📺

Transformerよりもシンプル?「MLP-Mixer」爆誕(3日目) ~Experiments 1編~

2021/05/23に公開

ニツオです。TwitterでAIやMLについて関連する話題を紹介してます。海外の研究者をフォローしていますので、情報源を増やしたい方はお気軽にフォローください。

さて、2021年5月にMLP-Mixerというモデルが爆誕しました。本日はその解説シリーズ3日目です。

  • 1日目: Abstract / Introduction
  • 2日目: Mixer Architecture**
  • 3日目: Experiments 1
  • 4日目: Experiments 2**
  • 5日目: Related Work
  • 6日目: Conclusion
  • 7日目: Appendix
  • 8日目: Source Code

「MLP-Mixer: An all-MLP Architecture for Vision」の原文はこちらです。2021年5月4日にGoogle ResearchとGoogle Brainの混合チームから発表され、関係者のTwitterでもかなり話題になっています。

シリーズ関連記事は一番下にリンク貼ってます。
早速みていきましょう。

3 Experiments

We evaluate the performance of MLP-Mixer models, pre-trained with medium- to large-scale datasets, on a range of small and mid-sized downstream classification tasks.

中規模から大規模なデータセットで事前に学習させたMLP-Mixerモデルを、さらに中小規模の別の分類タスクで学習させ、それを評価します。

downstreamはStackoverflowでも質問されていて、時々使われるわりにはネイティブにとっても意味がわかりにくいようですが、学習を前半・後半にわけて、前半の学習をUpstream、後半の学習(主に微調整が目的)をDownstreamとしてるようです。

そもそも事前学習で大量データである程度が学習しておき、特定のタスクに向けて、転移学習として、大量ではないデータで微調整的に再度学習するのがここ最近の主流のようです。

We are interested in three primary quantities:
(1) Accuracy on the downstream task.
(2) Total computational cost of pre-training, which is important when training the model from scratch on the upstream dataset.
(3) Throughput at inference time, which is important to the practitioner.

次の3つの主要な数値をみていく。

  1. Downstreamタスクに対する正確性・精度
  2. 学習のトータル計算時間。これはスクラッチからモデルを学習させる使い方をする場合に重要(Upstream+Downstreamと思われます)
  3. 推論時の処理能力。実務者や利用者にとって重要

Our goal is not to demonstrate state-of-the-art results, but to show that, remarkably, a simple MLP-based model is competitive with today’s best convolutional and attention-based models.

私たちの目標は、最先端の結果(State of the Art=SOTA)を達成することではなく、ただのシンプルなMLPベースのモデルが、今日(こんにち)の最高モデルとされる畳み込みモデルやAttentionベースのモデルと比べて遜色ないことを示すことです。これが示せれば驚くべきことである(そして当然示されることになる)

このCompetitiveが重要ですね。別に性能で勝利する必要はなくて、同じくらいの性能を示せれば、MLPモデルがシンプルであるがゆえに、十分すごいことになるからです。

Downstream tasks

We use multiple popular downstream tasks such as ILSVRC2012 “ImageNet” (1.3M training examples, 1k classes) with the original validation labels [13] and cleaned-up ReaL labels [5], CIFAR-10/100 (50k examples, 10/100 classes) [23], Oxford-IIIT Pets (3.7k examples, 36 classes) [31], and Oxford Flowers-102 (2k examples, 102 classes) [30].

性能調整には、一般的なデータを複数使用しています。例えば、

  • ILSVRC2012 "ImageNet" (130万個の学習用データ、1,000クラス・種類)+オリジナルの検証ラベルとクリーンアップされたReaLラベル
  • CIFAR-10/100 (5万個の学習データ、10クラスと100クラス)
  • Oxford-IIIT Pets (3700個の学習データ、36クラス)
  • Oxford Flowers-102 (2,000個の学習データ、102クラス)

など。
いくつか主要なデータセットが出てきましたので、2つほど紹介しておきます。

ILSVRC2012 "ImageNet":

ImageNet Large Scale Visual Recognition Challenge 2012 の頭文字をとったデータセット名。スタンフォード大学が運営。

この2012のVersionだとクラス=000~999までの1,000種類の画像がある(もっと最新の大量のデータセットももちろんある)。動物とか建物とか小物とかが対象。

公式だと写真を気軽に見れないので、こちらの記事が参考になりました。意外ときれいな写真ではなくて、人が写り込んでたり、光が強すぎたり、クジラとかは海にほとんど隠れちゃってました。学習データってこういうのでもいいんですね。サイズもばらばらです。

サンプル例

https://image-net.org/challenges/LSVRC/2012/
http://starpentagon.net/analytics/ilsvrc2012_class_image/
https://starpentagon.net/analytics/ilsvrc2012_class_image_01/

CIFAR-10/100:

Canadian Institute For Advanced Researchの頭文字をとったデータセット名。CIFAR-10とCIFAR-100がある。

CIFAR-10であれば、ラベル0~9の10種類の画像セットがあり、飛行機とか、猫とか、犬とか、自動車など。作ったのはウクライナ人のAlex Krizhevsky。2012年のImageNetの画像認識コンペで優勝して、自分の会社をGoogleに売却した人です。公式サイトと解説記事を載せておきます。

サンプル例。そして人間的な感覚で言うと、画像がけっこう荒い

https://www.cs.toronto.edu/~kriz/cifar.html
https://www.atmarkit.co.jp/ait/articles/2006/10/news021.html

We also evaluate on the Visual Task Adaptation Benchmark (VTAB-1k), which consists of 19 diverse datasets, each with 1k training examples [55].

また、19種類の多様なデータセットからなるVTAB-1k(Visual Task Adaptation Benchmarkの略)でもDownstream評価を行いました。このVTAB-1kは、それぞれ1,000個の学習データ例で構成されています。

Google AI Blogからとってきたサンプル例。おなじみの猫や飛行機の画像もあるが、3行目の画像などは人間が作った塀と木??と思えるくらい謎の画像もある。

https://ai.googleblog.com/2019/11/the-visual-task-adaptation-benchmark.html

Pre-training data

We follow the standard transfer learning setup: pre-training followed by finetuning on the downstream tasks.

事前学習について。

標準的・一般的な伝達学習の設定に従っています。つまり、事前にトレーニングを行い、その後、後続のDownstreamタスクで微調整を行います。ここは前述した通り。

We pre-train all models on two public datasets: ILSVRC2021 ImageNet, and ImageNet-21k, a superset of ILSVRC2012 that contains 21k classes and 14M images [13].

2つの公開データセットを用いて,すべてのモデルの事前学習を行いました。ILSVRC2021 ImageNetと、ILSVRC2012の上位セットであるImageNet-21k(21,000個のクラスと1400万個の画像を含む)です。

To assess performance at even larger scale, we also train on JFT-300M, a proprietary dataset with 300M examples and 18k classes [43].

さらに大規模なデータセットでもパフォーマンスを評価するために、3億個の学習データと1.8万個のクラスを持つ独自のデータセットであるJFT-300Mでもトレーニングを行いました。

We de-duplicate all pre-training datasets with respect to the test sets of the downstream tasks as done in Dosovitskiy et al. [14], Kolesnikov et al. [22].

参考文献にある、Dosovitskiy(本論文の著者Lucasも共著だが)の論文 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale や Kolesnikov ら(これも本論文の著者Lucasも共著だが)の論文 Big transfer (BiT): General visual representation learning が行っているように、下流のタスクのテストセットに関して、すべての事前学習データセットの重複を除外します。

Pre-training details

We pre-train all models using Adam with β_1 = 0.9, β_2 = 0.999, and batch size 4,096, using weight decay, and gradient clipping at global norm 1.

事前学習の詳細。

最適化アルゴリズムにはAdamを用いる。そのパラメータは、β_1 = 0.9, β_2 = 0.999、バッチサイズ=4,096 で、重み W の減衰Weight Decayと、Gradient Clipping(勾配を閾値で打ち切る手法の名前)をそのパラメータのglobal norm =1 で用いて、すべてのモデルを事前に学習します。

Adamはよく使われる最適化アルゴリズムです。これはこちらの記事が図や絵がよくまとまっていて、わかりやすかったので紹介します。
https://qiita.com/omiita/items/1735c1d048fe5f611f80

We use a linear learning rate warmup of 10k steps and linear decay.

1万(計算)ステップまでは線形的な学習率でのウォームアップをして、そのあとは線形的な学習率の減衰をするようにしています。

We pre-train all models at resolution 224.

すべてのモデルを解像度 =224 dpi(dot per inch)で事前トレーニングします。dpi は、1インチ当たりのドット数(画素数)です。

For JFT-300M, we pre-process images by applying the cropping technique from Szegedy et al. [44] in addition to random horizontal flipping.

JFT-300Mでは、参考文献にあげた Szegedy らのクロッピング技術(切り取り)に加え、ランダムな水平反転を行うことで、画像の前処理を行っています。

For ImageNet and ImageNet-21k, we employ additional data augmentation and regularization techniques.
In particular, we use RandAugment [12], mixup [56], dropout [42], and stochastic depth [19].

ImageNetとImageNet-21kでは、追加的に、データの補強と正則化の技術を採用しています。
具体的には、RandAugment [12]、mixup [56]、dropout [42]、stochastic depth [19]を使用しています。

言葉の意味ですが、

  • RandAugment: データ量を増やす時の手法で、Googleが2019年に公開した。回転させたり、変形させたりすることでデータ量を増やす。同じくGoogleが公開したAutoAugmentより早く、現時点で最も優れた手法
  • mixup: データ量を増やすときの手法で、MITとFacebook AI Researchが2018年に公開した。2つのデータとそのラベルのセットをそれぞれに係数をかけてSumすることで新たな混合データ(mixup)を作る
  • dropout: 過学習を防ぐために、意図的に学習途中のデータをDropして後続に渡さない手法。これもImageNetを作ったAlex達トロント大学のメンバーが2014年にJMLRに公開した
  • stochastic depth: Dropoutはノードに対するOFFでしたが、これをレイヤー単位に拡張した手法で、Cornell Universityなどが2016年に公開した。Dropoutと同じように確率的に指定したレイヤーを落とすことで、レイヤーが深くなった際の勾配消失を防ぐ効果がある

いくつか論文の解説記事も参考になりました。
https://qiita.com/kitfactory/items/d89457eeab5c185880be
https://qiita.com/yu4u/items/70aa007346ec73b7ff05
https://qiita.com/supersaiakujin/items/eb0553a1ef1d46bd03fa

This set of techniques was inspired by the timm library [52] and Touvron et al. [46].

この一連の技術は、timm library [52]やTouvron et al [46]に触発されたものです。

More details on these hyperparameters are provided in Supplementary B.

これらのハイパーパラメータの詳細は、補足Bに記載されています。

こちらのことですね。Mixerの後に続く文字は、BがBase、LがLarge、SがSmall、HがHugeで一番重厚なモデルです。説明は割愛しますが、RandAugはMagnitudeの m、MixupはMixing Strengthの p、DropoutはDropping Rateの d、Stochastic Depthは変数名は与えられてないですが、そのドロップの確率が、それぞれ記載されてます。

Fine-tuning details

We fine-tune using SGD with momentum, batch size 512, gradient clipping at global norm 1, and a cosine learning rate schedule with a linear warmup.

Fine-Tuning(微調整)の詳細

momentum項ありのSGDを使って微調整します。バッチサイズ=512、グローバルノルム=1での勾配クリッピング、線形ウォームアップを伴うコサイン学習率を用います。

  • SGD: Stochastic Gradient Descentの略で、日本語だと確率的勾配降下法といい、最適化アルゴリズムの1つ
  • momentum: SGDが最適化される過程、つまり損失関数を最小化させるパラメータを見つける過程で、その損失自体が振動してしまう現象を抑制してくれる働きをもつ数式的な項。意味合い的には、前の状態を覚えていて、移動平均的な効果を与えてくれることによって、振動をならしてくれる

こちらの最適化アルゴリズムの説明はこちらの記事が参考になります。
https://qiita.com/omiita/items/1735c1d048fe5f611f80

We do not use weight decay when fine-tuning.
Following common practice [22, 46], we also apply fine-tune at higher resolutions with respect to those used during pre-training.

モデルのパラメータを微調整する際、Weight Decayは使用しません。また、一般的な慣習[22, 46]に従い、事前学習の際に使用した解像度よりも高い解像度のデータで微調整を行います。

仕上げなのでより解像度の高いリッチなデータを用いるイメージでしょうか。

また、Weight Decayとは重みの減衰なのですが、つまり、レイヤーが深くなるにつれてモデルとしては表現力が高くなるのは一般的にわかってることなのですが、それとともに過学習のリスクも高まります。そのバランスを保つ手法として、重み W の各要素の大きさを小さくしてしまうことによって、バイアスの方が相対的に大きくなり、深いレイヤーの効果を薄める手法がWeight Decayのイメージです。下記の記事も参考にしています。

https://qiita.com/supersaiakujin/items/97f4c0017ef76e547976

Since we keep the patch resolution fixed, this increases the number of input patches (say from S to S' ) and thus requires modifying the shape of Mixer’s token-mixing MLP blocks.

(明確に書いてないが)Fine-Tuningの段階では、事前学習の時より解像度の高い画像を使っているため、その解像度 H \times W は大きくなる。

一方で、パッチの解像度 P \times P32 \times 3216 \times 16 などに固定している。

なので、入力パッチの数 S は、S = HW/P^2 で定義されているので、入力パッチ数が増える(例えばSからS'に)

したがって、S が大きくなってしまうと、Mixerの入口にあるトークン混合MLPブロックの形状(行列の縦×横の意味)を変更する必要があります。

Formally, the input in Eq.(1) is left-multiplied by a weight matrix \mathrm{W_1} ∈ \mathbb{R}^{D_S×S} and this operation has to be adjusted when changing the input dimension S.

数式的には、式(1)の入力で次元(つまり縦の要素数が)が元々 S であった列ベクトル \mathrm{X} に対して、重み行列 \mathrm{W_1} ∈ \mathbb{R}^{D_S×S} が左乗されることであり、入力次元 S を変える際にはこの操作を調整する必要がある。

数式(1)はこれでしたね。

\begin{aligned} \mathrm{U}_{∗,i} &= \mathrm{X}_{∗,i} + \mathrm{W}_2 \ σ(\mathrm{W}_1 \ \mathrm{LayerNorm}(\mathrm{X})_{∗,i}), \ \mathrm{for} \ i = 1,..., C \end{aligned}

具体的には元論文の補足Cにありますが、この数式の重み \mathrm{W_1} ∈ \mathbb{R}^{D_S×S}\mathrm{W^{'}_1} ∈ \mathbb{R}^{(K^2D_S)×(K^2S)} に置き換えられます。ここで K は整数の変数で入力画像の画素数の元となるピクセルが1辺あたり K 倍になるとした時に解像度が K^2 倍になることを意味します。

For this, we increase the hidden layer width from D_S to D^{'}_S in proportion to the number of patches and initialize the (now larger) weight matrix W^{'}_2 ∈ R^{D^{'}_S × S^{'}} with a block-diagonal matrix containing copies of W_2 on its diagonal.

そのために、隠れ層の幅(次元)をパッチの数 SK^2 倍比例して D_S から D^{'}_S に増やし、(今までより大きい)重み行列 W^{'}_2 ∈ R^{D^{'}_S×S^{'}} を、W_2 と同じ値を行列の対角線上に含むブロック対角線行列で初期化します。

See Supplementary C for more details.

詳しくは元論文の補足Cをご覧ください。

On the VTAB-1k benchmark we follow the BiT-HyperRule [22] and fine-tune Mixer models at resolution 224 and 448 on the datasets with small and large input images respectively.

VTAB-1kベンチマークでは,BiT-HyperRule[22]に従い,入力画像の数が小さいデータセットでは解像度224、大きいデータセットでは解像度448でMixerモデルを微調整しました。

Metrics

We evaluate the trade-off between the model’s computational cost and quality.

評価指標について。
モデルの計算コストと性能のトレードオフを評価しています。

For the former we compute two metrics:

前者、つまりモデルの計算コストについては、2つのメトリクス(指標)をはかります。

(1) Total pre-training time on TPU-v3 accelerators, which combines three relevant factors:

(1) 1つ目は、TPU-v3アクセラレータでの事前トレーニングのトータル時間。これは以下の関連する3つの要素を組み合わせたものである。

the theoretical FLOPs for each training setup, the computational efficiency on the relevant training hardware, and the data efficiency.

(1-1) 各学習設定における理論上のFLOP数
(1-2) 関連する学習ハードウェア上の計算効率
(1-3) およびそのデータ効率

(2) Throughput in images/sec/core on TPU-v3.

(2) 2つ目は、TPU-v3でのスループット=性能だ。1秒あたり、1コアあたり、どれだけの画像を処理できるか、が単位。

Since models of different sizes may benefit from different batch sizes, we sweep the batch sizes in {32, 64, . . . , 8192} and report the highest throughput for each model.

異なるサイズのモデルでは、異なるバッチサイズが有効である可能性があるため、バッチサイズを {32, 64, ... ... , 8192} で色々調べ、各モデルの最高スコアを記録した。

For model quality, we focus on top-1 downstream accuracy after fine-tuning.

モデルの性能については、モデルパラメータ微調整後のDownstreamタスクにおける精度のTop1(つまり一番精度のよかった数字)に注目した。

On one occasion (Figure 3, right), where fine-tuning all of the models would have been too costly, we report the few-shot accuracies obtained by solving the l_2-regularized linear regression problem between the frozen learned representations of images and the labels.

また、すべてのモデルを微調整するとコストがかかりすぎます。
なので、各モデル、数ショットの精度を報告します(図3右)。
それらは、学習した画像表現とラベルの間で l_2 の正則化・線形回帰問題を解くことで得られた精度値である。

これを見ると、縦軸が正確性の精度で、横軸は左右それぞれ異なり、右側の図であれば、処理能力になる。ここではグラフの意味は説明してないが、確かに点の数はまばらだが、それは時間がかかりすぎるので仕方ないということを言っている。

Models

We compare various configurations of Mixer, summarized in Table 1, to the most recent, state-of-the-art, CNNs and attention-based models.

表1にまとめたMixerの様々な構成を、最新の最先端のCNNやアテンションベースのモデルと比較してみました。

表1はこちら。

表2はこちら。

ぜんぶは比較されてませんが、Mixer-L/16とH/14 が比較されてます。LはLarge、HはHugeの意味で、数字は各パッチの1辺の解像度です。

表の左端がモデル名、2~4列目の数字は精度を表しています。これをみると、Mixer-L/16は微妙にVitなどに負けてるものの、VTABでは買っており、まさにCompetitiveといった結果ですね。

Mixer-H/14では4つのデータセットすべてで微妙に負けてますが、それでもCompetitiveとはいえそうです。

繰り返しますが、Competitiveであるだけで、十分よい結果です。

In all the figures and tables, the MLP-based Mixer models are marked with pink ( ・ ), convolution-based models with yellow ( ・ ), and attention-based models with blue ( ・ ).

表では、MLPベースのMixerモデルをピンク(・)、コンボリューションベースのモデルを黄色(・)、アテンションベースのモデルを青(・)で表示しています。

The Vision Transformers (ViTs) have model scales and patch resolutions similar to Mixer, including ViT-L/16 and ViT-H/14.

Vision Transformer(ViT)は、ViT-L/16やViT-H/14など、モデルのスケールやパッチの解像度がMixerに近いものがあります。

HaloNets are attention-based models that use a ResNet-like structure with local self-attention layers instead of 3×3 convolutions [49].

HaloNetsは、3×3の畳み込みフィルターの代わりに、局所的なAttentionレイヤーを持つResNetのような構造を用いた、Attentionベースのモデルである。

We focus on the particularly efficient “HaloNet-H4 (base 128, Conv-12)” model, which is a hybrid variant of the wider HaloNet-H4 architecture with some of the self-attention layers replaced by convolutions.

ここでは、特に効率の高い「HaloNet-H4 (base 128, Conv-12)」モデルと比較した。
これは、HaloNet-H4アーキテクチャのサイズを大きくしたハイブリッド型のモデルで、Attentionレイヤーの一部を畳み込みフィルターに置き換えたものです。

Note, we mark HaloNets with both attention and convolutions with blue ( ・ ).

なお、アテンションとコンボリューションの両方を持つHaloNetsを青(・)でマークしています。

Big Transfer (BiT) [22] models are ResNets optimized for transfer learning, pre-trained on ImageNet-21k or JFT-300M.

Big Transfer (BiT) [22] モデルは、ImageNet-21kまたはJFT-300Mで事前に学習された、転移学習に最適化されたResNetsです。

NFNets [7] are normalizer-free ResNets with several optimizations for ImageNet classification.
We consider the NFNet-F4+ model variant.

NFNets [7]は、ImageNetの分類のためにいくつかの最適化を行ったノーマライザ(正規化レイヤー)のないResNetsです。
ここでは、NFNet-F4+モデルの型と比較します。

Finally, we consider MPL [33] and ALIGN [21] for EfficientNet architectures.

最後に、EfficientNetアーキテクチャのためのMPL [33]とALIGN [21]を比較します。

MPL is pre-trained at very large-scale on JFT-300M images, using meta-pseudo labelling from ImageNet instead of the original labels.
We compare to the EfficientNet-B6-Wide model variant.

MPLは、JFT-300M画像を用いて大規模な事前学習を行っており、オリジナルのラベルの代わりにImageNetのメタ疑似ラベルを使用しています。
具体的には、MPLの中のEfficientNet-B6-Wideモデルと比較しています。

ALIGN pre-train image encoder and language encoder on noisy web image text pairs in a contrastive way.
We compare to their best EfficientNet-L2 image encoder.

ALIGNは、ノイズの多いウェブ画像とテキストのペアに対して、画像エンコーダーと言語エンコーダーを対照的に事前学習するモデルです。
Alignの最高傑作であるEfficientNet-L2画像エンコーダと比較します。

おわり

「MLP-Mixer」を解説するシリーズ3日目は以上です。次回はExperiments内のMain Resultsからです。

感想や要望・指摘等は、本記事へのコメントか、TwitterのリプライやDMでもお待ちしております!
また、結構な時間を費やして書いていますので、投げ銭・サポートの程、よろしくお願いいたします!

シリーズ関連記事はこちら
https://zenn.dev/attentionplease/articles/532a3de6308f57
https://zenn.dev/attentionplease/articles/7a11a56d767280
https://zenn.dev/attentionplease/articles/df6170f8581b71
https://zenn.dev/attentionplease/articles/7a3e74ad1bc9bf
https://zenn.dev/attentionplease/articles/a0d88939f9ceed
https://zenn.dev/attentionplease/articles/719580daf5a2d1

【2023年5月追記】
また、Slack版ChatGPT「Q」というサービスを開発・運営しています。
こちらもぜひお試しください。
https://q-bot.suchica.com/

Discussion