Rectified Flowを試す
概要
ICLR2023にContinuous Normalizing Flow関連のものが出ていて,それ関連で調べていたら以下の論文を見つけた.
定式化
色々端折って簡単に示す.
扱いやすい分布
この過程はODEで記述できて,
ただし,
というように,ODEで表せるから,あとは右項の
学習方法
したがって,本来のODEは以下の通り
しかしながら,推論時に
直線を考えるのであれば,途中の
このときの変化量を予測するのだから,損失関数は
つまり,軌道上での変化量を
なので,これをpytorchで書くと以下のように表せる.
import torch
import torch.nn.functional as F
def loss_fn(model, x_1):
x_0 = torch.randn_like(x_1)
x_t = t * x_1 + (1 - t) * x_0
v = model(x_t, t)
loss = F.mse_loss(x1 - x0, v)
return loss
なんとシンプルなのだろうか.
x_1はデータで,今回の記事の実験では,MNIST等の画像データを用いた.
なお,モデル
理由としては,この記事のコードを改造して,score-sde系の実験をしていたため,そのまま使うのが楽だからである.
推論方法
定式化のところで,変換の過程がODEで表されることを見た.
よって,サンプリングではscipyなどにあるODEソルバーをそのまま使うことができる.
今回はscipy.integrate.solve_ivp(method='RK45')を利用した.
使い方としては,例えば以下のようにできる.
@torch.no_grad()
def sample_ode(model, shape):
device = next(model.parameters()).device
b = shape[0]
# x_0
x = torch.randn(shape, device=device)
def ode_func(t, x):
x = torch.tensor(x, device=device, dtype=torch.float).reshape(shape)
t = torch.full(size=(b,), fill_value=t, device=device, dtype=torch.float)
v = model(x, t)
return v.cpu().numpy().reshape((-1,)).astype(np.float64)
res = integrate.solve_ivp(ode_func, (0, 1.), x.reshape((-1,)).cpu().numpy(), method='RK45')
# x_1
x = torch.tensor(res.y[:, -1], device=device).reshape(shape)
return x
基本1次元のnumpyベクトルとしてODEを解くわけだが,ODEの部分はpytorchに都度変換して行う.
ボトルネックになっていそうだが,致し方なし.
結局ODEが解ければ何でもいいので,他の手法でも代用可.
実際,本当に軌道が直線になっていれば,オイラー法で1ステップ推論が可能となるだろう.
したがって,RK45は過剰すぎる気はする.
実験
実際に学習させてみる.
- Loss : 前節のもの
- NN : UNet(前節のリンクのもの)
- Optimizer : AdamW(lr = 1e-4)
- gradient clip(1.0)
- データ : MNIST, Fashion MNIST
- 10 epoch
結果は以下の通りである.
モデルのパラメータ数的には大げさであるが,それなりに生成できているのではないだろうか.
ただ,データが比較的簡単であるし,定量的な評価はできていないので,なんとも言えない...
score-sdeでは,連続時間の拡散モデルをSDEで記述していたわけだが,今回のようにシンプルなODEで記述することで,SDEみたいな複雑なことを考えなくても良くなる.
それに,結局score-sdeでも逆過程はODE(probability flow ode)で解いたほうが品質良かったような気がするので,だったら,はじめからODEで記述して良いではないだろうか.
追実験(AFHQ cat 128x128)
- 100epoch
100epochだとまだ崩れている部分や高周波が潰れている部分が見受けられた.
ただ,Lossはまだ下がりそうでもっと回せば良くなるかも.
また,モデルの構造についても,改善の余地あり.
まあ,上のようなシンプルなLossだけでここまでの品質になり得るなら十分ではないだろうか?
- 300 epoch
300epochでも目の部分が崩れているものが多い.
モデルサイズを大きくしたりすればよいのだろうか.
感想等
何よりシンプルなのが良い.
学習もMSEのみで安定している点も利点といえる.
逆過程についても,時間を逆に考えれば良くて,シンプルに
コード
読みづらいのは許してください.
Discussion