NumPyro:基本のモデル
連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。
はじめに
前回までで単回帰やNumPyro特有の内容を見てきたので、それらを少し応用して基本的なモデルを実装していきます。理論的な内容は世の中にある優れた本にお任せし、実装中心の内容です。
また、前回までで出てきた内容は極力説明を省いて進行します。
ライブラリのインポート
import os
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
import arviz as az
az.style.use("arviz-darkgrid")
assert numpyro.__version__.startswith("0.11.0")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
データの準備
書籍「StanとRでベイズ統計モデリング」のデータを使用します。以下のリンクからcloneしてください。
重回帰
データの準備
chapter05のデータ(chap05/input/data-attendance-1.txt)を使用します。ScoreとYには相関がありそうです。またScoreは0~1付近になるように前処理されています。
df = pd.read_csv("./RStanBook/chap05/input/data-attendance-1.txt")
sns.pairplot(df, hue="A")
df["Score"] = df.Score.values/200.
モデルの定義
単回帰の時とほとんど同じです。複数の説明変数があるので、係数coef
の分布をexpand()
で説明変数の数だけ定義しています。また、coef
が2次元の配列になることに伴いmu
の計算でもjnp.dot(X, coef)
と行列計算を行なっています。
def model(X, y=None):
# 切片
intercept = numpyro.sample("intercept", dist.Normal(0., 100.))
# 重み
coef = numpyro.sample("coef", dist.Normal(0.0, 100).expand([X.shape[1]]))
# muを計算
mu = numpyro.deterministic("mu", jnp.dot(X, coef) + intercept)
# ノイズ
sigma = numpyro.sample("sigma", dist.Exponential(1.0))
# 正規分布からのサンプリング.yは観測値なので、obs=yを追加
with numpyro.plate("N", len(X)):
numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
sample sitesのshapeの確認
重回帰では簡単なモデリングなのでほとんど必要ないですが、自分で実装した際に各サンプルサイトのShapeが知りたいことがよくあります。NumPyroではnumpyro.util.format_shapes()
という便利関数が用意されており、以下のようにtrace
を渡すことで一括でShapeを表示できます。ここで、intercept
とsigma
の形状は()
なため空欄になっていますが、coef
は2
、N plate
の中にあるobs
は50
になっていることが分かります。
with numpyro.handlers.seed(rng_seed=0):
trace = numpyro.handlers.trace(model).get_trace(X=df[["A", "Score"]].values, y=df["Y"].values)
print(numpyro.util.format_shapes(trace))
Trace Shapes:
Param Sites:
Sample Sites:
intercept dist |
value |
coef dist 2 |
value 2 |
sigma dist |
value |
N plate 50 |
obs dist 50 |
value 50 |
グラフィカルモデル
numpyro.render_model(
model=model,
model_kwargs={"X": df[["A", "Score"]].values, "y": df["Y"].values},
render_params=True,
render_distributions=True
)
MCMC
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df[["A", "Score"]].values,
y=df["Y"].values,
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
coef[0] -0.14 0.02 -0.14 -0.17 -0.12 4948.18 1.00
coef[1] 0.32 0.05 0.32 0.24 0.41 3467.48 1.00
intercept 0.12 0.03 0.12 0.07 0.18 3376.21 1.00
sigma 0.05 0.01 0.05 0.04 0.06 4800.30 1.00
Number of divergences: 0
Arvizによる可視化
結果を見るとうまく収束していることが確認できます。
numpyro2az = az.from_numpyro(mcmc)
az.plot_trace(numpyro2az, figsize=(8,4));
ポアソン回帰
上記の重回帰の例と一部だけ異なるだけなので、モデルの定義だけ例を載せておきます。
モデルの定義
想定しているデータはchapter05のデータ(chap05/input/data-attendance-2.txt)になります。重回帰の時と同じ素性のデータでYがカウントデータになったものです。
def model(X, y=None):
# 切片
intercept = numpyro.sample("intercept", dist.Normal(0, 100))
# 重み
coef = numpyro.sample("coef", dist.Normal(0.0, 100).expand([X.shape[1]]))
lambda_ = numpyro.deterministic("lambda", jnp.exp(jnp.dot(X, coef) + intercept))
with numpyro.plate("N", len(X)):
numpyro.sample("obs", dist.Poisson(lambda_), obs=y)
2項ロジスティック回帰
上記の重回帰の例と一部だけ異なるだけなので、モデルの定義だけ例を載せておきます。
モデルの定義
想定しているデータはchapter05のデータ(chap05/input/data-attendance-2.txt)になります。NumPyroでは自動でlogitsを計算してくれる分布dist.BinomialLogits
と手動で指定できるdist.BinomialProbs
が用意されているので今回は前者を使用しています。
def model(X, total_count, y=None):
# 切片
intercept = numpyro.sample("intercept", dist.Normal(0, 100))
# 重み
coef = numpyro.sample("coef", dist.Normal(0.0, 100).expand([X.shape[1]]))
q = numpyro.deterministic("q", jnp.dot(X, coef) + intercept)
with numpyro.plate("N", len(X)):
# logitsに指定することで、自動的にlogitsを計算してくれる
numpyro.sample("obs", dist.BinomialLogits(logits=q, total_count=total_count), obs=y)
ロジスティック回帰
上記の重回帰の例と一部だけ異なるだけなので、モデルの定義だけ例を載せておきます。
モデルの定義
想定しているデータはchapter05のデータ(chap05/input/data-attendance-3.txt)になります。重回帰の時と同じ素性のデータでYが(0, 1)のバイナリーデータになったものです。NumPyroでは自動でlogitsを計算してくれる分布dist.BernoulliLogits
と手動で指定できるdist.BernoulliProbs
が用意されているので今回は前者を使用しています。
def model(X, y=None):
# 切片
intercept = numpyro.sample("intercept", dist.Normal(0, 100))
# 重み
coef = numpyro.sample("coef", dist.Normal(0.0, 100).expand([X.shape[1]]))
q = numpyro.deterministic("q", jnp.dot(X, coef) + intercept)
with numpyro.plate("N", len(X)):
# logitsに指定することで、自動的にlogitsを計算してくれる
numpyro.sample("obs", dist.BernoulliLogits(logits=q), obs=y)
外れ値
以前の単回帰の例と一部だけ異なるだけなので、モデルの定義だけ例を載せておきます。
モデルの定義
想定しているデータはchapter07のデータ(chap07/input/outlier.txt)になります。Yに外れ値が含まれるデータになります。こちらも最後の分布にdist.Cauchy
を使用しているだけです。
def model(x, y=None):
intercept = numpyro.sample("intercept", dist.Normal(0, 100))
coef = numpyro.sample("coef", dist.Normal(0, 100))
sigma = numpyro.sample("sigma", dist.Exponential(1.0))
mu = numpyro.deterministic("mu", coef*x + intercept)
with numpyro.plate("N", len(x)):
numpyro.sample("obs", dist.Cauchy(mu, sigma), obs=y)
多項ロジスティック回帰
データの準備
chapter10のデータ(chap10/input/data-category.txt)を使用します。XはAge,Sex,Incomeの3つで、Yはカテゴリを表す列になります。
前処理としてAgeとIncomeを0~1付近にする処理とNumPyroはStanと異なり0始まりなのでYの値から1を引いてます。また、アヒル本の実装と同様にXに切片の項を追加しています。
df = pd.read_csv("./RStanBook/chap10/input/data-category.txt")
df["Age"] = df["Age"]/100
df["Income"] = df["Income"]/1000
df["Y"] = df["Y"] - 1
df["Intercept"] = 1
df = df[["Intercept", "Age", "Sex", "Income", "Y"]]
df.head()
Intercept Age Sex Income Y
0 1 0.18 1 0.472 1
1 1 0.18 0 0.468 4
2 1 0.18 1 0.451 5
3 1 0.18 1 0.441 5
4 1 0.18 1 0.499 5
モデルの定義
パラメータを識別可能にするために重みにゼロを設定しています。識別可能性などに関する議論はアヒル本の10章を参照してください。各カテゴリにおけるqを行列計算で同時に計算し、dist.CategoricalLogits
で自動的にlogitsを計算し推論します。
def model(X, y=None):
# Xの次元数
D = X.shape[1]
# クラス数
K = len(np.unique(y))
# 重み
# 識別可能にするためにゼロを追加している.詳細はアヒル本参照.
coef_raw = numpyro.sample("coef", dist.Normal(0.0, 10.0).expand([D, K-1]))
zeros = np.zeros((D, 1))
coef = jnp.concatenate([zeros, coef_raw], axis=1)
# muを計算
q = numpyro.deterministic("q", jnp.dot(X, coef))
with numpyro.plate("N", len(X)):
# logitsに指定することで、自動的にlogitsを計算してくれる
numpyro.sample("obs", dist.CategoricalLogits(logits=q), obs=y)
sample sitesのshapeの確認
coef
は(4, 5)
、N plate
の中にあるobs
は300
になっていることが分かります。
with numpyro.handlers.seed(rng_seed=0):
trace = numpyro.handlers.trace(model).get_trace(X=df[["Intercept", "Age", "Sex", "Income"]].values, y=df["Y"].values)
print(numpyro.util.format_shapes(trace))
Trace Shapes:
Param Sites:
Sample Sites:
coef dist 4 5 |
value 4 5 |
N plate 300 |
obs dist 300 |
value 300 |
グラフィカルモデル
numpyro.render_model(
model=model,
model_kwargs={"X": df[["Intercept", "Age", "Sex", "Income"]].values, "y": df["Y"].values},
render_params=True,
render_distributions=True
)
MCMC
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df[["Intercept", "Age", "Sex", "Income"]].values,
y=df["Y"].values,
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
coef[0,0] -0.18 1.03 -0.16 -1.90 1.48 3559.38 1.00
coef[0,1] 0.16 1.73 0.18 -2.67 2.99 4100.31 1.00
coef[0,2] -0.99 1.15 -0.99 -2.80 0.95 3576.01 1.00
coef[0,3] 1.48 2.35 1.51 -2.26 5.39 5809.16 1.00
coef[0,4] 0.52 1.13 0.53 -1.40 2.32 3709.39 1.00
coef[1,0] -1.21 1.51 -1.20 -3.71 1.23 5457.01 1.00
coef[1,1] -4.16 2.76 -4.10 -8.79 0.22 6483.10 1.00
coef[1,2] -1.90 1.70 -1.91 -4.79 0.76 5736.89 1.00
coef[1,3] -7.76 3.90 -7.63 -13.86 -1.26 9303.49 1.00
coef[1,4] -2.82 1.71 -2.82 -5.62 -0.03 5675.71 1.00
coef[2,0] -1.20 0.35 -1.20 -1.77 -0.62 7126.00 1.00
coef[2,1] -3.30 0.87 -3.23 -4.72 -1.95 8978.87 1.00
coef[2,2] -1.84 0.39 -1.84 -2.48 -1.20 7533.20 1.00
coef[2,3] 0.01 0.98 -0.06 -1.66 1.57 9798.43 1.00
coef[2,4] -1.94 0.39 -1.93 -2.53 -1.23 7530.28 1.00
coef[3,0] 2.66 1.54 2.64 0.08 5.15 4013.50 1.00
coef[3,1] 1.84 2.61 1.82 -2.41 6.19 4806.71 1.00
coef[3,2] 4.41 1.70 4.38 1.80 7.37 4091.33 1.00
coef[3,3] -2.83 3.85 -2.79 -8.96 3.59 6989.27 1.00
coef[3,4] 2.43 1.69 2.43 -0.51 5.00 4189.68 1.00
Number of divergences: 0
simplexベクトルとImproperUniform
上記の例とは少し趣が異なりますが、Stanで定義できるsimplex
などの特殊な変数の定義の仕方を見ていきます。例としては、アヒル本9章のサイコロの例で説明します。
データの準備
chapter09のデータ(chap09/input/data-dice.txt)を使用します。出たサイコロの目を表すFaceの列のみがあります。NumPyroはStanと異なり0始まりなのでFaceの値から1を引いてます。
df = pd.read_csv("./RStanBook/chap09/input/data-dice.txt")
df["Face"] = df["Face"] -1
df.head()
Face
0 0
1 1
2 5
3 4
4 3
モデルの定義
dist.ImproperUniform
NumPyroでは特殊な分布としてImproperUniform(support, batch_shape, event_shape, *, validate_args=None)
が用意されています。最初の引数のsupport
にconstraints.???
を指定することでその制約条件にあった値が使用されることになります。また、第2、第3の引数でbatch_shape
とevent_shape
をそれぞれ指定します。
一見便利な分布ではあるのですが、sample methodが使用できないので一部の関数でNotImplementedErrorが起きることがあルので他の分布で代用できる場合はそちらを使用します。例えば、以下のように使用できます。constraintsは他にもたくさん定義されているので、こちらを参考にしてください。
# ordered vector with length 10
x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
# real matrix with shape (3, 4)
y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
# a shape-(6, 8) batch of length-5 vectors greater than 3
z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))
# aより大きい分布
a = sample('a', Normal(0, 1))
x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))
ここではdist.ImproperUniform
を使用した例を載せます。同様の処理はdist.Dirichlet(jnp.ones(n_category)
で実装できます。
def model(Y):
n_category = len(np.unique(Y))
theta = numpyro.sample("theta", dist.ImproperUniform(constraints.simplex, (), (n_category,)))
#theta = numpyro.sample("theta_", dist.Dirichlet(jnp.ones(n_category)))
with numpyro.plate("N", len(Y)):
Y = numpyro.sample("obs", dist.Categorical(theta), obs=Y)
sample sitesのshapeの確認
先に説明したようにdist.ImproperUniformは以下のような制約があるので、NotImplementedErrorが出てしまいます。
sample method is not implemented for this distribution. In autoguide and mcmc, initial parameters for improper sites are derived from init_to_uniform or init_to_value strategies.
with numpyro.handlers.seed(rng_seed=0):
trace = numpyro.handlers.trace(model).get_trace(Y=df["Face"].values)
print(numpyro.util.format_shapes(trace))
NotImplementedError Traceback (most recent call last)
Cell In[5], line 2
1 with numpyro.handlers.seed(rng_seed=0):
----> 2 trace = numpyro.handlers.trace(model).get_trace(Y=df["Face"].values)
3 print(numpyro.util.format_shapes(trace))
File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Cell In[4], line 3, in model(Y)
1 def model(Y):
...
246 :rtype: numpy.ndarray
247 """
--> 248 raise NotImplementedError
NotImplementedError:
MCMC
from numpyro.infer import init_to_feasible
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model, init_strategy=init_to_feasible)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
rng_key=rng_key,
Y=df["Face"].values,
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
theta[0] 0.11 0.02 0.11 0.07 0.14 2611.22 1.00
theta[1] 0.37 0.03 0.37 0.31 0.42 2221.78 1.00
theta[2] 0.10 0.02 0.10 0.07 0.13 2472.86 1.00
theta[3] 0.25 0.03 0.25 0.20 0.30 2512.91 1.00
theta[4] 0.10 0.02 0.10 0.06 0.13 2152.20 1.00
theta[5] 0.07 0.02 0.07 0.04 0.10 2816.17 1.00
Number of divergences: 0
多変量正規分布
データの準備
chapter09のデータ(chap09/input/data-mvn.txt)を使用します。今回は2変数になります。
df = pd.read_csv("./RStanBook/chap09/input/data-mvn.txt")
df.head()
Y1 Y2
0 9.2 2.56
1 9.8 1.99
2 9.4 2.40
3 9.2 2.27
4 8.1 3.68
モデルの定義
分散共分散行列とLKJ分布との関係はこちらを参考にしてください。
def model(Y):
N, P = Y.shape
with numpyro.plate("features", P):
mu = numpyro.sample("mu", dist.Uniform(-1000, 1000))
sqrt_diag = numpyro.sample("sqrt_diag", dist.Uniform(0, 1))
rho = numpyro.sample("rho", dist.LKJ(dimension=P, concentration=1))
# jnp.diag(sqrt_diag)@rho@jnp.diag(sqrt_diag) と同じ.より高速.
theta = numpyro.deterministic("theta", jnp.outer(sqrt_diag,sqrt_diag)*rho)
with numpyro.plate("N", N):
Y = numpyro.sample("obs", dist.MultivariateNormal(mu, theta), obs = Y)
MCMC
rhoの対角成分は必ず1なのでr_hat=nanでも無視します。
from numpyro.infer import init_to_feasible
# run model
nuts_kernel = NUTS(model, init_strategy=init_to_feasible)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=5000)
mcmc.run(rng_key = random.PRNGKey(0), Y=df[["Y1", "Y2"]].values)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mu[0] 9.19 0.10 9.19 9.02 9.36 5157.74 1.00
mu[1] 2.56 0.09 2.56 2.41 2.69 5083.20 1.00
rho[0,0] 1.00 0.00 1.00 1.00 1.00 nan nan
rho[0,1] -0.84 0.05 -0.85 -0.93 -0.76 908.55 1.00
rho[1,0] -0.84 0.05 -0.85 -0.93 -0.76 908.55 1.00
rho[1,1] 1.00 0.00 1.00 1.00 1.00 22.51 1.00
sqrt_diag[0] 0.61 0.09 0.60 0.48 0.74 679.01 1.00
sqrt_diag[1] 0.51 0.07 0.50 0.40 0.61 720.88 1.00
Number of divergences: 0
ここで、得られたrho
とsqrt_diag
から分散共分散行列を計算すると以下のようになりアヒル本の結果とほとんど同じになります。
rho = jnp.array([[1., -0.84], [-0.84, 1.]])
sqrt_diag = jnp.array([0.61, 0.50])
jnp.outer(sqrt_diag,sqrt_diag)*rho
Array([[ 0.37210003, -0.2562 ],
[-0.2562 , 0.25 ]], dtype=float32)
最後に
以上で「基本のモデル」は終わりです。次回からより詳細なトピックの説明になります。
Discussion