状態空間モデル 論文解説④「 Mamba 」
論文
前回は、少し脇道にそれて、S4系列とは別の著者の論文であるH3について解説しました。
今回は、いよいよ本命の Mamba になります。
タイトル
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
論文: https://arxiv.org/pdf/2312.00752
Github: https://github.com/state-spaces/mamba
解説ステップ
概要
- SSMにおけるパラメータA,B,C,dtを時間(入力)依存にした
- 時間依存による制限で、カーネル計算をやめた
- GPUの高速メモリを用いて高速な並列計算を実現した
- H3 と Gated MLP の構造を取り入れた
いったん S4D までの理論的な枠組みは忘れてください。Mamba では、状態空間モデル(SSM)をカーネル計算など行わずシンプルに再帰的な計算を並列スキャンで実現します。
SSMのおさらい
入力
※Mamba Layer では D_in と D_out の次元は一致します。論文に合わせて記号を変えています。
「
行列
S4Dからの変更点
行列A,B,C,\Delta の時間依存
S4D以前では、行列A,B,C は時間不変(Linear Time Invariance (LTI))でした。Mamba ではこのパラメータを**時間依存(入力依存)**形式に置き換える事で、より柔軟なパラメータを実現しています。
-
\Delta
SSMの離散化における時間間隔を表します。S4D以前は単一のパラメータでした。
それが今回入力 依存となり、それと同じ次元の変数を持ちます。\vec{x}
つまり、各データ・各時系列点・各潜在変数に対して1対1対応で を持ち、その重要度を係数として操作できるようになります。\Delta -
\={A}
同様に入力 依存を目指しますが、これは\vec{x} と掛ける事により入力依存となります。\Delta
\={A}=\exp{\Delta A} -
\={B}
同様に入力 依存です。S4D以前の、\vec{x} は一切関係が無いので注意してください。\={B}=(\Delta A)^{-1} (\exp{\Delta A} - I) \cdot \Delta B -
C
同様に入力 依存です。\vec{x}
カーネルKを計算しない
各パラメータが時間依存になった事により、S4D以前で定式化されていた
そのため Mamba では 再帰的な計算と同じ結果を作ります。つまり
の再帰的な計算結果を用意するという事です。しかしご存じの通り、再帰的な計算のままでは遅いので、ハードウェアの観点での工夫を取り入れます。それが Selective Scan と呼ばれるものです。こちらについては後述します。
Architecture の変更

H3 に見られる 1次元Conv や、Gated MLP の要素を取り入れて融合しています。
この詳細は下記のモデル構造で書きます。
モデル構造
全体的なモデル構造を明示します。コードベースで確認したところ、以下のようになっています。
簡易python環境構築
CPU環境で、何でもいいからモデルのインスタンスを作るための最低限簡易手順です。
mkdir -p ~/tmp && cd ~/tmp
git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout 10b5d6358f27966f6a40e4bf0baa17a460688128
python -m venv venv && source ./venv/bin/activate
pip install torch==2.8.0+cpu --index-url https://download.pytorch.org/whl/cpu
pip install einops triton packaging huggingface_hub transformers
sed -i 's/^import selective_scan_cuda/# &/' ./mamba_ssm/ops/selective_scan_interface.py
python -i -m mamba_ssm.models.mixer_seq_simple
>>> model = MixerModel(128, 4, 128, 50277)
>>> model
MixerModel(
(embedding): Embedding(50277, 128)
(layers): ModuleList(
(0-3): 4 x Block(
(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
(mixer): Mamba(
(in_proj): Linear(in_features=128, out_features=512, bias=False)
(conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
(act): SiLU()
(x_proj): Linear(in_features=256, out_features=40, bias=False)
(dt_proj): Linear(in_features=8, out_features=256, bias=True)
(out_proj): Linear(in_features=256, out_features=128, bias=False)
)
(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
(mlp): GatedMLP(
(fc1): Linear(in_features=128, out_features=256, bias=False)
(fc2): Linear(in_features=128, out_features=128, bias=False)
)
)
)
(norm_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
コードから読み解いて絵で描くと以下のようになっています。MLP Layer はオプションで設定できるようです。基本的には Transformer と同じような構成になっています。

Mamba Layer
さて、いよいよ Mamba Layer を深堀していきます。
このコードを読み解くと、以下のような絵になります。
B # バッチ数
L # 時系列長. sequence length
D # 入力次元. 言語モデルにおいてはトークンの embedding 次元
H = 16 # SSMの内部状態(h)の次元数
Din = D * 2 # D を projection した次元
K = 4 # 1次元畳み込みのカーネルサイズの初期値
R = D / 16 # dtのランク. 詳細不明

\Delta, \={A}, \={B} パラメータの解釈
「モデル構造」で見られるように、
となります。
下記の論文での解説はその事を指します。

Aは対角...ではない?
SSMのおさらい にもあるように、S4D以前では
例えば
>>> A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner,
).contiguous()
>>> A_log = torch.log(A) # Keep A_log in fp32
>>> self.A_log = nn.Parameter(A_log)
そして計算のときは、以下のように行列ではなく要素積としての計算が行われています。
係数の付き方としては対角行列を行列計算した時と似ていますが、パラメータの学習のされ方がこの場合では違っています。
Aは必ず負の値となる
Aは以下のように計算され、必ず負の値 をとります。
>>> A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
>>> A
tensor([[ -1.0000, -2.0000, -3.0000, ..., -14.0000, -15.0000, -16.0000],
[ -1.0000, -2.0000, -3.0000, ..., -14.0000, -15.0000, -16.0000],
[ -1.0000, -2.0000, -3.0000, ..., -14.0000, -15.0000, -16.0000],
...,
[ -1.0000, -2.0000, -3.0000, ..., -14.0000, -15.0000, -16.0000],
[ -1.0000, -2.0000, -3.0000, ..., -14.0000, -15.0000, -16.0000],
[ -1.0000, -2.0000, -3.0000, ..., -14.0000, -15.0000, -16.0000]],
grad_fn=<NegBackward0>)
>>> A.shape
torch.Size([256, 16])
[補足] Selective scan
上の図で Selective Scan というオレンジ色のブロックがあるかと思います。Selective Scan は、以下の3つの要素を備えた計算/メモリ効率化の要となるアルゴリズムです。
- kernel fusion
- parallel scan
- recomputation
kernel fusion
Selective Scan では、SRAMというGPUの高速な内部メモリを使った高速な並列計算 が使われています。それが kernel fusion です。実際のコードは、CUDAで最適化されており、内部状態

ざっくり言えば、オレンジ色はGPUレイヤーのみでの内部計算、緑色はCPU⇔GPUレイヤーの計算という事です。
parallel scan

こちらから参照
parallel scan 自体は一般的なアルゴリズムであり、並列可能な箇所から別々に計算を行い、統合しながら、全体を計算するものです。これが SSM の再帰計算を並列化するために用いられます。
具体的な並列計算の例
SSM の内部状態は以下のように計算されます。Mamba では各パラメータが時間依存のため、添え字
t=4 まで展開すると以下になります。
この時、次のようなペア
そして、次のような結合演算子を定義します。
さて、これを元に次のような計算
この2つ目の要素は
となり、途中を先に計算しても、最終的には同じ結果が得られます。これを工夫する事で、うまく並列化が可能です。t=7 まで増やして計算してみると
-
を並列で計算T_{10}, T_{32}, T_{54}, T_{76} -
を並列で計算T_{3210}=T_{32} \circ T_{10}, T_{7654}=T_{76} \circ T_{54} -
を計算T_{76543210}=T_{7654} \circ T_{3210}, T_{543210}=T_{54} \circ T_{3210} - この時点で
の計算結果はそろったので、後はその間のh_1, h_3, h_5, h_7 を計算すれば、e_0, T_{210}=e_2 \circ T_{10}, T_{43210}=e_4 \circ T_{3210}, T_{6543210}=e_6 \circ T_{543210} ~h_0 までが全て計算できる。h_7
※実際の並列的な手順と一致しているとは限らないのでご注意ください。あくまで例です。
recomputation
訓練時における forward の計算を高速に回すため GPU の内部メモリを使用しますが、その結果中間値が記録されません(例えば Pytorch の autograd に残りません)。つまり、誤差逆伝播で必要な中間値を再計算して取り戻す必要があり、それが recomputation になります。具体的な計算方法は分からないので割愛します。
特殊なタスクに対する性能
H3でも問題提起されていましが、SSMの性能として、ある特殊なタスクに対する性能を期待しています。以下は、その3つのタスクに関しての図です。

特殊なタスクの説明
Copying
指定された長さの連続したシーケンスを、後の時点で出力として再現するタスクです。このタスクは、モデルが入力を損なうことなく長時間にわたり保持する能力を必要とします。
Selective Copying
Copying タスクの発展形です。どの情報を保持すべきかを入力に基づいて決めるという要件を加えたものです。
Induction heads
Key-Value のような形で、あるトークンに対する直後のトークンを出力するようなタスクです。図では、■のトークンに対して、青い□のトークンが出力できる事を期待しています。
特殊なタスクの性能評価

Mambaはどちらのタスクでもとても良い結果となっています。これが Selective State Spaces と言われる所以でもあります。
先の「パラメータの解釈」でも述べたように、
Transformer との比較
Transformer は原理的に全ての系列情報を保存するような構造となっています。対してSSMでは、その情報の保存は内部状態
言語モデルの結果
様々なデータで結果を出していますが、ここでは言語モデルとしての性能のみ記載します。

割と他を圧倒しているように見えますが、この当時の状況や他のモデルについての理解も足りないため、結果に対する言及は避けます。
参考
Discussion
良い参考記事を発見できたので、Selective Scan の章を加筆しました。※2025/10/23 修正