🧠

LTCにおけるODEソルバー比較:Fused Euler法 vs 閉形式近似(CfC)

に公開

はじめに

Liquid Time-Constant Networks (LTC) は、連続時間ダイナミクスをモデル化できる強力なリカレントニューラルネットワーク(RNN)です。しかし、その中核をなす常微分方程式(ODE)の状態更新処理は順伝搬・逆伝搬の双方において計算上のボトルネックとなり得ます。

この ODE を解くために主に 2 つのソルバーが提案されています。一つは Fused Euler 法に代表される数値的ソルバー[1]、もう一つは CfC (Closed-form Continuous-time)で用いられる解析的(閉形式)ソルバー[2]です。論文では CfC の計算効率の優位性が主張されていますが、その差が定量的には示されていません。

本記事では、公式リポジトリの実装に基づき、これら二つの手法の計算コスト(FLOPs)をコードから検討し、その性能差を明らかにすることを目的とします。

本記事の前提となる LNN の概要については、以下の記事をご参照ください。

https://zenn.dev/yryromrk/articles/71f60cf8cd0efb
https://zenn.dev/yryromrk/articles/a1aa2cf3fb1bff

数値的ソルバー: Fused Euler

数値的ソルバーは、微分方程式を微少な時間ステップで離散的に近似し、逐次的に解く手法である。LTC が扱う常微分方程式は時間による変化激しい 「硬い(stiff)方程式」であるため、一般的な解法の一つである Runge=Kutta 法では不安定になりやすいとされる。
そこで、Fused Euler と呼ばれる手法が LTC の原論文では提案されている[R.Hasani(2020)]。この手法は Euler 法でも 陰解法(Explicit Euler) と 陽解法(Implicit Euler) を融合(Fuse)した手法になっており、時間ステップを細かく刻みながら各ステップの状態を計算していく手法である。この時間ステップを複数回反復することで安定性と精度を両立する。具体的には以下のようなアルゴリズムで設計されている。
FusedEulerアルゴリズム
LTC update by fused ODE Solver

  • メリット:
    ・アルゴリズムが物理的な挙動を模倣しており、直感的で理解しやすい
  • デメリット:
    ・精度や安定性を高めるには反復回数を増やす必要があり、計算コストが線形に増大する

解析的ソルバー:CfC(閉形式近似)

解析的ソルバーでは、時間ステップを刻む代わりに、微分方程式の解そのものを閉形式(Closed-form)で近似する手法である[R.Hasani(2022)]。これにより、任意の時刻の状態を一度の計算で求めることを可能にする。
実際に LTC の解は以下のような指数関数を用いた式で近似される。

x(t)=(x_0-A)e^{-[\frac{1}{\tau}+f(I(t))]t}f(-I(t))+A

CfCの概要
Neural and Synapse Dynamics

  • メリット:
    時間ステップの反復が不要になるため、原理的に計算コストが低い。
  • デメリット:
    あくまでシステム全体の挙動を近似した解になるため、元の方程式の挙動を再現できるわけではなく、解釈性に問題がある可能性がある。

具体的にアルゴリズムは以下のように設計されている。
CfC algorithm
Translate a trained LTC network into its closed-form variant

二つの手法の計算量(FLOPs)比較

上記の二つの手法について公式リポジトリに NCP モデルの実装が公開されている。今回はこの実装例と論文の内容から計算量を可能な限り厳密に計算する。ただし、あくまでも私がコードを読み解き、推定したものであるため、過度に信用には注意されたい。

ここでは公式リポジトリの実装に基づき、各ソルバーの 1 タイムステップあたりの FLOPs を導出した。シーケンス長Lのデータに対しては以下の内容がL回繰り返されることになる。

変数定義

  • N: 隠れ状態の次元数(units)
  • D_{in}: 入力特徴量の次元数(input_size)
  • D_{out}: 出力特徴量の次元数(output_size)
  • K: Fused Euler 法の反復回数(ode_unfolds)
  • B_U: CfC バックボーンの中間層ユニット数
  • B_L: CfC バックボーンの層数

Fused Euler 法

NCP で使用されるパラメータの次元数は以下である。sensory 系はいずれも感覚ニューロンのパラメータである。

  • gleak, vleak, cm: units
  • sigma, mu, w, erev: (units,units)
  • sensory_sigma, mu, w,erev: (input_size,units)
  • sensory_mask: (units,units)
  • sensory_sparsity_mask: (input_size,units)
  • input_w, b: input_size
  • output_w,b: output_size

入力のマッピング(map_inputs)

入力にアフィン変換を行い、 NCP の感覚ニューロンに入力する。

def _map_inputs(self, inputs):
   if self._input_mapping in ["affine", "linear"]: # inputs(Din) * input_w(Din) -> 要素ごとの積
      inputs = inputs * self._params["input_w"]
   if self._input_mapping == "affine": # inputs(Din) + input_b(Din) -> 要素ごとの和
      inputs = inputs + self._params["input_b"]
   return inputs

input はD_{in}のサイズであるため、アフィン変換であれば乗算と加算でそれぞれD_{in}回演算が必要になり、2D_{in}FLOPs かかる。

微分方程式の演算(ode_solver)

この部分が最も大きい部分であるためいくつか分解して説明していく。

感覚ニューロンへの入力の影響計算(ループ前処理)

微分方程式を解くループを開始する前に現在の入力が状態更新に与える影響をあらかじめ計算する。

# 1. シグモイド関数の計算
sensory_w_activation = self.make_positive_fn(...) * self._sigmoid(
inputs, self._params["sensory_mu"], self._params["sensory_sigma"]
)
# 2. スパーシティマスクと erev の適用
sensory*w_activation = sensory_w_activation * self._params["sensory_sparsity_mask"]
sensory*rev_activation = sensory_w_activation * self._params["sensory_erev"]
# 3. 合計の計算
w_numerator_sensory = torch.sum(sensory_rev_activation, dim=1)
w_denominator_sensory = torch.sum(sensory_w_activation, dim=1)
  1. シグモイド関数の計算
    入力をシグモイドによる活性化関数に通す。これは inputs と sensory×mu,sigma で演算が行われるため、この演算はD_{in}×N回の減算と乗算, torch.sigmoid 算(約 6FLOPs)が行われ、O(D_{in}×N)かかる。

  2. スパーシティマスクと erev の適用
    ここは単純に乗算するだけなのでO(D_{in}×N)

  3. 合計の計算
    形状がD_{in}×Nのテンソルを合計するためO(D_{in}×N)かかる。長さ N の一次元ベクトルを生成する。

ODE ソルバーの反復計算(ループ内処理)
for t in range(self._ode_unfolds):
   # 1. シグモイド関数の計算
   w_activation = w_param * self._sigmoid(
   v_pre, self._params["mu"], self._params["sigma"]
   )
   # 2. スパーシティマスクと Erev の適用
   w*activation = w_activation * self._params["sparsity_mask"]
   rev*activation = w_activation * self._params["erev"]
   # 3. 合計の計算
   w_numerator = torch.sum(rev_activation, dim=1) + w_numerator_sensory
   w_denominator = torch.sum(w_activation, dim=1) + w_denominator_sensory
   # 4. 状態の更新
   numerator = cm*t * v*pre + gleak * self._params["vleak"] + w_numerator
   denominator = cm_t + gleak + w_denominator
   v_pre = numerator / (denominator + self._epsilon)

概ねやることは前処理と変わらないが状態の更新のみ追加される。また sensory とは違い、それぞれのサイズがN×Nになる。したがって 1 と 2 と 3 はO(N^2)。FLOPs では約 10FLOPs×N^2かかる。

  1. 状態の更新
    numerator と denominator の双方は形状が(N)のベクトル同士の要素の演算になるため、一回につき FLOPs はおよそ 10 程度で、計算量は O(N)になる。

これらの演算が ode_unfolds の数 K だけループするため、この部分の全体の計算量はO(K×N^2)になる。

出力のマッピング

更新された隠れ状態の state から最終的な出力の計算を行う。

def _map_outputs(self, state):
   output = state
   if self.motor_size < self.state_size:
   output = output[:, 0 : self.motor_size] # スライス
   if self._output_mapping in ["affine", "linear"]: # output(Dout) * output_w(Dout) -> 要素ごとの積
   output = output * self._params["output_w"]
   if self._output_mapping == "affine": # output(Dout) + output_b(Dout) -> 要素ごとの和
   output = output + self._params["output_b"]
   return output

状態ベクトルから出力に必要な部分をスライスし、その部分に対してアフィン変換を行う。したがって、この部分の計算量はO(D_{out})になる。

以上から、Fused Euler 法を用いた順伝搬の計算量はO(D_{in}N+KN^2+D_{out})になるが、概ねO(KN^2)といえる。

CfC

次に CfC による閉形式近似の計算量を推定する。先ほどの Fused Euler とは打って変わり、数値的に反復計算で解くのではなく、特定の条件下で閉形式解を用いて次の状態を直接計算するため、ode_unfolds のような反復ループは不要である。

CfC には表現力と柔軟性を重視した実用版の default モードと元論文の閉形式解を忠実に再現した pure モードが存在する。ここでは実用的なタスクでの使用を意識し、default モードについて比較する。

前処理

バックボーンネットワークへの変換

入力と隠れ状態を結合したベクトルをB_L層の MLP に通し、より表現力の高い特徴ベクトルに変換する。

Linear(D_in + N, B_U)
for i in range(B_L-1):
   Linear(B_U,B_U)

すべての処理を合計するとB_U(2(D_{in}+N)+2)+(B_L-1)(2B_U^2+2B_U) FLOPs になる。O(B_L×B_U^2)$

主要な線形変換

バックボーンから出力されたベクトルを 4 つの異なる線形層(ff1,ff2,time_a,time_b)に入力し、隠れ状態の計算に必要な要素の生成を行う。

ff1 = self.ff1(x)
if self.mode != "pure":
   ff2 = self.ff2(x)
   t_a = self.time_a(x)
   t_b = self.time_b(x)

4 つの線形層を通すため、4(2B_U+1)×N FLOPs かかる。O(B_U×N)

時間的補間係数の計算

ff1 と ff2 を tanh を用いて活性化を行い、経過時間 ts を用いて状態をどれだけ更新するかを決定する補完係数 t_interp を計算する。

ff1 = self.tanh(ff1)
ff2 = self.tanh(ff2)
t_interp = self.sigmoid(t_a * ts + t_b)

合計で5N FLOPs だけかかる。O(N)

隠れ状態の計算

new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2

次の隠れ状態はこれまでの計算結果を使用して上の式で計算される。
この計算は4N FLOPs だけかかる。O(N)


これらを合計すると
B_U(2(D_{in}+N)+2)+(B_L-1)(2B_U^2+2B_U)+(8B_U+13)×NFLOPs になる。

少し複雑なのでバックボーンを考慮しない場合も考える。
そのときは線形層に入力するベクトルのサイズがD_{in}+Nになるから4(2(D_{in}+N)+1)×NFLOPs に変化し、合計は(8D_{in}+8N+13)NFLOPs、オーダーはO(D_{in}N+N^2)\simeq O(N^2)になる。

実際、論文[1]では Fused Euler と CfC などの ODE ソルバーの比較結果を以下の図に示している。
ODEソルバーの比較
Time Complexity of the process to compute K solver's steps

まとめ

今回の FLOPs の解析から、以下のことが明らかになった。

  1. CfC の効率性:バックボーンを使用しない場合、CfC は Fused Euler 法の K=1 の場合よりも低コストである。また、安定性を高めるためには K を 3~6 程度に増やす必要があるため、CfC の計算効率の優位性はさらに増大する。
  2. バックボーンの代償:CfC の性能はバックボーンに依存するため、B_Uの大きさにより安定性を高めることが出来る一方で、Fused Euler のコストに匹敵する可能性が存在する。

これらの比較を行うことが重要だろう。また、ODE ソルバーの検討は十分になされていないため、まだまだ改善の余地があると考える。

参考文献

[1]R. Hasani, M. Lechner, A. Amini, D. Rus, and R. Grosu, "Liquid time-constant networks," Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, no. 9, pp. 7657–7666, 2020.
[2]R. Hasani, M. Lechner, A. Amini, L. Liebenwein, A. Ray, M. Tschaikowski, G. Teschl, and D. Rus, "Closed-form continuous-time neural models," Nature Machine Intelligence, vol. 4, pp. 992–1003, 2021.

Discussion