Zenn
🤖

🚀 FP8ってなに? 低精度で爆速トレーニングするための技術をやさしく解説!

2025/03/22に公開

高校生でもわかるようにFP8について説明していきましょう。

🔢 そもそもFP8ってなに?

FP8とは「8ビットの浮動小数点数」のこと。最近のNVIDIA Hopper世代のGPUからサポートされた新しい数値フォーマットで、超ざっくり言うと「精度はちょっと落ちるけど、その分めちゃくちゃ速くなる&省メモリ」な技術です。

FP8には2つの種類があります:

  • E4M3(指数4ビット、仮数3ビット)
  • E5M2(指数5ビット、仮数2ビット)

それぞれ、使える数の範囲や細かさが違うんです。

📉 E4M3ってこんな感じ

  • 最大値:約448
  • 最小値:約2^-9(subnormalを使えば)

📈 E5M2はもっと広い範囲が扱える

  • 最大値:約57,344
  • 最小値:約2^-14

ただし、精度が粗くなるので、E5M2ばっかり使えばいいってわけでもないんです。

💡 なんでFP8を使うの?

3つの大きな理由があります:

  1. モデルを軽くする → ストレージやメモリの使用量が減る!
  2. 通信を速くする → GPU内のデータ転送が速くなる!
  3. 計算が速くなる → 特にFP8 Tensor Coreを活用できる!

特に大規模なLLMをトレーニングするときは③がめちゃくちゃ重要です!

📏 FP8を使うときに必要な「スケーリング」

FP8って、取り扱える数の範囲が狭いんです。
そのまま普通の数(例えばFP32)をFP8に変換しようとすると、オーバーフローしたり、精度がすごく落ちたりします。

そこで、**スケーリング(拡大縮小)**という工夫をします。

たとえば…

FP8の範囲に収まるように、元のテンソルの最大値に合わせて「全体を小さくする」みたいなイメージです。これをしないと、情報が欠落してしまいます。

スケーリングの粒度(Granularity)

  • Tensor単位で1つのスケール値 → シンプルだけど、外れ値の影響を受けやすい
  • 行ごと・列ごとにスケール値を変える → 精度は高くなるが実装が大変

🧠 FP8トレーニングで使われる2つの方法

Delayed Scaling

  • 過去のデータの最大値を使ってスケールを決める
  • メリット:高速で余分な計算が少ない
  • デメリット:最新の情報じゃないから、時々「数値の切り捨て」が起きてしまう

Dynamic Scaling

  • その時点での最大値からスケールを決める
  • メリット:精度が高い
  • デメリット:少し遅くなる(毎回amaxを計算する必要がある)

✏️ 実際にどうやって使われてるの?

TransformerのMLP層などで、行列積(matmul)の処理に使われています。

流れとしては、

  1. 重み(weight)と活性化(activation)をFP8に変換
  2. reshapeして2次元に
  3. FP8のままmatmul
  4. 結果をまた元の形に戻す

というような順番です。

🤔 PyTorchではFP8ってどう扱われてる?

PyTorchでは、入力と出力の数値精度を分けることができないという制限があります。

たとえば:

mat1_fp8 = mat1.to(torch.float8_e4m3fn)
mat2_fp8 = mat2.to(torch.float8_e4m3fn)
res_fp8 = torch.mm(mat1_fp8, mat2_fp8)

こうやって計算すると、出力も自動的にFP8になります。

本来は「FP8で計算して、出力はBF16で受け取りたい」みたいなこともやりたいのに、それができないのが現状の課題の1つです。

✨ おわりに

今回はFP8の基礎の基礎をざっくり解説しました。

低精度って「精度が悪くなるからやだな〜」って思われがちですが、うまく使えば速く、軽く、強くなるとても重要な技術なんです。

Discussion

ログインするとコメントできます