🐍

時系列データ分析論文④「 S4 実装編 」

に公開

概要

前回に引き続き、S4論文の解説をします。今回は実装編です。

Github: https://github.com/state-spaces/s4

大変ありがたい事に、全実装は以下のページにまとまっています。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py

さて、理論編でどういった変数が登場し、どう計算すれば入力uからyが求まるのか、が分かりました。ではいったい、それらをどう実装に落とし込むのか。

ざっくり分けると、以下の要点をおさえる必要があります。

  • 何を学習するのか
  • どう計算するのか

ではこれらにそって解説していきましょう。

何を学習するのか

関数のフロー

モデル自体は S4Block になるますが、関係のある処理は全て self.layer = FFTConv / self.kernel = kernel_cls の中にあります。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1625-L1631

となっているため、注目すべき class は SSMKernelDPLR となります。

パラメータの初期化と学習パラメータの宣言について、クラスや関数の流れを、下記の instance を作成した際のフローチャートを記載しています。この流れに沿って見ていきましょう。

model = SSMKernelDPLR(disc="bilinear", d_model=1, d_state=4)

行列A, B

SSMの基本パラメータです。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L391

>>> A, B = transition("legs", 4)
>>> A = torch.as_tensor(A) # (N, N)
>>> B = torch.as_tensor(B)[:, 0] # (N,)
>>> A
tensor([[-1.0000,  0.0000,  0.0000,  0.0000],
        [-1.7321, -2.0000,  0.0000,  0.0000],
        [-2.2361, -3.8730, -3.0000,  0.0000],
        [-2.6458, -4.5826, -5.9161, -4.0000]], dtype=torch.float64)
>>> B
tensor([1.0000, 1.7321, 2.2361, 2.6458], dtype=torch.float64)

固有値分解

式(11)を再掲します。

A_{LegS} = S-pq^T

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L395-L396

>>> P = rank_correction("legs", 4, rank=1)
>>> P
tensor([[0.7071, 1.2247, 1.5811, 1.8708]])
>>> torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3)
tensor([[0.5000, 0.8660, 1.1180, 1.3229],
        [0.8660, 1.5000, 1.9365, 2.2913],
        [1.1180, 1.9365, 2.5000, 2.9580],
        [1.3229, 2.2913, 2.9580, 3.5000]])
>>> AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3)
>>> AP
tensor([[-0.5000,  0.8660,  1.1180,  1.3229],
        [-0.8660, -0.5000,  1.9365,  2.2913],
        [-1.1180, -1.9365, -0.5000,  2.9580],
        [-1.3229, -2.2913, -2.9580, -0.5000]], dtype=torch.float64)

この AP は式(11)で言うところの正規行列S=A+pq^Tに当たります。

式(12)を再掲します。

A_{LegS}=V\Lambda V^* - pq^T=V(\Lambda - (V^*p)(V^*q)^*)V^*

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L404-L413

>>> W_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)
>>> W_re
tensor([-0.5000], dtype=torch.float64)
>>> W_im, V = torch.linalg.eigh(AP*-1j)
>>> W_im
tensor([-4.6033, -0.5565,  0.5565,  4.6033], dtype=torch.float64)
>>> V
tensor([[-0.2870-0.0000j,  0.6462+0.0000j, -0.6462+0.0000j,  0.2870+0.0000j],
        [-0.4335+0.1953j, -0.1925-0.4867j,  0.1925-0.4867j,  0.4335+0.1953j],
        [-0.1866+0.5358j, -0.0829+0.4139j,  0.0829+0.4139j,  0.1866+0.5358j],
        [ 0.4415+0.4181j,  0.1961-0.3030j, -0.1961-0.3030j, -0.4415+0.4181j]],
       dtype=torch.complex128)
>>> W = W_re + 1j * W_im
>>> W
tensor([-0.5000-4.6033j, -0.5000-0.5565j, -0.5000+0.5565j, -0.5000+4.6033j],
       dtype=torch.complex128)
>>> (V @ torch.diag_embed(W) @ V.conj().transpose(-1, -2)).real # ~= AP
tensor([[-0.5000,  0.8660,  1.1180,  1.3229],
        [-0.8660, -0.5000,  1.9365,  2.2913],
        [-1.1180, -1.9365, -0.5000,  2.9580],
        [-1.3229, -2.2913, -2.9580, -0.5000]], dtype=torch.float64)
>>> V_inv = V.conj().transpose(-1, -2)
>>> V_inv
tensor([[-0.2870+0.0000j, -0.4335-0.1953j, -0.1866-0.5358j,  0.4415-0.4181j],
        [ 0.6462-0.0000j, -0.1925+0.4867j, -0.0829-0.4139j,  0.1961+0.3030j],
        [-0.6462-0.0000j,  0.1925+0.4867j,  0.0829-0.4139j, -0.1961+0.3030j],
        [ 0.2870-0.0000j,  0.4335-0.1953j,  0.1866-0.5358j, -0.4415-0.4181j]],
       dtype=torch.complex128)

torch.linalg.eigh は固有値分解を行っており、正規行列Sを固有値分解してVを得ています。

そして W については、(V @ torch.diag_embed(W) @ V.conj().transpose(-1, -2)).realAP とほぼ等しくなる事から、S=V\Lambda V^* ( =AP ) なので、W\Lambda の固有値ベクトルです。

また、V_invV^* です。

行列B, P

>>> B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B
>>> P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P
>>> B
tensor([-0.2870-2.6425j,  0.6462+0.7193j, -0.6462+0.7193j,  0.2870-2.6425j],
       dtype=torch.complex128)
>>> P
tensor([[-0.2030-1.8685j,  0.4570+0.5086j, -0.4570+0.5086j,  0.2030-1.8685j]],
       dtype=torch.complex128)

nplr の返却

最終的に return W, P, B, V となっており、つまり、\Lambda(の固有値ベクトル), P, B', V を返しています。

離散化におけるdt

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1016

>>> model = SSMKernelDPLR(disc="bilinear", d_model=1, d_state=4)
>>> model.init_dt()
tensor([[-5.1659]])

inv_dtlog{\Delta } を表します。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L831-L833

\Delta の範囲(min ~ max)をハイパーパラメータで持っておいて、それを log 化した後に、その log の範囲の一様分布から選択しています。

行列C

C は以下のようになっており、random (虚数 torch.complex64) なパラメータ。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L868

channel の説明は以下.

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L692-L697

学習パラメータ

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1017

>>> A, P, B, C = self.init_ssm_dplr()
>>> A
tensor([[-0.5000-4.6033j, -0.5000-0.5565j]])
>>> P
tensor([[[-0.2030-1.8685j,  0.4570+0.5086j]]])
>>> B
tensor([[-0.2870-2.0000j,  0.6462+0.7193j]])
>>> C
tensor([[[0.0148+0.6038j, 0.5207-0.9530j]]])

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1020

最終的に学習するパラメータは A, B, C, inv_dt, P となり、\Lambda, B', C', log{\Delta}, P です。

パラメータとしての登録の詳細

虚数部は以下のように変換して parameter ( NN の weight ) として登録されます。_c2r = torch.view_as_real という関数がその役割です。

>>> C
tensor([[[0.0148+0.6038j, 0.5207-0.9530j]]])
>>> _resolve_conj(C)
tensor([[[0.0148-0.6038j, 0.5207+0.9530j]]])
>>> _c2r(_resolve_conj(C))
tensor([[[[ 0.0148, -0.6038],
          [ 0.5207,  0.9530]]]])

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1307-L1314

どう計算するのか

では、forward 関数を見ていきましょう。S4では、カーネル計算を行い SSM に出てくるyを計算しますが、その前後に activation や Linear なレイヤーを挟んでいます。本解説では それらは考慮しません のでご注意ください。

関数のフロー

カーネル計算の関数の流れを示します。

頻出の係数

頻出の係数の内容について見ていきましょう。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1455

D は式(3)と一部関連する (sI-\Lambda)^{-1} です。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1470

E は式(3)と一部関連する (sI+\Lambda) です。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1507-L1509

この new_state は式(2) の x_k=\={A}x_{k-1}+\={B}u_k を計算していますが、計算内部は少し異なっています。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1238

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1520-L1521

この dA, _ = self._setup_state() は、少し複雑な過程を経ているのですが、\={A}' にあたります。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1239-L1246

以降の計算で使いまわすために self.C を update しています。ここは \tilde{C}'=\={C}'(I-{\={A}'}^L) を計算しています。

Omega (z)

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1348

>>> discrete_L = 8
>>> omega, z = self._omega(discrete_L, dtype=A.dtype, device=A.device, cache=(rate==1.0))
>>> omega
tensor([ 1.0000e+00-0.0000e+00j,  7.0711e-01-7.0711e-01j,
        -4.3711e-08-1.0000e+00j, -7.0711e-01-7.0711e-01j,
        -1.0000e+00+8.7423e-08j])
>>> z
tensor([0.0000e+00+0.0000e+00j, 3.4229e-08+8.2843e-01j, 5.9605e-08+2.0000e+00j,
        7.3861e-07+4.8284e+00j, 2.1296e+07-3.1235e+07j])

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1261-L1265

omegazz2\frac{1-z}{1+z} にあたります。

カーネル計算

さて、いよいよカーネルを計算していきます。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1385-L1389

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1415

この k_f は式(22) の {\hat{K}_L}'(z) を計算しています。

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1418

そして、周波数領域から逆フーリエ変換して、それを返却しています。

y の計算

https://github.com/state-spaces/s4/blob/e757cef57d89e448c413de7325ed5601aceaac13/models/s4/s4.py#L1731-L1734

これを FFTConv.forward で受け取り、上述の記述によって、再度、フーリエ変換でDFT(K)DFT(u) を計算し、掛けて DFL(y) を計算し、逆フーリエ変換で y を得ています。

最後に

以上で S4 についての解説を終えます。まだまだ解説が甘い箇所もあるかと思いますが、自分自身も、だいぶ理解が深まったように感じます。

理論編、実装編を通して、どちらも理解が大変でした。理論編を理解したとしても、どう実装するのかというのは難しいです。このように複素数などバンバン入り乱れるものになってくると、自分でイチから実装できる自信はありません。その意味でも、すごく丁寧なコードに落として公開してくれているのは、大変ありがたいです。

Discussion