文献
Sudharsanan, Nikkil & Maarten J. Bijlsma, 2021, “Educational Note: Causal Decomposition of Population Health Differences Using Monte Carlo Integration and the G-Formula,” International Journal of Epidemiology , 50(6): 2098–2107 (10.1093/ije/dyab090 ).
下準備
library (tidyverse)
library (cfdecomp)
kable <- partial (
knitr:: kable,
digits = 3
)
# the decomposition functions in our package are computationally intensive
# to make the example run quick, I perform it on a subsample (n=125) of the data:
set.seed (100 )
data <-
cfd.example.data[sample (1000 ),] |>
select (SES, age, med.gauss, out.gauss, id) |>
as_tibble ()
パッケージによる推定
# cfd.mean
mean.results.1 <-
cfd.mean (
formula.y = out.gauss ~ SES * med.gauss * age,
formula.m = med.gauss ~ SES * age,
mediator = 'med.gauss' ,
group = 'SES' ,
data = as.data.frame (data),
family.y = 'gaussian' ,
family.m = 'gaussian' ,
bs.size= 250 ,
mc.size= 10 ,
alpha= 0.05 ,
# cluster.sample=FALSE,
# cluster.name='id'
)
SES2とSES3のmediatorの分布をSES1のmediatorの分布に揃える
tibble (
category = c ("SES1" , "SES2" , "SES3" ),
factual_mean = c (
mean (mean.results.1 $ out_nc_y[, 1 ]),
mean (mean.results.1 $ out_nc_y[, 2 ]),
mean (mean.results.1 $ out_nc_y[, 3 ])
),
# and after giving the gaussian mediator of SES group 2 the distribution of the one in group 1
# the difference becomes:
counterfactual_mean = c (
mean (mean.results.1 $ out_cf_y[, 1 ]),
mean (mean.results.1 $ out_cf_y[, 2 ]),
mean (mean.results.1 $ out_cf_y[, 3 ])
)) |>
kable ()
SES1
4.306
4.306
SES2
3.239
3.511
SES3
2.212
2.946
step 1: regression estimates
mediatorとoutcomeのモデルをデータから推定
mediator_model <- lm (med.gauss ~ SES * age, data = data)
outcome_model <- lm (out.gauss ~ SES * age * med.gauss, data = data)
step 2: simulate the natural-course pseudo-population
推定したmediatorのモデルから、mediatorの「分布」を再現
\[\begin{align*}
Med_i = \mathrm{E}[Med | X] + e_i \\
e_i \sim \mathrm{N}(0, \sigma)
\end{align*}\]
# predict mediator
# mediatorの「分布」のパラメータを取得
pred_mean_m <- predict (mediator_model, newdata = data, type = "response" )
residual_ref_m <- mediator_model$ residuals
sd_ref_m <- sd (residual_ref_m)
推定したパラメータをもとに、mediatorの値をシミュレート
df_nc_med <-
data |>
mutate (
# ランダム性なし
pred_med = pred_mean_m,
# ランダム性をもたせる1(推定した標準偏差のパラメータを使用)
pred_med_draw_1 = rnorm (n (), mean = pred_mean_m, sd = sd_ref_m),
# ランダム性をもたせる2(残差からランダムにサンプリング)
pred_med_draw_2 = pred_mean_m + sample (residual_ref_m, n (), replace = TRUE )
)
df_nc_med |>
summarise (
across (c (med.gauss, pred_med: pred_med_draw_2), mean),
.by = SES
) |>
arrange (SES) |>
kable ()
1
8.428
8.428
8.418
8.286
2
7.223
7.223
7.276
7.454
3
5.410
5.410
5.579
5.374
シミュレートした値をoutcomeモデルに代入して予測値を計算・集計
df_nc_med |>
mutate (
# ランダム性なし
pred_out = predict (
outcome_model, newdata = df_nc_med |> mutate (med.gauss = pred_med)
),
# ランダム性をもたせる1(推定した標準偏差のパラメータを使用)
pred_out_draw_1 = predict (
outcome_model, newdata = df_nc_med |> mutate (med.gauss = pred_med_draw_1)
),
# ランダム性をもたせる2(残差からランダムにサンプリング)
pred_out_draw_2 = predict (
outcome_model, newdata = df_nc_med |> mutate (med.gauss = pred_med_draw_2)
)
) |>
summarise (
across (c (out.gauss, pred_out: pred_out_draw_2), mean),
.by = SES
) |>
arrange (SES) |>
kable ()
1
4.309
4.309
4.312
4.286
2
3.237
3.237
3.249
3.292
3
2.217
2.217
2.254
2.214
step 3: simulate the counterfactual pseudo-population
推定したmediatorのモデルにおいて、全員のSESが1だった場合のmediatorの分布を再現
全員のSESを1にしてmediatorのパラメータを取得
回帰モデルでは残差の部分は共変量に依存しない(SES間で分布が同じ、平均0・共通の標準偏差の正規分布)
ならばSESが1のグループの標準偏差を使わなくても良いのでは?(全体の標準偏差でもよい)
標準偏差もグループによって異なる、といったモデルの場合にはどうなるか?
# 平均
pred_mean_m_SES1 <- predict (mediator_model, newdata = data |> mutate (SES = '1' ))
# SES = 1のグループの残差
residual_ref_m_SES1 <-
broom:: augment (mediator_model) |>
filter (SES == '1' ) |>
pull (.resid)
# 標準偏差
sd_ref_m_SES1 <- sd (residual_ref_m_SES1)
推定したパラメータをもとに、mediatorの値をシミュレート
df_cf_med <-
data |>
mutate (
# ランダム性なし
pred_med_SES1 = pred_mean_m_SES1,
# ランダム性をもたせる1(推定した標準偏差のパラメータを使用)
pred_med_draw_1_SES1 = rnorm (n (), mean = pred_mean_m_SES1, sd = sd_ref_m_SES1),
# ランダム性をもたせる2(残差からランダムにサンプリング)
pred_med_draw_2_SES1 = pred_mean_m_SES1 + sample (residual_ref_m_SES1, n (), replace = TRUE )
)
df_cf_med |>
summarise (
across (c (med.gauss, pred_med_SES1: pred_med_draw_2_SES1), mean),
.by = SES
) |>
arrange (SES) |>
kable ()
1
8.428
8.428
8.386
8.632
2
7.223
8.337
8.258
8.292
3
5.410
8.442
8.292
8.593
シミュレートした値をoutcomeモデルに代入して予測値を計算・集計
df_cf_med |>
mutate (
# ランダム性なし
pred_out_SES1 = predict (
outcome_model, newdata = df_cf_med |> mutate (med.gauss = pred_med_SES1)
),
# ランダム性をもたせる1(推定した標準偏差のパラメータを使用)
pred_out_draw_1_SES1 = predict (
outcome_model, newdata = df_cf_med |> mutate (med.gauss = pred_med_draw_1_SES1)
),
# ランダム性をもたせる2(残差からランダムにサンプリング)
pred_out_draw_2_SES1 = predict (
outcome_model, newdata = df_cf_med |> mutate (med.gauss = pred_med_draw_2_SES1)
)
) |>
summarise (
across (c (out.gauss, pred_out_SES1: pred_out_draw_2_SES1), mean),
.by = SES
) |>
arrange (SES) |>
kable ()
1
4.309
4.309
4.303
4.347
2
3.237
3.507
3.485
3.498
3
2.217
2.949
2.920
2.979
monte carloとbootstrapの実装
実際にはrandom drawは一回ではなく何回か行うことで不確実性を表現する
標準誤差の推定のためにbootstrap法も必要
まずはtreatmentとoutcomeのモデルを推定し、パラメータを取得
# パラメータ推定
estimate_model <- function (data) {
mediator_model <- lm (med.gauss ~ SES * age, data = data)
outcome_model <- lm (out.gauss ~ SES * age * med.gauss, data = data)
pred_mean_m <- predict (mediator_model, newdata = data, type = "response" )
residual_ref_m <- mediator_model$ residuals
sd_ref_m <- sd (residual_ref_m)
pred_mean_m_SES1 <- predict (mediator_model, newdata = data |> mutate (SES = '1' ))
residual_ref_m_SES1 <-
broom:: augment (mediator_model) |>
filter (SES == '1' ) |>
pull (.resid)
sd_ref_m_SES1 <- sd (residual_ref_m_SES1)
}
パラメータをもとにmediatorをシミュレートするのを何回か繰り返す
montecarlo_sampling <- function (data, mc = 10 ) {
# パラメータ推定
estimate_model (data)
# モンテカルロシミュレーション
map (1 : mc, \(mc) {
# mediatorサンプリング
boot_sample <-
data |>
mutate (
pred_med_draw_1 = rnorm (n (), mean = pred_mean_m, sd = sd_ref_m),
pred_med_draw_2 = pred_mean_m + sample (residual_ref_m, n (), replace = TRUE ),
pred_med_draw_1_SES1 = rnorm (n (), mean = pred_mean_m_SES1, sd = sd_ref_m_SES1),
pred_med_draw_2_SES1 = pred_mean_m_SES1 + sample (residual_ref_m_SES1, n (), replace = TRUE ),
)
# サンプリングしたものからoutcome予測
boot_sample |>
mutate (
pred_out_draw_1 = predict (
outcome_model, newdata = boot_sample |> mutate (med.gauss = pred_med_draw_1)
),
pred_out_draw_2 = predict (
outcome_model, newdata = boot_sample |> mutate (med.gauss = pred_med_draw_2)
),
pred_out_draw_1_SES1 = predict (
outcome_model, newdata = boot_sample |> mutate (med.gauss = pred_med_draw_1_SES1)
),
pred_out_draw_2_SES1 = predict (
outcome_model, newdata = boot_sample |> mutate (med.gauss = pred_med_draw_2_SES1)
)
) |>
group_by (SES) |>
summarise (across (c (pred_out_draw_1: pred_out_draw_2_SES1), mean))
}) |>
list_rbind (names_to = 'mc' ) |>
# シミュレーション結果を集計
group_by (SES) |>
summarise (across (c (pred_out_draw_1: pred_out_draw_2_SES1), mean))
}
result <-
map (1 : 250 , \(index) {
# bootstrapサンプル発生
bootsample <- slice_sample (data, prop = 1 , replace = TRUE )
montecarlo_sampling (bootsample, mc = 10 )
}) |>
list_rbind (names_to = 'index' )
result |>
pivot_longer (
cols = c (pred_out_draw_1: pred_out_draw_2_SES1),
names_to = 'type' ,
values_to = 'value'
) |>
summarise (
mean = mean (value),
conf.low = quantile (value, 0.025 ),
conf.high = quantile (value, 0.975 ),
.by = c (SES, type)
) |>
arrange (type) |>
kable ()
1
pred_out_draw_1
4.015
3.937
4.091
2
pred_out_draw_1
3.183
3.124
3.251
3
pred_out_draw_1
2.645
2.599
2.688
1
pred_out_draw_1_SES1
4.276
4.193
4.351
2
pred_out_draw_1_SES1
3.513
3.446
3.582
3
pred_out_draw_1_SES1
2.970
2.934
3.000
1
pred_out_draw_2
4.015
3.938
4.094
2
pred_out_draw_2
3.182
3.123
3.253
3
pred_out_draw_2
2.646
2.600
2.687
1
pred_out_draw_2_SES1
4.276
4.183
4.358
2
pred_out_draw_2_SES1
3.515
3.457
3.584
3
pred_out_draw_2_SES1
2.970
2.941
2.998