🗂
NumPyro:ベイジアンABテスト
連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。
はじめに
今回はベイジアンABテストを扱います。各ケースにおいて適切に事前分布等を設計するなど手間はかかりますが、最終的にBがAより優れている/効果があった確率は◯%というような結果が出てくるので、次のアクションを決める意思決定に向いています。
ベイジアンABテストはいろいろな企業で導入されており以下のように紹介記事も多いので、今回はNumPyroでの実装を簡単にだけ紹介して終わりにします。
ライブラリのインポート
import os
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
import arviz as az
az.style.use("arviz-darkgrid")
assert numpyro.__version__.startswith("0.11.0")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
ベイジアンABテスト
デモデータの準備
今回は適当に正規分布からサンプリングされた2群のデータをA、Bとして扱います。AとBのサンプルはお互いに重なっており、見た目だけではどれだけの違いがあるのかは一見わかりづらいデータになっています。
A = dist.Normal(50, 5).sample(random.PRNGKey(2), (10,))
B = dist.Normal(55, 5).sample(random.PRNGKey(1), (10,))
sns.distplot(A)
sns.distplot(B)
モデルの定義
def model_normal(x):
mu = numpyro.sample("mu", dist.Normal(52, 5))
scale = numpyro.sample("scale", dist.LogNormal(1))
with numpyro.plate("N", len(x)):
numpyro.sample("obs", dist.Normal(mu, scale), obs=x)
MCMC
うまく推論できていそうです。
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model_normal)
mcmcA = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmcA.run(
rng_key=rng_key,
x=A
)
samplesA = mcmcA.get_samples()
mcmcA.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mu 48.94 1.47 48.91 46.50 51.27 4122.14 1.00
scale 4.77 1.18 4.56 3.02 6.46 3948.31 1.00
Number of divergences: 0
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model_normal)
mcmcB = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmcB.run(
rng_key=rng_key,
x=B
)
samplesB = mcmcB.get_samples()
mcmcB.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mu 55.20 1.35 55.21 52.91 57.31 4204.37 1.00
scale 4.35 1.07 4.17 2.72 5.87 3833.45 1.00
Number of divergences: 0
分布間の違いを定量化する
既に各正規分布の平均と標準偏差を推定できたので、この分布からmodel
を新しく定義せずに計算していけば良いのですが、ここではNumPyroのPredictive
の変わった使い方の紹介として、わざわざPredictive
を使用してAとBの差分をサンプリングしていきたいと思います。
def model_diff():
mu_A = numpyro.sample("mu_A", dist.Normal(52, 5))
scale_A = numpyro.sample("scale_A", dist.LogNormal(1))
mu_B = numpyro.sample("mu_B", dist.Normal(52, 5))
scale_B = numpyro.sample("scale_B", dist.LogNormal(1))
normal_A = numpyro.sample("normal_A", dist.Normal(mu_A, scale_A))
normal_B = numpyro.sample("normal_B", dist.Normal(mu_B, scale_B))
diff = numpyro.deterministic("diff", (normal_B - normal_A)/normal_B)
上記のモデルのサイトネームに合うように辞書を作成します。
samples = {}
samples["mu_A"] = samplesA["mu"]
samples["scale_A"] = samplesA["scale"]
samples["mu_B"] = samplesB["mu"]
samples["scale_B"] = samplesB["scale"]
事後分布のサンプルが与えられなかったサイトのサンプリング結果が得られます。
rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(model_diff, samples)
predictions = predictive(rng_key_)
predictions.keys()
dict_keys(['diff', 'normal_A', 'normal_B'])
AとBの事後分布を可視化してみます。
sns.distplot(predictions["normal_A"], label="A")
sns.distplot(predictions["normal_B"], label="B")
plt.legend()
plt.show()
AとBの差の分布を可視化します。83%の確率でBがAよりよいことがわかりました。
diff = predictions["diff"]
prob_BgtA = len(diff[diff>0])/len(diff)
g = sns.displot(diff[diff>0], color="red", label=f"B>A: prob={prob_BgtA:.2f}")
g.map(sns.histplot, data=diff[diff<=0], color="blue", label=f"B<A: prob={1-prob_BgtA:.2f}")
plt.vlines(0, 0, 500, colors="gray")
plt.ylim([0, 500])
plt.legend()
最後に
以上で、「ベイジアンABテスト」は終わりです。頻度主義のABテストとは異なり、解釈もしやすい方法ですし、確率で出力が得られるのは納得感もあるのではないでしょうか。次は「事前分布や各分布に関してまとめ」です。
Discussion