複数時点での処置の因果効果の推定

Causal Inference
Author

Kentaro Kamada

Published

February 22, 2024

Modified

May 14, 2025

Hernán and Robins (2020) のPart3では、複数時点での処置の因果推論について議論されている。ここでは、IPW、g-formulaを用いた因果効果の推定を実際にやってみる。

下準備

ライブラリなど

library(tidyverse)
library(broom)

kable <- partial(
  knitr::kable,
  digits = 3
)

set.seed(95)

データの作成

以下のDAGを考える(Hernán and Robins 2020, fig. 20.3)

Treatmentは\(A_0\)\(A_1\)であるが、ここから\(Y\)へのpathはないので、因果効果は全ての組み合わせについて0になる

具体的には、potential outcomeのすべての組み合わせについて、差分を取った値が0になる

\[\begin{align} \mathrm{E}[Y^{0, 0} - Y^{1, 0}] &= 0\\ \mathrm{E}[Y^{0, 0} - Y^{0, 1}] &= 0\\ \mathrm{E}[Y^{0, 0} - Y^{1, 1}] &= 0\\ \mathrm{E}[Y^{1, 0} - Y^{0, 1}] &= 0\\ \mathrm{E}[Y^{1, 0} - Y^{1, 1}] &= 0\\ \mathrm{E}[Y^{0, 1} - Y^{1, 1}] &= 0 \end{align}\]

上記のDAGに合うように作成されたデータが以下である(Hernán and Robins 2020, Table 20.1)

data_agg <- 
  tribble(
    ~N, ~A0, ~L1, ~A1, ~Y, 
    2400, 0, 0, 0, 84, 
    1600, 0, 0, 1, 84, 
    2400, 0, 1, 0, 52, 
    9600, 0, 1, 1, 52, 
    4800, 1, 0, 0, 76, 
    3200, 1, 0, 1, 76, 
    1600, 1, 1, 0, 44, 
    6400, 1, 1, 1, 44
  )

data_agg |> kable()
N A0 L1 A1 Y
2400 0 0 0 84
1600 0 0 1 84
2400 0 1 0 52
9600 0 1 1 52
4800 1 0 0 76
3200 1 0 1 76
1600 1 1 0 44
6400 1 1 1 44

これを個人レベルのデータに変換

data_ind <- uncount(data_agg, N)

data_ind
# A tibble: 32,000 × 4
      A0    L1    A1     Y
   <dbl> <dbl> <dbl> <dbl>
 1     0     0     0    84
 2     0     0     0    84
 3     0     0     0    84
 4     0     0     0    84
 5     0     0     0    84
 6     0     0     0    84
 7     0     0     0    84
 8     0     0     0    84
 9     0     0     0    84
10     0     0     0    84
# ℹ 31,990 more rows

各処置の因果効果

  • 各処置の効果を確認
  • 処置からアウトカムへの直接のパスはないので、各時点の処置の因果効果も0になる

\(A_0\)の因果効果

data_ind |> 
  lm(Y ~ A0, data = _) |> 
  tidy() |> 
  kable()
term estimate std.error statistic p.value
(Intercept) 60 0.118 507.077 0
A0 0 0.167 0.000 1

\(A_1\)の因果効果

  • 以下のバックドアが存在するので\(L_1\)を統制
  • \(A_1 \gets L_1 \gets U_1 \to Y\)
data_ind |> 
  lm(Y ~ A1 + L1, data = _) |>
  tidy() |> 
  kable()
term estimate std.error statistic p.value
(Intercept) 78.667 0.040 1943.969 0
A1 0.000 0.050 0.000 1
L1 -29.867 0.049 -611.652 0

collider bias

  • \(A_0\)の因果効果の推定において、\(L_1\)を条件づける
  • \(A_0 \to \boxed{L_1} \gets U_1 \to Y\)というパスが開いて、バイアスをもたらす
data_ind |> 
  lm(Y ~ A0 + L1, data = _) |> 
  tidy() |> 
  kable()
term estimate std.error statistic p.value
(Intercept) 84 0 3.657678e+14 0
A0 -8 0 -3.597746e+13 0
L1 -32 0 -1.393401e+14 0

IPW

  • \(\bar{A}_k = (A_1, A_2, \dots, A_k)\)\(k\)時点までの処置の履歴
  • \(\bar{L}_k = (L_1, L_2, \dots, L_k)\)\(k\)時点までの共変量の履歴

以下のウェイトを作成する

\[ W^{\bar{A}_k} = \prod_{k=0}^{K}\frac{1}{f(A_k | \bar{A}_{k-1},\bar{L}_{k})} \]

2時点では以下のように書ける

\[ W^{A_0, A_1} = \frac{1}{f(A_0 | L_0)} \times \frac{1}{f(A_1 | A_0, L_0, L_1)} \]

さらに、今回はベースライン共変量がない(\(L_0 = \varnothing\))ので

\[ W^{A_0, A_1} = \frac{1}{f(A_0)} \times \frac{1}{f(A_1 | A_0, L_1)} \]

Stabilized IP weights も推定してみる

\[ SW^{A_0, A_1} = \frac{f(A_0)}{f(A_0)} \times \frac{f(A_1 | A_0)}{f(A_1 | A_0, L_1)} \]

各処置に対するウェイトの推定

まずは、\(f(A_0), f(A_1|A_0,L_1), f(A_1|A_0)\)の3つを推定する

# 各時点の処置を予測するモデル
model_A0 <- glm(A0 ~ 1, data = data_ind, family = "binomial")
model_A1 <- glm(A1 ~ A0 * L1, data = data_ind, family = "binomial")

# Stabilized Weightsの分子を予測するモデル
model_A1_without_L1 <- glm(A1 ~ A0, data = data_ind, family = 'binomial')

予測値を算出

ipw_data1 <-
  data_ind |> 
  mutate(
    # 各モデルの予測値を算出
    pred_A0 = predict(model_A0, type = 'response'),
    pred_A1 = predict(model_A1, type = 'response'),
    pred_A1_without_L1 = predict(model_A1_without_L1, type = 'response'),
  )

ipw_data1 |> 
  count(pick(everything())) |> 
  kable()
A0 L1 A1 Y pred_A0 pred_A1 pred_A1_without_L1 n
0 0 0 84 0.5 0.4 0.7 2400
0 0 1 84 0.5 0.4 0.7 1600
0 1 0 52 0.5 0.8 0.7 2400
0 1 1 52 0.5 0.8 0.7 9600
1 0 0 76 0.5 0.4 0.6 4800
1 0 1 76 0.5 0.4 0.6 3200
1 1 0 44 0.5 0.8 0.6 1600
1 1 1 44 0.5 0.8 0.6 6400

各処置に対するウェイトを作成

ipw_data2 <-
  ipw_data1 |> 
  mutate(
    # A0に対するウェイト
    ipw_A0 = case_when(
      A0 == 1 ~ 1 / pred_A0,
      A0 == 0 ~ 1 / (1 - pred_A0)
    ),
    # A0に対するSW
    sw_A0 = case_when(
      A0 == 1 ~ pred_A0 / pred_A0,
      A0 == 0 ~ (1 - pred_A0) / (1 - pred_A0)
    ),
    # A1に対するウェイト
    ipw_A1 = case_when(
      A1 == 1 ~ 1 / pred_A1,
      A1 == 0 ~ 1 / (1 - pred_A1)
    ),
    # A1に対するSW
    sw_A1 = case_when(
      A1 == 1 ~ pred_A1_without_L1 / pred_A1,
      A1 == 0 ~ (1 - pred_A1_without_L1) / (1 - pred_A1)
    ),
  )

ipw_data2 |>
  count(pick(everything())) |>
  kable()
A0 L1 A1 Y pred_A0 pred_A1 pred_A1_without_L1 ipw_A0 sw_A0 ipw_A1 sw_A1 n
0 0 0 84 0.5 0.4 0.7 2 1 1.667 0.500 2400
0 0 1 84 0.5 0.4 0.7 2 1 2.500 1.750 1600
0 1 0 52 0.5 0.8 0.7 2 1 5.000 1.500 2400
0 1 1 52 0.5 0.8 0.7 2 1 1.250 0.875 9600
1 0 0 76 0.5 0.4 0.6 2 1 1.667 0.667 4800
1 0 1 76 0.5 0.4 0.6 2 1 2.500 1.500 3200
1 1 0 44 0.5 0.8 0.6 2 1 5.000 2.000 1600
1 1 1 44 0.5 0.8 0.6 2 1 1.250 0.750 6400

IP weightの作成

各時点のウェイトを掛け合わせる。たしかにSWのほうがウェイトが狭い範囲にまとまっていることがわかる。

ipw_data3 <-
  ipw_data2 |> 
  mutate(
    # A0に対するウェイトとA1に対するウェイトをかけ算
    ipw = ipw_A0 * ipw_A1,
    sw = sw_A0 * sw_A1
  ) 

ipw_data3 |> 
  count(pick(everything())) |> 
  kable()
A0 L1 A1 Y pred_A0 pred_A1 pred_A1_without_L1 ipw_A0 sw_A0 ipw_A1 sw_A1 ipw sw n
0 0 0 84 0.5 0.4 0.7 2 1 1.667 0.500 3.333 0.500 2400
0 0 1 84 0.5 0.4 0.7 2 1 2.500 1.750 5.000 1.750 1600
0 1 0 52 0.5 0.8 0.7 2 1 5.000 1.500 10.000 1.500 2400
0 1 1 52 0.5 0.8 0.7 2 1 1.250 0.875 2.500 0.875 9600
1 0 0 76 0.5 0.4 0.6 2 1 1.667 0.667 3.333 0.667 4800
1 0 1 76 0.5 0.4 0.6 2 1 2.500 1.500 5.000 1.500 3200
1 1 0 44 0.5 0.8 0.6 2 1 5.000 2.000 10.000 2.000 1600
1 1 1 44 0.5 0.8 0.6 2 1 1.250 0.750 2.500 0.750 6400

Counterfactual meanの推定

すべてのパターンで同じ値なので、差分をとれば0になる

res_ipw <- 
  ipw_data3 |>
  # 重み付け推定
  summarise(
    CFmean_ipw = weighted.mean(Y, ipw), 
    CFmean_sw = weighted.mean(Y, sw), 
    .by = c(A0, A1),
  ) |>
  arrange(A0, A1)

res_ipw |> 
  kable()
A0 A1 CFmean_ipw CFmean_sw
0 0 60 60
0 1 60 60
1 0 60 60
1 1 60 60

Marginal structural modelの推定

以下のようなmarginal structural modelを考えてみる

\[ \mathrm{E}[Y^{A_0,A_1}] = \beta_0 + \beta_1 (A_0 + A_1) \]

当然ながら\(\beta_1 = 0\)となる

estimatr::lm_robust(Y ~ I(A0 + A1), data = ipw_data3, weight = ipw) |> 
  broom::tidy() |> 
  kable()
term estimate std.error statistic p.value conf.low conf.high df outcome
(Intercept) 60 0.154 389.52 0 59.698 60.302 31998 Y
I(A0 + A1) 0 0.122 0.00 1 -0.240 0.240 31998 Y
estimatr::lm_robust(Y ~ I(A0 + A1), data = ipw_data3, weights = sw) |> 
  broom::tidy() |> 
  kable()
term estimate std.error statistic p.value conf.low conf.high df outcome
(Intercept) 60 0.162 369.84 0 59.682 60.318 31998 Y
I(A0 + A1) 0 0.127 0.00 1 -0.250 0.250 31998 Y

G-formula

Sequential exchangeabilityが成立しているとき、以下の式が成り立つ

\[ \mathrm{E}[Y^{\bar{A}_k}] = \sum_\bar{l}\mathrm{E}[Y | \bar{A}_k = \bar{a}_k, \bar{L}_k = \bar{l}_k]\prod_{k = 0}^K f(l_k | \bar{a}_{k-1},\bar{l}_{k-1}) \]

一般式はイカツイが、2時点かつ\(L_0\)がない今回の条件だと以下のようになる

\[ \mathrm{E}[Y^{A_0,A_1}] = \sum_{l_1} \mathrm{E}[Y | A_0 = a_0,A_1 = a_1,L_1 = l_1]f(l_1 | a_0) \]

アウトカムと共変量の予測モデルを推定

  • \(\mathrm{E}[Y | A_0 = a_0,A_1 = a_1,L_1 = l_1]\)\(f(l_1 | a_0)\)のモデルを推定
  • 1時点の時と比較すると、共変量の予測モデルが必要になる点が新しい
# アウトカムモデル
model_Y <- lm(Y ~ A0*A1*L1, data = data_ind)
# L1の予測モデル
model_L1 <- lm(L1 ~ A0, data = data_ind)

共変量の予測値を推定

ひとまず、\(A_0 = A_1 = 0\)のときのpotential outcome\(\mathrm{E}[Y^{0,0}]\)を推定することを考える

まずサンプル全員の処置の値を\(A_0 = A_1 = 0\)で置き換える

gform_data1 <- 
  data_ind |> 
  # 名前を変えておく
  rename(A0_obs = A0, A1_obs = A1, L1_obs = L1, Y_obs = Y) |> 
  # 処置の値を変更
  mutate(A0 = 0, A1 = 0)

gform_data1 |> 
  count(pick(everything())) |> 
  kable()
A0_obs L1_obs A1_obs Y_obs A0 A1 n
0 0 0 84 0 0 2400
0 0 1 84 0 0 1600
0 1 0 52 0 0 2400
0 1 1 52 0 0 9600
1 0 0 76 0 0 4800
1 0 1 76 0 0 3200
1 1 0 44 0 0 1600
1 1 1 44 0 0 6400

次に、\(A_0 = 0\)の状況での\(L_1\)の値を予測する。これが\(f(l_1 | 0)\)

\(L^{A_0 = 0}\)の値をシミュレートしているとも考えられる

gform_data2 <- 
  gform_data1 |> 
  # 全員がA0 = 0のときの共変量L1の値をシミュレート
  mutate(L1 = predict(model_L1, newdata = gform_data1))

gform_data2 |> 
  count(pick(everything())) |> 
  kable()
A0_obs L1_obs A1_obs Y_obs A0 A1 L1 n
0 0 0 84 0 0 0.75 2400
0 0 1 84 0 0 0.75 1600
0 1 0 52 0 0 0.75 2400
0 1 1 52 0 0 0.75 9600
1 0 0 76 0 0 0.75 4800
1 0 1 76 0 0 0.75 3200
1 1 0 44 0 0 0.75 1600
1 1 1 44 0 0 0.75 6400

アウトカムの予測値を推定

さらに、シミュレートした\(L_1\)と、\(A_0 = A_1 = 0\)のもとでの\(Y\)の値を予測する。これが\(\mathrm{E}[Y | A_0 = 0,A_1 = 0,L_1 = l_1]\)

gform_data3 <- 
  gform_data2 |> 
  mutate(Y = predict(model_Y, newdata = gform_data2))

gform_data3 |> 
  count(pick(everything())) |> 
  kable()
A0_obs L1_obs A1_obs Y_obs A0 A1 L1 Y n
0 0 0 84 0 0 0.75 60 2400
0 0 1 84 0 0 0.75 60 1600
0 1 0 52 0 0 0.75 60 2400
0 1 1 52 0 0 0.75 60 9600
1 0 0 76 0 0 0.75 60 4800
1 0 1 76 0 0 0.75 60 3200
1 1 0 44 0 0 0.75 60 1600
1 1 1 44 0 0 0.75 60 6400

サンプル全体で平均する

以下のように変形できるので、アウトカムの予測値をサンプル全体で平均すればよい

\[ \sum_{l_1} \mathrm{E}[Y | A_0 = a_0,A_1 = a_1,L_1 = l_1]f(l_1 | a_0) = \mathrm{E}[\mathrm{E}[Y | A_0 = a_0, A_1 = a_1, L_1]] \]

gform_data3 |> 
  summarise(CFmean = mean(Y), .by = c(A0, A1)) |> 
  kable()
A0 A1 CFmean
0 0 60

すべてのパターンを推定

# 処置のパターンごとにデータを4つ作成
list(
  data_ind |> mutate(A0 = 0, A1 = 0),
  data_ind |> mutate(A0 = 0, A1 = 1),
  data_ind |> mutate(A0 = 1, A1 = 0),
  data_ind |> mutate(A0 = 1, A1 = 1)
) |>
  # 各パターンにおいてL1の値を予測
  map(
    \(data) data |> mutate(L1 = predict(model_L1, newdata = data))
  ) |>
  # 予測したL1の値とトリートメントを用いて、Potential Outcomeを予測
  map(
    \(data) data |> mutate(Y_pred = predict(model_Y, newdata = data))
  ) |> 
  bind_rows() |> 
  summarise(
    CFmean = mean(Y_pred),
    .by = c(A0, A1)
  ) |> 
  kable()
A0 A1 CFmean
0 0 60
0 1 60
1 0 60
1 1 60

共変量が増えるとそれだけシミュレートする変数が増えるので大変そう…

疑問点&残された課題

  • Marginal structural modelのパラメータはg-formulaでは推定できないのか?
    • L-TMLEの推定ができるパッケージではMSMのパラメータも推定できる
  • G-formulaとRegression with residualsの関係が気になった
    • どちらも共変量のモデリングを行うが、g-fomulaは予測値を用いるのに対して、RWRは残差の方を用いる
  • 本当は今回の課題でTMLEまで実装したかった
    • 理解が追いつかず断念

References

Hernán, Miguel A., and James M. Robins. 2020. Causal Inference: What If. Boca Raton: Chapman & Hall/CRC. https://miguelhernan.org/whatifbook.