⚠️ FP8トレーニングの落とし穴?Delayed ScalingとDynamic Scalingのリアルな話
前回はFP8の基礎についてざっくり解説しましたが、今回はより実践的な話に踏み込んで、特に**スケーリング方式(Delayed vs Dynamic)**の話を深掘りしていきます。
大規模言語モデル(LLM)をFP8でトレーニングするとき、「スケーリングのやり方」ひとつで学習が不安定になったり、性能がガタ落ちしたりすることも…。
🔁 Delayed Scalingとは?
✔ 特徴
Delayed Scalingは、「過去の最大値」を使ってスケーリングする手法です。
history_len = 16
scale_fn_name = "max"
のように、過去16回分のamax(absolute max)を保存し、その中の最大値をスケールに使う、という感じです。
👍 メリット
- 一度計算したamaxを使い回せるので速い
- メモリアクセスが少ないので軽い
👎 デメリット
- 過去の値でスケールするので、今のテンソルに合ってない場合がある
- その結果、一部の値が「FP8で表現できる最大値」で**バッサリ切られる(clamp)**ことに…
これ、実は地味にヤバいです。
というのも、大規模モデルの中ではたった1個の超巨大なactivationが重要な情報を持っている場合があり、それが失われると**学習が暴走(loss spike)**してしまうこともあるんです。
⚡ Dynamic Scalingとは?
Dynamic Scalingは、今のテンソルから直接amaxを計算してスケールを決めるやり方です。
amax = torch.max(torch.abs(tensor))
scale = fp8_max / amax
というように、その場でスケール値を決定します。
👍 メリット
- 常に最新のスケールで最適に変換できる
- clampのリスクが低い → 学習が安定しやすい
👎 デメリット
- 毎回amaxを計算するため、追加のメモリアクセスが必要
- トレーニングが少し遅くなる
🧪 実際どっちが良いの?
これは一長一短です。
研究レベルでの議論もありますが、ざっくりまとめると:
モデル規模 | 推奨スケーリング |
---|---|
小〜中規模(〜200B tokens) | Delayed Scaling(速くて十分安定) |
大規模(1T tokens〜, 継続学習など) | Dynamic Scaling(精度優先) |
🔥 Delayed Scalingの「落とし穴」
- 過去のamaxが古すぎてclamp地獄になる
- 特に「突然大きな値が出た」ときに対応できない
- 実際、torchaoチームも「実用例が少ない」として、Delayed Scalingの廃止を検討しているとのこと
🧩 clampが与える影響をどう評価するか?
面白い課題です。
たとえば:
- 変換前の値と、clamp後の値との差分を記録
- loss spikeとの相関を調べる
- 一部だけDynamic Scalingに切り替えて影響を比較する
などの方法で、どの程度clampが学習を不安定化させるかを調べることができます。
研究的にもホットな話題なので、「一緒に調べてみたい!」という方がいたらぜひコラボしたいですね!
🛠 PyTorch実装の注意点
PyTorchでFP8に変換するとき、こんなコードがよく出てきます:
tensor_scaled = tensor.to(torch.float32) * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
ここでsaturated
とあるように、FP8の範囲に収まるようにclampする処理が入ってるんですね。
もしこれを入れずにそのまま.to(float8_dtype)
すると、PyTorch側の挙動が保証されず、思わぬオーバーフローが起こる可能性も…。
🔄 まとめ:FP8は「速いけど、繊細」な武器
- FP8は計算・転送が超高速だけど、扱いを間違えると精度が落ちる
- スケーリング戦略は、モデルサイズや目的に応じて選ぶべし
- clampによる情報損失は、時に学習全体を壊す可能性もある
- PyTorch実装では、
to_fp8_saturated
やfloat32キャストの扱いに注意!
次回は、FP8 matmulの詳細実装や、PyTorchの数値精度まわりの罠について解説していこうと思います!
Discussion