✂️
切断パレート分布のNumPyro実装
切断パレート分布を実装した。コードはgithubを参照。
numpyroの切断分布はパレート分布をサポートしていない
ふと、切断パレート分布からサンプリングしたくなった。
NumPyroにはTruncatedDistribution
が用意されていて、確率分布をラップして簡単に切断分布を定義できる。
パレート分布も同じ手順で実装できると思っていたらエラーを吐いた。
def model():
dist.TruncatedDistribution(dist.Pareto(scale, alpha), low=low, high=high, validate_args=validate_args)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=500, num_chains=4)
# エラーになる
mcmc.run(jax.random.PRNGKey(0), **inputs)
ソースコードを確認したところ、どうやらパレート分布はサポートされていないようだった。
困ったのでカスタム分布として切断パレート分布を実装した。
切断パレート分布の実装
NumPyroでカスタム分布を実装するには確率密度関数とサンプリングを定義する必要がある。
パレート分布
定義を整理する。
確率密度関数
累積密度関数
累積密度関数の逆関数
確率密度関数
切断分布の確率密度関数はオリジナルの分布の確率密度関数を切断区間で調整する。
となる。分母は
実装は、
def log_prob(self, value):
log_m = self.logcdf(self.high, self.scale, self.alpha) - self.logcdf(self.low, self.scale, self.alpha)
log_p = self.logpdf(value, self.scale, self.alpha)
return jnp.where((self.low < value) * (value < self.high), log_p - log_m, -jnp.inf)
サンプリング
逆関数法を用いる。切断分布の場合は、区間
def sample(self, key, sample_shape=()):
shape = sample_shape + self.batch_shape
minval = lax.exp(self.logcdf(self.low, self.scale, self.alpha))
maxval = lax.exp(self.logcdf(self.high, self.scale, self.alpha))
u = jax.random.uniform(key, shape, minval=minval, maxval=maxval)
return lax.exp(self.logicdf(u, self.scale, self.alpha))
動作確認
サンプリング結果をscipyの乱数と比較する。
scale = 10.0
alpha = 2.0
low = 15.0
high = 100.0
# numpyro
samples_numpyro = TruncatedPareto(scale, alpha, low, high).sample(jax.random.PRNGKey(0), (10000,))
# scipy
samples_scipy = scipy.stats.pareto.rvs(b=2, scale=10, size=10000)
samples_scipy = samples_scipy[(samples_scipy > low) & (samples_scipy < high)]
ヒストグラムの形状はだいたい同じ。
分布推定も試す。
def model(num_samples, x=None, min_x=None, max_x=None):
scale = numpyro.sample("scale", dist.LogNormal())
alpha = numpyro.sample("alpha", dist.LogNormal())
low = numpyro.sample("low", dist.TruncatedDistribution(dist.Cauchy(), low=0, high=min_x))
high = numpyro.sample("high", dist.TruncatedDistribution(dist.Cauchy(), low=max_x))
with numpyro.plate("observations", num_samples):
numpyro.sample("x", TruncatedPareto(scale, alpha, low, high), obs=x)
inputs = dict(num_samples=len(samples_scipy), x=samples_scipy, min_x=np.min(samples_scipy), max_x=np.max(samples_scipy))
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=500, num_chains=1)
mcmc.run(jax.random.PRNGKey(0), **inputs)
mcmc.print_summary()
sample: 100%|██████████| 1500/1500 [00:01<00:00, 772.30it/s, 7 steps of size 4.13e-01. acc. prob=0.93]
mean std median 5.0% 95.0% n_eff r_hat
alpha 2.12 0.03 2.12 2.06 2.17 556.78 1.00
high 100.77 1.23 100.44 99.57 102.37 847.73 1.00
low 15.00 0.00 15.00 15.00 15.00 715.23 1.00
scale 10.83 0.08 10.84 10.72 10.96 503.15 1.00
Number of divergences: 0
もとのパラメータに近い値が推定できている。
乱数が確率分布に従うことを言うのはけっこう大変な気がする。真面目に検証するならベイズファクターとか分布間距離を持ち出すのだろうか?
Discussion