Build and Visualize a simple Decision Tree using Sklearn and Graphviz

Chinmay Gaikwad · Published in ChiGa · 5 min read · Sep 21, 2021

Classifier models solve problems where we have to decide which of two or more classes an entity belongs to, based on one or more attributes. Classifiers are supervised machine learning algorithms, so they require labeled data to learn from and to uncover hidden associations.

Introduction

Decision trees mimic the human decision-making process to distinguish between classes of objects and are especially effective in dealing with categorical data. Unlike algorithms such as logistic regression and support vector machines (SVMs), decision trees do not assume a linear relationship between the independent variables and the target variable, so they can model highly non-linear data. Decision trees also let us explain all the factors that lead to a particular decision or prediction.

A decision tree splits the data into multiple subsets, and each of these subsets is then split further until the desired decision is reached.

The first and top node of a decision tree is called the root node; the arrows in a decision tree always point away from it. A node that cannot be split further is called a leaf node; the arrows in a decision tree always point towards it. Any node that has descendant nodes and is not a leaf node is called an internal node.

Let’s build a simple Decision Tree on Telecom Customer Data to find out whether a customer is likely to churn or not.

The dataset we’re using is from Kaggle. It contains information about roughly 7,000 customers: their personal details and details related to the services they have opted in to with the telecom company. The dataset can be found here.

We will use the Python libraries NumPy and Pandas to perform basic data processing, and pydotplus and graphviz to visualize the built Decision Tree.

Data Preparation and Cleaning

First, we import NumPy and Pandas and read in the dataset.
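A minimal sketch, assuming the Kaggle CSV has been downloaded locally (the filename below is an assumption; adjust the path to your copy):

import numpy as np
import pandas as pd

# Load the Telco Customer Churn dataset
df = pd.read_csv('WA_Fn-UseC_-Telco-Customer-Churn.csv')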

Let’s have a quick look at the dataset through some records:
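# Display the first few records
df.head()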

We will quickly check the data types and count of records using df.info()

There are a lot of columns with categorical data. We can check which category values are present in each column as below:
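One way to do this (a sketch):

# Print the distinct values of every non-numeric column
for col in df.select_dtypes(include='object').columns:
    print(col, ':', df[col].unique())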

As we saw earlier, there are 3 columns with numeric data, namely MonthlyCharges, tenure, and SeniorCitizen. However, the SeniorCitizen column isn’t really numeric; it’s categorical with numeric levels. We can convert the first two columns into categorical features by binning (bucketing) their values. Also, TotalCharges is typed as an object but holds numeric data. Let’s fix these data issues.

# Map SeniorCitizen's numeric levels to categorical labels
df['SeniorCitizen'] = df['SeniorCitizen'].map({0: 'No', 1: 'Yes'})

# Binning the tenure column
cut_labels = ['0-12', '13-24', '25-36', '37-48', '49-60', '61-72']
cut_bins = [0, 12, 24, 36, 48, 60, 72]
df['Tenure Period'] = pd.cut(df['tenure'], bins=cut_bins, labels=cut_labels)

# Binning the MonthlyCharges column
cut_labels = ['0-20', '21-40', '41-60', '61-80', '81-100', '101-120']
cut_bins = [0, 20, 40, 60, 80, 100, 120]
df['MonthlyCharges_Range'] = pd.cut(df['MonthlyCharges'], bins=cut_bins, labels=cut_labels)

# TotalCharges is stored as an object; convert it to numeric before binning
df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce')

# Binning the TotalCharges column
cut_labels = ['0-1000', '1001-2000', '2001-4000', '4001-6000', '6001-8000', '8001-10000']
cut_bins = [0, 1000, 2000, 4000, 6000, 8000, 10000]
df['TotalCharges_Range'] = pd.cut(df['TotalCharges'], bins=cut_bins, labels=cut_labels)
df['TotalCharges_Range'].value_counts()

Now, let’s check whether our dataset has any missing values.
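For example:

# Count missing values per column
df.isnull().sum()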

Since the data is categorical, the best strategy is to impute missing entries with the most frequent value within that feature.
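A sketch of this mode-based imputation:

# Replace missing entries in each column with that column's most frequent value
for col in df.columns:
    df[col] = df[col].fillna(df[col].mode()[0])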

Here, we can remove the columns that do not add any value to the dataset.
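For instance, the customer identifier and the raw numeric columns that the binned range columns now represent can be dropped (the exact list below is an assumption):

# Drop the identifier and the raw columns replaced by their binned versions
df = df.drop(['customerID', 'tenure', 'MonthlyCharges', 'TotalCharges'], axis=1)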

Let’s give the dataset a final sanity check.
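For example:

# Confirm the dtypes, record count, and that no missing values remain
df.info()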

Label Encoding

Machine learning algorithms cannot work with strings or character data, so we need to convert these values to numbers before feeding the dataset to the algorithm. One way to handle this is Label Encoding, which assigns a unique number, starting from 0, to each class of data.
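A sketch using scikit-learn’s LabelEncoder on every remaining non-numeric column:

from sklearn.preprocessing import LabelEncoder

# Encode each non-numeric column as integers starting from 0
le = LabelEncoder()
for col in df.select_dtypes(exclude='number').columns:
    df[col] = le.fit_transform(df[col])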

Train Test Split

Now, we will divide our dataset into two sets: one for training and one for evaluating performance on unseen records. First, we will separate the target variable from the features.

from sklearn.model_selection import train_test_split
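Completing the split (assuming Churn is the encoded target column; the 80/20 split and random_state are arbitrary choices for this sketch):

# Separate the features and the target
X = df.drop('Churn', axis=1)
y = df['Churn']

# Hold out 20% of the records for evaluation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)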

Model Building

The data has been split into train and test sets; now we will proceed to fit a Decision Tree Classifier from scikit-learn’s sklearn.tree module.

from sklearn.tree import DecisionTreeClassifier
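Fitting the model (default hyperparameters, which let the tree grow fully; random_state is only for reproducibility):

# Fit a decision tree on the training set
dt = DecisionTreeClassifier(random_state=42)
dt.fit(X_train, y_train)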

Visualization of Tree

Yay, we have fitted a Decision Tree on the training set. Next, we will see how the tree has grown to predict the final outcomes. For this, we need to install and import the pydotplus and graphviz Python libraries.
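One way to render the fitted tree, using sklearn’s export_graphviz together with pydotplus (variable names follow the sketches above; the output filename is arbitrary):

from sklearn.tree import export_graphviz
import pydotplus

# Export the fitted tree to DOT format, then render it
# class_names assumes 'No' was encoded as 0 and 'Yes' as 1
dot_data = export_graphviz(dt, out_file=None,
                           feature_names=X.columns,
                           class_names=['No', 'Yes'],
                           filled=True, rounded=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('dt_telecom_churn.png')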

We can even export the visualization to a PDF using the line below:

graph.write_pdf("dt_telecom_churn.pdf")

From the above, we can clearly see that the tree model has found only one set of rules for deciding that a customer would not churn. Is this a terrible tree, or is this just how it is?

Let’s check the performance of this model on the test data and see how the accuracy and other metrics turn out.

Model Evaluation

We will derive the Specificity and Sensitivity from the confusion matrix.

First, let’s get the accuracy score and the confusion matrix:
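A sketch, reusing the model and the test split from above:

from sklearn.metrics import accuracy_score, confusion_matrix

# Predict on the held-out test set
y_pred = dt.predict(X_test)
print('Accuracy:', accuracy_score(y_test, y_pred))

# Rows are actual classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred)
print(cm)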

Now, let’s calculate the rest of the metrics:
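Sensitivity and Specificity can be read directly off the confusion matrix (this assumes the positive class, churn, was encoded as 1):

# For binary labels [0, 1], ravel() yields tn, fp, fn, tp
tn, fp, fn, tp = cm.ravel()

sensitivity = tp / (tp + fn)  # share of actual churners we caught
specificity = tn / (tn + fp)  # share of non-churners correctly identified
print('Sensitivity:', sensitivity)
print('Specificity:', specificity)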

Conclusion

Decision Trees are simple and intuitive models. However, they are high-variance models: a slight change in the training data may result in poor performance on the test set, because the tree tries to capture complex relationships in the data. This is also called overfitting.

Although our model has stable results on the train and test data, it has low Sensitivity and high Specificity, so it would likely perform poorly at predicting the true positive class, i.e. the customers who have churned.

Improving this calls for hyperparameter tuning. We will see how to improve the performance of this tree by fine-tuning the hyperparameters, along with some other techniques, in the next one.

Thanks for reading! Please check out my work on my GitHub profile, and do give it a star if you find it useful!
