Closed7

mamba-ssm から cuda を取り除けるか?

ピン留めされたアイテム
yuji96yuji96

公式実装に pytorch 版の selective_scan_ref が用意されていた。

https://github.com/state-spaces/mamba/issues/89#issuecomment-1873167894

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/mamba_ssm/ops/selective_scan_interface.py#L86

(どこからも参照されてないから存在意義が分からなかったけどそういうことなのか)

yuji96yuji96

取り敢えず、build はパスできる

git clone git@github.com:state-spaces/mamba.git
cd mamba
CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE pip install -e .

triton は linux のみ対応なので mac でやるなら setup.py から消す。
https://github.com/openai/triton?tab=readme-ov-file#compatibility

こんな感じで triton がなければ torch に切り替えてくれるはず
https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/mamba_ssm/modules/mamba_simple.py#L20-L28

causal_conv1d もこの目的だと要らないか

yuji96yuji96

色々削ってて気づいたこと

悲報: torch に RMSNorm がない。みんな自作してるのか?

yuji96yuji96

ココはどう足掻いても cuda になってしまうので無理そう

https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/mamba_ssm/ops/selective_scan_interface.py#L37C16-L37C16

元々存在は知ってたけど、解説だけではなく

Equivalent numerical output as official implementation for both forward and backward pass

とのことなのでこれを使えばいいだけの話だったかも
https://github.com/johnma2006/mamba-minimal

このスクラップは2024/01/18にクローズされました