😺

Rectified Flow④

2025/01/21に公開

前回Rectified Flowの論文を読み終わったので, 今回はその改良論文である2つ目の論文を読みます. 全部読む予定でしたが長くなったので実験部分は次回にします.

https://openreview.net/forum?id=mSHs6C7Nfa


https://zenn.dev/fmuuly/articles/37cc3a2f17138e


https://zenn.dev/fmuuly/articles/a062fcd340207f


https://zenn.dev/fmuuly/articles/0f262fc003e202

書籍情報

参考文献も同一です.

Sangyun Lee and Zinan Lin and Giulia Fanti. Improving the Training of Rectified Flows. Advances in Neural Information Processing Systems, 38, 2024.

関連リンク

今回はarXivにあるものを読みます.

TL; DR

タイトルはImproving the Training of Rectified Flowsなのですが, この論文ではReflowに焦点を当ていて, Reflowが1回で十分であることを主張し, 1回のreflowで性能を上げる改善案を出しています. なお, 一部は画像ドメイン限定で利用可能なものです.

  • timestepはuniformではなくU-shapedな分布からサンプリングし, t=0, 1 付近を重点的に訓練する
  • L2 lossではなくPseudo-Huber lossを採用する. さらに重み付きLPIPSをlossに加える
  • 事前学習済み拡散モデル (EDM)のパラメータで初期化する
  • 実データを用いる

導入

Rectified Flow

まずはじめに, 簡単にRectified Flowについて思い出します. 以降では, 今回の論文のnotationを採用するのでこれまでのnotationとは異なります. Rectified Flowは2つの分布 p_{\bold{x}}p_{\bold{z}} をなめらかに移動するようなモデルです. なめらかというのは曖昧ですが, 軌跡の傾きの変化が緩やか, すなわち直線軌道のことと考えられます. \bold{x}\sim p_{\bold{x}}\bold{z}\sim p_{\bold{z}} に対して t\in[0, 1] として補間 \bold{x}_t=(1-t)\bold{x}+t\bold{z} を考えたとき, ODE

\dfrac{d\bold{z}_t}{dt}=\bold{v}(\bold{z}_t)\coloneqq\dfrac{1}{t}(\bold{z}_t-\mathbb{E}[\bold{x}\mid\bold{x}_t=\bold{z}_t])

は任意の t に対して同じ周辺分布 \bold{x}_t となります. 実際はこの条件付き期待値をNNで求めるわけですが, L2 lossで訓練します.

\min_{\boldsymbol{\theta}}\mathbb{E}_{\bold{x},\bold{z} \sim p_{\bold{x}\bold{z}}}\mathbb{E}_{t\sim p_t}[\omega(t)\|\bold{x}-\bold{x}_{\boldsymbol{\theta}}(\bold{x}_t, t)\|_2^2]

実装上では速度 \bold{v}_{t} をパラメータ化します.

\min_{\boldsymbol{\theta}}\mathbb{E}_{\bold{x},\bold{z} \sim p_{\bold{x}\bold{z}}}\mathbb{E}_{t\sim p_t}[\|(\bold{z}-\bold{x})-\bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t)\|_2^2]

これは, 先ほどの式において \omega(t)=\dfrac{1}{t^2} としたときに相当します. 具体的には

\begin{align*} \int_0^1\mathbb{E}[\|(\bold{z}-\bold{x})-\bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t)\|_2^2]dt&=\int_0^1[\|(\bold{x}_{t}-\bold{x})/t-\bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t)\|_2^2]dt \\ &=\int_0^1[\|(\bold{x}_{t}-\bold{x})/t-(\bold{x}-\bold{x}_{\boldsymbol{\theta}}(\bold{x}_t, t))/t\|_2^2]dt \\ &=\int_0^1\mathbb{E}[\dfrac{1}{t^2}\|\bold{x}-\bold{x}_{\boldsymbol{\theta}}(\bold{x}_t, t)\|_2^2]dt \end{align*}

によって従います. この論文では片方の分布, ここでは p_{\bold{z}} をガウス分布 \mathcal{N}(\bold{0}, \bold{I}) である状況を考えます. この状況では \bold{x}\bold{z} は独立, すなわち p_{\bold{xz}}(\bold{x}, \bold{z})=p_{\bold{x}}(\bold{x})p_{\bold{z}}(\bold{z}) です. また, \bold{x}_t に関して非線形補間を用います.

Reflow

これまでtoy-dataで確認してきたように, 単にrectified flowを行っただけでは軌道が直線にはなりません. Reflowではrectified flowを行ったモデルでデータを生成し, もう一度rectified flowを行います. 理論上は無限回行うことで完全な直線が保証されますが, 前回確認したrectified flowの論文では2回行っていました.

Reflowの回数は...

この論文で著者らは, reflowの回数は1回でOKであることを主張しています. rectified flowの論文では2回行っていましたが, これは訓練設定などが十分に最適化されていないことが原因で, これを改善することで十分な結果を1回のみのreflowで手に入れることができるようです.

まず, 最適な2-rectified flowの曲率が0 (直線軌道である)ということは, 1-rectified flowによって生成されたペアの線形補間軌道が交差しない場合, または同等に全てのペア (\bold{x}',\bold{z}') に対して \mathbb{E}[\bold{x}\mid\bold{x}_t=(1-t)\bold{x}_t'+t\bold{z}_t']=\bold{x}' が成り立つ場合に限られることに注意します.

まずは, 合成分布 p^1(\bold{x}) = \int p^1(\bold{x}, \bold{z})\,d\bold{z} の多様体 \mathcal{M}_{\bold{x}} を考えます. この多様体上の2点 \bold{x}', \bold{x}'' と1-rectified flowによって \bold{x}', \bold{x}'' にマッピングされる2つのノイズ \bold{z}', \bold{z}'' を考えます. このとき2つのペア (\bold{x}', \bold{z}')(\bold{x}'', \bold{z}'') が交差するとはある t\in[0, 1] が存在し, (1-t)\bold{x}'+t\bold{z}'=(1-t)\bold{x}''+t\bold{z}'' が成り立つことを言います. 例えば下図の(a)では, 2つの軌道が中間の t で交差しています.

時刻 t で2つの軌道が交差するための条件は2つあります.

  1. 1-rectified flowが \bold{z}''\bold{x}'' にマッピングすること.
  2. \bold{z}''=\bold{z}'+\dfrac{1-t}{t}(\bold{x}'-\bold{x}'') が成り立つこと.

しかしながら, 現実のデータ分布に対して1-rectified flowが十分に訓練されている場合, \bold{z}''=\bold{z}'+\dfrac{1-t}{t}(\bold{x}'-\bold{x}'') は一般的なノイズとはなりません. これも上の図の (a)を見ればわかりますが, ガウス分布の球状分布には収まっていないことがわかります.

1-rectified flowは多くの場合, 一般的なガウス分布から得られたノイズを用いて訓練されるので, \bold{z}'' を多様体 \mathcal{M}_{\bold{x}} にマッピングすることは一般にはできないです. 上図の(c, d)を見てみると, \bold{z}'' が一般的なガウシアンノイズとは異なる性質を持っていること (normが大きいこと, 自己相関が0でないこと)が確認できます. その結果, 上図の(b)を見ればわかるように \bold{z}' から得られる生成例と \bold{z}'+\bold{x}'-\bold{x}'' から得られる生成例は全く異なることがわかります.

このことは, 2-rectified flowを訓練するときに交差はあまり発生しない (\mathbb{E}[\bold{x}\mid\bold{x}_t=(1-t)\bold{x}_t'+t\bold{z}_t']\approx\bold{x}')であることを示唆します. すると, 先ほどの定義から, 最適な2-rectified flowの軌道がほぼ直線であることがわかります. このことから, reflowを複数回行う必要はなく, 却って品質の低下につながることが示唆されます.

コーナーケースを考える

\|\bold{x}'-\bold{x}''\|_2 が小さい場合を考えます. このとき, 1-rectified flowは \bold{z}'' を多様体 \mathcal{M}_{\bold{x}} 上のある点にマッピングする可能性があります. しかし, この場合でも \bold{x}'\bold{x}'' の平均は \bold{x}' に近いので, \mathbb{E}[\bold{x}\mid\bold{x}_t=(1-t)\bold{x}_t'+t\bold{z}_t']\approx \bold{x}' が成り立ちますので, 結論は変わりません.

同様に, t が1に近い場合を考えます. このとき \dfrac{1-t}{t}(\bold{x}'-\bold{x}'')\approx\bold{0} となるので, 1-rectified flowが \bold{z}'' を多様体 \mathcal{M}_{\bold{x}} 上の点にマッピングすることができます. 1-rectified flowが L-Lipschitzの場合は \|\bold{x}'-\bold{x}''\|_2\leq L\|\bold{z}'-\bold{z}''\|_2 なので \mathbb{E}[\bold{x}\mid\bold{x}_t]\bold{x}' から大きく離れることはないです.

Reflowの訓練改善

先ほどの考察から最適な2-rectified flowはほぼ直線軌道です. そのため, 2-rectified flowモデルのone-step生成の品質が期待した通りでない場合は訓練設定が良くないことが原因として考えられます. 以降では, few-step生成の品質を向上させる手法を見ていきます.

Timestep distribution

拡散モデルもrectified flowも時刻 t を用いて訓練します. この t のサンプリング戦略は大事で, 一般にはuniformを使いますがそうである必要はないです. 例えば, t によって難易度差が生じている場合は難しい t を多くサンプリングしてあげると効率的に訓練ができます. 一般的なアプローチとしてlossを見て, lossが大きい t に割り振るというものがあります. しかし, rectified flowのlossは難易度を測ることができないです. lossを変形してみると

\mathcal{L}(\boldsymbol{\theta}, t)\coloneqq\mathbb{E}[\dfrac{1}{t^2}\|\bold{x}-\bold{x}_{\boldsymbol{\theta}}(\bold{x}_t, t)\|_2^2]=\dfrac{1}{t^2}\mathbb{E}[\bold{x}-\mathbb{E}[\bold{x}|\bold{x}_t]]+\overline{\mathcal{L}}(\boldsymbol{\theta}, t)

となります. 最初の項は \boldsymbol{\theta} に依存しないので, 訓練が進むことでこの部分が減少することはありません. 2番目の項が本当のlossになりますが, これを直接可視化することは第1項が通常は未知のため, できません. しかし, 先ほど最適な2-rectified flowにおいては第1項がほぼ0であることがわかったので, そのままlossを難易度を測る指標として使うことができます.

実際に, CIFAR-10を用いてlossを見てみます.

緑の線がlossで, 薄く塗られているのはlossの1標準偏差です. これを見ると, t\in[0, 1] の端の領域ではlossが大きくなっているのに対して中間に対しては十分小さくなっています. そこで, U字型のtimestep分布を用意します. 図では赤の破線で描かれているものです. 具体的には p_t(u)\propto\exp(au)+\exp(-au) を採用します. 実験的には a=4 がいいと著者らは主張していますが, a の値による性能差は示されていません.

Loss function

ReflowにおいてもL2 lossを用いていますが, 先ほど見たように \mathbb{E}[\bold{x}|\bold{x}_t=(1-t)\bold{x}_t'+t\bold{z}_t']\approx \bold{x}' ですので, m を任意の距離関数として

\min_{\boldsymbol{\theta}}\mathbb{E}_{\bold{x},\bold{z} \sim p_{\bold{x}\bold{z}}}\mathbb{E}_{t\sim p_t}[m(\bold{z}-\bold{x}, \bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t))]

と等価です. この論文では m としてL2以外にも以下の3つを検討します.

  • Pseudo-Huber (以下では d をデータの次元として c=0.00054d)

    m_{\mathrm{hub}}(\bold{z}-\bold{x}, \bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t))=\sqrt{\|\bold{z}-\bold{x}-\bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t)\|_2^2+c^2}-c
  • LPIPS-Huber

    m_{\mathrm{lp-hub}}(\bold{z}-\bold{x}, \bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t))=(1-t)m_{\mathrm{hub}}(\bold{z}-\bold{x}, \bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t))+\mathrm{LPIPS}(\bold{x}, \bold{x}_t-t\cdot\bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t))
  • LPIPS-Huber-\frac{1}{t}

    m_{\mathrm{lp-hub-}\frac{1}{t}}(\bold{z}-\bold{x}, \bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t))=(1-t)m_{\mathrm{hub}}(\bold{z}-\bold{x}, \bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t))+\frac{1}{t}\mathrm{LPIPS}(\bold{x}, \bold{x}_t-t\cdot\bold{v}_{\boldsymbol{\theta}}(\bold{x}_t, t))

この中でPseudo-Huber lossはL2に比べると外れ値に対する感度が低く, 勾配分散を減少させる可能性があり, 学習が容易になります. これはconsistency modelsの改善論文でも使われているものです.

https://openreview.net/forum?id=WNzy9bRDvG

LPIPSは生成データとGTの知覚的距離を測るので, この距離を小さくするようモデルにも強制したいです. 一方で, 異なる2点が知覚的に類似していると0になってしまうのでpremetricとして使うことはできないです. そのため, Pseudo-Huber lossと組み合わせて使い, 重みづけを行います. 1-t の重みをつけることで t=1 に近い場合はLPIPSをより重視するようになります (と論文では述べられていますが正確には「Pseudo-Huber lossを軽視する」が正しい気がします. 相対的にはLPIPSを重視しているので変わらないですが...). ただ, これは t=0 近辺で勾配消失の可能性があるので, t で割るLPIPS-Huber-\frac{1}{t} も用意します.

Initialization with pre-trained diffusion models

1-rectified flowをスクラッチで訓練するのは大変です. 既存研究には事前学習済みの拡散モデルでrectified flowの目的関数

\dfrac{d\bold{z}_t}{dt}=\bold{v}(\bold{z}_t)\coloneqq\dfrac{1}{t}(\bold{z}_t-\mathbb{E}[\bold{x}\mid\bold{x}_t=\bold{z}_t])

\mathbb{E}[\bold{x}|\bold{x}_t=\bold{z}_t] を予測しています. この既存研究

https://openreview.net/forum?id=PLIt3a4yTm

で提示されている補題2 (arXiv版にあり, 上記リンクでは3.2の内容です)の特別な場合の命題をここでは用います. 論文ではProposition 1とされていて,

p^{RE}(\bold{x}|\bold{x}_t, t),\ p^{VP}(\bold{x}|\bold{x}_t, t),\ p^{VE}(\bold{x}|\bold{x}_t, t) をそれぞれ \mathcal{N}((1-t)\bold{x}, t^2\bold{I}),\ \mathcal{N}(\alpha(t)\bold{x}, (1-\alpha(t))^2\bold{I}),\ \mathcal{N}(\bold{x}, t^2\bold{I}) の事後分布とするとき,

\int p^{RE}(\bold{x}|\bold{x}_t, t)\bold{x}d\bold{x}=\int p^{VP}(\bold{x}|\bold{x}_t=s_{VP}\bold{z}_t, t_{VP})\bold{x}d\bold{x}=\int p^{VE}(\bold{x}|\bold{x}_t=s_{VE}\bold{z}_t, t_{VE})\bold{x}d\bold{x}

である. s_{VP}, s_{VE} はscaling factor, t_{VP}, t_{VE} はconverted timeです.

証明

p_t(\bold{x}_t|\bold{x})=\mathcal{N}(s(t)\bold{x},\sigma(t)^2\bold{I})p_t'(\bold{x}_t|\bold{x})=\mathcal{N}(s'(t)\bold{x},\sigma'(t)^2\bold{I}) を考えます. すると

\begin{align*} &p_t(\bold{x}_t|\bold{x})=\dfrac{1}{(2\pi\sigma(t)^2)^{d/2}}\exp\left(-\dfrac{1}{2\sigma(t)^2}\|\bold{x}_t-s(t)\bold{x}\|_2^2\right) \\ &p_t'(\bold{x}_t|\bold{x})=\dfrac{1}{(2\pi\sigma'(t)^2)^{d/2}}\exp\left(-\dfrac{1}{2\sigma'(t)^2}\|\bold{x}_t-s'(t)\bold{x}\|_2^2\right) \end{align*}

です. t'(t)\dfrac{s(t)}{\sigma(t)}=\dfrac{s'(t')}{\sigma'(t')} を満たすようにとります. 以降では p_t(\bold{x}|\bold{x}_t)=p'_{t'}\left(\bold{x}|\dfrac{s'(t')}{s(t)}\bold{x}_t\right) を示します. 変形すると,

\begin{align*} p'_{t'}\left(\bold{x}|\dfrac{s'(t')}{s(t)}\bold{x}_t\right)&= \dfrac{1}{(2\pi\sigma'(t')^2)^{d/2}}\exp\left(-\dfrac{1}{2\sigma'(t')^2}\|\dfrac{s'(t')}{s(t)}\bold{x}_t-s'(t')\bold{x}\|_2^2\right) \\ &=\dfrac{1}{(2\pi\sigma'(t')^2)^{d/2}}\exp\left(-\dfrac{1}{2\sigma'(t')^2}\|\dfrac{s'(t')}{s(t)}(\bold{x}_t-s(t)\bold{x})\|_2^2\right) \\ &=\dfrac{1}{(2\pi\sigma'(t')^2)^{d/2}}\exp\left(-\dfrac{1}{2\sigma'(t')^2}\dfrac{s'(t')^2}{s(t)^2}\|\bold{x}_t-s(t)\bold{x}\|_2^2\right) \\ &=\dfrac{1}{(2\pi\sigma'(t')^2)^{d/2}}\exp\left(-\dfrac{1}{2\sigma(t)^2}\dfrac{s(t)^2}{s(t)^2}\|\bold{x}_t-s(t)\bold{x}\|_2^2\right) \\ &=\dfrac{1}{(2\pi\sigma'(t')^2)^{d/2}}\exp\left(-\dfrac{1}{2\sigma(t)^2}\|\bold{x}_t-s(t)\bold{x}\|_2^2\right) \\ &=\dfrac{1}{(2\pi\sigma'(t')^2)^{d/2}}\dfrac{(2\pi\sigma(t)^2)^{d/2}}{(2\pi\sigma(t)^2)^{d/2}}\exp\left(-\dfrac{1}{2\sigma(t)^2}\|\bold{x}_t-s(t)\bold{x}\|_2^2\right) \\ &=\dfrac{(2\pi\sigma(t)^2)^{d/2}}{(2\pi\sigma'(t')^2)^{d/2}}p_t(\bold{x}_t|\bold{x}) \\ &=\left(\dfrac{\sigma(t)}{\sigma'(t')}\right)^dp_t(\bold{x}_t|\bold{x}) \end{align*}

となります. ここから, p_t(\bold{x}_t|\bold{x})\propto p'_{t'}\left(\bold{x}|\dfrac{s'(t')}{s(t)}\bold{x}_t\right) が言えます (比例するだけで = ではないです). よって,

\begin{align*} &p_t(\bold{x}|\bold{x}_t)=\dfrac{1}{p_t(\bold{x}_t)}p_{\bold{x}}(\bold{x})p_t(\bold{x}_t|\bold{x}) \\ &p'_{t'}\left(\bold{x}|\dfrac{s'(t')}{s(t)}\bold{x}_t\right)=\left(\dfrac{\sigma(t)}{\sigma'(t')}\right)^d\dfrac{1}{p'_{t'}\left(\bold{x}|\dfrac{s'(t')}{s(t)}\bold{x}_t\right)}p_{\bold{x}}(\bold{x})p_t(\bold{x}_t|\bold{x}) \end{align*}

ここで,

\left(\dfrac{\sigma(t)}{\sigma'(t')}\right)^d\dfrac{1}{p'_{t'}\left(\bold{x}|\dfrac{s'(t')}{s(t)}\bold{x}_t\right)}=\int p_{\bold{x}}(\bold{x})p_t(\bold{x}_t|\bold{x})d\bold{x}

となり, 2つの密度は等しいです. 事後密度が同一である場合, その期待値も同一となり, 証明終了です.

converted timeは

t_{VP} t_{VE} s_{VP} s_{VE}
\dfrac{1}{9.95}\left(-0.05+\sqrt{0.0025-19.9\log\dfrac{1-t}{\sqrt{(1-t)^2+t^2}}}\right) \dfrac{t}{1-t} \dfrac{\alpha(t_{VP})}{1-t} \dfrac{1}{1-t}

です. ここで \alpha(t)=\exp\left(-\dfrac{1}{2}\displaystyle\int_0^t(19.9s+0.1)ds\right) です. この命題を利用することでEDMやDDPMなどの学習済みの拡散モデルを用いてReflowを初期化し, 上の表で時間とスケールの調整を行うことができます.

Incorporating real data

最後の工夫です. 2-rectified flowの訓練はdata freeで行うことができるのが特徴でした (データは1-rectified flowで生成して手に入れるため). 実際には実データが入手できる場合が多いのでこれを使うことを考えてみます.

似たアイデアは

https://openreview.net/forum?id=ctSjIlYN74

でも実験されています.

なお, 生成データと実ノイズのペアの割合を変えた場合でも実験しているようですが, この割合が0 (すなわちこのペアを全く使わない)の場合が最も性能が高いようです.

簡単な実験

実際の実験パートに移る前に, それぞれの改善方法が効果があるのかをCIFAR-10, AFHQ, FFHQを用いて確かめます. 評価指標はFIDです. ベースライン (config A)はuniformの時間分布でL2 lossを採用したrectified flowです.

まず, 学習済みの拡散モデルでの初期化は大きく性能改善に寄与していそうです. ただ, これでは初期化あるいはlarger batch size (128から512)のどちらが寄与したのかわからないという問題があります. ちなみにbatch sizeにはcritical batch sizeというものがあり, そこに達しない範囲でbatch sizeを上げて性能向上と謳うのは当たり前のことしか言っていないように思えます.

loss以外の部分については導入したら性能が上がっていることが確認できます. batch sizeに対する感度の低いPseudo-Huber lossを採用するようです. 3種類あったlossですが, 全体的にLPIPSを導入した方が良さそうです.

lossについて

ここでの流れとしてはまずL2 lossとPseudo-Huber lossのどっちがいいか?を検討し, Pseudo-Huber lossにLPIPSつけるとどうか?を次に検証しています. L2 lossとPseudo-Huber lossでは性能に大差ないですが, 以下の理由でPseudo-Huber lossが選択されているようです.

既存研究では, Pseudo-Huber lossはL2 lossと比較して外れ値に対してsensitiveではないので, 勾配の分散を抑えて学習が安定するというメリットがあります.

実際にbatch sizeが小さい場合, L2 lossより優れている (config A vs B)と論文では主張されていますが先述の通り, AからBでは要素が2つ追加されているのでどちらが大きく寄与しているかは不明です. 上の表からbatch sizeが大きい場合, CIFAR-10やFFHQではL2 lossと同等の性能で, AFHQでは上回る結果となっています. そのため, batch sizeに対する感度が低いのでPseudo-Huber lossを選ぶとしています.

終わりに

ここまでで15000字くらい書いてしまいました. 実験パートも続けて載せていいような気がしますが長くなり過ぎてしまう気がするので一旦終わりにします. 次回は実験パートを見ます.

Discussion