🦔

# 【拡散モデル】デノイジングスコアマッチングによるスコアの学習

2023/03/22に公開

\sum_{t=1}^T w_t \mathbb{E}_{x \sim p_{data}(x),\, \tilde{x}\sim \mathcal{N}(x, \sigma_t^2 I)} \left[\left\| \frac{x-\tilde{x}}{\sigma_t^2} - s_\theta(\tilde{x}, \sigma_t) \right\|^2 \right]

w_t=\sigma_t^2に設定する。

s_\theta は隠れ層を3個持つ単純な多層パーセプトロンを使って学習する。

## コード

モデルを学習した後、真のスコアと学習したスコアを比較するためのプロットを表示する。

import numpy as np
import torch
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt

from generate_dist import GaussianMixture, display_pdf, display_score

class MLP(torch.nn.Module):
    """Score network s_theta: a small fully-connected net with ReLU activations.

    The input is expected to be the (noised) data point with the noise level
    appended as an extra feature column; the output has the dimensionality of
    the data — the estimated score.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        hidden = 64
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_size),
        )

    def forward(self, x):
        return self.model(x)

def train(model: torch.nn.Module, noises: np.ndarray, dist: "GaussianMixture",
          num_epoch: int = 10000):
    """Train the score network with denoising score matching.

    Each step draws a batch from the data distribution, perturbs it with
    Gaussian noise of a randomly chosen level sigma, and regresses the model
    output onto the denoising target (x - x_noised) / sigma^2.  The weight
    w_t = sigma_t^2 is applied by multiplying the MSE loss.

    Args:
        model: score network taking a (batch, d+1) tensor of
            [noised point, sigma] and returning a (batch, d) score estimate.
        noises: 1-D array of noise levels sigma_t to sample uniformly from.
        dist: data distribution exposing ``sample(size=...) -> (size, d)``
            ndarray (e.g. ``GaussianMixture``).
        num_epoch: number of optimization steps (default keeps the original
            behavior of 10000 steps).

    Returns:
        (model, loss_log, x_log, noised_x_log): the trained model, a list of
        (epoch, loss) pairs, and the raw / noised batches seen at each step.
    """
    x_log = []
    noised_x_log = []
    loss_log = []
    criterion = nn.MSELoss()
    # BUG FIX: `optimizer` was used below but never defined anywhere in the
    # file (NameError).  The original optimizer configuration is not visible
    # here, so Adam with default hyperparameters is used — adjust if needed.
    optimizer = optim.Adam(model.parameters())

    for epoch in range(num_epoch):

        data_x = dist.sample(size=2000)
        sigma = np.random.choice(noises)
        noised_x = np.random.normal(data_x, sigma, size=data_x.shape)

        # Target: score of the perturbation kernel N(x, sigma^2 I) at x~,
        # i.e. (x - x~) / sigma^2.
        y_tensor = torch.from_numpy((data_x - noised_x) / sigma / sigma).float()
        # Condition the network on sigma by appending it as an input column.
        x_tensor = torch.from_numpy(
            np.insert(noised_x, data_x.shape[1], sigma, axis=1)
        ).float()
        pred = model(x_tensor)

        # Weight w_t = sigma^2 cancels the 1/sigma^2 scale of the target.
        loss = criterion(pred, y_tensor) * sigma * sigma
        # BUG FIX: without zero_grad() gradients accumulate across steps.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_log.append((epoch + 1, loss.item()))
        x_log.append(data_x)
        noised_x_log.append(noised_x)

        if (epoch + 1) % 100 == 0:
            print(f'Epoch[{epoch+1}/{num_epoch}], Loss: {loss.item():.4f}')

    return model, loss_log, x_log, noised_x_log

def model_score(x_1_grid, x_2_grid, model: "MLP", noise):
    """Evaluate the learned score on a 2-D meshgrid.

    Flattens the grid to (N, 2) points, appends the noise level as a third
    input column, runs the model, and reshapes the result back to
    (rows, cols, 2).
    """
    points = np.stack([x_1_grid.flatten(), x_2_grid.flatten()], axis=1)
    inputs = torch.from_numpy(
        np.insert(points, points.shape[1], noise, axis=1)
    ).float()
    scores = model(inputs).detach().numpy()
    rows, cols = x_1_grid.shape
    return scores.reshape((rows, cols, 2))

def display_model_score(x_1_line: np.array, x_2_line: np.array, model, noise):
    """Draw the model's score field as blue arrows on the current axes."""
    grid_1, grid_2 = np.meshgrid(x_1_line, x_2_line)
    # Arrows are scaled way down so they stay readable on the plot.
    arrows = 0.0001 * model_score(grid_1, grid_2, model, noise)

    plt.quiver(grid_1, grid_2, arrows[:, :, 0], arrows[:, :, 1],
               color='blue', angles='xy')

def compare_model_score(model, noises):
    """For each noise level, overlay the true pdf/score with the learned score.

    Shows one figure per noise level: the noised mixture pdf, the analytic
    score field, and the model's estimated score field.
    """
    for noise in noises:
        # Dense grid for the pdf contour.
        dense = np.linspace(-5, 10, 100)
        display_pdf(GaussianMixture(noise=noise), dense, dense)

        # Coarser grid so the score arrows do not overlap.
        coarse = np.linspace(-5, 10, 20)
        display_score(GaussianMixture(noise=noise), coarse, coarse)
        display_model_score(coarse, coarse, model, noise)

        plt.title('noise: ' + str(noise))
        plt.show()

if __name__ == "__main__":
    # Input is (x1, x2, sigma); output is the 2-D score estimate.
    model = MLP(3, 2)
    noise_levels = np.array([0.25, 0.5, 1.0, 2.0, 4.0, 8.0])
    model, loss_log, x_log, noised_x_log = train(model, noise_levels, GaussianMixture())

    compare_model_score(model, noise_levels)


## 結果

おおむねうまくいっているように見えるが、尤度が低い場所でずれが生じている。