S4の実装を読んでみる (状態空間モデル × 深層学習)
はじめに
「状態空間モデル」と「深層学習」を組み合わせた手法である S4 (Structured State Spaces for Sequence Modeling) の公式実装 (PyTorch) を読んでみます。
公式実装 (S4、S4Dなど)
原著論文 (S4D)
実装を読んでみる
S4 より実装がシンプルな S4D のソースコードを読んでいきます。
関係するファイルは、example.py
と models/s4/s4d.py
です。
S4モデル
タスクはCIFAR-10データセットの分類です。
画像を平坦化して系列データとして扱うため、入力形状は
以下は
========================================================================================
Layer (type:depth-idx) Input Shape Output Shape Param #
========================================================================================
S4Model [64, 1024, 3] [64, 10] --
├─Linear: 1-1 [64, 1024, 3] [64, 1024, 128] 512
├─S4D: 1-2 [64, 128, 1024] [64, 128, 1024] 128
│ └─S4DKernel: 2-1 -- [128, 1024] 16,512
│ └─GELU: 2-2 [64, 128, 1024] [64, 128, 1024] --
│ └─DropoutNd: 2-3 [64, 128, 1024] [64, 128, 1024] --
│ └─Sequential: 2-4 [64, 128, 1024] [64, 128, 1024] 33,024
├─Dropout1d: 1-3 [64, 128, 1024] [64, 128, 1024] --
├─LayerNorm: 1-4 [64, 1024, 128] [64, 1024, 128] 256
(省略)
├─Linear: 1-14 [64, 128] [64, 10] 1,290
========================================================================================
S4Dレイヤーの入出力形状はともに
S4Dレイヤー
k
と入力u
の線形畳み込み (linear convolution) を計算します。
この計算は、畳み込み定理 (convolution theorem) を用いることで高速に計算できます。
すなわち、k
とu
をそれぞれ高速フーリエ変換したものの積が、k
とu
の循環畳み込み (circular convolution) を高速フーリエ変換したものと等しくなることを利用します。
線形畳み込みを計算するため、k
とu
をゼロパディングして高速フーリエ変換します。
また、畳み込みの結果y
と入力u
をスキップ接続します。この係数D
は最適化するパラメータの一つです。
S4Dカーネル
k
を計算します。
具体的には、零次ホールド (ZOH) による離散化を行い、ヴァンデルモンド行列 (Vandermonde matrix) の積を計算します。
最適化するパラメータは
ただしモデルの安定化のため、
また
参考文献
- A Gu, et al. On the Parameterization and Initialization of Diagonal State Space Models. NeurIPS 2022.
- A Gu, et al. Efficiently Modeling Long Sequences with Structured State Spaces. ICLR 2022.
- The Annotated S4
- HiPPO/S4解説
- Is Attention All You Need? Part 1 Transformer を超える(?)新モデルS4
- mambaの理論を理解する①:HiPPOフレームワークとLSSL
- mambaの理論を理解する②:S4のアルゴリズム
Discussion