😇

NumPyroによるベイジアン線形単回帰の実装

2020/10/01に公開

はじめに

確率モデリングにハマっているkajyuuenです。最近PyroでMCMCを実行したときの遅さに耐えかねて、NumPyroに入門しました。そこで本記事では勉強も兼ねて、NumPyroでベイジアン線形単回帰を実装していきます。また可視化にはベイジアンモデリングの可視化ライブラリであるarvizを利用します。

ソースコードはGitHubに公開しています。
https://github.com/kajyuuen/playground

トイデータの作成

今回、利用するデータには疑似データを使います。
具体的には以下のような数式で表現出来るデータについて考えます。

y_i = 0.9x_i + 5 + \epsilon_{\rm real}
\epsilon_{\rm real} \sim {\mathcal N}(0, 0.5)

では、データを作成していきましょう。
まず、必要となるライブラリ群を準備します。

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import arviz as az
import numpyro
import numpyro.distributions as dist
import jax
plt.style.use('seaborn-darkgrid')

次にNumpyを用いてデータを生成します。

N = 100
alpha_real, beta_real = 0.9, 5
epsilon_real = np.random.normal(0, 0.5, N)

x = np.random.normal(10, 1, N)
y_real = alpha_real * x + beta_real
y = y_real + epsilon_real

生成したデータをプロットした結果は次のようになります。

ベイジアン線形単回帰モデル

生成過程

ではベイジアン線形単回帰モデルについて考えていきましょう。
線形単回帰は傾き\alpha, 切片を\betaとしたとき、各変数y_iについて次のように書けます。

y_i = \alpha x_i + \beta

この線形単回帰を確率を用いて、表現すると次のようになります。

{\boldsymbol y} \sim {\mathcal N}(\mu = \alpha {\boldsymbol x} + \beta, \sigma = \epsilon)

すなわちデータ系列{\boldsymbol y}は平均\alpha {\boldsymbol x}+ \beta 、標準偏差\epsilonの正規分布に従うと仮定できます。ただし各パラメータ\alpha, \beta, \epsilonについて、実際の値はわからないので、次のような事前分布を仮定します。この事前分布は適当な分布であれば何でもいいです。

\alpha \sim {\mathcal N}(\mu_\alpha, \sigma_\alpha)
\beta \sim {\mathcal N}(\mu_\beta, \sigma_\beta)
\epsilon \sim {\rm HalfCauchy(\sigma_\epsilon)}

これらの事前分布と確率モデルのグラフィカルモデルは以下のように書けます。

これをNumPyroで実装します。

def model(x, y = None):
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))
    beta = numpyro.sample("beta", dist.Normal(0, 10))
    eplison = numpyro.sample("eplison", dist.HalfCauchy(5))
    
    mu = numpyro.deterministic("mu", alpha * x + beta)
    
    with numpyro.plate("data", size = len(x)):
        y_pred = numpyro.sample("y_pred", dist.Normal(mu, eplison), obs=y)

推論

続いてMCMCによるサンプリングを行います。今回はNo-U-turn sampler(NUTS)を用います。

kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel, num_samples=1000, num_warmup=300, num_chains=4, chain_method="parallel")
mcmc.run(jax.random.PRNGKey(0), x, y)

サンプリング結果をarvizを用いて可視化しましょう。

az.plot_trace(mcmc, var_names=["alpha", "beta", "eplison"])

左がカーネル密度推定のグラフです。中心極限定理によって正規分布に似たグラフになっていることが確認できると思います。
右側が書くサンプリングで得た値になります。規則性、周期性が見られないグラフになっていて、上手く混合されている様子がわかります。

またprint_summary()を用いると、主要な統計量を見ることが出来ます。

mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     alpha      0.95      0.05      0.95      0.86      1.04   1090.27      1.01
      beta      4.56      0.54      4.56      3.71      5.44   1089.63      1.01
   eplison      0.51      0.04      0.51      0.45      0.57   1546.22      1.00

Number of divergences: 0

予測分布

次は事後分布からのサンプリングを元に予測分布を生成してみましょう。
まずget_samples()を用いて、事後分布から各パラメータ\alpha, \beta, \epsilonをサンプリングしましょう。

posterior_samples = mcmc.get_samples()

続いて、サンプリングしたパラメータ\alpha, \beta, \epsilonを用いて{\boldsymbol y}の予測分布を求めましょう。

posterior_predictive = numpyro.infer.Predictive(model, posterior_samples)(
    jax.random.PRNGKey(1), x
)

真のデータのカーネル密度推定と、事後分布から計算されたカーネル密度推定をプロットしてみます。

_, ax = plt.subplots()
for y_pred in posterior_predictive["y_pred"][0:50]:
    sns.kdeplot(y_pred, alpha=0.1, c='r', ax=ax)
# real data
sns.kdeplot(y, linewidth=3, color='k', ax=ax)
    
plt.xlabel('$y$', fontsize=16);

真のデータにかなり近いグラフが書けていることが確認できます。

最後に、事後分布からサンプリングされた直線と\pm \sigma, \pm 3\sigmaの不確実性を帯で表現したグラグを記してみます。

def summary(samples):
    s_mean, s_std = samples.mean(0), samples.std(0)
    site_stats = {
        "mean": s_mean,
        "std": s_std,
        "upper_sigma1": s_mean + s_std,
        "lower_sigma1": s_mean - s_std,
        "upper_sigma3": s_mean + 3 * s_std,
        "lower_sigma3": s_mean - 3 * s_std
    }
    return site_stats
pred_summary = summary(posterior_predictive["y_pred"])

plt.plot(x, y, 'C0.');

# 直線
alpha_m = posterior_samples['alpha'].mean()
beta_m = posterior_samples['beta'].mean()
plt.plot(x, alpha_m * x + beta_m,
         c='k', label='y = {:.2f} * x + {:.2f}'.format(alpha_m, beta_m))

# 範囲
idx = np.argsort(x)
x_ord = x[idx]
plt.fill_between(x_ord, pred_summary["upper_sigma1"][idx], pred_summary["lower_sigma1"][idx], alpha=0.3, color="b")
plt.fill_between(x_ord, pred_summary["upper_sigma3"][idx], pred_summary["lower_sigma3"][idx], alpha=0.1, color="b")



plt.xlabel('$x$', fontsize=16)
plt.ylabel('$y$', fontsize=16, rotation=0)
plt.legend(loc=2, fontsize=14)

いい感じに予測できているのではないでしょうか?
NumPyroはPyroに比べてMCMCがめちゃくちゃ早いので、これからはMCMCを使うならNumPyro、変分ベイズ法やニューラルネット側で色々凝ったことをやりたいのならPyroといった使い分けになるのかなぁ…と勝手に思っていますがどうなんでしょうか?詳しい人がいたら教えて下さい。

参考資料

Discussion