Survival on Gradient Boosting Regression Model

Machine Learning Survival Model

Learn from the mistake!

Jasper Lok https://jasperlok.netlify.app/
03-30-2024

Photo by Tyler Nix on Unsplash

Previously, I explored how to perform survival analysis.

In this post, I will continue the quest by using gradient boosting machine (GBM) models to perform survival analysis.

Demonstration

In this demonstration, I will be using a bank customer churn dataset from Kaggle.

Set up the environment

First, I will load the necessary packages into the environment.

pacman::p_load(tidyverse, lubridate, janitor, survival, survminer, censored, gbm, Hmisc)

Import Data

Next, I will import the dataset into the environment.

I will also clean the column names, drop the columns I don’t need, and convert the columns to the right format.

df <- read_csv("https://raw.githubusercontent.com/jasperlok/my-blog/master/_posts/2022-09-10-kaplan-meier/data/Churn_Modelling.csv") %>%
  clean_names() %>%
  select(-c(row_number, customer_id, surname)) %>%
  filter(tenure > 0) %>% 
  mutate(across(!where(is.numeric), as.factor))

Note that gbm cannot handle character variables, so I have converted all the character columns into factors.

Now, let’s start building the GBM model!

Model Building

For reproducibility, I will set the random seed.

set.seed(1234)

gbm_fit <-
  gbm(Surv(tenure, exited) ~ .
      ,distribution = "coxph"
      ,data = df
      ,interaction.depth = 2
      ,n.trees = 1000
      ,cv.folds = 5
      ,n.cores = 5)

Note that I have set interaction.depth to 2, which allows the trees to capture two-way interactions between variables. With the default value of 1, each tree consists of a single split and the model is additive, as explained in the gbm documentation.

I have also set cv.folds = 5 to run 5-fold cross-validation, and n.cores = 5 so that the cross-validation fits run in parallel on multiple cores.

If the distribution is not specified, gbm guesses an appropriate one from the response. Since the response here is a Surv object, it assumes coxph, as the message below shows.

gbm_fit_noDist <-
  gbm(Surv(tenure, exited) ~ .
      ,data = df
      ,interaction.depth = 2
      ,n.trees = 100
      ,cv.folds = 5
      ,n.cores = 5)
Distribution not specified, assuming coxph ...

Now, let’s return to the original fitted model.

gbm_fit
gbm(formula = Surv(tenure, exited) ~ ., distribution = "coxph", 
    data = df, n.trees = 1000, interaction.depth = 2, cv.folds = 5, 
    n.cores = 5)
A gradient boosted model with coxph loss function.
1000 iterations were performed.
The best cross-validation iteration was 149.
There were 9 predictors of which 8 had non-zero influence.

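The output shows that the best cross-validation iteration was 149 (out of 1,000) and that 8 of the 9 predictors had non-zero influence.

In other words, one predictor was never used for a split in the trees up to the best iteration. Note that zero influence is not the same as statistical non-significance; the relative influence results later in this post show that this predictor is has_cr_card.

Alternatively, we get the same output by passing the fitted object to the print function.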

print(gbm_fit)
gbm(formula = Surv(tenure, exited) ~ ., distribution = "coxph", 
    data = df, n.trees = 1000, interaction.depth = 2, cv.folds = 5, 
    n.cores = 5)
A gradient boosted model with coxph loss function.
1000 iterations were performed.
The best cross-validation iteration was 149.
There were 9 predictors of which 8 had non-zero influence.
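The "coxph loss function" in the output refers to the Cox partial likelihood: with distribution = "coxph", each boosting iteration fits a tree to the gradient of the negative log partial likelihood with respect to the current log hazard predictions. As a rough illustration of that loss (not gbm's internal code), the sketch below computes the negative log partial likelihood with Breslow handling of ties, which is an assumption since gbm's own treatment of tied times may differ, and checks it against survival::coxph on simulated data. The names cox_neg_log_pl and sim are hypothetical.

```r
library(survival)

# Negative log partial likelihood of a Cox model (Breslow handling of ties).
# eta is the linear predictor, i.e. each subject's log hazard ratio.
cox_neg_log_pl <- function(time, status, eta) {
  # at_risk[i, j] is TRUE when subject j is still at risk at subject i's time
  at_risk <- outer(time, time, "<=")
  # sum of exp(eta) over the risk set of each subject
  risk_sum <- as.vector(at_risk %*% exp(eta))
  # only subjects with an event contribute a term
  -sum(status * (eta - log(risk_sum)))
}

# simulated data: one covariate with a true log hazard ratio of 0.7
set.seed(2024)
n <- 300
x <- rnorm(n)
event_time <- rexp(n, rate = 0.1 * exp(0.7 * x))
censor_time <- rexp(n, rate = 0.05)
sim <- data.frame(
  time = pmin(event_time, censor_time),
  status = as.integer(event_time <= censor_time),
  x = x
)

fit <- coxph(Surv(time, status) ~ x, data = sim, ties = "breslow")

# loss at beta = 0 and at the fitted coefficient; coxph reports the
# corresponding log partial likelihoods in fit$loglik
loss_null <- cox_neg_log_pl(sim$time, sim$status, rep(0, n))
loss_fitted <- cox_neg_log_pl(sim$time, sim$status, sim$x * coef(fit))
c(null = loss_null, fitted = loss_fitted, coxph = -fit$loglik)
```

Boosting lowers this same quantity step by step, one small tree at a time, instead of solving for a single linear coefficient.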

Model Performance

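A common way to measure how well a survival model performs is the concordance index (C-index): the proportion of comparable pairs of customers in which the customer with the higher predicted risk churned first. A C-index of 0.5 is no better than random ordering, while 1 means a perfect ordering.

Before doing that, we need to generate the predicted values. For the coxph distribution, predict returns them on the log hazard scale, so a higher value means a higher risk of churning.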

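# predict using the best cross-validation iteration
gbm_predict <-
  predict(gbm_fit
          ,newdata = df
          ,n.trees = gbm.perf(gbm_fit, method = "cv", plot.it = FALSE))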

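Then, we will use the rcorr.cens function from the Hmisc package to compute the C-index. The predictions are negated because rcorr.cens treats higher values as longer survival, whereas a higher log hazard means churning sooner. Keep in mind that this C-index is computed on the same data used to fit the model, so it is likely to be optimistic.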

rcorr.cens(-gbm_predict, Surv(df$tenure, df$exited))["C Index"]
C Index 
0.80961 
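To make the C-index concrete, the sketch below computes Harrell's C directly from its definition: among comparable pairs (the subject with the shorter time had the event), it counts how often the subject with the higher predicted risk failed first, counting ties in predicted risk as one half. It uses simulated data and a hypothetical helper harrell_c, and cross-checks the result against survival::concordance, where reverse = TRUE declares that higher predictions mean shorter survival (which is why the code above passes -gbm_predict to rcorr.cens). The simulated data have no tied times, so tie-handling differences between implementations do not come into play.

```r
library(survival)

# Harrell's C-index from its definition, using all ordered pairs (i, j).
harrell_c <- function(time, status, risk) {
  # comparable[i, j]: subject i had an event before subject j's observed time
  comparable <- outer(time, time, "<") & (status == 1)
  higher_risk <- outer(risk, risk, ">")
  tied_risk <- outer(risk, risk, "==")
  (sum(comparable & higher_risk) + 0.5 * sum(comparable & tied_risk)) /
    sum(comparable)
}

set.seed(42)
n <- 200
risk <- rnorm(n)  # toy log hazard predictions
event_time <- rexp(n, rate = exp(risk))
censor_time <- rexp(n, rate = 0.5)
toy <- data.frame(
  time = pmin(event_time, censor_time),
  status = as.integer(event_time <= censor_time),
  risk = risk
)

c_manual <- harrell_c(toy$time, toy$status, toy$risk)
c_survival <- concordance(Surv(time, status) ~ risk, data = toy,
                          reverse = TRUE)$concordance
c(manual = c_manual, survival = c_survival)
```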

Cross Validation

best_iter <- gbm.perf(gbm_fit, method = 'cv')
best_iter
[1] 149

Besides returning the best iteration, gbm.perf plots the training and cross-validation deviance against the number of iterations; the blue vertical line marks the best iteration.

Interaction

The gbm package also provides interact.gbm, which estimates the strength of an interaction effect using Friedman's H-statistic.

To use it, we specify the variables whose interaction we would like to estimate.

For example, I would like to estimate the interaction effect between age and geography.

interact.gbm(gbm_fit, df, i.var = c("age", "geography"))
[1] 0.04414741
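Friedman's H-statistic compares the joint partial dependence of two variables with the sum of their individual partial dependences; whatever the sum fails to explain is attributed to the interaction. The sketch below applies that definition with brute-force partial dependence to two toy functions rather than to a gbm model (gbm derives partial dependence from the fitted trees, so its numbers are computed differently). It returns H, the square root of the unexplained share of variance. The helpers pd_at_rows and friedman_h are hypothetical names.

```r
# Partial dependence of f on `vars`, evaluated at each observed row:
# fix `vars` at row i's values for every row, predict, and average.
pd_at_rows <- function(f, data, vars) {
  vapply(seq_len(nrow(data)), function(i) {
    grid <- data
    grid[vars] <- data[rep(i, nrow(data)), vars, drop = FALSE]
    mean(f(grid))
  }, numeric(1))
}

# Friedman's H-statistic for the pair (j, k), using centred partial dependences
friedman_h <- function(f, data, j, k) {
  centre <- function(v) v - mean(v)
  pd_jk <- centre(pd_at_rows(f, data, c(j, k)))
  pd_j <- centre(pd_at_rows(f, data, j))
  pd_k <- centre(pd_at_rows(f, data, k))
  sqrt(sum((pd_jk - pd_j - pd_k)^2) / sum(pd_jk^2))
}

set.seed(1)
toy <- data.frame(x1 = runif(50), x2 = runif(50), x3 = runif(50))
additive <- function(d) d$x1 + d$x2 + d$x3
with_interaction <- function(d) d$x1 * d$x2 + d$x3

h_additive <- friedman_h(additive, toy, "x1", "x2")
h_interaction <- friedman_h(with_interaction, toy, "x1", "x2")
c(additive = h_additive, interaction = h_interaction)
```

For the additive function, H is zero up to rounding error, while the product term produces a clearly positive value.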

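Values close to 0 indicate little interaction, so the interaction between age and geography appears to be weak. Alternatively, we can run a loop to estimate the interaction effects between every pair of variables.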

var_list <-
  df %>% 
  select(-c(tenure
            ,exited)) %>% 
  names()

interact_result <-
  tibble(variable_1 = character()
         ,variable_2 = character()
         ,interaction = numeric())

for(i in var_list){
  for(j in var_list){
    if(i != j){
      interact_result <-
        interact_result %>% 
        add_row(
          variable_1 = i
          ,variable_2 = j
          ,interaction = interact.gbm(gbm_fit
                                      ,df
                                      ,i.var = c(i, j))
        )
    }
  }
}

interact_result %>% 
  ggplot(aes(variable_1, variable_2, fill = interaction)) +
  geom_tile() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
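Since Friedman's H is symmetric in its two variables, the nested loop above evaluates every pair twice. A possible refinement, assuming interact.gbm returns the same value for c(i, j) and c(j, i), is to evaluate each unordered pair once with combn and mirror the result so the heatmap still receives both halves. The helper pairwise_symmetric below is hypothetical and accepts any two-argument statistic; with the post's objects, it could be called as pairwise_symmetric(var_list, function(a, b) interact.gbm(gbm_fit, df, i.var = c(a, b))).

```r
# Evaluate a symmetric pairwise statistic once per unordered pair and return
# a long table with both orderings, ready for ggplot2::geom_tile().
pairwise_symmetric <- function(vars, stat_fun) {
  pairs <- t(combn(vars, 2))  # one row per unordered pair
  values <- apply(pairs, 1, function(p) stat_fun(p[1], p[2]))
  data.frame(
    variable_1 = c(pairs[, 1], pairs[, 2]),
    variable_2 = c(pairs[, 2], pairs[, 1]),
    interaction = c(values, values),
    stringsAsFactors = FALSE
  )
}

# toy statistic that records how many times it is called
n_calls <- 0
toy_stat <- function(a, b) {
  n_calls <<- n_calls + 1
  nchar(a) * nchar(b)
}

result <- pairwise_symmetric(c("age", "balance", "gender", "geography"), toy_stat)
n_calls  # 6 unordered pairs instead of 12 ordered ones
```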

Variable Importance

Another awesome feature of the gbm package is that it estimates the importance of every variable, measured as relative influence: the improvement from splits on that variable, summed over the trees.

There are a few ways to obtain the variable importance results.

Method 1: Use the summary function

summary(gbm_fit)

                              var    rel.inf
age                           age 26.9149738
num_of_products   num_of_products 24.3632770
credit_score         credit_score 13.2040099
balance                   balance 13.1281317
estimated_salary estimated_salary 10.8689779
is_active_member is_active_member  5.1401539
geography               geography  4.6517175
gender                     gender  1.2907075
has_cr_card           has_cr_card  0.4380509

Method 2: Use the relative.influence function

relative.influence(gbm_fit)
n.trees not given. Using 149 trees.
    credit_score        geography           gender              age 
        72.72859        103.52181         37.78743        707.29663 
         balance  num_of_products      has_cr_card is_active_member 
       122.45139        629.82370          0.00000        166.83147 
estimated_salary 
        55.72705 

Personally, I prefer method 2, as it lets me pass the result to ggplot and visualize it as shown below.

relative.influence(gbm_fit) %>%
  # turn the named vector into a two-column tibble
  enframe(name = "variable", value = "variable_importance") %>% 
  ggplot(aes(variable_importance, reorder(variable, variable_importance))) +
  geom_col() +
  labs(x = "Relative influence", y = NULL)
n.trees not given. Using 149 trees.
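The two methods report different scales: summary() rescales the relative influences so that they sum to 100, while relative.influence() returns the raw values. The sketch below rescales the 149-tree values printed by relative.influence() above in the same way. The resulting percentages still differ from the summary() table, most likely because summary() defaults to all 1000 trees (its n.trees argument defaults to the total number of trees) rather than the best iteration. The name raw_influence is hypothetical.

```r
# Raw relative influence values copied from the relative.influence(gbm_fit)
# output above (149 trees).
raw_influence <- c(
  credit_score = 72.72859, geography = 103.52181, gender = 37.78743,
  age = 707.29663, balance = 122.45139, num_of_products = 629.82370,
  has_cr_card = 0, is_active_member = 166.83147, estimated_salary = 55.72705
)

# Rescale to percentages, as summary() does by default, and sort them
pct_influence <- sort(100 * raw_influence / sum(raw_influence), decreasing = TRUE)
round(pct_influence, 2)
```

To get both on the same footing directly from gbm, something like summary(gbm_fit, n.trees = best_iter, plotit = FALSE) should work, since summary.gbm accepts both the n.trees and plotit arguments.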

Partial Dependence

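The gbm package can also generate partial dependence plots, which show the marginal effect of a variable on the prediction after averaging over the other variables. For the coxph distribution, the values are on the log hazard scale by default.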

plot(gbm_fit
     ,i.var = "age")

However, I prefer using ggplot to draw the chart.

Hence, I will set return.grid = TRUE so that the function returns the grid of estimated values instead of drawing the plot.

plot(gbm_fit
     ,i.var = "age"
     ,return.grid = TRUE) %>% 
  ggplot(aes(age, y)) +
  geom_line() +
  ylab("") +
  labs(title = "Partial Dependence Plot for Age") +
  theme_minimal() +
  theme(axis.text.y = element_blank()
        ,axis.ticks.y = element_blank())
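Partial dependence itself can be computed by brute force: for each value on a grid, set the variable to that value for every row, predict, and average the predictions. The sketch below does this for a linear model on simulated data, where the partial dependence curve for x1 should be a straight line whose slope equals the fitted coefficient. The helper partial_dependence is hypothetical, and gbm computes its curves internally from the fitted trees rather than by this averaging, so treat it as an illustration of the idea only.

```r
# Brute-force partial dependence of `model` on one variable over `grid_values`
partial_dependence <- function(model, data, var, grid_values) {
  vapply(grid_values, function(v) {
    modified <- data
    modified[[var]] <- v  # set the variable to v for every row
    mean(predict(model, newdata = modified))
  }, numeric(1))
}

set.seed(7)
sim <- data.frame(x1 = runif(100), x2 = rnorm(100))
sim$y <- 2 + 3 * sim$x1 - sim$x2 + rnorm(100, sd = 0.5)
fit <- lm(y ~ x1 + x2, data = sim)

grid_values <- seq(0, 1, by = 0.25)
pd <- partial_dependence(fit, sim, "x1", grid_values)

# slope of the partial dependence curve between consecutive grid points,
# compared with the fitted coefficient of x1
diff(pd) / diff(grid_values)
coef(fit)[["x1"]]
```

For a tree ensemble such as gbm, the same averaging yields a step-shaped curve instead of a straight line, which is what the ggplot chart above displays.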

To compute the joint partial dependence of more than one variable, we just need to pass all of them to the i.var argument.

# partial dependence of two variables
plot(gbm_fit
     ,i.var = c("balance", "geography")
     ,return.grid = TRUE) %>% 
  ggplot(aes(balance, y, color = geography)) +
  geom_line() +
  ylab("") +
  labs(title = "Partial Dependence Plot for Balance by Geography") +
  theme_minimal() +
  theme(axis.text.y = element_blank()
        ,axis.ticks.y = element_blank())

Voilà! That is how GBM can be used to perform survival analysis.

Conclusion

That’s all for the day!

Thanks for reading the post until the end.

Feel free to contact me through email or LinkedIn if you have any suggestions on future topics to share.

Refer to this link for the blog disclaimer.

Till next time, happy learning!

Photo by RDNE Stock project