Causal Inference - Weight

Machine Learning Causal Inference
Jasper Lok https://jasperlok.netlify.app/
01-02-2024

Photo by Elena Mozhvilo on Unsplash

In this post, I continue my journey of exploring causal inference.

Previously, I looked at how to draw a directed acyclic graph (DAG).

Steps to perform causal inference

Once the assumptions are made, the following are the basic steps in performing a causal analysis using data preprocessing (Greifer 2023):

  1. Decide on covariates for which balance must be achieved
  2. Estimate the distance measure (e.g., propensity score)
  3. Condition on the distance measure (e.g., using matching, weighting, or subclassification)
  4. Assess balance on the covariates of interest; if poor, repeat steps 2-4
  5. Estimate the treatment effect in the conditioned sample

Different balancing techniques

Below are different balancing techniques:

Different estimands

As part of causal inference, we need to choose an estimand that matches the question we are trying to answer.

I find the summary table in this book rather helpful for choosing the appropriate estimand (Barrett, McGowan, and Gerke 2023a).

For simplicity, I will be using the average treatment effect in this analysis.

What is the ‘Average Treatment Effect’ (ATE)?

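ATE is the average effect of the treatment across the whole population, i.e. the mean difference between the outcomes units would have with and without the treatment (Nguyen 2023).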
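Written in potential-outcomes notation (my own summary of the standard definition, not a quote from the cited book), with Y(1) and Y(0) the outcomes a unit would have with and without the treatment, T the treatment indicator, and n_1, n_0 the sizes of the treated and control groups:

```latex
\begin{aligned}
\text{ATE} &= \mathbb{E}\left[Y(1) - Y(0)\right] = \mathbb{E}[Y(1)] - \mathbb{E}[Y(0)] \\
\hat{\tau}_{\text{diff}} &= \frac{1}{n_1}\sum_{i:\,T_i = 1} Y_i \;-\; \frac{1}{n_0}\sum_{i:\,T_i = 0} Y_i
\end{aligned}
```

Only one of Y_i(1) and Y_i(0) is observed for each unit, so the ATE has to be estimated; the second line is the naive difference-in-means estimator.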

The author notes that, with random assignment, the observed difference in means between the two groups is an unbiased estimator of the average treatment effect.
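To see why random assignment matters, here is a small simulation sketch in base R. The data, the true effect of 3 and the variable names are all made up for illustration and are not from the employee dataset. Because both potential outcomes are simulated, the true ATE is known, so we can compare it with the naive difference in means under random and under confounded assignment.

```r
# Hypothetical simulated data: both potential outcomes are generated,
# so the true ATE is known by construction.
set.seed(2024)
n <- 10000
x <- rnorm(n)                    # a confounder
y0 <- 1 + 2 * x + rnorm(n)       # outcome without treatment
y1 <- y0 + 3                     # outcome with treatment: effect of 3 for everyone
true_ate <- mean(y1 - y0)        # equals 3

# Random assignment: treatment does not depend on x
t_rand <- rbinom(n, 1, 0.5)
y_rand <- ifelse(t_rand == 1, y1, y0)   # only one potential outcome is observed
diff_rand <- mean(y_rand[t_rand == 1]) - mean(y_rand[t_rand == 0])

# Confounded assignment: units with larger x are more likely to be treated
t_conf <- rbinom(n, 1, plogis(2 * x))
y_conf <- ifelse(t_conf == 1, y1, y0)
diff_conf <- mean(y_conf[t_conf == 1]) - mean(y_conf[t_conf == 0])

round(c(true = true_ate, random = diff_rand, confounded = diff_conf), 2)
```

With random assignment the difference in means lands close to 3; when treatment depends on x, which also drives the outcome, the same estimator is biased upwards. Removing that kind of bias is what the balancing steps are for.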

Demonstration

In this demonstration, I will be using the employee resignation dataset from Kaggle.

pacman::p_load(tidyverse, janitor, WeightIt, cobalt)

Import Data

First, I will import the data into the environment.

Explanations of the data wrangling steps can be found in this post.

df <- read_csv("https://raw.githubusercontent.com/jasperlok/my-blog/master/_posts/2022-03-12-marketbasket/data/general_data.csv") %>%
  # drop the columns we don't need
  dplyr::select(-c(EmployeeCount, StandardHours, EmployeeID)) %>%
  clean_names() %>% 
  mutate(
    # impute the missing values with the mean values
    num_companies_worked = case_when(
      is.na(num_companies_worked) ~ mean(num_companies_worked, na.rm = TRUE),
      TRUE ~ num_companies_worked),
    total_working_years = case_when(
      is.na(total_working_years) ~ mean(total_working_years, na.rm = TRUE),
      TRUE ~ total_working_years),
    # treatment indicator: "yes" if the last promotion was at most 1 year ago
    ind_promoted_in_last1Yr = if_else(years_since_last_promotion <= 1, "yes", "no"),
    ind_promoted_in_last1Yr = as.factor(ind_promoted_in_last1Yr),
    # convert the categorical columns to factors
    attrition = as.factor(attrition),
    job_level = as.factor(job_level)
    ) %>%
  droplevels()

Balance statistics

I start by using the bal.tab function from the cobalt package to check whether any of the covariates are imbalanced between the two treatment groups.

This corresponds to the first step mentioned in the earlier section.

# unadjusted
bal.tab(ind_promoted_in_last1Yr ~ age + gender + department 
        ,data = df
        ,estimand = "ATE"
        ,stats = c("m", "v")
        ,thresholds = c(m = 0.05))
Balance Measures
                                     Type Diff.Un      M.Threshold.Un
age                               Contin. -0.3297 Not Balanced, >0.05
gender_Male                        Binary -0.0024     Balanced, <0.05
department_Human Resources         Binary  0.0141     Balanced, <0.05
department_Research & Development  Binary -0.0183     Balanced, <0.05
department_Sales                   Binary  0.0042     Balanced, <0.05
                                  V.Ratio.Un
age                                   1.0497
gender_Male                                .
department_Human Resources                 .
department_Research & Development          .
department_Sales                           .

Balance tally for mean differences
                    count
Balanced, <0.05         4
Not Balanced, >0.05     1

Variable with the greatest mean difference
 Variable Diff.Un      M.Threshold.Un
      age -0.3297 Not Balanced, >0.05

Sample sizes
      no  yes
All 1596 2814

We can set the threshold for the mean differences by passing it to the thresholds argument, as shown above. The same logic applies to the variance ratios, e.g. thresholds = c(m = 0.05, v = 2).

Based on the result above, age is not balanced in this dataset: its unadjusted standardized mean difference (-0.3297) exceeds the 0.05 threshold.
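To make the Diff.Un and V.Ratio.Un columns less of a black box, below is a hand-rolled sketch. I am assuming it mirrors cobalt's defaults for estimand = "ATE": for a continuous covariate the mean difference is standardized (treated minus control, divided by the pooled standard deviation), while binary covariates such as gender_Male are reported as a raw difference in proportions. The age values are simulated, not taken from the dataset.

```r
# Standardized mean difference (SMD): treated mean minus control mean,
# divided by the pooled SD sqrt((var_treated + var_control) / 2).
# Assumption: this matches cobalt's default for estimand = "ATE".
smd <- function(x, treat) {
  m1 <- mean(x[treat == 1])
  m0 <- mean(x[treat == 0])
  pooled_sd <- sqrt((var(x[treat == 1]) + var(x[treat == 0])) / 2)
  (m1 - m0) / pooled_sd
}

# Variance ratio: treated variance over control variance
var_ratio <- function(x, treat) {
  var(x[treat == 1]) / var(x[treat == 0])
}

# Hypothetical data: the treated group is about 3 years younger
set.seed(1)
age <- c(rnorm(3000, mean = 33, sd = 9), rnorm(2000, mean = 36, sd = 9))
treat <- c(rep(1, 3000), rep(0, 2000))

smd(age, treat)                # negative: treated units are younger on average
var_ratio(age, treat)          # close to 1: similar spread in both groups
abs(smd(age, treat)) > 0.05    # TRUE means it would be flagged at the 0.05 threshold
```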

bal.tab also lets us check balance on interactions and polynomial terms through the int and poly arguments, respectively.
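As a rough illustration of what these extra terms are, the sketch below builds a squared age term and age-by-gender interaction columns on a hypothetical four-row data frame using base R's model.matrix. cobalt constructs these terms internally, so its naming and encoding may differ; this only shows the idea behind rows such as age² and age * gender_Male.

```r
# Hypothetical toy data (not from the employee dataset)
toy <- data.frame(
  age = c(25, 32, 41, 29),
  gender = factor(c("Male", "Female", "Male", "Female"))
)

# Squared term: the kind of column poly = 2 adds for a continuous covariate
toy$age_sq <- toy$age^2

# Interaction columns: without an age main effect or intercept, R codes
# gender with one indicator per level, giving age:genderFemale and age:genderMale
X <- model.matrix(~ age:gender - 1, data = toy)
X
```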

bal.tab(ind_promoted_in_last1Yr ~ age + gender + department
        ,data = df
        ,int = TRUE
        ,poly = 2
        ,estimand = "ATE"
        ,stats = c("m", "v")
        ,thresholds = c(m = 0.05))
Balance Measures
                                                     Type Diff.Un
age                                               Contin. -0.3297
gender_Male                                        Binary -0.0024
department_Human Resources                         Binary  0.0141
department_Research & Development                  Binary -0.0183
department_Sales                                   Binary  0.0042
age²                                              Contin. -0.3056
age * gender_Female                               Contin. -0.0552
age * gender_Male                                 Contin. -0.0974
age * department_Human Resources                  Contin.  0.0644
age * department_Research & Development           Contin. -0.1455
age * department_Sales                            Contin. -0.0363
gender_Female * department_Human Resources         Binary  0.0079
gender_Female * department_Research & Development  Binary  0.0071
gender_Female * department_Sales                   Binary -0.0127
gender_Male * department_Human Resources           Binary  0.0062
gender_Male * department_Research & Development    Binary -0.0254
gender_Male * department_Sales                     Binary  0.0168
                                                       M.Threshold.Un
age                                               Not Balanced, >0.05
gender_Male                                           Balanced, <0.05
department_Human Resources                            Balanced, <0.05
department_Research & Development                     Balanced, <0.05
department_Sales                                      Balanced, <0.05
age²                                              Not Balanced, >0.05
age * gender_Female                               Not Balanced, >0.05
age * gender_Male                                 Not Balanced, >0.05
age * department_Human Resources                  Not Balanced, >0.05
age * department_Research & Development           Not Balanced, >0.05
age * department_Sales                                Balanced, <0.05
gender_Female * department_Human Resources            Balanced, <0.05
gender_Female * department_Research & Development     Balanced, <0.05
gender_Female * department_Sales                      Balanced, <0.05
gender_Male * department_Human Resources              Balanced, <0.05
gender_Male * department_Research & Development       Balanced, <0.05
gender_Male * department_Sales                        Balanced, <0.05
                                                  V.Ratio.Un
age                                                   1.0497
gender_Male                                                .
department_Human Resources                                 .
department_Research & Development                          .
department_Sales                                           .
age²                                                  0.9356
age * gender_Female                                   0.8591
age * gender_Male                                     0.8896
age * department_Human Resources                      1.3362
age * department_Research & Development               0.8793
age * department_Sales                                0.8940
gender_Female * department_Human Resources                 .
gender_Female * department_Research & Development          .
gender_Female * department_Sales                           .
gender_Male * department_Human Resources                   .
gender_Male * department_Research & Development            .
gender_Male * department_Sales                             .

Balance tally for mean differences
                    count
Balanced, <0.05        11
Not Balanced, >0.05     6

Variable with the greatest mean difference
 Variable Diff.Un      M.Threshold.Un
      age -0.3297 Not Balanced, >0.05

Sample sizes
      no  yes
All 1596 2814

For simplicity, however, I will not include any interaction or polynomial terms in the formula.

The cobalt package also lets users plot the distribution of a covariate by treatment group with bal.plot, which helps to assess the independence between the treatment and the selected covariate.

For example, below is the density plot for age. It shows that employees promoted in the last year are, on average, younger than those who were not.

bal.plot(ind_promoted_in_last1Yr ~ age + gender + department
         ,data = df
         ,"age")

Another nice feature of this function is that the plot switches to a bar plot when the covariate is categorical.

bal.plot(ind_promoted_in_last1Yr ~ age + gender + department
         ,data = df
         ,"department")

Next, I use the love.plot function to plot the absolute mean differences of the covariates; once weights are supplied, it shows both the unadjusted and adjusted values.

Commonly recommended thresholds for the mean differences are 0.1 and 0.05.

Before adjustment, only age has an absolute mean difference above 0.05; gender and department are already below it.

love.plot(ind_promoted_in_last1Yr ~ age + gender + department
          ,data = df
          ,drop.distance = TRUE
          ,abs = TRUE
          ,threshold = c(m = 0.05)
          ,line = TRUE)

According to the documentation, love.plot can also display other balance statistics, such as variance ratios, through its stats argument.

Weight

In this post, I will be exploring how to perform weighting.

One way to think about matching is as a crude “weight” where everyone who was matched gets a weight of 1 and everyone who was not matched gets a weight of 0 in the final sample. Another option is to allow this weight to be smooth, applying a weight to allow, on average, the covariates of interest to be balanced in the weighted population (Barrett, McGowan, and Gerke 2023b).
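For the ATE, a common choice of smooth weight is the inverse probability of the treatment actually received: 1/e(x) for treated units and 1/(1 - e(x)) for control units, where e(x) is the propensity score. The sketch below computes these weights by hand on simulated data with a logistic regression propensity model. I am assuming this is close to what weightit() does by default (its output reports method "glm"), but it is an illustration, not the package's implementation.

```r
# Hypothetical simulated data: younger employees are more likely to be treated
set.seed(42)
n <- 2000
age <- rnorm(n, mean = 36, sd = 9)
treat <- rbinom(n, 1, plogis(3 - 0.07 * age))

# Step 2: estimate the propensity score e(x) = P(treat = 1 | age)
ps_model <- glm(treat ~ age, family = binomial())
ps <- fitted(ps_model)

# Step 3: ATE weights, 1 / e(x) for treated and 1 / (1 - e(x)) for control
w <- ifelse(treat == 1, 1 / ps, 1 / (1 - ps))

summary(w[treat == 1])
summary(w[treat == 0])
```

Each group's weights sum to roughly the full sample size, so both weighted groups stand in for the whole population, which is what the ATE asks for.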

w_outcome <- 
  weightit(ind_promoted_in_last1Yr ~ age + gender + department
           ,data = df
           ,estimand = "ATE")

w_outcome
A weightit object
 - method: "glm" (propensity score weighting with GLM)
 - number of obs.: 4410
 - sampling weights: none
 - treatment: 2-category
 - estimand: ATE
 - covariates: age, gender, department

Based on the output, we noted the following:

Considerations when building a propensity score model (Barrett, McGowan, and Gerke 2023c):

Next, I will pass the weightit object into the summary function.

summary(w_outcome)
                 Summary of weights

- Weight ranges:

       Min                                  Max
no  1.7816     |-----------------------| 4.9589
yes 1.2273 |------|                      2.2876

- Units with the 5 most extreme weights by group:
                                       
       2717   1247   3479   2009    539
  no 4.9529 4.9529 4.9589 4.9589 4.9589
       3562   2422   2092    952    622
 yes 2.2876 2.2876 2.2876 2.2876 2.2876

- Weight statistics:

    Coef of Var   MAD Entropy # Zeros
no        0.203 0.161   0.020       0
yes       0.131 0.101   0.008       0

- Effective Sample Sizes:

                no     yes
Unweighted 1596.   2814.  
Weighted   1533.05 2766.48
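The effective sample size (ESS) summarizes how much information is lost to unequal weights. The sketch below uses Kish's formula, (Σw)² / Σw². I am assuming this is what WeightIt reports here; it is at least consistent with the output above, since the formula can also be written as n / (1 + CV²), and 1596 / (1 + 0.203²) ≈ 1533 and 2814 / (1 + 0.131²) ≈ 2766. The weights in the example are made up.

```r
# Kish's effective sample size for a vector of weights
ess <- function(w) sum(w)^2 / sum(w^2)

ess(rep(1.5, 100))        # equal weights: ESS equals the number of units (100)
ess(c(rep(1, 99), 50))    # one extreme weight shrinks the ESS to about 8.5
```

Because the weight ranges in the summary are fairly narrow, the effective sample sizes barely drop from the unweighted counts.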

To see the distribution of the weights, we can pass the summary object to the plot function.

plot(summary(w_outcome))

# balance after weighting
bal.tab(w_outcome
        ,stats = c("m", "v")
        ,thresholds = c(m = 0.05))
Balance Measures
                                      Type Diff.Adj     M.Threshold
prop.score                        Distance   0.0056 Balanced, <0.05
age                                Contin.  -0.0139 Balanced, <0.05
gender_Male                         Binary   0.0023 Balanced, <0.05
department_Human Resources          Binary  -0.0013 Balanced, <0.05
department_Research & Development   Binary   0.0009 Balanced, <0.05
department_Sales                    Binary   0.0004 Balanced, <0.05
                                  V.Ratio.Adj
prop.score                             1.1592
age                                    1.2024
gender_Male                                 .
department_Human Resources                  .
department_Research & Development           .
department_Sales                            .

Balance tally for mean differences
                    count
Balanced, <0.05         6
Not Balanced, >0.05     0

Variable with the greatest mean difference
 Variable Diff.Adj     M.Threshold
      age  -0.0139 Balanced, <0.05

Effective sample sizes
                no     yes
Unadjusted 1596.   2814.  
Adjusted   1533.05 2766.48

# compare the covariate distributions before and after weighting
bal.plot(w_outcome, "age", which = "both")
bal.plot(w_outcome, "department", which = "both")
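The Diff.Adj column can be reproduced in spirit with weighted group means. The sketch below uses a simulated setup (made-up data, logistic propensity model, ATE weights) and computes the standardized mean difference of age before and after weighting. I am assuming, as I understand cobalt's default, that the standard deviation in the denominator comes from the unweighted sample, so only the numerator changes after weighting.

```r
# Hypothetical simulated data and ATE weights from a logistic propensity model
set.seed(42)
n <- 5000
age <- rnorm(n, mean = 36, sd = 9)
treat <- rbinom(n, 1, plogis(3 - 0.07 * age))
ps <- fitted(glm(treat ~ age, family = binomial()))
w <- ifelse(treat == 1, 1 / ps, 1 / (1 - ps))

# Denominator: pooled SD from the unweighted sample (assumed cobalt default)
pooled_sd <- sqrt((var(age[treat == 1]) + var(age[treat == 0])) / 2)

# Numerators: unweighted vs weighted difference in group means
smd_before <- (mean(age[treat == 1]) - mean(age[treat == 0])) / pooled_sd
smd_after <- (weighted.mean(age[treat == 1], w[treat == 1]) -
              weighted.mean(age[treat == 0], w[treat == 0])) / pooled_sd

round(c(before = smd_before, after = smd_after), 3)
```

Before weighting, the SMD is far beyond 0.05; after weighting it shrinks towards zero, the same pattern the weighted balance table shows for the real data.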

Calling the love.plot function on the weightit object shows that all the covariates are below the threshold after weighting.

love.plot(w_outcome
          ,data = df
          ,drop.distance = TRUE
          ,abs = TRUE
          ,threshold = c(m = 0.05)
          ,line = TRUE)

Conclusion

That’s all for the day!

Thanks for reading the post until the end.

Feel free to contact me through email or LinkedIn if you have any suggestions on future topics to share.

Refer to this link for the blog disclaimer.

Till next time, happy learning!

Photo by Piret Ilver on Unsplash

References

Barrett, Malcolm, Lucy D’Agostino McGowan, and Travis Gerke. 2023a. Bookdown. https://www.r-causal.org/chapters/11-estimands#choosing-estimands.
———. 2023b. Bookdown. https://www.r-causal.org/chapters/08-building-ps-models.
———. 2023c. Bookdown. https://www.r-causal.org/chapters/07-prep-data.
Greifer, Noah. 2023. “Covariate Balance Tables and Plots: A Guide to the Cobalt Package.” https://cloud.r-project.org/web/packages/cobalt/vignettes/cobalt.html.
Nguyen, Mike. 2023. Bookdown. https://bookdown.org/mike/data_analysis/causal-inference.html.