When music is your passion…

alessandro guarnieri · Published in Analytics Vidhya · 7 min read · Aug 20, 2021
Churn prediction with Spark ML

Project Overview

We have all, at least once, listened to music through a streaming application and enjoyed being able to quickly find our favorite artist and listen to their music for as long as we wanted.

As with every application, some premium features are available only if you decide to pay a certain amount of money and become a premium member.

As long as you are happy with the features offered by the app, you keep paying to keep your premium features. After a while, however, you may decide for some reason to stop your payments or, worse, to unsubscribe altogether.

What if we could leverage machine learning to predict whether a user will unsubscribe?

In the following paragraphs, I will show how I approached this problem, a common business problem called “churn prediction”. I will describe how the available data were processed and how a machine learning pipeline, based on supervised machine learning techniques, was implemented.

Machine learning models and pipelines are a good fit for predicting user churn because they can predict what a single user will do based on that user’s previous behavior. Statistical methods and tests can give useful results, but they are not as granular and detailed as a machine learning model.

The final goal is to produce a model which can be used to classify users and predict if they will churn or not.

Data Exploration and Visualization

The dataset used in this project contains a subset of data related to users registered to Sparkify, a fictitious application for music streaming.

Data were loaded into a Spark dataframe to be analyzed. The schema, shown below, contains an ID for each user and session, together with information about the user (gender, location…) and the actions performed (visited page…), ordered by timestamp.

Dataframe schema
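The schema above comes from the loaded dataframe. As a minimal sketch of the loading step (the JSON file name is an assumption used only for illustration):

```python
from pyspark.sql import SparkSession

# Create (or reuse) a Spark session for the analysis
spark = SparkSession.builder.appName("Sparkify").getOrCreate()

# Load the event log; the file name here is illustrative
df = spark.read.json("mini_sparkify_event_data.json")

# Print the schema shown above
df.printSchema()
```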

To start the analysis, exploratory data analysis (EDA) was done to better understand the content of the dataset.

The dataset contains 278,154 rows and 225 different users. Before feature engineering, 17 features are available, excluding the userId column.
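As a hedged illustration, these figures can be reproduced with simple Spark aggregations on the dataframe loaded above:

```python
# Total number of rows in the event log
print(df.count())

# Number of distinct users in the log
print(df.select("userId").distinct().count())
```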

Here are some results from the exploration (more can be found in the related Jupyter notebook):

Details on some columns of the dataframe

The most important data cleaning done during this phase was the removal of rows with a null user ID. It was also interesting to notice that certain columns contained null values whenever the user’s action was not related to a song: for those rows, the song’s author, name, and duration were not populated.
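A minimal sketch of that cleaning step, assuming missing user IDs appear either as nulls or as empty strings (an assumption about how the log encodes them):

```python
from pyspark.sql import functions as F

# Keep only rows where the user can be identified
df_clean = df.filter(F.col("userId").isNotNull() & (F.col("userId") != ""))
```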

Data preprocessing and implementation

Label definition

The goal of our analysis is to predict whether a user will churn, so we added a label column that defines whether the user decided to cancel the subscription or not.

pySpark code to define label column “Churn”
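The original snippet is in the screenshot above; a possible reconstruction of the idea, assuming a user counts as churned once a “Cancellation Confirmation” page event appears anywhere in their history, could be:

```python
from pyspark.sql import functions as F
from pyspark.sql import Window

# 1 for the single cancellation event, 0 otherwise
churn_event = F.when(F.col("page") == "Cancellation Confirmation", 1).otherwise(0)

# Propagate the flag to all rows of the same user: the label is 1 if the user ever churned
user_window = Window.partitionBy("userId")
df_labeled = (df_clean
              .withColumn("churn_event", churn_event)
              .withColumn("Churn", F.max("churn_event").over(user_window)))
```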

The column “Churn” will be used as the target column for our model so that we can use it together with the feature columns to train and then test the machine learning models.

Feature Engineering

The features available in the dataset need to be manipulated to be used in a Machine Learning model.

Categorical features

“gender” was the only categorical feature used. It was necessary to convert it from categorical to numerical, using the one-hot encoding technique.
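A hedged sketch of that conversion with Spark ML’s StringIndexer and OneHotEncoder (the output column names are illustrative); both stages can later be added to the pipeline:

```python
from pyspark.ml.feature import StringIndexer, OneHotEncoder

# Map the gender strings to numeric indices
gender_indexer = StringIndexer(inputCol="gender", outputCol="gender_index")

# One-hot encode the index into a (sparse) vector column
gender_encoder = OneHotEncoder(inputCols=["gender_index"], outputCols=["gender_vec"])
```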

Numerical features

The “page” column was used to build some of the numerical features, creating a count of the page actions for each user. Besides this, some other numerical features were added, such as the number of sessions per user, the number of listened songs per user, and the average number of listened songs per user per session.
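As a hedged illustration of how these aggregates can be built (column and page names follow the dataset above; the exclusion of the cancellation pages anticipates the note on data leakage below):

```python
from pyspark.sql import functions as F

# One column per page type with the number of visits per user,
# excluding the pages that would leak the label
page_counts = (df_labeled
               .filter(~F.col("page").isin("Cancel", "Cancellation Confirmation"))
               .groupBy("userId")
               .pivot("page")
               .count()
               .fillna(0))

# Per-user sessions, listened songs, and average songs per session
user_stats = (df_labeled
              .groupBy("userId")
              .agg(F.countDistinct("sessionId").alias("n_sessions"),
                   F.count(F.when(F.col("page") == "NextSong", True)).alias("n_songs"))
              .withColumn("avg_songs_per_session", F.col("n_songs") / F.col("n_sessions")))
```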

Note on data leakage

It was important to take data leakage into account when performing feature engineering: for example, it was necessary not to include the “Cancellation Confirmation” or “Cancel” page events in the features, since they directly reveal the label.

Model Evaluation and Validation

The ML models used in this project were Logistic Regression and Random Forest. This choice was made taking into account the nature of the problem, which is a classification problem, and the fact that the classes are imbalanced.

The training was performed on 80% of the data, while the remaining 20% was dedicated to validation.
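A minimal sketch of that split, assuming features_df is the per-user table obtained by joining the engineered features with the “Churn” label (the name and the seed are arbitrary):

```python
# 80% for training, 20% for validation
train, test = features_df.randomSplit([0.8, 0.2], seed=42)
```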

Hyper-parameters tuning

Cross-Validation was performed on the training set so that we could tune the hyperparameters of the models and choose the best ones.

Definition of Cross-Validation, after defining pipeline, hyperparameters and evaluator
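The exact definition is in the screenshot above; a hedged reconstruction for the Random Forest case could look like this (feature_cols and the grid values are illustrative assumptions):

```python
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Pipeline: assemble the engineered columns into a feature vector, then classify
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
rf = RandomForestClassifier(labelCol="Churn", featuresCol="features")
pipeline = Pipeline(stages=[assembler, rf])

# Hyperparameter grid to explore
param_grid = (ParamGridBuilder()
              .addGrid(rf.numTrees, [20, 50])
              .addGrid(rf.maxDepth, [5, 10, 15])
              .build())

# F1-score as the evaluation metric (see the Metrics section below)
evaluator = MulticlassClassificationEvaluator(labelCol="Churn", metricName="f1")

cv = CrossValidator(estimator=pipeline,
                    estimatorParamMaps=param_grid,
                    evaluator=evaluator,
                    numFolds=3)
cv_model = cv.fit(train)
```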

Here are the best parameter values after hyperparameter tuning:

Logistic Regression:

  • regParam: 0.1
  • elasticNetParam: 0.0

Random Forest:

  • numTrees: 50
  • maxDepth: 15

Metrics

The metrics used during cross-validation were the F1-score and the area under the precision-recall curve (areaUnderPR), because the classes were imbalanced and predicting the positive class well was more important.
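Both metrics are available through Spark ML evaluators; a short hedged sketch, reusing the names from the cross-validation sketch above:

```python
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

# Weighted F1-score over both classes
f1_eval = MulticlassClassificationEvaluator(labelCol="Churn", metricName="f1")

# Area under the precision-recall curve, more sensitive to the positive (churn) class
pr_eval = BinaryClassificationEvaluator(labelCol="Churn", metricName="areaUnderPR")

predictions = cv_model.transform(test)
print(f1_eval.evaluate(predictions), pr_eval.evaluate(predictions))
```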

On the best model, other metrics were also computed for the positive class, as shown below:

Metrics on the test dataset, after training (evaluation metric: f1-score)

The trained Random Forest model was also used to predict the labels of the training set; as shown below, the metric values on the training set are higher, so the model may be overfitting:

Side notes on evaluation

Parameters responsible for overfitting

It is interesting to notice that, out of the two parameters tuned for the Random Forest model, the one most affecting overfitting was maxDepth. It was very important to tune it correctly, and for this it was necessary to perform several experiments.

F1-score in pySpark: challenges

The F1-score metric was used as an evaluation metric and for fine-tuning of the model. The Spark evaluator MulticlassClassificationEvaluator computes a weighted average of the per-class F1-scores, so we obtained an F1-score of 0.58 for Logistic Regression and 0.66 for Random Forest.

Despite this, if we look at the F1-score for the positive class (shown above), we see that the F1-score of the Logistic Regression is 0. This affects both the cross-validation process and our perception of the Logistic Regression model, which could seem better than it really is.
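One hedged way to obtain the positive-class F1-score is to go through the RDD-based MulticlassMetrics class, whose fMeasure method accepts a label:

```python
from pyspark.sql import functions as F
from pyspark.mllib.evaluation import MulticlassMetrics

# Build (prediction, label) pairs from the prediction dataframe
pred_and_labels = (predictions
                   .select(F.col("prediction"), F.col("Churn").cast("double"))
                   .rdd.map(tuple))

metrics = MulticlassMetrics(pred_and_labels)
print(metrics.fMeasure(1.0))  # F1-score of the positive (churn) class
```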

Validation

Using collectSubModels during cross-validation, it was possible to keep the models of each fold and then evaluate them as well, to test the robustness of the best model. The selected model appears not to be robust and a bit too dependent on the chosen dataset. This issue could be due to class imbalance and could also be mitigated by using more data.
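Enabling this is a small change to the CrossValidator defined earlier; a hedged sketch, reusing the same names:

```python
from pyspark.ml.tuning import CrossValidator

# Keep the fitted model of every fold and every parameter combination
cv = CrossValidator(estimator=pipeline,
                    estimatorParamMaps=param_grid,
                    evaluator=evaluator,
                    numFolds=3,
                    collectSubModels=True)
cv_model = cv.fit(train)

# subModels[fold][param_map] holds the model of each fold, ready for extra evaluation
sub_models = cv_model.subModels
```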

Final reflections and improvements

The model chosen to solve our churn prediction problem is the Random Forest, because, as shown above, it outperformed Logistic Regression on the evaluation metrics used. After hyperparameter tuning, the model scored an F1-score of 0.66 overall and of 0.28 on the positive class. Out of all the features used, the most important ones for predicting churn were “ThumbsDown” and “ThumbsUp”.

The most important part of the whole process was the hyperparameter tuning: it was important to understand the meaning of the metrics used in relation to the business problem to solve, and to check not only the overall F1-score but also the one for the positive class.

Are we finally ready to help Sparkify predict whether their users will churn or not?

Well, a big step has been taken, but there are many other steps that could improve both the preprocessing and the model tuning.

For example, it could be useful to run this code on a distributed cluster to leverage the power of Spark and see if better performance could be achieved.

It would also be interesting to correct the class imbalance or to add additional features, for example, a categorical feature such as the last status available for each user (free or paid) or other numerical features.

Finally, it could also be useful to use different models or create an ensemble of models.

You can find the complete Jupyter Notebook here
