🦁

NumPyro:スパースモデル

2023/04/25に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

今回はスパースモデルとして、Bayesian Lasso回帰と馬蹄事前分布を使用した回帰モデル、正則化つき馬蹄事前分布を使用した回帰モデルを扱います。

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import make_regression

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam, TransformReparam

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

デモデータの準備

今回は正則化つき馬蹄事前分布の効果が発揮された例を紹介したいので、データ数N=100, 次元数P=1000,情報をもつ変数P0=7のN<<Pの例をみます。

x, y, true_coef = make_regression(random_state=12, 
                       n_samples=100, 
                       n_features=1000,
                       n_informative=7,
                       noise=10.0,
                       bias=0.0,
                       coef=True)

df = pd.DataFrame(x)
df['y'] = y

print(true_coef[np.abs(true_coef)>0])
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
 59.46697217]

Bayesian Lasso

こちらを参考にしました。
回帰係数の事前分布として以下のLaplace(二重指数)分布を採用することで実装できます。
Laplace分布は原点上で尖ったような確率密度を持つ分布です。

d = dist.Laplace(0, 1)
samples = d.sample(random.PRNGKey(0), sample_shape=(1000,))
sns.displot(samples)

モデルの定義

ベータの事前分布として0を原点としたラプラス分布を採用しています。その他は重回帰の例と同じです。

def model_lasso(X, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0, 100))
    b = numpyro.sample("b", dist.HalfNormal(10))
    sigma = numpyro.sample("sigma", dist.HalfNormal(10))
    
    with numpyro.plate("K", X.shape[1]):
        beta = numpyro.sample("beta", dist.Laplace(0, b))
    
    mu = numpyro.deterministic("mu", jnp.dot(X, beta) + alpha)
    
    with numpyro.plate("N", len(X)):
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

モデルのレンダリング

numpyro.render_model(
    model=model_lasso, 
    model_kwargs={"X": df.iloc[:, :-1].values, "y": df["y"].values}, 
    render_params=True, 
    render_distributions=True
)

MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_lasso)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df.iloc[:, :-1].values,
    y=df["y"].values,
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     alpha    -14.44     12.68    -14.37    -33.47      7.50    662.98      1.00
         b      2.94      0.26      2.94      2.54      3.37     72.60      1.03
   beta[0]     -0.37      3.66     -0.17     -6.26      5.33   1951.85      1.00
   beta[1]     -1.63      4.29     -0.98     -8.67      4.56   1407.75      1.00
   beta[2]     -0.66      3.57     -0.37     -6.17      5.55   1665.70      1.00
   beta[3]      1.14      4.28      0.60     -5.58      8.19   1265.37      1.00
   beta[4]     -0.92      3.93     -0.56     -8.14      4.60   1214.64      1.00
   beta[5]      1.30      3.95      0.76     -5.06      7.74   1462.59      1.00
   beta[6]     -1.83      4.22     -1.10     -9.67      3.70    893.02      1.00
   beta[7]      1.23      3.84      0.76     -4.91      7.37   1278.04      1.00
   beta[8]     -2.19      4.27     -1.45     -8.93      4.61   1265.98      1.00
   beta[9]      0.86      3.82      0.48     -5.50      7.30   1619.54      1.00
  beta[10]      0.66      3.81      0.25     -4.93      6.96   1120.14      1.00
  beta[11]     -0.13      3.66     -0.07     -5.96      5.90   1547.37      1.00
  beta[12]      0.99      3.62      0.63     -4.37      7.15   1589.32      1.00
  beta[13]     -0.85      3.84     -0.60     -6.70      5.42    939.16      1.00
  beta[14]     -0.39      3.93     -0.17     -7.05      5.50   1867.71      1.00
  beta[15]      0.25      3.67      0.12     -6.21      5.73   1762.83      1.00
  beta[16]      0.07      3.56      0.06     -5.67      5.83   2110.80      1.00
  beta[17]     -0.47      3.78     -0.21     -6.90      5.23   1564.20      1.00
  beta[18]     -0.20      3.42     -0.09     -6.23      5.12   1945.12      1.00
  beta[19]     -1.22      3.97     -0.67     -7.74      5.06   2029.20      1.00
  beta[20]      0.74      4.06      0.32     -5.86      7.12   1499.98      1.00
...
 beta[999]     -1.07      3.86     -0.56     -7.39      4.86   1331.80      1.00
     sigma      8.78      4.68      8.22      1.61     15.35     20.15      1.00

Number of divergences: 0

結果の確認

重みの平均と真の係数を比較します。ほとんど一致してないことがわかります。(データ数100, 次元数200の時はうまく推定できていたのでコードの誤り等はないと思います)

samples = mcmc.get_samples()

print(samples["beta"].mean(axis=0)[np.abs(true_coef) > 0])
print(true_coef[np.abs(true_coef) > 0])
[ 5.583676    1.8187047  24.921148   15.995176    0.89190835 39.011353
 19.61189   ]
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
 59.46697217]

馬蹄事前分布を使用した回帰モデル

Bayesian Lassoの事後中央値は0の値とならないことが欠点としてある(参考)ため、馬蹄事前分布を使用した回帰モデルを実装します。こちらスライド, numpyroのドキュメントを参考にしました。

式としては以下の通りです。ここで全係数に共通のパラメータ\tau\betaを0へ近づけるパラメータで、係数ごとのパラメータ\lambda_jはいくつかの係数を0から遠ざけるパラメータになります。

y \sim Normal(mu, \sigma) \\ mu = \beta x + intercept \\ \beta_j \sim Normal(0, \lambda_j^2 \tau^2) \\ \lambda_j \sim HalfCauchy(0, 1)

モデルの定義とMCMC

再パラメータ化しない場合

def model_horseshoe(X, y=None):
    D_X = X.shape[1]

    # sample from horseshoe prior
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(D_X)))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    betas = numpyro.sample("betas", dist.Normal(jnp.zeros(D_X), tau*lambdas)) 

    # compute mean function using linear coefficients
    mean_function = jnp.dot(X, betas)

    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # observe data
    numpyro.sample("Y", dist.Normal(mean_function, sigma_obs), obs=y)

n_effやr_hatを見ると収束していないことが確認できます。そのため、次は再パラメータ化を試します。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_horseshoe)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df.iloc[:, :-1].values,
    y=df["y"].values,
)
mcmc.print_summary()
                  mean       std    median      5.0%     95.0%     n_eff     r_hat
    betas[0]     -0.02      0.06      0.00     -0.09      0.05     25.25      1.10
    betas[1]     -0.00      0.04     -0.00     -0.04      0.03    152.23      1.01
    betas[2]     -0.04      0.10     -0.01     -0.09      0.05     14.40      1.19
    betas[3]     -0.00      0.01     -0.00     -0.02      0.02     25.59      1.05
    betas[4]     -0.07      0.09     -0.01     -0.18      0.03      3.36      2.16
    betas[5]     -0.01      0.04     -0.00     -0.05      0.06     21.46      1.11
    betas[6]      0.01      0.03     -0.00     -0.03      0.02     88.96      1.03
    betas[7]      0.00      0.02      0.00     -0.02      0.03     52.17      1.00
    betas[8]      0.00      0.01      0.00     -0.02      0.02     54.11      1.08
    betas[9]      0.00      0.04      0.00     -0.06      0.04     41.61      1.01
   betas[10]      0.00      0.01      0.00     -0.02      0.01    321.02      1.00
   betas[11]      0.00      0.02     -0.00     -0.03      0.02     53.88      1.00
   betas[12]      0.00      0.03      0.00     -0.04      0.04    229.37      1.00
   betas[13]      0.00      0.02      0.00     -0.03      0.04     48.81      1.10
   betas[14]     -0.00      0.01     -0.00     -0.02      0.02    136.49      1.02
   betas[15]      0.03      0.07      0.01     -0.07      0.18      8.90      1.00
   betas[16]     -0.00      0.01     -0.00     -0.02      0.02     86.24      1.03
   betas[17]      0.00      0.01      0.00     -0.02      0.03     87.96      1.01
   betas[18]     -0.00      0.02     -0.00     -0.03      0.02     93.97      1.02
   betas[19]      0.00      0.04      0.00     -0.05      0.03     82.88      1.06
   betas[20]     -0.01      0.04      0.00     -0.08      0.02     43.25      1.01
   betas[21]     -0.00      0.02      0.00     -0.03      0.03    176.26      1.00
   betas[22]     -0.02      0.04     -0.00     -0.09      0.01      4.94      1.50
...
    prec_obs      0.00      0.00      0.00      0.00      0.00     15.05      1.08
      tau[0]      0.01      0.00      0.01      0.01      0.01      9.68      1.21

Number of divergences: 970  

全然見当違いの値になっています。

samples = mcmc.get_samples()

print(samples["betas"].mean(axis=0)[np.abs(true_coef) > 0])
print(true_coef[np.abs(true_coef) > 0])
[-2.7003107e-03 -4.6002562e-03 -2.8566900e-03  4.5012035e-03
  6.8844520e-02  1.5965151e-03  5.0731182e+01]
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
 59.46697217]

再パラメータ化する場合

馬蹄事前分布を使用すると収束性が悪くなることがあるため、再パラメータ化を行います。ここでは、dist.TransformedDistributionを使用して変換する方法を採用しています。(reparam_model_horseshoe = reparam(model_horseshoe, config={"betas": LocScaleReparam(centered=0)})やチュートリアルでの手動で書く方法などの別の方法を試したのですがなぜかうまくいきません??)

def reparam_model_horseshoe(X, y=None):
    D_X = X.shape[1]

    # sample from horseshoe prior
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(D_X)))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # 再パラメータ化
    with numpyro.handlers.reparam(config={"x": TransformReparam()}):
        betas = numpyro.sample(
            "betas", 
            dist.TransformedDistribution(
                dist.Normal(jnp.zeros(D_X), jnp.ones(D_X)),
                dist.transforms.AffineTransform(0.0, tau*lambdas)
            )
        )

    # compute mean function using linear coefficients
    mean_function = jnp.dot(X, betas)

    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # observe data
    numpyro.sample("Y", dist.Normal(mean_function, sigma_obs), obs=y)

今回の例ではNumPyroのチュートリアルと同じ再パラメータ化してもうまく推定できていません。ちなみに、データ数100, 次元数200でやった時は多少N_effやR_hatの改善が見られました。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(reparam_model_horseshoe)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df.iloc[:, :-1].values,
    y=df["y"].values,
)
mcmc.print_summary()
                  mean       std    median      5.0%     95.0%     n_eff     r_hat
    betas[0]      0.00      0.01      0.00     -0.01      0.02    100.89      1.00
    betas[1]     -0.00      0.01     -0.00     -0.02      0.02     35.94      1.07
    betas[2]     -0.00      0.01     -0.00     -0.02      0.02     96.75      1.03
    betas[3]     -0.00      0.01     -0.00     -0.01      0.01    163.45      1.00
    betas[4]     -0.00      0.05     -0.00     -0.02      0.03     87.50      1.00
    betas[5]     -0.02      0.05     -0.00     -0.09      0.03     11.75      1.08
    betas[6]      0.00      0.01      0.00     -0.01      0.01     75.07      1.02
    betas[7]      0.00      0.01      0.00     -0.01      0.01    357.82      1.00
    betas[8]     -0.00      0.01      0.00     -0.02      0.01    378.84      1.01
    betas[9]     -0.01      0.04     -0.00     -0.04      0.04     12.07      1.12
   betas[10]     -0.00      0.01     -0.00     -0.02      0.02     94.46      1.00
   betas[11]     -0.02      0.02     -0.02     -0.06      0.00      4.18      1.18
   betas[12]      0.00      0.02      0.00     -0.03      0.04     20.02      1.09
   betas[13]      0.00      0.01      0.00     -0.01      0.01    293.32      1.00
   betas[14]      0.00      0.01      0.00     -0.02      0.02     68.68      1.04
   betas[15]     -0.00      0.03      0.00     -0.08      0.03     19.21      1.02
   betas[16]     -0.01      0.03     -0.00     -0.08      0.02      5.79      1.26
   betas[17]      0.00      0.03      0.00     -0.03      0.03     67.48      1.03
   betas[18]      0.00      0.02     -0.00     -0.03      0.03     11.10      1.17
   betas[19]      0.00      0.01      0.00     -0.01      0.02    122.03      1.01
   betas[20]      0.00      0.01     -0.00     -0.01      0.01    330.34      1.00
   betas[21]     -0.00      0.02     -0.00     -0.02      0.02     50.76      1.01
   betas[22]     -0.00      0.01     -0.00     -0.01      0.01     80.41      1.00
...
    prec_obs      0.00      0.00      0.00      0.00      0.00     26.99      1.00
      tau[0]      0.00      0.00      0.00      0.00      0.01     18.63      1.04

Number of divergences: 1261  

正則化付き馬蹄事前分布を使用した回帰モデル

論文で提案されている正則化付き馬蹄事前分布を使用した回帰モデルを実装します。理論の説明は省略しますが、スライドもあるので、こちらの方がイメージつくかもしれません。また、パラメータがいくつかありますが、brmsのデフォルト値を参考にしています。

モデルの定義

NumPyroの実装として新しい点はdist.FoldedDistributionです。既存のクラスにdist.HalfNormaldist.HalfCohcyはありますがそれ以外の分布はないので、dist.FoldedDistribution(baseのdist)のように定義し使用します。この点以外は実装で新しいところはありません。

以下は今回のモデルのパラメータです。par_ratioを指定するとscale_globalが無視されますが、par_ratioを直接指定した方が推奨されているそうです。今回も直接指定した方が良い結果になりました。

  • scale_icept : 切片の正規分布におけるスケール
  • scale_global : global shrinkage parameterのstudent-t分布のスケール
  • nu_local : local shrinkage parametersのStudent-t分布の自由度
  • nu_global : global shrinkage parametersのStudent-t分布の自由度
  • slab_df : student-t slabの自由度
  • slab_scale : Student-t slabのスケール
  • par_ratio : ゼロの係数に対する非ゼロ係数の割合
def regularized_horseshoe(X, y=None, scale_icept=10, scale_global=1, nu_local=1, nu_global=1, slab_df=4, slab_scale=2, par_ratio=None):
    # パラメータはこれを参照 https://rdrr.io/cran/brms/man/horseshoe.html
    # Appendix C : https://projecteuclid.org/journals/electronic-journal-of-statistics/volume-11/issue-2/Sparsity-information-and-regularization-in-the-horseshoe-and-other-shrinkage/10.1214/17-EJS1337SI.full
    # https://discourse.julialang.org/t/regularized-horseshoe-prior/71599/2
    
    if par_ratio is not None:
        scale_global = par_ratio / np.sqrt(len(X))
    
    D_X = X.shape[1]
    
    # unscaled_betas
    z = numpyro.sample("z", dist.Normal(jnp.zeros(D_X), jnp.ones(D_X)))
    # local shrinkage parameter
    lambdas = numpyro.sample("lambdas", dist.FoldedDistribution(dist.StudentT(nu_local, 0, 1)).expand([D_X]))
    # noise std
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    # global shrinkage parameter
    tau = numpyro.sample("tau", dist.FoldedDistribution(dist.StudentT(nu_global, 0, scale_global*sigma)))
    # slab degrees of freedom for the regularized horseshoe
    caux = numpyro.sample("caux", dist.InverseGamma(0.5*slab_df, 0.5*slab_df))
    # intercept
    beta0 = numpyro.sample("beta0", dist.Normal(0, scale_icept)) 
    
    # slab scale
    c = numpyro.deterministic("c", slab_scale*jnp.sqrt(caux))
    # truncated local shrinkage parameter
    lambda_tilde = numpyro.deterministic("lambda_tilde", jnp.sqrt(c**2*jnp.square(lambdas) / (c**2 + tau**2*jnp.square(lambdas))))
    # beta
    beta = numpyro.deterministic("beta", z*lambda_tilde*tau)
    
    # mean function
    f = numpyro.deterministic("f", beta0 + jnp.dot(X, beta))
    
    with numpyro.plate("N", len(X)):
        numpyro.sample("obs", dist.Normal(f, sigma), obs=y)

モデルのレンダリング

MCMC

オリジナルのhorseshoeに比べ少し収束している感じも気持ちありますがまだN_effやR_hatを見ると収束してないですね。また、正則化付き馬蹄事前分布は傾向としてNumber of divergencesも少ないことが知られていますが、今回はそれでも少し多かったのでtarget_accept_probをデフォルトの0.8から0.9へ変更しています。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(regularized_horseshoe, target_accept_prob=0.9)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df.iloc[:, :-1].values,
    y=df["y"].values,
)
mcmc.print_summary()
                  mean       std    median      5.0%     95.0%     n_eff     r_hat
       beta0      1.20      0.51      1.15      0.47      2.21      9.62      1.00
        caux    836.78    368.27    774.34    335.49   1331.69     13.21      1.21
  lambdas[0]      2.54      3.56      1.11      0.02      7.07      9.65      1.21
  lambdas[1]      1.69      2.34      0.82      0.01      5.34      9.84      1.11
  lambdas[2]      1.46      1.40      1.08      0.02      3.38     39.71      1.00
  lambdas[3]      2.38      3.20      1.20      0.04      5.59     10.77      1.18
  lambdas[4]      1.79      2.91      0.82      0.01      4.37     32.66      1.06
  lambdas[5]      1.72      1.98      1.12      0.13      3.53     51.05      1.00
  lambdas[6]      1.72      1.80      1.08      0.06      4.03     25.36      1.10
  lambdas[7]      0.74      0.70      0.48      0.02      1.70     11.24      1.02
  lambdas[8]      1.69      2.37      1.01      0.09      3.50     28.32      1.07
  lambdas[9]      2.27      4.12      0.57      0.00      7.06     19.25      1.01
 lambdas[10]      0.96      1.09      0.51      0.04      2.40     17.42      1.00
 lambdas[11]      1.44      2.97      0.69      0.01      3.10     37.59      1.03
 lambdas[12]      2.64      3.38      1.38      0.02      6.85     12.09      1.24
 lambdas[13]      0.78      0.71      0.67      0.00      1.79     15.80      1.07
 lambdas[14]      1.79      2.48      0.93      0.10      4.63     41.14      1.01
 lambdas[15]      5.11      9.00      1.27      0.02     14.57      9.80      1.11
 lambdas[16]      1.39      1.70      0.91      0.10      2.83     32.90      1.01
 lambdas[17]      0.41      0.48      0.25      0.00      0.95     34.87      1.07
 lambdas[18]      2.27      3.56      0.87      0.06      6.70     17.80      1.14
 lambdas[19]      1.63      2.51      0.82      0.02      4.00     23.02      1.06
 lambdas[20]      1.65      2.13      0.97      0.02      4.12     12.47      1.17
...
      z[998]     -0.15      0.89     -0.21     -1.55      1.31      5.59      1.48
      z[999]      0.47      1.06      0.41     -1.24      1.87     10.09      1.32

Number of divergences: 11

収束は悪かったですが、係数を比較するとかなり近い値になりました。

samples = mcmc.get_samples()

print(samples["beta"].mean(axis=0)[np.abs(true_coef) > 0])
print(true_coef[np.abs(true_coef) > 0])
[26.861206  12.739779  67.2756    69.541016   3.4015157 76.7576
 58.429512 ]
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
 59.46697217]

各係数をプロットしてみます。今回の真の係数の7つ以外にも0から離れている係数が10程度ありそうです。

numpyro2az = az.from_numpyro(mcmc)
az.plot_forest(numpyro2az, var_names=["beta"], figsize=(8,4));

MCMC par_ratioを指定

正則化付き馬蹄事前分布のパラメータとして係数が非0の個数の割合を(なんとなくでもいいので)知識として持っていた場合その情報を以下のようにpar_ratioとして加えることができます。今回は上記の結果を見て適当に10個としてみました。
結果を見るとN_effとR_hat共に改善していることがわかります。ある程度知識を持っているのであれば大雑把にでも指定したほうがよさそうです。

# Xの次元数
D=1000
# 非ゼロの係数の個数
p0=10
par_ratio = p0/(D-p0)

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(regularized_horseshoe, target_accept_prob=0.9)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df.iloc[:, :-1].values,
    y=df["y"].values,
    par_ratio=par_ratio
)
mcmc.print_summary()
                  mean       std    median      5.0%     95.0%     n_eff     r_hat
       beta0      0.77      0.79      0.75     -0.50      2.09    249.31      1.00
        caux   1032.52    996.18    775.49    210.42   1861.12    445.57      1.00
  lambdas[0]      2.55      8.49      0.91      0.00      5.14   1057.49      1.00
  lambdas[1]      2.19      4.29      0.94      0.00      4.91    894.40      1.00
  lambdas[2]      2.60      7.04      0.97      0.00      5.70   1145.11      1.00
  lambdas[3]      2.11      4.64      0.91      0.00      4.71    879.15      1.00
  lambdas[4]      2.99      8.28      1.03      0.00      6.47    787.06      1.00
  lambdas[5]      3.89     13.76      1.09      0.00      7.50    943.71      1.00
  lambdas[6]      2.30      5.51      0.94      0.00      5.18   1129.57      1.00
  lambdas[7]      2.36      5.87      0.96      0.00      4.82    983.58      1.00
  lambdas[8]      4.17     26.52      1.05      0.00      6.55    700.98      1.00
  lambdas[9]      3.02      8.22      0.99      0.00      6.37    611.91      1.00
 lambdas[10]      2.91     10.19      0.95      0.00      5.62    875.64      1.00
 lambdas[11]      2.88      6.67      1.01      0.00      6.16    491.22      1.00
 lambdas[12]      3.73     11.99      1.08      0.00      7.33    801.54      1.00
 lambdas[13]      2.31      4.59      0.91      0.00      5.44   1082.74      1.00
 lambdas[14]      2.53      5.13      1.00      0.00      5.72    634.83      1.00
 lambdas[15]      2.87      7.23      1.02      0.00      6.29    903.70      1.00
 lambdas[16]     16.10     38.02      2.09      0.00     44.98    221.81      1.00
 lambdas[17]      2.31      5.35      0.95      0.00      5.08    904.12      1.00
 lambdas[18]      2.22      4.61      0.94      0.00      4.99   1176.63      1.00
 lambdas[19]      3.84     19.86      1.04      0.00      6.18   1069.36      1.00
 lambdas[20]      2.54      5.68      0.95      0.00      5.82    856.07      1.00
...
      z[998]     -0.08      1.00     -0.06     -1.88      1.44    948.18      1.00
      z[999]      0.02      0.99     -0.00     -1.51      1.72    882.59      1.00

Number of divergences: 50

係数もかなり良く一致しています。

samples = mcmc.get_samples()

print(samples["beta"].mean(axis=0)[np.abs(true_coef) > 0])
print(true_coef[np.abs(true_coef) > 0])
[26.88424   13.017417  66.96016   69.43479    3.9294827 77.563736
 58.247986 ]
[28.63790195 12.72556546 66.80163498 68.72486556  4.45580178 77.34607556
 59.46697217]

先ほどに比べ、真の係数ではない0から離れた係数の数が少なくなっていることがわかります。

numpyro2az = az.from_numpyro(mcmc)
az.plot_forest(numpyro2az, var_names=["beta"], figsize=(8,4));

最後に

今回はスパースモデルとして、Bayesian Lasso回帰と馬蹄事前分布を使用した回帰モデル、正則化つき馬蹄事前分布を使用した回帰モデルを扱いました。今回のようなN<<Pの例では正則化つき馬蹄事前分布を使用した回帰モデルだけがうまく推定できていたので、今後いろいろなケースで活用してみたいと感じる結果でした。次回は「打ち切りデータの扱い方」です。

Discussion