MCMC法のギブスサンプリングをやってみよう
本記事の目的
言語化して浅い理解を深くしていく。今回はベイズ推定で用いるサンプリング方法について。ベイズ推定は、ベイズ統計学という言葉もあるように統計の中でもメジャーだと思われる。まず、このベイズ推定をいつも通り理論→実践で理解する。理論についてはこちらの書籍を参考にさせていただいた。(amazonに飛ぶのでご注意ください)
ベイズ理論とは
ベイズ推定の話に移る前に、その根底にあるベイズ理論について簡単に説明する。これは入力したデータを加味し、次のデータを予測する理論である。身近な一例を挙げる。ある試験を受けようとする学生がいる。試験の結果は合格か不合格の2択なので、試験を受ける前の学生の合格確率はとりあえず1/2と置くことできる。そして、試験後になると学生の手ごたえはよく感じていた。すると、学生の合格確率は3/4ぐらいに上がるだろう。このようにデータをもとに次のデータを予測するのがベイズ理論である。
ベイズ推定とは
先述のベイズ理論と絡めると、直前データをもとに次データの確率を推定する手法である。途中式は省くがベイズの定理を用いれば、次のように表せる。
ここで、Dはデータ、
コインの表が出る確率をθとし、4回投げ表→裏→表→裏の順に出た。このときの最も取りうる確率が大きくなるθを推測せよ。
解説:直感的にθ=2/4=0.5と思う方もいると思うが、その通りである。ただし、もう少し深く理解していこう。θ=0.1であれ、θ=0.9であれ、起きにくいかもしれないが「4回投げ表→裏→表→裏となる」確率は0ではない。今回求めるのは0<θ<1で「4回投げ表→裏→表→裏」確率が最大のθである。
まず、1回目のデータ(表)を読み込む。1回目のデータは事前確率を持たないので、事前確率は0と1の間なら何でも良い。今回は1にして計算する。1回目の事後確率
θを0<θ<1で積分すると総和は1となる(規格化)ので、
よってk=2となるので1回目の事後確率
次に2回目のデータ(裏)を読み込む。1回目の事後確率が2回目の事前確率となるので
以降は上記を繰り返すので、途中式は省略。各事後確率は下記の通り。
k=3となるので2回目の事後確率
k=2となるので3回目の事後確率
k=5/2となるので4回目の事後確率
この4回目の事後確率
このように、ある事後確率が次サンプルの事前確率となり、次サンプルの事後確率を求めていく。この事後確率が分かると、次の式を用いることで、求める母数の期待値を求めることができる。
この
区間推定と点推定について
統計を勉強している方は既知のことだと思うが、区間推定では信頼区間を用いて母数を推定している。例えば、95%信頼区間で
最尤推定とベイズ推定の違い
似た点推定に最尤推定という手法がある。最尤推定とは、その名の通り「最も尤もらしい値を推定」する手法。ここでもコイン投げを例に取る。
コインの表が出る確率をpとし、4回投げ表となった回数は2回だった。この事象が最も起きやすい確率pを求めよ。
解説:「4回投げ2回、表となる」確率をL(p)とすると
これを解いていく。指数の場合は両辺を対数でとってやると計算しやすくなる。
両辺をpで微分すると
よって、logL(p)の極値はp=0.5である。計算は省くが0<p<1で(logL(p))''<0となるため、p=0.5でlogL(p)は最大値を取る。また、logL(p)は単調増加なので、L(p)もp=0.5で最大値を取る。よって最も「4回投げ2回、表となる」確率(L(p))が大きくなるのはp=0.5。
以上から分かるようにベイズ推定は事前情報を使うのに対し、最尤推定は事前情報を使わない。
簡単にベイズ推定を行うには
上記のベイズ推定の例題は、たまたま簡単にベイズ推定できたが、本来ならばこうはいかない。EAP推定量
事前分布 | 尤度 | 事後分布 |
---|---|---|
ベータ分布 | 二項分布 | ベータ分布 |
正規分布 | 正規分布 | 正規分布 |
逆ガンマ分布 | 正規分布 | 逆ガンマ分布 |
ガンマ分布 | ポアソン分布 | ガンマ分布 |
※このような事前分布を尤度の自然な共役分布という。
上記の例題では事前分布がベータ分布、尤度が二項分布であったため、事後分布がベータ分布となり、繰り返し処理を行うだけで良かった。
複雑でもベイズ推定を行うには
現実世界では、上記のようなケースは稀である。コンピューターでできる範囲で計算が複雑になってもよいから、ベイズ推定を行うには後述するMCMC法(マルコフ連鎖モンテカルロ法)という乱数発生アルゴリズムを用いる。この手法は、マルコフチェーンとモンテカルロ法を組み合わせた手法であり、サンプリング手法の1種である。この手法を用いることで、
マルコフチェーンとは
事後データは、前データの中でも直前のデータのみに依存するモデルを指す。数学的に書くなら、時刻t+1のデータ
モンテカルロ法とは
確率で重みづけした乱数を用いて、データサンプリングを行う。モンテカルロ法については下記リンクで簡単にまとめている。
MCMC法の代表的な手法:ギブスサンプリング
MCMC法にはギブスサンプリングやモンテカルロ法など、さらに細分化されている。本記事ではギブスサンプリングに焦点を当てて説明する。他の手法についても、いずれはまとめたい。
書くにあたって、こちらのサイトを参考にさせていただいた。
多次元のサンプルA
D次元正規分布の式
2次元正規分布でμとΣを自分で定義し、条件付き確率が正規分布になるか確かめよう。
解説:各変数を次のように定義する。
記載はしないが、逆行列の公式
が成り立つことを用いて、
時刻tの時、サンプルが
ここで、条件付き確率
分母の
ここで
ガウス積分より
以上より☆式は
これは、
バーンインについて
MCMC法では、求めるパラメータの初期値は、そのパラメータの取りうる値ならばどんな値でも良いという特徴がある。とてもありがたい特徴なのだが、あまりにも母数から違いすぎると、求めるデータに影響を与える場合がある。そこで、MCMC法の最初の試行回数の何個かは省きましょうというのがバーンインという。下にサンプルの平均値(mu)10、標準偏差(sigma)8の正規分布データからギブスサンプリングした結果をバーンイン無し・有りでの推定結果を記す。
最初のデータだけ値が大きく違う
左:バーンイン無し、右:最初の100データをバーンイン
バーンインの数はデータによりけりなので、色々な値で試してみる必要がある。これについても後ほど実装してみる。
理論:ギブスサンプリングをやってみよう
では、ギブスサンプリングをやってみよう。といっても上記の例題と同様にパラメータの条件付き確率がどのような分布を取るか確認し、それをプログラミングするだけ。今回は既に持っているサンプル
<既に持っているサンプルの分布>
<既に持っているサンプルの確率分布>
<初期値>
最初に
ここで分子
よって、
最後の式ではガウス分布の式に合わせるために無理やり定数
次に
この件は下記のリンク参照した。
では、どう解くか。expの式の場合は、とりあえず対数をとってみるのが定石。また、
ここでガンマ分布
ここで
よって、
実践:ギブスサンプリングをやってみよう
繰り返しになるが、こちらのサイトを参考にしている。
今回はモデルの評価や途中データの詳細な確認等は行わないためプログラムから省いているが、興味がある方は上記のサイトを是非参考にしてほしい。
まずは環境構成とディレクトリ構成図。
Python:3.9.13
VSCode:1.76.2
GIBBS-SAMPLING
├ backend.py
└ frontend.py
まずは、backend.pyのコード。
import numpy as np
import pandas as pd
def sample_mu(y, N, sigma):
#平均
mean = np.sum(y) / N
#分散
variance = sigma * sigma / N
return np.random.normal(mean, np.sqrt(variance))
def sample_sigma(y, N, mu):
#アルファ
alpha = N / 2 + 1
#残差(yi-mu)
residuals = y - mu
#ベータ
beta = np.sum(residuals * residuals) / 2
#タウ
tau = np.random.gamma(alpha, 1/beta)
return (1 / (np.sqrt(tau)))
def model1(y, iters, init):
mu = init["mu"]
sigma = init["sigma"]
N = len(y)
trace = np.zeros((iters, 2))
for i in range(iters):
mu = sample_mu(y, N, sigma)
sigma = sample_sigma(y, N, mu)
trace[i, :] = np.array((mu, sigma))
trace = pd.DataFrame(trace)
trace.columns = ['mu', 'sigma']
return trace
-
np.random.normalで何をやっているか
事後分布である正規分布の確率密度関数f(x)から乱数を発生させている。事後分布が正規分布になるものを用いた理由は、乱数発生をnp.random.normal()
だけでできるためである。この乱数はf(x)の値に比例してxを抽出している。(=f(x1)大ならばx1は乱数抽出されやすい)今回は省略しているが、第三引数は乱数抽出数を表す。省略しているので、乱数は1つしか抽出されない。
-
tau = np.random.gamma(alpha, 1/beta)で何をやっているか
numpyではガンマ関数を下記の通り設定してある。
理論編で述べたガウス関数と比較すると、
次にfrontend.pyのコード。
import matplotlib.pyplot as plt
import numpy as np
from backend import model1
import streamlit as st
with st.form(key="gibbs-sampling"):
iters:int=st.slider("試行回数", 0, 10000, 0, 100)
size:int=st.slider("現時点のサンプル数", 0, 10000, 0, 10)
MU:int=st.number_input("サンプルの平均値",0)
SIGMA:int=st.number_input("サンプルの標準偏差",0)
burn_in:int=st.number_input("バーンイン数",0)
submit_button=st.form_submit_button("ギブスサンプリング")
my_bar = st.progress(0)
if submit_button:
plt.rcParams['figure.figsize'] = (16, 8)
np.random.seed(11)
init = {"mu": np.random.uniform(-10000, 10000), "sigma": np.random.uniform(0, 10000)}
my_bar.progress(20)
y = np.random.normal(MU, SIGMA, size=size)
trace = model1(y, iters, init)
fig1 = plt.figure()
ax1 = fig1.add_subplot(2, 1, 1)
ax2 = fig1.add_subplot(2, 1, 2)
trace.plot(ax=ax1)
ax1.legend(loc='upper right', prop={'size': 15})
kwargs = dict(histtype='stepfilled', alpha=0.3, density=True, bins=60, ec="k")
my_bar.progress(40)
ax2.hist(trace['mu'], label='mu', **kwargs)
ax2.hist(trace['sigma'], label='sigma', **kwargs)
ax2.legend(loc='upper right', prop={'size': 15})
row_data=trace.describe()
st.markdown("#### データ推移と事後確率密度分布")
st.pyplot(fig1)
st.markdown("#### データ表")
st.write(row_data)
burn_in_trace=trace[burn_in:]
fig2 = plt.figure()
ax3 = fig2.add_subplot(2, 1, 1)
ax4 = fig2.add_subplot(2, 1, 2)
burn_in_trace.plot(ax=ax3)
ax3.legend(loc='upper right', prop={'size': 15})
my_bar.progress(60)
kwargs = dict(histtype='stepfilled', alpha=0.3, density=True, bins=60, ec="k")
my_bar.progress(80)
ax4.hist(burn_in_trace['mu'], label='mu', **kwargs)
ax4.hist(burn_in_trace['sigma'], label='sigma', **kwargs)
ax4.legend(loc='upper right', prop={'size': 15})
burn_in_data=burn_in_trace.describe()
st.markdown("#### データ推移と事後確率密度分布(バーンイン込み)")
st.pyplot(fig2)
st.markdown("#### データ表(バーンイン込み)")
st.write(burn_in_data)
my_bar.progress(100)
-
itersとsizeは何か
itersはギブスサンプリングの試行回数、つまり求める事後確率密度分布を何個の点で作るかを示している。sizeは現時点で獲得しているサンプル数を示している。
-
my_bar = st.progress(0)で何をやっているか
プログレスバーの表示を今回は導入。計算に時間がかかるときは、これを入れておくとよいだろう。使い方としては、my_bar = st.progress(0)
を入れた後のコードの各部分にmy_bar.progress(20)
、my_bar.progress(40)
、・・・、と入れていき、最後の処理完了時にmy_bar.progress(100)
を入れることでプログレスバーが完成する。
-
plt.rcParams['figure.figsize'] = (16, 8)で何をやっているか
rcParamsで以降のmatplotlibの設定がすべて変わる。図を複数同じ形で出力するときは初めにplt.rcParams[]
を行うとよい。
-
init = {"mu": np.random.uniform(-10000, 10000), "sigma": np.random.uniform(0, 10000)}で何をやっているか
np.random.uniform(下限値,上限値)
を用いることで、初期値を一様分布から乱数発生させている。
-
ax1 = fig1.add_subplot(2, 1, 1)で何をやっているか
add_subplot()
でfigの表示方法を指定している。第一引数が行の分割数、第二引数が列の分割数を表しており、今回の場合は2×1表示している。第三引数は通し番号を表しているが、ここが1から採番していないと下記のエラーが発生する。
#第三引数に3を入力したとき
num must be an integer with 1 <= num <= 2, not 3
-
kwargs = dict(histtype='stepfilled', alpha=0.3, density=True, bins=60, ec="k")で何をやっているか
kwargs(ケーダブルアーグス)は複数のキーワード引数を持つデータを辞書型で受け取る際に用いられる。下記のリンクから分かるようにグラフの設定引数は多く存在するため、kwargsを用いることで、見やすくなる。
-
burn_in_trace=trace[burn_in:]で何をやっているか
序盤に述べたバーンイン部分をカットしている
結果:ギブスサンプリングをやってみよう
初期画面は下のようになるはず。
streamlit初期画面
試しにやってみよう。
実行画面
出力結果①
出力結果②
無限母集団のデータであっても、100個のデータ数だけで母数の確率密度関数が割り出せることが分かった。また、色々なデータで試してみると、当たり前のことかもしれないが下記のことが分かる。
・サンプル数が少ないと母数の確率密度関数の分散が大きくなるので、サンプル数が増えるたびにギブスサンプリングを行うとよい
・ギブスサンプリングの試行回数を増やしてみると母数の確率密度関数がより滑らかになる
・バーンインの範囲が小さくても、良い精度のサンプリングができる。
皆さんも色々なデータで試してみてほしい。
まとめ
ギブスサンプリングの素晴らしさを理解してもらえただろうか。長く書いてしまったが、やっていることは「事前データから求める一つの母数の事後確率を求める→事後確率から疑似母数を発生させる→異なる一つの母数の事後確率を求める→事後確率から疑似母数を発生させる→・・・」である。実サンプルのサンプリングが難しくても、ギブスサンプリングを用いれば疑似サンプルを生み出せるのは良いことである。一つ欠点を挙げるとしたら、事後確率が分かりやすい分布でないと疑似母数を発生させることが難しいことである。これを補っているのがメトロポリス法であるがまた今度。お疲れさまでした。
Discussion