Learn from the mistake!

Previously, I explored how to perform survival analysis.
In this post, I will continue the quest by using GBM models to perform survival analysis.
In this demonstration, I will be using this bank dataset from Kaggle.
First, I will load the necessary packages into the environment.
pacman::p_load(tidyverse, lubridate, janitor, survival, survminer, censored, gbm, Hmisc)
Next, I will import the dataset into the environment.
I will also clean the column names, drop the columns I don't need, and transform the columns into the right formats.
Note that the GBM algorithm cannot accept character variables, so I have converted all the character columns into factors.
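The cleaning steps might look something like this. This is a minimal sketch: the file name and the identifier columns being dropped are assumptions based on the usual layout of the Kaggle bank-churn dataset.

```r
# Assumed file name and ID columns from the Kaggle bank-churn dataset
df <-
  read_csv("Churn_Modelling.csv") %>%
  # standardize column names to snake_case
  clean_names() %>%
  # drop identifier columns that carry no predictive signal
  select(-c(row_number, customer_id, surname)) %>%
  # GBM cannot handle character variables, so convert them to factors
  mutate(across(where(is.character), as.factor))
```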
Now, let’s start building the GBM model!
For reproducibility purposes, I will specify the random seed.
set.seed(1234)
gbm_fit <-
gbm(Surv(tenure, exited) ~ .
,distribution = "coxph"
,data = df
,interaction.depth = 2
,n.trees = 1000
,cv.folds = 5
,n.cores = 5)
Note that I have specified interaction.depth to be 2. This allows interaction terms within the model; otherwise the algorithm assumes an additive model, as explained in the GBM documentation.
I have also specified cv.folds and n.cores to use multiple cores to perform cross-validation.
If the distribution is not specified, the algorithm will guess an appropriate distribution for the problem.
gbm_fit_noDist <-
gbm(Surv(tenure, exited) ~ .
,data = df
,interaction.depth = 2
,n.trees = 100
,cv.folds = 5
,n.cores = 5)
Distribution not specified, assuming coxph ...
Nevertheless, let’s go back to the original fitted model.
gbm_fit
gbm(formula = Surv(tenure, exited) ~ ., distribution = "coxph",
data = df, n.trees = 1000, interaction.depth = 2, cv.folds = 5,
n.cores = 5)
A gradient boosted model with coxph loss function.
1000 iterations were performed.
The best cross-validation iteration was 149.
There were 9 predictors of which 8 had non-zero influence.
From the output, we can see which iteration gave the best cross-validation result and how many predictors were kept in the final model.
We can also see that one of the nine predictors had zero influence, meaning it contributed nothing to the final model.
Alternatively, we get the same output by passing the fitted object to the print function.
print(gbm_fit)
gbm(formula = Surv(tenure, exited) ~ ., distribution = "coxph",
data = df, n.trees = 1000, interaction.depth = 2, cv.folds = 5,
n.cores = 5)
A gradient boosted model with coxph loss function.
1000 iterations were performed.
The best cross-validation iteration was 149.
There were 9 predictors of which 8 had non-zero influence.
A common way to measure how well the fitted model performs is to calculate the C-index (concordance index).
Before doing that, we need to generate the predicted values.
gbm_predict <-
predict(gbm_fit, df)
Then, we use the rcorr.cens function from the Hmisc package to compute the C-index.
rcorr.cens(-gbm_predict, Surv(df$tenure, df$exited))["C Index"]
C Index
0.80961
To find the optimal number of trees, we can use the gbm.perf function, which returns the best iteration based on cross-validation.
best_iter <- gbm.perf(gbm_fit, method = 'cv')
best_iter
[1] 149
gbm.perf also produces a plot, in which the blue dotted line marks the best iteration.
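Since the earlier call to predict did not specify a number of trees, it is worth noting that we can pass the cross-validated best iteration explicitly via the n.trees argument of predict. A sketch, assuming the gbm_fit, best_iter, and df objects defined above:

```r
# Predict using only the cross-validated best number of trees,
# then recompute the C-index on those predictions
gbm_predict_best <- predict(gbm_fit, df, n.trees = best_iter)
rcorr.cens(-gbm_predict_best, Surv(df$tenure, df$exited))["C Index"]
```

Using the best iteration rather than all 1000 trees helps guard against overfitting to the training data.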
The gbm package also provides a function, interact.gbm, to estimate the strength of interaction effects.
To use it, we specify which interaction effect we would like to estimate.
For example, I would like to estimate the interaction effect between age and geography.
interact.gbm(gbm_fit, df, i.var = c("age", "geography"))
[1] 0.04414741
Alternatively, we can run a loop to estimate the pairwise interaction effects across all the variables.
var_list <-
df %>%
select(-c(tenure
,exited)) %>%
names()
interact_result <-
tibble(variable_1 = character()
,variable_2 = character()
,interaction = numeric())
for(i in var_list){
for(j in var_list){
if(i != j){
interact_result <-
interact_result %>%
add_row(
variable_1 = i
,variable_2 = j
,interaction = interact.gbm(gbm_fit
,df
,i.var = c(i, j))
)
}
}
}
interact_result %>%
ggplot(aes(variable_1, variable_2, fill = interaction)) +
geom_tile() +
theme(axis.text.x = element_text(angle = 45, hjust = 1))

Another awesome feature of the gbm package is that it estimates the variable importance of every variable.
There are a few ways to obtain the variable importance results.
Method 1: the summary function
summary(gbm_fit)

var rel.inf
age age 26.9149738
num_of_products num_of_products 24.3632770
credit_score credit_score 13.2040099
balance balance 13.1281317
estimated_salary estimated_salary 10.8689779
is_active_member is_active_member 5.1401539
geography geography 4.6517175
gender gender 1.2907075
has_cr_card has_cr_card 0.4380509
Method 2: the relative.influence function
relative.influence(gbm_fit)
n.trees not given. Using 149 trees.
credit_score geography gender age
72.72859 103.52181 37.78743 707.29663
balance num_of_products has_cr_card is_active_member
122.45139 629.82370 0.00000 166.83147
estimated_salary
55.72705
Personally, I prefer the second method, as it allows me to pass the result to ggplot and visualize it, as shown below.
as.data.frame(relative.influence(gbm_fit)) %>%
rownames_to_column("variable") %>%
rename(variable_importance = `relative.influence(gbm_fit)`) %>%
ggplot(aes(variable_importance, reorder(variable, variable_importance))) +
geom_col()
n.trees not given. Using 149 trees.

The gbm package also provides a plot function for generating partial dependence plots, which show the marginal effect of a variable.
plot(gbm_fit
,i.var = "age")

However, I prefer plotting the chart with ggplot.
Hence, I will set return.grid to TRUE so that the function returns the estimated values instead of plotting them.
plot(gbm_fit
,i.var = "age"
,return.grid = TRUE) %>%
ggplot(aes(age, y)) +
geom_line() +
ylab("") +
labs(title = "Partial Dependence Plot for Age") +
theme_minimal() +
theme(axis.text.y = element_blank()
,axis.ticks.y = element_blank())

If we would like to perform partial dependence on more than one variable, we just need to specify that in the i.var argument.
# partial dependence of two variables
plot(gbm_fit
,i.var = c("balance", "geography")
,return.grid = TRUE) %>%
ggplot(aes(balance, y, color = geography)) +
geom_line() +
ylab("") +
labs(title = "Partial Dependence Plot for Balance and Geography") +
theme_minimal() +
theme(axis.text.y = element_blank()
,axis.ticks.y = element_blank())

Voilà! That is how GBM can be used to perform survival analysis.
That’s all for the day!
Thanks for reading the post until the end.
Feel free to contact me through email or LinkedIn if you have any suggestions on future topics to share.
Refer to this link for the blog disclaimer.
Till next time, happy learning!
