LTCにおけるODEソルバー比較:Fused Euler法 vs 閉形式近似(CfC)
はじめに
Liquid Time-Constant Networks (LTC) は、連続時間ダイナミクスをモデル化できる強力なリカレントニューラルネットワーク(RNN)です。しかし、その中核をなす常微分方程式(ODE)の状態更新処理は順伝搬・逆伝搬の双方において計算上のボトルネックとなり得ます。
この ODE を解くために主に 2 つのソルバーが提案されています。一つは Fused Euler 法に代表される数値的ソルバー[1]、もう一つは CfC (Closed-form Continuous-time)で用いられる解析的(閉形式)ソルバー[2]です。論文では CfC の計算効率の優位性が主張されていますが、その差が定量的には示されていません。
本記事では、公式リポジトリの実装に基づき、これら二つの手法の計算コスト(FLOPs)をコードから検討し、その性能差を明らかにすることを目的とします。
本記事の前提となる LNN の概要については、以下の記事をご参照ください。
数値的ソルバー: Fused Euler
数値的ソルバーは、微分方程式を微少な時間ステップで離散的に近似し、逐次的に解く手法である。LTC が扱う常微分方程式は時間による変化激しい 「硬い(stiff)方程式」であるため、一般的な解法の一つである Runge=Kutta 法では不安定になりやすいとされる。
そこで、Fused Euler と呼ばれる手法が LTC の原論文では提案されている[R.Hasani(2020)]。この手法は Euler 法でも 陰解法(Explicit Euler) と 陽解法(Implicit Euler) を融合(Fuse)した手法になっており、時間ステップを細かく刻みながら各ステップの状態を計算していく手法である。この時間ステップを複数回反復することで安定性と精度を両立する。具体的には以下のようなアルゴリズムで設計されている。
LTC update by fused ODE Solver
- メリット:
・アルゴリズムが物理的な挙動を模倣しており、直感的で理解しやすい - デメリット:
・精度や安定性を高めるには反復回数を増やす必要があり、計算コストが線形に増大する
解析的ソルバー:CfC(閉形式近似)
解析的ソルバーでは、時間ステップを刻む代わりに、微分方程式の解そのものを閉形式(Closed-form)で近似する手法である[R.Hasani(2022)]。これにより、任意の時刻の状態を一度の計算で求めることを可能にする。
実際に LTC の解は以下のような指数関数を用いた式で近似される。
Neural and Synapse Dynamics
- メリット:
時間ステップの反復が不要になるため、原理的に計算コストが低い。 - デメリット:
あくまでシステム全体の挙動を近似した解になるため、元の方程式の挙動を再現できるわけではなく、解釈性に問題がある可能性がある。
具体的にアルゴリズムは以下のように設計されている。
Translate a trained LTC network into its closed-form variant
二つの手法の計算量(FLOPs)比較
上記の二つの手法について公式リポジトリに NCP モデルの実装が公開されている。今回はこの実装例と論文の内容から計算量を可能な限り厳密に計算する。ただし、あくまでも私がコードを読み解き、推定したものであるため、過度に信用には注意されたい。
ここでは公式リポジトリの実装に基づき、各ソルバーの 1 タイムステップあたりの FLOPs を導出した。シーケンス長
変数定義
-
: 隠れ状態の次元数(units)N -
: 入力特徴量の次元数(input_size)D_{in} -
: 出力特徴量の次元数(output_size)D_{out} -
: Fused Euler 法の反復回数(ode_unfolds)K -
: CfC バックボーンの中間層ユニット数B_U -
: CfC バックボーンの層数B_L
Fused Euler 法
NCP で使用されるパラメータの次元数は以下である。sensory 系はいずれも感覚ニューロンのパラメータである。
- gleak, vleak, cm: units
- sigma, mu, w, erev: (units,units)
- sensory_sigma, mu, w,erev: (input_size,units)
- sensory_mask: (units,units)
- sensory_sparsity_mask: (input_size,units)
- input_w, b: input_size
- output_w,b: output_size
入力のマッピング(map_inputs)
入力にアフィン変換を行い、 NCP の感覚ニューロンに入力する。
def _map_inputs(self, inputs):
if self._input_mapping in ["affine", "linear"]: # inputs(Din) * input_w(Din) -> 要素ごとの積
inputs = inputs * self._params["input_w"]
if self._input_mapping == "affine": # inputs(Din) + input_b(Din) -> 要素ごとの和
inputs = inputs + self._params["input_b"]
return inputs
input は
微分方程式の演算(ode_solver)
この部分が最も大きい部分であるためいくつか分解して説明していく。
感覚ニューロンへの入力の影響計算(ループ前処理)
微分方程式を解くループを開始する前に現在の入力が状態更新に与える影響をあらかじめ計算する。
# 1. シグモイド関数の計算
sensory_w_activation = self.make_positive_fn(...) * self._sigmoid(
inputs, self._params["sensory_mu"], self._params["sensory_sigma"]
)
# 2. スパーシティマスクと erev の適用
sensory*w_activation = sensory_w_activation * self._params["sensory_sparsity_mask"]
sensory*rev_activation = sensory_w_activation * self._params["sensory_erev"]
# 3. 合計の計算
w_numerator_sensory = torch.sum(sensory_rev_activation, dim=1)
w_denominator_sensory = torch.sum(sensory_w_activation, dim=1)
-
シグモイド関数の計算
入力をシグモイドによる活性化関数に通す。これは inputs と sensory×mu,sigma で演算が行われるため、この演算は 回の減算と乗算, torch.sigmoid 算(約 6FLOPs)が行われ、D_{in}×N かかる。O(D_{in}×N) -
スパーシティマスクと erev の適用
ここは単純に乗算するだけなので 。O(D_{in}×N) -
合計の計算
形状が のテンソルを合計するためD_{in}×N 。長さ N の一次元ベクトルを生成する。O(D_{in}×N)かかる
ODE ソルバーの反復計算(ループ内処理)
for t in range(self._ode_unfolds):
# 1. シグモイド関数の計算
w_activation = w_param * self._sigmoid(
v_pre, self._params["mu"], self._params["sigma"]
)
# 2. スパーシティマスクと Erev の適用
w*activation = w_activation * self._params["sparsity_mask"]
rev*activation = w_activation * self._params["erev"]
# 3. 合計の計算
w_numerator = torch.sum(rev_activation, dim=1) + w_numerator_sensory
w_denominator = torch.sum(w_activation, dim=1) + w_denominator_sensory
# 4. 状態の更新
numerator = cm*t * v*pre + gleak * self._params["vleak"] + w_numerator
denominator = cm_t + gleak + w_denominator
v_pre = numerator / (denominator + self._epsilon)
概ねやることは前処理と変わらないが状態の更新のみ追加される。また sensory とは違い、それぞれのサイズが
- 状態の更新
numerator と denominator の双方は形状が(N)のベクトル同士の要素の演算になるため、一回につき FLOPs はおよそ 10 程度で、計算量は O(N)になる。
これらの演算が ode_unfolds の数 K だけループするため、この部分の全体の計算量は
出力のマッピング
更新された隠れ状態の state から最終的な出力の計算を行う。
def _map_outputs(self, state):
output = state
if self.motor_size < self.state_size:
output = output[:, 0 : self.motor_size] # スライス
if self._output_mapping in ["affine", "linear"]: # output(Dout) * output_w(Dout) -> 要素ごとの積
output = output * self._params["output_w"]
if self._output_mapping == "affine": # output(Dout) + output_b(Dout) -> 要素ごとの和
output = output + self._params["output_b"]
return output
状態ベクトルから出力に必要な部分をスライスし、その部分に対してアフィン変換を行う。したがって、この部分の計算量は
以上から、Fused Euler 法を用いた順伝搬の計算量は
CfC
次に CfC による閉形式近似の計算量を推定する。先ほどの Fused Euler とは打って変わり、数値的に反復計算で解くのではなく、特定の条件下で閉形式解を用いて次の状態を直接計算するため、ode_unfolds のような反復ループは不要である。
CfC には表現力と柔軟性を重視した実用版の default モードと元論文の閉形式解を忠実に再現した pure モードが存在する。ここでは実用的なタスクでの使用を意識し、default モードについて比較する。
前処理
バックボーンネットワークへの変換
入力と隠れ状態を結合したベクトルを
Linear(D_in + N, B_U)
for i in range(B_L-1):
Linear(B_U,B_U)
すべての処理を合計すると
主要な線形変換
バックボーンから出力されたベクトルを 4 つの異なる線形層(ff1,ff2,time_a,time_b)に入力し、隠れ状態の計算に必要な要素の生成を行う。
ff1 = self.ff1(x)
if self.mode != "pure":
ff2 = self.ff2(x)
t_a = self.time_a(x)
t_b = self.time_b(x)
4 つの線形層を通すため、
時間的補間係数の計算
ff1 と ff2 を tanh を用いて活性化を行い、経過時間 ts を用いて状態をどれだけ更新するかを決定する補完係数 t_interp を計算する。
ff1 = self.tanh(ff1)
ff2 = self.tanh(ff2)
t_interp = self.sigmoid(t_a * ts + t_b)
合計で
隠れ状態の計算
new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2
次の隠れ状態はこれまでの計算結果を使用して上の式で計算される。
この計算は
これらを合計すると
少し複雑なのでバックボーンを考慮しない場合も考える。
そのときは線形層に入力するベクトルのサイズが
実際、論文[1]では Fused Euler と CfC などの ODE ソルバーの比較結果を以下の図に示している。
Time Complexity of the process to compute K solver's steps
まとめ
今回の FLOPs の解析から、以下のことが明らかになった。
- CfC の効率性:バックボーンを使用しない場合、CfC は Fused Euler 法の K=1 の場合よりも低コストである。また、安定性を高めるためには K を 3~6 程度に増やす必要があるため、CfC の計算効率の優位性はさらに増大する。
- バックボーンの代償:CfC の性能はバックボーンに依存するため、
の大きさにより安定性を高めることが出来る一方で、Fused Euler のコストに匹敵する可能性が存在する。B_U
これらの比較を行うことが重要だろう。また、ODE ソルバーの検討は十分になされていないため、まだまだ改善の余地があると考える。
参考文献
[1]R. Hasani, M. Lechner, A. Amini, D. Rus, and R. Grosu, "Liquid time-constant networks," Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, no. 9, pp. 7657–7666, 2020.
[2]R. Hasani, M. Lechner, A. Amini, L. Liebenwein, A. Ray, M. Tschaikowski, G. Teschl, and D. Rus, "Closed-form continuous-time neural models," Nature Machine Intelligence, vol. 4, pp. 992–1003, 2021.
Discussion