Turing.jlで統計モデリング
はじめに
Juliaでベイジアンモデリングをするためのツールを調べています。JuliaではTuring.jlというパッケージが有名です。
こちらの記事などでも詳しく紹介されています。
ベイジアンなデータ分析というと概ね以下のような手順になります。なおMCMCで推論することを想定しています。
- データの準備
- 確率モデルの定義
- 事前分布の決定
- 推論
- MCMCの収束チェック
- 結果の解釈
この一連の手順でTuringをどう使うかをまとめてみたいと思います。
データの準備
以下のデータを使います。様々な条件(サイズや密度・捕食者の有無など)におけるオタマジャクシの生存数を調べた実験データです。
まずはデータをロードします。
using HTTP
using CSV
using DataFrames
resp = HTTP.get("https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/reedfrogs.csv").body;
df = CSV.read(resp, DataFrame, header=1, delim=';')
first(df, 10)
全部で48行のテーブルデータで、各行には初めのオタマジャクシ数 (density
) と一定時間後の生存数 (surv
) が記録されています。データの各行を「タンク」と呼ぶことにします。[1]
各タンクでの生存率をグラフにすると以下のようになります。
タンクのdensityで三種類に分類しています(縦の点線)。左のブロックが一番小さく、右のブロックが一番大きいです。平均生存率は75%程度ですが、図を見ればわかる通り、タンクごとの差の大きいデータセットになっています。これは実験の介入の有無だけでなく、観測しきれていないタンク固有の環境差からも来ています。今回はこのデータを使って介入 (pred
, size
) の効果を分析してみます。
確率モデルの定義
以下のようなロジスティック回帰モデルを定義します。後で見るように、上記データはタンクごとの環境差が大きいのでこのモデルでは予測精度は良くありません。
ここで
-
はi番目のタンクの捕食者の有無を表し、\mathrm{predator}[i] pred="pred"
なら2、そうでなければ1をとるとします -
はi番目のタンクのオタマジャクシの大きさを表し、\mathrm{size}[i] size="big"
なら2、そうでなければ1をとるとします -
はタンクを表します。j=1,\dots,48 はタンクy_j の生存数です。j
Turingではこのモデルは以下のような直感的な記法で定義することができます。
using Turing
using Distributions
sigmoid(x) = 1 / (1 + exp(-x))
@model function logistic_reg(x, y)
α ~ Normal(0, 1.5)
β ~ MvNormal(2, 1.0)
γ ~ MvNormal(2, 1.0)
logit_p = map(xj->α + β[xj.predator] + γ[xj.size], x)
p = sigmoid.(logit_p)
for j in eachindex(x)
y[j] ~ Binomial(x[j].n, p[x[j].i])
end
return
end
確率変数は ~
で定義します。それ以外は普通のJuliaの文法になっているのでわかりやすいです。今回はデータ y
は観測値として与えるため、モデルの引数に入れています。
このモデルに以下のようにしてデータをセットしました。
obs_x = [(i=i, n=df[i, :density], predator=ifelse(df[i, :pred]=="pred", 2, 1), size=ifelse(df[i, :size]=="big", 2, 1)) for i in 1:nrow(df)]
obs_y = df[!, :surv];
model = logistic_reg(obs_x, obs_y)
事前分布の決定
上では天下りに
ベイジアンなデータ分析では、できるだけデータを見る前に事前分布を設計することが望ましいです。ですので初めの方に書いた図は一旦忘れて、事前分布として何が相応しいか考えてみます。私はオタマジャクシの生存率の分布に関して特に専門的な知見を持ち合わせているわけではないので、生存率が一様になるような事前分布が欲しいです。ですので、上で定義した確率モデル logistic_reg
からランダムサンプリングした時に、生存率分布が一様になるかどうかを確かめる必要があります(これをprior predictive check/simulationと言ったりします)。
Turingではこのために事前分布からサンプリングする方法が用意されています
prior_chain = sample(model, Prior(), 100);
これで得られるチェインを使ってy
に対して予測をします。model
はすでに観測値obs_y
をセットしてしまっているので、新たにy[i]
をmissing
にしたモデルを定義して予測値を得ます。
model_missing = logistic_reg(obs_x, fill(missing, length(obs_x)))
prior_pred_sim = predict(model_missing, prior_chain);
sample
関数にmodel_missing
を通してもよさそうに思うのですが、やってみるとうまくいきませんでした。そのため、prior predictive checkのために、データの載ったmodel
を使う必要がありやや不自然に感じます。
何はともあれ得られた生存率を図にしてみます。
plot([16.5, 16.5], [0, 1], linestyle=:dot, color=:black)
plot!([32.5, 32.5], [0, 1], linestyle=:dot, color=:black)
for i in 1:size(prior_pred_sim)[1]
p = reshape(prior_pred_sim[i,:,1].value.data, 48)
scatter!(1:48, p ./ df[!, :density], markersize=2, markercolor=:gray, legend=false, markeralpha=0.3, markerstrokewidth=0)
end
current()
おおむね一様に点が打たれているので問題なさそうです。[2] 事前に調べてこうなることがわかっていたので、初めから上記の事前分布を用いていましたが、本来は自分の仮説を表現する分布を得るために試行錯誤が必要です。また、一様な分布が欲しいからといって、例えば
推論
準備ができたので、データをもとにパラメータの分布を推定します。NUTSサンプラーを利用します。
post = sample(model, NUTS(), MCMCThreads(), 3000, 3);
チェイン数を3、チェインごとに3000サンプルとる設定にしています。Julia起動時にスレッド数を指定していれば、マルチスレッドで並列に処理できます。
MCMCの収束チェック
結果を解釈する前に、得られたチェインの収束をチェックします。
summarize(post)
とすると以下のように表示されました。
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
α 0.8679 0.8433 0.0089 0.0166 3094.4956 1.0007
β[1] 1.5037 0.6648 0.0070 0.0108 3464.4218 1.0000
β[2] -1.1635 0.6655 0.0070 0.0112 3458.5998 0.9999
γ[1] 0.5181 0.6615 0.0070 0.0125 3053.7478 1.0008
γ[2] -0.1460 0.6605 0.0070 0.0126 3049.1536 1.0008
ess
が3000を超えていて、rhat
も1に近いので問題なさそうです。念の為可視化もしてみます。
using StatsPlots
plot(post)
ここでうまく収束していない場合は、MCMCの設定を修正したり、モデルのparametrizationを変えたりしてうまくいくよう試行錯誤することになります。
結果の解釈
上のグラフと合わせて、パラメータ分布のquantileも見てみます。
quantile(post)
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
α -0.7870 0.3015 0.8859 1.4292 2.5049
β[1] 0.2078 1.0510 1.5170 1.9514 2.7924
β[2] -2.4590 -1.6129 -1.1575 -0.7160 0.1364
γ[1] -0.7556 0.0686 0.5042 0.9472 1.8420
γ[2] -1.4270 -0.5943 -0.1553 0.2852 1.1805
タンクごとの環境差があり分散が大きいデータセットなので、単純なロジスティック回帰では精度はあまりよくないです。
その他
WAIC
モデルの汎化性能を比較する際に情報量基準を使うことがあると思うので、TuringでWAICを計算する方法を紹介します。WAICは
で計算できます。pointwise_loglikelihoods
で手に入るので以下のようにすればWAICが評価できます。
using StatsFuns #logsumexpを使うため
logmeanexp(x) = logsumexp(x) - log(length(x))
function waic(model, chain)
model_params = chain.name_map[:parameters]
lppd = pointwise_loglikelihoods(model, chain[model_params])
lppd = values(lppd)
pointwise_waic = -2*(logmeanexp.(lppd) - var.(lppd))
return sum(pointwise_waic)
end
先程のモデルで求めてみると
waic(model1, post1)
262.34169031358215
となりました。
階層モデル
上のモデルは環境差が考慮できていなかったので精度がよくありませんでした。そこでここでは以下のような階層モデルを使って分析してみます。
logitの
階層モデルもTuringでは普通に定義できます。
@model function multi_level(x, y)
ᾱ ~ Normal(0, 1.0)
σ ~ Exponential(1)
# use non-centered parametrization for better convergence
z ~ MvNormal(48, 1.0)
α = ᾱ .+ σ * z
β ~ MvNormal(2, 1.0)
γ ~ MvNormal(2, 1.0)
logit_p = map(xj->α[xj.i] + β[xj.predator] + γ[xj.size], x)
p = sigmoid.(logit_p)
for j in eachindex(x)
y[j] ~ Binomial(x[j].n, p[x[j].i])
end
return
end
上と同様の手順で推論をすると以下のようになりました。
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
β[1] 0.3535 1.1577 1.6025 2.0441 2.8447
β[2] -2.1994 -1.3932 -0.9617 -0.5310 0.2613
γ[1] -0.6647 0.1573 0.5846 1.0151 1.8045
γ[2] -1.1479 -0.3358 0.0895 0.5088 1.3538
σ 0.5188 0.6731 0.7630 0.8625 1.0938
ᾱ -0.7047 0.1832 0.6742 1.1667 2.0396
最後に
上で見てきたように、Turing.jlを使えば定型のデータ分析はこなせそうです。次はSoss.jlを使ってTuringとの比較をしてみたいと思います。
-
Statistical Rethinkingにならっています。 ↩︎
-
縦方向に等間隔の隙間があるのは、
が各タンクで固定だからです。 ↩︎N_j -
まあデータが多ければそこまで気にするようなことでもないのですが。 ↩︎
-
ただし
は95%区間で正の値を取るので、証拠としては弱めです。 ↩︎\beta[2]
Discussion