【第三回】非線形微分方程式系のベイズ推定 with Stan: 複数時系列データへの拡張
アイデア
第二回で、パラメータを推定する際の問題点として、識別可能性と感度があるというお話をしました。識別可能性は根本的な数理モデルの問題なので対処のしようがありませんが、感度については「その測定条件(初期条件: ここでは
使用するデータ
なぜそうなるかは一旦おいておくとして、とりあえず出来レースでデータを作ります。今回はこのデータセットall_data_2.csvを使います。
library(deSolve)
library(ggplot2)
library(dplyr)
library(tidyr)
library(extraDistr)
# 微分方程式の定義
Ppun_aggr <- function(t, x, params) {
alpha <- params["alpha"]
beta <- params["beta"]
s <- params["s"]
N <- params["N"]
dxdt <- (alpha + beta * x[1]) * (N - x[1]) - (s * x[1]) / (s + x[1])
list(c(dxdt))
}
# パラメータの設定
alpha <- 0.0045
beta <- 0.00015
s <- 10
N <- 600
times <- seq(0, 100, by = 1) # 時間範囲
# 3つの初期値からデータを生成する
set.seed(100)
initial_values <- c(2.0, 100, 200)
results <- list()
for (i in 1:length(initial_values)) {
x0 <- initial_values[i]
params <- c(alpha = alpha, beta = beta, s = s, N = N)
output <- ode(y = c(x = x0), times = times, func = Ppun_aggr, parms = params)
df <- as.data.frame(output)
df$initial_value <- x0
# ベータ二項分布に従ってばらつかせる
alpha_beta <- 100 # ベータ分布の形状パラメータ1
beta_beta <- 100 # ベータ分布の形状パラメータ2
prob <- df$x / N
df$x_obs <- rbbinom(n = nrow(df), size = N, alpha = alpha_beta * prob, beta = beta_beta * (1 - prob))
df$n_total <- N
results[[i]] <- df
}
# データを結合
all_data <- bind_rows(results, .id = "series")
# ggplotで散布図を作成
plt <- ggplot(all_data, aes(x = time, y = x_obs, color = as.factor(series))) +
geom_point(size=0.8) +
geom_line(aes(y = x), linetype = "solid") +
labs(x = "時間", y = "観測個体数", color = "系列") +
theme_classic() # +theme(legend.position = "none")
ggsave("plot2.png", plot = plt, width = 1000, height = 800, units = "px")
write.csv(all_data[all_data$time>0,], "all_data_2.csv", row.names = FALSE)
Stanによる実装
基本は第一回と同じモデルになります。どの時系列データでもパラメータ
事前分布はほぼ第一回と同様に
とします。
functions { // モデル式の宣言
vector beekman(real t, vector x, array[] real par) {
vector[1] dxdt;
real a = par[1]/10000;
real b = par[2]/100000;
real s = par[3];
real n_total = par[4];
dxdt[1] = (a + b * x[1]) * (n_total - x[1]) - (s * x[1]) / (s + x[1]);
return dxdt;
}
}
data {
int<lower=0> Series; // 時系列のデータ数
int<lower=0> N; // 1時系列のデータ数
array[Series, N] real<lower=0> ts; // 時間
array[Series, N] real<lower=0> x; // ts=tにおける個体数
array[Series, N] real<lower=0> n_total; // 総個体数
}
parameters {
real<lower=0> a;
real<lower=0> b;
real<lower=0> s;
array[Series] real<lower=0> sigma;
array[Series] vector<lower=0>[1] x0;
}
transformed parameters {
array[Series, 4] real par;
for (i in 1:Series) {
par[i, 1] = a;
par[i, 2] = b;
par[i, 3] = s;
par[i, 4] = n_total[i, 1];
}
}
model {
// priors
a ~ normal(0, 100);
b ~ normal(0, 100);
s ~ normal(0, 100);
for (i in 1:Series) {
sigma[i] ~ cauchy(0, 2.5);
x0[i, 1] ~ uniform(0, n_total[i, 1]);
}
// 各時系列ごとにODEを解く
for (i in 1:Series) {
array[N] vector[1] mu = ode_rk45(beekman, x0[i,], 0, ts[i, 1:N], par[i,]);
for (j in 1:N) {
x[i,j] ~ normal(mu[j], sigma[i]);
}
}
}
generated quantities {
array[Series, N] vector[1] mu_pred;
for (i in 1:Series){
mu_pred[i,] = ode_rk45(beekman, x0[i,], 0, ts[1, 1:N], par[i,]);
}
}
実行
こちらも入力データの編集部分以外は第一回とほとんど同じです。個人的な感想ですが、データ数と実行時間には非線形な関係があるように思います笑 お菓子でも食べてゆっくり待ちましょう。cmdstanrの仕様上、1つのchainはCPU1スレッドで実行されます。parallel_chainsは並列で実行されるchain数ですので、parallel_chains>chainsにしても意味がないです。処理を早くしたい場合はシングルコア性能が高いCPUを使ってください。chain内並列処理もStanの機能としてあるそうですが、自分は試したことがありません。
library(tidyverse)
library(rstan)
library(bayesplot)
library(ggplot2)
library(cmdstanr)
library(tidybayes)
raw_data <- read.csv("all_data_2.csv")
data <- raw_data %>% select(series, time, x_obs, n_total)
series_ids <- unique(data$series)
x_list <- map(series_ids, ~ data %>% filter(series == .x) %>% pull(x_obs))
ts_list <- map(series_ids, ~ data %>% filter(series == .x) %>% pull(time))
n_list <- map(series_ids, ~ data %>% filter(series == .x) %>% pull(n_total))
ts_lengths <- map(ts_list, length) # 各系列の長さ このモデルでは全て同じ長さにする必要があるがこれを使えば拡張はできるはず
input <- list(
Series = length(series_ids),
N = ts_lengths[1],
ts = ts_list,
x = x_list,
n_total = n_list
)
stan <- cmdstan_model('model2.stan')
fit <- stan$sample(data = input, iter_warmup = 2000, iter_sampling = 2000, parallel_chains = 4, chains = 4, save_warmup = TRUE)
# 以下各種グラフを描画
color_scheme_set("brewer-RdYlBu")
plt_dens <- mcmc_dens_overlay(fit$draws(c("a", "b", "s"),inc_warmup = F)) + geom_density(linewidth = 1) + theme_classic()
color_scheme_set("blue")
plt_trace <- mcmc_trace(fit$draws(c("a", "b", "s"),inc_warmup = T),n_warmup = 2000) + theme_classic()
plt_pairs <- mcmc_pairs(fit$draws(c("a", "b", "s"),inc_warmup = F), off_diag_args = list(size = 0.5, alpha = 0.5))
ggsave("trace_plot2.png", plot = plt_trace, width = 1000, height = 400, units = "px", dpi=180)
ggsave("dens_plot2.png", plot = plt_dens, width = 1000, height = 400, units = "px", dpi=180)
ggsave("pairs_plot2.png", plot = plt_pairs, width = 1000, height = 800, units = "px", dpi=180)
df_xpred <- fit$draws(format = "df") %>%
spread_draws(mu_pred[series,time,vector]) %>%
mode_hdi(.width = 0.95) # MAP推定値と95%CIの計算
df_list <- split(df_xpred, df_xpred$series)# series列で分割してリストに格納
list2env(setNames(df_list, paste0("df", 1:3)), envir = .GlobalEnv) # 各リスト要素を df1, df2, df3 に代入
plt_fitting <- ggplot(df1, aes(x = time)) +
geom_ribbon(aes(ymin = .lower, ymax = .upper), fill = '#D53E4F', alpha = 0.4) +
geom_line(aes(y = mu_pred), linewidth = 1, col = "#D53E4F") +
geom_ribbon(data = df2, aes(ymin = .lower, ymax = .upper), fill = '#32CD32', alpha = 0.4) +
geom_line(data = df2, aes(y = mu_pred), linewidth=1, col="#32CD32") +
geom_ribbon(data = df3, aes(ymin = .lower, ymax = .upper), fill = '#4169E2', alpha = 0.4) +
geom_line(data = df3, aes(y = mu_pred), linewidth=1, col="#4169E2") +
geom_line(data=data, aes(x = time, y = x_obs, group = as.factor(series), col = as.factor(series))) +
geom_point(data=data, aes(x = time, y = x_obs, group = as.factor(series), col = as.factor(series))) +
theme_classic() +
labs(x = "時間", y = "観測個体数")
ggsave("fitting_plot2.png", plot = plt_fitting, width = 1000, height = 800, units = "px", dpi=180)
結果
出力した結果を列挙します。きれいに収束してくれました。真の値はx0[1,1], x0[2,1], ...
はそれぞれ系列1,2,...の推定された初期値です。
>fit$summary()
# A tibble: 322 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -998. -998. 2.24 2.10 -1.00e+3 -995. 1.00 2804. 3760.
2 a 41.5 41.5 3.97 4.01 3.49e+1 47.9 1.00 4459. 4248.
3 b 14.4 14.4 1.82 1.81 1.15e+1 17.5 1.00 2051. 2394.
4 s 9.58 9.54 1.58 1.57 7.03e+0 12.3 1.00 2037. 2488.
5 sigma[1] 5.25 5.23 0.368 0.355 4.69e+0 5.88 1.00 6671. 5464.
6 sigma[2] 30.3 30.2 2.14 2.13 2.69e+1 34.0 1.00 6346. 5341.
7 sigma[3] 35.4 35.2 2.56 2.54 3.15e+1 39.8 1.00 6292. 4509.
8 x0[1,1] 4.17 3.71 2.95 3.11 3.47e-1 9.65 1.00 3246. 2042.
9 x0[2,1] 103. 103. 4.98 4.93 9.47e+1 111. 1.00 2678. 3213.
10 x0[3,1] 198. 198. 6.09 6.00 1.88e+2 208. 1.00 4485. 4213.
# ℹ 312 more rows
# ℹ Use `print(n = ...)` to see more rows
>fit$draws(format = "df") %>% spread_draws(a,b,s) %>% mode_hdi() # MAPと95%HDI
a a.lower a.upper b b.lower b.upper s s.lower s.upper .width .point .interval
41.48638 33.9068 49.3635 14.3135 10.9573 18.0643 9.653436 6.48942 12.7132 0.95 mode hdi
まとめ
1つの時系列データだけでは推定困難なパラメータも、条件を変えたデータを取り入れることで推定することができました。うまくいった理由は、機会があれば説明を追記するつもりですが、第二回の感度分析を今回増やしたデータセットそれぞれで試してみるとわかるかもしれません。どんなデータを使えば推定がうまくいくのか、ぜひご自身でデータをいじってみてください。次回は、
追記
これまで使ったコード類はすべて私のGitHubにアップロードしています。データのcsvファイルもありますので、自分で作るのが面倒でしたらダウンロードしてください。ところどころ数値が違うかもしれません💦未整理でごめんなさい。
Discussion