Transformerよりもシンプル?「MLP-Mixer」爆誕(2日目) ~Mixer Architecture編~
ニツオです。TwitterでAIやMLについて関連する話題を紹介してます。海外の研究者をフォローしていますので、情報源を増やしたい方はお気軽にフォローください。
さて、2021年5月にMLP-Mixerというモデルが爆誕しました。本日はその解説シリーズ2日目です。
- 1日目: Abstract / Introduction
- 2日目: Mixer Architecture
- 3日目: Experiments
- 4日目: Related Work
- 5日目: Conclusion
- 6日目: Appendix
- 7日目: Source Code
「MLP-Mixer: An all-MLP Architecture for Vision」の原文はこちらです。2021年5月4日にGoogle ResearchとGoogle Brainの混合チームが発表され、関係者のTwitterでもかなり話題になっています。
- 論文の要約はこちら
- 論文のPDFはこちら
- 論文のコードはこちら ※但し、2021年5月16日時点ではいわゆるMasterブランチにはまだ反映されていませんので、LinenブランチのURLを貼ってます
シリーズ関連記事は一番下にリンク貼ってます。
早速みていきましょう。
2 Mixer Architecture
Modern deep vision architectures consist of layers that mix features (i) at a given spatial location, (ii) between different spatial locations, or both at once.
最近のディープビジョンのアーキテクチャは、(i)特定の空間位置での特徴、(ii)異なる空間位置間での特徴、またはその両方を一度に混合する層で構成されている。
In CNNs, (ii) is implemented with N × N convolutions (for N > 1) and pooling.
CNNでは、(ii) は形状が
プーリングとはDeepLでみると共同計算とあるのですが、機械学習においては、畳み込みの出力結果を平均したり、最大値をとったりして、データ的に圧縮してくれるレイヤーのことを言います。
Neurons in deeper layers have a larger receptive field [1, 27].
深い層のニューロンは受容野が大きい。
At the same time, 1×1 convolutions also perform (i), and larger kernels perform both (i) and (ii).
Introductionでも触れたが、
In Vision Transformers and other attention-based architectures, self-attention layers allow both (i) and (ii) and the MLP-blocks perform (i).
Vision Transformers(ViTと呼ばれる)をはじめとするAttentionベースのアーキテクチャでは、Self-Attentionレイヤーが(i)と(ii)の両方を可能にし、MLPブロックが(i)を実行する。
The idea behind the Mixer architecture is to clearly separate the per-location (channel-mixing) operations (i) and cross-location (token-mixing) operations (ii).
Mixerのアーキテクチャの考え方は、位置ごと(チャネル混合)の操作(i)と位置横断(トークン混合)の操作(ii)を明確に分けること。
Both operations are implemented with MLPs.
どちらの操作もMLPで行う。
Figure 1 summarizes the architecture.
改めて図1をのせる。
Mixer takes as input a sequence of
non-overlapping image patches, each one projected to a desired hidden dimension C. S
Mixerは、
This results in a two-dimensional real-valued input table,
. \mathrm{X} \in \mathbb{R}^{S \times C}
この結果、2階テンソルの(つまり行列形式の)実数値の入力テーブル(行列)である
基本的に人間が現実世界においてイメージできる次元は1~3次元までで、時間を入れても4次元といったところですが、数学的に、線形代数的には何次元でもとることができます。あるベクトル空間(といってもイメージは出来ない)を表現するために必要な独立したベクトルの数、つまり一時独立なベクトルの数が次元数だ、という定義だからです。
つまり、
If the original input image has resolution
, and each patch has resolution (H, W) , then the number of patches is (P, P) . S = HW/P^2
元の入力画像の解像度が縦×横
割り切れる数にしておかないと、余りが出ちゃいますね。後で出てきますが、Sはもちろん整数で49、196、256と色々変えながら実験されます。HはHeight、WはWidth、PはPatchの頭文字です。Pも後で出てきますが、だいたい32か16です。
All patches are linearly projected with the same projection matrix.
すべてのパッチは、同じ射影行列で線形的に射影(線形代数用語)されます。ある特徴を表すベクトルを別のベクトルで表現しなおすイメージです。
Mixer consists of multiple layers of identical size, and each layer consists of two MLP blocks.
Mixerは同じ大きさの複数のレイヤー層で構成されており、各層は2つのMLPブロックで構成されています。下記の図の通りに左から右に進んでいくにあたって、データの大きさは変わらず、MLPブロックが2回出てくる構造だ、ということです。
The first one is the token-mixing MLP block: it acts on columns of
(i.e. it is applied to a transposed input table \mathrm{X} ), maps \mathrm{X^T} , and is shared across all columns. \mathbb{R}^S → \mathbb{R}^S
1つ目、トークン混合MLPブロックです。これは、入力テーブル(=入力行列)
1つのパッチは元画像を小さく切り分けた縦と横のサイズが同じな画像を、行ベクトルに変換され、それを集合したものを1つの入力データ
The second one is the channel-mixing MLP block: it acts on rows of
, maps \mathrm{X} , and is shared across all rows. \mathbb{R}^C → \mathbb{R}^C
2つ目は、チャネル混合MLPブロックです。これは入力行列
Each MLP block contains two fully-connected layers and a non-linearity applied independently to each row of its input data tensor.
それぞれのMLPブロックは、2つの全結合層(下図のFully-connected)と、その入力データテンソルの各行に独立して適用される非線形変換レイヤー(下図のGELU)で出てきています。
行列とテンソルは異なるものですが、一旦区別しなくてもよいでしょう。詳しく違いを知りたい方はこちらの記事が参考になりました。
Mixer layers can be written as follows (omitting layer indices):
Mixerのレイヤーは、以下のように書くことができます(レイヤーのインデックスは数式上省略しています)。
インデックスとは、このレイヤーは下記図のように
Here
is an element-wise nonlinearity (GELU [16]). σ
ここで、
GELUの詳細は割愛しますが、RELUと比べて、ちょっと滑らかな曲線カーブを描きます。こちらの記事を参照しました。
で、まず数式の1行目の
上記の図のように、左から言葉でとらえていくと、
- 入力データは、行列
であり、その縦×横のサイズは\mathrm{X} である。あとで出てきますが、S \times C 、S \in 49, 196, 256 で実験されますC \in 512, 768, 1024, 1280 - 1つ目の数式
は、トークン混合MLPがある前半ブロックの出力列ベクトルである\mathrm{U} -
の添え字\mathrm{U}_{∗,i} は、* 行目* 列目の要素を指す。i はワイルドカードで、「すべて」を表すので、つまり* 番目の列ベクトルを指します。i です。i = 1,..., C
これで数式の左辺がはっきりしました。
次に右辺の説明ですが、
- 入力行列の各列ベクトル
がSkip Connectionで最後で同じく\mathrm{X}_{*, i} 番目の列に加算されますi - それとは別に入力行列の各列
がLayer Norm層で正規化されます。この時の出力は\mathrm{X}_{*, i} 。行列の正規化とは各要素の平均を0、標準偏差が1になるように変換すること。中の計算は意識しなくてもいいので、関数の形で書かれてます。細かい話ですが、\mathrm{LayerNorm}(\mathrm{X})_{*, i} が関数の外にあるのは、正規化自体は行列全体に対して行われるので、正規化されたものに対して、改めて{*, i} 番目の列ベクトルをとるためです。i - その出力が、MLP1に入力するために、転置されて、行ベクトル
となり、MLP1に入力される\mathrm{LayerNorm}(\mathrm{X})_{i, *} - MLP1の中で、最初にFully-connectedレイヤーで重み行列
が右からかけられて、\mathrm{W}^{'}_1 となる。この際、バイアスは定義が書かれていないので正確には不明ですが、ここでは論点ではないので無視します\mathrm{LayerNorm}(\mathrm{X})_{i, *} \mathrm{W}^{'}_1 - MLP1の中で、GELUで非線形変換される。GELU関数を
という関数だと定義してるので、引数に上の値が入って、\sigma() となります\sigma(\mathrm{LayerNorm}(\mathrm{X})_{i, *} \mathrm{W}^{'}_1) - MLP1の中で、再度Fully-connectedレイヤーで重み
が右からかけられて、\mathrm{W}^{'}_2 となる\sigma(\mathrm{LayerNorm}(\mathrm{X})_{i, *} \mathrm{W}^{'}_1) \mathrm{W}^{'}_2 - MLP2に入力するために、また転置される。
が成り立つので、(ABC)^T=C^T B^T A^T と置きなおすと、(\mathrm{W}^{'}_1)^T=\mathrm{W}_1, \ (\mathrm{W}^{'}_2)^T=\mathrm{W}_2 に関しても添え字を入れ替えて、\mathrm{X} \mathrm{W}_2 \ \sigma(\mathrm{W}_1 \ \mathrm{LayerNorm}(\mathrm{X})_{i, *})
です。これで下記数式の右辺も説明できました。
こんな風に複雑に考えなくても、1つ目のトークン混合MLPは列ベクトルを処理できるMLPだと考えれば、転置の部分が数式上表現しなくてもよくなるので、一発で導出できるようにも思います。その際は重み
正直、この部分でかなり時間を使ってしまったので、困り果てて、著者のLucas Beyerさん(Google Brainスイス)に直接メールして教えてもらったのですが、重み
行列の内積における交換の法則が成立しないですし、かける方向によっては
and D_S are tunable hidden widths in the token-mixing and channel-mixing MLPs, respectively. D_C
Note that
is selected independently of the number of input patches. D_S
ちなみに、
Therefore, the computational complexity of the network is linear in the number of input patches, unlike ViT whose complexity is quadratic.
ここまでの話より、このネットワークの計算量は、入力パッチ数に対して線形に比例する形となり、それが二次関数的な計算量になるViTとは異なる。
Since
is independent of the patch size, the overall complexity is linear in the number of pixels in the image, as for a typical CNN. D_C
As mentioned above, the same channel-mixing MLP (token-mixing MLP) is applied to every row (column) of X.
前述のとおり、入力
Tying the parameters of the channel-mixing MLPs (within each layer) is a natural choice—it provides positional invariance, a prominent feature of convolutions.
チャンネル混合MLPの重みなどのパラメータを(各レイヤー内で)結びつける、つまり同じものを使うことは自然な選択であり、畳み込みの顕著な特徴である位置不変性をもたらします。
However, tying parameters across channels is much less common.
しかし、チャネル間でパラメータを結びつけることはあまり一般的ではありません。
For example, separable convolutions [9, 39], used in some CNNs, apply convolutions to each channel independently of the other channels.
例えば、CNNの一部で使われている、分離可能な畳み込み処理(別論文を参照)は、各チャネルに他のチャネルとは別の畳み込み処理を施します。
However, in separable convolutions, a different convolutional kernel is applied to each channel unlike the token-mixing MLPs in Mixer that share the same kernel (of full receptive field) for all of the channels.
ただ、この分離可能な畳み込み処理では、トークン混合MLPは各チャネルに同じカーネルを共有するのと違って、各チャネルに別のカーネルを使用することになってしまう。
The parameter tying prevents the architecture from growing too fast when increasing the hidden dimension
or the sequence length C and leads to significant memory savings. S
隠れ次元
Surprisingly, this choice does not affect the empirical performance, see Supplementary A.1.
驚くべきことにこの選択は、実験上、パフォーマンスに影響を与えません。補足A.1を参照してください。これは別の回で見てみましょう。
Each layer in Mixer (except for the initial patch projection layer) takes an input of the same size.
Mixerの各レイヤー(最初のパッチ射影レイヤーを除く)は、同じサイズの入力を受け取ります。
This “isotropic” design is most similar to Transformers, or deep RNNs in other domains, that also use a fixed width.
この「等方的な」デザインは、Transformerや他の領域のディープRNNに最も似ていますが、これらも固定の幅(次元やサイズ)を使用しています。
This is unlike most CNNs, which have a pyramidal structure: deeper layers have a lower resolution input, but more channels.
これは、多くのCNNがピラミッド型の構造(深い層は、入力の解像度は低いですが、チャネル数は多い)をしているのとは異なります。
Note that while these are the typical designs, other combinations exist, such as isotropic ResNets [37] and pyramidal ViTs [50].
なお、これらは典型的なデザインと比較した話ですが、等方性のResNetsやピラミッド型のViTなど、他の組み合わせも存在します。
Aside from the MLP layers, Mixer uses other standard architectural components: skip-connections [15] and Layer Normalization [2].
ちなみに、MLPレイヤー以外にも、Mixerは機構を色々持っていて、スキップコネクションや正規化レイヤーといった標準的なアーキテクチャコンポーネントを使用しています。
Furthermore, unlike ViTs, Mixer does not use position embeddings because the token-mixing MLPs are sensitive to the order of the input tokens, and therefore may learn to represent location.
さらに、多くのViTとは異なり、Mixerは位置エンコーディングを使用していません。これは、トークン混合MLPが入力トークンの順序に敏感であるため、位置の表現を学習する可能性があるためです。
Finally, Mixer uses a standard classification head with the global average pooling layer followed by a linear classifier.
最後に、Mixerでは、グローバルアベレージプーリング層の後に線形分類器を使用した標準的な分類ヘッドを使用しています。
この図における、上の部分ですね。
Overall, the architecture can be written
compactly in JAX/Flax, the code is given in Supplementary.
全体として、このアーキテクチャはJAX/Flaxでコンパクトに書くことができます。JAX/Flaxでコンパクトに書くことができます。
JAX/Flaxというのは、こう公式に説明があります。
Flax: A neural network library and ecosystem for JAX designed for flexibility.
Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.
つまり、Google Brain(本論文の著者もGoogle Brain)のエンジニアや研究者たちが開発したもので、ニューラルネットワーク専用のライブラリだそうです。今はオープンソースとして、扱われているようです。
おわり
「MLP-Mixer」を解説するシリーズ2日目は以上です。次回はExperimentsです。
感想や要望・指摘等は、本記事へのコメントか、TwitterのリプライやDMでもお待ちしております!
また、結構な時間を費やして書いていますので、投げ銭・サポートの程、よろしくお願いいたします!
シリーズ関連記事はこちら
【2023年5月追記】
また、Slack版ChatGPT「Q」というサービスを開発・運営しています。
こちらもぜひお試しください。
Discussion