Responsible Machine Learning for Survival Analysis
A brief introduction to survival analysis and the use of machine learning models in this area
You can skip the first two sections if you already know what survival analysis is. Otherwise, read them to build intuition about this kind of analysis.
What is survival analysis?
Let us start by defining the type of task. Survival analysis is used to predict the time until some specific event occurs for a selected individual in the considered population. This time is estimated based on various features (covariates) describing this individual. For example, the task may be to estimate the survival time for a given cancer patient. Hence the name of this type of statistical procedure.
It is worth bearing in mind that although survival analysis is most often used in a medical context, it is not limited to medicine. An event may also be the failure of a machine component or a customer cancelling a service.
Why not regression?
The short answer to why simple regression is not appropriate is the presence of censored data in the study. Censoring occurs when we have incomplete information about an individual’s survival time.
The most common type is right-censoring: part of the population did not experience the event during the study (data collection). This still tells us something about these individuals, namely that they survived up to a certain point without experiencing the event, but we don’t know how soon after the end of the study they would have experienced it. Thus, the true survival time of those individuals is at least as long as their study participation time.
Going back to the example above, suppose we know that a certain patient survived for 100 weeks, until data collection was completed or contact with him was lost. However, we do not know what happened next: whether he died a week later, 100 weeks later, or is still alive.
Censored data should not be analyzed with standard classification or regression techniques: classification ignores the time to event, and regression ignores the event indicator (it would treat censoring times as if they were event times). Survival analysis models have been developed to deal with such data.
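To see the problem concretely, the following sketch simulates a hypothetical study in which true event times follow an exponential distribution with mean 10 and everyone still event-free at time 10 is right-censored. Both naive fixes, treating censoring times as event times or discarding censored individuals, underestimate the mean survival time. The function name `simulate_right_censoring` and all parameter values are illustrative assumptions.

```python
import random


def simulate_right_censoring(n, rate, study_end, seed=0):
    """Draw exponential event times and right-censor them at study_end.

    Returns (observed_time, event) pairs: event is 1 if the event happened
    during the study and 0 if the individual was censored.
    """
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        true_time = rng.expovariate(rate)  # unknown to the analyst in a real study
        if true_time <= study_end:
            data.append((true_time, 1))
        else:
            # We only learn that the individual survived past study_end.
            data.append((study_end, 0))
    return data


data = simulate_right_censoring(n=10_000, rate=0.1, study_end=10.0)
true_mean = 1 / 0.1  # mean of an exponential distribution with rate 0.1

# Naive option 1: treat censoring times as if they were event times.
naive_all = sum(t for t, _ in data) / len(data)

# Naive option 2: drop censored individuals altogether.
events_only = [t for t, e in data if e == 1]
naive_events = sum(events_only) / len(events_only)

censored_share = 1 - len(events_only) / len(data)
print(f"censored share: {censored_share:.0%}")
print(f"true mean {true_mean:.1f}, censored-as-event {naive_all:.1f}, "
      f"events-only {naive_events:.1f}")
```

In expectation, the first naive estimate is about 6.3 and the second about 4.2, against a true mean of 10. Survival models avoid this bias by using censored observations as partial information rather than as exact event times.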
Before we can explore survival analysis models, a short mathematical introduction is necessary. Instead of a single point in time, the output of survival analysis is a function of time.
The survival function describes the probability that an individual survives beyond time t without experiencing the event: $$S(t) = P(T > t),$$
where T is a non-negative continuous random variable representing the time-to-event for an individual.
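The hazard function can be interpreted as the instantaneous risk of the event: the probability that it occurs in a short (infinitesimal) time interval after t, given that it has not occurred by time t, divided by the length of that interval: $$h(t) = \lim_{\Delta t \to 0} \frac{P(t \le T < t + \Delta t \mid T \ge t)}{\Delta t}.$$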
The goal is to estimate the risk of an event occurring at a given point in time. However, since T is a continuous random variable, the probability that it equals any particular value is 0. The hazard is therefore defined as the limit of a conditional probability divided by the interval length; it is a rate, not a probability, so its values can be greater than 1.
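Of course, the above functions are related to each other. It can be shown that: $$h(t) = \frac{f(t)}{S(t)} = -\frac{d}{dt}\ln S(t), \qquad S(t) = \exp\left(-\int_0^t h(u)\,du\right),$$ where f is the probability density function of T.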
Another important object is the cumulative hazard function: $$H(t) = \int_0^t h(u)\,du = -\ln S(t).$$
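These relations can be checked numerically. The sketch below assumes a Weibull time-to-event distribution (an illustrative choice with made-up shape and scale) and verifies that the limit definition of the hazard matches f(t)/S(t), and that S(t) = exp(-H(t)) when H is obtained by numerically integrating h.

```python
import math

# Illustrative Weibull time-to-event distribution (shape and scale are made up).
SHAPE, SCALE = 1.5, 10.0


def survival(t):
    return math.exp(-((t / SCALE) ** SHAPE))


def density(t):
    return (SHAPE / SCALE) * (t / SCALE) ** (SHAPE - 1) * survival(t)


def hazard(t):
    return (SHAPE / SCALE) * (t / SCALE) ** (SHAPE - 1)


def hazard_from_limit(t, dt=1e-6):
    # P(t <= T < t + dt | T >= t) / dt for a small dt
    return (survival(t) - survival(t + dt)) / survival(t) / dt


def cumulative_hazard(t, steps=100_000):
    # H(t): integral of h(u) over [0, t], approximated with the midpoint rule
    du = t / steps
    return sum(hazard((i + 0.5) * du) for i in range(steps)) * du


t = 7.0
print("h(t):", hazard(t), "f/S:", density(t) / survival(t), "limit:", hazard_from_limit(t))
print("S(t):", survival(t), "exp(-H):", math.exp(-cumulative_hazard(t)))
```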
With the basic knowledge and notation already in place, let’s move on to discussing a few models.
Cox Proportional Hazards Model
The most popular model in survival analysis, especially among medical researchers, is the Cox Proportional Hazards model. It consists of two parts:
- the baseline hazard function that describes the change in risk over time (does not depend on covariates),
- the effect parameters that describe how the hazard changes depending on the characteristics of the individual (based on covariates values).
This structure is clearly visible in the formula describing the model: $$h(t \mid x) = h_0(t)\exp(\beta^\top x),$$
where h₀ is the baseline hazard function (estimated from the entire population), β are the model parameters, and x are the covariate values of the analyzed individual. Taking logarithms gives log h(t | x) = log h₀(t) + βᵀx, so the log-hazard of an individual is a linear function of their covariates x.
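A small numeric sketch of this structure, using a made-up baseline hazard and coefficients (`baseline_hazard`, `beta`, and both patients are hypothetical). Because h₀(t) cancels in the ratio, the hazards of two individuals stay proportional at every time point, and a one-unit increase in a covariate multiplies the hazard by exp(βᵢ).

```python
import math


def baseline_hazard(t):
    """Hypothetical baseline hazard h0(t), increasing over time."""
    return 0.02 * math.sqrt(t)


beta = [0.7, -0.3]  # hypothetical coefficients for two covariates


def cox_hazard(t, x):
    # h(t | x) = h0(t) * exp(beta^T x)
    linear_predictor = sum(b * xi for b, xi in zip(beta, x))
    return baseline_hazard(t) * math.exp(linear_predictor)


patient_a = [1.0, 2.0]
patient_b = [0.0, 2.0]  # same as patient_a, but covariate 1 is one unit lower

for t in (1.0, 5.0, 20.0):
    ratio = cox_hazard(t, patient_a) / cox_hazard(t, patient_b)
    # h0(t) cancels, so the ratio equals exp(0.7) at every t
    print(t, cox_hazard(t, patient_a), ratio)
```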
The model parameters β are estimated by maximizing the partial likelihood: $$L(\beta) = \prod_{i \in D} \frac{\exp(\beta^\top x_i)}{\sum_{j \in R_i} \exp(\beta^\top x_j)},$$
where D is the set of individuals who experienced the event, and Rᵢ is the risk set of the i-th individual: those who had neither experienced the event nor been censored before the i-th individual’s event time (including the i-th individual itself).
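The partial likelihood translates into a short function. This is a minimal sketch assuming no tied event times (real software handles ties with corrections such as Breslow’s or Efron’s), applied to a hypothetical five-person dataset; the function name is illustrative, and the grid search only stands in for the Newton-type optimizers used in practice.

```python
import math


def cox_log_partial_likelihood(beta, X, times, events):
    """Log of the Cox partial likelihood, assuming no tied event times.

    X: covariate vectors; times: observed times;
    events: 1 if the event occurred, 0 if the individual was censored.
    """
    scores = [sum(b * x for b, x in zip(beta, xi)) for xi in X]  # beta^T x_i
    loglik = 0.0
    for i, (t_i, d_i) in enumerate(zip(times, events)):
        if not d_i:
            continue  # censored individuals contribute only through risk sets
        # Risk set R_i: everyone whose observed time is at least t_i (including i)
        risk_sum = sum(math.exp(scores[j]) for j, t_j in enumerate(times) if t_j >= t_i)
        loglik += scores[i] - math.log(risk_sum)
    return loglik


# Hypothetical toy data: one covariate, five individuals.
X = [[1.0], [0.0], [0.0], [1.0], [0.0]]
times = [2, 3, 5, 7, 11]
events = [1, 0, 1, 1, 0]

# Crude grid search over beta; practical software uses Newton-type optimizers.
grid = [b / 100 for b in range(-300, 301)]
best_beta = max(grid, key=lambda b: cox_log_partial_likelihood([b], X, times, events))
print(best_beta)
```

Both individuals with x = 1 experienced the event, so the estimated β comes out positive. Note that h₀ never appears: only the ordering of the observed times enters the partial likelihood, which is why β can be estimated without specifying the baseline hazard.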
Notably, the relative simplicity of this model makes it partially interpretable. A positive value of βᵢ means that the i-th covariate is positively associated with the hazard of the event: higher values of this covariate increase the risk of the event, and lower values reduce it. More precisely, a one-unit increase in the i-th covariate multiplies the hazard by exp(βᵢ), the hazard ratio.
Variations of the Cox model include its regularized versions. The regularization is not specific to survival analysis: it relies on well-known penalties such as lasso, ridge, or elastic net.
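For concreteness, the elastic-net version can be written as the following penalized objective. Packages differ in how they scale the likelihood and the penalty, so treat this as one common parameterization rather than a specific library’s definition. Here log L(β) is the log of the partial likelihood, λ ≥ 0 sets the overall penalty strength, and ρ ∈ [0, 1] (a symbol introduced here) mixes the two penalties: ρ = 1 gives the lasso and ρ = 0 gives ridge.

$$\hat{\beta} = \arg\max_{\beta}\left\{\log L(\beta) - \lambda\left(\rho\,\lVert\beta\rVert_1 + \frac{1-\rho}{2}\,\lVert\beta\rVert_2^2\right)\right\}$$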
Machine learning in survival analysis
In many cases, the Cox model will meet expectations, but you should bear in mind its limitations, such as the fact that it is linear in the covariates. Therefore, to build a better model, it is worth considering more complex algorithms. Machine learning methods can learn complex relationships between covariates and survival time, and this flexibility can lead to more accurate predictions. The algorithms used in survival analysis are modifications of well-known classification or regression algorithms, adapted to censored data. So let’s look at some examples…
Random Survival Forest
Random Survival Forest (RSF) extends random forests to censored data. It is an ensemble of survival trees; a single survival tree is also considered an interpretable model, but aggregating the predictions of many trees yields better performance. The randomness comes from training each tree on a different bootstrap sample of the data and from searching for the best split at each node only among a random subset of covariates.
The key to understanding RSF is how the individual survival trees are grown. Naturally, the main difference from regression or classification trees is the splitting rule (i.e., the criterion measuring the survival difference between the daughter nodes). The most common choice is the log-rank splitting rule. It is based on the log-rank test, which compares the survival distributions of two samples. The goal is to find the split that maximizes the absolute value of the test statistic: $$L = \frac{\sum_{t=1}^{k}\left(a_t - E(A_t)\right)}{\sqrt{\sum_{t=1}^{k}\operatorname{Var}(A_t)}},$$
where k is the number of distinct event times in the node being split, Aₜ is the random variable corresponding to the number of events at time t in the first daughter node, and aₜ is its observed value. The expected value and variance are computed under the null hypothesis (i.e., that both daughter nodes have the same survival distribution). For further information on understanding and computing this formula, see Regression Trees for Censored Data.
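A minimal sketch of this splitting rule for one continuous covariate: the standardized log-rank statistic for a candidate split, and an exhaustive search over thresholds that keeps the split with the largest absolute statistic. The function names and toy data are hypothetical; production RSF implementations are considerably more elaborate.

```python
import math


def log_rank_statistic(times, events, in_left):
    """Standardized log-rank statistic comparing two daughter nodes.

    times, events: observed times and event indicators (1 = event, 0 = censored)
    in_left: True if the individual is sent to the first (left) daughter node
    """
    event_times = sorted({t for t, e in zip(times, events) if e})
    numerator = 0.0
    variance = 0.0
    for t in event_times:
        at_risk = [i for i, ti in enumerate(times) if ti >= t]
        n = len(at_risk)
        n1 = sum(1 for i in at_risk if in_left[i])
        d = sum(1 for i in at_risk if times[i] == t and events[i])
        a_t = sum(1 for i in at_risk if times[i] == t and events[i] and in_left[i])
        expected = d * n1 / n  # E(A_t) under the null hypothesis
        numerator += a_t - expected
        if n > 1:  # hypergeometric variance of A_t
            variance += d * (n1 / n) * (1 - n1 / n) * (n - d) / (n - 1)
    return numerator / math.sqrt(variance) if variance > 0 else 0.0


def best_split(x, times, events):
    """Try every threshold on covariate x; keep the largest |statistic|."""
    best = None
    for threshold in sorted(set(x))[:-1]:  # skip the max so both nodes are non-empty
        in_left = [xi <= threshold for xi in x]
        stat = abs(log_rank_statistic(times, events, in_left))
        if best is None or stat > best[1]:
            best = (threshold, stat)
    return best


# Hypothetical node: low x values experience the event early, high x values are censored.
x = [1, 2, 3, 4, 5, 6, 7, 8]
times = [1, 2, 3, 4, 10, 11, 12, 13]
events = [1, 1, 1, 1, 0, 0, 0, 0]
print(best_split(x, times, events))
```

On this toy node the search picks x ≤ 2 rather than the perfectly separating x ≤ 4, because the standardized statistic also depends on how many individuals each daughter node keeps at risk.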
The prediction of a single survival tree for an individual is the cumulative hazard function (CHF) estimated from all training individuals in the terminal node the individual falls into; the CHF is estimated with the Nelson-Aalen estimator. For terminal node h, the estimator is: $$\hat{H}_h(t) = \sum_{t_{l,h} \le t} \frac{d_{l,h}}{R_{l,h}},$$
where tₗₕ are the distinct event times in node h, and dₗₕ and Rₗₕ are the number of events and the number of individuals at risk at time tₗₕ.
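The Nelson-Aalen estimator for one terminal node takes only a few lines. This sketch uses hypothetical node data and a hypothetical function name; an individual censored at an event time is counted as still at risk at that time, which is the usual convention.

```python
def nelson_aalen(times, events):
    """Nelson-Aalen cumulative hazard for the individuals in one terminal node.

    Returns (t_l, H(t_l)) pairs at the distinct event times; H is constant
    between these times (a right-continuous step function).
    """
    steps = []
    cumulative = 0.0
    for t in sorted({t for t, e in zip(times, events) if e}):
        d = sum(1 for ti, e in zip(times, events) if ti == t and e)  # events at t
        at_risk = sum(1 for ti in times if ti >= t)  # individuals at risk at t
        cumulative += d / at_risk
        steps.append((t, cumulative))
    return steps


# Hypothetical terminal node: five individuals, one censored at the event time 3.
node_times = [2, 3, 3, 5, 8]
node_events = [1, 1, 0, 1, 0]
print(nelson_aalen(node_times, node_events))
```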
The prediction of the entire random survival forest is then the CHF averaged over all trees: $$\hat{H}(t \mid x) = \frac{1}{N}\sum_{i=1}^{N} H_i(t \mid x),$$
where Hᵢ is the estimated CHF of the terminal node into which individual x falls in the i-th of the N trees.
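Averaging step functions is equally short. In this sketch, the per-tree CHFs for one individual are made-up (time, value) steps standing in for terminal-node Nelson-Aalen estimates; `step_value` and `forest_chf` are hypothetical names.

```python
def step_value(steps, t):
    """Evaluate a right-continuous step function given as sorted (time, value) pairs."""
    value = 0.0
    for t_l, h in steps:
        if t_l > t:
            break
        value = h
    return value


def forest_chf(tree_chfs, t):
    """Ensemble CHF at time t: the mean of the per-tree terminal-node CHFs."""
    return sum(step_value(steps, t) for steps in tree_chfs) / len(tree_chfs)


# Hypothetical terminal-node CHFs (Nelson-Aalen steps) reached by one
# individual x in each of N = 3 trees.
tree_chfs = [
    [(2, 0.2), (3, 0.45), (5, 0.95)],
    [(1, 0.1), (4, 0.6)],
    [(3, 0.3), (6, 0.8)],
]
for t in (0, 2, 4, 6):
    print(t, forest_chf(tree_chfs, t))
```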
Survival Gradient Boosting
Another framework known from classification and regression that has been extended to survival analysis is gradient boosting. As you might guess, a survival GBM is an additive model composed of simple base learners, so it can be represented as $$F(x) = \sum_{i=1}^{M} \beta_i f_i(x),$$
where fᵢ denotes the i-th weak learner, βᵢ is its weight, and M is the number of base learners.
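The GBM is built in a greedy, stage-wise process; the objective is to maximize the log-partial likelihood known from the Cox model, with the model prediction F replacing the linear predictor βᵀx: $$\ell(F) = \sum_{i \in D}\left[F(x_i) - \ln \sum_{j \in R_i} \exp\big(F(x_j)\big)\right].$$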
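In each boosting stage, the next weak learner is fitted to the negative gradient of the loss, which here is the gradient of the log-partial likelihood with respect to each prediction F(xₖ). The sketch below computes these pseudo-residuals for hypothetical data (no tied times) and can be checked against finite differences of the objective; the weak learner and learning rate are left out, and all names are illustrative.

```python
import math


def log_partial_likelihood(F, times, events):
    """Cox log-partial likelihood with predictions F[k] = F(x_k) in place of beta^T x_k."""
    total = 0.0
    for i, (t_i, d_i) in enumerate(zip(times, events)):
        if d_i:
            risk_sum = sum(math.exp(F[j]) for j, t_j in enumerate(times) if t_j >= t_i)
            total += F[i] - math.log(risk_sum)
    return total


def pseudo_residuals(F, times, events):
    """Gradient of the log-partial likelihood with respect to each F(x_k).

    A boosting stage would fit its next weak learner to these values.
    """
    grad = [float(d) for d in events]  # the delta_k term: 1 for events, 0 if censored
    for i, (t_i, d_i) in enumerate(zip(times, events)):
        if not d_i:
            continue
        risk_set = [j for j, t_j in enumerate(times) if t_j >= t_i]
        denom = sum(math.exp(F[j]) for j in risk_set)
        for k in risk_set:
            grad[k] -= math.exp(F[k]) / denom  # softmax weight of k within R_i
    return grad


# Hypothetical data (no tied times) and current predictions of the ensemble.
times = [2, 3, 5, 7, 11]
events = [1, 0, 1, 1, 0]
F = [0.3, -0.2, 0.1, 0.5, -0.4]
print(pseudo_residuals(F, times, events))
```

At F = 0 these values reduce to martingale residuals (each event indicator minus the Nelson-Aalen cumulative hazard at that individual’s time), and for any F they sum to zero.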
An interesting modification is a model whose objective is to maximize the concordance index (CI). The concordance index is a performance measure often used in survival analysis. It is the fraction of correctly ordered predictions among comparable pairs of individuals; in mathematical notation: $$\mathrm{CI} = \frac{1}{|P|}\sum_{(i,j) \in P} I\big(F(x_i) > F(x_j)\big),$$
where I is the indicator of the condition in parentheses, P is the set of all comparable ordered pairs of individuals, and F(x) is read as a risk score (a higher value means an earlier expected event). A pair (i, j) is comparable if individual i experienced the event at time tᵢ and individual j is known to have survived longer (tⱼ > tᵢ), whether j later experienced the event or was censored.
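A direct translation of this definition, with predictions interpreted as risk scores. This sketch is O(n²) and gives no credit for tied predictions (Harrell’s C usually counts them as 1/2); the function name and data are hypothetical.

```python
def concordance_index(risk, times, events):
    """Fraction of comparable pairs whose predicted risks are correctly ordered.

    A pair (i, j) is comparable if i experienced the event and t_i < t_j.
    Higher risk should mean an earlier event. Tied predictions count as
    discordant here (Harrell's C usually gives them credit 1/2).
    """
    concordant = 0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue  # a censored individual cannot be the earlier member of a pair
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1
    return concordant / comparable


times = [2, 3, 5, 7, 11]
events = [1, 0, 1, 1, 0]
risk = [0.9, 0.5, 0.7, 0.8, 0.1]  # hypothetical model outputs
print(concordance_index(risk, times, events))
```

Here 6 of the 7 comparable pairs are ordered correctly; the exception is the pair with times 5 and 7, where the later event received the higher risk.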
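However, the CI is a step function of the predictions and therefore not differentiable, so a differentiable approximation, the smoothed concordance index (SCI), is used for optimization: $$\mathrm{SCI} = \frac{1}{|P|}\sum_{(i,j) \in P} \frac{1}{1 + \exp\left(\frac{F(x_j) - F(x_i)}{\alpha}\right)},$$
where α is the smoothing parameter of the sigmoid function; as α approaches 0, the SCI approaches the CI.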
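The smoothed version only replaces the indicator with a sigmoid. The sketch below (hypothetical names and data) uses a numerically stable sigmoid and shows that a small α reproduces the ordinary CI, while a very large α pushes every pair toward 1/2.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def smoothed_concordance_index(risk, times, events, alpha):
    """C-index with the indicator I(F(x_i) > F(x_j)) replaced by a sigmoid."""
    total = 0.0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue
        for j in range(len(times)):
            if times[i] < times[j]:  # comparable pair: i had the event first
                comparable += 1
                total += sigmoid((risk[i] - risk[j]) / alpha)
    return total / comparable


times = [2, 3, 5, 7, 11]
events = [1, 0, 1, 1, 0]
risk = [0.9, 0.5, 0.7, 0.8, 0.1]  # hypothetical model outputs
for alpha in (0.001, 0.1, 1000.0):
    print(alpha, smoothed_concordance_index(risk, times, events, alpha))
```

Smaller α tracks the CI more closely but makes the gradients sharper, so α trades off fidelity against ease of optimization.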
Of course, there are other models adapted to censored data, such as Survival-SVM or artificial neural networks (like DeepSurv). However, the purpose of this post is to introduce the use of ML in survival analysis through these two popular algorithms. The extensions of other algorithms rely on similar intuitions.
What about the responsibility?
Machine learning models often work as black boxes. Hence, responsible ML is an extremely important direction (you can read more on this blog). In survival analysis, which is often used in medicine and insurance, responsible modeling may be even more crucial. Undoubtedly, explainability is one of the critical aspects of this approach to machine learning.
Naturally, some explanation methods known from classification or regression can also be applied to survival models (e.g., by explaining the predicted risk score). However, since the natural prediction of such models is a function of time (a survival function or a cumulative hazard function), collapsing it into a single score discards information and can be misleading. This motivates explanation methods designed specifically for survival analysis models. One example is Surv-LIME (a family of methods), an extension of LIME in which the black-box prediction for an individual is explained by fitting a Cox PH model in that individual’s local neighborhood. The survxai R package, in turn, provides four methods tailored to survival models, all producing plots that are functions of time: Ceteris Paribus, BreakDown, Partial Dependence, and Model Performance plots.
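To make the idea of a time-dependent explanation concrete, here is a sketch of a Ceteris Paribus profile for survival: one covariate is varied while the others stay fixed, and for each value the whole predicted survival curve is recorded. It assumes a hypothetical Cox model with made-up coefficients and baseline survival, using S(t | x) = S₀(t)^exp(βᵀx); it illustrates the concept and is not the survxai API.

```python
import math

beta = [0.7, -0.3]  # hypothetical fitted coefficients


def baseline_survival(t):
    """Hypothetical baseline survival S0(t)."""
    return math.exp(-((t / 10.0) ** 1.5))


def predict_survival(x, t):
    # Proportional hazards: H(t | x) = H0(t) * exp(beta^T x), so S(t | x) = S0(t) ** exp(beta^T x)
    return baseline_survival(t) ** math.exp(sum(b * xi for b, xi in zip(beta, x)))


def ceteris_paribus(x, feature, values, times):
    """Predicted survival curve for each value of one feature, all others held fixed."""
    profile = {}
    for v in values:
        x_mod = list(x)
        x_mod[feature] = v
        profile[v] = [predict_survival(x_mod, t) for t in times]
    return profile


patient = [0.5, 1.0]
times = [1, 5, 10, 20]
for value, curve in ceteris_paribus(patient, 0, [0.0, 0.5, 1.0], times).items():
    print(value, [round(s, 3) for s in curve])
```

For a Cox model these curves can never cross, since they are powers of the same baseline survival; for black-box models they may, and that is exactly the kind of behavior a time-dependent explanation can reveal.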
Unfortunately, there are still few methods specific to the survival analysis models. An even bigger problem is that methods described in the literature are often not available for easy use due to the lack of open-source packages.
If you are interested in other posts about explainable, fair, and responsible ML, follow #ResponsibleML on Medium.