🎃

時系列データ分析 論文解説⑤「 S4D 」

に公開

論文

前回は、S4の理論編実装編を解説しました。

今回は、その続編である以下の論文について解説します。

タイトル

On the Parameterization and Initialization of Diagonal State Space Models

論文: https://arxiv.org/pdf/2206.11893
Github: https://github.com/state-spaces/s4/blob/main/models/s4/s4d.py

概要

  • S4 をシンプルにしたモデル
  • 行列Aを対角行列に制限した(低ランク表現をやめた)
  • 行列Aの実部をマイナスに制限した
  • シンプルになったにも関わらず、精度は概ねS4と同等である

SSMのおさらい

入力uの次元D_{in}=2, 内部状態xの次元D_x=3, 出力yD_{out}=4, 入力uの時系列長Lの場合(簡単のため行列Aは対角行列)

u_0=\begin{bmatrix} u_{00} \\ u_{01} \end{bmatrix}, \ ... \ , \ \ u_t=\begin{bmatrix} u_{t0} \\ u_{t1} \end{bmatrix}, \ ... \ , \ \ u_L=\begin{bmatrix} u_{L0} \\ u_{L1} \end{bmatrix}
\begin{bmatrix} x_{t}^{(0)} \\ x_{t}^{(1)} \\ x_{t}^{(2)} \\ \end{bmatrix}= \underbrace{ \begin{bmatrix} {\lambda}_{0} & 0 & 0 \\ 0 & {\lambda}_{1} & 0 \\ 0 & 0 & {\lambda}_{2} \end{bmatrix} }_{A} \begin{bmatrix} x_{t-1}^{(0)} \\ x_{t-1}^{(1)} \\ x_{t-1}^{(2)} \\ \end{bmatrix} + \underbrace{ \begin{bmatrix} b_{00} & b_{01} \\ b_{10} & b_{11} \\ b_{20} & b_{21} \\ \end{bmatrix} }_{B} \begin{bmatrix} u_{t0} \\ u_{t1} \end{bmatrix}
\begin{bmatrix} y_{t}^{(0)} \\ y_{t}^{(1)} \\ y_{t}^{(2)} \\ y_{t}^{(3)} \end{bmatrix}= \underbrace{ \begin{bmatrix} c_{00} & c_{01} & c_{02} \\ c_{10} & c_{11} & c_{12} \\ c_{20} & c_{21} & c_{22} \\ c_{30} & c_{31} & c_{32} \\ \end{bmatrix} }_{C} \begin{bmatrix} x_{t}^{(0)} \\ x_{t}^{(1)} \\ x_{t}^{(2)} \\ \end{bmatrix} + \underbrace{ \begin{bmatrix} d_{00} & d_{01} \\ d_{10} & d_{11} \\ d_{20} & d_{21} \\ d_{30} & d_{31} \\ \end{bmatrix} }_{D} \begin{bmatrix} u_{t0} \\ u_{t1} \end{bmatrix}

愚直に変数表現で表すと上記のようになります(自分にとっては分かりやすいです)。基本的にはD=0と考えますが、念のため全部書き下してみました。

行列A

ざっくり言えば、S4DはS4に比べて行列Aをシンプルにしたモデルです。ではその詳細について見てみましょう。

A= \begin{bmatrix} {\lambda}_{1} & ... & 0 \\ ... & ... & ... \\ 0 & ... & {\lambda}_{n} \end{bmatrix}

行列のランクを無くす

S4では行列AがDPLR(Diagonal Plus Low-Ran)形式で記述できる事で、カーネルの計算を簡略化できていました。改めて式を記述すると以下になります。

A'=V^*AV=\Lambda - PQ^*

S4DではこのPQ^*という項を無くしています。え、勝手にそんな事をしていいの? って感じですが、この項が無くなると、どういう影響があるのかを見てみましょう。

この項は、カーネル計算に影響します。改めてカーネルの計算式を書いてみます。

{\hat{K}_L}'(z)=\frac{2}{1+z} \tilde{C}' \left( R(z) - R(z)P(I+Q^*R(z)P)^{-1}Q^*R(z) \right) B'
R(z)=\left( \frac{1-z}{1+z}sI-\Lambda \right)^{-1}

{\hat{K}_L}'(z)の第二項にPQ^*が現れます。これはつまり、対角行列R(z)の固有値を混ぜ合わせて複雑な状態を表現している訳です。

この低ランク項PQ^*を無くすという事は、この固有値の相互作用が発生する項を無くす、つまり表現力が落ちる事を意味します。この表現力に関する議論は Expressivity and Limitations of Diagonal SSMs で記述されているような気がしますが、ちょっと難しいので割愛します。

実部の制約

とありますが、S4Dでは行列Aの実部がマイナスになるように制限しています。これは、長い時系列でカーネルが発散するのを防ぐためです。

ReLU形式なども取れるとしており、行列Aは以下のような形になります。

A=-exp(A_{Re})+i\cdot A_{Img} \ \ \ or \ \ \ A=-ReLU(A_{Re})+i\cdot A_{Img}

連続値SSMのカーネルCe^{tA}B、離散SSMのカーネルC\={A}^{t-1}\={B} となっています。Aの実部がマイナスであれば連続値SSMのカーネルは発散せず、離散SSMでは\={A}=(I-\Delta /2A)^{-1}(I+\Delta /2A) (Bilinear形式)なので、Aがマイナスであれば \={A}<1となり、離散SSMのカーネルは発散しません。

上図で、横軸は時系列長で、縦軸はカーネルの値です。時間が大きい箇所(つまり過去の情報、K_{l=L}となるような時系列点)では値が減衰している事が分かります。

上図の疑似コードはZOH展開での計算式(以下)を使用しており、 (exp(dt*A)-1)/A となっている箇所は、\={B}を展開した時の項が現れています。

カーネルの計算がぐっと楽になっている事が分かります。

S4D kernel

ある時系列点lのカーネルはK_l=\={C}{\={A}}^{l-1}\={B}\ \ (\={C}=C)のように表せるので(簡単のため、変数表記を統一します)

K_l= \begin{bmatrix} c_{00} & c_{01} & c_{02} \\ c_{10} & c_{11} & c_{12} \\ c_{20} & c_{21} & c_{22} \\ c_{30} & c_{31} & c_{32} \\ \end{bmatrix} \begin{bmatrix} {\lambda}_{0}^{l-1} & 0 & 0 \\ 0 & {\lambda}_{1}^{l-1} & 0 \\ 0 & 0 & {\lambda}_{2}^{l-1} \end{bmatrix} \begin{bmatrix} b_{00} & b_{01} \\ b_{10} & b_{11} \\ b_{20} & b_{21} \\ \end{bmatrix}
=\begin{bmatrix} c_{00}{\lambda}_{0}^{l-1}b_{00} + c_{01}{\lambda}_{1}^{l-1}b_{10} + c_{02}{\lambda}_{2}^{l-1}b_{20} & c_{00}{\lambda}_{0}^{l-1}b_{01} + c_{01}{\lambda}_{1}^{l-1}b_{11} + c_{02}{\lambda}_{2}^{l-1}b_{21} \\ c_{10}{\lambda}_{0}^{l-1}b_{00} + c_{11}{\lambda}_{1}^{l-1}b_{10} + c_{12}{\lambda}_{2}^{l-1}b_{20} & c_{10}{\lambda}_{0}^{l-1}b_{01} + c_{11}{\lambda}_{1}^{l-1}b_{11} + c_{12}{\lambda}_{2}^{l-1}b_{21} \\ c_{20}{\lambda}_{0}^{l-1}b_{00} + c_{21}{\lambda}_{1}^{l-1}b_{10} + c_{22}{\lambda}_{2}^{l-1}b_{20} & c_{20}{\lambda}_{0}^{l-1}b_{01} + c_{21}{\lambda}_{1}^{l-1}b_{11} + c_{22}{\lambda}_{2}^{l-1}b_{21} \\ c_{30}{\lambda}_{0}^{l-1}b_{00} + c_{31}{\lambda}_{1}^{l-1}b_{10} + c_{32}{\lambda}_{2}^{l-1}b_{20} & c_{30}{\lambda}_{0}^{l-1}b_{01} + c_{31}{\lambda}_{1}^{l-1}b_{11} + c_{32}{\lambda}_{2}^{l-1}b_{21} \\ \end{bmatrix}
=\begin{bmatrix} c_{00}{\lambda}_{0}^{l-1}b_{00} & c_{00}{\lambda}_{0}^{l-1}b_{01} \\ c_{10}{\lambda}_{0}^{l-1}b_{00} & c_{10}{\lambda}_{0}^{l-1}b_{01} \\ c_{20}{\lambda}_{0}^{l-1}b_{00} & c_{20}{\lambda}_{0}^{l-1}b_{01} \\ c_{30}{\lambda}_{0}^{l-1}b_{00} & c_{30}{\lambda}_{0}^{l-1}b_{01} \\ \end{bmatrix} + \begin{bmatrix} c_{01}{\lambda}_{1}^{l-1}b_{10} & c_{01}{\lambda}_{1}^{l-1}b_{11} \\ c_{11}{\lambda}_{1}^{l-1}b_{10} & c_{11}{\lambda}_{1}^{l-1}b_{11} \\ c_{21}{\lambda}_{1}^{l-1}b_{10} & c_{21}{\lambda}_{1}^{l-1}b_{11} \\ c_{31}{\lambda}_{1}^{l-1}b_{10} & c_{31}{\lambda}_{1}^{l-1}b_{11} \\ \end{bmatrix} + \begin{bmatrix} c_{02}{\lambda}_{2}^{l-1}b_{20} & c_{02}{\lambda}_{2}^{l-1}b_{21} \\ c_{12}{\lambda}_{2}^{l-1}b_{20} & c_{12}{\lambda}_{2}^{l-1}b_{21} \\ c_{22}{\lambda}_{2}^{l-1}b_{20} & c_{22}{\lambda}_{2}^{l-1}b_{21} \\ c_{32}{\lambda}_{2}^{l-1}b_{20} & c_{32}{\lambda}_{2}^{l-1}b_{21} \\ \end{bmatrix}

行列B,C を内部状態の次元に対して以下のように考えてみます。

B=\begin{bmatrix} B_0 \\ B_1 \\ B_2 \\ \end{bmatrix}, \ \ \ C=\begin{bmatrix} C_0 & C_1 & C_2 \end{bmatrix}, \ \ \ \left( B_0= \begin{bmatrix} b_{00} & b_{01} \end{bmatrix}, \ \ \ C_0= \begin{bmatrix} c_{00} \\ c_{10} \\ c_{20} \\ c_{30} \\ \end{bmatrix} \right)

この形式で表現すると

K_l={\lambda}_{0}^{l-1} C_0 B_0 + {\lambda}_{1}^{l-1} C_1 B_1 + {\lambda}_{2}^{l-1} C_2 B_2 \ \ \ \in \mathbb{R}^{4\times 2}

内部状態x の次元D_x=N で一般化すると(D_{in}, D_{out} の次元)

K_l=\sum_{n=0}^{N-1}{\lambda}_{n}^{l-1} C_n B_n \ \ \ \in \mathbb{R}^{D_{out}\times D_{in}}

さらに、カーネルK_L時系列長Lのベクトルとして一般化すると

K_L= \begin{bmatrix} C_0 B_0 & ... & C_n B_n & ... C_{N-1} B_{N-1} \end{bmatrix} \begin{bmatrix} 1 & ... & {\lambda}_{0}^{l-1} & ... & {\lambda}_{0}^{L-1} \\ 1 & ... & ... & ... & ... \\ 1 & ... & {\lambda}_{n-1}^{l-1} & ... & {\lambda}_{n-1}^{L-1} \\ 1 & ... & ... & ... & ... \\ 1 & ... & {\lambda}_{N-1}^{l-1} & ... & {\lambda}_{N-1}^{L-1} \\ \end{bmatrix} \ \ \ \in \mathbb{R}^{L \times D_{out} \times D_{in}}

D_{out} \times D_{in}のカーネルK_lL個並んでいる状態です。実際には D_{out} \times D_{in}の次元ではなく、reshape(-1) にして横並びに保持して計算されているようです。

こちらは、論文の以下の箇所に該当する解説です。論文ではより一般化した形式で書かれています。

疑似コード解説

log_dt

なぜ log_dt としてパラメータを定義しているのか。

これは \Delta > 0 と強制するためです。dt = np.exp(log_dt) とする事で、log_dt がどんな値をとろうが、dt > 0 となります。

行列Aの初期化

A = -0.5 + 1j * np.pi * np.arange(N//2) これは、S4D-Lin という形式で初期化されています。

N//2 のパラメータ数

複素数を扱うので、実質的にはパラメータは半分でいいとのこと。

カーネルの計算

これ多分誤字があって、正しくは return 2 * ((dB*C) @ (dA[:, None] ** np.arange(L))).real だと思われます。B*CではなくdB*C

yの計算

S4同様、カーネルKと入力uをフーリエ変換し、掛け合わせ、逆フーリエ変換して出力yを計算しています。

補足:パラメータB

疑似コードではBを1で初期化してパラメータ化していますが、B,Cというパラメータはカーネルの計算C\={B} でしか関わらず、まとめてC'として扱う事も可能です。

現にコードでは

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4d.py#L31-L47

となっており、パラメータはBは出てこず、\={B}=A^{-1} (exp(\Delta A) - I) BBの係数は以下で吸収されています。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4d.py#L44

※コードはZOHでの計算になっています。

結果

今回は構造の理解が目的で、あまり結果には言及しませんが、今回の実験では若干S4が良いという結果になっています。

というのも、S4DはS4における低ランク表現を捨てているため、S4より劣るというのは予想通りという事です。

しかしながら、それを捨てる事による計算をシンプルにできるメリットは大きく、また、低ランク表現はSSMの内部状態次元を多くする事で代替できると期待されます。

Discussion