💬

NumPyro:階層モデル

2023/04/20に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

今回は階層モデルを扱います。以前までと同様に本シリーズの記事は実装メインなため、理論的な詳細は省略します。最初に1つの階層のモデルを扱い、次に複数の階層のモデルを扱います。どちらもNumpyの行列計算と今までのexpandplateの挙動が理解できていればイメージ掴める内容かと思います。

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

単一の階層モデル

データの読み込み

前回までと同様にアヒル本のデータを使用します。今回はchap08のchap08/input/data-salary-2.txtを使用します。説明変数はXのみで、個体を表すIDとしてKIDが与えられています。また、Stanは1始まりですが、NumPyroは0始まりなのでKID - 1を行なっています。

df = pd.read_csv("../tutorial/RStanBook/chap08/input/data-salary-2.txt")
df["KID"] = df["KID"] -1
df.head()
X	Y	KID
0	7	457	0
1	10	482	0
2	16	518	0
3	25	535	0
4	5	427	0

ちなみに、jax.numpyだとnumpyと異なり、以下のようにIndexを超えてアクセスしても値が返されてしまいます。そのため、バグに気づかないこともあるので注意してください。(私はこれで数時間無駄にしました。。。)

numpy_array = np.array([0, 1, 2, 3])
jax_numpy_array = jnp.array([0, 1, 2, 3])

try:
    print(numpy_array[10])
except IndexError:
    print("Numpy だとIndexErrorが出る")

try:
    print(jax_numpy_array[10])
except IndexError:
    print("jax.numpy だとIndexErrorが出ない")
Numpy だとIndexErrorが出る
3

モデルの定義

モデルの等価な式の形として2通りあるので両方とも実装する。ここで、階層モデルを扱う際に以下のNumpyの操作を使用しています。通常のNumpyと同様の挙動ですが、一応確認しておきましょう。

# coef[KID]では以下のように動作します
kid = np.array([0, 1, 2, 3, 2, 1, 0])
coef_diff = np.array([0, 10, 20, 30])

coef_diff[kid]
array([ 0, 10, 20, 30, 20, 10,  0])

係数の式をまとめない場合

coef[k] = coef_{common} + coef_{diff}[k] \\ coef_{diff}[k] \sim Normal(0, \sigma_{coef}) \\ intercept[k] = intercept_{common} + intercept_{diff}[k] \\ intercept_{diff}[k] \sim Normal(0, \sigma_{intercept})
def model1(x, KID, y=None):
    intercept_common = numpyro.sample("intercept_common", dist.Normal(0, 100))
    coef_common = numpyro.sample("coef_common", dist.Normal(0, 100))
    sigma_intercept = numpyro.sample("sigma_intercept", dist.HalfNormal(100))
    sigma_coef = numpyro.sample("sigma_coef", dist.HalfNormal(100))
    sigma_Y = numpyro.sample("sigma_Y", dist.HalfNormal(100))
    
    with numpyro.plate("K", len(np.unique(KID))):
        coef_diff = numpyro.sample("coef_diff", dist.Normal(0, sigma_coef))
        intercept_diff = numpyro.sample("intercept_diff", dist.Normal(0, sigma_intercept))  
    
    coef = numpyro.deterministic("coef", coef_common + coef_diff)
    intercept = numpyro.deterministic("intercept", intercept_common + intercept_diff)    
    mu = numpyro.deterministic("mu", coef[KID]*x + intercept[KID])
    
    with numpyro.plate("N", len(x)):
        numpyro.sample("obs", dist.Normal(mu, sigma_Y), obs=y)

係数の式をまとめる場合

coef[k] \sim Normal(coef_{common}, \sigma_{coef}) \\ intercept[k] \sim Normal(intercept_{common}, \sigma_{intercept})
def model2(x, KID, y=None):
    intercept_common = numpyro.sample("intercept_common", dist.Normal(0, 100))
    coef_common = numpyro.sample("coef_common", dist.Normal(0, 100))
    sigma_intercept = numpyro.sample("sigma_intercept", dist.HalfNormal(100))
    sigma_coef = numpyro.sample("sigma_coef", dist.HalfNormal(100))
    sigma_Y = numpyro.sample("sigma_Y", dist.HalfNormal(100))
    
    with numpyro.plate("K", len(np.unique(KID))):
        coef = numpyro.sample("coef", dist.Normal(coef_common, sigma_coef))
        intercept = numpyro.sample("intercept", dist.Normal(intercept_common, sigma_intercept))  

    mu = numpyro.deterministic("mu", coef[KID]*x + intercept[KID])
    
    with numpyro.plate("N", len(x)):
        numpyro.sample("obs", dist.Normal(mu, sigma_Y), obs=y)

モデルのレンダリング

それぞれのモデルを描画してみます。numpyro.plateを使用することで階層も分かりやすく表示されますね。

係数の式をまとめない場合

係数の式をまとめる場合

MCMC

結果を確認するとちゃんと収束していることが分かります。また、実行時間は4chainsで3.25s, 3.1sとほとんど変わらないですね。アヒル本の結果を見るとStanでは1chainsごとに2s, 1sという結果のようなので、コンパイル時間や実行時間の点でも今回の例ではNumPyroの速さが際立っています。(他のサイトの結果を見るとStanとNumPyroではほとんど差がかわらないが、コンパイル時間まで加味するとNumPyroが速いという結果が多いように思えます)

係数の式をまとめない場合

%%time
# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model1)
mcmc1 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc1.run(
    rng_key=rng_key,
    x=df["X"].values,
    KID=df["KID"].values,
    y=df["Y"].values,
)
mcmc1.print_summary()
CPU times: user 4.21 s, sys: 65.7 ms, total: 4.27 s
Wall time: 3.25 s

                        mean       std    median      5.0%     95.0%     n_eff     r_hat
       coef_common     13.08      6.74     13.05      3.53     22.10    865.19      1.00
      coef_diff[0]     -5.17      6.75     -5.10    -14.11      4.31    875.46      1.00
      coef_diff[1]      6.46      6.83      6.32     -2.87     15.98    879.47      1.00
      coef_diff[2]     -0.73      6.83     -0.79    -10.24      8.86    878.51      1.00
      coef_diff[3]     -0.37      7.23     -0.00    -11.21      8.89    998.62      1.00
  intercept_common    287.39     73.96    306.88    169.71    386.68    899.60      1.00
 intercept_diff[0]     93.05     77.00     74.28    -10.02    214.83    907.62      1.00
 intercept_diff[1]     45.16     73.34     24.29    -58.57    161.55    926.99      1.00
 intercept_diff[2]     28.84     72.88      9.52    -73.35    150.35   1035.96      1.00
 intercept_diff[3]    121.24    135.47     82.92    -46.88    337.16    958.68      1.00
           sigma_Y     29.18      4.06     28.82     22.80     35.44   2792.77      1.00
        sigma_coef     10.32     10.06      7.36      2.26     19.39   1275.23      1.00
   sigma_intercept     95.76     65.00     80.52      2.64    189.72   1008.46      1.00

Number of divergences: 5

係数の式をまとめる場合

%%time
# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model1)
mcmc2 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc2.run(
    rng_key=rng_key,
    x=df["X"].values,
    KID=df["KID"].values,
    y=df["Y"].values,
)
mcmc2.print_summary()
CPU times: user 4.08 s, sys: 55.9 ms, total: 4.14 s
Wall time: 3.1 s

                        mean       std    median      5.0%     95.0%     n_eff     r_hat
       coef_common     13.08      6.74     13.05      3.53     22.10    865.19      1.00
      coef_diff[0]     -5.17      6.75     -5.10    -14.11      4.31    875.46      1.00
      coef_diff[1]      6.46      6.83      6.32     -2.87     15.98    879.47      1.00
      coef_diff[2]     -0.73      6.83     -0.79    -10.24      8.86    878.51      1.00
      coef_diff[3]     -0.37      7.23     -0.00    -11.21      8.89    998.62      1.00
  intercept_common    287.39     73.96    306.88    169.71    386.68    899.60      1.00
 intercept_diff[0]     93.05     77.00     74.28    -10.02    214.83    907.62      1.00
 intercept_diff[1]     45.16     73.34     24.29    -58.57    161.55    926.99      1.00
 intercept_diff[2]     28.84     72.88      9.52    -73.35    150.35   1035.96      1.00
 intercept_diff[3]    121.24    135.47     82.92    -46.88    337.16    958.68      1.00
           sigma_Y     29.18      4.06     28.82     22.80     35.44   2792.77      1.00
        sigma_coef     10.32     10.06      7.36      2.26     19.39   1275.23      1.00
   sigma_intercept     95.76     65.00     80.52      2.64    189.72   1008.46      1.00

複数の階層モデル

NumPyroのチュートリアルを一部参考にしてコードを作成しています。
https://num.pyro.ai/en/latest/tutorials/bayesian_hierarchical_linear_regression.html

データの読み込み

前回までと同様にアヒル本のデータを使用します。今回はchap08のchap08/input/data-salary-3.txtを使用します。説明変数はXのみで、個体を属するIDとしてKIDが与えられており、KIDが属するさらに大きな概念としてGIDが与えられています。また先ほどと同様に、Stanは1始まりですが、NumPyroは0始まりなので1引いています。

df = pd.read_csv("./RStanBook/chap08/input/data-salary-3.txt")
df["KID"] = df["KID"] -1
df["GID"] = df["GID"] -1
print(df.KID.unique())
print(df.GID.unique())
df.head()
X	Y	KID
0	7	457	0
1	10	482	0
2	16	518	0
3	25	535	0
4	5	427	0

ここで、次のモデル内部で使用するのでKIDからGIDへマッピングする配列を作成しておきます。

map_kid_to_gid= (
    df[["KID", "GID"]]
    .drop_duplicates()
    .set_index("KID", verify_integrity=True)
    .sort_index()["GID"]
    .values
)
print(map_kid_to_gid)
array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       2, 2, 2, 2, 2, 2, 2, 2])

モデルの定義

先ほどのモデル2の分布の平均coef_{common}に所属するグループの平均を採用した2段階の階層モデルになっています。

a[k] \sim Normal(a_{group}[K2G[k]], \sigma_a) \\ a_{group}[g] \sim Normal(a_{global}, \sigma_{ag})
def model3(x, GID, KID, map_kid_to_gid, y=None):
    a0 = numpyro.sample("a0", dist.Normal(0, 100))
    b0 = numpyro.sample("b0", dist.Normal(0, 100))
    s_ag = numpyro.sample("s_ag", dist.HalfNormal(100))
    s_bg = numpyro.sample("s_bg", dist.HalfNormal(100))
    s_a = numpyro.sample("s_a", dist.HalfNormal(100))
    s_b = numpyro.sample("s_b", dist.HalfNormal(100))
    s_Y = numpyro.sample("s_Y", dist.HalfNormal(100))
    
    with numpyro.plate("G", len(np.unique(GID))):
        a1 = numpyro.sample("a1", dist.Normal(a0, s_ag))
        b1 = numpyro.sample("b1", dist.Normal(b0, s_bg))
    
    with numpyro.plate("K", len(np.unique(KID))):
        a = numpyro.sample("a", dist.Normal(a1[map_kid_to_gid], s_a))
        b = numpyro.sample("b", dist.Normal(b1[map_kid_to_gid], s_b))
    
    mu = numpyro.deterministic("mu", a[KID]*x + b[KID])
    with numpyro.plate("N", len(x)):
        numpyro.sample("obs", dist.Normal(mu, s_Y), obs=y)

ここで、少し慣れた人は2つ目のnumpyro.plateはいらないのではないかと思うかと思います。なぜならa1[map_kid_to_gid]が長さlen(np.unique(KID))の配列であり、plateがなくてもlen(np.unique(KID))の数だけ分布が生成されるためです。これは正しい認識でして、plateをコメントアウトした以下の実装でも結果は同じです。しかし、plateがあることで次のモデルのレンダリング時にプレートが複数表示されたり、コード的にも階層であることが分かりやすくなったりするので(恐らく)わざわざ書くのが推奨の書き方かと思います。(公式のチュートリアルの書き方も同じです)

def model3_non(x, GID, KID, map_kid_to_gid, y=None):
    a0 = numpyro.sample("a0", dist.Normal(0, 100))
    b0 = numpyro.sample("b0", dist.Normal(0, 100))
    s_ag = numpyro.sample("s_ag", dist.HalfNormal(100))
    s_bg = numpyro.sample("s_bg", dist.HalfNormal(100))
    s_a = numpyro.sample("s_a", dist.HalfNormal(100))
    s_b = numpyro.sample("s_b", dist.HalfNormal(100))
    s_Y = numpyro.sample("s_Y", dist.HalfNormal(100))
    
    with numpyro.plate("G", len(np.unique(GID))):
        a1 = numpyro.sample("a1", dist.Normal(a0, s_ag))
        b1 = numpyro.sample("b1", dist.Normal(b0, s_bg))
    
    # a1[map_kid_to_gid]の時点でKIDの数だけ生成されるのでplateなしでも同じ
    #with numpyro.plate("K", len(np.unique(KID))):
    a = numpyro.sample("a", dist.Normal(a1[map_kid_to_gid], s_a))
    b = numpyro.sample("b", dist.Normal(b1[map_kid_to_gid], s_b))
    
    mu = numpyro.deterministic("mu", a[KID]*x + b[KID])
    with numpyro.plate("N", len(x)):
        numpyro.sample("obs", dist.Normal(mu, s_Y), obs=y)

shapeの確認

少し複雑なので、各サンプルサイトのShapeを確認してみましょう。

with numpyro.handlers.seed(rng_seed=0):
    trace = numpyro.handlers.trace(model3).get_trace(
        x=df["X"].values,
        GID=df["GID"].values,
        KID=df["KID"].values,
        map_kid_to_gid=map_kid_to_gid, 
        y=df["Y"].values
    )
print(numpyro.util.format_shapes(trace))
Trace Shapes:      
 Param Sites:      
Sample Sites:      
      a0 dist     |
        value     |
      b0 dist     |
        value     |
    s_ag dist     |
        value     |
    s_bg dist     |
        value     |
     s_a dist     |
        value     |
     s_b dist     |
        value     |
     s_Y dist     |
        value     |
      G plate   3 |
      a1 dist   3 |
        value   3 |
      b1 dist   3 |
        value   3 |
      K plate  30 |
       a dist  30 |
        value  30 |
       b dist  30 |
        value  30 |
      N plate 300 |
     obs dist 300 |
        value 300 |

モデルのレンダリング

MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model3)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(
    rng_key=rng_key,
    x=df["X"].values, 
    GID=df["GID"].values,
    KID=df["KID"].values, 
    map_kid_to_gid=map_kid_to_gid,
    y=df["Y"].values
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      a[0]      9.08      1.78      9.07      6.01     11.83   2683.95      1.00
      a[1]     17.81      2.27     17.79     14.07     21.50   2732.26      1.00
      a[2]     11.00      1.97     10.97      7.80     14.29   2673.58      1.00
      a[3]     14.60      2.13     14.62     11.25     18.16   3840.61      1.00
      a[4]     33.08      2.12     33.03     29.72     36.66   3978.88      1.00
      a[5]     36.14      3.00     36.40     31.05     40.76   1425.89      1.00
      a[6]     30.88      1.25     30.89     28.84     32.96   4008.69      1.00
      a[7]     25.19      1.73     25.17     22.38     28.03   4909.21      1.00
      a[8]     22.45      1.50     22.38     20.07     25.00   3242.20      1.00
      a[9]     29.06      1.85     29.03     25.97     32.03   1443.74      1.00
     a[10]     32.46      1.81     32.51     29.44     35.35   5953.17      1.00
     a[11]     34.20      1.36     34.18     32.01     36.42   5050.82      1.00
     a[12]     28.77      2.28     28.86     24.98     32.34   2716.09      1.00
     a[13]     23.99      2.02     23.96     20.81     27.41   3379.93      1.00
     a[14]     33.75      2.77     33.66     29.26     38.24   4712.94      1.00
     a[15]     19.46      1.34     19.44     17.29     21.67   4019.57      1.00
     a[16]     28.23      1.41     28.21     26.09     30.65   4729.48      1.00
     a[17]     28.87      2.32     28.91     24.98     32.65   4110.67      1.00
     a[18]     20.84      2.35     20.72     17.09     24.77   2053.59      1.00
     a[19]     34.04      1.52     34.06     31.56     36.54   3397.22      1.00
     a[20]     26.52      2.49     26.51     22.48     30.62   3974.52      1.00
     a[21]     28.18      1.36     28.24     25.82     30.25   1607.29      1.00
     a[22]     13.35      1.95     13.33     10.23     16.66   2817.55      1.00
...
       s_b     25.68     12.67     24.76      4.38     42.65    430.02      1.00
      s_bg    161.02     58.09    155.91     67.22    250.22   4277.68      1.00

Number of divergences: 93

最後に

以上で「階層モデル」は終わりです。基本的にNumpyの行列計算と今までのexpandplateの挙動が理解できていればイメージ掴める内容だったかと思います。次回は離散潜在変数を扱います。

Discussion