🐡

【拡散モデル】混合ガウス分布とそのスコアをプロット

2023/03/21に公開

拡散モデル本 の勉強。

拡散モデルの多峰性・高次元性に対応するために、スコアという対数尤度の勾配を用いて、モデルを学習する。(本の1章の内容)

そこで、拡散モデルが、多峰性をうまく学習できているかを実装通して、実感したい。

真の分布を混合ガウス分布として、そこからデータをサンプルして拡散モデルを学習することを今後やっていく。まずは、混合ガウス分布を実装するところから。ついでにスコアのベクトル場をプロットしてみる。

混合ガウス分布のスコアの計算

混合ガウス分布のスコアは、プログラム上を見ると少し計算がややこしいが、以下のように書き下せばシンプル。

混合ガウス分布 p は、各i成分の分布をp_i、重みを \pi_i として、

p(x) = \Sigma_i \pi_i p_i(x)

と表すことができる。 p_i は以下のようなd次元正規分布で表されるとする:

p_i(x) = \frac{1}{\sqrt{(2\pi)^d |\Sigma_i|}} \exp\{-\frac{1}{2}(x-\mu_i)^T \Sigma_i^{-1}(x-\mu_i) \} .

xに関する勾配は以下のようになる:

\nabla_x p_i(x) = p_i(x)(-\Sigma^{-1}(x-\mu)) .

\nabla_x p_i(x)を用いることで、スコアは、

\nabla_x \log p(x) = \frac{\nabla_xp(x)}{p(x)} = \frac{\Sigma_i \pi_i \nabla_x p_i(x)}{p(x)}

で計算できる。

コード

今後、ノイズをたたみ込むを見越して、混合ガウス分布をinitする際に、ノイズを渡せるようにしておく。

import japanize_matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal


class GaussianMixture:
    def __init__(self, noise=0):
        self.components = [
            multivariate_normal(
                np.array([3, 7]), np.array([[4, -1.2], [-1.2, 1]]) + noise*np.eye(2)
            ),
            multivariate_normal(
                np.array([1, 4]), np.array([[2, 1], [1, 2]]) + noise*np.eye(2)
            ),
            multivariate_normal(
                np.array([7, 0]), np.array([[1, -0.5], [-0.5, 1]]) + noise*np.eye(2)
            ),
        ]
        self.weights = [0.5, 0.3, 0.2]

        assert sum(self.weights) == 1

    def pdf(self, x_1_grid, x_2_grid):
        x_point_arr = np.stack([x_1_grid.flatten(), x_2_grid.flatten()], axis=1)

        dens = 0
        for weight, component in zip(self.weights, self.components):
            dens += component.pdf(x_point_arr) * weight

        return dens.reshape(x_1_grid.shape)


    def score(self, x_1_grid, x_2_grid):
        x_point_arr = np.stack([x_1_grid.flatten(), x_2_grid.flatten()], axis=1)

        pdf = self.pdf(x_1_grid, x_2_grid)

        score = 0
        for weight, component in zip(self.weights, self.components):
            p_grad_inner = np.array([-np.matmul(np.linalg.inv(component.cov), (v - component.mean)) for v in x_point_arr])
            p = component.pdf(x_point_arr).reshape((p_grad_inner.shape[0], 1))

            p_grad = p*p_grad_inner
            score += weight*p_grad

        score = score / pdf.reshape((score.shape[0], 1))

        return score.reshape((x_1_grid.shape[0], x_1_grid.shape[1], 2))


    def sample(self, size):
        k = np.random.choice(len(self.weights), p=self.weights)
        return self.components[k].rvs(size=size)


def display_pdf(dist: GaussianMixture, x_1_line: np.array, x_2_line: np.array):
    x_1_grid, x_2_grid = np.meshgrid(x_1_line, x_2_line)
    dens = dist.pdf(x_1_grid, x_2_grid)
    plt.axes().set_aspect('equal', 'datalim')
    plt.contour(x_1_grid, x_2_grid, dens, alpha=0.5)
    plt.title('真の分布')
    plt.xlabel('$x_1$')
    plt.ylabel('$x_2$')

    plt.colorbar() # 等高線の色


def display_score(dist: GaussianMixture, x_1_line: np.array, x_2_line: np.array):
    x_1_grid, x_2_grid = np.meshgrid(x_1_line, x_2_line)
    score = dist.score(x_1_grid, x_2_grid)*0.0001

    plt.quiver(x_1_grid, x_2_grid, score[:,:,0], score[:,:,1], color='red',angles='xy')


if __name__ == "__main__":
    dist = GaussianMixture()

    x_1_line = np.linspace(-5, 10, 100)
    x_2_line = np.linspace(-5, 10, 100)
    display_pdf(dist, x_1_line, x_2_line)

    x_1_line = np.linspace(-5, 10, 25)
    x_2_line = np.linspace(-5, 10, 25)
    display_score(dist, x_1_line, x_2_line)
    plt.show()

実行結果

スコアを用いてサンプリングをするランジュバン・モンテカルロ法は、確かに尤度が小さい領域では尤度が大きい方向へ向かおうとし、尤度が大きい領域ではあまり動かないということがわかる!

Discussion