Published in ResponsibleML

survex: model-agnostic explainability for survival analysis

In this blog post, we’d like to show how model explainability can help you make informed choices when working with survival models, by showcasing the capabilities of the survex R package.

Survival analysis and explainability

When talking about machine learning, most people have classification and regression tasks in mind, as they are the most popular. However, these are not the only applications of ML models. Another popular one, especially in medicine and insurance, is survival analysis, which deals with predicting the time until a certain event (e.g., patient death or machine breakdown) occurs.

You can learn more about it in this blog, but, long story short, survival models (most often) predict a survival function, which gives the probability that the event has not happened by a given time t. The output can also be a single value (e.g., a risk score), but such scores are always some aggregate of the survival function, which naturally leads to a loss of the information contained in the prediction.
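To make this concrete, here is a minimal sketch using plain survival package code (not survex). As an assumption for illustration, it fits a Cox model on the veteran lung cancer data shipped with survival, whose karno, celltype, and age variables match those discussed in this post. It predicts the survival function for one patient and then collapses it into a single number, the restricted mean survival time (the area under S(t) up to a horizon tau; tau = 365 days is an arbitrary choice).

```r
library(survival)

# veteran lung cancer data shipped with the survival package
vet <- survival::veteran
fit <- coxph(Surv(time, status) ~ karno + celltype + age, data = vet)

# Predicted survival function S(t) for one patient (a step function)
sf <- survfit(fit, newdata = vet[1, ])
times <- sf$time
surv <- as.numeric(sf$surv)

# Collapse S(t) into one number: the restricted mean survival time,
# i.e. the area under S(t) on [0, tau]
tau <- 365
keep <- times < tau
grid <- c(0, times[keep], tau)   # interval boundaries
heights <- c(1, surv[keep])      # S(t) is constant on each interval
rmst <- sum(heights * diff(grid))

cat("S(t) at t = 30, 90, 180:",
    round(summary(sf, times = c(30, 90, 180))$surv, 3), "\n")
cat("RMST up to", tau, "days:", round(rmst, 1), "\n")
```

Two patients with quite different survival curves can share the same RMST, which is exactly the loss of information that comes with reducing the prediction to a single value.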

Because the output of a survival model is an entire function of time rather than a single number, standard explanation methods cannot be applied directly.

Because of this, my team (Mateusz Krzyziński, Hubert Baniecki, and Przemyslaw Biecek) and I developed survex, an R package that provides explanations for survival models. We hope this tool enables more widespread use of complex machine learning models for survival analysis. Until now, simpler statistical models such as the Cox Proportional Hazards model were preferred for their interpretability, which is vital in areas such as medicine, even though they were frequently outperformed by complex machine learning models.

Cheatsheet highlighting the main functionalities of the survex package

In this post, I’d like to present the main functionalities of survex and explain how to interpret the explanations it produces.

What influences the overall model predictions?

Knowing which variables matter most to a model when it makes predictions is critical. This information can be compared with domain knowledge to assess whether the predictions are based on the right variables or whether the model relies on something unexpected. Explanations of this kind can be calculated via the model_parts() function of survex.

To measure the global importance of variables, we use permutational variable importance. It works as follows: we permute (randomly shuffle) one variable (column) of the dataset and make predictions using this permuted input. Then we calculate a performance metric and see how much the model’s performance worsened compared with the baseline result on the original data. We repeat this for all variables in the dataset.

We interpret the result as follows: the larger the drop in the model’s performance after permuting a variable, the more important that variable is for making predictions, so the variables higher on the plot are the most important ones. Additionally, if we use a time-dependent metric, we can see whether some variables are more important at certain times of the study.
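The procedure can be sketched in a few lines of R. This is not survex’s model_parts() implementation: as simplifying assumptions, the sketch explains a Cox model from the survival package fitted on the veteran data, and it uses Harrell’s concordance index (a single, time-independent metric) as the performance measure, so it yields one importance score per variable rather than a curve over time. The number of permutations (10) is arbitrary.

```r
library(survival)

set.seed(42)
vet <- survival::veteran
vars <- c("karno", "celltype", "age", "diagtime")
fit <- coxph(Surv(time, status) ~ karno + celltype + age + diagtime, data = vet)

# Harrell's C-index of the model on `data`; reverse = TRUE because a
# higher linear predictor means higher risk, i.e. shorter survival
c_index <- function(model, data) {
  d <- data.frame(time = data$time, status = data$status,
                  lp = predict(model, newdata = data, type = "lp"))
  concordance(Surv(time, status) ~ lp, data = d, reverse = TRUE)$concordance
}

# Average drop in C-index after permuting each variable n_perm times
perm_importance <- function(model, data, vars, n_perm = 10) {
  baseline <- c_index(model, data)
  sapply(vars, function(v) {
    drops <- replicate(n_perm, {
      permuted <- data
      permuted[[v]] <- sample(permuted[[v]])  # break the link with the outcome
      baseline - c_index(model, permuted)
    })
    mean(drops)
  })
}

imp <- perm_importance(fit, vet, vars)
print(round(sort(imp, decreasing = TRUE), 3))
```

Averaging over several permutations reduces the randomness of a single shuffle; swapping c_index() for a time-dependent metric evaluated on a grid of time points would give importance curves instead of single scores.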

Below you can see the permutational variable importance for a Random Survival Forest model.

Variable importance for a Random Survival Forest model

We can see that the karno variable is the most important until around time 100, after which celltype takes the lead.

How does the prediction change on average?

Another question a user of a survival model might ask is: how do the model’s predictions depend on individual variables? This is difficult to answer, because models are complex and variables often depend on one another. Partial dependence profiles, accessible via the model_profile() function in survex, present this kind of information.

These profiles aggregate “what-if” explanations across an entire dataset. To calculate them, we set a single variable to a chosen value for every observation, observe how the predictions change, and plot the average response; repeating this over a range of values shows how the average prediction depends on that variable.
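A minimal sketch of this averaging for survival curves, again assuming a Cox model from the survival package fitted on the veteran data (this is not survex’s model_profile() code). For each grid value of karno, every patient’s karno is set to that value, survival probabilities are predicted at a few time points, and they are averaged over patients. The grid and the time points are arbitrary choices.

```r
library(survival)

vet <- survival::veteran
fit <- coxph(Surv(time, status) ~ karno + celltype + age, data = vet)

# Partial dependence of the survival function on one variable:
# returns a (times x grid) matrix of average predicted survival
pd_survival <- function(model, data, var, grid, times) {
  out <- sapply(grid, function(g) {
    nd <- data
    nd[[var]] <- g                                 # same value for everyone
    sf <- survfit(model, newdata = nd)             # one curve per patient
    rowMeans(as.matrix(summary(sf, times = times)$surv))  # average over patients
  })
  dimnames(out) <- list(time = times, value = grid)
  out
}

pd <- pd_survival(fit, vet, "karno", grid = c(20, 40, 60, 80),
                  times = c(30, 90, 180, 365))
print(round(pd, 3))
```

Each column of pd is the average survival function for one karno value; plotting the columns against time gives curves analogous to those in a survival partial dependence plot.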

Below, the partial dependence profiles are presented for two variables: the categorical celltype and the continuous karno.

Partial dependence profiles for a Random Survival Forest

We see that the model treats the large and squamous cell types almost identically, and likewise adeno and smallcell, with the latter pair having much worse survival chances. We also observe that observations with high values of the continuous karno variable are more likely to survive longer.

What influences the prediction for a selected observation?

Knowing how the model works in general is very useful, but often we want to know which factors contribute to the prediction for a single chosen observation.

This can be done with the SurvSHAP(t) and SurvLIME methods, accessible via the predict_parts() function in survex.

SurvSHAP(t) is an extension of SHAP explanations to models with functional output. The basic idea is to calculate model responses with some variables “turned off” and measure the additive effect of turning each variable back on, across different subsets of the other variables. The resulting explanation has some nice properties; e.g., at every time point t, the SurvSHAP(t) values of all variables sum to the difference between the prediction for the explained observation and the average prediction for the dataset. This method allows for detecting time-dependent effects of variables: for instance, a treatment can be effective at the beginning but decrease the survival chances from a specific time point onwards.
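To illustrate the idea, here is a brute-force sketch of Shapley values for a survival function evaluated at a few time points. It is not survex’s SurvSHAP(t) implementation. Assumptions: a Cox model from the survival package (veteran data) plays the role of the black box; a “turned off” variable keeps its values from the dataset, so predictions are averaged over them; and all 2^p subsets are enumerated exactly, which is only feasible for a handful of variables. The helper names (mean_surv, coalition_value, survshap_exact) are hypothetical.

```r
library(survival)

vet <- survival::veteran
vars <- c("karno", "celltype", "age")
fit <- coxph(Surv(time, status) ~ karno + celltype + age, data = vet)
times <- c(30, 90, 180, 365)

# Average predicted survival at `times` over the rows of `data`
mean_surv <- function(data) {
  sf <- survfit(fit, newdata = data)
  rowMeans(as.matrix(summary(sf, times = times)$surv))
}

# Value of a coalition S: variables in S are fixed to the explained
# observation x, the remaining ones keep their values from the data
coalition_value <- function(S, x, data) {
  for (v in S) data[[v]] <- x[[v]]
  mean_surv(data)
}

# Exact Shapley values: one row per time point, one column per variable
survshap_exact <- function(x, data, vars) {
  p <- length(vars)
  members <- function(m) vars[as.logical(intToBits(m))[seq_len(p)]]
  value <- lapply(0:(2^p - 1), function(m) coalition_value(members(m), x, data))
  phi <- matrix(0, length(times), p,
                dimnames = list(time = as.character(times), var = vars))
  for (j in seq_len(p)) {
    bit <- 2^(j - 1)
    for (m in 0:(2^p - 1)) {
      if (bitwAnd(m, bit) != 0) next           # only subsets without variable j
      s <- length(members(m))                  # subset size |S|
      w <- factorial(s) * factorial(p - s - 1) / factorial(p)
      phi[, j] <- phi[, j] + w * (value[[m + bit + 1]] - value[[m + 1]])
    }
  }
  phi
}

x <- vet[1, ]
phi <- survshap_exact(x, vet, vars)
print(round(phi, 4))
```

By construction, each row of phi sums to the explained patient’s predicted survival minus the dataset’s average predicted survival at that time point, which is the additivity property described above.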

SurvLIME, on the other hand, works by fitting a surrogate Cox Proportional Hazards model on artificially generated data from the neighborhood of the explained observation. The coefficients of this surrogate Cox model, multiplied by the values of the corresponding variables, show the local importance of the variables.

Below, both explanations are presented for the same observation. They show that the survival function for this observation is below average, and that the main contributors are the celltype and age variables. SurvSHAP(t) shows the contribution of each variable separately at each time point, whereas SurvLIME describes only the overall contribution. For the SurvLIME explanation, the survival function of the surrogate model is also plotted together with the black-box model’s prediction, which shows how closely the explanation approximates the actual prediction.

SurvSHAP(t) explanation
SurvLIME explanation

How does the prediction change for a selected observation?

The question “what if?” inspires the next kind of explanation: ceteris paribus profiles. They help us understand what happens to a model’s prediction when we change one variable for a single observation (patient) while keeping all other variables fixed. They can be used to verify the correctness of a model; for example, we can check whether, according to the model, introducing a treatment would increase a patient’s chances of survival. These explanations are available via the predict_profile() function of survex.
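For a single patient, such a profile is just a loop over candidate values of one variable. A minimal sketch, assuming a Cox model from the survival package fitted on the veteran data (not survex’s predict_profile() code); the age grid and time points are arbitrary. Each column is the predicted survival of the first patient with only age changed; averaging such profiles over all patients gives a partial dependence profile.

```r
library(survival)

vet <- survival::veteran
fit <- coxph(Surv(time, status) ~ karno + celltype + age, data = vet)

# Ceteris paribus profile: vary one variable for one observation,
# keep everything else fixed, and record the predicted survival
ceteris_paribus <- function(model, obs, var, grid, times) {
  nd <- obs[rep(1, length(grid)), ]   # one copy of the patient per grid value
  nd[[var]] <- grid
  sf <- survfit(model, newdata = nd)
  out <- as.matrix(summary(sf, times = times)$surv)
  dimnames(out) <- list(time = times, value = grid)
  out
}

cp <- ceteris_paribus(fit, vet[1, ], "age", grid = c(40, 50, 60, 70, 80),
                      times = c(30, 90, 180))
print(round(cp, 3))
```

Because this sketch uses a Cox model with a single linear age term, its profile is necessarily monotone in age; a flexible model such as a Random Survival Forest can produce non-monotone profiles.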

Below, a ceteris paribus explanation for the age variable is presented.

Ceteris paribus profile for the age variable

We see that, according to the model, this particular patient would have higher chances of survival if they were younger.

Measuring performance

The performance of models used for decision-making, especially in critical fields such as medicine, matters greatly to their users. survex allows measuring and comparing model performance via the model_performance() function. The performance of survival models can be measured in different ways. Some metrics are time-dependent, that is, they assess performance at specific time points; these include the Brier score (also known as the Graf score) and the cumulative/dynamic (C/D) AUC. Others, such as the concordance index and the integrated versions of the Brier score and C/D AUC, summarize the model’s performance over the entire time domain.

The Brier score is the counterpart of the MSE metric, with adjustments (inverse probability of censoring weights) that take censored data into account, whereas C/D AUC is an extension of the AUC metric known from classification problems.
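The censoring adjustment can be made concrete. Below is a minimal sketch of the Brier score at a single time t0 with inverse probability of censoring weights, where the censoring distribution G is estimated with Kaplan-Meier. It is not survex’s implementation; the Cox model, the veteran data, t0 = 90 days, and the covariate-free Kaplan-Meier curve used as a comparison model are illustrative assumptions. Patients who died by t0 are scored against a target of 0, patients still event-free after t0 against a target of 1, and patients censored before t0 contribute only through the weights.

```r
library(survival)

vet <- survival::veteran
fit <- coxph(Surv(time, status) ~ karno + celltype + age, data = vet)
t0 <- 90

# IPCW Brier score at time t0 for predicted survival probabilities surv_t0
brier_ipcw <- function(surv_t0, time, status, t0) {
  cens <- survfit(Surv(time, 1 - status) ~ 1)                  # KM of censoring
  G <- stepfun(cens$time, c(1, cens$surv))                     # G(s)
  G_minus <- stepfun(cens$time, c(1, cens$surv), right = TRUE) # G(s-)
  died <- time <= t0 & status == 1
  alive <- time > t0
  contrib <- numeric(length(time))                             # censored before t0: 0
  contrib[died] <- surv_t0[died]^2 / G_minus(time[died])
  contrib[alive] <- (1 - surv_t0[alive])^2 / G(t0)
  mean(contrib)
}

# Cox model predictions vs. a Kaplan-Meier curve that ignores covariates
s_cox <- as.vector(summary(survfit(fit, newdata = vet), times = t0)$surv)
km <- survfit(Surv(time, status) ~ 1, data = vet)
s_km <- rep(summary(km, times = t0)$surv, nrow(vet))

bs_cox <- brier_ipcw(s_cox, vet$time, vet$status, t0)
bs_km <- brier_ipcw(s_km, vet$time, vet$status, t0)
cat("Brier score at t =", t0, " Cox:", round(bs_cox, 3),
    " Kaplan-Meier:", round(bs_km, 3), "\n")
```

Lower is better. Evaluating on the same data the model was fitted on, as done here for brevity, gives an optimistic estimate; held-out data or resampling is more appropriate in practice.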

Below is a comparison of these metrics for two models, prepared using survex.

Performance comparison of Cox Proportional Hazards and Random Survival Forest models

Summary

We hope that the availability of explanation tools for survival models will lead to the adoption of explanations as a standard step in survival analysis modeling pipelines.

If this article seems compelling, there’s no better way to find out how it all works than to try it yourself. We highly encourage you to explore the explanations provided by survex. Just type install.packages("survex") and start explaining your survival models!

If you are interested in other posts about explainable, fair, and responsible ML, follow #ResponsibleML on Medium.

