📉

Optimizerの変遷 ~最急降下法からAdamWまで~

2025/01/20に公開

はじめに

機械学習のコードを書いているエンジニアの方はoptimizerに何を選べばいいか迷ったことがあると思います。
とりあえず有名なモデルでよく使われているAdamやAdamWをよく考えずに選んでいませんか。
本記事ではそれぞれのoptimizerのパラメータ更新式の定性的な意味について、最急降下法からAdamWまでの変遷とともに説明したいと思います。

optimizerの発展の流れ

最急降下法

まずは、最も基本的なパラメータ更新手法である、最急降下法について軽くおさらいする。

パラメータ更新式

\bm{G}_t = \nabla f(\bm{\theta}_{t-1})
\bm{\theta}_t = \bm{\theta}_{t-1} - \gamma \bm{G}_t

ここで\gammaは学習率。

トレーニングデータに対する損失f(\bm{\theta})を最小化するパラメータ\bm{\theta}を求める方法を考える。
ステップtごとに、パラメータ\bm{\theta}_{t-1}地点での勾配\bm{G}_tを求め勾配を下る方向にパラメータ\bm{\theta}を更新すれば損失f(\bm{\theta})は最小値に近づくはずである。
上記の仮定に基づいてパラメータを更新していく方法が最急降下法。

SGD

SGD(Stochastic Gradient Descent)は確率論的な最急降下法という意味。

オンライン学習・ミニバッチ学習

機械学習が発展するにつれて学習データセットの量が非常に大きくなった。そうすると、最急降下法では、一回のステップを計算するのに長い時間がかかるようになる。それを解消するために毎回ランダムに1つのデータを選び、それに対する損失を利用してパラメータを更新するのがオンライン学習。
しかし、オンライン学習では各データ毎に損失の違いが大きいことで学習が安定しないという課題があった、そこでオンライン学習と従来の方法(バッチ学習)のいいとこ取りをしたミニバッチ学習が生まれた。
ミニバッチ学習では、1つのデータを選ぶのではなく事前に決めたバッチサイズの数のデータをランダムサンプリングして、それに対する損失を利用してパラメータを更新するようにした。

パラメータ更新式

\bm{g}_t = \nabla f_t(\bm{\theta}_{t-1})
\bm{\theta}_t = \bm{\theta}_{t-1} - \gamma \bm{g}_t

式を見てみると最急降下法と内容はほとんど同じである。
唯一の違いは、損失を計算する方法。最急降下法では全トレーニングデータに対する損失f(\bm{\theta})を利用していたが、SGDではステップtごとに異なるトレーニングデータによる損失f_t(\bm{\theta})を利用するようにしたところ。

PyTorchで利用する場合

PyTorchで利用する場合は以下のコード。

optimizer = optim.SGD(params=theta, lr=gamma)

問題点

SGDには最小値にたどり着くまでに近い点の間で振動してしまって効率よく最小値を探せない問題がある。

モメンタム法

モメンタムは勢いという意味。

パラメータ更新式

\bm{m}_0 = \bm{0}
\bm{m}_t = \mu \bm{m}_{t-1} + (1-\tau) \bm{g}_t
\bm{\theta}_t = \bm{\theta}_{t-1} - \gamma \bm{m}_t

勾配\bm{g}_tの代わりに勾配の移動平均\bm{m}_tを利用する。
これは、定性的に過去のデータの慣性を利用することで振動を打ち消していると理解できる。

PyTorchで利用する場合

PyTorchで利用する場合は、以下のコードのようにSGDにmomentumdampeningを設定する。

optimizer = optim.SGD(params=theta, lr=gamma, momentum=mu, dampening=tau)

RMSprop

RMSprop(Root Mean Square PROPagation)は二乗平均平方根の伝播という意味。

パラメータ更新式

g_t^2 = \bm{g}_t \cdot \bm{g}_t
v_0 = 0
v_t = \alpha v_{t-1} + (1-\alpha)g_t^2
\bm{\theta}_t = \bm{\theta}_{t-1} - \gamma \frac{\bm{g}_t}{\sqrt{v_t}+\epsilon}

\epsilonはゼロ除算を防ぐための微小な値。

振動していないならば徐々に勾配の大きさは小さくなっていくはずであると仮定する。
したがって、長い間、勾配が急であるならば振動している度合いが大きいといえる。
振動しているときにはパラメータを小さく更新したいので、勾配が長い間急である度合いでスケーリングする。
勾配が長い間急である度合いは勾配の二乗平均v_tを用いれば定量化できる。

RMSpropでは、パラメータの更新を勾配の二乗平均平方根\sqrt{v_t}でスケーリングすることで振動を抑えている。

PyTorchで利用する場合

PyTorchで利用する場合は以下のコード。

optimizer = optim.RMSprop(params=theta, lr=gamma, alpha=alpha, eps=epsilon)

Adam

Adam(ADAptive Moment estimation)は適応的モーメント推定という意味。

Adamはモメンタム法+RMSpropにバイアス補正をかけたもの。

パラメータ更新式

\bm{m}_t = \beta_1 \bm{m}_{t-1} + (1-\beta_1) \bm{g}_t
v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2
\hat{\bm{m}_t} = \frac{\bm{m}_t}{1-\beta_1^t}
\hat{v_t} = \frac{v_t}{1-\beta_2^t}
\bm{\theta}_t = \bm{\theta}_{t-1} - \gamma\frac{\hat{\bm{m}_t}}{\sqrt{\hat{v_t}}+\epsilon}

バイアス補正以外は\mu=\tau=\beta_1としたモメンタム法と\alpha=\beta_2としたRMSpropの組み合わせにほかならない。

バイアス補正

v_tのバイアス補正の係数1/(1-\beta_2^t)の意味について説明する。v_tの期待値について考える。

v_tの一般項は以下の式で求められる

v_t = (1-\beta _2) \sum _{i=1} ^t {\beta _2^{t-i} g_i^2}

v_tの期待値E[v_t]

\begin{aligned} E[v_t] &= E\left[(1-\beta_2) \sum_{i=1} ^t {\beta_2^{t-i} g_i^2}\right]\\ &=(1-\beta_2) \sum_{i=1} ^t {\beta_2^{t-i} E[g_i^2]}\\ &\approx(1-\beta_2) \left(\sum_{i=1} ^t {\beta_2^{t-i}}\right)E[g_t^2]\\ &\approx(1-\beta_2^t)E[g_t^2] \end{aligned}

途中の変形は期待値の線形性、g^2の定常性E[g_i^2]\approx E[g_t^2]、等比数列の和の公式を利用した。

tが小さいときのv_tは非常に小さくなるため、初期のパラメータが大きく変化してしまい学習が安定しない。
そこでE[\hat{v_t}]\approx E[g_t^2]となるように係数で補正する。
\hat{\bm{m}_t}についても同様のバイアス補正をかけている。

PyTorchで利用する場合

PyTorchで利用する場合は以下のコード。Adamにweight_decayを設定すると後述のL2正則化が適応されてしまうので、weight_decayを利用したい場合はAdamWを使用することが推奨される。

optimizer = optim.Adam(
    params=theta,
    lr=gamma,
    betas=(beta_1, beta_2),
    eps=epsilon,
)

AdamW

AdamW(Adam + Weight decay)は重み減衰付きAdamの意味。

AdamWが導入された背景としてL2正則化の問題点がある。

L2正則化

過学習は各トレーニングデータに対する過剰なフィッティングである。

オーバーフィッティングの例
オーバーフィッティングの例 [1]

上の画像で青い関数(過剰なフィッティングが起きている)と緑の関数を比べてみると、以下が言えそう。
過剰なフィッティングが起きている。\iffパラメータの大きさが非常に大きい。

パラメータの大きさを評価する項を損失関数に組み込むことで過学習を抑制する方法が正則化。

L2正則化では損失関数にパラメータの二乗和|\bm{\theta}|^2を追加することでパラメータが大きくなることに対するペナルティを与えられる。

f'(\bm{\theta})=f(\bm{\theta})+\frac{\lambda}{2}|\bm{\theta}|^2
\begin{aligned} \bm{g}'_t &= \nabla f'_t(\bm{\theta}_{t-1}) \\ &= \nabla f_t(\bm{\theta}_{t-1}) + \lambda\bm{\theta}_{t-1} \end{aligned}

ここで\lambdaは正則化項の影響の大きさを決めるパラメータ。

問題点

L2正則化は以下の2つの点でAdamのバイアス補正との相性が悪い。

  • 期待値E[v_t]の推定の誤差が大きくなる
  • 正則化の項に対して間違ったバイアス補正がかかってしまう

Weight Decay

L2正則化の問題を解消するため別の方式で過剰なフィッティングを防ぐことを考える。
SGDでのL2正則化の振る舞いを見てみる。

\begin{aligned} \bm{\theta}'_t &= \bm{\theta}_{t-1} - \gamma \bm{g}'_t\\ &= \bm{\theta}_{t-1} -\gamma\lambda\bm{\theta}_{t-1} - \gamma \bm{g}_t \end{aligned}

SGDにおいてはL2正則化はパラメータ更新時に\gamma\lambda\bm{\theta}_{t-1}を引くことと同値。
この方式をWeight Decayと呼ぶ。

パラメータ更新式

\bm{\theta}_t = \bm{\theta}_{t-1} - \gamma\lambda\bm{\theta}_{t-1} - \gamma\frac{\hat{\bm{m}_t}}{\sqrt{\hat{v_t}}+\epsilon}

上の式のようにAdamに対してL2正則化の代わりにWeight Decayを行うことで\bm{m}_tv_tに影響を与えることなく過学習を防いだ手法がAdamW。

PyTorchで利用する場合

PyTorchで利用する場合は以下のコード。

optimizer = optim.AdamW(
    params=theta,
    lr=gamma,
    betas=(beta_1, beta_2),
    eps=epsilon,
    weight_decay=lambda_,
)

最後に

機械学習の理解の助けやoptimizer選びの参考になれば幸いです。

参考

  • モメンタム法・RMSpropまでの説明がわかりやすい。Adamのバイアス補正についての説明が省略されているのでそこだけ注意
脚注
  1. By Nicoguaro - Own work, CC BY 4.0 ↩︎

mutex Official Tech Blog

Discussion