TransformersでoptimizerのMuonを使う

最近、深層学習のOptimizerでMuon (MomentUm Orthogonalized by Newton-Schulz)が注目されています。発音はよくわかりませんが「ミューオン」と呼んでいます。
言語モデルで注目をされていますが、言語タスクに限定せず、KaggleであればtascjさんがYale/UNC-CH - Geophysical Waveform InversionコンペでMuonを用いて3位入賞されてます。
この記事では、Muonの理論については紹介せず、簡単に提案元の紹介とTransformersで使用するときのコード例について紹介します。
Muonについて
Muon: An optimizer for hidden layers in neural networksで提案されています。具体的な定義や設計思想はこちらを読んでください。実装はKellerJordan/Muonに公開されています。
簡単にこのOptimizerの特徴としては、SGD-momentumによる更新行列をNewton-Schulz反復法で近似的に直交化することにあります。なぜ直行化するといいのかは、Why is it good to orthogonalize the update?以降を読んでください。
この直行化により、2階以上のテンソルのパラメータにのみ最適化が可能です。スカラーや1階のテンソルのパラメータに対してはAdamWなど標準的なOptimizerを使うことが推奨されています。
Muonをtransformersで使う
Muon使用することはシンプルです(bf16が使えるAmpere以降のGPUでないといけないかも?)。
まずKellerJordan/Muonをインストールします。
pip install git+https://github.com/KellerJordan/Muon
uvであれば以下でインストールします。
uv add git+https://github.com/KellerJordan/Muon
以下でtorch.optim.Optimizer
インスタンスを作成します。クラスはMuonWithAuxAdam
を使用します。なおMuonWithAuxAdam
は分散並列環境で実行することを前提としているため、シングルGPUでも実行時にtorchrun
などで実行するといいです。notebookなどでシングルGPUで実行する場合はSingleDeviceMuonWithAuxAdam
を使用してください。
from typing import cast
import torch
from muon import MuonWithAuxAdam
def build_muon(
model: torch.nn.Module,
lr_muon: float = 2e-2,
lr_aux: float = 3e-4,
beta1: float = 0.9,
beta2: float = 0.95,
weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
hidden_weights: list[torch.nn.Parameter] = []
hidden_gains_biases: list[torch.nn.Parameter] = []
for _, p in model.named_parameters():
if not p.requires_grad:
continue
if p.ndim >= 2:
hidden_weights.append(p)
else:
hidden_gains_biases.append(p)
return cast(
torch.optim.Optimizer,
MuonWithAuxAdam(
param_groups=[
dict(
params=hidden_weights,
use_muon=True,
lr=lr_muon,
weight_decay=weight_decay,
),
dict(
params=hidden_gains_biases,
use_muon=False,
lr=lr_aux,
betas=(beta1, beta2),
weight_decay=weight_decay,
),
]
),
)
パラメータの階数が2階以上の場合とそれ未満の場合でparam_groups
を分ける必要があるのが注意です。
transformersの場合は、Trainerクラスのoptimizers
という引数にタプルで渡します。独自のschedulerを使う場合はNone
を変更します。
trainer = Trainer(
...
optimizers=(optimizer, None),
)
簡単な実験結果
簡単に文章の二値分類タスクでOptimizerでベンチマークします。
データセットはGLUEのsst2を使用します。モデルはModernBERT-baseを使用します。
AdamWとMuonを比較します。各パラメータは以下で実行します。
- batch_size = 32
- AdamW:
- lr = 5e-5
- beta1 = 0.9
- beta2 = 0.99
- weight_decay = 1e-2
- Muon
- lr (use_muon=True) = 1e-3
- weight_decay (use_muon=True) = 1e-2
- lr (use_muon= False) = 5e-5
- weight_decay (use_muon=False) = 1e-2
- beta1 (use_muon= False) = 0.9
- beta2 (use_muon= False) = 0.99
ここでbeta2だけ変更して実験を回した結果が以下です。
train 1epoch時のloss
validation 各epochのAUC
パラメータについては広い範囲で探索したわけではないので、正確なスコアではないことに気をつけてください(このベンチマークコードについては後で公開すると思います、多分)。
感想
Muon自体はデフォルトのパラメータが lr (use_muon=True) = 2e-2
ですが、これはBERTのfinetuneでは小さくした方が学習が安定しました(同様にlr (use_muon=False) = 5e-5
と小さくしています)。
あまり実験は回せていませんが、AdamWよりも精度がかなり高くなるのは面白かったです。変更も今後初手でOptimizerをAdamWとMuon2つ試そうと思います。GPUメモリもMuonの方が使用割合が少ないのもいいですね。
p.ndim >= 2
でMuonかAdamWに最適化するか分けるときに、各レイヤーごとに調整すると性能が変わりました。あとで調べてみるとEmpirical considerationsに言及されていました。
以下の画像で橙色がp.ndim >= 2でmuon, p.ndim < 2でadamの普通の実装で、embedding layerをAdamWで最適化するようにした結果は黄緑です。黄緑の方がlossを落ちており、validation scoreも改善していました(Xでのポスト)。
実装としては以下のようにシンプルに入れることができます。モデルによってこのlayer名は変わるので、適宜確認してください。
def _is_nonhidden_param(name: str) -> bool:
n = name.lower()
# Typical HF module names for embeddings and classifier/LM heads
keywords = [
"embed", # embeddings, embed_tokens
"embedding",
"embeddings",
"lm_head",
"classifier",
"cls.",
"score", # e.g., RobertaForSequenceClassification.classifier.out_proj
]
return any(k in n for k in keywords)
def build_muon(
model: torch.nn.Module,
lr_muon: float = 2e-2,
lr_aux: float = 3e-4,
beta1: float = 0.9,
beta2: float = 0.95,
weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
hidden_weights: list[torch.nn.Parameter] = []
hidden_gains_biases: list[torch.nn.Parameter] = []
nonhidden_params: list[torch.nn.Parameter] = []
for name, p in model.named_parameters():
if not p.requires_grad:
continue
if _is_nonhidden_param(name):
nonhidden_params.append(p)
elif p.ndim >= 2:
hidden_weights.append(p)
else:
hidden_gains_biases.append(p)
return cast(
torch.optim.Optimizer,
MuonWithAuxAdam(
param_groups=[
dict(
params=hidden_weights,
use_muon=True,
lr=lr_muon,
weight_decay=weight_decay,
),
dict(
params=[*hidden_gains_biases, *nonhidden_params],
use_muon=False,
lr=lr_aux,
betas=(beta1, beta2),
weight_decay=weight_decay,
),
]
),
)
良いMuonライフを。

雑な実験だけどlr (use_muon=True)
を変えて実験すると、lr = 1e-3
が良さそう。

muon側のlrを1e-3
に固定して、AdamW側のlr (use_muon=False)
を変えても精度はあまり変わらない。