5.4 潜在変数モデル - 第5章ベイズ推論プログラミング
はじめに
Pythonでスラスラわかる ベイズ推論「超」入門(赤石 雅典 (著), 須山 敦志 (監修))の5.4節のPyMCコードをNumPyroで書き直しました。
アイリス・データセットのversicolorとvirginicaをがく片の幅sepal_widthだけでクラスタリングします。
それぞれのがく片の幅が正規分布に従うというシンプルな仮定から見事にクラスタリングされる様子を味わうことができます。
フォルダ構造とユーティリティ関数、ライブラリimport
リンク集の記事にフォルダ構造とユーティリティ関数、ライブラリimportを掲載しました。
準備としてそちらのページをご覧ください。
5.4 潜在変数モデル
5.4.1 問題設定
versicolor と virginica のがく片の幅の分布から2つの正規分布の混在比率を同時に求めます。
5.4.2 データ準備
がく片の長さについて調べるので目的変数の
# アイリスデータセットの読み込み
df = sns.load_dataset('iris')
# 花の種類をsetosa以外の2種類に絞り込む
df_exclude_setosa = df.query('species != "setosa"')
# インデックスを0から振り直す
df_exclude_setosa = df_exclude_setosa.reset_index(drop=True)
# petal_widthの項目値をx_dataにセット
Y = jnp.array(df_exclude_setosa['petal_width'].values, dtype = float)
display(Y)
# 色分けしないでプロットする。
bins = np.arange(0.8, 3.0, 0.1)
fig, ax = plt.subplots()
sns.histplot(bins=bins, x=Y)
ax.set_xlabel('petal_width')
ax.xaxis.set_tick_params(rotation=90)
ax.set_title('petal_widthのヒストグラム')
ax.set_xticks(bins)
plt.tight_layout()
plt.show()

# 花の種類の答えのプロット
bins = np.arange(0.8, 3.0, 0.1)
fig, ax = plt.subplots()
sns.histplot(data=df_exclude_setosa, bins=bins, x='petal_width',
hue='species', kde=True)
ax.xaxis.set_tick_params(rotation=90)
ax.set_title('petal_widthのヒストグラム')
ax.set_xticks(bins);

5.4.3 確率モデル定義
確率モデルをプログラミングするために、前章と同様に数式を使って状況を整理します。
まず、ヒストグラムよりがく片の長さは花の種類ごとの正規分布に従うと仮定します。
参考書籍で標準偏差の逆数である精度
正規分布のパラメータは花の種類ごとに次の確率分布に従うと仮定します。
ここで花の種類の序数はベルヌーイ分布に従うと仮定します。
このベルヌーイ分布の確率の事前分布は一様分布
ここまで複雑だと数式がごちゃごちゃするので、一気にプログラミングを行います。
ここまでの仮定を、後ろの方から記述します。
def model_latent_variable_models(Y = None, N = None, n_groups = None):
'''
5.4節の2種類の花のがく片の幅の潜在変数モデル
'''
# 花の種類を決めるベルヌーイ分布の確率の事前分布は一様分布 $[0,1]$ と仮定します
p = numpyro.sample("p", dist.Uniform(low = 0, high = 1))
# 花の種類の序数はベルヌーイ分布に従うと仮定します
with numpyro.plate("N", N):
s = numpyro.sample("s", dist.Bernoulli(probs = p))
# がく片の長さは花の種類ごとの正規分布に従うと仮定します
with numpyro.plate("group", n_groups):
# 正規分布のパラメータは花の種類ごとに次の確率分布に従うと仮定します
μ_s = numpyro.sample("μ_s", dist.Normal(loc = 0, scale = 10))
τ_s = numpyro.sample("τ_s", dist.HalfNormal(scale = 10))
# NumPyroの確立分布では標準偏差が必要なので逆数を計算します。
σ_s = numpyro.deterministic("σ_s", jnp.sqrt(1.0 / (τ_s + 0.001)))
# がく片の長さは花の種類ごとの正規分布に従うと仮定します。
# ベクトル化(学習用データを確率変数に割り当てるためのNumPyroのお作法)
with numpyro.plate("N", N):
numpyro.sample("Y", dist.Normal(loc = μ_s[s], scale = σ_s[s]), obs = Y)
作成したモデルをプロットします。
モデル化するとパラメータの関係性が分かります。
model_args = {
"Y": Y,
"N": len(Y),
"n_groups": 2,
}
try_render_model(model_latent_variable_models, render_name = "潜在変数モデル", **model_args)

5.4.4 サンプリングと結果分析
データを用意してモデルを作成したら後はユーティリティ関数に渡すだけです。
model_args = {
"Y": Y,
"N": len(Y),
"n_groups": 2,
}
idata = run_mcmc(
model_latent_variable_models,
num_chains = 1,
num_warmup = 2000,
num_samples = 1000,
thinning = 1,
seed = 42,
target_accept_prob = 0.99,
log_likelihood = False,
**model_args
)
結果分析のコードは書籍とほぼ同じです。
まずはサンプリングが上手くいったか確認します。
az.plot_trace(idata, compact = False, var_names = ["p", "μ_s", "σ_s"])
plt.tight_layout()
知りたかったそれぞれの正規分布のパラメータをプロットします。
plt.rcParams['figure.figsize']=(6,6)
az.plot_posterior(idata, var_names = ["μ_s", "σ_s"])
plt.tight_layout();

summary = az.summary(idata, var_names = ["μ_s", "σ_s"])
display(summary)

5.4.5 ヒストグラムと正規分布関数の重ね描き
ヒストグラムにベイズ推論で求めた正規分布を重ねてプロットします。
がく片の幅のデータだけでうまくクラスタリングできたことがわかります。
# 正規分布関数の定義
def norm(x, mu, sigma):
return np.exp(-((x - mu)/sigma)**2/2) / (np.sqrt(2 * np.pi) * sigma)
# 推論結果から各パラメータの平均値を取得
mean = summary['mean']
# muの平均値取得
mean_mu0 = mean['μ_s[0]']
mean_mu1 = mean['μ_s[1]']
# sigmaの平均値取得
mean_sigma0 = mean['σ_s[0]']
mean_sigma1 = mean['σ_s[1]']
# 正規分布関数値の計算
x = np.arange(0.8, 3.0, 0.05)
delta = 0.1
y0 = norm(x, mean_mu0, mean_sigma0) * delta / 2
y1 = norm(x, mean_mu1, mean_sigma1) * delta / 2
# ラベルを追加します
label_0 = 'Bayse versicolor' if mean_mu0 < mean_mu1 else 'Bayse virginica'
label_1 = 'Bayse versicolor' if mean_mu0 >= mean_mu1 else 'Bayse virginica'
# グラフ描画
bins = np.arange(0.8, 3.0, delta)
plt.rcParams['figure.figsize']=(6,6)
fig, ax = plt.subplots()
sns.histplot(data=df_exclude_setosa, bins=bins, x='petal_width',
hue='species', kde=True, ax=ax, stat='probability')
ax.get_lines()[1].set_label('KDE versicolor')
ax.get_lines()[0].set_label('KDE virginica')
ax.plot(x, y0, c='b', lw=3, label=label_0)
ax.plot(x, y1, c='y', lw=3, label=label_1)
ax.set_xticks(bins);
ax.xaxis.set_tick_params(rotation=90)
ax.set_title('ヒストグラムと正規分布関数の重ね描き')
plt.legend();

5.4.6 潜在変数の確率分布
省略
終わりに
潜在変数モデルのベイズ推論を行いました。
数式で考えるのは難しそうですね。
Discussion