🐉

【第三回】非線形微分方程式系のベイズ推定 with Stan: 複数時系列データへの拡張

2024/11/22に公開

アイデア

第二回で、パラメータを推定する際の問題点として、識別可能性と感度があるというお話をしました。識別可能性は根本的な数理モデルの問題なので対処のしようがありませんが、感度については「その測定条件(初期条件: ここではx_{0}N)では推定が難しい」という問題なので、他の時系列データがあればなんとかなるかもしれません。すごく単純に言うと「あるパラメータを効率的に推定したければ、そのパラメータを少し変化させただけで動態が大きく変化するようなデータを集めればよい」ということですので、「ある時系列データでaの推定が難しくても、aの感度が高い別の時系列データと組み合わせて推定すればいいんじゃね?」というのが今回のアイデアです。複数の時系列データからまとめてa,b,sを推定するような統計モデルをStanで実装してみましょう。

使用するデータ

なぜそうなるかは一旦おいておくとして、とりあえず出来レースでデータを作ります。今回はこのデータセットall_data_2.csvを使います。

library(deSolve)
library(ggplot2)
library(dplyr)
library(tidyr)
library(extraDistr)

# 微分方程式の定義
Ppun_aggr <- function(t, x, params) {
  alpha <- params["alpha"]
  beta <- params["beta"]
  s <- params["s"]
  N <- params["N"]
  
  dxdt <- (alpha + beta * x[1]) * (N - x[1]) - (s * x[1]) / (s + x[1])
  
  list(c(dxdt))
}

# パラメータの設定
alpha <- 0.0045
beta <- 0.00015
s <- 10
N <- 600
times <- seq(0, 100, by = 1)  # 時間範囲

# 3つの初期値からデータを生成する
set.seed(100) 
initial_values <- c(2.0, 100, 200)
results <- list()

for (i in 1:length(initial_values)) {
  x0 <- initial_values[i]
  params <- c(alpha = alpha, beta = beta, s = s, N = N)
  output <- ode(y = c(x = x0), times = times, func = Ppun_aggr, parms = params)
  df <- as.data.frame(output)
  df$initial_value <- x0

  # ベータ二項分布に従ってばらつかせる
  alpha_beta <- 100  # ベータ分布の形状パラメータ1
  beta_beta <- 100   # ベータ分布の形状パラメータ2
  prob <- df$x / N
  df$x_obs <- rbbinom(n = nrow(df), size = N, alpha = alpha_beta * prob, beta = beta_beta * (1 - prob))
  df$n_total <- N
  results[[i]] <- df
}

# データを結合
all_data <- bind_rows(results, .id = "series")

# ggplotで散布図を作成
plt <- ggplot(all_data, aes(x = time, y = x_obs, color = as.factor(series))) +
  geom_point(size=0.8) +
  geom_line(aes(y = x), linetype = "solid") +
  labs(x = "時間", y = "観測個体数", color = "系列") +
  theme_classic() # +theme(legend.position = "none") 

ggsave("plot2.png", plot = plt, width = 1000, height = 800, units = "px")
write.csv(all_data[all_data$time>0,], "all_data_2.csv", row.names = FALSE)

plot2

Stanによる実装

基本は第一回と同じモデルになります。どの時系列データでもパラメータa,b,sは共通とします。数理モデルからの誤差\sigmaと初期値x_{0}は、系統によって独立であると仮定します(=別々のパラメータとして推定する)。総個体数N_{i}は入力データとして与えます。数式を使って書くとこんな感じです。系列iの観測個体数をx_{obs, i}とします。

x_{obs,i} \sim Normal(x_{i}, \sigma_{i}) \\ \frac{dx_{i}}{dt}=(\alpha+\beta x_{i})(N_{i}-x_{i})-\frac{sx_{i}}{s+x_{i}}

事前分布はほぼ第一回と同様に
a \sim Normal^+(0, 100)\\ b \sim Normal^+(0, 100)\\ s \sim Normal^+(0, 100)\\ \sigma_{i} \sim Cauchy^+(0, 2.5) \\ x_{0,i} \sim Uniform(0, N_{i})

とします。

model2.stan
functions { // モデル式の宣言
  vector beekman(real t, vector x, array[] real par) {
    vector[1] dxdt;
    
    real a = par[1]/10000;
    real b = par[2]/100000;
    real s = par[3];
    real n_total = par[4];
    
    dxdt[1] = (a + b * x[1]) * (n_total - x[1]) - (s * x[1]) / (s + x[1]);
    
    return dxdt;
  }
}

data {
  int<lower=0> Series; // 時系列のデータ数
  int<lower=0> N; // 1時系列のデータ数
  array[Series, N] real<lower=0> ts; // 時間
  array[Series, N] real<lower=0> x; // ts=tにおける個体数
  array[Series, N] real<lower=0> n_total; // 総個体数
}

parameters {
  real<lower=0> a;
  real<lower=0> b;
  real<lower=0> s;
  array[Series] real<lower=0> sigma;
  array[Series] vector<lower=0>[1] x0; 
}

transformed parameters {
  array[Series, 4] real par;
  for (i in 1:Series) {
    par[i, 1] = a;
    par[i, 2] = b;
    par[i, 3] = s;
    par[i, 4] = n_total[i, 1];
  }
}

model {

  // priors
  a ~ normal(0, 100);
  b ~ normal(0, 100);
  s ~ normal(0, 100);
  
  for (i in 1:Series) {
    sigma[i] ~ cauchy(0, 2.5);
    x0[i, 1] ~ uniform(0, n_total[i, 1]);
  }
  
  // 各時系列ごとにODEを解く
  for (i in 1:Series) {
    array[N] vector[1] mu = ode_rk45(beekman, x0[i,], 0, ts[i, 1:N], par[i,]); 
    for (j in 1:N) {
      x[i,j] ~ normal(mu[j], sigma[i]);
    }
  }
}

generated quantities {
  array[Series, N] vector[1] mu_pred;
  for (i in 1:Series){
    mu_pred[i,] = ode_rk45(beekman, x0[i,], 0, ts[1, 1:N], par[i,]);
  }
}

実行

こちらも入力データの編集部分以外は第一回とほとんど同じです。個人的な感想ですが、データ数と実行時間には非線形な関係があるように思います笑 お菓子でも食べてゆっくり待ちましょう。cmdstanrの仕様上、1つのchainはCPU1スレッドで実行されます。parallel_chainsは並列で実行されるchain数ですので、parallel_chains>chainsにしても意味がないです。処理を早くしたい場合はシングルコア性能が高いCPUを使ってください。chain内並列処理もStanの機能としてあるそうですが、自分は試したことがありません。

run_model2.R
library(tidyverse)
library(rstan)
library(bayesplot)
library(ggplot2)
library(cmdstanr)
library(tidybayes)

raw_data <- read.csv("all_data_2.csv")
data <- raw_data %>% select(series, time, x_obs, n_total)
series_ids <- unique(data$series)
x_list <- map(series_ids, ~ data %>% filter(series == .x) %>% pull(x_obs))
ts_list <- map(series_ids, ~ data %>% filter(series == .x) %>% pull(time))
n_list <- map(series_ids, ~ data %>% filter(series == .x) %>% pull(n_total))
ts_lengths <- map(ts_list, length)  # 各系列の長さ このモデルでは全て同じ長さにする必要があるがこれを使えば拡張はできるはず

input <- list(
  Series = length(series_ids),
  N = ts_lengths[1],
  ts = ts_list,
  x = x_list,
  n_total = n_list
)

stan <- cmdstan_model('model2.stan')
fit <- stan$sample(data = input, iter_warmup = 2000, iter_sampling = 2000, parallel_chains = 4, chains = 4, save_warmup = TRUE)

# 以下各種グラフを描画
color_scheme_set("brewer-RdYlBu")
plt_dens <- mcmc_dens_overlay(fit$draws(c("a", "b", "s"),inc_warmup = F)) + geom_density(linewidth = 1) + theme_classic()
color_scheme_set("blue")
plt_trace <- mcmc_trace(fit$draws(c("a", "b", "s"),inc_warmup = T),n_warmup = 2000) + theme_classic()
plt_pairs <- mcmc_pairs(fit$draws(c("a", "b", "s"),inc_warmup = F), off_diag_args = list(size = 0.5, alpha = 0.5))
ggsave("trace_plot2.png", plot = plt_trace, width = 1000, height = 400, units = "px", dpi=180)
ggsave("dens_plot2.png", plot = plt_dens, width = 1000, height = 400, units = "px", dpi=180)
ggsave("pairs_plot2.png", plot = plt_pairs, width = 1000, height = 800, units = "px", dpi=180)

df_xpred <- fit$draws(format = "df") %>%
  spread_draws(mu_pred[series,time,vector]) %>%
  mode_hdi(.width = 0.95)  # MAP推定値と95%CIの計算

df_list <- split(df_xpred, df_xpred$series)# series列で分割してリストに格納

list2env(setNames(df_list, paste0("df", 1:3)), envir = .GlobalEnv) # 各リスト要素を df1, df2, df3 に代入

plt_fitting <- ggplot(df1, aes(x = time)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), fill = '#D53E4F', alpha = 0.4) +
  geom_line(aes(y = mu_pred), linewidth = 1, col = "#D53E4F") +
  geom_ribbon(data = df2, aes(ymin = .lower, ymax = .upper), fill = '#32CD32', alpha = 0.4) +
  geom_line(data = df2, aes(y = mu_pred), linewidth=1, col="#32CD32") +
  geom_ribbon(data = df3, aes(ymin = .lower, ymax = .upper), fill = '#4169E2', alpha = 0.4) +
  geom_line(data = df3, aes(y = mu_pred), linewidth=1, col="#4169E2") +
  geom_line(data=data, aes(x = time, y = x_obs, group = as.factor(series), col = as.factor(series))) +
  geom_point(data=data, aes(x = time, y = x_obs, group = as.factor(series), col = as.factor(series))) +
  theme_classic() +
  labs(x = "時間", y = "観測個体数")

ggsave("fitting_plot2.png", plot = plt_fitting, width = 1000, height = 800, units = "px", dpi=180)

結果

出力した結果を列挙します。きれいに収束してくれました。真の値はa=45, b=15, s=10ですので、1個の時系列データを使った推定よりも非常に良い推定結果が得られているのが分かるかと思います。ちなみにx0[1,1], x0[2,1], ...はそれぞれ系列1,2,...の推定された初期値です。

result.R
>fit$summary()
# A tibble: 322 × 10
   variable    mean  median    sd   mad       q5     q95  rhat ess_bulk ess_tail
   <chr>      <dbl>   <dbl> <dbl> <dbl>    <dbl>   <dbl> <dbl>    <dbl>    <dbl>
 1 lp__     -998.   -998.   2.24  2.10  -1.00e+3 -995.    1.00    2804.    3760.
 2 a          41.5    41.5  3.97  4.01   3.49e+1   47.9   1.00    4459.    4248.
 3 b          14.4    14.4  1.82  1.81   1.15e+1   17.5   1.00    2051.    2394.
 4 s           9.58    9.54 1.58  1.57   7.03e+0   12.3   1.00    2037.    2488.
 5 sigma[1]    5.25    5.23 0.368 0.355  4.69e+0    5.88  1.00    6671.    5464.
 6 sigma[2]   30.3    30.2  2.14  2.13   2.69e+1   34.0   1.00    6346.    5341.
 7 sigma[3]   35.4    35.2  2.56  2.54   3.15e+1   39.8   1.00    6292.    4509.
 8 x0[1,1]     4.17    3.71 2.95  3.11   3.47e-1    9.65  1.00    3246.    2042.
 9 x0[2,1]   103.    103.   4.98  4.93   9.47e+1  111.    1.00    2678.    3213.
10 x0[3,1]   198.    198.   6.09  6.00   1.88e+2  208.    1.00    4485.    4213.
# ℹ 312 more rows
# ℹ Use `print(n = ...)` to see more rows
>fit$draws(format = "df") %>% spread_draws(a,b,s) %>% mode_hdi() # MAPと95%HDI
         a a.lower a.upper       b b.lower b.upper        s s.lower s.upper  .width .point .interval
  41.48638 33.9068 49.3635 14.3135 10.9573 18.0643 9.653436 6.48942 12.7132   0.95   mode       hdi

densplot
pairsplot
traceplot
fittingplot

まとめ

1つの時系列データだけでは推定困難なパラメータも、条件を変えたデータを取り入れることで推定することができました。うまくいった理由は、機会があれば説明を追記するつもりですが、第二回の感度分析を今回増やしたデータセットそれぞれで試してみるとわかるかもしれません。どんなデータを使えば推定がうまくいくのか、ぜひご自身でデータをいじってみてください。次回は、\alpha, \beta, sの群間差(e.g. コロニー差)を考慮できるベイズ階層モデルに拡張します。【第四回に続く】

追記

これまで使ったコード類はすべて私のGitHubにアップロードしています。データのcsvファイルもありますので、自分で作るのが面倒でしたらダウンロードしてください。ところどころ数値が違うかもしれません💦未整理でごめんなさい。
https://github.com/putaro-cu/StanODE

Discussion