Predicting user churn with PySpark (Part 1)
First of a three-part series in which we start exploring user data from the fictional music streaming platform Sparkify and define what it means to churn
Introduction
This article series grew out of my work on the final project required to complete Udacity’s Data Scientist Nanodegree and is meant to be educational in nature.
Sparkify is a fictional service similar to Spotify. At a high level, users can come and play songs from a large pool of artists either as a guest or as a logged-in user. They can also decide to pay for the service for more benefits. They’re also free to unsubscribe from the service at any time.
Udacity graciously provides both a medium (128MB) and a large (12GB) dataset of artificial user activity to play with. Each row in the dataset represents the action of a particular user at some point in time, e.g. playing a song by the artist “Metallica”.
Over the course of three articles, I’ll show you how I used PySpark to craft a supervised machine learning model that predicts whether a user will churn from the platform in the near future (in this case, unsubscribe from the service).
Predicting churn is a common and challenging problem that data scientists and analysts regularly encounter in any customer-facing business. Additionally, the ability to efficiently manipulate large datasets with Spark is one of the most in-demand skills in the data field.
Here’s a breakdown of what you’ll learn in each article:
- Part 1 (this article): We’ll run a data exploration of the 128MB dataset and then work backwards from the user events to define churn.
- Part 2 (link): Armed with the knowledge from the exploration phase, we’ll craft some predictive features and feed the data into a supervised machine learning model.
- Part 3 (link): Finally, we’ll walk through how to set up an AWS EMR cluster to train and evaluate our model on the 12GB dataset.
Let’s get started!
If you prefer, you can skip these tutorials and just visit the GitHub repo, which has all the code and instructions for the results presented in the articles (and some more).
Prerequisites
I’ll assume you’re already familiar with the basics of PySpark SQL; if not, I recommend checking out the official getting started guide first and then coming back.
If you want to follow along and execute the code locally, you’ll need to download the medium size dataset, which you can find here.
I also highly recommend running the code within a Jupyter Notebook session. Check out this guide if you want an introduction.
Python Dependencies
I recommend setting up a virtual environment to install the dependencies. I personally like conda, for which you can find the installation instructions here.
Once you have an environment, open a terminal and run the following command, which installs all the required Python dependencies:
pip install jupyterlab==1.2.4 \
pyspark==2.4.4 \
numpy==1.18.1 \
matplotlib==3.1.2 \
pandas==0.25.3
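To confirm the installation worked, you can check the version from a Python shell. This quick check is my own addition and isn’t part of the original notebook:

import pyspark

# The reported version should match the one pinned above (2.4.4)
print(pyspark.__version__)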
Loading the Data
Let’s start by importing all the necessary packages (some of these will only be used in future articles):
import datetime

import matplotlib.pyplot as plt
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.mllib.evaluation import MulticlassMetrics
import pyspark.sql.functions as sqlF
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.types import IntegerType

%matplotlib inline
Next, let’s create the SparkSession that we’ll use from this point onward and load the medium dataset for analysis (the output of each snippet is shown right after the code):
# Create the spark session that will be used for the whole notebook
spark = SparkSession \
    .builder \
    .appName("Sparkify") \
    .getOrCreate()

# Read the medium sized sparkify dataset for the initial exploration
# This assumes the json file was downloaded and is on the same
# directory in which you're running the code
file_path = 'mini_sparkify_event_data.json'
df = spark.read.json(file_path)
df.head()

Row(artist='Martha Tilston', auth='Logged In', firstName='Colin', gender='M', itemInSession=50, lastName='Freeman', length=277.89016, level='paid', location='Bakersfield, CA', method='PUT', page='NextSong', registration=1538173362000, sessionId=29, song='Rockpools', status=200, ts=1538352117000, userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0', userId='30')
Exploring the Data
Let’s start by gathering some high level facts about the dataset:
def nice_describe(df, jump_size=5):
    """Wrapper around describe that prints columns a few at a time"""
    ncols = len(df.columns)
    for idx in range(0, ncols, jump_size):
        col_list = df.columns[idx:idx+jump_size]
        print(f'Summary statistics for {col_list}')
        df.describe(col_list).show()

# Print the schema per entry in the json
df.printSchema()

root
|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)

# Print how many rows and columns are in the dataset
df.count(), len(df.columns)

(286500, 18)

# Print descriptive statistics 2 columns at a time
nice_describe(df, 2)

Summary statistics for ['artist', 'auth']
+-------+------------------+----------+
|summary| artist| auth|
+-------+------------------+----------+
| count| 228108| 286500|
| mean| 551.0852017937219| null|
| stddev|1217.7693079161374| null|
| min| !!!| Cancelled|
| max| ÃÂlafur Arnalds|Logged Out|
+-------+------------------+----------+
Summary statistics for ['firstName', 'gender']
+-------+---------+------+
|summary|firstName|gender|
+-------+---------+------+
| count| 278154|278154|
| mean| null| null|
| stddev| null| null|
| min| Adelaida| F|
| max| Zyonna| M|
+-------+---------+------+
Summary statistics for ['itemInSession', 'lastName']
+-------+------------------+--------+
|summary| itemInSession|lastName|
+-------+------------------+--------+
| count| 286500| 278154|
| mean|114.41421291448516| null|
| stddev|129.76726201141085| null|
| min| 0| Adams|
| max| 1321| Wright|
+-------+------------------+--------+
Summary statistics for ['length', 'level']
+-------+------------------+------+
|summary| length| level|
+-------+------------------+------+
| count| 228108|286500|
| mean|249.11718197783722| null|
| stddev| 99.23517921058324| null|
| min| 0.78322| free|
| max| 3024.66567| paid|
+-------+------------------+------+
Summary statistics for ['location', 'method']
+-------+-----------------+------+
|summary| location|method|
+-------+-----------------+------+
| count| 278154|286500|
| mean| null| null|
| stddev| null| null|
| min| Albany, OR| GET|
| max|Winston-Salem, NC| PUT|
+-------+-----------------+------+
Summary statistics for ['page', 'registration']
+-------+-------+--------------------+
|summary| page| registration|
+-------+-------+--------------------+
| count| 286500| 278154|
| mean| null|1.535358834085557E12|
| stddev| null| 3.291321616328068E9|
| min| About| 1521380675000|
| max|Upgrade| 1543247354000|
+-------+-------+--------------------+
Summary statistics for ['sessionId', 'song']
+-------+-----------------+--------------------+
|summary| sessionId| song|
+-------+-----------------+--------------------+
| count| 286500| 228108|
| mean|1041.526554973822| Infinity|
| stddev|726.7762634630834| NaN|
| min| 1|ÃÂg ÃÂtti Gr...|
| max| 2474|ÃÂau hafa slopp...|
+-------+-----------------+--------------------+
Summary statistics for ['status', 'ts']
+-------+------------------+--------------------+
|summary| status| ts|
+-------+------------------+--------------------+
| count| 286500| 286500|
| mean|210.05459685863875|1.540956889810471...|
| stddev| 31.50507848842202|1.5075439608187113E9|
| min| 200| 1538352117000|
| max| 404| 1543799476000|
+-------+------------------+--------------------+
Summary statistics for ['userAgent', 'userId']
+-------+--------------------+------------------+
|summary| userAgent| userId|
+-------+--------------------+------------------+
| count| 278154| 286500|
| mean| null| 59682.02278593872|
| stddev| null|109091.94999910519|
| min|"Mozilla/5.0 (Mac...| |
| max|Mozilla/5.0 (comp...| 99|
+-------+--------------------+------------------+
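One detail worth noting from the statistics above is that registration and ts are Unix timestamps in milliseconds. We don’t need human-readable dates yet, but as a small aside (this snippet is my own addition, not part of the original notebook), converting them is straightforward:

# Convert the millisecond epoch `ts` column into a readable timestamp
# (from_unixtime expects seconds, hence the division by 1000)
df.withColumn('event_time', sqlF.from_unixtime((sqlF.col('ts') / 1000).cast('long'))) \
    .select('ts', 'event_time') \
    .show(3)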
Let’s plot some visualizations as well:
# Show a bar chart with proportions of visits per page
page_counts_pd = df.groupby('page').count().sort('count').toPandas()
page_counts_pd['count'] = page_counts_pd['count'].astype(float)
total_visits = page_counts_pd['count'].sum()
page_counts_pd['prop'] = page_counts_pd['count'] / total_visits

plt.figure(figsize=(16, 6))
plt.barh(page_counts_pd['page'], page_counts_pd["prop"])
plt.title("Proportions of visits per page")
plt.xlabel("Proportion")
plt.ylabel("Page");
# Show a bar chart with proportions of auth types
auth_counts_pd = df.groupby('auth').count().sort('count').toPandas()
auth_counts_pd['count'] = auth_counts_pd['count'].astype(float)
total_auths = auth_counts_pd['count'].sum()
auth_counts_pd['prop'] = auth_counts_pd['count'] / total_auths

plt.figure(figsize=(16, 6))
plt.barh(auth_counts_pd['auth'], auth_counts_pd["prop"])
plt.title("Proportions of auth types")
plt.xlabel("Proportion")
plt.ylabel("auth");
# Distribution of user actions per session
action_counts_pd = df.groupby('userId', 'sessionId') \
    .max() \
    .withColumnRenamed('max(itemInSession)', 'session_actions') \
    .toPandas()

plt.figure(figsize=(16, 6))
plt.hist(action_counts_pd['session_actions'])
plt.title("Distribution of user actions per session (computed from itemInSession)")
plt.xlabel("Amount of actions");
Cleaning the Dataset
We’re not missing any userId values, but it looks like some rows have it set to an empty string. Since we’re interested in user churn, ideally we want to be able to trace each row back to some user’s action. Let’s explore those rows and then decide what to do with them:
def show_unique_stats(df, columns, sample_size=5):
    """Function to print unique value stats of specific columns"""
    for col in columns:
        print(f'\nColumn "{col}":')
        uniques = df.select(col).dropDuplicates()
        nuniques = uniques.count()
        print(f'\tNumber of unique values: {nuniques}')
        print(f'\tSample: {uniques.head(sample_size)}')

# Explore rows with an empty user_id
no_user_df = df.filter('userId == ""')
print(f'Number of rows with empty userId: {no_user_df.count()}')
print('Sample row:')
no_user_df.head(1)

Number of rows with empty userId: 8346
Sample row:
[Row(artist=None, auth='Logged Out', firstName=None, gender=None, itemInSession=100, lastName=None, length=None, level='free', location=None, method='GET', page='Home', registration=None, sessionId=8, song=None, status=200, ts=1538355745000, userAgent=None, userId='')]

# Print unique value statistics of all categorical columns
# for rows with no user_id defined
categorical_cols = [
'artist', 'auth', 'firstName', 'gender', 'lastName', 'level',
'location', 'method', 'page', 'song', 'userAgent'
]
show_unique_stats(no_user_df, categorical_cols, 10)

Column "artist":
Number of unique values: 1
Sample: [Row(artist=None)]

Column "auth":
Number of unique values: 2
Sample: [Row(auth='Logged Out'), Row(auth='Guest')]

Column "firstName":
Number of unique values: 1
Sample: [Row(firstName=None)]

Column "gender":
Number of unique values: 1
Sample: [Row(gender=None)]

Column "lastName":
Number of unique values: 1
Sample: [Row(lastName=None)]

Column "level":
Number of unique values: 2
Sample: [Row(level='free'), Row(level='paid')]

Column "location":
Number of unique values: 1
Sample: [Row(location=None)]

Column "method":
Number of unique values: 2
Sample: [Row(method='PUT'), Row(method='GET')]

Column "page":
Number of unique values: 7
Sample: [Row(page='Home'), Row(page='About'), Row(page='Submit Registration'), Row(page='Login'), Row(page='Register'), Row(page='Help'), Row(page='Error')]

Column "song":
Number of unique values: 1
Sample: [Row(song=None)]

Column "userAgent":
Number of unique values: 1
Sample: [Row(userAgent=None)]
So rows without a userId belong to either guests or logged-out individuals, which makes sense. Given this, I think it’s safe to continue with just the rows that have a userId defined:
# Remove rows with an empty user_id
df = df.filter('userId != ""')
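As a quick check (again, my own addition rather than part of the original notebook), we can confirm that exactly those 8,346 rows were dropped and that no empty userId remains:

# 286500 total rows minus the 8346 empty-userId rows should leave 278154
remaining_rows = df.count()
empty_left = df.filter('userId == ""').count()
print(f'Rows after filtering: {remaining_rows}')
print(f'Rows with an empty userId left: {empty_left}')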
Defining Churn
We can define churn as the action of a user unsubscribing from the Sparkify service. During the initial exploration, we saw that the auth field can take the value Cancelled, and I anticipate those rows will allow us to identify users that churned. Let's look at a row in that state:
df.where('auth == "Cancelled"').head(1)[Row(artist=None, auth='Cancelled', firstName='Adriel', gender='M', itemInSession=104, lastName='Mendoza', length=None, level='paid', location='Kansas City, MO-KS', method='GET', page='Cancellation Confirmation', registration=1535623466000, sessionId=514, song=None, status=200, ts=1538943990000, userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.77.4 (KHTML, like Gecko) Version/7.0.5 Safari/537.77.4"', userId='18')]
The user visited the Cancellation Confirmation page, so it sounds like they really did churn at that point. Let's explore the timeline of that user in that particular session to better understand what happened:
def user_timeline(df, user_id, session_id, cols, n=5):
    """Print the last n actions of a user in a session"""
    user_df = df.where(f'userId={user_id} AND sessionId={session_id}')
    print(f'Number of rows for user with id {user_id} and session id {session_id}: {user_df.count()}')
    user_df.select(cols).sort(sqlF.desc('ts')).show(n)

# Timeline for the user with id 18 and session with id 514
user_timeline(df, 18, 514, ['ts', 'sessionId', 'auth', 'page'])

Number of rows for user with id 18 and session id 514: 102
+-------------+---------+---------+--------------------+
| ts|sessionId| auth| page|
+-------------+---------+---------+--------------------+
|1538943990000| 514|Cancelled|Cancellation Conf...|
|1538943740000| 514|Logged In| Cancel|
|1538943739000| 514|Logged In| Downgrade|
|1538943726000| 514|Logged In| NextSong|
|1538943440000| 514|Logged In| NextSong|
+-------------+---------+---------+--------------------+
only showing top 5 rows
The picture starts to make sense now: once a user visits the Cancellation Confirmation page, that user is no longer Logged In. We can validate that:
# Validate that when a user visits the `Cancellation Confirmation` page, they are no longer `Logged In`
cancel_subset_df = df.where('page="Cancellation Confirmation"')
show_unique_stats(cancel_subset_df, ['auth'])

Column "auth":
Number of unique values: 1
Sample: [Row(auth='Cancelled')]

# Does a `Cancelled` auth mean the user can only have visited the `Cancellation Confirmation` page?
auth_subset_df = df.where('auth="Cancelled"')
show_unique_stats(auth_subset_df, ['page'])

Column "page":
Number of unique values: 1
Sample: [Row(page='Cancellation Confirmation')]
So given all the above, I think it’s safe to say that any user with an auth value of Cancelled can be considered churned at that point.
Let's add a churned column to the dataframe that is set to 1 on every row belonging to a user who churned from the platform at some point, and to 0 otherwise:
def add_label_churned(df):
    """Add the `churned` column to indicate whether the user churned"""
    # Identify the rows with a cancelled auth state and mark those
    # with 1, then use a window function that partitions by user and
    # puts the cancel event at the top (if any), so that the running
    # sum marks every row of a churned user with a 1
    cancelled_udf = sqlF.udf(
        lambda x: 1 if x == 'Cancelled' else 0, IntegerType())
    current_window = Window.partitionBy('userId') \
        .orderBy(sqlF.desc('cancelled')) \
        .rangeBetween(Window.unboundedPreceding, 0)
    churned_df = df.withColumn('cancelled', cancelled_udf('auth')) \
        .withColumn("churned",
                    sqlF.sum('cancelled').over(current_window))
    return churned_df.drop('cancelled')

# Add the `churned` label
df = add_label_churned(df)

# Show once again the timeline of actions for the user with id 18
# and session id 514
user_timeline(df, 18, 514,
              ['ts', 'sessionId', 'page', 'churned'])

Number of rows for user with id 18 and session id 514: 102
+-------------+---------+--------------------+-------+
| ts|sessionId| page|churned|
+-------------+---------+--------------------+-------+
|1538943990000| 514|Cancellation Conf...| 1|
|1538943740000| 514| Cancel| 1|
|1538943739000| 514| Downgrade| 1|
|1538943726000| 514| NextSong| 1|
|1538943440000| 514| NextSong| 1|
+-------------+---------+--------------------+-------+
only showing top 5 rows
We now have all rows correctly labeled!
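As an extra sanity check (this snippet is my own addition, not part of the original notebook), we can collapse the label to one flag per user and count how many users churned versus stayed:

# One churn flag per user, then count churned vs. retained users
user_churn_df = df.groupBy('userId') \
    .agg(sqlF.max('churned').alias('user_churned'))
user_churn_df.groupBy('user_churned').count().show()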
To be continued…
In the next article of the series, I share the process I followed to craft some predictive features, shape the data into a form where each user is represented by a single row, and feed it into a supervised machine learning model.
As mentioned at the beginning of the article, you can also visit my GitHub repo if you’re interested in the actual code used to reproduce the results of all this work.