はじめに
この前のブログでは、HiPPOフレームワークとLSSLについて書いた。
https://zenn.dev/izmyon/articles/8374a11d272602
本シリーズでは、mambaの理論的背景を理解するために、それらの先行研究を順々にまとめて解説していく。重要な先行研究としては、LMU, HiPPO, LSSL, S4, H3などがあり、主に以下の流れで解説する。
HiPPOフレームワークとLSSL
S4のアルゴリズム
H3とHyena Hierarchy
mamba
本記事は、第二弾として、S4のアルゴリズムについて書く。
元論文を以下に記す。
S4
https://arxiv.org/abs/2111.00396
読者の方へ
補足や訂正などがあれば、コメントにて優しく丁寧にご教示いただけると喜びます。
書くのけっこう大変だったのでバッチを送っていただけると嬉しいです。たくさんもらえると僕のやる気が上がって投稿頻度が上がるかもしれません。(逆にもらえないと下がるかもしれない。)
おさらい:状態空間モデル
LSSLでは、以下のような状態方程式を考え、1次元の信号u(t) を、1次元の信号y(t) に出力する前に、N 次元の隠れ状態x(t) に投影するのであった。
\begin{align}
x'(t) = A x(t) + B u(t)
\\
y(t) = C x(t) + D u(t)
\end{align}
S4の論文中や解説記事であるThe annotated S4(文献[8])では、行列D の影響が小さく計算が簡単であることから省略されて説明されるが、本記事では特に省略しない。
そして、行列A には、HiPPO-LegSで導かれたHiPPO行列と呼ばれる以下の行列で初期化するのであった。このように初期化することで、x(t) に入力信号u(t) の履歴を記憶することを可能とし、MNISTベンチマークでの性能を60%から98%に上昇させることが出来る。
\begin{align}
A_{nk} = - \left\{
\begin{array}{ll}
(2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\
n+1 & \text{if} \ n=k \\
0 & \text{if} \ n<k \\
\end{array}
\right.
\end{align}
また、式(1)、(2)ではu(t) は連続関数であるとしているが、実際には、ステップサイズ\Delta により離散化された入力系列(u_0, u_1, \ldots) が与えられる。このu_k は、連続関数u(t) からステップサイズ\Delta でサンプリングされたものであると考えられ、u_k = u(k\Delta) である。このような離散的な入力系列(u_0, u_1, \ldots) を扱うために、式(1)、(2)を離散化する必要があり、GBTを用いて以下のように離散化するのであった。
\begin{align}
x_k &= \bar{A} x_{k-1} + \bar{B} u_k
\\
y_k &= \bar{C} x_k + \bar{D} u_k
\\
\text{where,} \notag
\\
\bar{A} &= (I - \frac{\Delta}{2}A)^{-1} (I + \frac{\Delta}{2}A)
\\
\bar{B} &= (I - \frac{\Delta}{2}A)^{-1} \Delta B
\\
\bar{C} &= C, \bar{D} = D
\end{align}
また、式(4)、(5)から、以下が導かれる。
\begin{align}
y_k = (\overline{C A}^k \overline{B} u_0 + \overline{C A}^{k-1} \overline{B} u_1 + \cdots + \overline{C A B}u_{k-1} + \overline{CB}u_k) + \overline{D}u_k
\end{align}
ここで、入力u の系列長がL であるとし、カーネル\bar{K} を以下のように定義する。
\begin{align}
\overline{K} \in \mathbb{R}^L := \mathcal{K}_L (\overline{A}, \overline{B}, \overline{C}):= (\overline{C A}^i \overline{B})_{i \in [L]} = (\overline{CB}, \overline{CAB}, \ldots, \overline{CA}^{L-1} \overline{B})
\end{align}
これを用いると、以下のように表せる。
\begin{align}
y = \bar{K} * u + \overline{D}u_k
\end{align}
ただし、u = (u_0, u_1, \ldots, u_{L-1}) とした。もし、カーネル\bar{K} が既知として与えられたとすれば、式(11)はFFTにより非常に高速に求められる。このカーネル\bar{K} を、SSM畳み込みカーネルと呼ぶことにする。
SSMカーネルの計算に伴う課題
式(4)、(5)の再帰形式、式(11)の畳み込み形式で、\bar{A} のべき乗を大量に含んでいることや、\bar{A} によるベクトル行列演算(MVP; Matrix-Vector Product)が、計算のボトルネックとなる。そこで、これらの計算の効率化を考えるわけだが、まず\bar{A} 、もといA の対角化について考えてみる。
!
ここで、なぜ対角化を考えるのかがわからなかった人のために対角化について説明する。
n 次の正方行列A の固有ベクトルだけでn次元ベクトル空間の基底が構成できる、つまりA がn 本の一次独立な固有ベクトルを持つとき、n 次対角行列\Lambda とn 次正則行列V が存在して、以下のように対角化することが出来る。
ここで、対角行列\Lambda が、以下のように表されるとする。
\Lambda =
\begin{bmatrix}
\lambda_1 & 0 & \cdots & 0 \\
0 & \lambda_2 & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
0 & 0 & \cdots & \lambda_n \\
\end{bmatrix}
すると、\Lambda のべき乗は以下のように簡単に計算できる。
\Lambda^k =
\begin{bmatrix}
\lambda_1^k & 0 & \cdots & 0 \\
0 & \lambda_2^k & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
0 & 0 & \cdots & \lambda_n^k \\
\end{bmatrix}
さらに、元の行列A について、以下が成り立つ。
\begin{align}
A^k &= (V \Lambda V^{-1})^k = (V \Lambda V^{-1})(V \Lambda V^{-1}) \cdots (V \Lambda V^{-1}) \notag
\\
&= V \Lambda (V^{-1} V) \Lambda (V^{-1} V) \Lambda \cdots (V^{-1} V) \Lambda V^{-1} \notag
\\
&= V \Lambda^k V^{-1}
\end{align}
\Lambda^k は簡単に求められたから、式(12)を使うと、A^k も効率よく求めることが出来る。
式(1), (2)の対角化に関して、次の二つの補題が成り立つ。
Lemma [3, p.5, 3.1]
式(1)、(2)のSSMに共役な作用を施しても同じ変換\bold{u \mapsto y} を維持する。\bold{(A, B, C, D) \simeq (V^{-1}AV, V^{-1}B, CV, D)}
証明 :
状態x と状態\tilde{x} がそれぞれ、以下の二つのSSMを満たすとする。
\begin{align}
x' = Ax + Bu
\\
y = Cx + Du
\end{align}
\begin{align}
\tilde{x}' = V^{-1} A V \tilde{x} + V^{-1} B u
\\
y = C V \tilde{x} + D u
\end{align}
式(15)に左からV を掛け、x = V \tilde{x} と置けば、式(15)、(16)は式(13)、(14)のSSMと同じになる。つまり、状態x の基底をV によって変化させることで、これら二つのSSMは全く同じ変換u \mapsto y である。
(証明終)
Lemma [3, p.5, 3.2]
式(3)のHiPPO行列A は、行列\bold{V_{i,j} = \binom{i+j}{i-j}} によって対角化される。特に、\bold{V_{3i,i} = \binom{4i}{2i} \simeq 2^{4i}} であり、\bold{V} は最大で\bold{2^{4N/3}} の大きさの要素を持つ。
(証明略)
これらの補題から、A の代わりにA を対角化した行列を用いて計算することが考えられだろう。実際にこの方法によってカーネル\bar{K} の計算を高速化することが出来るが、残念ながらこの数値計算は安定しない。そこで、対角化ではない別の方法を考え、LSSLで用いられたアルゴリズムが考案されたが、LSSLのアルゴリズムも潜在次元N が大きくなると数値的に安定しないことが証明されている。(詳細は[2, B.2]に記載されているが、ここでは割愛する。)そこで考案されたのが、数値的に安定しかつ高速なS4のアルゴリズムである。
Normal Plus Low-Rank (NPLR)とDiagnoral Plus Low-Rank (DPLR)
ここれまで行列A を対角化けるすることを考えてきたが、実は、行列A として用いるHiPPO行列に関して、以下の定理が成り立ち、NPLRおよびDPLRという形式で表現することが出来る。これを手掛かりとしてS4のアルゴリズムが導かれる。
Theorem [3, p.6, 1]
文献[1]で用いたすべてのHiPPO行列は、ユリタリ行列V \in \mathbb{C}^{N \times N} 、対角行列\Lambda 、そして低ランク行列P, Q \in \mathbb{R}^{N \times r} (21)で表されるNPLR表現を持つ。すべてのHiPPO行列でr=1 かr=2 であり特に、式(3)のHiPPO行列ではr=1 である。
\begin{align}
A = V \Lambda V^* - PQ^T = V (\Lambda - (V^* P)(V^* Q)^* ) V^*
\end{align}
文献[3](S4の論文)では、文献[1](HiPPOの論文)で紹介されたすべてのHiPPO行列について証明を与えているが、HiPPO-LegSの行列A (すなわち、式(3))がNPLR表現を持つことは、式(3)に行列M_{nk} = \frac{1}{2}(2n+1)^{1/2}(2k+1)^{1/2} を加えることで示すことが出来る。
\begin{align*}
M_{nk} + A_{nk} =& - \left\{
\begin{array}{ll}
\frac{1}{2}(2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\
\frac{1}{2} & \text{if} \ n=k \\
-\frac{1}{2}(2n+1)^{1/2}(2k+1)^{1/2} & \text{if} \ n<k \\
\end{array}
\right.
\end{align*}
ここで、以下のように交代行列S を定義する。
\begin{align*}
S_{nk} = - \left\{
\begin{array}{ll}
\frac{1}{2}(2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\
0 & \text{if} \ n=k \\
-\frac{1}{2}(2n+1)^{1/2}(2k+1)^{1/2} & \text{if} \ n<k \\
\end{array}
\right.
\end{align*}
すると、M+A は、
\begin{align*}
M + A = - \frac{1}{2} I + S
\end{align*}
と表せる。したがって、
\begin{align}
A = (- \frac{1}{2} I + S) - M
\end{align}
交代行列S は、ユリタリ行列により対角化され、同じ行列により- \frac{1}{2} I + S も対角化することが出来る。
さらに、M はN \times 1 の列ベクトルP_n = \frac{1}{2}(2n+1)^{1/2} およびQ_n = (2n+1)^{1/2} 用いると、M=PQ^T で表すことができ、- \frac{1}{2} I + S を対角化することでV \Lambda V^* と表すことで切るとすると、以下のように式(18)はNPLR形式であることが分かる。
\begin{align*}
A = (- \frac{1}{2} I + S) - M = V \Lambda V^* - PQ^T
\end{align*}
さらに、NPLRは、以下の式(19)で表されるDiagnoral Plus Low-Rank (DPLR)形式と共役である。このことは、NPLRの式(17)からすぐに分かるだろう。
\begin{align}
V^* A V = \Lambda - (V^*P)(V^*Q)^*
\end{align}
S4では、実験的に性能の向上が確認されている、式(3)のNPLR形式で表されるHiPPO行列(Theorem [3, p.6, 1])で行列A を初期化し、(A, B, C, D) によるSSMを考えるが、Lemma[3, p.5, 3.1]により、ある\Lambda とベクトルP, Q, \tilde{B}, \tilde{C}, D \in \mathbb{C}^{N \times 1} に対し、同じ変換を維持する(\Lambda - PQ^*, \tilde{B}, \tilde{C}, D) が存在するため、行列A がDPLRであるものとして考えることが出来る。
!
ここをもう少し詳しく考える。
Theorem [3, p.6, 1]により、HiPPO行列A がある\Lambda とベクトルP, Q, B, C \in \mathbb{C}^{N \times 1} 、ユリタリ行列V によって以下のように表されるとする。
\begin{align*}
A = V \Lambda V^* - PQ^T = V (\Lambda - (V^* P)(V^* Q)^* ) V^*
\end{align*}
これを式(13), (14)のSSMに代入すると、以下のようになる。
\begin{align*}
x' &= V (\Lambda - (V^* P)(V^* Q)^* ) V^* x + Bu
\\
y &= Cx + Du
\end{align*}
両辺に左からV^* を掛けると以下のようになる。
\begin{align*}
V^* x' &= (\Lambda - (V^* P)(V^* Q)^* ) V^* x + V^* Bu
\\
y &= Cx + Du
\end{align*}
ここで、\tilde{x} = V^* x 、\tilde{C} = CV 、\tilde{B} = V^* B 、\tilde{P} = V^* P 、\tilde{Q} = V^* Q と置くと、
\begin{align*}
\tilde{x}' &= (\Lambda - \tilde{P}\tilde{Q}^* ) \tilde{x} + \tilde{B}u
\\
y &= \tilde{C}\tilde{x} + Du
\end{align*}
これは、(\Lambda - \tilde{P}\tilde{Q}^*, \tilde{B}, \tilde{C}, D) のSSMである。
以上より、(A, B, C, D) のSSMは、\tilde{x} = V^* x 、\tilde{C} = CV 、\tilde{B} = V^* B 、\tilde{P} = V^* P 、\tilde{Q} = V^* Q と置くことで、(\Lambda - \tilde{P}\tilde{Q}^*, \tilde{B}, \tilde{C}, D) のSSMに変換されることが分かる。
したがって、NPLRであるHiPPO行列A で初期化したとしても、行列A の代わりに\Lambda - \tilde{P}\tilde{Q}^* を用いたSSMと等価であると考えてアルゴリズムを導出することが出来る。
S4のアルゴリズム
S4のアルゴリズムを用いれば、行列A がDPLR形式であるとき、再帰形式でO(N) 、畳み込み形式で\tilde{O}(N+L) の計算量とO(N+L) の空間量で計算することが出来る。これから、アルゴリズムの詳細を解説する。
ここで、以下の補題で表されるWoodburyの行列恒等式が鍵を握る。
Proposition [3, p.23, 4]
可換環\bold{\mathcal{R}} 上において、\bold{A \in \mathcal{R}^{N \times N}} 、\bold{U, V \in \mathcal{R}^{N \times p}} とする。\bold{A} と\bold{A+UV^*} が可逆であるとすると、\bold{I_p + V^* A^{-1} U \in \mathcal{R}^{p \times p}} は可逆であり、以下が成り立つ。
\begin{align}
(A + UV^*)^{-1} = A^{-1} - A^{-1} U (I_p + V^* A^{-1} U)^{-1} V^* A^{-1}
\end{align}
再帰形式
再帰形式では、式(4)、(5)を計算していく。A = \Lambda - PQ^* というDPLR形式で表されるとき、式(6)、(7)に従って離散化された\bar{A} および\bar{B} を求める。特に、\bar{A} は、(I - \Delta/2 \cdot A)^{-1} と(I + \Delta/2 \cdot A) の積で表され、\bar{B} も(I - \Delta/2 \cdot A)^{-1} を持ち、しかもA を含むのはこの項のみである。そこで、(I + \Delta/2 \cdot A) と(I - \Delta/2 \cdot A)^{-1} をそれぞれ別々に求めてみる。
(I + \Delta/2 \cdot A) を求める。
\begin{align}
I + \frac{\Delta}{2} A &= I + \frac{\Delta}{2} (\Lambda - PQ^*) \notag
\\
&= \frac{\Delta}{2} [\frac{2}{\Delta} I + (\Lambda - PQ^*)] \notag
\\
&= \frac{\Delta}{2} A_0
\end{align}
ここで、以下のように置いた。
\begin{align}
A_0 = \frac{2}{\Delta} I + (\Lambda - PQ^*)
\end{align}
(I - \Delta/2 \cdot A)^{-1} を求める。
\begin{align}
(I - \frac{\Delta}{2} A)^{-1} &= (I - \frac{\Delta}{2} (\Lambda - PQ^*))^{-1} \notag
\\
&= (\frac{\Delta}{2} (\frac{2}{\Delta}- \Lambda + PQ^*))^{-1} \notag
\\
&= \frac{2}{\Delta} (\frac{2}{\Delta}- \Lambda + PQ^*)^{-1} \notag
\\
&= \frac{2}{\Delta} [R - RP(I + Q^* RP)^{-1} Q^* R] \quad (\because Woodburyの行列恒等式を用いた。) \notag
\\
&= \frac{2}{\Delta} A_1
\end{align}
ここで、以下のように置いた。
\begin{align}
R &= (\frac{2}{\Delta} - \Lambda)^{-1}
\\
A_1 &= R - RP(I + Q^* RP)^{-1} Q^* R
\end{align}
以上を用いると、S4の再帰形式用いる\bar{A} 、\bar{B} は、以下のように表せる。
\begin{align}
\bar{A} &= A_1 A_0
\\
\bar{B} &= \frac{2}{\Delta} A_1 \Delta B = 2 A_1 B
\end{align}
従って、離散化されたSSM、式(4)、(5)は、以下のように表せる。
\begin{align}
x_k &= \bar{A} x_{k-1} + \bar{B} u_k \notag
\\
&= A_1 A_0 x_{k-1} + 2 A_1 B u_k
\\
y_k &= C x_k
\end{align}
ここで、A_0 、A_1 は、どちらもDPLR形式で表されるため、両方ともO(N) の行列ベクトル積によって計算することができ、以下の定理が成り立つ。
Theorem [3, p.6, 2]
任意のステップサイズ\Delta が与えられた時、再帰計算のの1ステップは、潜在次元をN としたときにO(N) の計算量で求められる。
畳み込み形式
畳み込み形式の計算で重要なのは、SSMカーネル\bar{K} を効率的に計算することであり、本節ではその方法を考える。
S4で用いるアルゴリズムを計算するために、行列C を\mathbb{C}^{1 \times N} の形になるように転置し、B, P, Q と同じ形状を持つようにする。
\begin{align}
x'(t) &= A x(t) + B u(t)
\\
y(t) &= C^* x(t) + D u(t)
\end{align}
この時、式(10)のSSMカーネルは、以下のように表される。
\begin{align}
\mathcal{K}_L (\overline{A}, \overline{B}, \overline{C}) = (\overline{C}^* \overline{B}, \overline{C}^* \overline{AB}, \ldots, \overline{C}^*\overline{A}^{L-1} \overline{B}) \in \mathbb{R}^L
\end{align}
ここで、SSM生成関数を、z の関数として以下のように定義する。
\begin{align}
\hat{\mathcal{K}}(z; \bar{A}, \bar{B}, \bar{C}) \in \mathbb{C} := \sum^{\infty}_{i=0} \overline{C}^*\overline{A}^i \overline{B} z^i = \overline{C}^* \overline{B} \sum^{\infty}_{i=0} (\overline{A} z)^i = \overline{C}^* (I - \overline{A} z)^{-1} \overline{B}
\end{align}
さらに、長さL までで切り捨てられたSSM生成関数を、以下のように定義する。
\begin{align}
\hat{\mathcal{K}}_L (z; \bar{A}, \bar{B}, \bar{C}) \in \mathbb{C} := \sum^{L-1}_{i=0} \overline{C}^*\overline{A}^i \overline{B} z^i = \overline{C}^* (I - \overline{A}^L z^L)(I - \overline{A}z)^{-1} \overline{B}
\end{align}
ここでの目的は\bar{A} のべき乗を多く含む式(32)のSSMカーネルを直接計算するのではなく、より効率良く計算することであるが、式(33), (34)のように生成関数を定義するのは、以下が成り立つからである。
Lemma [3, p.23, C.2]
SSM関数\mathcal{K}_L (\overline{A}, \overline{B}, \overline{C}) は、1のべき根\Omega = {\exp(- 2 \pi i \frac{k}{L}) : k \in [L])} (ただし、i は虚数単位である。)におけるSSM生成関数\hat{\mathcal{K}}_L (\Omega; \bar{A}, \bar{B}, \bar{C}) からO(L \log L) で安定して求めることが出来る。
証明 :
\begin{align*}
\overline{K} &= \mathcal{K}_L (\overline{A}, \overline{B}, \overline{C}) = (\overline{C}^* \overline{B}, \overline{C}^* \overline{AB}, \ldots, \overline{C}^*\overline{A}^{L-1} \overline{B})
\\
\hat{K} &= \hat{\mathcal{K}}_L (\Omega; \bar{A}, \bar{B}, \bar{C})
\end{align*}
と表すこととし、さらに、\overline{K} のk 番目の要素を\overline{K}_k と表すとする。
つまり、以下のように表すことにする。
\begin{align*}
\overline{K}_k = \overline{C}^*\overline{A}^k \overline{B}
\end{align*}
さらに、以下のように\hat{K}_j を定義する。
\begin{align}
\hat{K}_j = \sum^{L-1}_{k=0} \overline{K}_k \exp(- 2 \pi i \frac{jk}{L})
\end{align}
式(35)は離散フーリエ変換(DFT)の式と一致しており、\overline{K} を離散フーリエ変換した関数が\hat{K} であることを示している。
したがって逆に、\hat{K} を求めることが出来れば、\overline{K} は逆離散フーリエ変換により求めることができ、高速フーリエ変換(FFT)アルゴリズムを用いれば、O(L \log L) で計算できる。
(証明終)
したがって、ここからは\hat{K} を求める問題について考えるのであるが、実は、\hat{K} について以下の補題が成り立つ。
Lemma [3, p.23, C.3]
\bold{A} が\bold{A=\Lambda - PQ^*} というDPLR形式を持つとすると、任意の1のべき根\bold{z \in \Omega} に対して、長さ\bold{L} までのSSM生成関数は、以下を満たす。
\begin{align}
\hat{\mathcal{K}}_L (z; \bar{A}, \bar{B}, \bar{C}) &= \frac{2}{1+z}[\tilde{C}^*R(z)B - \tilde{C}^*R(z)P (1 + Q^*R(z)P)^{-1} Q^* R(z)B]
\\
\tilde{C} &= (I - \overline{A}^L)^* \overline{C}
\\
R(z; \Lambda) &= (\frac{2}{\Delta} \frac{1-z}{1+z} I - \Lambda)^{-1}
\end{align}
証明 :
\begin{align*}
\hat{\mathcal{K}}_L (z; \bar{A}, \bar{B}, \bar{C})
&= \overline{C}^*\overline{B} + \overline{C}^*\overline{AB}z + \cdots + \overline{C}^*\overline{A}^{L-1}\overline{B} z^{L-1}
\\
&= \overline{C}^* (I - \overline{A}^L) (I - \overline{A}z)^{-1} \overline{B}
\\
&= \tilde{C}^* (I - \overline{A}z)^{-1} \overline{B}
\end{align*}
ここで、
\begin{align}
\tilde{C}^* =\overline{C}^* (I - \overline{A}^L)
\end{align}
とした。これは式(37)の共役転置である。
ここで、式(6)、(7)で以下のように離散化したことを思い出す。
\begin{align*}
\bar{A} &= (I - \frac{\Delta}{2}A)^{-1} (I + \frac{\Delta}{2}A)
\\
\bar{B} &= (I - \frac{\Delta}{2}A)^{-1} \Delta B
\end{align*}
まず\bar{A} = (I - \frac{\Delta}{2}A)^{-1} (I + \frac{\Delta}{2}A) を代入し、
\begin{align}
\hat{\mathcal{K}}_L (z; \bar{A}, \bar{B}, \bar{C}) =& \tilde{C}^* (I - \overline{A} z ) \notag
\\
=& \tilde{C}^* [(I - \frac{\Delta}{2}A)^{-1} (I - \frac{\Delta}{2}A) - (I - \frac{\Delta}{2}A)^{-1} (I + \frac{\Delta}{2}A) z ]^{-1} \overline{B} \notag
\\
=& \tilde{C}^* [(I - \frac{\Delta}{2}A) - (I + \frac{\Delta}{2}A) z]^{-1} (I - \frac{\Delta}{2}A) \overline{B} \notag
\\
=& \tilde{C}^* [I(1-z) - \frac{\Delta}{2}A(1+z)]^{-1} \Delta B \notag
\\
&(\because 式(7)より、\Delta B = \bar{B} (I - \frac{\Delta}{2}A)) \notag
\\
=& \tilde{C}^* [\frac{1+z}{2} ( 2 \frac{1-z}{1+z} I - \Delta A)]^{-1} \Delta B \notag
\\
=& \frac{2}{1+z} \tilde{C}^* [ \frac{2}{\Delta} \frac{1-z}{1+z} I - A]^{-1} B \notag
\\
=& \frac{2}{1+z} \tilde{C}^* [ \frac{2}{\Delta} \frac{1-z}{1+z} I - \Lambda + PQ^* ]^{-1} B \notag
\\
=& \frac{2}{1+z} \tilde{C}^* [ R(z)^{-1} + PQ^* ]^{-1} B
\end{align}
ここで、
\begin{align}
R(z) = (\frac{2}{\Delta} \frac{1-z}{1+z} I - \Lambda)^{-1}
\end{align}
と置いた。
ここで、Woodburyの行列恒等式を用いると、
\begin{align}
[ R(z)^{-1} + PQ^* ]^{-1} = R(z) - R(z) P (I + Q^* R(z) P)^{-1} Q^* R(z)
\end{align}
従って、式(40)に式(42)を代入して、
\begin{align*}
\hat{\mathcal{K}}_L (z; \bar{A}, \bar{B}, \bar{C}) &= \frac{2}{1+z} \tilde{C}^* [ R(z) - R(z) P (I + Q^* R(z) P)^{-1} Q^* R(z) ] B
\\
&= \frac{2}{1+z}[\tilde{C}^* R(z) B - \tilde{C}^* R(z) P (I + Q^* R(z) P)^{-1} Q^* R(z) B]
\end{align*}
が成り立つ。
(証明終)
この補題により、\hat{K} は式(36)で表すことが出来ることが分かった。この式で\hat{K} を求め、iFFTにより\overline{K} を求めるのである。
従って次は式(36)でどうやって\hat{K} を求めるのかということを考える。式(36)の右辺を見てみると、\hat{K} を求めるためには、R(z) が真ん中にある四つの積、\tilde{C}^* R(z) B 、\tilde{C}^* R(z) P 、Q^* R(z) P 、Q^* R(z) B を含むことが分かる。そして、これを計算できればあとは簡単に計算できそうである。ここで、実は、式(41)で表されるR(z) は以下のように定義されるコーシー行列と呼ばれるものである。
Definition [3, p.25, 3]
ノード\bold{\Omega = (\omega_i ) \in \mathbb{C}^M} および\bold{\Lambda = (\lambda_j) \in \mathbb{C}^N} 上のコーシー行列またはカーネルは以下で表される。
\begin{align*}
M \in \mathbb{C}^{M \times N} = M(\Omega, \Lambda) = (M_ij)_{i \in [M], j \in [N]}
\quad M_ij = \frac{1}{\omega_i - \lambda_j}
\end{align*}
また、サイズM \times N のコーシー行列の行列ベクトル積の計算時間をC(M, N) と表すこととする。コーシー行列の行列ベクトル積は、高速多重極法(Fast Multipole Method; FFM)と呼ばれるアルゴリズムを用いることで、以下が成り立つ。
Proposition [3, p.25, 5]
コーシー核は\bold{O(M+N)} の空間を必要とし、以下の計算量を必要とする。
\begin{align*}
C(M, N) = \left\{
\begin{array}{ll}
O(MN) & \text{if} \ ナイーブな計算 \\
O((M+N) \log^2 (M+N)) & 正確な計算 \\
O((M+N) \log (M+N) \log \frac{1}{\epsilon}) & εの精度での数値計算 \\
\end{array}
\right.
\end{align*}
したがって、ビッグオー記法O から対数因子を無視した記法である、「ソフトオー」と呼ばれる記法\tilde{O} を用いると、以下が成り立つ。
Corollary [3, p.25, C.5]
任意のノード集合\bold{\Omega \in \mathbb{C}^L} 、対角行列\bold{\Lambda} 、およびベクトル\bold{P} 、\bold{Q} に対して、\bold{Q^*R(\Omega; \Lambda)P} を評価すると、\bold{C(L, N)} の計算量と\bold{O(L+N)} の空間量で計算できる。ここで、\bold{C(L, N) = \tilde{O}(L+N)} はコーシー行列の行列ベクトル積のコストである。
このことから、\hat{K} を求めるために必要な四の行列ベクトル積、\tilde{C}^* R(z) B 、\tilde{C}^* R(z) P 、Q^* R(z) P 、Q^* R(z) B は、\tilde{O}(L+N) で計算することが出来、以下のように行列K を定義すると、一括で計算できる。
\begin{align}
K &=
\begin{bmatrix}
\tilde{C} & Q
\end{bmatrix}^*
R(z)
\begin{bmatrix}
B & P
\end{bmatrix}
=
\begin{bmatrix}
\tilde{C}^* \\
Q^*
\end{bmatrix}
R(z)
\begin{bmatrix}
B & P
\end{bmatrix}
=
\begin{bmatrix}
\tilde{C}^* R(z) \\
Q^* R(z)
\end{bmatrix}
\begin{bmatrix}
B & P
\end{bmatrix} \notag
\\
&=
\begin{bmatrix}
\tilde{C}^* R(z) B & \tilde{C}^* R(z) P\\
Q^* R(z) B & Q^* R(z) P
\end{bmatrix}
\end{align}
K の各要素を以下のように表すことにする。
\begin{align*}
K =
\begin{bmatrix}
k_{00} (\omega) & k_{01} (\omega) \\
k_{10} (\omega) & k_{11} (\omega) \\
\end{bmatrix}
\end{align*}
すると、\hat{K} は。K の各要素を用いて、以下のように計算することが出来る。
\begin{align*}
\hat{\mathcal{K}}_L (z; \bar{A}, \bar{B}, \bar{C}) &= \frac{2}{1+z}[k_{00} (\omega) - k_{01} (\omega) (1 + k_{11} (\omega))^{-1} k_{10} (\omega)]
\end{align*}
このようにして\hat{K} を求めた後、iFFTによって\overline{K} をO(L \log L) で求めるのである。
S4のSSMカーネルを求めるアルゴリズム
ここまでの流れをまとめると、S4のSSM畳み込みカーネル\overline{K} を求めるアルゴリズムは、以下のようになる。
S4の畳み込みカーネルアルゴリズム
Input: S4パラメータ\Lambda, P, Q, B, C \in \mathbb{C}^N とステップサイズ\Delta
Output: A = \Lambda - PQ^* とした際のSSN畳み込みカーネル \overline{K}=\mathcal{K}_L (\overline{A}, \overline{B}, \overline{C})
1. \tilde{C} \leftarrow (I - \overline{A}^L)^* \overline{C}
2. \begin{bmatrix} k_{00} (\omega) & k_{01} (\omega) \\ k_{10} (\omega) & k_{11} (\omega) \\ \end{bmatrix} \leftarrow \begin{bmatrix} \tilde{C} & Q \end{bmatrix}^* R(z) \begin{bmatrix} B & P \end{bmatrix}
3. \hat{\mathcal{K}} (z) \leftarrow \frac{2}{1+z}[k_{00} (\omega) - k_{01} (\omega) (1 + k_{11} (\omega))^{-1} k_{10} (\omega)]
4. \hat{\mathcal{K}} = \{\hat{\mathcal{K}}(\omega): \omega = \exp(2 \pi i \frac{k}{L}) \}
5. \overline{K} \leftarrow iFFT(\hat{K})
以上より、このアルゴリズムについて、以下の定理が成り立つ。
Theorem [3, p.6, 3]
任意のステップサイズ\Delta が与えられた時、SSM畳み込みカーネル\overline{K} を求める計算は、4つのコーシーカーネルを求める計算に削減され、\tilde{O}(N+L) の計算量とO(N+L) の空間量を必要とする。
このアルゴリズムによりSSMカーネル\bar{K} を求めた後、式(11)をFFTにより高速に計算することで、出力y を求めることが出来る。
Deep S4 Layer
Under Constructionn
実装とともに理解したい方へ
ここまで理解した読者であれば、Sasha Rush氏とSidd Karamcheti氏による以下の解説記事The annotated S4(文献[8])をすらすら読むことが出来ると思われる。JAXによる実装がついた非常に詳細で面白い記事となっているので、是非読んでみてください。
https://srush.github.io/annotated-s4/
S4D
Under Constructionn
参考文献
[1] Gu, A., Dao, T., Ermon, S., Rudra, A., & Ré, C. (2020). Hippo: Recurrent memory with optimal polynomial projections. Advances in neural information processing systems, 33, 1474-1487.
[2] Gu, A., Johnson, I., Goel, K., Saab, K., Dao, T., Rudra, A., & Ré, C. (2021). Combining recurrent, convolutional, and continuous-time models with linear state space layers. Advances in neural information processing systems, 34, 572-585.
[3] Gu, A., Goel, K., & Ré, C. (2021). Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396.
[4] Gu, A., Johnson, I., Timalsina, A., Rudra, A., & Ré, C. (2022). How to train your hippo: State space models with generalized orthogonal basis projections. arXiv preprint arXiv:2206.12037.
[5] 角田良太郎「HiPPO/S4解説」, https://techblog.morphoinc.com/entry/2022/05/24/102648 , 2024年2月10日閲覧
[6] Guofeng Zhang, Tongwen Chen, and Xiang Chen. Performance recovery in digital implementation of analogue systems. SIAM journal on control and optimization, 45(6):2207–2223, 2007.
[7] 竹内修「線形代数Ⅰ」, https://dora.bk.tsukuba.ac.jp/~takeuchi/?線形代数I , 2024年2月15日閲覧
[8] Sasha Rush and Sidd Karamcheti, "The Annotated S4", https://srush.github.io/annotated-s4/ , accessed on March 4th 2024.
Discussion