Optimizerの変遷 ~最急降下法からAdamWまで~
はじめに
機械学習のコードを書いているエンジニアの方はoptimizerに何を選べばいいか迷ったことがあると思います。
とりあえず有名なモデルでよく使われているAdamやAdamWをよく考えずに選んでいませんか。
本記事ではそれぞれのoptimizerのパラメータ更新式の定性的な意味について、最急降下法からAdamWまでの変遷とともに説明したいと思います。
optimizerの発展の流れ
最急降下法
まずは、最も基本的なパラメータ更新手法である、最急降下法について軽くおさらいする。
パラメータ更新式
ここで
トレーニングデータに対する損失
ステップ
上記の仮定に基づいてパラメータを更新していく方法が最急降下法。
SGD
SGD(Stochastic Gradient Descent)は確率論的な最急降下法という意味。
オンライン学習・ミニバッチ学習
機械学習が発展するにつれて学習データセットの量が非常に大きくなった。そうすると、最急降下法では、一回のステップを計算するのに長い時間がかかるようになる。それを解消するために毎回ランダムに1つのデータを選び、それに対する損失を利用してパラメータを更新するのがオンライン学習。
しかし、オンライン学習では各データ毎に損失の違いが大きいことで学習が安定しないという課題があった、そこでオンライン学習と従来の方法(バッチ学習)のいいとこ取りをしたミニバッチ学習が生まれた。
ミニバッチ学習では、1つのデータを選ぶのではなく事前に決めたバッチサイズの数のデータをランダムサンプリングして、それに対する損失を利用してパラメータを更新するようにした。
パラメータ更新式
式を見てみると最急降下法と内容はほとんど同じである。
唯一の違いは、損失を計算する方法。最急降下法では全トレーニングデータに対する損失
PyTorchで利用する場合
PyTorchで利用する場合は以下のコード。
optimizer = optim.SGD(params=theta, lr=gamma)
問題点
SGDには最小値にたどり着くまでに近い点の間で振動してしまって効率よく最小値を探せない問題がある。
モメンタム法
モメンタムは勢いという意味。
パラメータ更新式
勾配
これは、定性的に過去のデータの慣性を利用することで振動を打ち消していると理解できる。
PyTorchで利用する場合
PyTorchで利用する場合は、以下のコードのようにSGDにmomentum
とdampening
を設定する。
optimizer = optim.SGD(params=theta, lr=gamma, momentum=mu, dampening=tau)
RMSprop
RMSprop(Root Mean Square PROPagation)は二乗平均平方根の伝播という意味。
パラメータ更新式
振動していないならば徐々に勾配の大きさは小さくなっていくはずであると仮定する。
したがって、長い間、勾配が急であるならば振動している度合いが大きいといえる。
振動しているときにはパラメータを小さく更新したいので、勾配が長い間急である度合いでスケーリングする。
勾配が長い間急である度合いは勾配の二乗平均
RMSpropでは、パラメータの更新を勾配の二乗平均平方根
PyTorchで利用する場合
PyTorchで利用する場合は以下のコード。
optimizer = optim.RMSprop(params=theta, lr=gamma, alpha=alpha, eps=epsilon)
Adam
Adam(ADAptive Moment estimation)は適応的モーメント推定という意味。
Adamはモメンタム法+RMSpropにバイアス補正をかけたもの。
パラメータ更新式
バイアス補正以外は
バイアス補正
途中の変形は期待値の線形性、
そこで
PyTorchで利用する場合
PyTorchで利用する場合は以下のコード。Adamにweight_decay
を設定すると後述のL2正則化が適応されてしまうので、weight_decay
を利用したい場合はAdamWを使用することが推奨される。
optimizer = optim.Adam(
params=theta,
lr=gamma,
betas=(beta_1, beta_2),
eps=epsilon,
)
AdamW
AdamW(Adam + Weight decay)は重み減衰付きAdamの意味。
AdamWが導入された背景としてL2正則化の問題点がある。
L2正則化
過学習は各トレーニングデータに対する過剰なフィッティングである。
オーバーフィッティングの例 [1]
上の画像で青い関数(過剰なフィッティングが起きている)と緑の関数を比べてみると、以下が言えそう。
過剰なフィッティングが起きている。
パラメータの大きさを評価する項を損失関数に組み込むことで過学習を抑制する方法が正則化。
L2正則化では損失関数にパラメータの二乗和
ここで
問題点
L2正則化は以下の2つの点でAdamのバイアス補正との相性が悪い。
- 期待値
の推定の誤差が大きくなるE[v_t] - 正則化の項に対して間違ったバイアス補正がかかってしまう
Weight Decay
L2正則化の問題を解消するため別の方式で過剰なフィッティングを防ぐことを考える。
SGDでのL2正則化の振る舞いを見てみる。
SGDにおいてはL2正則化はパラメータ更新時に
この方式をWeight Decayと呼ぶ。
パラメータ更新式
上の式のようにAdamに対してL2正則化の代わりにWeight Decayを行うことで
PyTorchで利用する場合
PyTorchで利用する場合は以下のコード。
optimizer = optim.AdamW(
params=theta,
lr=gamma,
betas=(beta_1, beta_2),
eps=epsilon,
weight_decay=lambda_,
)
最後に
機械学習の理解の助けやoptimizer選びの参考になれば幸いです。
参考
- モメンタム法・RMSpropまでの説明がわかりやすい。Adamのバイアス補正についての説明が省略されているのでそこだけ注意
Discussion