🐍

機械学習で因果推論~Double Machine Learning~

2022/12/31に公開

はじめに

Double Machine Learning(以下、DML)について、Pythonによる実装を交えてまとめました。内容について誤り等ございましたら、コメントにてご指摘いただけますと幸いです。

DMLの概要

DMLとは、機械学習手法を用いつつ2段階に分けて処置効果を推定する手法です。1段階目で2つの予測タスクを行い、2段階目で処置効果を推定するモデルを作成します。

DMLの利点

DMLを用いる利点はたくさんありますが、次の4つを紹介させていただきます。

  1. 処置効果の異質性(HTE: Heterogeneous Treatment Effect)を考慮した推定が可能
  2. 処置変数が連続・離散問わず適応可能
  3. あまりに高次元すぎて古典的な統計学的アプローチでは適応できない・パラメトリックな関数で十分にモデル化できない場合にも利用可能
  4. 漸近正規性や信頼区間の構築など、望ましい統計的性質を多く維持したままモデルを作成することが可能

DMLのアルゴリズム

まず、記号を整理します。

  • Y: アウトカムを表す変数
  • T: 処置を表す変数
  • X: 効果修飾因子を表す変数(の集合)で、Y単体あるいはYTの両方に影響を与えます(共変量の一部)
  • W: X以外の共変量を表す変数(集合)で、YT(あるいは両方)に影響を与えます

ここで推定したいのは、Xで条件づけた処置効果(CATE)です。そのため、XWを区別しています。CATEを\theta(X)と置きます。

1段階目

YTに関する2つの予測タスクを行い、部分線形モデルを作成します。

Y = \theta(X) \cdot T + g(X, W) + \epsilon , \quad T = f(X, W) + \eta
ただし、
E[\epsilon|X, W]=0, \ E[\eta|X,W]=0, \ E[\eta \cdot \epsilon|X,W]=0
ここで、gfは構造的な仮定を置いておらず、任意のノンパラメトリックな機械学習手法を用いてYTを推定しています。

2段階目

YTの残差同士の回帰を行うことで、CATE(\theta(X))を推定するモデルを作成します。

Y - E[Y|X, W] = \theta(X) \cdot (T - E[T|X, W]) + \epsilon \tag{R}
ここで、
q(X, W) = E[Y|X,W], \ f(X,W) = E[T|X,W]
と置くと、YTの残差\tilde{Y}\tilde{T}は次のように表されます。
\tilde{Y} = Y - q(X,W), \ \tilde{T} = T - f(X, W) = \eta
よって(R)式は次のように表すことができ、
\tilde{Y} = \theta(X) \cdot \tilde{T} + \epsilon \tag{R'}
この(R')式の二乗誤差を最小とするような\theta(X)が、求めたい推定量\hat{\theta}になります。
\hat{\theta} = \underset{\theta}{\argmin} \ E[(\tilde{Y} - \theta(X) \cdot \tilde{T})^2]

DMLのアルゴリズムに関する補足

DMLのアルゴリズム内で用いられている背景知識や手法についてざっくり紹介します。

部分線形モデルとRobinsonの手法

DMLの理論には、部分線形モデルを推定するRobinsonの手法が背景にあります。

部分線形モデル

部分線形モデルとは、共変量の集合X, Wについてノンパラメトリックな関数gでモデリングし、そのモデルを用いて、処置Tg(X, W)でアウトカムYについて線形モデリングしたものです。

一般的な線形モデルが

Y = \theta T + b^T X + c^T W + \epsilon
と表されるのに対し、部分線形モデルは
Y = \theta T + g(X, W) + \epsilon
と表されます。

Robinsonの手法

Robinsonの手法とは、部分線形モデルの推定法の1つです。セミパラメトリックモデルの効率性を持つことが知られています。

2段階に分けて推定する方法で、大まかな流れはDMLと同じです。1段階目で共変量の集合X, WからアウトカムYと処置Tを予測するモデルを作り、2段階目でYTの残差同士の回帰を行うことでパラメータ\thetaの推定量を求めます。

1段階目
Y = \theta \cdot T + g(X, W) + \epsilon, \quad T = f(X, W) + \eta
2段階目
Y - E[Y|X, W] = \theta (T - E[T|X, W]) + \epsilon

Robinsonの手法における\thetaがDMLでは\theta(X)になっています。

2つのバイアスへの対応

機械学習を用いた、例えば、

Y = \hat{\theta} T + \hat{g}(X, W) + \hat{\epsilon}
から\thetaの推定値を求めるようなアプローチ(ナイーブなアプローチ)では

  • Regularization Bias: 正則化バイアス
  • Overfitting Bias: 過学習バイアス

の2つのバイアスが生じることが知られています。DMLでは

  • 直交化(ネイマン直交条件)
  • cross-fitting

を行うことによって、これらのバイアスに対応しています。

Pythonによる実装

Pythonでデータを生成し、DMLを用いて処置効果を推定してみます。

データの準備

下記の設定に従うデータを生成します。

T \backsim Bernoulli(f(W))
Y = T \cdot \theta(X) + \braket{W, \gamma} + \epsilon
W \backsim Normal(0, I_{nw})
X \backsim Uniform(0, 1)^{n_x}
ただし、
f(W)=\sigma(\braket{W, \beta} + \eta), \ \eta \backsim Uniform(-1, 1), \ \epsilon \backsim Uniform(-1, 1)
今回、推定したい処置効果は
\theta(x) = exp(2 \cdot x_0)
とします。

# 必要なライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt
import econml
from econml.dml import DML, LinearDML, CausalForestDML
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
import shap

# Jupyter上にグラフを描画
%matplotlib inline

# 処置効果を返す関数
def exp_te(x):
  return np.exp(2*x[0])

# データの設定
np.random.seed(123)
n = 1000
n_w = 30
support_size = 5
n_x = 4

# アウトカムデータを生成するために利用するオブジェクトを作成
support_Y = np.random.choice(range(n_w), size=support_size, replace=False)
coefs_Y = np.random.uniform(0, 1, size=support_size)
epsilon_sample = lambda n: np.random.uniform(-1, 1, size=n)

# 処置データを生成するために利用するオブジェクトを作成
support_T = support_Y
coefs_T = np.random.uniform(0, 1, size=support_size)
eta_sample = lambda n: np.random.uniform(-1, 1, size=n)

# 共変量データの生成
W = np.random.normal(0, 1, size=(n, n_w))
X = np.random.uniform(0, 1, size=(n, n_x))

# 処置効果データの生成
TE = np.array([exp_te(x_i) for x_i in X])

# 処置データの生成
log_odds = np.dot(W[:, support_T], coefs_T) + eta_sample(n)
T_sigmoid = 1 / (1 + np.exp(-log_odds))
T = np.array([np.random.binomial(1, p) for p in T_sigmoid])

# アウトカムデータの生成
Y = TE * T + np.dot(W[:, support_Y], coefs_Y) + epsilon_sample(n)

# テストデータを生成
X_test = np.random.uniform(0, 1, size=(n, n_x))
X_test[:, 0] = np.linspace(0, 1, n)

DMLの実装

今回はLinearDMLとCausalForestDMLという2つの手法で処置効果を推定します。

まずはLinearDMLで推定します。

# LinearDML
est = LinearDML(model_y=RandomForestRegressor(), 
                model_t=RandomForestClassifier(min_samples_leaf=10), 
                discrete_treatment=True, 
                linear_first_stages=False, 
                cv=6)
est.fit(Y, T, X=X, W=W)
te_pred = est.effect(X_test)
lb, ub = est.effect_interval(X_test, alpha=0.01)

次にCausalForestDMLで推定します。

# CausalForestDML
est2 = CausalForestDML(model_y=RandomForestRegressor(), 
                       model_t=RandomForestClassifier(min_samples_leaf=10), 
                       discrete_treatment=True, 
                       n_estimators=1000, 
                       min_impurity_decrease=0.001, 
                       verbose=0, 
                       cv=6)
est2.tune(Y, T, X=X, W=W)
est2.fit(Y, T, X=X, W=W)
te_pred2 = est2.effect(X_test)
lb2, ub2 = est2.effect_interval(X_test, alpha=0.01)

推定結果を描画します。

# 推定結果の描画
expected_te = np.array([exp_te(x_i) for x_i in X_test])
plt.figure(figsize=(12, 4))

# LinearDML
plt.subplot(1, 2, 1)
plt.plot(X_test[:, 0], te_pred, label='LinearDML', alpha=0.6)
plt.fill_between(X_test[:, 0], lb, ub, alpha=0.4)
plt.plot(X_test[:, 0], expected_te, 'b--', label='True effect')
plt.ylabel('Treatment Effect')
plt.xlabel('x')
plt.legend()

# CausalForestDML
plt.subplot(1, 2, 2)
plt.plot(X_test[:, 0], te_pred2, label='CausalForestDML', alpha=0.6)
plt.fill_between(X_test[:, 0], lb2, ub2, alpha=0.4)
plt.plot(X_test[:, 0], expected_te, 'b--', label='True effect')
plt.ylabel('Treatment Effect')
plt.xlabel('x')
plt.legend()

plt.show()

破線部が推定したい効果で、青塗り部分が推定結果の信頼区間です。

SHAP値も算出することができます。

# SHAP値を可視化
shap_values = est.shap_values(X)
shap.plots.beeswarm(shap_values['Y0']['T0_1'])

shap_values = est2.shap_values(X)
shap.plots.beeswarm(shap_values['Y0']['T0_1'])

データの生成過程どおり、X_0が処置効果の推定結果に強く影響を与えていることが分かります。

参考文献

参照日はすべて2022/12/31です。

おわりに

最後まで読んでいただきありがとうございました。他にも「Python×データ分析」をメインテーマに記事を執筆しているので、参考にしていただけたら幸いです。内容の誤り等がございましたら、コメントにてご指摘くださいませ。

他にも下記のような記事を書いています。ご一読いただけますと幸いです。

また、過去にLTや勉強会で発表した資料は下記リンクにまとめてあります。ぜひ、ご一読くださいませ。
https://speakerdeck.com/s1ok69oo

Discussion