🐍

S4の実装を読んでみる (状態空間モデル × 深層学習)

2024/03/06に公開

はじめに

「状態空間モデル」と「深層学習」を組み合わせた手法である S4 (Structured State Spaces for Sequence Modeling) の公式実装 (PyTorch) を読んでみます。

公式実装 (S4、S4Dなど)
https://github.com/state-spaces/s4

原著論文 (S4D)
https://arxiv.org/abs/2206.11893

実装を読んでみる

S4 より実装がシンプルな S4D のソースコードを読んでいきます。
関係するファイルは、example.pymodels/s4/s4d.py です。

S4モデル

タスクはCIFAR-10データセットの分類です。
画像を平坦化して系列データとして扱うため、入力形状は (Batch\_size, 1024, 3) に、出力形状は (Batch\_size, 10) になります。
以下は Batch\_size = 64Model\_dim = 128 の時のモデル構造です。

========================================================================================
Layer (type:depth-idx)                   Input Shape      Output Shape     Param #
========================================================================================
S4Model                                  [64, 1024, 3]    [64, 10]         --
├─Linear: 1-1                            [64, 1024, 3]    [64, 1024, 128]  512   
├─S4D: 1-2                               [64, 128, 1024]  [64, 128, 1024]  128
│    └─S4DKernel: 2-1                    --               [128, 1024]      16,512
│    └─GELU: 2-2                         [64, 128, 1024]  [64, 128, 1024]  --
│    └─DropoutNd: 2-3                    [64, 128, 1024]  [64, 128, 1024]  --
│    └─Sequential: 2-4                   [64, 128, 1024]  [64, 128, 1024]  33,024
├─Dropout1d: 1-3                         [64, 128, 1024]  [64, 128, 1024]  --
├─LayerNorm: 1-4                         [64, 1024, 128]  [64, 1024, 128]  256
(省略)
├─Linear: 1-14                           [64, 128]        [64, 10]         1,290
========================================================================================

S4Dレイヤーの入出力形状はともに (Batch\_size, Model\_dim, 1024) となります。

S4Dレイヤー

https://github.com/state-spaces/s4/blob/main/models/s4/s4d.py#L62-L107
S4Dレイヤーでは、カーネルkと入力uの線形畳み込み (linear convolution) を計算します。

(k \ast u)[n]=\sum_{m=0}^{n} k[m] u[n-m]

この計算は、畳み込み定理 (convolution theorem) を用いることで高速に計算できます。
すなわち、kuをそれぞれ高速フーリエ変換したものの積が、kuの循環畳み込み (circular convolution) を高速フーリエ変換したものと等しくなることを利用します。

(k \ast u)[n]=\sum_{m=0}^{N} k[m] u[n-m]

線形畳み込みを計算するため、kuをゼロパディングして高速フーリエ変換します。

また、畳み込みの結果yと入力uをスキップ接続します。この係数Dは最適化するパラメータの一つです。

S4Dカーネル

https://github.com/state-spaces/s4/blob/main/models/s4/s4d.py#L11-L59
S4Dレイヤーで出てきたカーネルkを計算します。
具体的には、零次ホールド (ZOH) による離散化を行い、ヴァンデルモンド行列 (Vandermonde matrix) の積を計算します。

\begin{align} \bm{\overline{A}} &= \exp(\Delta \bm{A}) \\ \bm{\overline{B}} &= (\Delta \bm{A})^{-1} (\exp(\Delta \cdot \bm{A}) - \bm{I}) \cdot \Delta \bm{B} \end{align}
\begin{align*} \bm{\overline{K}} = \begin{bmatrix} \bm{\overline{B}}_0 \bm{C}_0 & \dots & \bm{\overline{B}}_{N-1} \bm{C}_{N-1} \end{bmatrix} \begin{bmatrix} 1 & \bm{\overline{A}}_0 & \bm{\overline{A}}_0^2 & \dots & \bm{\overline{A}}_0^{L-1} \\ 1 & \bm{\overline{A}}_1 & \bm{\overline{A}}_1^2 & \dots & \bm{\overline{A}}_1^{L-1} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \bm{\overline{A}}_{N-1} & \bm{\overline{A}}_{N-1}^2 & \dots & \bm{\overline{A}}_{N-1}^{L-1} \\ \end{bmatrix} \end{align*}

最適化するパラメータは\Delta\bm{C}\bm{A}です。
ただしモデルの安定化のため、\bm{A}の実部と虚部を別のパラメータに分け、\bm{A}の実部が必ず負になるようにします。
また\Delta\bm{C}は乱数で、\bm{A}\bm{A}_n = -\frac{1}{2} + i \pi n に初期化します。

参考文献

  1. A Gu, et al. On the Parameterization and Initialization of Diagonal State Space Models. NeurIPS 2022.
  2. A Gu, et al. Efficiently Modeling Long Sequences with Structured State Spaces. ICLR 2022.
  3. The Annotated S4
  4. HiPPO/S4解説
  5. Is Attention All You Need? Part 1 Transformer を超える(?)新モデルS4
  6. mambaの理論を理解する①:HiPPOフレームワークとLSSL
  7. mambaの理論を理解する②:S4のアルゴリズム

Discussion