Cost Complexity Pruning in Decision Trees

Understanding the problem of Overfitting in Decision Trees and solving it by Minimal Cost-Complexity Pruning using Scikit-Learn in Python

Sarthak Arora
Analytics Vidhya
5 min read · Sep 19, 2020


Decision Tree is one of the most intuitive and effective tools in a Data Scientist’s toolkit. It has an inverted tree-like structure which was once used only in Decision Analysis but is now a brilliant Machine Learning Algorithm as well, especially when we have a Classification problem at hand.

They are well known for their ability to capture patterns in the data. But excess of everything is harmful, right? Decision Trees are infamous for clinging too closely to the data they’re trained on.
Hence, your tree gives poor results on deployment because it cannot deal with a new set of values.

But, you need not worry. Just like a skilled mechanic has wrenches of all sizes readily available in his toolbox, a skilled Data Scientist also has his set of techniques to deal with any kind of problem.

Pruning is one of the techniques used to overcome our problem of Overfitting. Pruning, in its literal sense, is a practice which involves the selective removal of certain parts of a tree (or plant), such as branches, buds, or roots, to improve the tree’s structure and promote healthy growth. This is exactly what Pruning does to our Decision Trees as well. It makes the tree versatile so that it can adapt if we feed any new kind of data to it, thereby fixing the problem of overfitting.

It reduces the size of a Decision Tree which might slightly increase your training error but drastically decrease your testing error, hence making it more adaptable.

Minimal Cost-Complexity Pruning is one of the types of Pruning of Decision Trees.

This algorithm is parameterized by α(≥0) known as the complexity parameter.

The complexity parameter is used to define the cost-complexity measure, Rα(T) of a given tree T: Rα(T)=R(T)+α|T|
where |T| is the number of terminal nodes in T and R(T) is traditionally defined as the total misclassification rate of the terminal nodes.
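As a quick worked example of this measure (the numbers here are illustrative, not from the article):

```python
# Cost-complexity measure: R_alpha(T) = R(T) + alpha * |T|
def cost_complexity(misclass_rate, n_leaves, alpha):
    return misclass_rate + alpha * n_leaves

# An illustrative tree with R(T) = 0.05 and 10 leaves, at alpha = 0.01:
r_alpha = cost_complexity(0.05, 10, 0.01)  # 0.05 + 0.01 * 10 = 0.15
```

A larger alpha penalizes trees with many leaves more heavily, so minimizing Rα(T) favors smaller subtrees as alpha grows.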

In version 0.22, Scikit-learn introduced a parameter called ccp_alpha (Yes! It’s short for Cost Complexity Pruning - Alpha) in Decision Trees, which can be used to perform this pruning.

We will use the Iris dataset to fit the Decision Tree on. You can download the dataset here.

First, let us import the basic libraries and the dataset-

Importing the libraries and dataset
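(The original code was shown as an image. A minimal sketch of this step, using scikit-learn’s bundled copy of Iris instead of the downloaded CSV; the column renaming is my own, to match the article’s wording.)

```python
from sklearn.datasets import load_iris

# Load Iris as a DataFrame (the article reads a downloaded CSV instead).
iris = load_iris(as_frame=True)
df = iris.frame.rename(columns={
    "sepal length (cm)": "SepalLength",
    "sepal width (cm)": "SepalWidth",
})
# Map the integer target to species names, matching the article's "Species" column.
df["Species"] = iris.target_names[iris.target]
print(df.head())
```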

The Dataset looks like this-

A Snippet of our dataset

Our aim is to predict the Species of a flower based on its Sepal Length and Width.

We will split the dataset into two parts- Train and Test so that we can see how our model performs on unseen data as well.
(We shall use the train_test_split function from sklearn.model_selection to split the dataset)

Splitting the Dataset into Train and Test
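(A sketch of the split; the test_size and random_state values are assumptions, since the article’s code is not shown.)

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Only sepal length and width are used as features, per the article.
iris = load_iris()
X = iris.data[:, :2]   # columns 0 and 1: sepal length (cm), sepal width (cm)
y = iris.target

# test_size and random_state are assumptions; the article does not show them.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)
```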

Now, let’s fit a Decision Tree to the train part and predict on both test and train.
(We will use DecisionTreeClassifier from sklearn.tree for this purpose)

Implementation of the Decision Tree
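(A self-contained sketch of the fit; the split parameters are assumptions, so your exact scores may differ from the article’s 0.95 and 0.63.)

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data[:, :2], iris.target  # sepal length and width only
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)

# By default ccp_alpha=0.0, so the tree grows unpruned.
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_train, y_train)

train_acc = accuracy_score(y_train, clf.predict(X_train))
test_acc = accuracy_score(y_test, clf.predict(X_test))
print(f"Train accuracy: {train_acc:.2f}, Test accuracy: {test_acc:.2f}")
```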

By default, the Decision Tree classifier doesn’t perform any pruning and lets the tree grow as much as it can. We get accuracy scores of 0.95 and 0.63 on the train and test parts respectively, as shown below. We can say that our model is Overfitting, i.e. memorizing the train part but failing to perform equally well on the test part.

Accuracy on train and test is 0.95 and 0.63 respectively

DecisionTreeClassifier in sklearn has a function called cost_complexity_pruning_path, which gives the effective alphas of subtrees during pruning along with the corresponding impurities. In other words, we can use these values of alpha to prune our decision tree-

Cost Complexity Pruning Path
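(A sketch of this call, under the same assumed split as before.)

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data[:, :2], iris.target, test_size=0.25, random_state=42
)

clf = DecisionTreeClassifier(random_state=0)
# Effective alphas of the subtrees produced during pruning, together with the
# total leaf impurities of each subtree; both arrays are in increasing order.
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
```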

We will take these values of alpha and pass them to the ccp_alpha parameter of our DecisionTreeClassifier. By looping over the alphas array, we will find the accuracy on both the Train and Test parts of our dataset.

Code to loop over the alphas and plot the line graph for corresponding Train and Test accuracies,
Accuracy v/s Alpha
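(A sketch of that loop and plot; the split parameters and random states are assumptions, and matplotlib is used for the line graph.)

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs headless
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data[:, :2], iris.target, test_size=0.25, random_state=42
)

path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(
    X_train, y_train
)

# Fit one tree per effective alpha and record its train/test accuracy.
train_scores, test_scores = [], []
for alpha in path.ccp_alphas:
    tree = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha)
    tree.fit(X_train, y_train)
    train_scores.append(tree.score(X_train, y_train))
    test_scores.append(tree.score(X_test, y_test))

plt.plot(path.ccp_alphas, train_scores, marker="o", label="Train")
plt.plot(path.ccp_alphas, test_scores, marker="o", label="Test")
plt.xlabel("ccp_alpha")
plt.ylabel("Accuracy")
plt.legend()
plt.savefig("accuracy_vs_alpha.png")
```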

From the above plot, we can see that between alpha=0.01 and 0.02, we get the maximum test accuracy. Although our train accuracy has decreased to 0.8, our model is now more generalized and it will perform better on unseen data.

Train and Test Accuracy at α=0.02
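(Fitting at the article’s chosen alpha can be sketched as below; with the assumed split parameters your exact scores may differ from the article’s.)

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data[:, :2], iris.target, test_size=0.25, random_state=42
)

# Prune with the alpha picked from the accuracy-vs-alpha plot.
pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=0.02)
pruned.fit(X_train, y_train)
print("Train:", pruned.score(X_train, y_train))
print("Test:", pruned.score(X_test, y_test))

# For comparison, an unpruned tree has many more leaves.
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
print("Leaves:", full.get_n_leaves(), "->", pruned.get_n_leaves())
```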

We can also use K-fold cross validation to test our model rather than using a train-test split. It will give us a much better overview of our model’s performance on unseen data.
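(A sketch of that check with cross_val_score; the fold count and alpha value are assumptions.)

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data[:, :2], iris.target

# 5-fold CV of the pruned tree; every row gets used for testing exactly once.
scores = cross_val_score(
    DecisionTreeClassifier(random_state=0, ccp_alpha=0.02), X, y, cv=5
)
print("Mean accuracy:", scores.mean(), "+/-", scores.std())
```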

If you want to understand the Math behind Cost-Complexity Pruning, click here. Check out the scikit-learn documentation of Decision trees by clicking here.

You can find the notebook on my GitHub and take a closer look at what I have done.
Also, connect with me on LinkedIn and let’s discuss about Data!

Thanks for staying till the end! ^_^

Also, please give me a follow here, on Medium. That will motivate me to create content frequently for you. Cheers!


Data Scientist @ Jupiter.co | Ex - Assistant Manager in Analytics @ Paisabazaar | I write about Data Science and ML | https://www.linkedin.com/in/iasarthak/