📌

PyStan の実行例 - Eight Schools

5 min read

PyStan の実行例 - Eight Schools

有名な例題である Eight Schools を Google Colab 上で実行してみました。コードはこちらのサイトを参考に書いています。

https://github.com/stan-dev/rstan/wiki/RStan-Getting-Started

Install Package

まずは、PyStan をインストールします。

!pip install pystan==2.19.1.1

Import Packages

次に、必要なパッケージをインポートします。

import pystan
import arviz as az

Define Data

データは、下のデータを利用しました。

data = {'J': 8,
        'y': [28,  8, -3,  7, -1,  1, 18, 12],
        'sigma': [15, 10, 16, 11,  9, 11, 10, 18]}

Define Model & Inference

モデルを定義して、推論を行います。

model_code = """
data {
  int<lower=0> J;         // number of schools
  real y[J];              // estimated treatment effects
  real<lower=0> sigma[J]; // standard error of effect estimates
}
parameters {
  real mu;                // population treatment effect
  real<lower=0> tau;      // standard deviation in treatment effects
  vector[J] eta;          // unscaled deviation from mu by school
}
transformed parameters {
  vector[J] theta = mu + tau * eta;        // school treatment effects
}
model {
  target += normal_lpdf(eta | 0, 1);       // prior log-density
  target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""
%%time

sm = pystan.StanModel(model_code=model_code)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_67cb7d0f2cb7720776cbeb52007d2dbb NOW.


CPU times: user 1.86 s, sys: 132 ms, total: 1.99 s
Wall time: 1min 10s

Google Colab だと、コードのコンパイルに1分くらいかかりました。手元のマシンで計算すると、もっと速くなる可能性はあるかもしれません。次に、Chain を 1本だけ計算させた場合と Chain を 4本計算させた場合の両方で時間を測ってみます(Chain を並列に計算させる方法がわからなかったため ^^;)

%%time

mcmc = sm.sampling(data=data, iter=100000, warmup=500, chains=1, control=dict(adapt_delta=0.99, max_treedepth=10))
WARNING:pystan:2 of 99500 iterations ended with a divergence (0.00201 %).
WARNING:pystan:Try running with adapt_delta larger than 0.99 to remove the divergences.


CPU times: user 7.44 s, sys: 130 ms, total: 7.57 s
Wall time: 7.54 s
%%time

mcmc = sm.sampling(data=data, iter=100000, warmup=500, chains=4, control=dict(adapt_delta=0.99, max_treedepth=10))
WARNING:pystan:10 of 398000 iterations ended with a divergence (0.00251 %).
WARNING:pystan:Try running with adapt_delta larger than 0.99 to remove the divergences.


CPU times: user 7.45 s, sys: 434 ms, total: 7.88 s
Wall time: 38.3 s

Chain 1本あたりで 8~10秒くらいの時間がかかっているようです。それにしても、Stan は WARNING が細かく、親切設計ですね ^^

idata = az.from_pystan(mcmc)
az.plot_trace(idata);

png

az.summary(idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
0 mu 7.93 5.217 -1.818 17.534 0.015 0.022 212292 176914 1
1 tau 6.581 5.728 0 16.263 0.019 0.026 144668 174137 1
2 eta[0] 0.39 0.938 -1.369 2.172 0.002 0.001 361783 281559 1
3 eta[1] 0.001 0.874 -1.674 1.643 0.001 0.001 389953 288434 1
4 eta[2] -0.194 0.927 -1.922 1.573 0.001 0.002 396820 288937 1
5 eta[3] -0.031 0.883 -1.676 1.672 0.001 0.001 357902 281880 1
6 eta[4] -0.351 0.878 -1.995 1.342 0.002 0.001 332011 273770 1
7 eta[5] -0.21 0.893 -1.91 1.474 0.001 0.001 415465 286695 1
8 eta[6] 0.342 0.886 -1.358 2.009 0.001 0.001 354287 272854 1
9 eta[7] 0.056 0.933 -1.711 1.81 0.001 0.002 410448 290257 1
10 theta[0] 11.384 8.345 -2.935 28.463 0.016 0.011 305154 302281 1
11 theta[1] 7.894 6.282 -4.21 19.89 0.009 0.007 509134 350682 1
12 theta[2] 6.134 7.729 -9.339 20.493 0.013 0.01 389201 319471 1
13 theta[3] 7.645 6.536 -4.839 20.317 0.009 0.007 486618 348020 1
14 theta[4] 5.12 6.355 -7.239 16.81 0.01 0.008 409586 336711 1
15 theta[5] 6.142 6.727 -7.249 18.535 0.01 0.008 466074 348823 1
16 theta[6] 10.652 6.785 -1.628 24.027 0.011 0.008 368395 334347 1
17 theta[7] 8.443 7.879 -6.834 23.83 0.013 0.011 387466 300924 1

Summary

Google Colab 上でやっているためかもしれませんが、Stan のコードをコンパイルするのに1分くらいの時間がかかりました。サンプルはかなり多く発生させましたが、Chain 1本あたり 8~10秒くらいで走りました。

関連情報

https://note.com/ds_kotaro/n/n9e4072503f51

Discussion

ログインするとコメントできます