Zenn
📝

NumPyro:再パラメータ化

2023/04/22に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

複雑なモデルや変数間に強い相関がある場合などでは結果が収束しないことがあります。収束がしづらいときは「MCMC実行時のパラメータを変える」「弱事前情報分布を指定する」「モデルの構造を変える」などで対応することもありますが、対応方法の1つとして再パラメータ化があります。言葉での説明が難しいので、有名なNealの漏斗で実装を確認してみます。

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer import init_to_feasible
from numpyro.infer.util import initialize_model
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam, TransformReparam


import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

Nealの漏斗

うまくいかない例

まずはうまくサンプリングができない例を見ていきます。以下の式を考えますが、この場合xがyに強く依存しているため、うまくMCMCではサンプリングできないという結果になります。

yNormal(0,3)x[n]Normal(0,exp(y/2)) y \sim Normal(0, 3) \\ x[n] \sim Normal(0, exp(y / 2))
def model1(dim=10):
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))

結果を見ると漏斗の形になっておらず、先端部分がうまくサンプリングできていないことがわかります。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model1, adapt_step_size=True) # init_to_feasibleにするとN_effが増える
mcmc1 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc1.run(rng_key=rng_key, dim=1000)
samples1 = mcmc1.get_samples()

# plot
plt.scatter(samples1["x"][:, 0], samples1["y"], c="blue")

再パラメータ化

こういったケースでは再パラメータ化がうまくいくことがあります。再パラメータ化はrawパラメータ(Normal(0, 1)など)を考えてそれに元の分布を表すスケールをかけることで元の分布を表現します。こういった変換を行うことでサンプリング効率が良くなるというものです。式として考えると以下の通りです。

yNormal(0,3)xdecentered[n]Normal(0,1)x=exp(y/2)xdecentered y \sim Normal(0, 3) \\ x_{decentered}[n] \sim Normal(0, 1) \\ x = \exp(y / 2) * x_{decentered}

実際に実装をいくつか見ていきましょう。

手動で再パラメータ化する場合

手作業で再パラメータ化する場合は以下のようになります。式のままの実装ですね。numpyro.deterministicを使用することで、xを保存しています。
手動で書く場合は実装がごちゃつくのが難点ですね。

def model2(dim=10):
    y = numpyro.sample("y", dist.Normal(0, 3))
    x_decentered = numpyro.sample("x_decentered", dist.Normal(jnp.zeros(dim - 1), jnp.ones(dim - 1)))
    x = numpyro.deterministic("x", jnp.exp(y / 2)*x_decentered)

LocScaleReparamを使用する場合

numpyro.infer.reparam.LocScaleReparamを使用することで上記の操作と同様の処理を行うことができます。numpyro.infer.reparam.LocScaleReparamの引数をcentered=0とした際の挙動としては以下のようになります。今回の場合、Baseの分布はdist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2))になります。

xdecenteredNormal(0,1)delta=xdecenteredvalue=(Baseの分布の平均)+(Baseの分布のスケール)delta x_{decentered} \sim Normal(0, 1) \\ delta = x_{decentered} \\ value = (Baseの分布の平均) + (Baseの分布のスケール)*delta
def model3(dim=10):
    y = numpyro.sample("y", dist.Normal(0, 3))
    with numpyro.handlers.reparam(config={"x": LocScaleReparam(centered=0)}):
        numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))

LocScaleReparamをmodel関数外に書く場合

上記の記述はmodel関数内部に再パラメータ化のコードを書いていましたが、以下のように定義したmodel関数の外側で再パラメータ化の処理を加えたmodel関数を作ることもできます。

reparam_model = reparam(model1, config={"x": LocScaleReparam(centered=0)})

TransformedDistributionを使用する場合

以前基本の操作のところで扱ったTransformedDistributionを使用しても同様の操作を実装できます。これはベースとなる分布(dist.Normal(jnp.zeros(dim - 1), jnp.ones(dim - 1)),)を作成した後にAffine変換でスケールをjnp.exp(y / 2)だけ変換してます。
このとき、with numpyro.handlers.reparam(config={"x": TransformReparam()}):は特になくてもいいのですが、このコンテキストを使用することでベースとなっている分布(Normal)のサンプリング結果もx_baseとして得ることができます。

def model4(dim=10):
    y = numpyro.sample("y", dist.Normal(0, 3))
    # with numpyro.handlers.reparam(config={"x": TransformReparam()}):
    # があることで、baseの分布(Normal)のサンプリング結果もx_baseとして得ることができる
    with numpyro.handlers.reparam(config={"x": TransformReparam()}):
        numpyro.sample(
            "x", 
            dist.TransformedDistribution(
                dist.Normal(jnp.zeros(dim - 1), jnp.ones(dim - 1)),
                dist.transforms.AffineTransform(0.0, jnp.exp(y / 2))
            )
        )

MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model1, adapt_step_size=True) # init_to_feasibleにするとN_effが増える
mcmc1 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc1.run(rng_key=rng_key, dim=1000)

# NUTSでMCMCを実行する
kernel = NUTS(model2, adapt_step_size=True) # init_to_feasibleにするとN_effが増える
mcmc2 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc2.run(rng_key=rng_key, dim=1000)

# NUTSでMCMCを実行する
kernel = NUTS(model3, adapt_step_size=True) # init_to_feasibleにするとN_effが増える
mcmc3 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc3.run(rng_key=rng_key, dim=1000)

# NUTSでMCMCを実行する
kernel = NUTS(model4, adapt_step_size=True) # init_to_feasibleにするとN_effが増える
mcmc4 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc4.run(rng_key=rng_key, dim=1000)

# NUTSでMCMCを実行する
kernel = NUTS(reparam_model, adapt_step_size=True) # init_to_feasibleにするとN_effが増える
mcmc5 = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc5.run(rng_key=rng_key, dim=1000)

samples1 = mcmc1.get_samples()
samples2 = mcmc2.get_samples()
samples3 = mcmc3.get_samples()
samples4 = mcmc4.get_samples()
samples5 = mcmc5.get_samples()

reparamを使用すると以下のように、x_decenteredの結果も得ることができます。

print(samples3.keys())
print(samples5.keys())
dict_keys(['x', 'x_decentered', 'y'])
dict_keys(['x', 'x_decentered', 'y'])

TransformReparamを使用する場合、以下のようにx_baseとして元の分布のサンプリング結果を得ることができます。

print(samples4.keys())
dict_keys(['x', 'x_base', 'y'])

全ての結果を可視化します。潰れてしまっていますが、再パラメータ化すると漏斗の先端部分までうまくサンプリングできたことがわかります。

plt.scatter(samples2["x"][:, 0], samples2["y"], c="red", alpha=0.2)
plt.scatter(samples3["x"][:, 0], samples3["y"], c="green", alpha=0.2)
plt.scatter(samples4["x"][:, 0], samples4["y"], c="orange", alpha=0.2)
plt.scatter(samples5["x"][:, 0], samples5["y"], c="yellow", alpha=0.2)
plt.scatter(samples1["x"][:, 0], samples1["y"], c="blue")

最後に

以上で「再パラメータ化」は終わりです。再パラメータ化はうまくいく場合もありますがうまくいかない場合もあるので、階層モデルのような複雑なモデルでもreparam_model = reparam(model, config={name: LocScaleReparam(centered=0)})のようにするだけで再パラメータ化することができ気軽に試行錯誤できるのはうれしいですね。次は「順序回帰と独自の分布の定義」です。

Discussion

ログインするとコメントできます