Closed7
mamba-ssm から cuda を取り除けるか?
ピン留めされたアイテム

公式実装に pytorch 版の selective_scan_ref
が用意されていた。
(どこからも参照されてないから存在意義が分からなかったけどそういうことなのか)

ちゃんと cuda kernel と一致するかのテストも実装されてる

好きなところから内部状態を取り出すためには cuda kernel だと困ることがある。

取り敢えず、build はパスできる
git clone git@github.com:state-spaces/mamba.git
cd mamba
CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE pip install -e .
triton は linux のみ対応なので mac でやるなら setup.py から消す。
こんな感じで triton がなければ torch に切り替えてくれるはず
causal_conv1d もこの目的だと要らないか

色々削ってて気づいたこと
悲報: torch に RMSNorm がない。みんな自作してるのか?

T5, LLaMA ,Mistral はこれを使い回しているらしい。
mamba はココで実装してた。
ただし、fused_add_norm を false にしないと triton が必要になってしまう
(fused_add_norm ってなんぞ?)

ココはどう足掻いても cuda になってしまうので無理そう
元々存在は知ってたけど、解説だけではなく
Equivalent numerical output as official implementation for both forward and backward pass
とのことなのでこれを使えばいいだけの話だったかも
このスクラップは2024/01/18にクローズされました