🐥

【拡散モデル】SBMサンプリング

2023/03/25に公開

拡散モデル本 の勉強。

今回は、p.39に記載されているSBMサンプリングのアルゴリズムを、前回学習したスコアを用いてサンプリングしてみようと思う。

コード

import numpy as np
import torch
import torch.functional as F
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from mim import train

from generate_dist import GaussianMixture, display_pdf, display_score
from score_matching import MLP


def sampling(input_size, score, noises):
    BASE_ALPHA = 0.5
    K = 50 # loop num

    x = torch.normal(0, noises[-1], (input_size,1)).squeeze()
    xs = []

    for noise in reversed(noises):
        xs_per_noise = []
        alpha = BASE_ALPHA * noise/noises[-1]
        for k in range(K):
            u = torch.normal(0, 1, (2,1)).squeeze()
            s = score(x, noise)
            x = x + alpha*s + np.sqrt(2*alpha)*u
            xs_per_noise.append(x.detach().numpy())
        xs.append(xs_per_noise)

    return xs


def _calc_true_score(x, noise):
    dist = GaussianMixture(noise=noise)
    x = x.detach().numpy()
    v = dist.score(np.array([[x[0]]]), np.array([[x[1]]]))

    return torch.tensor(v).squeeze()




def plot_trail(model, noises):
    x_1_line = np.linspace(-5, 10, 100)
    x_2_line = np.linspace(-5, 10, 100)
    display_pdf(GaussianMixture(), x_1_line, x_2_line)

    x_1_line = np.linspace(-5, 10, 20)
    x_2_line = np.linspace(-5, 10, 20)
    display_score(GaussianMixture(), x_1_line, x_2_line)

    xs = sampling(2, lambda x, noise: model(torch.concat((x, torch.tensor([noise], dtype=torch.float)))), noises)
    for i, xs in enumerate(xs):
         color = 'rgbcmyk'[i]
         plt.plot([x[0] for x in xs], [x[1] for x in xs], c=color, marker='o', markersize=max(6-i,1), alpha=0.5)
    plt.show()


if __name__ == "__main__":
    model = MLP(3, 2)
    noises = np.array([1/4, 1/2, 1, 2, 4, 8])
    model, loss_log, x_log, noised_x_log = train(model, noises, GaussianMixture())

    plot_trail(model, noises)

    sample_result = []
    for t in range(1000):
        xs = sampling(2, lambda x, noise: model(torch.concat((x, torch.tensor([noise], dtype=torch.float)))), noises)
        sample_result.append(xs[-1][-1])

    x_1_line = np.linspace(-5, 10, 100)
    x_2_line = np.linspace(-5, 10, 100)
    display_pdf(GaussianMixture(), x_1_line, x_2_line)

    for i, xs in enumerate(xs):
        plt.scatter([x[0] for x in sample_result], [x[1] for x in sample_result], color='red', marker=".", s=48, alpha=0.05)

    plt.show()

サンプリング1過程の軌跡の表示と、1000回サンプルした時の散布図を表示している。

結果

まずは、サンプリングの過程の結果。赤 -> 緑 -> 青...の順で、ノイズが小さくなっていくが、ノイズが大きい時には、スコアの値を無視して、大きく動くことがわかる。これにより、初期値付近にいた峰を回避して、別の峰に移動できており、多峰性の困難さを克服していることがわかる。ノイズが小さくなるときは、より尤もらしい点をサンプルできるように、サンプルの品質を上げて行っていることがわかる。

実際に1000回サンプリングするとこんな感じで、真の分布に対して上手くいっているように見える。1サンプルあたりに、二重ループを回さなければならないので、結構遅い。

次回は、DDPMの学習とサンプリングに取り組んでみよう!

Discussion