はじめに
この前のブログでは、mambaの論文を翻訳した。
https://izmyon.hatenablog.com/entry/2023/12/11/155551
本シリーズでは、mambaの理論的背景を理解するために、それらの先行研究を順々にまとめて解説していく。重要な先行研究としては、LMU, HiPPO, LSSL, S4, H3などがあり、主に以下の流れで解説する。
HiPPOフレームワークとLSSL
S4のアルゴリズム
H3とHyena Hierarchy
mamba
本記事は、第一段として、HiPPOフレームワークとLSSLについてまとめる。
元論文を以下に記す。
HiPPO
https://arxiv.org/abs/2008.07669
LSSL
https://arxiv.org/abs/2110.13985
この記事の続きはこちら
https://zenn.dev/izmyon/articles/c56a2fd6670546
読者の方へ
補足や訂正などがあれば、コメントにて優しく丁寧にご教示いただけると喜びます。
書くのけっこう大変だったのでバッチを送っていただけると嬉しいです。たくさんもらえると僕のやる気が上がって投稿頻度が上がるかもしれません。(逆にもらえないと下がるかもしれない。)
状態空間モデル
時刻t における1次元の入力信号u(t) を、N 次元の潜在空間x(t) に投影し、出力信号y(t) を計算する以下の状態空間モデルで表される線形システムを考える。LSSL、S4、H3、mambaなどの一連の状態空間モデルは、下記の方程式を離散化し、長さL の系列を同じく長さL の系列に変換する\mathbb{R}^L \rightarrow \mathbb{R}^L のsequence-to-sequenceモデルを用いる。
\begin{align}
\dot{x}(t) = A x(t) + B u(t) \quad
\\
y(t) = C x(t) + D u(t) \quad
\end{align}
状態空間モデルは、Hidden Markov Models (HMM)のように、制御工学や機械学習などの幅広い分野でみられる潜在空間を利用するモデルであり、行列A \in \mathbb{R}^{N \times N} 、B \in \mathbb{R}^{N \times 1} 、 C \in \mathbb{R}^{1 \times N} 、D \in \mathbb{R}^{1 \times 1} は、勾配降下によって求められる学習パラメータである。
これは連続時間のシステムであるが、コンピュータ上では離散時間信号を扱うため、式(1)、式(2)の離散化を行う。ここで、式(2)の離散化は簡単だが問題なのは式(1)の常微分方程式(ODE)を離散化することである。
ゼロ次ホールド(ZOH)と同様に、時刻t における入力信号u(t) が次のタイムステップ\Delta t まで一定である、つまりu(t+\Delta t) = u(t) であると仮定する。(1)式の右辺をt とx(t) の関数として、\dot{x}(t) = f(t, x(t)) と表す、離散化された時刻t_i において、状態x(t_0), x(t_1), \ldots は、ピカールの逐次近似法x(t_{i+1}) = x(t_i) + \int^{t_{i+1}}_{t_i} f(s, x(s))ds によって求められる。この右側積分を推定する方法には様々な方法があり、ここでは、Generalized bilinear transform (GBT)を用いる。この方法では、定数\alpha を用いて、以下のように右側積分を近似する。
\begin{align*}
\begin{split}
&\int^{t+\Delta}_{t} f(s, x(s))ds \\
&\simeq \alpha \Delta t f(t+\Delta, x(t+\Delta t)) + (1-\alpha) \Delta t f(t, x(t)) \\
&= \alpha \Delta t (A x(t+\Delta) + B u(t+\Delta)) + (1-\alpha) \Delta t (A x(t) + B u(t)) \\
&= \alpha \Delta t (A x(t+\Delta) + B u(t)) + (1-\alpha) \Delta t (A x(t) + B u(t)) \quad (\because u(t+\Delta) = u(t)) \\
&= \alpha \Delta t A x(t+ \Delta) + (1-\alpha) \Delta t A x(t) + B \Delta t u(t)
\end{split}
\end{align*}
この結果を用いると、x(t+\Delta t) = x(t) + \int^{t+\Delta t}_{t} f(s, x(s))ds より、
\begin{align*}
x(t+ \Delta t)
&= x(t) + \alpha \Delta t A x(t+ \Delta t) + (1-\alpha) \Delta t A x(t) + B \Delta t u(t) \notag \\
(I- \alpha \Delta t A) x(t+\Delta t)
&= (I + (1-\alpha) \Delta t A) x(t) + B \Delta t u(t) \notag \\
\therefore x(t+\Delta t)
&= \left(I - \alpha \Delta t A \right)^{-1} \left(I + (1- \alpha) \Delta t A\right) x(t) + \left( I - \alpha \Delta t A \right)^{-1} B \Delta t u(t)
\end{align*}
ここでは、\alpha = \frac{1}{2} として、下式を用いる。
\begin{align}
\begin{split}
x(t+\Delta t)
= \left(I - \frac{1}{2} \Delta t A \right)^{-1} \left(I + \frac{1}{2} \Delta t A\right) x(t) + \left( I - \frac{1}{2} \Delta t A \right)^{-1} B \Delta t u(t)
\end{split}
\end{align}
さらに、以下のように行列A, B を行列\bar{A}, \bar{B} に変換する。
\begin{align}
\bar{A} = \left(I - \frac{1}{2} \Delta t A \right)^{-1} \left(I + \frac{1}{2} \Delta t A\right)
\\
\bar{B} = \left( I - \frac{1}{2} \Delta t A \right)^{-1} B \Delta t
\end{align}
すると、以下のように式(1), (2)の連続時間システムを離散化することが出来る。
\begin{align}
x_t &= \bar{A} x_{t-1} + \bar{B} u_t \quad
\\
y_t &= C x_t + D u_t \quad
\end{align}
!
GBTは、\alpha = 0 とすると、オイラー法x(t+\Delta t)= x(t) + \Delta t f(t, x(t)) と一致し、\alpha = 1 とすると、後退オイラー法x(t) = x(t + \Delta t) - \Delta t f(t+\Delta t, x(t+\Delta t)) と一致する。実は、GBTはこれらの手法を定数\alpha \in [0, 1] を用いて一般化した手法である。
積分\int^{t+\Delta}_{t} f(s, x(s))ds の近似方法に着目するならば、\alpha=0 のオイラー法の場合は、下図のように、高さf(t, x(t)) の長方形で近似し、\alpha=1 の後退オイラー法の場合は、高さf(t+\Delta t, x(t+\Delta t)) の長方形で近似する。一方で、\alpha = \frac{1}{2} の場合には、左辺がf(t, x(t)) 、右辺がf(t+\Delta t, x(t+\Delta t)) の台形で近似するということである。
HiPPO (High-Order Polynomial Projection Operator; 高次多項式投影演算子)
ここまで線形システム(1)、(2)を離散化し、新たなシステムである式(6),(7)を得た。次に、本節で説明するHiPPOと呼ばれるフレームワークを適用し、以下のように係数行列A を初期化する。
A_{nk} = \left\{
\begin{array}{ll}
(2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\
n+1 & \text{if} \ n=k \\
0 & \text{if} \ n<k \\
\end{array}
\right.
なぜこのような初期化を行うのかをこれから説明化するが、長いのでLSSLの全体像を先につかみたい場合は、次の節「畳み込みによる高速化」に進んでもらって構わない。
HiPPOの定義
時間変化する(-\infty, t] でサポートされる測度族\mu^{(t)} 、N 次元の(直交)多項式が張る部分空間\mathcal{G} 、そして連続入力信号u: \mathbb{R}_{\geq 0} \rightarrow \mathbb{R} が与えられた時、u を最適化された投影係数x: \mathbb{R}_{\geq 0} \rightarrow \mathbb{R}^N に移す演算\text{hippo} を定義する。この演算は、以下のように投影演算子\text{proj}_t と、係数抽出演算子\text{coef}_t をすべての時間ステップt で求め、それらの合成\text{coef}_t \circ \text{proj}_t である。(つまり、(\text{hippo}(u))(t)=\text{coef}_t (\text{proj}_t (u)) である。)
\text{proj}_t は、時刻t までの信号u 、つまりu_{<t}:=\{u(y)\}_{y \leq t} を、推論誤差\| u_{\leq t} - g^{(t)} \|_{L_2 (\mu^{(t)})} を最小にする多項式g^{(t)} \in \mathcal{G} に射影する。
\text{coef}_t: \mathcal{G} \rightarrow \mathbb{R}^N は、多項式g^{(t)} を、測度\mu^{(t)} に関して定義された直交多項式の基底関数の係数x(t) \in \mathbb{R}^N に射影する。
HiPPOを導くためのHiPPOフレームワーク
\{P_n^{(t)}\}_{n \in \mathbb{N}} は、時間変化する測度\mu^{(t)} に関する直交多項式列であり、p_n^{(t)} を、P_n^{(t)} を正規した多項式、すなわちp_n^{(t)} = P_n^{(t)}/ \| P_n^{(t)} \|_{\mu}^2 であるとする。正規直交基底の定義より、以下が成り立つ。
\begin{align}
\int^{\infty}_{-\infty} p_n^{(t)}(y) p_m^{(t)}(y) \omega^{(t)}(y) dy = \delta_{m,n}
\end{align}
ここで、\omega^{(t)}(y) は重みであり、d \mu^{(t)} = \omega^{(t)}(y) dy, \int^{\infty}_{-\infty} d \mu^{(t)} = 1 である。
入力信号の履歴u_{\leq t}:=\{u(y)\}_{y \leq t} が、ステップサイズ\Delta t で連続関数u(y) からサンプリングされたと考える。N個の基底\{p_n^{(t)}\}_{n=0,1,\ldots, N-1} を用いて、u_{\leq t} を通る全時刻にわたる連続関数u(y) を多項式g^{(t)} で近似する。すなわち、u_{\leq t} を\{p_n^{(t)}\}_{n=0,1,\ldots, N-1} が張る空間へ射影し、多項式g^{(t)} が得られるとする。
\begin{align}
u_{\leq t} \simeq g^{(t)} = \sum_{n=0}^{N-1} x_n (t) p_n^{(t)}
\end{align}
このように近似すると、時刻t までの入力信号の履歴を、N 個の係数\{x_n^{(t)}\}_{n=0,1,\ldots, N-1} で表現することが出来るようになる。係数x_n (t) は、時刻t までの入力信号の履歴u_{\leq t} を用いて、フーリエ係数を求める要領で以下のように求めることが出来る。
\begin{align}
x_n (t) = \int^{\infty}_{-\infty} g^{(t)} p_n^{(t)} \omega^{(t)}(y) dy
&\simeq \int^{\infty}_{-\infty} u_{\leq t} p_n^{(t)} \omega^{(t)}(y) dy
\end{align}
実は、式(10)に従って、以下の入力信号u_{\leq t} の履歴からN 個の係数を計算し、出力することが\text{hippo} 演算に他ならない。
\begin{align*}
\text{hippo}(u_{\leq t}) := (x_0, x_1, \ldots, x_{N-1}) \\
\text{where} \ u_{\leq t} \simeq g^{(t)} = \sum_{n=0}^{N-1} x_n (t) p_n^{(t)}
\end{align*}
しかしながら、これらの係数を各時刻で計算するのは容易ではない。そこでHiPPOフレームワークでは、式(1)を離散化して式(6)の更新式を導いたように、x(t) が満たすODEを求めた後、それを離散化して更新式を求め、各時刻でx(t) を更新していくことを考える。
!
ということでHiPPOでは式(10)を用いてx_n(t) を求めることはしないわけだが、式(10)の右辺を計算する方法を考えてみる。
まず、u_{\leq t} のk 番目の要素がu[k], k = 0, \ldots, L-1 で表されるとし、u_{\leq t} を以下のように定義し直す。
\begin{align*}
u_{\leq t}(y) = \left\{
\begin{array}{ll}
u[k] & \text{if} \quad y = k\Delta t, \quad k = 0, \ldots, L-1\\
0 & otherwise \\
\end{array}
\right.
\end{align*}
これを用いて、
\begin{align*}
\int^{\infty}_{-\infty} u_{\leq t} p_n (t, y) \omega (t,y) dy
&= \int^{\infty}_{-\infty} ( \sum^{L-1}_{k=0} u[k]\delta(y-k \Delta t) ) p_n (t, y) \omega (t,y) dy
\\
&= \sum^{L-1}_{k=0} u[k] \int^{\infty}_{-\infty} \delta(y-k \Delta t) p_n (t, y) \omega (t,y) dy
\\
&= \sum^{L-1}_{k=0} u[k] p_n (t, k \Delta t) \omega (t, k \Delta t)
\end{align*}
従って、以下の式でx_n(t) を計算できるのではないかと個人的には考えている。
\begin{align*}
x_n (t) \simeq \sum^{L-1}_{k=0} u[k] p_n (t, k \Delta t) \omega (t, k \Delta t)
\end{align*}
ちなみにここで、n を\omega と入れ替えて、基底をp_{\omega}^{(t)} (y) = e^{-j \omega y} 、測度を\omega (t, y) = 1 とし、求める係数をx^{(t)} (\omega) としてみると、
\begin{align*}
x^{(t)} (\omega) = \sum^{L-1}_{k=0} u[k] e^{-j \omega k \Delta t}
\end{align*}
となり、これは離散フーリエ変換(DFT)と考えることが出来る。
係数が満たすODE
以上より、下式が成り立つことが分かった。
\begin{align}
x_n (t) = \int^{\infty}_{-\infty} u(y) p_n (t, y) \omega(t, y) dy
\end{align}
両辺を微分して、係数x^{(t)}_n が満たすODEを導く。ここで、微分と積分の可換性を認めた。
\begin{align*}
\dot{x}_n (t) &= \frac{\partial}{\partial t} \int^{\infty}_{-\infty} u(y) p_n (t, y) \omega(t, y) dy
\\
&= \int^{\infty}_{-\infty} \frac{\partial}{\partial t} (u(y) p_n (t, y) \omega(t, y) ) dy
\\
&= \int^{\infty}_{-\infty} u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy + \int^{\infty}_{-\infty} u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy
\end{align*}
以上より、以下の式(11)が成り立つ。
\begin{align}
\dot{x}_n (t)
= \int^{\infty}_{-\infty} u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy + \int^{\infty}_{-\infty} u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy
\end{align}
HiPPOフレームワークでは、具体的に用いる基底\{P_n^{(t)}\}_{n \in \mathbb{N}} 及び、測度\mu^{(t)} を選択し、式(12)に代入することで、\dot{x}(t) とその時点での係数x(t) 、入力信号u(t) との関係を導き、式(1)に当てはめて係数行列A とB を設定する。
HiPPO-LegS
HiPPOの論文では、様々な基底や測度を用いた場合について検証されており、LSSLではその中で最も良い結果を示したScaled Legendre Measure (LegS)という設定を用いて得られる行列A に初期化する。この設定を用いたHiPPO-LegSは、基底としてルシャンドル多項式を用い、すべての履歴に対して一様に重みをつける\omega_t = \frac{1}{t} \mathbb{1}_{[0, t]} を用いるものである。
ここではまず、ルシャンドル多項式の性質や、基底として用いた場合の\omega_t に関する直交性などを確認し、正規直交基底を求め、式(12)に代入してODEを求める。
正規直交規定および測度
ルシャンドル多項式は、以下のn 次多項式である。
\begin{align}
P_n (x) := \frac{1}{2^n n!} \frac{d^n}{{dx}^n} [ (x^2 - 1 )^n ]
\end{align}
測度\omega^{\text{leg}} = \bold{1}_{[-1, 1]} に関して、以下の直行性が成り立つ。
\begin{align}
\frac{2n+1}{2} \int^1_{-1} P_n(x) P_m(x) dx = \delta_{nm}
\end{align}
さらに、以下の性質を持つ。
\begin{align}
P_n (1) &= 1
\\
P_n (-1) &= (-1)^n
\\
(2n+1) P_n &= P'_{n+1} - P'_{n-1}
\\
P'_{n+1} &= (n+1) P_n + x P'_n
\end{align}
以上のルシャンドル多項式は区間x \in [-1, 1] で成り立つ性質であったが、区間[0, t]についても成り立つようにスケーリングするため、y = \frac{t}{2}(x+1) 、つまりx = \frac{2}{t} y - 1 と変数変換する。すると、dx = \frac{2}{t} dy であり、式(14)の左辺を変形すると、以下のようになる。
\begin{align*}
& \frac{2n+1}{2} \int^1_{-1} P_n(x) P_m(x) dx
\\
=& \frac{2n+1}{2} \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \frac{2}{t}dy
\\
=& (2n+1) \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \omega^{\text{leg}} (\frac{2}{t} y - 1 ) \frac{1}{t} dy
\end{align*}
従って、以下が成り立つ。
\begin{align}
(2n+1) \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \omega^{\text{leg}} (\frac{2}{t} y - 1 ) \frac{1}{t} dy = \delta_{nm}
\end{align}
この式より、新たな測度\omega_t = \frac{1}{t} \mathbb{1}_{[0, t]} を用い、さらに以下の正規直交規定を用いることにする。
\begin{align}
p_n (t, y) = (2n+1)^{1/2} P_n ( \frac{2}{t} y - 1 )
\end{align}
すると、以下が成り立つ。
\begin{align}
\int^t_{0} p_n(t, y) p_m(t, y) \omega_t dy = \delta_{nm}
\end{align}
ODEの導出
HiPPO-LegSが用いるODEを導出する前に、導出で用いるルシャンドル基底の微分が満たす性質について述べる。まず、式(17)から、以下が導かれる。
\begin{align}
P'_{n+1} = (2n+1) P_n + (2n-3)P_{n-2} + \ldots,
\end{align}
次数を一つずらすと、
\begin{align}
P'_{n} = (2n-1) P_{n-1} + (2n-5)P_{n-3} + \ldots,
\end{align}
これらの性質を用いて(x+1)P'_n (x) を計算すると、以下が成り立つ。
\begin{align}
(x+1) P'_{n}
=& x P'_{n} + P'_{n} = P'_{n+1} + P'_n - (n+1) P_n \notag
\\
&(\because \text{式(18)より、} xP'_n = P'_n - (n+1) P_n) \notag
\\
=& nP_n + (2n-1) P_{n-1} + (2n-3)P_{n-2} + \ldots
\\
&(\because \text{式(22), (23)を代入}) \notag
\\
\end{align}
まず、\omega_t と、p_n(t, y) の微分をそれぞれ求める。
\begin{align}
\frac{\partial}{\partial t} \omega(t) =& -t^{-2} \mathcal{1}_{[0,1]} + t^{-1} \delta_t \notag
\\
=& t^{-1} ( - \omega (t) + \delta_t)
\\
\frac{\partial}{\partial t} p_n (t, y) &= - (2n+1)^{\frac{1}{2}} 2y t^{-2} P_n' (\frac{2}{t} y - 1 ) \notag
\\
=& - (2n+1)^{\frac{1}{2}} t^{-1} ((\frac{2}{t} y - 1) + 1 ) P_n' (\frac{2}{t} y - 1 ) \notag
\\
=& - (2n+1)^{\frac{1}{2}} t^{-1} (x + 1 ) P_n' (x) \ (\because x = \frac{2}{t} y - 1) \notag
\\
=& - (2n+1)^{\frac{1}{2}} t^{-1} [nP_n (x) + (2n-1) P_{n-1}(x) + (2n-3) P_{n-2}(x) + \ldots ] \ (\because 式(24)を用いた。) \notag
\\
=& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} p_n (t, y) + (2n-1)^{\frac{1}{2}} p_{n-1} (t, y) + (2n-3)^{\frac{1}{2}} p_{n-2}(t, y) + \ldots ] \notag
\\
& (\because 式(20)より、P_n (x) = (2n+1)^{-\frac{1}{2}} p_n (t, y) )
\end{align}
これらを用いて、式(12)のODEの右辺を計算する。
まず、第一項を求める。
\begin{align}
&\int^t_0 u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy \notag
\\
\simeq & \int^t_0 g (t, y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy \notag
\\
=&\int^t_0 ( \sum_{n=0}^{N-1} x_n (t) p_n (t, y) ) \{ - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} p_n (t, y) + (2n-1)^{\frac{1}{2}} p_{n-1} (t, y) + (2n-3)^{\frac{1}{2}} p_{n-2}(t, y) + \ldots ] \} \omega (t, y) dy \notag
\\
&(\because 式(26)を用いた。) \notag
\\
=& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ]
\\
&(\because 式(21)の正規直交性を用いた。) \notag
\\
\end{align}
つぎに、第二項を求める。
\begin{align}
&\int^t_0 u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy \notag
\\
=&\int^t_0 u(y) p_n (t, y) ( t^{-1} ( - \omega (t) + \delta_t) ) dy \ (\because 式(25)を用いた。) \notag
\\
=& -t^{-1} \int^t_0 u(y) p_n (t, y) \omega (t) dy + t^{-1} \int^t_0 u(y) p_n (t, y) \delta_t dy \notag
\\
=& -t^{-1} x_n (t) + t^{-1} u(t) p_n (t, t) \notag
\\
=& -t^{-1} x_n (t) + t^{-1} (2n+1)^{1/2} u(t)
\end{align}
式(12)の右辺に式(27), (28)を代入して、以下が成り立つ。
\begin{align}
\dot{x}_n (t)
=& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ] \notag
\\
&-t^{-1} x_n (t) + t^{-1} (2n+1)^{1/2} u(t) \notag
\\
=& - (2n+1)^{\frac{1}{2}} t^{-1} [(n+1) (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ] \notag
\\
& + t^{-1} (2n+1)^{1/2} u(t)
\end{align}
式(29)をベクトル化すると、以下の式(30)が成立する。
\begin{align}
\dot{x}(t) &= - \frac{1}{t} A x(t) + \frac{1}{t} B u(t)
\\
A_{nk} &= \left\{
\begin{array}{ll}
(2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\
n+1 & \text{if} \ n=k \\
0 & \text{if} \ n<k \\
\end{array}
\right.
\\
B_n &= (2n+1)^{\frac{1}{2}}
\\
\end{align}
式(31)の行列A は、LSSL、S4、H3、mambaなどの後続研究では、HiPPO行列と呼ばれる。
!
式(21)の直交性を用いると、式(11)より、
\begin{align}
x_n (t) &= \int^{\infty}_{-\infty} u(y) p_n (t, y) \omega(t, y) dy \notag
\\
&= \int^t_0 u(y) p_n (t, y) \omega(t, y) dy
\end{align}
としてみる。
ライプニッツの積分法則
\begin{align*}
\frac{\partial}{\partial t} \int^{\beta (t)}_{\alpha (t)} f(t, y) dy
= \int^{\beta (t)}_{\alpha (t)} \frac{\partial}{\partial t} f(t, y) dy - \alpha' (t) f(\alpha(t), t) + \beta' (t) f( \beta(t), t)
\end{align*}
について考え、ここで、\beta (t) = t 、\alpha (t) = 0 とすると、
\begin{align*}
\frac{d}{dt} \int^t_0 f(t, y) dy
&= f(t, t) \frac{d}{dt} t - F(t,0)\frac{d}{dt} 0 + \int^t_0 \frac{\partial}{\partial t} f(t, y) dy
\\
&= \int^t_0 \frac{\partial}{\partial t} f(t, y) dy + f(t, t)
\end{align*}
f (t y) = u(y) p_n (t, y) \omega(t, y) として、式(33)の両辺を微分すると、
\begin{align*}
\dot{x}_n (t)
&= \frac{d}{dt} \int^t_0 u(y) p_n (t, y) \omega(t, y) dy
\\
&= \frac{d}{dt} \int^t_0 f(t, y) dy
\\
&= \int^t_0 \frac{\partial}{\partial t} f(t, y) dy + f(t, t) \quad (\because ライプニッツの積分法則)
\\
&= \int^t_0 u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy + \int^t_0 u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy
\\ & \quad + u(t) p_n(t,t) \omega (t, t)
\end{align*}
となり、式(12)と比較すると、第三項が増える。
実は、HiPPO-LegSにおいて、この第三項を計算すると、u(t) p_n(t,t) \omega (t, t) = (2n+1)^{\frac{1}{2}} t^{-1} u(t) となり、式(32)はB_n = 2(2n+1)^{\frac{1}{2}} となると思われるが、実際にHiPPOを用いた実験では式(32)でうまくいってるので、これはおそらく違うのだろうが、それがなぜかわからない。
https://github.com/HazyResearch/hippo-code/issues/11
HiPPO-LegSは時間スケールに依存しない。
式(30)のODEを計算するにあたり、式(6)を求めたように両辺を積分し、GBT(\alpha=\frac{1}{2} )を使って離散化して、x_n (t) の更新式を得る。
\begin{align}
x(t+\Delta t) - x(t) &= \int^{t+\Delta t}_t - \frac{1}{t} A x(t) + \frac{1}{t} B u(t) dt \notag
\\
&\simeq \frac{\Delta t}{2}{(-\frac{1}{t}Ac(t) + \frac{1}{t}Bu(t))+(-\frac{1}{t+\Delta t}Ac(t+\Delta t) + \frac{1}{t+\Delta t}Bu(t+\Delta t))} \notag
\end{align}
ゼロ次ホールドを仮定し、u(t) = u(t+\Delta t) であるとすると、
\begin{align}
x(t+\Delta t) - x(t) &= \frac{\Delta t}{2}(-\frac{1}{t}Ax(t) + \frac{1}{t}Bu(t)) + \frac{\Delta t}{2}(-\frac{1}{t+\Delta t}Ax(t+\Delta t) + \frac{1}{t+\Delta t}Bu(t)) \notag
\\
x(t+\Delta t) + \frac{\Delta t}{2(t+\Delta t)}Ax(t+\Delta t) &= x(t) - \frac{\Delta t}{2t} Ax(t) + (\frac{\Delta t}{2t} + \frac{\Delta t}{2(t+\Delta t)})Bu(t) \notag
\\
(I + \frac{\Delta t}{2(t+\Delta t)}A) x(t+\Delta t) &= (I - \frac{\Delta t}{2t} A)x(t) + (\frac{\Delta t}{2t} + \frac{\Delta t}{2(t+\Delta t)})Bu(t)
\end{align}
ここで、時間を離散化し、t=k\Delta t とし、x_k := x(k\Delta t) 、u_k := u(k \Delta t) とすれば、
\begin{align*}
(I + \frac{\Delta t}{2(k\Delta t+\Delta t)}A) x_{k+1} &= (I - \frac{\Delta t}{2k\Delta t} A)x_k + (\frac{\Delta t}{2k\Delta t} + \frac{\Delta t}{2(k\Delta t+\Delta t)})Bu_k
\end{align*}
従って、以下が成り立つ。
\begin{align}
(I + \frac{1}{2(k+1)}A) x_{k+1} &= (I - \frac{1}{2k} A)x_k + (\frac{1}{2k} + \frac{1}{2(k+1)})Bu_k
\end{align}
式(35)では、時間スケール\Delta t がなくなっている。すなわち、HiPPO-LegSは時間スケールに依存しない。これはHiPPO-LegSが持つ特殊な特徴である。
この式に従ってx_n (t) を更新する。入力信号の履歴u_{\leq t} は、各時刻で式(9)に従って以下のように再構成することが出来る。
\begin{align*}
u_{\leq t} \simeq g^{(t)} &= \sum_{n=0}^{N-1} x_n (t) p_n (t, y)
\\
&= \sum_{n=0}^{N-1} x_n (t) (2n+1)^{\frac{1}{2}} P_n (\frac{2y}{t} - 1)
\end{align*}
LSSL(Linear State-Space Layers)
ここまで、HiPPOおよびHiPPO-LegSについて説明してきた。HiPPO-LegSでは、式(30)のODEが導かれ、式(35)のように離散化された。LSSLやその発展であるS4、H3、mambaなどでは、式(1),(2)の線形システムを考え、式(6)、(7)の形に離散化するが、行列A はHiPPO-LegSから導かれ、HiPPO行列と呼ばれる式(31)と同じになるように初期化する。このように初期化することで、x(t) に入力信号u_{\leq t} の履歴を記憶することを可能とし、MNISTベンチマークでの性能を60%から98%に上昇させることが出来る。ここでは、これらの式(6), (7)から導かれるLSSLの性質について見ていく。
畳み込みによる高速化
式(6),(7)を見た時に思ったかもしれないが、式(6)を式(7)に代入することで、以下のようにx を用いずにu_{\leq t} のみで出力y_k を計算することが出来る。ただし、x_{-1} = 0 とする。
\begin{align*}
y_k =& C x_k + D u_k
\\
=& C (\bar{A} x_{k-1} + \bar{B} u_k) + D u_k
\\
=& C \bar{A} (\bar{A} x_{k-2} + \bar{B} u_{k-1}) + C \bar{B} u_k + D u_k
\\
=& C (\bar{A})^2 (\bar{A} x_{k-3} + \bar{B} u_{k-2}) + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k
\\
=& C (\bar{A})^3 (\bar{A} x_{k-4} + \bar{B} u_{k-3}) + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k
\\
\vdots
\\
=& C (\bar{A})^k x_0 + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k
\\
=& C (\bar{A})^k (\bar{A} x_{-1} + \bar{B} u_0) + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k
\\
=& C (\bar{A})^k \bar{B} u_0 + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k
\\
&(\because x_{-1} = 0)
\end{align*}
従って、\mathcal{K}_k (\bar{A}, \bar{B}, C) = (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B}) として、以下が成り立つ。
\begin{align}
y_k &= (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B}) * u_{\leq k} + D u_k \notag
\\
&= \mathcal{K}_k (\bar{A}, \bar{B}, C) * u_{\leq k} + D u_k
\end{align}
ここで、以下のように置いた。
\begin{align}
\mathcal{K}_k (\bar{A}, \bar{B}, C)
&= (C \bar{A}^i B)_{i = 0, \ldots, k} \notag
\\
&= (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B})
\\
u_{\leq k} &= (u_0, u_1, \ldots, u_k)
\end{align}
また、系列長L の入力信号u_{\leq L-1} = (u_0, u_1, \ldots, u_{L-1}) が与えられたとき、出力y_{L-1} は、以下のように表される。
\begin{align}
y_{L-1} = \mathcal{K}_{L-1} (\bar{A}, \bar{B}, C) * u_{\leq L-1} + D u_{L-1}
\end{align}
このように畳み込み形式で計算することで、全時刻にわたるy \in \mathbb{R}^{H \times L} を三度のFFTで一度に求めることが出来る。
ただし、LSSLの計算にはボトルネックがある。式(6)の再帰形式では、離散化された状態行列A 、つまり\bar{A} の、行列-ベクトル乗算(matrix-vector multiplication; MVM)、そして式(35)の畳み込み形式では、行列\bar{A} のべき乗を多く含むKrylov function \mathcal{K}_L が、それぞれボトルネックとなる。これらをうまく解決し、効率よく計算するアルゴリズムがLSSLでも提案されているが、S4でより効率的な手法が提案されているため、ここでは紹介しない。(S4については、次の記事で説明する。)
LSSLの表現力
ここまで、LSSLが式(6)、(7)の再帰形式、式(39)の畳み込み形式の二つの形式を持つことを見てきた。LSSLでは線形計算のみを行い、RNNやTransformerなどで見られるような非線形関数は用いていない。それでもLSSLは高い表現力を持つことを示す。
畳み込みはLSSLである。
Under Construction
RNNはLSSLである。
Under Construction
Lemma [3, p.6, 3.1]
一次元ゲート付き再帰\bold{x_t = (1 - \sigma(z)) x_{t-1} + \sigma(z) u_t} (ただし、\bold{\sigma} はシグモイド関数で\bold{z} は任意の表現)は、1次元の線形ODE\bold{\dot{x}(t) = -x(t) + u(t)} を\bold{\alpha=1} のGBT(つまり、後退オイラー法)で離散化したものとみなすことが出来る。
証明:
\dot{x}(t) = -x(t) + u(t) を後退オイラー法で離散化すると、以下が成り立つ。
\begin{align*}
x_{t-1} = x_t - \Delta t (- x_t + u_t)
\end{align*}
ここで、\Delta t = e^z とすると、
\begin{align*}
x_{t-1} = x_t - e^z (- x_t + u_t) = (1 + e^z)x_t - e^z u_t
\end{align*}
従って、以下が成り立つ。
\begin{align}
x_t &= \frac{1}{1 + e^z} x_{t-1} + \frac{e^z}{1 + e^z} u_t \notag
\\
&= \frac{e^{-z}}{1 + e^{-z}} x_{t-1} + \frac{1}{1 + e^{-z}} u_t \notag
\\
&= (1 - \sigma(z)) x_{t-1} + \sigma(z) u_t \quad (\because \sigma(z) = \frac{1}{1 + e^{-z}})
\end{align}
(証明終)
Lemma [3, p.6, 3.2]
Under Construction
Deep LSSLs
Under Construction
参考文献
[1] Voelker, A., Kajić, I., & Eliasmith, C. (2019). Legendre memory units: Continuous-time representation in recurrent neural networks. Advances in neural information processing systems, 32.
[2] Gu, A., Dao, T., Ermon, S., Rudra, A., & Ré, C. (2020). Hippo: Recurrent memory with optimal polynomial projections. Advances in neural information processing systems, 33, 1474-1487.
[3] Gu, A., Johnson, I., Goel, K., Saab, K., Dao, T., Rudra, A., & Ré, C. (2021). Combining recurrent, convolutional, and continuous-time models with linear state space layers. Advances in neural information processing systems, 34, 572-585.
[4] Gu, A., Goel, K., & Ré, C. (2021). Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396.
[5] Gu, A., Johnson, I., Timalsina, A., Rudra, A., & Ré, C. (2022). How to train your hippo: State space models with generalized orthogonal basis projections. arXiv preprint arXiv:2206.12037.
[6] 角田良太郎「HiPPO/S4解説」, https://techblog.morphoinc.com/entry/2022/05/24/102648 , 2024年2月10日閲覧
[7] Guofeng Zhang, Tongwen Chen, and Xiang Chen. Performance recovery in digital implementation of analogue systems. SIAM journal on control and optimization, 45(6):2207–2223, 2007.
[8] 岡島寛「状態フィードバック制御・状態方程式に基づく制御のまとめ」、制御工学の教科書、2024年2月17日公開、2024年2月21日閲覧
Discussion