Published in Analytics Vidhya

Predicting Subscription Churn Using PySpark ML


Customer churn is one of the major issues businesses face, and minimizing churn rates has a significant impact on revenue and the bottom line. Each lost customer can hurt profits, which is why companies go to great lengths to keep customers satisfied.

This analysis shows how customer activity data can help pinpoint specific customers who are more likely to leave a subscription service. Having an “alarm system” like this in place can help businesses plan targeted interventions, such as loyalty and reward programs, with the goal of minimizing attrition.

How can subscription-based services predict churn? Which methods are effective? Which variables and features are potentially helpful? These are the key questions this analysis attempts to address.

The Dataset

The dataset contains events from a fictitious digital music streaming service called Sparkify. The events represent the actions customers take while interacting with the service: playing and liking songs, logging in and out, downgrading, canceling their subscription, and so on. These are the fields included in the medium-sized (23 MB) dataset:

|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

The data contains a total of 4470 events coming from 448 unique users.

This analysis showcases PySpark used in tandem with IBM Cloud. While most of the analysis could be done with standard Python libraries such as Pandas, NumPy, and Scikit-Learn, Spark's distributed engine offers an efficient way to wrangle data and train models on larger datasets. The model was deployed on IBM Cloud Lite for additional computational power.

Exploratory Data Analysis

For this subscription service, we define churned users as subscribers who canceled their subscription. Of the 448 users, 99 canceled, which puts the churn rate at 22%.

Count of Churned Users: 99
Churn Rate: 0.22098214285714285

Apart from canceling, users can also downgrade their subscription. Downgrade events may be worth adding as a feature to predict churn: users who downgraded might be strong candidates to leave the service, and this is explored further later in the analysis.

Count of Downgraded Users: 301
Downgrade Rate: 0.671875

Looking at the users’ demographic characteristics would be a good starting point to know which groups have a high likelihood of churning.

Churn distribution by Gender

By gender, females appear to have a slightly higher tendency to churn (24%) than males (22%), although the difference is small.

Churn distribution by level / tier of subscription

By subscription tier, whether users are on the paid or the free tier does not appear to make a difference in churn.

Churn distribution by downgrade status

As expected, users who downgraded from the paid to the free tier appear more likely to churn (25%) than those who did not downgrade (16%).

Feature Engineering

Apart from the readily available fields, I derived several features that could plausibly affect churn. Here are the features I considered for modeling:

  • total_sessions: Cumulative count of sessions when the user signed in to the service
  • song_count: Cumulative count of songs that the user listened to
  • artist_count: Cumulative count of artists whose songs the user listened to
  • avg_session_length: Average length (in minutes) spent in a session
  • avg_session_gap: Average gap (in days) between visits
  • days_from_reg: Number of days from registration date
  • thumbsup_count: Total count of thumbs up made
  • thumbsdown_count: Total count of thumbs down made
  • addfriend_count: Number of friend requests sent
  • addplaylist_count: Number of “add to playlist” actions
  • rolladvert_count: Number of ads seen
  • pageerror_count: Number of errors the user experienced
  • nextsong_count: Number of times the user skipped to next song
  • help_count: Number of times the user clicked “help”

Will they help us come up with a useful prediction model for churn? We’ll see this in the next section.

Predictive Modeling

Churn prediction is a binary classification task. In the data, “churn” is a binary outcome that takes the value 1 if the customer has left and 0 if they are still subscribed. This is the key outcome to predict, using both readily available and engineered features.

The data was split into a training set (70%) and a test set (30%), and 3-fold cross-validation was performed on the training set to get a better sense of out-of-sample performance.

Methods and Hyperparameters Used

Three types of models were considered: Logistic Regression, Random Forest, and Gradient Boosting Trees. All three models are available in PySpark ML as LogisticRegression(), RandomForestClassifier(), and GBTClassifier() objects.

Default hyperparameters were used in fitting all three models, with the goal of choosing the best initial model to tune, if needed. Here are the key hyperparameters used in the default models:

  • Logistic Regression: maxIter=100, regParam=0.0, elasticNetParam=0, threshold=0.5
  • Random Forest: maxDepth=5, maxBins=32, numTrees=20
  • Gradient Boosting Trees: maxDepth=5, maxBins=32, maxIter=20

Evaluation Metrics Considered

Several evaluation metrics can be considered in classification problems:

  1. Accuracy: The proportion of correct predictions among all predictions made.
  2. Precision: What percentage of users predicted as churners are actually churners?
  3. Recall: What percentage of actual churners were correctly classified by the model?
  4. F1-score: Considers both precision and recall; it is the harmonic mean of the two:
F1-score = 2 * (precision * recall) / (precision + recall)
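The formula can be checked by hand on a small example. This plain-Python sketch, which has no PySpark dependency, computes precision, recall, and F1 for the churn class (label 1) from made-up labels. Note that PySpark's `MulticlassClassificationEvaluator` with `metricName="f1"` reports a class-weighted F1, which can differ from the churn-class F1 computed here.

```python
def precision_recall_f1(y_true, y_pred, positive=1):
    """Precision, recall and F1 for one class, from paired label lists."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# 4 churners; the model catches 3 of them and wrongly flags 2 non-churners.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]
p, r, f = precision_recall_f1(y_true, y_pred)
print(f"precision={p:.3f} recall={r:.3f} f1={f:.3f}")  # precision=0.600 recall=0.750 f1=0.667
```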

Accuracy is a valid metric when the outcome is balanced, i.e. not skewed toward one category. Here the data is imbalanced, with a 22% churn rate: a model that predicts “no churn” for every user would already be about 78% accurate while catching no churners at all.

One might argue that recall is the way to go, since we want the model to catch as many churners as possible to avoid potential business loss. But what if the company is conservative with its budget and wants to spend on retention only when it is confident that a member is about to leave? In that case, precision is the better metric.

Given this dilemma, I chose F1-score as it considers both precision and recall in the evaluation. It is also a better choice than accuracy given that the data is imbalanced.

The Results

These are the results I obtained after training these three models using default parameters:

F1-scores from training the 3 models

Gradient Boosting came out with the highest F1-score by far, at 0.96. This is already a strong score, and I tried to improve it further by tuning two hyperparameters: maxBins (32 and 50) and maxIter (20 and 30). The F1-score did not improve; it actually dropped slightly to 0.95.

Since the original F1-score is already good enough, I chose the original GBTClassifier model (maxDepth=5, maxBins=32, maxIter=20) as the final churn prediction model.

Feature Importance Ranking

Now that we have a good performing model to determine churn, the question is: which key variables contributed most to the predictive outcome?

Conveniently, the fitted PySpark GBT model answers this through its featureImportances attribute.

Feature Importance Scores from the GBTClassifier() model

From this, we can see that days_from_reg (the number of days since registration) emerged as the most important variable in predicting churn.

Number of Days from Registration by Churn Status

This tells us that churners are generally more recent subscribers, with a median of around 2 months into their subscription. This information can be used to pinpoint milestones in their customer subscription journey. For example, once a user reaches 2 months into their subscription, Sparkify should consider providing promos, incentives, and features that will entice them to stay.

Other valuable features were the average number of sessions per week, song count, and average session length.


This article walked through the following steps for gaining insight into customer churn:

  • Initial data inspection and cleaning
  • Data exploration and visualization to guide our initial hypotheses
  • Feature engineering to derive candidate predictors of churn (this step produced the variables that turned out to be most important in predicting churn)
  • Choosing an appropriate evaluation metric to assess predictions (in this case, F1-score was the most suitable metric)
  • Considering various methods for benchmarking purposes (we chose Gradient Boosting Trees because it had the highest F1-score)
  • Analyzing the relative importance of features in predicting churn (this pointed to actionable steps to guide possible interventions)


This analysis has shown that customer transaction and log data can reveal insights that are directly actionable for businesses. When big data technologies such as Spark and IBM Cloud are leveraged, the possibilities are limited mainly by our creativity in building features.

Further improvements could be made on the same data with these considerations:

  • More data exploration and feature engineering may help enrich insights. For example, if there is a larger dataset, location-based patterns can be utilized.
  • Apart from binary classification, approaches such as survival analysis and time series analysis could inform not just the probability of a customer churning over their lifetime, but also the likelihood of churn at a specific point in time (e.g. after a year, how likely is the customer to still be with the service?).

All code used for this analysis can be viewed on GitHub. Feel free to post your comments and questions about the project!




Noemi Ramiro
