🐥

NumPyro:ODE

2023/04/27に公開

連載している記事の1つです。以前までの記事を読んでいる前提で書いているので、必要であればNumPyroの記事一覧から各記事を参考にしてください。

はじめに

今回は、MCMCで常微分方程式のパラメータを推論する方法を見ていきます。MCMCでパラメータ推論することで物理法則などの事前知識の導入や不確実性を考慮しながら、多変数のパラメータ最適化をバランスよく実施できることが期待できます。

ライブラリのインポート

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from jax.experimental.ode import odeint
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

az.style.use("arviz-darkgrid")

assert numpyro.__version__.startswith("0.11.0")

numpyro.enable_x64(True)
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

化学反応速度を例にして

私の専門が化学なので、化学反応式の化学反応速度を題材にして紹介します。今回は、こちらを参考にして2段階で進む反応を想定します。

A \rightarrow B \rightarrow C

ここで、各段階の反応の反応速度定数をk_1, k2とし、AとBの反応次数をam, bmとすると、反応速度式は以下のように表すことができます。[A], [B], [C]はそれぞれ各成分の濃度です。

\frac{d[A]}{dt} = -k_1 [A]^{am} \\ \frac{d[B]}{dt} = k_1 [A]^{am} - k_2 [B]^{bm} \\ \frac{d[C]}{dt} = k_2 [B]^{bm} \\

ルンゲ=クッタ法による時間発展

jaxにはjax.experimental.ode.odeintという内部でルンゲ=クッタ法が動く時間発展用の関数があり、odeintの引数には時間発展式dz_dt, z_initには各初期値, time_trueには時間発展用のfloat型の時間が格納された配列, その他パラメータを与えます。

これを上記の反応速度式に適用した結果を見ていきます。パラメータはこちらを参考にしてk_1=0.10 s^{−1}k_2=0.010 s^{−1}[A]_0=1.0 Mとしました。1段階目の反応が律速反応になっている様子が見えます。

def dz_dt(z, t, k1, k2, am, bm):
    A, B, C = z
    
    dA_dt = -k1 * jnp.power(A, am)
    dB_dt = k1 * jnp.power(A, am) - k2 * jnp.power(B, bm)
    dC_dt = k2 * jnp.power(B, bm)
    
    return jnp.stack([dA_dt, dB_dt, dC_dt])

k1_true = 0.1 # [s^-1]
k2_true = 0.01 # [s^-1]
am_true = 1.0
bm_true = 1.0
# A, B, Cの初期値
z_init = jnp.array([1.0, 0.0, 0.0])
# 時間発展用のfloat型の時間が格納された配列
time_true = jnp.arange(700).astype(float)

z_true = odeint(dz_dt, z_init, time_true, k1_true, k2_true, am_true, bm_true, rtol=1e-7, atol=1e-6, mxstep=10000)

plt.plot(z_true, label=['A', 'B', 'C'])
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()

デモデータの作成

先ほど作成した真の値に適当な正規分布に従うノイズを追加したものを観測値として作成しました。ここで、観測している時間間隔は1分刻みとしました。

sigma_true = 0.02

time_obs = np.arange(0, 701, 60)
z_obs = z_true[time_obs] 
z_obs += np.random.normal(np.zeros(z_obs.shape), sigma_true)

plt.plot(z_true, label=['A', 'B', 'C'])
plt.scatter(time_obs, z_obs[:, 0])
plt.scatter(time_obs, z_obs[:, 1])
plt.scatter(time_obs, z_obs[:, 2])
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()

モデルの定義

ODEの箇所以外は今までの線形回帰などと同じです。各種パラメータを何かしらの分布からサンプリングし、その値を用いてodeintにより時間発展させます。

def model(time_ode, time_obs, z_init, z_obs):
    # param
    k1 = numpyro.sample("k1", dist.HalfNormal(1))
    k2 = numpyro.sample("k2", dist.HalfNormal(1))
    am = numpyro.sample("am", dist.Normal(1, 1))
    bm = numpyro.sample("bm", dist.Normal(1, 1))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    
    # ode
    z = odeint(dz_dt, z_init, time_ode, k1, k2, am, bm, rtol=1e-7, atol=1e-6, mxstep=10000)
    # 観測された時間だけ
    numpyro.sample('y', dist.Normal(z[time_obs], sigma), obs=z_obs)

MCMCと結果の確認

amk1の値が真の値から少しずれてますね。今回の推定値を用いた結果と比較すると、観測値が得られていない範囲(t=0~50 sあたり)がうまくフィットされていないことが分かります。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    time_ode=jnp.arange(700).astype(float),
    time_obs=time_obs,
    z_init=np.array([1.0, 0.0, 0.0]),
    z_obs=z_obs
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        am      1.66      0.26      1.62      1.24      2.08     92.32      1.00
        bm      0.95      0.06      0.94      0.88      1.05     74.84      1.00
        k1      0.87      0.58      0.73      0.11      1.71    101.22      1.00
        k2      0.01      0.00      0.01      0.01      0.01     42.43      1.02
     sigma      0.02      0.00      0.02      0.02      0.03    260.58      1.00

Number of divergences: 1579
samples = mcmc.get_samples()
k1_mcmc = samples["k1"].mean(axis=0)
k2_mcmc = samples["k2"].mean(axis=0)
am_mcmc = samples["am"].mean(axis=0)
bm_mcmc = samples["bm"].mean(axis=0)
z_init = jnp.array([1.0, 0.0, 0.0])
time_true = jnp.arange(700).astype(float)

z_mcmc = odeint(dz_dt, z_init, time_true, k1_mcmc, k2_mcmc, am_mcmc, bm_mcmc, rtol=1e-7, atol=1e-6, mxstep=10000)

plt.plot(z_true, label=['A', 'B', 'C'], color="b")
plt.plot(z_mcmc, label=['A_mcmc', 'B_mcmc', 'C_mcmc'], color="r")
plt.scatter(time_obs, z_obs[:, 0], label="A_obs")
plt.scatter(time_obs, z_obs[:, 1], label="B_obs")
plt.scatter(time_obs, z_obs[:, 2], label="C_obs")
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()

デモデータの作成2

先ほどの結果を受けて、最初の段階をより細かく測定したようなデータをデモデータとして作成しました。

sigma_true = 0.02

time_obs = np.hstack([np.arange(0, 100, 10), np.arange(100, 701, 60)])
z_obs = z_true[time_obs] 
z_obs += np.random.normal(np.zeros(z_obs.shape), sigma_true)

plt.plot(z_true, label=['A', 'B', 'C'])
plt.scatter(time_obs, z_obs[:, 0])
plt.scatter(time_obs, z_obs[:, 1])
plt.scatter(time_obs, z_obs[:, 2])
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()

MCMCと結果の確認

先ほどの結果と比較すると、amk1の値が真の値に近づきました。図で見てもかなりよくフィッティングするパラメータを見つけることができています。

# 乱数の固定に必要
rng_key= random.PRNGKey(0)

# NUTSでMCMCを実行する
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    time_ode=jnp.arange(700).astype(float),
    time_obs=time_obs,
    z_init=np.array([1.0, 0.0, 0.0]),
    z_obs=z_obs
)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        am      1.19      0.03      1.18      1.16      1.23     48.26      1.02
        bm      0.93      0.02      0.93      0.91      0.97     52.57      1.00
        k1      0.12      0.00      0.12      0.11      0.12     57.10      1.01
        k2      0.01      0.00      0.01      0.01      0.01     38.45      1.00
     sigma      0.02      0.00      0.02      0.02      0.02    139.08      1.00

Number of divergences: 1858
samples = mcmc.get_samples()
k1_mcmc = samples["k1"].mean(axis=0)
k2_mcmc = samples["k2"].mean(axis=0)
am_mcmc = samples["am"].mean(axis=0)
bm_mcmc = samples["bm"].mean(axis=0)
z_init = jnp.array([1.0, 0.0, 0.0])
time_true = jnp.arange(700).astype(float)

z_mcmc = odeint(dz_dt, z_init, time_true, k1_mcmc, k2_mcmc, am_mcmc, bm_mcmc, rtol=1e-7, atol=1e-6, mxstep=10000)

plt.plot(z_true, label=['A', 'B', 'C'], color="b")
plt.plot(z_mcmc, label=['A_mcmc', 'B_mcmc', 'C_mcmc'], color="r")
plt.scatter(time_obs, z_obs[:, 0], label="A_obs")
plt.scatter(time_obs, z_obs[:, 1], label="B_obs")
plt.scatter(time_obs, z_obs[:, 2], label="C_obs")
plt.xlabel('time[s]')
plt.ylabel('c[mol]')
plt.legend()
plt.show()

最後に

以上で「ODE」は終わりです。モデルの定式化ができれば、簡単に試すことができそうです。また、触媒の有無や温度違いで測定したデータがあれば、物理的な知見として反応速度定数や次数がこれくらい変わる等という事前知識を入れてモデリングするなど拡張性もある点も良いところでしょうか。次回は「次元圧縮(PCA)、行列分解」です。

Discussion