🦩

【第四回】非線形微分方程式系のベイズ推定 with Stan: 階層ベイズモデル

2024/11/28に公開

はじめに

第三回で複数時系列データを扱えるように拡張しました。今回は、階層ベイズモデルに拡張することで、群ごとのパラメータのばらつきを扱えるようにしたいと思います。まず今まで使ってきたアリの採餌動員モデルを再掲しておきます。

\frac{dx}{dt}=(\alpha+\beta x)(N-x)-\frac{sx}{s+x}

\alpha, \beta, sは生物学的な意味合いに照らせば、それぞれ餌の発見確率、動員確率、フェロモンの効果の強さを表しています。これらの値は、常に一定とするよりも、実験間や由来コロニー間にばらつき(異質性)があると考えるほうが自然でしょう。

階層モデルについて

階層モデルでは、各グループのパラメータを「全体平均」と「グループ差」に分解して表現します。この時「グループ差」は正規分布に従うと仮定します。今回は、1つの時系列データを「グループ」として考えます。もしコロニー差みたいなものを考えたい場合は、由来コロニーが同じ時系列を1つの「グループ」と扱えばよいです。
アリの採餌動員モデルにおけるパラメータ\alphaを代表にして、具体的に解説していきます。時系列iのパラメータ\alpha_{i}を「全体平均」\mu_{\alpha}と「グループ差」\delta_{i}に分解すると数式を使って

\alpha_{i}=\mu_{\alpha}+\delta_{i}

と記述できます。さらに、「グループ差」\delta_{i}は平均0、標準偏差\sigma_{\alpha}の正規分布に従うと仮定するので、以下のようになります。
\delta_{i} \sim Normal(0,\sigma_{\alpha})

これらを簡潔に一つの式でまとめると
\alpha_{i} \sim Normal(\mu_{\alpha},\sigma_{\alpha})

となります。このパイパーパラメータ\mu_{\alpha}\sigma_{\alpha}を推定します。

Stanによる実装

\alpha, \beta, sすべてで系列間の異質性があるとして階層モデル化します。また、簡単のためa=10000\alpha, b=100000\betaと変換します。モデル式を書くと以下のようになります。

x_{obs,i} \sim Normal(x_{i}, \sigma_{i}) \\ \frac{dx_{i}}{dt}=(\alpha_{i}+\beta_{i} x_{i})(N_{i}-x_{i})-\frac{s_{i}x_{i}}{s_{i}+x_{i}} \\ a_{i} \sim Normal(\mu_{a},\sigma_{a}) \\ b_{i} \sim Normal(\mu_{b},\sigma_{b}) \\ s_{i} \sim Normal(\mu_{s},\sigma_{s})

事前分布は以下のものを与えます。
\mu_{a} \sim Normal^+(0, 100)\\ \mu_{b} \sim Normal^+(0, 100)\\ \mu_{s} \sim Normal^+(0, 100) \\ \sigma_{a} \sim Cauchy^+(0, 2.5) \\ \sigma_{b} \sim Cauchy^+(0, 2.5) \\ \sigma_{s} \sim Cauchy^+(0, 2.5) \\ \sigma_{i} \sim Cauchy^+(0, 2.5) \\ x_{0,i} \sim Uniform(0, N_{i})

model3.stan
functions { // モデル式の宣言
  vector beekman(real t, vector x, array[] real par) {
    vector[1] dxdt;
    
    real a = par[1]/10000;
    real b = par[2]/100000;
    real s = par[3];
    real n_total = par[4];
    
    dxdt[1] = (a + b * x[1]) * (n_total - x[1]) - (s * x[1]) / (s + x[1]);
    
    return dxdt;
  }
}

data {
  int<lower=0> Series; // 時系列のデータ数
  int<lower=0> N; // 1時系列のデータ数
  array[Series, N] real<lower=0> ts; // 時間
  array[Series, N] real<lower=0> x; // ts=tにおける個体数
  array[Series, N] real<lower=0> n_total; // 総個体数
}

parameters {
  array[Series] real<lower=0> a;
  array[Series] real<lower=0> b;
  array[Series] real<lower=0> s;
  real<lower=0> mu_a;
  real<lower=0> mu_b;
  real<lower=0> mu_s;
  real<lower=0> sigma_a;
  real<lower=0> sigma_b;
  real<lower=0> sigma_s;
  array[Series] real<lower=0> sigma;
  array[Series] vector<lower=0>[1] x0; 
}

transformed parameters {
  array[Series, 4] real par;
  for (i in 1:Series) {
    par[i, 1] = a[i];
    par[i, 2] = b[i];
    par[i, 3] = s[i];
    par[i, 4] = n_total[i, 1];
  }
}

model {

  // priors
  mu_a ~ normal(0, 100);
  mu_b ~ normal(0, 100);
  mu_s ~ normal(0, 100);
  sigma_a ~ cauchy(0, 2.5);
  sigma_b ~ cauchy(0, 2.5);
  sigma_s ~ cauchy(0, 2.5);
  
  for (i in 1:Series) {
    a[i] ~ normal(mu_a, sigma_a);
    b[i] ~ normal(mu_b, sigma_b);
    s[i] ~ normal(mu_s, sigma_s);
    sigma[i] ~ cauchy(0, 2.5);
    x0[i, 1] ~ uniform(0, n_total[i, 1]);
  }
  
  // 各時系列ごとにODEを解く
  for (i in 1:Series) {
    array[N] vector[1] mu = ode_rk45(beekman, x0[i,], 0, ts[i, 1:N], par[i,]); 
    for (j in 1:N) {
      x[i,j] ~ normal(mu[j], sigma[i]);
    }
  }
}

generated quantities {
  array[Series, N] vector[1] mu_pred;
  for (i in 1:Series){
    mu_pred[i,] = ode_rk45(beekman, x0[i,], 0, ts[1, 1:N], par[i,]);
  }
}

実行

以下の6個の時系列データを使用します。各パラメータの全体平均と標準偏差を推定するので前回より多いデータ数にしています。生成に使ったコードは省略しますが、基本的には以前のコードに少し改変を加えただけです。私のGitHubから見ることができます。真の値は\mu_{a}=45, \mu_{b}=15, \mu_{s}=10, \sigma_{a}=5, \sigma_{b}=5, \sigma_{s}=0.1です。階層モデルにすると実行時間がかなり長くなります。私の環境では約100分ほどかかりました。
plot3

run_model3.R
library(tidyverse)
library(rstan)
library(bayesplot)
library(ggplot2)
library(cmdstanr)
library(tidybayes)

raw_data <- read.csv("all_data_3.csv")
data <- raw_data %>% select(series, time, x_obs, n_total)
series_ids <- unique(data$series)
x_list <- map(series_ids, ~ data %>%
    filter(series == .x) %>%
    pull(x_obs))
ts_list <- map(series_ids, ~ data %>%
    filter(series == .x) %>%
    pull(time))
n_list <- map(series_ids, ~ data %>%
    filter(series == .x) %>%
    pull(n_total))
ts_lengths <- map(ts_list, length)

input <- list(
    Series = length(series_ids),
    N = ts_lengths[1],
    ts = ts_list,
    x = x_list,
    n_total = n_list
)

stan <- cmdstan_model("model3.stan")
fit <- stan$sample(data = input, iter_warmup = 2000, iter_sampling = 2000, parallel_chains = 4, chains = 4, save_warmup = TRUE)


color_scheme_set("brewer-RdYlBu")
plt_dens <- mcmc_dens_overlay(fit$draws(c("mu_a", "mu_b", "mu_s"), inc_warmup = F)) + geom_density(linewidth = 1) + theme_classic()
plt_dens2 <- mcmc_dens_overlay(fit$draws(c("sigma_a", "sigma_b", "sigma_s"), inc_warmup = F)) + geom_density(linewidth = 1) + theme_classic()
color_scheme_set("blue")
plt_trace <- mcmc_trace(fit$draws(c("mu_a", "mu_b", "mu_s", "sigma_a", "sigma_b", "sigma_s"), inc_warmup = T), n_warmup = 2000) + theme_classic()
plt_pairs <- mcmc_pairs(fit$draws(c("mu_a", "mu_b", "mu_s"), inc_warmup = F), off_diag_args = list(size = 0.5, alpha = 0.5))
plt_pairs2 <- mcmc_pairs(fit$draws(c("sigma_a", "sigma_b", "sigma_s"), inc_warmup = F), off_diag_args = list(size = 0.5, alpha = 0.5))

ggsave("trace_plot3.png", plot = plt_trace, width = 1000, height = 800, units = "px", dpi = 180)
ggsave("dens_plot3-1.png", plot = plt_dens, width = 1000, height = 400, units = "px", dpi = 180)
ggsave("dens_plot3-2.png", plot = plt_dens2, width = 1000, height = 400, units = "px", dpi = 180)
ggsave("pairs_plot3-1.png", plot = plt_pairs, width = 1000, height = 800, units = "px", dpi = 180)
ggsave("pairs_plot3-2.png", plot = plt_pairs2, width = 1000, height = 800, units = "px", dpi = 180)

df_xpred <- fit$draws(format = "df") %>%
    spread_draws(mu_pred[series, time, vector]) %>%
    mode_hdi(.width = 0.95) # MAP推定値と95%CIの計算

df_list <- split(df_xpred, df_xpred$series) # series列で分割してリストに格納

list2env(setNames(df_list, paste0("df", 1:6)), envir = .GlobalEnv) # 各リスト要素を df1, df2, df3 に代入

plt_fitting <- ggplot(df1, aes(x = time)) +
    geom_ribbon(aes(ymin = .lower, ymax = .upper), fill = "#D53E4F", alpha = 0.4) +
    geom_line(aes(y = mu_pred), linewidth = 1, col = "#D53E4F") +
    geom_ribbon(data = df2, aes(ymin = .lower, ymax = .upper), fill = "#cd9c32", alpha = 0.4) +
    geom_line(data = df2, aes(y = mu_pred), linewidth = 1, col = "#cd9c32") +
    geom_ribbon(data = df3, aes(ymin = .lower, ymax = .upper), fill = "#32CD32", alpha = 0.4) +
    geom_line(data = df3, aes(y = mu_pred), linewidth = 1, col = "#32CD32") +
    geom_ribbon(data = df4, aes(ymin = .lower, ymax = .upper), fill = "#41b4e2", alpha = 0.4) +
    geom_line(data = df4, aes(y = mu_pred), linewidth = 1, col = "#41b4e2") +
    geom_ribbon(data = df5, aes(ymin = .lower, ymax = .upper), fill = "#4169E2", alpha = 0.4) +
    geom_line(data = df5, aes(y = mu_pred), linewidth = 1, col = "#4169E2") +
    geom_ribbon(data = df6, aes(ymin = .lower, ymax = .upper), fill = "#a70de4", alpha = 0.4) +
    geom_line(data = df6, aes(y = mu_pred), linewidth = 1, col = "#a70de4") +
    geom_line(data = data, aes(x = time, y = x_obs, group = as.factor(series), col = as.factor(series))) +
    geom_point(data = data, aes(x = time, y = x_obs, group = as.factor(series), col = as.factor(series))) +
    labs(x = "時間", y = "観測個体数", color = "系列") +
    theme_classic() +
    labs(x = "時間", y = "観測個体数")

ggsave("fitting_plot3.png", plot = plt_fitting, width = 1000, height = 800, units = "px", dpi = 180)

結果

出力した結果が以下のようになります。推定結果には各系列のa_{i}なども含まれています。真の値は\mu_{a}=45, \mu_{b}=15, \mu_{s}=10, \sigma_{a}=5, \sigma_{b}=5, \sigma_{s}=0.1です。おおよそ正確に推定されています。

result.R
> print(fit$summary(),n=40)
# A tibble: 661 × 10
   variable     mean   median     sd    mad        q5      q95  rhat ess_bulk
   <chr>       <dbl>    <dbl>  <dbl>  <dbl>     <dbl>    <dbl> <dbl>    <dbl>
 1 lp__     -1983.   -1983.    9.33   9.31  -1999.    -1968.    1.00     649.
 2 mu_a        43.0     42.8   7.18   6.30     32.1      54.6   1.00    1905.
 3 mu_b        15.4     15.2   3.40   3.01     10.4      21.3   1.00    3364.
 4 mu_s         9.38     9.32  1.36   1.20      7.32     11.6   1.00    2545.
 5 sigma_a      3.36     2.02  4.33   1.97      0.256    11.0   1.01     516.
 6 sigma_b      6.18     5.70  3.02   2.42      2.25     11.6   1.00    1949.
 7 sigma_s      1.51     1.18  1.34   1.07      0.105     3.99  1.01     718.
 8 a[1]        43.3     43.0   5.67   5.44     34.5      53.0   1.00    2434.
 9 a[2]        42.3     42.3   8.05   6.78     29.1      54.6   1.00    2091.
10 a[3]        43.1     42.8   8.04   6.77     31.0      56.0   1.00    2195.
11 a[4]        43.2     42.9   8.96   6.73     30.6      56.5   1.00    2385.
12 a[5]        43.1     42.8   8.82   6.88     30.5      56.2   1.00    2235.
13 a[6]        43.3     42.8   9.06   6.80     30.5      56.7   1.00    2328.
14 b[1]        15.4     15.0   6.56   5.83      5.22     27.1   1.00    3846.
15 b[2]        14.7     14.6   1.53   1.47     12.4      17.4   1.00    2408.
16 b[3]         7.76     7.47  2.05   1.78      4.95     11.5   1.00    1752.
17 b[4]        15.4     15.3   1.66   1.59     12.8      18.2   1.00    3422.
18 b[5]        16.1     16.2   2.61   2.56     11.7      20.2   1.00    2103.
19 b[6]        23.2     23.7   5.18   5.06     14.1      31.0   1.00    1630.
20 s[1]         9.76     9.48  2.49   1.67      6.61     13.8   1.00    2768.
21 s[2]         9.90     9.79  1.42   1.35      7.78     12.5   1.00    2029.
22 s[3]        10.1      9.79  1.67   1.40      7.82     13.2   1.00    1642.
23 s[4]         9.44     9.40  1.16   1.10      7.61     11.4   1.00    3354.
24 s[5]         8.64     8.69  1.47   1.45      6.11     11.0   1.00    2100.
25 s[6]         8.36     8.52  1.84   1.80      5.17     11.1   1.00    1646.
26 sigma[1]     5.16     5.14  0.371  0.358     4.60      5.81  1.00   10479.
27 sigma[2]    28.3     28.2   2.08   2.04     25.1      32.0   1.00    8242.
28 sigma[3]    24.7     24.6   1.81   1.74     21.9      27.9   1.00   10993.
29 sigma[4]    24.0     23.9   1.73   1.71     21.3      27.0   1.00   10574.
30 sigma[5]    25.3     25.2   1.81   1.77     22.5      28.4   1.00    9584.
31 sigma[6]    16.6     16.6   1.21   1.18     14.8      18.7   1.00    9303.
32 x0[1,1]      4.81     4.39  3.18   3.38      0.481    10.6   1.00    6861.
33 x0[2,1]     99.2     99.1   5.22   5.17     90.7     108.    1.00    3208.
34 x0[3,1]    207.     207.    7.37   7.16    194.      219.    1.00    3668.
35 x0[4,1]    298.     298.    8.37   8.19    284.      312.    1.00    5060.
36 x0[5,1]    407.     407.   10.4   10.1     389.      424.    1.00    4851.
37 x0[6,1]    516.     516.    9.00   8.72    500.      530.    1.00    5927.
38 par[1,1]    43.3     43.0   5.67   5.44     34.5      53.0   1.00    2434.
39 par[2,1]    42.3     42.3   8.05   6.78     29.1      54.6   1.00    2091.
40 par[3,1]    43.1     42.8   8.04   6.77     31.0      56.0   1.00    2195.
# ℹ 621 more rows
# ℹ 1 more variable: ess_tail <dbl>
# ℹ Use `print(n = ...)` to see more rows

densplot3-1
densplot3-2
pairsplot3-1
pairsplot3-2
traceplot3
fittingplot3

階層化しなかった場合との比較

第三回のコードでデータだけ入れ替えて推定を行ってみましょう。このモデルでは、すべての時系列でa,b,sの値が等しいと仮定していました。そのせいで系列3のようなずれた挙動をするデータが混ざっている場合にうまく推定ができません。コードは省略して結果だけお見せします。

results.R
> fit$summary()
# A tibble: 640 × 10
   variable     mean   median     sd    mad        q5      q95  rhat ess_bulk
   <chr>       <dbl>    <dbl>  <dbl>  <dbl>     <dbl>    <dbl> <dbl>    <dbl>
 1 lp__     -2198.   -2198.    2.93   2.86  -2203.    -2194.    1.00    2564.
 2 a           18.5     21.9   7.81   4.28      1.90     26.2   1.01     644.
 3 b            7.34     7.11  0.948  0.856     6.13      9.17  1.01     724.
 4 s            3.17     3.09  0.474  0.443     2.54      4.07  1.01     867.
.......().......

pairsplot2
fittingplot2

まとめ

今回は、それぞれのパラメータの異質性を考慮できる階層ベイズモデルへの拡張についてご紹介しました。階層化によって全体の傾向\muとばらつき\sigmaとしてそれぞれのパラメータを解釈することができます。自分の理解の範囲を超えますが、さらなる発展として時系列データを考慮した状態空間モデルへの拡張も可能です。このアプローチについて詳しく知りたい方は、CIUKF-MCMC法: Linden et al. 2022をぜひ参照してみてください。

Discussion