🔍 FP8×行列演算(matmul)のリアル —— PyTorch実装の裏側を解説!
前回はFP8のスケーリング方式(Delayed vs Dynamic)について掘り下げました。今回は、実際にFP8を使って行列積(matmul)を計算する流れや、PyTorchの内部実装がどうなっているかを、開発者視点で見ていきます。
🧮 FP8でmatmulするってどういうこと?
TransformerのMLP層では、重みと活性化(activation)をかけ合わせる「行列積(matmul)」が頻繁に登場します。
通常は torch.mm()
や torch.matmul()
を使いますが、FP8で演算するにはちょっとした前処理が必要です。
✅ matmulの流れ(torchao実装)
- weightを転置しておく
- weightをFP8にQuantize
- activationをFP8にQuantize
- matmulを実行
- 出力をreshapeして元の形に戻す
✏️ 例(pseudocode)
# Step 1: weightは(out, in)の形なので転置しておく
weight_t = self.weight.t()
# Step 2: weightをFP8化
scale = calculate_scale(weight)
weight_fp8 = to_fp8(weight, scale)
# Step 3: activationをFP8化(動的スケーリング)
input_fp8 = to_fp8_dynamic(input)
# Step 4: matmul(bs, h)x(h, h_ffn)
res = torch.mm(input_fp8, weight_fp8.t())
# Step 5: reshape
res = res.view(original_shape)
🎯 なぜreshapeが必要?
PyTorchの torch.mm()
は 2次元のテンソルしか扱えません。
でも、Transformerの入力って普通は (batch_size, seq_length, hidden_size)
の3次元ですよね?
そこで、
(b, s, h) → (b×s, h)
にreshapeしてmatmulを計算。終わったら元の形に戻す、というわけです。
🧠 数値精度と変換のワナ
ここで問題になるのが数値精度(dtype)の変換です。
💥 PyTorchの制限
PyTorchでは、FP8の入力同士で torch.mm()
を行うと、出力も必ずFP8になります。
a = mat1.to(torch.float8_e4m3fn)
b = mat2.to(torch.float8_e4m3fn)
res = torch.mm(a, b)
print(res.dtype) # → float8_e4m3fn
一見自然に見えますが、実はこれが大きな落とし穴になります。
🧨 精度変換が挟まると何がマズい?
Transformerでは、Linear層の出力に対して**Residual Connection(元の入力との加算)**を行います。
このとき、こんな処理になります:
FP32 acc → FP8 → BF16 → 足し算
もしPyTorch側で「FP8出力固定」になっていると、不要なFP8変換が挟まって:
- 精度が落ちる(量子化誤差が入る)
- upcastが発生してコストが増える
というデメリットがあります。
🤯 float8_e4m3fnとfloat8_e4m3fnuzの違い
PyTorchには複数のFP8型があります:
float8_e4m3fn
float8_e4m3fnuz
float8_e5m2
float8_e5m2fnuz
このうち fnuz
が付いているものは、Exponent Biasが異なるなど、やや特殊な挙動をします。
つまり「同じE4M3に見えて微妙に仕様が違う」ので注意が必要です。
現状(PyTorch 2.5.1)では、ユーザーが出力dtypeを自由に指定する方法がありません。
🚧 これから解決したい課題
- FP8出力から直接BF16に変換したい(不要なFP8経由を避けたい)
- matmulの入力・出力dtypeを個別に指定できるようにしたい
- PyTorchにおけるfloat8型の扱いをもっと柔軟に!
これらの機能が拡張されれば、より効率的で高精度なFP8トレーニングが可能になります。
🔚 おわりに
今回は、FP8を使った行列演算の流れや、PyTorch実装上の制約・注意点について紹介しました。
FP8は「新しくて強い」技術ですが、まだまだ発展途上。
ライブラリの実装やハードウェア仕様を深く理解することで、本当の強さを引き出すことができるんです。
次回は:
✅ FP8と混合精度(mixed precision)学習の相性
✅ Residual AddやLayerNormでのdtype戦略
✅ トレーニング安定化のための実験ノート紹介
Discussion