👌

NumPyro:次元圧縮

2023/04/28に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

今回は次元圧縮としてベイジアン主成分分析とその拡張であるautomatic relevance determination(ARD)付き主成分分析を扱います。

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

ベイジアン主成分分析

データの準備

Irisデータセットを使用します。今回は、平均が0であることを前提にモデリングするので、標準化をしています。

iris = load_iris()
df = pd.DataFrame(iris["data"], columns=iris['feature_names'])
df["name"] = iris.target

# 前処理
X = df[iris['feature_names']]
sc = StandardScaler()
X_sc = sc.fit_transform(X)

モデルの定義

主成分分析の場合は、K次元の潜在変数zWにより線形変換され元の行列Xが生成されるとしてモデリングします。書籍を参考にしてモデルの実装を行いました。

z_j \sim Normal(0, 1) \\ w_{ij} \sim Normal(0, 1) \\ sigma \sim HalfCauchy(1) \\ x_{i} \sim Normal((wz)_i, sigma)
def model(X):
    N = len(X)
    D = X.shape[1]
    K = 2
    
    W = numpyro.sample("W", dist.Normal(0, 1).expand([K, D]).to_event(2))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
    
    with numpyro.plate("N", N):
        z = numpyro.sample("z", dist.Normal(0, 1).expand([K]).to_event(1))
        mu = numpyro.deterministic("mu", jnp.dot(z, W))
        numpyro.sample("X", dist.Normal(mu, sigma).to_event(1), obs=X)

Shapeの確認

with numpyro.handlers.seed(rng_seed=0):
    trace = numpyro.handlers.trace(model).get_trace(X=X_sc)
print(numpyro.util.format_shapes(trace))
Trace Shapes:          
 Param Sites:          
Sample Sites:          
       W dist     | 2 4
        value     | 2 4
   sigma dist     |    
        value     |    
      N plate 150 |    
       z dist 150 | 2  
        value 150 | 2  
       X dist 150 | 4  
        value 150 | 4  

モデルのレンダリング

numpyro.render_model(
    model=model, 
    model_kwargs={"X": X_sc}, 
    render_params=True, 
    render_distributions=True
)

MCMC

今回はWとZの符号を反転させても同じ観測値が得られるので、Wが一意に決まりません。そのため、各chainごとに異なる結果になるため、num_chains=1としています。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=X_sc
)

可視化

結果をsklearnのPCAと比較してみます。ほとんど一緒の結果になりました。

samples = mcmc.get_samples()
mean = samples["z"].mean(axis=0)

plt.scatter(mean[:, 0], mean[:, 1], c=iris.target)

from sklearn.decomposition import PCA

X_pca = PCA(n_components=2).fit_transform(X_sc)

plt.scatter(X_pca[:, 0], X_pca[:, 1], c=iris.target)

Automatic Relevance Determination(ARD)付き主成分分析

上記のPCAだと潜在変数の次元数を2と決め打ちでしたが、Automatic Relevance Determination(ARD)を使用することで自動的に情報量が多い次元を考慮して次元数を決めることができます。

データの準備

こちらのデモデータのコードを使用しています。ここで、平均は標準化で0にするので省略してます。

def create_toy_data(sample_size=100, ndim_hidden=1, ndim_observe=2, std=1.):
    Z = np.random.normal(size=(sample_size, ndim_hidden))
    #mu = np.random.uniform(-5, 5, size=(ndim_observe))
    W = np.random.uniform(-5, 5, (ndim_hidden, ndim_observe))

    # PRML式(12.33)
    X = Z.dot(W) + np.random.normal(scale=std, size=(sample_size, ndim_observe))
    #X = Z.dot(W) + mu + np.random.normal(scale=std, size=(sample_size, ndim_observe))
    return X, W

X, W_true = create_toy_data(sample_size=100, ndim_hidden=3, ndim_observe=10, std=1.)

sc = StandardScaler()
X_sc = sc.fit_transform(X)

モデルの定義

こちらを参考にして実装しました。先ほどと異なる点はWのスケール\alphaを事前分布からサンプリングしている点です。これが自動で調整されて、情報量が多い次元だけ残るイメージです。

\alpha_j \sim Gamma(1, 1) \\ w_{ij} \sim Normal(0, \alpha_j) \\ z_j \sim Normal(0, 1) \\ \sigma \sim Gamma(1, 1) \\ x_{i} \sim Normal((wz)_i, sigma)
def model_ard(X):
    N = len(X)
    D = X.shape[1]
    
    alpha = numpyro.sample("alpha", dist.Gamma(1, 1).expand([D]).to_event(1))
    W = numpyro.sample("W", dist.Normal(jnp.zeros((D, D)), jnp.tile(alpha, D).reshape((D, D)).T).to_event(2))
    sigma = numpyro.sample("sigma", dist.Gamma(1, 1))
    #mu = numpyro.sample("mu", dist.Normal(0, 1).expand([D]).to_event(1))
    
    with numpyro.plate("N", N):
        z = numpyro.sample("z", dist.Normal(0, 1).expand([D]).to_event(1))
        mu_x = numpyro.deterministic("mu_x", jnp.dot(z, W))
        #mu_x = numpyro.deterministic("mu_x", jnp.dot(z, W) + mu)
        numpyro.sample("X", dist.Normal(mu_x, sigma).to_event(1), obs=X)

MCMC

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_ard)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=X_sc
)

結果の確認

\alphaの事後分布の平均値を確認します。今回のデモデータは3次元だけ意味のある次元が含まれていますが、\alphaの値が1.9付近の次元が3つあることが分かります。

samples = mcmc.get_samples()
np.exp(samples["alpha"].mean(axis=0))
array([1.0321453, 1.856095 , 1.0323012, 1.0356137, 1.0291867, 1.035723 ,
       1.839454 , 1.9357456, 1.0326754, 1.0342481], dtype=float32)

値が小さい次元のWを見ると、Wのほとんどが0になっていることが分かります。

sns.displot(samples["W"][:, 0, :].ravel())

それに対して、値が大きい次元のWを見ると、何かしら情報が取れていそうです。このように、自動で\alphaが調整されて潜在変数の次元数が決定できることがわかります。

最後に、Hinton diagramを描画して、Wを可視化してみます。3次元以外もかすかに残っていますが、重要な次元は3次元だけであることがわかります。

def hinton(matrix, max_weight=None, ax=None):
    """Draw Hinton diagram for visualizing a weight matrix."""
    ax = ax if ax is not None else plt.gca()

    if not max_weight:
        max_weight = 2 ** np.ceil(np.log(np.abs(matrix).max()) / np.log(2))

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle([y - size / 2, x - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()
    plt.xlim(-0.5, np.size(matrix, 1) - 0.5)
    plt.ylim(-0.5, len(matrix) - 0.5)
    plt.show()

hinton(samples["W"].mean(axis=0))

最後に

以上で「次元圧縮」は終わりです。automatic relevance determination(ARD)付き主成分分析は使えそうな気がしますね。次回は「ABテスト」です。

Discussion