NumPyro:基本のモデル

2023/04/18に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

前回までで単回帰やNumPyro特有の内容を見てきたので、それらを少し応用して基本的なモデルを実装していきます。理論的な内容は世の中にある優れた本にお任せし、実装中心の内容です。
また、前回までで出てきた内容は極力説明を省いて進行します。

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

データの準備

書籍「StanとRでベイズ統計モデリング」のデータを使用します。以下のリンクからcloneしてください。
https://www.kyoritsu-pub.co.jp/book/b10003786.html
https://github.com/MatsuuraKentaro/RStanBook

重回帰

データの準備

chapter05のデータ(chap05/input/data-attendance-1.txt)を使用します。ScoreとYには相関がありそうです。またScoreは0~1付近になるように前処理されています。

df = pd.read_csv("./RStanBook/chap05/input/data-attendance-1.txt")
sns.pairplot(df, hue="A")
df["Score"] = df.Score.values/200.

モデルの定義

単回帰の時とほとんど同じです。複数の説明変数があるので、係数coefの分布をexpand()で説明変数の数だけ定義しています。また、coefが2次元の配列になることに伴いmuの計算でもjnp.dot(X, coef)と行列計算を行なっています。

def model(X, y=None):
    # 切片
    intercept = numpyro.sample("intercept", dist.Normal(0., 100.))
    # 重み
    coef = numpyro.sample("coef", dist.Normal(0.0, 100).expand([X.shape[1]]))
    
    # muを計算
    mu = numpyro.deterministic("mu", jnp.dot(X, coef) + intercept)
    
    # ノイズ
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    
    # 正規分布からのサンプリング.yは観測値なので、obs=yを追加
    with numpyro.plate("N", len(X)):
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

sample sitesのshapeの確認

重回帰では簡単なモデリングなのでほとんど必要ないですが、自分で実装した際に各サンプルサイトのShapeが知りたいことがよくあります。NumPyroではnumpyro.util.format_shapes()という便利関数が用意されており、以下のようにtraceを渡すことで一括でShapeを表示できます。ここで、interceptsigmaの形状は()なため空欄になっていますが、coef2N plateの中にあるobs50になっていることが分かります。

with numpyro.handlers.seed(rng_seed=0):
    trace = numpyro.handlers.trace(model).get_trace(X=df[["A", "Score"]].values, y=df["Y"].values)
print(numpyro.util.format_shapes(trace))
 Trace Shapes:     
  Param Sites:     
 Sample Sites:     
intercept dist    |
         value    |
     coef dist  2 |
         value  2 |
    sigma dist    |
         value    |
       N plate 50 |
      obs dist 50 |
         value 50 |

グラフィカルモデル

numpyro.render_model(
    model=model, 
    model_kwargs={"X": df[["A", "Score"]].values, "y": df["Y"].values}, 
    render_params=True, 
    render_distributions=True
)

MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df[["A", "Score"]].values,
    y=df["Y"].values,
)
mcmc.print_summary()
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
    coef[0]     -0.14      0.02     -0.14     -0.17     -0.12   4948.18      1.00
    coef[1]      0.32      0.05      0.32      0.24      0.41   3467.48      1.00
  intercept      0.12      0.03      0.12      0.07      0.18   3376.21      1.00
      sigma      0.05      0.01      0.05      0.04      0.06   4800.30      1.00

Number of divergences: 0

Arvizによる可視化

結果を見るとうまく収束していることが確認できます。

numpyro2az = az.from_numpyro(mcmc)
az.plot_trace(numpyro2az, figsize=(8,4));

ポアソン回帰

上記の重回帰の例と一部だけ異なるだけなので、モデルの定義だけ例を載せておきます。

モデルの定義

想定しているデータはchapter05のデータ(chap05/input/data-attendance-2.txt)になります。重回帰の時と同じ素性のデータでYがカウントデータになったものです。

def model(X, y=None):
    # 切片
    intercept = numpyro.sample("intercept", dist.Normal(0, 100))
    # 重み
    coef = numpyro.sample("coef", dist.Normal(0.0, 100).expand([X.shape[1]]))
    
    lambda_ = numpyro.deterministic("lambda", jnp.exp(jnp.dot(X, coef) + intercept))
    
    with numpyro.plate("N", len(X)):
        numpyro.sample("obs", dist.Poisson(lambda_), obs=y)

2項ロジスティック回帰

上記の重回帰の例と一部だけ異なるだけなので、モデルの定義だけ例を載せておきます。

モデルの定義

想定しているデータはchapter05のデータ(chap05/input/data-attendance-2.txt)になります。NumPyroでは自動でlogitsを計算してくれる分布dist.BinomialLogitsと手動で指定できるdist.BinomialProbsが用意されているので今回は前者を使用しています。

def model(X, total_count, y=None):
    # 切片
    intercept = numpyro.sample("intercept", dist.Normal(0, 100))
    # 重み
    coef = numpyro.sample("coef", dist.Normal(0.0, 100).expand([X.shape[1]]))
    
    q = numpyro.deterministic("q", jnp.dot(X, coef) + intercept)
    
    with numpyro.plate("N", len(X)):
        # logitsに指定することで、自動的にlogitsを計算してくれる
        numpyro.sample("obs", dist.BinomialLogits(logits=q, total_count=total_count), obs=y)

ロジスティック回帰

上記の重回帰の例と一部だけ異なるだけなので、モデルの定義だけ例を載せておきます。

モデルの定義

想定しているデータはchapter05のデータ(chap05/input/data-attendance-3.txt)になります。重回帰の時と同じ素性のデータでYが(0, 1)のバイナリーデータになったものです。NumPyroでは自動でlogitsを計算してくれる分布dist.BernoulliLogitsと手動で指定できるdist.BernoulliProbsが用意されているので今回は前者を使用しています。

def model(X, y=None):
    # 切片
    intercept = numpyro.sample("intercept", dist.Normal(0, 100))
    # 重み
    coef = numpyro.sample("coef", dist.Normal(0.0, 100).expand([X.shape[1]]))
    
    q = numpyro.deterministic("q", jnp.dot(X, coef) + intercept)
    
    with numpyro.plate("N", len(X)):
        # logitsに指定することで、自動的にlogitsを計算してくれる
        numpyro.sample("obs", dist.BernoulliLogits(logits=q), obs=y)

外れ値

以前の単回帰の例と一部だけ異なるだけなので、モデルの定義だけ例を載せておきます。

モデルの定義

想定しているデータはchapter07のデータ(chap07/input/outlier.txt)になります。Yに外れ値が含まれるデータになります。こちらも最後の分布にdist.Cauchyを使用しているだけです。

def model(x, y=None):
    intercept = numpyro.sample("intercept", dist.Normal(0, 100))
    coef = numpyro.sample("coef", dist.Normal(0, 100))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    
    mu = numpyro.deterministic("mu", coef*x + intercept)
    with numpyro.plate("N", len(x)):
        numpyro.sample("obs", dist.Cauchy(mu, sigma), obs=y)

多項ロジスティック回帰

データの準備

chapter10のデータ(chap10/input/data-category.txt)を使用します。XはAge,Sex,Incomeの3つで、Yはカテゴリを表す列になります。
前処理としてAgeとIncomeを0~1付近にする処理とNumPyroはStanと異なり0始まりなのでYの値から1を引いてます。また、アヒル本の実装と同様にXに切片の項を追加しています。

df = pd.read_csv("./RStanBook/chap10/input/data-category.txt")
df["Age"] = df["Age"]/100
df["Income"] = df["Income"]/1000
df["Y"] = df["Y"] - 1
df["Intercept"] = 1
df = df[["Intercept", "Age", "Sex", "Income", "Y"]]
df.head()
Intercept	Age	Sex	Income	Y
0	1	0.18	1	0.472	1
1	1	0.18	0	0.468	4
2	1	0.18	1	0.451	5
3	1	0.18	1	0.441	5
4	1	0.18	1	0.499	5

モデルの定義

パラメータを識別可能にするために重みにゼロを設定しています。識別可能性などに関する議論はアヒル本の10章を参照してください。各カテゴリにおけるqを行列計算で同時に計算し、dist.CategoricalLogitsで自動的にlogitsを計算し推論します。

def model(X, y=None):
    # Xの次元数
    D = X.shape[1]
    # クラス数
    K = len(np.unique(y))
    
    # 重み
    # 識別可能にするためにゼロを追加している.詳細はアヒル本参照.
    coef_raw = numpyro.sample("coef", dist.Normal(0.0, 10.0).expand([D, K-1]))
    zeros = np.zeros((D, 1))
    coef = jnp.concatenate([zeros, coef_raw], axis=1)
    
    # muを計算
    q = numpyro.deterministic("q", jnp.dot(X, coef)) 
    
    with numpyro.plate("N", len(X)):
        # logitsに指定することで、自動的にlogitsを計算してくれる
        numpyro.sample("obs", dist.CategoricalLogits(logits=q), obs=y)

sample sitesのshapeの確認

coef(4, 5)N plateの中にあるobs300になっていることが分かります。

with numpyro.handlers.seed(rng_seed=0):
    trace = numpyro.handlers.trace(model).get_trace(X=df[["Intercept", "Age", "Sex", "Income"]].values, y=df["Y"].values)
print(numpyro.util.format_shapes(trace))
Trace Shapes:        
 Param Sites:        
Sample Sites:        
    coef dist 4   5 |
        value 4   5 |
      N plate   300 |
     obs dist   300 |
        value   300 |

グラフィカルモデル

numpyro.render_model(
    model=model, 
    model_kwargs={"X": df[["Intercept", "Age", "Sex", "Income"]].values, "y": df["Y"].values}, 
    render_params=True, 
    render_distributions=True
)

MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df[["Intercept", "Age", "Sex", "Income"]].values,
    y=df["Y"].values,
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
 coef[0,0]     -0.18      1.03     -0.16     -1.90      1.48   3559.38      1.00
 coef[0,1]      0.16      1.73      0.18     -2.67      2.99   4100.31      1.00
 coef[0,2]     -0.99      1.15     -0.99     -2.80      0.95   3576.01      1.00
 coef[0,3]      1.48      2.35      1.51     -2.26      5.39   5809.16      1.00
 coef[0,4]      0.52      1.13      0.53     -1.40      2.32   3709.39      1.00
 coef[1,0]     -1.21      1.51     -1.20     -3.71      1.23   5457.01      1.00
 coef[1,1]     -4.16      2.76     -4.10     -8.79      0.22   6483.10      1.00
 coef[1,2]     -1.90      1.70     -1.91     -4.79      0.76   5736.89      1.00
 coef[1,3]     -7.76      3.90     -7.63    -13.86     -1.26   9303.49      1.00
 coef[1,4]     -2.82      1.71     -2.82     -5.62     -0.03   5675.71      1.00
 coef[2,0]     -1.20      0.35     -1.20     -1.77     -0.62   7126.00      1.00
 coef[2,1]     -3.30      0.87     -3.23     -4.72     -1.95   8978.87      1.00
 coef[2,2]     -1.84      0.39     -1.84     -2.48     -1.20   7533.20      1.00
 coef[2,3]      0.01      0.98     -0.06     -1.66      1.57   9798.43      1.00
 coef[2,4]     -1.94      0.39     -1.93     -2.53     -1.23   7530.28      1.00
 coef[3,0]      2.66      1.54      2.64      0.08      5.15   4013.50      1.00
 coef[3,1]      1.84      2.61      1.82     -2.41      6.19   4806.71      1.00
 coef[3,2]      4.41      1.70      4.38      1.80      7.37   4091.33      1.00
 coef[3,3]     -2.83      3.85     -2.79     -8.96      3.59   6989.27      1.00
 coef[3,4]      2.43      1.69      2.43     -0.51      5.00   4189.68      1.00

Number of divergences: 0

simplexベクトルとImproperUniform

上記の例とは少し趣が異なりますが、Stanで定義できるsimplexなどの特殊な変数の定義の仕方を見ていきます。例としては、アヒル本9章のサイコロの例で説明します。

データの準備

chapter09のデータ(chap09/input/data-dice.txt)を使用します。出たサイコロの目を表すFaceの列のみがあります。NumPyroはStanと異なり0始まりなのでFaceの値から1を引いてます。

df = pd.read_csv("./RStanBook/chap09/input/data-dice.txt")
df["Face"] = df["Face"] -1
df.head()
	Face
0	0
1	1
2	5
3	4
4	3

モデルの定義

dist.ImproperUniform

NumPyroでは特殊な分布としてImproperUniform(support, batch_shape, event_shape, *, validate_args=None)が用意されています。最初の引数のsupportconstraints.???を指定することでその制約条件にあった値が使用されることになります。また、第2、第3の引数でbatch_shapeevent_shapeをそれぞれ指定します。
一見便利な分布ではあるのですが、sample methodが使用できないので一部の関数でNotImplementedErrorが起きることがあルので他の分布で代用できる場合はそちらを使用します。例えば、以下のように使用できます。constraintsは他にもたくさん定義されているので、こちらを参考にしてください。

# ordered vector with length 10
x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))

# real matrix with shape (3, 4)
y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))

# a shape-(6, 8) batch of length-5 vectors greater than 3
z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))

# aより大きい分布
a = sample('a', Normal(0, 1))
x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))

ここではdist.ImproperUniformを使用した例を載せます。同様の処理はdist.Dirichlet(jnp.ones(n_category)で実装できます。

def model(Y):
    n_category = len(np.unique(Y))
    theta = numpyro.sample("theta", dist.ImproperUniform(constraints.simplex, (), (n_category,)))
    #theta = numpyro.sample("theta_", dist.Dirichlet(jnp.ones(n_category)))

    with numpyro.plate("N", len(Y)):
        Y = numpyro.sample("obs", dist.Categorical(theta), obs=Y)

sample sitesのshapeの確認

先に説明したようにdist.ImproperUniformは以下のような制約があるので、NotImplementedErrorが出てしまいます。

sample method is not implemented for this distribution. In autoguide and mcmc, initial parameters for improper sites are derived from init_to_uniform or init_to_value strategies.

with numpyro.handlers.seed(rng_seed=0):
    trace = numpyro.handlers.trace(model).get_trace(Y=df["Face"].values)
print(numpyro.util.format_shapes(trace))
NotImplementedError                       Traceback (most recent call last)
Cell In[5], line 2
      1 with numpyro.handlers.seed(rng_seed=0):
----> 2     trace = numpyro.handlers.trace(model).get_trace(Y=df["Face"].values)
      3 print(numpyro.util.format_shapes(trace))

File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[4], line 3, in model(Y)
      1 def model(Y):
...
    246     :rtype: numpy.ndarray
    247     """
--> 248     raise NotImplementedError

NotImplementedError: 

MCMC

from numpyro.infer import init_to_feasible

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model, init_strategy=init_to_feasible)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
    rng_key=rng_key,
    Y=df["Face"].values,
)
mcmc.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  theta[0]      0.11      0.02      0.11      0.07      0.14   2611.22      1.00
  theta[1]      0.37      0.03      0.37      0.31      0.42   2221.78      1.00
  theta[2]      0.10      0.02      0.10      0.07      0.13   2472.86      1.00
  theta[3]      0.25      0.03      0.25      0.20      0.30   2512.91      1.00
  theta[4]      0.10      0.02      0.10      0.06      0.13   2152.20      1.00
  theta[5]      0.07      0.02      0.07      0.04      0.10   2816.17      1.00

Number of divergences: 0

多変量正規分布

データの準備

chapter09のデータ(chap09/input/data-mvn.txt)を使用します。今回は2変数になります。

df = pd.read_csv("./RStanBook/chap09/input/data-mvn.txt")
df.head()
	Y1	Y2
0	9.2	2.56
1	9.8	1.99
2	9.4	2.40
3	9.2	2.27
4	8.1	3.68

モデルの定義

分散共分散行列とLKJ分布との関係はこちらを参考にしてください。

def model(Y):
    N, P = Y.shape
    
    with numpyro.plate("features", P):
        mu = numpyro.sample("mu", dist.Uniform(-1000, 1000))
        sqrt_diag = numpyro.sample("sqrt_diag", dist.Uniform(0, 1))
        
    rho = numpyro.sample("rho", dist.LKJ(dimension=P, concentration=1))  
    
    # jnp.diag(sqrt_diag)@rho@jnp.diag(sqrt_diag) と同じ.より高速.
    theta = numpyro.deterministic("theta", jnp.outer(sqrt_diag,sqrt_diag)*rho)
    
    with numpyro.plate("N", N):
        Y = numpyro.sample("obs", dist.MultivariateNormal(mu, theta), obs = Y)

MCMC

rhoの対角成分は必ず1なのでr_hat=nanでも無視します。

from numpyro.infer import init_to_feasible
# run model
nuts_kernel = NUTS(model, init_strategy=init_to_feasible)

mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=5000)
mcmc.run(rng_key = random.PRNGKey(0), Y=df[["Y1", "Y2"]].values)
mcmc.print_summary()
                  mean       std    median      5.0%     95.0%     n_eff     r_hat
       mu[0]      9.19      0.10      9.19      9.02      9.36   5157.74      1.00
       mu[1]      2.56      0.09      2.56      2.41      2.69   5083.20      1.00
    rho[0,0]      1.00      0.00      1.00      1.00      1.00       nan       nan
    rho[0,1]     -0.84      0.05     -0.85     -0.93     -0.76    908.55      1.00
    rho[1,0]     -0.84      0.05     -0.85     -0.93     -0.76    908.55      1.00
    rho[1,1]      1.00      0.00      1.00      1.00      1.00     22.51      1.00
sqrt_diag[0]      0.61      0.09      0.60      0.48      0.74    679.01      1.00
sqrt_diag[1]      0.51      0.07      0.50      0.40      0.61    720.88      1.00

Number of divergences: 0

ここで、得られたrhosqrt_diagから分散共分散行列を計算すると以下のようになりアヒル本の結果とほとんど同じになります。

rho = jnp.array([[1., -0.84], [-0.84, 1.]])
sqrt_diag = jnp.array([0.61, 0.50])
jnp.outer(sqrt_diag,sqrt_diag)*rho
Array([[ 0.37210003, -0.2562    ],
       [-0.2562    ,  0.25      ]], dtype=float32)

最後に

以上で「基本のモデル」は終わりです。次回からより詳細なトピックの説明になります。

Discussion