😎

ベイズとMCMC

2023/02/18に公開

はじめに

参考

http://web.sfc.keio.ac.jp/~maunz/BS19/BS19-11.pdf

復習

事前分布・事後分布ってなんだっけ?

観測データxを得たもとでのパラメータが従う分布p(\theta|x)
が事後分布。\thetaは確率分布のパラメータだと思っておけばOK

事後分布とは、ベイズの定理から以下のようにあらわせる

p(\theta|x)=\frac{p(x|\theta)p(\theta)}{p(x)}

ここで
p(\theta)は事前分布、p(x|\theta)は尤度である

尤度はデータに対する当てはまりを表すパラメータの関数。
事前分布はあらかじめ\thetaに対してわかっている事前情報を表す。

MCMCとは

コンピュータで事後分布を求めるためには解析計算が必要。しかし複雑な世の中では
解析的に解けないものがほとんどだ。
そんな事後分布をコンピュータによる力業で求めるための方法がMCMCだ。

MCMCを行う際の流れ

  • 事前分布の値の決めつけ

  • MCMCによるモンテカルロ法

  • 事後分布の計算

  • やるときはBlanchard and Kahn[1980] の条件も満たさないとうまく行かないとのこと。
    どういう意味かは要整理

実例

問題設定

x=np.random.randn(10000)*0.4+3 に対してガウス分布の事後分布を作成。

x_i\sim \mathcal{N}(\mu,{\sigma}^2) \ \ (i=1,2,...,10000)

事前分布としてガウス分布の平均が一様分布
\mu \sim \mathcal{U}(\mu_{min},\mu_{max})

標準偏差が指数に従うものとする。
\sigma \sim \text{Exp}(\lambda)

  1. この事後分布を求めるプログラムを作成せよ
  2. 事後分布の\mu\sigmaのヒストグラムを図示せよ

洞察

事後分布をベイズの定理より定式化すると

p(\mu, {\sigma}^2 | x)=\frac{p(x |\mu, {\sigma}^2) p(\mu) p({\sigma}^2)}{p(x)}

=\frac{p(x |\mu, {\sigma}^2) p(\mu) p({\sigma}^2)}{\iint p(x | \mu, {\sigma}^2) p(\mu) p({\sigma}^2) \mathrm{d} \mu \mathrm{d} {\sigma}^2}

しかし残念ながらこれは解析的に解けないため、MCMCを使う必要がある。

  • 尤度関数 : p(x |\mu, {\sigma}^2)

    p(x|\mu, \sigma^2) = \prod_{i=1}^N \mathcal{N}(x_i|\mu, \sigma^2)

  • 一様分布 : p(\mu)

  • 指数分布 : p(\sigma)

解答

実際に事後分布を求めるためのMCMCをする。
xの与え方から明らかに平均3、標準偏差0.4に従うはず。
事後分布でもこの結果が出ることを、以降行うシミュレーション結果から確認する

準備

まずは、必要なライブラリをインポート、データを生成する

import numpy as np
import pymc3 as pm
import arviz as az
import matplotlib.pyplot as plt

# データの生成
np.random.seed(123)
x = np.random.randn(10000)*0.4+3

今回使うデータxを可視化しておく

# データの可視化
plt.hist(x, bins=50)
plt.show()

確かに正規分布になっている。
今回の事前分布のパラメータは以下のように決めておく(適当)

# 事前分布のハイパーパラメータ
mu_min = 0
mu_max = 6
lam = 2

以降MCMCの操作を記述する

MCMC条件の決定

サンプル数、バーンイン期間、チェーンの状態数を決める。

サンプル数

MCMCでとるサンプル数。十分取らないとまともな値に収束しない。
これは乱数間の自己相関が高くなっていることが原因らしい。
これを調べるための指標があるらしい。(後日まとめます!!)

  • 経験者X曰く、百万回くらいを2分割して各50万とったらしい
  • 経験者Y曰く、十万単位必要であるケースは稀で、基本的に数千あればよいとのこと
    数千くらいやって収束診断しながら増やしていくのが一般的。
    診断方法はこちら↓
    http://web.sfc.keio.ac.jp/~maunz/BS19/BS19-11.pdf
    今回は簡単な問題のはずなので10000程度で様子を見てみる。
バーンイン期間

前述のとおりMCMCはなかなか収束しないものらしい。
MCMCサンプリング初期の段階をバーンイン期間とよび、この段階では最終的なサンプルとして望ましくはないのでこの期間だけサンプルが破棄される。
今回は簡単な問題なので1000でいいだろうと判断。

チェーンの状態数

MCMCアルゴリズムで生成されるサンプルは、一連の状態(シーケンス)として表される。これらの状態のシーケンスを「chain(チェーン)」と呼ぶ。
MCMCアルゴリズムは、複数のチェーンを生成し、それらを平均化することで、より正確な結果を得ることができる。

  • MCMCにおいては、各チェーンは独立している必要がある、すなわち、各チェーンが異なる初期値から開始され、状態の遷移にランダム性が必要
  • 各チェーンの状態数が十分に大きく、バーンイン期間を適切に設定する必要あり。

今回は簡単な問題のはずなので以下の値に設定

N_sampling = 10000 # 事後分布の推定に使用する総サンプル数
tune_period = 1000 # 初期の状態から定常状態に収束するまでの間に行うバーンイン期間
N_chains = 4 # チェーンの状態数

モデルの記述

問題設定のモデルを忠実に記述する

問題設定(再掲)

x=np.random.randn(10000)*0.4+3 に対してガウス分布の事後分布を作成。

x_i\sim \mathcal{N}(\mu,{\sigma}^2) \ \ (i=1,2,...,10000)

事前分布としてガウス分布の平均が一様分布
\mu \sim \mathcal{U}(\mu_{min},\mu_{max})

標準偏差が指数に従うものとする。
\sigma \sim \text{Exp}(\lambda)

これに対応するPymcのプログラムは以下の通り。

# モデルの定義
with pm.Model() as model:
    # 平均の事前分布:一様分布
    mu = pm.Uniform('mu', lower=mu_min, upper=mu_max)
    
    # 分散の事前分布:指数分布
    sigma = pm.Exponential('sigma', lam=lam)
    
    # ガウス分布モデル
    obs = pm.Normal('obs', mu=mu, sd=sigma, observed=x)
    
    # MCMCサンプリング
    trace = pm.sample(N_sampling, tune=tune_period, chains=N_chains, random_seed=123)

結果の可視化

最後にこの事後分布を可視化するための記述を描く。

# 事後分布の可視化
pm.plot_trace(trace)
plt.tight_layout()
plt.show()

pm.plot_posterior(trace, var_names=['mu', 'sigma'], color='purple')
plt.tight_layout()
plt.show()

まとめ

最後に、上記のまとめコードです

全コード
import numpy as np
import pymc3 as pm
import arviz as az
import matplotlib.pyplot as plt

# データの生成
np.random.seed(123)
x = np.random.randn(10000)*0.4+3

# データの可視化
plt.hist(x, bins=50)
plt.show()

# 事前分布のハイパーパラメータ
mu_min = 0
mu_max = 6
lam = 2

# MCMC条件
N_sampling = 10000 # 事後分布の推定に使用する総サンプル数
tune_period = 1000 # 初期の状態から定常状態に収束するまでの間に行うバーンイン期間
N_chains = 4 # チェーンの状態数

# モデルの定義
with pm.Model() as model:
    # 平均の事前分布:一様分布
    mu = pm.Uniform('mu', lower=mu_min, upper=mu_max)
    
    # 分散の事前分布:指数分布
    sigma = pm.Exponential('sigma', lam=lam)
    
    # ガウス分布モデル
    obs = pm.Normal('obs', mu=mu, sd=sigma, observed=x)
    
    # MCMCサンプリング
    trace = pm.sample(N_sampling, tune=tune_period, chains=N_chains, random_seed=123)
    
# 事後分布の可視化
pm.plot_trace(trace)
plt.tight_layout()
plt.show()

pm.plot_posterior(trace, var_names=['mu', 'sigma'], color='purple')
plt.tight_layout()
plt.show()

結果

  • \mu,\sigmaの分布

問題の設定どおり、平均が3、標準偏差が0.4を中心に分布していることがわかる

MAP推定, EAP推定

Discussion