NumPyro:ODE
連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。
はじめに
今回は、MCMCで常微分方程式のパラメータを推論する方法を見ていきます。MCMCでパラメータ推論することで物理法則などの事前知識の導入や不確実性を考慮しながら、多変数のパラメータ最適化をバランスよく実施できることが期待できます。
ライブラリのインポート
import os
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from jax.experimental.ode import odeint
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
import arviz as az
az.style.use("arviz-darkgrid")
assert numpyro.__version__.startswith("0.11.0")
numpyro.enable_x64(True)
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
化学反応速度を例にして
私の専門が化学なので、化学反応式の化学反応速度を題材にして紹介します。今回は、こちらを参考にして2段階で進む反応を想定します。
ここで、各段階の反応の反応速度定数をk_1
, k2
とし、AとBの反応次数をam
, bm
とすると、反応速度式は以下のように表すことができます。[A], [B], [C]はそれぞれ各成分の濃度です。
ルンゲ=クッタ法による時間発展
jaxにはjax.experimental.ode.odeint
という内部でルンゲ=クッタ法が動く時間発展用の関数があり、odeint
の引数には時間発展式dz_dt
, z_initには各初期値, time_trueには時間発展用のfloat型の時間が格納された配列, その他パラメータを与えます。
これを上記の反応速度式に適用した結果を見ていきます。パラメータはこちらを参考にして
def dz_dt(z, t, k1, k2, am, bm):
A, B, C = z
dA_dt = -k1 * jnp.power(A, am)
dB_dt = k1 * jnp.power(A, am) - k2 * jnp.power(B, bm)
dC_dt = k2 * jnp.power(B, bm)
return jnp.stack([dA_dt, dB_dt, dC_dt])
k1_true = 0.1 # [s^-1]
k2_true = 0.01 # [s^-1]
am_true = 1.0
bm_true = 1.0
# A, B, Cの初期値
z_init = jnp.array([1.0, 0.0, 0.0])
# 時間発展用のfloat型の時間が格納された配列
time_true = jnp.arange(700).astype(float)
z_true = odeint(dz_dt, z_init, time_true, k1_true, k2_true, am_true, bm_true, rtol=1e-7, atol=1e-6, mxstep=10000)
plt.plot(z_true, label=['A', 'B', 'C'])
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()
デモデータの作成
先ほど作成した真の値に適当な正規分布に従うノイズを追加したものを観測値として作成しました。ここで、観測している時間間隔は1分刻みとしました。
sigma_true = 0.02
time_obs = np.arange(0, 701, 60)
z_obs = z_true[time_obs]
z_obs += np.random.normal(np.zeros(z_obs.shape), sigma_true)
plt.plot(z_true, label=['A', 'B', 'C'])
plt.scatter(time_obs, z_obs[:, 0])
plt.scatter(time_obs, z_obs[:, 1])
plt.scatter(time_obs, z_obs[:, 2])
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()
モデルの定義
ODEの箇所以外は今までの線形回帰などと同じです。各種パラメータを何かしらの分布からサンプリングし、その値を用いてodeint
により時間発展させます。
def model(time_ode, time_obs, z_init, z_obs):
# param
k1 = numpyro.sample("k1", dist.HalfNormal(1))
k2 = numpyro.sample("k2", dist.HalfNormal(1))
am = numpyro.sample("am", dist.Normal(1, 1))
bm = numpyro.sample("bm", dist.Normal(1, 1))
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
# ode
z = odeint(dz_dt, z_init, time_ode, k1, k2, am, bm, rtol=1e-7, atol=1e-6, mxstep=10000)
# 観測された時間だけ
numpyro.sample('y', dist.Normal(z[time_obs], sigma), obs=z_obs)
MCMCと結果の確認
am
やk1
の値が真の値から少しずれてますね。今回の推定値を用いた結果と比較すると、観測値が得られていない範囲(t=0~50 sあたり)がうまくフィットされていないことが分かります。
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
time_ode=jnp.arange(700).astype(float),
time_obs=time_obs,
z_init=np.array([1.0, 0.0, 0.0]),
z_obs=z_obs
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
am 1.66 0.26 1.62 1.24 2.08 92.32 1.00
bm 0.95 0.06 0.94 0.88 1.05 74.84 1.00
k1 0.87 0.58 0.73 0.11 1.71 101.22 1.00
k2 0.01 0.00 0.01 0.01 0.01 42.43 1.02
sigma 0.02 0.00 0.02 0.02 0.03 260.58 1.00
Number of divergences: 1579
samples = mcmc.get_samples()
k1_mcmc = samples["k1"].mean(axis=0)
k2_mcmc = samples["k2"].mean(axis=0)
am_mcmc = samples["am"].mean(axis=0)
bm_mcmc = samples["bm"].mean(axis=0)
z_init = jnp.array([1.0, 0.0, 0.0])
time_true = jnp.arange(700).astype(float)
z_mcmc = odeint(dz_dt, z_init, time_true, k1_mcmc, k2_mcmc, am_mcmc, bm_mcmc, rtol=1e-7, atol=1e-6, mxstep=10000)
plt.plot(z_true, label=['A', 'B', 'C'], color="b")
plt.plot(z_mcmc, label=['A_mcmc', 'B_mcmc', 'C_mcmc'], color="r")
plt.scatter(time_obs, z_obs[:, 0], label="A_obs")
plt.scatter(time_obs, z_obs[:, 1], label="B_obs")
plt.scatter(time_obs, z_obs[:, 2], label="C_obs")
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()
デモデータの作成2
先ほどの結果を受けて、最初の段階をより細かく測定したようなデータをデモデータとして作成しました。
sigma_true = 0.02
time_obs = np.hstack([np.arange(0, 100, 10), np.arange(100, 701, 60)])
z_obs = z_true[time_obs]
z_obs += np.random.normal(np.zeros(z_obs.shape), sigma_true)
plt.plot(z_true, label=['A', 'B', 'C'])
plt.scatter(time_obs, z_obs[:, 0])
plt.scatter(time_obs, z_obs[:, 1])
plt.scatter(time_obs, z_obs[:, 2])
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()
MCMCと結果の確認
先ほどの結果と比較すると、am
やk1
の値が真の値に近づきました。図で見てもかなりよくフィッティングするパラメータを見つけることができています。
# 乱数の固定に必要
rng_key= random.PRNGKey(0)
# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
rng_key=rng_key,
time_ode=jnp.arange(700).astype(float),
time_obs=time_obs,
z_init=np.array([1.0, 0.0, 0.0]),
z_obs=z_obs
)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
am 1.19 0.03 1.18 1.16 1.23 48.26 1.02
bm 0.93 0.02 0.93 0.91 0.97 52.57 1.00
k1 0.12 0.00 0.12 0.11 0.12 57.10 1.01
k2 0.01 0.00 0.01 0.01 0.01 38.45 1.00
sigma 0.02 0.00 0.02 0.02 0.02 139.08 1.00
Number of divergences: 1858
samples = mcmc.get_samples()
k1_mcmc = samples["k1"].mean(axis=0)
k2_mcmc = samples["k2"].mean(axis=0)
am_mcmc = samples["am"].mean(axis=0)
bm_mcmc = samples["bm"].mean(axis=0)
z_init = jnp.array([1.0, 0.0, 0.0])
time_true = jnp.arange(700).astype(float)
z_mcmc = odeint(dz_dt, z_init, time_true, k1_mcmc, k2_mcmc, am_mcmc, bm_mcmc, rtol=1e-7, atol=1e-6, mxstep=10000)
plt.plot(z_true, label=['A', 'B', 'C'], color="b")
plt.plot(z_mcmc, label=['A_mcmc', 'B_mcmc', 'C_mcmc'], color="r")
plt.scatter(time_obs, z_obs[:, 0], label="A_obs")
plt.scatter(time_obs, z_obs[:, 1], label="B_obs")
plt.scatter(time_obs, z_obs[:, 2], label="C_obs")
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()
最後に
以上で「ODE」は終わりです。モデルの定式化ができれば、簡単に試すことができそうです。また、触媒の有無や温度違いで測定したデータがあれば、物理的な知見として反応速度定数や次数がこれくらい変わる等という事前知識を入れてモデリングするなど拡張性もある点も良いところでしょうか。次回は「次元圧縮(PCA)、行列分解」です。
Discussion