Open3

TransformersでoptimizerのMuonを使う

colum2131colum2131

最近、深層学習のOptimizerでMuon (MomentUm Orthogonalized by Newton-Schulz)が注目されています。発音はよくわかりませんが「ミューオン」と呼んでいます。

言語モデルで注目をされていますが、言語タスクに限定せず、KaggleであればtascjさんがYale/UNC-CH - Geophysical Waveform InversionコンペでMuonを用いて3位入賞されてます。

この記事では、Muonの理論については紹介せず、簡単に提案元の紹介とTransformersで使用するときのコード例について紹介します。

Muonについて

Muon: An optimizer for hidden layers in neural networksで提案されています。具体的な定義や設計思想はこちらを読んでください。実装はKellerJordan/Muonに公開されています。

簡単にこのOptimizerの特徴としては、SGD-momentumによる更新行列をNewton-Schulz反復法で近似的に直交化することにあります。なぜ直行化するといいのかは、Why is it good to orthogonalize the update?以降を読んでください。

この直行化により、2階以上のテンソルのパラメータにのみ最適化が可能です。スカラーや1階のテンソルのパラメータに対してはAdamWなど標準的なOptimizerを使うことが推奨されています。

Muonをtransformersで使う

Muon使用することはシンプルです(bf16が使えるAmpere以降のGPUでないといけないかも?)。

まずKellerJordan/Muonをインストールします。

pip install git+https://github.com/KellerJordan/Muon

uvであれば以下でインストールします。

uv add git+https://github.com/KellerJordan/Muon

以下でtorch.optim.Optimizerインスタンスを作成します。クラスはMuonWithAuxAdamを使用します。なおMuonWithAuxAdamは分散並列環境で実行することを前提としているため、シングルGPUでも実行時にtorchrunなどで実行するといいです。notebookなどでシングルGPUで実行する場合はSingleDeviceMuonWithAuxAdamを使用してください。

from typing import cast

import torch
from muon import MuonWithAuxAdam


def build_muon(
    model: torch.nn.Module,
    lr_muon: float = 2e-2,
    lr_aux: float = 3e-4,
    beta1: float = 0.9,
    beta2: float = 0.95,
    weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
    hidden_weights: list[torch.nn.Parameter] = []
    hidden_gains_biases: list[torch.nn.Parameter] = []
    for _, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim >= 2:
            hidden_weights.append(p)
        else:
            hidden_gains_biases.append(p)

    return cast(
        torch.optim.Optimizer,
        MuonWithAuxAdam(
            param_groups=[
                dict(
                    params=hidden_weights,
                    use_muon=True,
                    lr=lr_muon,
                    weight_decay=weight_decay,
                ),
                dict(
                    params=hidden_gains_biases,
                    use_muon=False,
                    lr=lr_aux,
                    betas=(beta1, beta2),
                    weight_decay=weight_decay,
                ),
            ]
        ),
    )

パラメータの階数が2階以上の場合とそれ未満の場合でparam_groupsを分ける必要があるのが注意です。

transformersの場合は、Trainerクラスのoptimizersという引数にタプルで渡します。独自のschedulerを使う場合はNoneを変更します。

trainer = Trainer(
    ...
    optimizers=(optimizer, None),
)

簡単な実験結果

簡単に文章の二値分類タスクでOptimizerでベンチマークします。
データセットはGLUEのsst2を使用します。モデルはModernBERT-baseを使用します。

AdamWとMuonを比較します。各パラメータは以下で実行します。

  • batch_size = 32
  • AdamW:
    • lr = 5e-5
    • beta1 = 0.9
    • beta2 = 0.99
    • weight_decay = 1e-2
  • Muon
    • lr (use_muon=True) = 1e-3
    • weight_decay (use_muon=True) = 1e-2
    • lr (use_muon= False) = 5e-5
    • weight_decay (use_muon=False) = 1e-2
    • beta1 (use_muon= False) = 0.9
    • beta2 (use_muon= False) = 0.99

ここでbeta2だけ変更して実験を回した結果が以下です。

train 1epoch時のloss

validation 各epochのAUC

パラメータについては広い範囲で探索したわけではないので、正確なスコアではないことに気をつけてください(このベンチマークコードについては後で公開すると思います、多分)。

感想

Muon自体はデフォルトのパラメータが lr (use_muon=True) = 2e-2 ですが、これはBERTのfinetuneでは小さくした方が学習が安定しました(同様にlr (use_muon=False) = 5e-5と小さくしています)。

あまり実験は回せていませんが、AdamWよりも精度がかなり高くなるのは面白かったです。変更も今後初手でOptimizerをAdamWとMuon2つ試そうと思います。GPUメモリもMuonの方が使用割合が少ないのもいいですね。

p.ndim >= 2でMuonかAdamWに最適化するか分けるときに、各レイヤーごとに調整すると性能が変わりました。あとで調べてみるとEmpirical considerationsに言及されていました。
以下の画像で橙色がp.ndim >= 2でmuon, p.ndim < 2でadamの普通の実装で、embedding layerをAdamWで最適化するようにした結果は黄緑です。黄緑の方がlossを落ちており、validation scoreも改善していました(Xでのポスト)。

実装としては以下のようにシンプルに入れることができます。モデルによってこのlayer名は変わるので、適宜確認してください。

def _is_nonhidden_param(name: str) -> bool:
    n = name.lower()
    # Typical HF module names for embeddings and classifier/LM heads
    keywords = [
        "embed",  # embeddings, embed_tokens
        "embedding",
        "embeddings",
        "lm_head",
        "classifier",
        "cls.",
        "score",  # e.g., RobertaForSequenceClassification.classifier.out_proj
    ]
    return any(k in n for k in keywords)
def build_muon(
    model: torch.nn.Module,
    lr_muon: float = 2e-2,
    lr_aux: float = 3e-4,
    beta1: float = 0.9,
    beta2: float = 0.95,
    weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
    hidden_weights: list[torch.nn.Parameter] = []
    hidden_gains_biases: list[torch.nn.Parameter] = []
    nonhidden_params: list[torch.nn.Parameter] = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if _is_nonhidden_param(name):
            nonhidden_params.append(p)
        elif p.ndim >= 2:
            hidden_weights.append(p)
        else:
            hidden_gains_biases.append(p)

    return cast(
        torch.optim.Optimizer,
        MuonWithAuxAdam(
            param_groups=[
                dict(
                    params=hidden_weights,
                    use_muon=True,
                    lr=lr_muon,
                    weight_decay=weight_decay,
                ),
                dict(
                    params=[*hidden_gains_biases, *nonhidden_params],
                    use_muon=False,
                    lr=lr_aux,
                    betas=(beta1, beta2),
                    weight_decay=weight_decay,
                ),
            ]
        ),
    )

良いMuonライフを。

colum2131colum2131

雑な実験だけどlr (use_muon=True)を変えて実験すると、lr = 1e-3が良さそう。

colum2131colum2131

muon側のlrを1e-3に固定して、AdamW側のlr (use_muon=False)を変えても精度はあまり変わらない。