library(tidyverse)
library(broom)
<- partial(
kable ::kable,
knitrdigits = 3
)
set.seed(95)
Hernán and Robins (2020) のPart3では、複数時点での処置の因果推論について議論されている。ここでは、IPW、g-formulaを用いた因果効果の推定を実際にやってみる。
下準備
ライブラリなど
データの作成
以下のDAGを考える(Hernán and Robins 2020, fig. 20.3)
Treatmentは\(A_0\)と\(A_1\)であるが、ここから\(Y\)へのpathはないので、因果効果は全ての組み合わせについて0になる
具体的には、potential outcomeのすべての組み合わせについて、差分を取った値が0になる
\[\begin{align} \mathrm{E}[Y^{0, 0} - Y^{1, 0}] &= 0\\ \mathrm{E}[Y^{0, 0} - Y^{0, 1}] &= 0\\ \mathrm{E}[Y^{0, 0} - Y^{1, 1}] &= 0\\ \mathrm{E}[Y^{1, 0} - Y^{0, 1}] &= 0\\ \mathrm{E}[Y^{1, 0} - Y^{1, 1}] &= 0\\ \mathrm{E}[Y^{0, 1} - Y^{1, 1}] &= 0 \end{align}\]
上記のDAGに合うように作成されたデータが以下である(Hernán and Robins 2020, Table 20.1)
<-
data_agg tribble(
~N, ~A0, ~L1, ~A1, ~Y,
2400, 0, 0, 0, 84,
1600, 0, 0, 1, 84,
2400, 0, 1, 0, 52,
9600, 0, 1, 1, 52,
4800, 1, 0, 0, 76,
3200, 1, 0, 1, 76,
1600, 1, 1, 0, 44,
6400, 1, 1, 1, 44
)
|> kable() data_agg
N | A0 | L1 | A1 | Y |
---|---|---|---|---|
2400 | 0 | 0 | 0 | 84 |
1600 | 0 | 0 | 1 | 84 |
2400 | 0 | 1 | 0 | 52 |
9600 | 0 | 1 | 1 | 52 |
4800 | 1 | 0 | 0 | 76 |
3200 | 1 | 0 | 1 | 76 |
1600 | 1 | 1 | 0 | 44 |
6400 | 1 | 1 | 1 | 44 |
これを個人レベルのデータに変換
<- uncount(data_agg, N)
data_ind
data_ind
# A tibble: 32,000 × 4
A0 L1 A1 Y
<dbl> <dbl> <dbl> <dbl>
1 0 0 0 84
2 0 0 0 84
3 0 0 0 84
4 0 0 0 84
5 0 0 0 84
6 0 0 0 84
7 0 0 0 84
8 0 0 0 84
9 0 0 0 84
10 0 0 0 84
# ℹ 31,990 more rows
各処置の因果効果
- 各処置の効果を確認
- 処置からアウトカムへの直接のパスはないので、各時点の処置の因果効果も0になる
\(A_0\)の因果効果
|>
data_ind lm(Y ~ A0, data = _) |>
tidy() |>
kable()
term | estimate | std.error | statistic | p.value |
---|---|---|---|---|
(Intercept) | 60 | 0.118 | 507.077 | 0 |
A0 | 0 | 0.167 | 0.000 | 1 |
\(A_1\)の因果効果
- 以下のバックドアが存在するので\(L_1\)を統制
- \(A_1 \gets L_1 \gets U_1 \to Y\)
|>
data_ind lm(Y ~ A1 + L1, data = _) |>
tidy() |>
kable()
term | estimate | std.error | statistic | p.value |
---|---|---|---|---|
(Intercept) | 78.667 | 0.040 | 1943.969 | 0 |
A1 | 0.000 | 0.050 | 0.000 | 1 |
L1 | -29.867 | 0.049 | -611.652 | 0 |
collider bias
- \(A_0\)の因果効果の推定において、\(L_1\)を条件づける
- \(A_0 \to \boxed{L_1} \gets U_1 \to Y\)というパスが開いて、バイアスをもたらす
|>
data_ind lm(Y ~ A0 + L1, data = _) |>
tidy() |>
kable()
term | estimate | std.error | statistic | p.value |
---|---|---|---|---|
(Intercept) | 84 | 0 | 3.657678e+14 | 0 |
A0 | -8 | 0 | -3.597746e+13 | 0 |
L1 | -32 | 0 | -1.393401e+14 | 0 |
IPW
- \(\bar{A}_k = (A_1, A_2, \dots, A_k)\):\(k\)時点までの処置の履歴
- \(\bar{L}_k = (L_1, L_2, \dots, L_k)\):\(k\)時点までの共変量の履歴
以下のウェイトを作成する
\[ W^{\bar{A}_k} = \prod_{k=0}^{K}\frac{1}{f(A_k | \bar{A}_{k-1},\bar{L}_{k})} \]
2時点では以下のように書ける
\[ W^{A_0, A_1} = \frac{1}{f(A_0 | L_0)} \times \frac{1}{f(A_1 | A_0, L_0, L_1)} \]
さらに、今回はベースライン共変量がない(\(L_0 = \varnothing\))ので
\[ W^{A_0, A_1} = \frac{1}{f(A_0)} \times \frac{1}{f(A_1 | A_0, L_1)} \]
Stabilized IP weights も推定してみる
\[ SW^{A_0, A_1} = \frac{f(A_0)}{f(A_0)} \times \frac{f(A_1 | A_0)}{f(A_1 | A_0, L_1)} \]
各処置に対するウェイトの推定
まずは、\(f(A_0), f(A_1|A_0,L_1), f(A_1|A_0)\)の3つを推定する
# 各時点の処置を予測するモデル
<- glm(A0 ~ 1, data = data_ind, family = "binomial")
model_A0 <- glm(A1 ~ A0 * L1, data = data_ind, family = "binomial")
model_A1
# Stabilized Weightsの分子を予測するモデル
<- glm(A1 ~ A0, data = data_ind, family = 'binomial') model_A1_without_L1
予測値を算出
<-
ipw_data1 |>
data_ind mutate(
# 各モデルの予測値を算出
pred_A0 = predict(model_A0, type = 'response'),
pred_A1 = predict(model_A1, type = 'response'),
pred_A1_without_L1 = predict(model_A1_without_L1, type = 'response'),
)
|>
ipw_data1 count(pick(everything())) |>
kable()
A0 | L1 | A1 | Y | pred_A0 | pred_A1 | pred_A1_without_L1 | n |
---|---|---|---|---|---|---|---|
0 | 0 | 0 | 84 | 0.5 | 0.4 | 0.7 | 2400 |
0 | 0 | 1 | 84 | 0.5 | 0.4 | 0.7 | 1600 |
0 | 1 | 0 | 52 | 0.5 | 0.8 | 0.7 | 2400 |
0 | 1 | 1 | 52 | 0.5 | 0.8 | 0.7 | 9600 |
1 | 0 | 0 | 76 | 0.5 | 0.4 | 0.6 | 4800 |
1 | 0 | 1 | 76 | 0.5 | 0.4 | 0.6 | 3200 |
1 | 1 | 0 | 44 | 0.5 | 0.8 | 0.6 | 1600 |
1 | 1 | 1 | 44 | 0.5 | 0.8 | 0.6 | 6400 |
各処置に対するウェイトを作成
<-
ipw_data2 |>
ipw_data1 mutate(
# A0に対するウェイト
ipw_A0 = case_when(
== 1 ~ 1 / pred_A0,
A0 == 0 ~ 1 / (1 - pred_A0)
A0
),# A0に対するSW
sw_A0 = case_when(
== 1 ~ pred_A0 / pred_A0,
A0 == 0 ~ (1 - pred_A0) / (1 - pred_A0)
A0
),# A1に対するウェイト
ipw_A1 = case_when(
== 1 ~ 1 / pred_A1,
A1 == 0 ~ 1 / (1 - pred_A1)
A1
),# A1に対するSW
sw_A1 = case_when(
== 1 ~ pred_A1_without_L1 / pred_A1,
A1 == 0 ~ (1 - pred_A1_without_L1) / (1 - pred_A1)
A1
),
)
|>
ipw_data2 count(pick(everything())) |>
kable()
A0 | L1 | A1 | Y | pred_A0 | pred_A1 | pred_A1_without_L1 | ipw_A0 | sw_A0 | ipw_A1 | sw_A1 | n |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 84 | 0.5 | 0.4 | 0.7 | 2 | 1 | 1.667 | 0.500 | 2400 |
0 | 0 | 1 | 84 | 0.5 | 0.4 | 0.7 | 2 | 1 | 2.500 | 1.750 | 1600 |
0 | 1 | 0 | 52 | 0.5 | 0.8 | 0.7 | 2 | 1 | 5.000 | 1.500 | 2400 |
0 | 1 | 1 | 52 | 0.5 | 0.8 | 0.7 | 2 | 1 | 1.250 | 0.875 | 9600 |
1 | 0 | 0 | 76 | 0.5 | 0.4 | 0.6 | 2 | 1 | 1.667 | 0.667 | 4800 |
1 | 0 | 1 | 76 | 0.5 | 0.4 | 0.6 | 2 | 1 | 2.500 | 1.500 | 3200 |
1 | 1 | 0 | 44 | 0.5 | 0.8 | 0.6 | 2 | 1 | 5.000 | 2.000 | 1600 |
1 | 1 | 1 | 44 | 0.5 | 0.8 | 0.6 | 2 | 1 | 1.250 | 0.750 | 6400 |
IP weightの作成
各時点のウェイトを掛け合わせる。たしかにSWのほうがウェイトが狭い範囲にまとまっていることがわかる。
<-
ipw_data3 |>
ipw_data2 mutate(
# A0に対するウェイトとA1に対するウェイトをかけ算
ipw = ipw_A0 * ipw_A1,
sw = sw_A0 * sw_A1
)
|>
ipw_data3 count(pick(everything())) |>
kable()
A0 | L1 | A1 | Y | pred_A0 | pred_A1 | pred_A1_without_L1 | ipw_A0 | sw_A0 | ipw_A1 | sw_A1 | ipw | sw | n |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 84 | 0.5 | 0.4 | 0.7 | 2 | 1 | 1.667 | 0.500 | 3.333 | 0.500 | 2400 |
0 | 0 | 1 | 84 | 0.5 | 0.4 | 0.7 | 2 | 1 | 2.500 | 1.750 | 5.000 | 1.750 | 1600 |
0 | 1 | 0 | 52 | 0.5 | 0.8 | 0.7 | 2 | 1 | 5.000 | 1.500 | 10.000 | 1.500 | 2400 |
0 | 1 | 1 | 52 | 0.5 | 0.8 | 0.7 | 2 | 1 | 1.250 | 0.875 | 2.500 | 0.875 | 9600 |
1 | 0 | 0 | 76 | 0.5 | 0.4 | 0.6 | 2 | 1 | 1.667 | 0.667 | 3.333 | 0.667 | 4800 |
1 | 0 | 1 | 76 | 0.5 | 0.4 | 0.6 | 2 | 1 | 2.500 | 1.500 | 5.000 | 1.500 | 3200 |
1 | 1 | 0 | 44 | 0.5 | 0.8 | 0.6 | 2 | 1 | 5.000 | 2.000 | 10.000 | 2.000 | 1600 |
1 | 1 | 1 | 44 | 0.5 | 0.8 | 0.6 | 2 | 1 | 1.250 | 0.750 | 2.500 | 0.750 | 6400 |
Counterfactual meanの推定
すべてのパターンで同じ値なので、差分をとれば0になる
<-
res_ipw |>
ipw_data3 # 重み付け推定
summarise(
CFmean_ipw = weighted.mean(Y, ipw),
CFmean_sw = weighted.mean(Y, sw),
.by = c(A0, A1),
|>
) arrange(A0, A1)
|>
res_ipw kable()
A0 | A1 | CFmean_ipw | CFmean_sw |
---|---|---|---|
0 | 0 | 60 | 60 |
0 | 1 | 60 | 60 |
1 | 0 | 60 | 60 |
1 | 1 | 60 | 60 |
Marginal structural modelの推定
以下のようなmarginal structural modelを考えてみる
\[ \mathrm{E}[Y^{A_0,A_1}] = \beta_0 + \beta_1 (A_0 + A_1) \]
当然ながら\(\beta_1 = 0\)となる
::lm_robust(Y ~ I(A0 + A1), data = ipw_data3, weight = ipw) |>
estimatr::tidy() |>
broomkable()
term | estimate | std.error | statistic | p.value | conf.low | conf.high | df | outcome |
---|---|---|---|---|---|---|---|---|
(Intercept) | 60 | 0.154 | 389.52 | 0 | 59.698 | 60.302 | 31998 | Y |
I(A0 + A1) | 0 | 0.122 | 0.00 | 1 | -0.240 | 0.240 | 31998 | Y |
::lm_robust(Y ~ I(A0 + A1), data = ipw_data3, weights = sw) |>
estimatr::tidy() |>
broomkable()
term | estimate | std.error | statistic | p.value | conf.low | conf.high | df | outcome |
---|---|---|---|---|---|---|---|---|
(Intercept) | 60 | 0.162 | 369.84 | 0 | 59.682 | 60.318 | 31998 | Y |
I(A0 + A1) | 0 | 0.127 | 0.00 | 1 | -0.250 | 0.250 | 31998 | Y |
G-formula
Sequential exchangeabilityが成立しているとき、以下の式が成り立つ
\[ \mathrm{E}[Y^{\bar{A}_k}] = \sum_\bar{l}\mathrm{E}[Y | \bar{A}_k = \bar{a}_k, \bar{L}_k = \bar{l}_k]\prod_{k = 0}^K f(l_k | \bar{a}_{k-1},\bar{l}_{k-1}) \]
一般式はイカツイが、2時点かつ\(L_0\)がない今回の条件だと以下のようになる
\[ \mathrm{E}[Y^{A_0,A_1}] = \sum_{l_1} \mathrm{E}[Y | A_0 = a_0,A_1 = a_1,L_1 = l_1]f(l_1 | a_0) \]
アウトカムと共変量の予測モデルを推定
- \(\mathrm{E}[Y | A_0 = a_0,A_1 = a_1,L_1 = l_1]\)と\(f(l_1 | a_0)\)のモデルを推定
- 1時点の時と比較すると、共変量の予測モデルが必要になる点が新しい
# アウトカムモデル
<- lm(Y ~ A0*A1*L1, data = data_ind)
model_Y # L1の予測モデル
<- lm(L1 ~ A0, data = data_ind) model_L1
共変量の予測値を推定
ひとまず、\(A_0 = A_1 = 0\)のときのpotential outcome\(\mathrm{E}[Y^{0,0}]\)を推定することを考える
まずサンプル全員の処置の値を\(A_0 = A_1 = 0\)で置き換える
<-
gform_data1 |>
data_ind # 名前を変えておく
rename(A0_obs = A0, A1_obs = A1, L1_obs = L1, Y_obs = Y) |>
# 処置の値を変更
mutate(A0 = 0, A1 = 0)
|>
gform_data1 count(pick(everything())) |>
kable()
A0_obs | L1_obs | A1_obs | Y_obs | A0 | A1 | n |
---|---|---|---|---|---|---|
0 | 0 | 0 | 84 | 0 | 0 | 2400 |
0 | 0 | 1 | 84 | 0 | 0 | 1600 |
0 | 1 | 0 | 52 | 0 | 0 | 2400 |
0 | 1 | 1 | 52 | 0 | 0 | 9600 |
1 | 0 | 0 | 76 | 0 | 0 | 4800 |
1 | 0 | 1 | 76 | 0 | 0 | 3200 |
1 | 1 | 0 | 44 | 0 | 0 | 1600 |
1 | 1 | 1 | 44 | 0 | 0 | 6400 |
次に、\(A_0 = 0\)の状況での\(L_1\)の値を予測する。これが\(f(l_1 | 0)\)
\(L^{A_0 = 0}\)の値をシミュレートしているとも考えられる
<-
gform_data2 |>
gform_data1 # 全員がA0 = 0のときの共変量L1の値をシミュレート
mutate(L1 = predict(model_L1, newdata = gform_data1))
|>
gform_data2 count(pick(everything())) |>
kable()
A0_obs | L1_obs | A1_obs | Y_obs | A0 | A1 | L1 | n |
---|---|---|---|---|---|---|---|
0 | 0 | 0 | 84 | 0 | 0 | 0.75 | 2400 |
0 | 0 | 1 | 84 | 0 | 0 | 0.75 | 1600 |
0 | 1 | 0 | 52 | 0 | 0 | 0.75 | 2400 |
0 | 1 | 1 | 52 | 0 | 0 | 0.75 | 9600 |
1 | 0 | 0 | 76 | 0 | 0 | 0.75 | 4800 |
1 | 0 | 1 | 76 | 0 | 0 | 0.75 | 3200 |
1 | 1 | 0 | 44 | 0 | 0 | 0.75 | 1600 |
1 | 1 | 1 | 44 | 0 | 0 | 0.75 | 6400 |
アウトカムの予測値を推定
さらに、シミュレートした\(L_1\)と、\(A_0 = A_1 = 0\)のもとでの\(Y\)の値を予測する。これが\(\mathrm{E}[Y | A_0 = 0,A_1 = 0,L_1 = l_1]\)
<-
gform_data3 |>
gform_data2 mutate(Y = predict(model_Y, newdata = gform_data2))
|>
gform_data3 count(pick(everything())) |>
kable()
A0_obs | L1_obs | A1_obs | Y_obs | A0 | A1 | L1 | Y | n |
---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 84 | 0 | 0 | 0.75 | 60 | 2400 |
0 | 0 | 1 | 84 | 0 | 0 | 0.75 | 60 | 1600 |
0 | 1 | 0 | 52 | 0 | 0 | 0.75 | 60 | 2400 |
0 | 1 | 1 | 52 | 0 | 0 | 0.75 | 60 | 9600 |
1 | 0 | 0 | 76 | 0 | 0 | 0.75 | 60 | 4800 |
1 | 0 | 1 | 76 | 0 | 0 | 0.75 | 60 | 3200 |
1 | 1 | 0 | 44 | 0 | 0 | 0.75 | 60 | 1600 |
1 | 1 | 1 | 44 | 0 | 0 | 0.75 | 60 | 6400 |
サンプル全体で平均する
以下のように変形できるので、アウトカムの予測値をサンプル全体で平均すればよい
\[ \sum_{l_1} \mathrm{E}[Y | A_0 = a_0,A_1 = a_1,L_1 = l_1]f(l_1 | a_0) = \mathrm{E}[\mathrm{E}[Y | A_0 = a_0, A_1 = a_1, L_1]] \]
|>
gform_data3 summarise(CFmean = mean(Y), .by = c(A0, A1)) |>
kable()
A0 | A1 | CFmean |
---|---|---|
0 | 0 | 60 |
すべてのパターンを推定
# 処置のパターンごとにデータを4つ作成
list(
|> mutate(A0 = 0, A1 = 0),
data_ind |> mutate(A0 = 0, A1 = 1),
data_ind |> mutate(A0 = 1, A1 = 0),
data_ind |> mutate(A0 = 1, A1 = 1)
data_ind |>
) # 各パターンにおいてL1の値を予測
map(
|> mutate(L1 = predict(model_L1, newdata = data))
\(data) data |>
) # 予測したL1の値とトリートメントを用いて、Potential Outcomeを予測
map(
|> mutate(Y_pred = predict(model_Y, newdata = data))
\(data) data |>
) bind_rows() |>
summarise(
CFmean = mean(Y_pred),
.by = c(A0, A1)
|>
) kable()
A0 | A1 | CFmean |
---|---|---|
0 | 0 | 60 |
0 | 1 | 60 |
1 | 0 | 60 |
1 | 1 | 60 |
共変量が増えるとそれだけシミュレートする変数が増えるので大変そう…
疑問点&残された課題
- Marginal structural modelのパラメータはg-formulaでは推定できないのか?
- L-TMLEの推定ができるパッケージではMSMのパラメータも推定できる
- G-formulaとRegression with residualsの関係が気になった
- どちらも共変量のモデリングを行うが、g-fomulaは予測値を用いるのに対して、RWRは残差の方を用いる
- 本当は今回の課題でTMLEまで実装したかった
- 理解が追いつかず断念