NumPyro:次元圧縮
連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。
はじめに
今回は次元圧縮としてベイジアン主成分分析とその拡張であるautomatic relevance determination(ARD)付き主成分分析を扱います。
ライブラリのインポート
import os
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
import arviz as az
az.style.use("arviz-darkgrid")
assert numpyro.__version__.startswith("0.11.0")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
ベイジアン主成分分析
データの準備
Irisデータセットを使用します。今回は、平均が0であることを前提にモデリングするので、標準化をしています。
iris = load_iris()
df = pd.DataFrame(iris["data"], columns=iris['feature_names'])
df["name"] = iris.target
# 前処理
X = df[iris['feature_names']]
sc = StandardScaler()
X_sc = sc.fit_transform(X)
モデルの定義
主成分分析の場合は、K次元の潜在変数z
がW
により線形変換され元の行列X
が生成されるとしてモデリングします。書籍を参考にしてモデルの実装を行いました。
def model(X):
N = len(X)
D = X.shape[1]
K = 2
W = numpyro.sample("W", dist.Normal(0, 1).expand([K, D]).to_event(2))
sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
with numpyro.plate("N", N):
z = numpyro.sample("z", dist.Normal(0, 1).expand([K]).to_event(1))
mu = numpyro.deterministic("mu", jnp.dot(z, W))
numpyro.sample("X", dist.Normal(mu, sigma).to_event(1), obs=X)
Shapeの確認
with numpyro.handlers.seed(rng_seed=0):
trace = numpyro.handlers.trace(model).get_trace(X=X_sc)
print(numpyro.util.format_shapes(trace))
Trace Shapes:
Param Sites:
Sample Sites:
W dist | 2 4
value | 2 4
sigma dist |
value |
N plate 150 |
z dist 150 | 2
value 150 | 2
X dist 150 | 4
value 150 | 4
モデルのレンダリング
numpyro.render_model(
model=model,
model_kwargs={"X": X_sc},
render_params=True,
render_distributions=True
)
MCMC
今回はWとZの符号を反転させても同じ観測値が得られるので、Wが一意に決まりません。そのため、各chainごとに異なる結果になるため、num_chains=1
としています。
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=X_sc
)
可視化
結果をsklearnのPCAと比較してみます。ほとんど一緒の結果になりました。
samples = mcmc.get_samples()
mean = samples["z"].mean(axis=0)
plt.scatter(mean[:, 0], mean[:, 1], c=iris.target)
from sklearn.decomposition import PCA
X_pca = PCA(n_components=2).fit_transform(X_sc)
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=iris.target)
Automatic Relevance Determination(ARD)付き主成分分析
上記のPCAだと潜在変数の次元数を2と決め打ちでしたが、Automatic Relevance Determination(ARD)を使用することで自動的に情報量が多い次元を考慮して次元数を決めることができます。
データの準備
こちらのデモデータのコードを使用しています。ここで、平均は標準化で0にするので省略してます。
def create_toy_data(sample_size=100, ndim_hidden=1, ndim_observe=2, std=1.):
Z = np.random.normal(size=(sample_size, ndim_hidden))
#mu = np.random.uniform(-5, 5, size=(ndim_observe))
W = np.random.uniform(-5, 5, (ndim_hidden, ndim_observe))
# PRML式(12.33)
X = Z.dot(W) + np.random.normal(scale=std, size=(sample_size, ndim_observe))
#X = Z.dot(W) + mu + np.random.normal(scale=std, size=(sample_size, ndim_observe))
return X, W
X, W_true = create_toy_data(sample_size=100, ndim_hidden=3, ndim_observe=10, std=1.)
sc = StandardScaler()
X_sc = sc.fit_transform(X)
モデルの定義
こちらを参考にして実装しました。先ほどと異なる点はWのスケール
def model_ard(X):
N = len(X)
D = X.shape[1]
alpha = numpyro.sample("alpha", dist.Gamma(1, 1).expand([D]).to_event(1))
W = numpyro.sample("W", dist.Normal(jnp.zeros((D, D)), jnp.tile(alpha, D).reshape((D, D)).T).to_event(2))
sigma = numpyro.sample("sigma", dist.Gamma(1, 1))
#mu = numpyro.sample("mu", dist.Normal(0, 1).expand([D]).to_event(1))
with numpyro.plate("N", N):
z = numpyro.sample("z", dist.Normal(0, 1).expand([D]).to_event(1))
mu_x = numpyro.deterministic("mu_x", jnp.dot(z, W))
#mu_x = numpyro.deterministic("mu_x", jnp.dot(z, W) + mu)
numpyro.sample("X", dist.Normal(mu_x, sigma).to_event(1), obs=X)
MCMC
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model_ard)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=X_sc
)
結果の確認
samples = mcmc.get_samples()
np.exp(samples["alpha"].mean(axis=0))
array([1.0321453, 1.856095 , 1.0323012, 1.0356137, 1.0291867, 1.035723 ,
1.839454 , 1.9357456, 1.0326754, 1.0342481], dtype=float32)
値が小さい次元のWを見ると、Wのほとんどが0になっていることが分かります。
sns.displot(samples["W"][:, 0, :].ravel())
それに対して、値が大きい次元のWを見ると、何かしら情報が取れていそうです。このように、自動で
最後に、Hinton diagramを描画して、Wを可視化してみます。3次元以外もかすかに残っていますが、重要な次元は3次元だけであることがわかります。
def hinton(matrix, max_weight=None, ax=None):
"""Draw Hinton diagram for visualizing a weight matrix."""
ax = ax if ax is not None else plt.gca()
if not max_weight:
max_weight = 2 ** np.ceil(np.log(np.abs(matrix).max()) / np.log(2))
ax.patch.set_facecolor('gray')
ax.set_aspect('equal', 'box')
ax.xaxis.set_major_locator(plt.NullLocator())
ax.yaxis.set_major_locator(plt.NullLocator())
for (x, y), w in np.ndenumerate(matrix):
color = 'white' if w > 0 else 'black'
size = np.sqrt(np.abs(w) / max_weight)
rect = plt.Rectangle([y - size / 2, x - size / 2], size, size,
facecolor=color, edgecolor=color)
ax.add_patch(rect)
ax.autoscale_view()
ax.invert_yaxis()
plt.xlim(-0.5, np.size(matrix, 1) - 0.5)
plt.ylim(-0.5, len(matrix) - 0.5)
plt.show()
hinton(samples["W"].mean(axis=0))
最後に
以上で「次元圧縮」は終わりです。automatic relevance determination(ARD)付き主成分分析は使えそうな気がしますね。次回は「ABテスト」です。
Discussion