👋

Rectified Flow②

2025/01/05に公開

論文の続きを読みます.

https://openreview.net/forum?id=XVjTT1nw5z

今回もベースとなっている1つ目の論文で, 2.2の途中からです. 若干式変形がわからず, 論文の式変形をそのまま載せています.

前回
https://zenn.dev/fmuuly/articles/37cc3a2f17138e

書籍情報

参考文献も同様です

Xingchao Liu, Chengyue Gong, and Qiang Liu. Flow straight and fast: Learning to generate and transfer data with rectified flow. The Eleventh International Conference on Learning Representations, 2023

関連リンク

Main Results and Properties

前回は, rectified flowの概要を軽くさらった後に性質を3つ確認しました.

  1. \forall t\in[0, 1]\mathrm{Law}(X_t)=\mathrm{Law}(Z_t)
  2. 任意のconvext cost関数 c:\mathbb{R}^d\rightarrow\mathbb{R} に対して \mathbb{E}[c(Z_1-Z_0)]\leq\mathbb{E}[c(X_1-X_0)]
  3. \displaystyle\sum_{k=0}^{K}S(\boldsymbol{Z}^{k+1})+V((Z_0, Z_1^k))\leq\mathbb{E}[\|X_1-X_0\|^2]

Distillation

k-th rectified flow \boldsymbol{Z}^k を得た後, (Z_0^k, Z_1^k) の関係をNN \hat{T} に蒸留することで, flowをシミュレーションせずに直接 Z_0^k から Z_1^k を予測する手法として使えます. これにより推論速度が向上します. flowがほぼ直線的 (すなわち1stepの更新で十分に近似できる) な状況で, かつ蒸留が効率的にできる場合を考えます. 特に, 蒸留モデルとして \hat{T}=Z_0+v(Z_0, 0) とした場合, \boldsymbol{Z}^k の蒸留では

\mathbb{E}\left[\|(Z^k_1-Z^k_0)-v(Z^k_0,0)\|^2\right]

を損失関数として用います. これはOverviewで示した v を得るための回帰問題

\min_{v}\int_0^1\mathbb{E}\left[\|(X_1-X_0)-v(X_t, t)\|^2\right]\mathrm{d}t

t=0 の状況と一致します.

ここまでの説明ではdistillationとreflow (reflection)の違いがあまりよく分かりません. 著者らはカップリング (Z_0^k, Z_1^k) を忠実に近似するのがdistillation, reflectionはより低い輸送コストとより直線的なflowを持つ別のカップリング (Z_0^{k+1}, Z_1^{k+1}) を生成すること, としています. この違いから, 蒸留は最終段階で, 一段階の推論を高速化するためにモデルを微調整したい場合のみに適用すべきとしています.

On the velocity field v^X

X_0X_1=x_1 の条件下で条件付き密度関数 \rho(x_0\mid x_1) を与えるとき, 最適な速度場 v^X(z, t)=\mathbb{E}[X_1-X_0\mid X_t=z]

v^X(z,t)=\mathbb{E}\left[\dfrac{X_1-z}{1-t}\eta_t(X_1,z)\right]

で表されます. ここで,

\eta_t(X_1,z)=\dfrac{\rho\left(\dfrac{z-tX_1}{1-t}\middle|X_1\right)}{\mathbb{E}\left[\rho\left(\dfrac{z-tX_1}{1-t}\middle|X_1\right)\right]}

で, 期待値 \mathbb{E}[\cdot]X_1\sim\pi_1 に関して取られます.

これは, X_0=\dfrac{z-tX_1}{1-t},\ X_1-X_0=\dfrac{X_1-z}{1-t} とすると, 条件 X_t=z のとき, と見ることができます.

したがって \rho が正かつあらゆるところで連続な場合, v^X はwell definedかつ\mathbb{R}^d\times[0,1) で連続です. さらに, \log\eta_tz に関して連続的に微分可能であるとき

\nabla_z v^X(z, t)=\dfrac{1}{1-t}\mathbb{E}\left[((X_1-z)\nabla_z\log \eta_t(X_1, z)-1)\eta_t(X_1, z)\right].

です. もし v^X が任意の a<1 に対して [0, a] 上で一様にLipschitz連続であるならば, dZ_t=v^X(Z_t,t)\mathrm{d}t は一意の解を持つことが保証されます.

もし X_0\mid X_1=x_1 が条件付き密度関数を与えない場合は v^X(z,t) が未定義または不連続となる可能性があります. この場合はODE dZ_t=v^X(Z_t,t)\mathrm{d}t が適切に振る舞わないことも考えられます. 簡単な解決策として, X_0(X_0, X_1) と独立なガウシアンノイズ \xi\sim\mathcal{N}(0, \sigma^2I) を加えて平滑化された変数 \tilde{X}_0=X_0+\xi を作り, これを使って X_1 とのrectified flowを構成します.

Smooth function approximation

先ほどの速度場の式ですが, これに従うと \rho(\cdot\mid x_1) が存在して既知であることに加え, \pi_1 が有限個の経験測度である場合 (すなわちその期待値が正確に算出可能である場合) 正確に v^X(z, t) を計算できます. この場合, rectified flowを順方向に実行すると, \pi_1 のデータを正確に再現できますが, これは過学習状態であるため, 実用的ではないです. したがって, v^X をなめらかな関数近似器で適合させることが必要かつ有益です.

DNNは大規模問題において最良ですが, 低次元の場合は以下のようなシンプルなNadaraya–Watson styleのnon-parametricなものを用いると \rho が未知でも正確なrectified flowに対して精度のいい近似を得ることができます.

v^{X, h}(z, t)=\mathbb{E}\left[\dfrac{X_1-z}{1-t}\omega_h(X_t, z)\right]

ここで, \omega_h(X_t, z)=\dfrac{\kappa_h(X_t, z)}{\mathbb{E}[\kappa_h(X_t, z)]},\ \kappa_h(x, z) はsmoothing kernelで h>0 のbandwith parameterを持ち, zx の類似性を測定します. Gaussian BFR kernel \kappa_h(z, t)=\exp(-\|x-z\|^2/2h^2) として, h\to0^+ を考えると v^{X, h}(z, t)v^X(z, t)=\mathbb{E}\left[\dfrac{X_1-z}{1-t}\mid X_t=z\right] に収束します.

A Nonlinear Extension

rectified flowの非線形拡張を考えます. 今まで X_tX_0X_1 の線形補間として用いていましたが, この部分を X_0X_1 を結ぶ任意の時間微分可能な曲線に置き換えて拡張をします.

この場合, 前回見た定理の一部は成り立ちません. 例えば

\forall t\in[0, 1]\mathrm{Law}(X_t)=\mathrm{Law}(Z_t)

は成り立ちます (X_t の求め方に依存せず成り立つ)が, X_tX_0X_1 の線形補間であることを用いる

任意のconvext cost関数 c:\mathbb{R}^d\rightarrow\mathbb{R} に対して \mathbb{E}[c(Z_1-Z_0)]\leq\mathbb{E}[c(X_1-X_0)]

などは成り立たないことに注意が必要です. 具体的には3つ挙げた性質うち一つ目のみが成り立ちます. では, なぜ拡張をするかというと, rectified flowを非線形 (も含む形)に拡張すると, rectified flow, probability flow, DDIMを同じ枠組みの中で議論できるからです. それによってそれぞれの手法の特徴などを説明することが可能になります.

まず, \boldsymbol{X}=\{X_t:t\in[0,1]\} を任意の時間微分可能で X_0X_1 を結ぶrandom processとします. \dot{X}_tX_tt で微分したものとします. \boldsymbol{X} に誘発される (nonlinear) rectified flowは以下のように定義されます.

\mathrm{d}Z_t=v^{\boldsymbol{X}}(Z_t, t)\mathrm{d}t,\quad Z_0=X_0,\quad v^{\boldsymbol{X}}(Z_t, t)=\mathbb{E}\left[\dot{X}_t\mid X_t=t\right]

v^{\boldsymbol{X}} は以下を解くことで得られます.

\min_{v}\int_{0}^{1}\mathbb{E}\left[w_t\|v(X_t-t)-\dot{X}_t\|^2\right]\,\mathrm{d}t

ここで, w_t: (0, 1)\rightarrow(0, +\infty) は重みで, w_t=1 がデフォルトです.

簡単な補間プロセスの1つは X_t=\alpha_t X_1+\beta_tX_0 です. \alpha_t, \beta_t は微分可能で \alpha_1=\beta_0=1,\ \alpha_0=\beta_1=0 を満たすとします. この初期条件は X_0, X_1 をスタートとゴールにするためのものです. この場合では, \dot{X}_t=\dot{\alpha}_tX_1+\dot{\beta}_tX_0 で, カーブの形状は \alpha_t, \beta_t の関係性によって決定されます. ちなみに, \alpha_t+\beta_t=1 のときと, 特別な場合のみ直線軌道になります. ただし, \alpha_t+\beta_t=1 を満たしても速度一定は保証されないです.

Probability Flow ODEs and DDIM

今回のメインの部分になります. Probability Flow ODEs (PF-ODEs)とDDIMは球面ガウス分布[1] \pi_0 から始まる \pi_1 のODEベースの生成手法です. これは, DDIMで学習されたSDEと等価なODEに変換して導出されます. 詳細は以下の論文を参照ください.

https://openreview.net/forum?id=PxTIG12RRHS

この論文では, 3種類のSDEs (VE SDE, VP SDE, sub-VP SDE)から3種類のPF-ODEs (VE ODE, VP ODE, sub-VP ODE)を導出しています. 特に, VP ODEはDDIMの連続時間極限と等価です. 簡単にPF-ODEsとDDIMについてまとめます.

denoising diffusion modelsは標準ブラウン運動 W_t によって駆動されるSDEモデルを構築して学習を行います. 以下では, rectified flowの論文に沿ったnotationを使うのであまり見慣れないかもしれないです.

\mathrm{d}U_t=b(U_t,t)\mathrm{d}t+\sigma_t\mathrm{d}W_t,\quad U_0\sim\pi_0

\sigma_t:[0, 1]\rightarrow[0,+\infty] は多くの場合固定されます. b はNNで初期分布 \pi_0 は球面ガウス分布ですがアルゴリズム依存です. アイデアとしてはまず, 拡散過程 (forward process)でデータを近似的なガウス分布に壊し, reverse processでforward processの時間反転として過程を推定します. 証明はここでの本質ではないので省略しますが, 上記論文におけるVE, VP, sub-VP SDEsのlossは以下のようにまとめられます.

\min_{v}\int_{0}^{1}\mathbb{E}[w_t\|v(V_t, t)-Y_t\|_2^2]\,\mathrm{d}t,\quad V_t=\alpha_tX_1+\beta_t\xi_t,\quad Y_t=-\eta_tV_t-\dfrac{\sigma_t^2}{\beta_t}\xi_t

ここで, \xi_t はforward processで \xi_t\sim\mathcal{N}(0, I) を満たします. \eta_t, \sigma_t はハイパーパラメータの列で, \alpha_t, \beta_t

\alpha_t=\exp\left(\int_{t}^{1}\eta_s\,\mathrm{d}s\right),\quad\beta_t^2=\int_{t}^{1}\exp\left(2\int_{t}^{s}\eta_r\,\mathrm{d}r\right)\sigma_s^2\,\mathrm{d}s

です. VE SDEでは, \eta_t=0 とします. ここから \alpha_t=1 も従います. sub-VP SDEでは \eta_ss の線形関数として設定します. これにより a=19.9, b=0.1 として

\alpha_t=\exp\left(-\dfrac{1}{4}a(1-t)^2-\dfrac{1}{2}b(1-t)\right),\quad \beta_t=1-\alpha_t^2

となります. VP SDEでは \eta_t=-\dfrac{1}{2}\sigma_t^2 として \beta_t=\sqrt{1-\alpha_t^2} とします. DDPMでは b(x, t)=-\eta_tx-\dfrac{\sigma_t^2}{\beta_t}\varepsilon(x, t) としています. \varepsilon はNNとして推定され, (V_t, t) から \xi_t を予測します.

理論的には, 先ほどの式

\mathrm{d}U_t=b(U_t,t)\mathrm{d}t+\sigma_t\mathrm{d}W_t,\quad U_0\sim\pi_0

のSDEにおいて, b

\min_{v}\int_{0}^{1}\mathbb{E}[w_t\|v(V_t, t)-Y_t\|_2^2]\,\mathrm{d}t,\quad V_t=\alpha_tX_1+\beta_t\xi_t,\quad Y_t=-\eta_tV_t-\dfrac{\sigma_t^2}{\beta_t}\xi_t

を解いているとき, 初期値 U_0 = \alpha_0 X_1 + \beta_0 \xi_0 のとき \mathrm{Law}=(U_1)=\mathrm{Law}(X_1)=\pi_1 が得られます. また, \alpha_0X_1\ll\beta_0\xi_0 のときは U_0\approx\beta_0\xi_0 と近似できます.

Fokker-Planck方程式の性質を用いることで, DDIMなどの文献では上の目的関数 \min_{v}\int_{0}^{1}\mathbb{E}[w_t\|v(V_t, t)-Y_t\|^2]\,\mathrm{d}t に基づいて訓練された b を用いる場合には \mathrm{d}U_t=b(U_t,t)\mathrm{d}t+\sigma_t\mathrm{d}W_t のSDEは同じ周辺分布を持つODEに変換可能であることが観察されています. この辺りの内容は

https://www.youtube.com/watch?v=qo-pR-kgKbc

でわかりやすく説明されています. さて, この内容を数式で表すと

\mathrm{d}Z_t=\tilde{b}(Z_t,t)\mathrm{d}t,\quad \tilde{b}(\tilde{z},t)=\dfrac{1}{2}(b(z, t)-\eta_tz),\quad Z_0=U_0=\alpha_0X_1+\beta_0\xi_0

となります. 同様に, \tilde{b}

\min_{v}\int_{0}^{1}\mathbb{E}[w_t\|v(V_t, t)-\tilde{Y}_t\|_2^2]\,\mathrm{d}t,\quad V_t=\alpha_tX_1+\beta_t\xi_t,\quad \tilde{Y}_t=-\eta_tV_t-\dfrac{\sigma_t^2}{2\beta_t}\xi_t

の解として見なすことができます. これは \mathrm{d}U_t=b(U_t,t)\mathrm{d}t+\sigma_t\mathrm{d}W_t とは \dfrac{1}{2} の違いを除いて同一です. これは初期値が Z_0=U_0=\alpha_0X_1+\beta_0\xi_0 のときのみに成り立つ同値性です.

ここまで長々とDDIMやSDEsの話をしてきましたが, ようやく本題に入れます. 以降では X_t=\alpha_tX_1+\beta_t\xi を用いてnonlinear rectified flowのフレームワークにPB-ODEsやDDIMが含まれることを示します. 主に, \eta_t, \sigma_t を消去して \tilde{Y}_t\dot{X}_t と等価であることを示します.

論文では命題3.11とされていて, 以下の内容です.

\alpha_t=\exp\left(\displaystyle\int_{t}^{1}\eta_s\,\mathrm{d}s\right),\quad\beta_t^2=\displaystyle\int_{t}^{1}\exp\left(2\displaystyle\int_{t}^{s}\eta_r\,\mathrm{d}r\right)\sigma_s^2\,\mathrm{d}s を仮定する. X_t=\alpha_tX_1+\beta_t\xi の下では \displaystyle\min_{v}\displaystyle\int_0^1\mathbb{E}[w_t\|v(X_t, t)-\dot{X}_t\|^2]\,\mathrm{d}t\displaystyle\min_{v}\int_{0}^{1}\mathbb{E}[w_t\|v(V_t, t)-\tilde{Y}_t\|_2^2]\,\mathrm{d}t は同値.

先ほど述べたように, \tilde{Y}_t\dot{X}_t が等価であることを示ればOKです.

まず, \xi_t の相関構造が結果に影響を与えることはないので, 任意の時刻 t に対して \xi_t=\xi とできます. これにより, V_t=X_t=\alpha_tX_1+\beta_t\xi です. これにより, 命題を示すには \dot{X}_t=\tilde{Y}_t を示せばいいので, 変形します.

\begin{align*} \tilde{Y}_t&=-\eta_tV_t-\dfrac{\sigma_t^2}{2\beta_t}\xi_t=-\eta_tV_t-\dfrac{\sigma_t^2}{2\beta_t}\xi \quad (\because\xi_t=\xi) \\ &=-\eta_t(\alpha_tX_1+\beta_t\xi)-\dfrac{\sigma_t^2}{2\beta_t}\xi \quad (\because V_t=X_t=\alpha_tX_1+\beta_t\xi) \\ &=-\eta_t\alpha_tX_1+\left(-\eta_t\beta_t+\dfrac{\sigma_t^2}{2\beta_t}\right)\xi \\ &=\dot{\alpha}_tX_1+\dot{\beta}_t\xi \\ &=\dot{X}_t \end{align*}

下から2つ目の等号について, 第1項は \eta_t=-\dfrac{\dot{\alpha}_t}{\alpha_t}, 第2項は \sigma_t^2=2\beta_t^2\left(\dfrac{\dot{\alpha}_t}{\alpha_t}-\dfrac{\dot{\beta_t}}{\beta_t}\right) を用いています.

第1項について

\alpha_t=\exp\left(\displaystyle\int_{t}^{1}\eta_s\,\mathrm{d}s\right) なので両辺の対数を取り t で微分すると

\begin{align*} &\alpha_t=\exp\left(\int_{t}^{1}\eta_s\,\mathrm{d}s\right) \\ &\log\alpha_t=\int_t^1\eta_s \mathrm{d}s \\ &\frac{\dot{\alpha_t}}{\alpha_t}=-\eta_t\qquad \dot{\alpha_t}=-\eta_t\alpha_t \end{align*}
第2項について

\displaystyle\int\eta_rdr=F(r)+C とすると, \exp の中身は 2F(s)-2F(t)

\begin{align*} &\beta_t^2=\int_t^1\exp\left(2\int_t^s\eta_rdr\right)\sigma_s^2ds \\ &\beta_t^2=\int_t^1\exp\left(2F(s)-2F(t)\right)\sigma_s^2ds \\ &\beta_t^2=-\int_1^t\exp\left(2F(s)-2F(t)\right)\sigma_s^2ds \\ &\beta_t^2=-\int_1^t\exp(2F(s))\cdot\exp(-2F(t))\sigma_s^2ds \\ &\beta_t^2=-\exp(-2F(t))\int_1^t\exp(2F(s))\sigma_s^2ds \\ &\exp(2F(t))\beta_t^2=-\int_1^t\exp(2F(s))\sigma_s^2ds \\ \end{align*}

両辺を t で微分して

\begin{align*} &\dfrac{d}{dt}\{\exp(2F(t))\beta_t^2\}=-\dfrac{d}{dt}\int_1^t\exp(2F(s))\sigma_s^2ds \\ &2F'(t)\exp(2F(t))\beta_t^2+2\exp(2F(t))\beta_t\dot{\beta_t}=-\exp(2F(t))\sigma_t^2 \\ &2F'(t)\beta_t^2+2\beta_t\dot{\beta_t}=-\sigma_t^2 \\ \end{align*}

F'(t)=\eta_t なので \eta_t=-\dfrac{\dot{\alpha_t}}{\alpha_t} も用いると

\begin{align*} -\sigma_t^2&=2F'(t)\beta_t^2+2\beta_t\dot{\beta_t} \\ &=2\eta_t\beta_t^2+2\beta_t\dot{\beta_t} \\ &=-2\dfrac{\dot{\alpha}_t}{\alpha_t}\beta_t^2+2\beta_t\dot{\beta_t} \\ &=-2\beta_t^2\left(\dfrac{\dot{\alpha}_t}{\alpha_t}-\dfrac{\dot{\beta_t}}{\beta_t}\right) \end{align*}

よって \sigma_t^2=2\beta_t^2\left(\dfrac{\dot{\alpha}_t}{\alpha_t}-\dfrac{\dot{\beta_t}}{\beta_t}\right) となります.

さて, 実際にtoy-dataを使って実験してみます.

特に, N=5 がわかりやすい結果となっています. N=1 ではそもそも中間の分布 (オレンジの分布)を推論しないので直線軌道は当たり前です. ある程度推論回数が増えて N=5 になると \pi_1 である赤色の分布にどの手法でもたどり着くものの, rectified flow以外は曲がった軌道 (非直線)であることがわかります. また, 真ん中の2つは分布同士の間隔も一定ではなく, これが一定速度でないことを示しています.

ところで, この実験結果ではVE ODEがありません. それは以下の理由によります. VE ODEでは \alpha_t=1,\ \beta_t=\sigma_{\min}=\sqrt{r^{2(1-t)}-1} が使われ, デフォルトでは \sigma_{\min}=0.01 です. r\pi_1 からの全ての訓練データの点の最大のユークリッド距離が \sigma_{\max}=r\sigma_{\min} になるように設定します. \sigma_{\max}^2\sigma_{\min}^2X_1 の分散より遥かに大きいとすると, X_0=X_1+\beta_0\xi\approx\sigma_{\max}\xi となり, 初期分布を \pi_0\sim\mathcal{N}(0, \sigma_{\max}I^2) にできます. これは \pi_1 より遥かに大きい分散を持っているので, toy-dataには適用できなくなってしまいます.

まとめ

\min_{v}\int_{0}^{1}\mathbb{E}\left[w_t\|v(X_t-t)-\dot{X}_t\|^2\right]\,\mathrm{d}t

で示されたnonlinearなrectified flowのフレームワークは既存のものを簡略化および拡張するもので, いくつかの重要なことが言えそうです.

  • ODEの学習はdiffusion/SDEな手法に頼ることなく, 直接的かつ独立して考えることができる
  • 学習したODEの経路は X_0X_1 の間の任意の滑らかな補間曲線 X_tによって指定できる
  • 初期分布 \pi_0は補間 X_t の選択とは無関係に任意に選択可能
  • X_t は任意に選べるが線形補間 X_t=tX_1+(1-t)X_0 が推奨
  • 非線形な X_t は変数の非ユークリッド幾何学的構造を取り入れたり, ODEの軌跡に特定の制約を設けたい場合に有用

おわりに

13000字ほど書いたので今回はここで終了です. これで論文のsection 3まで終わりました (section 3.3の定理3.6とsection3.4はスキップしてます). 関連研究は飛ばすので (適宜触れることはあります), 次回はarXivにある論文の22ページからの実験パートです.

脚注
  1. 高次元の正規分布は超球面上に分布するという話だと思います ↩︎

Discussion