💭

機械学習による因果推論の感度分析入門 ~DMLを使った手法~

に公開

はじめに

因果推論という本で DML(Double Machine Learning) を使った感度分析が紹介されていました。
ですが、ほぼ名前を出して終わりのサラッとした紹介のみだったので、具体的に分析で使うにはどうすれば良いのか気になり調べてみました。

因果推論の感度分析

機械学習を用いない因果推論の感度分析の有名な手法としては、Imbens(2003)の手法E-Value を用いた方法があります。
一方、機械学習を用いた感度分析の手法としては、Austen PlotDML を用いた手法があります。

Austen Plot は 2020 年、DML を用いる方法は 2022 年に発表と、機械学習を用いた感度分析は因果推論の領域の中でも比較的新しい分野になっており、
そのためまだデファクトスタンダードとなる手法が固まっていないようです。

ただ、DML を用いた手法については以下のライブラリなどで既に実装がなされており、DoubleMLDoWhyといったライブラリで比較的手軽に試すことが出来るようになっています。

この記事では、 DoubleML ライブラリ を使って DML を使った感度分析のチュートリアルをやってみます。

DML を用いた感度分析やってみる

以降の内容はこちらの notebook を参考にしています。
https://docs.doubleml.org/stable/examples/py_double_ml_sensitivity.html

0. 準備

まずは、必要なライブラリをインストールします。

pip install doubleml nbformat numpy pandas scikit-learn

必要な import 文 は以下です。

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
import doubleml as dml
from doubleml.datasets import make_confounded_plr_data, fetch_401K

1. サンプルデータの作成

DoubleML 組み込みの make_confounded_plr_data() を使うと、観測できない交絡因子の影響を受けたデータセットを簡単に作成できます。

データ作成にあたり設定するパラメータは以下です。

  • cf_y: 結果変数を予測するモデルで計算された残差の分散のうち、未観測の交絡因子で説明される割合
  • cf_d: 処置変数を予測するモデルで計算されたの残差の分散のうち、未観測の交絡因子で説明される割合
  • theta: 真の処置効果

cf_ycf_dの説明が難しいですね。
めちゃめちゃざっくりな説明をすると、cf_ycf_dはそれぞれ「観測できない交絡因子の強さ」を表します。

ではデータを作ります。

# 観測されない交絡因子の強さの設定値
cf_y = 0.1
cf_d = 0.1
# 真の処置効果の設定値
theta = 5.0

np.random.seed(42)
dpg_dict = make_confounded_plr_data(n_obs=1000, cf_y=cf_y, cf_d=cf_d, theta=theta)
x_cols = [f"X{i + 1}" for i in np.arange(dpg_dict["x"].shape[1])]
df = pd.DataFrame(
    np.column_stack((dpg_dict["x"], dpg_dict["y"], dpg_dict["d"])),
    columns=x_cols + ["y", "d"],
)

作成されたdfはこんな感じになっています。
X1X4 が共変量、y が結果変数、d が処置変数になっています。

上記で設定したとおり、観測できない交絡因子も dataframe には入っていませんが内在的に存在しており、その影響をydが受けています。

Index X1 X2 X3 X4 y d
38 -0.680025 0.232254 0.293072 -0.714351 197.567781 2.585162
8 -0.013497 -1.057711 0.822545 -1.220844 162.394830 -0.635069
327 1.085896 0.474698 -0.025027 0.817766 250.706827 -0.370320
430 0.513106 -0.259547 0.738810 0.615367 212.628500 -1.003343
75 -0.828995 -0.560181 0.747294 0.610370 171.618438 0.739475

2. DML による因果推論の実行

感度分析をするためにはまず普通に因果推論をする必要があるので、これは DML で行います。
より具体的には、DML の中でも最も基本的な手法である、部分線形回帰モデル(Partially Linear Regression: PLR)を使った手法を用います。

これは DoubleML ライブラリのDoubleMLPLRクラスを使うと一撃です。

DML についての説明は今回の記事では省略しますが、
DML では 2 つの機械学習モデルを作る必要があるため、ここでは Random Forest を使ってそれぞれのモデルを構築します。

np.random.seed(42)

dml_data = dml.DoubleMLData(df, "y", "d")
dml_obj = dml.DoubleMLPLR(
    dml_data,
    ml_l=RandomForestRegressor(),
    ml_m=RandomForestRegressor(),
    n_folds=5,
    score="partialling out",
)
dml_obj.fit()
print(dml_obj)

print される結果は以下のようになります。

================== DoubleMLPLR Object ==================

------------------ Data Summary      ------------------
Outcome variable: y
Treatment variable(s): ['d']
Covariates: ['X1', 'X2', 'X3', 'X4']
Instrument variable(s): None
No. Observations: 1000


------------------ Score & Algorithm ------------------
Score function: partialling out

------------------ Machine Learner   ------------------
Learner ml_l: RandomForestRegressor()
Learner ml_m: RandomForestRegressor()
Out-of-sample Performance:
Regression:
Learner ml_l RMSE: [[11.82684324]]
Learner ml_m RMSE: [[1.11552911]]

------------------ Resampling        ------------------
No. folds: 5
No. repeated sample splits: 1

------------------ Fit Summary       ------------------
       coef   std err        t         P>|t|     2.5 %    97.5 %
d  4.130122  0.455448  9.06827  1.209219e-19  3.237461  5.022783

結果の中で大事な点は「Fit Summary」の箇所です。
coefつまり推定された処置効果の値が約4.13となっており、上記でtheta=5.0と設定した真の値とはズレてしまっていることが分かります。

これは予想通りの結果で、make_confounded_plr_data()でデータを作る際にcf_ycf_dを設定して「観測できない交絡因子」の影響をデータに与えていることが起因しています。

3. DML を用いた感度分析

ここからようやく本題です!

DoubleML ライブラリで DML を用いた感度分析を実行するには、sensitivity_analysis()メソッドを使用すれば OK です。

sensitivity_analysis()の実行時に指定できるパラメータは以下です。

  • cf_y
    • make_confounded_plr_data()でデータを作成したときに設定したパラメータと意味は同じ
    • デフォルト値は0.03
  • cf_d
    • make_confounded_plr_data()でデータを作成したときに設定したパラメータと意味は同じ
    • デフォルト値は0.03
  • rho
    • 結果変数に対する長形式(full model)と短形式(reduced model)の出力の差と、 Riesz representer の間の相関を表します(ムズカシイ...)
    • 雑な説明をすると、観測できない交絡因子が処置変数と結果変数にどのくらいダイレクトに影響を与えるかを表す指標です。
    • デフォルト値は1.0となっており、これは観測できない交絡因子の影響が最大限発生するケースで、「こんな悲観的な条件で検証してロバストなら問題ないでしょ」と判断するための保守的な値になっています。
    • ちなみに、rho-1.0~1.0の値を取りますが、感度分析では絶対値のみ評価されるので符号は関係ないです。
  • null_hypothesis
    • ロバストネス値(後述)を計算する際に使用する、帰無仮説の処置効果の値を指定します。
    • デフォルト値は0.0、つまり処置効果ゼロを表します。

また、sensitivity_analysis()の内部では統計的なばらつきを考慮するために有意水準(デフォルト値は 0.95 )が考慮されて分析がなされます。

では実際に感度分析をやってみます。

dml_obj.sensitivity_analysis(cf_y=0.1, cf_d=0.1, rho=1.0, null_hypothesis=0.0)
print(dml_obj.sensitivity_summary)

print される結果は以下のようになります。

================== Sensitivity Analysis ==================

------------------ Scenario          ------------------
Significance Level: level=0.95
Sensitivity parameters: cf_y=0.1; cf_d=0.1, rho=1.0

------------------ Bounds with CI    ------------------
   CI lower  theta lower     theta  theta upper  CI upper
d  2.392091      3.09976  4.132084     5.164407  5.976682

------------------ Robustness Values ------------------
   H_0     RV (%)    RVa (%)
d  0.0  34.220072  29.692382

ここで注意点としては、
実際の感度分析ではcf_ycf_dの真の値(今回だと両者とも0.1)は当然分からない状態で分析をすることになります。

ここでは、あくまでチュートリアルなので、真の値をそのまま感度分析でセットしていると解釈してください。

感度分析の結果の見方について説明していきます。

Bounds with CI

  • theta
    • 観測できない交絡因子を考慮せずに因果推論した際のthetaの値を表します。
  • theta lower, theta upper
    • DML のモデルやcf_y, cf_d, rhoの設定に基づいて推定されたthetaの下限/上限を表します。
  • CI lower, CI upper
    • theta lower, theta upperをさらに統計的なばらつきを考慮した上でのthetaの下限/上限を表します。

今回の結果の解釈としては、
処置効果はcf_y=0.1, cf_d=0.1, rho=1.0の条件下では、観測できない交絡因子の影響で最小で約2.39、最大で約5.98まで歪められる、ということになります。

確かに、真の処置効果theta=5.0はこの範囲に含まれているので、この結果は妥当そうだと言えます。

ロバストネス値(RV)

RV とは、設定した帰無仮説が成立するには、観測できない交絡因子が結果変数と処置変数の両方の残差の分散をどのくらい説明する必要があるか、を表す指標です。

めちゃめちゃ雑に言うと、この値が大きいほどロバストです。(ほんとに雑な説明)

今回の結果の解釈としては、
cf_y=0.34, cf_d=0.34のときに帰無仮説が成立、つまり処置効果ゼロになる、ということになります。

0.34という値はそこそこ大きいので、今回の結果は比較的ロバストだと言えます。
実際、真の値はtheta=5.0で処置効果は全然ゼロではないので、「ロバストそう」という解釈は妥当です。

調整済みロバストネス値(RVa)

RV を統計的なばらつきを考慮して調整した指標が RVa です。
RV より厳しめな値が出ます。

cf_ycf_dを変動させた際のプロット

上記のように感度分析の結果をテキストで出すのも良いですが、
cf_ycf_dの値を固定している点と、やっぱり何かしらのグラフで見れると分かりやすい、という点はあるかと思います。

そんなときはsensitivity_plot()メソッドでグラフを作れます。

dml_obj.sensitivity_plot()

縦軸がcf_y、横軸がcf_dで、グラフの色が処置効果の値を表しています。

また、Unadjustedと書かれている箇所が観測できない交絡因子を一切考慮しない場合(つまり普通に因果推論した場合)、
Scenarioと書かれている箇所が上記でsensitivity_analysis()に指定したcf_ycf_dの値を使った場合の位置を表しています。

また、cf_ycd_dの値を大きくするとnull_hypothesisで指定した値に近づくようなプロットになります。

実際にnull_hypothesis=5.0として感度分析をやって、グラフをプロットしてみます。

dml_obj.sensitivity_analysis(cf_y=0.1, cf_d=0.1, rho=1.0, null_hypothesis=5.0)
dml_obj.sensitivity_plot()

null_hypothesis=5.0に向けて観測できない交絡因子が影響する仮定でプロットが作られるので、
1つ前のプロットと比べて、グラフの色が反転していることが分かります。

また、よく見ると、Scenarioの位置の値は5.16あたりでtheta upperが使われていることが分かります。

4. 実データを用いてやってみる

これまで使ってきた無機質な人工データだとふーんという感じであまり手触り感が無く終わってしまいそうなので、
401(k)という実データでも感度分析をしてみます。

このデータセットは複数の因果推論の論文で使用されているデータで、アメリカの 401(k)という確定拠出年金制度への参加資格が、その人の累積資産にどのような影響を与えるか分析できるものになっています。

以下、401(k)のデータで

  1. データ取得
  2. DML による因果推論
  3. DML を用いた感度分析

の順で実行していきます。

# --- 1. データ取得 ---
data = fetch_401K(return_type="DataFrame")
features_base = ["age", "inc", "educ", "fsize", "marr", "twoearn", "db", "pira", "hown"]
data_dml = dml.DoubleMLData(data, y_col="net_tfa", d_cols="e401", x_cols=features_base)

# --- 2. DML による因果推論 ---
learner_l = RandomForestRegressor(
    n_estimators=500, max_depth=10, max_features=5, min_samples_leaf=10
)
learner_m = RandomForestClassifier(
    n_estimators=500, max_depth=10, max_features=5, min_samples_leaf=10
)
np.random.seed(42)
dml_plr_obj = dml.DoubleMLPLR(
    data_dml, ml_l=learner_l, ml_m=learner_m, n_folds=5, n_rep=3
)
dml_plr_obj.fit()

# --- 3. DML を用いた感度分析 ---
dml_plr_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03, rho=1.0, null_hypothesis=0.0)

感度分析の出力は以下のようになります。

================== Sensitivity Analysis ==================

------------------ Scenario          ------------------
Significance Level: level=0.95
Sensitivity parameters: cf_y=0.03; cf_d=0.03, rho=1.0

------------------ Bounds with CI    ------------------
         CI lower  theta lower        theta   theta upper      CI upper
e401  2892.043073  5133.780059  8837.100328  12540.420598  14762.001569

------------------ Robustness Values ------------------
      H_0    RV (%)   RVa (%)
e401  0.0  7.009296  5.200793

どうやら、
シンプルに因果推論すると因果効果は約8837(ドル?)で、
観測できない交絡因子の影響を考慮すると最小で約2892、最大で約14762まで歪められる可能性がある、ということになります。
(ただし、cf_y=0.03, cf_d=0.03だと仮定した場合の結果)

また RV は約7.0なので、観測できない交絡因子が結果変数と処置変数の両方の残差の分散を約 7%説明する場合に帰無仮説が成立(つまり因果効果ゼロ)、ということになります。
7%くらいであれば発生しそうな気もするので、あまりロバストとは言えないのかもしれません。

次に、
cf_y=0.03, cf_d=0.03だと仮定せずに、cf_ycf_dを変動させた場合のグラフは以下になります。

dml_plr_obj.sensitivity_plot()

グラフの真ん中ぐらいで因果効果ゼロになっているのと、右上だと-9000とかまで行ってしまっているので、
あまりロバストと言えなさそうな雰囲気が出ています。(雑な説明)

Benchmarking Analysis

ここまでくると薄々既に感じている方もいるかもしれませんが、
「そもそも観測できない交絡因子の強さcf_y, cf_dがどれくらいになるかなんて分からなくない?」という問題があります。

この疑問は至極真っ当で、その対応として Benchmarking Analysis という手法があるので、それを最後に試してみます。

Benchmarking Analysis では、使用している共変量が仮に観測できない交絡因子であった場合に、グラフ上にプロットしてみる手法です。
これは sensitivity_benchmark()メソッドを使うと簡単に実施できます。

ベンチマークに使う既知の共変量は以下を使うことにします。

  • inc: 収入
  • pira: 個人退職年金口座(IRA)への参加の有無
  • twoearn: 共働きかどうか

それではやってみます。

# 既知の交絡因子をベンチマークとしてセットする
benchmark_inc = dml_plr_obj.sensitivity_benchmark(benchmarking_set=["inc"])
benchmark_pira = dml_plr_obj.sensitivity_benchmark(benchmarking_set=["pira"])
benchmark_twoearn = dml_plr_obj.sensitivity_benchmark(benchmarking_set=["twoearn"])

benchmark_dict = {
    "cf_y": [
        benchmark_inc.loc["e401", "cf_y"],
        benchmark_pira.loc["e401", "cf_y"],
        benchmark_twoearn.loc["e401", "cf_y"],
    ],
    "cf_d": [
        benchmark_inc.loc["e401", "cf_d"],
        benchmark_pira.loc["e401", "cf_d"],
        benchmark_twoearn.loc["e401", "cf_d"],
    ],
    "name": ["inc", "pira", "twoearn"],
}
# 感度分析 with ベンチマーク
dml_plr_obj.sensitivity_plot(benchmarks=benchmark_dict)

Benchmarking Analysis では、既存の共変量がプロット上でどこに位置するのかを可視化することが出来ます。

これを見ると、
piratwoearnレベルが観測されていない交絡因子で存在してもほぼ問題ないことが分かります。

逆に、incomeレベルがあるとちょっとやばいね、ということも分かります。
また、処置効果を-9000 とかまで引き下げるような交絡因子は現在見つかっている共変量から考えると、ほぼ存在しなさそうなこともグラフから推察できます。

ここまできてようやく感度分析って便利じゃん!と思えた方もいるのでは無いでしょうか?(僕はそうでした)

まとめ

この記事では、DML を使った感度分析の一連の流れを実装してみました。

感度分析自体は因果推論の中でそこまでメジャーではない印象ですが、やると得られる知見は多いので基本的にセットでやるべきだなと感じました。

といっても、今回は DML を使った感度分析の数式には一切触れず、また既存のライブラリを使って実装したので
「比較的新しめの手法で感度分析したらこのような結果になった」くらいの理解になってしまってはいるなと思いました。

誰かに自信を持って説明するには、やっぱりもう少し数式的な理解をしてから望みたい気持ちはあるので、いつかそんな記事も書けると良いなと思います。

Discussion