Overfitting and Pruning in Decision Trees — Improving Model’s Accuracy

Rishika Ravindran · Published in Nerd For Tech · 6 min read · Jan 18, 2023

Decision Trees are non-parametric supervised machine-learning models that use labeled input and target data for training. They can be used for both classification and regression tasks. Recapitulating from my previous article on Decision Trees: they represent the decision-making process through branching, tree-like structures, making decisions based on how previous sets of questions (labels/nodes) were answered. A decision tree model predicts the value of a target variable by learning decision rules from the input features.

The techniques discussed in this article are illustrated later with a house-price prediction example; the implemented notebook can be found here.

What is Overfitting?

Overfitting is a common problem that needs to be handled while training a decision tree model. It occurs when a model fits the training data too closely and becomes less accurate when encountering new data or predicting future outcomes. In an overfit condition, the model memorizes the noise in the training data and fails to capture its essential patterns [1].


Overfitting in Decision Trees

In decision trees, in order to fit the data (even noisy data), the model keeps generating new nodes, and ultimately the tree becomes too complex to interpret. Such a tree predicts well on the training data but can be inaccurate on new data. If a decision tree model is allowed to train to its full potential, it can overfit the training data.
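To make this concrete, here is a minimal sketch that fits an unpruned, full-depth tree. It assumes a synthetic make_regression dataset as a stand-in for real data; the gap between the training and test scores is the overfitting described above.

```python
# A minimal sketch of an unpruned (full-depth) decision tree overfitting,
# using synthetic data as a stand-in for a real dataset.
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=500, n_features=8, noise=10.0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

tree = DecisionTreeRegressor(random_state=42)  # no pruning: grows to full depth
tree.fit(X_train, y_train)

print("Train R^2:", tree.score(X_train, y_train))  # close to 1.0 (memorizes training data)
print("Test R^2:", tree.score(X_test, y_test))     # noticeably lower -> overfitting
```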

There are techniques to prevent the overfitting of decision trees which will be discussed in this article.

What is Pruning?

Pruning is a technique that removes parts of the decision tree and prevents it from growing to its full depth. Pruning removes those parts of the decision tree that do not have the power to classify instances. Pruning can be of two types — Pre-Pruning and Post-Pruning.

Comparing a pruned and an unpruned decision tree side by side highlights the difference: the unpruned tree is denser, more complex, and has higher variance, resulting in overfitting.

Pre-Pruning

Pre-Pruning, also known as ‘Early Stopping’ or ‘Forward Pruning’, stops the growth of the decision tree, preventing it from reaching its full depth. It stops non-significant branches from being generated. Pre-pruning involves tuning the hyperparameters prior to training the model. Hyperparameters are parameters whose values control the learning process and determine the values of the model parameters. In simpler terms, they are any parameters whose values are decided before the model’s training begins and remain the same when it ends [5].

Pre-Pruning stops the tree-building process early, for example at leaves with few samples. During each stage of splitting, the cross-validation error is monitored [2]. If the error does not continue to decrease, the tree’s growth is stopped.
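The following sketch illustrates that idea under simplifying assumptions: it uses max_depth as a proxy for the stages of tree growth and stops increasing the depth once the cross-validated error no longer improves. It is an illustration of the principle, not how any library stops splits internally.

```python
# A rough sketch of the early-stopping idea: grow progressively deeper trees
# and stop once the cross-validated error no longer decreases.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=500, n_features=8, noise=10.0, random_state=42)

best_depth, best_error = None, np.inf
for depth in range(1, 21):
    model = DecisionTreeRegressor(max_depth=depth, random_state=42)
    # cross_val_score maximizes its scorer, so negate the score to get an error
    error = -cross_val_score(model, X, y, cv=5,
                             scoring="neg_mean_squared_error").mean()
    if error < best_error:
        best_depth, best_error = depth, error
    else:
        break  # error stopped decreasing: stop growing the tree deeper

print("Depth chosen by early stopping:", best_depth)
```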

Hyperparameter Tuning for Pre-Pruning Decision Trees

The hyperparameters that can be tuned for pre-pruning or early stopping are max_depth, min_samples_leaf, and min_samples_split; a short tuning sketch follows their descriptions below.

  • max_depth: Specifies the maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than min_samples_split samples.

The larger the value of max_depth, the more complex the tree will be. Increasing max_depth decreases the training error but can result in inaccurate predictions on the test data (overfitting). Hence, the correct max_depth value is the one that yields the best-fit decision tree, one that neither underfits nor overfits the data.

  • min_samples_leaf: Specifies the minimum number of samples required at a leaf node.

Let’s consider a scenario where the min_samples_leaf of a model is set to 5 and the current node has 20 samples. The model has to decide whether to continue splitting at this node or stop here. To respect this parameter, the model considers the candidate split and checks the number of samples that would fall into the left and right branches. If both contain at least 5 samples, the split is made. If either branch would contain fewer than 5 samples, the split is not performed and the unsplit node becomes a terminal node (leaf).

  • min_samples_split: Specifies the minimum number of samples required to split an internal node.

Let’s consider a scenario where min_samples_split is set to 50. Then any node with fewer than 50 samples will not be split further and becomes a leaf, whereas any node that is split (an internal node) must contain at least 50 samples [3].
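As promised above, here is a minimal pre-pruning sketch that tunes these three hyperparameters with GridSearchCV on a synthetic stand-in dataset; the parameter ranges are illustrative, not recommendations.

```python
# A minimal sketch of pre-pruning by tuning max_depth, min_samples_leaf, and
# min_samples_split with GridSearchCV on synthetic data.
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=500, n_features=8, noise=10.0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

param_grid = {
    "max_depth": [3, 5, 7, 10, None],
    "min_samples_leaf": [1, 5, 10, 20],
    "min_samples_split": [2, 10, 50],
}
search = GridSearchCV(DecisionTreeRegressor(random_state=42), param_grid,
                      cv=5, scoring="neg_mean_squared_error")
search.fit(X_train, y_train)

print("Best hyperparameters:", search.best_params_)
print("Test R^2 of the pre-pruned tree:", search.best_estimator_.score(X_test, y_test))
```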

Post-Pruning

Post-Pruning, or ‘backward pruning’, is a technique that eliminates branches from a completely grown decision tree model to reduce its complexity and variance. This technique allows the decision tree to grow to its full depth and then removes branches to prevent the model from overfitting. By doing so, the training error may increase slightly, but the testing error can decrease drastically [4].

In Post-Pruning, non-significant branches of the model are removed using the Cost Complexity Pruning (CCP) technique. This algorithm is parameterized by α (≥ 0), known as the complexity parameter. The cost_complexity_pruning_path function of the sklearn package in Python calculates the effective alphas and their corresponding impurities at each step of the pruning process. In cost complexity pruning, the ccp_alpha (alpha) value can be tuned to get the best-fit decision tree model.

How does Cost Complexity Pruning or ‘Weakest Link Pruning’ Work?

Cost Complexity Pruning, or ‘Weakest Link Pruning’, works by calculating a tree score based on the Sum of Squared Residuals (SSR) of the tree or subtree and a tree complexity penalty, which is a function of the number of leaves (terminal nodes) T in the tree or subtree. The SSR increases as trees get shorter, and the tree complexity penalty compensates for the difference in the number of leaves.

  • Tree Score = SSR + alpha*T, where alpha is a tuning parameter we find using Cross Validation.

This pruning technique then calculates different values of alpha, giving us a sequence of trees ranging from the full-sized tree down to just a leaf. The process is repeated within a 10-fold cross-validation, and the final value of alpha is the one that, on average, gives the lowest Sum of Squared Residuals on the testing data [6].
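To see how the tree score trades off fit against size, consider a hypothetical example with made-up numbers. Suppose the full tree has SSR = 10 with 6 leaves, and a pruned subtree has SSR = 14 with 3 leaves. With alpha = 1, the scores are 10 + 1*6 = 16 for the full tree and 14 + 1*3 = 17 for the subtree, so the full tree is kept. With alpha = 3, the scores become 10 + 3*6 = 28 and 14 + 3*3 = 23, so the pruned subtree wins. Larger alpha values therefore favor smaller trees.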

Implementing Post-Pruning with Python’s scikit-learn Package

Let’s try to understand the implementation of Post-Pruning in decision trees through a house-price prediction example. The dataset used is publicly available on Kaggle.

This notebook can be accessed in my GitHub repository.

Python’s sklearn package provides functionality to implement cost complexity pruning through the cost_complexity_pruning_path function, which calculates the effective alphas and their corresponding impurities at each step of the pruning process. The ccp_alpha value can be tuned to get the best-fit model; as ccp_alpha increases, more nodes of the tree are pruned.

After training a decision tree to its full depth, the cost_complexity_pruning_path function can be used to obtain an array of ccp_alphas and the corresponding impurities. The decision tree can then be trained for each value of ccp_alpha, and the train and test performance scores can be computed for each alpha using a performance or accuracy metric. The alpha value with the highest performance score on the testing data is chosen as the final ccp_alpha for the model [1].
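A minimal sketch of that procedure is shown below, again using a synthetic stand-in dataset rather than the Kaggle house-prices data from the notebook.

```python
# A minimal sketch of post-pruning with cost complexity pruning. Synthetic data
# stands in for the house-prices dataset used in the original notebook.
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=500, n_features=8, noise=10.0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Effective alphas (and corresponding impurities) along the pruning path.
path = DecisionTreeRegressor(random_state=42).cost_complexity_pruning_path(X_train, y_train)

best_alpha, best_score = 0.0, float("-inf")
for alpha in path.ccp_alphas:
    alpha = max(alpha, 0.0)  # guard against tiny negative values from floating point
    pruned = DecisionTreeRegressor(random_state=42, ccp_alpha=alpha)
    pruned.fit(X_train, y_train)
    score = pruned.score(X_test, y_test)  # R^2 on the held-out test set
    if score > best_score:
        best_alpha, best_score = alpha, score

print("Chosen ccp_alpha:", best_alpha)
print("Test R^2 of the post-pruned tree:", best_score)
```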

Through this example, we can see how the accuracy of a decision tree model can be improved through the implementation of Cost Complexity Pruning, as compared to an unpruned tree. The model accuracy can be improved further by tuning hyperparameters such as max_depth, min_samples_leaf, and min_samples_split along with post-pruning.
