Sparkify: Using PySpark to Predict Customer Churn

Vagner Zeizer C. Paes
Published in Geek Culture · 10 min read · May 2, 2021

Introduction

In this story, I will present a data science study based on a fictitious company named Sparkify (think Spotify) whose goal is to predict customer churn on a dataset that mimics real-world music streaming data. Customer churn happens when a customer decides to cancel or downgrade their subscription to a given plan. This project is my personal choice for the capstone project of Udacity's Data Scientist Nanodegree.

Goals of the Project

The main goal of this project is to create a machine learning model that predicts which customers will (likely) churn at a given time. If the model can predict with high reliability which customers are likely to churn, then applying it over millions of customers should save the (fictitious) Sparkify millions of dollars: the company could contact these customers and offer them a discount, potentially preventing the predicted churn from taking place. Exploratory data analysis, feature engineering, and machine learning modeling will come in handy to find a good approach to this problem in PySpark. The metric chosen to validate the machine learning model on the test data is the f1-score, since it balances precision and recall (it is their harmonic mean): f1 = 2 · precision · recall / (precision + recall).

Understanding the data

In order to work with this dataset on the Udacity clusters, a mini version consisting of 286,500 rows and 18 columns was made available. The dataset showed no duplicates; however, it contains many NaNs, which were properly handled, as will be shown here. The full code can be downloaded from the GitHub repository here.

Let us start by printing the schema of the DataFrame (the mini dataset that was made available):
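A minimal sketch of loading the data and printing the schema, assuming the mini dataset file is named mini_sparkify_event_data.json:

# Load the mini dataset into a Spark DataFrame and inspect its schema.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Sparkify").getOrCreate()
df = spark.read.json("mini_sparkify_event_data.json")  # file name is an assumption
df.printSchema()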

|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

The columns of the DataFrame are self-explanatory, and the most important ones for this project will be studied in detail below. It is worth emphasizing that, since the DataFrame contains a timestamp ('ts') as one of its columns, the data constitute a time series.

Searching for nulls and NaNs in the dataset, we find:

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId| song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
| 58392| 0| 8346| 8346| 0| 8346| 58392| 0| 8346| 0| 0| 8346| 0|58392| 0| 0| 8346| 0|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
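For reference, a count like the one above could be produced with a snippet along these lines (a sketch, not necessarily the exact code used in the project):

# Count nulls for every column, plus NaNs for the floating-point columns
# (isnan only applies to double/float columns in Spark).
from pyspark.sql.functions import col, count, isnan, when
from pyspark.sql.types import DoubleType, FloatType

float_cols = {f.name for f in df.schema.fields
              if isinstance(f.dataType, (DoubleType, FloatType))}

df.select([
    count(when(col(c).isNull() | isnan(c), c)).alias(c) if c in float_cols
    else count(when(col(c).isNull(), c)).alias(c)
    for c in df.columns
]).show()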

This is quite a lot of missing data relative to the overall size of the dataset, so simply dropping the missing rows is not a good idea. Instead, we approached the problem by inserting the placeholder "missing_val" into the NaNs of the categorical columns and the respective column's average into the numerical columns. The cleaned dataset had no NaNs left.
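A minimal sketch of this imputation step, assuming the DataFrame loaded above is named df:

# Fill missing values: "missing_val" for string columns, the column mean for numeric columns.
from pyspark.sql.functions import mean
from pyspark.sql.types import StringType

string_cols = [f.name for f in df.schema.fields if isinstance(f.dataType, StringType)]
numeric_cols = [f.name for f in df.schema.fields if f.name not in string_cols]

df_clean = df.fillna("missing_val", subset=string_cols)

# Compute each numerical column's average and use it to fill the remaining nulls.
means = df_clean.select([mean(c).alias(c) for c in numeric_cols]).first().asDict()
df_clean = df_clean.fillna(means)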

The next step was to create the churn column, which can be derived from the page column: whenever the page is "Submit Downgrade", "Downgrade", "Cancellation Confirmation", or "Cancel", churn has taken place. This led to an imbalanced 'churn' column in the dataset.
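A sketch of how this churn flag can be derived with PySpark's when/otherwise:

# Flag churn events based on the 'page' column.
from pyspark.sql.functions import col, when

churn_pages = ["Submit Downgrade", "Downgrade", "Cancellation Confirmation", "Cancel"]
df_clean = df_clean.withColumn(
    "churn", when(col("page").isin(churn_pages), 1).otherwise(0)
)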

Now we are ready to move on to the next stage of our story: the exploratory data analysis, which will be done in Pandas!
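Since the mini dataset is small enough to fit in driver memory, the cleaned Spark DataFrame can be brought into Pandas for plotting; a minimal sketch (variable names are assumptions):

# Convert to Pandas for the exploratory plots and inspect the churn counts.
df_pd = df_clean.toPandas()
df_pd["churn"].value_counts()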

Exploratory Data Analysis

In the Exploratory Data Analysis shown below, I will investigate customers who churned, getting insights from this data.

  1. Churn imbalance:

Figure 1 shows the total number of no-churns (zeros) and churns (ones) in the dataset:

0    284278
1      2222
Figure 1. Imbalance of the churn values: zeros represent customers who did not churn and ones represent customers who churned.

The churn imbalance is shown in the above histogram, and it can be seen that the dataset is very imbalanced. A priori, this should make the prediction somewhat harder.

2. Churn length:

The histogram below (Figure 2) shows the length of the songs played by customers who churned. Interestingly, they all have the same length! This is related to the fact that the artist, song, and length entries of the churn events were all NaNs in the original mini dataset, so the length values were imputed with the column average.

Figure 2. Histogram of song length for customers who churned.

3. Churn by gender:

Figure 3 below shows the gender histogram for customers who churned.

Figure 3. Histogram of churn counts by gender.

As can be seen, more female customers churned than male customers. It seems appropriate for Sparkify to offer a discount to female customers in order to diminish this pronounced "female churn".

4. Churn by location:

Figure 4 below shows the total number of churns by location.

Figure 4. Top 10 locations where churn took place.

From this figure, it can be seen that Los Angeles, New York, and Boston have the highest churn counts. Therefore, it is appropriate to offer a discount on Sparkify's plans in order to minimize customer churn in these places (ideally, the discount would be applied in the top 10 locations shown in the figure).

5. Churn userAgent:

Figure 5 below shows the total number of customer churn by userAgent.

Figure 5. Total number of customers who churned, by userAgent.

As previously discussed, we can focus on the top 10 userAgents and try to minimize customer churn there.

6. Churn itemInSession:

Figure 6 below shows the total number of customer churns by itemInSession.

Figure 6. Top 20 itemInSession values at which customer churn takes place.

From the above figure, the same reasoning as for the previously mentioned features holds: it is appropriate to focus on these top 20 itemInSession values in order to try to minimize customer churn.

7. Temporal Exploratory Data Analysis

Now we come to the temporal analysis of this dataset.

Figure 7 shows the total number of customer churns by day in the given period.

Figure 7. Total churns by day.

This time series is interesting: it has an almost well-defined pattern, with the number of total churns reaching maxima and minima almost periodically.

Feature Engineering

In this section, I will summarize the feature engineering that was done in order to build the machine learning models through pipelines.

Columns that we found not to be predictive (they were dropped):

to_drop_cols = ['artist', 'auth', 'firstName', 'lastName', 'page', 'registration', 'sessionId', 'song', 'hour']

since 'artist' and 'song' were just NaNs for the churn events; 'firstName', 'lastName', 'registration', 'sessionId', and 'hour' (derived from the timestamp) do not carry important information about the customer; and 'page' had already been used to derive the churn label.

The columns ['gender', 'level', 'method'] were encoded using StringIndexer and OneHotEncoder, while the columns ['location', 'userAgent'] were only encoded with StringIndexer, since they contain too many categorical values. A VectorAssembler then combined the engineered features into a single feature vector.
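A sketch of this encoding pipeline is shown below; the column lists follow the text, while the stage wiring and parameter choices are assumptions (OneHotEncoder with inputCols/outputCols requires Spark 3.x):

# Index and one-hot encode the low-cardinality categoricals, index-only the high-cardinality ones,
# and assemble everything (plus the remaining numeric columns) into a single feature vector.
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

ohe_cols = ["gender", "level", "method"]
index_only_cols = ["location", "userAgent"]
numeric_cols = ["itemInSession", "length", "status", "ts"]  # numeric columns kept after the drop

indexers = [StringIndexer(inputCol=c, outputCol=c + "_idx", handleInvalid="keep")
            for c in ohe_cols + index_only_cols]
encoder = OneHotEncoder(inputCols=[c + "_idx" for c in ohe_cols],
                        outputCols=[c + "_ohe" for c in ohe_cols])
assembler = VectorAssembler(
    inputCols=[c + "_ohe" for c in ohe_cols] + [c + "_idx" for c in index_only_cols] + numeric_cols,
    outputCol="features")

feature_pipeline = Pipeline(stages=indexers + [encoder, assembler])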

Machine Learning Model

In order to avoid data leakage in this time series, the mini dataset was first ordered chronologically and then split into training (80%) and test (the remaining 20%) DataFrames. Several machine learning models, along with parameter tuning, were tested in order to find the most appropriate one.
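One possible way to implement such a chronological 80/20 split is sketched below (the exact mechanism is not described in the text, so this is an assumption):

# Rank rows by timestamp and keep the earliest 80% for training, the latest 20% for testing.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

ranked = df_clean.withColumn("time_rank", F.percent_rank().over(Window.orderBy("ts")))
train_df = ranked.filter(F.col("time_rank") <= 0.8).drop("time_rank")
test_df = ranked.filter(F.col("time_rank") > 0.8).drop("time_rank")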

For all machine learning models, we set a reasonable number of cross-validation folds, namely 4. The metric to be maximized (the closer to 1.0, the better) was chosen to be the f1-score, as previously discussed. For some models, the parameter grid was evaluated exhaustively in order to improve the results, while for other, more convenient models we reached a good model straightforwardly without an exhaustive grid search, albeit at a higher computational cost. The tuning setup for one of the models is sketched below.
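As an illustration, the cross-validated grid search for the Random Forest might look like the sketch below; the grid values mirror the ranges reported in the model list, while the variable names (feature_pipeline, train_df, test_df) refer to the earlier sketches and are assumptions:

# 4-fold cross-validation over a small Random Forest grid, scored with the f1 metric.
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# maxBins must be at least the number of categories of the largest indexed column;
# 512 is an assumed safe value for the mini dataset.
rf = RandomForestClassifier(labelCol="churn", featuresCol="features", maxBins=512)
param_grid = (ParamGridBuilder()
              .addGrid(rf.numTrees, [50, 100])
              .addGrid(rf.maxDepth, [10, 15, 20])
              .build())
evaluator = MulticlassClassificationEvaluator(labelCol="churn", metricName="f1")

cv = CrossValidator(estimator=Pipeline(stages=[feature_pipeline, rf]),
                    estimatorParamMaps=param_grid,
                    evaluator=evaluator,
                    numFolds=4)
cv_model = cv.fit(train_df)
predictions = cv_model.transform(test_df)
print("f1 on test data:", evaluator.evaluate(predictions))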

In summary, we have tested the following models:

  1. Logistic Regression (LR): regularization parameter of 0.05 and 0.15, elastic net parameter of 0.25 and 0.75, and maximum iterations of 5 and 10; the best parameters were found to be regParam: 0.05, elasticNetParam: 0.25, maxIter: 5;
  2. Linear Regression (LReg): regularization parameter from 0.0 to 0.8 and elastic net parameter from 0.0 to 0.75; the best parameters are regParam: 0.0, elasticNetParam: 0.0;
  3. Random Forest (RF): number of trees from 50 to 100 and max depth from 10 to 20; the best parameters are numTrees: 50, maxDepth: 15;
  4. OneVsRest Classifier (OVR), a machine learning reduction that performs multiclass classification given a base classifier that can perform binary classification efficiently: regularization parameter varying from 0.05 to 0.15, elastic net parameter with values of 0.25 and 0.75, and maximum iterations of 5 or 10;
  5. Support Vector Machine (SVM): run with regularization parameters varying from 0.05 to 0.4; the best parameter was regParam: 0.1;
  6. Gradient Boosted Tree Classifier (GBT): run with default parameters, which were therefore also the best.

The table below summarizes our best findings for each model on the test data:

Model   f1-score
LR      0.98829
LReg    0.00000
RF      0.98829
OVR     0.98829
SVM     0.98829
GBT     0.98799

Random Forest was chosen as the most appropriate machine learning model for this problem. Other models with a similar f1-score on the test data could have been chosen, but we preferred Random Forest because it is built on decision trees and trains relatively fast. Figure 8 shows the top 10 most important features obtained from this classifier:

Figure 8. Top 10 feature importances for the Random Forest classifier.

The most important features correspond to the columns 'userAgent', 'itemInSession', and 'ts' (time).
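For reference, the importances behind Figure 8 could be extracted from the fitted model and mapped back to column names roughly as follows (a sketch building on the earlier code, not the project's exact code):

# Pull the fitted Random Forest out of the best pipeline and rank the feature-vector slots
# by importance, using the metadata attached by the VectorAssembler to recover slot names.
rf_model = cv_model.bestModel.stages[-1]          # RandomForestClassificationModel
importances = rf_model.featureImportances

attrs = predictions.schema["features"].metadata["ml_attr"]["attrs"]
slot_names = sorted((a["idx"], a["name"]) for group in attrs.values() for a in group)

top10 = sorted(((importances[idx], name) for idx, name in slot_names), reverse=True)[:10]
for score, name in top10:
    print(f"{name}: {score:.4f}")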

Model Evaluation and Validation

The best model, the Random Forest, achieves quite a good result on the training data (0.98796) and the test data (0.98829), with the best number of trees equal to 50 and maximum depth equal to 15. More trees or a larger maximum depth do not make the algorithm perform better. These values are quite good for a Random Forest classifier, and the model trains relatively fast. Cross-validations with K=2 or K=3 were also run, and similar results were found. Additionally, the top 3 features of this model make a lot of sense as drivers of customer churn, as previously discussed. Moreover, as will be discussed in the Justification section below, Random Forest is suitable for large datasets.

Justification

Random Forest is my current choice as the best machine learning model for this problem because it is easy to understand, being built on decision trees, and relatively fast to train. Although more complex than a single decision tree, Random Forest is not prone to overfitting on the test data. Moreover, the power and scalability of Random Forest on massive data make it well suited to run on clusters such as Amazon Web Services or IBM Cloud.

Discussion

We have successfully trained several models and evaluated them on the test data, reaching a very good f1-score. This might be due to the small size of the data; running the full 12 GB dataset should be a better approach and would yield a more realistic value for the f1-score. Nevertheless, our model does not overfit, and many interesting conclusions could be drawn from these analyses, as shown here. Linear Regression, the simplest model, failed miserably in training and prediction, while the other models reached similar results.

Conclusion

To summarize, we have studied in detail a dataset through PySpark and Pandas that mimics a real music streaming service, named Sparkify.

Our EDA results have pointed out that only a small fraction of the customers churn and that more females churn than males. Location, itemInSession, and userAgent were also investigated, and customer churn can be minimized by properly handling these features.

The machine learning results have shown that Random Forest can be considered the most appropriate model for this problem. It has a great f1-score on the test data; therefore, applied over millions of users, it would probably save Sparkify millions of dollars when trying to minimize customer churn. The most predictive columns of this model are 'userAgent', 'itemInSession', and 'ts' (time). This is quite interesting, since the first two features were explored in the Exploratory Data Analysis section and play a fundamental role in customer churn. So, working on these features is really important and deserves strong attention from the (fictitious) enterprise Sparkify.

Further Improvements

As a further improvement to the Exploratory Data Analysis, one could make a more in-depth analysis of the time series. For instance, one could analyze the total number of customer churns by minute, in the same way as was shown here by day.

The feature engineering could also have been done differently (based on this notebook from Kaggle); a minimal sketch of a couple of these ideas follows the list:

  1. Use scaling and normalization: StandardScaler, MinMaxScaler, and many others.
  2. Use bucketing to convert the continuous variables into categorical ones. After that, the problem can be tackled with just StringIndexer and OneHotEncoder.
  3. Use PCA to reduce the dimensionality of the features.
  4. Perform feature selection with the chi-square selector.
  5. Models like the Multilayer Perceptron, XGBoost, and CatBoost classifiers can also be tested. However, in order to train XGBoost and CatBoost you will have to use PySpark wrappers.
  6. Stacking: once supervised learning has been performed with several models, stacking could be used to obtain a better model from their predictions. Stacking uses the predictions of a few basic classifiers as the first level (base) and then uses another model at the second level to predict the output from the first-level predictions.
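As a small illustration of ideas 1 and 3, scaling and PCA stages could be appended to the feature pipeline roughly as follows (the value k=10 and the column names are assumptions):

# Standardize the assembled feature vector and project it onto 10 principal components.
from pyspark.ml.feature import StandardScaler, PCA

scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
pca = PCA(k=10, inputCol="scaled_features", outputCol="pca_features")
# These stages would go into the pipeline before the classifier,
# with the classifier's featuresCol switched to "pca_features".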

I hope this was an insightful read for you :)

Feel free to add me on LinkedIn.

Constructive criticism is welcome.

If you liked it, please give it some claps.
