library(bartcs)
The bartcs package finds confounders and estimates treatment effects with Bayesian Additive Regression Trees (BART).
This tutorial uses the Infant Health and Development Program (IHDP) dataset. The dataset includes 6 continuous and 19 binary covariates with a simulated outcome, a cognitive test score. It was first used by Hill (2011). My version of the dataset is the first realization generated by Louizos et al. (2017); you can find other versions in their GitHub repository.
data(ihdp, package = "bartcs")
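Before fitting, it is worth a quick look at the data to confirm the layout described above (the covariates used below sit in columns 6 to 30):

dim(ihdp)              # number of rows and columns
names(ihdp)[1:5]       # treatment and outcome columns precede the covariates
table(ihdp$treatment)  # treated vs. control counts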
fit <- mbart(
  Y = ihdp$y_factual,
  trt = ihdp$treatment,
  X = ihdp[, 6:30],
  num_tree = 10,
  num_chain = 4,
  num_post_sample = 100,
  num_burn_in = 100,
  verbose = FALSE
)
fit
#> `bartcs` fit by `mbart()`
#>
#> mean 2.5% 97.5%
#> ATE 3.971219 3.772492 4.191938
#> Y1 6.384208 6.219867 6.576574
#> Y0 2.412989 2.318502 2.515517
You can get the mean and 95% credible interval of the average treatment effect (ATE) and the potential outcomes Y1 and Y0.
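If you want the raw posterior draws rather than the printed summary, you can extract them yourself. This is a sketch assuming the fit stores per-chain draws as a coda mcmc.list in fit$mcmc_list; inspect str(fit) if your bartcs version differs.

draws <- as.matrix(fit$mcmc_list)                # stack all chains (assumes fit$mcmc_list exists)
quantile(draws[, "ATE"], c(0.025, 0.5, 0.975))   # should match the printed interval
mean(draws[, "Y1"] - draws[, "Y0"])              # ATE recovered as E[Y1 - Y0]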
Both sbart() and mbart() fit multiple MCMC chains. summary() provides the results and the Gelman-Rubin statistic to check convergence.
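For reference, the analogous single-model fit mirrors the mbart() call above (a sketch, assuming sbart() accepts the same arguments):

fit_s <- sbart(
  Y = ihdp$y_factual,
  trt = ihdp$treatment,
  X = ihdp[, 6:30],
  num_tree = 10,
  num_chain = 4,
  num_post_sample = 100,
  num_burn_in = 100,
  verbose = FALSE
)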
summary(fit)
#> `bartcs` fit by `mbart()`
#>
#> Treatment Value
#> Treated group : 1
#> Control group : 0
#>
#> Tree Parameters
#> Number of Tree : 10 Value of alpha : 0.95
#> Prob. of Grow : 0.28 Value of beta : 2
#> Prob. of Prune : 0.28 Value of nu : 3
#> Prob. of Change : 0.44 Value of q : 0.95
#>
#> Chain Parameters
#> Number of Chains : 4 Number of burn-in : 100
#> Number of Iter : 200 Number of thinning : 0
#> Number of Sample : 100
#>
#> Outcome Diagnostics
#> Gelman-Rubin : 0.9954304
#>
#> Outcome
#> estimand chain 2.5% 1Q mean median 3Q 97.5%
#> ATE 1 3.766454 3.871160 3.943341 3.941979 4.006615 4.133084
#> ATE 2 3.757339 3.896470 3.967855 3.985382 4.044467 4.175748
#> ATE 3 3.807881 3.918242 3.986488 3.988572 4.048306 4.189810
#> ATE 4 3.807688 3.905469 3.987192 3.987528 4.060553 4.218099
#> ATE agg 3.772492 3.900104 3.971219 3.972040 4.043685 4.191938
#> Y1 1 6.240561 6.334287 6.389011 6.393528 6.430904 6.560700
#> Y1 2 6.215269 6.341583 6.399921 6.402613 6.456707 6.586694
#> Y1 3 6.221501 6.309304 6.368682 6.366694 6.418269 6.547575
#> Y1 4 6.223379 6.313169 6.379219 6.363937 6.449049 6.569930
#> Y1 agg 6.219867 6.320230 6.384208 6.380211 6.443003 6.576574
#> Y0 1 2.362361 2.414919 2.445670 2.444196 2.478593 2.538175
#> Y0 2 2.359850 2.404359 2.432066 2.431125 2.456632 2.515024
#> Y0 3 2.308972 2.355458 2.382193 2.382278 2.414769 2.444218
#> Y0 4 2.309223 2.356522 2.392027 2.392113 2.425868 2.474419
#> Y0 agg 2.318502 2.377731 2.412989 2.413843 2.447588 2.515517
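The Gelman-Rubin statistic reported above can also be computed directly with the coda package (again a sketch assuming the per-chain draws live in fit$mcmc_list); values near 1 indicate convergence:

coda::gelman.diag(fit$mcmc_list, multivariate = FALSE)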
You can get the posterior inclusion probability (PIP) for each variable.
plot(fit, method = "pip")
Since inclusion_plot() is a wrapper around ggcharts::bar_chart(), you can use its arguments for a better plot.
plot(fit, method = "pip", top_n = 10)
plot(fit, method = "pip", threshold = 0.5)
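Other bar_chart() arguments should pass through the same way; for example, highlight emphasizes a single bar ("X1" is a hypothetical covariate name here, so substitute one from your data):

plot(fit, method = "pip", top_n = 10, highlight = "X1")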
With trace_plot(), you can visually check the trace of effects or other parameters.
plot(fit, method = "trace")
plot(fit, method = "trace", "dir_alpha")
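The parameter names accepted by the trace plot depend on the model. Assuming the same fit$mcmc_list structure as above, you can list what is available:

colnames(fit$mcmc_list[[1]])  # e.g. ATE, Y1, Y0, dir_alpha, ...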
count_omp_thread()
#> [1] 6
count_omp_thread() checks whether OpenMP is supported; you need more than one thread for multi-threading. Due to the overhead of multi-threading, parallelization will not be effective with small and moderate datasets. I recommend parallelization for datasets with at least 10,000 rows.
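Putting those rules together, here is a small illustrative helper (not part of bartcs) for deciding the parallel flag:

# Parallelize only when OpenMP provides multiple threads and the data
# are large enough (at least 10,000 rows) for the overhead to pay off.
should_parallelize <- function(n_rows) {
  count_omp_thread() > 1 && n_rows >= 1e4
}
should_parallelize(nrow(ihdp))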
For comparison purposes, I will create a dataset with 20,000 rows by bootstrapping from the IHDP dataset. Then, for fast computation, I will set most parameters to 1.
idx  <- sample(nrow(ihdp), 2e4, TRUE)
ihdp <- ihdp[idx, ]
microbenchmark::microbenchmark(
  simple_mbart = mbart(
    Y = ihdp$y_factual,
    trt = ihdp$treatment,
    X = ihdp[, 6:30],
    num_tree = 1,
    num_chain = 1,
    num_post_sample = 10,
    num_burn_in = 0,
    verbose = FALSE,
    parallel = FALSE
  ),
  parallel_mbart = mbart(
    Y = ihdp$y_factual,
    trt = ihdp$treatment,
    X = ihdp[, 6:30],
    num_tree = 1,
    num_chain = 1,
    num_post_sample = 10,
    num_burn_in = 0,
    verbose = FALSE,
    parallel = TRUE
  ),
  times = 50
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> simple_mbart 59.8103 68.0068 72.41204 72.5930 76.8579 86.1473 50
#> parallel_mbart 56.5558 61.3377 68.89435 64.3631 68.1698 169.9934 50
The results show that parallelization gives better performance on a dataset of this size.