論文
前回は、S4の理論編と実装編を解説しました。
今回は、その続編である以下の論文について解説します。
タイトル
On the Parameterization and Initialization of Diagonal State Space Models
論文: https://arxiv.org/pdf/2206.11893
Github: https://github.com/state-spaces/s4/blob/main/models/s4/s4d.py
概要
- S4 をシンプルにしたモデル
- 行列Aを対角行列に制限した(低ランク表現をやめた)
- 行列Aの実部をマイナスに制限した
- シンプルになったにも関わらず、精度は概ねS4と同等である
SSMのおさらい
入力uの次元D_{in}=2, 内部状態xの次元D_x=3, 出力yのD_{out}=4, 入力uの時系列長Lの場合(簡単のため行列Aは対角行列)
u_0=\begin{bmatrix}
u_{00} \\ u_{01}
\end{bmatrix}, \ ... \ , \ \
u_t=\begin{bmatrix}
u_{t0} \\ u_{t1}
\end{bmatrix}, \ ... \ , \ \
u_L=\begin{bmatrix}
u_{L0} \\ u_{L1}
\end{bmatrix}
\begin{bmatrix}
x_{t}^{(0)} \\
x_{t}^{(1)} \\
x_{t}^{(2)} \\
\end{bmatrix}=
\underbrace{
\begin{bmatrix}
{\lambda}_{0} & 0 & 0 \\
0 & {\lambda}_{1} & 0 \\
0 & 0 & {\lambda}_{2}
\end{bmatrix}
}_{A}
\begin{bmatrix}
x_{t-1}^{(0)} \\
x_{t-1}^{(1)} \\
x_{t-1}^{(2)} \\
\end{bmatrix} +
\underbrace{
\begin{bmatrix}
b_{00} & b_{01} \\
b_{10} & b_{11} \\
b_{20} & b_{21} \\
\end{bmatrix}
}_{B}
\begin{bmatrix}
u_{t0} \\ u_{t1}
\end{bmatrix}
\begin{bmatrix}
y_{t}^{(0)} \\
y_{t}^{(1)} \\
y_{t}^{(2)} \\
y_{t}^{(3)}
\end{bmatrix}=
\underbrace{
\begin{bmatrix}
c_{00} & c_{01} & c_{02} \\
c_{10} & c_{11} & c_{12} \\
c_{20} & c_{21} & c_{22} \\
c_{30} & c_{31} & c_{32} \\
\end{bmatrix}
}_{C}
\begin{bmatrix}
x_{t}^{(0)} \\
x_{t}^{(1)} \\
x_{t}^{(2)} \\
\end{bmatrix} +
\underbrace{
\begin{bmatrix}
d_{00} & d_{01} \\
d_{10} & d_{11} \\
d_{20} & d_{21} \\
d_{30} & d_{31} \\
\end{bmatrix}
}_{D}
\begin{bmatrix}
u_{t0} \\ u_{t1}
\end{bmatrix}
愚直に変数表現で表すと上記のようになります(自分にとっては分かりやすいです)。基本的にはD=0と考えますが、念のため全部書き下してみました。
行列A
ざっくり言えば、S4DはS4に比べて行列Aをシンプルにしたモデルです。ではその詳細について見てみましょう。
A=
\begin{bmatrix}
{\lambda}_{1} & ... & 0 \\
... & ... & ... \\
0 & ... & {\lambda}_{n}
\end{bmatrix}
行列のランクを無くす
S4では行列AがDPLR(Diagonal Plus Low-Ran)形式で記述できる事で、カーネルの計算を簡略化できていました。改めて式を記述すると以下になります。
S4DではこのPQ^*という項を無くしています。え、勝手にそんな事をしていいの? って感じですが、この項が無くなると、どういう影響があるのかを見てみましょう。
この項は、カーネル計算に影響します。改めてカーネルの計算式を書いてみます。
{\hat{K}_L}'(z)=\frac{2}{1+z} \tilde{C}' \left( R(z) - R(z)P(I+Q^*R(z)P)^{-1}Q^*R(z) \right) B'
R(z)=\left( \frac{1-z}{1+z}sI-\Lambda \right)^{-1}
{\hat{K}_L}'(z)の第二項にPとQ^*が現れます。これはつまり、対角行列R(z)の固有値を混ぜ合わせて複雑な状態を表現している訳です。
この低ランク項PQ^*を無くすという事は、この固有値の相互作用が発生する項を無くす、つまり表現力が落ちる事を意味します。この表現力に関する議論は Expressivity and Limitations of Diagonal SSMs で記述されているような気がしますが、ちょっと難しいので割愛します。
実部の制約

とありますが、S4Dでは行列Aの実部がマイナスになるように制限しています。これは、長い時系列でカーネルが発散するのを防ぐためです。
ReLU形式なども取れるとしており、行列Aは以下のような形になります。
A=-exp(A_{Re})+i\cdot A_{Img} \ \ \ or \ \ \ A=-ReLU(A_{Re})+i\cdot A_{Img}
連続値SSMのカーネルCe^{tA}B、離散SSMのカーネルC\={A}^{t-1}\={B} となっています。Aの実部がマイナスであれば連続値SSMのカーネルは発散せず、離散SSMでは\={A}=(I-\Delta /2A)^{-1}(I+\Delta /2A) (Bilinear形式)なので、Aがマイナスであれば \={A}<1となり、離散SSMのカーネルは発散しません。

上図で、横軸は時系列長で、縦軸はカーネルの値です。時間が大きい箇所(つまり過去の情報、K_{l=L}となるような時系列点)では値が減衰している事が分かります。
上図の疑似コードはZOH展開での計算式(以下)を使用しており、 (exp(dt*A)-1)/A
となっている箇所は、\={B}を展開した時の項が現れています。

カーネルの計算がぐっと楽になっている事が分かります。
S4D kernel
ある時系列点lのカーネルはK_l=\={C}{\={A}}^{l-1}\={B}\ \ (\={C}=C)のように表せるので(簡単のため、変数表記を統一します)
K_l=
\begin{bmatrix}
c_{00} & c_{01} & c_{02} \\
c_{10} & c_{11} & c_{12} \\
c_{20} & c_{21} & c_{22} \\
c_{30} & c_{31} & c_{32} \\
\end{bmatrix}
\begin{bmatrix}
{\lambda}_{0}^{l-1} & 0 & 0 \\
0 & {\lambda}_{1}^{l-1} & 0 \\
0 & 0 & {\lambda}_{2}^{l-1}
\end{bmatrix}
\begin{bmatrix}
b_{00} & b_{01} \\
b_{10} & b_{11} \\
b_{20} & b_{21} \\
\end{bmatrix}
=\begin{bmatrix}
c_{00}{\lambda}_{0}^{l-1}b_{00} + c_{01}{\lambda}_{1}^{l-1}b_{10} + c_{02}{\lambda}_{2}^{l-1}b_{20} & c_{00}{\lambda}_{0}^{l-1}b_{01} + c_{01}{\lambda}_{1}^{l-1}b_{11} + c_{02}{\lambda}_{2}^{l-1}b_{21} \\
c_{10}{\lambda}_{0}^{l-1}b_{00} + c_{11}{\lambda}_{1}^{l-1}b_{10} + c_{12}{\lambda}_{2}^{l-1}b_{20} & c_{10}{\lambda}_{0}^{l-1}b_{01} + c_{11}{\lambda}_{1}^{l-1}b_{11} + c_{12}{\lambda}_{2}^{l-1}b_{21} \\
c_{20}{\lambda}_{0}^{l-1}b_{00} + c_{21}{\lambda}_{1}^{l-1}b_{10} + c_{22}{\lambda}_{2}^{l-1}b_{20} & c_{20}{\lambda}_{0}^{l-1}b_{01} + c_{21}{\lambda}_{1}^{l-1}b_{11} + c_{22}{\lambda}_{2}^{l-1}b_{21} \\
c_{30}{\lambda}_{0}^{l-1}b_{00} + c_{31}{\lambda}_{1}^{l-1}b_{10} + c_{32}{\lambda}_{2}^{l-1}b_{20} & c_{30}{\lambda}_{0}^{l-1}b_{01} + c_{31}{\lambda}_{1}^{l-1}b_{11} + c_{32}{\lambda}_{2}^{l-1}b_{21} \\
\end{bmatrix}
=\begin{bmatrix}
c_{00}{\lambda}_{0}^{l-1}b_{00} & c_{00}{\lambda}_{0}^{l-1}b_{01} \\
c_{10}{\lambda}_{0}^{l-1}b_{00} & c_{10}{\lambda}_{0}^{l-1}b_{01} \\
c_{20}{\lambda}_{0}^{l-1}b_{00} & c_{20}{\lambda}_{0}^{l-1}b_{01} \\
c_{30}{\lambda}_{0}^{l-1}b_{00} & c_{30}{\lambda}_{0}^{l-1}b_{01} \\
\end{bmatrix} +
\begin{bmatrix}
c_{01}{\lambda}_{1}^{l-1}b_{10} & c_{01}{\lambda}_{1}^{l-1}b_{11} \\
c_{11}{\lambda}_{1}^{l-1}b_{10} & c_{11}{\lambda}_{1}^{l-1}b_{11} \\
c_{21}{\lambda}_{1}^{l-1}b_{10} & c_{21}{\lambda}_{1}^{l-1}b_{11} \\
c_{31}{\lambda}_{1}^{l-1}b_{10} & c_{31}{\lambda}_{1}^{l-1}b_{11} \\
\end{bmatrix} +
\begin{bmatrix}
c_{02}{\lambda}_{2}^{l-1}b_{20} & c_{02}{\lambda}_{2}^{l-1}b_{21} \\
c_{12}{\lambda}_{2}^{l-1}b_{20} & c_{12}{\lambda}_{2}^{l-1}b_{21} \\
c_{22}{\lambda}_{2}^{l-1}b_{20} & c_{22}{\lambda}_{2}^{l-1}b_{21} \\
c_{32}{\lambda}_{2}^{l-1}b_{20} & c_{32}{\lambda}_{2}^{l-1}b_{21} \\
\end{bmatrix}
行列B,C を内部状態の次元に対して以下のように考えてみます。
B=\begin{bmatrix}
B_0 \\
B_1 \\
B_2 \\
\end{bmatrix}, \ \ \
C=\begin{bmatrix}
C_0 & C_1 & C_2
\end{bmatrix}, \ \ \
\left(
B_0=
\begin{bmatrix}
b_{00} & b_{01}
\end{bmatrix}, \ \ \
C_0=
\begin{bmatrix}
c_{00} \\
c_{10} \\
c_{20} \\
c_{30} \\
\end{bmatrix}
\right)
この形式で表現すると
K_l={\lambda}_{0}^{l-1} C_0 B_0 + {\lambda}_{1}^{l-1} C_1 B_1 + {\lambda}_{2}^{l-1} C_2 B_2 \ \ \ \in \mathbb{R}^{4\times 2}
内部状態x の次元D_x=N で一般化すると(D_{in}, D_{out} の次元)
K_l=\sum_{n=0}^{N-1}{\lambda}_{n}^{l-1} C_n B_n \ \ \ \in \mathbb{R}^{D_{out}\times D_{in}}
さらに、カーネルK_L時系列長Lのベクトルとして一般化すると
K_L=
\begin{bmatrix}
C_0 B_0 & ... & C_n B_n & ... C_{N-1} B_{N-1}
\end{bmatrix}
\begin{bmatrix}
1 & ... & {\lambda}_{0}^{l-1} & ... & {\lambda}_{0}^{L-1} \\
1 & ... & ... & ... & ... \\
1 & ... & {\lambda}_{n-1}^{l-1} & ... & {\lambda}_{n-1}^{L-1} \\
1 & ... & ... & ... & ... \\
1 & ... & {\lambda}_{N-1}^{l-1} & ... & {\lambda}_{N-1}^{L-1} \\
\end{bmatrix} \ \ \ \in \mathbb{R}^{L \times D_{out} \times D_{in}}
※D_{out} \times D_{in}のカーネルK_l が L個並んでいる状態です。実際には D_{out} \times D_{in}の次元ではなく、reshape(-1)
にして横並びに保持して計算されているようです。
こちらは、論文の以下の箇所に該当する解説です。論文ではより一般化した形式で書かれています。

疑似コード解説

log_dt
なぜ log_dt としてパラメータを定義しているのか。
これは \Delta > 0 と強制するためです。dt = np.exp(log_dt)
とする事で、log_dt がどんな値をとろうが、dt > 0
となります。
行列Aの初期化
A = -0.5 + 1j * np.pi * np.arange(N//2)
これは、S4D-Lin という形式で初期化されています。
N//2 のパラメータ数

複素数を扱うので、実質的にはパラメータは半分でいいとのこと。
カーネルの計算
これ多分誤字があって、正しくは return 2 * ((dB*C) @ (dA[:, None] ** np.arange(L))).real
だと思われます。B*C
ではなくdB*C
yの計算
S4同様、カーネルKと入力uをフーリエ変換し、掛け合わせ、逆フーリエ変換して出力yを計算しています。
補足:パラメータB
疑似コードではBを1で初期化してパラメータ化していますが、B,Cというパラメータはカーネルの計算C\={B} でしか関わらず、まとめてC'として扱う事も可能です。
現にコードでは
https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4d.py#L31-L47
となっており、パラメータはBは出てこず、\={B}=A^{-1} (exp(\Delta A) - I) B のBの係数は以下で吸収されています。
https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4d.py#L44
※コードはZOHでの計算になっています。
結果
今回は構造の理解が目的で、あまり結果には言及しませんが、今回の実験では若干S4が良いという結果になっています。

というのも、S4DはS4における低ランク表現を捨てているため、S4より劣るというのは予想通りという事です。
しかしながら、それを捨てる事による計算をシンプルにできるメリットは大きく、また、低ランク表現はSSMの内部状態次元を多くする事で代替できると期待されます。
Discussion