Zenn
😎

mambaの理論を理解する①:HiPPOフレームワークとLSSL

2024/02/10に公開

はじめに

この前のブログでは、mambaの論文を翻訳した。

https://izmyon.hatenablog.com/entry/2023/12/11/155551

本シリーズでは、mambaの理論的背景を理解するために、それらの先行研究を順々にまとめて解説していく。重要な先行研究としては、LMU, HiPPO, LSSL, S4, H3などがあり、主に以下の流れで解説する。

  1. HiPPOフレームワークとLSSL
  2. S4のアルゴリズム
  3. H3とHyena Hierarchy
  4. mamba

本記事は、第一段として、HiPPOフレームワークとLSSLについてまとめる。

元論文を以下に記す。

HiPPO

https://arxiv.org/abs/2008.07669

LSSL

https://arxiv.org/abs/2110.13985

この記事の続きはこちら

https://zenn.dev/izmyon/articles/c56a2fd6670546

読者の方へ

  • 補足や訂正などがあれば、コメントにて優しく丁寧にご教示いただけると喜びます。
  • 書くのけっこう大変だったのでバッチを送っていただけると嬉しいです。たくさんもらえると僕のやる気が上がって投稿頻度が上がるかもしれません。(逆にもらえないと下がるかもしれない。)

状態空間モデル

時刻ttにおける1次元の入力信号u(t)u(t)を、NN次元の潜在空間x(t)x(t)に投影し、出力信号y(t)y(t)を計算する以下の状態空間モデルで表される線形システムを考える。LSSL、S4、H3、mambaなどの一連の状態空間モデルは、下記の方程式を離散化し、長さLLの系列を同じく長さLLの系列に変換するRLRL\mathbb{R}^L \rightarrow \mathbb{R}^Lのsequence-to-sequenceモデルを用いる。

x˙(t)=Ax(t)+Bu(t)y(t)=Cx(t)+Du(t) \begin{align} \dot{x}(t) = A x(t) + B u(t) \quad \\ y(t) = C x(t) + D u(t) \quad \end{align}

状態空間モデルは、Hidden Markov Models (HMM)のように、制御工学や機械学習などの幅広い分野でみられる潜在空間を利用するモデルであり、行列ARN×NA \in \mathbb{R}^{N \times N}BRN×1B \in \mathbb{R}^{N \times 1}CR1×NC \in \mathbb{R}^{1 \times N}DR1×1D \in \mathbb{R}^{1 \times 1}は、勾配降下によって求められる学習パラメータである。

これは連続時間のシステムであるが、コンピュータ上では離散時間信号を扱うため、式(1)、式(2)の離散化を行う。ここで、式(2)の離散化は簡単だが問題なのは式(1)の常微分方程式(ODE)を離散化することである。

ゼロ次ホールド(ZOH)と同様に、時刻ttにおける入力信号u(t)u(t)が次のタイムステップΔt\Delta tまで一定である、つまりu(t+Δt)=u(t)u(t+\Delta t) = u(t)であると仮定する。(1)式の右辺をttx(t)x(t)の関数として、x˙(t)=f(t,x(t))\dot{x}(t) = f(t, x(t))と表す、離散化された時刻tit_iにおいて、状態x(t0),x(t1),x(t_0), x(t_1), \ldotsは、ピカールの逐次近似法x(ti+1)=x(ti)+titi+1f(s,x(s))dsx(t_{i+1}) = x(t_i) + \int^{t_{i+1}}_{t_i} f(s, x(s))dsによって求められる。この右側積分を推定する方法には様々な方法があり、ここでは、Generalized bilinear transform (GBT)を用いる。この方法では、定数α\alphaを用いて、以下のように右側積分を近似する。

tt+Δf(s,x(s))dsαΔtf(t+Δ,x(t+Δt))+(1α)Δtf(t,x(t))=αΔt(Ax(t+Δ)+Bu(t+Δ))+(1α)Δt(Ax(t)+Bu(t))=αΔt(Ax(t+Δ)+Bu(t))+(1α)Δt(Ax(t)+Bu(t))(u(t+Δ)=u(t))=αΔtAx(t+Δ)+(1α)ΔtAx(t)+BΔtu(t) \begin{align*} \begin{split} &\int^{t+\Delta}_{t} f(s, x(s))ds \\ &\simeq \alpha \Delta t f(t+\Delta, x(t+\Delta t)) + (1-\alpha) \Delta t f(t, x(t)) \\ &= \alpha \Delta t (A x(t+\Delta) + B u(t+\Delta)) + (1-\alpha) \Delta t (A x(t) + B u(t)) \\ &= \alpha \Delta t (A x(t+\Delta) + B u(t)) + (1-\alpha) \Delta t (A x(t) + B u(t)) \quad (\because u(t+\Delta) = u(t)) \\ &= \alpha \Delta t A x(t+ \Delta) + (1-\alpha) \Delta t A x(t) + B \Delta t u(t) \end{split} \end{align*}

この結果を用いると、x(t+Δt)=x(t)+tt+Δtf(s,x(s))dsx(t+\Delta t) = x(t) + \int^{t+\Delta t}_{t} f(s, x(s))dsより、

x(t+Δt)=x(t)+αΔtAx(t+Δt)+(1α)ΔtAx(t)+BΔtu(t)(IαΔtA)x(t+Δt)=(I+(1α)ΔtA)x(t)+BΔtu(t)x(t+Δt)=(IαΔtA)1(I+(1α)ΔtA)x(t)+(IαΔtA)1BΔtu(t) \begin{align*} x(t+ \Delta t) &= x(t) + \alpha \Delta t A x(t+ \Delta t) + (1-\alpha) \Delta t A x(t) + B \Delta t u(t) \notag \\ (I- \alpha \Delta t A) x(t+\Delta t) &= (I + (1-\alpha) \Delta t A) x(t) + B \Delta t u(t) \notag \\ \therefore x(t+\Delta t) &= \left(I - \alpha \Delta t A \right)^{-1} \left(I + (1- \alpha) \Delta t A\right) x(t) + \left( I - \alpha \Delta t A \right)^{-1} B \Delta t u(t) \end{align*}

ここでは、α=12\alpha = \frac{1}{2}として、下式を用いる。

x(t+Δt)=(I12ΔtA)1(I+12ΔtA)x(t)+(I12ΔtA)1BΔtu(t) \begin{align} \begin{split} x(t+\Delta t) = \left(I - \frac{1}{2} \Delta t A \right)^{-1} \left(I + \frac{1}{2} \Delta t A\right) x(t) + \left( I - \frac{1}{2} \Delta t A \right)^{-1} B \Delta t u(t) \end{split} \end{align}

さらに、以下のように行列A,BA, Bを行列Aˉ,Bˉ\bar{A}, \bar{B}に変換する。

Aˉ=(I12ΔtA)1(I+12ΔtA)Bˉ=(I12ΔtA)1BΔt \begin{align} \bar{A} = \left(I - \frac{1}{2} \Delta t A \right)^{-1} \left(I + \frac{1}{2} \Delta t A\right) \\ \bar{B} = \left( I - \frac{1}{2} \Delta t A \right)^{-1} B \Delta t \end{align}

すると、以下のように式(1), (2)の連続時間システムを離散化することが出来る。

xt=Aˉxt1+Bˉutyt=Cxt+Dut \begin{align} x_t &= \bar{A} x_{t-1} + \bar{B} u_t \quad \\ y_t &= C x_t + D u_t \quad \end{align}

HiPPO (High-Order Polynomial Projection Operator; 高次多項式投影演算子)

ここまで線形システム(1)、(2)を離散化し、新たなシステムである式(6),(7)を得た。次に、本節で説明するHiPPOと呼ばれるフレームワークを適用し、以下のように係数行列AAを初期化する。

Ank={(2n+1)1/2(2k+1)1/2if n>kn+1if n=k0if n<k A_{nk} = \left\{ \begin{array}{ll} (2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\ n+1 & \text{if} \ n=k \\ 0 & \text{if} \ n<k \\ \end{array} \right.

なぜこのような初期化を行うのかをこれから説明化するが、長いのでLSSLの全体像を先につかみたい場合は、次の節「畳み込みによる高速化」に進んでもらって構わない。

HiPPOの定義

時間変化する(,t](-\infty, t]でサポートされる測度族μ(t)\mu^{(t)}NN次元の(直交)多項式が張る部分空間G\mathcal{G}、そして連続入力信号u:R0Ru: \mathbb{R}_{\geq 0} \rightarrow \mathbb{R}が与えられた時、uuを最適化された投影係数x:R0RNx: \mathbb{R}_{\geq 0} \rightarrow \mathbb{R}^Nに移す演算hippo\text{hippo}を定義する。この演算は、以下のように投影演算子projt\text{proj}_tと、係数抽出演算子coeft\text{coef}_tをすべての時間ステップttで求め、それらの合成coeftprojt\text{coef}_t \circ \text{proj}_tである。(つまり、(hippo(u))(t)=coeft(projt(u))(\text{hippo}(u))(t)=\text{coef}_t (\text{proj}_t (u))である。)

  1. projt\text{proj}_tは、時刻ttまでの信号uu、つまりu<t:={u(y)}ytu_{<t}:=\{u(y)\}_{y \leq t}を、推論誤差utg(t)L2(μ(t))\| u_{\leq t} - g^{(t)} \|_{L_2 (\mu^{(t)})}を最小にする多項式g(t)Gg^{(t)} \in \mathcal{G}に射影する。

  2. coeft:GRN\text{coef}_t: \mathcal{G} \rightarrow \mathbb{R}^Nは、多項式g(t)g^{(t)}を、測度μ(t)\mu^{(t)}に関して定義された直交多項式の基底関数の係数x(t)RNx(t) \in \mathbb{R}^Nに射影する。

HiPPOを導くためのHiPPOフレームワーク

{Pn(t)}nN\{P_n^{(t)}\}_{n \in \mathbb{N}}は、時間変化する測度μ(t)\mu^{(t)}に関する直交多項式列であり、pn(t)p_n^{(t)}を、Pn(t)P_n^{(t)}を正規した多項式、すなわちpn(t)=Pn(t)/Pn(t)μ2p_n^{(t)} = P_n^{(t)}/ \| P_n^{(t)} \|_{\mu}^2であるとする。正規直交基底の定義より、以下が成り立つ。

pn(t)(y)pm(t)(y)ω(t)(y)dy=δm,n \begin{align} \int^{\infty}_{-\infty} p_n^{(t)}(y) p_m^{(t)}(y) \omega^{(t)}(y) dy = \delta_{m,n} \end{align}

ここで、ω(t)(y)\omega^{(t)}(y)は重みであり、dμ(t)=ω(t)(y)dy,dμ(t)=1d \mu^{(t)} = \omega^{(t)}(y) dy, \int^{\infty}_{-\infty} d \mu^{(t)} = 1である。

入力信号の履歴ut:={u(y)}ytu_{\leq t}:=\{u(y)\}_{y \leq t}が、ステップサイズΔt\Delta tで連続関数u(y)u(y)からサンプリングされたと考える。N個の基底{pn(t)}n=0,1,,N1\{p_n^{(t)}\}_{n=0,1,\ldots, N-1}を用いて、utu_{\leq t}を通る全時刻にわたる連続関数u(y)u(y)を多項式g(t)g^{(t)}で近似する。すなわち、utu_{\leq t}{pn(t)}n=0,1,,N1\{p_n^{(t)}\}_{n=0,1,\ldots, N-1}が張る空間へ射影し、多項式g(t)g^{(t)}が得られるとする。

utg(t)=n=0N1xn(t)pn(t) \begin{align} u_{\leq t} \simeq g^{(t)} = \sum_{n=0}^{N-1} x_n (t) p_n^{(t)} \end{align}

このように近似すると、時刻ttまでの入力信号の履歴を、NN個の係数{xn(t)}n=0,1,,N1\{x_n^{(t)}\}_{n=0,1,\ldots, N-1}で表現することが出来るようになる。係数xn(t)x_n (t)は、時刻ttまでの入力信号の履歴utu_{\leq t}を用いて、フーリエ係数を求める要領で以下のように求めることが出来る。

xn(t)=g(t)pn(t)ω(t)(y)dyutpn(t)ω(t)(y)dy \begin{align} x_n (t) = \int^{\infty}_{-\infty} g^{(t)} p_n^{(t)} \omega^{(t)}(y) dy &\simeq \int^{\infty}_{-\infty} u_{\leq t} p_n^{(t)} \omega^{(t)}(y) dy \end{align}

実は、式(10)に従って、以下の入力信号utu_{\leq t}の履歴からNN個の係数を計算し、出力することがhippo\text{hippo}演算に他ならない。

hippo(ut):=(x0,x1,,xN1)where utg(t)=n=0N1xn(t)pn(t) \begin{align*} \text{hippo}(u_{\leq t}) := (x_0, x_1, \ldots, x_{N-1}) \\ \text{where} \ u_{\leq t} \simeq g^{(t)} = \sum_{n=0}^{N-1} x_n (t) p_n^{(t)} \end{align*}

しかしながら、これらの係数を各時刻で計算するのは容易ではない。そこでHiPPOフレームワークでは、式(1)を離散化して式(6)の更新式を導いたように、x(t)x(t)が満たすODEを求めた後、それを離散化して更新式を求め、各時刻でx(t)x(t)を更新していくことを考える。

係数が満たすODE

以上より、下式が成り立つことが分かった。

xn(t)=u(y)pn(t,y)ω(t,y)dy \begin{align} x_n (t) = \int^{\infty}_{-\infty} u(y) p_n (t, y) \omega(t, y) dy \end{align}

両辺を微分して、係数xn(t)x^{(t)}_nが満たすODEを導く。ここで、微分と積分の可換性を認めた。

x˙n(t)=tu(y)pn(t,y)ω(t,y)dy=t(u(y)pn(t,y)ω(t,y))dy=u(y)(tpn(t,y))ω(t,y)dy +u(y)pn(t,y)(tω(t,y))dy \begin{align*} \dot{x}_n (t) &= \frac{\partial}{\partial t} \int^{\infty}_{-\infty} u(y) p_n (t, y) \omega(t, y) dy \\ &= \int^{\infty}_{-\infty} \frac{\partial}{\partial t} (u(y) p_n (t, y) \omega(t, y) ) dy \\ &= \int^{\infty}_{-\infty} u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy + \int^{\infty}_{-\infty} u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy \end{align*}

以上より、以下の式(11)が成り立つ。

x˙n(t)=u(y)(tpn(t,y))ω(t,y)dy +u(y)pn(t,y)(tω(t,y))dy \begin{align} \dot{x}_n (t) = \int^{\infty}_{-\infty} u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy + \int^{\infty}_{-\infty} u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy \end{align}

HiPPOフレームワークでは、具体的に用いる基底{Pn(t)}nN\{P_n^{(t)}\}_{n \in \mathbb{N}}及び、測度μ(t)\mu^{(t)}を選択し、式(12)に代入することで、x˙(t)\dot{x}(t)とその時点での係数x(t)x(t)、入力信号u(t)u(t)との関係を導き、式(1)に当てはめて係数行列AABBを設定する。

HiPPO-LegS

HiPPOの論文では、様々な基底や測度を用いた場合について検証されており、LSSLではその中で最も良い結果を示したScaled Legendre Measure (LegS)という設定を用いて得られる行列AAに初期化する。この設定を用いたHiPPO-LegSは、基底としてルシャンドル多項式を用い、すべての履歴に対して一様に重みをつけるωt=1t1[0,t]\omega_t = \frac{1}{t} \mathbb{1}_{[0, t]}を用いるものである。

ここではまず、ルシャンドル多項式の性質や、基底として用いた場合のωt\omega_tに関する直交性などを確認し、正規直交基底を求め、式(12)に代入してODEを求める。

正規直交規定および測度

ルシャンドル多項式は、以下のnn次多項式である。

Pn(x):=12nn!dndxn[(x21)n] \begin{align} P_n (x) := \frac{1}{2^n n!} \frac{d^n}{{dx}^n} [ (x^2 - 1 )^n ] \end{align}

測度ωleg=1[1,1]\omega^{\text{leg}} = \bold{1}_{[-1, 1]}に関して、以下の直行性が成り立つ。

2n+1211Pn(x)Pm(x)dx=δnm \begin{align} \frac{2n+1}{2} \int^1_{-1} P_n(x) P_m(x) dx = \delta_{nm} \end{align}

さらに、以下の性質を持つ。

Pn(1)=1Pn(1)=(1)n(2n+1)Pn=Pn+1Pn1Pn+1=(n+1)Pn+xPn \begin{align} P_n (1) &= 1 \\ P_n (-1) &= (-1)^n \\ (2n+1) P_n &= P'_{n+1} - P'_{n-1} \\ P'_{n+1} &= (n+1) P_n + x P'_n \end{align}

以上のルシャンドル多項式は区間x[1,1]x \in [-1, 1]で成り立つ性質であったが、区間[0, t]についても成り立つようにスケーリングするため、y=t2(x+1)y = \frac{t}{2}(x+1)、つまりx=2ty1x = \frac{2}{t} y - 1と変数変換する。すると、dx=2tdydx = \frac{2}{t} dyであり、式(14)の左辺を変形すると、以下のようになる。

2n+1211Pn(x)Pm(x)dx=2n+120tPn(2ty1)Pm(2ty1)2tdy=(2n+1)0tPn(2ty1)Pm(2ty1)ωleg(2ty1)1tdy \begin{align*} & \frac{2n+1}{2} \int^1_{-1} P_n(x) P_m(x) dx \\ =& \frac{2n+1}{2} \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \frac{2}{t}dy \\ =& (2n+1) \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \omega^{\text{leg}} (\frac{2}{t} y - 1 ) \frac{1}{t} dy \end{align*}

従って、以下が成り立つ。

(2n+1)0tPn(2ty1)Pm(2ty1)ωleg(2ty1)1tdy=δnm \begin{align} (2n+1) \int^t_{0} P_n( \frac{2}{t} y - 1 ) P_m( \frac{2}{t} y - 1 ) \omega^{\text{leg}} (\frac{2}{t} y - 1 ) \frac{1}{t} dy = \delta_{nm} \end{align}

この式より、新たな測度ωt=1t1[0,t]\omega_t = \frac{1}{t} \mathbb{1}_{[0, t]}を用い、さらに以下の正規直交規定を用いることにする。

pn(t,y)=(2n+1)1/2Pn(2ty1) \begin{align} p_n (t, y) = (2n+1)^{1/2} P_n ( \frac{2}{t} y - 1 ) \end{align}

すると、以下が成り立つ。

0tpn(t,y)pm(t,y)ωtdy=δnm \begin{align} \int^t_{0} p_n(t, y) p_m(t, y) \omega_t dy = \delta_{nm} \end{align}

ODEの導出

HiPPO-LegSが用いるODEを導出する前に、導出で用いるルシャンドル基底の微分が満たす性質について述べる。まず、式(17)から、以下が導かれる。

Pn+1=(2n+1)Pn+(2n3)Pn2+, \begin{align} P'_{n+1} = (2n+1) P_n + (2n-3)P_{n-2} + \ldots, \end{align}

次数を一つずらすと、

Pn=(2n1)Pn1+(2n5)Pn3+, \begin{align} P'_{n} = (2n-1) P_{n-1} + (2n-5)P_{n-3} + \ldots, \end{align}

これらの性質を用いて(x+1)Pn(x)(x+1)P'_n (x)を計算すると、以下が成り立つ。

(x+1)Pn=xPn+Pn=Pn+1+Pn(n+1)Pn(式(18)より、xPn=Pn(n+1)Pn)=nPn+(2n1)Pn1+(2n3)Pn2+(式(22), (23)を代入) \begin{align} (x+1) P'_{n} =& x P'_{n} + P'_{n} = P'_{n+1} + P'_n - (n+1) P_n \notag \\ &(\because \text{式(18)より、} xP'_n = P'_n - (n+1) P_n) \notag \\ =& nP_n + (2n-1) P_{n-1} + (2n-3)P_{n-2} + \ldots \\ &(\because \text{式(22), (23)を代入}) \notag \\ \end{align}

まず、ωt\omega_tと、pn(t,y)p_n(t, y)の微分をそれぞれ求める。

tω(t)=t21[0,1]+t1δt=t1(ω(t)+δt)tpn(t,y)=(2n+1)122yt2Pn(2ty1)=(2n+1)12t1((2ty1)+1)Pn(2ty1)=(2n+1)12t1(x+1)Pn(x) (x=2ty1)=(2n+1)12t1[nPn(x)+(2n1)Pn1(x)+(2n3)Pn2(x)+] ((24)を用いた。)=(2n+1)12t1[n(2n+1)12pn(t,y)+(2n1)12pn1(t,y)+(2n3)12pn2(t,y)+]((20)より、Pn(x)=(2n+1)12pn(t,y)) \begin{align} \frac{\partial}{\partial t} \omega(t) =& -t^{-2} \mathcal{1}_{[0,1]} + t^{-1} \delta_t \notag \\ =& t^{-1} ( - \omega (t) + \delta_t) \\ \frac{\partial}{\partial t} p_n (t, y) &= - (2n+1)^{\frac{1}{2}} 2y t^{-2} P_n' (\frac{2}{t} y - 1 ) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} ((\frac{2}{t} y - 1) + 1 ) P_n' (\frac{2}{t} y - 1 ) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} (x + 1 ) P_n' (x) \ (\because x = \frac{2}{t} y - 1) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} [nP_n (x) + (2n-1) P_{n-1}(x) + (2n-3) P_{n-2}(x) + \ldots ] \ (\because 式(24)を用いた。) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} p_n (t, y) + (2n-1)^{\frac{1}{2}} p_{n-1} (t, y) + (2n-3)^{\frac{1}{2}} p_{n-2}(t, y) + \ldots ] \notag \\ & (\because 式(20)より、P_n (x) = (2n+1)^{-\frac{1}{2}} p_n (t, y) ) \end{align}

これらを用いて、式(12)のODEの右辺を計算する。

まず、第一項を求める。

0tu(y)(tpn(t,y))ω(t,y)dy0tg(t,y)(tpn(t,y))ω(t,y)dy=0t(n=0N1xn(t)pn(t,y)){(2n+1)12t1[n(2n+1)12pn(t,y)+(2n1)12pn1(t,y)+(2n3)12pn2(t,y)+]}ω(t,y)dy((26)を用いた。)=(2n+1)12t1[n(2n+1)12xn(t)+(2n1)12xn1(t)+(2n3)12xn2(t)+]((21)の正規直交性を用いた。) \begin{align} &\int^t_0 u(y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy \notag \\ \simeq & \int^t_0 g (t, y) (\frac{\partial}{\partial t} p_n (t, y)) \omega (t, y) dy \notag \\ =&\int^t_0 ( \sum_{n=0}^{N-1} x_n (t) p_n (t, y) ) \{ - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} p_n (t, y) + (2n-1)^{\frac{1}{2}} p_{n-1} (t, y) + (2n-3)^{\frac{1}{2}} p_{n-2}(t, y) + \ldots ] \} \omega (t, y) dy \notag \\ &(\because 式(26)を用いた。) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ] \\ &(\because 式(21)の正規直交性を用いた。) \notag \\ \end{align}

つぎに、第二項を求める。

0tu(y)pn(t,y)(tω(t,y))dy=0tu(y)pn(t,y)(t1(ω(t)+δt))dy ((25)を用いた。)=t10tu(y)pn(t,y)ω(t)dy+t10tu(y)pn(t,y)δtdy=t1xn(t)+t1u(t)pn(t,t)=t1xn(t)+t1(2n+1)1/2u(t) \begin{align} &\int^t_0 u(y) p_n (t, y) (\frac{\partial}{\partial t} \omega (t, y) ) dy \notag \\ =&\int^t_0 u(y) p_n (t, y) ( t^{-1} ( - \omega (t) + \delta_t) ) dy \ (\because 式(25)を用いた。) \notag \\ =& -t^{-1} \int^t_0 u(y) p_n (t, y) \omega (t) dy + t^{-1} \int^t_0 u(y) p_n (t, y) \delta_t dy \notag \\ =& -t^{-1} x_n (t) + t^{-1} u(t) p_n (t, t) \notag \\ =& -t^{-1} x_n (t) + t^{-1} (2n+1)^{1/2} u(t) \end{align}

式(12)の右辺に式(27), (28)を代入して、以下が成り立つ。

x˙n(t)=(2n+1)12t1[n(2n+1)12xn(t)+(2n1)12xn1(t)+(2n3)12xn2(t)+]t1xn(t)+t1(2n+1)1/2u(t)=(2n+1)12t1[(n+1)(2n+1)12xn(t)+(2n1)12xn1(t)+(2n3)12xn2(t)+]+t1(2n+1)1/2u(t) \begin{align} \dot{x}_n (t) =& - (2n+1)^{\frac{1}{2}} t^{-1} [n (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ] \notag \\ &-t^{-1} x_n (t) + t^{-1} (2n+1)^{1/2} u(t) \notag \\ =& - (2n+1)^{\frac{1}{2}} t^{-1} [(n+1) (2n+1)^{-\frac{1}{2}} x_n (t) + (2n-1)^{\frac{1}{2}} x_{n-1} (t) + (2n-3)^{\frac{1}{2}} x_{n-2}(t) + \ldots ] \notag \\ & + t^{-1} (2n+1)^{1/2} u(t) \end{align}

式(29)をベクトル化すると、以下の式(30)が成立する。

x˙(t)=1tAx(t)+1tBu(t)Ank={(2n+1)1/2(2k+1)1/2if n>kn+1if n=k0if n<kBn=(2n+1)12 \begin{align} \dot{x}(t) &= - \frac{1}{t} A x(t) + \frac{1}{t} B u(t) \\ A_{nk} &= \left\{ \begin{array}{ll} (2n+1)^{1/2} (2k+1)^{1/2} & \text{if} \ n>k \\ n+1 & \text{if} \ n=k \\ 0 & \text{if} \ n<k \\ \end{array} \right. \\ B_n &= (2n+1)^{\frac{1}{2}} \\ \end{align}

式(31)の行列AAは、LSSL、S4、H3、mambaなどの後続研究では、HiPPO行列と呼ばれる。

HiPPO-LegSは時間スケールに依存しない。

式(30)のODEを計算するにあたり、式(6)を求めたように両辺を積分し、GBT(α=12\alpha=\frac{1}{2})を使って離散化して、xn(t)x_n (t)の更新式を得る。

x(t+Δt)x(t)=tt+Δt1tAx(t)+1tBu(t)dtΔt2(1tAc(t)+1tBu(t))+(1t+ΔtAc(t+Δt)+1t+ΔtBu(t+Δt)) \begin{align} x(t+\Delta t) - x(t) &= \int^{t+\Delta t}_t - \frac{1}{t} A x(t) + \frac{1}{t} B u(t) dt \notag \\ &\simeq \frac{\Delta t}{2}{(-\frac{1}{t}Ac(t) + \frac{1}{t}Bu(t))+(-\frac{1}{t+\Delta t}Ac(t+\Delta t) + \frac{1}{t+\Delta t}Bu(t+\Delta t))} \notag \end{align}

ゼロ次ホールドを仮定し、u(t)=u(t+Δt)u(t) = u(t+\Delta t)であるとすると、

x(t+Δt)x(t)=Δt2(1tAx(t)+1tBu(t))+Δt2(1t+ΔtAx(t+Δt)+1t+ΔtBu(t))x(t+Δt)+Δt2(t+Δt)Ax(t+Δt)=x(t)Δt2tAx(t)+(Δt2t+Δt2(t+Δt))Bu(t)(I+Δt2(t+Δt)A)x(t+Δt)=(IΔt2tA)x(t)+(Δt2t+Δt2(t+Δt))Bu(t) \begin{align} x(t+\Delta t) - x(t) &= \frac{\Delta t}{2}(-\frac{1}{t}Ax(t) + \frac{1}{t}Bu(t)) + \frac{\Delta t}{2}(-\frac{1}{t+\Delta t}Ax(t+\Delta t) + \frac{1}{t+\Delta t}Bu(t)) \notag \\ x(t+\Delta t) + \frac{\Delta t}{2(t+\Delta t)}Ax(t+\Delta t) &= x(t) - \frac{\Delta t}{2t} Ax(t) + (\frac{\Delta t}{2t} + \frac{\Delta t}{2(t+\Delta t)})Bu(t) \notag \\ (I + \frac{\Delta t}{2(t+\Delta t)}A) x(t+\Delta t) &= (I - \frac{\Delta t}{2t} A)x(t) + (\frac{\Delta t}{2t} + \frac{\Delta t}{2(t+\Delta t)})Bu(t) \end{align}

ここで、時間を離散化し、t=kΔtt=k\Delta tとし、xk:=x(kΔt)x_k := x(k\Delta t)uk:=u(kΔt)u_k := u(k \Delta t)とすれば、

(I+Δt2(kΔt+Δt)A)xk+1=(IΔt2kΔtA)xk+(Δt2kΔt+Δt2(kΔt+Δt))Buk \begin{align*} (I + \frac{\Delta t}{2(k\Delta t+\Delta t)}A) x_{k+1} &= (I - \frac{\Delta t}{2k\Delta t} A)x_k + (\frac{\Delta t}{2k\Delta t} + \frac{\Delta t}{2(k\Delta t+\Delta t)})Bu_k \end{align*}

従って、以下が成り立つ。

(I+12(k+1)A)xk+1=(I12kA)xk+(12k+12(k+1))Buk \begin{align} (I + \frac{1}{2(k+1)}A) x_{k+1} &= (I - \frac{1}{2k} A)x_k + (\frac{1}{2k} + \frac{1}{2(k+1)})Bu_k \end{align}

式(35)では、時間スケールΔt\Delta tがなくなっている。すなわち、HiPPO-LegSは時間スケールに依存しない。これはHiPPO-LegSが持つ特殊な特徴である。

この式に従ってxn(t)x_n (t)を更新する。入力信号の履歴utu_{\leq t}は、各時刻で式(9)に従って以下のように再構成することが出来る。

utg(t)=n=0N1xn(t)pn(t,y)=n=0N1xn(t)(2n+1)12Pn(2yt1) \begin{align*} u_{\leq t} \simeq g^{(t)} &= \sum_{n=0}^{N-1} x_n (t) p_n (t, y) \\ &= \sum_{n=0}^{N-1} x_n (t) (2n+1)^{\frac{1}{2}} P_n (\frac{2y}{t} - 1) \end{align*}

LSSL(Linear State-Space Layers)

ここまで、HiPPOおよびHiPPO-LegSについて説明してきた。HiPPO-LegSでは、式(30)のODEが導かれ、式(35)のように離散化された。LSSLやその発展であるS4、H3、mambaなどでは、式(1),(2)の線形システムを考え、式(6)、(7)の形に離散化するが、行列AAはHiPPO-LegSから導かれ、HiPPO行列と呼ばれる式(31)と同じになるように初期化する。このように初期化することで、x(t)x(t)に入力信号utu_{\leq t}の履歴を記憶することを可能とし、MNISTベンチマークでの性能を60%から98%に上昇させることが出来る。ここでは、これらの式(6), (7)から導かれるLSSLの性質について見ていく。

畳み込みによる高速化

式(6),(7)を見た時に思ったかもしれないが、式(6)を式(7)に代入することで、以下のようにxxを用いずにutu_{\leq t}のみで出力yky_kを計算することが出来る。ただし、x1=0x_{-1} = 0とする。

yk=Cxk+Duk=C(Aˉxk1+Bˉuk)+Duk=CAˉ(Aˉxk2+Bˉuk1)+CBˉuk+Duk=C(Aˉ)2(Aˉxk3+Bˉuk2)+CAˉBˉuk1+CBˉuk+Duk=C(Aˉ)3(Aˉxk4+Bˉuk3)+C(Aˉ)2Bˉuk2+CAˉBˉuk1+CBˉuk+Duk=C(Aˉ)kx0++C(Aˉ)2Bˉuk2+CAˉBˉuk1+CBˉuk+Duk=C(Aˉ)k(Aˉx1+Bˉu0)++C(Aˉ)2Bˉuk2+CAˉBˉuk1+CBˉuk+Duk=C(Aˉ)kBˉu0++C(Aˉ)2Bˉuk2+CAˉBˉuk1+CBˉuk+Duk(x1=0) \begin{align*} y_k =& C x_k + D u_k \\ =& C (\bar{A} x_{k-1} + \bar{B} u_k) + D u_k \\ =& C \bar{A} (\bar{A} x_{k-2} + \bar{B} u_{k-1}) + C \bar{B} u_k + D u_k \\ =& C (\bar{A})^2 (\bar{A} x_{k-3} + \bar{B} u_{k-2}) + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ =& C (\bar{A})^3 (\bar{A} x_{k-4} + \bar{B} u_{k-3}) + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ \vdots \\ =& C (\bar{A})^k x_0 + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ =& C (\bar{A})^k (\bar{A} x_{-1} + \bar{B} u_0) + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ =& C (\bar{A})^k \bar{B} u_0 + \cdots + C (\bar{A})^2 \bar{B} u_{k-2} + C \bar{A} \bar{B} u_{k-1} + C \bar{B} u_k + D u_k \\ &(\because x_{-1} = 0) \end{align*}

従って、Kk(Aˉ,Bˉ,C)=(CBˉ,CAˉBˉ,,C(Aˉ)kBˉ)\mathcal{K}_k (\bar{A}, \bar{B}, C) = (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B})として、以下が成り立つ。

yk=(CBˉ,CAˉBˉ,,C(Aˉ)kBˉ)uk+Duk=Kk(Aˉ,Bˉ,C)uk+Duk \begin{align} y_k &= (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B}) * u_{\leq k} + D u_k \notag \\ &= \mathcal{K}_k (\bar{A}, \bar{B}, C) * u_{\leq k} + D u_k \end{align}

ここで、以下のように置いた。

Kk(Aˉ,Bˉ,C)=(CAˉiB)i=0,,k=(CBˉ,CAˉBˉ,,C(Aˉ)kBˉ)uk=(u0,u1,,uk) \begin{align} \mathcal{K}_k (\bar{A}, \bar{B}, C) &= (C \bar{A}^i B)_{i = 0, \ldots, k} \notag \\ &= (C \bar{B}, C \bar{A} \bar{B}, \ldots, C (\bar{A})^k \bar{B}) \\ u_{\leq k} &= (u_0, u_1, \ldots, u_k) \end{align}

また、系列長LLの入力信号uL1=(u0,u1,,uL1)u_{\leq L-1} = (u_0, u_1, \ldots, u_{L-1})が与えられたとき、出力yL1y_{L-1}は、以下のように表される。

yL1=KL1(Aˉ,Bˉ,C)uL1+DuL1 \begin{align} y_{L-1} = \mathcal{K}_{L-1} (\bar{A}, \bar{B}, C) * u_{\leq L-1} + D u_{L-1} \end{align}

このように畳み込み形式で計算することで、全時刻にわたるyRH×Ly \in \mathbb{R}^{H \times L}を三度のFFTで一度に求めることが出来る。

ただし、LSSLの計算にはボトルネックがある。式(6)の再帰形式では、離散化された状態行列AA、つまりAˉ\bar{A}の、行列-ベクトル乗算(matrix-vector multiplication; MVM)、そして式(35)の畳み込み形式では、行列Aˉ\bar{A}のべき乗を多く含むKrylov function KL\mathcal{K}_Lが、それぞれボトルネックとなる。これらをうまく解決し、効率よく計算するアルゴリズムがLSSLでも提案されているが、S4でより効率的な手法が提案されているため、ここでは紹介しない。(S4については、次の記事で説明する。)

LSSLの表現力

ここまで、LSSLが式(6)、(7)の再帰形式、式(39)の畳み込み形式の二つの形式を持つことを見てきた。LSSLでは線形計算のみを行い、RNNやTransformerなどで見られるような非線形関数は用いていない。それでもLSSLは高い表現力を持つことを示す。

畳み込みはLSSLである。

Under Construction

RNNはLSSLである。

Under Construction

Lemma [3, p.6, 3.1]

一次元ゲート付き再帰xt=(1σ(z))xt1+σ(z)ut\bold{x_t = (1 - \sigma(z)) x_{t-1} + \sigma(z) u_t}(ただし、σ\bold{\sigma}はシグモイド関数でz\bold{z}は任意の表現)は、1次元の線形ODEx˙(t)=x(t)+u(t)\bold{\dot{x}(t) = -x(t) + u(t)}α=1\bold{\alpha=1}のGBT(つまり、後退オイラー法)で離散化したものとみなすことが出来る。

証明:
x˙(t)=x(t)+u(t)\dot{x}(t) = -x(t) + u(t)を後退オイラー法で離散化すると、以下が成り立つ。

xt1=xtΔt(xt+ut) \begin{align*} x_{t-1} = x_t - \Delta t (- x_t + u_t) \end{align*}

ここで、Δt=ez\Delta t = e^zとすると、

xt1=xtez(xt+ut)=(1+ez)xtezut \begin{align*} x_{t-1} = x_t - e^z (- x_t + u_t) = (1 + e^z)x_t - e^z u_t \end{align*}

従って、以下が成り立つ。

xt=11+ezxt1+ez1+ezut=ez1+ezxt1+11+ezut=(1σ(z))xt1+σ(z)ut(σ(z)=11+ez) \begin{align} x_t &= \frac{1}{1 + e^z} x_{t-1} + \frac{e^z}{1 + e^z} u_t \notag \\ &= \frac{e^{-z}}{1 + e^{-z}} x_{t-1} + \frac{1}{1 + e^{-z}} u_t \notag \\ &= (1 - \sigma(z)) x_{t-1} + \sigma(z) u_t \quad (\because \sigma(z) = \frac{1}{1 + e^{-z}}) \end{align} (証明終)

Lemma [3, p.6, 3.2]

Under Construction

Deep LSSLs

Under Construction

参考文献

[1] Voelker, A., Kajić, I., & Eliasmith, C. (2019). Legendre memory units: Continuous-time representation in recurrent neural networks. Advances in neural information processing systems, 32.
[2] Gu, A., Dao, T., Ermon, S., Rudra, A., & Ré, C. (2020). Hippo: Recurrent memory with optimal polynomial projections. Advances in neural information processing systems, 33, 1474-1487.
[3] Gu, A., Johnson, I., Goel, K., Saab, K., Dao, T., Rudra, A., & Ré, C. (2021). Combining recurrent, convolutional, and continuous-time models with linear state space layers. Advances in neural information processing systems, 34, 572-585.
[4] Gu, A., Goel, K., & Ré, C. (2021). Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396.
[5] Gu, A., Johnson, I., Timalsina, A., Rudra, A., & Ré, C. (2022). How to train your hippo: State space models with generalized orthogonal basis projections. arXiv preprint arXiv:2206.12037.
[6] 角田良太郎「HiPPO/S4解説」, https://techblog.morphoinc.com/entry/2022/05/24/102648, 2024年2月10日閲覧
[7] Guofeng Zhang, Tongwen Chen, and Xiang Chen. Performance recovery in digital implementation of analogue systems. SIAM journal on control and optimization, 45(6):2207–2223, 2007.
[8] 岡島寛「状態フィードバック制御・状態方程式に基づく制御のまとめ」、制御工学の教科書、2024年2月17日公開、2024年2月21日閲覧

Discussion

ログインするとコメントできます