Cox比例ハザードモデルをlifelinesとscikit-survivalで比較してみた
はじめに
Cox比例ハザードモデルは生存時間解析におけるハザード関数を共変量(説明変数)も込めて回帰するための基本モデルである.ハザード関数が求まると共変量
生存時間解析のPython実装についてはlifelinesとscikit-survivalというライブラリが代表的だと思うが,公式ドキュメントを見ると両者における生存関数の仕様(数式)が異なっていることがわかった[1][2].
この記事は実際に違いがあるのかをひとまず一つのデータセットで確認してみた記録である.
(結論としては,生存関数に関して言えば両者の値にほとんど差が無かったので,どちらのライブラリを使用してもよさそうである.当たり前か...)
生存時間解析の用語についてざっくり復習
共変量を考慮しない場合
"生存時間"を表す確率変数
である(時刻
と変形できる.このことから,累積ハザード関数
を用いて.生存関数を
と表せる.
Cox比例ハザードモデル
Cox比例ハザードモデルとは,ハザード関数(共変量
という形であることを仮定したモデルである.時刻
このとき,累積ハザード関数は
となる.
と表現できることになる.
lifelinesとscikit-survivalによるモデル作成
lifelinesのrossi
データセットを使う.
# 各種ライブラリのインポート
from lifelines.datasets import load_rossi
from lifelines import CoxPHFitter
from sksurv.linear_model import CoxPHSurvivalAnalysis
import numpy as np
rossi = load_rossi()
print(rossi)
week arrest fin age race wexp mar paro prio
0 20 1 0 27 1 0 0 1 3
1 17 1 0 18 1 0 0 1 8
2 25 1 0 19 0 1 0 1 13
3 52 0 1 23 1 1 1 1 1
4 52 0 0 19 0 1 0 1 3
.. ... ... ... ... ... ... ... ... ...
427 52 0 1 31 0 1 0 1 3
428 52 0 0 20 1 0 0 1 1
429 52 0 1 20 1 1 1 1 1
430 52 0 0 29 1 1 0 1 3
431 52 0 1 24 1 1 0 1 1
[432 rows x 9 columns]
各カラムの説明
このデータセットは再犯のリスク要因を分析するために収集されたようである.
-
week
: 追跡期間の週数.観測された期間の長さを表す. -
arrest
: 再犯の有無.1は再犯があったこと,0は再犯が無かったことを示す. -
fin
: 個人が追跡期間中に刑務所から出所したかどうか.1は出所したこと,0は出所していないことを示す. -
age
: 対象者の年齢. -
race
: 人種を示すカテゴリ変数. -
wexp
: 雇用経験の有無.1は過去に雇用経験があること,0は雇用経験がないことを示す. -
mar
: 結婚の有無.1は結婚していること,0は結婚していないことを示す. -
paro
: 生活保護の有無.1は生活保護を受けていること,0は受けていないことを示す. -
prior
: 過去の犯罪歴の数を示す変数.過去に逮捕された回数などを示す.
モデル作成にあたり,生存時間の打ち切り有無情報に関して加工が必要である.(lifelinesだと0/1の整数値でOKだが,scikit-survivalではTrue/Falseのブール値でなければいけない.)
def rossi_transform(rossi):
# lifelines向けには加工無しでOK
# scikit-survival向け
X_ss = rossi[['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']]
y_ss = rossi[["arrest", "week"]]
y_ss["arrest"] = y_ss["arrest"].astype(bool)
y_ss = np.array(list(zip(y_ss['arrest'], y_ss['week'])), dtype=[('arrest', '?'), ('week', '<f8')]) # イベントと時間を構造化配列に変換
# 共変量
X_cov = rossi[['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']]
return X_ss, y_ss, X_cov
X_ss, y_ss, X_cov = rossi_transform(rossi)
# lifelines
model_ll = CoxPHFitter()
model_ll.fit(rossi, duration_col='week', event_col='arrest')
# scikit-survival
model_ss = CoxPHSurvivalAnalysis()
model_ss.fit(X_ss, y_ss)
これでモデルが作成できたので,両者の生存関数が一致しているか確認していこう.
lifelinesで生存関数を復元する試行
ベースライン生存率が一致することの確認
# モデルのメソッドから直接呼び出し
base_surv_ll = model_ll.baseline_survival_.values
# ベースライン累積ハザードから計算
base_surv_calc_ll = np.exp( - model_ll.baseline_cumulative_hazard_).values
# 一致の判定
base_surv_ll.tolist() == base_surv_calc_ll.tolist()
True
よって一致している.
共変量の影響の確認
lifelinesにおける生存関数の定義は
ではなく
であることに注意する.(なお,
# 中心化
X_cov_centerd = X_cov - X_cov.mean()
# 線形予測子
linear_predictor_ll = np.dot(X_cov_centerd, model_ll.params_)
# モデルのメソッドから直接呼び出し
surv_ll = model_ll.predict_survival_function(X_cov)[0].values
# 線形予測子から計算
surv_calc_ll = (model_ll.baseline_survival_ ** np.exp(linear_predictor_ll[0])).values
# 一致の判定
(surv_ll - surv_calc_ll).T
array([[ 0.00000000e+00, 0.00000000e+00, 1.11022302e-16,
-1.11022302e-16, 1.11022302e-16, 1.11022302e-16,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
1.11022302e-16, 1.11022302e-16, -1.11022302e-16,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 1.11022302e-16, 1.11022302e-16,
0.00000000e+00, 0.00000000e+00, 1.11022302e-16,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
1.11022302e-16, -1.11022302e-16, 1.11022302e-16,
1.11022302e-16, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 1.11022302e-16, -1.11022302e-16,
0.00000000e+00, -1.11022302e-16, 1.11022302e-16,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 1.11022302e-16,
-1.11022302e-16, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 1.11022302e-16,
0.00000000e+00]])
となってほぼ一致している.
scikit-survivalで生存関数を復元する試行
同様に,scikit-survivalでも実施してみる.
ベースライン生存率が一致することの確認
# モデルのメソッドから直接呼び出し
base_surv_ss = model_ss.baseline_survival_.y
# ベースライン累積ハザードから計算
base_surv_calc_ss = np.exp(- model_ss.cum_baseline_hazard_.y)
# 一致の判定
base_surv_ss.tolist() == base_surv_calc_ss.tolist()
True
となって一致する.
共変量の影響の確認
# 線形予測子
linear_predictor_ss = np.dot(X_cov, model_ss.coef_)
# モデルのメソッドから直接呼び出し
surv_ss = model_ss.predict_survival_function(X_cov)[0].y
# 線形予測子から計算
surv_calc_ss = model_ss.baseline_survival_.y ** np.exp(linear_predictor_ss[0])
# 一致の判定
surv_ss == surv_calc_ss
array([ True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True])
となって一致する.
lifelinesとscikit-survivalの比較
共変量の中心化
lifelinesとscikit-survivalでは,生存関数の定義式が違うので単純比較はできない(
rossi_centerd = rossi.copy()
col_list = ['fin', 'age', 'race', 'wexp', 'mar', 'paro', 'prio']
for col in col_list:
rossi_centerd[col] = rossi_centerd[col] - rossi_centerd[col].mean()
X_ss_c, y_ss_c, X_cov_c = rossi_transform(rossi_centerd)
model_ll_c = CoxPHFitter()
model_ll_c.fit(rossi_centerd, duration_col='week', event_col='arrest')
model_ss_c = CoxPHSurvivalAnalysis()
model_ss_c.fit(X_ss_c, y_ss_c)
それでは比較していこう.
共変量の中心化前後でlifelines同士を比較
生存関数
surv_ll_c = model_ll_c.predict_survival_function(X_cov_c)[0].values
# 一致の判定
(surv_ll_c - surv_ll).T
array([ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 1.11022302e-16, 1.11022302e-16, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 1.11022302e-16, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, -1.11022302e-16,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, -1.11022302e-16,
-1.11022302e-16])
となってほぼ一致している.
係数
また,
model_ll.params_ - model_ll_c.params_
covariate
fin -2.220446e-16
age 4.163336e-17
race 5.551115e-17
wexp -2.498002e-16
mar 2.220446e-16
paro 2.775558e-17
prio -1.387779e-17
Name: coef, dtype: float64
と,係数もほぼ一致している.
ベースライン生存率
また,
model_ll.baseline_survival_.values.T - model_ll_c.baseline_survival_.values.T
array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.11022302e-16,
0.00000000e+00, 0.00000000e+00, 1.11022302e-16, 1.11022302e-16,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
1.11022302e-16, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.11022302e-16,
0.00000000e+00, 1.11022302e-16, 0.00000000e+00, 1.11022302e-16,
1.11022302e-16, 0.00000000e+00, 1.11022302e-16, 1.11022302e-16,
1.11022302e-16, 1.11022302e-16, 2.22044605e-16, 1.11022302e-16,
1.11022302e-16, 1.11022302e-16, 1.11022302e-16, 2.22044605e-16,
1.11022302e-16, 1.11022302e-16, 2.22044605e-16, 1.11022302e-16,
2.22044605e-16, 2.22044605e-16, 2.22044605e-16, 2.22044605e-16,
2.22044605e-16]])
と,ベースライン生存率もほぼ一致している.
共変量の中心化前後でscikit-survival同士を比較
生存関数
surv_ss_c = model_ss_c.predict_survival_function(X_cov_c)[0].y
# 一致の判定
surv_ss_c - surv_ss
array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 1.11022302e-16, 1.11022302e-16,
0.00000000e+00, 2.22044605e-16, 2.22044605e-16, 2.22044605e-16,
2.22044605e-16, 3.33066907e-16, 3.33066907e-16, 2.22044605e-16,
4.44089210e-16, 3.33066907e-16, 3.33066907e-16, 3.33066907e-16,
4.44089210e-16, 4.44089210e-16, 3.33066907e-16, 5.55111512e-16,
4.44089210e-16, 5.55111512e-16, 5.55111512e-16, 4.44089210e-16,
5.55111512e-16, 5.55111512e-16, 5.55111512e-16, 6.66133815e-16,
5.55111512e-16, 6.66133815e-16, 6.66133815e-16, 6.66133815e-16,
6.66133815e-16, 7.77156117e-16, 8.88178420e-16, 7.77156117e-16,
7.77156117e-16, 7.77156117e-16, 7.77156117e-16, 9.99200722e-16,
8.88178420e-16, 8.88178420e-16, 8.88178420e-16, 9.99200722e-16,
1.11022302e-15])
と,ほぼ一致している.
係数
model_ss_c.coef_ - model_ss.coef_
array([-1.66533454e-16, -6.73072709e-16, -3.60822483e-15, 1.08246745e-15,
3.88578059e-16, -2.09554596e-15, -3.19189120e-16])
となって係数はほぼ一致している.
ベースライン生存率
model_ss.baseline_survival_.y - model_ss_c.baseline_survival_.y
array([-0.00480573, -0.00958362, -0.01432301, -0.01903026, -0.02371066,
-0.02836888, -0.03301095, -0.05563968, -0.06453064, -0.06892717,
-0.07768108, -0.08643973, -0.09076616, -0.10354943, -0.11197271,
-0.12027714, -0.13246588, -0.14449435, -0.152395 , -0.17156456,
-0.17902885, -0.1827159 , -0.18638324, -0.20066976, -0.21105287,
-0.22130345, -0.22797978, -0.23453482, -0.24099696, -0.24419089,
-0.25049132, -0.25667972, -0.26279742, -0.27466517, -0.2832675 ,
-0.29430942, -0.29701113, -0.3023222 , -0.31254303, -0.31751823,
-0.32706002, -0.33168084, -0.33619055, -0.34487569, -0.34698076,
-0.35110341, -0.36089667, -0.36643967, -0.37347089])
ベースライン生存関数は結構差があるように見える.
生存関数
共変量の中心化前後でlifelinesとscikit-survivalを比較
# そのままのデータで学習したモデルに対するllとssの差
dif = surv_ll - surv_ss
dif
array([3.96366097e-06, 7.93129322e-06, 1.18894659e-05, 1.58473963e-05,
1.97949853e-05, 2.37423061e-05, 2.76800911e-05, 4.71303772e-05,
5.48339525e-05, 5.86833668e-05, 6.62248655e-05, 7.33776931e-05,
7.69421130e-05, 8.75713039e-05, 9.45252643e-05, 1.01457027e-04,
1.11806223e-04, 1.21987668e-04, 1.28732353e-04, 1.45413135e-04,
1.52074106e-04, 1.55407123e-04, 1.58706599e-04, 1.71751591e-04,
1.81437826e-04, 1.90549378e-04, 1.96609923e-04, 2.02644165e-04,
2.08611439e-04, 2.11590860e-04, 2.17466781e-04, 2.23322323e-04,
2.29052569e-04, 2.40302553e-04, 2.48599246e-04, 2.59544098e-04,
2.62223803e-04, 2.67558896e-04, 2.78043489e-04, 2.83202635e-04,
2.93351117e-04, 2.98302705e-04, 3.03184143e-04, 3.12737681e-04,
3.15117235e-04, 3.19847253e-04, 3.31299341e-04, 3.38050019e-04,
3.46467560e-04])
# 中心化したデータで学習したモデルに対するllとssの差
dif_c = surv_ll_c - surv_ss_c
dif_c
array([3.96366097e-06, 7.93129322e-06, 1.18894659e-05, 1.58473963e-05,
1.97949853e-05, 2.37423061e-05, 2.76800911e-05, 4.71303772e-05,
5.48339525e-05, 5.86833668e-05, 6.62248655e-05, 7.33776931e-05,
7.69421130e-05, 8.75713039e-05, 9.45252643e-05, 1.01457027e-04,
1.11806223e-04, 1.21987668e-04, 1.28732353e-04, 1.45413135e-04,
1.52074106e-04, 1.55407123e-04, 1.58706599e-04, 1.71751591e-04,
1.81437826e-04, 1.90549378e-04, 1.96609923e-04, 2.02644165e-04,
2.08611439e-04, 2.11590860e-04, 2.17466781e-04, 2.23322323e-04,
2.29052569e-04, 2.40302553e-04, 2.48599246e-04, 2.59544098e-04,
2.62223803e-04, 2.67558896e-04, 2.78043489e-04, 2.83202635e-04,
2.93351117e-04, 2.98302705e-04, 3.03184143e-04, 3.12737681e-04,
3.15117235e-04, 3.19847253e-04, 3.31299341e-04, 3.38050019e-04,
3.46467560e-04])
となる.結局,共変量の中心化の有無によらず,lifelinesとscikit-survivalの生存関数はほぼ一致しているようである.
Discussion