🔖

AdamWにおける改善点をきちんと理解する

2023/09/20に公開

概要

本稿ではAdamWの論文[1]の内容を元に、AdamとAdamWの違いを掘り下げて説明する。

本稿で扱う内容

  • AdamとAdamWの違いについて
  • Adamのweight decayにはどのような問題があるか?また、AdamWではこの問題をどのようにして解決しているか?
  • 論文の実装と主要なフレームワークの実装の違いについて

Adam と AdamWの違いについて

AdamWはAdamのweight decayの実装を改良することを目的としたもので、以下のような違いがある:

  • Adamにおけるweight decayは実際にはL2正則化として実装されたのに対し、AdamWはweight decayを本来の形式(重み減衰)で実装している

なお、weight decayの実装以外については両者は同じものである。

L2正則化とweight decay

L2正則化とweight decayは以下のような形式で定義される。\thetaはモデルのパラメータ、fは目的関数とする。

L2正則化

f_t^\mathit{reg}(\theta) = f_t(\theta) + \frac{\lambda^\prime}{2} \| \theta \|^2_2 \tag{1}

weight decay

\theta_{t+1} = (1-\lambda)\theta_t \tag{2}

なお、L2正則化は実際は正則化項の微分を勾配に加算する実装になっている。

g_{t} = \nabla f_t (\theta_t) + \lambda^{\prime} \theta_t

L2正則化 \neq weight decay

Adamにおける"weight decay"は実際にはL2正則化によって実装されている。

これは論文[1]でも指摘されている通り、weight decayはL2正則化と等価であるという誤解によるものだと思うが、これはSGD(Stochastic Gradient Decent)の場合にのみ成り立つ事実であり、一般的には成立しない。

なお、論文[1]にはSGDで両者が等価になること、およびAdamでは等価とならないことの証明が掲載されているが、ここでは省略する(前者の証明は簡単なので自分でやってみることを勧める)。

Adamにおけるweight decayの実装の問題点

Algorithm1にpytorchにおけるAdamの実装[2]を転載する。簡単のため、いくつかのパラメータを省略している。

Adam はSGDと異なり、勾配を正規化した上でパラメータ更新する。
簡単のため移動平均の処理を無視して考えると、勾配をその絶対値で正規化する処理と等価である(スケール処理)。この処理をL2正則化とともに用いる場合は以下の問題が生じる。

  • L2正則化を用いた場合、勾配に正則化項\lambda\theta_t が加算されるが、正則化項の値が勾配よりも大きい場合、スケール処理におけるスケールの推定誤差が大きくなる。
  • スケール処理により、正則化項もスケールされる。このことは、勾配の大きさが大きなパラメータほど減衰率が小さくなるという不都合な状況が生じる。

これはいわば、勾配のスケーリング重みの正則化という2つの異なる目的のために導入した処理が干渉しあい(coupled)、互いに悪影響を及ぼしている状況と解釈できる。

Algorithm 1: PyTorchにおけるAdamの実装

\begin{aligned} &\rule{110mm}{0.4pt} \\ &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ &\hspace{13mm} \lambda \text{ (weight decay)} \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned}

Adam + 重み減衰(AdamW)

そこで、AdamWではAlgorithm 2([3]より一部改変して転載)のように、勾配のスケーリング処理重み減衰という2つの処理を分離した形式(decoupled)で実装することを提案する。両者の計算過程は独立しているので、互いに干渉することがない。

Algorithm 2: PyTorchにおけるAdamWの実装

\begin{aligned} &\rule{110mm}{0.4pt} \\ &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \: \epsilon \text{ (epsilon)} \\ &\hspace{13mm} \lambda \text{(weight decay)} \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned}

AdamとAdamWにおけるweight decayの実装の計算グラフ

以上を踏まえて、AdamとAdamWにおけるweight decayの実装を計算グラフで表すとそれぞれ以下の図a, bのようになる。このグラフでは簡単のため、m_t, v_t の計算におけるEMA(Exponential Moving Average)の計算を省略し、「勾配の絶対値による正規化」で代替している。

図1. AdamとAdamWのweight decayの実装の計算グラフ

論文の実装とPyTorchなどの主要なフレームワークの実装との違い

重み減衰の本来の意味からすると、重みの減衰率の初期値は学習率とは独立して決定すべきだが、pytorchのAdamWの実装では重みの減衰率が1-\gamma\lambda であり、学習率にも依存する形式になっている。おそらくこれは既存のAdamの実装との互換性を維持する目的だと思うが、1-\lambdaとする方が個人的にパラメータの意味が直感的に理解しやすいと感じる(論文[1]ではむしろこちらの形式で提案している)。

これが気持ち悪い場合はAdamWのweight_decayを設定する際に、学習率の初期値で割ったweight_decay / lrを渡せば良い。

本稿のまとめ

  • AdamWはAdamのweight decayの実装の問題を解消したものである。
  • 「weight decayはL2正則化と等価」という認識はSGDにのみ成立する内容で、一般的には正しくない。
  • Adamにおけるweight decayは実際にはL2正則化であり、この実装では勾配のスケーリング重みの正則化の処理が互いに干渉しあい、本来の機能を互いに阻害する。
  • AdamWでは勾配のスケーリング重みの正則化の処理を独立して計算することで、Adamにおけるweight decayの実装の問題点を解消した。
  • PyTorchのAdamWの実装では論文と異なり、weight_decayが学習率に連動する形式になっている。

参考文献

GitHubで編集を提案

Discussion