Transformerよりもシンプル?「MLP-Mixer」爆誕(2日目) ~Mixer Architecture編~

15 min read読了の目安(約13700字

ニツオです。TwitterでAIやMLについて関連する話題を紹介してます。海外の研究者をフォローしていますので、情報源を増やしたい方はお気軽にフォローください。

さて、2021年5月にMLP-Mixerというモデルが爆誕しました。本日はその解説シリーズ2日目です。

  • 1日目: Abstract / Introduction
  • 2日目: Mixer Architecture
  • 3日目: Experiments
  • 4日目: Related Work
  • 5日目: Conclusion
  • 6日目: Appendix
  • 7日目: Source Code

「MLP-Mixer: An all-MLP Architecture for Vision」の原文はこちらです。2021年5月4日にGoogle ResearchとGoogle Brainの混合チームが発表され、関係者のTwitterでもかなり話題になっています。

シリーズ関連記事は一番下にリンク貼ってます。
早速みていきましょう。

2 Mixer Architecture

Modern deep vision architectures consist of layers that mix features (i) at a given spatial location, (ii) between different spatial locations, or both at once.

最近のディープビジョンのアーキテクチャは、(i)特定の空間位置での特徴、(ii)異なる空間位置間での特徴、またはその両方を一度に混合する層で構成されている。

In CNNs, (ii) is implemented with N × N convolutions (for N > 1) and pooling.

CNNでは、(ii) は形状が N \times N のフィルターでの畳み込み( N>1 )とプーリングで実装される。

プーリングとはDeepLでみると共同計算とあるのですが、機械学習においては、畳み込みの出力結果を平均したり、最大値をとったりして、データ的に圧縮してくれるレイヤーのことを言います。

https://knowledge.sakura.ad.jp/13726/

Neurons in deeper layers have a larger receptive field [1, 27].

深い層のニューロンは受容野が大きい。

At the same time, 1×1 convolutions also perform (i), and larger kernels perform both (i) and (ii).

Introductionでも触れたが、1 \times 1 のサイズのフィルターで畳み込みをした場合が (i) だ。それより大きいカーネル(=フィルターのこと)であれば、(i) (ii) のどちらにもなる。

In Vision Transformers and other attention-based architectures, self-attention layers allow both (i) and (ii) and the MLP-blocks perform (i).

Vision Transformers(ViTと呼ばれる)をはじめとするAttentionベースのアーキテクチャでは、Self-Attentionレイヤーが(i)と(ii)の両方を可能にし、MLPブロックが(i)を実行する。

The idea behind the Mixer architecture is to clearly separate the per-location (channel-mixing) operations (i) and cross-location (token-mixing) operations (ii).

Mixerのアーキテクチャの考え方は、位置ごと(チャネル混合)の操作(i)と位置横断(トークン混合)の操作(ii)を明確に分けること。

Both operations are implemented with MLPs.

どちらの操作もMLPで行う。

Figure 1 summarizes the architecture.

改めて図1をのせる。

Mixer takes as input a sequence of S non-overlapping image patches, each one projected to a desired hidden dimension C.

Mixerは、S 個(図で例えると9個)の非重複画像パッチのシーケンス(配列データ)を入力とする。各パッチは隠れ層の次元 C に射影され、次元が変わる(計算しやすい都合のよい次元数に変えたいから)。

This results in a two-dimensional real-valued input table, \mathrm{X} \in \mathbb{R}^{S \times C} .

この結果、2階テンソルの(つまり行列形式の)実数値の入力テーブル(行列)である \mathrm{X} \in \mathbb{R}^{S \times C} が出来上がる。\mathbb{R} は実数の集合のことで、その次元が S \times C です。

基本的に人間が現実世界においてイメージできる次元は1~3次元までで、時間を入れても4次元といったところですが、数学的に、線形代数的には何次元でもとることができます。あるベクトル空間(といってもイメージは出来ない)を表現するために必要な独立したベクトルの数、つまり一時独立なベクトルの数が次元数だ、という定義だからです。

つまり、\mathrm{X} は、2階テンソルの形、つまり行列であり、縦×横=S \times C の形状をしており、要素数は S \times C 個あり、その要素の数だけ実数空間が広がってる、ということです。

If the original input image has resolution (H, W), and each patch has resolution (P, P), then the number of patches is S = HW/P^2.

元の入力画像の解像度が縦×横 (H, W) のサイズで、小さく切り分けた各パッチの解像度が縦×横 (P, P) とすると、パッチの数は非重複と定義したので S = HW/P^2 となります。

割り切れる数にしておかないと、余りが出ちゃいますね。後で出てきますが、Sはもちろん整数で49、196、256と色々変えながら実験されます。HはHeight、WはWidth、PはPatchの頭文字です。Pも後で出てきますが、だいたい32か16です。

All patches are linearly projected with the same projection matrix.

すべてのパッチは、同じ射影行列で線形的に射影(線形代数用語)されます。ある特徴を表すベクトルを別のベクトルで表現しなおすイメージです。

Mixer consists of multiple layers of identical size, and each layer consists of two MLP blocks.

Mixerは同じ大きさの複数のレイヤー層で構成されており、各層は2つのMLPブロックで構成されています。下記の図の通りに左から右に進んでいくにあたって、データの大きさは変わらず、MLPブロックが2回出てくる構造だ、ということです。

The first one is the token-mixing MLP block: it acts on columns of \mathrm{X} (i.e. it is applied to a transposed input table \mathrm{X^T}), maps \mathbb{R}^S → \mathbb{R}^S, and is shared across all columns.

1つ目、トークン混合MLPブロックです。これは、入力テーブル(=入力行列)\mathrm{X} の要素数が S 個あるそれぞれの列部分に作用し(すなわち,転置された入力テーブル \mathrm{X^T} のそれぞれの行部分に適用される)、\mathbb{R}^S → \mathbb{R}^S の次元に射影するもの(つまりMLPを通してもデータの形状が変わらない)です。また、入力 \mathrm{X} のすべての列で同じMLPが使われます。

1つのパッチは元画像を小さく切り分けた縦と横のサイズが同じな画像を、行ベクトルに変換され、それを集合したものを1つの入力データ\mathrm{X} としてますが、その1列1列が1チャネルで C 列あります。その1列のチャネルの中の各行の値がトークンです。つまり、1チャネルの中のトークンを複数入力とするMLPなので、トークン混合MLPと定義されてます。

The second one is the channel-mixing MLP block: it acts on rows of \mathrm{X}, maps \mathbb{R}^C → \mathbb{R}^C , and is shared across all rows.

2つ目は、チャネル混合MLPブロックです。これは入力行列 \mathrm{X} の行に作用するもので、1行あたりC 個の要素があるので、\mathbb{R}^C → \mathbb{R}^C の次元に射影するもの(つまりさきほどと同じく、データの形状はMLPを通しても変わらない)で、すべての行で同じMLPが共有されます。

Each MLP block contains two fully-connected layers and a non-linearity applied independently to each row of its input data tensor.

それぞれのMLPブロックは、2つの全結合層(下図のFully-connected)と、その入力データテンソルの各行に独立して適用される非線形変換レイヤー(下図のGELU)で出てきています。

行列とテンソルは異なるものですが、一旦区別しなくてもよいでしょう。詳しく違いを知りたい方はこちらの記事が参考になりました。

https://mathwords.net/tensor
https://www.mynote-jp.com/entry/TensorAndMatrixRelation

Mixer layers can be written as follows (omitting layer indices):

Mixerのレイヤーは、以下のように書くことができます(レイヤーのインデックスは数式上省略しています)。

\begin{aligned} \mathrm{U}_{∗,i} &= \mathrm{X}_{∗,i} + \mathrm{W}_2 \ σ \ (\mathrm{W}_1 \ \mathrm{LayerNorm}(\mathrm{X})_{∗,i}), \ \mathrm{for} \ i = 1,..., C \\ \mathrm{Y}_{j,∗} &= \mathrm{U}_{j,∗} + \mathrm{W}_4 \ σ \ (\mathrm{W}_3 \ \mathrm{LayerNorm}(\mathrm{U})_{j,∗}), \ \mathrm{for} \ j = 1,..., S \end{aligned}

インデックスとは、このレイヤーは下記図のように N_X 回の層になってるのですが、その N_X のことです。行列のインデックスの場合はその行番号と列番号をインデックスといいますが、ここではレイヤーインデックスなので、どのMixer Layerかを指してると思います。

Here σ is an element-wise nonlinearity (GELU [16]).

ここで、σ(シグマと読む。ギリシャ文字の1つ)は、下記図にでてくるMLP内部のGELU部分で入力の要素ごとに計算される非線形の変換関数を表したものです。GELUはGaussian Error Linear Unitsの略で、RELUやシグモイド関数などと同じ活性化関数の1種です。

GELUの詳細は割愛しますが、RELUと比べて、ちょっと滑らかな曲線カーブを描きます。こちらの記事を参照しました。

https://data-analytics.fun/2020/09/04/understanding-gelu/

で、まず数式の1行目の \mathrm{U} の導出ですが、下記の図のように左前半の出力結果を \mathrm{U} としています。\mathrm{U} はパーセプトロンの出力結果としてよく使われる変数です。

上記の図のように、左から言葉でとらえていくと、

  1. 入力データは、行列 \mathrm{X} であり、その縦×横のサイズは S \times C である。あとで出てきますが、S \in 49, 196, 256C \in 512, 768, 1024, 1280で実験されます
  2. 1つ目の数式 \mathrm{U} は、トークン混合MLPがある前半ブロックの出力列ベクトルである
  3. \mathrm{U}_{∗,i} の添え字 * は、* 行目 i 列目の要素を指す。* はワイルドカードで、「すべて」を表すので、つまり i 番目の列ベクトルを指します。i = 1,..., C です。

これで数式の左辺がはっきりしました。

\begin{aligned} \mathrm{U}_{∗,i} &= \mathrm{X}_{∗,i} + \mathrm{W}_2 \ σ(\mathrm{W}_1 \ \mathrm{LayerNorm}(\mathrm{X})_{∗,i}), \ \mathrm{for} \ i = 1,..., C \end{aligned}

次に右辺の説明ですが、

  1. 入力行列の各列ベクトル \mathrm{X}_{*, i} がSkip Connectionで最後で同じく i 番目の列に加算されます
  2. それとは別に入力行列の各列 \mathrm{X}_{*, i} がLayer Norm層で正規化されます。この時の出力は \mathrm{LayerNorm}(\mathrm{X})_{*, i}。行列の正規化とは各要素の平均を0、標準偏差が1になるように変換すること。中の計算は意識しなくてもいいので、関数の形で書かれてます。細かい話ですが、{*, i} が関数の外にあるのは、正規化自体は行列全体に対して行われるので、正規化されたものに対して、改めて i 番目の列ベクトルをとるためです。
  3. その出力が、MLP1に入力するために、転置されて、行ベクトル\mathrm{LayerNorm}(\mathrm{X})_{i, *} となり、MLP1に入力される
  4. MLP1の中で、最初にFully-connectedレイヤーで重み行列 \mathrm{W}^{'}_1 が右からかけられて、\mathrm{LayerNorm}(\mathrm{X})_{i, *} \mathrm{W}^{'}_1 となる。この際、バイアスは定義が書かれていないので正確には不明ですが、ここでは論点ではないので無視します
  5. MLP1の中で、GELUで非線形変換される。GELU関数を \sigma() という関数だと定義してるので、引数に上の値が入って、\sigma(\mathrm{LayerNorm}(\mathrm{X})_{i, *} \mathrm{W}^{'}_1) となります
  6. MLP1の中で、再度Fully-connectedレイヤーで重み \mathrm{W}^{'}_2 が右からかけられて、\sigma(\mathrm{LayerNorm}(\mathrm{X})_{i, *} \mathrm{W}^{'}_1) \mathrm{W}^{'}_2 となる
  7. MLP2に入力するために、また転置される。(ABC)^T=C^T B^T A^T が成り立つので、(\mathrm{W}^{'}_1)^T=\mathrm{W}_1, \ (\mathrm{W}^{'}_2)^T=\mathrm{W}_2 と置きなおすと、\mathrm{X} に関しても添え字を入れ替えて、\mathrm{W}_2 \ \sigma(\mathrm{W}_1 \ \mathrm{LayerNorm}(\mathrm{X})_{i, *})

です。これで下記数式の右辺も説明できました。

\begin{aligned} \mathrm{U}_{∗,i} &= \mathrm{X}_{∗,i} + \mathrm{W}_2 \ σ(\mathrm{W}_1 \ \mathrm{LayerNorm}(\mathrm{X})_{∗,i}), \ \mathrm{for} \ i = 1,..., C \end{aligned}

こんな風に複雑に考えなくても、1つ目のトークン混合MLPは列ベクトルを処理できるMLPだと考えれば、転置の部分が数式上表現しなくてもよくなるので、一発で導出できるようにも思います。その際は重み W をかけるのが右側からだったのが、左側からかけることになります。

正直、この部分でかなり時間を使ってしまったので、困り果てて、著者のLucas Beyerさん(Google Brainスイス)に直接メールして教えてもらったのですが、重み W を左と右のどちらからかけるかは左からかけて表現することが多いし、今回もそうしているが、慣習的な問題で、あまり意味はない、ような趣旨の返事をもらいました。

行列の内積における交換の法則が成立しないですし、かける方向によっては W の形状も変わるので、重要なのでは・・・と思っていたのですが、実際のコーディング上ももしかしたらあまり意識しないのかもしれません。実装上も、MLPブロックは、Dense関数を使っており、そのDense関数のソースをたどると、dot_general関数が出てきて、それをさらにたどっていかないと中身はわかりません。

D_S and D_C are tunable hidden widths in the token-mixing and channel-mixing MLPs, respectively.

D_SD_C は調整可能な隠れ層の幅(次元数)で、それぞれトークン混合MLPとチャネル混合MLPで使われます。実際、後続の実験でこれらの値はいくつかのパターンで試されます。

Note that D_S is selected independently of the number of input patches.

ちなみに、D_S は入力パッチの数 S とは独立して決めることができます。

Therefore, the computational complexity of the network is linear in the number of input patches, unlike ViT whose complexity is quadratic.

ここまでの話より、このネットワークの計算量は、入力パッチ数に対して線形に比例する形となり、それが二次関数的な計算量になるViTとは異なる。

Since D_C is independent of the patch size, the overall complexity is linear in the number of pixels in the image, as for a typical CNN.

D_C はパッチのサイズとは独立しているので、全体の計算量は、一般的なCNNと同様に、画像のピクセル数に依存する。

As mentioned above, the same channel-mixing MLP (token-mixing MLP) is applied to every row (column) of X.

前述のとおり、入力 X の各行に対して、同じチャネル混合MLPが使われる。同様に、入力 X の各列に対して、同じトークン混合MLPが使われる。

Tying the parameters of the channel-mixing MLPs (within each layer) is a natural choice—it provides positional invariance, a prominent feature of convolutions.

チャンネル混合MLPの重みなどのパラメータを(各レイヤー内で)結びつける、つまり同じものを使うことは自然な選択であり、畳み込みの顕著な特徴である位置不変性をもたらします。

However, tying parameters across channels is much less common.

しかし、チャネル間でパラメータを結びつけることはあまり一般的ではありません。

For example, separable convolutions [9, 39], used in some CNNs, apply convolutions to each channel independently of the other channels.

例えば、CNNの一部で使われている、分離可能な畳み込み処理(別論文を参照)は、各チャネルに他のチャネルとは別の畳み込み処理を施します。

However, in separable convolutions, a different convolutional kernel is applied to each channel unlike the token-mixing MLPs in Mixer that share the same kernel (of full receptive field) for all of the channels.

ただ、この分離可能な畳み込み処理では、トークン混合MLPは各チャネルに同じカーネルを共有するのと違って、各チャネルに別のカーネルを使用することになってしまう。

The parameter tying prevents the architecture from growing too fast when increasing the hidden dimension C or the sequence length S and leads to significant memory savings.

隠れ次元 C や、配列の長さ S を増加させたときに、パラメータの共有により、アーキテクチャが急激に成長することを防ぎ、メモリを大幅に節約することができます。

Surprisingly, this choice does not affect the empirical performance, see Supplementary A.1.

驚くべきことにこの選択は、実験上、パフォーマンスに影響を与えません。補足A.1を参照してください。これは別の回で見てみましょう。

Each layer in Mixer (except for the initial patch projection layer) takes an input of the same size.

Mixerの各レイヤー(最初のパッチ射影レイヤーを除く)は、同じサイズの入力を受け取ります。

This “isotropic” design is most similar to Transformers, or deep RNNs in other domains, that also use a fixed width.

この「等方的な」デザインは、Transformerや他の領域のディープRNNに最も似ていますが、これらも固定の幅(次元やサイズ)を使用しています。

This is unlike most CNNs, which have a pyramidal structure: deeper layers have a lower resolution input, but more channels.

これは、多くのCNNがピラミッド型の構造(深い層は、入力の解像度は低いですが、チャネル数は多い)をしているのとは異なります。

Note that while these are the typical designs, other combinations exist, such as isotropic ResNets [37] and pyramidal ViTs [50].

なお、これらは典型的なデザインと比較した話ですが、等方性のResNetsやピラミッド型のViTなど、他の組み合わせも存在します。

Aside from the MLP layers, Mixer uses other standard architectural components: skip-connections [15] and Layer Normalization [2].

ちなみに、MLPレイヤー以外にも、Mixerは機構を色々持っていて、スキップコネクションや正規化レイヤーといった標準的なアーキテクチャコンポーネントを使用しています。

Furthermore, unlike ViTs, Mixer does not use position embeddings because the token-mixing MLPs are sensitive to the order of the input tokens, and therefore may learn to represent location.

さらに、多くのViTとは異なり、Mixerは位置エンコーディングを使用していません。これは、トークン混合MLPが入力トークンの順序に敏感であるため、位置の表現を学習する可能性があるためです。

Finally, Mixer uses a standard classification head with the global average pooling layer followed by a linear classifier.

最後に、Mixerでは、グローバルアベレージプーリング層の後に線形分類器を使用した標準的な分類ヘッドを使用しています。

この図における、上の部分ですね。

Overall, the architecture can be written
compactly in JAX/Flax, the code is given in Supplementary.

全体として、このアーキテクチャはJAX/Flaxでコンパクトに書くことができます。JAX/Flaxでコンパクトに書くことができます。

JAX/Flaxというのは、こう公式に説明があります。

Flax: A neural network library and ecosystem for JAX designed for flexibility.
Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.

つまり、Google Brain(本論文の著者もGoogle Brain)のエンジニアや研究者たちが開発したもので、ニューラルネットワーク専用のライブラリだそうです。今はオープンソースとして、扱われているようです。

おわり

「MLP-Mixer」を解説するシリーズ2日目は以上です。次回はExperimentsです。

感想や要望・指摘等は、本記事へのコメントか、TwitterのリプライやDMでもお待ちしております!

https://twitter.com/hnishio0105/status/1395703350759333895?s=20

また、結構な時間を費やして書いていますので、投げ銭・サポートの程、よろしくお願いいたします!

シリーズ関連記事はこちら

https://zenn.dev/attentionplease/articles/532a3de6308f57
https://zenn.dev/attentionplease/articles/7a11a56d767280
https://zenn.dev/attentionplease/articles/df6170f8581b71
https://zenn.dev/attentionplease/articles/7a3e74ad1bc9bf
https://zenn.dev/attentionplease/articles/a0d88939f9ceed
https://zenn.dev/attentionplease/articles/719580daf5a2d1