😎

mambaの理論を理解する①:HiPPOフレームワークとLSSL

2024/02/10に公開

はじめに

この前のブログでは、mambaの論文を翻訳した。

https://izmyon.hatenablog.com/entry/2023/12/11/155551

本シリーズでは、mambaの理論的背景を理解するために、それらの先行研究を順々にまとめて解説していく。重要な先行研究としては、LMU, HiPPO, LSSL, S4, H3などがあり、主に以下の流れで解説する。

  1. HiPPOフレームワークとLSSL
  2. S4のアルゴリズム
  3. H3とHyena Hierarchy
  4. mamba

本記事は、第一段として、HiPPOフレームワークとLSSLについてまとめる。

元論文を以下に記す。

HiPPO

https://arxiv.org/abs/2008.07669

LSSL

https://arxiv.org/abs/2110.13985

この記事の続きはこちら

https://zenn.dev/izmyon/articles/c56a2fd6670546

読者の方へ

  • 補足や訂正などがあれば、コメントにて優しく丁寧にご教示いただけると喜びます。
  • 書くのけっこう大変だったのでバッチを送っていただけると嬉しいです。たくさんもらえると僕のやる気が上がって投稿頻度が上がるかもしれません。(逆にもらえないと下がるかもしれない。)

状態空間モデル

時刻tにおける1次元の入力信号u(t)を、N次元の潜在空間x(t)に投影し、出力信号y(t)を計算する以下の状態空間モデルで表される線形システムを考える。LSSL、S4、H3、mambaなどの一連の状態空間モデルは、下記の方程式を離散化し、長さLの系列を同じく長さLの系列に変換する\mathbb{R}^L \rightarrow \mathbb{R}^Lのsequence-to-sequenceモデルを用いる。

\begin{align} \dot{x}(t) = A x(t) + B u(t) \quad \\ y(t) = C x(t) + D u(t) \quad \end{align}

状態空間モデルは、Hidden Markov Models (HMM)のように、制御工学や機械学習などの幅広い分野でみられる潜在空間を利用するモデルであり、行列A \in \mathbb{R}^{N \times N}B \in \mathbb{R}^{N \times 1}C \in \mathbb{R}^{1 \times N}D \in \mathbb{R}^{1 \times 1}は、勾配降下によって求められる学習パラメータである。

これは連続時間のシステムであるが、コンピュータ上では離散時間信号を扱うため、式(1)、式(2)の離散化を行う。ここで、式(2)の離散化は簡単だが問題なのは式(1)の常微分方程式(ODE)を離散化することである。

ゼロ次ホールド(ZOH)と同様に、時刻tにおける入力信号u(t)が次のタイムステップ\Delta tまで一定である、つまりu(t+\Delta t) = u(t)であると仮定する。(1)式の右辺をtx(t)の関数として、\dot{x}(t) = f(t, x(t))と表す、離散化された時刻t_iにおいて、状態x(t_0), x(t_1), \ldotsは、ピカールの逐次近似法x(t_{i+1}) = x(t_i) + \int^{t_{i+1}}_{t_i} f(s, x(s))dsによって求められる。この右側積分を推定する方法には様々な方法があり、ここでは、Generalized bilinear transform (GBT)を用いる。この方法では、定数\alphaを用いて、以下のように右側積分を近似する。

\begin{align*} \begin{split} &\int^{t+\Delta}_{t} f(s, x(s))ds \\ &\simeq \alpha \Delta t f(t+\Delta, x(t+\Delta t)) + (1-\alpha) \Delta t f(t, x(t)) \\ &= \alpha \Delta t (A x(t+\Delta) + B u(t+\Delta)) + (1-\alpha) \Delta t (A x(t) + B u(t)) \\ &= \alpha \Delta t (A x(t+\Delta) + B u(t)) + (1-\alpha) \Delta t (A x(t) + B u(t)) \quad (\because u(t+\Delta) = u(t)) \\ &= \alpha \Delta t A x(t+ \Delta) + (1-\alpha) \Delta t A x(t) + B \Delta t u(t) \end{split} \end{align*}

この結果を用いると、x(t+\Delta t) = x(t) + \int^{t+\Delta t}_{t} f(s, x(s))dsより、

\begin{align*} x(t+ \Delta t) &= x(t) + \alpha \Delta t A x(t+ \Delta t) + (1-\alpha) \Delta t A x(t) + B \Delta t u(t) \notag \\ (I- \alpha \Delta t A) x(t+\Delta t) &= (I + (1-\alpha) \Delta t A) x(t) + B \Delta t u(t) \notag \\ \therefore x(t+\Delta t) &= \left(I - \alpha \Delta t A \right)^{-1} \left(I + (1- \alpha) \Delta t A\right) x(t) + \left( I - \alpha \Delta t A \right)^{-1} B \Delta t u(t) \end{align*}

ここでは、\alpha = \frac{1}{2}として、下式を用いる。

\begin{align} \begin{split} x(t+\Delta t) = \left(I - \frac{1}{2} \Delta t A \right)^{-1} \left(I + \frac{1}{2} \Delta t A\right) x(t) + \left( I - \frac{1}{2} \Delta t A \right)^{-1} B \Delta t u(t) \end{split} \end{align}

さらに、以下のように行列A, Bを行列\bar{A}, \bar{B}に変換する。

\begin{align} \bar{A} = \left(I - \frac{1}{2} \Delta t A \right)^{-1} \left(I + \frac{1}{2} \Delta t A\right) \\ \bar{B} = \left( I - \frac{1}{2} \Delta t A \right)^{-1} B \Delta t \end{align}

すると、以下のように式(1), (2)の連続時間システムを離散化することが出来る。

\begin{align} x_t &= \bar{A} x_{t-1} + \bar{B} u_t \quad \\ y_t &= C x_t + D u_t \quad \end{align}

HiPPO (High-Order Polynomial Projection Operator; 高次多項式投影演算子)

ここまで線形システム(1)、(2)を離散化し、新たなシステムである式(6),(7)を得た。次に、本節で説明するHiPPOと呼ばれるフレームワークを適用し、以下のように係数行列Aを初期化する。

A_{nk} = \left\{ \begin{array}{ll} (2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\ n+1 & \text{if} \ n=k \\ 0 & \text{if} \ n<k \\ \end{array} \right.

なぜこのような初期化を行うのかをこれから説明化するが、長いのでLSSLの全体像を先につかみたい場合は、次の節「畳み込みによる高速化」に進んでもらって構わない。

HiPPOの定義

時間変化する(-\infty, t]でサポートされる測度族\mu^{(t)}N次元の(直交)多項式が張る部分空間\mathcal{G}、そして連続入力信号u: \mathbb{R}_{\geq 0} \rightarrow \mathbb{R}が与えられた時、uを最適化された投影係数x: \mathbb{R}_{\geq 0} \rightarrow \mathbb{R}^Nに移す演算\text{hippo}を定義する。この演算は、以下のように投影演算子\text{proj}_tと、係数抽出演算子\text{coef}_tをすべての時間ステップtで求め、それらの合成\text{coef}_t \circ \text{proj}_tである。(つまり、(\text{hippo}(u))(t)=\text{coef}_t (\text{proj}_t (u))である。)

  1. \text{proj}_tは、時刻tまでの信号u、つまりu_{<t}:=\{u(y)\}_{y \leq t}を、推論誤差\| u_{\leq t} - g^{(t)} \|_{L_2 (\mu^{(t)})}を最小にする多項式g^{(t)} \in \mathcal{G}に射影する。

  2. \text{coef}_t: \mathcal{G} \rightarrow \mathbb{R}^Nは、多項式g^{(t)}を、測度\mu^{(t)}に関して定義された直交多項式の基底関数の係数x(t) \in \mathbb{R}^Nに射影する。

HiPPOを導くためのHiPPOフレームワーク

\{P_n^{(t)}\}_{n \in \mathbb{N}}は、時間変化する測度\mu^{(t)}に関する直交多項式列であり、p_n^{(t)}を、P_n^{(t)}を正規した多項式、すなわちp_n^{(t)} = P_n^{(t)}/ \| P_n^{(t)} \|_{\mu}^2であるとする。正規直交基底の定義より、以下が成り立つ。

\begin{align} \int^{\infty}_{-\infty} p_n^{(t)}(y) p_m^{(t)}(y) \omega^{(t)}(y) dy = \delta_{m,n} \end{align}

ここで、\omega^{(t)}(y)は重みであり、d \mu^{(t)} = \omega^{(t)}(y) dy, \int^{\infty}_{-\infty} d \mu^{(t)} = 1である。

入力信号の履歴u_{\leq t}:=\{u(y)\}_{y \leq t}が、ステップサイズ\Delta tで連続関数u(y)からサンプリングされたと考える。N個の基底\{p_n^{(t)}\}_{n=0,1,\ldots, N-1}を用いて、u_{\leq t}を通る全時刻にわたる連続関数u(y)を多項式g^{(t)}で近似する。すなわち、u_{\leq t}\{p_n^{(t)}\}_{n=0,1,\ldots, N-1}が張る空間へ射影し、多項式g^{(t)}が得られるとする。

\begin{align} u_{\leq t} \simeq g^{(t)} = \sum_{n=0}^{N-1} x_n (t) p_n^{(t)} \end{align}

このように近似すると、時刻tまでの入力信号の履歴を、N個の係数\{x_n^{(t)}\}_{n=0,1,\ldots, N-1}で表現することが出来るようになる。係数x_n (t)は、時刻tまでの入力信号の履歴u_{\leq t}を用いて、フーリエ係数を求める要領で以下のように求めることが出来る。

\begin{align} x_n (t) = \int^{\infty}_{-\infty} g^{(t)} p_n^{(t)} \omega^{(t)}(y) dy &\simeq \int^{\infty}_{-\infty} u_{\leq t} p_n^{(t)} \omega^{(t)}(y) dy \end{align}

実は、式(10)に従って、以下の入力信号u_{\leq t}の履歴からN個の係数を計算し、出力することが\text{hippo}演算に他ならない。

\begin{align*} \text{hippo}(u_{\leq t}) := (x_0, x_1, \ldots, x_{N-1}) \\ \text{where} \ u_{\leq t} \simeq g^{(t)} = \sum_{n=0}^{N-1} x_n (t) p_n^{(t)} \end{align*}

しかしながら、これらの係数を各時刻で計算するのは容易ではない。そこでHiPPOフレームワークでは、式(1)を離散化して式(6)の更新式を導いたように、x(t)が満たすODEを求めた後、それを離散化して更新式を求め、各時刻でx(t)を更新していくことを考える。

係数が満たすODE

以上より、下式が成り立つことが分かった。

\begin{align} x_n (t) = \int^{\infty}_{-\infty} u(y) p_n (t, y) \omega(t, y) dy \end{align}

両辺を微分して、係数x^{(t)}_nが満たすODEを導く。ここで、微分と積分の可換性を認めた。

\begin{align*} \dot{x}_n (t) &= \frac{\partial}{\partial t} \int^{\infty}_{-\infty} u(y) p_n (t, y) \omega(t, y) dy \\ &= \int^{\infty}_{-\infty} \frac{\partial}{\partial t} (u(y) p_n (t, y) \omega(t, y) ) dy \\ &= \int^{\infty}_{-\infty} u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy + \int^{\infty}_{-\infty} u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy \end{align*}

以上より、以下の式(11)が成り立つ。

\begin{align} \dot{x}_n (t) = \int^{\infty}_{-\infty} u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy + \int^{\infty}_{-\infty} u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy \end{align}

HiPPOフレームワークでは、具体的に用いる基底\{P_n^{(t)}\}_{n \in \mathbb{N}}及び、測度\mu^{(t)}を選択し、式(12)に代入することで、\dot{x}(t)とその時点での係数x(t)、入力信号u(t)との関係を導き、式(1)に当てはめて係数行列ABを設定する。

HiPPO-LegS

HiPPOの論文では、様々な基底や測度を用いた場合について検証されており、LSSLではその中で最も良い結果を示したScaled Legendre Measure (LegS)という設定を用いて得られる行列Aに初期化する。この設定を用いたHiPPO-LegSは、基底としてルシャンドル多項式を用い、すべての履歴に対して一様に重みをつける\omega_t = \frac{1}{t} \mathbb{1}_{[0, t]}を用いるものである。

ここではまず、ルシャンドル多項式の性質や、基底として用いた場合の\omega_tに関する直交性などを確認し、正規直交基底を求め、式(12)に代入してODEを求める。

正規直交規定および測度

ルシャンドル多項式は、以下のn次多項式である。

\begin{align} P_n (x) := \frac{1}{2^n n!} \frac{d^n}{{dx}^n} [ (x^2 - 1 )^n ] \end{align}

測度\omega^{\text{leg}} = \bold{1}_{[-1, 1]}に関して、以下の直行性が成り立つ。

\begin{align} \frac{2n+1}{2} \int^1_{-1} P_n(x) P_m(x) dx = \delta_{nm} \end{align}

さらに、以下の性質を持つ。

\begin{align} P_n (1) &= 1 \\ P_n (-1) &= (-1)^n \\ (2n+1) P_n &= P'_{n+1} - P'_{n-1} \\ P'_{n+1} &= (n+1) P_n + x P'_n \end{align}

以上のルシャンドル多項式は区間x \in [-1, 1]で成り立つ性質であったが、区間[0, t]についても成り立つようにスケーリングするため、y = \frac{t}{2}(x+1)、つまりx = \frac{2}{t} y - 1と変数変換する。すると、dx = \frac{2}{t} dyであり、式(14)の左辺を変形すると、以下のようになる。

\begin{align*} & \frac{2n+1}{2} \int^1_{-1} P_n(x) P_m(x) dx \\ =& \frac{2n+1}{2} \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \frac{2}{t}dy \\ =& (2n+1) \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \omega^{\text{leg}} (\frac{2}{t} y - 1 ) \frac{1}{t} dy \end{align*}

従って、以下が成り立つ。

\begin{align} (2n+1) \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \omega^{\text{leg}} (\frac{2}{t} y - 1 ) \frac{1}{t} dy = \delta_{nm} \end{align}

この式より、新たな測度\omega_t = \frac{1}{t} \mathbb{1}_{[0, t]}を用い、さらに以下の正規直交規定を用いることにする。

\begin{align} p_n (t, y) = (2n+1)^{1/2} P_n ( \frac{2}{t} y - 1 ) \end{align}

すると、以下が成り立つ。

\begin{align} \int^t_{0} p_n(t, y) p_m(t, y) \omega_t dy = \delta_{nm} \end{align}

ODEの導出

HiPPO-LegSが用いるODEを導出する前に、導出で用いるルシャンドル基底の微分が満たす性質について述べる。まず、式(17)から、以下が導かれる。

\begin{align} P'_{n+1} = (2n+1) P_n + (2n-3)P_{n-2} + \ldots, \end{align}

次数を一つずらすと、

\begin{align} P'_{n} = (2n-1) P_{n-1} + (2n-5)P_{n-3} + \ldots, \end{align}

これらの性質を用いて(x+1)P'_n (x)を計算すると、以下が成り立つ。

\begin{align} (x+1) P'_{n} =& x P'_{n} + P'_{n} = P'_{n+1} + P'_n - (n+1) P_n \notag \\ &(\because \text{式(18)より、} xP'_n = P'_n - (n+1) P_n) \notag \\ =& nP_n + (2n-1) P_{n-1} + (2n-3)P_{n-2} + \ldots \\ &(\because \text{式(22), (23)を代入}) \notag \\ \end{align}

まず、\omega_tと、p_n(t, y)の微分をそれぞれ求める。

\begin{align} \frac{\partial}{\partial t} \omega(t) =& -t^{-2} \mathcal{1}_{[0,1]} + t^{-1} \delta_t \notag \\ =& t^{-1} ( - \omega (t) + \delta_t) \\ \frac{\partial}{\partial t} p_n (t, y) &= - (2n+1)^{\frac{1}{2}} 2y t^{-2} P_n' (\frac{2}{t} y - 1 ) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} ((\frac{2}{t} y - 1) + 1 ) P_n' (\frac{2}{t} y - 1 ) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} (x + 1 ) P_n' (x) \ (\because x = \frac{2}{t} y - 1) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} [nP_n (x) + (2n-1) P_{n-1}(x) + (2n-3) P_{n-2}(x) + \ldots ] \ (\because 式(24)を用いた。) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} p_n (t, y) + (2n-1)^{\frac{1}{2}} p_{n-1} (t, y) + (2n-3)^{\frac{1}{2}} p_{n-2}(t, y) + \ldots ] \notag \\ & (\because 式(20)より、P_n (x) = (2n+1)^{-\frac{1}{2}} p_n (t, y) ) \end{align}

これらを用いて、式(12)のODEの右辺を計算する。

まず、第一項を求める。

\begin{align} &\int^t_0 u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy \notag \\ \simeq & \int^t_0 g (t, y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy \notag \\ =&\int^t_0 ( \sum_{n=0}^{N-1} x_n (t) p_n (t, y) ) \{ - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} p_n (t, y) + (2n-1)^{\frac{1}{2}} p_{n-1} (t, y) + (2n-3)^{\frac{1}{2}} p_{n-2}(t, y) + \ldots ] \} \omega (t, y) dy \notag \\ &(\because 式(26)を用いた。) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ] \\ &(\because 式(21)の正規直交性を用いた。) \notag \\ \end{align}

つぎに、第二項を求める。

\begin{align} &\int^t_0 u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy \notag \\ =&\int^t_0 u(y) p_n (t, y) ( t^{-1} ( - \omega (t) + \delta_t) ) dy \ (\because 式(25)を用いた。) \notag \\ =& -t^{-1} \int^t_0 u(y) p_n (t, y) \omega (t) dy + t^{-1} \int^t_0 u(y) p_n (t, y) \delta_t dy \notag \\ =& -t^{-1} x_n (t) + t^{-1} u(t) p_n (t, t) \notag \\ =& -t^{-1} x_n (t) + t^{-1} (2n+1)^{1/2} u(t) \end{align}

式(12)の右辺に式(27), (28)を代入して、以下が成り立つ。

\begin{align} \dot{x}_n (t) =& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ] \notag \\ &-t^{-1} x_n (t) + t^{-1} (2n+1)^{1/2} u(t) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} [(n+1) (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ] \notag \\ & + t^{-1} (2n+1)^{1/2} u(t) \end{align}

式(29)をベクトル化すると、以下の式(30)が成立する。

\begin{align} \dot{x}(t) &= - \frac{1}{t} A x(t) + \frac{1}{t} B u(t) \\ A_{nk} &= \left\{ \begin{array}{ll} (2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\ n+1 & \text{if} \ n=k \\ 0 & \text{if} \ n<k \\ \end{array} \right. \\ B_n &= (2n+1)^{\frac{1}{2}} \\ \end{align}

式(31)の行列Aは、LSSL、S4、H3、mambaなどの後続研究では、HiPPO行列と呼ばれる。

HiPPO-LegSは時間スケールに依存しない。

式(30)のODEを計算するにあたり、式(6)を求めたように両辺を積分し、GBT(\alpha=\frac{1}{2})を使って離散化して、x_n (t)の更新式を得る。

\begin{align} x(t+\Delta t) - x(t) &= \int^{t+\Delta t}_t - \frac{1}{t} A x(t) + \frac{1}{t} B u(t) dt \notag \\ &\simeq \frac{\Delta t}{2}{(-\frac{1}{t}Ac(t) + \frac{1}{t}Bu(t))+(-\frac{1}{t+\Delta t}Ac(t+\Delta t) + \frac{1}{t+\Delta t}Bu(t+\Delta t))} \notag \end{align}

ゼロ次ホールドを仮定し、u(t) = u(t+\Delta t)であるとすると、

\begin{align} x(t+\Delta t) - x(t) &= \frac{\Delta t}{2}(-\frac{1}{t}Ax(t) + \frac{1}{t}Bu(t)) + \frac{\Delta t}{2}(-\frac{1}{t+\Delta t}Ax(t+\Delta t) + \frac{1}{t+\Delta t}Bu(t)) \notag \\ x(t+\Delta t) + \frac{\Delta t}{2(t+\Delta t)}Ax(t+\Delta t) &= x(t) - \frac{\Delta t}{2t} Ax(t) + (\frac{\Delta t}{2t} + \frac{\Delta t}{2(t+\Delta t)})Bu(t) \notag \\ (I + \frac{\Delta t}{2(t+\Delta t)}A) x(t+\Delta t) &= (I - \frac{\Delta t}{2t} A)x(t) + (\frac{\Delta t}{2t} + \frac{\Delta t}{2(t+\Delta t)})Bu(t) \end{align}

ここで、時間を離散化し、t=k\Delta tとし、x_k := x(k\Delta t)u_k := u(k \Delta t)とすれば、

\begin{align*} (I + \frac{\Delta t}{2(k\Delta t+\Delta t)}A) x_{k+1} &= (I - \frac{\Delta t}{2k\Delta t} A)x_k + (\frac{\Delta t}{2k\Delta t} + \frac{\Delta t}{2(k\Delta t+\Delta t)})Bu_k \end{align*}

従って、以下が成り立つ。

\begin{align} (I + \frac{1}{2(k+1)}A) x_{k+1} &= (I - \frac{1}{2k} A)x_k + (\frac{1}{2k} + \frac{1}{2(k+1)})Bu_k \end{align}

式(35)では、時間スケール\Delta tがなくなっている。すなわち、HiPPO-LegSは時間スケールに依存しない。これはHiPPO-LegSが持つ特殊な特徴である。

この式に従ってx_n (t)を更新する。入力信号の履歴u_{\leq t}は、各時刻で式(9)に従って以下のように再構成することが出来る。

\begin{align*} u_{\leq t} \simeq g^{(t)} &= \sum_{n=0}^{N-1} x_n (t) p_n (t, y) \\ &= \sum_{n=0}^{N-1} x_n (t) (2n+1)^{\frac{1}{2}} P_n (\frac{2y}{t} - 1) \end{align*}

LSSL(Linear State-Space Layers)

ここまで、HiPPOおよびHiPPO-LegSについて説明してきた。HiPPO-LegSでは、式(30)のODEが導かれ、式(35)のように離散化された。LSSLやその発展であるS4、H3、mambaなどでは、式(1),(2)の線形システムを考え、式(6)、(7)の形に離散化するが、行列AはHiPPO-LegSから導かれ、HiPPO行列と呼ばれる式(31)と同じになるように初期化する。このように初期化することで、x(t)に入力信号u_{\leq t}の履歴を記憶することを可能とし、MNISTベンチマークでの性能を60%から98%に上昇させることが出来る。ここでは、これらの式(6), (7)から導かれるLSSLの性質について見ていく。

畳み込みによる高速化

式(6),(7)を見た時に思ったかもしれないが、式(6)を式(7)に代入することで、以下のようにxを用いずにu_{\leq t}のみで出力y_kを計算することが出来る。ただし、x_{-1} = 0とする。

\begin{align*} y_k =& C x_k + D u_k \\ =& C (\bar{A} x_{k-1} + \bar{B} u_k) + D u_k \\ =& C \bar{A} (\bar{A} x_{k-2} + \bar{B} u_{k-1}) + C \bar{B} u_k + D u_k \\ =& C (\bar{A})^2 (\bar{A} x_{k-3} + \bar{B} u_{k-2}) + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ =& C (\bar{A})^3 (\bar{A} x_{k-4} + \bar{B} u_{k-3}) + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ \vdots \\ =& C (\bar{A})^k x_0 + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ =& C (\bar{A})^k (\bar{A} x_{-1} + \bar{B} u_0) + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ =& C (\bar{A})^k \bar{B} u_0 + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ &(\because x_{-1} = 0) \end{align*}

従って、\mathcal{K}_k (\bar{A}, \bar{B}, C) = (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B})として、以下が成り立つ。

\begin{align} y_k &= (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B}) * u_{\leq k} + D u_k \notag \\ &= \mathcal{K}_k (\bar{A}, \bar{B}, C) * u_{\leq k} + D u_k \end{align}

ここで、以下のように置いた。

\begin{align} \mathcal{K}_k (\bar{A}, \bar{B}, C) &= (C \bar{A}^i B)_{i = 0, \ldots, k} \notag \\ &= (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B}) \\ u_{\leq k} &= (u_0, u_1, \ldots, u_k) \end{align}

また、系列長Lの入力信号u_{\leq L-1} = (u_0, u_1, \ldots, u_{L-1})が与えられたとき、出力y_{L-1}は、以下のように表される。

\begin{align} y_{L-1} = \mathcal{K}_{L-1} (\bar{A}, \bar{B}, C) * u_{\leq L-1} + D u_{L-1} \end{align}

このように畳み込み形式で計算することで、全時刻にわたるy \in \mathbb{R}^{H \times L}を三度のFFTで一度に求めることが出来る。

ただし、LSSLの計算にはボトルネックがある。式(6)の再帰形式では、離散化された状態行列A、つまり\bar{A}の、行列-ベクトル乗算(matrix-vector multiplication; MVM)、そして式(35)の畳み込み形式では、行列\bar{A}のべき乗を多く含むKrylov function \mathcal{K}_Lが、それぞれボトルネックとなる。これらをうまく解決し、効率よく計算するアルゴリズムがLSSLでも提案されているが、S4でより効率的な手法が提案されているため、ここでは紹介しない。(S4については、次の記事で説明する。)

LSSLの表現力

ここまで、LSSLが式(6)、(7)の再帰形式、式(39)の畳み込み形式の二つの形式を持つことを見てきた。LSSLでは線形計算のみを行い、RNNやTransformerなどで見られるような非線形関数は用いていない。それでもLSSLは高い表現力を持つことを示す。

畳み込みはLSSLである。

Under Construction

RNNはLSSLである。

Under Construction

Lemma [3, p.6, 3.1]

一次元ゲート付き再帰\bold{x_t = (1 - \sigma(z)) x_{t-1} + \sigma(z) u_t}(ただし、\bold{\sigma}はシグモイド関数で\bold{z}は任意の表現)は、1次元の線形ODE\bold{\dot{x}(t) = -x(t) + u(t)}\bold{\alpha=1}のGBT(つまり、後退オイラー法)で離散化したものとみなすことが出来る。

証明:
\dot{x}(t) = -x(t) + u(t)を後退オイラー法で離散化すると、以下が成り立つ。

\begin{align*} x_{t-1} = x_t - \Delta t (- x_t + u_t) \end{align*}

ここで、\Delta t = e^zとすると、

\begin{align*} x_{t-1} = x_t - e^z (- x_t + u_t) = (1 + e^z)x_t - e^z u_t \end{align*}

従って、以下が成り立つ。

\begin{align} x_t &= \frac{1}{1 + e^z} x_{t-1} + \frac{e^z}{1 + e^z} u_t \notag \\ &= \frac{e^{-z}}{1 + e^{-z}} x_{t-1} + \frac{1}{1 + e^{-z}} u_t \notag \\ &= (1 - \sigma(z)) x_{t-1} + \sigma(z) u_t \quad (\because \sigma(z) = \frac{1}{1 + e^{-z}}) \end{align} (証明終)

Lemma [3, p.6, 3.2]

Under Construction

Deep LSSLs

Under Construction

参考文献

[1] Voelker, A., Kajić, I., & Eliasmith, C. (2019). Legendre memory units: Continuous-time representation in recurrent neural networks. Advances in neural information processing systems, 32.
[2] Gu, A., Dao, T., Ermon, S., Rudra, A., & Ré, C. (2020). Hippo: Recurrent memory with optimal polynomial projections. Advances in neural information processing systems, 33, 1474-1487.
[3] Gu, A., Johnson, I., Goel, K., Saab, K., Dao, T., Rudra, A., & Ré, C. (2021). Combining recurrent, convolutional, and continuous-time models with linear state space layers. Advances in neural information processing systems, 34, 572-585.
[4] Gu, A., Goel, K., & Ré, C. (2021). Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396.
[5] Gu, A., Johnson, I., Timalsina, A., Rudra, A., & Ré, C. (2022). How to train your hippo: State space models with generalized orthogonal basis projections. arXiv preprint arXiv:2206.12037.
[6] 角田良太郎「HiPPO/S4解説」, https://techblog.morphoinc.com/entry/2022/05/24/102648, 2024年2月10日閲覧
[7] Guofeng Zhang, Tongwen Chen, and Xiang Chen. Performance recovery in digital implementation of analogue systems. SIAM journal on control and optimization, 45(6):2207–2223, 2007.
[8] 岡島寛「状態フィードバック制御・状態方程式に基づく制御のまとめ」、制御工学の教科書、2024年2月17日公開、2024年2月21日閲覧

Discussion