💥

ロススパイクをXXす

に公開

Transformerの学習していると出てくるloss spikeなんとかころしたいですよね。
そんなあなたへの論文です。
ZClip: Adaptive Spike Mitigation for LLM Pre-Training

“ZClip: Adaptive Spike Mitigation for LLM Pre-Training” 解説

arXiv: 2504.02507, 3 Apr 2025)


1 Introduction

巨大言語モデル(LLM)の事前学習では loss spike(突発的な損失増大)が深刻な問題になる。

  • 540 B パラメータモデルで 20 回以上のスパイクが報告され、各回で数百バッチを巻き戻したという実例 [6]。
  • 65 B モデルではスパイク対処だけで 129 MWh もの追加電力が必要だった [7]。
    理論的には Jacobian のスペクトルノルム暴走が原因と示唆されているが(Spike-No-More [8])、学習中に動的に対処できる手法は限られている。そこで著者らは ZClip を提案する。ZClip は勾配の \ell_2 ノルム分布を逐次推定し、z-score による外れ値検出動的クリッピングでスパイクを抑制する。

2 Gradient Clipping Methods

2.1 Fixed-threshold clipping

勾配ベクトル g_t(パラメータ \theta_t に対する損失勾配)の総ノルム

\lVert g_t\rVert_2 \;=\;\sqrt{\sum_{i=1}^{N} g_{ti}^2}\tag{1}

が閾値 c を超えたときだけスケールダウンする

g_t^{\ast}\;=\; \begin{cases} g_t,&\lVert g_t\rVert_2\le c\\[4pt] \dfrac{c}{\lVert g_t\rVert_2}\,g_t,&\lVert g_t\rVert_2>c \end{cases}\tag{2}

静的しきい値 c

  1. トークン長・学習率・バッチサイズなどで 最適値が変動
  2. 訓練後期に 過小クリップ→スパイクを許す
    という欠点がある。

2.2 AutoClip

過去 k ステップのノルム履歴からパーセンタイル p を計算して閾値とする手法 [14]。LLM では履歴保管コストと 履歴中の外れ値汚染が問題で、事前学習数百万ステップには不向き。

2.3 ZClip

2.3.1 Spike detection

スカラー勾配ノルム
$$
g_t=\lVert g_t\rVert_2\tag{3}
$$

を用い、指数移動平均 (EMA) で

\mu_t = \alpha \mu_{t-1} + (1-\alpha)g_t, \qquad \sigma_t =\sqrt{\alpha\sigma_{t-1}^{2} + (1-\alpha)(g_t-\mu_t)^2}\tag{4–5}
z_t=\frac{g_t-\mu_t}{\sigma_t}\tag{6}

が閾値 z_{\text{thres}} を超えればスパイクと判定。

2.3.2 Gradient adjustment

g_t^{\ast}= \begin{cases} g_t,&z_t\le z_{\text{thres}}\\[4pt] \dfrac{\mu_t+z_t^{\ast}\sigma_t}{\lVert g_t\rVert_2}\,g_t,&z_t> z_{\text{thres}} \end{cases}\tag{7}
z_t^{\ast}= \xi(z_t)\tag{8}

\xi の候補

方式 特徴
clip-to-mean z_t^{\ast}=0 不連続・最強クリップ
clip-to-max z_t^{\ast}=z_{\text{thres}} 連続・弱め
reciprocal z_t^{\ast}=z_{\text{thres}}^{2}/z_t 連続・外れ値ほど強く抑制

著者は reciprocal を推奨。

2.3.3 Warm-up初期化

最初 N_w ステップはクリップせず

\mu_{N_w}=\frac1{N_w}\sum_{t=1}^{N_w}g_t,\qquad \sigma_{N_w}=\sqrt{\frac1{N_w}\sum_{t=1}^{N_w}(g_t-\mu_{N_w})^2}\tag{9}

で統計を初期化。

2.3.4 統計更新時のスパイク処理

スパイク時は クリップ後ノルム で統計を更新

g_t^{\text{update}} = \begin{cases} g_t,&z_t\le z_{\text{thres}}\\[2pt] g_t^{\ast},&z_t> z_{\text{thres}} \end{cases}\tag{10}

2.3.5 Algorithm 1(実装骨子)

  1. Warm-up(\mu,\sigma) を計算
  2. 各 step で勾配ノルム→z-score
  3. z_t>z_{\text{thres}} なら (7) で縮小
  4. PyTorch の clip_grad_norm_ に渡しパラメータ更新
  5. (10) に従い EMA を更新

3 Experiment Setup

項目 設定
モデル LLaMA 1 B(16 層 / hidden 2048 / RMSNorm / SwiGLU)
データ SmolLM 50 B token(FineWebEdu, Cosmopedia-V2, Python-Edu)
計算環境 4 ノード × 8 H100, FSDP
最大学習率 1\times10^{-4}5\times10^{-3}(線形 warm-up+cosine decay)
ZClip HP \alpha=0.97, z_{\text{thres}}=2.5(高 LR では 2.0)

4 Results and Analysis

4.1 High-LR regime (3{\times}10^{-3})

  • 固定クリップ (c=1.0) は発散。
  • ZClip は同じ LR で安定し、検証 loss を 35 % 早く達成
  • HellaSwag / WinoGrande でも大幅改善(表 2)。
  • ただし 5{\times}10^{-3} では両者とも発散し、ZClip だけで LR を無限に上げられるわけではない

4.2 Low-/Mid-LR regime (10^{-3}10^{-4})

  • スパイク発生は AutoClip と ZClip が 完全抑制 (0 回)
  • 下流性能は ZClip が最良 (HellaSwag 49.3 %, WinoGrande 54.9 %)。
  • クリップ後ノルム分布が滑らかになり、過剰正則化なく情報勾配を保持

5 Conclusion

ZClip は

  1. 局所ガウス仮定+EMA で (\mu_t,\,\sigma_t) を軽量推定
  2. z-score 外れ値 を reciprocal クリップで滑らかに抑制
    という 2 段構えで loss spike を抑える。1 B モデル実験では
  • 固定閾値より 広い LR 空間 を安全に探索
  • トークン 18.6 B 削減で同等損失に到達
    を実証した。将来は 7 B–70 B 規模や RL への適用が課題として挙げられている。

6 Appendix ハイライト

6.2 Percentile parameterization
AutoClip に似せて「上位 p %」を保ちたい場合、標準正規分布の逆 CDF で

z_{\text{thres}} = \Phi^{-1}(p)

を計算し ZClip に渡すことで同等挙動を得られる(例:p=0.99\Rightarrow z_{\text{thres}}\simeq2.326)。

6.3 クリップ戦略比較

方式 スパイク回数 下流性能
max 1
mean 0
reciprocal 0

6.4 ノルム正規性検証
135 ステップ窓でノルム分布をフィットし、

  • 早期は右裾重 (skew)
  • 中盤では正規近似が良好
    → 短窓 EMA と z-score の有効性を支持。

研究・実装のポイント

  1. EMA の \alphaz_{\text{thres}} はほぼ転移可能
    著者の sweep では \alpha\!=\!0.97, z_{\text{thres}}\!=\!2.5 が広範囲に妥当。

  2. Reciprocal クリップは連続性と攻撃性のトレードオフが良好
    \xi(z)=z_{\text{thres}}^{2}/zz\to\infty で平均値に収束し、極端外れ値を強く抑える。

  3. 実装はたった数行で既存 clip_grad_norm_ を包める
    追加状態は (\mu,\sigma) の 2 スカラーと warm-up カウンタのみ。

  4. Spike 検出と統計更新を分離
    スパイクを完全に統計から除外すると閾値が下がり過ぎ、
    “クリップ後ノルムで更新” がバランスを取る。


総括 — ZClip は「確率統計的クリッピング」という位置付けで、
  - 長期学習での スパイク再発防止
  - ハイパーパラメータ調整コスト低減
を同時に達成する実践的なソリューションである。Transformer 系の loss spike に悩む実務・研究のどちらにも即投入できるだろう。

Discussion