「Pythonでスラスラわかる ベイズ推論「超」入門」をNumPyroで書き直す
はじめに
Pythonでスラスラわかる ベイズ推論「超」入門(赤石 雅典 (著), 須山 敦志 (監修))はとても分かりやすい良書です。
この記事は当該書籍のベイズ推論のライブラリをPyMCからNumPyroに書き換えたコード(Zenn記事)へのリンク集です。
異なるライブラリを使って同じ結果を得る過程は良い学びとなります。
それぞれのライブラリの共通項を理解することでその裏にある原理への理解も深まります。
これらの記事が皆さんの参考になれば幸いです。
この記事のセールスポイント
この記事の売りは次の2点です。
- 高速ライブラリNumPyroを使ったベイズ推論のコードを数多く紹介している。
- 確率モデルを作成する過程を丁寧に書いている。
1点目は言わずもがな、高速なベイズ推論ライブラリであるNumPyroに惹かれてこの記事を開いて頂いたはずです。
参考書籍のコードの大部分をNumPyroに書き換えました。
それぞれのコードをコピペするだけですぐに高速ライブラリの世界に入門することができます。
2点目は学習中にその重要性に気づいたので書籍の内容を掘り下げました。
ベイズ推論を学び始めたときに「ベイズの定理
定理とモデルをつなげる鍵はずばり数式です。
数式といっても行列式のような計算はほとんどありません。
「考えている現象をどの確率分布で近似するか」を繰り返すだけです。
確率モデルのプログラミングはこの仮定をコードに書き直すだけです。
ライブラリと同様に異なる手段で同じ結果を得ることで、確率モデルへの理解が深まります。
実用上は条件付確率
本記事でも 比例式
PyMCとNumPyroの比較
それぞれのチュートリアルを一通り触ってみての所感です。
最初は癖の少ないPyMCで学ぶのが良いと思います。
速度が必要になってからNumPyroに移るのがよいでしょう。
| 項目 | PyMC | NumPyro |
|---|---|---|
| 特徴量と目的変数のデータ型 |
numpy.ndarray, pandas.DataFrameを使用できる。慣れている人が多い。 |
jax.numpy.arrayを使う。 |
| 確率モデルの作り方 | モデルのインスタンスを生成した後にwith構文内でモデルの詳細を実装する。 |
モデル名の関数の中でモデルの詳細を実装する。こちらの方が好み。 |
| 確率モデルのプロット | 説明変数も可視化できる。わかりやすい。 | プロットがすっきりする。 |
| 確率モデルのプロットのコードの書きやすさ | 簡単。書籍の通りでOK。 | 面倒くさい。ユーティリティ関数を作るべき。(本記事にコードを掲載しました) |
| サンプリングの実行速度(パフォーマンス) | 遅い。他のバックエンドで少し改善する。 | 速い。もう戻れない。 |
| サンプリングのコードの書きやすさ(ユーザーフレンドリー) | 簡単。書籍の通りでOK。 | 面倒くさい。ユーティリティ関数を作るべき。(本記事にコードを掲載しました) |
| サンプリング後の検証 | ArviZのコードを使用する。 | PyMCと同じ |
ライブラリのバージョン
活発に更新されているライブラリでは破壊的変更が行われることがあります。
読者が同じコードを実行できるように、本記事とリンク先で使用するライブラリのバージョンを列挙します。
First releasedから一年たつのでそろそろPython 3.13に引き上げても良いだろうと判断しました。
# Pythonのバージョン
import sys
print("Miniforge", sys.version)
# DataFrame, Numerical computation
import pandas as pd
print("Pandas", pd.__version__)
import numpy as np
print("NumPy",np.__version__)
import jax
print("Jax", jax.__version__)
# ベイズ推定
import numpyro
print("NumPyro", numpyro.__version__)
# plot
import matplotlib
print("Matplotlib", matplotlib.__version__)
%matplotlib inline
import arviz as az
print("ArviZ", az.__version__)
# plotで日本語フォントを使用する。
import japanize_matplotlib
#print(japanize_matplotlib.__version__) # エラーになる

フォルダ構造とユーティリティ関数
PyMCとNumPyroの比較で書いた通り、NumPyroで確率モデルのプロットと事後分布のサンプリングのコードを毎回書くのは面倒くさいです。
繰り返しを減らすためにユーティリティ関数を用意しました。
本記事とリンク先では下記のディレクトリ構造で作成しました。
ディレクトリ構成図はTreeというサイトで作成しました。
.
├── mod/
│ └── numpyro_utility.py
└── notebooks/
├── this_article.ipynb
└── other_article.ipynb
"""
NumPyroを使ったベイズ統計ユーティリティ
=====================================
本モジュールは NumPyro で定義した確率モデルの可視化と、
NUTS によるMCMCサンプリングを簡易に実行するための関数を提供します。
- try_render_model: 確率モデルのグラフ可視化 (Graphviz / SVG)
- run_mcmc: NUTS + MCMC を走らせて ArviZ InferenceData を返す
注記
----
- ここでいう「model」は NumPyro のモデル関数 (sample / plate などを内部で呼ぶ関数) を指します。
- 乱数シードは jax.random.PRNGKey(seed) のみを固定します。ハードウェアや並列実行によって厳密再現性が揺らぐことがあります。
"""
from __future__ import annotations
# ---- import ----
from typing import Any, Callable, Optional, Literal
import jax
import numpyro
import arviz as az
def try_render_model(
model: Callable[..., None],
render_name: str,
**model_args: Any,
) -> Optional[str]:
"""
ベイズ統計モデルを可視化してSVGファイルに保存する。
Parameters
----------
model : Callable[..., None]
NumPyro のモデル関数。
render_name : str
出力するファイル名(拡張子なし)。`{render_name}.svg` が保存されます。
**model_args : Any
モデルに渡すキーワード引数。データやハイパーパラメータなど。
Returns
-------
Optional[str]
正常終了時は出力SVGのパス、エラー時は ``None``。
Notes
-----
- Graphviz が環境にインストールされていない場合はレンダリングに失敗します。
- Jupyter 上では生成したSVGをその場で表示します。スクリプト実行時はファイル保存のみです。
Examples
--------
>>> def model(y):
... import numpyro.distributions as dist
... theta = numpyro.sample("theta", dist.Beta(1, 1))
... numpyro.sample("obs", dist.Bernoulli(theta), obs=y)
>>> try_render_model(model, "coin_model", y=[0, 1, 1, 0])
'coin_model.svg'
"""
try:
# 確率モデルを作成する (パラメータ名や分布も描画)
g = numpyro.render_model(
model=model,
model_args=(),
model_kwargs=model_args,
render_distributions=True,
render_params=True,
)
# SVGで保存
outpath = f"{render_name}.svg"
g.render(render_name, format="svg", cleanup=True)
print(f"Model graph saved to: {outpath}")
# Jupyter 環境ならプレビュー表示も行う
try:
from IPython.display import display, SVG # type: ignore
display(SVG(filename=outpath))
except Exception:
# 表示側での失敗は無視 (ファイルは保存済み)
print(f"Preview skipped; file saved: {outpath}")
return outpath
except Exception as e:
print(f"(Skip model rendering for {render_name}: {e})")
return None
def run_mcmc(
model: Callable[..., None],
num_chains: int = 4,
num_warmup: int = 1000,
num_samples: int = 1000,
thinning: int = 1,
seed: int = 42,
target_accept_prob: float = 0.8,
log_likelihood: bool = False,
**model_args: Any,
) -> az.InferenceData:
"""
NumPyro のベイズ統計モデルで NUTS によるMCMCサンプリングを行い、
ArviZ の ``InferenceData`` を返す。
Parameters
----------
model : Callable[..., None]
NumPyro のモデル関数。
num_chains : int, default 4
同時に走らせるMCMCチェーンの本数。
num_warmup : int, default 1000
ウォームアップ(バーンイン)の反復回数。
num_samples : int, default 1000
保存する事後サンプル数 (各チェーンあたり)。
thinning : int, default 1
サンプルの間引き間隔。``1`` なら間引きなし。
seed : int, default 42
乱数シード (``jax.random.PRNGKey(seed)`` に渡されます)。
target_accept_prob : float, default 0.8
NUTS のステップサイズ調整で目標とする受理率。
log_likelihood : bool, default False
``az.from_numpyro`` で対数尤度を同梱するかどうか。
**model_args : Any
モデルに渡すキーワード引数 (データなど)。
Returns
-------
az.InferenceData
事後分布サンプル等を含む ``InferenceData``。
Notes
-----
- 利用可能な JAX デバイス数に応じて、チェーンを ``parallel`` または ``sequential`` に自動切替します。
- 進捗バーは対話環境で有効です。非対話環境では無効化される場合があります。
Examples
--------
>>> import numpy as np
>>> import numpyro.distributions as dist
>>> def model(x, y=None):
... beta0 = numpyro.sample("beta0", dist.Normal(0, 10))
... beta1 = numpyro.sample("beta1", dist.Normal(0, 10))
... sigma = numpyro.sample("sigma", dist.HalfNormal(1))
... mu = beta0 + beta1 * x
... numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
>>> x = np.linspace(0, 1, 50)
>>> y = 1 + 2 * x + np.random.normal(0, 0.1, size=x.size)
>>> idata = run_mcmc(model, num_warmup=500, num_samples=1000, x=x, y=y)
>>> az.summary(idata)
"""
# NUTSサンプラーの構築
sampler = numpyro.infer.NUTS(model, target_accept_prob=target_accept_prob)
# 並列実行可能かを判定
num_devices = jax.local_device_count()
chain_method: Literal["parallel", "sequential"] = (
"parallel" if num_devices >= num_chains else "sequential"
)
# MCMCオブジェクトの作成
mcmc = numpyro.infer.MCMC(
sampler=sampler,
num_warmup=num_warmup,
num_samples=num_samples,
num_chains=num_chains,
thinning=thinning,
chain_method=chain_method,
progress_bar=True,
)
# 乱数キーを初期化して実行
mcmc.run(jax.random.PRNGKey(seed), **model_args)
# ArviZ の InferenceData に変換
idata = az.from_numpyro(mcmc, log_likelihood=log_likelihood)
return idata
関連Notebookの共通import
関連記事のコードは全て下記のライブラリをimportして使用します。
# Module
import sys
sys.path.append("../")
from mod.numpyro_utility import *
# DataFrame, Numerical computation
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
# ベイズ推定
import numpyro
import numpyro.distributions as dist # 確率分布
# plot
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import arviz as az
# プロットに日本語を表示する。
import japanize_matplotlib
第1章から3章にはPyMCのコードが無いので省略します。
仕方ないね。
第4章 はじめてのベイズ推論実習
くじ引きを例にベルヌーイ分布と二項分布のベイズ推論を行います。
この例で数式から書き始めるのは冗長ですが、後の例の練習と信じてお付き合いください。
リンク
5.1 データ分布のベイズ推論 - 第5章ベイズ推論プログラミング
アイリス・データセットを使ってSetosaのがく片の長さsepal_lengthの分布を例に正規分布のベイズ推論を行います。
ここから条件付確率
リンク
5.2 線形回帰のベイズ推論 - 第5章ベイズ推論プログラミング
アイリス・データセットのversicolorのがく片の長さsepal_lengthとがく片の幅sepal_widthの1次関数近似のベイズ推論を行います。
単純な線形回帰であっても一目ではベイズの定理と確率モデルとの関係性が分かりにくいので、数式で考えることの恩恵が大きくなります。
リンク
5.3 階層ベイズモデル - 第5章ベイズ推論プログラミング
アイリス・データセットの3種類の花のデータを3個ずつ抽出しました。
合計9個のデータで3種類の花のがく片の長さsepal_lengthとがく片の幅sepal_widthの1次関数近似のベイズ推論を行います。
関連するデータが少しずつあるという業務でありがちな状況でベイズ推論が輝きます。
リンク
5.4 潜在変数モデル - 第5章ベイズ推論プログラミング
アイリス・データセットのversicolorとvirginicaをがく片の幅sepal_widthだけでクラスタリングします。
それぞれのがく片の幅が正規分布に従うというシンプルな仮定から見事にクラスタリングされる様子を味わうことができます。
リンク
6.1 ABテストの効果検証 - 第6章ベイズ推論の業務活用事例
業務で使用することが多いABテストをベイズ推論します。
シンプルながら奥が深い分析です。
リンク
終わりに
Pythonでスラスラわかる ベイズ推論「超」入門(赤石 雅典 (著), 須山 敦志 (監修))のPyMCで書かれたコードをNumPyroで書き直したZenn記事へのリンクを紹介しました。
記事に起こす過程は自分の理解度を振り返ることにもなり、一段と理解が深まった実感があります。
これらの記事が読者の方への参考になれば幸いです。
Discussion