👌

2023/04/28に公開

# ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import StandardScaler

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)


# ベイジアン主成分分析

## データの準備

Irisデータセットを使用します。今回は、平均が0であることを前提にモデリングするので、標準化をしています。

iris = load_iris()
df = pd.DataFrame(iris["data"], columns=iris['feature_names'])
df["name"] = iris.target

# 前処理
X = df[iris['feature_names']]
sc = StandardScaler()
X_sc = sc.fit_transform(X)


## モデルの定義

z_j \sim Normal(0, 1) \\ w_{ij} \sim Normal(0, 1) \\ sigma \sim HalfCauchy(1) \\ x_{i} \sim Normal((wz)_i, sigma)
def model(X):
N = len(X)
D = X.shape[1]
K = 2

W = numpyro.sample("W", dist.Normal(0, 1).expand([K, D]).to_event(2))
sigma = numpyro.sample("sigma", dist.HalfCauchy(1))

with numpyro.plate("N", N):
z = numpyro.sample("z", dist.Normal(0, 1).expand([K]).to_event(1))
mu = numpyro.deterministic("mu", jnp.dot(z, W))
numpyro.sample("X", dist.Normal(mu, sigma).to_event(1), obs=X)


## Shapeの確認

with numpyro.handlers.seed(rng_seed=0):
trace = numpyro.handlers.trace(model).get_trace(X=X_sc)
print(numpyro.util.format_shapes(trace))

Trace Shapes:
Param Sites:
Sample Sites:
W dist     | 2 4
value     | 2 4
sigma dist     |
value     |
N plate 150 |
z dist 150 | 2
value 150 | 2
X dist 150 | 4
value 150 | 4


## モデルのレンダリング

numpyro.render_model(
model=model,
model_kwargs={"X": X_sc},
render_params=True,
render_distributions=True
)


## MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=X_sc
)


## 可視化

samples = mcmc.get_samples()
mean = samples["z"].mean(axis=0)

plt.scatter(mean[:, 0], mean[:, 1], c=iris.target)


from sklearn.decomposition import PCA

X_pca = PCA(n_components=2).fit_transform(X_sc)

plt.scatter(X_pca[:, 0], X_pca[:, 1], c=iris.target)


# Automatic Relevance Determination(ARD)付き主成分分析

## データの準備

こちらのデモデータのコードを使用しています。ここで、平均は標準化で0にするので省略してます。

def create_toy_data(sample_size=100, ndim_hidden=1, ndim_observe=2, std=1.):
Z = np.random.normal(size=(sample_size, ndim_hidden))
#mu = np.random.uniform(-5, 5, size=(ndim_observe))
W = np.random.uniform(-5, 5, (ndim_hidden, ndim_observe))

# PRML式(12.33)
X = Z.dot(W) + np.random.normal(scale=std, size=(sample_size, ndim_observe))
#X = Z.dot(W) + mu + np.random.normal(scale=std, size=(sample_size, ndim_observe))
return X, W

X, W_true = create_toy_data(sample_size=100, ndim_hidden=3, ndim_observe=10, std=1.)

sc = StandardScaler()
X_sc = sc.fit_transform(X)


## モデルの定義

こちらを参考にして実装しました。先ほどと異なる点はWのスケール\alphaを事前分布からサンプリングしている点です。これが自動で調整されて、情報量が多い次元だけ残るイメージです。

\alpha_j \sim Gamma(1, 1) \\ w_{ij} \sim Normal(0, \alpha_j) \\ z_j \sim Normal(0, 1) \\ \sigma \sim Gamma(1, 1) \\ x_{i} \sim Normal((wz)_i, sigma)
def model_ard(X):
N = len(X)
D = X.shape[1]

alpha = numpyro.sample("alpha", dist.Gamma(1, 1).expand([D]).to_event(1))
W = numpyro.sample("W", dist.Normal(jnp.zeros((D, D)), jnp.tile(alpha, D).reshape((D, D)).T).to_event(2))
sigma = numpyro.sample("sigma", dist.Gamma(1, 1))
#mu = numpyro.sample("mu", dist.Normal(0, 1).expand([D]).to_event(1))

with numpyro.plate("N", N):
z = numpyro.sample("z", dist.Normal(0, 1).expand([D]).to_event(1))
mu_x = numpyro.deterministic("mu_x", jnp.dot(z, W))
#mu_x = numpyro.deterministic("mu_x", jnp.dot(z, W) + mu)
numpyro.sample("X", dist.Normal(mu_x, sigma).to_event(1), obs=X)


## MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_ard)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=X_sc
)


## 結果の確認

\alphaの事後分布の平均値を確認します。今回のデモデータは３次元だけ意味のある次元が含まれていますが、\alphaの値が1.9付近の次元が3つあることが分かります。

samples = mcmc.get_samples()
np.exp(samples["alpha"].mean(axis=0))

array([1.0321453, 1.856095 , 1.0323012, 1.0356137, 1.0291867, 1.035723 ,
1.839454 , 1.9357456, 1.0326754, 1.0342481], dtype=float32)


sns.displot(samples["W"][:, 0, :].ravel())


それに対して、値が大きい次元のWを見ると、何かしら情報が取れていそうです。このように、自動で\alphaが調整されて潜在変数の次元数が決定できることがわかります。

def hinton(matrix, max_weight=None, ax=None):
"""Draw Hinton diagram for visualizing a weight matrix."""
ax = ax if ax is not None else plt.gca()

if not max_weight:
max_weight = 2 ** np.ceil(np.log(np.abs(matrix).max()) / np.log(2))

ax.patch.set_facecolor('gray')
ax.set_aspect('equal', 'box')
ax.xaxis.set_major_locator(plt.NullLocator())
ax.yaxis.set_major_locator(plt.NullLocator())

for (x, y), w in np.ndenumerate(matrix):
color = 'white' if w > 0 else 'black'
size = np.sqrt(np.abs(w) / max_weight)
rect = plt.Rectangle([y - size / 2, x - size / 2], size, size,
facecolor=color, edgecolor=color)

ax.autoscale_view()
ax.invert_yaxis()
plt.xlim(-0.5, np.size(matrix, 1) - 0.5)
plt.ylim(-0.5, len(matrix) - 0.5)
plt.show()

hinton(samples["W"].mean(axis=0))