NumPyro:順序回帰と独自の分布の定義
連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。
はじめに
今回は順序回帰と独自の分布の定義の仕方を見ていきます。順序回帰に関してはNumPyroのチュートリアル通りになります。より詳細な式などに関してはこちらを参照ください。
独自の分布の定義は他に入れる記事がなかったので、適当に今回の記事の中に入っているだけで順序回帰との関係はありません。
ライブラリのインポート
import os
import jax.numpy as jnp
from jax import random
from jax.nn import softplus
import jax.numpy as jnp
from jax.scipy.special import expit, logit
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from jax.experimental.ode import odeint
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
from numpyro.infer.reparam import TransformReparam
import arviz as az
az.style.use("arviz-darkgrid")
assert numpyro.__version__.startswith("0.11.0")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
順序回帰
順序付きのデータをモデリングする際は以下のように考えます。
左図は何かしらの連続値の潜在変数(
このとき、右図に示されているように、各カテゴリの確率は以下のように累積分布関数の差分で表されます。
ここで、累積分布関数は以下の形で与えられます。
ここで、上記の関数に
この式に今考えている潜在変数Xからの影響を表した
これが実装されているのがdist.OrderedLogistic
分布になります。同じ式ですが、この分布はStanのマニュアルに記載のものと同じです。
データの準備
チュートリアルと同じデータです。Xは1変数で、Yが3クラスある順序付きの変数になります。
simkeys = random.split(random.PRNGKey(1), 2)
nsim = 50
nclasses = 3
Y = dist.Categorical(logits=np.zeros(nclasses)).sample(simkeys[0], sample_shape=(nsim,))
X =dist. Normal().sample(simkeys[1], sample_shape=(nsim,))
X += Y
print("value counts of Y:")
df = pd.DataFrame({"X": X, "Y": Y})
print(df.Y.value_counts())
for i in range(nclasses):
print(f"mean(X) for Y == {i}: {X[np.where(Y==i)].mean():.3f}")
value counts of Y:
1 19
2 16
0 15
Name: Y, dtype: int64
mean(X) for Y == 0: 0.042
mean(X) for Y == 1: 0.832
mean(X) for Y == 2: 1.448
モデルの定義とMCMC
上記の内容からdist.OrderedLogistic
には潜在変数からの影響を表したc_k
の両方が必要になります。ImproperUniform
とNormal
、Dirichlet
分布からサンプリングした値を使用する3つの方法を実装します。
ImproperUniforを使用する場合
カットポイントnclasses - 1
個になります。また、カットポイントはパラメータの識別可能性の点からも順序付きのベクトルである必要があるため、support=constraints.ordered_vector
として指定しています。
ImproperUniform分布は、制約もある分布ですが、そのパラメータに関する事前分布の位置やスケールなどの情報を追加することなく、ドメインに制約のあるパラメータを使用することができます。
def model1(X, Y=None):
n_classes = len(np.unique(Y))
b_X_eta = numpyro.sample("b_X_eta", dist.Normal(0, 10))
c_y = numpyro.sample(
"c_y",
dist.ImproperUniform(
support=constraints.ordered_vector,
batch_shape=(),
event_shape=(nclasses - 1,),
),
)
eta = numpyro.deterministic("eta", X * b_X_eta)
with numpyro.plate("obs", X.shape[0]):
numpyro.sample("Y", dist.OrderedLogistic(predictor=eta, cutpoints=c_y), obs=Y)
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model1)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df["X"].values,
Y=df["Y"].values,
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
b_X_eta 1.43 0.37 1.42 0.83 2.03 4135.09 1.00
c_y[0] -0.11 0.40 -0.11 -0.77 0.53 4096.81 1.00
c_y[1] 2.16 0.51 2.14 1.33 2.99 4468.53 1.00
Number of divergences: 0
Normalを使用する場合
TransformedDistribution
を使用することで、nclasses - 1
個の正規分布dist.Normal(0, 1)
からサンプリングしたベクトルをdist.transforms.OrderedTransform
により順序付きのベクトルに変換しています。
ベースの分布に正規分布dist.Normal(0, 1)
を使用しているため、カットポイントが0から離れすぎない値になってほしい時などに有効です。
def model2(X, Y=None):
n_classes = len(np.unique(Y))
b_X_eta = numpyro.sample("b_X_eta", dist.Normal(0, 10))
# 今回はカットポイントが0から離れすぎない値になってほしいのでNormal(0, 1)で順序付きに変換
c_y = numpyro.sample(
"c_y",
dist.TransformedDistribution(
dist.Normal(0, 1).expand([nclasses - 1]), dist.transforms.OrderedTransform()
),
)
eta = numpyro.deterministic("eta", X * b_X_eta)
with numpyro.plate("obs", X.shape[0]):
numpyro.sample("Y", dist.OrderedLogistic(predictor=eta, cutpoints=c_y), obs=Y)
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model2)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df["X"].values,
Y=df["Y"].values,
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
b_X_eta 1.37 0.35 1.36 0.80 1.94 4868.12 1.00
c_y[0] -0.05 0.36 -0.06 -0.62 0.55 4570.64 1.00
c_y[1] 2.04 0.47 2.02 1.22 2.77 4933.96 1.00
Number of divergences: 0
Dirichletを使用する場合
こちらは2023/4/22現在の時点でバグが見つかったので、既存の関数に変更を加えたものを使用します。
class SimplexToOrderedTransform_(dist.transforms.Transform):
"""
Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints)
Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.
:param anchor_point: Anchor point is a nuisance parameter to improve the identifiability of the transform.
For simplicity, we assume it is a scalar value, but it is broadcastable x.shape[:-1].
For more details please refer to Section 2.2 in [1]
**References:**
1. *Ordinal Regression Case Study, section 2.2*,
M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html
**Example**
.. doctest::
>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import SimplexToOrderedTransform
>>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
>>> transform = SimplexToOrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
"""
domain = constraints.simplex
codomain = constraints.ordered_vector
def __init__(self, anchor_point=0.0):
self.anchor_point = anchor_point
def __call__(self, x):
s = jnp.cumsum(x[..., :-1], axis=-1)
y = logit(s) + jnp.expand_dims(self.anchor_point, -1)
return y
def _inverse(self, y):
y = y - jnp.expand_dims(self.anchor_point, -1)
s = expit(y)
# x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1]
# add two boundary points 0 and 1
pad_width = [(0, 0)] * (jnp.ndim(s) - 1) + [(1, 1)]
s = jnp.pad(s, pad_width, constant_values=(0, 1))
x = s[..., 1:] - s[..., :-1]
return x
def log_abs_det_jacobian(self, x, y, intermediates=None):
# |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y))
# we know log derivative of expit(y) is `-softplus(y) - softplus(-y)`
J_logdet = (softplus(y) + softplus(-y)).sum(-1)
return J_logdet
def forward_shape(self, shape):
"""
Infers the shape of the forward computation, given the input shape.
Defaults to preserving shape.
"""
return shape[:-1] + (shape[-1] - 1,)
def inverse_shape(self, shape):
"""
Infers the shapes of the inverse computation, given the output shape.
Defaults to preserving shape.
"""
return shape[:-1] + (shape[-1] - 1,)
Dirichlet分布からSimplexなベクトルをサンプリングし、その後SimplexToOrderedTransform_
により順序付きのベクトルへ変換しています。Dirichlet分布を事前分布に使用することで各クラスの出現割合に事前知識を持っていた場合に反映することが可能になります。
例えば、以下のようにconcentration
を各クラス全てで同じようにすれば各クラスの出現確率は同じになります。
concentration = jnp.ones(3)*10
d = dist.TransformedDistribution(
dist.Dirichlet(concentration),
SimplexToOrderedTransform_(0.0),
)
cutpoints = d.sample(random.PRNGKey(0), sample_shape=(1000,))
from jax.scipy.special import logit, expit
p1 = expit(cutpoints[:, 0]) - expit(-np.inf)
p2 = expit(cutpoints[:, 1]) - expit(cutpoints[:, 0])
p3 = expit(np.inf) - expit(cutpoints[:, 1])
print(p1.mean(), p2.mean(), p3.mean())
0.33179164 0.33491898 0.33328944
また、レアなクラスがあった場合に以下のように事前知識を入れることができます。特にデータが少なく&レアなクラスがある場合などに有効そうです。
# 事前情報として出現割合を知識として入れ込むことができる
concentration = jnp.array([1, 4, 5])
d = dist.TransformedDistribution(
dist.Dirichlet(concentration),
SimplexToOrderedTransform_(0.0),
)
cutpoints = d.sample(random.PRNGKey(0), sample_shape=(1000,))
from jax.scipy.special import logit, expit
p1 = expit(cutpoints[:, 0]) - expit(-np.inf)
p2 = expit(cutpoints[:, 1]) - expit(cutpoints[:, 0])
p3 = expit(np.inf) - expit(cutpoints[:, 1])
print(p1.mean(), p2.mean(), p3.mean())
これらを使用したモデルが以下になります。anchor_point
はsimplexベクトルは総和が1になるという制約がある中でカットポイントを決める際に使用するものです。
def model3(X, Y, nclasses, concentration, anchor_point=0.0):
b_X_eta = numpyro.sample("b_X_eta", dist.Normal(0, 5))
with numpyro.handlers.reparam(config={"c_y": TransformReparam()}):
c_y = numpyro.sample(
"c_y",
dist.TransformedDistribution(
dist.Dirichlet(concentration),
SimplexToOrderedTransform_(anchor_point),
#dist.transforms.SimplexToOrderedTransform(anchor_point),
)
)
with numpyro.plate("obs", X.shape[0]):
eta = X * b_X_eta
numpyro.sample("Y", dist.OrderedLogistic(eta, c_y), obs=Y)
concentration = np.ones((nclasses,)) * 10.0
rng_key= random.PRNGKey(0)
kernel = NUTS(model3)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
X=df["X"].values,
Y=df["Y"].values,
nclasses=nclasses,
concentration=concentration,
)
# with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (vis. `c_y_base`)
mcmc.print_summary(exclude_deterministic=False)
mean std median 5.0% 95.0% n_eff r_hat
b_X_eta 1.01 0.29 1.00 0.51 1.47 1222.00 1.00
c_y[0] -0.43 0.28 -0.43 -0.85 0.07 1566.81 1.00
c_y[1] 1.35 0.31 1.36 0.84 1.83 1212.22 1.00
c_y_base[0] 0.40 0.07 0.39 0.29 0.50 1559.41 1.00
c_y_base[1] 0.39 0.06 0.39 0.30 0.49 1721.26 1.00
c_y_base[2] 0.21 0.05 0.20 0.13 0.29 1249.45 1.00
独自の分布の定義
こでは、Conway–Maxwell–Poisson distributionというNumPyroでは実装がされていない分布を作成してみます。この分布はポアソン分布を拡張したような分布です。ポアソン分布の場合平均と分散が同じ
実装としては、英語版Wikipediaとpymc3での実装、numpyroのdist.Poissonのコードを参考にしました。
sample
メソッドは一様分布からサンプリングした値とCDFの値から乱数生成しています。log_prob
メソッドは英語版Wikipediaの式をそのまま実装してます。pymc3での実装では無限級数のところを無視していたのですが、CDFの値がWikipediaの図と結構違うと個人的に感じたので修正しています。cdf
メソッドは今回は離散変数なので全部の確率の総和を取っているだけです。
このように必要なメソッドを定義すれば自分でも独自の分布を定義できます。 今回はsample
メソッドを2重ループで書いているので複雑なモデルでは遅くなりそうです。逆関数を作成するなどして行列演算で書き直した方がいいかもしれません。もし独自の分布の定義の仕方をより詳しく知りたい方はドキュメントの5.1 Recap of NumPyro distributionsに載っているので参考にしてみてください。
from numpyro.distributions import Distribution
from jax.scipy.special import gammaln, logsumexp
from numpyro.distributions.util import is_prng_key, validate_sample
class CMPoisson(Distribution):
# https://en.wikipedia.org/wiki/Conway–Maxwell–Poisson_distribution
# dist.Poissonを参考に実装
# この分布は、カウントデータでよく見られる過剰分散や過小分散を考慮することができる柔軟な分布である
# https://rss.onlinelibrary.wiley.com/doi/10.1111/j.1467-9876.2005.00474.x
# https://gist.github.com/dadaromeo/33e581d9e3bcbad83531b4a91a87509f
arg_constraints = {"lambda_": constraints.positive, "nu": constraints.positive}
support = constraints.nonnegative_integer
def __init__(self, lambda_, nu, *, validate_args=None):
self.lambda_ = lambda_
self.nu = nu
self.alpha = jnp.power(lambda_, 1/nu)
super(CMPoisson, self).__init__(jnp.shape(lambda_), validate_args=validate_args)
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
shape = sample_shape + self.batch_shape
u = random.uniform(key, shape).ravel()
values = jnp.empty(u.shape, dtype=int)
for i in range(len(u)):
value = 0
cdf = self.cdf(value)
while u[i] > cdf:
value += 1
cdf = self.cdf(value)
values = values.at[i].set(value)
return values.reshape(shape)
@validate_sample
def log_prob(self, value):
""" log(PMF)
gammaln : 正の整数 nに対して階乗関数にlogを取ったものになる
https://hazm.at/mox/math/special-function/gamma-function.html
"""
if self._validate_args:
self._validate_sample(value)
lambda_ = self.lambda_
nu = self.nu
alpha = self.alpha
# The normalizing constantはclosed formを持たないので、漸近展開した形を使用
# 参考にしたコードでは無限級数のところを無視しているので修正
log_Z = nu * alpha - (nu - 1)*0.5*jnp.log(2*jnp.pi*alpha) - 0.5*jnp.log(nu)
log_Z += jnp.log(1. + (nu**2 - 1.)/24.*jnp.power(nu*alpha, -1.) + (nu**2 - 1.)*(nu**2+23)/1152*jnp.power(nu*alpha, -2.))
return value * jnp.log(lambda_) - nu * gammaln(value+1) - log_Z
def cdf(self, value):
return jnp.sum(jnp.exp(self.log_prob(jnp.arange(value+1))))
ちゃんと実装できているかCDFの図をWikipediaと比べてみます。
X = jnp.arange(20)
plt.figure(figsize=(4,2))
plt.plot(X, [CMPoisson(1, 1.5).cdf(x) for x in X], c="blue")
plt.plot(X, [CMPoisson(3, 1.1).cdf(x) for x in X], c="green")
plt.plot(X, [CMPoisson(5, 0.7).cdf(x) for x in X], c="red")
Wikipediaの図
pymc3での実装を真似てモデルに使用してみます。これはちゃんと動くかどうかの確認だけです。
データの準備
n,d = 1000, 4
X = np.abs(np.random.randn(n,d))
y = np.round(X.sum(axis=1)).astype(int)
pd.Series(y).describe()
count 1000.000000
mean 3.236000
std 1.216464
min 1.000000
25% 2.000000
50% 3.000000
75% 4.000000
max 8.000000
dtype: float64
plt.hist(y, bins=50)
plt.title("Observed data")
モデルの定義とMCMC
def model(X, y=None):
alpha = numpyro.sample("alpha", dist.Normal(1, 1))
beta = numpyro.sample("beta", dist.Normal(1, 1).expand([X.shape[1]]))
nu = numpyro.sample("nu", dist.HalfNormal(10))
lam = numpyro.deterministic("lam", alpha + jnp.dot(X, beta))
with numpyro.plate("N", len(X)):
numpyro.sample("obs", CMPoisson(lambda_=lam, nu=nu), obs=y)
nuts = numpyro.infer.NUTS(model, target_accept_prob=0.99)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=2000, num_samples=1000, num_chains=4)
mcmc.run(random.PRNGKey(0), X, y)
mcmc_samples = mcmc.get_samples()
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
alpha -2.17 0.29 -2.18 -2.62 -1.67 1480.93 1.00
beta[0] 4.83 0.49 4.81 4.00 5.59 2194.15 1.00
beta[1] 5.43 0.55 5.42 4.48 6.26 2130.09 1.00
beta[2] 5.16 0.52 5.14 4.36 6.07 2286.47 1.00
beta[3] 4.63 0.52 4.62 3.75 5.41 2035.34 1.00
nu 2.11 0.05 2.11 2.03 2.18 1552.29 1.00
Number of divergences: 41
最後に
以上で「順序回帰と独自の分布の定義」は終わりです。Dirichlet分布を事前分布に使用する順序回帰は使い所がありそうですね。次回は「スパースモデル」です。
Discussion