Predicting User Churn — Apache Spark

Mati Kucz
14 min read · Jan 23, 2019


Project Definition

I am always eager to learn new frameworks and expand my capabilities, so when I heard about the possibility of a project utilizing Apache Spark and Hadoop, I was immediately intrigued. Having learned the basics of Apache Spark's PySpark API, I could think of no better way to display machine learning prowess than in a big data context. This project revolves around a key business issue that many firms face: how can we know which customers want to leave, and how can our marketing department target them?

Business applications are what excites me the most about Data Science. Proving that I can glean valuable insights from corporate-sized data sources would prove to me that I can use the term Big Data as more than just a buzzword.

Apache software foundation [Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)], via Wikimedia Commons

See the project development and source code on GitHub

Overview

The goal is to analyze a 12GB+ dataset that contains information regarding users and their usage of a music streaming service, create features based on this data, and train a supervised machine learning algorithm to be able to predict users who are most likely to quit the service — churn.

I will use Apache Spark 2.4.0 through PySpark, the Python API for Apache Spark. pyspark.ml is the main PySpark machine learning package; pyspark.mllib is another machine learning library, but it is built on Spark's RDDs (Resilient Distributed Datasets). I will run the program on AWS EMR (Elastic MapReduce), a managed Hadoop framework where one can configure cluster size and computational power while leaving much of the backend, storage, and software configuration to the AWS platform.

pyspark.ml provides an extra layer of abstraction on top of RDDs and lets the user work strictly with DataFrames. Programming with RDDs is more granular (and offers more functionality), but it is possible to use both frameworks concurrently. I will focus on Apache Spark's newer framework, pyspark.ml.

Apache Spark's SQL library provides the ability to query structured data and is equivalent in functionality to the DataFrame API. Code written with either Spark SQL or the DataFrame API goes through the same optimizer in Spark, so there should be no performance difference either.

In order to solve the churn prediction problem I have to think critically about the defining factors of why a user would cancel their account. This could include:

  • How recommended songs fit musical taste
  • Frequency of use
  • Subscription level and status

The schema of the dataset is as follows:

root
|-- artist: string
|-- auth: string
|-- firstName: string
|-- gender: string
|-- itemInSession: long
|-- lastName: string
|-- length: double
|-- level: string
|-- location: string
|-- method: string
|-- page: string
|-- registration: long
|-- sessionId: long
|-- song: string
|-- status: long
|-- ts: long
|-- userAgent: string
|-- userId: string

Problem Statement

The end goal is to successfully identify users who will churn. I will follow the PySpark Pipeline workflow, which relies on Transformers (DataFrame modifications) and Estimators (algorithms such as a supervised learning model). First, I will perform exploratory data analysis to understand the dataset and see which features could prove useful. Thereafter, I will write a set of functions that create the features, which will make fitting and optimizing models easier. I will compare the performance of various supervised machine learning algorithms on a validation set before optimizing the hyperparameters of the best models. Development will be done on a small subset of the data to decrease computation time. In my experience, random forest classifiers tend to perform best and are computationally less expensive than support vector machines. However, distributed computing might make the difference in fitting complexity negligible.

Metrics

It is important that the model finds as many of the users who are considering churning as possible; this is the recall metric. However, the model cannot be judged on recall alone, as it would then be best simply to predict that ALL users will churn. Although that maximizes recall, it is unrealistic in a business setting to provide incentives to every user. Precision is the share of users the model marks as churning who actually churn. The F-score (the harmonic mean of precision and recall) combines both metrics into one score, and this will be the most important metric.
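For reference, the three metrics in terms of true positives (TP), false positives (FP), and false negatives (FN), with churn as the positive class. The second line is an illustration with made-up numbers (100 users, 20 of whom churn, and a model that flags every user), showing why recall alone is misleading.

```latex
\text{Precision} = \frac{TP}{TP + FP}, \qquad
\text{Recall} = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}

TP = 20,\; FP = 80,\; FN = 0 \;\Rightarrow\;
\text{Recall} = \frac{20}{20} = 1,\quad
\text{Precision} = \frac{20}{100} = 0.2,\quad
F_1 = \frac{2 \cdot 0.2 \cdot 1}{0.2 + 1} \approx 0.33
```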

It turns out that BinaryClassificationEvaluator only supports the AUC-ROC and AUC-PR metrics for Pipeline optimization. In this dataset non-churners outnumber churners by a wide margin, so given this class imbalance it makes sense to use AUC-PR as the optimization metric.

Exploratory Data Analysis

It is important to investigate a dataset and its variables to understand how best to structure it for algorithmic processing.

Essential variables

  • itemInSession: the nth action in the user's current session
  • length: how long a song is
  • level: whether the user is paid or free
  • registration: when the user signed up
  • ts: when the user accessed a certain page
  • userId: the user's unique identifier
  • sessionId: unique identifier for a user's session
  • page: what the user action is doing (described below)

+--------------------+
|                page|
+--------------------+
|               About|
|          Add Friend|
|     Add to Playlist|
|              Cancel|
|Cancellation Conf...|
|           Downgrade|
|               Error|
|                Help|
|                Home|
|              Logout|
|            NextSong|
|         Roll Advert|
|       Save Settings|
|            Settings|
|    Submit Downgrade|
|      Submit Upgrade|
|         Thumbs Down|
|           Thumbs Up|
|             Upgrade|
+--------------------+

Non-essential variables

artist, firstName, lastName, auth, location, method, song, userAgent, status: these variables tell us little about how users behave, only about who they are. For example, I don't believe the browser being used influences the churn rate, and neither should the type of music listened to. method and status come from the HTTP requests and will not reveal much about user behaviour.

Plots of variable relationships — small subset

I explored the relationships in the small subset of data to decide on what features I would create.

This shows that churn could be tied to song playing behaviors, duh!

Methodology

Data Preprocessing

Preprocessing of the dataset focuses on creating user-level features from the click-level dataset. This was also probably the most difficult step, as it requires critical thinking about the data and an element of putting yourself in a user's shoes. First, I remove the rows where userId or sessionId is an empty string (presumably errors, or simply tracking of users who have not yet subscribed). Then I begin to codify the features. I do not clean any of the other essential variables listed in the exploratory data analysis section, even though length contains null values and has a minimum value of 0.78322 seconds. While it seems odd for the streaming service to include such a short clip, that doesn't make it an illegitimate value. Furthermore, nulls are acceptable, since not all actions are associated with a song, for example. The resulting schema was as follows:

root
|-- userId: string (nullable = true)
|-- downgraded: long (nullable = true)
|-- cancelled: long (nullable = true)
|-- visited_cancel: long (nullable = true)
|-- visited_downgrade: long (nullable = true)
|-- dailyHelpVisits: double (nullable = true)
|-- dailyErrors: double (nullable = true)
|-- free: integer (nullable = true)
|-- paid: integer (nullable = true)
|-- avgThumbsUp: double (nullable = true)
|-- avgThumbsDown: double (nullable = true)
|-- numFriends: long (nullable = false)
|-- avgSongsTillHome: double (nullable = true)
|-- avgTimeSkipped: double (nullable = true)
|-- skipRate: double (nullable = true)

The two most important tools in the data preprocessing step, essential to creating user-level data, were:

  1. pyspark.sql.functions.udf: user-defined functions, which I frequently used to keep track of factors within features.
  2. pyspark.sql.Window: window specifications that help keep track of information across rows.

The individual features were built as follows:
  • userId was simply extracted by finding all the unique values
  • downgraded and cancelled count how many times a user has pressed the Submit Downgrade or Cancellation Confirmation buttons, respectively.
  • visited_cancel and visited_downgrade differ from the features above: visiting a cancellation page usually requires the user to confirm that intent, to prevent accidental unsubscriptions, so a visit alone does not mean the account was cancelled. I believe this feature is strongly indicative of churn, since those who look at downgrading or cancelling are most likely discontent with the service. Summing the visits to these pages per user portrays how many times they have thought of quitting, even if some of those visits are inadvertent.
  • dailyHelpVisits and dailyErrors track the number of visits to the Help and Error pages, respectively. Users who get many errors or need a lot of help navigating the service are more likely to be discontent with it. Engineering these variables required summing the visits per day and then averaging over days, which was done with window functions after converting the Unix timestamps into datetime.date objects.
  • free vs paid users could also differ in their propensity to churn. It is reasonable to assume that regardless of the tier a user subscribes to, they contribute equal amounts of revenue to the business whether it is through their monthly dues or the ads they listen to. Therefore, it is prudent to keep both paying and free users. However, it is probable that users who pay are more likely to cancel — potentially without downgrading to a free subscription first — since they are actually paying money for the service.
  • avgThumbsUp and avgThumbsDown are fairly self-explanatory. They were created with the same window-function aggregation process as dailyHelpVisits and dailyErrors and track how often users rate songs up or down. If users rate many songs up, they are probably content with the service regardless of their subscription level; many down votes, on the other hand, would indicate an unhappy user. These features alone are not strong indicators, since users interact with the service differently: some unhappy users will not interact at all, while others will leave many down votes.
  • numFriends is created in a similar vein to avgThumbsUp. Users with a strong community of many friends are less likely to unsubscribe. It is computed by counting each user's visits to the Add Friend page.
  • avgTimeSkipped was the last feature I implemented, after deciding to delve deeper into the data for insights that were not visible at first. This variable was complicated, as it requires calculating the difference between how long a song was supposed to be and the time until the next HTTP request. It also only takes into account consecutive songs within one session, on the assumption that a song does not keep playing if the user navigates to another page. One thing that became evident is that the song keeps playing when a user presses the thumbs up or down buttons, so these events are removed before the analysis starts. The general approach again uses window functions, with pyspark.sql.functions.lag to find the difference between consecutive timestamps, which is then compared to the song length.

Implementation

The overall flow of implementation is as follows:

  1. Load the data from the remote database.
  2. Engineer a new user-level DataFrame from the click-level data with the feature_engineering() function. Persist the DataFrame with df_scaled.persist() to keep it in memory and reduce computation time, since the same DataFrame is used repeatedly when creating features.
  3. Scale the data with the feature_scaling() function.
  4. Define training, validation, and testing sets with a 0.85 : 0.075 : 0.075 ratio, using randomSplit() with a constant seed and then splitting the result of the first split again to get the validation and testing sets.
  5. Define a classifier, fit() it, and then call transform() on the validation set to get a DataFrame with a new column of predictions.
  6. Call custom_evaluation() to compare the model's predictions to the actual labels.

Custom Evaluation

The BinaryClassificationEvaluator that comes with PySpark 2.4.0 is limited in the metrics it provides. Unlike its multi-class counterpart, the binary evaluator only provides Precision-Recall Area Under Curve (PR_AUC) or Receiver Operating Characteristic Area Under Curve (ROC_AUC). To get a more in-depth assessment, I count the true positives (TP), false positives (FP), true negatives (TN), and false negatives (FN) manually, and print the PR_AUC for comparison. In certain examples I have seen the multi-class evaluator used for binary evaluation, which allows easy evaluation of a variety of metrics; however, this did not work for me. Moreover, the BinaryClassificationEvaluator returned PR_AUC=1 for all classifiers, and I was not able to figure out this discrepancy in the metric.

Model Training And Evaluation

I selected four classification algorithms to train initially and compare on the validation set. They were chosen for their classification performance and computational efficiency (the multi-layer perceptron classifier was left out). The four algorithms were:

  1. Random Forest
  2. Gradient Boosted Trees
  3. Support Vector Machines
  4. Logistic Regression

What followed was very odd: three of the models (random forest, gradient boosted trees, and support vector machines) perfectly classified the validation set (only true negatives and true positives), whereas logistic regression only did well on non-cancelling users, as it did not predict any users to churn. This result was naturally alarming (I assumed it more likely that my process was faulty than that I had a perfect classifier), so I began to investigate (which is harder than it sounds with such a big file). However, I found no data leakage, and comparing the predictions and labels manually did not disprove the returned metrics.

Comparison of Confusion Matrices between best and worst classifiers

Refinement

Since the model returned a perfect score on the validation set, there is not much to improve. However, I wanted to investigate whether Principal Component Analysis could be used to decrease the number of features and also shed light on the most important factors in determining whether a user will want to cancel. The PCA shows that 97.69% of the variance in the dataset is explained by the first 6 components. Therefore, when fitting and predicting, it would be prudent to save time and computational power by performing PCA(k=6).

Scree plot informing the decision to keep 6 principal components. Even 4 would be enough.

Results

Model Evaluation

The final chosen model is a Random Forest classifier, chosen because it was the fastest model to train and generate predictions, twice as fast as the next fastest model. Even on the dataset of reduced dimensionality, the classifier still achieved exceptional performance.

Confusion matrix of testing set results: perfect classification

The random forest classifier was mostly the out-of-the-box PySpark random forest. Below are the main parameters of the model:

  • featureSubsetStrategy: 'auto' - the number of features to consider for splits at each tree node
  • numTrees: 10 - number of trees to train & query (default=20)
  • minInfoGain: 0.0 - minimum information gain for a split to be considered at a tree node
  • minInstancesPerNode: 1 - minimum number of instances each child must have after a split
  • impurity: 'gini' - criterion used for information gain calculation
  • maxBins: 32 - maximum number of bins for discretizing continuous features
  • maxDepth: 5 - maximum depth of the tree
  • subsamplingRate: 1.0 - fraction of the training data used for learning each decision tree

The most interesting part of examining the model's parameters is that it uses a simpler random forest than the out-of-the-box version: instead of the default 20 decision trees, it trains only 10 and lets them vote during prediction. The fact that this model made no mistakes on the validation or testing datasets suggests that it is a robust model.

Justification

The striking success of the final model is in stark contrast to the abysmal performance of logistic regression. The downfall of the logistic regression classifier is probably multicollinearity among the input features, and possibly outliers in the dataset as well. Both can substantially degrade a linear model's performance, and it is clear that this approach did not work well in this case.

The relative speed and efficiency of random forest classification is clear: there is relatively little training to do, since the classifier is built from a set of randomly constructed decision trees.

The final random forest classifier performs exactly as one would want. It correctly classifies all users who will cancel their subscription and produces no false positives. This means that whatever action the service takes, it will not waste resources trying to keep users who were not going to leave anyway. Having no false negatives is also a great sign of a solid model, as it does not miss any users who churned.

Conclusion

Reflection

Classifying user churn likelihood based on interaction-level data is challenging, as it requires a high level of feature engineering to predict what a user is feeling: their satisfaction with a service. There are naturally factors outside of a user's satisfaction that might affect their subscription status and are not included in the data provided. However, although the challenge is great, the fun and exciting aspect is putting oneself in a user's mind, thinking of the reasons they might want to quit, and then creating numeric variables that describe this to a machine. Distributed and parallel computation with a framework like Apache Spark makes this process easier by allowing large datasets to be processed in reasonable time. The engineered features focus on a variety of behaviours, such as interactions with the service through thumbs up, thumbs down, or adding friends, as well as listening habits that focus on time and habit rather than music taste. The features also include measures of satisfaction with the service, tracking visits to the help, downgrade, and cancel pages and the number of errors users see daily. The features are then scaled to the range [0, 1] so that binary variables are not distorted and larger values do not distort the model. Finally, principal component analysis is used to reduce the dimensionality of the dataset, keeping only the top 6 principal components, which explain more than 97% of the variance in the dataset.

One difficult part of the project was scaling up the scripts and algorithms using Apache Spark and AWS. I tracked many errors and exceptions that could only be attributed to the AWS or Hadoop backend. Furthermore, the dataset is so large, with so many possible features and explanations, that it was also difficult to decide where to focus the feature engineering.

Improvement

Since many of the models performed very well, I would want to perform k-fold cross-validation to tease out any differences between them. However, I have not seen a cross-validation implementation in PySpark that simply trains and tests the model on each fold without also optimizing hyperparameters. That model's hyperparameters could then be optimized, but any difference would only show in future testing, as it could not possibly do better on the final testing set than the current model.


Mati Kucz

UC San Diego '18, Data Scientist and Golfer: Learning & Exploring