🐥

# 【拡散モデル】SBMサンプリング

2023/03/25に公開

## コード

``````import numpy as np
import torch
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from mim import train

from generate_dist import GaussianMixture, display_pdf, display_score
from score_matching import MLP

def sampling(input_size, score, noises):
BASE_ALPHA = 0.5
K = 50 # loop num

x = torch.normal(0, noises[-1], (input_size,1)).squeeze()
xs = []

for noise in reversed(noises):
xs_per_noise = []
alpha = BASE_ALPHA * noise/noises[-1]
for k in range(K):
u = torch.normal(0, 1, (2,1)).squeeze()
s = score(x, noise)
x = x + alpha*s + np.sqrt(2*alpha)*u
xs_per_noise.append(x.detach().numpy())
xs.append(xs_per_noise)

return xs

def _calc_true_score(x, noise):
dist = GaussianMixture(noise=noise)
x = x.detach().numpy()
v = dist.score(np.array([[x[0]]]), np.array([[x[1]]]))

def plot_trail(model, noises):
x_1_line = np.linspace(-5, 10, 100)
x_2_line = np.linspace(-5, 10, 100)
display_pdf(GaussianMixture(), x_1_line, x_2_line)

x_1_line = np.linspace(-5, 10, 20)
x_2_line = np.linspace(-5, 10, 20)
display_score(GaussianMixture(), x_1_line, x_2_line)

xs = sampling(2, lambda x, noise: model(torch.concat((x, torch.tensor([noise], dtype=torch.float)))), noises)
for i, xs in enumerate(xs):
color = 'rgbcmyk'[i]
plt.plot([x[0] for x in xs], [x[1] for x in xs], c=color, marker='o', markersize=max(6-i,1), alpha=0.5)
plt.show()

if __name__ == "__main__":
model = MLP(3, 2)
noises = np.array([1/4, 1/2, 1, 2, 4, 8])
model, loss_log, x_log, noised_x_log = train(model, noises, GaussianMixture())

plot_trail(model, noises)

sample_result = []
for t in range(1000):
xs = sampling(2, lambda x, noise: model(torch.concat((x, torch.tensor([noise], dtype=torch.float)))), noises)
sample_result.append(xs[-1][-1])

x_1_line = np.linspace(-5, 10, 100)
x_2_line = np.linspace(-5, 10, 100)
display_pdf(GaussianMixture(), x_1_line, x_2_line)

for i, xs in enumerate(xs):
plt.scatter([x[0] for x in sample_result], [x[1] for x in sample_result], color='red', marker=".", s=48, alpha=0.05)

plt.show()
``````

サンプリング１過程の軌跡の表示と、1000回サンプルした時の散布図を表示している。

## 結果

まずは、サンプリングの過程の結果。赤 -> 緑 -> 青...の順で、ノイズが小さくなっていくが、ノイズが大きい時には、スコアの値を無視して、大きく動くことがわかる。これにより、初期値付近にいた峰を回避して、別の峰に移動できており、多峰性の困難さを克服していることがわかる。ノイズが小さくなるときは、より尤もらしい点をサンプルできるように、サンプルの品質を上げて行っていることがわかる。