GBDT系モデルで生存予測分析
こんにちは。LLMが流行っていますが、たまにはGBDT系のモデルも思い出してあげて下さい。
先日kaggle上で、生存予測分析がお題の「CIBMTR - Equity in post-HCT Survival Predictions」というコンペが開催されました(通称CIBMTR)。個人的に、生存予測分析をお馴染みのGBDT系モデルに落とし込む方法がとても勉強になったので、公開ノートブックや上位ソリューションにおける手法を備忘録がてらここにまとめておこうかと思います。なお、生存予測分析におけるそれぞれの手法や理論の深掘りは行いませんので、ご了承ください。
導入
生存予測分析の簡単な紹介
生存予測分析は、ある対象(患者や機械など)が特定のイベント(死亡、機械の故障など)に至るまでの期間を予測する分析手法です。古典的な手法では、ハザード関数と呼ばれる「ある時点
例えば、生存予測を行う有名なモデルとして、回帰によってハザード関数を推定するCox比例ハザードモデル(Cox回帰)などがあります。Cox比例ハザードモデルにおけるハザード関数は以下の通りです。
-
:h(t | X) におけるハザード率(イベント発生のリスク)t -
:ベースラインハザード(基準となるハザード率。相対的なリスクスコアを出すだけであればこれは無視できる。)h_0(t) -
:説明変数(年齢、性別、治療法など)X_1, X_2, ..., X_p -
:説明変数に対応する係数\beta_1, \beta_2, ..., \beta_p
データの特徴としては、観測打ち切りデータが含まれる点が挙げられます。これを表現するための重要なカラムとして、「イベントが起きたかどうか」と「観測期間」が存在します。観測期間とは、対象を観察し、イベントが起きるかどうかを確認するために設けられた期間のことを指します。したがって、もしイベントが発生した場合、発生するまでの期間が格納されます。一方でイベントが発生しなかった場合は、なんらかの理由で打ち切られるまでの期間が格納されます。
コンペの紹介
コンペのデータを例に説明しますので、CIBMTRコンペがどのようなタスクだったのかを簡単に紹介します。
CIBMTRでは、造血細胞移植という治療を受けた患者の生存予測がお題でした。予測としては、イベントが起きる(=死亡)確率を出力する必要がありました。データは、survival GANというモデルを用いてシミュレートされたテーブルデータです。造血細胞移植後に死亡が観測されたかどうかが「efs」列に、観測期間が「efs_time」列に入っています。
評価指標はc-indexというものを少しカスタマイズしたものが用いられています。c-indexはモデルの予測が実際の生存時間の順序をどれだけ正しく捉えているかを測る指標です。データをペアにしてモデルの予測リスクが小さい方が実際に長生きしていれば「正解」とみなし、この正解数を元にスコアが計算されます。つまり、絶対的な値よりも相対的な順序が大切であり、言ってしまえば生存予測版のAUCというイメージです。
ベースラインの手法
CIBMTRコンペ期間中に公開されていた手法として、以下の2つを紹介します。
ターゲットを生成し回帰タスクに落とし込む
イベント列と観測期間列から、ルールに基づく計算方法でターゲットを生成し、お馴染みの回帰タスクに落とし込む方法です。例えば、KaplanMeierFitterを用いた変換が挙げられます。これはデータを観測期間順にソートし、ある時点
-
: イベント(死亡など)が発生する時間t_i -
: 時点d_i でイベントが発生した人数t_i -
: 時点n_i 直前まで生存していた人数t_i
これはlifelinesというパッケージを用いて以下のようにサクッと求めることができます。
from lifelines import KaplanMeierFitter
kmf = KaplanMeierFitter()
kmf.fit(durations=train['efs_time'], event_observed=train['efs'])
train['target'] = kmf.survival_function_at_times(train['efs_time']).values
また、lifelinesのドキュメントを確認すると、たくさんの種類のフィッターが用意されているのがわかります。CIBMTRにおいても、様々な種類のフィッター使ってアンサンブルするノートブックが公開されていました。
少し毛色の違う方法として、フィッターを使わずに、より直感的な変換を行っているノートブックもありました。例えば以下の手順です。
- イベントが発生したデータにおいて、観測期間を変換しリスクスコアとする(rank変換や逆数のQuantile標準化など)
- イベントが発生していないデータは、1.の最小値から定数値を引き算した値に置換する
このような単純なルールに則って計算されたtargetによるモデルは、KaplanMeierと同等かそれ以上のスコアを達成していました。
ライブラリに用意されている生存予測用目的関数を用いる
XGBoost、CatBoostには生存予測を解くための目的関数が何種類か実装されています。ちなみにLightGBMには残念ながら存在していませんでした。
例えば、前述したCox比例ハザードモデルは、XGBoostの目的関数(objective)として設定することができます。以下に簡単なコードの例を示します。ポイントはobjectibeパラメータの指定と、ターゲットの生成方法です。
params = {
"objective": "survival:cox", # Cox比例ハザードモデル用の目的関数
"eval_metric": "cox-nloglik",
}
# Coxloss用のターゲット変換。生存データの観測期間に-1を掛ける。
train['target'] = train['efs_time'].copy()
train.loc[train['efs'] == 0, 'target'] *= -1
model = xgb.XGBRegressor(**params)
# 以下、回帰モデルと同様に学習/予測
同じようなCox比例ハザードモデル実装は、CatBoostにも存在しています。
もう一つ、Accelerated Failure Timeモデルというモデルも、XGBoostとCatBoost両方に実装されています。AFTモデルに関してはこちらのdiscussionがわかりやすいでしょう(丸投げですみません)。上記に紐づくノートブックを確認すると、若干ターゲットの生成に癖がありますが、このモデルも比較的容易に使用することができます。ただし、AFTモデルにおけるXGBoostのsklearn APIが現時点で存在しないようなのでその点は注意です。
上位入賞者の手法
次のセクションとして、CIBMTRの上位陣が用いていた解法の一つを紹介します。一言で言ってしまうと、クラス分類と回帰にタスクを分解するという手法です。
クラス分類の方は、愚直にイベントが起きたか起きていないかの二値分類を行います。回帰の方は、学習時はイベントが起きたデータのみに対して、イベントが起きるまでの期間を回帰として予測します。予測はすべてのデータにイベントが起きたと仮定して、予測値を出します。これらを掛け合わせることで、最終的な予測を生成します。
ここから各ソリューションで工夫が見られ、例えばefs_timeをlogやrankに変換したり、efs==0側も同様の式で計算したりなどなど行なっていました。
余談ですが、「予測はすべてのデータにイベントが起きたと仮定」しても回帰予測が機能するというのが、今回のデータがシミュレートデータである故の結果なのか、汎用的に効果がある手法なのかは個人的に少し疑問です。(学習時に含まれてないデータも予測することになるので、外挿になりそうな気もする)
まとめ
kaggleで開催されたコンペを例に、生存予測分析をGBDT系モデルで解く方法を紹介しました。特に回帰に落とし込む方法に関しては、GBDT系以外のモデルにも汎用的に使える手法だと思います。上位陣の詳しい解法が気になった方、また今回紹介しなかった深層学習を用いた解法を知りたい場合はぜひkaggleのdiscussionを覗いてみてください。
参考
- CIBMTRのディスカッションとかノートブックたくさん
- https://qiita.com/nakey_tdse/items/b40238599395653a7965
Discussion