KagglerによるAWPの公開実装について
概要
AWP(Adversarial Weight Perturbation)[1]は学習サンプルとモデルのパラメータに同時に敵対的摂動を加えながら学習する最適化手法である。noisy な条件や学習サンプルが少ない場合にロバストな解へ収束することが著者らにより報告されている。また、AWP は現実の問題設定に近いタスク( Kaggle コンペ)でも学習の安定性や汎化性能の向上に寄与したと報告されている[3,4]。
本稿では Kaggle ユーザによる AWP の公開実装について、オリジナルの実装との違いを確認する。
AWPのオリジナルの実装
まず、オリジナルの実装の計算過程について確認する(Fig.1)。
- 敵対的サンプルを生成する(1st backprop)
- 敵対的サンプルを入力して敵対的摂動を計算する(2nd backprop)
- 敵対的摂動を各ブロックのパラメータのスケールに応じて正規化する
- 摂動が加えられたパラメータ点において各種勾配法(SGD, Adam etc.)によりパラメータを動かす(3rd backprop)
- 摂動分を相殺する
最初の敵対的サンプルの生成部分は AWP とは独立しているが、全部含めると 3 回 back prop を適用する事になり、学習時間は通常の 3 倍かかる。
Fig.1
@wht1996 による実装
次に、@wht1996 による実装を見ていく。計算過程(Fig.2)は以下のようになる。
- 最初の時点のパラメータを保存する
- 現在のパラメータを元に勾配を計算する(1st backprop)
- 勾配を元に正規化された敵対的摂動を計算する
- パラメータに摂動を加え、その点での勾配を計算する(2nd backprop)
- 最初の時点のパラメータを復元する
- 4 で計算した勾配に基づいて
optimizer.step()
を適用する
大きな違いとして、敵対的サンプルの生成を省略することで backprop の計算を 2 回で済ませている。また、パラメータと勾配のアップデートの計算の工夫として、途中の敵対的摂動と摂動点での勾配を計算する処理では optimizer のパラメータを更新せず、最後に 1 度だけ行っている。このようにすることで optimizer が保持する momentum などの状態に影響しないような配慮をしている。オリジナルの実装では元のモデルと proxy モデルの 2 つ分を用意しているが、この実装では 1 つのモデルのみで完結している。
また、細かな違いとして、Fig.3 に示すように摂動の範囲が一定の範囲内に収まるように clipping を適用している。これは、おそらく学習の安定化のために導入したのだろう。他には、AWP を最初から適用するのでなく、テストのスコアがある閾値を超えてから初めて適用するといったように計算時間を短縮する工夫がされている。
Fig.2
Fig.3
@junkodaによる実装(FastAWP)
最後に、@junkoda による FastAWP という実装[5]について確認する(Fig.4)。このアイデアは、摂動点を計算するために追加の backprop を実行するのでなく、 optimizer の保持している勾配のモーメントで代替するというもの。backprop の計算が 1 回ですむため、追加のオーバヘッドが発生しない。ただし、この方法で推定した摂動は真の敵対的摂動とは異なるため性能が劣化している可能性がある。効果については実際のタスクに当てはめて評価する必要がある。
Fig.4
Reference
- Adversarial Weight Perturbation Helps Robust Generalization
- https://www.kaggle.com/code/wht1996/feedback-nn-train/notebook
- https://www.kaggle.com/competitions/feedback-prize-2021/discussion/313177
- https://www.kaggle.com/competitions/nbme-score-clinical-patient-notes/discussion/315707
- https://www.kaggle.com/code/junkoda/fast-awp
Discussion