🙆

NumPyro:ガウス過程

2023/05/04に公開

はじめに

今回はガウス過程を扱います。

ライブラリのインポート

import os

import jax
import jax.numpy as jnp
from jax import vmap
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
from numpyro.infer import (
    init_to_feasible,
    init_to_median,
    init_to_sample,
    init_to_uniform,
    init_to_value,
)

import arviz as az

# Plot style used by all figures in this article.
az.style.use("arviz-darkgrid")

# The article was written against NumPyro 0.11.0; stop early on any other
# version, since the APIs used below may behave differently there.
assert numpyro.__version__.startswith("0.11.0")

# Global JAX/NumPyro configuration: enable float64 arithmetic (presumably for
# the numerical stability of the kernel-matrix Cholesky factorizations below),
# run on the CPU, and expose a single host device.
# NOTE(review): NumPyro's docs recommend calling these at program start,
# before JAX is used — keep them ahead of any array creation.
numpyro.enable_x64(True)
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

RBFカーネルと事前分布からのサンプリング

今回は代表的なカーネルであるRBFカーネルを使用します。

def RBF(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    """Squared-exponential (RBF) kernel matrix between the rows of X and Z.

    Adapted from Pyro's isotropic RBF kernel:
    https://docs.pyro.ai/en/stable/_modules/pyro/contrib/gp/kernels/isotropic.html#RBF

    Args:
        X: inputs of shape (n, d); a 1-D array of length n is treated as (n, 1).
        Z: inputs of shape (m, d); a 1-D array of length m is treated as (m, 1).
        var: kernel variance (output scale).
        length: length scale; inputs are divided by it before taking distances.
        noise: observation-noise variance added to the diagonal.
        jitter: small constant added to the diagonal for numerical stability.
        include_noise: if True, add ``(noise + jitter) * I`` to the result.
            This only makes sense when X and Z are the same points (n == m).

    Returns:
        Array of shape (n, m) with entries
        ``var * exp(-0.5 * |x - z|^2 / length^2)`` (plus the diagonal term).
    """
    X = jnp.asarray(X)
    Z = jnp.asarray(Z)
    # Promote 1-D inputs to a single feature column (previously an error).
    if X.ndim == 1:
        X = X[:, None]
    if Z.ndim == 1:
        Z = Z[:, None]
    scaled_X = X / length
    scaled_Z = Z / length
    # Pairwise squared distances via |x|^2 - 2 x.z + |z|^2.
    X2 = (scaled_X**2).sum(axis=1, keepdims=True)
    Z2 = (scaled_Z**2).sum(axis=1, keepdims=True)
    XZ = jnp.matmul(scaled_X, scaled_Z.T)
    r2 = X2 - 2 * XZ + Z2.T
    # Round-off can make the expansion slightly negative; clamp at zero.
    # jnp.maximum avoids jnp.clip's `a_min` keyword, deprecated in newer JAX.
    r2 = jnp.maximum(r2, 0.0)
    k = var * jnp.exp(-0.5 * r2)

    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k
def model_prior(X):
    """Sample ``y`` from a zero-mean GP prior with a fixed RBF kernel at ``X``.

    Hyperparameters are fixed: variance 1, length scale 0.2 and no
    observation noise (only RBF's default jitter reaches the diagonal).
    """
    n_points = X.shape[0]
    prior_cov = RBF(X, X, var=1, length=0.2, noise=0.)
    prior_mean = jnp.zeros(n_points)
    numpyro.sample(
        "y",
        dist.MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov),
    )
    
# 100 evenly spaced inputs on [-1, 1], shaped (100, 1) as RBF expects.
x_sim = np.linspace(-1, 1, 100)[:, None]

# Draw 20 functions from the GP prior and overlay them in one figure.
rng_key = random.PRNGKey(0)
prior_predictive = Predictive(model_prior, num_samples=20)
prior_predictions = prior_predictive(rng_key, X=x_sim)["y"]

plt.figure(figsize=(6, 4))
for draw in prior_predictions:
    plt.plot(x_sim, draw)

ガウス過程

データの準備

[NumPyroのチュートリアル](https://num.pyro.ai/en/latest/examples/gp.html)のデータを使用します。

# create artificial regression dataset
def get_data(N=30, sigma_obs=0.15, N_test=400, seed=0):
    """Create the artificial 1-D regression dataset of the NumPyro GP tutorial.

    Args:
        N: number of training inputs, evenly spaced on [-1, 1].
        sigma_obs: standard deviation of the Gaussian noise added to Y.
        N_test: number of test inputs, evenly spaced on [-1.3, 1.3].
        seed: seed for the noise; the default 0 reproduces the original data.

    Returns:
        (X, Y, X_test): X has shape (N,); Y has shape (N,) and is standardized
        to zero mean and unit (population) standard deviation; X_test has
        shape (N_test,).
    """
    # A private RandomState seeded like np.random.seed(seed) yields exactly the
    # same draws, without resetting NumPy's global RNG as a side effect.
    rng = np.random.RandomState(seed)
    X = jnp.linspace(-1, 1, N)
    Y = X + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0) * jnp.sin(4.0 * X)
    Y += sigma_obs * rng.randn(N)
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N,)
    assert Y.shape == (N,)

    X_test = jnp.linspace(-1.3, 1.3, N_test)

    return X, Y, X_test

X_, Y, X_test_ = get_data(N=25)
# RBF expects 2-D (n_points, n_features) inputs, so add a trailing feature axis.
X, X_test = X_[:, None], X_test_[:, None]

the marginal likelihood GP

数値的に推奨されているコレスキー分解を明示的には使用していない形式で、最もイメージを掴みやすいコードになっています（なお、NumPyroの`MultivariateNormal`は`covariance_matrix`を渡された場合、内部でコレスキー分解を計算します）。

def model_marginal_likelihood_GP(X, Y):
    """GP regression with the latent function marginalized out.

    Y ~ MultivariateNormal(0, K), where K is the RBF kernel on X with
    (noise + jitter) added to the diagonal. Kernel variance, length scale and
    noise variance share a very wide LogNormal(0, 10) prior.
    """
    hyper_prior = dist.LogNormal(0.0, 10.0)
    var = numpyro.sample("kernel_var", hyper_prior)
    length = numpyro.sample("kernel_length", hyper_prior)
    noise = numpyro.sample("kernel_noise", hyper_prior)

    n_points = X.shape[0]
    cov = RBF(X, X, var, length, noise)
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=jnp.zeros(n_points), covariance_matrix=cov),
        obs=Y,
    )

# One key drives MCMC; the other seeds the posterior-predictive noise later.
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# NUTS from a feasible starting point: 1000 warmup + 1000 draws on a single
# chain, keeping every second draw (thinning=2).
kernel = NUTS(model_marginal_likelihood_GP, init_strategy=init_to_feasible)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1, thinning=2)
mcmc.run(rng_key, X=X, Y=Y)
mcmc.print_summary()
samples = mcmc.get_samples()
                     mean       std    median      5.0%     95.0%     n_eff     r_hat
  kernel_length      0.70      0.23      0.66      0.37      1.00    322.50      1.01
   kernel_noise      0.06      0.02      0.06      0.03      0.09    423.82      1.00
     kernel_var      2.57      4.00      1.36      0.29      5.31    328.78      1.00

Number of divergences: 0
def predict(rng_key, X, Y, X_test, var, length, noise):
    """Posterior predictive of the marginal-likelihood GP for one hyperparameter draw.

    Args:
        rng_key: JAX PRNG key for the predictive noise.
        X, Y: training inputs (n, d) and targets (n,).
        X_test: test inputs (p, d).
        var, length, noise: RBF kernel variance, length scale, noise variance.

    Returns:
        (mean, sample): the posterior mean at X_test, and the mean plus
        independent per-point Gaussian noise with the posterior predictive
        marginal standard deviation. Observation noise is included because
        k_pp is built with include_noise=True.
    """
    # compute kernels between train and test data, etc.
    k_pp = RBF(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = RBF(X_test, X, var, length, noise, include_noise=False)
    k_XX = RBF(X, X, var, length, noise, include_noise=True)
    # W = k_XX^{-1} k_Xp via one linear solve, which is cheaper and more
    # accurate than forming inv(k_XX). k_XX is symmetric, so
    # k_pX k_XX^{-1} Y == W^T Y.
    W = jnp.linalg.solve(k_XX, jnp.transpose(k_pX))
    K = k_pp - jnp.matmul(k_pX, W)
    mean = jnp.matmul(jnp.transpose(W), Y)
    # Clamp tiny negative variances caused by round-off before the sqrt.
    sigma_noise = jnp.sqrt(jnp.maximum(jnp.diag(K), 0.0)) * jax.random.normal(
        rng_key, X_test.shape[:1]
    )
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise

# do prediction
# One PRNG key per posterior draw. vmap maps `predict` over the leading axis
# of every entry of vmap_args; X, Y and X_test are shared by all draws.
vmap_args = (
    random.split(rng_key_predict, samples["kernel_var"].shape[0]),
    *(samples[site] for site in ("kernel_var", "kernel_length", "kernel_noise")),
)


def _predict_single_draw(rng_key, var, length, noise):
    # Adapter that fixes the data arguments so only per-draw values are mapped.
    return predict(rng_key, X, Y, X_test, var, length, noise)


means, predictions = vmap(_predict_single_draw)(*vmap_args)

# Average the per-draw mean functions; the 5%/95% quantiles of the
# predictive samples form the 90% band.
mean_prediction = np.mean(means, axis=0)
percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

# make plots: training data, 90% band, posterior mean
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
ax.plot(X.ravel(), Y, "kx")
lower, upper = percentiles
ax.fill_between(X_test.ravel(), lower, upper, color="lightblue")
ax.plot(X_test.ravel(), mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

the marginal likelihood GP + cholesky_decompose

数値的に推奨されているコレスキー分解を明示的に行い、得られた下三角行列を`scale_tril`として`MultivariateNormal`に渡す形式です。

def model_marginal_likelihood_GP_cholesky_decompose(X, Y):
    """Marginal-likelihood GP parameterized by an explicit Cholesky factor.

    Same priors and kernel as ``model_marginal_likelihood_GP``. Here the
    lower-triangular factor of K is computed up front and passed to the
    likelihood as ``scale_tril``.
    """
    hyper_prior = dist.LogNormal(0.0, 10.0)
    var = numpyro.sample("kernel_var", hyper_prior)
    length = numpyro.sample("kernel_length", hyper_prior)
    noise = numpyro.sample("kernel_noise", hyper_prior)

    # Lower-triangular chol with chol @ chol.T == K.
    chol = jnp.linalg.cholesky(RBF(X, X, var, length, noise))
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), scale_tril=chol),
        obs=Y,
    )

# One key drives MCMC; the other seeds the posterior-predictive noise later.
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# NUTS from a feasible starting point: 1000 warmup + 1000 draws on a single
# chain, keeping every second draw (thinning=2).
kernel = NUTS(model_marginal_likelihood_GP_cholesky_decompose, init_strategy=init_to_feasible)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1, thinning=2)
mcmc.run(rng_key, X=X, Y=Y)
mcmc.print_summary()
samples = mcmc.get_samples()
                     mean       std    median      5.0%     95.0%     n_eff     r_hat
  kernel_length      0.67      0.22      0.63      0.33      0.97    393.17      1.00
   kernel_noise      0.06      0.02      0.06      0.03      0.09    425.12      1.00
     kernel_var      2.57      4.22      1.30      0.27      5.55    252.62      1.01

Number of divergences: 0
def predict(rng_key, X, Y, X_test, var, length, noise):
    """Posterior predictive of the marginal-likelihood GP using Cholesky solves.

    Args:
        rng_key: JAX PRNG key for the predictive noise.
        X, Y: training inputs (n, d) and targets (n,).
        X_test: test inputs (p, d).
        var, length, noise: RBF kernel variance, length scale, noise variance.

    Returns:
        (mean, sample): the posterior mean at X_test, and the mean plus
        independent per-point Gaussian noise with the posterior predictive
        marginal standard deviation (observation noise included via k_pp).
    """
    from jax.scipy.linalg import cho_factor, cho_solve

    # compute kernels between train and test data, etc.
    k_pp = RBF(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = RBF(X_test, X, var, length, noise, include_noise=False)
    k_XX = RBF(X, X, var, length, noise, include_noise=True)
    # Factor k_XX = L L^T once and reuse it for both solves. This matches the
    # Cholesky-based model of this section and avoids forming inv(k_XX).
    k_XX_factor = cho_factor(k_XX, lower=True)
    K = k_pp - jnp.matmul(k_pX, cho_solve(k_XX_factor, jnp.transpose(k_pX)))
    # Clamp tiny negative variances caused by round-off before the sqrt.
    sigma_noise = jnp.sqrt(jnp.maximum(jnp.diag(K), 0.0)) * jax.random.normal(
        rng_key, X_test.shape[:1]
    )
    mean = jnp.matmul(k_pX, cho_solve(k_XX_factor, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise

# do prediction
# One PRNG key per posterior draw. vmap maps `predict` over the leading axis
# of every entry of vmap_args; X, Y and X_test are shared by all draws.
vmap_args = (
    random.split(rng_key_predict, samples["kernel_var"].shape[0]),
    *(samples[site] for site in ("kernel_var", "kernel_length", "kernel_noise")),
)


def _predict_single_draw(rng_key, var, length, noise):
    # Adapter that fixes the data arguments so only per-draw values are mapped.
    return predict(rng_key, X, Y, X_test, var, length, noise)


means, predictions = vmap(_predict_single_draw)(*vmap_args)

# Average the per-draw mean functions; the 5%/95% quantiles of the
# predictive samples form the 90% band.
mean_prediction = np.mean(means, axis=0)
percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

# make plots: training data, 90% band, posterior mean
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
ax.plot(X.ravel(), Y, "kx")
lower, upper = percentiles
ax.fill_between(X_test.ravel(), lower, upper, color="lightblue")
ax.plot(X_test.ravel(), mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

the latent variable GP + cholesky_decompose

GPを潜在変数形式にしたものです。最後のNormalのところをPoissonなどに変更するとポアソン回帰などが実装できます。

# Author's open question (translated): "Does the error grow because the
# number of parameters increases?"
def RBF_for_latent(X, Z, var, length, jitter=1.0e-9, include_noise=True):
    """RBF kernel for the latent-variable GP: no observation-noise term.

    Same kernel as ``RBF``, but when ``include_noise`` is True only ``jitter``
    is added to the diagonal. In the latent model the observation noise is
    applied by the Normal likelihood instead.

    Args:
        X: inputs of shape (n, d); a 1-D array of length n is treated as (n, 1).
        Z: inputs of shape (m, d); a 1-D array of length m is treated as (m, 1).
        var: kernel variance (output scale).
        length: length scale.
        jitter: diagonal term for numerical stability of the Cholesky factor.
        include_noise: if True, add ``jitter * I`` (requires n == m).

    Returns:
        Array of shape (n, m) with entries ``var * exp(-0.5 * |x - z|^2 / length^2)``.
    """
    # https://docs.pyro.ai/en/stable/_modules/pyro/contrib/gp/kernels/isotropic.html#RBF
    X = jnp.asarray(X)
    Z = jnp.asarray(Z)
    # Promote 1-D inputs to a single feature column (previously an error).
    if X.ndim == 1:
        X = X[:, None]
    if Z.ndim == 1:
        Z = Z[:, None]
    scaled_X = X / length
    scaled_Z = Z / length
    X2 = (scaled_X**2).sum(axis=1, keepdims=True)
    Z2 = (scaled_Z**2).sum(axis=1, keepdims=True)
    XZ = jnp.matmul(scaled_X, scaled_Z.T)
    r2 = X2 - 2 * XZ + Z2.T
    # Clamp round-off negatives; jnp.maximum avoids the deprecated clip(a_min=).
    r2 = jnp.maximum(r2, 0.0)
    k = var * jnp.exp(-0.5 * r2)

    if include_noise:
        k += jitter * jnp.eye(X.shape[0])
    return k

def model_latent_variable_gp_cholesky_decompose(X, Y):
    """Latent-variable (non-centered) GP regression.

    f = L @ eta, where L is the Cholesky factor of the RBF kernel on X and
    eta ~ N(0, I). Observations follow Y ~ Normal(f, sqrt(noise)), so
    ``kernel_noise`` acts as a variance. Swapping this Normal likelihood for
    e.g. a Poisson turns the model into GP-based Poisson regression.
    """
    n_points = X.shape[0]
    var = numpyro.sample("kernel_var", dist.HalfNormal(1))
    length = numpyro.sample("kernel_length", dist.InverseGamma(5, 5))
    noise = numpyro.sample("kernel_noise", dist.HalfNormal(1))
    eta = numpyro.sample("eta", dist.Normal(0.0, 1.0).expand([n_points]))

    chol = jnp.linalg.cholesky(RBF_for_latent(X, X, var, length))
    f = numpyro.deterministic("f", chol @ eta)

    numpyro.sample("Y", dist.Normal(loc=f, scale=jnp.sqrt(noise)), obs=Y)

# One key drives MCMC; the other seeds the posterior-predictive noise later.
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# NUTS from a feasible starting point: 1000 warmup + 1000 draws on a single
# chain, keeping every second draw (thinning=2).
kernel = NUTS(model_latent_variable_gp_cholesky_decompose, init_strategy=init_to_feasible)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1, thinning=2)
mcmc.run(rng_key, X=X, Y=Y)
mcmc.print_summary()
samples = mcmc.get_samples()
                     mean       std    median      5.0%     95.0%     n_eff     r_hat
         eta[0]     -1.43      0.35     -1.39     -1.97     -0.89    360.96      1.00
         eta[1]      0.22      0.51      0.21     -0.67      0.98    455.29      1.00
         eta[2]      0.10      0.96      0.02     -1.48      1.62    358.04      1.00
         eta[3]      0.59      0.90      0.64     -0.77      2.17    346.34      1.00
         eta[4]      0.76      0.85      0.80     -0.64      2.03    470.42      1.00
         eta[5]      0.60      0.96      0.55     -0.87      2.17    467.36      1.00
         eta[6]      0.41      1.06      0.42     -1.34      2.15    435.65      1.00
         eta[7]      0.36      0.98      0.33     -1.24      1.88    507.43      1.01
         eta[8]      0.19      0.99      0.20     -1.28      1.92    378.93      1.00
         eta[9]     -0.05      0.92     -0.06     -1.64      1.31    414.54      1.00
        eta[10]     -0.09      1.05     -0.10     -1.74      1.57    388.84      1.00
        eta[11]     -0.11      0.98     -0.11     -1.78      1.32    297.98      1.00
        eta[12]     -0.25      1.00     -0.27     -1.72      1.58    394.82      1.00
        eta[13]     -0.20      0.98     -0.18     -1.79      1.46    415.22      1.00
        eta[14]     -0.19      0.97     -0.16     -1.78      1.34    471.53      1.00
        eta[15]     -0.10      0.97     -0.15     -1.54      1.54    429.80      1.00
        eta[16]     -0.10      0.99     -0.14     -1.43      1.76    441.01      1.00
        eta[17]     -0.05      0.96     -0.10     -1.47      1.64    316.26      1.00
        eta[18]     -0.06      0.98     -0.08     -1.58      1.62    451.81      1.00
        eta[19]     -0.04      0.94     -0.06     -1.42      1.59    484.91      1.00
        eta[20]      0.08      1.10      0.08     -1.79      1.70    403.79      1.00
        eta[21]      0.08      1.04      0.10     -1.74      1.55    444.79      1.00
        eta[22]      0.04      1.01     -0.02     -1.59      1.63    318.95      1.00
...
   kernel_noise      0.07      0.03      0.07      0.03      0.11    331.39      1.00
     kernel_var      1.13      0.49      1.05      0.40      1.88    389.86      1.00

Number of divergences: 0

観測誤差がGPの外に出ているので、GPの事後予測にはyではなくfを使用します。

def predict(rng_key, X, f, X_test, var, length, noise):
    """Posterior predictive of the latent-variable GP for one posterior draw.

    Conditions the GP on the sampled latent values ``f`` at the training
    inputs (noise-free). It then adds observation noise with variance
    ``noise``, matching the model's likelihood scale ``sqrt(noise)``.

    Args:
        rng_key: JAX PRNG key; split internally into two independent keys.
        X: training inputs (n, d); f: latent values at X, shape (n,).
        X_test: test inputs (p, d).
        var, length: RBF kernel variance and length scale.
        noise: observation-noise variance.

    Returns:
        (mean, sample): the conditional mean of the latent function at X_test,
        and mean + latent-function noise + observation noise, drawn
        independently per point.
    """
    from jax.scipy.linalg import cho_factor, cho_solve

    # Independent keys for the two noise sources. Reusing one key made both
    # normal draws identical, widening the band to (sd_f + sd_obs) instead of
    # sqrt(var_f + var_obs).
    key_f, key_obs = jax.random.split(rng_key)

    # compute kernels between train and test data, etc.
    k_pp = RBF_for_latent(X_test, X_test, var, length, include_noise=True)
    k_pX = RBF_for_latent(X_test, X, var, length, include_noise=False)
    k_XX = RBF_for_latent(X, X, var, length, include_noise=True)
    # Cholesky solves instead of an explicit inverse, consistent with the model.
    k_XX_factor = cho_factor(k_XX, lower=True)
    K = k_pp - jnp.matmul(k_pX, cho_solve(k_XX_factor, jnp.transpose(k_pX)))
    mean = jnp.matmul(k_pX, cho_solve(k_XX_factor, f))

    # Clamp tiny negative variances caused by round-off before the sqrt.
    sigma_f = jnp.sqrt(jnp.maximum(jnp.diag(K), 0.0)) * jax.random.normal(
        key_f, X_test.shape[:1]
    )
    sigma_obs = jax.random.normal(key_obs, X_test.shape[:1]) * jnp.sqrt(noise)

    return mean, mean + sigma_f + sigma_obs

# do prediction
# One PRNG key per posterior draw. vmap maps `predict` over the leading axis
# of every entry of vmap_args (including the sampled latent f); X and X_test
# are shared by all draws.
vmap_args = (
    random.split(rng_key_predict, samples["kernel_var"].shape[0]),
    *(samples[site] for site in ("f", "kernel_var", "kernel_length", "kernel_noise")),
)


def _predict_single_draw(rng_key, f, var, length, noise):
    # Adapter that fixes the inputs so only per-draw values are mapped.
    return predict(rng_key, X, f, X_test, var, length, noise)


means, predictions = vmap(_predict_single_draw)(*vmap_args)

# Average the per-draw mean functions; the 5%/95% quantiles of the
# predictive samples form the 90% band.
mean_prediction = np.mean(means, axis=0)
percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

# make plots: training data, 90% band, posterior mean
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
ax.plot(X.ravel(), Y, "kx")
lower, upper = percentiles
ax.fill_between(X_test.ravel(), lower, upper, color="lightblue")
ax.plot(X_test.ravel(), mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

予測

速度はかなり遅くなりますが、以下のようにモデルの中でテスト点の予測を同時にサンプリングすることもできます。

# Author's open question (translated): "Does the error grow because the
# number of parameters increases?"
def RBF_for_latent(X, Z, var, length, jitter=1.0e-9, include_noise=True):
    """RBF kernel for the latent-variable GP: no observation-noise term.

    Same kernel as ``RBF``, but when ``include_noise`` is True only ``jitter``
    is added to the diagonal. In the latent model the observation noise is
    applied by the Normal likelihood instead.

    Args:
        X: inputs of shape (n, d); a 1-D array of length n is treated as (n, 1).
        Z: inputs of shape (m, d); a 1-D array of length m is treated as (m, 1).
        var: kernel variance (output scale).
        length: length scale.
        jitter: diagonal term for numerical stability of the Cholesky factor.
        include_noise: if True, add ``jitter * I`` (requires n == m).

    Returns:
        Array of shape (n, m) with entries ``var * exp(-0.5 * |x - z|^2 / length^2)``.
    """
    # https://docs.pyro.ai/en/stable/_modules/pyro/contrib/gp/kernels/isotropic.html#RBF
    X = jnp.asarray(X)
    Z = jnp.asarray(Z)
    # Promote 1-D inputs to a single feature column (previously an error).
    if X.ndim == 1:
        X = X[:, None]
    if Z.ndim == 1:
        Z = Z[:, None]
    scaled_X = X / length
    scaled_Z = Z / length
    X2 = (scaled_X**2).sum(axis=1, keepdims=True)
    Z2 = (scaled_Z**2).sum(axis=1, keepdims=True)
    XZ = jnp.matmul(scaled_X, scaled_Z.T)
    r2 = X2 - 2 * XZ + Z2.T
    # Clamp round-off negatives; jnp.maximum avoids the deprecated clip(a_min=).
    r2 = jnp.maximum(r2, 0.0)
    k = var * jnp.exp(-0.5 * r2)

    if include_noise:
        k += jitter * jnp.eye(X.shape[0])
    return k

def model_latent_variable_gp_cholesky_decompose(X_train, Y_train, X_test=None):
    """Latent-variable GP that also samples test-point predictions inside MCMC.

    The GP prior is placed jointly on the training and test inputs. The
    deterministic site ``f`` therefore covers both: ``f[:N_train]`` explains
    ``Y_train`` and ``f[N_train:]`` is the latent function at ``X_test``.
    Predictive observations at the test points are sampled as the latent site
    ``y_tmp`` (NUTS samples it like any other parameter, which presumably
    contributes to the slower run) and exposed as the deterministic ``y_test``.

    If ``X_test`` is None, only the training points are modeled and the
    ``y_tmp``/``y_test`` sites are not created.
    """
    N_train = X_train.shape[0]
    # Previously jnp.vstack([X_train, None]) crashed when X_test was omitted.
    X = X_train if X_test is None else jnp.vstack([X_train, X_test])

    var = numpyro.sample("kernel_var", dist.HalfNormal(1))
    length = numpyro.sample("kernel_length", dist.InverseGamma(5, 5))
    noise = numpyro.sample("kernel_noise", dist.HalfNormal(1))
    eta = numpyro.sample("eta", dist.Normal(0.0, 1.0).expand([X.shape[0]]))

    K = RBF_for_latent(X, X, var, length)
    L_K = jnp.linalg.cholesky(K)
    f = numpyro.deterministic("f", jnp.matmul(L_K, eta))

    # `noise` is treated as a variance (as in the prediction formula), so its
    # square root is the Normal's standard deviation.
    obs_scale = jnp.sqrt(noise)
    numpyro.sample(
        "Y",
        dist.Normal(loc=f[:N_train, ...], scale=obs_scale),
        obs=Y_train,
    )
    if X_test is not None:
        numpyro.deterministic(
            "y_test",
            numpyro.sample("y_tmp", dist.Normal(loc=f[N_train:, ...], scale=obs_scale)),
        )

# 50 test inputs on [-1.3, 1.3], shaped (50, 1) like X.
X_test = jnp.linspace(-1.3, 1.3, 50)[..., None]

rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# NUTS from a feasible starting point: 1000 warmup + 1000 draws on a single
# chain, keeping every second draw (thinning=2). X_test is passed so the
# test-point predictions are sampled jointly with the parameters.
kernel = NUTS(model_latent_variable_gp_cholesky_decompose, init_strategy=init_to_feasible)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1, thinning=2)
mcmc.run(rng_key, X, Y, X_test=X_test)
mcmc.print_summary()
samples = mcmc.get_samples()
                     mean       std    median      5.0%     95.0%     n_eff     r_hat
         eta[0]     -1.42      0.40     -1.36     -2.01     -0.81    249.50      1.00
         eta[1]      0.18      0.50      0.18     -0.63      0.97    348.62      1.00
         eta[2]      0.09      0.90     -0.00     -1.36      1.47    271.80      1.01
         eta[3]      0.58      0.95      0.57     -1.26      1.86    439.55      1.00
         eta[4]      0.68      0.96      0.67     -0.85      2.20    494.06      1.00
         eta[5]      0.56      0.94      0.62     -1.03      2.08    480.94      1.00
         eta[6]      0.55      1.00      0.58     -1.24      2.08    409.36      1.00
         eta[7]      0.33      1.01      0.33     -1.28      2.07    384.09      1.00
         eta[8]      0.21      1.02      0.21     -1.33      2.03    419.77      1.00
         eta[9]      0.07      0.95      0.11     -1.59      1.55    455.84      1.00
        eta[10]     -0.01      0.98      0.04     -1.53      1.61    387.75      1.00
        eta[11]     -0.19      1.01     -0.17     -1.88      1.34    554.14      1.00
        eta[12]     -0.18      1.01     -0.16     -1.98      1.27    521.37      1.00
        eta[13]     -0.12      0.96     -0.14     -1.82      1.28    358.13      1.00
        eta[14]     -0.09      0.90     -0.08     -1.65      1.23    467.08      1.00
        eta[15]     -0.10      0.94     -0.11     -1.66      1.40    505.94      1.00
        eta[16]     -0.07      0.99     -0.14     -1.45      1.70    537.61      1.00
        eta[17]     -0.06      1.02     -0.10     -1.60      1.73    475.25      1.00
        eta[18]     -0.07      1.03     -0.15     -1.75      1.72    446.85      1.00
        eta[19]     -0.05      0.99     -0.04     -1.70      1.50    435.32      1.00
        eta[20]      0.05      1.09      0.06     -1.98      1.62    488.93      1.00
        eta[21]      0.03      1.04      0.01     -1.64      1.82    495.63      1.00
        eta[22]     -0.04      1.01     -0.04     -1.90      1.44    562.01      1.00
...
   kernel_noise      0.07      0.03      0.07      0.03      0.11    365.76      1.00
     kernel_var      1.16      0.57      1.03      0.36      2.03    307.11      1.00

Number of divergences: 0
# Posterior mean of the latent f (training points first, then test points)
# and the 5%/95% quantiles of the sampled test-point observations.
mean_prediction = np.mean(samples["f"], axis=0)
percentiles = np.percentile(samples["y_test"], [5.0, 95.0], axis=0)

# make plots
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

# f stacks the training points first, so the test-point values start at N_train.
N_train = X.shape[0]

ax.plot(X.ravel(), Y, "kx")  # training data
lower, upper = percentiles
ax.fill_between(X_test.ravel(), lower, upper, color="lightblue")  # 90% band
ax.plot(X_test.ravel(), mean_prediction[N_train:], "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

最後に

以上で「ガウス過程」は終わりです。
データ(訓練時/推論時の両方)が許容できる範囲で少ない&カスタマイズのしやすさを重視したい場合はNumPyroでも問題はないですが、NumPyroと同じように書けるPyroではガウス過程用のモジュールがあり、GPyTorchとの連携も可能なので、データ数が多くなった場合や単純なガウス過程を行いたい場合はGPyTorchやPyroを使った方が楽かなと思いました。次は、「時系列分析」です。

Discussion