Cox比例ハザードモデルをlifelinesとscikit-survivalで比較してみた

2024/11/15に公開

はじめに

Cox比例ハザードモデルは生存時間解析におけるハザード関数を共変量(説明変数)も込めて回帰するための基本モデルである.ハザード関数が求まると共変量\bm{x},時刻tにおける生存率を表す生存関数S(\bm{x}; t)が求められる.

生存時間解析のPython実装についてはlifelinesとscikit-survivalというライブラリが代表的だと思うが,公式ドキュメントを見ると両者における生存関数の仕様(数式)が異なっていることがわかった[1][2]

この記事は実際に違いがあるのかをひとまず一つのデータセットで確認してみた記録である.

(結論としては,生存関数に関して言えば両者の値にほとんど差が無かったので,どちらのライブラリを使用してもよさそうである.当たり前か...)

生存時間解析の用語についてざっくり復習

共変量を考慮しない場合

"生存時間"を表す確率変数Tに興味がある.Tが確率密度関数f(t),累積確率密度関数F(t)で表される分布に従うとする.生存関数はS(t) = P(t \leq T) = - F(t)で定義される.ハザード関数の定義は

h(t) := \lim_{\Delta t \to 0} \frac{P(t \leq T < t + \Delta t \, | \, T \geq t)}{\Delta t}

である(時刻tにおける"瞬間死亡率").その上で

\begin{align*} h(t) &= \lim_{\Delta t \to 0} \frac{P(T < t + \Delta t) - P(T < t)}{P(T \geq t)} \frac{1}{\Delta t} \\ &= \lim_{\Delta t \to 0} \frac{F(t + \Delta t) - F(t)}{\Delta t} \frac{1}{S(t)} \\ &= \frac{F'(t)}{S(t)} = - \frac{S'(t)}{S(t)} = - \frac{d}{dt} \log{S(t)} \end{align*}

と変形できる.このことから,累積ハザード関数

H(t) = \int_{0}^{t} h(u) du

を用いて.生存関数を

S(t) = e^{- H(t)}

と表せる.

Cox比例ハザードモデル

Cox比例ハザードモデルとは,ハザード関数(共変量\bm{x}で表される個体の時刻tにおける"瞬間死亡率")が

h(t; \bm{x}) = h_0(t) e^{\bm{x}^{\top} \bm{\beta}}

という形であることを仮定したモデルである.時刻tの影響を表すh_0(t)はベースラインハザードと呼ばれ,共変量\bm{x}による影響と分離されていることが見てとれる.

このとき,累積ハザード関数は

H(t; \bm{x}) = \int_0^t h_0(u) e^{\bm{x}^{\top} \bm{\beta}} du = \int_0^t h_0(u) du \times e^{\bm{x}^{\top} \bm{\beta}} = H_0(t) e^{\bm{x}^{\top} \bm{\beta}}

となる.H_0(t)はベースライン累積ハザードと呼ばれる.これによって生存関数は

S(t; \bm{x}) = e^{- H(t; \bm{x})} = e^{- H_0(t) e^{\bm{x}^{\top} \bm{\beta}}} = \left( e^{- H_0(t)} \right)^{e^{\bm{x}^{\top} \bm{\beta}}} = S_0(t)^{e^{\bm{x}^{\top} \bm{\beta}}}

と表現できることになる.S_0(t) = e^{-H_0(t)}はベースライン生存関数と呼ばれる.

lifelinesとscikit-survivalによるモデル作成

lifelinesのrossiデータセットを使う.

# 各種ライブラリのインポート
from lifelines.datasets import load_rossi
from lifelines import CoxPHFitter
from sksurv.linear_model import CoxPHSurvivalAnalysis
import numpy as np
rossi = load_rossi()
print(rossi)
     week  arrest  fin  age  race  wexp  mar  paro  prio
0      20       1    0   27     1     0    0     1     3
1      17       1    0   18     1     0    0     1     8
2      25       1    0   19     0     1    0     1    13
3      52       0    1   23     1     1    1     1     1
4      52       0    0   19     0     1    0     1     3
..    ...     ...  ...  ...   ...   ...  ...   ...   ...
427    52       0    1   31     0     1    0     1     3
428    52       0    0   20     1     0    0     1     1
429    52       0    1   20     1     1    1     1     1
430    52       0    0   29     1     1    0     1     3
431    52       0    1   24     1     1    0     1     1

[432 rows x 9 columns]
各カラムの説明

このデータセットは再犯のリスク要因を分析するために収集されたようである.

  1. week: 追跡期間の週数.観測された期間の長さを表す.

  2. arrest: 再犯の有無.1は再犯があったこと,0は再犯が無かったことを示す.

  3. fin: 個人が追跡期間中に刑務所から出所したかどうか.1は出所したこと,0は出所していないことを示す.

  4. age: 対象者の年齢.

  5. race: 人種を示すカテゴリ変数.

  6. wexp: 雇用経験の有無.1は過去に雇用経験があること,0は雇用経験がないことを示す.

  7. mar: 結婚の有無.1は結婚していること,0は結婚していないことを示す.

  8. paro: 生活保護の有無.1は生活保護を受けていること,0は受けていないことを示す.

  9. prior: 過去の犯罪歴の数を示す変数.過去に逮捕された回数などを示す.

モデル作成にあたり,生存時間の打ち切り有無情報に関して加工が必要である.(lifelinesだと0/1の整数値でOKだが,scikit-survivalではTrue/Falseのブール値でなければいけない.)

def rossi_transform(rossi):
    # lifelines向けには加工無しでOK

    # scikit-survival向け
    X_ss = rossi[['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']]
    y_ss = rossi[["arrest", "week"]]
    y_ss["arrest"] = y_ss["arrest"].astype(bool)
    y_ss = np.array(list(zip(y_ss['arrest'], y_ss['week'])), dtype=[('arrest', '?'), ('week', '<f8')]) # イベントと時間を構造化配列に変換

    # 共変量
    X_cov = rossi[['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']]

    return X_ss, y_ss, X_cov

X_ss, y_ss, X_cov = rossi_transform(rossi)

# lifelines
model_ll = CoxPHFitter()
model_ll.fit(rossi, duration_col='week', event_col='arrest')

# scikit-survival
model_ss = CoxPHSurvivalAnalysis()
model_ss.fit(X_ss, y_ss)

これでモデルが作成できたので,両者の生存関数が一致しているか確認していこう.

lifelinesで生存関数を復元する試行

S_0(t) = e^{-H_0(t)}であることの確認をしよう.

ベースライン生存率が一致することの確認

# モデルのメソッドから直接呼び出し
base_surv_ll = model_ll.baseline_survival_.values

# ベースライン累積ハザードから計算
base_surv_calc_ll = np.exp( - model_ll.baseline_cumulative_hazard_).values

# 一致の判定
base_surv_ll.tolist() == base_surv_calc_ll.tolist()
True

よって一致している.

共変量の影響の確認

lifelinesにおける生存関数の定義は

S(t; \bm{x}) = \left( e^{-H_0(t)} \right)^{e^{\bm{x}^{\top} \bm{\beta}}} = S_0(t)^{e^{\bm{x}^{\top} \bm{\beta}}}

ではなく

S(t; \bm{x}) = \left( e^{-H_0(t)} \right)^{e^{(\bm{x} - \bar{\bm{x}})^{\top} \bm{\beta}}} = S_0(t)^{e^{(\bm{x} - \bar{\bm{x}})^{\top} \bm{\beta}}}

であることに注意する.(なお,\bar{\bm{x}}は共変量の平均である.)

# 中心化
X_cov_centerd = X_cov - X_cov.mean()

# 線形予測子
linear_predictor_ll = np.dot(X_cov_centerd, model_ll.params_)
# モデルのメソッドから直接呼び出し
surv_ll = model_ll.predict_survival_function(X_cov)[0].values

# 線形予測子から計算
surv_calc_ll = (model_ll.baseline_survival_ ** np.exp(linear_predictor_ll[0])).values

# 一致の判定
(surv_ll - surv_calc_ll).T
array([[ 0.00000000e+00,  0.00000000e+00,  1.11022302e-16,
        -1.11022302e-16,  1.11022302e-16,  1.11022302e-16,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.11022302e-16,  1.11022302e-16, -1.11022302e-16,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.11022302e-16,  1.11022302e-16,
         0.00000000e+00,  0.00000000e+00,  1.11022302e-16,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.11022302e-16, -1.11022302e-16,  1.11022302e-16,
         1.11022302e-16,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  1.11022302e-16, -1.11022302e-16,
         0.00000000e+00, -1.11022302e-16,  1.11022302e-16,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.11022302e-16,
        -1.11022302e-16,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  1.11022302e-16,
         0.00000000e+00]])

となってほぼ一致している.

scikit-survivalで生存関数を復元する試行

同様に,scikit-survivalでも実施してみる.

ベースライン生存率が一致することの確認

# モデルのメソッドから直接呼び出し
base_surv_ss = model_ss.baseline_survival_.y

# ベースライン累積ハザードから計算
base_surv_calc_ss = np.exp(- model_ss.cum_baseline_hazard_.y)

# 一致の判定
base_surv_ss.tolist() == base_surv_calc_ss.tolist()
True

となって一致する.

共変量の影響の確認

# 線形予測子
linear_predictor_ss = np.dot(X_cov, model_ss.coef_)

# モデルのメソッドから直接呼び出し
surv_ss = model_ss.predict_survival_function(X_cov)[0].y

# 線形予測子から計算
surv_calc_ss = model_ss.baseline_survival_.y ** np.exp(linear_predictor_ss[0])

# 一致の判定
surv_ss == surv_calc_ss
array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True])

となって一致する.

lifelinesとscikit-survivalの比較

共変量の中心化

lifelinesとscikit-survivalでは,生存関数の定義式が違うので単純比較はできない(S_0(t)^{e^{(\bm{x} - \bar{\bm{x}})^{\top} \beta}}S_0(t)^{e^{\bm{x}^{\top} \beta}}).平均を引いた(中心化した)共変量でモデルをつくることで双方の比較ができるようになると考えた.

rossi_centerd = rossi.copy()

col_list = ['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']

for col in col_list:
    rossi_centerd[col] = rossi_centerd[col] - rossi_centerd[col].mean()

X_ss_c, y_ss_c, X_cov_c = rossi_transform(rossi_centerd)

model_ll_c = CoxPHFitter()
model_ll_c.fit(rossi_centerd, duration_col='week', event_col='arrest')

model_ss_c = CoxPHSurvivalAnalysis()
model_ss_c.fit(X_ss_c, y_ss_c)

それでは比較していこう.

共変量の中心化前後でlifelines同士を比較

生存関数

surv_ll_c = model_ll_c.predict_survival_function(X_cov_c)[0].values

# 一致の判定
(surv_ll_c - surv_ll).T
array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  1.11022302e-16,  1.11022302e-16,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  1.11022302e-16,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -1.11022302e-16,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00, -1.11022302e-16,
       -1.11022302e-16])

となってほぼ一致している.

係数

また,

model_ll.params_ - model_ll_c.params_
covariate
fin    -2.220446e-16
age     4.163336e-17
race    5.551115e-17
wexp   -2.498002e-16
mar     2.220446e-16
paro    2.775558e-17
prio   -1.387779e-17
Name: coef, dtype: float64

と,係数もほぼ一致している.

ベースライン生存率

また,

model_ll.baseline_survival_.values.T - model_ll_c.baseline_survival_.values.T
array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.11022302e-16,
        0.00000000e+00, 0.00000000e+00, 1.11022302e-16, 1.11022302e-16,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        1.11022302e-16, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.11022302e-16,
        0.00000000e+00, 1.11022302e-16, 0.00000000e+00, 1.11022302e-16,
        1.11022302e-16, 0.00000000e+00, 1.11022302e-16, 1.11022302e-16,
        1.11022302e-16, 1.11022302e-16, 2.22044605e-16, 1.11022302e-16,
        1.11022302e-16, 1.11022302e-16, 1.11022302e-16, 2.22044605e-16,
        1.11022302e-16, 1.11022302e-16, 2.22044605e-16, 1.11022302e-16,
        2.22044605e-16, 2.22044605e-16, 2.22044605e-16, 2.22044605e-16,
        2.22044605e-16]])

と,ベースライン生存率もほぼ一致している.

共変量の中心化前後でscikit-survival同士を比較

生存関数

surv_ss_c = model_ss_c.predict_survival_function(X_cov_c)[0].y

# 一致の判定
surv_ss_c - surv_ss
array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 1.11022302e-16, 1.11022302e-16,
       0.00000000e+00, 2.22044605e-16, 2.22044605e-16, 2.22044605e-16,
       2.22044605e-16, 3.33066907e-16, 3.33066907e-16, 2.22044605e-16,
       4.44089210e-16, 3.33066907e-16, 3.33066907e-16, 3.33066907e-16,
       4.44089210e-16, 4.44089210e-16, 3.33066907e-16, 5.55111512e-16,
       4.44089210e-16, 5.55111512e-16, 5.55111512e-16, 4.44089210e-16,
       5.55111512e-16, 5.55111512e-16, 5.55111512e-16, 6.66133815e-16,
       5.55111512e-16, 6.66133815e-16, 6.66133815e-16, 6.66133815e-16,
       6.66133815e-16, 7.77156117e-16, 8.88178420e-16, 7.77156117e-16,
       7.77156117e-16, 7.77156117e-16, 7.77156117e-16, 9.99200722e-16,
       8.88178420e-16, 8.88178420e-16, 8.88178420e-16, 9.99200722e-16,
       1.11022302e-15])

と,ほぼ一致している.

係数

model_ss_c.coef_ - model_ss.coef_
array([-1.66533454e-16, -6.73072709e-16, -3.60822483e-15,  1.08246745e-15,
        3.88578059e-16, -2.09554596e-15, -3.19189120e-16])

となって係数はほぼ一致している.

ベースライン生存率

model_ss.baseline_survival_.y - model_ss_c.baseline_survival_.y
array([-0.00480573, -0.00958362, -0.01432301, -0.01903026, -0.02371066,
       -0.02836888, -0.03301095, -0.05563968, -0.06453064, -0.06892717,
       -0.07768108, -0.08643973, -0.09076616, -0.10354943, -0.11197271,
       -0.12027714, -0.13246588, -0.14449435, -0.152395  , -0.17156456,
       -0.17902885, -0.1827159 , -0.18638324, -0.20066976, -0.21105287,
       -0.22130345, -0.22797978, -0.23453482, -0.24099696, -0.24419089,
       -0.25049132, -0.25667972, -0.26279742, -0.27466517, -0.2832675 ,
       -0.29430942, -0.29701113, -0.3023222 , -0.31254303, -0.31751823,
       -0.32706002, -0.33168084, -0.33619055, -0.34487569, -0.34698076,
       -0.35110341, -0.36089667, -0.36643967, -0.37347089])

ベースライン生存関数は結構差があるように見える.

生存関数S_0(t)^{e^{\bm{x}^{\top} \beta}}のうち,共変量\bm{x}を含む部分e^{\bm{x}^{\top} \bm{\beta}}の(係数\bm{\beta}は変わらないが)\bm{x}の中心化による値の変化とベースライン生存関数S_0(t)のずれが合わさることで,結局,共変量を含めた生存関数自体はほぼ一致する.

共変量の中心化前後でlifelinesとscikit-survivalを比較

# そのままのデータで学習したモデルに対するllとssの差
dif = surv_ll - surv_ss
dif
array([3.96366097e-06, 7.93129322e-06, 1.18894659e-05, 1.58473963e-05,
       1.97949853e-05, 2.37423061e-05, 2.76800911e-05, 4.71303772e-05,
       5.48339525e-05, 5.86833668e-05, 6.62248655e-05, 7.33776931e-05,
       7.69421130e-05, 8.75713039e-05, 9.45252643e-05, 1.01457027e-04,
       1.11806223e-04, 1.21987668e-04, 1.28732353e-04, 1.45413135e-04,
       1.52074106e-04, 1.55407123e-04, 1.58706599e-04, 1.71751591e-04,
       1.81437826e-04, 1.90549378e-04, 1.96609923e-04, 2.02644165e-04,
       2.08611439e-04, 2.11590860e-04, 2.17466781e-04, 2.23322323e-04,
       2.29052569e-04, 2.40302553e-04, 2.48599246e-04, 2.59544098e-04,
       2.62223803e-04, 2.67558896e-04, 2.78043489e-04, 2.83202635e-04,
       2.93351117e-04, 2.98302705e-04, 3.03184143e-04, 3.12737681e-04,
       3.15117235e-04, 3.19847253e-04, 3.31299341e-04, 3.38050019e-04,
       3.46467560e-04])
# 中心化したデータで学習したモデルに対するllとssの差
dif_c = surv_ll_c - surv_ss_c

dif_c
array([3.96366097e-06, 7.93129322e-06, 1.18894659e-05, 1.58473963e-05,
       1.97949853e-05, 2.37423061e-05, 2.76800911e-05, 4.71303772e-05,
       5.48339525e-05, 5.86833668e-05, 6.62248655e-05, 7.33776931e-05,
       7.69421130e-05, 8.75713039e-05, 9.45252643e-05, 1.01457027e-04,
       1.11806223e-04, 1.21987668e-04, 1.28732353e-04, 1.45413135e-04,
       1.52074106e-04, 1.55407123e-04, 1.58706599e-04, 1.71751591e-04,
       1.81437826e-04, 1.90549378e-04, 1.96609923e-04, 2.02644165e-04,
       2.08611439e-04, 2.11590860e-04, 2.17466781e-04, 2.23322323e-04,
       2.29052569e-04, 2.40302553e-04, 2.48599246e-04, 2.59544098e-04,
       2.62223803e-04, 2.67558896e-04, 2.78043489e-04, 2.83202635e-04,
       2.93351117e-04, 2.98302705e-04, 3.03184143e-04, 3.12737681e-04,
       3.15117235e-04, 3.19847253e-04, 3.31299341e-04, 3.38050019e-04,
       3.46467560e-04])

となる.結局,共変量の中心化の有無によらず,lifelinesとscikit-survivalの生存関数はほぼ一致しているようである.

脚注
  1. https://lifelines.readthedocs.io/en/latest/Survival Regression.html#cox-s-proportional-hazard-model ↩︎

  2. https://scikit-survival.readthedocs.io/en/stable/user_guide/00-introduction.html#Multivariate-Survival-Models ↩︎

Discussion