🐎

状態空間モデル 論文解説④「 Mamba 」

に公開
1

論文

前回は、少し脇道にそれて、S4系列とは別の著者の論文であるH3について解説しました。

今回は、いよいよ本命の Mamba になります。

タイトル

Mamba: Linear-Time Sequence Modeling with Selective State Spaces

論文: https://arxiv.org/pdf/2312.00752
Github: https://github.com/state-spaces/mamba

解説ステップ

概要

  • SSMにおけるパラメータA,B,C,dtを時間(入力)依存にした
  • 時間依存による制限で、カーネル計算をやめた
  • GPUの高速メモリを用いて高速な並列計算を実現した
  • H3 と Gated MLP の構造を取り入れた

いったん S4D までの理論的な枠組みは忘れてください。Mamba では、状態空間モデル(SSM)をカーネル計算など行わずシンプルに再帰的な計算を並列スキャンで実現します。

SSMのおさらい

入力\vec{x}の次元D_{in}, 内部状態hの次元D_h=3, 出力yD_{out}=D_{in}, 入力xの時系列長Lの場合
※Mamba Layer では D_in と D_out の次元は一致します。論文に合わせて記号を変えています。

\underbrace{ \begin{bmatrix} \vec{h}_{t}^{(0)} \\ \vec{h}_{t}^{(1)} \\ \vec{h}_{t}^{(2)} \\ \end{bmatrix} }_{\in \mathbb{R}^{3 \times D_{in}}} =\underbrace{ \begin{bmatrix} a_0 & 0 & 0 \\ 0 & a_1 & 0 \\ 0 & 0 & a_2 \\ \end{bmatrix} }_{A \in \mathbb{R}^{3 \times 3}} \underbrace{ \begin{bmatrix} \vec{h}_{t-1}^{(0)} \\ \vec{h}_{t-1}^{(1)} \\ \vec{h}_{t-1}^{(2)} \\ \end{bmatrix} }_{\in \mathbb{R}^{3 \times D_{in}}} + \underbrace{ \begin{bmatrix} b_{0} \\ b_{1} \\ b_{2} \\ \end{bmatrix} }_{B \in \mathbb{R}^{3 \times 1}} \underbrace{ \vec{x}_t }_{\in \mathbb{R}^{1 \times D_{in}}}
\underbrace{ \vec{y}_t }_{\in \mathbb{R}^{1 \times D_{out}}} =\underbrace{ \begin{bmatrix} c_{0} & c_{1} & c_{2} \\ \end{bmatrix} }_{C \in \mathbb{R}^{1 \times 3}} \begin{bmatrix} \vec{h}_{t}^{(0)} \\ \vec{h}_{t}^{(1)} \\ \vec{h}_{t}^{(2)} \\ \end{bmatrix} + \underbrace{ \vec{d} }_{D \in \mathbb{R}^{1 \times D_{in}}} \odot \vec{x}_t

\odot」は要素積です。ベクトルの内積ではなく、要素同士の積になっている事に注意してください。

行列Bは、入力\vec{x}を複数の重みで保持するような行列で、行列Aは内部状態hを適切な係数をかけて繰り越すような意味合いを持ちます。行列Chを適切な重みで合成する操作です。

S4Dからの変更点

行列A,B,C,\Deltaの時間依存

S4D以前では、行列A,B,C は時間不変(Linear Time Invariance (LTI))でした。Mamba ではこのパラメータを**時間依存(入力依存)**形式に置き換える事で、より柔軟なパラメータを実現しています。

  • \Delta
    SSMの離散化における時間間隔を表します。S4D以前は単一のパラメータでした。
    それが今回入力\vec{x}依存となり、それと同じ次元の変数を持ちます。
    つまり、各データ・各時系列点・各潜在変数に対して1対1対応で\Deltaを持ち、その重要度を係数として操作できるようになります。
  • \={A}
    同様に入力\vec{x}依存を目指しますが、これは \Deltaと掛ける事により入力依存となります。
    \={A}=\exp{\Delta A}
  • \={B}
    同様に入力\vec{x}依存です。S4D以前の、\={B}=(\Delta A)^{-1} (\exp{\Delta A} - I) \cdot \Delta B一切関係が無いので注意してください。
  • C
    同様に入力\vec{x}依存です。

カーネルKを計算しない

各パラメータが時間依存になった事により、S4D以前で定式化されていた \={K}=(C\={B}, C\={A}\={B}, ..., C\={A}^{k}\={B}, ...) といったカーネルはそう単純に計算できません。

そのため Mamba では 再帰的な計算と同じ結果を作ります。つまり

h_t=Ah_{t-1} + Bx_t, y_t=Ch_t + Dx_t

の再帰的な計算結果を用意するという事です。しかしご存じの通り、再帰的な計算のままでは遅いので、ハードウェアの観点での工夫を取り入れます。それが Selective Scan と呼ばれるものです。こちらについては後述します。

Architecture の変更

H3 に見られる 1次元Conv や、Gated MLP の要素を取り入れて融合しています。

この詳細は下記のモデル構造で書きます。

モデル構造

全体的なモデル構造を明示します。コードベースで確認したところ、以下のようになっています。

簡易python環境構築

CPU環境で、何でもいいからモデルのインスタンスを作るための最低限簡易手順です。

mkdir -p ~/tmp && cd ~/tmp
git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout 10b5d6358f27966f6a40e4bf0baa17a460688128
python -m venv venv && source ./venv/bin/activate
pip install torch==2.8.0+cpu --index-url https://download.pytorch.org/whl/cpu
pip install einops triton packaging huggingface_hub transformers
sed -i 's/^import selective_scan_cuda/# &/' ./mamba_ssm/ops/selective_scan_interface.py
python -i -m mamba_ssm.models.mixer_seq_simple
>>> model = MixerModel(128, 4, 128, 50277)
>>> model
MixerModel(
  (embedding): Embedding(50277, 128)
  (layers): ModuleList(
    (0-3): 4 x Block(
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mixer): Mamba(
        (in_proj): Linear(in_features=128, out_features=512, bias=False)
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
        (act): SiLU()
        (x_proj): Linear(in_features=256, out_features=40, bias=False)
        (dt_proj): Linear(in_features=8, out_features=256, bias=True)
        (out_proj): Linear(in_features=256, out_features=128, bias=False)
      )
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): GatedMLP(
        (fc1): Linear(in_features=128, out_features=256, bias=False)
        (fc2): Linear(in_features=128, out_features=128, bias=False)
      )
    )
  )
  (norm_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)

https://github.com/state-spaces/mamba/blob/10b5d6358f27966f6a40e4bf0baa17a460688128/mamba_ssm/modules/block.py#L10

コードから読み解いて絵で描くと以下のようになっています。MLP Layer はオプションで設定できるようです。基本的には Transformer と同じような構成になっています。

Mamba Layer

さて、いよいよ Mamba Layer を深堀していきます。

https://github.com/state-spaces/mamba/blob/10b5d6358f27966f6a40e4bf0baa17a460688128/mamba_ssm/modules/mamba_simple.py#L31

このコードを読み解くと、以下のような絵になります。

B # バッチ数
L # 時系列長. sequence length
D # 入力次元. 言語モデルにおいてはトークンの embedding 次元
H   = 16     # SSMの内部状態(h)の次元数
Din = D * 2  # D を projection した次元
K   = 4      # 1次元畳み込みのカーネルサイズの初期値
R   = D / 16 # dtのランク. 詳細不明

\Delta, \={A}, \={B} パラメータの解釈

「モデル構造」で見られるように、\DeltaABに作用します。そして、\Deltaは softplus の関数を通ってくるので必ず正の値を取ります。

\={A}=\exp{(\Delta A)}, \ \ \ \={B}=\Delta B

A は後述しますが、必ず負の値 をとります。つまり
\Delta \rightarrow \infin では

\={A} \rightarrow 0, \ \ \ \={B} \rightarrow \infin

\Delta \rightarrow 0 では

\={A} \rightarrow 1, \ \ \ \={B} \rightarrow 0

となります。

\={A}\={B}の大きさは、内部状態hと入力xどれだけ残すかに関係があります。つまり、\Delta とはその制御パラメータです。

下記の論文での解説はその事を指します。

Aは対角...ではない?

SSMのおさらい にもあるように、S4D以前では A \in \mathbb{R}^{H \times H} は対角でした。しかし今回は以下のように、repeat してからパラメータとして登録しています。

例えば D_{H}=3 では、a_0, a_1, a_2 だけがパラメータとなるなずですが、さらに D_{in} 分繰り返して生成しています。

>>> A = repeat(
    torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
    "n -> d n",
    d=self.d_inner,
).contiguous()
>>> A_log = torch.log(A)  # Keep A_log in fp32
>>> self.A_log = nn.Parameter(A_log)

そして計算のときは、以下のように行列ではなく要素積としての計算が行われています。

\begin{bmatrix} a_0^{(1)} & a_0^{(2)} & a_0^{(3)} & \cdots & a_0^{(D_{in})} \\ a_1^{(1)} & a_1^{(2)} & a_1^{(3)} & \cdots & a_1^{(D_{in})} \\ a_2^{(1)} & a_2^{(2)} & a_2^{(3)} & \cdots & a_2^{(D_{in})} \\ \end{bmatrix} \odot \begin{bmatrix} \vec{h}_{t-1}^{(0)} \\ \vec{h}_{t-1}^{(1)} \\ \vec{h}_{t-1}^{(2)} \\ \end{bmatrix}

係数の付き方としては対角行列を行列計算した時と似ていますが、パラメータの学習のされ方がこの場合では違っています。

Aは必ず負の値となる

Aは以下のように計算され、必ず負の値 をとります。

https://github.com/state-spaces/mamba/blob/10b5d6358f27966f6a40e4bf0baa17a460688128/mamba_ssm/modules/mamba_simple.py#L143

>>> A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
>>> A
tensor([[ -1.0000,  -2.0000,  -3.0000,  ..., -14.0000, -15.0000, -16.0000],
        [ -1.0000,  -2.0000,  -3.0000,  ..., -14.0000, -15.0000, -16.0000],
        [ -1.0000,  -2.0000,  -3.0000,  ..., -14.0000, -15.0000, -16.0000],
        ...,
        [ -1.0000,  -2.0000,  -3.0000,  ..., -14.0000, -15.0000, -16.0000],
        [ -1.0000,  -2.0000,  -3.0000,  ..., -14.0000, -15.0000, -16.0000],
        [ -1.0000,  -2.0000,  -3.0000,  ..., -14.0000, -15.0000, -16.0000]],
       grad_fn=<NegBackward0>)
>>> A.shape
torch.Size([256, 16])

[補足] Selective scan

上の図で Selective Scan というオレンジ色のブロックがあるかと思います。Selective Scan は、以下の3つの要素を備えた計算/メモリ効率化の要となるアルゴリズムです。

  • kernel fusion
  • parallel scan
  • recomputation

kernel fusion

Selective Scan では、SRAMというGPUの高速な内部メモリを使った高速な並列計算 が使われています。それが kernel fusion です。実際のコードは、CUDAで最適化されており、内部状態hは PyTorch(Python) 側の変数として表面には出てきません。

ざっくり言えば、オレンジ色はGPUレイヤーのみでの内部計算緑色はCPU⇔GPUレイヤーの計算という事です。

parallel scan


こちらから参照

parallel scan 自体は一般的なアルゴリズムであり、並列可能な箇所から別々に計算を行い、統合しながら、全体を計算するものです。これが SSM の再帰計算を並列化するために用いられます。

具体的な並列計算の例

SSM の内部状態は以下のように計算されます。Mamba では各パラメータが時間依存のため、添え字tをつけています。

h_t=A_t h_{t-1} + B_t x_t = A_t h_{t-1} + b_t

t=4 まで展開すると以下になります。

\begin{align*} h_0 &= b_0 \\ h_1 &= A_1 ( b_0 ) + b_1 \\ h_2 &= A_2 ( A_1 ( b_0 ) + b_1 ) + b_2 \\ h_3 &= A_3 ( A_2 ( A_1 ( b_0 ) + b_1 ) + b_2 ) + b_3 \\ h_4 &= A_4 ( A_3 ( A_2 ( A_1 ( b_0 ) + b_1 ) + b_2 ) + b_3 ) + b_4 \\ \end{align*}

この時、次のようなペアe_t=(A_t, b_t)を考えます。

e_0=(1, b_0), e_1=(A_1, b_1), e_2=(A_2, b_2), e_3=(A_3, b_3), e_4=(A_4, b_4)

そして、次のような結合演算子を定義します。

T_{ij} = e_i \circ e_j = (A_i, b_i) \circ (A_j, b_j) = (A_i A_j, A_i b_j + b_i)

さて、これを元に次のような計算Xを考えます。

\begin{align*} X &= e_4 \circ e_3 \circ e_2 \circ e_1 \circ e_0 \\ &= (A_4, b_4) \circ (A_3, b_3) \circ (A_2, b_2) \circ (A_1, b_1) \circ (1, b_0) \\ &= (A_4, b_4) \circ (A_3, b_3) \circ (A_2, b_2) \circ (A_1, A_1 b_0 + b_1) \\ &= (A_4, b_4) \circ (A_3, b_3) \circ (A_2 A_1, A_2 (A_1 b_0 + b_1) + b_2) \\ &= (A_4, b_4) \circ (A_3 A_2 A_1, A_3 (A_2 (A_1 b_0 + b_1) + b_2) + b_3) \\ &= (A_4 A_3 A_2 A_1, A_4 (A_3 (A_2 (A_1 b_0 + b_1) + b_2) + b_3) + b_4) \end{align*}

この2つ目の要素は h_4 そのものです。では、後ろから順に計算するのではなく、途中を先に計算してみましょう。

\begin{align*} X &= (e_4 \circ e_3) \circ (e_2 \circ e_1) \circ e_0 \\ &= ((A_4, b_4) \circ (A_3, b_3)) \circ ((A_2, b_2) \circ (A_1, b_1)) \circ (1, b_0) \\ &= (A_4 A_3, A_4 b_3 + b_4) \circ(A_2 A_1, A_2 b_1 + b_2) \circ (1, b_0) \\ &= ( (A_4 A_3) (A_2 A_1), (A_4 A_3) (A_2 b_1 + b_2) + (A_4 b_3 + b_4)) \circ (1, b_0) \\ &= (A_4 A_3 A_2 A_1, (A_4 A_3 A_2 A_1) b_0 + (A_4 A_3) (A_2 b_1 + b_2) + (A_4 b_3 + b_4)) &= (A_4 A_3 A_2 A_1, A_4 A_3 A_2 A_1 b_0 + A_4 A_3 A_2 b_1 + A_4 A_3 b_2 + A_4 b_3 + b_4) \end{align*}

となり、途中を先に計算しても、最終的には同じ結果が得られます。これを工夫する事で、うまく並列化が可能です。t=7 まで増やして計算してみると

  1. T_{10}, T_{32}, T_{54}, T_{76} を並列で計算
  2. T_{3210}=T_{32} \circ T_{10}, T_{7654}=T_{76} \circ T_{54} を並列で計算
  3. T_{76543210}=T_{7654} \circ T_{3210}, T_{543210}=T_{54} \circ T_{3210} を計算
  4. この時点で h_1, h_3, h_5, h_7 の計算結果はそろったので、後はその間の e_0, T_{210}=e_2 \circ T_{10}, T_{43210}=e_4 \circ T_{3210}, T_{6543210}=e_6 \circ T_{543210} を計算すれば、h_0h_7 までが全て計算できる。

※実際の並列的な手順と一致しているとは限らないのでご注意ください。あくまで例です。

recomputation

訓練時における forward の計算を高速に回すため GPU の内部メモリを使用しますが、その結果中間値が記録されません(例えば Pytorch の autograd に残りません)。つまり、誤差逆伝播で必要な中間値を再計算して取り戻す必要があり、それが recomputation になります。具体的な計算方法は分からないので割愛します。

特殊なタスクに対する性能

H3でも問題提起されていましが、SSMの性能として、ある特殊なタスクに対する性能を期待しています。以下は、その3つのタスクに関しての図です。

特殊なタスクの説明

Copying

指定された長さの連続したシーケンスを、後の時点で出力として再現するタスクです。このタスクは、モデルが入力を損なうことなく長時間にわたり保持する能力を必要とします。

Selective Copying

Copying タスクの発展形です。どの情報を保持すべきかを入力に基づいて決めるという要件を加えたものです。

Induction heads

Key-Value のような形で、あるトークンに対する直後のトークンを出力するようなタスクです。図では、■のトークンに対して、青い□のトークンが出力できる事を期待しています。

特殊なタスクの性能評価

Mambaはどちらのタスクでもとても良い結果となっています。これが Selective State Spaces と言われる所以でもあります。

先の「パラメータの解釈」でも述べたように、\Delta による情報の取捨選択機構がうまく作用した結果と言えます。

Transformer との比較

Transformer は原理的に全ての系列情報を保存するような構造となっています。対してSSMでは、その情報の保存は内部状態hの次元数に大きく制限されます。情報を圧縮し、内部状態を持ちまわしているので、当然といえば当然です。

言語モデルの結果

様々なデータで結果を出していますが、ここでは言語モデルとしての性能のみ記載します。

割と他を圧倒しているように見えますが、この当時の状況や他のモデルについての理解も足りないため、結果に対する言及は避けます。

参考

https://speakerdeck.com/kurita/mamba

Discussion