Step I: A Beginner’s step by step guide to build Machine Learning Models of Customer Churn for a Music Company in PySpark.

Kaustubh Ursekar
Jan 9 · 11 min read

Preface:

“All I want is my customer back” is increasingly becoming the dominant attitude in business meetings these days. From CEOs to project managers, everyone wants to track metrics like customer satisfaction and customer retention. These metrics sustain a business in the modern world, and data science is a rapidly growing field that can be leveraged for this purpose. Thus, nobody wishes to stay behind or miss this opportunity to increase their market footprint.

Introduction:

I shall discuss an example of music company data to predict customer churn. This post is intended for people who are brand new to coding in the data science field, or for those who are unaware of the technical coding aspects of customer churn.

Below are the main features of this project:

  1. Python 3.x will be used along with Jupyter Notebook.
  2. A 128 MB sample of the full 12 GB dataset is used to work in PySpark.
  3. One can set up this Spark cluster on AWS or IBM Cloud with minor modifications to the code.
  4. A step-by-step guide from importing libraries to hyperparameter tuning.
  5. Two models, Gradient Boosted Trees and Logistic Regression, shall be used to predict churning customers (a short preview sketch follows this list).
  6. Link to the GitHub repository for the entire notebook: https://github.com/ursekar/ML-Model-of-Customer-Churn/blob/master/Sparkify.ipynb
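
As a quick preview of point 5, here is a minimal sketch of how the two estimators are instantiated in pyspark.ml. The column names "features" and "churn" refer to columns that will only be built later in this series, and the parameter values are purely illustrative, not the tuned ones from Part IV.

from pyspark.ml.classification import GBTClassifier, LogisticRegression

# Minimal preview of the two estimators; "features" and "churn" are columns
# built later in this series, and maxIter=10 is only an illustrative value.
gbt = GBTClassifier(featuresCol="features", labelCol="churn", maxIter=10)
lr = LogisticRegression(featuresCol="features", labelCol="churn", maxIter=10)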

The Project:

I have divided my project into 4 parts, as it would otherwise be too long to grasp in one sitting. I have gone into detail throughout my voyage and steered it through each part of this oceanic topic. The 4 humongous waves are:

Part I: Understanding the Dataset & Data Wrangling; “You are here!”

Part II: EDA or Exploratory Data Analysis;

https://medium.com/@kaustubh.ursekar/step-ii-a-beginners-step-by-step-guide-to-build-machine-learning-models-of-customer-churn-for-a-8b6fefc07bba

Part III: Feature Engineering;

https://medium.com/@kaustubh.ursekar/part-iii-a-beginners-step-by-step-guide-to-build-machine-learning-models-of-customer-churn-for-a-88302bc468c8

Part IV: Modelling & Hyperparameter Tuning;

https://medium.com/@kaustubh.ursekar/part-iv-a-beginners-step-by-step-guide-to-build-machine-learning-models-of-customer-churn-for-a-85dc64257e0d

Data Sources:

I have used data provided by Udacity. Since the full dataset is 12 GB, which is very large, three versions of it are available to us: a small subset of 128 MB, a medium-sized subset of about 6 GB, and the complete data.

The deployment can be performed using IBM Cloud or AWS, as one desires. For the simplicity of this blog, I shall go with the small 128 MB subset provided to us. Do not worry, this code can be reused with minor to no changes on both cloud types.

The 3 versions of the dataset available are:

  1. 12 GB dataset: You can use the path below for importing the data into your Spark session on AWS (a short loading sketch follows these links).

Full Sparkify Dataset: s3n://udacity-dsnd/sparkify/sparkify_event_data.json

  2. 6 GB dataset: You can download the data from the link below.

https://video.udacity-data.com/topher/2018/December/5c1d6681_medium-sparkify-event-data/medium-sparkify-event-data.json

  3. 128 MB subset: You can download the data from my repository below. I have uploaded a zipped version of the data, so do not forget to unzip and extract the 128 MB file.

https://github.com/ursekar/BigData-PySpark
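
To make options 1 and 3 concrete, here is a minimal loading sketch. The s3n:// path is the one given above and requires S3 credentials and the Hadoop AWS connector on your cluster; the zip file name below is an assumption, so adjust it to whatever you actually downloaded from the repository.

import zipfile
from pyspark.sql import SparkSession

# Spark session (the same session is created again further below in this post)
spark = SparkSession.builder.appName("Spark SQL Quiz").getOrCreate()

# Option 1: full 12 GB dataset on AWS (needs S3 credentials and the hadoop-aws connector)
# full_data = spark.read.json("s3n://udacity-dsnd/sparkify/sparkify_event_data.json")

# Option 3: unzip the archive from the repository, then read the 128 MB file locally.
# The zip file name here is an assumption; change it to the file you downloaded.
with zipfile.ZipFile("mini_sparkify_event_data.json.zip", "r") as z:
    z.extractall(".")
mini_data = spark.read.json("mini_sparkify_event_data.json")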

Importing Libraries and Data Introduction:

The project is done in PySpark, along with Pandas and Matplotlib, which are used for visualizations. To perform counts, sums, group-bys, sorting, and other mathematical and logical tweaks, as well as to define our own functions, we shall need the SQL functions library from pyspark.sql.functions. You will soon see this library prove useful in data exploration, munging, and wrangling. Secondly, we shall need the ml library from pyspark for machine learning and for preparing our variables: we will import feature, evaluation, classification, tuning, etc. Here is all you need to import.

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import isnan, count, when, col, desc, sort_array, asc, avg
from pyspark.sql.functions import sum as Fsum, split, mean, first, lit
from pyspark.sql.functions import udf
from pyspark.sql.functions import broadcast
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import GBTClassifier, LinearSVC, LogisticRegression
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
import pandas as pd
import numpy as np
import pyspark.sql.functions as f
import datetime
import matplotlib.pyplot as plt

Now let us import our data. We need to create a Spark session for this purpose. I have given the app name as "Spark SQL Quiz"; one can give it any other name if they desire. Once that is done, I import my JSON file named mini_sparkify_event_data.json.

# create a Spark session
spark = SparkSession \
    .builder \
    .appName("Spark SQL Quiz") \
    .getOrCreate()

data = spark.read.json("mini_sparkify_event_data.json")

Let us view our schema, which simply means looking at the columns (and their types) in our data, using the printSchema() method on our dataset, as shown.

data.printSchema()

root
|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

Let us see some records in our data-frame too. Observe that Spark data-frames generally have many columns, and the output comes out badly aligned when we simply call show() on them. Hence, we shall pass some parameters to show() and produce the first 5 records of the dataframe without truncating them.

data.show(vertical = True, n=5, truncate=False)

-RECORD 0-----------------------------------------------------------
artist | Martha Tilston
auth | Logged In
firstName | Colin
gender | M
itemInSession | 50
lastName | Freeman
length | 277.89016
level | paid
location | Bakersfield, CA
method | PUT
page | NextSong
registration | 1538173362000
sessionId | 29
song | Rockpools
status | 200
ts | 1538352117000
userAgent | Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0
userId | 30
-RECORD 1-----------------------------------------------------------
artist | Five Iron Frenzy
auth | Logged In
firstName | Micah
gender | M
itemInSession | 79
lastName | Long
length | 236.09424
level | free
location | Boston-Cambridge-Newton, MA-NH
method | PUT
page | NextSong
registration | 1538331630000
sessionId | 8
song | Canada
status | 200
ts | 1538352180000
userAgent | "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2062.103 Safari/537.36"
userId | 9
-RECORD 2-----------------------------------------------------------
artist | Adam Lambert
auth | Logged In
firstName | Colin
gender | M
itemInSession | 51
lastName | Freeman
length | 282.8273
level | paid
location | Bakersfield, CA
method | PUT
page | NextSong
registration | 1538173362000
sessionId | 29
song | Time For Miracles
status | 200
ts | 1538352394000
userAgent | Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0
userId | 30
-RECORD 3-----------------------------------------------------------
artist | Enigma
auth | Logged In
firstName | Micah
gender | M
itemInSession | 80
lastName | Long
length | 262.71302
level | free
location | Boston-Cambridge-Newton, MA-NH
method | PUT
page | NextSong
registration | 1538331630000
sessionId | 8
song | Knocking On Forbidden Doors
status | 200
ts | 1538352416000
userAgent | "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2062.103 Safari/537.36"
userId | 9
-RECORD 4-----------------------------------------------------------
artist | Daft Punk
auth | Logged In
firstName | Colin
gender | M
itemInSession | 52
lastName | Freeman
length | 223.60771
level | paid
location | Bakersfield, CA
method | PUT
page | NextSong
registration | 1538173362000
sessionId | 29
song | Harder Better Faster Stronger
status | 200
ts | 1538352676000
userAgent | Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0
userId | 30
only showing top 5 rows

I shall discuss the types of columns and how we would be using or dropping them in coming parts.
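
For a quick look at the column types outside of printSchema(), the dtypes property of a Spark dataframe lists (name, type) pairs. The drop() call shown in the comment is only a hypothetical example; the actual keep/drop decisions come in the later parts.

# (column name, type) pairs for a quick overview of the schema
print(data.dtypes)

# Dropping columns would look like this; "method" and "status" are only
# hypothetical examples here — the real decisions are made in the later parts.
# data = data.drop("method", "status")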

Data Cleaning and Wrangling:

Now comes the time to get our hands dirty in this project. The most important yet tedious part of any data science project, after the machine learning itself, is arguably data exploration. So I will jump right into it without any fancy quotes or inspirational idioms.

I am a complete noob (amateur or newbie). I hope some of you reading this article are like me; those of you who are experts, please feel free to criticize me in your responses below and tell me how I could have done better from here in this project.

Thus, I start by viewing the null or NaN entries in each column of my data using isnan(). But this gives only a boolean output, and hence I need an aggregation like count() to know exactly how many null entries there are in each column, as in count(when(isnan(c), c)).

Also, I need a loop, as I will be passing each column to isnan() one at a time, as in for c in data.columns. All of this comes together to make the entire check one line of code:

data.select([count(when(isnan(c), c)).alias(c) for c in data.columns]).show(vertical = True, truncate = False)

-RECORD 0------------
artist | 0
auth | 0
firstName | 0
gender | 0
itemInSession | 0
lastName | 0
length | 0
level | 0
location | 0
method | 0
page | 0
registration | 0
sessionId | 0
song | 0
status | 0
ts | 0
userAgent | 0
userId | 0

Thus, we do not have any null or NaN entries in our data. But is that really true, or are we being deceived? Well, we shall see that later. For now, this is enough to proceed with deciding which column will act as "churn", or simply put, the dependent variable in our ML models. We see that a user leaves completely once they confirm the cancellation of their subscription to a product or service. Hence, we need to look under the page column for such a page to tell us whether a user has cancelled or not.

data.select("page").distinct().show(truncate = False)

+-------------------------+
|page |
+-------------------------+
|Cancel |
|Submit Downgrade |
|Thumbs Down |
|Home |
|Downgrade |
|Roll Advert |
|Logout |
|Save Settings |
|Cancellation Confirmation|
|About |
|Submit Registration |
|Settings |
|Login |
|Register |
|Add to Playlist |
|Add Friend |
|NextSong |
|Thumbs Up |
|Help |
|Upgrade |
+-------------------------+

From all these, we see that there is a "Cancellation Confirmation" page that guarantees, or assures us of, the cancellation of a user from this particular music service. Hence, we take this page as our label of interest. But we cannot use this label as it is: we need a binary encoding of 1 and 0 to get it into a language the ML models understand. Therefore, we can simply write a flag function with an if..else.. condition that marks a match with "Cancellation Confirmation" as 1 and everything else as 0.

# Define a flag function
flag_cancelation_event = udf(lambda x: 1 if x == "Cancellation Confirmation" else 0, IntegerType())

But here is one problem. The userId column will have multiple entries for a user, because the user would have liked, disliked and changed songs, or done other activities causing them to visit more than one page, such as "Thumbs Up", "Thumbs Down", "NextSong", etc. Therefore, there will be more than one record for a particular user who has churned, and we want to be sure that all of them get the value 1 in the new churn column. Thus, we need something like SQL here that partitions the records of our users by their unique ids and yields the dataset of our interest. This can be achieved by importing Window from pyspark.sql.window. We save these 1s and 0s in a separate column called churn, by adding it to our dataset.

# Define a flag function
flag_cancelation_event = udf(lambda x: 1 if x == "Cancellation Confirmation" else 0, IntegerType())

# apply the flag to the dataframe
data = data.withColumn("churn", flag_cancelation_event("page"))

# Define window bounds spanning all of a user's records
windowval = Window.partitionBy("userId").rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# Apply the window so every record of a churned user gets churn = 1
data = data.withColumn("churn", Fsum("churn").over(windowval))

print('The total number of churners are:', data.agg(Fsum("churn")).collect()[0][0])

The total number of churners are: 44864

But the above is the total number of records belonging to these churners. To find the number of unique churners, we need distinct(). We have used Fsum(), which is nothing but, as I call it, the "fancy sum()" of PySpark, working just like sum() in Python; it will be our friend in the PySpark environment.

Therefore, we select the userId and churn columns, apply distinct(), and then aggregate with Fsum() to get the number of unique churned users.

data.select('userId', 'churn').distinct().agg(Fsum("churn")).collect()[0][0]

52

Now, recollect that earlier we had only gone through the Null and NaN values for each column in our dataset, and we did not find any. Hence, to make sure that there really are none, we should also look for blank "" entries in each column. Therefore, we shall define a function that can do this job for us no matter how many columns get created in this dataset in the future. Another reason for defining a function is that it will be easy to reuse this code when creating the .py files for turning this code into an application. So here we go.

The logic is to filter out records that are NaN, Null or blank "" and count them. This can easily be done with the filter method of the dataframe in PySpark. But we also wish to go over each column, and we know the column names of a dataframe come out as a list, so we just introduce a for loop to walk through every column of our dataframe.

def missing_recs(df, col):
    # count records that are NaN, Null, or blank for the given column
    return df.filter((isnan(df[col])) | (df[col].isNull()) | (df[col] == "")).count()

print("\nColumns in data are:\n", data.columns)
print("\n[missing values]\n")

for col in data.columns:
    missing_count = missing_recs(data, col)
    if missing_count > 0:
        print("{}: {}".format(col, missing_count))
Columns in data are:
['artist', 'auth', 'firstName', 'gender', 'itemInSession', 'lastName', 'length', 'level', 'location', 'method', 'page', 'registration', 'sessionId', 'song', 'status', 'ts', 'userAgent', 'userId', 'churn']

[missing values]

artist: 58392
firstName: 8346
gender: 8346
lastName: 8346
length: 58392
location: 8346
registration: 8346
song: 58392
userAgent: 8346
userId: 8346

Well!! This is a surprise now. Here we have caught a lot of missing entries in our data. So are we supposed to bite our nails and scratch our heads over them? Not at all!! Simply take a step back and look at the overall picture. This is a music company, and every user must be registered to enjoy either the free or the paid service.

Here, userId is acting something like a primary key, a valuable column for our data; something we can call a base point to build our analysis on. Therefore, let us drop the records from the dataframe where the entry for userId is missing and keep only the genuine or authentic part of our data to establish our analysis upon.

For now, do not worry about the blank entries of the other columns. You are going to see how industrial requirements differ from academic teaching in the later modelling parts, and learn the "Feature Engineering" approach that is used most of the time for highly sophisticated models. We touch on the basics of it here too, so stay on!

print('Total number of rows in data are:', data.count())
print('Length of data after dropping blank userId:', data[data['userId'] != ''].count())
data = data[data['userId'] != '']

Total number of rows in data are: 286500
Length of data after dropping blank userId: 278154

Now it is time to dive into Exploratory Data Analysis, or simply its acronym, EDA.

Thank you to my readers:

If you have made it this far, then I appreciate your patience in reading my article. Please leave your responses below; I would surely appreciate any improvements or criticisms you suggest.

The link for the next part is given below:

https://medium.com/@kaustubh.ursekar/step-ii-a-beginners-step-by-step-guide-to-build-machine-learning-models-of-customer-churn-for-a-8b6fefc07bba

Written by Kaustubh Ursekar

Data Scientist @ National Physician Services - Hartford, CT, USA
