🐬

【第二回】非線形微分方程式系のベイズ推定 with Stan: 識別可能性とSobol感度分析

2024/11/22に公開

発生した問題

前回、データからパラメータを推定する中で、事後分布が事前分布からほとんど変化しないものが出てきてしまいました。今回は、そのような推定困難なパラメータが生じる原因と対処法を紹介します。
統計モデルとサンプルデータ、系列10のデータ(茶色)からの推定結果を再掲しておきます(a=10000\alpha, b=100000\betaに変換しています)。真の値はa=45, b=15 , s=10です。N=600に固定して実験をしたという想定です。aの事後分布に注目しますと、分散が大きすぎて意味のある推定値が得られていないことが分かるかと思います。

x_{obs} \sim Normal(x, \sigma) \\ \frac{dx}{dt}=(\alpha+\beta x)(N-x)-\frac{sx}{s+x}

Fig1
pairsplot

識別可能性解析(Identifiability Analysis)

数理モデル自体の問題で、どんなにデータを集めてもパラメータが一意に定められない(識別不可能)場合があります。これを検証するには識別可能性解析が有用です。データから推定を行う前に識別不可能な数理モデルになっていないか、きちんと確認しておくのが良いでしょう。

構造的識別可能性(Structual Identifiability)とは

異なるパラメータセットから同じデータが生成できてしまうような数理モデルでは、たとえノイズが全くなくても観測データ\bm{y}(t)からパラメータを一意に定めることはできません。このことを構造的に識別不可能であるといいます(Wieland et al. 2021)。

定義: あるパラメータ\theta_{i}がグローバルに構造的識別可能であるとき、任意のパラメータベクトル\bm{\theta}について

\bm{y}(t| \bm{\theta'})=\bm{y}(t| \bm{\theta}) \Rightarrow \theta_{i}'=\theta_{i}

がすべてのt>0で成り立つ

特定のパラメータ範囲において一意に推定されるかどうかは局所識別可能性(local Identifiability)ということもあります。
識別可能性解析は、Juliaのパッケージで実装されています
https://github.com/SciML/StructuralIdentifiability.jl

解析例

今回用いた微分方程式に適用してみます。実験設定からN=600は既知であるとして代入しておきます。この数理モデルではすべてのパラメータがグローバルに識別可能です。

sia.jl
using ModelingToolkit
using DifferentialEquations
using StructuralIdentifiability
using Random
using GlobalSensitivity
using Plots

iv = @variables t
states = @variables x1(t)
@variables y(t)
ps = @parameters α=0.0045 β=0.00015 s=10
D = Differential(t)

eqs = [
    D(x1) ~ (α + β * x1) * (600 - x1) - s * x1 / (s + x1)
]

obs_eq = [y ~ x1]
measured_quantities = [y ~ x1]

@named model = ODESystem(eqs, t, states, ps; observed = obs_eq)

sia_result = assess_identifiability(model, measured_quantities = measured_quantities, p = 0.99) # 識別可能性解析の実行
println(sia_result)
実行結果
OrderedCollections.OrderedDict{SymbolicUtils.BasicSymbolic{Real}, Symbol} with 4 entries:
  x1(t) => :globally
  α     => :globally
  β     => :globally
  s     => :globally

:globally meaning that the parameter is globally identifiable
:locally meaning that the parameter is locally but not globally identifiable
:nonidentifiable meaning that the parameter is not identifiable even locally.

識別不可能なモデル例

明らかに変なモデルですが、次のような例を考えてみます。

\frac{dx}{dt}=(\alpha+\beta)x

初期値x_{0}とパラメータ\alpha, \betaを与えればx(t)が一意に定まりますが、x(t)からパラメータは一意に定まりません。具体的に言うと\alpha=1,\beta=2であれば、x(t)=3xt+x_{0}と求められますが、x(t)=3xt+x_{0}となる\alpha, \betaの組み合わせはこれ以外にも無数にあります(例: \alpha=2,\beta=1)。

sia.jl
iv = @variables t
states = @variables x1(t)
@variables y(t)
ps = @parameters α=1 β=1
D = Differential(t)

eqs = [
    D(x1) ~ (α + β) * x1
]

obs_eq = [y ~ x1]
measured_quantities = [y ~ x1]

@named model = ODESystem(eqs, t, states, ps; observed = obs_eq)

sia_result = assess_identifiability(model, measured_quantities = measured_quantities, p = 0.99) # 識別可能性解析の実行
実行結果
OrderedCollections.OrderedDict{SymbolicUtils.BasicSymbolic{Real}, Symbol} with 3 entries:
  x1(t) => :globally
  α     => :nonidentifiable
  β     => :nonidentifiable

対処法: 再パラメータ化

StructuralIdentifiability.jlには、識別可能になるように再パラメータ化してくれる非常に便利な関数find_identifiable_functions()が実装されています。もう少し複雑なモデルにして試してみましょう。

\frac{dx_{1}}{dt}=ax_{1}-bx_{1}x_{2} \\ \frac{dx_{2}}{dt}=-cx_{2}+dx_{1}x_{2}

x_{1}しか観測できない場合、識別可能性解析の結果は次のようになります。ちなみにx_{1}x_{2}が両方とも観測できる場合は、すべてのパラメータが識別可能です。

sia.jl
iv = @variables t
states = @variables x1(t) x2(t)
@variables y1(t) y2(t)
ps = @parameters a b c d
D = Differential(t)

eqs = [
    D(x1) ~ a * x1 - b * x1 * x2,
    D(x2) ~ -c * x2 + d * x1 * x2
]

obs_eq = [y1 ~ x1] # x1もx2も観測できる場合は[y1 ~ x1, y2~x2]に変更する
measured_quantities = [y1 ~ x1] # x1もx2も観測できる場合は[y1 ~ x1, y2~x2]に変更する

@named model = ODESystem(eqs, t, states, ps; observed = obs_eq)

sia_result = assess_identifiability(model, measured_quantities = measured_quantities, p = 0.99) # 識別可能性解析の実行
println(sia_result)

find_result = find_identifiable_functions(model, measured_quantities = measured_quantities, with_states = true) # 識別可能性な関数の探索を実行
println(find_result)
実行結果
>sia_result
OrderedCollections.OrderedDict{SymbolicUtils.BasicSymbolic{Real}, Symbol} with 6 entries:
  x1(t) => :globally
  x2(t) => :nonidentifiable
  a     => :globally
  b     => :nonidentifiable
  c     => :globally
  d     => :globally
>find_result
5-element Vector{Num}:
   x1(t)
       d
       c
       a
 b*x2(t)

つまり、x_{2}(t),bは識別不可能ですが、\hat{x}_{2}(t)=bx_{2}(t)と変換すれば、すべてのパラメータが識別可能になります。

Sobol感度分析

数理モデルの出力に対するパラメータの影響を評価する分析を感度分析といいます。Sobol感度分析はグローバルな感度解析の一種です(Sobol 2001)。構造的に識別可能であっても、モデルの出力への貢献が小さいパラメータは推定が難しくなります。こちらも、Juliaのパッケージで実装されています。
https://github.com/SciML/GlobalSensitivity.jl

解析例

ここでのモデルの出力はt=100の値とします。初期値は系列10に合わせてx_{0}=102.2と指定しています。グラフを見てみると、パラメータ\alphaのSensitivity Indexが他と比べて非常に小さいことが分かります。これは\alphaの値を変えてもモデルの動態がほとんど変化しないことを意味します。実際に\alphaを様々に変化させてグラフを描いてみると良いでしょう。
観測データから各パラメータを推定する際、このような感度が低いパラメータは少し変えても尤度がほとんど変化しませんので、意味のある推定結果になりません。

sensitivity.jl
function model_func(p)
    prob = ODEProblem(model, [102.2], (0.0, 100.0), p)
    sol = solve(prob, Tsit5())
    return sol[1, end]
end


param_ranges = [(0.0001, 0.01), (0.00001, 0.001), (1, 50.0)] # 動かすパラメータ範囲

sobol_result = gsa(model_func, Sobol(), param_ranges, samples=1000) # 感度分析の実行
println(sobol_result)

first_order = sobol_result.S1[:]
total_order = sobol_result.ST[:]

param_names = ["α", "β", "s"]
p1 = bar(param_names, first_order, legend=:none)
xlabel!(p1, "Parameters")
ylabel!(p1, "First-Order Sensitivity Index")
p2 = bar(param_names, total_order, legend=:none)
xlabel!(p2, "Parameters")
ylabel!(p2, "Total-Order Sensitivity Index")
p_combined = plot(p1, p2, layout=(1, 2), size=(800, 400), left_margin=5Plots.mm, right_margin=5Plots.mm, bottom_margin=5Plots.mm, top_margin=5Plots.mm)

savefig(p_combined, "sobol_sensitivity_analysis_combined.png")

Sobolplot
改めて推定結果をみてみると、確かに感度が低いaの推定がうまくいっていないことが分かります。
pairssplot

対処法: 感度が低いパラメータの固定

Sobol感度解析を行って、感度が非常に低いパラメータ(First-Order Sensitivity Index < 0.1)については、推定をせず妥当な値に固定するとよいという文献があります(Linden et al. 2023)。上の例で、a=50に固定して、ベイズ推定をやり直してみましょう。

run_model1_2.R
library(tidyverse)
library(rstan)
library(bayesplot)
library(ggplot2)
library(cmdstanr)
library(tidybayes)

raw_data <- read.csv("all_data.csv")
data <- raw_data %>% filter(series == 10) %>% select(time, x_obs, n_total)

input <- list(
  N = nrow(data),
  ts = data$time,
  x = data$x_obs,
  n_total = data$n_total[1]
)

stan <- cmdstan_model('model1_2.stan')
fit <- stan$sample(data = input, iter_warmup = 5000, iter_sampling = 5000, parallel_chains = 4, chains = 4, save_warmup = TRUE, adapt_delta = 0.9)

model1_2.stan
functions { // モデル式の宣言
  vector beekman(real t, vector x, array[] real par) {
    vector[1] dxdt;
    
    real b = par[1]/100000;
    real s = par[2];
    real n_total = par[3];
    
    dxdt[1] = (0.005 + b * x[1]) * (n_total - x[1]) - (s * x[1]) / (s + x[1]);
    
    return dxdt;
  }
}

data {
  int<lower=0> N; // 1時系列のデータ数
  array[N] real<lower=0> ts; // 時間
  array[N] real<lower=0> x; // ts=tにおける個体数
  int<lower=0> n_total; // 総個体数
}

parameters {
  real<lower=0> b;
  real<lower=0> s;
  real<lower=0> sigma;
  vector<lower=0>[1] x0; 
}

transformed parameters {
  array[3] real par;
  par[1] = b;
  par[2] = s;
  par[3] = n_total;
}

model {

  // priors
  a ~ normal(0, 100);
  b ~ normal(0, 100);
  s ~ normal(0, 100);
  sigma ~ cauchy(0, 2.5);
  x0[1] ~ uniform(0, n_total);
  
  array[N] vector[1] mu = ode_rk45(beekman, x0, 0, ts, par); 
  
  for (i in 1:N) {
    x[i] ~ normal(mu[i], sigma);
  }
}

generated quantities {
  array[N] vector[1] mu_pred = ode_rk45(beekman, x0, 0, ts, par);
}

推定結果
> fit$summary()                  
# A tibble: 108 × 10
   variable       mean median    sd   mad      q5    q95  rhat ess_bulk ess_tail
   <chr>         <dbl>  <dbl> <dbl> <dbl>   <dbl>  <dbl> <dbl>    <dbl>    <dbl>
 1 lp__         -373.  -373.   1.52  1.30 -376.   -371.   1.00    4667.    5440.
 2 b              15.2   15.1  2.63  2.68   11.0    19.7  1.00    2962.    2895.
 3 s              10.4   10.3  2.29  2.31    6.90   14.4  1.00    2929.    2880.
 4 sigma          27.4   27.3  1.93  1.95   24.3    30.6  1.00    6376.    7459.

pairplot1_2
真の値がb=15, s=10ですので、aを固定したほうがよい推定結果が得られました。やったー!と喜びたいところですがa=50と真の値(45)に近い値に固定したのである意味当たり前の結果です。本来の実験ではaについて皆目見当がつかない状態から始めるわけですから、何とかしてaの値もデータから推定したいところです。これには、Nの値を変えたりして複数の時系列データをまとめて推定するという方法が直感的によさそうに思われます。

[Tips] 実行中のエラーについて

実行中に次のようなエラーが出ることがあります

実行中のエラー
Chain 1 Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Chain 1 Exception: ode_rk45: ode parameters and data[1] is inf, but must be finite! (in 'C:/Users/XXXX/AppData/Local/Temp/RtmpCchH2Q/model-56cc1ee585.stan', line 44, column 2 to column 60)
Chain 1 If this warning occurs sporadically, such as for highly constrained variable types like covariance matrices, then the sampler is fine,      
Chain 1 but if this warning occurs often then your model may be either severely ill-conditioned or misspecified.

このエラーは、stanのODEソルバーが計算ができずにNaNを返した時に発生するようです。パラメータの探索中に極端な微分方程式が生成されてしまうことが原因です。エラーが出てもMCMC自体は進行します。warmup中に発生するのは問題ないと思われますが、sampling期間にも多発する場合は異常な結果につながることがあります。stanコード中のパラメータ範囲を見直してみると良いかもしれません。今回はwarmupとsampling期間を5000に延長し、adapt_deltaを0.9(デフォルト0.8)にして対処しました。

まとめ

今回は、想定している数理モデルからパラメータを推定できるのかどうかを事前に調べる手法を紹介しました。実験をする前に識別可能性解析&感度分析をして、実験計画を行うようにしましょう。頑張ってデータをとってもパラメータ推定に全く無意味だったらすごく悲しいことになってしまいます。もちろん、実際に分析する中で「MCMCが収束しない……」「事後分布が事前分布とほとんど変わらない……」などのトラブルがあった際にも役に立ちます。次回はこのモデルを複数時系列に拡張したいと思います。【第三回に続く】

Discussion