How to predict churn of Sparkify users

This article describes the method and some of the results of my analysis of user actions on a fictitious music streaming service called Sparkify. The objective was to identify and model user churn so as to be able to predict it. We use the Spark library to manage the whole analysis inside a Jupyter notebook.
This data science project is the capstone project of the Udacity Data Scientist nanodegree, and the article is based on the full results available in a GitHub repository at https://github.com/fkstef/sparkify
Project Overview
The Sparkify music service wants to identify the population of users who churn (they can come from the free service or the paid service). Once the possible criteria for churn are better understood, our objective is to build a model that predicts user churn, so that specific attention can be given to at-risk users in an attempt to retain them on the platform.
We will be using a tiny subset (128MB) of the full available dataset (12GB). The provided dataset contains logs of user actions: the pages seen by users, along with other encoded properties such as location, time, operating system, and subscription level.
We will be using Spark, with the pyspark library (version 2.4.3) to handle the dataset. Even though the tiny subset of data fits into the memory of a single computer, the usage of Spark allows the same code to scale to the full dataset and allows the execution of the analysis on a cluster of machines.
Problem Statement
The Sparkify service would like to understand the reasons behind user churn and to be able to predict it in advance. This would allow the service to implement actions to try to prevent the churn.
The following methodology will be used to come up with a solution to that problem:
- Data loading and cleanup: to ensure the rest of the data analysis is based on accurate and clean data
- Data exploration and understanding: before the creation of a model we need to understand the data we are observing, so the exploration will be used to get familiar with the data
- Features engineering: as a preparation to the model, we need to identify and build a features dataset that can serve as input to the model
- Modeling: here the real fun takes place, where we will get to try different models on the dataset and assess their ability to provide a good prediction
- Conclusion: once the most robust model has been identified and selected, we discuss the possible problems of the approach we used and identify future enhancements
Data Cleaning
The first step is to load the data with Spark. The printSchema function conveniently prints a summary of the data structure:
root
|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

The dataset has 286,500 rows and there are no null values. We notice, however, that 8,346 records have an empty string for their userId value, as identified in the first row of the following table:
+------+-----+
|userId|count|
+------+-----+
| | 8346|
| 10| 795|
| 100| 3214|
|100001| 187|
|100002| 218|
|100003| 78|
+------+-----+

The rest of the data does not need further processing. Since the dataset consists of logs automatically collected by the Sparkify platform, the available columns are consistently filled in. Machine-generated data such as this is usually easier to clean than data collected from user surveys or manual input.
Data Exploration
In the next step, we explore the data to get a better understanding of the possible values in the different columns.
For example, possible pages visited by identified users are:
+--------------------+
| page|
+--------------------+
| About|
| Add Friend|
| Add to Playlist|
| Cancel|
|Cancellation Conf...|
| Downgrade|
| Error|
| Help|
| Home|
| Logout|
| NextSong|
| Roll Advert|
| Save Settings|
| Settings|
| Submit Downgrade|
| Submit Upgrade|
| Thumbs Down|
| Thumbs Up|
| Upgrade|
+--------------------+

We also identify the platform on which each user listens to music, by decomposing the userAgent field and grouping the platforms into:
+--------+
|platform|
+--------+
| Linux|
| Apple|
| Windows|
+--------+

Define Churn
To complement the exploration, and later to facilitate the modeling, we add a new column to the dataset called churn, with a value of 1 for users who churned and 0 otherwise. We base this label on visits to the page named Cancellation Confirmation.
We can now continue with the data exploration, presenting charts that compare the churn users with the rest of the population.
The first useful piece of information is the proportion of churn users, which represents a little less than a quarter of the population, as we see on the following chart:

Then we look at the gender information and compare the proportion of males/females in the churn population vs. the rest of the population. We can see that the churn population contains a higher proportion of men than the non-churn population:

Users can be subscribed to different levels of the service at the time of churn, as there is a free service (with rolling ads) and a paid service. Since a user can upgrade or downgrade while using the service, the following chart shows their last level of service right at the time of churn. We can see that free users are present in a higher proportion in the churn population:

Since we encoded the platform information in the early steps of the data exploration, we can also use it to compare the platforms used by users at the time of churn. The chart that follows indicates that the ratio of Linux users is higher in the churn population, while the ratio of Apple and Windows users is roughly the same, or slightly lower, than in the non-churn population:

One of the key pieces of information in the dataset is, of course, the actions of users during their navigation of the Sparkify service. From this we can extract the number of pages of each type visited by users, and the following chart plots the relative ratio of pages seen by the churn group compared to the non-churn group. For example, the Roll Advert page was seen 69% more often by users who churned than by those who didn’t:

In the chart above:
- a value above 0 is the ratio of pages that are seen more frequently by the users in the churn group
- a value below 0 is the ratio of pages that are seen less frequently in the churn group.
Feature Engineering
Now that we have a good understanding of the information contained in the dataset, we can build out the features to train a model on.
I have selected the following features that I believe are good indicators of differences in the properties or behaviour of users who churn:
- gender (M/F)
- level (free/paid)
- number of different artists per user
- platform (Apple/Windows/Linux)
- number of songs per hour on average
- number of Thumbs Down
- number of Thumbs Up
- number of Roll Advert
- number of Add Friend
- number of Add to Playlist
- number of Upgrade

So in the notebook, we create a new feature column for each of these selected features. Each feature column encodes the feature as a numeric value for each userId in the dataset.
For example, we build a feature column for the Thumbs Up action:
+------+---------+
|userId|thumbs up|
+------+---------+
|100010| 17|
|200002| 21|
| 51| 100|
| 124| 171|
| 7| 7|
+------+---------+
only showing top 5 rows

We build another one for the Roll Advert action:
+------+------+
|userId|advert|
+------+------+
|100010| 52|
|200002| 7|
| 125| 1|
| 7| 16|
| 124| 4|
+------+------+
only showing top 5 rows

Once we have our feature columns, we build the features dataframe by joining all the individual columns together (using userId as the key) so that it can be used in the modeling phase. We also replace null values with 0, as nulls are introduced by the fact that not every user performed every action (sparse data).
Here are the first few rows of the resulting dataframe:

With this dataframe, we calculate its correlation matrix to check that the data is diverse enough for the model to separate the contributions of the different features in the classification. If all columns were correlated, it would be more difficult to distinguish the most important features from the others.

Interestingly, the churn label is not strongly correlated with any other column. Hence there is a real need for a machine learning model to help us predict this value, since no obvious correlation stands out here.
Modeling
The data is now ready to be used with different Machine Learning models to try to predict the churn. We will test out a few classification methods and evaluate their accuracy using the F1 score as the metric to optimize.
This metric is best suited to our situation since there will always be a small percentage of churn users in the data. If we optimised for accuracy, we could for example design a classifier that always returns 0 (meaning “not churn”) and still end up with high accuracy, but it would be useless for us since it never predicts churn.
The models we will be evaluating are:
- Random Forest Classifier
- Logistic Regression
- Gradient-Boosted Trees
- Multilayer Perceptron Classifier

We divide the data into 3 sets: a training set (60%), a test set (20%), and a validation set (20%). All model training and evaluation is performed on the training and test sets; the validation set is used only as a final step to calculate the accuracy of the selected best model.
Every model is constructed using the Pipeline structure available in Spark. The stages of the pipeline are:
- a vector assembler: takes the individual feature columns and builds a single vector out of them
- a scaler: we use normalization to ensure all features get a value between 0 and 1
- a model: the stage of the pipeline that is swapped out for each classifier and whose parameters we tune
1. Random Forest Classifier
This model was created with the pyspark default parameters, and the resulting evaluation is:
f1 score: 0.786406870440484
accuracy: 0.8163265306122449

We can get a view of the feature importances from this kind of model. The chart below shows that “Thumbs Down” has the highest importance in the churn of users.

2. Logistic Regression
This model was also created with the pyspark default parameters, and the resulting evaluation is:
f1 score: 0.7024205030849549
accuracy: 0.7551020408163265

We can get a view of the coefficients of the logistic regression model. The chart below shows that the platform seems to have the most influence in the logistic regression model:

3. Gradient-Boosted Trees classifier
Similar to the above, this model was created with the pyspark default parameters, and the resulting evaluation is:
f1 score: 0.7285284484164036
accuracy: 0.7346938775510204

As with the Random Forest Classifier, we can also get a view of the feature importances here. The chart below shows that “Thumbs Down” has the highest importance in the churn of users for this model as well, with “Add Friend” a close second.

4. Multilayer Perceptron Classifier
This is the last model we evaluate, and we create the classifier with the following layers as a starting architecture: [13, 25, 25, 10, 2]. The input layer has to be 13, as this is the number of features in our dataset, and the output layer has to be 2, as we are predicting 2 possible values (churn / not-churn).
The resulting evaluation is:
f1 score: 0.8454388984509467
accuracy: 0.8571428571428571

Parameters Tuning
The Random Forest Classifier and the Multilayer Perceptron Classifier provided the highest F1 scores as well as the highest accuracy. We therefore tune a few parameters of both models with a ParamGrid and measure their accuracy on the validation set.
For the Random Forest Classifier, I decided to tune the impurity, maxDepth, and numTrees parameters. The resulting evaluation on the test set after tuning is:
f1 score: 0.6823323990981037
accuracy: 0.8163265306122449

For the Multilayer Perceptron Classifier, I decided to tune only the layers parameter, to evaluate different network architectures. The resulting evaluation on the test set after tuning is:
f1 score: 0.7096096198263381
accuracy: 0.7755102040816326

The best model after tuning is the Multilayer Perceptron Classifier, which obtained the higher F1 score.
Note that the evaluation results after tuning can be lower than those of our initial models. This is because the values shown after tuning are an average over the n folds of the data (3 folds in our case), so a lower value is possible, and should also be more representative of the actual performance of the model.
We are now evaluating this best model accuracy on the validation dataset:
Multilayer Perceptron Classifier
===============================
accuracy: 0.7924528301886793
Results
We have now trained 4 different models on the training set, selected the 2 best for parameter tuning, and evaluated the resulting best model on the validation set. A visualisation of these results is presented here:

The Logistic Regression algorithm had the lowest score of the selected models, since it relies on a linear relationship between the features and the label. However, user behaviour, and how the features relate to churn, are not necessarily linear, and the “area” of feature space that leads to churn is not necessarily a single region.
This is the reason why I introduced other classifiers such as random forests and gradient-boosted trees, which base their decisions on rules like “feature x < value” and “feature y > value” and “feature z < value”, and can therefore isolate more precise “areas” of a user's activity to predict churn. We could see this approach gives the model a higher score.
Finally, the multilayer perceptron classifier gives even more flexibility through the structure of the network implemented in its different layers, enabling it to identify relations that a tree may not be able to describe. In the end, we managed to train this model to perform the best.
Conclusion
In this project, I wanted to load and explore the Sparkify user log dataset, identify and engineer possible features that could indicate the user churn, and build a model to be able to predict it.
I identified there was not a lot of cleaning required, and focused on removing empty values for the userId field. The data exploration allowed me to identify some differences in the data when separating the population between users who churn and the others. This allowed me to define a set of interesting features to describe each user.
The definition of features was not easy since, in many situations, there is no obvious influencer of user behaviour (with regard to churn), and the difference between the two groups is not always clear. I therefore consider the selection and engineering of features an “art” that a data scientist needs to grow and nurture. I used the identified features to design a small data transformation process that builds the feature dataframe from the base user log data.
From the pyspark classification algorithms, I identified a list I wanted to explore further in the modeling phase, which I tuned and evaluated on the F1 metric. The complexity of this task is that there are many algorithms to choose from, and selecting a model requires a sense of matching the objective with the capabilities of each algorithm, so as to maximise the chances of success. Parameter tuning also benefits from the experience of the data scientist to make choices that are likely to improve the model.
After running the best model on the validation data, we achieved an accuracy of 79.2% with the Multilayer Perceptron Classifier. This result is quite good, given the small size of our dataset and the relatively small portion of churn users in it.
Future improvements
This data analysis would benefit from being run on the full dataset, and this should be a future enhancement for this work. Getting an AWS cluster up and running is not a very complex task, and the whole notebook should be executed there.
Future work could include the engineering of more or different features which would be less correlated together to see if this could further improve the model accuracy.
It would also be valuable to run a wider variety of parameter tuning for the models, to ensure multiple combinations are explored and that the finally selected model has been well optimised.
Learnings
I learnt many things in this project, especially how to use pyspark to handle a large dataset and distribute the data and computations to a cluster of machines.
I liked the functional programming approach that is necessary when running distributed computations (to make parallelisation of work easier, and to run calculations on chunks of data).
Pyspark provides a number of models, and the fact that it follows a logic similar to sklearn's helps you get started more easily; however, I find the structure of the pyspark documentation less convenient to work with than that of sklearn. It was not easy to find the relevant information in the very long pages of the Spark API documentation. I am nevertheless likely to use it again, since data science problems often exceed the memory capacity of a single computer.
You can find the details of the analysis on github: https://github.com/fkstef/sparkify
You can also visualise the workbook with all results here: https://nbviewer.jupyter.org/github/fkstef/sparkify/blob/master/Sparkify.ipynb
