Build and Visualize a Simple Decision Tree Using Sklearn and Graphviz
Classifier models solve problems where we must determine whether an entity belongs to one class or another (one or more) based on one or more attributes. Classifiers are thus supervised machine learning algorithms: they require labeled data to learn from and to uncover hidden associations.
Introduction
Decision trees mimic the human decision-making process to distinguish between classes of objects and are especially effective with categorical data. Unlike algorithms such as logistic regression and support vector machines (SVMs), decision trees do not model a linear relationship between the independent variables and the target variable; instead, they can model highly non-linear data. We can therefore use a decision tree to trace all the factors that lead to a particular decision or prediction.
A decision tree splits data into multiple subsets of data. Each of these subsets is then further split into more subsets to arrive at the desired decision.
The first and topmost node of a decision tree is called the root node; the arrows in a decision tree always point away from it. A node that cannot be split further is called a leaf node; the arrows always point towards it. Any node that has descendant nodes and is not a leaf node is called an internal node.
Let’s build a simple decision tree on telecom customer data to find out whether a customer is likely to churn.
The dataset we’re using is from Kaggle and contains information about roughly 7,000 customers: their personal details and details of the services they have opted into with the telecom company. The dataset can be found here.
We will use the Python libraries NumPy and Pandas for basic data processing, and pydotplus and graphviz for visualizing the built decision tree.
Data Preparation and Cleaning
Importing NumPy and Pandas to read the dataset.
Let’s have a quick look at the dataset through some records
We will quickly check the data types and count of records using df.info()
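The loading and inspection steps above can be sketched as follows. Since the Kaggle CSV isn’t bundled here, a few synthetic rows with the same column names stand in for it; in practice you would load the real file with `pd.read_csv`.

```python
import pandas as pd

# Stand-in for pd.read_csv("<path-to-telco-churn.csv>") -- a few
# synthetic rows mirroring the dataset's schema (an assumption).
df = pd.DataFrame({
    "tenure": [1, 34, 2, 45],
    "MonthlyCharges": [29.85, 56.95, 53.85, 42.30],
    "TotalCharges": ["29.85", "1889.5", "108.15", "1840.75"],
    "SeniorCitizen": [0, 0, 1, 0],
    "Contract": ["Month-to-month", "One year", "Month-to-month", "One year"],
    "Churn": ["No", "No", "Yes", "No"],
})

print(df.head())   # a quick look at the first records
df.info()          # dtypes and non-null counts per column
```

Note that `df.info()` already hints at the issue discussed below: `TotalCharges` shows up as an object column even though it holds numbers.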
There are a lot of columns with categorical data. We can check which category values are present in each column as below:
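One simple way to list the levels of every categorical column is to loop over the object-typed columns; the small DataFrame here is a stand-in for the real dataset.

```python
import pandas as pd

# Stand-in rows with a few of the dataset's categorical columns.
df = pd.DataFrame({
    "Contract": ["Month-to-month", "One year", "Two year", "Month-to-month"],
    "PaymentMethod": ["Electronic check", "Mailed check",
                      "Electronic check", "Credit card (automatic)"],
    "Churn": ["No", "No", "Yes", "No"],
})

# Print the distinct levels of every object (string) column.
for col in df.select_dtypes(include="object").columns:
    print(col, "->", df[col].unique().tolist())
```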
As we saw earlier, there are three columns with numeric data, namely MonthlyCharges, tenure, and SeniorCitizen. However, SeniorCitizen isn’t really numeric: it’s categorical with numeric levels. We can process the first two columns by converting them into categorical features; this is achieved with binning (bucketing). Also, TotalCharges is typed as an object but holds numeric data. Let’s fix these data issues.
# SeniorCitizen is categorical with numeric levels -- map it to Yes/No
df['SeniorCitizen'] = df['SeniorCitizen'].map({0: 'No', 1: 'Yes'})

# Binning the tenure column
cut_labels = ['0-12', '13-24', '25-36', '37-48', '49-60', '61-72']
cut_bins = [0, 12, 24, 36, 48, 60, 72]
df['Tenure Period'] = pd.cut(df['tenure'], bins=cut_bins, labels=cut_labels)

# Binning the MonthlyCharges column
cut_labels = ['0-20', '21-40', '41-60', '61-80', '81-100', '101-120']
cut_bins = [0, 20, 40, 60, 80, 100, 120]
df['MonthlyCharges_Range'] = pd.cut(df['MonthlyCharges'], bins=cut_bins, labels=cut_labels)

# TotalCharges is stored as an object -- convert it to numeric before binning
df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce')

# Binning the TotalCharges column
cut_labels = ['0-1000', '1001-2000', '2001-4000', '4001-6000', '6001-8000', '8001-10000']
cut_bins = [0, 1000, 2000, 4000, 6000, 8000, 10000]
df['TotalCharges_Range'] = pd.cut(df['TotalCharges'], bins=cut_bins, labels=cut_labels)
df['TotalCharges_Range'].value_counts()
Now, let’s check whether our dataset has any missing values.
Since the data is categorical, the best strategy to impute them is by replacing void entries with the most frequent values within that feature.
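The check-and-impute step above can be sketched as below: count the missing entries per column, then fill each gap with that column’s most frequent value (its mode). The two-column DataFrame is a stand-in for the real one.

```python
import pandas as pd

# Stand-in columns containing a couple of missing entries.
df = pd.DataFrame({
    "Contract": ["Month-to-month", None, "Two year", "Month-to-month"],
    "TotalCharges_Range": ["0-1000", "1001-2000", None, "0-1000"],
})

# Count missing entries per column
print(df.isnull().sum())

# Impute each column's gaps with its most frequent value (the mode)
for col in df.columns:
    df[col] = df[col].fillna(df[col].mode()[0])
```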
Here, we can remove the columns that do not add any value to the dataset.
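This is a one-liner with `DataFrame.drop`. The column names here are an assumption: the customer identifier adds no predictive value, and the raw numeric columns are now redundant with their binned versions.

```python
import pandas as pd

# Stand-in frame with an identifier, a raw column, and its binned version.
df = pd.DataFrame({
    "customerID": ["0001-A", "0002-B"],
    "tenure": [1, 34],
    "Tenure Period": ["0-12", "25-36"],
    "Churn": ["No", "Yes"],
})

# Drop identifier-like columns and the raw numeric columns that the
# binned versions now replace (exact column names are an assumption).
df = df.drop(columns=["customerID", "tenure"])
```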
Let’s give the dataset a final sanity check.
Label Encoding
Machine learning algorithms cannot work with strings or character data, so we need to convert these values to numbers before feeding the dataset to the algorithm. One way to handle this is label encoding, which assigns a unique number, starting from 0, to each class of data.
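A minimal sketch of label encoding with scikit-learn’s `LabelEncoder`, applied to every string column of a stand-in frame. (In a real pipeline you would keep one fitted encoder per column so the mappings can be inverted later.)

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Stand-in categorical columns.
df = pd.DataFrame({
    "Contract": ["Month-to-month", "One year", "Two year"],
    "Churn": ["No", "Yes", "No"],
})

# Encode each string column; classes are numbered 0, 1, 2, ...
# in alphabetical order of their labels.
le = LabelEncoder()
for col in df.select_dtypes(include="object").columns:
    df[col] = le.fit_transform(df[col])
```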
Train Test Split
Now we will divide our dataset into two sets: one for training and one for evaluating performance on unseen records. First, we separate the target variable from the features.
from sklearn.model_selection import train_test_split
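The split can be sketched as below on a small encoded stand-in frame; the target column name `Churn` matches the dataset, while the 30% test fraction and random seed are assumptions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Stand-in, already label-encoded data.
df = pd.DataFrame({
    "Contract": [0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
    "SeniorCitizen": [0, 0, 1, 0, 1, 0, 0, 1, 0, 0],
    "Churn": [0, 0, 1, 0, 1, 0, 0, 1, 0, 1],
})

# Separate the target from the features, then hold out 30% for testing.
X = df.drop(columns=["Churn"])
y = df["Churn"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42)
```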
Model Building
The data has been split into train and test sets; now we will fit a Decision Tree Classifier from scikit-learn’s sklearn.tree module.
from sklearn.tree import DecisionTreeClassifier
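Fitting the classifier looks like the sketch below. A synthetic dataset stands in for the churn split, and the hyperparameters (Gini criterion, a depth cap of 4) are assumptions, not the article’s exact settings.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Toy stand-in for X_train / y_train from the churn split.
X_train, y_train = make_classification(
    n_samples=100, n_features=5, random_state=0)

# Gini impurity and a depth cap are reasonable defaults for a first tree.
dt = DecisionTreeClassifier(criterion="gini", max_depth=4, random_state=42)
dt.fit(X_train, y_train)
```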
Visualization of Tree
Yay, we have fitted a decision tree on the train set; next we will see how the tree has grown to predict the final outcomes. For this, we need to install and import the pydotplus and graphviz Python libraries.
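The visualization step can be sketched as below: scikit-learn’s `export_graphviz` turns the fitted tree into a Graphviz DOT string, which pydotplus can then render. A small built-in dataset stands in for the churn data, and the rendering step is wrapped in a try/except because pydotplus and the Graphviz system binaries may not be installed.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Toy stand-in for the fitted churn tree.
iris = load_iris()
dt = DecisionTreeClassifier(max_depth=2, random_state=0)
dt.fit(iris.data, iris.target)

# Export the fitted tree to Graphviz DOT format as a string.
dot_data = export_graphviz(
    dt,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
)

# With pydotplus and the Graphviz binaries installed, the DOT
# string can be rendered to a PDF; both are optional here.
try:
    import pydotplus
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf("dt_demo.pdf")
except Exception:
    pass
```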
We can even export the visualization to a PDF using the line below:
graph.write_pdf("dt_telecom_churn.pdf")
From the above we can clearly see that the tree model has found only one set of rules for identifying customers who would not churn. Is this a terrible tree, or is this just how it is?
Let’s check the performance of this model on the test data and see how the accuracy and other metrics turn out.
Model Evaluation
Let’s derive the specificity and sensitivity from the confusion matrix.
Let’s get the accuracy score and the confusion matrix
Now, let’s calculate the rest of the metrics
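The evaluation steps above can be sketched as follows; the label/prediction arrays are hypothetical stand-ins for the churn model’s test-set output, not the article’s actual results.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Hypothetical test labels and predictions (1 = churned, 0 = retained).
y_test = np.array([0, 0, 1, 1, 0, 1, 0, 0])
y_pred = np.array([0, 0, 1, 0, 0, 1, 0, 1])

print("Accuracy:", accuracy_score(y_test, y_pred))

# For binary labels, confusion_matrix returns [[TN, FP], [FN, TP]].
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
sensitivity = tp / (tp + fn)   # true positive rate (recall on churners)
specificity = tn / (tn + fp)   # true negative rate
print("Sensitivity:", sensitivity, "Specificity:", specificity)
```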
Conclusion
Decision trees are simple and intuitive models. However, they are high-variance models: a slight change in the training data may result in poor test performance, because they try to find complex relationships in the data. This is also called overfitting.
Although our model has stable results on the train and test data, it has low sensitivity and high specificity, so it would likely perform poorly at predicting the true positive class, i.e., the customers who have churned.
This can be improved through hyperparameter tuning. We will see how to boost the performance of this tree by fine-tuning its hyperparameters, along with some other techniques, in the next one.
Thanks for reading, Please check out my work on my GitHub profile and do give it ★ if you find it useful!