📘

⚠️ FP8トレーニングの落とし穴?Delaye​d ScalingとDynamic Scalingのリアルな話

2025/03/22に公開

前回はFP8の基礎についてざっくり解説しましたが、今回はより実践的な話に踏み込んで、特に**スケーリング方式(Delayed vs Dynamic)**の話を深掘りしていきます。

大規模言語モデル(LLM)をFP8でトレーニングするとき、「スケーリングのやり方」ひとつで学習が不安定になったり、性能がガタ落ちしたりすることも…。


🔁 Delayed Scalingとは?

✔ 特徴

Delayed Scalingは、「過去の最大値」を使ってスケーリングする手法です。

history_len = 16
scale_fn_name = "max"

のように、過去16回分のamax(absolute max)を保存し、その中の最大値をスケールに使う、という感じです。

👍 メリット

  • 一度計算したamaxを使い回せるので速い
  • メモリアクセスが少ないので軽い

👎 デメリット

  • 過去の値でスケールするので、今のテンソルに合ってない場合がある
  • その結果、一部の値が「FP8で表現できる最大値」で**バッサリ切られる(clamp)**ことに…

これ、実は地味にヤバいです。
というのも、大規模モデルの中ではたった1個の超巨大なactivationが重要な情報を持っている場合があり、それが失われると**学習が暴走(loss spike)**してしまうこともあるんです。


⚡ Dynamic Scalingとは?

Dynamic Scalingは、今のテンソルから直接amaxを計算してスケールを決めるやり方です。

amax = torch.max(torch.abs(tensor))
scale = fp8_max / amax

というように、その場でスケール値を決定します。

👍 メリット

  • 常に最新のスケールで最適に変換できる
  • clampのリスクが低い → 学習が安定しやすい

👎 デメリット

  • 毎回amaxを計算するため、追加のメモリアクセスが必要
  • トレーニングが少し遅くなる

🧪 実際どっちが良いの?

これは一長一短です。
研究レベルでの議論もありますが、ざっくりまとめると:

モデル規模 推奨スケーリング
小〜中規模(〜200B tokens) Delayed Scaling(速くて十分安定)
大規模(1T tokens〜, 継続学習など) Dynamic Scaling(精度優先)

🔥 Delayed Scalingの「落とし穴」

  • 過去のamaxが古すぎてclamp地獄になる
  • 特に「突然大きな値が出た」ときに対応できない
  • 実際、torchaoチームも「実用例が少ない」として、Delayed Scalingの廃止を検討しているとのこと

🧩 clampが与える影響をどう評価するか?

面白い課題です。

たとえば:

  • 変換前の値と、clamp後の値との差分を記録
  • loss spikeとの相関を調べる
  • 一部だけDynamic Scalingに切り替えて影響を比較する

などの方法で、どの程度clampが学習を不安定化させるかを調べることができます。

研究的にもホットな話題なので、「一緒に調べてみたい!」という方がいたらぜひコラボしたいですね!


🛠 PyTorch実装の注意点

PyTorchでFP8に変換するとき、こんなコードがよく出てきます:

tensor_scaled = tensor.to(torch.float32) * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)

ここでsaturatedとあるように、FP8の範囲に収まるようにclampする処理が入ってるんですね。

もしこれを入れずにそのまま.to(float8_dtype)すると、PyTorch側の挙動が保証されず、思わぬオーバーフローが起こる可能性も…。


🔄 まとめ:FP8は「速いけど、繊細」な武器

  • FP8は計算・転送が超高速だけど、扱いを間違えると精度が落ちる
  • スケーリング戦略は、モデルサイズや目的に応じて選ぶべし
  • clampによる情報損失は、時に学習全体を壊す可能性もある
  • PyTorch実装では、to_fp8_saturatedやfloat32キャストの扱いに注意!

次回は、FP8 matmulの詳細実装や、PyTorchの数値精度まわりの罠について解説していこうと思います!

Discussion