Managing Attrition: Feature-Driven Survival Analysis
This is the second article in the Survival Analysis series. The first article (here) was an introduction to survival analysis using simple non-parametric methods, namely the Kaplan-Meier method.
Survival analysis is a class of models and techniques used to analyze and predict the time to an event. It is useful in any context where we want to analyze the time to an event, for example customer churn, equipment failure, or, as in this article, time to re-arrest.
Survival models broadly fall into two categories:
- Non-parametric survival models: these are built directly from the data and assume no parametric distribution. In a business context they give good high-level segment averages but cannot perform sensitivity analysis. They also tend to be stepped or discontinuous.
- Parametric models: these are generally survival regression models, built for sensitivity analysis. They can answer questions like: keeping all else constant, how will the time to event change if one feature changes?
Now let's look at how parametric models work. Fully parametric models are also known as AFT (Accelerated Failure Time) models and can be written as

X = e * exp(w1*x1 + ... + wn*xn)

where:
- X represents the time to event
- e in this case represents a base distribution. We need to find a base distribution that fits the data well
- x1, ..., xn represent the features and w1, ..., wn the weights (regression coefficients) attached to them.
- In other words, there is a base survival curve, and changing a feature value accelerates or decelerates the time to event, changing the shape of the survival curve. A negative coefficient for a feature means that increasing the feature value decreases the time to event.
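To make the acceleration idea concrete, here is a minimal sketch (numpy only, with a made-up Weibull baseline and a made-up coefficient, not values from the analysis below) showing how a negative coefficient pulls the event earlier:

```python
import numpy as np

# Assumed Weibull baseline survival: S0(t) = exp(-(t/scale)**shape)
scale, shape = 30.0, 1.4      # made-up baseline parameters
w, x = -0.25, 2.0             # made-up coefficient and feature value

def baseline_survival(t):
    return np.exp(-(t / scale) ** shape)

# AFT says T = T0 * exp(w*x), which implies S(t | x) = S0(t * exp(-w*x))
def survival_given_x(t):
    return baseline_survival(t * np.exp(-w * x))

# With w < 0, survival at any time t is lower than the baseline,
# i.e. the event tends to happen earlier
print(baseline_survival(20.0), survival_given_x(20.0))
```

The whole curve is stretched or compressed along the time axis, which is exactly the "acceleration" the model name refers to.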
Now let's apply this to a dataset. We will use the prison recidivism (Rossi) dataset to understand what factors affect the time to arrest for a population of previous offenders. We will visualize the results and try to generate some predictions. Fortunately, Python has a robust package called lifelines for all kinds of survival analysis, which we will use below.
import pandas as pd
from lifelines import KaplanMeierFitter
from lifelines import WeibullFitter
import numpy as np
import matplotlib.pyplot as plt
# Loading the dataset
prison = pd.read_csv('https://assets.datacamp.com/production/repositories/5850/datasets/4e20aa97a26bbe32106a94b76ae4cabf1a632d59/rossi.csv')
prison.head()
prison.shape
Looking at the first 5 rows we can see a series of variables like paroled (paro), priors (prio), age, race, etc. The event of interest is arrest and the duration column is week: arrest=1 means an arrest took place at the corresponding week, while arrest=0 means no arrest was observed during the follow-up period (a right-censored observation).
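As an illustration (with made-up rows, not the real data), this is how the week/arrest pair encodes both observed events and right-censored subjects:

```python
import pandas as pd

# Made-up rows mirroring the structure of the Rossi dataset:
# arrest=1 -> re-arrest observed at 'week'
# arrest=0 -> followed for 'week' weeks with no arrest (right-censored)
sample = pd.DataFrame({
    'week':   [17, 52, 25, 52],
    'arrest': [1,  0,  1,  0],
    'prio':   [3,  0,  1,  2],
    'age':    [27, 19, 35, 22],
})
censored = sample[sample['arrest'] == 0]
print(censored)
```

Censored rows still carry information: they tell the model the subject survived at least that long, which is why survival methods are preferred over ordinary regression here.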
Now let's try to build a parametric survival model. The first step in this process is to find a baseline distribution that fits the actual survival well. This corresponds to the term e in the equation above.
An easy way to visualize this is to check whether a distribution overlays closely on the Kaplan-Meier curve. The lifelines package offers many candidate distributions, such as the LogNormal, Weibull, and Exponential. We try several of them using the code below.
from lifelines import LogNormalFitter
from lifelines import ExponentialFitter
from lifelines import LogLogisticFitter
from lifelines.plotting import qq_plot
# Instantiating the various distribution fitters
wb = WeibullFitter()
ln = LogNormalFitter()
Exp = ExponentialFitter()
logit = LogLogisticFitter()
kmf = KaplanMeierFitter()
# Fitting to the data to get the best possible parameters for each distribution
wb.fit(durations=prison['week'],event_observed=prison['arrest'])
ln.fit(durations=prison['week'],event_observed=prison['arrest'])
Exp.fit(durations=prison['week'],event_observed=prison['arrest'])
logit.fit(durations=prison['week'],event_observed=prison['arrest'])
kmf.fit(durations=prison['week'],event_observed=prison['arrest'])
# Plotting the fitted distributions over the Kaplan-Meier curve
plt.style.use('ggplot')
fig, ax = plt.subplots()
kmf.plot_survival_function(ax=ax)
wb.plot_survival_function(ax=ax)
Exp.plot_survival_function(ax=ax)
logit.plot_survival_function(ax=ax)
ax.set_title('Survival Parametric Models vs Kaplan Meier Actuals')
ax.set_ylabel('Fraction')
It's not very clear here, but the log-logistic and Weibull distributions seem to overlay the KM curve most closely. Another way is to use a qq_plot, which compares the empirical quantiles with the quantiles predicted by the distribution. If the empirical quantiles line up with the distribution quantiles, the scatter points will fall along the y=x line, as shown in the plots below.
# Generating qq_plots from the already fitted models
models = [wb, ln, Exp, logit]
for model in models:
    qq_plot(model)
    plt.show()
Based on the qq_plots it is very clear that the Weibull and log-logistic distributions are close fits; we will use the Weibull as our base distribution.
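The intuition behind the qq_plot can be sketched without lifelines: compare empirical quantiles of the durations against the candidate distribution's theoretical quantiles. The sketch below uses synthetic, uncensored data and made-up Weibull parameters (the lifelines version also accounts for censoring):

```python
import numpy as np

rng = np.random.default_rng(0)
shape, scale = 1.4, 30.0

# Synthetic durations drawn from the candidate Weibull itself
durations = scale * rng.weibull(shape, size=1000)

# Empirical vs theoretical quantiles; Weibull inverse CDF:
# Q(p) = scale * (-ln(1 - p)) ** (1/shape)
probs = np.linspace(0.05, 0.95, 19)
empirical = np.quantile(durations, probs)
theoretical = scale * (-np.log(1 - probs)) ** (1 / shape)

# A good fit puts these quantile pairs close to the y = x line
max_rel_gap = np.max(np.abs(empirical - theoretical) / theoretical)
print(max_rel_gap)
```

When the data really come from the candidate distribution, the quantile pairs hug y = x; systematic curvature away from the line signals a poor fit.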
Now we can build the survival regression model. It models the baseline survival with a Weibull distribution and fits a regression on top of it, giving us an estimate of how much each factor (such as race or age) causes a deviation from the baseline survival.
To accomplish this we import the WeibullAFTFitter class, which fits a regression on top of a baseline Weibull curve.
from lifelines import WeibullAFTFitter
aft = WeibullAFTFitter()
# In this case we are using all columns in the data set
aft.fit(prison,duration_col='week',event_col='arrest')
aft.summary
The summary of the results is shown in the table above.
- The two most important columns are exp(coef) and p.
- Since the regression is fit on the log of the time (X), the exponent of the coefficient is more interpretable: exp(coef) is the factor by which the time to event is multiplied for a unit increase in the feature.
- The p column shows the p-value; the features prio (priors) and age have p-values below 0.05 and are statistically significant.
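To see how exp(coef) is read, here is a small sketch with a hypothetical coefficient for prio (the real fitted value comes from aft.summary):

```python
import math

coef_prio = -0.066            # hypothetical AFT coefficient for 'prio'
time_ratio = math.exp(coef_prio)

# Each additional prior multiplies the expected time to arrest by this
# ratio (here about 0.94, i.e. roughly a 6% shorter time per prior)
print(time_ratio)

# k additional priors compound multiplicatively
k = 3
print(time_ratio ** k)
```

Because the effect is multiplicative on time, differences between subjects compound: three extra priors shrink the expected time by the cube of the ratio, not three times the percentage.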
Let's use the plot_partial_effects_on_outcome method to understand the effect of each of these factors vs. the baseline survival.
# Plotting the effect of priors on time to arrest
fig, ax = plt.subplots(figsize=(6, 4))
aft.plot_partial_effects_on_outcome('prio', [0, 2, 6], ax=ax)
ax.set_title('Effect of priors on time to arrest')

# Plotting the effect of age on time to arrest
fig, ax = plt.subplots(figsize=(6, 4))
aft.plot_partial_effects_on_outcome('age', [20, 26, 35], ax=ax)
ax.set_title('Effect of age on time to arrest')
# Predicting survival curves for new instances held in a data frame
# called new (it must contain the same feature columns as the training data)
aft.predict_survival_function(new).transpose()
We can also pass new instances to the predict methods and the fitted model will produce a survival curve for each one. In a business context, new signups or subscriptions can be scored, i.e. assigned a survival curve, at sign-up time. This kind of model can serve as the first step in calculating Customer Lifetime Value (CLV).
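As a sketch of that first CLV step, the area under a predicted survival curve gives the expected retention time over a horizon (the restricted mean survival time). The curve below is an assumed Weibull standing in for the output of predict_survival_function, and the margin figure is made up:

```python
import numpy as np

# Assumed predicted survival curve for one customer (made-up Weibull)
weeks = np.arange(0.0, 105.0, 1.0)
surv = np.exp(-(weeks / 60.0) ** 1.3)

# Restricted mean survival time = area under the curve (trapezoid rule)
rmst = np.sum(0.5 * (surv[1:] + surv[:-1]) * np.diff(weeks))
print(rmst)   # expected weeks retained within the ~2-year horizon

# A simple CLV estimate: expected weeks retained * margin per week
margin_per_week = 9.99      # made-up figure
clv = rmst * margin_per_week
print(clv)
```

Restricting the integral to a fixed horizon avoids extrapolating the parametric curve far beyond the observed data.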
Key shortcoming of this approach and possible alternatives
To use this approach, it is critical to find a parametric distribution that fits the actual data (the baseline distribution) fairly well; qq_plots (as discussed above) help here. In the example above the Weibull distribution fit the data well. In some cases no distribution will fit the data, and the fully parametric approach is not ideal. In such cases semi-parametric models, such as Cox Proportional Hazards, may be used.
In conclusion, parametric models are very powerful tools for predicting survival. They can estimate the effect of different features/covariates on survival and can predict personalized survival curves. This is very useful for understanding what factors affect customer attrition, or equipment failure in the case of operations.