🦔
【拡散モデル】デノイジングスコアマッチングによるスコアの学習
拡散モデル本 の勉強。
今回は、1.5.5節で説明されるデノイジングスコアマッチングによってスコアを学習してみる。SBMによるサンプリングを見据えて、p.38の式を目的関数とする。
コード
モデルを学習した後、真のスコアと学習したスコアを比較するためのプロットを表示する。
import numpy as np
import torch
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from generate_dist import GaussianMixture, display_pdf, display_score
class MLP(torch.nn.Module):
def __init__(self, input_size, output_size):
super(MLP, self).__init__()
self.model = nn.Sequential(
torch.nn.Linear(input_size, 64),
torch.nn.ReLU(),
torch.nn.Linear(64, 64),
torch.nn.ReLU(),
torch.nn.Linear(64, 64),
torch.nn.ReLU(),
torch.nn.Linear(64, output_size),
# torch.nn.Linear(input_size, output_size),
)
def forward(self, x):
x = self.model(x)
return x
def train(model: torch.nn.Module, noises: np.array, dist: GaussianMixture):
x_log = []
noised_x_log = []
loss_log = []
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epoch = 10000
for epoch in range(num_epoch):
optimizer.zero_grad()
data_x = dist.sample(size=2000)
sigma = np.random.choice(noises)
noised_x = np.random.normal(data_x, sigma, size=data_x.shape)
y_tensor = torch.from_numpy((data_x - noised_x)/sigma/sigma).float()
x_tensor = torch.from_numpy(np.insert(noised_x, data_x.shape[1], sigma, axis=1)).float()
pred = model(x_tensor)
loss = criterion(pred, y_tensor)*sigma*sigma
loss.backward()
optimizer.step()
loss_log.append((epoch+1, loss.item()))
x_log.append(data_x)
noised_x_log.append(noised_x)
if (epoch+1) % 100 == 0:
print(f'Epoch[{epoch+1}/{num_epoch}], Loss: {loss.item():.4f}')
return model, loss_log, x_log, noised_x_log
def model_score(x_1_grid, x_2_grid, model: MLP, noise):
x_point_arr = np.stack([x_1_grid.flatten(), x_2_grid.flatten()], axis=1)
x_tensor = torch.from_numpy(np.insert(x_point_arr, x_point_arr.shape[1], noise, axis=1)).float()
v = model(x_tensor).detach().numpy()
return v.reshape((x_1_grid.shape[0], x_1_grid.shape[1], 2))
def display_model_score(x_1_line: np.array, x_2_line: np.array, model, noise):
x_1_grid, x_2_grid = np.meshgrid(x_1_line, x_2_line)
score = model_score(x_1_grid, x_2_grid, model, noise)*0.0001
plt.quiver(x_1_grid, x_2_grid, score[:,:,0], score[:,:,1], color='blue',angles='xy')
def compare_model_score(model, noises):
for noise in noises:
x_1_line = np.linspace(-5, 10, 100)
x_2_line = np.linspace(-5, 10, 100)
display_pdf(GaussianMixture(noise=noise), x_1_line, x_2_line)
x_1_line = np.linspace(-5, 10, 20)
x_2_line = np.linspace(-5, 10, 20)
display_score(GaussianMixture(noise=noise), x_1_line, x_2_line)
display_model_score(x_1_line, x_2_line, model, noise)
plt.title('noise: ' + str(noise))
plt.show()
if __name__ == "__main__":
model = MLP(3, 2)
noises = np.array([1/4, 1/2, 1, 2, 4, 8])
model, loss_log, x_log, noised_x_log = train(model, noises, GaussianMixture())
compare_model_score(model, noises)
結果
前回描画した真の分布のスコア(赤)と、今回学習したモデルのスコア(青)を比較してみる。
おおむねうまくいっているように見えるが、尤度が低い場所でずれが生じている。
尤度が低い場所からどちらの峰の方向を選ぶかが、真の値のベクトルとずれている。
次回は、SBMサンプリングを実行して、ランダムな初期点からサンプリングしてみよう!
Discussion