🐜

【第一回】非線形微分方程式系のベイズ推定 with Stan: 1時系列データ

2024/11/19に公開

初めに

非線形な微分方程式で表される系は、分岐や多重安定性など、複雑な挙動を示します。今回は特に解析的に解くことができない数理モデルでベイズ推定する方法を紹介します。まず1時系列のデータのみを扱い、そこから複数時系列データ&階層モデルに拡張することを目指します。

実行環境

R ver. 4.4.2
cmdstanr ver. 0.8.1
cmdstan ver. 2.35.0
VSCode ver. 1.95.3

扱うモデル

アリの採餌動員についての数理モデル(Beekman et al. 2001)を使用します。

\frac{dx}{dt}=(\alpha+\beta x)(N-x)-\frac{sx}{s+x}

このモデルの詳細は省略しますが、動員されるアリの個体数xの時間変化\frac{dx}{dt}は、既に採餌している個体による動員の効果(第一項)と脱落する効果(第二項)によって決まるという数理モデルです。
この微分方程式はパラメータ\alpha, \beta , N, sによって平衡点の数と安定性が変化します。パラメータのうち\alpha, \beta , sはシステム固有のものであると考え、\alpha=0.0045, \beta=0.00015 , s=10を真の値とします。
今回はN=600で生成した時系列データから、\alpha, \beta , sをベイズ推定します。現実の実験ではコロニーの総個体数Nを人為的に操作してシステムの挙動の変化を観察することになりますので、最終的にはN=300N=1200などのパターンにも当てはめていきます。

データの生成

観察するアリの動員個体数x_{obs}

x_{obs} \sim BetaBin(N, \frac{100x}{N}, 100(1-\frac{x}{N})) \\ \frac{dx}{dt}=(\alpha+\beta x)(N-x)-\frac{sx}{s+x}
に従うとします。これによって生成したデータの例を以下に示します。\alpha=0.0045, \beta=0.00015 , s=10, N=600の場合、この系には2つの安定平衡点と1つの不安定平衡点が存在します。
Fig1

library(deSolve)
library(ggplot2)
library(dplyr)
library(tidyr)
library(extraDistr)

# 微分方程式の定義
Ppun_aggr <- function(t, x, params) {
  alpha <- params["alpha"]
  beta <- params["beta"]
  s <- params["s"]
  N <- params["N"]
  
  dxdt <- (alpha + beta * x[1]) * (N - x[1]) - (s * x[1]) / (s + x[1])
  
  list(c(dxdt))
}

# パラメータの設定
alpha <- 0.0045
beta <- 0.00015
s <- 10
N <- 600
times <- seq(0, 100, by = 1)  # 時間範囲

# 初期値をランダムにして10個のデータを生成
set.seed(100) 
initial_values <- runif(10, min = 0, max = N)
results <- list()

for (i in 1:10) {
  x0 <- initial_values[i]
  params <- c(alpha = alpha, beta = beta, s = s, N = N)
  output <- ode(y = c(x = x0), times = times, func = Ppun_aggr, parms = params)
  df <- as.data.frame(output)
  df$initial_value <- x0

  # ベータ二項分布に従ってばらつかせる
  alpha_beta <- 100  # ベータ分布の形状パラメータ1
  beta_beta <- 100   # ベータ分布の形状パラメータ2
  prob <- df$x / N
  df$x_obs <- rbbinom(n = nrow(df), size = N, alpha = alpha_beta * prob, beta = beta_beta * (1 - prob))
  df$n_total <- N
  results[[i]] <- df
}

# データを結合
all_data <- bind_rows(results, .id = "series")

# ggplotで散布図を作成
plt <- ggplot(all_data, aes(x = time, y = x_obs, color = as.factor(series))) +
  geom_point(size=0.8) +
  geom_line(aes(y = x), linetype = "solid") +
  labs(x = "時間", y = "観測個体数", color = "系列") +
  theme_classic() # +theme(legend.position = "none") 

ggsave("plot.png", plot = plt, width = 1000, height = 800, units = "px")
write.csv(all_data[all_data$time>0,], "all_data.csv", row.names = FALSE)

Stanによる実装

stanには内部にODEソルバーが実装(ode_rk45)されているので、これを利用します。\alpha, \betaは、簡単のためa=10000\alpha, b=100000\betaと変換しa, bを推定します。生成時はx_{obs}がベータ二項分布に従うとしましたが、現実には分布は未知なのが普通ですので、ここでは正規分布に従うとします。
また、微分方程式を解くには初期値x_{0}を与える必要がありますが、定数として与えることも、事前分布を与えて推定することもできます。今回は初期値も推定します。

x_{obs} \sim Normal(x, \sigma) \\ \frac{dx}{dt}=(\alpha+\beta x)(N-x)-\frac{sx}{s+x}

ある程度パラメータの大きさが分かっているものとして、事前分布として、
a \sim Normal^+(0, 100)\\ b \sim Normal^+(0, 100)\\ s \sim Normal^+(0, 100)\\ \sigma \sim Cauchy^+(0, 2.5) \\ x_{0} \sim Uniform(0, 600)

を設定します。

model1.stan
functions { // 微分方程式の宣言
  vector beekman(real t, vector x, array[] real par) {
    vector[1] dxdt;
    
    real a = par[1]/10000;
    real b = par[2]/100000;
    real s = par[3];
    real n_total = par[4];
    
    dxdt[1] = (a + b * x[1]) * (n_total - x[1]) - (s * x[1]) / (s + x[1]);
    
    return dxdt;
  }
}

data { // 入力データ形式の宣言
  int<lower=0> N; // 1時系列のデータ数
  array[N] real<lower=0> ts; // 時間
  array[N] real<lower=0> x; // ts=tにおける個体数
  int<lower=0> n_total; // 総個体数
}

parameters { // 推定するパラメータの宣言
  real<lower=0> a;
  real<lower=0> b;
  real<lower=0> s;
  real<lower=0> sigma;
  vector<lower=0>[1] x0; // ode_rk45に与える初期値(vector形式で与える必要がある)
}

transformed parameters { // 変形されたパラメータの宣言
  array[4] real par;
  par[1] = a;
  par[2] = b;
  par[3] = s;
  par[4] = n_total;
}

model {  // モデル構造の宣言

  // priors
  a ~ normal(0, 100);
  b ~ normal(0, 100);
  s ~ normal(0, 100);
  sigma ~ cauchy(0, 2.5);
  x0[1] ~ uniform(0, n_total);
  
  array[N] vector[1] mu = ode_rk45(beekman, x0, 0, ts, par); // 微分方程式の数値解を得る
  for (i in 1:N) {
    x[i] ~ normal(mu[i], sigma); // 各データx_iは平均mu[i], 標準偏差sigmaの正規分布に従う
  }
}

generated quantities {
  array[N] vector[1] mu_pred = ode_rk45(beekman, x0, 0, ts, par);
}

実行

R&cmdstanによってMCMCを行います。インストールの方法はほかの記事を参照してください。今回はサンプルデータのうち系列10のデータのみ使用します。

run_model1.R
library(tidyverse)
library(rstan)
library(bayesplot)
library(ggplot2)
library(cmdstanr)
library(tidybayes)

raw_data <- read.csv("all_data.csv")
data <- raw_data %>% filter(series == 10) %>% select(time, x_obs, n_total)

input <- list(
  N = nrow(data),
  ts = data$time,
  x = data$x_obs,
  n_total = data$n_total[1]
)

stan <- cmdstan_model('model1.stan')
# warmup中のプロットも取得したい場合はsave_warmup = TRUEが必要
fit <- stan$sample(data = input, iter_warmup = 2000, iter_sampling = 2000, parallel_chains = 4, chains = 4, save_warmup = TRUE)

結果の確認

基本的なプロットを出力します。結果の可視化はbayesplotパッケージがおすすめです。

run_model1.R
color_scheme_set("brewer-RdYlBu")
plt_dens <- mcmc_dens_overlay(fit$draws(c("a", "b", "s", "sigma"),inc_warmup = F)) + geom_density(linewidth = 1) + theme_classic()
color_scheme_set("blue")
plt_trace <- mcmc_trace(fit$draws(c("a", "b", "s", "sigma"),inc_warmup = T), n_warmup = 2000) + theme_classic()
plt_pairs <- mcmc_pairs(fit$draws(c("a", "b", "s", "sigma"),inc_warmup = F), off_diag_args = list(size = 0.5, alpha = 0.5))
ggsave("trace_plot.png", plot = plt_trace, width = 1000, height = 800, units = "px", dpi=180)
ggsave("dens_plot.png", plot = plt_dens, width = 1000, height = 800, units = "px", dpi=180)
ggsave("pairs_plot.png", plot = plt_pairs, width = 1000, height = 800, units = "px", dpi=180)

推定結果

実行結果は

fit$summary()

で確認できます。cmdstanでは、デフォルトで出力される確信区間が90%であることに注意しましょう。\hat{R}>1.1であれば、収束していないと判断されることが多いです。トレースプロットや事後分布の形がチェーンによって違ったりする場合もうまく推定できていません。

> fit$summary()
# A tibble: 110 × 10
   variable    mean  median    sd   mad      q5    q95  rhat ess_bulk ess_tail
   <chr>      <dbl>   <dbl> <dbl> <dbl>   <dbl>  <dbl> <dbl>    <dbl>    <dbl>
 1 lp__     -370.   -370.    1.93  1.71 -374.   -368.   1.00    1386.    1583.
 2 a          20.6    16.9  15.9  15.5     1.51   51.1  1.00    1612.    2367.
 3 b          12.0    11.9   2.44  2.43    8.38   16.3  1.00    1186.    1435.
 4 s           6.66    6.41  2.44  2.44    3.13   11.1  1.00    1138.    1567.
 5 sigma      27.0    26.9   1.96  1.94   24.0    30.3  1.00    3183.    3610.
 6 x0[1]     102.    102.    6.12  6.25   91.3   111.   1.00    2107.    2255.

今回、真の値はa=45, b=15, s=10であるので、b, sはおおむね正しい値が推定されていますが、aは分散が大きくうまく推定できていません。
事後平均値(EAP: Expected A Posteriori), 事後最頻値(MAP: Maximum A Posterior Estimator), 事後中央値(MED: posterior MEDian)といった点推定値や、等裾事後確信区間(ETI: Equal-Tailed Interval)や最大事後密度確信区間(HDI: Highest posterior Density Interval)といった区間推定値を計算するにはtidybayesパッケージを使って次のようにします。

fit$draws(format = "df") %>%
    spread_draws(a,b,s) %>%
    mean_qi() # EAPと95%ETI
fit$draws(format = "df") %>%
    spread_draws(a,b,s) %>%
    mode_hdi() # MAPと95%HDI
fit$draws(format = "df") %>%
    spread_draws(a,b,s) %>%
    median_hdi() # MEDと95%HDI
         a   a.lower  a.upper       b  b.lower b.upper        s s.lower   s.upper .width .point .interval
1 20.60315 0.7498388 58.71253 12.0445 7.710016 17.0512 6.655556   2.601   12.03596   0.95   mean        qi

         a     a.lower a.upper      b b.lower b.upper        s s.lower s.upper  .width .point .interval
1 4.889889 0.007764758 51.1794 11.428  7.6966 17.0375 5.533436 2.15004 11.4708   0.95   mode       hdi

        a     a.lower a.upper       b b.lower b.upper       s s.lower s.upper  .width .point .interval
1 16.9045 0.007764758 51.1794 11.8855  7.6966 17.0375 6.40527 2.15004 11.4708   0.95 median       hdi

トレースプロット

traceplot

事後分布

densplot
pairssplot

フィッティング結果の描画

事後最頻値(MAP)とその95%CI(HDI)を計算してグラフにします。

run_model1.R
df_xpred <- fit$draws(format = "df") %>%
  spread_draws(mu_pred[time]) %>%
  mode_hdi(.width = 0.95)  # 予測値と95%CIの計算

plt_fitting <- ggplot(df_xpred, aes(x = time)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), fill = '#D53E4F', alpha = 0.4) +
  geom_line(aes(y = mu_pred), linewidth = 1, col = "#D53E4F") +
  geom_line(data = data.frame(input), aes(x = ts, y = x), col = "#D53E4F") +
  geom_point(data = data.frame(input), aes(x = ts, y = x), col = "#D53E4F") +
  ylim(c(0, 600)) +
  theme_classic() +
  labs(x = "時間", y = "観測個体数")

ggsave("fitting_plot.png", plot = plt_fitting, width = 1000, height = 800, units = "px", dpi=180)

fittingplot

違うデータではどうなる?

データによっては、うまく推定できない場合があります。系列7のデータを使って同じように推定を行ってみます。一見、うまくデータにフィットしているように見えますが、系列10の結果と比べると、推定値の分散が非常に大きいことが分かります。特にaの事後分布は与えた事前分布からほとんど変化していません。

> fit_alt$summary()
# A tibble: 110 × 10
   variable   mean median    sd   mad      q5    q95  rhat ess_bulk ess_tail
   <chr>     <dbl>  <dbl> <dbl> <dbl>   <dbl>  <dbl> <dbl>    <dbl>    <dbl>
 1 lp__     -357.  -357.   1.61  1.41 -361.   -356.   1.00    2390.    3067.
 2 a          81.7   68.2 62.1  60.0     6.44  200.   1.00    2419.    1878.
 3 b         118.   117.  44.0  45.2    51.0   195.   1.00    1583.    1579.
 4 s          89.8   85.6 39.1  38.2    34.8   161.   1.00    1583.    1524.
 5 sigma      25.4   25.3  1.87  1.83   22.5    28.6  1.00    3416.    3464.
 6 x0[1]     529.   528.  33.0  35.9   477.    585.   1.00    1861.    2033.

traceplot2
densplot
pairsplot
fittingplot

まとめ

今回は、微分方程式で表される数理モデルのパラメータをStanとRを使ってベイズ推定する手法と結果の図示の仕方について簡単に紹介しました。単純な曲線でフィッティングするよりも、より生物学的な意味を持たせた解釈ができるのではないでしょうか。次回は推定時に発生した問題について考えていきます。【第二回に続く】

参考文献

松浦健太郎 StanとRでベイズ統計モデリング
https://www.kyoritsu-pub.co.jp/book/b10003786.html
分寺杏介 統計的方法論特殊研究講義資料
https://www2.kobe-u.ac.jp/~bunji/resource.html
その他参考にしたもの
https://www.martinmodrak.cz/2018/05/14/identifying-non-identifiability/
https://futaba-nt.com/archives/36
https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html
https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1010651

Discussion