Closed17

mamba-ssm の cuda 実装を(苦しみながら)追う

yuji96yuji96

cuda は2通りの呼び出され方をしている。

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/mamba_ssm/modules/mamba_simple.py#L145

が True ならこっち↓。conv とワンセットになってる。

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/mamba_ssm/modules/mamba_simple.py#L146-L160

False ならこっち↓。これは conv と selective scan が別々になっている。

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/mamba_ssm/modules/mamba_simple.py#L189-L200

前者が training 用で後者が inference 用なのか?

yuji96yuji96

z=None にすると silu を無視できるので普通にデバッグにも使える

yuji96yuji96

selective_scan_cuda.fwd という関数名で cuda を呼び出せるのは、ここで binding ということがされてるからっぽい

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/csrc/selective_scan/selective_scan.cpp#L494-L497

yuji96yuji96

すると selective_scan_fwd_cuda が呼び出される

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/csrc/selective_scan/selective_scan.cpp#L330

どういう経緯か分からんが、float32 だとこの selective_scan_fwd_cuda にたどり着きそう

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/csrc/selective_scan/selective_scan_fwd_fp32.cu#L9

コンパイルされる時に <float, float> に基づいてポリモーフィズム的なことがされているのかな?

yuji96yuji96

分からないとこ

論文にあるが実装で見つからない

  • B の離散化どこ?

A は確かにココにある

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/csrc/selective_scan/selective_scan_fwd_kernel.cuh#L216-L217

実装にあるが論文で見つからない

  • \Delta_t u_t なんて論文にあったか?

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/csrc/selective_scan/selective_scan_fwd_kernel.cuh#L157

yuji96yuji96

もしかして2つの理由は同じで、公式実装も mamba-minimal と同様に B の離散化を省略している…?(論文では省略については触れてなかったような)

https://github.com/johnma2006/mamba-minimal/issues/2

Q1. じゃあなんで公式実装と minimal の実行結果が ssm を通ると全く別物になってしまうんだ?
Q2. しかもなんで公式の重みを minimal でロードしても良い感じの文章を生成できるんだ?

yuji96yuji96

A1: やっぱりそうだ。minimal は加算され後が出力されるのに対して、official は残差接続の直前の2状態が別々に出力されるのか。(理由は fused_add_norm_fn で高速化したいからっぽい)

official:
https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/mamba_ssm/models/mixer_seq_simple.py#L154-L157

minimal:
https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L85-L86

つまり、全く別物どころかほとんど同じ。


A2: 最終出力を比較してみた

torch.manual_seed(0)
random_ids = torch.randint(model_official.config.vocab_size, size=(1, 100)).cuda()
out1 = model_official(random_ids).logits
out2 = model_minimal(random_ids)

for i in range(2, 9):
    print(i, torch.isclose(out1, out2, atol=10**-i).float().mean())
精度 close rate
1e-2 1
1e-3 0.9698
1e-4 0.95
1e-5 0.9372
1e-6 0.9345
1e-7 0.9342
1e-8 0.9342

こんだけ似てればたしかに大丈夫そう

このスクラップは3ヶ月前にクローズされました