library(tidyverse)
library(sl3)
library(tmle3)
library(future)
<- partial(
kable ::kable,
knitrdigits = 3
)
plan(multisession, workers = 8)
set.seed(7)
TMLEの手順
参考ページ:https://www.khstats.com/blog/tmle/tutorial-pt2
下準備
ライブラリなどの読み込み
データ作成
元のサイトでは\(Y\)が2値変数になっていて、推定が真値と一致しているかどうかの評価が難しいので、ここでは\(Y\)を連続変数としている。
ATEの真値は0.3に設定
<- function(n){
generate_data <- rbinom(n, size=1, prob=0.2) # binary confounder
W1 <- rbinom(n, size=1, prob=0.5) # binary confounder
W2 <- round(runif(n, min=2, max=7)) # continuous confounder
W3 <- round(runif(n, min=0, max=4)) # continuous confounder
W4 <- rbinom(n, size=1, prob= plogis(-0.2 + 0.2*W2 + log(0.1*W3) + 0.3*W4 + 0.2*W1*W4)) # binary treatment depends on confounders
A <- -1 + 0.3*A - 0.1*W1 + 0.2*W2 + 0.3*W3 - 0.1*W4 + sin(0.1*W2*W4) # continuous outcome depends on confounders
Y return(tibble(Y, W1, W2, W3, W4, A))
}
<- 1000
n <- generate_data(n) # generate a data set with n observations
dat_obs
|>
dat_obs summarise(samplemean = mean(Y), .by = A)
# A tibble: 2 × 2
A samplemean
<int> <dbl>
1 1 0.686
2 0 0.237
機械学習ライブラリの設定
- glm、Lasso、Random forest、Multivariate adaptive regression splineをスタッキング
<-
sl_libs $new(
Lrnr_sllearners = Stack$new(
$new(),
Lrnr_glm$new(alpha = 1),
Lrnr_glmnet$new(num.trees = 2000, max.depth = 3),
Lrnr_ranger$new()
Lrnr_earth
) )
Step1:アウトカムの予測
アウトカムの条件付き期待値関数を推定
\[ {\mathrm E}[Y | A, W] \]
<-
task $new(
sl3_Task
dat_obs, covariates = select(dat_obs, !Y) |> names(),
outcome = 'Y',
outcome_type = 'continuous',
folds = 8L
)
<-
task_A1 $new(
sl3_Task|> mutate(A = 1),
dat_obs covariates = select(dat_obs, !Y) |> names(),
outcome = 'Y',
outcome_type = 'continuous',
folds = 8L
)
<-
task_A0 $new(
sl3_Task|> mutate(A = 0),
dat_obs covariates = select(dat_obs, !Y) |> names(),
outcome = 'Y',
outcome_type = 'continuous',
folds = 8L
)
# 全サンプルで学習
<- sl_libs$train(task) sl_fit
以下の3つの予測値を算出
\(A := 1\)は全サンプルでAを1にする(\(A = 1\)はサンプルのうちのAが1となる部分集団)
\[\begin{align} &{\mathrm E}[Y | A, W] \\ &{\mathrm E}[Y | A := 1, W] \\ &{\mathrm E}[Y | A := 0, W] \end{align}\]
<-
dat_tmle1 |>
dat_obs mutate(
# 観測サンプルについての予測値
Q_A = sl_fit$predict(task),
# 全てのサンプルでA = 1に固定したときの予測値
Q_A1 = sl_fit$predict(task_A1),
# 全てのサンプルでA = 0に固定した時の予測値
Q_A0 = sl_fit$predict(task_A0)
)
- standardization(g-computation)によるATE
\[ ATE_{g \mathrm{-}comp} = {\mathrm E}[ {\mathrm E}[Y | A := 1, W] - {\mathrm E}[Y | A := 0, W]] \]
|>
dat_tmle1 summarise(ATE = mean(Q_A1 - Q_A0))
# A tibble: 1 × 1
ATE
<dbl>
1 0.300
Step2:処置確率(傾向スコア)の予測
- 傾向スコアを機械学習モデルにより予測
\[ \mathrm{Pr}(A = 1 | W) \]
<-
task_g $new(
sl3_Taskdata = dat_obs,
covariates = select(dat_obs, !c(Y, A)) |> names(),
outcome = 'A',
outcome_type = 'binomial',
folds = 8
)
<- sl_libs$train(task_g) sl_fit_g
Clever Covariateの作成
- 傾向スコアからClever Covariateと呼ばれる情報を作成(IPWに似ている)
\[\begin{align} &H(A,W) &= \frac{A}{\mathrm{Pr}(A = 1 | W)} - \frac{1 - A}{1 - \mathrm{Pr}(A = 1 | W)} \\ &H(1,W) &= \frac{A}{\mathrm{Pr}(A = 1 | W)} \\ &H(0,W) &= - \frac{1 - A}{1 - \mathrm{Pr}(A = 1 | W)} \end{align}\]
<-
dat_tmle2 |>
dat_tmle1 mutate(
# Propensity Scoreの予測
ps = sl_fit_g$predict(task_g),
# ipw (Inverse Probability Weight)
ipw = case_when(
== 1 ~ 1 / ps,
A == 0 ~ 1 / (1 - ps)
A
),# Clever Covariates
H_A = case_when(
== 1 ~ 1 / ps,
A == 0 ~ -1 / (1 - ps)
A
),H_A1 = case_when(
== 1 ~ H_A,
A == 0 ~ 0
A
),H_A0 = case_when(
== 1 ~ 0,
A == 0 ~ H_A
A
) )
- IPWによるATE
\[ ATE_{ipw} = {\mathrm E}[\frac{A}{\mathrm{Pr}(A = 1 | W)}Y - \frac{1 - A}{1 - \mathrm{Pr}(A = 1 | W)}Y] \]
|>
dat_tmle2 summarise(CFmean = sum(Y*ipw) / sum(ipw), .by = A) |>
arrange(A) |>
summarise(ATE = diff(CFmean))
# A tibble: 1 × 1
ATE
<dbl>
1 0.301
- Augumented IPWによるATE
- (ほんとは関数推定時にcross-fitをする)
\[ ATE_{aipw} = \mathrm E[\mathrm E[Y | A := 1, W] - \mathrm E[Y | A := 0, W] + \frac{A}{\mathrm Pr(A = 1 | W)}(Y - {\mathrm E}[Y | A := 1, W]) - \frac{1 - A}{1 - \mathrm Pr(A = 1 | W)}(Y - \mathrm E[Y | A := 0, W])] \]
|>
dat_tmle2 summarise(ATE = mean(Q_A1 - Q_A0 + ipw*A*(Y - Q_A1) - ipw*(1 - A)*(Y - Q_A0)))
# A tibble: 1 × 1
ATE
<dbl>
1 0.300
Step3:変動パラメータの推定
- AIPWの問題点:統計的最適化がターゲットのパラメータ(ATE)に対してではなく、母平均関数\({\mathrm E}[Y | A,W]\)および傾向スコア関数\(\mathrm{Pr}(A = 1 | W)\)のパラメータについて最適化されている点
- 推定したいパラメータ(ATE)のEIF(Efficient Influence Function)を解くことがこのステップのポイントらしい
- 具体的には、Step1で推定した\(\mathrm{E}[Y | A, W]\)と、Step2で推定したClever Covariate\(H(A, W)\)を用いて、以下の回帰式の\(\epsilon\)(変動パラメータ)を推定する
\[ Y = \mathrm{E}[Y | A, W] + \epsilon H(A,W) \]
- 切片が0で、Step1の推定値の係数を1に固定するために、-1と
offset
を利用する
<- glm(Y ~ -1 + offset(Q_A) + H_A, data = dat_tmle2, family = gaussian()) fit
- 変動パラメータの推定値
<- coef(fit)
epsilon
epsilon
H_A
6.907486e-07
Step4:アウトカムの予測値を更新
- 推定したepsilonと\(Y\)の予測値をもとに、\(Y\)の予測値を更新
<-
dat_tmle3 |>
dat_tmle2 mutate(
Q_A_update = Q_A + epsilon*H_A,
Q_A1_update = Q_A1 + epsilon*H_A1,
Q_A0_update = Q_A0 + epsilon*H_A0,
)
Step5:推定したい統計量を推定
- 更新されたアウトカムの予測値を用いて、Standardizationの要領でATEを推定
|>
dat_tmle3 summarise(ATE = mean(Q_A1_update - Q_A0_update))
# A tibble: 1 × 1
ATE
<dbl>
1 0.300
<- mean(dat_tmle3$Q_A1_update - dat_tmle3$Q_A0_update) ATE
Step6:標準誤差の推定
- TMLEではbootstrapによらずとも標準誤差を算出できる(!)
- まずは、Influence Functionを推定する
- Influence Function:各サンプルがどれだけATEに影響をあたえるか?
<-
dat_tmle4 |>
dat_tmle3 mutate(
IF = (Y - Q_A_update)*H_A + Q_A1_update - Q_A0_update - ATE
)
ATEの標準誤差はIFを用いて
\[ SE = \sqrt{\frac{\mathrm{var}(IF)}{N}} \]
|>
dat_tmle4 summarise(SE = sqrt(var(IF) / 1000))
# A tibble: 1 × 1
SE
<dbl>
1 0.0000217
tmle3
パッケージによる実行
<-
node_list list(
W = dat_obs |> select(!c(A, Y)) |> names(),
A = 'A',
Y = 'Y'
)
<-
tmle3_fit tmle3(
tmle_spec = tmle_ATE(treatment_level = 1, control_level = 0),
data = dat_obs,
node_list = node_list,
learner_list = list(A = sl_libs, Y = sl_libs)
)
大体正しく推定できている
tmle3_fit
A tmle3_Fit that took 1 step(s)
type param init_est tmle_est se lower
<char> <char> <num> <num> <num> <num>
1: ATE ATE[Y_{A=1}-Y_{A=0}] 0.2999748 0.2996577 0.0002061923 0.2992535
upper psi_transformed lower_transformed upper_transformed
<num> <num> <num> <num>
1: 0.3000618 0.2996577 0.2992535 0.3000618