📌

NumPyro の実行例 - Eight Schools

8 min read

NumPyro の実行例 - Eight Schools

有名な例題である Eight Schools を Google Colab 上で実行してみました。コードは、こちらのサイトを参考に書いています。

http://num.pyro.ai/en/latest/getting_started.html#a-simple-example-8-schools

Install Packages

まずは、NumPyro をインストールします。インストール完了後に、ランタイムは再起動しておきます。

!pip install --upgrade jax==0.2.17 jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install numpyro==0.7.2
!pip install arviz==0.11.2

Import Packages

次に、必要なパッケージをインポートします。

import numpyro
import numpyro.distributions as dist

import jax
import jax.numpy as jnp
import arviz as az
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

Define Data

データは、下のデータを利用しました。

J = 8
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

Define Model & Inference

まずは、NumPyro の Getting Started にある通りのモデルでやってみます。

def model(J, sigma, y=None):

    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    
    eta = numpyro.sample('eta', dist.Normal(0, 1), sample_shape=(J, ))
    theta = numpyro.deterministic('theta', mu + tau * eta)
    
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
%%time

nuts = numpyro.infer.NUTS(model, target_accept_prob=0.99, max_tree_depth=10)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=100000, num_chains=4, progress_bar=False)

mcmc.run(jax.random.PRNGKey(0), J, sigma, y)

idata = az.from_numpyro(mcmc)
CPU times: user 20.3 s, sys: 188 ms, total: 20.5 s
Wall time: 14.6 s

Chain は 4本計算しましたが、概ね 15秒くらいで計算できました。

az.plot_trace(idata);

png

az.summary(idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
0 eta[0] 0.315 0.989 -1.544 2.169 0.002 0.002 348062 289037 1
1 eta[1] 0.098 0.94 -1.691 1.864 0.002 0.001 387829 294075 1
2 eta[2] -0.085 0.97 -1.908 1.748 0.001 0.002 419545 305056 1
3 eta[3] 0.063 0.942 -1.726 1.83 0.001 0.001 419036 302751 1
4 eta[4] -0.161 0.931 -1.911 1.604 0.002 0.001 361756 287679 1
5 eta[5] -0.07 0.942 -1.866 1.689 0.001 0.001 400910 299947 1
6 eta[6] 0.356 0.96 -1.474 2.145 0.002 0.001 346002 286779 1
7 eta[7] 0.075 0.975 -1.755 1.918 0.002 0.002 379096 292705 1
8 mu 4.387 3.321 -1.967 10.544 0.006 0.004 352183 278018 1
9 tau 3.592 3.227 0 9.294 0.006 0.004 220316 162144 1
10 theta[0] 6.188 5.598 -3.768 17.023 0.009 0.007 360746 322368 1
11 theta[1] 4.942 4.672 -3.781 13.981 0.007 0.005 438859 337950 1
12 theta[2] 3.92 5.269 -6.207 13.716 0.009 0.006 392005 334055 1
13 theta[3] 4.749 4.764 -4.203 13.968 0.007 0.006 438214 339686 1
14 theta[4] 3.605 4.663 -5.417 12.27 0.007 0.006 414163 332309 1
15 theta[5] 4.039 4.828 -5.303 13.073 0.007 0.006 422227 343072 1
16 theta[6] 6.294 5.083 -2.851 16.278 0.008 0.006 396408 327497 1
17 theta[7] 4.838 5.296 -5.165 14.943 0.009 0.006 384954 327053 1

ただ、PyStan で計算させたときと、計算結果が違うようなので、事前分布をもう少しフラットなものに置き換えます。PyStan での実行例は下をご覧頂けたらと思います。

https://zenn.dev/eota/articles/pystan_eight_schools
def model_modified(J, sigma, y=None):

    mu = numpyro.sample('mu', dist.Normal(0, 100))
    tau = numpyro.sample('tau', dist.HalfNormal(100))
    
    eta = numpyro.sample('eta', dist.Normal(0, 1), sample_shape=(J, ))
    theta = numpyro.deterministic('theta', mu + tau * eta)
    
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
%%time

nuts = numpyro.infer.NUTS(model_modified, target_accept_prob=0.99, max_tree_depth=10)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=100000, num_chains=4, progress_bar=False)

mcmc.run(jax.random.PRNGKey(0), J, sigma, y)

idata = az.from_numpyro(mcmc)
CPU times: user 21.3 s, sys: 60.2 ms, total: 21.3 s
Wall time: 14.4 s

計算時間はあまり変わりません。やはり 15秒くらいでした。

az.plot_trace(idata);

png

az.summary(idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
0 eta[0] 0.389 0.94 -1.409 2.139 0.002 0.001 357491 278624 1
1 eta[1] -0.001 0.874 -1.675 1.646 0.001 0.001 440918 288705 1
2 eta[2] -0.193 0.933 -1.944 1.578 0.001 0.002 460476 287100 1
3 eta[3] -0.029 0.884 -1.696 1.654 0.001 0.001 456832 290115 1
4 eta[4] -0.353 0.875 -2.008 1.317 0.001 0.001 366124 285362 1
5 eta[5] -0.211 0.893 -1.896 1.488 0.001 0.001 428520 280111 1
6 eta[6] 0.346 0.888 -1.345 2.019 0.001 0.001 388660 280306 1
7 eta[7] 0.059 0.936 -1.689 1.84 0.001 0.002 459742 286121 1
8 mu 7.91 5.16 -1.763 17.535 0.011 0.01 236141 191930 1
9 tau 6.518 5.557 0 16.117 0.015 0.012 145392 163708 1
10 theta[0] 11.342 8.275 -3.133 27.98 0.015 0.011 305037 302050 1
11 theta[1] 7.88 6.248 -3.923 20.054 0.009 0.007 525761 348429 1
12 theta[2] 6.13 7.769 -9.565 20.486 0.012 0.01 414265 314637 1
13 theta[3] 7.644 6.515 -4.988 20.074 0.009 0.007 513369 341057 1
14 theta[4] 5.137 6.337 -7.308 16.689 0.01 0.008 417383 337156 1
15 theta[5] 6.14 6.71 -7.051 18.672 0.01 0.008 463469 342193 1
16 theta[6] 10.645 6.781 -1.396 24.202 0.011 0.008 389765 340423 1
17 theta[7] 8.443 7.838 -6.644 23.83 0.012 0.01 423002 307346 1

今度は PyStan で計算したときの計算結果に近いものが出てきました。

Summary

NumPyro の方では、Chain を 4本にしたときしか計算していませんが、いずれのケースでも 15秒程度で計算が完了しました。

関連情報

https://note.com/ds_kotaro/n/n9e4072503f51

Discussion

ログインするとコメントできます