動的ハザードモデル(using Stan)
前置き
比例ハザードモデルの「比例」の部分を緩和する拡張はいくつか提案されている.(例えば Hemming, K., & Shaw, J. E. H. 2005. A class of parametric dynamic survival models. Lifetime data analysis, 11, 81-98.)が, あまり普及はしていない印象がある. その理由の一つは使いやすい実装が少ないせいかもしれない. Stanを使うとわりとかんたんに書ける.(といいつつも私は結構手こずったが…)
モデル
生存関数を
...
生存関数をハザードを使って書くと,
...
各時間区切りごとのイベント数を
と二項分布で書ける.
さらにハザード
ここでは
以上では記号をやや曖昧に書いてしまったが、次のStanのコードを見て納得してほしい.
data{
int<lower=1> N;
int<lower=1> M;
array[N] int<lower=0> Y;
array[N] int<lower=0> R;
array[N] int<lower=1,upper=M> time;
array[N] real x;
array[M] int<lower=1> delta;
}
parameters{
array[M] real alpha;
array[M] real beta;
real<lower=0> sigma_a;
real<lower=0> sigma_b;
}
model{
for(i in 1:N){
Y[i] ~ binomial_logit(R[i],alpha[time[i]]+beta[time[i]]*x[i]);
}
for(i in 2:M){
alpha[i] ~ normal(alpha[i-1], sqrt(delta[i])*sigma_a);
beta[i] ~ normal(beta[i-1], sqrt(delta[i])*sigma_b);
}
beta[1] ~ normal(0,5);
alpha[1] ~ normal(0,5);
sigma_b ~ student_t(3,0,1);
sigma_a ~ student_t(3,0,1);
}
generated quantities{
vector[M] S1;
vector[M] S2;
S1[1] = 1-inv_logit(alpha[1]);
S2[1] = 1-inv_logit(alpha[1]+beta[1]);
for(i in 2:M){
S1[i] = (1-inv_logit(alpha[i]))*S1[i-1];
S2[i] = (1-inv_logit(alpha[i]+beta[i]))*S2[i-1];
}
}
これをベースにして, 説明変数によって時間変化しない係数とする係数を使い分けるとか, 変化に制約をつけるとか(例えば単調であるとか), いろいろ考えられると思う(ちゃんと推定できるかはやってみないとわからない).
データ分析
R のパッケージ MASS に入っている Gehan データ(白血病についてのデータだったと思う)に対して, 生存関数を推定してみた.
生存関数をプロットしてみる. 帯は95%信用区間.
カプラン・マイヤー推定量(点線)と比較すると, このモデル(Stanで書いたほう)はあまり当てはまりが良くない気がする….
6-MP(薬)投与群と, 対照群では明確に差が見て取れる.
以下、Rのコード:
library(cmdstanr)
library(posterior)
library(bayesplot)
library(survival)
library(dplyr)
library(ggplot2)
library(broom)
library(tidyr)
library(pammtools)
data(gehan, package="MASS")
sfit <- survfit(Surv(time,cens)~treat, data=gehan)
#plot(sfit)
ut <- unique(sort(gehan$time))
gehan2 <-gehan %>%
group_by(time,treat) %>%
summarise(Y=sum(cens),n=n()) %>%
group_by(treat) %>%
mutate(R=rev(cumsum(rev(n)))) %>%
ungroup() %>%
mutate(time2 = as.integer(factor(time)))
head(gehan2)
delta <- diff(c(0,ut))
stan_dat <- list(N = nrow(gehan2),
Y = gehan2$Y,
R = gehan2$R,
M = n_distinct(gehan2$time2),
time = gehan2$time2,
delta = delta,
x=gehan2$treat=="6-MP")
stan_dat
mod <- cmdstan_model("./Documents/dynamichazardmodel2.stan")
#args(mod$sample)
fit_mcmc <- mod$sample(
data = stan_dat,
seed = 1234,
chains = 4,
parallel_chains = 4,
iter_warmup = 8000, iter_sampling = 2000)
mcmc_trace(fit_mcmc$draws(c("lp__","sigma_a","sigma_b","alpha[1]", "beta[1]")))+
theme_bw()
M <- stan_dat$M
S1 <- apply(fit_mcmc$draws("S1"),3,mean)
S2 <- apply(fit_mcmc$draws("S2"),3,mean)
S1_ci <- t(apply(fit_mcmc$draws("S1"),3,quantile, prob=c(0.025,0.975)))
S2_ci <- t(apply(fit_mcmc$draws("S2"),3,quantile, prob=c(0.025,0.975)))
colnames(S1_ci) <- c("lower","upper")
colnames(S2_ci) <- c("lower","upper")
dfs <- tidy(sfit)
dfS1 <- data.frame(ut,S1,S1_ci)
dfS2 <- data.frame(ut,S2,S2_ci)
ggplot(dfs)+
geom_step(aes(x=time,y=estimate,colour=strata), linetype=2)+
geom_step(data = dfS1, aes(x=ut,y=S1,colour="treat=control"))+
geom_stepribbon(data = dfS1, aes(x=ut,ymin=lower,ymax=upper,fill="treat=control"),alpha=0.1)+
geom_step(data = dfS2, aes(x=ut,y=S2,colour="treat=6-MP"))+
geom_stepribbon(data = dfS2, aes(x=ut,ymin=lower,ymax=upper,fill="treat=6-MP"),alpha=0.1)+
theme_bw(16)+labs(fill="strata")
ggsave("surv.png")
b_ci <- t(apply(fit_mcmc$draws("beta"),3,quantile, prob=c(0.025,0.975)))
b_hat <- apply(fit_mcmc$draws("beta"),3,mean)
colnames(b_ci) <- c("lower","upper")
df_b <- data.frame(mean=b_hat, b_ci,time=ut)
ggplot(df_b)+
geom_line(aes(x=time,y=mean), linetype=2)+
geom_ribbon(aes(x=time,ymin=lower,ymax=upper),alpha=0.1)+
theme_bw(16)
ggsave("beta.png")
Discussion