📏

Rectified Flowを試す

2023/03/06に公開

概要

ICLR2023にContinuous Normalizing Flow関連のものが出ていて,それ関連で調べていたら以下の論文を見つけた.
https://arxiv.org/abs/2209.03003

定式化

色々端折って簡単に示す.
扱いやすい分布\pi_0からのサンプルを,目標とするデータ分布\pi_1のサンプルへの変換を考える.
この過程はODEで記述できて,t \in [0, 1]を用いて以下のように表せる.

\mathrm{d}X_t = v(X_t, t)\mathrm{d}t

ただし,X_0 \sim \pi_0, X_1 \sim \pi_1
というように,ODEで表せるから,あとは右項のv: \mathbb{R}^d \times [0, 1] \mapsto \mathbb{R}^dをNNで学習できればOK.

学習方法

X_0からX_1へのstraight pathを考えているのだから,その方向はX_1 - X_0で表せる.
したがって,本来のODEは以下の通り

\mathrm{d} X_t = (X_1 - X_0) \mathrm{d}t

しかしながら,推論時にX_1はわからないから,パスの途中をNNで学習させる.
直線を考えるのであれば,途中のX_tは線形補間で表してX_t = t X_1 + (1 - t) X_0
このときの変化量を予測するのだから,損失関数はX_1 - X_0X_ttを入力としたNNの出力とのLossということになる.
X_1をデータ,X_0を正規分布からのサンプル結果として,損失関数は以下のように表せる.

\begin{align} L(\theta) &= || (X_1 - X_0) - v_{\theta}(X_t, t) ||_2^2 \\ X_t &= t X_1 + (1 - t) X_0 \end{align}

つまり,軌道上での変化量をvに学習させるようなイメージである.

なので,これをpytorchで書くと以下のように表せる.

import torch
import torch.nn.functional as F

def loss_fn(model, x_1):
    x_0 = torch.randn_like(x_1)
    x_t = t * x_1 + (1 - t) * x_0
    v = model(x_t, t)
    loss = F.mse_loss(x1 - x0, v)
    return loss

なんとシンプルなのだろうか.
x_1はデータで,今回の記事の実験では,MNIST等の画像データを用いた.
なお,モデルv_\thetaの構造は何でもいいのだが,今回はdiffusion modelに合わせて,hugging faceの以下の記事のモデルをそのまま使う.

https://huggingface.co/blog/annotated-diffusion

理由としては,この記事のコードを改造して,score-sde系の実験をしていたため,そのまま使うのが楽だからである.

推論方法

定式化のところで,変換の過程がODEで表されることを見た.
よって,サンプリングではscipyなどにあるODEソルバーをそのまま使うことができる.
今回はscipy.integrate.solve_ivp(method='RK45')を利用した.

使い方としては,例えば以下のようにできる.

@torch.no_grad()
def sample_ode(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # x_0
    x = torch.randn(shape, device=device)
    
    def ode_func(t, x):
        x = torch.tensor(x, device=device, dtype=torch.float).reshape(shape)
        t = torch.full(size=(b,), fill_value=t, device=device, dtype=torch.float)
        v = model(x, t)
        return v.cpu().numpy().reshape((-1,)).astype(np.float64)
    
    res = integrate.solve_ivp(ode_func, (0, 1.), x.reshape((-1,)).cpu().numpy(), method='RK45')
    # x_1
    x = torch.tensor(res.y[:, -1], device=device).reshape(shape)
    return x

基本1次元のnumpyベクトルとしてODEを解くわけだが,ODEの部分はpytorchに都度変換して行う.
ボトルネックになっていそうだが,致し方なし.
結局ODEが解ければ何でもいいので,他の手法でも代用可.

実際,本当に軌道が直線になっていれば,オイラー法で1ステップ推論が可能となるだろう.
したがって,RK45は過剰すぎる気はする.

実験

実際に学習させてみる.

  • Loss : 前節のもの
  • NN : UNet(前節のリンクのもの)
  • Optimizer : AdamW(lr = 1e-4)
    • gradient clip(1.0)
  • データ : MNIST, Fashion MNIST
  • 10 epoch

結果は以下の通りである.

モデルのパラメータ数的には大げさであるが,それなりに生成できているのではないだろうか.
ただ,データが比較的簡単であるし,定量的な評価はできていないので,なんとも言えない...

score-sdeでは,連続時間の拡散モデルをSDEで記述していたわけだが,今回のようにシンプルなODEで記述することで,SDEみたいな複雑なことを考えなくても良くなる.
それに,結局score-sdeでも逆過程はODE(probability flow ode)で解いたほうが品質良かったような気がするので,だったら,はじめからODEで記述して良いではないだろうか.

追実験(AFHQ cat 128x128)

  • 100epoch

100epochだとまだ崩れている部分や高周波が潰れている部分が見受けられた.
ただ,Lossはまだ下がりそうでもっと回せば良くなるかも.
また,モデルの構造についても,改善の余地あり.
まあ,上のようなシンプルなLossだけでここまでの品質になり得るなら十分ではないだろうか?

  • 300 epoch

300epochでも目の部分が崩れているものが多い.
モデルサイズを大きくしたりすればよいのだろうか.

感想等

何よりシンプルなのが良い.
学習もMSEのみで安定している点も利点といえる.
逆過程についても,時間を逆に考えれば良くて,シンプルに\mathrm{d}tをネガティブ時間ステップと考えれば良いだけだろう.

コード

読みづらいのは許してください.

https://github.com/reppy4620/diffusion/tree/master/rect_flow

Discussion