📖

Turing.jlで統計モデリング

10 min read

はじめに

Juliaでベイジアンモデリングをするためのツールを調べています。JuliaではTuring.jlというパッケージが有名です。

https://github.com/TuringLang/Turing.jl

こちらの記事などでも詳しく紹介されています。

https://zenn.dev/takilog/articles/c5d29ecbce7565

ベイジアンなデータ分析というと概ね以下のような手順になります。なおMCMCで推論することを想定しています。

  1. データの準備
  2. 確率モデルの定義
  3. 事前分布の決定
  4. 推論
  5. MCMCの収束チェック
  6. 結果の解釈

この一連の手順でTuringをどう使うかをまとめてみたいと思います。

データの準備

以下のデータを使います。様々な条件(サイズや密度・捕食者の有無など)におけるオタマジャクシの生存数を調べた実験データです。

https://rdrr.io/github/rmcelreath/rethinking/man/reedfrogs.html

まずはデータをロードします。

using HTTP
using CSV
using DataFrames

resp = HTTP.get("https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/reedfrogs.csv").body;
df = CSV.read(resp, DataFrame, header=1, delim=';')

first(df, 10)

全部で48行のテーブルデータで、各行には初めのオタマジャクシ数 (density) と一定時間後の生存数 (surv) が記録されています。データの各行を「タンク」と呼ぶことにします。[1]
各タンクでの生存率をグラフにすると以下のようになります。

タンクのdensityで三種類に分類しています(縦の点線)。左のブロックが一番小さく、右のブロックが一番大きいです。平均生存率は75%程度ですが、図を見ればわかる通り、タンクごとの差の大きいデータセットになっています。これは実験の介入の有無だけでなく、観測しきれていないタンク固有の環境差からも来ています。今回はこのデータを使って介入 (pred, size) の効果を分析してみます。

確率モデルの定義

以下のようなロジスティック回帰モデルを定義します。後で見るように、上記データはタンクごとの環境差が大きいのでこのモデルでは予測精度は良くありません。

\begin{aligned} \alpha &\sim \mathrm{Normal}(0, 1.5) \\ \beta_k & \sim \mathrm{Normal}(0, 1) \quad(k=1,2)\\ \gamma_\ell & \sim \mathrm{Normal}(0, 1) \quad(\ell=1,2) \\ p_j &= \mathrm{sigmoid}(\alpha + \beta_{\mathrm{predator}[j]} + \gamma_{\mathrm{size}[j]})\\ y_j &\sim \mathrm{Binomial}(N_j, p_j) \end{aligned}

ここで

  • \mathrm{predator}[i]はi番目のタンクの捕食者の有無を表し、pred="pred"なら2、そうでなければ1をとるとします
  • \mathrm{size}[i]はi番目のタンクのオタマジャクシの大きさを表し、size="big"なら2、そうでなければ1をとるとします
  • j=1,\dots,48はタンクを表します。y_j はタンク j の生存数です。

Turingではこのモデルは以下のような直感的な記法で定義することができます。

using Turing
using Distributions

sigmoid(x) = 1 / (1 + exp(-x))

@model function logistic_reg(x, y)
    α ~ Normal(0, 1.5)
    β ~ MvNormal(2, 1.0)
    γ ~ MvNormal(2, 1.0)
    
    logit_p = map(xj->α + β[xj.predator] + γ[xj.size], x)
    p = sigmoid.(logit_p)
    
    for j in eachindex(x)
        y[j] ~ Binomial(x[j].n, p[x[j].i])
    end
    return
end

確率変数は ~ で定義します。それ以外は普通のJuliaの文法になっているのでわかりやすいです。今回はデータ y は観測値として与えるため、モデルの引数に入れています。

このモデルに以下のようにしてデータをセットしました。

obs_x = [(i=i, n=df[i, :density], predator=ifelse(df[i, :pred]=="pred", 2, 1), size=ifelse(df[i, :size]=="big", 2, 1)) for i in 1:nrow(df)]
obs_y = df[!, :surv];

model = logistic_reg(obs_x, obs_y)

事前分布の決定

上では天下りに \alpha, \beta, \gamma の事前分布を与えました。ここについてもう少し詳しくみます。

ベイジアンなデータ分析では、できるだけデータを見る前に事前分布を設計することが望ましいです。ですので初めの方に書いた図は一旦忘れて、事前分布として何が相応しいか考えてみます。私はオタマジャクシの生存率の分布に関して特に専門的な知見を持ち合わせているわけではないので、生存率が一様になるような事前分布が欲しいです。ですので、上で定義した確率モデル logistic_reg からランダムサンプリングした時に、生存率分布が一様になるかどうかを確かめる必要があります(これをprior predictive check/simulationと言ったりします)。

Turingではこのために事前分布からサンプリングする方法が用意されています

prior_chain = sample(model, Prior(), 100);

これで得られるチェインを使ってyに対して予測をします。modelはすでに観測値obs_yをセットしてしまっているので、新たにy[i]missingにしたモデルを定義して予測値を得ます。

model_missing = logistic_reg(obs_x, fill(missing, length(obs_x)))
prior_pred_sim = predict(model_missing, prior_chain);

sample関数にmodel_missingを通してもよさそうに思うのですが、やってみるとうまくいきませんでした。そのため、prior predictive checkのために、データの載ったmodelを使う必要がありやや不自然に感じます。
何はともあれ得られた生存率を図にしてみます。

plot([16.5, 16.5], [0, 1], linestyle=:dot, color=:black)
plot!([32.5, 32.5], [0, 1], linestyle=:dot, color=:black)
for i in 1:size(prior_pred_sim)[1]
    p = reshape(prior_pred_sim[i,:,1].value.data, 48)
    scatter!(1:48, p ./ df[!, :density], markersize=2, markercolor=:gray, legend=false, markeralpha=0.3, markerstrokewidth=0)
end
current()


おおむね一様に点が打たれているので問題なさそうです。[2] 事前に調べてこうなることがわかっていたので、初めから上記の事前分布を用いていましたが、本来は自分の仮説を表現する分布を得るために試行錯誤が必要です。また、一様な分布が欲しいからといって、例えば \alpha \sim \mathrm{Normal}(0, 100) のようにしてしまうと、sigmoid が入るせいで生存率分布が0か1に集中してしまうので注意が必要です。[3]

推論

準備ができたので、データをもとにパラメータの分布を推定します。NUTSサンプラーを利用します。

post = sample(model, NUTS(), MCMCThreads(), 3000, 3);

チェイン数を3、チェインごとに3000サンプルとる設定にしています。Julia起動時にスレッド数を指定していれば、マルチスレッドで並列に処理できます。

MCMCの収束チェック

結果を解釈する前に、得られたチェインの収束をチェックします。

summarize(post)

とすると以下のように表示されました。

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64 

           α    0.8679    0.8433     0.0089    0.0166   3094.4956    1.0007
        β[1]    1.5037    0.6648     0.0070    0.0108   3464.4218    1.0000
        β[2]   -1.1635    0.6655     0.0070    0.0112   3458.5998    0.9999
        γ[1]    0.5181    0.6615     0.0070    0.0125   3053.7478    1.0008
        γ[2]   -0.1460    0.6605     0.0070    0.0126   3049.1536    1.0008

essが3000を超えていて、rhatも1に近いので問題なさそうです。念の為可視化もしてみます。

using StatsPlots

plot(post)

ここでうまく収束していない場合は、MCMCの設定を修正したり、モデルのparametrizationを変えたりしてうまくいくよう試行錯誤することになります。

結果の解釈

上のグラフと合わせて、パラメータ分布のquantileも見てみます。

quantile(post)
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           α   -0.7870    0.3015    0.8859    1.4292    2.5049
        β[1]    0.2078    1.0510    1.5170    1.9514    2.7924
        β[2]   -2.4590   -1.6129   -1.1575   -0.7160    0.1364
        γ[1]   -0.7556    0.0686    0.5042    0.9472    1.8420
        γ[2]   -1.4270   -0.5943   -0.1553    0.2852    1.1805

\beta[1], \beta[2]はそれぞれ正・負の側に推定されています。\beta[1]が捕食者なし・\beta[2]が捕食者ありの場合のlogitへの寄与なので、捕食者がいると生存数が下がり、いないと上がることが確かめられました。[4]一方で、サイズの効果 \gamma は大きくないようです。

p_jをプロットしてみます。

タンクごとの環境差があり分散が大きいデータセットなので、単純なロジスティック回帰では精度はあまりよくないです。

その他

WAIC

モデルの汎化性能を比較する際に情報量基準を使うことがあると思うので、TuringでWAICを計算する方法を紹介します。WAICは

\mathrm{WAIC} = -2\left(\sum_i\log\left(\frac{1}{S}\sum_{s=1}^Sp(y_i|\theta_s)\right) - \sum_i\mathrm{var}_\theta \log p(y_i | \theta)\right)

で計算できます。y_iが観測データで、\theta_sはパラメータのs個目のサンプルです。この式を見ると、各データ点でのlog-likelihood \log p(y_i|\theta_s) があれば計算できることがわかります。これはTuringでは pointwise_loglikelihoods で手に入るので以下のようにすればWAICが評価できます。

using StatsFuns #logsumexpを使うため

logmeanexp(x) = logsumexp(x) - log(length(x))

function waic(model, chain)
    model_params = chain.name_map[:parameters]
    lppd = pointwise_loglikelihoods(model, chain[model_params])
    lppd = values(lppd)
    pointwise_waic = -2*(logmeanexp.(lppd) - var.(lppd))
    return sum(pointwise_waic)
end

先程のモデルで求めてみると

waic(model1, post1)
262.34169031358215

となりました。

階層モデル

上のモデルは環境差が考慮できていなかったので精度がよくありませんでした。そこでここでは以下のような階層モデルを使って分析してみます。

\begin{aligned} \bar\alpha &\sim \mathrm{Normal}(0, 1.5) \\ \sigma &\sim \mathrm{Exponential}(1) \\ \alpha_j &\sim \mathrm{Normal}(\bar{\alpha}, \sigma) \\ \beta_k & \sim \mathrm{Normal}(0, 1) \quad(k=1,2)\\ \gamma_\ell & \sim \mathrm{Normal}(0, 1) \quad(\ell=1,2) \\ p_j &= \mathrm{sigmoid}(\alpha_j + \beta_{\mathrm{predator}[j]} + \gamma_{\mathrm{size}[j]})\\ y_j &\sim \mathrm{Binomial}(N_j, p_j) \end{aligned}

logitの \alpha をタンクごとに \alpha_j とし、さらに \alpha_j の従う分布のパラメータ \bar\alpha, \sigma にも事前分布をおいてデータから推定できるようにしています。これによって、タンクごとの環境差を考慮することができます。単に \alpha_j に環境ごとに独立な事前分布を設定するのではなく、\bar\alpha, \sigma を通すことで、タンク間で知見が共有されるようになり、新しいタンクについても予測をすることができるようになります。

階層モデルもTuringでは普通に定義できます。

@model function multi_level(x, y)
    ᾱ ~ Normal(0, 1.0)
    σ ~ Exponential(1)

    # use non-centered parametrization for better convergence
    z ~ MvNormal(48, 1.0)
    
    α = ᾱ .+ σ * z    
    
    β ~ MvNormal(2, 1.0)
    γ ~ MvNormal(2, 1.0)
    
    logit_p = map(xj->α[xj.i] + β[xj.predator] + γ[xj.size], x)
    p = sigmoid.(logit_p)
    
    for j in eachindex(x)
        y[j] ~ Binomial(x[j].n, p[x[j].i])
    end
    return
end

上と同様の手順で推論をすると以下のようになりました。

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        β[1]    0.3535    1.1577    1.6025    2.0441    2.8447
        β[2]   -2.1994   -1.3932   -0.9617   -0.5310    0.2613
        γ[1]   -0.6647    0.1573    0.5846    1.0151    1.8045
        γ[2]   -1.1479   -0.3358    0.0895    0.5088    1.3538
           σ    0.5188    0.6731    0.7630    0.8625    1.0938-0.7047    0.1832    0.6742    1.1667    2.0396

\beta, \gamma の解釈に大きな変更はありません。しかしp_jをプロットしてみると以下のようになり、精度が良くなっていることがわかります。

最後に

上で見てきたように、Turing.jlを使えば定型のデータ分析はこなせそうです。次はSoss.jlを使ってTuringとの比較をしてみたいと思います。

脚注
  1. Statistical Rethinkingにならっています。 ↩︎

  2. 縦方向に等間隔の隙間があるのは、N_jが各タンクで固定だからです。 ↩︎

  3. まあデータが多ければそこまで気にするようなことでもないのですが。 ↩︎

  4. ただし\beta[2]は95%区間で正の値を取るので、証拠としては弱めです。 ↩︎

Discussion

ログインするとコメントできます