Sparkify Capstone Project — User Churn Prediction on an AWS Cluster

Why cancel a subscription with your music streaming service?

Introduction

Welcome to my capstone project for the Udacity Data Science Nanodegree. For this final project, I have chosen the Sparkify option, which, like some of the other projects, emphasizes prediction to enhance a product but, unlike the others, emphasizes Spark. The idea is to look at log files for a fictitious music streaming service, Sparkify, and predict whether or not a given user will cancel their subscription with the service. We will be leveraging Spark (and, by extension, Scala) to look at either the small dataset (128MB) or the large dataset (12GB). For the large dataset, we will use AWS to rent a Hadoop cluster through its EMR service.

Small Dataset

I decided that the best way to approach this project would be to do everything first on the small dataset: exploratory data analysis, feature engineering, and modeling.

The dataset contains log files that generate an entry whenever a user performs an action on the site, like picking the next song, giving a song a ‘thumbs down’, or landing on a new page. Each entry also records information about the user, such as their location, the user agent string used to access the site, and their account level: free or paid. It’s important to note that users of the free service receive advertisements. Let’s take a look at the schema of this log file:

Checking for Missing Data

Some rows have an empty string for the userId and nulls in many other columns; more importantly, in these rows the user has reached the “Logged Out” authentication status.

Some actions, like “Add to Playlist”, “Roll Advert”, and “Thumbs Up”, also generate nulls.

Later on, we are going to run aggregations on the userId column, so it cannot contain empty strings if we want to predict on it. We will therefore remove all rows whose userId is an empty string. The other columns do not need to be cleaned, as their nulls will not impact the calculations or predictions in later steps.

Wow, that operation removed 8,346 entries.
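As a rough illustration of this cleaning step, the plain-Python sketch below keeps only entries with a non-empty userId, assuming each log entry is a dict (the sample rows are made up). In the Spark notebook, the same step would be a DataFrame filter along the lines of `user_log.filter(user_log.userId != "")`.

```python
def drop_empty_user_ids(rows):
    """Keep only log entries whose userId is a non-empty string.

    None is dropped too, which mirrors a Spark filter on userId != "":
    comparing null to "" is not true, so those rows are filtered out as well.
    """
    return [row for row in rows if row.get("userId") not in (None, "")]


# Made-up sample entries, for illustration only.
sample_log = [
    {"userId": "30", "page": "NextSong", "auth": "Logged In"},
    {"userId": "", "page": "Home", "auth": "Logged Out"},
    {"userId": "9", "page": "Thumbs Up", "auth": "Logged In"},
]
cleaned = drop_empty_user_ids(sample_log)
removed = len(sample_log) - len(cleaned)  # one row dropped in this toy sample
```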

Exploratory Data Analysis

Now it is time to split the dataset into the two classes we will eventually try to predict: users who did and users who did not cancel their subscription with the service. Once the classes are defined, we can observe the distribution between the two groups.

After running this operation, the (small) dataset shows that 52 of the 225 users who signed up have left the service.
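The operation that defines the two classes isn't shown here. The sketch below assumes, as is common for this dataset, that a user counts as churned if any of their entries has the page value “Cancellation Confirmation”; that choice, the function names, and the toy log are assumptions for illustration.

```python
def label_churn(rows, churn_page="Cancellation Confirmation"):
    """Map each userId to 1 if the user ever reached the churn page, else 0."""
    labels = {}
    for row in rows:
        user = row["userId"]
        churned = int(row["page"] == churn_page)
        # Once a user is flagged as churned, keep the flag.
        labels[user] = max(labels.get(user, 0), churned)
    return labels


def class_counts(labels):
    """Return (number of churned users, total number of users)."""
    return sum(labels.values()), len(labels)


# Toy log for illustration; the article reports 52 of 225 on the small dataset.
log = [
    {"userId": "1", "page": "NextSong"},
    {"userId": "1", "page": "Cancellation Confirmation"},
    {"userId": "2", "page": "Thumbs Up"},
    {"userId": "3", "page": "NextSong"},
]
labels = label_churn(log)
churned, total = class_counts(labels)
```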

Looking Deeper at the Columns

Before addressing each column and its distribution, I want to first mention the columns that will not require further analysis:

auth: This column describes the user authentication steps, such as logging in and out of the service. Similar information can also be found by looking at unique sessionIds, which are generated and terminated based on the actions recorded in this column.

firstName, lastName: These columns act as identifiers for individual users, so they would only help the model directly identify which users churned, rather than learn general behavior. Furthermore, I’m not sure that people stop using a particular service because of their first or last name (ok, maybe I can think of examples, but it doesn’t seem likely).

length: This column will not be looked at, as it simply describes the length of the song currently being played. That choice is up to the user and would be the same regardless of the music streaming service.

method: Certain page actions happen through GET or POST requests, so it is more informative to look at the page actions that cause the GET or POST.

registration: This is an informative column, but the same information may be obtained by looking at the minimum value of a particular userId’s ts column.

song: There are simply too many songs for this column to be informative. Since the service is all about the end user listening to music, other metrics about total usage can offer the same information.

status: This column holds the HTTP response status code. Values other than 200 (a successful response) will be captured by the “Error” value in the page column.

userId: userId will be leveraged for aggregation, and later for deduplicating the data so that each entry represents one user and their features. However, we should not use it for prediction at that point, as a unique identifier would map directly onto the value we want to predict, userChurned.

Bar plots — for count of categorical data

Gender:

Level:

Page:

Let’s make more meaningful columns for location and userAgent

The location and user agent columns contain data that can be split into separate, more sensible columns. The location column can become a city column and a state column. Likewise, the user agent column provides information about the user, such as the browser and platform used to access the service, and each can become its own column. Let’s take a look at this data and compare users who did or did not churn from the service.
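A minimal plain-Python sketch of the split, assuming locations look like “City, ST” (multi-state metro areas such as “New York-Newark-Jersey City, NY-NJ-PA” keep their first state here) and using simple substring rules on the user agent string. The rules and their order are my assumptions, not a complete user agent parser; a dedicated parsing library would be more thorough.

```python
import re


def split_location(location):
    """Split 'City, ST' into (city, state); multi-state areas keep their first state."""
    city, separator, states = location.rpartition(", ")
    if not separator:  # no ", " in the string: treat it all as the city
        return location, None
    return city, states.split("-")[0]


def parse_user_agent(user_agent):
    """Very rough (browser, platform) detection from a user agent string."""
    match = re.search(r"\(([^)]*)\)", user_agent)  # first parenthesized block
    system = match.group(1) if match else ""
    if "iPhone" in system or "iPad" in system:
        platform = "iOS"
    elif "Macintosh" in system:
        platform = "Mac"
    elif "Windows" in system:
        platform = "Windows"
    elif "Linux" in system:
        platform = "Linux"
    else:
        platform = "Other"
    # Order matters: Chrome user agents also mention Safari,
    # and Edge user agents also mention Chrome.
    if "Edge" in user_agent:
        browser = "Edge"
    elif "Firefox" in user_agent:
        browser = "Firefox"
    elif "Chrome" in user_agent:
        browser = "Chrome"
    elif "Safari" in user_agent:
        browser = "Safari"
    elif "Trident" in user_agent or "MSIE" in user_agent:
        browser = "Internet Explorer"
    else:
        browser = "Other"
    return browser, platform
```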

State:

Browser:

Platform:

Looking at this graph, my immediate insight is that Mac and Windows users stand out: it seems that you’re much more likely to stay if you access the service from a Mac or a Windows machine. However, we will look further at feature importance after the modeling step.

Box and Whisker plot — for numerical data

Artist column:

Maybe people who do not like the music service do not explore as many artists. So I created a column for the number of distinct artists each user listened to.
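A small sketch of that aggregation in plain Python, assuming dict-shaped log entries with `userId` and `artist` keys. Entries without an artist (page visits, adverts) are skipped, so a user who never played a song would simply be missing from the result.

```python
from collections import defaultdict


def total_distinct_artists(rows):
    """Count the distinct artists each user listened to."""
    artists = defaultdict(set)
    for row in rows:
        if row.get("artist"):  # skip entries with no artist (None or "")
            artists[row["userId"]].add(row["artist"])
    return {user: len(names) for user, names in artists.items()}
```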

Session columns (itemInSession and sessionId):

From these values, we can derive the average session length, the number of entries each session generated, and the number of sessions per user, by counting the distinct session IDs.
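The article doesn't spell out how session length is measured, so the sketch below assumes it is the time between the first and last entry of a session, in minutes, with `ts` in epoch milliseconds. Sessions are keyed by (userId, sessionId) as a precaution, in case session IDs repeat across users. The output key names are illustrative.

```python
from collections import defaultdict
from statistics import mean


def session_features(rows):
    """Per user: number of sessions, mean entries per session, mean session length (minutes)."""
    sessions = defaultdict(list)
    for row in rows:
        # Key by user and session so equal session ids from different users never merge.
        sessions[(row["userId"], row["sessionId"])].append(row["ts"])

    per_user = defaultdict(list)
    for (user, _), stamps in sessions.items():
        per_user[user].append(stamps)

    features = {}
    for user, user_sessions in per_user.items():
        features[user] = {
            "numSessions": len(user_sessions),
            "avgItemsPerSession": mean(len(s) for s in user_sessions),
            # Assumed definition: last minus first timestamp, converted from ms to minutes.
            "avgSessionLength": mean((max(s) - min(s)) / 60000 for s in user_sessions),
        }
    return features
```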

Feature Engineering and More Visualizations

After getting acquainted with the data, cleaning it, and analyzing the values of different columns, I began to think of different ways this information could be represented. Having the ts (timestamp) column made me think of tracking each user’s actions on a daily basis and creating columns for those. So I created another column based on the ts data, called uniqueDay, to run aggregations on. First I counted each user’s total interactions for each day, then averaged those daily totals into a new column for average daily interactions.
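A sketch of the uniqueDay idea in plain Python, assuming `ts` holds epoch milliseconds and days are taken in UTC. The averaging choice is also an assumption: counts are divided by the number of days the user was active at all, so an active day without a thumbs down counts as zero. Given a page value, the same helper could produce columns like avgDailyThumbsUp or avgDailyErrors.

```python
from collections import defaultdict
from datetime import datetime, timezone


def unique_day(ts_ms):
    """Calendar day (UTC) of an epoch-millisecond timestamp, as 'YYYY-MM-DD'."""
    return datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d")


def avg_daily_count(rows, page=None):
    """Average number of entries per active day for each user.

    If page is given, only entries with that page value are counted, but the
    average is still taken over every day the user was active.
    """
    active_days = defaultdict(set)
    counts = defaultdict(int)
    for row in rows:
        active_days[row["userId"]].add(unique_day(row["ts"]))
        if page is None or row["page"] == page:
            counts[row["userId"]] += 1
    return {user: counts[user] / len(days) for user, days in active_days.items()}
```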

Thumbs up:

Thumbs Down:

If we ignore the outliers in this graph, the difference between the two groups is clear. Maybe the users who churned are not being recommended songs that they like, which discourages them from using the service.

Errors:

Advertisements:

Next Songs:

I then engineered another feature that is not based on daily metrics. Instead, it applies to the user over the entirety of their time using the service: the duration of time passed since registering with the service.
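A minimal sketch, assuming the duration is measured in days from a user’s earliest to latest `ts` (epoch milliseconds). Following the note on the registration column, the earliest timestamp stands in for the registration time; the registration column itself could be used instead.

```python
def duration_using_service(rows):
    """Days between each user's first and last log entry (ts in epoch ms)."""
    first, last = {}, {}
    for row in rows:
        user, ts = row["userId"], row["ts"]
        # Earliest ts approximates the registration time (an assumption).
        first[user] = min(first.get(user, ts), ts)
        last[user] = max(last.get(user, ts), ts)
    return {user: (last[user] - first[user]) / 86_400_000 for user in first}
```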

Before we begin the modeling step, let’s take a look at the result of engineering our features, and pick and choose which of these features we are going to keep.

We will now construct a dataset containing just the variables for prediction, deduplicating on the ‘userId’ column so that each row describes one user, in order to predict the userChurned column for that user. The features selected to define a user are the categorical variables ‘gender’, ‘level’, ‘browser’, ‘platform’, ‘state’, ‘city’, and the numerical variables ‘avgDailySessions’, ‘avgSessionLength’, ‘avgDailyThumbsUp’, ‘avgDailyThumbsDown’, ‘avgDailyErrors’, ‘avgDailyAdvertisements’, ‘avgDailyNextSong’, ‘totalDistinctArtists’, ‘durationUsingService’.

Feature Preprocessing

The modeling step will be broken down into vectorizing the data, training a few models, and improving the best performing model through cross validation.

Vectorizing the data meant using StringIndexer for the categorical variables and VectorAssembler for the numerical variables. Note that the assembled numerical vector was scaled and then passed through a second VectorAssembler, along with the indexed categorical variables, to produce one vector representing all of the features.
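The sketch below imitates the idea in plain Python rather than calling Spark: it maps each category to an index ordered by frequency (similar in spirit to StringIndexer’s default most-frequent-first ordering; breaking ties alphabetically is my choice), standardizes each numerical column, and concatenates everything into one feature list per user. Standardizing to zero mean and unit variance is an assumption, since the article does not say which scaler was used.

```python
from collections import Counter
from statistics import mean, pstdev


def fit_string_index(values):
    """Most frequent value -> 0.0, next -> 1.0, ...; ties broken alphabetically."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda value: (-counts[value], value))
    return {value: float(i) for i, value in enumerate(ordered)}


def fit_standardizer(values):
    """Return a function mapping x to (x - mean) / std; constant columns map to 0.0."""
    mu, sigma = mean(values), pstdev(values)
    return lambda x: 0.0 if sigma == 0 else (x - mu) / sigma


def assemble(users, categorical, numerical):
    """One feature list per user: indexed categoricals, then scaled numericals."""
    indexers = {col: fit_string_index([u[col] for u in users]) for col in categorical}
    scalers = {col: fit_standardizer([u[col] for u in users]) for col in numerical}
    return [
        [indexers[col][u[col]] for col in categorical]
        + [scalers[col](u[col]) for col in numerical]
        for u in users
    ]
```

In a real pipeline, the fitted indexers and scalers would be reused on the test set rather than refitted, so both sets share one encoding.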

Performance Metrics

F1 Score: The target variable we are trying to predict, whether or not a user churned, is represented by a 1 or a 0, so this is a binary classification problem. When we predict that a user churned, the prediction is either correct (a true positive) or incorrect (a false positive); likewise, a prediction that a user stayed is either a true negative or a false negative. From these outcomes we can calculate “Precision” and “Recall”, and combine them into the F1 score, their harmonic mean, which is more robust than plain accuracy here because the classes are imbalanced (52 churned users out of 225 in the small dataset).

See wiki for more on these calculations
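The calculations can be sketched in a few lines of plain Python for the positive class (userChurned = 1). One caveat, if the scores were computed with Spark’s MulticlassClassificationEvaluator: its "f1" metric is an F1 score weighted across both classes, so it can differ from the positive-class F1 computed here.

```python
def binary_scores(y_true, y_pred, positive=1):
    """Precision, recall and F1 for the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # F1 is the harmonic mean of precision and recall.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```

For example, a model that predicts that nobody churns would be right for 173 of the 225 users in the small dataset (about 77% accuracy), yet its F1 score for the churn class is 0, which is why F1 is the more telling metric here.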

Area Under ROC: I will also use this metric to measure the results of my models. It is also based on binary classification, but instead of precision and recall, it uses sensitivity (the true positive rate) and 1 - specificity (the false positive rate). The ROC curve plots these two rates against each other as the classification threshold moves, and the area under it measures how well the model ranks churned users above users who stayed. Thus, a high Area Under ROC shows that the model separates the two classes well across all thresholds, not just at a single cutoff.
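AUC also has a convenient equivalent definition: the probability that a randomly chosen churned user receives a higher score than a randomly chosen user who stayed, with ties counting as half. The plain-Python sketch below computes it that way; the pairwise loop is fine for illustration but not for millions of rows, where Spark’s BinaryClassificationEvaluator (metricName "areaUnderROC") would be used instead.

```python
def area_under_roc(labels, scores):
    """AUC as the fraction of (churned, retained) pairs ranked correctly; ties count 0.5."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one example of each class")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))
```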

Modeling

I then split the data into a training set and a test set, and trained a few models: Logistic Regression, Random Forest, and Gradient Boosted Trees.

Step 1: I began by measuring the F1 score and Area Under ROC of a Logistic Regression model. I chose this model as the baseline because it is a relatively simple classifier, so a more complex model should be expected to produce better results. Here are the baseline metrics to compare the more complex models against: an F1 score of 79.55% and an Area Under ROC of 85%.

Step 2: I then tried a Gradient Boosted Trees model, to see if it would improve on the baseline. However, this was not the case: it returned an F1 score of 60.27% and an Area Under ROC of 72.71%. As this model is more complex, I concluded that it performed much worse because it needed more data to work with; we are currently only using the small dataset locally.

Step 3: I wanted to try one more complex model, Random Forest, to see whether the issue in the previous step was the model itself rather than the lack of data. However, this model also did not perform as well as the baseline logistic regression model: an F1 score of 71.97% and an Area Under ROC of 77.08%. Both scores beat the Gradient Boosted Trees model, but both remain well below the baseline.

Step 4: Since the baseline model ended up working best, I decided to run a CrossValidator on it and tune two hyper-parameters, elasticNetParam and maxIter, though only increasing the maximum number of iterations improved the model. More iterations let the optimizer fit the training data more closely, so when choosing a solution like this it’s important to check for overfitting. As a counterweight, I tuned elasticNetParam, which sets the mix of L1 and L2 regularization, to potentially regularize the regression and reduce complexity. Here are the results: an F1 score of 79.55% (the same as the baseline) and an Area Under ROC of 83.33% (slightly lower). Even so, according to the table below, the scores increased when maxIter was set to 50 rather than just 10, while elasticNetParam did not seem to affect this data much. It will be interesting to see how these results change on the full dataset.
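Spark’s logistic regression documents its elastic-net penalty as regParam times a blend of an L1 term and half a squared L2 term, with elasticNetParam as the blending weight; regParam defaults to 0.0. The hypothetical helper below evaluates that penalty for a weight vector and shows that, with regParam at 0, changing elasticNetParam changes nothing. Whether regParam was also tuned in this project isn’t stated, so this is only one possible explanation for the flat elasticNetParam results.

```python
def elastic_net_penalty(weights, reg_param, elastic_net_param):
    """reg_param * (a * ||w||_1 + (1 - a) / 2 * ||w||_2^2), with a = elastic_net_param."""
    l1 = sum(abs(w) for w in weights)
    l2_squared = sum(w * w for w in weights)
    return reg_param * (elastic_net_param * l1 + (1 - elastic_net_param) / 2 * l2_squared)
```

With `reg_param=0.0`, the penalty is zero for every value of `elastic_net_param`, so the mix between L1 and L2 only matters once the overall regularization strength is above zero.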

Hyper-parameter values:

Feature Importance

Large Dataset

In order to work with the large dataset, I turned to AWS and its EMR service. Through this service, I was able to use a Jupyter notebook to explore the 12GB dataset of 25 million rows.

I went through the same data cleaning and preprocessing steps as above, but when I got to checking model performance, I noticed a clear difference: this time, all three models returned an F1 score of a little over 80%. It seems the other two models needed much more data to achieve reasonable performance.

I tuned the logistic regression model through cross validation, but it barely improved. I then ran cross validation on the Random Forest model, over the parameters maxDepth and impurity.
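Spark’s CrossValidator with a ParamGridBuilder grid does this at scale. The plain-Python sketch below shows the underlying procedure: for each parameter combination, train on k-1 folds, score on the held-out fold, and keep the combination with the best mean score. The 1-D nearest-neighbour model and its `neighbors` parameter are hypothetical stand-ins chosen to keep the example small; they are not the Random Forest used here.

```python
from itertools import product
from statistics import mean


def param_combinations(param_grid):
    """Expand {'name': [v1, v2], ...} into every combination of values."""
    names = sorted(param_grid)
    for values in product(*(param_grid[name] for name in names)):
        yield dict(zip(names, values))


def cross_validate(rows, param_grid, fit, score, k=3):
    """Return (best_params, best_mean_score, all_results) using k-fold cross validation."""
    folds = [rows[i::k] for i in range(k)]  # round-robin folds, no shuffling
    results = []
    for params in param_combinations(param_grid):
        fold_scores = []
        for i, validation in enumerate(folds):
            training = [row for j, fold in enumerate(folds) if j != i for row in fold]
            model = fit(training, params)
            fold_scores.append(score(model, validation))
        results.append((params, mean(fold_scores)))
    best_params, best_score = max(results, key=lambda result: result[1])
    return best_params, best_score, results


# Hypothetical toy model: majority vote of the nearest training points on one feature.
def fit_knn(training, params):
    def predict(x):
        nearest = sorted(training, key=lambda row: (abs(row[0] - x), row[0]))
        votes = [label for _, label in nearest[: params["neighbors"]]]
        return int(sum(votes) * 2 > len(votes))
    return predict


def accuracy(model, validation):
    return mean(1.0 if model(x) == y else 0.0 for x, y in validation)
```

For example, `cross_validate([(x, int(x >= 5)) for x in range(10)], {"neighbors": [1, 3]}, fit_knn, accuracy)` scores both settings and returns the better one; with a Random Forest, the grid would instead hold values for maxDepth and impurity.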

Hyper-parameters: These parameters were chosen to increase accuracy on the dataset. The first, maxDepth, allows the decision trees within the random forest to grow deeper, learning more combinations of features but risking overfitting. I thought this would not be much of a risk, as the dataset is now much larger and more difficult to overfit. The second, impurity, sets the criterion used to choose the split at each node: ‘gini’ (Gini impurity) or ‘entropy’ (information gain). Both measure how mixed the classes are within a node and often lead to similar trees, so either suits a binary classification problem like ours, and the cross validator can pick whichever works better on this data.
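The two impurity measures are short formulas over the class counts in a node; this sketch computes both for illustration. At each node, the tree picks the split that most reduces the weighted impurity of the resulting child nodes.

```python
from math import log2


def gini(counts):
    """Gini impurity: 1 - sum(p_i^2) over the class counts in a node."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


def entropy(counts):
    """Entropy in bits: -sum(p_i * log2(p_i)), skipping empty classes."""
    total = sum(counts)
    return -sum((c / total) * log2(c / total) for c in counts if c)
```

Both are 0 for a pure node (all churned or all retained) and largest for an evenly mixed node, which is why they usually rank candidate splits similarly.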

Final Results

By utilizing the cross validator and a much larger dataset, the Random Forest model ended up returning the best results in predicting user churn. With the help of Spark and AWS, training these models was fairly cheap and quick (it took longer to run these models locally on the smaller dataset)!

Be Careful when Using Spark

When using Spark, there is the consideration of memory and of delegating Spark jobs to workers. Spark evaluates transformations lazily: rather than computing each step immediately, it records the operations in a DAG (directed acyclic graph) and only executes them when an action, such as show() or count(), asks for a result. This lets Spark run the operations in an efficient order, to reduce runtime and memory. Here are some problems I ran into when using Spark casually, and the solutions I came up with.

  1. If you’re working on feature engineering, chances are you’re creating columns and appending them to your dataframe. You may want to run an aggregation, create a new table, and join it back to the original dataframe. However, there are issues when joining a dataframe with a table derived from itself! See this stackoverflow for a full discussion and the location of my solution.
I had to rename the column to have the same name…

  2. After running the code many times over, calling user_log.show(5) whenever I felt like it, and running multiple models, everything stopped working. I couldn’t output anything related to user_log; what happened? Well, after forcing Spark to deliver output many times, triggering the DAG again with every call, and running fit methods for multiple models, the driver ran out of memory. The solution was to reduce the number of times I called for results, and to cache the training and test datasets so I could safely run my models.
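To make the lazy evaluation and caching behaviour described above concrete, here is a toy plain-Python class (hypothetical, not part of Spark) that records transformations, runs them only when an action such as count() is called, and reuses the result once cache() has been requested. Like Spark’s cache(), marking the data for caching computes nothing by itself; the first action fills the cache.

```python
class LazyLog:
    """Toy model of a lazily evaluated DataFrame (illustration only, not Spark)."""

    def __init__(self, rows, steps=()):
        self._rows = rows
        self._steps = steps        # recorded transformations: the "plan"
        self._use_cache = False
        self._cached = None
        self.runs = 0              # how many times the plan was actually executed

    # Transformations: only extend the plan, compute nothing.
    def filter(self, predicate):
        return LazyLog(self._rows, self._steps + (("filter", predicate),))

    def map(self, function):
        return LazyLog(self._rows, self._steps + (("map", function),))

    def cache(self):
        # Lazy, like Spark: the first action will fill the cache.
        self._use_cache = True
        return self

    def _execute(self):
        if self._cached is not None:
            return self._cached
        self.runs += 1
        result = list(self._rows)
        for kind, function in self._steps:
            if kind == "filter":
                result = [row for row in result if function(row)]
            else:
                result = [function(row) for row in result]
        if self._use_cache:
            self._cached = result
        return result

    # Actions: trigger execution of the whole plan (unless cached).
    def count(self):
        return len(self._execute())

    def take(self, n):
        return self._execute()[:n]
```

Without cache(), every count() or take() replays the whole plan; with it, repeated actions (like repeated show() calls or fitting several models on the same training set) reuse one computed result.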

Further Improvements

  1. I think that designing more complicated features from the dataset would help improve these predictions, for example taking into account not only daily behavior but weekly and monthly behavior as well. I’d also like to look at specific user journeys, like the number of times a user ended a session immediately after receiving an advertisement.
  2. The model could be run for much longer, trying out more hyper-parameters, to achieve the best combination of options.
  3. Lastly, the model’s performance could be tested against real-time log files, to see whether a user who is predicted to eventually leave can be targeted and convinced to stay with the service. The model could then predict which users are worth targeting.

GitHub

Check out my code!
