🗂

NumPyro:ベイジアンABテスト

2023/04/28に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

今回はベイジアンABテストを扱います。各ケースにおいて適切に事前分布等を設計するなど手間はかかりますが、最終的にBがAより優れている/効果があった確率は◯%というような結果が出てくるので、次のアクションを決める意思決定に向いています。

ベイジアンABテストはいろいろな企業で導入されており以下のように紹介記事も多いので、今回はNumPyroでの実装を簡単にだけ紹介して終わりにします。

https://engineering.mercari.com/blog/entry/20221110-bayesian-testing-for-souzoh/
https://saltcooky.hatenablog.com/entry/2020/07/30/012109
https://inside.dmm.com/articles/bayesian-ab-testing-arpu/
https://hack.nikkei.com/blog/advent20221216/

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

ベイジアンABテスト

デモデータの準備

今回は適当に正規分布からサンプリングされた2群のデータをA、Bとして扱います。AとBのサンプルはお互いに重なっており、見た目だけではどれだけの違いがあるのかは一見わかりづらいデータになっています。

A = dist.Normal(50, 5).sample(random.PRNGKey(2), (10,))
B = dist.Normal(55, 5).sample(random.PRNGKey(1), (10,))

sns.distplot(A)
sns.distplot(B)

モデルの定義

def model_normal(x):
    mu = numpyro.sample("mu", dist.Normal(52, 5))
    scale = numpyro.sample("scale", dist.LogNormal(1))
    with numpyro.plate("N", len(x)):
        numpyro.sample("obs", dist.Normal(mu, scale), obs=x)

MCMC

うまく推論できていそうです。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_normal)
mcmcA = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmcA.run(
    rng_key=rng_key,
    x=A
)
samplesA = mcmcA.get_samples()
mcmcA.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu     48.94      1.47     48.91     46.50     51.27   4122.14      1.00
     scale      4.77      1.18      4.56      3.02      6.46   3948.31      1.00

Number of divergences: 0
# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model_normal)
mcmcB = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4, thinning=1)
mcmcB.run(
    rng_key=rng_key,
    x=B
)
samplesB = mcmcB.get_samples()
mcmcB.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu     55.20      1.35     55.21     52.91     57.31   4204.37      1.00
     scale      4.35      1.07      4.17      2.72      5.87   3833.45      1.00

Number of divergences: 0

分布間の違いを定量化する

既に各正規分布の平均と標準偏差を推定できたので、この分布からmodelを新しく定義せずに計算していけば良いのですが、ここではNumPyroのPredictiveの変わった使い方の紹介として、わざわざPredictiveを使用してAとBの差分をサンプリングしていきたいと思います。

def model_diff():
    mu_A = numpyro.sample("mu_A", dist.Normal(52, 5))
    scale_A = numpyro.sample("scale_A", dist.LogNormal(1))
    mu_B = numpyro.sample("mu_B", dist.Normal(52, 5))
    scale_B = numpyro.sample("scale_B", dist.LogNormal(1))

    normal_A = numpyro.sample("normal_A", dist.Normal(mu_A, scale_A))
    normal_B = numpyro.sample("normal_B", dist.Normal(mu_B, scale_B))
    
    diff = numpyro.deterministic("diff", (normal_B - normal_A)/normal_B)

上記のモデルのサイトネームに合うように辞書を作成します。

samples = {}
samples["mu_A"] = samplesA["mu"]
samples["scale_A"] = samplesA["scale"]
samples["mu_B"] = samplesB["mu"]
samples["scale_B"] = samplesB["scale"]

事後分布のサンプルが与えられなかったサイトのサンプリング結果が得られます。

rng_key, rng_key_ = random.split(rng_key)
predictive = Predictive(model_diff, samples)
predictions = predictive(rng_key_)
predictions.keys()
dict_keys(['diff', 'normal_A', 'normal_B'])

AとBの事後分布を可視化してみます。

sns.distplot(predictions["normal_A"], label="A")
sns.distplot(predictions["normal_B"], label="B")
plt.legend()
plt.show()

AとBの差の分布を可視化します。83%の確率でBがAよりよいことがわかりました。

diff = predictions["diff"]

prob_BgtA = len(diff[diff>0])/len(diff)

g = sns.displot(diff[diff>0], color="red", label=f"B>A: prob={prob_BgtA:.2f}")
g.map(sns.histplot, data=diff[diff<=0], color="blue", label=f"B<A: prob={1-prob_BgtA:.2f}")
plt.vlines(0, 0, 500, colors="gray")
plt.ylim([0, 500])
plt.legend()

最後に

以上で、「ベイジアンABテスト」は終わりです。頻度主義のABテストとは異なり、解釈もしやすい方法ですし、確率で出力が得られるのは納得感もあるのではないでしょうか。次は「事前分布や各分布に関してまとめ」です。

Discussion