NumPyroによるベイジアン線形単回帰の実装
はじめに
確率モデリングにハマっているkajyuuenです。最近PyroでMCMCを実行したときの遅さに耐えかねて、NumPyroに入門しました。そこで本記事では勉強も兼ねて、NumPyroでベイジアン線形単回帰を実装していきます。また可視化にはベイジアンモデリングの可視化ライブラリであるarvizを利用します。
ソースコードはGitHubに公開しています。
トイデータの作成
今回、利用するデータには疑似データを使います。
具体的には以下のような数式で表現出来るデータについて考えます。
では、データを作成していきましょう。
まず、必要となるライブラリ群を準備します。
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import arviz as az
import numpyro
import numpyro.distributions as dist
import jax
plt.style.use('seaborn-darkgrid')
次にNumpyを用いてデータを生成します。
N = 100
alpha_real, beta_real = 0.9, 5
epsilon_real = np.random.normal(0, 0.5, N)
x = np.random.normal(10, 1, N)
y_real = alpha_real * x + beta_real
y = y_real + epsilon_real
生成したデータをプロットした結果は次のようになります。
ベイジアン線形単回帰モデル
生成過程
ではベイジアン線形単回帰モデルについて考えていきましょう。
線形単回帰は傾き
この線形単回帰を確率を用いて、表現すると次のようになります。
すなわちデータ系列
これらの事前分布と確率モデルのグラフィカルモデルは以下のように書けます。
これをNumPyroで実装します。
def model(x, y = None):
alpha = numpyro.sample("alpha", dist.Normal(0, 1))
beta = numpyro.sample("beta", dist.Normal(0, 10))
eplison = numpyro.sample("eplison", dist.HalfCauchy(5))
mu = numpyro.deterministic("mu", alpha * x + beta)
with numpyro.plate("data", size = len(x)):
y_pred = numpyro.sample("y_pred", dist.Normal(mu, eplison), obs=y)
推論
続いてMCMCによるサンプリングを行います。今回はNo-U-turn sampler(NUTS)を用います。
kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel, num_samples=1000, num_warmup=300, num_chains=4, chain_method="parallel")
mcmc.run(jax.random.PRNGKey(0), x, y)
サンプリング結果をarvizを用いて可視化しましょう。
az.plot_trace(mcmc, var_names=["alpha", "beta", "eplison"])
左がカーネル密度推定のグラフです。中心極限定理によって正規分布に似たグラフになっていることが確認できると思います。
右側が書くサンプリングで得た値になります。規則性、周期性が見られないグラフになっていて、上手く混合されている様子がわかります。
またprint_summary()
を用いると、主要な統計量を見ることが出来ます。
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
alpha 0.95 0.05 0.95 0.86 1.04 1090.27 1.01
beta 4.56 0.54 4.56 3.71 5.44 1089.63 1.01
eplison 0.51 0.04 0.51 0.45 0.57 1546.22 1.00
Number of divergences: 0
予測分布
次は事後分布からのサンプリングを元に予測分布を生成してみましょう。
まずget_samples()
を用いて、事後分布から各パラメータ
posterior_samples = mcmc.get_samples()
続いて、サンプリングしたパラメータ
posterior_predictive = numpyro.infer.Predictive(model, posterior_samples)(
jax.random.PRNGKey(1), x
)
真のデータのカーネル密度推定と、事後分布から計算されたカーネル密度推定をプロットしてみます。
_, ax = plt.subplots()
for y_pred in posterior_predictive["y_pred"][0:50]:
sns.kdeplot(y_pred, alpha=0.1, c='r', ax=ax)
# real data
sns.kdeplot(y, linewidth=3, color='k', ax=ax)
plt.xlabel('$y$', fontsize=16);
真のデータにかなり近いグラフが書けていることが確認できます。
最後に、事後分布からサンプリングされた直線と
def summary(samples):
s_mean, s_std = samples.mean(0), samples.std(0)
site_stats = {
"mean": s_mean,
"std": s_std,
"upper_sigma1": s_mean + s_std,
"lower_sigma1": s_mean - s_std,
"upper_sigma3": s_mean + 3 * s_std,
"lower_sigma3": s_mean - 3 * s_std
}
return site_stats
pred_summary = summary(posterior_predictive["y_pred"])
plt.plot(x, y, 'C0.');
# 直線
alpha_m = posterior_samples['alpha'].mean()
beta_m = posterior_samples['beta'].mean()
plt.plot(x, alpha_m * x + beta_m,
c='k', label='y = {:.2f} * x + {:.2f}'.format(alpha_m, beta_m))
# 範囲
idx = np.argsort(x)
x_ord = x[idx]
plt.fill_between(x_ord, pred_summary["upper_sigma1"][idx], pred_summary["lower_sigma1"][idx], alpha=0.3, color="b")
plt.fill_between(x_ord, pred_summary["upper_sigma3"][idx], pred_summary["lower_sigma3"][idx], alpha=0.1, color="b")
plt.xlabel('$x$', fontsize=16)
plt.ylabel('$y$', fontsize=16, rotation=0)
plt.legend(loc=2, fontsize=14)
いい感じに予測できているのではないでしょうか?
NumPyroはPyroに比べてMCMCがめちゃくちゃ早いので、これからはMCMCを使うならNumPyro、変分ベイズ法やニューラルネット側で色々凝ったことをやりたいのならPyroといった使い分けになるのかなぁ…と勝手に思っていますがどうなんでしょうか?詳しい人がいたら教えて下さい。
参考資料
- Pythonによるベイズ統計モデリング
https://www.kyoritsu-pub.co.jp/bookdetail/9784320113374
Discussion