The causalOT package was developed to reproduce the methods in "Optimal transport methods for causal inference." The functions in the package are built to construct weights that make distributions more similar and to estimate causal effects. We recommend the Causal Optimal Transport methods since they are semi- to non-parametric. This document describes some simple uses of the functions in the package and should be enough to get users started.
The weights can be estimated by using the calc_weight
function in the package. We select optimal hyperparameters through our
bootstrap-based algorithm and target the average treatment effect.
library(causalOT)
set.seed(1111)
hainmueller <- Hainmueller$new(n = 128)
hainmueller$gen_data()

weights <- calc_weight(data = hainmueller, method = "Wasserstein",
                       add.divergence = TRUE, grid.search = TRUE,
                       estimand = "ATE",
                       verbose = FALSE)
These weights will balance the covariate distributions between the treated and control groups, which in turn removes confounding from estimates of the treatment effect. We can then estimate effects with
tau_hat <- estimate_effect(data = hainmueller, weights = weights,
                           hajek = TRUE, doubly.robust = FALSE,
                           estimand = "ATE")
This creates an object of class causalEffect, which can be fed into the native R function confint to calculate asymptotic confidence intervals.
ci_tau <- confint(object = tau_hat, level = 0.95,
                  method = "asymptotic",
                  model = "lm",
                  formula = list(control = "y ~ .", treated = "y ~ ."))
This then gives the following estimate and confidence interval:
print(tau_hat$estimate)
#> [1] 0.2574432
print(ci_tau$CI)
#> [1] -0.08814654 0.60303297
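For intuition, the Hájek estimator selected by hajek = TRUE is simply a difference of normalized weighted means. A minimal sketch of that calculation, using hypothetical stand-in vectors y (outcome), z (treatment indicator), and the group weights w0 and w1 (not the package's internal code):

# Hajek (normalized IPW) estimate of the ATE -- a sketch;
# y, z, w0, w1 are hypothetical stand-ins aligned by treatment group
hajek_ate <- function(y, z, w0, w1) {
  sum(w1 * y[z == 1]) / sum(w1) - sum(w0 * y[z == 0]) / sum(w0)
}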
Diagnostics are also an important part of deciding whether the weights perform well. There are several areas that we will explore: the effective sample size, Pareto smoothed importance sampling diagnostics, mean balance, and distributional balance.
Typically, effective sample sizes for weighted data are calculated as \(1/\sum_i w_i^2\), assuming the weights in each group sum to one. This gives us a measure of how much information is in the sample: the lower the effective sample size (ESS), the higher the variance of our estimates, and the more the weights concentrate on a few individuals. Of course, we can calculate this in causalOT!
ESS(weights)
#> Control Treated
#> 36.44976 31.14867
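As a sanity check, the same quantity is easy to compute by hand. A small sketch, assuming the slots weights$w0 and weights$w1 (used in the sinkhorn calls below) hold the normalized weights for each group:

# hand-rolled effective sample size for weights that sum to one
ess_by_hand <- function(w) 1 / sum(w^2)
ess_by_hand(weights$w0)  # should match the Control value above
ess_by_hand(weights$w1)  # should match the Treated value above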
Of course, this measure has problems because it can fail to diagnose issues with highly variable weights. In response, Vehtari et al. use Pareto smoothed importance sampling (PSIS). We offer some shell code to adapt the class causalWeights to the loo package:
raw_psis <- PSIS(weights)
This will also return the Pareto smoothed weights and log weights. If we want to examine the PSIS diagnostics easily, we can pull those out too:
PSIS_diag(raw_psis)
#> $w0
#> $w0$pareto_k
#> [1] 0.2173782
#>
#> $w0$n_eff
#> [1] 35.64768
#>
#>
#> $w1
#> $w1$pareto_k
#> [1] 0.3378445
#>
#> $w1$n_eff
#> [1] 30.33031
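If you prefer to work with the loo package directly, a similar Pareto-k diagnostic can be obtained from the raw weights. A sketch under the assumption that weights$w0 contains strictly positive control weights; loo::psis expects log importance ratios:

library(loo)
# PSIS on the log of the raw control weights; r_eff = NA skips the
# relative-efficiency correction
psis_w0 <- psis(log(weights$w0), r_eff = NA)
pareto_k_values(psis_w0)  # compare with $w0$pareto_k above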
We can see all of the \(k\) values are below the recommended threshold of 0.5, indicating the weights have finite variance and that the central limit theorem holds. Note that these estimated sample sizes are a bit lower than those from the ESS method above.
Many authors consider the standardized absolute mean balance as a marker of covariate balance: see Stuart (2010). That is \[ \frac{|\overline{X}_c - \overline{X}_t| }{\sigma_{\text{pool}}},\] where \(\overline{X}_c\) is the mean in the controls, \(\overline{X}_t\) is the mean in the treated, and \(\sigma_{\text{pool}}\) is the pooled standard deviation. We offer such checks in causalOT as well.
First, we consider the pre-weighting mean balance:
mean_bal(data = hainmueller)
#> X1 X2 X3 X4 X5 X6
#> 1.3156637 1.1497004 1.0696805 0.6061098 0.1089106 0.1555189
and then the mean balance after weighting:
mean_bal(data = hainmueller, weights = weights)
#> X1 X2 X3 X4 X5
#> 0.001290803 0.006736102 0.005986706 0.005668089 0.006262544
#> X6
#> 0.002943358
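The unweighted version of these standardized differences can be reproduced by hand from the covariate matrices. A sketch, assuming hainmueller$get_x0() and hainmueller$get_x1() (used below) return the control and treated covariate matrices, and taking the pooled standard deviation as the square root of the average group variance (one common convention; the package may pool differently):

x0 <- hainmueller$get_x0()
x1 <- hainmueller$get_x1()
# pooled SD per covariate, then the standardized absolute mean difference
sd_pool <- sqrt((apply(x0, 2, var) + apply(x1, 2, var)) / 2)
abs(colMeans(x0) - colMeans(x1)) / sd_pool  # compare to mean_bal(data = hainmueller)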
Pretty good!
However, mean balance doesn't ensure distributional balance.
Ultimately, distributional balance is what we care about in causal inference, and fortunately we can measure that too. If we have Python installed, we can use the GPU-enabled GeomLoss package; otherwise, the approxOT package can provide similar calculations. We consider the 2-Sinkhorn divergence of Genevay et al. since it metrizes convergence in distribution.
Before weighting, distributional balance looks poor:
# geomloss method
list(w0 = sinkhorn(x = hainmueller$get_x0(), y = hainmueller$get_x(),
a = rep(1/64, 64), b = rep(1/128,128),
power = 2, blur = 1e3, debias = TRUE)$loss,
w1 = sinkhorn(x = hainmueller$get_x1(), y = hainmueller$get_x(),
a = rep(1/64, 64), b = rep(1/128,128),
power = 2, blur = 1e3, debias = TRUE)$loss
)
#> $w0
#> [1] 1.549706
#>
#> $w1
#> [1] 1.549707
But after weighting, it looks much better!
# geomloss method
list(w0 = sinkhorn(x = hainmueller$get_x0(), y = hainmueller$get_x(),
a = weights$w0, b = rep(1/128,128),
power = 2, blur = 1e3, debias = TRUE)$loss,
w1 = sinkhorn(x = hainmueller$get_x1(), y = hainmueller$get_x(),
a = weights$w1, b = rep(1/128,128),
power = 2, blur = 1e3, debias = TRUE)$loss
)
#> $w0
#> [1] 6.205006e-06
#>
#> $w1
#> [1] 0.0002435214
After Causal Optimal Transport, the distributions are much more similar.
The calc_weight function can also handle other methods. We have implemented methods for logistic or probit regression, the covariate balancing propensity score (CBPS), stable balancing weights (SBW), and the synthetic control method (SCM).
calc_weight(data = hainmueller, method = "Logistic",
estimand = "ATE")
calc_weight(data = hainmueller, method = "Probit",
estimand = "ATE")
calc_weight(data = hainmueller, method = "CBPS",
estimand = "ATE",
verbose = FALSE)
calc_weight(data = hainmueller, method = "SBW",
grid.search = TRUE,
estimand = "ATE", solver = "osqp",
verbose = FALSE, formula = "~.")
calc_weight(data = hainmueller, method = "SCM", penalty = "none",
estimand = "ATE", solver = "osqp",
verbose = FALSE)
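Each of these calls returns weights in the same form as before, so the diagnostics above apply to every method. For example, a quick comparison of effective sample sizes (a sketch; the alternative fit is simply stored first):

logit_weights <- calc_weight(data = hainmueller, method = "Logistic",
                             estimand = "ATE")
# compare how much information each set of weights retains
ESS(weights)        # Causal Optimal Transport
ESS(logit_weights)  # logistic propensity score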