時系列データ分析論文④「 S4 実装編 」
概要
Github: https://github.com/state-spaces/s4
大変ありがたい事に、全実装は以下のページにまとまっています。
さて、理論編でどういった変数が登場し、どう計算すれば入力
ざっくり分けると、以下の要点をおさえる必要があります。
- 何を学習するのか
- どう計算するのか
ではこれらにそって解説していきましょう。
何を学習するのか
関数のフロー
モデル自体は S4Block
になるますが、関係のある処理は全て self.layer = FFTConv
/ self.kernel = kernel_cls
の中にあります。
となっているため、注目すべき class は SSMKernelDPLR
となります。
パラメータの初期化と学習パラメータの宣言について、クラスや関数の流れを、下記の instance を作成した際のフローチャートを記載しています。この流れに沿って見ていきましょう。
model = SSMKernelDPLR(disc="bilinear", d_model=1, d_state=4)
行列A, B
SSMの基本パラメータです。
>>> A, B = transition("legs", 4)
>>> A = torch.as_tensor(A) # (N, N)
>>> B = torch.as_tensor(B)[:, 0] # (N,)
>>> A
tensor([[-1.0000, 0.0000, 0.0000, 0.0000],
[-1.7321, -2.0000, 0.0000, 0.0000],
[-2.2361, -3.8730, -3.0000, 0.0000],
[-2.6458, -4.5826, -5.9161, -4.0000]], dtype=torch.float64)
>>> B
tensor([1.0000, 1.7321, 2.2361, 2.6458], dtype=torch.float64)
固有値分解
式(11)を再掲します。
>>> P = rank_correction("legs", 4, rank=1)
>>> P
tensor([[0.7071, 1.2247, 1.5811, 1.8708]])
>>> torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3)
tensor([[0.5000, 0.8660, 1.1180, 1.3229],
[0.8660, 1.5000, 1.9365, 2.2913],
[1.1180, 1.9365, 2.5000, 2.9580],
[1.3229, 2.2913, 2.9580, 3.5000]])
>>> AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3)
>>> AP
tensor([[-0.5000, 0.8660, 1.1180, 1.3229],
[-0.8660, -0.5000, 1.9365, 2.2913],
[-1.1180, -1.9365, -0.5000, 2.9580],
[-1.3229, -2.2913, -2.9580, -0.5000]], dtype=torch.float64)
この AP
は式(11)で言うところの正規行列
式(12)を再掲します。
>>> W_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)
>>> W_re
tensor([-0.5000], dtype=torch.float64)
>>> W_im, V = torch.linalg.eigh(AP*-1j)
>>> W_im
tensor([-4.6033, -0.5565, 0.5565, 4.6033], dtype=torch.float64)
>>> V
tensor([[-0.2870-0.0000j, 0.6462+0.0000j, -0.6462+0.0000j, 0.2870+0.0000j],
[-0.4335+0.1953j, -0.1925-0.4867j, 0.1925-0.4867j, 0.4335+0.1953j],
[-0.1866+0.5358j, -0.0829+0.4139j, 0.0829+0.4139j, 0.1866+0.5358j],
[ 0.4415+0.4181j, 0.1961-0.3030j, -0.1961-0.3030j, -0.4415+0.4181j]],
dtype=torch.complex128)
>>> W = W_re + 1j * W_im
>>> W
tensor([-0.5000-4.6033j, -0.5000-0.5565j, -0.5000+0.5565j, -0.5000+4.6033j],
dtype=torch.complex128)
>>> (V @ torch.diag_embed(W) @ V.conj().transpose(-1, -2)).real # ~= AP
tensor([[-0.5000, 0.8660, 1.1180, 1.3229],
[-0.8660, -0.5000, 1.9365, 2.2913],
[-1.1180, -1.9365, -0.5000, 2.9580],
[-1.3229, -2.2913, -2.9580, -0.5000]], dtype=torch.float64)
>>> V_inv = V.conj().transpose(-1, -2)
>>> V_inv
tensor([[-0.2870+0.0000j, -0.4335-0.1953j, -0.1866-0.5358j, 0.4415-0.4181j],
[ 0.6462-0.0000j, -0.1925+0.4867j, -0.0829-0.4139j, 0.1961+0.3030j],
[-0.6462-0.0000j, 0.1925+0.4867j, 0.0829-0.4139j, -0.1961+0.3030j],
[ 0.2870-0.0000j, 0.4335-0.1953j, 0.1866-0.5358j, -0.4415-0.4181j]],
dtype=torch.complex128)
torch.linalg.eigh
は固有値分解を行っており、正規行列
そして W
については、(V @ torch.diag_embed(W) @ V.conj().transpose(-1, -2)).real
が AP
とほぼ等しくなる事から、=AP
) なので、W
は
また、V_inv
は
行列B, P
>>> B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B
>>> P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P
>>> B
tensor([-0.2870-2.6425j, 0.6462+0.7193j, -0.6462+0.7193j, 0.2870-2.6425j],
dtype=torch.complex128)
>>> P
tensor([[-0.2030-1.8685j, 0.4570+0.5086j, -0.4570+0.5086j, 0.2030-1.8685j]],
dtype=torch.complex128)
nplr の返却
最終的に return W, P, B, V
となっており、つまり、
離散化におけるdt
>>> model = SSMKernelDPLR(disc="bilinear", d_model=1, d_state=4)
>>> model.init_dt()
tensor([[-5.1659]])
inv_dt
は
行列C
C
は以下のようになっており、random (虚数 torch.complex64
) なパラメータ。
channel の説明は以下.
学習パラメータ
>>> A, P, B, C = self.init_ssm_dplr()
>>> A
tensor([[-0.5000-4.6033j, -0.5000-0.5565j]])
>>> P
tensor([[[-0.2030-1.8685j, 0.4570+0.5086j]]])
>>> B
tensor([[-0.2870-2.0000j, 0.6462+0.7193j]])
>>> C
tensor([[[0.0148+0.6038j, 0.5207-0.9530j]]])
最終的に学習するパラメータは A, B, C, inv_dt, P
となり、
パラメータとしての登録の詳細
虚数部は以下のように変換して parameter ( NN の weight ) として登録されます。_c2r = torch.view_as_real
という関数がその役割です。
>>> C
tensor([[[0.0148+0.6038j, 0.5207-0.9530j]]])
>>> _resolve_conj(C)
tensor([[[0.0148-0.6038j, 0.5207+0.9530j]]])
>>> _c2r(_resolve_conj(C))
tensor([[[[ 0.0148, -0.6038],
[ 0.5207, 0.9530]]]])
どう計算するのか
では、forward
関数を見ていきましょう。S4では、カーネル計算を行い SSM に出てくる
関数のフロー
カーネル計算の関数の流れを示します。
頻出の係数
頻出の係数の内容について見ていきましょう。
D
は式(3)と一部関連する
E
は式(3)と一部関連する
この new_state
は式(2) の
この dA, _ = self._setup_state()
は、少し複雑な過程を経ているのですが、
以降の計算で使いまわすために self.C
を update しています。ここは
Omega (z)
>>> discrete_L = 8
>>> omega, z = self._omega(discrete_L, dtype=A.dtype, device=A.device, cache=(rate==1.0))
>>> omega
tensor([ 1.0000e+00-0.0000e+00j, 7.0711e-01-7.0711e-01j,
-4.3711e-08-1.0000e+00j, -7.0711e-01-7.0711e-01j,
-1.0000e+00+8.7423e-08j])
>>> z
tensor([0.0000e+00+0.0000e+00j, 3.4229e-08+8.2843e-01j, 5.9605e-08+2.0000e+00j,
7.3861e-07+4.8284e+00j, 2.1296e+07-3.1235e+07j])
omega
は z
は
カーネル計算
さて、いよいよカーネルを計算していきます。
この k_f
は式(22) の
そして、周波数領域から逆フーリエ変換して、それを返却しています。
y の計算
これを FFTConv.forward
で受け取り、上述の記述によって、再度、フーリエ変換で
最後に
以上で S4 についての解説を終えます。まだまだ解説が甘い箇所もあるかと思いますが、自分自身も、だいぶ理解が深まったように感じます。
理論編、実装編を通して、どちらも理解が大変でした。理論編を理解したとしても、どう実装するのかというのは難しいです。このように複素数などバンバン入り乱れるものになってくると、自分でイチから実装できる自信はありません。その意味でも、すごく丁寧なコードに落として公開してくれているのは、大変ありがたいです。
Discussion