✂️

切断パレート分布のNumPyro実装

2024/12/07に公開

切断パレート分布を実装した。コードはgithubを参照。

numpyroの切断分布はパレート分布をサポートしていない

ふと、切断パレート分布からサンプリングしたくなった。
NumPyroにはTruncatedDistributionが用意されていて、確率分布をラップして簡単に切断分布を定義できる。

パレート分布も同じ手順で実装できると思っていたらエラーを吐いた。

def model():
    dist.TruncatedDistribution(dist.Pareto(scale, alpha), low=low, high=high, validate_args=validate_args)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=500, num_chains=4)

# エラーになる
mcmc.run(jax.random.PRNGKey(0), **inputs)

ソースコードを確認したところ、どうやらパレート分布はサポートされていないようだった。

https://github.com/pyro-ppl/numpyro/blob/0.15.3/numpyro/distributions/truncated.py#L32

困ったのでカスタム分布として切断パレート分布を実装した。

切断パレート分布の実装

NumPyroでカスタム分布を実装するには確率密度関数とサンプリングを定義する必要がある。

パレート分布

定義を整理する。

確率密度関数

f(x; x_m, \alpha) = \frac{\alpha x_m^{\alpha}}{x^{\alpha+1}}

累積密度関数

F(x; x_m, \alpha) = 1 - \left( \frac{x_m}{x} \right) ^{\alpha}

累積密度関数の逆関数

F^{-1}(x; x_m, \alpha) = \frac{x_m}{(1-x)^{\frac{1}{\alpha}}}

確率密度関数

切断分布の確率密度関数はオリジナルの分布の確率密度関数を切断区間で調整する。f(x;\theta), F(x;\theta)をそれぞれ確率密度関数と累積密度関数とする。区間(L,U)の切断分布の確率密度関数f_{[L,U]}(x;\theta)は、

f_{[L,U]}(x;\theta) = \begin{cases} \frac{f(x;\theta)}{F(U;\theta)-F(L;\theta)} &\text{if } x \in [L, U] \\ 0 &\text{otherwise } \end{cases}

となる。分母は\int_L^U f(x;\theta) dxで、片側の場合は積分区間が[-\infty,U]とか[L,\infty]になる。
実装は、

    def log_prob(self, value):
        log_m = self.logcdf(self.high, self.scale, self.alpha) - self.logcdf(self.low, self.scale, self.alpha)
        log_p = self.logpdf(value, self.scale, self.alpha)
        return jnp.where((self.low < value) * (value < self.high), log_p - log_m, -jnp.inf)

サンプリング

逆関数法を用いる。切断分布の場合は、区間[F(L;\theta), F(U;\theta)]から取得した一様乱数を累積密度関数の逆関数で変換する。これを実装すると、

    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        minval = lax.exp(self.logcdf(self.low, self.scale, self.alpha))
        maxval = lax.exp(self.logcdf(self.high, self.scale, self.alpha))
        u = jax.random.uniform(key, shape, minval=minval, maxval=maxval)
        return lax.exp(self.logicdf(u, self.scale, self.alpha))

動作確認

サンプリング結果をscipyの乱数と比較する。

scale = 10.0
alpha = 2.0
low = 15.0
high = 100.0

# numpyro
samples_numpyro = TruncatedPareto(scale, alpha, low, high).sample(jax.random.PRNGKey(0), (10000,))

# scipy
samples_scipy = scipy.stats.pareto.rvs(b=2, scale=10, size=10000)
samples_scipy = samples_scipy[(samples_scipy > low) & (samples_scipy < high)]

ヒストグラムの形状はだいたい同じ。

分布推定も試す。

def model(num_samples, x=None, min_x=None, max_x=None):
    scale = numpyro.sample("scale", dist.LogNormal())
    alpha = numpyro.sample("alpha", dist.LogNormal())
    low = numpyro.sample("low", dist.TruncatedDistribution(dist.Cauchy(), low=0, high=min_x))
    high = numpyro.sample("high", dist.TruncatedDistribution(dist.Cauchy(), low=max_x))
    with numpyro.plate("observations", num_samples):
        numpyro.sample("x", TruncatedPareto(scale, alpha, low, high), obs=x)

inputs = dict(num_samples=len(samples_scipy), x=samples_scipy, min_x=np.min(samples_scipy), max_x=np.max(samples_scipy))

kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=500, num_chains=1)
mcmc.run(jax.random.PRNGKey(0), **inputs)
mcmc.print_summary()
sample: 100%|██████████| 1500/1500 [00:01<00:00, 772.30it/s, 7 steps of size 4.13e-01. acc. prob=0.93]  

                mean       std    median      5.0%     95.0%     n_eff     r_hat
     alpha      2.12      0.03      2.12      2.06      2.17    556.78      1.00
      high    100.77      1.23    100.44     99.57    102.37    847.73      1.00
       low     15.00      0.00     15.00     15.00     15.00    715.23      1.00
     scale     10.83      0.08     10.84     10.72     10.96    503.15      1.00

Number of divergences: 0

もとのパラメータに近い値が推定できている。

乱数が確率分布に従うことを言うのはけっこう大変な気がする。真面目に検証するならベイズファクターとか分布間距離を持ち出すのだろうか?

Discussion