Churn Prediction Using Machine Learning
Analyze all relevant customer data and develop a robust and accurate Churn Prediction model to retain customers and to form strategies for reducing customer attrition rates.
Churn refers to customers or users who stop using a service or move to a competitor. Every organization needs to both keep its existing customers and attract new ones; failing at either is bad for business. The goal of this project is to explore how machine learning can predict churn and help a company keep its competitive edge.
One of the most famous and useful case studies of churn prediction is in the telecom industry. It is important for telecom companies to analyze all relevant customer data and develop a robust and accurate Churn Prediction model to retain customers and to form strategies for reducing customer attrition rates.
This project uses the Telco Customer Churn dataset, which is available on Kaggle.
Attribute Information
Prediction column:
Churn: Whether the customer churned or not (Yes or No)
Three numerical columns:
1. tenure: Number of months the customer has stayed with the company
2. MonthlyCharges: The amount charged to the customer monthly
3. TotalCharges: The total amount charged to the customer
Seventeen categorical columns:
1. customerID: Customer ID, unique for each customer
2. gender: Whether the customer is a male or a female
3. SeniorCitizen: Whether the customer is a senior citizen or not (1, 0)
4. Partner: Whether the customer has a partner or not (Yes, No)
5. Dependents: Whether the customer has dependents or not (Yes, No)
6. PhoneService: Whether the customer has a phone service or not (Yes, No)
7. MultipleLines: Whether the customer has multiple lines or not (Yes, No, No phone service)
8. InternetService: Customer’s internet service provider (DSL, Fiber optic, No)
9. OnlineSecurity: Whether the customer has online security or not (Yes, No, No internet service)
10. OnlineBackup: Whether the customer has an online backup or not (Yes, No, No internet service)
11. DeviceProtection: Whether the customer has device protection or not (Yes, No, No internet service)
12. TechSupport: Whether the customer has tech support or not (Yes, No, No internet service)
13. StreamingTV: Whether the customer has streaming TV or not (Yes, No, No internet service)
14. StreamingMovies: Whether the customer has streaming movies or not (Yes, No, No internet service)
15. Contract: The contract term of the customer (Month-to-month, One year, Two year)
16. PaperlessBilling: Whether the customer has paperless billing or not (Yes, No)
17. PaymentMethod: The customer’s payment method (Electronic check, Mailed check, Bank transfer (automatic), Credit card (automatic))
The project is structured as follows:
- Data cleaning
- Exploratory Data Analysis
- Data Preprocessing
- Encoding
- Feature Selection
- Oversampling Technique
- Model Creation and Evaluation
- Improving the Model
1. Data Cleaning
Start by importing the required libraries:
import numpy as np # linear algebra
import pandas as pd # data processing
import seaborn as sns # For creating plots
import matplotlib.ticker as mtick # For specifying axes tick format
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
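import os  # for listing the input directory
sns.set(style='white')
# Input data files are available in the "../churn_prediction" directory.
print(os.listdir("../churn_prediction"))
# Load the data (assumption: the CSV keeps the default file name of the Kaggle download)
df = pd.read_csv("../churn_prediction/WA_Fn-UseC_-Telco-Customer-Churn.csv")
df.shape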
(7043, 21)
Convert the columns to the required data types before moving forward. The TotalCharges column is stored as object, although it is a numerical column.
# Convert TotalCharges to a numeric dtype; entries that cannot be parsed become NaN
df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce')
# Pass a dictionary to astype() to set several column types in one call.
# tenure is left numeric: it is plotted and divided by as a number below.
df = df.astype({"customerID": 'category',
                "gender": 'category',
                "SeniorCitizen": 'category',
                "Partner": 'category',
                "Dependents": 'category',
                "PhoneService": 'category',
                "MultipleLines": 'category',
                "InternetService": 'category',
                "OnlineSecurity": 'category',
                "OnlineBackup": 'category',
                "DeviceProtection": 'category',
                "TechSupport": 'category',
                "StreamingTV": 'category',
                "StreamingMovies": 'category',
                "Contract": 'category',
                "PaperlessBilling": 'category',
                "PaymentMethod": 'category',
                "MonthlyCharges": 'float64',
                })
First, check whether there are any missing values and, if so, what percentage of each column is missing, so that the imputation method can be chosen accordingly.
# Percentage of null values in each column
df.isnull().sum() * 100 / len(df)
customerID 0.000000
gender 0.000000
SeniorCitizen 0.000000
Partner 0.000000
Dependents 0.000000
tenure 0.000000
PhoneService 0.000000
MultipleLines 0.000000
InternetService 0.000000
OnlineSecurity 0.000000
OnlineBackup 0.000000
DeviceProtection 0.000000
TechSupport 0.000000
StreamingTV 0.000000
StreamingMovies 0.000000
Contract 0.000000
PaperlessBilling 0.000000
PaymentMethod 0.000000
MonthlyCharges 0.000000
TotalCharges 0.156183
Churn 0.000000
dtype: float64
Missing values are present, but only in a very small proportion, so they can either be removed from the dataset or imputed with simple mean imputation. TotalCharges has 11 missing values, only about 0.16% of its entries, so we fill them with the column mean.
# Fill missing values with the column mean
df['TotalCharges'] = df['TotalCharges'].fillna(df['TotalCharges'].mean())
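The conversion and mean-imputation steps above can be checked on a tiny example. This is a minimal sketch on made-up values; the blank string stands in for the unparseable TotalCharges entries (an assumption about how they appear in the raw CSV).

```python
import pandas as pd

# Toy stand-in for the raw TotalCharges column: numbers stored as strings,
# with one blank entry that cannot be parsed as a number (made-up values)
raw = pd.Series(["29.85", "1889.5", " ", "108.15"])

# errors="coerce" turns unparseable entries into NaN instead of raising
total = pd.to_numeric(raw, errors="coerce")

# Percentage of missing values, computed the same way as for the full DataFrame
missing_pct = total.isnull().sum() * 100 / len(total)

# Mean imputation: NaN is replaced by the mean of the non-missing values
filled = total.fillna(total.mean())

print(missing_pct)  # 25.0
print(filled.tolist())
```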
2. Exploratory Data Analysis
Check the class distribution for imbalance
# Class Distribution
df.Churn.value_counts()
No 5174
Yes 1869
Name: Churn, dtype: int64
Plot of Churn Class Distribution
def bar_plot(df, column):
    # Horizontal count plot of `column`, annotated with each bar's percentage
    ax = sns.countplot(y=column, data=df)
    plt.title('Distribution of {}'.format(column))
    plt.xlabel('Number of customers')
    total = len(df[column])
    for p in ax.patches:
        percentage = '{:.1f}%'.format(100 * p.get_width() / total)
        x = p.get_x() + p.get_width() + 0.02
        y = p.get_y() + p.get_height() / 2
        ax.annotate(percentage, (x, y))
    plt.show()

bar_plot(df, "Churn")
Target variable
We are trying to predict whether a user left the company in the previous month, so this is a binary classification problem with an imbalanced target (roughly 3:1):
- Churn: No (73.5%)
- Churn: Yes (26.5%)
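These percentages follow directly from the class counts shown earlier. The sketch below rebuilds a Series with those counts (rather than loading the real file) and uses `value_counts(normalize=True)`; it also computes the majority-class baseline, which is worth keeping in mind whenever accuracy is reported on imbalanced data.

```python
import pandas as pd

# Rebuild the target from the class counts reported above (5174 "No", 1869 "Yes")
churn = pd.Series(["No"] * 5174 + ["Yes"] * 1869, name="Churn")

# normalize=True returns proportions instead of raw counts
pct = churn.value_counts(normalize=True) * 100
print(pct.round(1))

# Majority-class baseline: always predicting "No" is already about 73.5% accurate
baseline_accuracy = pct.max()
print(round(baseline_accuracy, 1))
```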
Numerical features
There are only three numerical columns: tenure, monthly charges, and total charges.
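def kdeplot(feature, hist=False, kde=True):
    # Compare the distribution of `feature` for retained vs. churned customers.
    # sns.distplot is deprecated, so histplot/kdeplot are used instead.
    plt.figure(figsize=(9, 4))
    plt.title("Plot for {}".format(feature))
    for churn, color in [("No", "darkblue"), ("Yes", "orange")]:
        values = df.loc[df['Churn'] == churn, feature].dropna().astype(float)
        label = "Churn: {}".format(churn)
        if hist:
            # Density-scaled histogram so both classes share one y-axis
            sns.histplot(values, stat="density", color=color, edgecolor="black",
                         alpha=0.3, label=label)
            label = None  # avoid a duplicate legend entry for the KDE line
        if kde:
            sns.kdeplot(values, color=color, linewidth=4, label=label)
    plt.legend()
    plt.savefig('kde_{}.png'.format(feature))  # one file per feature
    plt.show()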
kdeplot('tenure', hist = False, kde = True)
kdeplot('MonthlyCharges', hist = False, kde = True)
kdeplot('TotalCharges', hist = False, kde = True)
From the plots above we can conclude that:
- Recent users (low tenure) are more likely to churn
- Users with higher MonthlyCharges are also more likely to churn
- TotalCharges shows a similar pattern for both classes
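A new feature can be generated from the difference between MonthlyCharges and TotalCharges divided by tenure (that is, the current monthly charge minus the average monthly charge so far):
# Calculate features; tenure == 0 is treated as missing to avoid division by zero
tenure = df['tenure'].astype(float).replace(0, np.nan)
df['total_charges_to_tenure_ratio'] = df['TotalCharges'] / tenure
df['monthly_charges_diff'] = df['MonthlyCharges'] - df['total_charges_to_tenure_ratio']
kdeplot('monthly_charges_diff', hist=False, kde=True)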
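To make the new features concrete, here is a sketch on three hypothetical customers (the numbers are invented for illustration). It follows the same formula, treating `tenure == 0` as missing. One possible reading of the result: a positive `monthly_charges_diff` means the customer now pays more per month than their average so far.

```python
import numpy as np
import pandas as pd

# Three hypothetical customers; the third has tenure 0 (a brand-new account)
toy = pd.DataFrame({
    "MonthlyCharges": [70.0, 50.0, 20.0],
    "TotalCharges": [700.0, 400.0, 20.0],
    "tenure": [10, 10, 0],
})

# Average historical monthly charge; tenure 0 becomes NaN to avoid division by zero
tenure = toy["tenure"].astype(float).replace(0, np.nan)
toy["total_charges_to_tenure_ratio"] = toy["TotalCharges"] / tenure

# Positive values: the current monthly charge is above the historical average
toy["monthly_charges_diff"] = toy["MonthlyCharges"] - toy["total_charges_to_tenure_ratio"]
print(toy)
```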
Categorical features
This dataset has 16 categorical features:
- Six binary features (two unique values each)
- Nine features with three unique values each (categories)
- One feature with four unique values
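The grouping above can be reproduced by counting the unique values in each column. The helper `group_by_cardinality` below is hypothetical (not part of the original notebook) and is shown on a toy frame; on the real data you would pass the categorical columns of `df` except customerID.

```python
import pandas as pd

def group_by_cardinality(frame, columns):
    """Map each number of unique values to the list of columns that have it."""
    groups = {}
    for col in columns:
        groups.setdefault(frame[col].nunique(), []).append(col)
    return dict(sorted(groups.items()))

# Toy frame with one column of each kind described above
toy = pd.DataFrame({
    "Partner": ["Yes", "No", "Yes", "No"],
    "Contract": ["Month-to-month", "One year", "Two year", "One year"],
    "PaymentMethod": ["Electronic check", "Mailed check",
                      "Bank transfer (automatic)", "Credit card (automatic)"],
})

result = group_by_cardinality(toy, toy.columns)
print(result)  # {2: ['Partner'], 3: ['Contract'], 4: ['PaymentMethod']}
```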
Binary Features
fig, axes = plt.subplots(2, 3, figsize=(12, 7), sharey=True)
# Recent seaborn versions require the column to be passed as x=...
sns.countplot(x="gender", data=df, ax=axes[0, 0])
sns.countplot(x="SeniorCitizen", data=df, ax=axes[0, 1])
sns.countplot(x="Partner", data=df, ax=axes[0, 2])
sns.countplot(x="Dependents", data=df, ax=axes[1, 0])
sns.countplot(x="PhoneService", data=df, ax=axes[1, 1])
sns.countplot(x="PaperlessBilling", data=df, ax=axes[1, 2])
fig.savefig("inp.png")
- Gender distribution: About half of the customers in the data set are male and the other half are female.
- Senior citizens: Only 16% of the customers are senior citizens, so most customers in the data are younger people.
- Partner: About 50% of the customers have a partner.
- Dependents: Only 30% of the customers have dependents.
- Phone service: About 90.3% of the customers have phone service.
- Paperless billing: About 59.2% of the customers use paperless billing.
Partner and Dependents:
sns.countplot(x="Partner", data=df, hue='Dependents')
fig, axis = plt.subplots(1, 2, figsize=(12, 4))
axis[0].set_title("Has partner")
axis[1].set_title("Has dependents")
axis_y = "percentage of customers"
# Plot Partner column: share of all customers in each (Partner, Churn) group.
# Naming the Series with rename() works whether pandas calls it "Churn" or "count".
gp_partner = (df.groupby('Partner')["Churn"].value_counts() * 100 / len(df)).rename(axis_y).reset_index()
ax = sns.barplot(x='Partner', y=axis_y, hue='Churn', data=gp_partner, ax=axis[0])
# Plot Dependents column
gp_dep = (df.groupby('Dependents')["Churn"].value_counts() * 100 / len(df)).rename(axis_y).reset_index()
ax = sns.barplot(x='Dependents', y=axis_y, hue='Churn', data=gp_dep, ax=axis[1])
- Customers who have a partner are more likely to have dependents
- Customers without a partner are more likely to churn
- Customers without dependents are also more likely to churn
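The bars above are divided by `len(df)`, so they mix how large each group is with how often it churns. An alternative view is the churn rate within each group. A minimal sketch using `pd.crosstab(..., normalize="index")` on a hypothetical mini-sample (on the real data you would pass `df` columns instead):

```python
import pandas as pd

# Hypothetical mini-sample standing in for df
toy = pd.DataFrame({
    "Partner": ["Yes", "Yes", "Yes", "Yes", "No", "No", "No", "No"],
    "Churn":   ["No",  "No",  "No",  "Yes", "No", "Yes", "Yes", "No"],
})

# normalize="index" divides each row by its total, giving the churn rate per group
rates = pd.crosstab(toy["Partner"], toy["Churn"], normalize="index")
print(rates)
```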
Senior Citizens and Dependents:
sns.countplot(x="SeniorCitizen", data=df, hue='Dependents')
- Senior citizens are less likely to have dependents
Phone and Internet services
bar_plot(df, "MultipleLines")
sns.countplot(x="MultipleLines", data=df, hue='Churn')
- Few customers don’t have phone service
- Customers with multiple lines have a slightly higher churn rate
bar_plot(df, "InternetService")
sns.countplot(x="InternetService", data=df, hue='Churn')
- Customers without internet service have a very low churn rate
- Customers with fiber optic are more likely to churn than those with a DSL connection
Internet Services
There are six additional services for customers with internet service:
OnlineSecurity, OnlineBackup, DeviceProtection, TechSupport, StreamingTV, StreamingMovies
cols = ["OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies"]
df1 = pd.melt(df[df["InternetService"] != "No"][cols]).rename({'value': 'Has service'}, axis=1)
plt.figure(figsize=(10, 4.5))
ax = sns.countplot(data=df1, x='variable', hue='Has service')
ax.set(xlabel='Additional service', ylabel='Num of customers')
plt.show()
- Customers who have any of the first four add-ons (OnlineSecurity through TechSupport) are less likely to churn
- The streaming services are not predictive of churn
Payment Method
bar_plot(df, "PaymentMethod")
sns.countplot(x="PaymentMethod", data=df, hue='Churn')
- Electronic check is the most common payment method
- Electronic check also has the highest churn of all payment methods
Correlation Between Features
# Drop the ID column and the engineered features before computing correlations
df.drop(['customerID', 'total_charges_to_tenure_ratio', 'monthly_charges_diff'],
        axis=1, inplace=True)
# Encode every column as integer codes so that .corr() can be computed
df_corr = df.apply(lambda x: pd.factorize(x)[0])
corr = df_corr.corr()
plt.figure(figsize=(12, 6))
ax = sns.heatmap(corr, xticklabels=corr.columns, yticklabels=corr.columns,
                 linewidths=.2, cmap="YlGnBu")
plt.figure(figsize=(15, 10))
sns.heatmap(corr, annot=True)
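One caveat about this heatmap: `pd.factorize` assigns integer codes in order of first appearance, so the coding of each column (and hence the sign of its correlations) depends on row order, and continuous columns such as MonthlyCharges lose their numeric meaning. The toy example below illustrates the ordering. For a correlation analysis it may be preferable to keep numeric columns as they are and use explicit encodings for categorical ones.

```python
import pandas as pd

# pd.factorize numbers categories in order of first appearance
codes, uniques = pd.factorize(pd.Series(["Yes", "No", "Yes", "No"]))
print(codes)          # [0 1 0 1]
print(list(uniques))  # ['Yes', 'No']

# The same values with a different first row get the opposite coding,
# which flips the sign of any correlation computed from these codes
codes_flipped, _ = pd.factorize(pd.Series(["No", "Yes", "Yes", "No"]))
print(codes_flipped)  # [0 1 1 0]
```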
Feature Importance
params = {'random_state': 0, 'n_jobs': 4, 'n_estimators': 5000, 'max_depth': 8}
# One-hot encode
df = pd.get_dummies(df)
# Drop redundant columns (for features with two unique values)
drop = ['Churn_Yes', 'Churn_No', 'gender_Female', 'Partner_No',
'Dependents_No', 'PhoneService_No', 'PaperlessBilling_No']
x, y = df.drop(drop,axis=1), df['Churn_Yes']
# Fit RandomForest Classifier
clf = RandomForestClassifier(**params)
clf = clf.fit(x, y)
# Plot features importances
imp = pd.Series(data=clf.feature_importances_, index=x.columns).sort_values(ascending=False)
plt.figure(figsize=(10,12))
plt.title("Feature importance")
ax = sns.barplot(y=imp.index, x=imp.values, palette="Blues_d", orient='h')
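The manual `drop` list above removes one redundant dummy for each two-valued feature. As an alternative, `pd.get_dummies(..., drop_first=True)` drops the first (alphabetically sorted) level of every categorical column automatically, including the multi-valued ones. A sketch on a toy frame; tree-based models such as random forests do not strictly need the redundant columns removed, so this is mostly a matter of taste there.

```python
import pandas as pd

# Toy frame with a binary, a three-valued and a target column
toy = pd.DataFrame({
    "gender": ["Female", "Male", "Male"],
    "Contract": ["Month-to-month", "One year", "Two year"],
    "Churn": ["Yes", "No", "No"],
})

# drop_first=True keeps k-1 indicator columns per categorical feature
encoded = pd.get_dummies(toy, drop_first=True)
print(encoded.columns.tolist())
# ['gender_Male', 'Contract_One year', 'Contract_Two year', 'Churn_Yes']
```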
Oversampling Technique
The Synthetic Minority Oversampling Technique (SMOTE) is widely used to handle imbalanced datasets. It synthesizes new minority-class data points by interpolating between existing minority samples and their nearest minority-class neighbours, oversampling that class until the classes are balanced.
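The imbalanced-learn `SMOTE` class used below handles this for you; the NumPy sketch here only illustrates the core interpolation idea and is not the library's implementation (it omits, for example, its sampling-strategy options). The function name `smote_sketch` is hypothetical. Note that because SMOTE interpolates, it produces fractional values for one-hot columns; imbalanced-learn also provides `SMOTENC` for data with categorical features, which may be worth considering here.

```python
import numpy as np

def smote_sketch(X_minority, n_new, k=5, seed=0):
    """Simplified SMOTE: interpolate between minority samples and their neighbours."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X_minority, dtype=float)
    # Pairwise Euclidean distances within the minority class
    dists = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
    np.fill_diagonal(dists, np.inf)  # a point is not its own neighbour
    k = min(k, len(X) - 1)
    neighbours = np.argsort(dists, axis=1)[:, :k]  # k nearest minority neighbours
    synthetic = np.empty((n_new, X.shape[1]))
    for i in range(n_new):
        a = rng.integers(len(X))              # pick a random minority sample
        b = neighbours[a, rng.integers(k)]    # and one of its k nearest neighbours
        gap = rng.random()                    # uniform in [0, 1)
        synthetic[i] = X[a] + gap * (X[b] - X[a])  # point on the segment a -> b
    return synthetic

X_min = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
new_points = smote_sketch(X_min, n_new=6, k=2)
print(new_points.shape)  # (6, 2)
```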
from imblearn.over_sampling import SMOTE
sm = SMOTE(random_state=0)
X_resampled, y_resampled = sm.fit_resample(x, y)
y_resampled.value_counts()
1 5174
0 5174
Name: Churn_Yes, dtype: int64
Train Test Split
Split the data into training (80%) and test (20%) subsets.
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size = 0.2, random_state=42)
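One caveat worth checking: because SMOTE was applied to the full dataset before this split, the test set contains synthetic samples built from points that may also be in the training set, and its class balance no longer reflects real customers, which tends to make test scores optimistic. A commonly recommended order is to split first and oversample only the training portion (imbalanced-learn's `imblearn.pipeline.Pipeline` does this automatically when SMOTE is combined with a classifier). The sketch below shows the ordering with scikit-learn only; synthetic data from `make_classification` stands in for `x`, `y`, and plain random oversampling is used as a simple stand-in for SMOTE.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

# Synthetic imbalanced data standing in for (x, y); roughly 73% / 27% like the churn target
X, y = make_classification(n_samples=1000, weights=[0.73], random_state=0)

# 1) Split first, stratifying so both subsets keep the original class ratio
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

# 2) Oversample the minority class (label 1) in the training subset only
#    (random duplication here, as a simple stand-in for SMOTE)
X_min, X_maj = X_tr[y_tr == 1], X_tr[y_tr == 0]
X_min_up = resample(X_min, replace=True, n_samples=len(X_maj), random_state=0)
X_tr_bal = np.vstack([X_maj, X_min_up])
y_tr_bal = np.concatenate([np.zeros(len(X_maj), dtype=int), np.ones(len(X_min_up), dtype=int)])

# The test set is untouched: only real samples, in the original class ratio
print(np.bincount(y_tr_bal), np.bincount(y_te))
```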
Model
As a starting point, a GradientBoostingClassifier with default hyperparameters is trained to show the results of a basic model and its predictions.
clf_forest = GradientBoostingClassifier()
clf_forest.fit(X_train, y_train)
Train Predict
Model prediction on the training dataset
pred = clf_forest.predict(X_train)
accuracy_score(y_train, pred)
0.8659096400096642
Test Predict
Model prediction on the test dataset
pred_test = clf_forest.predict(X_test)
accuracy_score(y_test, pred_test)
0.855072463768116
Evaluation
# Evaluation metrics on the test set
from sklearn.metrics import (confusion_matrix, accuracy_score, recall_score,
                             precision_score, classification_report)
confusion_matrix(y_test, pred_test)  # Confusion matrix
accuracy_score(y_test, pred_test)  # Accuracy
recall_score(y_test, pred_test, average=None)  # Recall for each class
precision_score(y_test, pred_test, average=None)  # Precision for each class
print(classification_report(y_test, pred_test))
Classification report:
              precision    recall  f1-score   support

           1       0.85      0.86      0.86      1049
           0       0.86      0.85      0.85      1021

    accuracy                           0.86      2070
   macro avg       0.86      0.85      0.86      2070
weighted avg       0.86      0.86      0.86      2070
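The per-class precision and recall in the report come from the confusion matrix. A small worked example on hypothetical labels (1 = churn), computing both metrics by hand and checking them against scikit-learn:

```python
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Hypothetical labels: 1 = churn, 0 = no churn
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1, 0, 1]

# For binary labels [0, 1], ravel() yields tn, fp, fn, tp
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

precision = tp / (tp + fp)  # of the predicted churners, how many really churned
recall = tp / (tp + fn)     # of the real churners, how many were caught
print(tn, fp, fn, tp, precision, recall)  # 2 2 1 3 0.6 0.75
```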
We achieved an overall test accuracy of about 85.5% by applying the model directly, without extensive feature engineering, feature selection, or hyperparameter tuning. Applying these techniques, and implementing and comparing different models, may improve the results further.
This tutorial focuses mainly on exploratory data analysis, because it is one of the most important parts of the machine learning project cycle: model building and improvement are comparatively straightforward, but understanding the data and developing an intuition for it is essential to solving a machine learning problem.
I hope this tutorial helps you gain intuition for churn prediction and its applications. I have left feature creation and selection for you to experiment with, so you can apply your own understanding of the problem.
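One way to act on the hyperparameter-tuning suggestion is a grid search with cross-validation. A minimal sketch using scikit-learn's `GridSearchCV`; the synthetic data from `make_classification` stands in for `X_train`/`y_train`, and the grid values are arbitrary examples, not tuned recommendations.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in for X_train / y_train so the sketch runs on its own
X, y = make_classification(n_samples=300, n_features=10, random_state=0)

# A deliberately small grid (8 combinations); real searches would cover more values
param_grid = {
    "n_estimators": [50, 100],
    "learning_rate": [0.05, 0.1],
    "max_depth": [2, 3],
}
search = GridSearchCV(GradientBoostingClassifier(random_state=0), param_grid,
                      cv=3, scoring="accuracy")
search.fit(X, y)

# Best combination and its mean cross-validated accuracy
print(search.best_params_, round(search.best_score_, 3))
```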
The complete notebook for the project can be downloaded from the repository.
Thank you for reading. Please let me know if you have any feedback.
I welcome feedback and constructive criticism and can be reached on LinkedIn.