Sudharsanan & Bijlsma (2021)

Causal Inference
Author

Kentaro Kamada

Published

February 22, 2024

文献

Sudharsanan, Nikkil & Maarten J. Bijlsma, 2021, “Educational Note: Causal Decomposition of Population Health Differences Using Monte Carlo Integration and the G-Formula,” International Journal of Epidemiology, 50(6): 2098–2107 (10.1093/ije/dyab090).

下準備

  • パッケージ
library(tidyverse)
library(cfdecomp)

kable <- partial(
  knitr::kable,
  digits = 3
)
  • データ
# the decomposition functions in our package are computationally intensive
# to make the example run quick, I perform it on a subsample (n=125) of the data:
set.seed(100)
data <- 
  cfd.example.data[sample(1000),] |> 
  select(SES, age, med.gauss, out.gauss, id) |> 
  as_tibble()

パッケージによる推定

  • cfdecompパッケージで推定
# cfd.mean 
mean.results.1 <- 
  cfd.mean(
    formula.y = out.gauss ~ SES * med.gauss * age,
    formula.m = med.gauss ~ SES * age,
    mediator = 'med.gauss',
    group = 'SES',
    data = as.data.frame(data),
    family.y = 'gaussian',
    family.m = 'gaussian',
    bs.size=250,
    mc.size=10,
    alpha=0.05,
    # cluster.sample=FALSE,
    # cluster.name='id'
  )
  • SES2とSES3のmediatorの分布をSES1のmediatorの分布に揃える
tibble(
  category = c("SES1", "SES2", "SES3"),
  factual_mean = c(
    mean(mean.results.1$out_nc_y[, 1]),
    mean(mean.results.1$out_nc_y[, 2]),
    mean(mean.results.1$out_nc_y[, 3])
  ),
  # and after giving the gaussian mediator of SES group 2 the distribution of the one in group 1
  # the difference becomes:
  counterfactual_mean = c(
    mean(mean.results.1$out_cf_y[, 1]),
    mean(mean.results.1$out_cf_y[, 2]),
    mean(mean.results.1$out_cf_y[, 3])
  )) |> 
  kable()
category factual_mean counterfactual_mean
SES1 4.306 4.306
SES2 3.239 3.511
SES3 2.212 2.946

自分でコードを書いてみる

step 1: regression estimates

  • mediatorとoutcomeのモデルをデータから推定
mediator_model <- lm(med.gauss ~ SES * age, data = data)
outcome_model <- lm(out.gauss ~ SES * age * med.gauss, data = data)

step 2: simulate the natural-course pseudo-population

  • 推定したmediatorのモデルから、mediatorの「分布」を再現

\[\begin{align*} Med_i = \mathrm{E}[Med | X] + e_i \\ e_i \sim \mathrm{N}(0, \sigma) \end{align*}\]

  • mediatorの分布のパラメータを取得
# predict mediator
# mediatorの「分布」のパラメータを取得
pred_mean_m <- predict(mediator_model, newdata = data, type = "response")
residual_ref_m <- mediator_model$residuals
sd_ref_m <- sd(residual_ref_m)
  • 推定したパラメータをもとに、mediatorの値をシミュレート
df_nc_med <- 
  data |> 
  mutate(
    # ランダム性なし
    pred_med = pred_mean_m,
    # ランダム性をもたせる1(推定した標準偏差のパラメータを使用)
    pred_med_draw_1 = rnorm(n(), mean = pred_mean_m, sd = sd_ref_m),
    # ランダム性をもたせる2(残差からランダムにサンプリング)
    pred_med_draw_2 = pred_mean_m + sample(residual_ref_m, n(), replace = TRUE)
  )

df_nc_med |> 
  summarise(
    across(c(med.gauss, pred_med:pred_med_draw_2), mean),
    .by = SES
  ) |> 
  arrange(SES) |> 
  kable()
SES med.gauss pred_med pred_med_draw_1 pred_med_draw_2
1 8.428 8.428 8.418 8.286
2 7.223 7.223 7.276 7.454
3 5.410 5.410 5.579 5.374
  • シミュレートした値をoutcomeモデルに代入して予測値を計算・集計
df_nc_med |> 
  mutate(
    # ランダム性なし
    pred_out = predict(
      outcome_model, newdata = df_nc_med |> mutate(med.gauss = pred_med)
    ),
    # ランダム性をもたせる1(推定した標準偏差のパラメータを使用)
    pred_out_draw_1 = predict(
      outcome_model, newdata = df_nc_med |> mutate(med.gauss = pred_med_draw_1)
    ),
    # ランダム性をもたせる2(残差からランダムにサンプリング)
    pred_out_draw_2 = predict(
      outcome_model, newdata = df_nc_med |> mutate(med.gauss = pred_med_draw_2)
    )
  ) |> 
  summarise(
    across(c(out.gauss, pred_out:pred_out_draw_2), mean),
    .by = SES
  ) |> 
  arrange(SES) |> 
  kable()
SES out.gauss pred_out pred_out_draw_1 pred_out_draw_2
1 4.309 4.309 4.312 4.286
2 3.237 3.237 3.249 3.292
3 2.217 2.217 2.254 2.214

step 3: simulate the counterfactual pseudo-population

  • 推定したmediatorのモデルにおいて、全員のSESが1だった場合のmediatorの分布を再現
  • 全員のSESを1にしてmediatorのパラメータを取得
  • 回帰モデルでは残差の部分は共変量に依存しない(SES間で分布が同じ、平均0・共通の標準偏差の正規分布)
    • ならばSESが1のグループの標準偏差を使わなくても良いのでは?(全体の標準偏差でもよい)
    • 標準偏差もグループによって異なる、といったモデルの場合にはどうなるか?
# 平均
pred_mean_m_SES1 <- predict(mediator_model, newdata = data |> mutate(SES = '1'))
# SES = 1のグループの残差
residual_ref_m_SES1 <-
  broom::augment(mediator_model)  |>
  filter(SES == '1') |> 
  pull(.resid)
# 標準偏差
sd_ref_m_SES1 <- sd(residual_ref_m_SES1)
  • 推定したパラメータをもとに、mediatorの値をシミュレート
df_cf_med <- 
  data |> 
  mutate(
    # ランダム性なし
    pred_med_SES1 = pred_mean_m_SES1,
    # ランダム性をもたせる1(推定した標準偏差のパラメータを使用)
    pred_med_draw_1_SES1 = rnorm(n(), mean = pred_mean_m_SES1, sd = sd_ref_m_SES1),
    # ランダム性をもたせる2(残差からランダムにサンプリング)
    pred_med_draw_2_SES1 = pred_mean_m_SES1 + sample(residual_ref_m_SES1, n(), replace = TRUE)
  )

df_cf_med |> 
  summarise(
    across(c(med.gauss, pred_med_SES1:pred_med_draw_2_SES1), mean),
    .by = SES
  ) |> 
  arrange(SES) |> 
  kable()
SES med.gauss pred_med_SES1 pred_med_draw_1_SES1 pred_med_draw_2_SES1
1 8.428 8.428 8.386 8.632
2 7.223 8.337 8.258 8.292
3 5.410 8.442 8.292 8.593
  • シミュレートした値をoutcomeモデルに代入して予測値を計算・集計
df_cf_med |> 
  mutate(
    # ランダム性なし
    pred_out_SES1 = predict(
      outcome_model, newdata = df_cf_med |> mutate(med.gauss = pred_med_SES1)
    ),
    # ランダム性をもたせる1(推定した標準偏差のパラメータを使用)
    pred_out_draw_1_SES1 = predict(
      outcome_model, newdata = df_cf_med |> mutate(med.gauss = pred_med_draw_1_SES1)
    ),
    # ランダム性をもたせる2(残差からランダムにサンプリング)
    pred_out_draw_2_SES1 = predict(
      outcome_model, newdata = df_cf_med |> mutate(med.gauss = pred_med_draw_2_SES1)
    )
  ) |> 
  summarise(
    across(c(out.gauss, pred_out_SES1:pred_out_draw_2_SES1), mean),
    .by = SES
  ) |> 
  arrange(SES) |> 
  kable()
SES out.gauss pred_out_SES1 pred_out_draw_1_SES1 pred_out_draw_2_SES1
1 4.309 4.309 4.303 4.347
2 3.237 3.507 3.485 3.498
3 2.217 2.949 2.920 2.979

monte carloとbootstrapの実装

  • 実際にはrandom drawは一回ではなく何回か行うことで不確実性を表現する
  • 標準誤差の推定のためにbootstrap法も必要
  • まずはtreatmentとoutcomeのモデルを推定し、パラメータを取得
# パラメータ推定
estimate_model <- function(data) {
  
  mediator_model <- lm(med.gauss ~ SES * age, data = data)
  outcome_model <- lm(out.gauss ~ SES * age * med.gauss, data = data)
  
  pred_mean_m <- predict(mediator_model, newdata = data, type = "response")
  residual_ref_m <- mediator_model$residuals
  sd_ref_m <- sd(residual_ref_m)

  pred_mean_m_SES1 <- predict(mediator_model, newdata = data |> mutate(SES = '1'))
  residual_ref_m_SES1 <-
    broom::augment(mediator_model)  |>
    filter(SES == '1') |> 
    pull(.resid)
  sd_ref_m_SES1 <- sd(residual_ref_m_SES1)

}
  • パラメータをもとにmediatorをシミュレートするのを何回か繰り返す
montecarlo_sampling <- function(data, mc = 10) {
  # パラメータ推定
  estimate_model(data)
  # モンテカルロシミュレーション
  map(1:mc, \(mc) {
    # mediatorサンプリング
    boot_sample <- 
      data |> 
      mutate(
        pred_med_draw_1 = rnorm(n(), mean = pred_mean_m, sd = sd_ref_m),
        pred_med_draw_2 = pred_mean_m + sample(residual_ref_m, n(), replace = TRUE),
        pred_med_draw_1_SES1 = rnorm(n(), mean = pred_mean_m_SES1, sd = sd_ref_m_SES1),
        pred_med_draw_2_SES1 = pred_mean_m_SES1 + sample(residual_ref_m_SES1, n(), replace = TRUE),
      )
    # サンプリングしたものからoutcome予測
    boot_sample |> 
      mutate(
        pred_out_draw_1 = predict(
          outcome_model, newdata = boot_sample |> mutate(med.gauss = pred_med_draw_1)
        ),
        pred_out_draw_2 = predict(
          outcome_model, newdata = boot_sample |> mutate(med.gauss = pred_med_draw_2)
        ),
        pred_out_draw_1_SES1 = predict(
          outcome_model, newdata = boot_sample |> mutate(med.gauss = pred_med_draw_1_SES1)
        ),
        pred_out_draw_2_SES1 = predict(
          outcome_model, newdata = boot_sample |> mutate(med.gauss = pred_med_draw_2_SES1)
        )
      ) |> 
      group_by(SES) |> 
      summarise(across(c(pred_out_draw_1:pred_out_draw_2_SES1), mean))
    
  }) |> 
    list_rbind(names_to = 'mc') |> 
    # シミュレーション結果を集計
    group_by(SES) |> 
    summarise(across(c(pred_out_draw_1:pred_out_draw_2_SES1), mean))
  
}
  • これをbootstrapで繰り返す
result <- 
  map(1:250, \(index) {
    # bootstrapサンプル発生
    bootsample <- slice_sample(data, prop = 1, replace = TRUE)
    montecarlo_sampling(bootsample, mc = 10) 
  }) |> 
  list_rbind(names_to = 'index')
  • 結果を集計
result |> 
  pivot_longer(
    cols = c(pred_out_draw_1:pred_out_draw_2_SES1), 
    names_to = 'type', 
    values_to = 'value'
  ) |> 
  summarise(
    mean = mean(value), 
    conf.low = quantile(value, 0.025),
    conf.high = quantile(value, 0.975),
    .by = c(SES, type)
  ) |> 
  arrange(type) |> 
  kable()
SES type mean conf.low conf.high
1 pred_out_draw_1 4.015 3.937 4.091
2 pred_out_draw_1 3.183 3.124 3.251
3 pred_out_draw_1 2.645 2.599 2.688
1 pred_out_draw_1_SES1 4.276 4.193 4.351
2 pred_out_draw_1_SES1 3.513 3.446 3.582
3 pred_out_draw_1_SES1 2.970 2.934 3.000
1 pred_out_draw_2 4.015 3.938 4.094
2 pred_out_draw_2 3.182 3.123 3.253
3 pred_out_draw_2 2.646 2.600 2.687
1 pred_out_draw_2_SES1 4.276 4.183 4.358
2 pred_out_draw_2_SES1 3.515 3.457 3.584
3 pred_out_draw_2_SES1 2.970 2.941 2.998