【第一回】非線形微分方程式系のベイズ推定 with Stan: 1時系列データ
初めに
非線形な微分方程式で表される系は、分岐や多重安定性など、複雑な挙動を示します。今回は特に解析的に解くことができない数理モデルでベイズ推定する方法を紹介します。まず1時系列のデータのみを扱い、そこから複数時系列データ&階層モデルに拡張することを目指します。
実行環境
R ver. 4.4.2
cmdstanr ver. 0.8.1
cmdstan ver. 2.35.0
VSCode ver. 1.95.3
扱うモデル
アリの採餌動員についての数理モデル(Beekman et al. 2001)を使用します。
このモデルの詳細は省略しますが、動員されるアリの個体数
この微分方程式はパラメータ
今回は
データの生成
観察するアリの動員個体数
library(deSolve)
library(ggplot2)
library(dplyr)
library(tidyr)
library(extraDistr)
# 微分方程式の定義
Ppun_aggr <- function(t, x, params) {
alpha <- params["alpha"]
beta <- params["beta"]
s <- params["s"]
N <- params["N"]
dxdt <- (alpha + beta * x[1]) * (N - x[1]) - (s * x[1]) / (s + x[1])
list(c(dxdt))
}
# パラメータの設定
alpha <- 0.0045
beta <- 0.00015
s <- 10
N <- 600
times <- seq(0, 100, by = 1) # 時間範囲
# 初期値をランダムにして10個のデータを生成
set.seed(100)
initial_values <- runif(10, min = 0, max = N)
results <- list()
for (i in 1:10) {
x0 <- initial_values[i]
params <- c(alpha = alpha, beta = beta, s = s, N = N)
output <- ode(y = c(x = x0), times = times, func = Ppun_aggr, parms = params)
df <- as.data.frame(output)
df$initial_value <- x0
# ベータ二項分布に従ってばらつかせる
alpha_beta <- 100 # ベータ分布の形状パラメータ1
beta_beta <- 100 # ベータ分布の形状パラメータ2
prob <- df$x / N
df$x_obs <- rbbinom(n = nrow(df), size = N, alpha = alpha_beta * prob, beta = beta_beta * (1 - prob))
df$n_total <- N
results[[i]] <- df
}
# データを結合
all_data <- bind_rows(results, .id = "series")
# ggplotで散布図を作成
plt <- ggplot(all_data, aes(x = time, y = x_obs, color = as.factor(series))) +
geom_point(size=0.8) +
geom_line(aes(y = x), linetype = "solid") +
labs(x = "時間", y = "観測個体数", color = "系列") +
theme_classic() # +theme(legend.position = "none")
ggsave("plot.png", plot = plt, width = 1000, height = 800, units = "px")
write.csv(all_data[all_data$time>0,], "all_data.csv", row.names = FALSE)
Stanによる実装
stanには内部にODEソルバーが実装(ode_rk45)されているので、これを利用します。
また、微分方程式を解くには初期値
ある程度パラメータの大きさが分かっているものとして、事前分布として、
を設定します。
functions { // 微分方程式の宣言
vector beekman(real t, vector x, array[] real par) {
vector[1] dxdt;
real a = par[1]/10000;
real b = par[2]/100000;
real s = par[3];
real n_total = par[4];
dxdt[1] = (a + b * x[1]) * (n_total - x[1]) - (s * x[1]) / (s + x[1]);
return dxdt;
}
}
data { // 入力データ形式の宣言
int<lower=0> N; // 1時系列のデータ数
array[N] real<lower=0> ts; // 時間
array[N] real<lower=0> x; // ts=tにおける個体数
int<lower=0> n_total; // 総個体数
}
parameters { // 推定するパラメータの宣言
real<lower=0> a;
real<lower=0> b;
real<lower=0> s;
real<lower=0> sigma;
vector<lower=0>[1] x0; // ode_rk45に与える初期値(vector形式で与える必要がある)
}
transformed parameters { // 変形されたパラメータの宣言
array[4] real par;
par[1] = a;
par[2] = b;
par[3] = s;
par[4] = n_total;
}
model { // モデル構造の宣言
// priors
a ~ normal(0, 100);
b ~ normal(0, 100);
s ~ normal(0, 100);
sigma ~ cauchy(0, 2.5);
x0[1] ~ uniform(0, n_total);
array[N] vector[1] mu = ode_rk45(beekman, x0, 0, ts, par); // 微分方程式の数値解を得る
for (i in 1:N) {
x[i] ~ normal(mu[i], sigma); // 各データx_iは平均mu[i], 標準偏差sigmaの正規分布に従う
}
}
generated quantities {
array[N] vector[1] mu_pred = ode_rk45(beekman, x0, 0, ts, par);
}
実行
R&cmdstanによってMCMCを行います。インストールの方法はほかの記事を参照してください。今回はサンプルデータのうち系列10のデータのみ使用します。
library(tidyverse)
library(rstan)
library(bayesplot)
library(ggplot2)
library(cmdstanr)
library(tidybayes)
raw_data <- read.csv("all_data.csv")
data <- raw_data %>% filter(series == 10) %>% select(time, x_obs, n_total)
input <- list(
N = nrow(data),
ts = data$time,
x = data$x_obs,
n_total = data$n_total[1]
)
stan <- cmdstan_model('model1.stan')
# warmup中のプロットも取得したい場合はsave_warmup = TRUEが必要
fit <- stan$sample(data = input, iter_warmup = 2000, iter_sampling = 2000, parallel_chains = 4, chains = 4, save_warmup = TRUE)
結果の確認
基本的なプロットを出力します。結果の可視化はbayesplotパッケージがおすすめです。
color_scheme_set("brewer-RdYlBu")
plt_dens <- mcmc_dens_overlay(fit$draws(c("a", "b", "s", "sigma"),inc_warmup = F)) + geom_density(linewidth = 1) + theme_classic()
color_scheme_set("blue")
plt_trace <- mcmc_trace(fit$draws(c("a", "b", "s", "sigma"),inc_warmup = T), n_warmup = 2000) + theme_classic()
plt_pairs <- mcmc_pairs(fit$draws(c("a", "b", "s", "sigma"),inc_warmup = F), off_diag_args = list(size = 0.5, alpha = 0.5))
ggsave("trace_plot.png", plot = plt_trace, width = 1000, height = 800, units = "px", dpi=180)
ggsave("dens_plot.png", plot = plt_dens, width = 1000, height = 800, units = "px", dpi=180)
ggsave("pairs_plot.png", plot = plt_pairs, width = 1000, height = 800, units = "px", dpi=180)
推定結果
実行結果は
fit$summary()
で確認できます。cmdstanでは、デフォルトで出力される確信区間が90%であることに注意しましょう。
> fit$summary()
# A tibble: 110 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -370. -370. 1.93 1.71 -374. -368. 1.00 1386. 1583.
2 a 20.6 16.9 15.9 15.5 1.51 51.1 1.00 1612. 2367.
3 b 12.0 11.9 2.44 2.43 8.38 16.3 1.00 1186. 1435.
4 s 6.66 6.41 2.44 2.44 3.13 11.1 1.00 1138. 1567.
5 sigma 27.0 26.9 1.96 1.94 24.0 30.3 1.00 3183. 3610.
6 x0[1] 102. 102. 6.12 6.25 91.3 111. 1.00 2107. 2255.
今回、真の値は
事後平均値(EAP: Expected A Posteriori), 事後最頻値(MAP: Maximum A Posterior Estimator), 事後中央値(MED: posterior MEDian)といった点推定値や、等裾事後確信区間(ETI: Equal-Tailed Interval)や最大事後密度確信区間(HDI: Highest posterior Density Interval)といった区間推定値を計算するにはtidybayesパッケージを使って次のようにします。
fit$draws(format = "df") %>%
spread_draws(a,b,s) %>%
mean_qi() # EAPと95%ETI
fit$draws(format = "df") %>%
spread_draws(a,b,s) %>%
mode_hdi() # MAPと95%HDI
fit$draws(format = "df") %>%
spread_draws(a,b,s) %>%
median_hdi() # MEDと95%HDI
a a.lower a.upper b b.lower b.upper s s.lower s.upper .width .point .interval
1 20.60315 0.7498388 58.71253 12.0445 7.710016 17.0512 6.655556 2.601 12.03596 0.95 mean qi
a a.lower a.upper b b.lower b.upper s s.lower s.upper .width .point .interval
1 4.889889 0.007764758 51.1794 11.428 7.6966 17.0375 5.533436 2.15004 11.4708 0.95 mode hdi
a a.lower a.upper b b.lower b.upper s s.lower s.upper .width .point .interval
1 16.9045 0.007764758 51.1794 11.8855 7.6966 17.0375 6.40527 2.15004 11.4708 0.95 median hdi
トレースプロット
事後分布
フィッティング結果の描画
事後最頻値(MAP)とその95%CI(HDI)を計算してグラフにします。
df_xpred <- fit$draws(format = "df") %>%
spread_draws(mu_pred[time]) %>%
mode_hdi(.width = 0.95) # 予測値と95%CIの計算
plt_fitting <- ggplot(df_xpred, aes(x = time)) +
geom_ribbon(aes(ymin = .lower, ymax = .upper), fill = '#D53E4F', alpha = 0.4) +
geom_line(aes(y = mu_pred), linewidth = 1, col = "#D53E4F") +
geom_line(data = data.frame(input), aes(x = ts, y = x), col = "#D53E4F") +
geom_point(data = data.frame(input), aes(x = ts, y = x), col = "#D53E4F") +
ylim(c(0, 600)) +
theme_classic() +
labs(x = "時間", y = "観測個体数")
ggsave("fitting_plot.png", plot = plt_fitting, width = 1000, height = 800, units = "px", dpi=180)
違うデータではどうなる?
データによっては、うまく推定できない場合があります。系列7のデータを使って同じように推定を行ってみます。一見、うまくデータにフィットしているように見えますが、系列10の結果と比べると、推定値の分散が非常に大きいことが分かります。特に
> fit_alt$summary()
# A tibble: 110 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -357. -357. 1.61 1.41 -361. -356. 1.00 2390. 3067.
2 a 81.7 68.2 62.1 60.0 6.44 200. 1.00 2419. 1878.
3 b 118. 117. 44.0 45.2 51.0 195. 1.00 1583. 1579.
4 s 89.8 85.6 39.1 38.2 34.8 161. 1.00 1583. 1524.
5 sigma 25.4 25.3 1.87 1.83 22.5 28.6 1.00 3416. 3464.
6 x0[1] 529. 528. 33.0 35.9 477. 585. 1.00 1861. 2033.
まとめ
今回は、微分方程式で表される数理モデルのパラメータをStanとRを使ってベイズ推定する手法と結果の図示の仕方について簡単に紹介しました。単純な曲線でフィッティングするよりも、より生物学的な意味を持たせた解釈ができるのではないでしょうか。次回は推定時に発生した問題について考えていきます。【第二回に続く】
参考文献
松浦健太郎 StanとRでベイズ統計モデリング
分寺杏介 統計的方法論特殊研究講義資料 その他参考にしたもの
Discussion