全ての学習率スケジューリングを過去にするOptimizer

2024/12/10に公開


  • RAdamScheduleFree という新しいoptimizerを作りました。
  • warmup、learning rate scheduler、訓練終端時刻の指定、全て不要です。
  • 安定かつ高速に収束します。多くの場合でAdamやAdamWより強いです。


https://x.com/hamanasu_nagisa/status/1861974296257864154
https://x.com/aaron_defazio/status/1863722349138219459
https://x.com/hamanasu_nagisa/status/1863739202665644212


https://github.com/facebookresearch/schedule_free
この Version 1.4 から RAdamScheduleFree が搭載されました。

みんな使ってね。


はじめに

お久しぶりです。はまなすです。

某記事を執筆してからそろそろ2年が経ちそうなことに戦慄を覚えつつ、いまでも新たに反応をいただくこともあり、じんわり嬉しさを覚える年の瀬です。


さて、実務や趣味で機械学習開発に携わっている方々は、皆等しくAdam(またはAdamW)とお友達だと思います。Adam台頭以降、RAdamAdaBoundAdaBeliefなど、直接的な後継といって差し支えないものだけでも、新たなoptimizerは数多く登場してきました。Lionなどの別系統も[1]。しかし、デファクト・スタンダードと呼べるoptimizerはいまだにAdamです。なぜでしょうか。

ひとつには、Adamにしておけばひとまずは問題ないという安心感がありそうです。もはやある種の枯れた技術となったAdamは、どこの馬の骨ともわからないoptimizerを試すより遥かに信頼できる実績が積み重なっており、とりあえずそれを選んでおくことで「もしかしてこの謎挙動、Adamを使わなかったせいかな…?」みたいな無益な精神の摩耗から我々を保護してくれます。

もうひとつには、Cosine annealing scheduler などの強力なschedulerが同様にデファクト・スタンダードに君臨していることが挙げられるでしょう。Adamと Cosine annealing scheduler の組み合わせは暴力的に汎用性が高く、そもそもoptimizerを変えようという発想を(optimizerジャンキーを除く一般的な)開発者から奪い去っています。それは同時に、本来はschedulerの導入にあたって発生する新たな自由度を飼い慣らす必要性さえをも我々の意識から遠ざけ、「とりあえずこれでいいや」という盲目的な気持ちの増長に一役買っています。本当は、適切な調整によりもっといいモデルが訓練できる余地があるかもしれないのに、です。

でも、タスク設計やデータ整備、モデルアーキテクチャなどのより本質的な部分に思考リソースを割きたい開発者にとって、optimizerやschedulerの調整にまで気を配るのは正直面倒なはずです。

では、そもそもそんな学習率スケジューリングなんて全く要らないoptimizerがあるとしたら?


ScheduleFree という新パラダイム

2024年春、ScheduleFree という新たなoptimizerシリーズがMetaより突如提案されました。

https://arxiv.org/abs/2405.15682

詳細は理論背景の章に譲るとして、これこそがまさに「学習率スケジューリングなんて全く要らないoptimizer」の正体そのものです。初期リリースではSGDとAdamWの ScheduleFree version が実装されており、論文内の実験でもこれら2種類における性能の高さが実証されています。

簡潔にまとめると、ScheduleFree、特にその主提案である AdamWScheduleFree によって、次のようなことが実現されました。

  • [手軽] いかなるlearning rate schedulerも不要。訓練終端時刻の指定も不要。
  • [性能] scheduler込みで調整したAdamWより安定で高速。収束品質も同等かそれ以上。



The Road Less Scheduled: Figure 5|既存のoptimizerに learning rate schedule を組み合わせたものと ScheduleFree の比較。上4つがSGD、下4つがAdamWによる実験結果。画像分類や画像の継続事前学習、言語翻訳や言語モデルなど、種々のタスクやモデルアーキテクチャにて強力な優位性と安定性を実証。



The Road Less Scheduled: Figure 7|AlgoPerf challengeに基づく、音声を含む様々なモダリティやタスクにおける検証結果。ベースラインであるNesterov AdamWと比較して、同様に優位性と安定性を実証。


このように旧来のベストプラクティスを次々と駆逐した AdamWScheduleFree でしたが、その影で「warmup stepsを指定する必要性」だけが若干の手間として残されました。

学習率のwarmupはいわずとしれた重要な調整項目で、とりわけAdamWのように勾配の二次指数移動平均を用いるoptimizerは、訓練初期の安定性に無視できない影響を受けます。これも「とりあえず設定しておけばいいハイパラ」の筆頭で、実際のところ、デフォルト値以外をきちんと探索する気概や余裕のある開発者は限られるのではないかという実感があります。

そこで登場するのが古株のRAdamです。RAdamは、AdamWでは職人芸的に調整するしかない最適なwarmup stepsと同等性能を、ハイパラ設定なしに実現してくれます。素晴らしいoptimizerです。あとはお分かりですね。

RAdamの ScheduleFree version があれば、我々はすべての苦行から解放されます。

https://github.com/facebookresearch/schedule_free

なので作りました。


これ使っときゃOK、を次の水準へ

なにはともあれインストール。とっても簡単ですね。

pip install schedulefree

なんらかのパッケージマネージャをお使いの場合でも、普段通りで大丈夫です。

rye add schedulefree
rye sync


それでは ScheduleFree の使い方を見ていきましょう。といっても、基本的には従来のoptimizerとなんら変わりません。差分は以下の2点です。

  • scheduler が不要になること
  • optimizer.eval()optimizer.train()を適切なタイミングで呼ぶこと

例えば、AdamWとなんらかのschedulerを使う場合のよくあるモックを考え、そこからの差分として骨子を表現してみると、以下のような感じになるかと思います[2]

 import argparse
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from pathlib import Path
+from schedulefree import RAdamScheduleFree
 from torch import Tensor
-from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler
 from torch.utils.data import DataLoader


 class MyModel(nn.Module):
     r"""
     Awesome neural network.
     """
     def __init__(self, *args, **kwargs):
         ...

     def forward(self, x: Tensor) -> Tensor:
         ...
         return y


 class MyDataset(nn.Module):
     r"""
     Awesome dataset.
     """
     def __init__(self, *args, **kwargs):
         ...

     def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
         ...
         return data, target


 def train(
     model: nn.Module,
     trainloader: DataLoader,
     optimizer: torch.optim.Optimizer,
-    scheduler: LRScheduler,
     device: torch.device,
 ):
     r"""
     Train the model for one epoch.
     """
     model.train()
+    optimizer.train()

     for batch in trainloader:
         data, target = map(lambda x: x.to(device), batch)
         optimizer.zero_grad()

         pred = model(data)
         loss = F.mse_loss(pred, target)
         loss.backward()

         optimizer.step()
-        scheduler.step()


 def validation(
     model: nn.Module,
     validloader: DataLoader,
+    optimizer: torch.optim.Optimizer,
     device: torch.device,
 ):
     r"""
     Evaluate model performance on the validation set.
     """
     model.eval()
+    optimizer.eval()
    
     with torch.inference_mode():
         for batch in validloader:
             data, target = map(lambda x: x.to(device), batch)
             pred = model(data)
            
             # update metrics
             ...


 def save(
     save_dir: Path,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    scheduler: LRScheduler,
 ):
     r"""
     Save model and training states to a checkpoint file.
     """
     torch.save(
         {
             "model": model.state_dict(),
             "optimizer": optimizer.state_dict(),
-            "scheduler": scheduler.state_dict(),
         },
         save_dir / "checkpoint.pt"
     )


 def load(
     load_dir: Path,
     device: torch.device,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    scheduler: LRScheduler,
 ):
     r"""
     Load model and training states from a checkpoint file.
     """
     pkg = torch.load(load_dir / "checkpoint.pt", map_location=device)
     model.load_state_dict(pkg["model"])
     optimizer.load_state_dict(pkg["optimizer"])
-    scheduler.load_state_dict(pkg["scheduler"])


 def main():
     # =========================== #
     #      Training settings      #
     # =========================== #
     parser = argparse.ArgumentParser()
     parser.add_argument('--epochs', type=int, default=1)
     parser.add_argument('--save', type=str, default=None)
     parser.add_argument('--load', type=str, default=None)
     ...
     args = parser.parse_args()
     device = ...

     # ====================== #
     #      Preparations      #
     # ====================== #     
     model = MyModel(...).to(device)

     params = model.parameters()
-    optimizer = torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999))
+    optimizer = RAdamScheduleFree(params, lr=1e-4, betas=(0.9, 0.999))
-    scheduler = CosineAnnealingLR(optimizer, ...)

     if args.load is not None:
         load(
             Path(args.load),
             device,
             model,
             optimizer,
-            scheduler,
         )

     trainloader = DataLoader(MyDataset(...), ...)
     validloader = DataLoader(MyDataset(...), ...)

     # ================== #
     #      Training      #
     # ================== #
     for epoch in range(args.epochs):
         train(
             model,
             trainloader,
             optimizer,
-            scheduler,
             device
         )

         validation(
             model,
             validloader,
+            optimizer,
             device
         )

         if args.save is not None:
             save(
                 Path(args.save),
                 model,
                 optimizer,
-                scheduler,
             )


 if __name__ == '__main__':
     main()

はい、schedulerがことごとく抹消されていることと[3]、optimizerに対してtrainevalというメソッドが呼ばれていることがわかりますね[4]


このoptimizerの状態切り替えが何をやっているかは、「ScheduleFree が何をしているか」と「実際の実装でそれがどう実現されているか」の両方を理解しないと正確には理解できないので、一旦は『modelとセットでoptimizerの状態も切り替える必要がある』とだけ覚えておいてください。

また、上例ではvalidationのすぐ後にsaveがあるので気にしなくてよかったのですが、『modelやoptimizerの保存・読み込み時は、optimizerが eval mode になっている必要がある』ことも頭の片隅に置いておいてください[5]。例えば、checkpointの保存が訓練ループの中にあるような場合には、torch.saveの前にoptimizer.eval()を呼んでくださいね[6]


さて、RAdamScheduleFree を実践で使用するための話は、ここまでで全て終わりました。他の多くのoptimizerと同様、中でなにをしているか知らなくても使えます。重要なのは、今回の提案により『warmup、learning rate scheduler、訓練終端時刻の指定、全てが不要になる』ということです[7]。これにより、我々はもっと本質的な開発や調整に時間を割くことができるようになります。

ところで、詰まるところ今回の取り組みスコープは既存の強いoptimizerの組み合わせでしかなく[8]、特にこれで論文を書いたりするわけでもないので、新規に網羅的な性能実験などはおこなっていません。つまり、皆さんにとっては依然として「どこの馬の骨ともわからないoptimizer」の類ではあるわけですが、それをあなたにとっての新しい「これ使っときゃOK」にするかどうかは、あなたの好奇心次第です。この記事の目的は、冒頭に述べたもうひとつのほうの懸念を潰すことで、その背中を押すことにもありました[9]

少なくとも、私の使用感では RAdamScheduleFree はかなり強いです。なにより超楽。

皆さんもぜひ使ってみてくださいね。こんな問題でもうまくいったよ、こんな問題だと期待薄だったよ、みたいな感想も歓迎しています。気が向いたらお気軽にお寄せください。
























──さて、沼にはまる覚悟はありますか?

理論背景

ここからは関連する理論背景について話していきます。怪しげな ScheduleFree の概要を理解したい方には多分面白い内容になると思いますが、もともと本記事の想定読者に据えている「optimizerの選定や調整なんて面倒すぎるよ〜」という方々にとっては特に有益な情報ではないと思うので、そのあたりの温度感をご理解いただいた上で読み進めていただければ幸いです。

以降、かなり玄人向けです。主に次の三本立てでお送りします。

  • ScheduleFree のお気持ちを理解する
  • RAdamを思い出す
  • そして RAdamScheduleFree へ


ScheduleFree のお気持ちを理解する

いきなり二次指数移動平均を用いる最適化手法を考えると話がややこしくなるので、モメンタムも考えない最もシンプルなSGDを出発点に ScheduleFree の挙動を確認していきましょう。

単純なSGDと、パラメータの反復平均化

モメンタムのないSGDは、更新されるパラメータ系列 z_t において、各時刻 t の確率的勾配方向に徐々に進んでいくだけの挙動を示します。ここで、f はある関数、\gamma は学習率、\zeta_t は各時刻で与えられる乱数(機械学習でいえばミニバッチに相当)を表しています。

z_{t+1} = z_t - \gamma \nabla f(z_t,\zeta_t).

さて、このような愚直な勾配降下は準最適な(つまり最適とはいえない)結果に落ち着くことが広く知られています。そこで、このようにして更新されるパラメータ系列 z_t をいい感じに平均化することで、より最適なパラメータ x_t を得られるのではないか、という考えが探求されました。そのひとつが Polyak averaging です。Polyak averaging の考え方は簡単で、それまでに得られたすべてのパラメータを素直に平均化するものです。

\begin{align*} z_{t+1} &= z_t - \gamma \nabla f(z_t,\zeta_t),\\ x_{t+1} &= \frac{1}{t+1}\sum^{t+1}_{i=1} z_{i}. \end{align*}

各時刻で勾配を計算するための(すなわち訓練時に評価される)パラメータは z_t のままであるものの、推論用のパラメータとして新たに x_t が返されるというイメージです。さらに、この x_t の計算を各時刻で逐次的におこなえるようにしたものが Polyak-Ruppert (PR) averaging です。

\begin{align*} z_{t+1} &= z_t - \gamma \nabla f(z_t,\zeta_t),\vphantom{\frac{1}{t}}\\ x_{t+1} &= (1-c_{t+1})\,x_t + c_{t+1} z_{t+1}, \vphantom{\frac{1}{t}}\\ \mathrm{where}\ \ c_{t+1} &= \frac{1}{t+1}. \end{align*}

計算してみるとわかりますが、x_{t+1} における z_i それぞれの寄与は \frac{1}{i}\frac{i}{i+1}\cdots\frac{t}{t+1}=\frac{1}{t+1} となるので、まさに Polyak averaging に一致します。

ところで、Polyak averaging が想定する関数 f のクラスは、本来はLipschitz連続な凸関数です。機械学習が扱うより複雑な関数との乖離がここにあり、実際、PR averaging は実践ではあまり有効ではないことが知られているようです。

話は変わり、比較的最近提案された平均化手法として、Primal averaging というものも存在します[10]。Primal averaging は見かけ上 PR averaging とそっくりですが、勾配を評価するパラメータが z_t ではなく x_t であることにのみ違いがあります。

\begin{align*} z_{t+1} &= z_t - \gamma \nabla f(x_t,\zeta_t),\vphantom{\frac{1}{t}}\\ x_{t+1} &= (1-c_{t+1})\,x_t + c_{t+1} z_{t+1}, \vphantom{\frac{1}{t}}\\ \mathrm{where}\ \ c_{t+1} &= \frac{1}{t+1}. \end{align*}

実は、うまく定数変換をすることで、Primal averaging は注意深く学習率スケジューリングされたモメンタム付きSGDと同一視できることが示されます。その意味で、機械学習最適化における超基礎たるモメンタム付きSGDは、勾配の指数移動平均ではなく、純粋なSGDにおけるパラメータの反復平均化という世界へ接続するのです。

これを踏まえて、改めてそれぞれの平均化の表式を眺めてみましょう。PR averaging では z_t で勾配が評価されていたため、z_t の更新速度そのものはモメンタムのないSGDと等しくなりますが、その平均を取る x_t はゆっくりと値が更新されていきます。一方 Primal averaging は、そのゆっくりと更新される x_t で勾配を評価しています。これにより z_t の更新自体も同様にゆっくりになってしまい、収束速度が著しく落ちるという性質が読み取れます。

パラメータ更新は速いが実践に弱い PR averaging と、収束の遅さは拭えないが「注意深く学習率スケジューリングされたモメンタム付きSGD」の形で有効性が実証されている Primal averaging という図式が揃いました。綺麗なトレードオフです。

こういうとき、両者のいいとこどりをしたいのが人情というものです。


パラメータ反復平均化の補間としての SGDScheduleFree

PR averaging と Primal averaging を内挿するような定数 0 \leq \beta \leq 1 を考えてみましょう。

\begin{align*} y_t &= (1-\beta)\,z_t + \beta\,x_t,\vphantom{\frac{1}{t}} \\ z_{t+1} &= z_t - \gamma g_t,\vphantom{\frac{1}{t}}\\ x_{t+1} &= (1-c_{t+1})\,x_t + c_{t+1} z_{t+1}, \vphantom{\frac{1}{t}}\\ \mathrm{where}\ \ g_t &= \nabla f(y_t,\zeta_t)\vphantom{\frac{1}{t}},\\ c_{t+1} &= \frac{1}{t+1}. \end{align*}

\beta=0y_t=z_t\beta=1y_t=x_tとなるので、両極にて PR averaging と Primal averaging を表すのがわかるでしょうか。\beta がその中間の値であれば、それぞれの性質を補間したような最適化手法となるわけですね。これこそが SGDScheduleFree です。

安定性と更新速度のいいとこどりを実現する実践的な値として、論文では \beta=0.9 が推奨されています。これは図らずもモメンタム付きSGDにおける \beta のベストプラクティスと同じなので、実際にコーディングする際の使用感としても遜色ないものとなりそうです。


ScheduleFree の収束性[最も抽象的で、難しい]

さて、このようにして導出された SGDScheduleFree は、学習率スケジューリングをしなくても優秀な収束性を持つことが示されています。論文内で最初に示される性質は、Lipschitz連続な非平滑凸関数設定を考えたとき、どのような \beta の選択に対しても最悪ケース最適性を持つということです。なにをいっているかわかりませんね。

Lipschitz連続な凸関数は、機械学習の確率的最適化手法を考える際によく出てくる問題のクラスです。先ほども出てきましたね。平たくいうと『どの2点を取っても、その間の関数の変化の大きさが距離に比例して抑えられている凸関数』のことです。急峻な変化を持つ箇所がどこにも存在しない、基本的にはなだらかなお椀型の形状を持ち、理論的な解析がしやすい関数のクラスと考えればよいでしょう。例えば、どのような \zeta を与えても関数 fG-Lipschitzであるとき、任意のパラメータ a, b について \|f(a, \zeta)-f(b, \zeta)\|\leq G\|a-b\| が成り立ちます。今回はさらに非平滑とあるので、連続だが尖った点の存在は許容されます。

また最悪ケース最適とは、理論的に可能な最高の最悪ケース収束率を達成していることを意味します。最悪ケース収束率とは『それ以上遅い収束はあり得ない』ことを表現する収束性指標で、例えば最適化の終端時刻 T を用いて表されます。ScheduleFree ではハイパーパラメータとして T を指定する必要はありませんが、仮にある T までの更新系列を収束性評価するならどうか、と考えるといいかもしれません。この場合、終端パラメータ x_T と最適解 x_{\star} それぞれにおける関数値の差の期待値 \mathbb{E}[F(x_T) - F(x_{\star})] が、T についてどのようなオーダで減少するかを評価します。

いま、G-Lipschitz な f について F(x)=\mathbb{E}[f(x,\zeta)] としたとき、F が凸関数であれば、\beta の選択にかかわらず SGDScheduleFree は以下を成立させます。この結果は F の平滑性には左右されないことに留意してください。

\begin{align*} \mathbb{E}[F(x_T) - F(x_{\star})] &\leq \frac{DG}{\sqrt{T}},\\ \mathrm{where}\ \ D &= \|x_1 - x_{\star}\|,\vphantom{\frac{1}{T}}\\ \gamma &= \frac{D}{G\sqrt{T}}. \end{align*}

論文では特に記載がありませんが、次のように変数を消去するともっとわかりやすいかもしれません。T を事前に決めずに好きな \gamma を設定したような状況としてはイメージしやすいでしょう。

\mathbb{E}[F(x_{\frac{D^2}{\gamma^2 G^2}}) - F(x_{\star})] \leq \gamma G^2.

ところで、Lipschitz連続な平滑凸関数の最悪ケース最適速度は O(\frac{1}{T})、Lipschitz連続な非平滑凸関数の最悪ケース最適速度は O(\frac{1}{\sqrt{T}}) であることが広く知られています。先の結果と照らし合わせれば、F が非平滑なとき、SGDScheduleFree の収束保証は最悪ケース最適なことがわかります。

結局 SGDScheduleFree がどういう性質を持つかというと、「Lipschitz連続な非平滑凸関数において、どんな \beta を設定しても、達成可能な最も速い最悪ケース収束率が常に実現される」と主張しています。勾配の一次指数移動平均を用いる既存のoptimizerは \beta によっては最適性を壊してしまうことが知られているため、これは ScheduleFree で初めて顕現した能力です。すごいですね。

論文ではこの議論をさらに拡張し、x_tz_i の任意の荷重平均である場合や、時刻ごとに \beta が変化する場合における一般的な収束率が導出されています。このとき、\sum_{t=1}^T {\langle g_t, z_t - x_\star\rangle} の形で表される regret という量を用いた議論に転換することで、Adamを含む様々なoptimizerがこれまで歩んできたオンライン凸最適化に関する議論史の包括的な統一化を果たしています。

後学のため、定式化も確認しておきましょう。独立同分布な変数系列 \zeta_t、任意の重み系列 w_t、任意の係数系列 \beta_t に対し、とあるパラメータ系列 z_t から SGDScheduleFree を拡張したような次の形式で x_t, y_t の系列を作ることを考えます。

\begin{align*} x_t &= \frac{\sum^t_{i=1} w_i z_i}{\sum^t_{i=1} w_i} = (1-c_t) \, x_{t-1} + c_t z_t, \\ y_t &= \beta_t x_t + (1-\beta_t) z_t,\vphantom{\frac{1}{T}}\\ \mathrm{where}\ \ g_t &= \nabla f(y_t,\zeta_t)\vphantom{\frac{1}{t}},\\ c_t &= \frac{w_t}{\sum^t_{i=1} w_i}. \end{align*}

このとき、凸関数 F に対して x_T の収束性は次のように評価できることが示されました。

\mathbb{E}[F(x_T) - F(x_{\star})] \leq \frac{\mathbb{E}[\sum_{t=1}^T {w_t \langle g_t, z_t - x_\star\rangle}]}{\sum^T_{t=1} w_t}.

この時点で z_t の更新の仕方が単なる勾配降下でなくてもいいように議論が拡張されており、Adamなどの二次指数移動平均を用いた手法へ接続する扉が開かれています。どういうことかというと、そもそも z_t のパラメータ系列がどのように更新されるかを一切記述しないまま、x_t の収束性を議論することが実現されているのです。離れ業ですね。こういうことができるのは、「なんらかの学習率系列 \gamma_t にしたがって更新される z_t」ではなく、「なんでもいいからとりあえず存在する z_t をパラメータ平均化した x_t」を対象に議論を展開しているためです。schedulerを撤廃し、パラメータ平均化の文脈で最適化を議論することで辿り着いた境地です。

そういうわけで、w_t\beta_t の値、また z_t の更新方法を変えることで、最初に提示した収束性指標を含め、この定式化は様々なoptimizerの議論を復元します。


色々と小難しい言葉を並べましたが、これまではschedulerを用いて時刻ごとに変化する学習率 \gamma_t を設計する必要があったのに対し、ScheduleFree シリーズは定数 \gamma さえ指定すれば問題なく素早く収束する、ということがわかれば十分です。これが ScheduleFree たる所以ですね。また \gamma の値にも頑健で、これまで用いられてきた一般的な学習率よりもオーダーレベルで大きな学習率を用いても安定して収束することも示されたりしています。

直感的な理解をすると、t\to\infty の極限で c_t\to0 になるため、訓練が進むにつれて x_t の動きは徐々に鈍り、ある点 x_{*} に収束するようにして動かなくなりそうです[11]。このとき、ある種の性質の良いアンカーの役割を果たす x_t \approx x_{*} と、確率的勾配に従って動き続ける z_t の内挿として y_t は表されるため、z_t ほどでないにしろ y_t も動き続けます。こうして、パラメータ空間内の探索と安定なパラメータ平均化のバランスが取られるわけですね。もし探索しているパラメータ空間の性質がよければ、x_* は最適解 x_\star となるのでしょう。

先に説明したように推論時は x_t が用いられることを考慮すると、訓練が進むにつれて新しい勾配の寄与が小さくなっていく、つまり、実質的な学習率が漸近的に小さくなっていくような効果が得られていることがわかるでしょうか。

このイメージを具象化してみましょう。論文の行間を埋める作業として、SGDScheduleFree において \beta=0,\ y_t=z_t で表される、シンプルな Polyak averaging に立ち戻ってみます。いま x_t はそれまでの全てのパラメータ系列 z_t の平均なので、以下のようにして各ステップの勾配 g_t が訓練中の x_t にどれだけ寄与しているかを求めることができます。

\begin{align*} x_t &= \frac{1}{t}\sum^t_{i=1} z_i\\ &=\frac{1}{t}\{(z_0 - \gamma g_0) + (z_1 - \gamma g_1) + \dots + (z_t - \gamma g_t)\}\\ &=\frac{1}{t}\{(z_0 - \gamma g_0) + (z_0 - \gamma\sum^1_{i=0} g_i) + \dots + (z_0 - \gamma\sum^{t-1}_{i=0} g_i)\}\\ &=z_0-\frac{\gamma}{t}\{g_0 + \sum^1_{i=0} g_i + \dots + \sum^{t-1}_{i=0} g_i\}\\ &=z_0-\frac{\gamma}{t}\{t g_0 + (t-1)g_1 + \dots + g_{t-1}\}\\ &=z_0-\gamma (g_0 + \frac{t-1}{t} g_1 + \dots + \frac{1}{t} g_{t-1})\\ &=z_0-\gamma \sum_{i=0}^{t-1} \left(1-\frac{i}{t}\right) g_i.\\ \end{align*}

興味深いことに、綺麗に線形減衰する寄与度を導くことができましたね。Primal averaging や SGDScheduleFree に対しても同様の議論をおこなうと下図が得られます。



The Road Less Scheduled Figure 4|ある時点のスナップショット(総訓練時間 T\frac{1}{3}\frac{2}{3}、訓練終了時)にて、それまでの各時刻 t の勾配がパラメータ系列に対してどの程度寄与するかの模式図。上から順に、linear decay schedule、Polyak averaging、Primal averaging、および SGDScheduleFree。水色が勾配評価点 y_t、橙色がパラメータ平均点 x_t における勾配寄与度を表す。Linear decay scheduleはパラメータ平均化をしないので y_t に関する様子のみが描かれており、指定された訓練終端時刻 T に向かって勾配寄与度が線形減衰する様子が読み取れる。Polyak averagingは y_t=z_t において直接勾配降下するため、y_t は常に一定の勾配寄与を受ける(無限遠点を終端に線形減衰しているという見方もできる)が、パラメータ平均点 x_t は各時刻 t を動的な終端とするような線形減衰挙動を示す。Primal averagingは y_t=x_t であるため、どちらも各時刻 t を動的な終端とするような線形減衰挙動を示す(図では両者が重なって x_t だけが見えている)。最後に SGDScheduleFree であるが、x_t は前者2つのパラメータ平均化と同様に各時刻 t を動的な終端とするような線形減衰挙動を示す。y_tx_ttを終端とする傾き)と z_t (傾き0)の内挿であるため、t よりも先の未来を終端に見据えつつ、その終端が動的に伸びていく線形減衰挙動を示す。


ここから分かるのは、SGDScheduleFree が内部的に Linear decay scheduler に類する挙動を実現しているということです。学習率スケジューリングを陽に設定せずとも、自動的に。このような観点からも、ScheduleFree の ScheduleFree たる所以が垣間見えます。


余談ですが、このように優秀な収束性を持つ ScheduleFree は、実践上も非常に面白い挙動を示します。下図に示すように、既存のoptimizerとschedulerを組み合わせ、異なるハイパラチューニングをして複数実験した結果を集めると、その訓練の行先が ScheduleFree の訓練曲線にほぼ一致するのです。これを論文では「訓練時間と評価損失のパレート・フロンティア」と表現しています。



The Road Less Scheduled Figure 1|ScheduleFree(黒線)が、1回の試行で「訓練時間と評価損失のパレート・フロンティア」をおおよそなぞる挙動。左がSGD、右がAdamW。赤線はそれぞれのoptimzierにcosine schedulerを組み合わせ、ハイパーパラメータを変えて試行した複数の実験を表す。ScheduleFreeは、既存のoptimizerとschedulerの組み合わせが実現し得る「訓練時間と評価損失のトレードオフ」の曲線をなぞるようにしながら、任意の時間訓練し続けられるという特性が読み取れる。


多変量同時最適化において各変数の最適性バランスを取った解はパレート解と呼ばれ、一意に定まらないパレート解の集合が描く曲線(あるは一般に超曲面)はパレート・フロンティアと呼ばれますが、ここでは「より良い評価損失を得るためには十分な訓練終端時刻(とそれに依存して学習率を減少させるscheduler)が必要」という既存システムのトレードオフを指してこのように表現しているようです。図から分かるように、ScheduleFree はこのパレート・フロンティアを1回の試行で悠々と撫でていくかのようです。トレードオフにより実現されるあらゆるパレート解を辿りながら、好きなだけ長い時間訓練し続けられるわけですね。ここに、ScheduleFree が訓練終端時刻 T に依存しないことの強みが現れています。


メモリ効率の良い SGDScheduleFree の実装

話を SGDScheduleFree に戻しましょう。これを愚直に実装すると、x_t, y_t, z_t それぞれのために、合計でモデルパラメータの3倍相当のテンソルを保持する必要性が出てくることが容易にわかります。モメンタム付きSGDがパラメータとモメンタムの計2倍で済むことを考えると、このメモリ増加は放置したくありませんね。

ScheduleFree のリポジトリを参照すると、SGDScheduleFreeReference というクラスにこれまでの定式化を素直に反映した実装が組み込まれており、理論と実装の対応を丁寧に追うことができます。これはあくまで参考用実装で、使用すると先に述べたように余計なメモリ消費を招きます。一方、通常利用が想定される SGDScheduleFree クラスは、実装を効率化することでメモリ消費量をモデルパラメータの2倍、つまり通常のモメンタム付きSGDと同じレベルまで削減しています。その実装が何をしているかを式に起こすと、次のようになります。

\begin{align*} y_{t+1} &= (1-c_{t+1})\,y_t + c_{t+1} z_t + \gamma \bigl\{\beta (1-c_{t+1}) - 1\bigr\}\,g_t,\vphantom{\frac{1}{t}}\\ z_{t+1} &= z_t - \gamma g_t, \vphantom{\frac{1}{t}}\\ \mathrm{where}\ \ g_t &= \nabla f(y_t,\zeta_t)\vphantom{\frac{1}{t}},\\ c_{t+1} &= \frac{1}{t+1}. \end{align*}

つまり、x_t に関する計算を各時刻のパラメータ更新に暗に組み込んでしまって、保持すべきパラメータを y_t, z_t だけにしているわけですね。z_{t+1} の計算が後になっているのは、z_t を用いる y_{t+1} の計算を先にしてから z_{t+1} をinplaceに計算するほうが実装上は効率がいいためです。

実践の説明で「optimizer.eval()を呼んでくださいね」と話した意味が、ここに繋がります。つまり、optimizerが train mode のときはモデルパラメータとして y_t がセットされており、推論時はこれを x_t に切り替える必要があるということだったのですね。x_ty_tz_t から直ちに計算できるので、これを常に保持しておく必要はないわけです。これこそが eval mode への切り替えにおける内部挙動だったのでした。同様に x_tz_t から y_t も直ちに導けるので、推論用にセットされたパラメータを y_t に戻す処理としてoptimizer.train()は実装されています。

ちなみに、x_t を隠蔽するこの計算方式は論文で特に記載がないので、ScheduleFreeのリポジトリの中身を読んだ人の中には論文内のアルゴリズムと実際の実装との乖離に混乱する人もいるようです。後学のため、上記の定式化を導いてみましょう。といってもただ計算するだけです。

\begin{align*} y_{t+1} &= (1-\beta)z_{t+1}+\beta x_{t+1}\\ &= (1-\beta+\beta c_{t+1})(z_t-\gamma g_t)+\beta (1-c_{t+1}) x_t\\ &= (1-\beta+\beta c_{t+1})(z_t-\gamma g_t)+(1-c_{t+1}) \{y_t-(1-\beta)z_t \}\\ &= (1-c_{t+1})\,y_t + c_{t+1} z_t + \gamma \bigl\{\beta (1-c_{t+1}) - 1\bigr\}\,g_t.\ \ {}_{\blacksquare} \end{align*}


AdamWScheduleFree への拡張

SGDScheduleFree を AdamWScheduleFree に拡張するのはそんなに難しい話ではなく、AdamWにて提案された以下の2点、

  • 二次指数移動平均による適応的ステップサイズ
  • weight decay \lambda の適切な組み込み

および、warmup steps T_wz_t の更新に組み込んでしまうだけです。

\begin{align*} y_t &= (1-\beta_1)\,z_t + \beta_1\,x_t,\vphantom{\frac{1}{t}} \\ v_t &= \beta_2 v_{t-1} + (1-\beta_2) g_t^2,\vphantom{\frac{1}{t}} \\ z_{t+1} &= z_t - \gamma_t \left\{\frac{g_t}{\sqrt{v_t}+\epsilon} + \lambda y_t\right\},\\ x_{t+1} &= (1-c_{t+1})\,x_t + c_{t+1} z_{t+1}, \vphantom{\frac{1}{t}}\\ \mathrm{where}\ \ g_t &= \nabla f(y_t,\zeta_t)\vphantom{\frac{1}{t}},\\ \gamma_t &= \gamma \sqrt{1-\beta_2^t} \min \left\{1, \frac{t}{T_w}\right\},\\ c_{t+1} &= \frac{\gamma_t^2}{\sum_{i=1}^t \gamma_i^2}. \end{align*}

このとき、適応的ステップサイズ \gamma_t を考慮して c_{t+1} も調整されていることがわかります。仮に \gamma_t=\mathrm{const}. であれば c_{t+1}=\frac{1}{t} となるため、SGDの場合は \gamma_t=\gamma と表せたことを踏まえれば、t に関する処理がひとつだけずれていることを除いて整合性が取れていることがわかります。

さて、このアルゴリズムは、ScheduleFree のリポジトリにて AdamWScheduleFreePaper として実装されています。わざわざPaperという接尾辞がついているのはなにやら不穏ですね。これは、著者らがその後、c_{t+1}\beta_2 に関する項を含めないほうが訓練が安定化するという結論に至ったことで、実際に使用が想定される AdamWScheduleFree で処理の細部を異なるものにしたためです。差別化のため命名を変えているのですね。AdamWScheduleFree は先に説明したメモリ消費を抑えるテクニックも適用されており、次のように表されます。


\begin{align*} v_t &= \beta_2 v_{t-1} + (1-\beta_2) g_t^2, \vphantom{\frac{1}{t}}\\ y_{t+1} &= (1-c_{t+1})\,y_t + c_{t+1} z_t + \gamma_t \bigl\{\beta_1 (1-c_{t+1}) - 1\bigr\}\,G_t, \vphantom{\frac{1}{t}}\\ z_{t+1} &= z_t - \gamma_t G_t, \vphantom{\frac{1}{t}}\\ \mathrm{where}\ \ g_t &= \nabla f(y_t,\zeta_t), \vphantom{\frac{1}{t}}\\ G_t &= \frac{g_t}{\sqrt{\frac{v_t}{1-\beta_2^t}}+\epsilon} + \lambda y_t,\\ \gamma_t &= \gamma \min \left\{1, \frac{t}{T_w}\right\},\\ c_{t+1} &= \frac{\gamma_t^2}{\sum_{i=1}^t \gamma_i^2}. \end{align*}

以上で、AdamWScheduleFree が何をしているかを全て追いかけることができました。同時に、記事の導入で触れたように warmup steps T_w を指定する部分だけが嬉しくないハイパーパラメータとして残されたことが、より明確化されました。長かったですね。ここまでで理論パートの前半です。次は話の舞台を移し、まさにこのwarmupを亡き者にしてくれるRAdamの説明をしましょう。


RAdamを思い出す

RAdamは5年以上前に提案された手法なので、既に有用な日本語解説記事がいくつもあります。
https://qiita.com/omiita/items/d24568a835da6911b01e
https://acro-engineer.hatenablog.com/entry/2019/12/25/130000
https://nykergoto.hatenablog.jp/entry/2019/08/16/Adam_の学習係数の分散を考えた論文_RAdam_を読んだよ!

RAdamが何をしているかを簡潔に述べると、「二次指数移動平均係数 \beta_2 に基づいて学習率を緩やかに立ち上げる補正項 r_t を付加する」、これだけです。定式化は以下のようになります。

\begin{align*} \rho_{\infty} &= \frac{2}{1-\beta_2}-1,\\ \rho_t &= \rho_\infty - \frac{2t \beta_2^t}{1-\beta_2^t}, \\ r_t &= \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\rho_t}},\\ \gamma_t &= \begin{cases} \gamma \, r_t \, {\left(\sqrt{\frac{v_t}{1-\beta_2^t}}+\epsilon\right)}^{-1}\ \ &(\rho_t>4)\\ \gamma \ \ &(\rho_t\leq4) \end{cases}. \end{align*}

式を見ればわかるように、\rho_t として計算される値が4以上のときのみ r_t は実数値を持つため、ここを境に挙動が変わっています。具体的には、RAdamの提案では \rho_t \leq 4 のときは学習率を特に変えずSGD的に振る舞い、\rho_t > 4 以降は r_t に従って学習率を変えながらAdam的な振る舞いをします。例えば慣例的な \beta_2=0.999 の場合、\rho_tt\geq5\rho_t>4 を満たします。訓練のごく初期は適応的学習率が暴れやすいので二次指数移動平均を用いず、ある程度勾配の移動平均が暖機運転できてからAdam的な振る舞いに移行するといった具合です。

この \rho_t の境目に関しては、PyTorchの公式実装では5になっているなど、複数の派閥が存在する気配を感じます。というのも、SGDからAdamへの切り替え挙動が不連続なので、その切り替えによる不安定さを取り除くための暗黙的な了解なのかもしれません。一方、
https://qiita.com/T-STAR/items/b9593d64a1ccfb2e775f
にも言及があるように、初期のSGDフェーズはそもそもなくてもいいという見方もあります。言い換えれば、\beta_2=0.999 の例でいえば t\leq 4 の間はパラメータを全く更新せず、一次および二次指数移動平均の暖機運転のみをおこなうということですね。式にすると単に次のような意味です。

\begin{align*} \gamma_t &= \begin{cases} \gamma \, r_t \, {\left(\sqrt{\frac{v_t}{1-\beta_2^t}}+\epsilon\right)}^{-1}\ \ &(\rho_t>4)\\ 0 \ \ &(\rho_t\leq4) \end{cases}. \end{align*}

この定式化であれば、\gamma_t\rho_t=4 でも連続性を保ちます。


そして RAdamScheduleFree へ

ここまでで、AdamWScheduleFree の warmup steps T_w をRAdamで改善する準備が整いました。RAdamの \gamma_tv_t による補正項を含みますが、AdamWScheduleFree では G_t にこの項が含まれることに注意して組み合わせると、最終的に以下のような表式を得ることができます。


\begin{align*} v_t &= \beta_2 v_{t-1} + (1-\beta_2) g_t^2, \vphantom{\frac{1}{t}}\\ y_{t+1} &= (1-c_{t+1})\,y_t + c_{t+1} z_t + \gamma_t \bigl\{\beta_1 (1-c_{t+1}) - 1\bigr\}\,G_t, \vphantom{\frac{1}{t}}\\ z_{t+1} &= z_t - \gamma_t G_t, \vphantom{\frac{1}{t}}\\ \mathrm{where}\ \ g_t &= \nabla f(y_t,\zeta_t), \vphantom{\frac{1}{t}}\\ G_t &= \frac{g_t}{\sqrt{\frac{v_t}{1-\beta_2^t}}+\epsilon} + \lambda y_t,\\ \rho_{\infty} &= \frac{2}{1-\beta_2}-1,\\ \rho_t &= \rho_\infty - \frac{2t \beta_2^t}{1-\beta_2^t}, \\ \gamma_t &= \begin{cases} \gamma \, \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\rho_t}} \ \ &(\rho_t>4)\\ 0 \ \ &(\rho_t\leq4) \end{cases},\\ c_{t+1} &= \begin{cases} \frac{\gamma_t^2}{\sum_{i=1}^t \gamma_i^2} \ \ &(\rho_t>4)\\ 0 \ \ &(\rho_t\leq4) \end{cases}. \end{align*}


\beta_2=0.999 における \gamma_t の動きは下図のようになります。\beta_2 に応じて自動決定されたペースで学習率が立ち上がり、緩やかに \gamma に漸近していくことがわかります。

前節でRAdamの実装にはいくつかの派閥が存在するという話をしましたが、今回の私の実装では \rho_t\leq 4 では学習率を0にするアプローチを取りました。c_{t+1} の計算を考えるとこれはほぼ自明な選択で、というのも、仮に訓練初期にSGD的に振る舞う方式を採用すると、最初の数ステップが \gamma_t=\gamma となり、c_{t+1} の初期に不当に高い値(本来であれば訓練後期にようやく \gamma_t が到達する値)が繰り返し加算されてしまうのです。これは ScheduleFree における望ましいパラメータ平均化を壊し、訓練の不安定化を招く可能性があります。

先述のように、最初の数ステップでSGD的な挙動をするか否かは全体の最適化においては瑣末な差異でしかなく、立ち上がりの安定性を担保することこそが最優先です。この判断に基づき、SGDフェーズを撤廃するような実装方式を取りました。この挙動は RAdamScheduleFree の初期化引数 silent_sgd_phase で制御しています。デフォルト値Trueのままにしておくのを推奨します。

以上、RAdamScheduleFree にまつわる理論背景説明、全て終了です。お疲れ様でした。


結び

本記事では、Adamに代わって新たに普段遣いしてほしいoptimizerとして RAdamScheduleFree を提案しました。概要に始まり、置き換えると何が嬉しいのかを実践例ベースで解説するとともに、内部挙動に興味のある方向けに要素技術の背景についても説明を施しました。

RAdamScheduleFree は汎用的で、強く、手軽です。本記事が、皆さんが取り組まれている様々な機械学習プロジェクトの開発効率をちょっとでも高めるエッセンスになれることを願い、筆を置きます。それではまた、いずれどこかで。


参考文献


脚注
  1. optimizerの提案史は適当に眺めるだけでも群雄割拠で、いまや古典的なものだけでも SGD、モメンタム付きSGD、NAG、AdaGrad、RMSProp、AdaDelta など数多く、Adam以降はもはや全容を把握するのは不可能なほど溢れています。Adamと地続きなものだけでも AdaMaxAdamWNAdamAMSGradRAdamAdaBoundAMSBoundAdaBelief、最近だとADOPTCautiousシリーズ などが挙げられますし、視野を広げれば SantaEveYellowFinLion、最近だと強力なものとして ShampooDistributed ShampooSOAP などが提案されています。 ↩︎

  2. 余談ですが、コードサンプルって MyModel や MyDataset などのトイクラスがどこからともなく現れるものが多いですよね。コンパクトさを重視するなら絶対そのほうがいいのですが、私は「その MyModel どこからimportしたんだよ」みたいな違和感で萎える性分なので、気持ち悪くない範囲であえて丁寧めに書いています。深層学習のコードってどうしても準備やお作法が結構ボリューミーですが、その辺を端折ると実践に当てはめる際のヒントとしては不十分になって迷わせてしまうこともあるので……。 ↩︎

  3. 余談ですが、これまではschedulerが直接 optimizer.param_groups[i]["lr"]を書き換えて学習率を操作することが多かったものの、ScheduleFreeでは"lr" 属性は変化せず、optimizer.param_groups[i]["scheduled_lr"]に読み取り用の値が格納される様式になっています。TensorboardやW&B等で学習率をモニタリングする際の属性は "scheduled_lr" を使ってくださいね。 ↩︎

  4. より詳しい話をすると、ScheduleFree シリーズにはそれぞれ ScheduleFreeClosure という亜種が実装されており、私も RAdamScheduleFreeClosure も含めて実装をしました。こちらはPyTorchのclosureを用いてoptimizerを動かす場合のクラスで、optimizer.step(closure)を呼ぶとこのmode切り替え相当の処理が内部で自動実行されるので、明示的にユーザーがtrainevalを呼ぶ必要はなくなっています。ただ、closureを使わない実装の方が一般的であることや、ScheduleFreeClosure は ScheduleFree と比べごく僅かに訓練効率が劣ること、本文に ScheduleFreeClosure についての言及を含めると混乱を招く可能性があることから、脚注にて補足しました。関連の記述は本家のHow to Useをご覧ください。 ↩︎

  5. これは、optimizerのmode切り替えによりモデルパラメータ自体が書き換えられるためです。 ↩︎

  6. 仕様の補足ですが、ScheduleFree 系のoptimizerは初期化時には eval mode になっています。eval mode の時にoptimizer.eval()を呼んだり、逆に train mode の時にoptimizer.train()を呼んでも何も起きないので、重ねがけによる副作用を心配する必要はありません。 ↩︎

  7. どこに補足するか迷ったのでこちらに。RAdamScheduleFree はAdamWと同様に decoupled weight decay を採用しているので、必要があれば weight decay を適用しても齟齬なく動作します。命名には陽に含めていませんが、実態としては RAdamWScheduleFree と思っていただいても問題ありません。 ↩︎

  8. 組み合わせるにあたってちょっとした工夫はあり、それについては本記事の終盤で紹介しています。ただ、RAdamと ScheduleFree という巨人の肩に乗った上でささやかに口笛を吹いたくらいのものなので、「組み合わせるというものでしかなく」という表現でも特に差し支えはありません。 ↩︎

  9. 背中を押してなんの得があるのだと思われるかもなぁと書いていて思いましたが、自分が使っていて効果が実感できたものを純粋に人にもお勧めしたいときってありますよね。そんな感じです。あとは、業界基準が更新されることが巡り巡って将来の自分のためにもなるかなという淡い感覚と、長らく塗り替えられなかった『optimizerのデファクト・スタンダードに関する全体意識』を刷新する一助になったら超面白いな、という気持ちもあるかもしれません。実際、Adamからの乗り換えを阻む主要因は、個々のoptimizerの性能云々よりもむしろ、乗り換えの面倒さだと思っています。その面倒さを極力排除した上で、既存の面倒さもついでに消し去りますよーみたいなお膳立てができればこの記事としてはいいのかなと考えて執筆しました。 ↩︎

  10. PR averaging が1990年前後に提案された手法であるのに対し、Primal averaging は2015年~2020年頃に整理されたものなので、近年の深層学習ブームに照らして再考されたパラメータ平均化手法と言えるかもしれません。 ↩︎

  11. 実際には無限に長く訓練することはできませんが、浮動小数点の表現精度の限界で実質 x_t がほとんど動かなくなる時点は来そうです。 ↩︎

DeNA Engineers

Discussion