【論文読解めも】Training data-efficient image transformers & distillation through attention
Training data-efficient image transformers & distillation through attention
Transformerベースの、効率的な画像分類モデルDeiTを提案。ViTでは訓練に時間がかかったのを効率化。ImageNetのみを使った訓練でTop1 accuracy83.1%を達成している。さらに、token-based distillationという専用の訓練の枠組みを導入することで、85.2%を達成している。
Touvron, Hugo, et al. "Training data-efficient image transformers & distillation through attention." arXiv preprint arXiv:2012.12877 (2020).
重要な先行研究:Vision Transformer
TransformerをVisionタスクに適用したVision Transformer(ViT)。入力画像をパッチに分割し、それをシーケンスとして入力する。各パッチはLinear層によって投影される。positional embeddingと足し合わされ、Transformer Encoderに入力される。Transformer Encoderの内部では、Multi-head Self Attention layerが使われており、そののち、2層のLinear層をGeLUで繋いだFeed Forward Networkが接続される。
0番目の位置に、訓練対象である[CLS]に対応するembeddingが入力される。その位置の出力をMLPへと入力し、それが分類器として訓練される。MLPは2層のLinear層からなり、GeLUが活性化関数に使われる。
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
重要な先行研究:Fixing the train-test resolution discrepancy
小さな画像サイズで訓練し、大きな画像サイズでファインチューニングすることで、高速かつ高精度な訓練が可能になるという枠組み。
通常、画像分類モデルの訓練時はデータ拡張としてRandomCrop、評価時にはCenterCropが用いられ、訓練時も評価時も同じ画像サイズを用いるが、これには2つの悪影響がある。
- 物体の見かけ上の大きさが、訓練時と評価時で変わってしまう
- Global Pooling層によって得られる特徴量のstatisticsが、訓練時と評価時で変わってしまう
これを回避するために、訓練時と評価時でCropする画像サイズを変えることで回避できることを示している。その結果、評価時の画像サイズよりも訓練時の画像サイズは小さいほうが良いということが判明したため、従来よりも高速に訓練が可能になることを示している。
💡かなり重要な手法なので、あとでちゃんと精読する
Touvron, Hugo, et al. "Fixing the train-test resolution discrepancy." Advances in neural information processing systems. 2019.
ViTとDeiTの違い
構造はViTと全く同じ。224x224のサイズで訓練し、384x384のサイズでファインチューニングするという流れも一緒。しかし、データ拡張やLabel smoothingの適用などのハイパーパラメータの違いによって訓練を高速化している。以下の設定は、事前学習の時のみ使用される。
Distillation through attention
蒸留に使用する教師モデルは、CNNでも、複数の分類器の組み合わせでもいい。本研究では、生徒モデルのターゲットとしてハード/ソフトな蒸留のどちらが良いかを確認し、それに加えて新たなターゲットになる蒸留トークンを提案している。
ハードな蒸留とソフトな蒸留
蒸留におけるハード/ソフトは、生徒モデルに与えられるターゲットが、教師モデルの予測したラベル
DeiTでは、比較の結果、ハードなターゲットを与えたほうが良い精度が得られるとしている。そのため、損失関数は、通常の分類損失であるCrossEntropy損失に、教師モデルから与えられる
蒸留トークン
通常の蒸留の考え方では、分類トークンの位置から得られるlogitに対して通常の分類損失と蒸留損失が計算される。DeiTでは分類トークンとは別の蒸留トークンを設けて、蒸留に関する損失はそこで計算するという別の方法を提案している。下図のように、分類トークンと同様に訓練対象となる蒸留トークンを設け、先ほどのハードな蒸留に関する損失
以上に関する実験は、以下の表のようにまとめられる。最初の3行は、蒸留なし、ソフトな蒸留、ハードな蒸留の比較になっており、ソフトよりもハードのほうが良い結果となっている。最後の3行は、蒸留トークンを使っている設定で、分類トークンのlogitをもとに予測したとき、蒸留トークンのlogitをもとに予測したとき、両トークンのlogitを足し合わせて予測したときの3パターンを表している。
細かいところの確認
事前学習で使われているテクニックで知らないものを確認した。
- stochastic depth
Residualな構造のレイヤーをランダムにDropすることで学習の高速化と正則化の効果が得られる。
repeated augmentation
各バッチに対して、複数のデータ拡張を適用するBatch Augmentationという手法。学習の高速化と正則化が得られることが示されている。
だいたい理解したので、おしまい。