勾配ベースの最適化手法について直近に読んだ論文のまとめ
概要
勾配ベースの最適化手法について直近数ヶ月で読んだ論文を元にまとめる。網羅的ではないことに注意。
構成
以下のような構成でまとめる。
- Adamの派生手法
- 補助的に用いられる最適化手法
- リソース消費を最適化する手法
1. Adamの派生手法
Decoupled weight decay (SGDW, AdamW)
[5]ではAdamのweight decayの実装の問題点を理論的に解明し、これに対処する代替アルゴリズムを提案した。提案手法はAdamの汎化性能を改善し、CIFAR-10, ImageNet32x32による評価でSGD with momentum(提案手法以前ではAdamを上回ることが多かった)と同等の性能となることを示した。
Rectified Adam (RAdam)
RMSpropやAdamなどのadaptiveな最適化手法において、warmupが学習の安定性や汎化性能の向上に寄与することが知られている。[6]ではこれらのadaptiveな最適化手法で学習の初期段階において分散が大きくなりすぎることを解明し、学習の初期段階において分散が小さくなるように保つ矯正項(rectification term)を導入するとを提案する。提案手法は機械翻訳(IWSLT)と画像分類(CIFAR-10)のベンチマークにおいて、広いパラメータ範囲でwarmupと同等の性能となることを経験的に示した。
2.補助的に用いられる最適化手法
以下で挙げるのは基本的なoptimizer(SGD, Adam)などと組み合わせて用いられる補助的な最適化手法である。
Adversarial Weight Perturbation (AWP)
学習サンプルに敵対的摂動を加えることでモデルの予測値がノイズに対して堅牢になり、汎化性能が向上することが知られている(Adversarial Training; AT)が、AWP(Adversarial Weight Perturbation)[2]はサンプルに対してだけでなく、モデルの重みに対しても敵対的摂動を加える手法である。
ベースとなるアイデアとしては、loss landscape(入力の変動に対するlossの変動をグラフにプロットしたもの)のが平らになることを目指したもの。提案手法は、CIFAR-10に対する評価でtestデータに対する汎化性能がAT単体よりも高くなることを示した。
なお、この手法はforward/backward stepが1つのバッチに対して2回実行されるため、学習時間が約2倍となる欠点がある。このため、実際に使用される場面では学習の終盤のみ適用することがある[3]。
Lookahead
[4]は探索範囲をkステップ先読み(lookahead)する手法で、既存の最適化手法(SGD, Adamなど)と組み合わせて利用することが可能である。提案手法はImageNet, CIFAR-10/100, 機械翻訳などのベンチマークでSGDやAdam単体と比較して性能が向上することを経験的に示した。
本手法はgradient accumulationやSWA(Stochastic Weight Average)とも関連した手法で、数理的には重みのEMA(Exponential moving average)を計算することに相当する。
3. リソース消費を最適化する手法
近年の大規模言語モデルに代表されるモデルのパラメータ数の巨大化に伴い、訓練時のインフラの要件が厳しくなってきたことを背景にして、訓練時のリソース消費を最適化する手法について挙げる。
AdaFactor
RMSProp, Adam, Adadeltaなどのadaptiveな最適化手法では、勾配のモーメントおよび2次のモーメントをメモリに保持している。これらはそれぞれモデルのパラメータを保持するのと同じだけのメモリ容量が必要である。[9]では2次のモーメントについて、行、列のそれぞれの和の移動平均のみ保持し、これらの移動平均を用いて各パラメータごとの2次のモーメントを推定する手法を提案する。この手法により、2次モーメントのメモリの消費量は
Zero Redundancy Optimizer (ZeRO)
[11]は分散訓練環境におけるメモリの重複保持をゼロにすることを目的とした手法である。
近年の大規模言語モデルに見られるようなモデルパラメータ数のモデルのファインチューニングを、VRAMの容量が小さな環境でも訓練できるように開発された。提案手法によるメモリ配置の最適化は、効率的に訓練できるモデルのパラメータ数を増やすとともに、訓練時間を大幅に短縮することに成功したと報告している。
なお、本論文では既存提案された分散手法(data parallelism, model parallelism, pipeline parallelism)のメリットとデメリットについて整理していることも特徴である。本論文は「何でメモリを消費しているか?」「何が重複して保持されるか?」を整理する資料としても参考になる。
参考資料
- [1] Stochastic Weight Averaging in PyTorch
- [2] Adversarial Weight Perturbation Helps Robust Generalization
- [3] Masaki AOTA, Kaggleで使用される敵対学習方法AWPの論文解説と実装解説 ~Adversarial Weight Perturbation Helps Robust Generalization~, 2022
- [4] Lookahead Optimizer: k steps forward, 1 step back
- [5] Decoupled Weight Decay Regularization, ICLR 2019
- [6] On the Variance of the Adaptive Learning Rate and Beyond
- [9] Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
- [10] T5 - Hugging Face
- [11] ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
Discussion