🗂

2023/04/28に公開

# はじめに

ベイジアンABテストはさまざまな企業で導入されており、解説記事も数多く公開されています。そのため今回は、NumPyroでの実装を簡単に紹介するだけにとどめます。

# ライブラリのインポート

``````import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
``````

# ベイジアンABテスト

## デモデータの準備

``````A = dist.Normal(50, 5).sample(random.PRNGKey(2), (10,))
B = dist.Normal(55, 5).sample(random.PRNGKey(1), (10,))

sns.distplot(A)
sns.distplot(B)
``````

## モデルの定義

``````def model_normal(x):
mu = numpyro.sample("mu", dist.Normal(52, 5))
scale = numpyro.sample("scale", dist.LogNormal(1))
with numpyro.plate("N", len(x)):
numpyro.sample("obs", dist.Normal(mu, scale), obs=x)
``````

## MCMC

以下の推論結果では、すべてのパラメータでr_hatが1.00、divergenceも0なので、うまく推論できていそうです。

``````# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_normal)
mcmcA = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmcA.run(
rng_key=rng_key,
x=A
)
samplesA = mcmcA.get_samples()
mcmcA.print_summary()
``````
```
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu     48.94      1.47     48.91     46.50     51.27   4122.14      1.00
     scale      4.77      1.18      4.56      3.02      6.46   3948.31      1.00

Number of divergences: 0
``````
``````# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_normal)
mcmcB = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmcB.run(
rng_key=rng_key,
x=B
)
samplesB = mcmcB.get_samples()
mcmcB.print_summary()
``````
```
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu     55.20      1.35     55.21     52.91     57.31   4204.37      1.00
     scale      4.35      1.07      4.17      2.72      5.87   3833.45      1.00

Number of divergences: 0
``````

## 分布間の違いを定量化する

``````def model_diff():
mu_A = numpyro.sample("mu_A", dist.Normal(52, 5))
scale_A = numpyro.sample("scale_A", dist.LogNormal(1))
mu_B = numpyro.sample("mu_B", dist.Normal(52, 5))
scale_B = numpyro.sample("scale_B", dist.LogNormal(1))

normal_A = numpyro.sample("normal_A", dist.Normal(mu_A, scale_A))
normal_B = numpyro.sample("normal_B", dist.Normal(mu_B, scale_B))

diff = numpyro.deterministic("diff", (normal_B - normal_A)/normal_B)
``````

``````samples = {}
samples["mu_A"] = samplesA["mu"]
samples["scale_A"] = samplesA["scale"]
samples["mu_B"] = samplesB["mu"]
samples["scale_B"] = samplesB["scale"]
``````

``````rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(model_diff, samples)
predictions = predictive(rng_key_)
predictions.keys()
``````
```
dict_keys(['diff', 'normal_A', 'normal_B'])
``````

AとBの事後予測分布（新たに観測される値の分布）を可視化してみます。

``````sns.distplot(predictions["normal_A"], label="A")
sns.distplot(predictions["normal_B"], label="B")
plt.legend()
plt.show()
``````

AとBの差（Bを基準にした相対差）の分布を可視化します。AとBから新たに1件ずつ観測したとき、83%の確率でBの値がAを上回る（BがAよりよい）ことがわかりました。なお、これは平均値の差ではなく、個々の観測値どうしを比較した確率である点に注意してください。

``````diff = predictions["diff"]

prob_BgtA = len(diff[diff>0])/len(diff)

g = sns.displot(diff[diff>0], color="red", label=f"B>A: prob={prob_BgtA:.2f}")
g.map(sns.histplot, data=diff[diff<=0], color="blue", label=f"B<A: prob={1-prob_BgtA:.2f}")
plt.vlines(0, 0, 500, colors="gray")
plt.ylim([0, 500])
plt.legend()
``````