NumPyro:階層モデル
連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。
はじめに
今回は階層モデルを扱います。以前までと同様に本シリーズの記事は実装メインなため、理論的な詳細は省略します。最初に1つの階層のモデルを扱い、次に複数の階層のモデルを扱います。どちらもNumpyの行列計算と今までのexpand
やplate
の挙動が理解できていればイメージ掴める内容かと思います。
ライブラリのインポート
import os
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
import arviz as az
az.style.use("arviz-darkgrid")
assert numpyro.__version__.startswith("0.11.0")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
単一の階層モデル
データの読み込み
前回までと同様にアヒル本のデータを使用します。今回はchap08のchap08/input/data-salary-2.txt
を使用します。説明変数はX
のみで、個体を表すIDとしてKID
が与えられています。また、Stanは1始まりですが、NumPyroは0始まりなのでKID - 1
を行なっています。
df = pd.read_csv("../tutorial/RStanBook/chap08/input/data-salary-2.txt")
df["KID"] = df["KID"] -1
df.head()
X Y KID
0 7 457 0
1 10 482 0
2 16 518 0
3 25 535 0
4 5 427 0
ちなみに、jax.numpy
だとnumpy
と異なり、以下のようにIndexを超えてアクセスしても値が返されてしまいます。そのため、バグに気づかないこともあるので注意してください。(私はこれで数時間無駄にしました。。。)
numpy_array = np.array([0, 1, 2, 3])
jax_numpy_array = jnp.array([0, 1, 2, 3])
try:
print(numpy_array[10])
except IndexError:
print("Numpy だとIndexErrorが出る")
try:
print(jax_numpy_array[10])
except IndexError:
print("jax.numpy だとIndexErrorが出ない")
Numpy だとIndexErrorが出る
3
モデルの定義
モデルの等価な式の形として2通りあるので両方とも実装する。ここで、階層モデルを扱う際に以下のNumpyの操作を使用しています。通常のNumpyと同様の挙動ですが、一応確認しておきましょう。
# coef[KID]では以下のように動作します
kid = np.array([0, 1, 2, 3, 2, 1, 0])
coef_diff = np.array([0, 10, 20, 30])
coef_diff[kid]
array([ 0, 10, 20, 30, 20, 10, 0])
係数の式をまとめない場合
def model1(x, KID, y=None):
intercept_common = numpyro.sample("intercept_common", dist.Normal(0, 100))
coef_common = numpyro.sample("coef_common", dist.Normal(0, 100))
sigma_intercept = numpyro.sample("sigma_intercept", dist.HalfNormal(100))
sigma_coef = numpyro.sample("sigma_coef", dist.HalfNormal(100))
sigma_Y = numpyro.sample("sigma_Y", dist.HalfNormal(100))
with numpyro.plate("K", len(np.unique(KID))):
coef_diff = numpyro.sample("coef_diff", dist.Normal(0, sigma_coef))
intercept_diff = numpyro.sample("intercept_diff", dist.Normal(0, sigma_intercept))
coef = numpyro.deterministic("coef", coef_common + coef_diff)
intercept = numpyro.deterministic("intercept", intercept_common + intercept_diff)
mu = numpyro.deterministic("mu", coef[KID]*x + intercept[KID])
with numpyro.plate("N", len(x)):
numpyro.sample("obs", dist.Normal(mu, sigma_Y), obs=y)
係数の式をまとめる場合
def model2(x, KID, y=None):
intercept_common = numpyro.sample("intercept_common", dist.Normal(0, 100))
coef_common = numpyro.sample("coef_common", dist.Normal(0, 100))
sigma_intercept = numpyro.sample("sigma_intercept", dist.HalfNormal(100))
sigma_coef = numpyro.sample("sigma_coef", dist.HalfNormal(100))
sigma_Y = numpyro.sample("sigma_Y", dist.HalfNormal(100))
with numpyro.plate("K", len(np.unique(KID))):
coef = numpyro.sample("coef", dist.Normal(coef_common, sigma_coef))
intercept = numpyro.sample("intercept", dist.Normal(intercept_common, sigma_intercept))
mu = numpyro.deterministic("mu", coef[KID]*x + intercept[KID])
with numpyro.plate("N", len(x)):
numpyro.sample("obs", dist.Normal(mu, sigma_Y), obs=y)
モデルのレンダリング
それぞれのモデルを描画してみます。numpyro.plate
を使用することで階層も分かりやすく表示されますね。
係数の式をまとめない場合
係数の式をまとめる場合
MCMC
結果を確認するとちゃんと収束していることが分かります。また、実行時間は4chainsで3.25s, 3.1sとほとんど変わらないですね。アヒル本の結果を見るとStanでは1chainsごとに2s, 1sという結果のようなので、コンパイル時間や実行時間の点でも今回の例ではNumPyroの速さが際立っています。(他のサイトの結果を見るとStanとNumPyroではほとんど差がかわらないが、コンパイル時間まで加味するとNumPyroが速いという結果が多いように思えます)
係数の式をまとめない場合
%%time
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model1)
mcmc1 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc1.run(
rng_key=rng_key,
x=df["X"].values,
KID=df["KID"].values,
y=df["Y"].values,
)
mcmc1.print_summary()
CPU times: user 4.21 s, sys: 65.7 ms, total: 4.27 s
Wall time: 3.25 s
mean std median 5.0% 95.0% n_eff r_hat
coef_common 13.08 6.74 13.05 3.53 22.10 865.19 1.00
coef_diff[0] -5.17 6.75 -5.10 -14.11 4.31 875.46 1.00
coef_diff[1] 6.46 6.83 6.32 -2.87 15.98 879.47 1.00
coef_diff[2] -0.73 6.83 -0.79 -10.24 8.86 878.51 1.00
coef_diff[3] -0.37 7.23 -0.00 -11.21 8.89 998.62 1.00
intercept_common 287.39 73.96 306.88 169.71 386.68 899.60 1.00
intercept_diff[0] 93.05 77.00 74.28 -10.02 214.83 907.62 1.00
intercept_diff[1] 45.16 73.34 24.29 -58.57 161.55 926.99 1.00
intercept_diff[2] 28.84 72.88 9.52 -73.35 150.35 1035.96 1.00
intercept_diff[3] 121.24 135.47 82.92 -46.88 337.16 958.68 1.00
sigma_Y 29.18 4.06 28.82 22.80 35.44 2792.77 1.00
sigma_coef 10.32 10.06 7.36 2.26 19.39 1275.23 1.00
sigma_intercept 95.76 65.00 80.52 2.64 189.72 1008.46 1.00
Number of divergences: 5
係数の式をまとめる場合
%%time
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model1)
mcmc2 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc2.run(
rng_key=rng_key,
x=df["X"].values,
KID=df["KID"].values,
y=df["Y"].values,
)
mcmc2.print_summary()
CPU times: user 4.08 s, sys: 55.9 ms, total: 4.14 s
Wall time: 3.1 s
mean std median 5.0% 95.0% n_eff r_hat
coef_common 13.08 6.74 13.05 3.53 22.10 865.19 1.00
coef_diff[0] -5.17 6.75 -5.10 -14.11 4.31 875.46 1.00
coef_diff[1] 6.46 6.83 6.32 -2.87 15.98 879.47 1.00
coef_diff[2] -0.73 6.83 -0.79 -10.24 8.86 878.51 1.00
coef_diff[3] -0.37 7.23 -0.00 -11.21 8.89 998.62 1.00
intercept_common 287.39 73.96 306.88 169.71 386.68 899.60 1.00
intercept_diff[0] 93.05 77.00 74.28 -10.02 214.83 907.62 1.00
intercept_diff[1] 45.16 73.34 24.29 -58.57 161.55 926.99 1.00
intercept_diff[2] 28.84 72.88 9.52 -73.35 150.35 1035.96 1.00
intercept_diff[3] 121.24 135.47 82.92 -46.88 337.16 958.68 1.00
sigma_Y 29.18 4.06 28.82 22.80 35.44 2792.77 1.00
sigma_coef 10.32 10.06 7.36 2.26 19.39 1275.23 1.00
sigma_intercept 95.76 65.00 80.52 2.64 189.72 1008.46 1.00
複数の階層モデル
NumPyroのチュートリアルを一部参考にしてコードを作成しています。
データの読み込み
前回までと同様にアヒル本のデータを使用します。今回はchap08のchap08/input/data-salary-3.txt
を使用します。説明変数はX
のみで、個体を属するIDとしてKID
が与えられており、KIDが属するさらに大きな概念としてGIDが与えられています。また先ほどと同様に、Stanは1始まりですが、NumPyroは0始まりなので1引いています。
df = pd.read_csv("./RStanBook/chap08/input/data-salary-3.txt")
df["KID"] = df["KID"] -1
df["GID"] = df["GID"] -1
print(df.KID.unique())
print(df.GID.unique())
df.head()
X Y KID
0 7 457 0
1 10 482 0
2 16 518 0
3 25 535 0
4 5 427 0
ここで、次のモデル内部で使用するのでKIDからGIDへマッピングする配列を作成しておきます。
map_kid_to_gid= (
df[["KID", "GID"]]
.drop_duplicates()
.set_index("KID", verify_integrity=True)
.sort_index()["GID"]
.values
)
print(map_kid_to_gid)
array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2])
モデルの定義
先ほどのモデル2の分布の平均
def model3(x, GID, KID, map_kid_to_gid, y=None):
a0 = numpyro.sample("a0", dist.Normal(0, 100))
b0 = numpyro.sample("b0", dist.Normal(0, 100))
s_ag = numpyro.sample("s_ag", dist.HalfNormal(100))
s_bg = numpyro.sample("s_bg", dist.HalfNormal(100))
s_a = numpyro.sample("s_a", dist.HalfNormal(100))
s_b = numpyro.sample("s_b", dist.HalfNormal(100))
s_Y = numpyro.sample("s_Y", dist.HalfNormal(100))
with numpyro.plate("G", len(np.unique(GID))):
a1 = numpyro.sample("a1", dist.Normal(a0, s_ag))
b1 = numpyro.sample("b1", dist.Normal(b0, s_bg))
with numpyro.plate("K", len(np.unique(KID))):
a = numpyro.sample("a", dist.Normal(a1[map_kid_to_gid], s_a))
b = numpyro.sample("b", dist.Normal(b1[map_kid_to_gid], s_b))
mu = numpyro.deterministic("mu", a[KID]*x + b[KID])
with numpyro.plate("N", len(x)):
numpyro.sample("obs", dist.Normal(mu, s_Y), obs=y)
ここで、少し慣れた人は2つ目のnumpyro.plate
はいらないのではないかと思うかと思います。なぜならa1[map_kid_to_gid]
が長さlen(np.unique(KID))
の配列であり、plate
がなくてもlen(np.unique(KID))
の数だけ分布が生成されるためです。これは正しい認識でして、plate
をコメントアウトした以下の実装でも結果は同じです。しかし、plate
があることで次のモデルのレンダリング時にプレートが複数表示されたり、コード的にも階層であることが分かりやすくなったりするので(恐らく)わざわざ書くのが推奨の書き方かと思います。(公式のチュートリアルの書き方も同じです)
def model3_non(x, GID, KID, map_kid_to_gid, y=None):
a0 = numpyro.sample("a0", dist.Normal(0, 100))
b0 = numpyro.sample("b0", dist.Normal(0, 100))
s_ag = numpyro.sample("s_ag", dist.HalfNormal(100))
s_bg = numpyro.sample("s_bg", dist.HalfNormal(100))
s_a = numpyro.sample("s_a", dist.HalfNormal(100))
s_b = numpyro.sample("s_b", dist.HalfNormal(100))
s_Y = numpyro.sample("s_Y", dist.HalfNormal(100))
with numpyro.plate("G", len(np.unique(GID))):
a1 = numpyro.sample("a1", dist.Normal(a0, s_ag))
b1 = numpyro.sample("b1", dist.Normal(b0, s_bg))
# a1[map_kid_to_gid]の時点でKIDの数だけ生成されるのでplateなしでも同じ
#with numpyro.plate("K", len(np.unique(KID))):
a = numpyro.sample("a", dist.Normal(a1[map_kid_to_gid], s_a))
b = numpyro.sample("b", dist.Normal(b1[map_kid_to_gid], s_b))
mu = numpyro.deterministic("mu", a[KID]*x + b[KID])
with numpyro.plate("N", len(x)):
numpyro.sample("obs", dist.Normal(mu, s_Y), obs=y)
shapeの確認
少し複雑なので、各サンプルサイトのShapeを確認してみましょう。
with numpyro.handlers.seed(rng_seed=0):
trace = numpyro.handlers.trace(model3).get_trace(
x=df["X"].values,
GID=df["GID"].values,
KID=df["KID"].values,
map_kid_to_gid=map_kid_to_gid,
y=df["Y"].values
)
print(numpyro.util.format_shapes(trace))
Trace Shapes:
Param Sites:
Sample Sites:
a0 dist |
value |
b0 dist |
value |
s_ag dist |
value |
s_bg dist |
value |
s_a dist |
value |
s_b dist |
value |
s_Y dist |
value |
G plate 3 |
a1 dist 3 |
value 3 |
b1 dist 3 |
value 3 |
K plate 30 |
a dist 30 |
value 30 |
b dist 30 |
value 30 |
N plate 300 |
obs dist 300 |
value 300 |
モデルのレンダリング
MCMC
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model3)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(
rng_key=rng_key,
x=df["X"].values,
GID=df["GID"].values,
KID=df["KID"].values,
map_kid_to_gid=map_kid_to_gid,
y=df["Y"].values
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
a[0] 9.08 1.78 9.07 6.01 11.83 2683.95 1.00
a[1] 17.81 2.27 17.79 14.07 21.50 2732.26 1.00
a[2] 11.00 1.97 10.97 7.80 14.29 2673.58 1.00
a[3] 14.60 2.13 14.62 11.25 18.16 3840.61 1.00
a[4] 33.08 2.12 33.03 29.72 36.66 3978.88 1.00
a[5] 36.14 3.00 36.40 31.05 40.76 1425.89 1.00
a[6] 30.88 1.25 30.89 28.84 32.96 4008.69 1.00
a[7] 25.19 1.73 25.17 22.38 28.03 4909.21 1.00
a[8] 22.45 1.50 22.38 20.07 25.00 3242.20 1.00
a[9] 29.06 1.85 29.03 25.97 32.03 1443.74 1.00
a[10] 32.46 1.81 32.51 29.44 35.35 5953.17 1.00
a[11] 34.20 1.36 34.18 32.01 36.42 5050.82 1.00
a[12] 28.77 2.28 28.86 24.98 32.34 2716.09 1.00
a[13] 23.99 2.02 23.96 20.81 27.41 3379.93 1.00
a[14] 33.75 2.77 33.66 29.26 38.24 4712.94 1.00
a[15] 19.46 1.34 19.44 17.29 21.67 4019.57 1.00
a[16] 28.23 1.41 28.21 26.09 30.65 4729.48 1.00
a[17] 28.87 2.32 28.91 24.98 32.65 4110.67 1.00
a[18] 20.84 2.35 20.72 17.09 24.77 2053.59 1.00
a[19] 34.04 1.52 34.06 31.56 36.54 3397.22 1.00
a[20] 26.52 2.49 26.51 22.48 30.62 3974.52 1.00
a[21] 28.18 1.36 28.24 25.82 30.25 1607.29 1.00
a[22] 13.35 1.95 13.33 10.23 16.66 2817.55 1.00
...
s_b 25.68 12.67 24.76 4.38 42.65 430.02 1.00
s_bg 161.02 58.09 155.91 67.22 250.22 4277.68 1.00
Number of divergences: 93
最後に
以上で「階層モデル」は終わりです。基本的にNumpyの行列計算と今までのexpand
やplate
の挙動が理解できていればイメージ掴める内容だったかと思います。次回は離散潜在変数を扱います。
Discussion