🐡
【拡散モデル】混合ガウス分布とそのスコアをプロット
拡散モデル本 の勉強。
拡散モデルの多峰性・高次元性に対応するために、スコアという対数尤度の勾配を用いて、モデルを学習する。(本の1章の内容)
そこで、拡散モデルが、多峰性をうまく学習できているかを実装通して、実感したい。
真の分布を混合ガウス分布として、そこからデータをサンプルして拡散モデルを学習することを今後やっていく。まずは、混合ガウス分布を実装するところから。ついでにスコアのベクトル場をプロットしてみる。
混合ガウス分布のスコアの計算
混合ガウス分布のスコアは、プログラム上を見ると少し計算がややこしいが、以下のように書き下せばシンプル。
混合ガウス分布
と表すことができる。
で計算できる。
コード
今後、ノイズをたたみ込むを見越して、混合ガウス分布をinitする際に、ノイズを渡せるようにしておく。
import japanize_matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal
class GaussianMixture:
def __init__(self, noise=0):
self.components = [
multivariate_normal(
np.array([3, 7]), np.array([[4, -1.2], [-1.2, 1]]) + noise*np.eye(2)
),
multivariate_normal(
np.array([1, 4]), np.array([[2, 1], [1, 2]]) + noise*np.eye(2)
),
multivariate_normal(
np.array([7, 0]), np.array([[1, -0.5], [-0.5, 1]]) + noise*np.eye(2)
),
]
self.weights = [0.5, 0.3, 0.2]
assert sum(self.weights) == 1
def pdf(self, x_1_grid, x_2_grid):
x_point_arr = np.stack([x_1_grid.flatten(), x_2_grid.flatten()], axis=1)
dens = 0
for weight, component in zip(self.weights, self.components):
dens += component.pdf(x_point_arr) * weight
return dens.reshape(x_1_grid.shape)
def score(self, x_1_grid, x_2_grid):
x_point_arr = np.stack([x_1_grid.flatten(), x_2_grid.flatten()], axis=1)
pdf = self.pdf(x_1_grid, x_2_grid)
score = 0
for weight, component in zip(self.weights, self.components):
p_grad_inner = np.array([-np.matmul(np.linalg.inv(component.cov), (v - component.mean)) for v in x_point_arr])
p = component.pdf(x_point_arr).reshape((p_grad_inner.shape[0], 1))
p_grad = p*p_grad_inner
score += weight*p_grad
score = score / pdf.reshape((score.shape[0], 1))
return score.reshape((x_1_grid.shape[0], x_1_grid.shape[1], 2))
def sample(self, size):
k = np.random.choice(len(self.weights), p=self.weights)
return self.components[k].rvs(size=size)
def display_pdf(dist: GaussianMixture, x_1_line: np.array, x_2_line: np.array):
x_1_grid, x_2_grid = np.meshgrid(x_1_line, x_2_line)
dens = dist.pdf(x_1_grid, x_2_grid)
plt.axes().set_aspect('equal', 'datalim')
plt.contour(x_1_grid, x_2_grid, dens, alpha=0.5)
plt.title('真の分布')
plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.colorbar() # 等高線の色
def display_score(dist: GaussianMixture, x_1_line: np.array, x_2_line: np.array):
x_1_grid, x_2_grid = np.meshgrid(x_1_line, x_2_line)
score = dist.score(x_1_grid, x_2_grid)*0.0001
plt.quiver(x_1_grid, x_2_grid, score[:,:,0], score[:,:,1], color='red',angles='xy')
if __name__ == "__main__":
dist = GaussianMixture()
x_1_line = np.linspace(-5, 10, 100)
x_2_line = np.linspace(-5, 10, 100)
display_pdf(dist, x_1_line, x_2_line)
x_1_line = np.linspace(-5, 10, 25)
x_2_line = np.linspace(-5, 10, 25)
display_score(dist, x_1_line, x_2_line)
plt.show()
実行結果
スコアを用いてサンプリングをするランジュバン・モンテカルロ法は、確かに尤度が小さい領域では尤度が大きい方向へ向かおうとし、尤度が大きい領域ではあまり動かないということがわかる!
Discussion