🦔

NumPyroでWAICを計算する

2023/12/09に公開

概要

WAICはAICを改良して様々な場合でも利用できるようにした情報量基準です。詳しい理論的な背景は考案者である渡辺澄夫氏の以下のWebページなどを参照してください。(2024年には公開が終わるようなので早めに保存したほうがいいかも)
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/waic2011.html

この記事ではJaxをベースに開発された確率プログラミングのフレームワークであるNumPyroを用いてこのWAICを求める方法について説明します。(WAICの理論やNumPyro自体の詳細な解説はしません)

コードはこちらで公開しています。
https://github.com/lucidfrontier45/numpyro_linear_regression_waic

環境

  • Python 3.11
  • numpyroとjaxtypingをpipなどでインストール

モデル定義とパラメータ推定

ここでは以下のような単純な線形モデルを例にとります。

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jaxtyping import Array, Float32


def linear_model(X: Float32[Array, "N D"], y: Float32[Array, " N"] | None):
    N, D = X.shape

    with numpyro.plate("dimension", D):
        w = numpyro.sample("w", dist.Normal(0, 1))

    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))

    z = jnp.dot(X, w)  # type: ignore

    with numpyro.plate("data", N):
        y = numpyro.sample("y", dist.Normal(z, sigma), obs=y)  # type: ignore

Hamiltonモンテカルロ法を用いたパラメータのサンプリングは以下のように行います。

import jax
import numpy as np
from jaxtyping import Array, Float, Float32
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import log_likelihood
from scipy.special import logsumexp


def run_mcmc(
    model,
    X: Float32[Array, "N D"],
    y: Float32[Array, " N"],
    num_warmup: int = 1000,
    num_samples: int = 1000,
    num_chains: int = 1,
    seed: int = 0,
):
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=False,
    )
    rng_key = jax.random.PRNGKey(seed)
    mcmc.run(rng_key, X, y)
    return mcmc

この関数を実行することでNumPyroのMCMC構造体が得られ、パラメータの事後分布からのサンプルが得られます。

WAICの計算

いよいよWAICの計算です。上記の渡辺氏の資料によるとWAICは以下のように計算することができます。

\begin{align} \mathrm{WAIC} &= T + V \\ T &= -\frac{1}{N}\sum_n^N \log \mathrm{E}_w\left[P(Z_n|w)\right] \\ V &= \frac{1}{N}\sum_n^N \left(\mathrm{E}_w\left[\log{P(Z_n|w)}^2\right] - \mathrm{E}_w\left[\log{P(Z_n|w)}\right]^2 \right) \\ Z_n &= \left\{X_n, y_n \right\} \end{align}

この計算をするためには事後分布のパラメータのサンプルそれぞれについて尤度P(Z_n|w)あるいは大数尤度\log P(Z_n|w)を求める必要がありますが、NumPyroではnumpyro.infer.util.log_likelihoodを使用することでこれを計算することができます。

logp = log_likelihood(model, posterior_samples, X, y)["y"]

これを使用してT,Vを計算していきます。基本的には確率分布に対する期待値をサンプル平均で置き換えます。

まずTですが、式を変形すると以下のようになります。

\begin{align} T = -\frac{1}{N}\sum_n^N \log \frac{1}{M}\sum_m^M \exp \log P(Z_n|w_m) \\ \end{align}

1/Mが途中にありますが、基本的にはいわゆるlogsumexpと呼ばれる計算です。大数尤度のlogsumexpに定数項-\log Mを足すという方式でもいいですが、SciPyのlogsumexp関数はスケーリング項bを受け付けるのでこれを使用すればいいです。

Vについてはnについての和の中身が二乗の平均引く平均の二乗で分散の定義そのままですのでパラメータのサンプル軸m方向にで\log P(Z_n|w_m)の分散を計算し、データのサンプル軸n方向に平均を計算すればいいです。

\begin{align} V = \frac{1}{N}\sum_n^N V_{w}\left[ \log P(Z_n|w_m) \right] \end{align}

まとめると以下のようになります。

def calc_waic(logp: Float[np.ndarray, "M D"]) -> float:
    M = logp.shape[0]  # number of samples
    T = -logsumexp(logp, axis=0, b=1.0 / M).mean()
    V = logp.var(axis=0).mean()
    return T + V


def evaluate_model(
    model,
    X: Float32[Array, "N D"],
    y: Float32[Array, " N"],
    posterior_samples: dict[str, Float32[Array, "M _*"]],
):
    logp = log_likelihood(model, posterior_samples, X, y)["y"]
    return calc_waic(jax.device_get(logp))

実験

# テストデータ準備
# 線形結合の係数wはあえて無駄な次元を2つ付け、モデル選択の検証に用いる 
w = np.array([3.5, -1.5,  0.0, 0.0])
sigma = 0.5

D = len(w)
N = 100
np.random.seed(0)
X_ = np.random.randn(N, D)
y_ = np.dot(X_, w) + np.random.randn(N) * sigma

X = jax.device_put(X_)
y = jax.device_put(y_)

# パラメータ推定とWAICの計算
mcmc = run_mcmc(linear_model, X, y)
waic = evaluate_model(linear_model, X, y, mcmc.get_samples())
print(waic)
> 0.7954699710394434

# arvizの実装と比較
# どうやらデータ数Nで割り算されていないようであるが、その分を除けば一致している
import arviz
arviz.waic(mcmc, scale="negative_log")
>            Estimate       SE
> -elpd_waic    79.55     7.12
> p_waic         4.73        -

# モデル選択
# 後ろ2つのダミー次元を除いた2次元が最適であると正しく求まった
for i in range(4):
    XX = X[:, :D-i]
    mcmc = run_mcmc(linear_model, XX, y)
    waic = evaluate_model(linear_model, XX, y, mcmc.get_samples())
    print(f"WAIC for {D-i} dimensions: {waic}")
> WAIC for 4 dimensions: 0.7954699710394434
> WAIC for 3 dimensions: 0.7834370291883676
> WAIC for 2 dimensions: 0.7765665572323642
> WAIC for 1 dimensions: 1.9153631023793407

まとめ

WAICはパラメータの事後分布からの各サンプルに対する大数尤度が求まればcalc_waicのように非常に簡単に求めることができると分かりました。機械学習のモデルの訓練をベイズ推定で行う場合、ハイパーパラメータの最適化時は交差検証を使用せずにWAICを利用することで大幅に計算時間を削減できそうです。

Discussion