🦁

2023/04/25に公開

# ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import make_regression

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam, TransformReparam

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)


# デモデータの準備

x, y, true_coef = make_regression(random_state=12,
n_samples=100,
n_features=1000,
n_informative=7,
noise=10.0,
bias=0.0,
coef=True)

df = pd.DataFrame(x)
df['y'] = y

print(true_coef[np.abs(true_coef)>0])

[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
59.46697217]


# Bayesian Lasso

こちらを参考にしました。

Laplace分布は原点上で尖ったような確率密度を持つ分布です。

d = dist.Laplace(0, 1)
samples = d.sample(random.PRNGKey(0), sample_shape=(1000,))
sns.displot(samples)


## モデルの定義

ベータの事前分布として0を原点としたラプラス分布を採用しています。その他は重回帰の例と同じです。

def model_lasso(X, y=None):
alpha = numpyro.sample("alpha", dist.Normal(0, 100))
b = numpyro.sample("b", dist.HalfNormal(10))
sigma = numpyro.sample("sigma", dist.HalfNormal(10))

with numpyro.plate("K", X.shape[1]):
beta = numpyro.sample("beta", dist.Laplace(0, b))

mu = numpyro.deterministic("mu", jnp.dot(X, beta) + alpha)

with numpyro.plate("N", len(X)):
numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)


## モデルのレンダリング

numpyro.render_model(
model=model_lasso,
model_kwargs={"X": df.iloc[:, :-1].values, "y": df["y"].values},
render_params=True,
render_distributions=True
)


## MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_lasso)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df.iloc[:, :-1].values,
y=df["y"].values,
)
mcmc.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
alpha    -14.44     12.68    -14.37    -33.47      7.50    662.98      1.00
b      2.94      0.26      2.94      2.54      3.37     72.60      1.03
beta[0]     -0.37      3.66     -0.17     -6.26      5.33   1951.85      1.00
beta[1]     -1.63      4.29     -0.98     -8.67      4.56   1407.75      1.00
beta[2]     -0.66      3.57     -0.37     -6.17      5.55   1665.70      1.00
beta[3]      1.14      4.28      0.60     -5.58      8.19   1265.37      1.00
beta[4]     -0.92      3.93     -0.56     -8.14      4.60   1214.64      1.00
beta[5]      1.30      3.95      0.76     -5.06      7.74   1462.59      1.00
beta[6]     -1.83      4.22     -1.10     -9.67      3.70    893.02      1.00
beta[7]      1.23      3.84      0.76     -4.91      7.37   1278.04      1.00
beta[8]     -2.19      4.27     -1.45     -8.93      4.61   1265.98      1.00
beta[9]      0.86      3.82      0.48     -5.50      7.30   1619.54      1.00
beta[10]      0.66      3.81      0.25     -4.93      6.96   1120.14      1.00
beta[11]     -0.13      3.66     -0.07     -5.96      5.90   1547.37      1.00
beta[12]      0.99      3.62      0.63     -4.37      7.15   1589.32      1.00
beta[13]     -0.85      3.84     -0.60     -6.70      5.42    939.16      1.00
beta[14]     -0.39      3.93     -0.17     -7.05      5.50   1867.71      1.00
beta[15]      0.25      3.67      0.12     -6.21      5.73   1762.83      1.00
beta[16]      0.07      3.56      0.06     -5.67      5.83   2110.80      1.00
beta[17]     -0.47      3.78     -0.21     -6.90      5.23   1564.20      1.00
beta[18]     -0.20      3.42     -0.09     -6.23      5.12   1945.12      1.00
beta[19]     -1.22      3.97     -0.67     -7.74      5.06   2029.20      1.00
beta[20]      0.74      4.06      0.32     -5.86      7.12   1499.98      1.00
...
beta[999]     -1.07      3.86     -0.56     -7.39      4.86   1331.80      1.00
sigma      8.78      4.68      8.22      1.61     15.35     20.15      1.00

Number of divergences: 0


## 結果の確認

samples = mcmc.get_samples()

print(samples["beta"].mean(axis=0)[np.abs(true_coef) > 0])
print(true_coef[np.abs(true_coef) > 0])

[ 5.583676    1.8187047  24.921148   15.995176    0.89190835 39.011353
19.61189   ]
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
59.46697217]


# 馬蹄事前分布を使用した回帰モデル

Bayesian Lassoの事後中央値は0の値とならないことが欠点としてある(参考)ため、馬蹄事前分布を使用した回帰モデルを実装します。こちらスライド, numpyroのドキュメントを参考にしました。

y \sim Normal(mu, \sigma) \\ mu = \beta x + intercept \\ \beta_j \sim Normal(0, \lambda_j^2 \tau^2) \\ \lambda_j \sim HalfCauchy(0, 1)

## モデルの定義とMCMC

def model_horseshoe(X, y=None):
D_X = X.shape[1]

# sample from horseshoe prior
lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(D_X)))
tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

betas = numpyro.sample("betas", dist.Normal(jnp.zeros(D_X), tau*lambdas))

# compute mean function using linear coefficients
mean_function = jnp.dot(X, betas)

prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
sigma_obs = 1.0 / jnp.sqrt(prec_obs)

# observe data
numpyro.sample("Y", dist.Normal(mean_function, sigma_obs), obs=y)


n_effやr_hatを見ると収束していないことが確認できます。そのため、次は再パラメータ化を試します。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_horseshoe)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df.iloc[:, :-1].values,
y=df["y"].values,
)
mcmc.print_summary()

                  mean       std    median      5.0%     95.0%     n_eff     r_hat
betas[0]     -0.02      0.06      0.00     -0.09      0.05     25.25      1.10
betas[1]     -0.00      0.04     -0.00     -0.04      0.03    152.23      1.01
betas[2]     -0.04      0.10     -0.01     -0.09      0.05     14.40      1.19
betas[3]     -0.00      0.01     -0.00     -0.02      0.02     25.59      1.05
betas[4]     -0.07      0.09     -0.01     -0.18      0.03      3.36      2.16
betas[5]     -0.01      0.04     -0.00     -0.05      0.06     21.46      1.11
betas[6]      0.01      0.03     -0.00     -0.03      0.02     88.96      1.03
betas[7]      0.00      0.02      0.00     -0.02      0.03     52.17      1.00
betas[8]      0.00      0.01      0.00     -0.02      0.02     54.11      1.08
betas[9]      0.00      0.04      0.00     -0.06      0.04     41.61      1.01
betas[10]      0.00      0.01      0.00     -0.02      0.01    321.02      1.00
betas[11]      0.00      0.02     -0.00     -0.03      0.02     53.88      1.00
betas[12]      0.00      0.03      0.00     -0.04      0.04    229.37      1.00
betas[13]      0.00      0.02      0.00     -0.03      0.04     48.81      1.10
betas[14]     -0.00      0.01     -0.00     -0.02      0.02    136.49      1.02
betas[15]      0.03      0.07      0.01     -0.07      0.18      8.90      1.00
betas[16]     -0.00      0.01     -0.00     -0.02      0.02     86.24      1.03
betas[17]      0.00      0.01      0.00     -0.02      0.03     87.96      1.01
betas[18]     -0.00      0.02     -0.00     -0.03      0.02     93.97      1.02
betas[19]      0.00      0.04      0.00     -0.05      0.03     82.88      1.06
betas[20]     -0.01      0.04      0.00     -0.08      0.02     43.25      1.01
betas[21]     -0.00      0.02      0.00     -0.03      0.03    176.26      1.00
betas[22]     -0.02      0.04     -0.00     -0.09      0.01      4.94      1.50
...
prec_obs      0.00      0.00      0.00      0.00      0.00     15.05      1.08
tau[0]      0.01      0.00      0.01      0.01      0.01      9.68      1.21

Number of divergences: 970


samples = mcmc.get_samples()

print(samples["betas"].mean(axis=0)[np.abs(true_coef) > 0])
print(true_coef[np.abs(true_coef) > 0])

[-2.7003107e-03 -4.6002562e-03 -2.8566900e-03  4.5012035e-03
6.8844520e-02  1.5965151e-03  5.0731182e+01]
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
59.46697217]


def reparam_model_horseshoe(X, y=None):
D_X = X.shape[1]

# sample from horseshoe prior
lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(D_X)))
tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

# 再パラメータ化
with numpyro.handlers.reparam(config={"x": TransformReparam()}):
betas = numpyro.sample(
"betas",
dist.TransformedDistribution(
dist.Normal(jnp.zeros(D_X), jnp.ones(D_X)),
dist.transforms.AffineTransform(0.0, tau*lambdas)
)
)

# compute mean function using linear coefficients
mean_function = jnp.dot(X, betas)

prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
sigma_obs = 1.0 / jnp.sqrt(prec_obs)

# observe data
numpyro.sample("Y", dist.Normal(mean_function, sigma_obs), obs=y)


# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(reparam_model_horseshoe)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df.iloc[:, :-1].values,
y=df["y"].values,
)
mcmc.print_summary()

                  mean       std    median      5.0%     95.0%     n_eff     r_hat
betas[0]      0.00      0.01      0.00     -0.01      0.02    100.89      1.00
betas[1]     -0.00      0.01     -0.00     -0.02      0.02     35.94      1.07
betas[2]     -0.00      0.01     -0.00     -0.02      0.02     96.75      1.03
betas[3]     -0.00      0.01     -0.00     -0.01      0.01    163.45      1.00
betas[4]     -0.00      0.05     -0.00     -0.02      0.03     87.50      1.00
betas[5]     -0.02      0.05     -0.00     -0.09      0.03     11.75      1.08
betas[6]      0.00      0.01      0.00     -0.01      0.01     75.07      1.02
betas[7]      0.00      0.01      0.00     -0.01      0.01    357.82      1.00
betas[8]     -0.00      0.01      0.00     -0.02      0.01    378.84      1.01
betas[9]     -0.01      0.04     -0.00     -0.04      0.04     12.07      1.12
betas[10]     -0.00      0.01     -0.00     -0.02      0.02     94.46      1.00
betas[11]     -0.02      0.02     -0.02     -0.06      0.00      4.18      1.18
betas[12]      0.00      0.02      0.00     -0.03      0.04     20.02      1.09
betas[13]      0.00      0.01      0.00     -0.01      0.01    293.32      1.00
betas[14]      0.00      0.01      0.00     -0.02      0.02     68.68      1.04
betas[15]     -0.00      0.03      0.00     -0.08      0.03     19.21      1.02
betas[16]     -0.01      0.03     -0.00     -0.08      0.02      5.79      1.26
betas[17]      0.00      0.03      0.00     -0.03      0.03     67.48      1.03
betas[18]      0.00      0.02     -0.00     -0.03      0.03     11.10      1.17
betas[19]      0.00      0.01      0.00     -0.01      0.02    122.03      1.01
betas[20]      0.00      0.01     -0.00     -0.01      0.01    330.34      1.00
betas[21]     -0.00      0.02     -0.00     -0.02      0.02     50.76      1.01
betas[22]     -0.00      0.01     -0.00     -0.01      0.01     80.41      1.00
...
prec_obs      0.00      0.00      0.00      0.00      0.00     26.99      1.00
tau[0]      0.00      0.00      0.00      0.00      0.01     18.63      1.04

Number of divergences: 1261


# 正則化付き馬蹄事前分布を使用した回帰モデル

## モデルの定義

NumPyroの実装として新しい点はdist.FoldedDistributionです。既存のクラスにdist.HalfNormaldist.HalfCohcyはありますがそれ以外の分布はないので、dist.FoldedDistribution(baseのdist)のように定義し使用します。この点以外は実装で新しいところはありません。

• scale_icept : 切片の正規分布におけるスケール
• scale_global : global shrinkage parameterのstudent-t分布のスケール
• nu_local : local shrinkage parametersのStudent-t分布の自由度
• nu_global :　global shrinkage parametersのStudent-t分布の自由度
• slab_df : student-t slabの自由度
• slab_scale : Student-t slabのスケール
• par_ratio : ゼロの係数に対する非ゼロ係数の割合
def regularized_horseshoe(X, y=None, scale_icept=10, scale_global=1, nu_local=1, nu_global=1, slab_df=4, slab_scale=2, par_ratio=None):
# パラメータはこれを参照 https://rdrr.io/cran/brms/man/horseshoe.html
# Appendix C : https://projecteuclid.org/journals/electronic-journal-of-statistics/volume-11/issue-2/Sparsity-information-and-regularization-in-the-horseshoe-and-other-shrinkage/10.1214/17-EJS1337SI.full
# https://discourse.julialang.org/t/regularized-horseshoe-prior/71599/2

if par_ratio is not None:
scale_global = par_ratio / np.sqrt(len(X))

D_X = X.shape[1]

# unscaled_betas
z = numpyro.sample("z", dist.Normal(jnp.zeros(D_X), jnp.ones(D_X)))
# local shrinkage parameter
lambdas = numpyro.sample("lambdas", dist.FoldedDistribution(dist.StudentT(nu_local, 0, 1)).expand([D_X]))
# noise std
sigma = numpyro.sample("sigma", dist.Exponential(1))
# global shrinkage parameter
tau = numpyro.sample("tau", dist.FoldedDistribution(dist.StudentT(nu_global, 0, scale_global*sigma)))
# slab degrees of freedom for the regularized horseshoe
caux = numpyro.sample("caux", dist.InverseGamma(0.5*slab_df, 0.5*slab_df))
# intercept
beta0 = numpyro.sample("beta0", dist.Normal(0, scale_icept))

# slab scale
c = numpyro.deterministic("c", slab_scale*jnp.sqrt(caux))
# truncated local shrinkage parameter
lambda_tilde = numpyro.deterministic("lambda_tilde", jnp.sqrt(c**2*jnp.square(lambdas) / (c**2 + tau**2*jnp.square(lambdas))))
# beta
beta = numpyro.deterministic("beta", z*lambda_tilde*tau)

# mean function
f = numpyro.deterministic("f", beta0 + jnp.dot(X, beta))

with numpyro.plate("N", len(X)):
numpyro.sample("obs", dist.Normal(f, sigma), obs=y)


## MCMC

オリジナルのhorseshoeに比べ少し収束している感じも気持ちありますがまだN_effやR_hatを見ると収束してないですね。また、正則化付き馬蹄事前分布は傾向としてNumber of divergencesも少ないことが知られていますが、今回はそれでも少し多かったのでtarget_accept_probをデフォルトの0.8から0.9へ変更しています。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(regularized_horseshoe, target_accept_prob=0.9)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df.iloc[:, :-1].values,
y=df["y"].values,
)
mcmc.print_summary()

                  mean       std    median      5.0%     95.0%     n_eff     r_hat
beta0      1.20      0.51      1.15      0.47      2.21      9.62      1.00
caux    836.78    368.27    774.34    335.49   1331.69     13.21      1.21
lambdas[0]      2.54      3.56      1.11      0.02      7.07      9.65      1.21
lambdas[1]      1.69      2.34      0.82      0.01      5.34      9.84      1.11
lambdas[2]      1.46      1.40      1.08      0.02      3.38     39.71      1.00
lambdas[3]      2.38      3.20      1.20      0.04      5.59     10.77      1.18
lambdas[4]      1.79      2.91      0.82      0.01      4.37     32.66      1.06
lambdas[5]      1.72      1.98      1.12      0.13      3.53     51.05      1.00
lambdas[6]      1.72      1.80      1.08      0.06      4.03     25.36      1.10
lambdas[7]      0.74      0.70      0.48      0.02      1.70     11.24      1.02
lambdas[8]      1.69      2.37      1.01      0.09      3.50     28.32      1.07
lambdas[9]      2.27      4.12      0.57      0.00      7.06     19.25      1.01
lambdas[10]      0.96      1.09      0.51      0.04      2.40     17.42      1.00
lambdas[11]      1.44      2.97      0.69      0.01      3.10     37.59      1.03
lambdas[12]      2.64      3.38      1.38      0.02      6.85     12.09      1.24
lambdas[13]      0.78      0.71      0.67      0.00      1.79     15.80      1.07
lambdas[14]      1.79      2.48      0.93      0.10      4.63     41.14      1.01
lambdas[15]      5.11      9.00      1.27      0.02     14.57      9.80      1.11
lambdas[16]      1.39      1.70      0.91      0.10      2.83     32.90      1.01
lambdas[17]      0.41      0.48      0.25      0.00      0.95     34.87      1.07
lambdas[18]      2.27      3.56      0.87      0.06      6.70     17.80      1.14
lambdas[19]      1.63      2.51      0.82      0.02      4.00     23.02      1.06
lambdas[20]      1.65      2.13      0.97      0.02      4.12     12.47      1.17
...
z[998]     -0.15      0.89     -0.21     -1.55      1.31      5.59      1.48
z[999]      0.47      1.06      0.41     -1.24      1.87     10.09      1.32

Number of divergences: 11


samples = mcmc.get_samples()

print(samples["beta"].mean(axis=0)[np.abs(true_coef) > 0])
print(true_coef[np.abs(true_coef) > 0])

[26.861206  12.739779  67.2756    69.541016   3.4015157 76.7576
58.429512 ]
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
59.46697217]


numpyro2az = az.from_numpyro(mcmc)
az.plot_forest(numpyro2az, var_names=["beta"], figsize=(8,4));


## MCMC par_ratioを指定

# Xの次元数
D=1000
# 非ゼロの係数の個数
p0=10
par_ratio = p0/(D-p0)

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(regularized_horseshoe, target_accept_prob=0.9)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df.iloc[:, :-1].values,
y=df["y"].values,
par_ratio=par_ratio
)
mcmc.print_summary()

                  mean       std    median      5.0%     95.0%     n_eff     r_hat
beta0      0.77      0.79      0.75     -0.50      2.09    249.31      1.00
caux   1032.52    996.18    775.49    210.42   1861.12    445.57      1.00
lambdas[0]      2.55      8.49      0.91      0.00      5.14   1057.49      1.00
lambdas[1]      2.19      4.29      0.94      0.00      4.91    894.40      1.00
lambdas[2]      2.60      7.04      0.97      0.00      5.70   1145.11      1.00
lambdas[3]      2.11      4.64      0.91      0.00      4.71    879.15      1.00
lambdas[4]      2.99      8.28      1.03      0.00      6.47    787.06      1.00
lambdas[5]      3.89     13.76      1.09      0.00      7.50    943.71      1.00
lambdas[6]      2.30      5.51      0.94      0.00      5.18   1129.57      1.00
lambdas[7]      2.36      5.87      0.96      0.00      4.82    983.58      1.00
lambdas[8]      4.17     26.52      1.05      0.00      6.55    700.98      1.00
lambdas[9]      3.02      8.22      0.99      0.00      6.37    611.91      1.00
lambdas[10]      2.91     10.19      0.95      0.00      5.62    875.64      1.00
lambdas[11]      2.88      6.67      1.01      0.00      6.16    491.22      1.00
lambdas[12]      3.73     11.99      1.08      0.00      7.33    801.54      1.00
lambdas[13]      2.31      4.59      0.91      0.00      5.44   1082.74      1.00
lambdas[14]      2.53      5.13      1.00      0.00      5.72    634.83      1.00
lambdas[15]      2.87      7.23      1.02      0.00      6.29    903.70      1.00
lambdas[16]     16.10     38.02      2.09      0.00     44.98    221.81      1.00
lambdas[17]      2.31      5.35      0.95      0.00      5.08    904.12      1.00
lambdas[18]      2.22      4.61      0.94      0.00      4.99   1176.63      1.00
lambdas[19]      3.84     19.86      1.04      0.00      6.18   1069.36      1.00
lambdas[20]      2.54      5.68      0.95      0.00      5.82    856.07      1.00
...
z[998]     -0.08      1.00     -0.06     -1.88      1.44    948.18      1.00
z[999]      0.02      0.99     -0.00     -1.51      1.72    882.59      1.00

Number of divergences: 50


samples = mcmc.get_samples()

print(samples["beta"].mean(axis=0)[np.abs(true_coef) > 0])
print(true_coef[np.abs(true_coef) > 0])

[26.88424   13.017417  66.96016   69.43479    3.9294827 77.563736
58.247986 ]
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
59.46697217]


numpyro2az = az.from_numpyro(mcmc)
az.plot_forest(numpyro2az, var_names=["beta"], figsize=(8,4));