mamba-ssm の cuda 実装を(苦しみながら)追う
cuda は2通りの呼び出され方をしている。
が True ならこっち↓。conv とワンセットになってる。
False ならこっち↓。これは conv と selective scan が別々になっている。
前者が training 用で後者が inference 用なのか?
True でも False でも結局は selective_scan_cuda.fwd
を呼び出す
Mamba の論文を読んだ時、肝は Δ が入力に依存することであって、B, C は固定で良くない?って思ったのだが、実装でもその余地があった
selective_scan_cuda.fwd
という関数名で cuda を呼び出せるのは、ここで binding ということがされてるからっぽい
そんなわけで指定された selective_scan_fwd
が呼び出される。たしかに引数が python が渡しているものと同じだ。
すると selective_scan_fwd_cuda
が呼び出される
どういう経緯か分からんが、float32 だとこの selective_scan_fwd_cuda
にたどり着きそう
コンパイルされる時に <float, float>
に基づいてポリモーフィズム的なことがされているのかな?
分からないとこ
論文にあるが実装で見つからない
- B の離散化どこ?
A は確かにココにある
実装にあるが論文で見つからない
-
なんて論文にあったか?\Delta_t u_t
もしかして2つの理由は同じで、公式実装も mamba-minimal と同様に B の離散化を省略している…?(論文では省略については触れてなかったような)
Q1. じゃあなんで公式実装と minimal の実行結果が ssm を通ると全く別物になってしまうんだ?
Q2. しかもなんで公式の重みを minimal でロードしても良い感じの文章を生成できるんだ?
A1. 自分の使い方が間違ってたかも。以下2つの出力は 1e-4 くらいの精度では一致した。
official:
minimal:
A1: やっぱりそうだ。minimal は加算され後が出力されるのに対して、official は残差接続の直前の2状態が別々に出力されるのか。(理由は fused_add_norm_fn
で高速化したいからっぽい)
official:
minimal:
つまり、全く別物どころかほとんど同じ。
A2: 最終出力を比較してみた
torch.manual_seed(0)
random_ids = torch.randint(model_official.config.vocab_size, size=(1, 100)).cuda()
out1 = model_official(random_ids).logits
out2 = model_minimal(random_ids)
for i in range(2, 9):
print(i, torch.isclose(out1, out2, atol=10**-i).float().mean())
精度 | close rate |
---|---|
1e-2 | 1 |
1e-3 | 0.9698 |
1e-4 | 0.95 |
1e-5 | 0.9372 |
1e-6 | 0.9345 |
1e-7 | 0.9342 |
1e-8 | 0.9342 |
こんだけ似てればたしかに大丈夫そう
公式実装の issue に行列Bの離散化を単純化したことが述べられていた
InclusiveScan
で計算しているらしい
オペレータからぽさ味は感じなくもない
この scan は cub ライブラリを呼び出していて、BLOCK_SCAN_WARP_SCANS
という手法を使っているらしい(全く分からん)