👻

NumPyro:順序回帰と独自の分布の定義

2023/04/23に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

今回は順序回帰と独自の分布の定義の仕方を見ていきます。順序回帰に関してはNumPyroのチュートリアル通りになります。より詳細な式などに関してはこちらを参照ください。
独自の分布の定義は他に入れる記事がなかったので、適当に今回の記事の中に入っているだけで順序回帰との関係はありません。

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
from jax.nn import softplus
import jax.numpy as jnp
from jax.scipy.special import expit, logit
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from jax.experimental.ode import odeint
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
from numpyro.infer.reparam import TransformReparam

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

順序回帰

順序付きのデータをモデリングする際は以下のように考えます。
左図は何かしらの連続値の潜在変数(X)の確率密度分布を表しており、右図はその累積分布関数を表しています。c_kがカットポイントと呼ばれるもので、X軸を区切る点になっており、例えばc_0からc_1はカテゴリ1、c_1からc_2はカテゴリ2というように順序を区切る点になっています。また、最初と最後のカットポイントはそれぞれ-∞+∞になります。

このとき、右図に示されているように、各カテゴリの確率は以下のように累積分布関数の差分で表されます。

p_k = \Pi(c_k) - \Pi(c_{k-1})

ここで、累積分布関数は以下の形で与えられます。

\Pi(x) = \frac{1}{1 + \exp(-x)} \equiv \sigma(x)

ここで、上記の関数に1 - \sigma(x) = \sigma(-x)という関係式が成り立つので、以下の式が成り立ちます。

p_k = \sigma(c_k) - \sigma(c_{k-1}) \\ = 1 - \sigma(-c_k) - (1 - \sigma(-c_{k-1})) \\ = \sigma(-c_{k-1}) - \sigma(-c_k)

この式に今考えている潜在変数Xからの影響を表した\gamma(ここでは\gamma = \beta x)を考慮したものが以下の式になります。

p_k = \Pi(c_k - \gamma) - \Pi(c_{k-1} - \gamma) \\ = \sigma(\gamma - c_{k-1}) - \sigma(\gamma - c_k)

これが実装されているのがdist.OrderedLogistic分布になります。同じ式ですが、この分布はStanのマニュアルに記載のものと同じです。

データの準備

チュートリアルと同じデータです。Xは1変数で、Yが3クラスある順序付きの変数になります。

simkeys = random.split(random.PRNGKey(1), 2)
nsim = 50
nclasses = 3
Y = dist.Categorical(logits=np.zeros(nclasses)).sample(simkeys[0], sample_shape=(nsim,))
X =dist. Normal().sample(simkeys[1], sample_shape=(nsim,))
X += Y

print("value counts of Y:")
df = pd.DataFrame({"X": X, "Y": Y})
print(df.Y.value_counts())

for i in range(nclasses):
    print(f"mean(X) for Y == {i}: {X[np.where(Y==i)].mean():.3f}")
value counts of Y:
1    19
2    16
0    15
Name: Y, dtype: int64
mean(X) for Y == 0: 0.042
mean(X) for Y == 1: 0.832
mean(X) for Y == 2: 1.448

モデルの定義とMCMC

上記の内容からdist.OrderedLogisticには潜在変数からの影響を表した\gammaとカットポイントc_kの両方が必要になります。gammaは今までの線形回帰と同様に\gamma = \beta xを使用すれば良いですが、カットポイントも何かしらの値をサンプリングする必要があります。ここでは、ImproperUniformNormalDirichlet分布からサンプリングした値を使用する3つの方法を実装します。

ImproperUniforを使用する場合

カットポイントc_0c_Kはそれぞれ-∞なので、サンプリングするカットポイントの数としては間のnclasses - 1個になります。また、カットポイントはパラメータの識別可能性の点からも順序付きのベクトルである必要があるため、support=constraints.ordered_vectorとして指定しています。
ImproperUniform分布は、制約もある分布ですが、そのパラメータに関する事前分布の位置やスケールなどの情報を追加することなく、ドメインに制約のあるパラメータを使用することができます。

def model1(X, Y=None):
    
    n_classes = len(np.unique(Y))
    
    b_X_eta = numpyro.sample("b_X_eta", dist.Normal(0, 10))
    c_y = numpyro.sample(
        "c_y",
        dist.ImproperUniform(
            support=constraints.ordered_vector,
            batch_shape=(),
            event_shape=(nclasses - 1,),
        ),
    )
    
    eta = numpyro.deterministic("eta", X * b_X_eta)
    
    with numpyro.plate("obs", X.shape[0]):
        numpyro.sample("Y", dist.OrderedLogistic(predictor=eta, cutpoints=c_y), obs=Y)

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model1)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df["X"].values,
    Y=df["Y"].values,
)
mcmc.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
   b_X_eta      1.43      0.37      1.42      0.83      2.03   4135.09      1.00
    c_y[0]     -0.11      0.40     -0.11     -0.77      0.53   4096.81      1.00
    c_y[1]      2.16      0.51      2.14      1.33      2.99   4468.53      1.00

Number of divergences: 0

Normalを使用する場合

TransformedDistributionを使用することで、nclasses - 1個の正規分布dist.Normal(0, 1)からサンプリングしたベクトルをdist.transforms.OrderedTransformにより順序付きのベクトルに変換しています。
ベースの分布に正規分布dist.Normal(0, 1)を使用しているため、カットポイントが0から離れすぎない値になってほしい時などに有効です。

def model2(X, Y=None):
    
    n_classes = len(np.unique(Y))
    
    b_X_eta = numpyro.sample("b_X_eta", dist.Normal(0, 10))
    # 今回はカットポイントが0から離れすぎない値になってほしいのでNormal(0, 1)で順序付きに変換
    c_y = numpyro.sample(
        "c_y",
        dist.TransformedDistribution(
            dist.Normal(0, 1).expand([nclasses - 1]), dist.transforms.OrderedTransform()
        ),
    )
    
    eta = numpyro.deterministic("eta", X * b_X_eta)
    
    with numpyro.plate("obs", X.shape[0]):
        numpyro.sample("Y", dist.OrderedLogistic(predictor=eta, cutpoints=c_y), obs=Y)

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model2)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df["X"].values,
    Y=df["Y"].values,
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
   b_X_eta      1.37      0.35      1.36      0.80      1.94   4868.12      1.00
    c_y[0]     -0.05      0.36     -0.06     -0.62      0.55   4570.64      1.00
    c_y[1]      2.04      0.47      2.02      1.22      2.77   4933.96      1.00

Number of divergences: 0

Dirichletを使用する場合

こちらは2023/4/22現在の時点でバグが見つかったので、既存の関数に変更を加えたものを使用します。

class SimplexToOrderedTransform_(dist.transforms.Transform):
    """
    Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints)
    Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.

    :param anchor_point: Anchor point is a nuisance parameter to improve the identifiability of the transform.
        For simplicity, we assume it is a scalar value, but it is broadcastable x.shape[:-1].
        For more details please refer to Section 2.2 in [1]

    **References:**

    1. *Ordinal Regression Case Study, section 2.2*,
       M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html

    **Example**

    .. doctest::

       >>> import jax.numpy as jnp
       >>> from numpyro.distributions.transforms import SimplexToOrderedTransform
       >>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
       >>> transform = SimplexToOrderedTransform()
       >>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)

    """

    domain = constraints.simplex
    codomain = constraints.ordered_vector

    def __init__(self, anchor_point=0.0):
        self.anchor_point = anchor_point

    def __call__(self, x):
        s = jnp.cumsum(x[..., :-1], axis=-1)
        y = logit(s) + jnp.expand_dims(self.anchor_point, -1)
        return y

    def _inverse(self, y):
        y = y - jnp.expand_dims(self.anchor_point, -1)
        s = expit(y)
        # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1]
        # add two boundary points 0 and 1
        pad_width = [(0, 0)] * (jnp.ndim(s) - 1) + [(1, 1)]
        s = jnp.pad(s, pad_width, constant_values=(0, 1))
        x = s[..., 1:] - s[..., :-1]
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # |dp/dc| = |dx/dy| = prod(ds/dy) = prod(expit'(y))
        # we know log derivative of expit(y) is `-softplus(y) - softplus(-y)`
        J_logdet = (softplus(y) + softplus(-y)).sum(-1)
        return J_logdet

    def forward_shape(self, shape):
        """
        Infers the shape of the forward computation, given the input shape.
        Defaults to preserving shape.
        """
        return shape[:-1] + (shape[-1] - 1,)


    def inverse_shape(self, shape):
        """
        Infers the shapes of the inverse computation, given the output shape.
        Defaults to preserving shape.
        """
        return shape[:-1] + (shape[-1] - 1,)

Dirichlet分布からSimplexなベクトルをサンプリングし、その後SimplexToOrderedTransform_により順序付きのベクトルへ変換しています。Dirichlet分布を事前分布に使用することで各クラスの出現割合に事前知識を持っていた場合に反映することが可能になります。

例えば、以下のようにconcentrationを各クラス全てで同じようにすれば各クラスの出現確率は同じになります。

concentration = jnp.ones(3)*10
d = dist.TransformedDistribution(
    dist.Dirichlet(concentration),
    SimplexToOrderedTransform_(0.0),
)
cutpoints = d.sample(random.PRNGKey(0), sample_shape=(1000,))

from jax.scipy.special import logit, expit
p1 = expit(cutpoints[:, 0]) - expit(-np.inf)
p2 = expit(cutpoints[:, 1]) - expit(cutpoints[:, 0])
p3 = expit(np.inf) - expit(cutpoints[:, 1])

print(p1.mean(), p2.mean(), p3.mean())
0.33179164 0.33491898 0.33328944

また、レアなクラスがあった場合に以下のように事前知識を入れることができます。特にデータが少なく&レアなクラスがある場合などに有効そうです。

# 事前情報として出現割合を知識として入れ込むことができる
concentration = jnp.array([1, 4, 5])
d = dist.TransformedDistribution(
    dist.Dirichlet(concentration),
    SimplexToOrderedTransform_(0.0),
)
cutpoints = d.sample(random.PRNGKey(0), sample_shape=(1000,))

from jax.scipy.special import logit, expit
p1 = expit(cutpoints[:, 0]) - expit(-np.inf)
p2 = expit(cutpoints[:, 1]) - expit(cutpoints[:, 0])
p3 = expit(np.inf) - expit(cutpoints[:, 1])

print(p1.mean(), p2.mean(), p3.mean())

これらを使用したモデルが以下になります。anchor_pointはsimplexベクトルは総和が1になるという制約がある中でカットポイントを決める際に使用するものです。

def model3(X, Y, nclasses, concentration, anchor_point=0.0):
    b_X_eta = numpyro.sample("b_X_eta", dist.Normal(0, 5))

    with numpyro.handlers.reparam(config={"c_y": TransformReparam()}):
        c_y = numpyro.sample(
            "c_y",
            dist.TransformedDistribution(
                dist.Dirichlet(concentration),
		SimplexToOrderedTransform_(anchor_point),
                #dist.transforms.SimplexToOrderedTransform(anchor_point),
            )
        )

    with numpyro.plate("obs", X.shape[0]):
        eta = X * b_X_eta
        numpyro.sample("Y", dist.OrderedLogistic(eta, c_y), obs=Y)

concentration = np.ones((nclasses,)) * 10.0
rng_key= random.PRNGKey(0)
kernel = NUTS(model3)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df["X"].values,
    Y=df["Y"].values,
    nclasses=nclasses,
    concentration=concentration,
)
# with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (vis. `c_y_base`)
mcmc.print_summary(exclude_deterministic=False)
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
    b_X_eta      1.01      0.29      1.00      0.51      1.47   1222.00      1.00
     c_y[0]     -0.43      0.28     -0.43     -0.85      0.07   1566.81      1.00
     c_y[1]      1.35      0.31      1.36      0.84      1.83   1212.22      1.00
c_y_base[0]      0.40      0.07      0.39      0.29      0.50   1559.41      1.00
c_y_base[1]      0.39      0.06      0.39      0.30      0.49   1721.26      1.00
c_y_base[2]      0.21      0.05      0.20      0.13      0.29   1249.45      1.00

独自の分布の定義

こでは、Conway–Maxwell–Poisson distributionというNumPyroでは実装がされていない分布を作成してみます。この分布はポアソン分布を拡張したような分布です。ポアソン分布の場合平均と分散が同じ\lambdaでただ一つのパラメータで表されますが、それゆえカウントデータでよく見られる過剰分散や過小分散を考慮できません。Conway–Maxwell–Poisson distributionはは、\lambda\nuという2つのパタメータを持ち、過剰分散や過小分散を考慮することができるより柔軟な分布です。

実装としては、英語版Wikipediapymc3での実装numpyroのdist.Poissonのコードを参考にしました。

sampleメソッドは一様分布からサンプリングした値とCDFの値から乱数生成しています。log_probメソッドは英語版Wikipediaの式をそのまま実装してます。pymc3での実装では無限級数のところを無視していたのですが、CDFの値がWikipediaの図と結構違うと個人的に感じたので修正しています。cdfメソッドは今回は離散変数なので全部の確率の総和を取っているだけです。
このように必要なメソッドを定義すれば自分でも独自の分布を定義できます。 今回はsampleメソッドを2重ループで書いているので複雑なモデルでは遅くなりそうです。逆関数を作成するなどして行列演算で書き直した方がいいかもしれません。もし独自の分布の定義の仕方をより詳しく知りたい方はドキュメントの5.1 Recap of NumPyro distributionsに載っているので参考にしてみてください。

from numpyro.distributions import Distribution
from jax.scipy.special import gammaln, logsumexp
from numpyro.distributions.util import is_prng_key, validate_sample


class CMPoisson(Distribution):
    # https://en.wikipedia.org/wiki/Conway–Maxwell–Poisson_distribution
    # dist.Poissonを参考に実装
    # この分布は、カウントデータでよく見られる過剰分散や過小分散を考慮することができる柔軟な分布である
    # https://rss.onlinelibrary.wiley.com/doi/10.1111/j.1467-9876.2005.00474.x
    # https://gist.github.com/dadaromeo/33e581d9e3bcbad83531b4a91a87509f
    
    arg_constraints = {"lambda_": constraints.positive, "nu": constraints.positive}
    support = constraints.nonnegative_integer
    
    def __init__(self, lambda_, nu, *, validate_args=None):
        self.lambda_ = lambda_
        self.nu = nu
        self.alpha = jnp.power(lambda_, 1/nu)
        super(CMPoisson, self).__init__(jnp.shape(lambda_), validate_args=validate_args)
    
    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        
        shape = sample_shape + self.batch_shape
        u = random.uniform(key, shape).ravel()
        
        values = jnp.empty(u.shape, dtype=int)
        for i in range(len(u)):
            value = 0
            cdf = self.cdf(value)
            while u[i] > cdf:
                value += 1
                cdf = self.cdf(value)
            values = values.at[i].set(value)
        
        return values.reshape(shape)

    @validate_sample
    def log_prob(self, value):
        """ log(PMF) 
        gammaln : 正の整数 nに対して階乗関数にlogを取ったものになる 
        https://hazm.at/mox/math/special-function/gamma-function.html
        """
        if self._validate_args:
            self._validate_sample(value)

        lambda_ = self.lambda_
        nu = self.nu
        alpha = self.alpha

        # The normalizing constantはclosed formを持たないので、漸近展開した形を使用
        # 参考にしたコードでは無限級数のところを無視しているので修正
        log_Z = nu * alpha - (nu - 1)*0.5*jnp.log(2*jnp.pi*alpha) - 0.5*jnp.log(nu)
        log_Z += jnp.log(1. + (nu**2 - 1.)/24.*jnp.power(nu*alpha, -1.) + (nu**2 - 1.)*(nu**2+23)/1152*jnp.power(nu*alpha, -2.))

        return value * jnp.log(lambda_) - nu * gammaln(value+1) - log_Z

    def cdf(self, value):
        return jnp.sum(jnp.exp(self.log_prob(jnp.arange(value+1))))

ちゃんと実装できているかCDFの図をWikipediaと比べてみます。

X = jnp.arange(20)
plt.figure(figsize=(4,2))
plt.plot(X, [CMPoisson(1, 1.5).cdf(x) for x in X], c="blue")
plt.plot(X, [CMPoisson(3, 1.1).cdf(x) for x in X], c="green")
plt.plot(X, [CMPoisson(5, 0.7).cdf(x) for x in X], c="red")

Wikipediaの図

pymc3での実装を真似てモデルに使用してみます。これはちゃんと動くかどうかの確認だけです。

データの準備

n,d = 1000, 4
X = np.abs(np.random.randn(n,d))
y = np.round(X.sum(axis=1)).astype(int)

pd.Series(y).describe()
count    1000.000000
mean        3.236000
std         1.216464
min         1.000000
25%         2.000000
50%         3.000000
75%         4.000000
max         8.000000
dtype: float64
plt.hist(y, bins=50)
plt.title("Observed data")

モデルの定義とMCMC

def model(X, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(1, 1))
    beta = numpyro.sample("beta", dist.Normal(1, 1).expand([X.shape[1]]))
    nu = numpyro.sample("nu", dist.HalfNormal(10))
    
    lam = numpyro.deterministic("lam", alpha + jnp.dot(X, beta))
    
    with numpyro.plate("N", len(X)):
        numpyro.sample("obs", CMPoisson(lambda_=lam, nu=nu), obs=y)

nuts = numpyro.infer.NUTS(model, target_accept_prob=0.99)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=2000, num_samples=1000, num_chains=4)

mcmc.run(random.PRNGKey(0), X, y)
mcmc_samples = mcmc.get_samples()
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     alpha     -2.17      0.29     -2.18     -2.62     -1.67   1480.93      1.00
   beta[0]      4.83      0.49      4.81      4.00      5.59   2194.15      1.00
   beta[1]      5.43      0.55      5.42      4.48      6.26   2130.09      1.00
   beta[2]      5.16      0.52      5.14      4.36      6.07   2286.47      1.00
   beta[3]      4.63      0.52      4.62      3.75      5.41   2035.34      1.00
        nu      2.11      0.05      2.11      2.03      2.18   1552.29      1.00

Number of divergences: 41

最後に

以上で「順序回帰と独自の分布の定義」は終わりです。Dirichlet分布を事前分布に使用する順序回帰は使い所がありそうですね。次回は「スパースモデル」です。

Discussion