🐥
【拡散モデル】SBMサンプリング
拡散モデル本 の勉強。
今回は、p.39に記載されているSBMサンプリングのアルゴリズムを、前回学習したスコアを用いてサンプリングしてみようと思う。
コード
import numpy as np
import torch
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from mim import train
from generate_dist import GaussianMixture, display_pdf, display_score
from score_matching import MLP
def sampling(input_size, score, noises):
BASE_ALPHA = 0.5
K = 50 # loop num
x = torch.normal(0, noises[-1], (input_size,1)).squeeze()
xs = []
for noise in reversed(noises):
xs_per_noise = []
alpha = BASE_ALPHA * noise/noises[-1]
for k in range(K):
u = torch.normal(0, 1, (2,1)).squeeze()
s = score(x, noise)
x = x + alpha*s + np.sqrt(2*alpha)*u
xs_per_noise.append(x.detach().numpy())
xs.append(xs_per_noise)
return xs
def _calc_true_score(x, noise):
dist = GaussianMixture(noise=noise)
x = x.detach().numpy()
v = dist.score(np.array([[x[0]]]), np.array([[x[1]]]))
return torch.tensor(v).squeeze()
def plot_trail(model, noises):
x_1_line = np.linspace(-5, 10, 100)
x_2_line = np.linspace(-5, 10, 100)
display_pdf(GaussianMixture(), x_1_line, x_2_line)
x_1_line = np.linspace(-5, 10, 20)
x_2_line = np.linspace(-5, 10, 20)
display_score(GaussianMixture(), x_1_line, x_2_line)
xs = sampling(2, lambda x, noise: model(torch.concat((x, torch.tensor([noise], dtype=torch.float)))), noises)
for i, xs in enumerate(xs):
color = 'rgbcmyk'[i]
plt.plot([x[0] for x in xs], [x[1] for x in xs], c=color, marker='o', markersize=max(6-i,1), alpha=0.5)
plt.show()
if __name__ == "__main__":
model = MLP(3, 2)
noises = np.array([1/4, 1/2, 1, 2, 4, 8])
model, loss_log, x_log, noised_x_log = train(model, noises, GaussianMixture())
plot_trail(model, noises)
sample_result = []
for t in range(1000):
xs = sampling(2, lambda x, noise: model(torch.concat((x, torch.tensor([noise], dtype=torch.float)))), noises)
sample_result.append(xs[-1][-1])
x_1_line = np.linspace(-5, 10, 100)
x_2_line = np.linspace(-5, 10, 100)
display_pdf(GaussianMixture(), x_1_line, x_2_line)
for i, xs in enumerate(xs):
plt.scatter([x[0] for x in sample_result], [x[1] for x in sample_result], color='red', marker=".", s=48, alpha=0.05)
plt.show()
サンプリング1過程の軌跡の表示と、1000回サンプルした時の散布図を表示している。
結果
まずは、サンプリングの過程の結果。赤 -> 緑 -> 青...の順で、ノイズが小さくなっていくが、ノイズが大きい時には、スコアの値を無視して、大きく動くことがわかる。これにより、初期値付近にいた峰を回避して、別の峰に移動できており、多峰性の困難さを克服していることがわかる。ノイズが小さくなるときは、より尤もらしい点をサンプルできるように、サンプルの品質を上げて行っていることがわかる。
実際に1000回サンプリングするとこんな感じで、真の分布に対して上手くいっているように見える。1サンプルあたりに、二重ループを回さなければならないので、結構遅い。
次回は、DDPMの学習とサンプリングに取り組んでみよう!
Discussion