📘

NumPyro:離散潜在変数の扱い方

2023/04/21に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

今回は離散潜在変数の扱い方を説明します。Stanでは自動で扱うことができないため、基本的に場合の数を数え上げて離散パラメータを消去(周辺化)した形で対数尤度を表現する必要があります。しかし、NumPyroでは一部制限はありますが自動で離散潜在変数も扱うことができます。(アヒル本のLDAはこれらに引っかかりこの機能付きでの実装はできませんでした。もちろんStanと同様に手作業で周辺化すればいいのですが、うまく実装できる方法がもしあれば教えてください。。。)

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

1. Bernoulli分布

例題としてアヒル本のモデル式11-1を扱います。背景はアヒル本を参照していただいて、実装する式は以下になります。

coin[n] \sim Bernoulli(0.5) \\ \theta[0] = q, \theta[1] = 1.0 \\ Y[n] \sim Bernoulli(\theta[coin[n]])

データの準備

アヒル本のchap11のデータ(chap11/input/data-coin.txt)を使用します。

df = pd.read_csv("./RStanBook/chap11/input/data-coin.txt")
df.head()
	Y
0	1
1	0
2	1
3	0
4	0

モデルの定義

NumPyroでもStanと同様に手動で周辺化ができるので2通りの方法で実装します。

自動

infer={‘enumerate’: ‘parallel’}により、MCMCにこの離散潜在変数を周辺化することを教えます。

def model1(y=None):
    
    q = numpyro.sample("q", dist.Uniform(0, 1))
    
    with numpyro.plate("N", len(y)):
        coin = numpyro.sample("coin", dist.Bernoulli(0.5), infer={"enumerate": "parallel"})
        theta = jnp.where(coin == 0, q, 1.0)
        numpyro.sample("obs", dist.Bernoulli(theta), obs=y)

手動

Stanの場合はtarget += ・・・の形で対数尤度を加えていきますが、NumPyroではnumpyro.factor(name, lp)で追加できます。

def model2(y=None):
    
    q = numpyro.sample("q", dist.Uniform(0, 1))
    
    p = jnp.exp(dist.Bernoulli(0.5).log_prob(0.5) + dist.Bernoulli(q).log_prob(y)) 
    p += jnp.exp(dist.Bernoulli(0.5).log_prob(0.5) + dist.Bernoulli(1).log_prob(y))  
    lp = jnp.log(p)
    
    numpyro.factor("lp", lp)

MCMC

model1もmodel2も同じ結果になるのでmolde1を実行します。このように一番最初に述べたNumPyroの制限に引っかからないケースであれば離散潜在変数を意識せずに実行することができます。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model1)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
    rng_key=rng_key,
    y=df["Y"].values,
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         q      0.20      0.09      0.20      0.05      0.35   1864.31      1.00

Number of divergences: 0

ここで、coinは周辺化されているのでqの情報しか得られていません。そのため、以下のような操作を実行することで離散潜在変数のサンプルも得ることができます。

posterior_samples = mcmc.get_samples()
print(posterior_samples)
{'q': Array([0.32800168, 0.32800168, 0.32800168, ..., 0.2598623 , 0.2598623 ,
       0.21775615], dtype=float32)}
import jax

num_chains = 4
num_samples = 2000

# infer_discrete=Trueにより離散潜在変数も取得
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), y=df["Y"].values)

# mcmc_samplesに追加
chain_discrete_samples = jax.tree_util.tree_map(
    lambda x: x.reshape((num_chains, num_samples) + x.shape[1:]),
    discrete_samples)
mcmc.get_samples().update(discrete_samples)
mcmc.get_samples(group_by_chain=True).update(chain_discrete_samples)

mcmc.get_samples()
{'q': Array([0.32800168, 0.32800168, 0.32800168, ..., 0.2598623 , 0.2598623 ,
        0.21775615], dtype=float32),
 'coin': Array([[1, 0, 1, ..., 0, 1, 1],
        [1, 0, 1, ..., 1, 1, 1],
        [1, 0, 1, ..., 1, 1, 1],
        ...,
        [1, 0, 1, ..., 1, 1, 1],
        [0, 0, 0, ..., 1, 1, 1],
        [1, 0, 1, ..., 1, 1, 0]], dtype=int32),
 'obs': Array([[1, 0, 1, ..., 1, 1, 1],
        [1, 0, 1, ..., 1, 1, 1],
        [1, 0, 1, ..., 1, 1, 1],
        ...,
        [1, 0, 1, ..., 1, 1, 1],
        [1, 0, 1, ..., 1, 1, 1],
        [1, 0, 1, ..., 1, 1, 1]], dtype=int32)}

2. Poisson分布

例題としてアヒル本のモデル式11-2を扱います。背景はアヒル本を参照していただいて、実装する式は以下になります。

m[n] \sim Poisson(\lambda) \\ Y[n] \sim Binomial(m[n], 0.5)

データの準備

アヒル本のchap11のデータ(chap11/input/data-poisson-binomial.txt)を使用します。

df = pd.read_csv("./RStanBook/chap11/input/data-poisson-binomial.txt")
df.head()
	Y
0	1
1	0
2	1
3	0
4	0

モデルの定義

ポアソン分布もベルヌーイ分布と同様にデフォルトでenumerateの機能があればいいのですが、ポアソン分布の値域は整数なので数え上げることができないためデフォルトではサポートされていません。そこで今回は既存のクラスを改変して使用してみます。

enumerateが可能かどうかの確認

print(dist.Bernoulli(1).has_enumerate_support)
print(dist.Poisson(1).has_enumerate_support)
True
False

enumerateに対応したPoisson分布

ここで、周辺化する際は数え上げる必要があるので、どの値まで数え上げるかをmax_valueで指定します。この時、max_valueより大きな値は確率を0と見なしていることになります。

下記の実装は既存のdist.Poissonを継承したクラスになります。周辺化が可能なクラスはenumerate_support()メソッドを持っており、これが場合の数を全て列挙しているような関数になっています。私も完璧に理解しているわけではないですが、他のenumerateをサポートしている分布のソースを見ると理解が少し深まるかと思います。例えば、カテゴリ分布では以下のように書かれており、カテゴリの場合の数(jnp.arange(self.probs.shape[-1]))が指定されています。

 def enumerate_support(self, expand=True):
        values = jnp.arange(self.probs.shape[-1]).reshape(
            (-1,) + (1,) * len(self.batch_shape)
        )
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
        return values

そのため、今回の場合は以下のようにjnp.arange(self.max_value)の箇所で場合の数を指定しています。

class PoissonMax(dist.Poisson):
    """ デフォルトでは対応していないenumerateに対応したポアソン分布
        事前知識からありえない大きさの数字をmax_valueに指定。それ以上の値は確率を0とみなす。
    """
    arg_constraints = {"rate": constraints.positive, "max_value": constraints.positive}
    has_enumerate_support = True
    
    def __init__(self, rate, max_value, *, is_sparse=False, validate_args=None):
        self.max_value = max_value
        super().__init__(rate, is_sparse=is_sparse, validate_args=validate_args)
    
    def enumerate_support(self, expand=True):
        values = jnp.arange(self.max_value).reshape((-1,) + (1,) * len(self.batch_shape))
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
        return values

モデルの定義

def model(y=None):
    
    lam = numpyro.sample("lam", dist.HalfNormal(10))
    
    with numpyro.plate("N", len(y)):
        m = numpyro.sample("m", PoissonMax(rate=lam, max_value=40), infer={"enumerate": "parallel"})
        numpyro.sample("obs", dist.Binomial(total_count=m, probs=0.5), obs=y)

MCMC

アヒル本の結果と一緒になりました。自動で求めてくれるので楽ですね。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    y=df["Y"].values,
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
       lam      9.61      0.45      9.63      8.89     10.36    695.51      1.00

Number of divergences: 0

混合正規分布

今回の実装とは異なりますが、公式のTutorialにもあります。

https://num.pyro.ai/en/latest/tutorials/gmm.html#Training-a-MAP-estimator

NumPyroの混合分布

NumPyroでは混合分布用のクラスが準備されています。MixtureSameFamilyMixtureGeneraの2つが存在しますが、今回は2つの同じ正規分布を混ぜることを想定するので前者を使います。

MixtureSameFamilyの使い方としては、どの分布にどれだけ属すのかの確率を表す分布のmixing_distと混合する分布のcomponent_distを入力にします。

# どの分布にどれだけ属すのかの確率
mixing_dist = dist.Categorical(probs=jnp.ones(2) / 2.)
# 2つの正規分布を定義
component_dist = dist.Normal(loc=jnp.array([0, 5]), scale=jnp.ones(2))

mixture = dist.MixtureSameFamily(mixing_dist, component_dist)

samples = mixture.sample(random.PRNGKey(0), (5000,))
plt.hist(samples)

データの準備

アヒル本のchap11のデータ(chap11/input/data-mix1.txt)を使用します。

df = pd.read_csv(".//RStanBook/chap11/input/data-mix1.txt")
plt.hist(df["Y"])

モデルの定義

同じ分布の場合パラメータが識別不可になるため計算が発散します。そのため、以下のように2つの分布に上下関係を作る必要があります。実装は以下の2通りありますが、dist.ImproperUniformを使わない分前者の方が良さそうです。

dist.TransformDistributionを使用

def model1(y=None):
    K = 2
    
    a = numpyro.sample("a", dist.Dirichlet(jnp.ones(K)))
    # muをmu = numpyro.sample("mu", dist.Normal(jnp.zeros(K), jnp.ones(K)*100))などすると識別不可能でInfが出る
    mu = numpyro.sample("mu", dist.TransformedDistribution(dist.Normal(0, 10).expand([K]), dist.transforms.OrderedTransform()))
    scale = numpyro.sample("scale", dist.Uniform(0, 10).expand([K]).to_event(1))
    
    with numpyro.plate("N", len(y)):
        mixing_dist = dist.Categorical(probs=a)
        component_dist = dist.Normal(loc=mu, scale=scale)
        numpyro.sample("obs", dist.MixtureSameFamily(mixing_dist, component_dist), obs=y)

dist.ImproperUniformとconstraints.greater_thanを使用

def model2(y=None):
    K = 2
    
    a = numpyro.sample("a", dist.Dirichlet(jnp.ones(K)))
    # muをmu = numpyro.sample("mu", dist.Normal(jnp.zeros(K), jnp.ones(K)*100))などすると識別不可能でInfが出る
    mu_inner = numpyro.sample("mu_inner", dist.Normal(0, 1))
    mu_outer = numpyro.sample("mu_outer", dist.ImproperUniform(constraints.greater_than(mu_inner), (), ()))
    scale = numpyro.sample("scale", dist.Uniform(0, 10).expand([K]).to_event(1))
    
    with numpyro.plate("N", len(y)):
        mixing_dist = dist.Categorical(probs=a)
        component_dist = dist.Normal(loc=jnp.array([mu_inner, mu_outer]), scale=scale)
        numpyro.sample("obs", dist.MixtureSameFamily(mixing_dist, component_dist), obs=y)

MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model1)
mcmc = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    y=df["Y"].values,
)
mcmc.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      a[0]      0.54      0.08      0.55      0.42      0.66    865.86      1.00
      a[1]      0.46      0.08      0.45      0.34      0.58    865.86      1.00
     mu[0]     -0.05      0.24     -0.05     -0.48      0.30   1576.76      1.00
     mu[1]      5.71      0.67      5.77      4.64      6.80    645.56      1.00
  scale[0]      1.30      0.19      1.28      1.01      1.62    949.62      1.00
  scale[1]      2.48      0.48      2.41      1.71      3.26    821.65      1.00

Number of divergences: 0

ゼロ過剰ポアソン分布

データの準備

アヒル本のchap11のデータ(chap11/input/data-ZIP.txt)を使用します。

df = pd.read_csv("./RStanBook/chap11/input/data-ZIP.txt")
df["Age"] = df["Age"]/10.
df.head()
	Sex	Sake	Age	Y
0	0	1	1.8	5
1	1	0	1.8	2
2	1	1	1.8	1
3	0	0	1.9	3
4	0	0	1.9	5

モデルの定義

NumPyroがデフォルトでdist.ZeroInflatedPoissonとしてサポートしているので、以下のように使用すれば実行できます。

import jax.scipy as jsp

def model(X, y=None):
    D = X.shape[1]
    
    b1 = numpyro.sample("b1", dist.Normal(0, 1).expand([D]))
    b2 = numpyro.sample("b2", dist.Normal(0, 1).expand([D]))
    
    # jsp.special.expit : 高速なシグモイド関数
    q = jsp.special.expit(jnp.dot(X, b1))
    lam = jnp.exp(jnp.dot(X, b2))
    
    with numpyro.plate("N", len(y)):
        numpyro.sample("obs", dist.ZeroInflatedPoisson(gate=q, rate=lam), obs=y)

MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df[["Sex", "Sake", "Age"]].values,
    y=df["Y"].values,
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     b1[0]     -1.35      0.36     -1.35     -1.96     -0.77   1546.72      1.00
     b1[1]     -2.25      0.44     -2.23     -2.96     -1.55   1973.49      1.00
     b1[2]      0.08      0.06      0.08     -0.01      0.17   1480.48      1.00
     b2[0]     -0.55      0.08     -0.55     -0.68     -0.41   1492.57      1.00
     b2[1]     -0.06      0.08     -0.06     -0.18      0.08   1717.80      1.00
     b2[2]      0.51      0.01      0.51      0.49      0.53   1425.44      1.00

Number of divergences: 0

最後に

以上で「離散潜在変数の扱い方」は終わりです。enumeratenumpyro.factorの使い方は理解できたでしょうか?
次は「再パラメータ化」を扱います。

Discussion