[拡散モデル] 拡散過程を可視化してみる
はじめに
現在、話題になっている拡散モデルでは、観測データに対して、ノイズを徐々に加えて最終的にノイズのみする拡散過程が使われています。
その拡散過程が実際にどの様にノイズに変換するかを確率分布で可視化してみようと思います。
以下の数式では、拡散モデルの元祖であるDDPM(Diffusion Denoising Probabilistic Models)とそろえます。
以下で使う可視化コードは、こちらにあります。
拡散過程の概要
拡散過程は、観測変数にノイズを徐々に加えていき、最終的にノイズのみの分布(
ステップを
データ分布を徐々にノイズ飲みの分布に変化させるために、分散を徐々に大きくする。(
また、
この様に定義すると、
また、この式は、
可視化
本章では、ダミーの観測データを作り、(1)式を通じて、どの様に確率分布が推移するかを可視化する。
観測データを混合1次元ガウス分布から何点かサンプリングし、それらを(1)の式で、拡散させ、それぞれのステップごとの確率分布を可視化する。
観測データの確率分布
2つの1次元ガウス分布を混合した混合ガウス分布からサンプリングしたデータを用いる。
この混合ガウス分布の確率変数を
n=2 -
: n番目のガウス分布の平均\mu_n \in \mathbb{R} -
: n番目のガウス分布の標準偏差\sigma_n \in \mathbb{R}^+ -
: n番目分布の混合割合w_n \sum_{i=1}^{n} n_i = 1
観測データ生成コード
# 全体のサンプルリング数
n = 200_000
# 2つの1次元ガウス分布の定義
mean1, std1, pi1 = 12, 0.4, 0.3
mean2, std2, pi2 = 8, 0.7, 1-pi1
n1 = int(pi1 * n)
n2 = n - n1
# データ生成(観測データ)
data1 = np.random.normal(mean1, std1, n1)
data2 = np.random.normal(mean2, std2, n2)
measured_data = np.concatenate([data1, data2])
# プロット(確率密度を描画する)
plt.figure(figsize=(10, 6))
sns.histplot(measured_data, kde=True, stat='density', linewidth=0, color='blue')
plt.title('生成した観測データ', fontsize=20)
plt.xlim(-5, 20)
plt.grid()
plt.ylabel('Density', fontsize=20)
plt.tight_layout()
plt.savefig('imgs/measured_data.png')
plt.show()
確率分布
拡散過程
拡散のステップ数は、30(
# 総ステップ数(拡散を実施する回数)
T = 30
betas = np.linspace(0.01, 0.99, T)
alphas = 1 - betas
plt.figure(figsize=(8, 6))
plt.plot(betas, marker="o", label=r'$ \beta $', markersize=3, alpha=0.8)
plt.plot(np.sqrt(alphas), marker='o', label=r'$ \sqrt{\alpha} $', markersize=3, alpha=0.8)
plt.xlabel('t', fontsize=30)
plt.title(r'$ \alpha, \beta $の可視化', fontsize=20)
plt.legend()
plt.grid()
plt.tight_layout()
plt.savefig('imgs/alpha_and_beta.png')
plt.show()
これらのデータを用いて、拡散過程を実行した結果をバイオリンプロットで可視化する。
この図を見ると、tが増えると、平均が0に近づき、分散も一定になっていくのがみて取れる。
また、観測データ分布の情報はなくなっているのがわかる。
コード
datas = [measured_data]
data_t = measured_data
for t in range(T):
# 拡散する
data_t = np.random.normal(loc=np.sqrt(alphas[t]) * data_t, scale=np.sqrt(betas[t]))
datas.append(data_t)
plt.figure(figsize=(12, 8))
sns.violinplot(datas)
plt.axhline(0, color='red', alpha=0.8, linestyle='--')
plt.title('tごとの確率変数($X_t$)の確率分布', fontsize=20)
plt.xlabel('t', fontsize=30)
plt.grid()
plt.ylabel('$x_t$', fontsize=30)
plt.tight_layout()
plt.savefig("imgs/difusion_process.png")
plt.show()
###(Extra)Tensorboardでの可視化
Tensorboardでは、ヒストグラムの軌跡を見ることができる。
コード
import tensorflow as tf
log_dir = 'logs/diffusion_process/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(logdir=log_dir)
tag_name = 'diffusion_process'
with file_writer.as_default():
# 観測データを記録 (t=0)
tf.summary.histogram(tag_name, measured_data, step=0)
for t, data_t in enumerate(datas, start=1):
tf.summary.histogram(tag_name, data_t, step=t)
# tensorboardを起動して、拡散過程を見る。
# ブラウザからlocalhost:6006にアクセス
!tensorboard --logdir=logs/diffusion_process --port=6006
Discussion