😎

論文解説:粗視化MDのための拡散モデルと力場 (2/3)

2023/08/15に公開

注意

半分メモです。勉強のために読みながら書いています。和訳を中心に適当に補足しています

この記事ではこの論文の本質である粗視化MDのための拡散モデルと力場の抽出方法、実際のネットワーク構造に関して解説していきます。イントロに関するメモはこちらから。

解説する論文

Two for One: Diffusion Models and Force Fields for Coarse-Grained
Molecular Dynamics

Microsoft Researchらのチームによって提案された手法。2023年2月にarxivに公開された。スコアベースの生成モデルにより訓練データ中に力の情報なしで粗視化力場を学習している。また、拡散生成モデルを用いることで、平衡分布からの粗視化構造のサンプリングができるだけでなく、スコア関数自体が粗視化力場として直接利用できることを示している。(おそらく、1つのモデルで構造サンプリングとダイナミクスの計算両方ができるというのがタイトルのtwo for oneの理由と思われる。)

拡散モデルを用いることで、訓練自体がシンプルになり、また、中規模程度までのタンパク質について平衡分布やダイナミクスの再現が先行研究であるCG-NetやFlowよりも改善されたとのこと。

粗視化MDのための拡散モデル

  • デノイジング拡散確率モデル (Denoising diffusion probabilistic models, DDPM)は逆拡散過程(デノイジング過程と呼ばれる)によって近似された確率分布からサンプルを生成する。拡散過程は次のようなLステップのマルコフ連鎖として定義される。ここで\mathbf{z}_0は不明なデータ分布q(\mathbf{z}_0)からのサンプルである。
q(\mathbf{z}_{1:L} | \mathbf{z}_0)= \prod_{i=1}^{L}q(\mathbf{z}_{i} | \mathbf{z}_{i-1})
  • 学習された逆拡散過程はデノイジングステップLの逆時間マルコフ連鎖として次のように定義される。
p(\mathbf{z}_{0:L}) := p(\mathbf{z}_L)\prod_{i=1}^L p_\theta(\mathbf{z}_{i-1} | \mathbf{z}_{i})
  • 実数の乱数に対しては、拡散過程は分散が固定されたガウス分布q(\mathbf{z}_i|\mathbf{z}_{i-1})=N(\mathbf{z}_{i-1};\sqrt{1-\beta_i}\mathbf{z}_{i}, \beta_i\bm{I})を使用し、標準正規定常分布の性質を持つマルコフ連鎖として動作する。
  • 逆拡散過程の分布は次のように拡散過程と同一の関数形を持つように選ばれる。ここで\mu_{\theta}(\mathbf{z}_i,i)はパラメータ\thetaで学習可能な関数で\sigma^2_iは分散を固定していることを示す。
p(\mathbf{z}_L)=N(\bm{0},\bm{I})
p_{\theta}(\mathbf{z}_{i-1} | \mathbf{z}_i) = N(\mathbf{z}_{i-1};\mu_{\theta}(\mathbf{z}_i,i), \sigma_i^2\bm{I})
  • ガウス分布の閉形式に対する周辺化を利用し、平均\mu_\theta(\mathbf{z}_i,i)を以下のようにパラメータ化することで、ロスを最小化することにより訓練を進めることができる。ここで\alpha=1-\beta_i\={\alpha}_i=\prod_{s=1}^i\alpha_sである。K_i=\frac{\beta_i^2}{2\sigma_i^2\alpha_i(1-\={\alpha}_i)}が成り立つ場合、この式は負の変分下界となる。しかし、Hoらは実用上はK_i=1として再重みづけする方が効果的だと報告している。
\mu_\theta(\mathbf{z}_i,i)=\frac{1}{\sqrt{\alpha_i}} ( \mathbf{z}_i-\frac{\beta_i}{\sqrt{1-\={\alpha}_i}}\epsilon_{\theta}(\mathbf{z}_i,i))
\tag{3} \sum_{i=1}^{L}K_i\mathbb{E}_{q(\mathbf{z}_0)}\mathbb{E}_{N(\epsilon;\bm{0},\bm{I})}\lbrack||\epsilon-\epsilon_\theta(\sqrt{\={\alpha}_i}\mathbf{z};\sqrt{1-\={\alpha}_i}\epsilon_i,i)||^2\rbrack
  • この論文では、データはボルツマン分布q(\mathbf{z}_0)\propto^{-\frac{V(\mathbf{z})}{k_BT}}からサンプルされたデータで構成される。ノイズ予測ネットワーク\epsilon_\theta(\mathbf{z}_i,i)を通じてパラメータ化された学習済み拡散モデルが与えられると、グラフモデルp(\mathbf{z}_L)\prod_{i=1}^Lp_\theta(\mathbf{z}_{i-1}|\mathbf{z}_i)からの祖先サンプリングにより、近似された粗視化分布の独立同分布サンプルを作成することができる。

デノイジング力場:拡散モデルからの力場の抽出

  • SongらはK_i=1における式(3)のDDPMロスが次のデノイジングスコアマッチング関数の重みづけ和と等価であることを示した。
\sum_{i=1}^{L}(1-\={\alpha}_i)\mathbb{E}_{q(\mathbf{z}_0)}\mathbb{E}_{q(\mathbf{z}_i|\mathbf{z}_0)}\lbrack||s_\theta(\mathbf{z}_i,i)-\nabla_{\mathbf{z}_i}\log q(\mathbf{z}_i|\mathbf{z}_0)||^2\rbrack
  • ここで、q(\mathbf{z}_i|\mathbf{z}_0)=N(\mathbf{z}_i;\sqrt{\={\alpha}_i}\mathbf{z}_0,(1-\={\alpha}_i)\bm{I})であり、s_\theta(\mathbf{z}_i,i)はスコアモデルである。これはSongらの論文では明示的に書かれてはいないが、これら2つのロスが等価であることは、スコアモデルs_\theta(\mathbf{z}_i,i)とノイズ予測ネットワーク\epsilon_\theta(\mathbf{z}_i,i)が次のように関連付けられることになる。(詳しくはAppendix 1)
s_\theta(\mathbf{z}_i,i)=-\frac{\epsilon_\theta(\mathbf{z}_i,i)}{\sqrt{1-\={\alpha}_i}}
  • 十分に表現力があるモデルと十分なデータがある場合、最適化スコアs_\theta^*(\mathbf{z}_i,i)\nabla_{\mathbf{z}_i}\log q(\mathbf{z}_i)と一致することが期待される。ここでq(\mathbf{z}_i)=\int d\mathbf{z}_0 q(\mathbf{z}_i|\mathbf{z}_0)q(\mathbf{z}_0)は拡散プロセスのi段階目における周辺確率分布である。
  • 十分にノイズレベルが低いとき、周辺確率分布q(\mathbf{z}_i)s_\theta^*(\mathbf{z}_i,i)が不明なデータ分布のスコアを有効的に近似するようなデータ分布q(\mathbf{z}_0)と似てくる。また後者が粗視化ボルツマン分布q(\mathbf{z}_0)\propto e^{\frac{V(\mathbf{z})}{k_BT}}に等しいとき、レベルi=1における最適化スコアs_\theta^*(\mathbf{z}_i,i)は粗視化力\nabla_{z}\log q(\mathbf{z})=\frac{-\nabla_zV(\mathbf{z})}{k_BT}=\frac{\mathbf{F}_\mathbf{z}}{k_BT}と近似的に一致する。
  • 最終的にs_\theta^*(\mathbf{z}_i,i)とノイズ予測ネットワーク\epsilon_\theta(\mathbf{z}_i,i)を用いて、式(3)のロスで学習されたデノイジング拡散モデルから粗視化された力を取り出すことができる。
\mathbf{F}_\mathbf{z}^{DFF}=-\frac{k_BT}{\sqrt{1-\={\alpha}_i}}\epsilon_\theta^*(\mathbf{z},i)
  • この近似された粗視化力場をdenoising force field (DFF)と呼ぶ。粗視化力場が良い近似になるのはi=1のときであるはずが、実用的にはiをハイパーパラメータとして扱い、クロスバリデーションによりベストなiを選ぶこともできる。
  • 力場とデノイジング拡散モデルとの接続は以前にも研究されている。Zaidiら(2022)は、局所的にボルツマン分布を最大化する(またはエネルギーを最小化する)分子構造をノイズ除去することによって、ノイズ除去拡散セットアップで特性予測を行うグラフニューラルネットワークを事前に訓練した。
  • データ分布をこれらの極小値周辺のガウス分布の混合として近似することで、スコアマッチング目的関数が混合ガウス分布の混合近似の力場を学習することと等価であることを示している。
  • 同様に、Xieらは小さなノイズに対するデノイジングネットワークの学習スコアはエネルギー最小な構造の周辺での調和振動子の力場になることを示している。重要な点はこれらの接続は安定構造周辺の近似された力場を与えるだけであり、用途が限られることである。
  • この研究では、平衡ボルツマン分布からのサンプル上のデノイジング拡散モデルの訓練が、全体の平衡分布の教師なし学習として力場が作成できることを示す。これは安定で信頼性の高い粗視化力場を用いた粗視化MDシミュレーションを安定して実行するために重要となる。

デノイジング力場を用いた分子動力学

  • 式5のDFFから次のようなランジュバン方程式の時間発展により粗視化MDを行うことができる。ここで、-\nabla_{\mathbf{z}}V(\mathbf{z})=\mathbf{F}_\mathbf{z}^{DFF}であり、Mは粗視化ビーズの質量、\gammaは摩擦係数、\mathbf{w}\mathbb{E}_{p(\mathbf{x})} \lbrack \mathbf{w}(t) \cdot \mathbf{w}(t') \rbrack = \delta (t-t')を満たすホワイトノイズである。
M\frac{d^2\mathbf{z}}{dt^2}=-\nabla_{\mathbf{z}}V(\mathbf{z})-\gamma M\frac{d\mathbf{z}}{dt}+\sqrt{2M\gamma k_BT\mathbf{w(t)}}
  • 我々の実験では\gammaTはデータを生成するために行ったオリジナルの全原子シミュレーションと同じ値を用いた。ゆえに、学習されたネットワーク\epsilon_\theta与えられたとき、ハイパーパラメータはレベルiだけが残ることになる。
  • ランジュバン方程式のよく知られた極限は、質量が無視でき、摩擦係数が大きい場合ときであり(つまり\eta = \gamma M)、これはブラウン力学または過減衰ランジュバン力学と呼ばれる。興味深いことに低いレベル(例えばi=1)で拡散とノイズを除去を繰り返すとタイムステップ\Delta t\Delta \frac{k_B T}{M \gamma} = 1 - \={\alpha}_{1} = \beta_{1}であるブラウン動力学が近似されることを示す。

デノイジング力場のアーキテクチャー

  • ニューラルネットワーク\epsilon_\thetaの選択は系の物理的な対称性に大きな影響を受ける。例えば、粗視化力場は保存されている必要があり粗視化ポテンシャルV_\theta(\mathbf{z})の負の勾配と等しい必要がある。そのため、\epsilon_\theta(\mathbf{z}_i,i)をスカラー出力を持つエネルギーのニューラルネットワークの勾配(\epsilon_\theta(\mathbf{z}_i,i)=\nabla_{{\mathbf{z}}_i}\mathrm{nn}_\theta (\mathbf{z}_i,i)))としてパラメータ化する。
  • 画像生成における先行研究では、スコアネットワークに拘束がなかったり、スコアがエネルギー関数の勾配としてパラメータ化されていたりしてもサンプルの品質に経験的な違いはなかった。しかし、拡散モデルにおける保存されたスコアを用いることは、デノイジングされた力場を使って安定的な粗視化MDを行うためには極めて重要であった。
  • さらに、力場は並進不変であり、回転同変である必要がある。我々はモデルの並進不変性を実現するためにネットワークの入力として2つの粗視化ビーズの座標をペアごとの差ベクトル\mathbf{z}_{(i)}-\mathbf{z}_{(j)}のみを考慮している。
  • 力は回転に対して等変である必要があるが、Trippeらによって報告されたような鏡像タンパク質の生成を避けるため、明確の鏡像に対する同変性は明確には必要ない。言い換えれば、O(3)の代わりにSO(3)同変が必要であるということである。高コストな球面調和近似や角度表現を使わずに、O(3)同変を近似するシンプルな戦略はデータ拡張である。
  • Gruverらの先行研究ではtransformerで学習した同変性は実際の同変ネットワークと競合できることを示している。Appendix B3に示したようにデノイジング力場はバリデーションセットに対して回転による相対二乗誤差が10^{-6}であることを示している。

アーキテクチャーの詳細

  • 本研究ではネットワークモデル\mathrm{nn}_\thetaとしてGraph Transformerを用いた。まずGraph Transformerで呼び出される関数を次のように名づける。
\mathrm{nodes}_{\mathrm{out}} = \mathrm{GT}\lbrack \mathrm{nodes}_{\mathrm{in}},\mathrm{edges}_{\mathrm{in}} \rbrack
  • ここで、\mathrm{nodes}_{\mathrm{in}} \in \mathbb{R}^{n \times d_n}\mathrm{edges}_{\mathrm{in}} \in \mathbb{R}^{n\times n \times d_e}である。またnはノードの数を表し、d_nはノード当たりの次元数を表し、d_eはエッジ当たりの次元数を表す。\mathrm{GT}はnode embedding \mathbf{h} \in \mathbb{R}^{n \times (\cdot)}とノイズレベルiをノードの特徴として受け取る。ベクトル差\mathbf{z}_j-\mathbf{k}のペアは次のようなエッジ特徴量を作る。
\mathrm{nodes}_{\mathrm{in}}\lbrack j,:\rbrack = \mathrm{concat}\lbrack \mathbf{h} \lbrack j,: \rbrack, i \rbrack \rbrack
\mathrm{edges}_{\mathrm{in}}\lbrack j,k,:\rbrack = \mathbf{z}\lbrack j, : \rbrack - \mathbf{z} \lbrack k,: \rbrack
  • ここで、ネットワーク\mathrm{nn}_{\theta^{'}} \rightarrow \mathbb{R}^{1}としてGraph Transformerを定義し、スカラーを出力する。
\mathrm{nn}_{\theta^{'}} : \lbrace \mathrm{nodes}_{\mathrm{in}}, \mathrm{edges}_{\mathrm{in}} \rbrace \rightarrow \mathrm{GT} \rightarrow \lbrace \mathrm{nodes}_{\mathrm{out}} \rbrace \rightarrow \mathrm{nn.Linear}(d_n, 1) \rightarrow \mathrm{sum}(\cdot) \rightarrow \lbrace \mathrm{output} \rbrace
  • 最終的に\epsilon_\thetaを定義するためにネットワークの学習可能なパラメータとして\theta=\lbrace \theta^{'}, \mathbb{h} \rbraceを導入し、\mathbf{z}に対する\mathrm{nn}_{\theta}の勾配を次のように計算する。
\epsilon_{\theta}(\mathbf{z},i)=\nabla_{\mathbf{z}}\mathrm{nn}_{\theta^{'}}(\mathbf{h},\mathbf{z},i)
  • なおニューラルネットワークベースの粗視化力場に関する先行研究では、スカラーエネルギーニューラルネットワークに事前エネルギー項を追加して、トレーニングデータセットから離れたCG力場の挙動を改善することが多いが、デノイジング力場では安定的な粗視化MDを行うためにこういった操作は不要である。

Discussion