NumPyroでWAICを計算する
概要
WAICはAICを改良して様々な場合でも利用できるようにした情報量基準です。詳しい理論的な背景は考案者である渡辺澄夫氏の以下のWebページなどを参照してください。(2024年には公開が終わるようなので早めに保存したほうがいいかも)
この記事ではJaxをベースに開発された確率プログラミングのフレームワークであるNumPyroを用いてこのWAICを求める方法について説明します。(WAICの理論やNumPyro自体の詳細な解説はしません)
コードはこちらで公開しています。
環境
- Python 3.11
- numpyroとjaxtypingをpipなどでインストール
モデル定義とパラメータ推定
ここでは以下のような単純な線形モデルを例にとります。
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jaxtyping import Array, Float32
def linear_model(X: Float32[Array, "N D"], y: Float32[Array, " N"] | None):
N, D = X.shape
with numpyro.plate("dimension", D):
w = numpyro.sample("w", dist.Normal(0, 1))
sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
z = jnp.dot(X, w) # type: ignore
with numpyro.plate("data", N):
y = numpyro.sample("y", dist.Normal(z, sigma), obs=y) # type: ignore
Hamiltonモンテカルロ法を用いたパラメータのサンプリングは以下のように行います。
import jax
import numpy as np
from jaxtyping import Array, Float, Float32
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import log_likelihood
from scipy.special import logsumexp
def run_mcmc(
model,
X: Float32[Array, "N D"],
y: Float32[Array, " N"],
num_warmup: int = 1000,
num_samples: int = 1000,
num_chains: int = 1,
seed: int = 0,
):
kernel = NUTS(model)
mcmc = MCMC(
kernel,
num_warmup=num_warmup,
num_samples=num_samples,
num_chains=num_chains,
progress_bar=False,
)
rng_key = jax.random.PRNGKey(seed)
mcmc.run(rng_key, X, y)
return mcmc
この関数を実行することでNumPyroのMCMC構造体が得られ、パラメータの事後分布からのサンプルが得られます。
WAICの計算
いよいよWAICの計算です。上記の渡辺氏の資料によるとWAICは以下のように計算することができます。
この計算をするためには事後分布のパラメータのサンプルそれぞれについて尤度numpyro.infer.util.log_likelihood
を使用することでこれを計算することができます。
logp = log_likelihood(model, posterior_samples, X, y)["y"]
これを使用して
まず
b
を受け付けるのでこれを使用すればいいです。
まとめると以下のようになります。
def calc_waic(logp: Float[np.ndarray, "M D"]) -> float:
M = logp.shape[0] # number of samples
T = -logsumexp(logp, axis=0, b=1.0 / M).mean()
V = logp.var(axis=0).mean()
return T + V
def evaluate_model(
model,
X: Float32[Array, "N D"],
y: Float32[Array, " N"],
posterior_samples: dict[str, Float32[Array, "M _*"]],
):
logp = log_likelihood(model, posterior_samples, X, y)["y"]
return calc_waic(jax.device_get(logp))
実験
# テストデータ準備
# 線形結合の係数wはあえて無駄な次元を2つ付け、モデル選択の検証に用いる
w = np.array([3.5, -1.5, 0.0, 0.0])
sigma = 0.5
D = len(w)
N = 100
np.random.seed(0)
X_ = np.random.randn(N, D)
y_ = np.dot(X_, w) + np.random.randn(N) * sigma
X = jax.device_put(X_)
y = jax.device_put(y_)
# パラメータ推定とWAICの計算
mcmc = run_mcmc(linear_model, X, y)
waic = evaluate_model(linear_model, X, y, mcmc.get_samples())
print(waic)
> 0.7954699710394434
# arvizの実装と比較
# どうやらデータ数Nで割り算されていないようであるが、その分を除けば一致している
import arviz
arviz.waic(mcmc, scale="negative_log")
> Estimate SE
> -elpd_waic 79.55 7.12
> p_waic 4.73 -
# モデル選択
# 後ろ2つのダミー次元を除いた2次元が最適であると正しく求まった
for i in range(4):
XX = X[:, :D-i]
mcmc = run_mcmc(linear_model, XX, y)
waic = evaluate_model(linear_model, XX, y, mcmc.get_samples())
print(f"WAIC for {D-i} dimensions: {waic}")
> WAIC for 4 dimensions: 0.7954699710394434
> WAIC for 3 dimensions: 0.7834370291883676
> WAIC for 2 dimensions: 0.7765665572323642
> WAIC for 1 dimensions: 1.9153631023793407
まとめ
WAICはパラメータの事後分布からの各サンプルに対する大数尤度が求まればcalc_waic
のように非常に簡単に求めることができると分かりました。機械学習のモデルの訓練をベイズ推定で行う場合、ハイパーパラメータの最適化時は交差検証を使用せずにWAICを利用することで大幅に計算時間を削減できそうです。
Discussion