🚀

🔍 FP8×行列演算(matmul)のリアル —— PyTorch実装の裏側を解説!

2025/03/22に公開

前回はFP8のスケーリング方式(Delayed vs Dynamic)について掘り下げました。今回は、実際にFP8を使って行列積(matmul)を計算する流れや、PyTorchの内部実装がどうなっているかを、開発者視点で見ていきます。


🧮 FP8でmatmulするってどういうこと?

TransformerのMLP層では、重みと活性化(activation)をかけ合わせる「行列積(matmul)」が頻繁に登場します。

通常は torch.mm()torch.matmul() を使いますが、FP8で演算するにはちょっとした前処理が必要です。


✅ matmulの流れ(torchao実装)

  1. weightを転置しておく
  2. weightをFP8にQuantize
  3. activationをFP8にQuantize
  4. matmulを実行
  5. 出力をreshapeして元の形に戻す

✏️ 例(pseudocode)

# Step 1: weightは(out, in)の形なので転置しておく
weight_t = self.weight.t()

# Step 2: weightをFP8化
scale = calculate_scale(weight)
weight_fp8 = to_fp8(weight, scale)

# Step 3: activationをFP8化(動的スケーリング)
input_fp8 = to_fp8_dynamic(input)

# Step 4: matmul(bs, h)x(h, h_ffn)
res = torch.mm(input_fp8, weight_fp8.t())

# Step 5: reshape
res = res.view(original_shape)

🎯 なぜreshapeが必要?

PyTorchの torch.mm()2次元のテンソルしか扱えません
でも、Transformerの入力って普通は (batch_size, seq_length, hidden_size) の3次元ですよね?

そこで、

(b, s, h)(b×s, h)

にreshapeしてmatmulを計算。終わったら元の形に戻す、というわけです。


🧠 数値精度と変換のワナ

ここで問題になるのが数値精度(dtype)の変換です。

💥 PyTorchの制限

PyTorchでは、FP8の入力同士で torch.mm() を行うと、出力も必ずFP8になります

a = mat1.to(torch.float8_e4m3fn)
b = mat2.to(torch.float8_e4m3fn)
res = torch.mm(a, b)
print(res.dtype)  # → float8_e4m3fn

一見自然に見えますが、実はこれが大きな落とし穴になります。


🧨 精度変換が挟まると何がマズい?

Transformerでは、Linear層の出力に対して**Residual Connection(元の入力との加算)**を行います。

このとき、こんな処理になります:

FP32 acc → FP8 → BF16 → 足し算

もしPyTorch側で「FP8出力固定」になっていると、不要なFP8変換が挟まって

  • 精度が落ちる(量子化誤差が入る)
  • upcastが発生してコストが増える

というデメリットがあります。


🤯 float8_e4m3fnとfloat8_e4m3fnuzの違い

PyTorchには複数のFP8型があります:

  • float8_e4m3fn
  • float8_e4m3fnuz
  • float8_e5m2
  • float8_e5m2fnuz

このうち fnuz が付いているものは、Exponent Biasが異なるなど、やや特殊な挙動をします。

つまり「同じE4M3に見えて微妙に仕様が違う」ので注意が必要です。

現状(PyTorch 2.5.1)では、ユーザーが出力dtypeを自由に指定する方法がありません。


🚧 これから解決したい課題

  • FP8出力から直接BF16に変換したい(不要なFP8経由を避けたい)
  • matmulの入力・出力dtypeを個別に指定できるようにしたい
  • PyTorchにおけるfloat8型の扱いをもっと柔軟に!

これらの機能が拡張されれば、より効率的で高精度なFP8トレーニングが可能になります


🔚 おわりに

今回は、FP8を使った行列演算の流れや、PyTorch実装上の制約・注意点について紹介しました。

FP8は「新しくて強い」技術ですが、まだまだ発展途上。
ライブラリの実装やハードウェア仕様を深く理解することで、本当の強さを引き出すことができるんです。


次回は:

✅ FP8と混合精度(mixed precision)学習の相性
✅ Residual AddやLayerNormでのdtype戦略
✅ トレーニング安定化のための実験ノート紹介

Discussion