Mamba理解のための前提知識(S4を中心に)
はじめに
少し前に時空間統計モデリングの授業を受け、状態空間モデルについて知りました。状態空間モデルについてさらに調べていくと、最近発表されたMambaというTransformerの代替になり得るモデルがあるということを知りました。Mambaの原理を学ぼうとしましたが、前提知識がほとんどないということで、理解することができませんでした。そこで、Mambaの基礎となった論文から理解を深めることにしました。
また、この記事は授業で状態空間モデルに関する論文を要約するという課題がでたため、それを転用しています。
Mambaの基礎となる論文
-
HiPPO: Recurrent Memory with Optimal Polynomial Projections, Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré(NeurIPS 2020 Spotlight)
-
Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer, Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021)
-
Efficiently Modeling Long Sequences with Structured State Spaces, Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral)
今回は、これらの論文を通して、Efficiently Modeling Long Sequences with Structured State Spacesを理解することを最終目標とする。
レポートのために作成したものなので、最初の2つについての具体的な内容は省略しています。
S4: Efficiently Modeling Long Sequences with Structured State Spaces
Introduction
近年、シーケンスモデリングの分野では、異なるデータモダリティやタスクで長距離依存性(long-range dependencies, LRDs)を効果的かつ効率的に扱う単一のモデルを設計することが中心的な課題となっている。しかしながら、従来のRNN、CNN、Transformerといったモデルは、LRDsを捉えるための特化型バリアントを開発してきたにもかかわらず、1万以上のシーケンス長を持つデータにおいてスケーラビリティの課題が残っている。これに対して、近年提案された「状態空間モデル(State Space Model, SSM)」は、以下の数式で表される基礎的な動的システムのシミュレーションを活用し、長距離依存性を理論的・実証的に扱う可能性を示した:
- x'(t) = Ax(t) + Bu(t)
- y(t) = Cx(t) + Du(t)
特に、適切な状態行列 A を用いることで、このアプローチは大規模な依存関係を効率的に捉えることが可能とされている。しかし、この方法は計算およびメモリコストが非常に高く、実用的なシーケンスモデリング全般の解決策としては現実的ではない。本研究では、新しいパラメータ化手法を用いたStructured State Space Model(S4)を提案し、従来のSSMアプローチの理論的な利点を維持しつつ、計算効率を大幅に向上させた。この技術的手法には、次の要素が含まれてる:
1. 状態行列 A を低ランク補正を加えて条件付け、安定した対角化を可能にする。
2. SSMの計算をCauchyカーネルという既知の効率的な計算問題に帰着させる。
この結果、S4は、トレーニングと推論の双方において効率的であり、次のような成果を達成した:
- Sequential CIFAR-10では、データ拡張や補助損失なしで91%の精度を達成し、大型の2D ResNetと同等の性能。
- 画像および言語モデリングタスクでTransformerとの差を縮小し、生成速度は60倍高速化。
- Long Range Arena(LRA)のすべてのタスクで最先端性能(SoTA)を達成し、特にPath-Xタスク(長さ16k)で唯一成功したモデルとして認識。
S4は、シーケンスモデリングにおける長距離依存性への対応を理論的および実証的に向上させるだけでなく、画像、音声、テキスト、時系列などの多様なデータモダリティにわたって汎用的に適用可能なモデルの開発に向けた一歩を示している。
Background
シーケンスモデリングにおいて、状態空間モデル(State Space Model, SSM)は、長い歴史を持つ重要な数理モデルである。このモデルは、入力信号 u(t) を潜在状態 x(t) に変換し、それを出力信号 y(t) に射影するという形で動作する。数式で表すと以下のようになる:
- x'(t) = Ax(t) + Bu(t)
- y(t) = Cx(t) + Du(t)
SSMは、制御理論や計算神経科学、信号処理など、幅広い分野で活用されてきたモデルであり、隠れマルコフモデル(HMM)や他の潜在状態モデルとも関連が深い。一方で、深層学習におけるシーケンスモデリングにSSMをそのまま適用することは困難である。特に、勾配がシーケンス長に比例して消失または発散する問題があり、これにより学習が難しくなる。この課題に対して、「HiPPOフレームワーク」が有効な解決策を提供する。このフレームワークは、SSMの状態行列 A を特殊な形式に設計することで、長距離依存性(Long-Range Dependencies, LRDs)を効率的に記憶できるようにする。HiPPO行列である状態行列Aは以下のように表現される。
Sequential MNISTというタスクでは、従来のSSMが60%程度の精度しか達成できなかったものの、HiPPO行列を導入することで98%の精度を実現している。SSMは一般にODEとして表現されるが、SSMを離散的なデータに適用する際には、連続時間の方程式を離散化する必要がある。典型的には、双線形法(Bilinear Method)を用いる。この手法により、連続的なモデルが再帰的な形式に変換され、シーケンスデータ全体を逐次的に処理することが可能となる。しかし、この再帰的形式は並列処理に適しておらず、GPUなどの現代的なハードウェアではトレーニング効率が低い。この課題を克服するため、SSMの畳み込み表現が提案されている。これにより、畳み込みカーネル K を用いて、SSMを効率的に計算することが可能である。例えば、1次元の音声データや時系列データでは、このアプローチを用いることでトレーニングと推論の速度が大幅に向上することが示されている。ただし、この畳み込みカーネルの直接的な計算は計算量が膨大であり、理論的な効率を実現するためには新たな技術的工夫が必要であった。これらの背景を踏まえ、Structured State Space(S4)は、SSMの持つ理論的な強みを維持しつつ、計算効率と汎用性を飛躍的に向上させたモデルである。例えば、S4は画像分類タスクで高い精度を達成しながらも、従来モデルに比べて30倍以上の速度でトレーニングが可能である。また、Path-Xのような極めて長いシーケンス(16,384ステップ)を含むタスクにおいても、唯一成功したモデルであることが確認されている。このように、S4はSSMの理論的な課題を解決し、長距離依存性に関するシーケンスモデリングの新たな可能性を切り拓いたモデルである。
Method: Structured State Spaces (s4)
Structured State Spaces(S4)は、状態空間モデル(SSM)の計算効率と数値安定性を大幅に向上させるために、新しいパラメータ化とアルゴリズムを提案する。特に、S4は連続時間、再帰的、畳み込みという3つのSSM表現のすべてを効率的に計算できるよう設計されている。
3.1 対角化の動機と課題
SSMの計算の主なボトルネックは、行列 A を累乗する必要があることである。この累乗計算は、計算量 O(N^2L)を要し、特に長いシーケンス L を扱う場合に非効率である。行列Aを対角化することで、この問題を理論的に解決できる。対角行列では累乗計算が要素ごとの累乗に簡略化され、効率的に計算可能となる。しかし、従来のSSMで用いられる行列(例:HiPPO行列)は数値的不安定性を持つため、単純な対角化では実用的でない。この問題に対処するため、S4では「正規行列+低ランク行列(Normal Plus Low-Rank, NPLR)」という新しいパラメータ化を導入する。
3.2 NPLR形式によるパラメータ化
行列Aを以下の形式に分解する:
ここで、
- V はユニタリ行列(数値的に安定)
-
は対角行列(累乗が効率的)\Lambda - P, Q は低ランク行列であり、A の特性を補正する役割を果たす。
この形式により、累乗計算における不安定性を軽減しつつ、行列演算の効率を向上させる。
さらに、この形式はWoodburyの補題を用いることで、低ランク補正を効率的に計算することを可能にする。
3.3 Cauchyカーネルの導入と効率化
NPLR形式を用いた A の操作は、Cauchyカーネルを利用してさらに効率化できる。Cauchyカーネルは以下の形で表される:
ここで、
具体的には、以下のアルゴリズムを用いる:
1. Cauchyカーネルを評価し、スペクトル領域で畳み込みカーネルを計算する。
2. 高速フーリエ変換(FFT)を利用して、結果を時間領域に変換する。
このアプローチにより、従来 O(N^2L) であった計算量をO(N + L) に削減できる。
3.4 S4の計算複雑度
S4は以下のような計算効率を実現している:
- 再帰表現:1ステップあたりO(N) の計算量で実行可能。
- 畳み込み表現:O(N + L) の計算量で畳み込みカーネルを生成。
- メモリ使用量:O(N + L) に削減され、長いシーケンスに対してもスケーラブル。
これらの特性により、S4は従来の状態空間モデルやTransformerに比べ、圧倒的に高い計算効率を実現している。
3.5 アーキテクチャの詳細
S4レイヤーは、SSMを基礎に設計されており、次の特徴を持つ:
1. 入力シーケンスを受け取り、畳み込みカーネルを用いて変換を行う。
2. 各レイヤーは非線形活性化関数を組み合わせ、深層学習モデルとしての柔軟性を持つ。
3. マルチチャネルのデータ(例:画像データ)には複数の独立したS4インスタンスを適用し、特徴間の結合を線形層で処理する。
この設計により、S4は従来のRNN、CNN、Transformerと同様の柔軟性を持ちながら、計算効率を飛躍的に向上させている。
Experiments
本章では、Structured State Spaces(S4)の性能を評価するために実施した実験について述べる。S4は効率性、長距離依存性(Long-Range Dependencies, LRDs)の学習能力、そして汎用性を備えており、その特性を確認するため、さまざまなベンチマークやタスクを対象に詳細な検証を行った。
4.1 計算効率の評価
S4は、従来の状態空間モデル(LSSL)や効率的なTransformerモデル(例: Performer、Linear Transformer)と比較して計算効率を向上させる設計を持つ。
実験設定
計算効率を評価するため、以下の2つの指標を測定した:
1. トレーニングステップの計算時間(1ステップあたりの時間)。
2. メモリ使用量(モデルトレーニング時に必要なメモリ量)。
結果
- LSSLとの比較
S4はLSSLに比べて計算速度が最大30倍向上し、メモリ使用量は400分の1に削減された。 - Transformerモデルとの比較
PerformerやLinear Transformerと同等の計算効率を示し、長いシーケンス(例: 長さ4096)においても競争力があることが確認された。
4.2 長距離依存性(LRD)の学習能力
S4のLRDに関する性能を評価するため、以下のタスクを実施した。
タスク 1: Long Range Arena(LRA)
LRAは、長さ1K〜16Kステップの6つのタスクで構成され、LRDをテストするために設計されたベンチマークである。
結果
- S4はすべてのタスクで他の全モデルを上回る性能を発揮した。
- 平均スコアは 86.09% であり、Transformerモデルを含む従来のベースラインを大幅に超えた。
- 特に、Path-Xタスク(長さ16,384)では、S4は 88% の精度を達成し、従来モデルがランダム推測レベル(50%)に留まる中、唯一成功したモデルであった。
タスク 2: 音声分類
Speech Commandsデータセットの長さ16,000の生音声データを対象に分類タスクを実施した。
結果
- S4は 98.3% の精度を達成し、従来の音声CNNモデルを上回る性能を示した。
- 特徴抽出(例: MFCC変換)を必要とせず、生データから直接学習する能力を証明した。
4.3 汎用性の評価
S4の汎用性を確認するため、さまざまなドメインでのタスクを実施した。
タスク 1: 大規模生成モデル
CIFAR-10(画像生成)とWikiText-103(言語生成)の2つのベンチマークを用いた。
結果
- CIFAR-10
S4は密度推定において2.85 bits/dimを達成し、PixelSNAILなどの2D特化型モデルに匹敵する性能を示した。 - WikiText-103
S4は20.95のパープレキシティを記録し、Transformerに非常に近い性能を達成した。
タスク 2: 高速な自動回帰生成
S4の再帰表現を利用して、CIFAR-10およびWikiText-103における生成速度を測定した。
結果
- S4はTransformerと比較して60倍高速に生成を行うことが可能であった。
タスク 3: 異なるサンプリングレートへの適応
Speech Commandsデータセットをサンプリングレートの半分(0.5×)で評価した。
結果
- 再トレーニングを行わず、96.3%の精度を維持した。
- S4の連続時間モデルとしての柔軟性が示された。
タスク 4: 画像分類
ピクセルレベルのSequential CIFAR-10タスクでの性能を評価した。
結果
- S4は91.13%の精度を達成し、特化型モデルを上回る性能を示した。
タスク 5: 時系列予測
複数の時系列予測タスク(例: 気象データ予測、エネルギーデータ予測)を実施した。
結果
- S4は従来の特化型モデル(例: Informer)を上回り、特に長い予測範囲において大幅な改善を達成した。
4.4 アブレーション研究
S4の設計要素の重要性を検証するため、以下の実験を実施した。
HiPPO初期化の有無
- HiPPO行列で初期化した場合、性能が大幅に向上することが確認された。
NPLRパラメータ化の影響
- NPLR形式自体では性能の向上は限定的であり、HiPPO行列との組み合わせが効果的であることが示された。
Conclusion
本研究では、長距離依存性(Long-Range Dependencies, LRDs)を効率的かつ効果的にモデル化するための新しい手法としてStructured State Spaces(S4)を提案した。S4は、状態空間モデル(SSM)の計算効率と数値安定性を向上させる新しいパラメータ化(NPLR形式)とアルゴリズムを導入し、従来のモデルが抱える計算コストやスケーラビリティの課題を解決した。
実験結果から、S4は以下の特性を備えていることが明らかになった:
1. 計算効率
S4は従来のLSSLやTransformerモデルに比べて、計算速度やメモリ効率において大幅に優れており、長いシーケンスを扱うタスクでも高いパフォーマンスを発揮した。
2. 長距離依存性の学習能力
S4は、Path-Xを含む長いシーケンスを持つベンチマークで最先端の性能を達成し、既存のモデルが学習できなかった課題に成功した唯一のモデルであった。
3. 汎用性
S4は、画像生成、音声分類、時系列予測など、異なるドメインにおいても高い性能を示し、連続時間モデルとしてサンプリングレートの変更にも柔軟に対応できる。
さらに、S4の設計要素の重要性をアブレーション研究によって検証した結果、HiPPOフレームワークの初期化とNPLR形式の組み合わせが特に効果的であることが確認された。
以上の結果から、S4は計算効率、長距離依存性の学習能力、汎用性のすべてにおいて優れた性能を示し、次世代のシーケンスモデリング手法として有望である。今後の研究では、S4をさらに発展させるために以下の方向性が考えられる:
- S4を高次元データ(例: 画像や動画)に適用するための拡張。
- 他のモデル(例: Transformer)と組み合わせたハイブリッドアプローチの探求。
- より大規模なデータセットや生成タスクへの応用。
本研究は、状態空間モデルの可能性を拡張し、多様な応用分野におけるシーケンスモデリングの新たな道を切り開いた。
おわりに
s4を理解するにあたって、論文にはゴリゴリの理論が立ちはだかっており、今回取り上げた本論文での工夫点がなぜ有効かについての証明など非常にボリュームのあるものである。いきなりその数式たちと対峙するのは、とても厳しい戦いになるため、今回は難しい証明などは省略し、結局この論文はどんな課題に、どのような工夫をして、どのくらい精度が上がったかのざっくりとした概要を把握することに努めた。
間違いなどありましたら、教えて頂けたら、今後の成長にもつながります。
読んでいただき、ありがとうございました。
参考文献
-
HiPPO: Recurrent Memory with Optimal Polynomial Projections, Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré(NeurIPS 2020 Spotlight)
-
Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer, Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré (NeurIPS 2021)
-
Efficiently Modeling Long Sequences with Structured State Spaces, Albert Gu, Karan Goel, Christopher Ré (ICLR 2022 Oral)
- HiPPO/S4解説 https://www.slideshare.net/MorphoIncPR/hippos4#14
Discussion