🦔

【拡散モデル】デノイジングスコアマッチングによるスコアの学習

2023/03/22に公開

拡散モデル本 の勉強。

今回は、1.5.5節で説明されるデノイジングスコアマッチングによってスコアを学習してみる。SBMによるサンプリングを見据えて、p.38の式を目的関数とする。

\Sigma_{t=1}^T w_t \mathbb{E}_{x \sim p_{data}(x), \tilde{x}\sim \mathcal{N}(x, \sigma_t^2 I)} \left[\left\| \frac{x-\tilde{x}}{\sigma_t^2} - s_\theta(\tilde{x}, \sigma_t) \right\|^2 \right]

w_t=\sigma_t^2に設定する。

s_\theta は隠れ層を2個持つ単純な多層パーセプトロンを使って学習する。

コード

モデルを学習した後、真のスコアと学習したスコアを比較するためのプロットを表示する。

import numpy as np
import torch
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt

from generate_dist import GaussianMixture, display_pdf, display_score


class MLP(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            torch.nn.Linear(input_size, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, output_size),
            # torch.nn.Linear(input_size, output_size),
        )

    def forward(self, x):
        x = self.model(x)
        return x


def train(model: torch.nn.Module, noises: np.array, dist: GaussianMixture):
    x_log = []
    noised_x_log = []
    loss_log = []
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    num_epoch = 10000
    for epoch in range(num_epoch):
        optimizer.zero_grad()

        data_x = dist.sample(size=2000)
        sigma = np.random.choice(noises)
        noised_x = np.random.normal(data_x, sigma, size=data_x.shape)



        y_tensor = torch.from_numpy((data_x - noised_x)/sigma/sigma).float()
        x_tensor = torch.from_numpy(np.insert(noised_x, data_x.shape[1], sigma, axis=1)).float()
        pred = model(x_tensor)

        loss = criterion(pred, y_tensor)*sigma*sigma
        loss.backward()
        optimizer.step()


        loss_log.append((epoch+1, loss.item()))
        x_log.append(data_x)
        noised_x_log.append(noised_x)

        if (epoch+1) % 100 == 0:
            print(f'Epoch[{epoch+1}/{num_epoch}], Loss: {loss.item():.4f}')

    return model, loss_log, x_log, noised_x_log


def model_score(x_1_grid, x_2_grid, model: MLP, noise):
    x_point_arr = np.stack([x_1_grid.flatten(), x_2_grid.flatten()], axis=1)
    x_tensor = torch.from_numpy(np.insert(x_point_arr, x_point_arr.shape[1], noise, axis=1)).float()
    v = model(x_tensor).detach().numpy()
    return v.reshape((x_1_grid.shape[0], x_1_grid.shape[1], 2))


def display_model_score(x_1_line: np.array, x_2_line: np.array, model, noise):
    x_1_grid, x_2_grid = np.meshgrid(x_1_line, x_2_line)
    score = model_score(x_1_grid, x_2_grid, model, noise)*0.0001

    plt.quiver(x_1_grid, x_2_grid, score[:,:,0], score[:,:,1], color='blue',angles='xy')


def compare_model_score(model, noises):
    for noise in noises:
        x_1_line = np.linspace(-5, 10, 100)
        x_2_line = np.linspace(-5, 10, 100)
        display_pdf(GaussianMixture(noise=noise), x_1_line, x_2_line)

        x_1_line = np.linspace(-5, 10, 20)
        x_2_line = np.linspace(-5, 10, 20)
        display_score(GaussianMixture(noise=noise), x_1_line, x_2_line)
        display_model_score(x_1_line, x_2_line, model, noise)

        plt.title('noise: ' + str(noise))
        plt.show()


if __name__ == "__main__":
    model = MLP(3, 2)
    noises = np.array([1/4, 1/2, 1, 2, 4, 8])
    model, loss_log, x_log, noised_x_log = train(model, noises, GaussianMixture())

    compare_model_score(model, noises)

結果

前回描画した真の分布のスコア(赤)と、今回学習したモデルのスコア(青)を比較してみる。

おおむねうまくいっているように見えるが、尤度が低い場所でずれが生じている。
尤度が低い場所からどちらの峰の方向を選ぶかが、真の値のベクトルとずれている。

次回は、SBMサンプリングを実行して、ランダムな初期点からサンプリングしてみよう!

Discussion