Hyperparameter Tuning and Pruning: More about Decision Trees in R with rpart

Dima Diachkov
Published in Data And Beyond
7 min read · Jul 19, 2023

In a previous article about decision trees (this one), we explored how to apply Decision Tree Classification in R using the Iris dataset. But there is still so much more to unearth in the world of machine learning. Today, we are going to dig deeper and investigate some slightly more advanced techniques and best practices to refine your knowledge in the data science realm.

This is part #31 of the “R for Applied Economics” guide, where we collectively explore various depths of R, data science, and financial/economic analysis. Today we are going to cut some branches off our overfitted trees.

Credits: Unsplash | Zoltan Tasi

Advancing your Decision Tree Technique in R

Last time we achieved almost 98% accuracy with our decision tree from the rpart package. Let’s re-establish that starting point from the previous article with the 98%-accuracy tree.

# data preparation
# Load the iris dataset
data(iris)
# Split the data into training and test sets
set.seed(20)
train_index <- sample(1:nrow(iris), nrow(iris)*0.7)
# train dataset formation
train_set <- iris[train_index, ]
str(train_set)
# test dataset formation
test_set <- iris[-train_index, ]
str(test_set)
# Growing trees
# Build the decision tree model
library(rpart)
iris_tree <- rpart(Species ~ ., data = train_set, method = "class")
iris_tree
# Predict on the test set
predictions <- predict(iris_tree, test_set, type = "class")
predictions
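
As a quick sanity check of that starting point, here is a hedged sketch of how the baseline accuracy can be confirmed with the same confusionMatrix call we use later in this article (it assumes the caret package is installed; your exact numbers may differ if you change the seed).

# Baseline evaluation of the starting tree on the test set
library(caret)
cm_baseline <- confusionMatrix(predictions, test_set$Species)
print("Baseline model")
cm_baseline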

Now we will try to improve it.

The versatility of decision trees doesn’t stop at simple classification. They can also be optimized and pruned for better performance. Let’s dive into these aspects:

Hyperparameter Tuning

Machine learning models often have several parameters that can be adjusted to improve performance. Decision trees, for example, have parameters such as the maximum depth of the tree, the minimum number of observations required to attempt a split, and the minimum number of observations in a leaf.

In R, we can use the rpart.control function to tune these parameters when we are building our decision tree. Let's see an example with default parameters:

library(rpart)
control <- rpart.control(minsplit = 20, minbucket = 7, maxdepth=30)
fit <- rpart(Species ~ ., data=iris, method="class", control=control)

In this example, the minsplit parameter determines the minimum number of observations that must exist in a node for a split to be attempted, minbucket is the minimum number of observations in any terminal node, and maxdepth is the maximum depth of any node of the final tree. By tuning these parameters, we can ensure our decision tree is neither overfitting nor underfitting.
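
Just to illustrate how strongly these controls can constrain the tree, here is a small sketch of my own (not part of the original workflow): forcing maxdepth = 1 produces a decision stump with a single split.

# Illustration: a very restrictive control yields a "stump" with one split
stump_control <- rpart.control(minsplit = 20, minbucket = 7, maxdepth = 1)
stump <- rpart(Species ~ ., data = iris, method = "class", control = stump_control)
stump  # prints a tree with a single split at the root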

So you can iteratively play with all these parameters (a small grid-search sketch follows the list below)… but I will jump right to the improvement part, because I don’t want to waste your time.

My logic is the following:

  • Setting minbucket to 1 doesn't bring any added value, because by default, each terminal node (or leaf) will contain at least one observation. If you adjust it to a higher number, like 3 or 7, it implies that each terminal node will consist of a minimum of three or seven observations respectively.
  • The lower you set the minbucket value, the more closely the tree can fit the training data. However, setting minbucket to an excessively low value, such as 1, could potentially lead to overfitting.
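
For completeness, here is a minimal sketch of what such an iterative search could look like: it loops over a few candidate minbucket values (chosen arbitrarily for illustration), refits the tree on the training set, and prints the test accuracy for each.

# Illustrative grid search over minbucket values
for (mb in c(1, 3, 5, 7)) {
  ctrl <- rpart.control(minsplit = 20, minbucket = mb, maxdepth = 20)
  fit_mb <- rpart(Species ~ ., data = train_set, method = "class", control = ctrl)
  preds_mb <- predict(fit_mb, test_set, type = "class")
  acc <- mean(preds_mb == test_set$Species)
  cat("minbucket =", mb, "-> test accuracy:", round(acc, 3), "\n")
}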

I will try to artificially maximize the fit. I will set minbucket to 5.


# NEW ADVANCED TREE with control component
control <- rpart.control(minsplit=20, minbucket=5, maxdepth=20)
iris_tree_advanced <- rpart(Species ~ ., data = train_set, method="class", control=control)

# Predict on the test set
predictions_advanced <- predict(iris_tree_advanced, test_set, type = "class")
predictions_advanced

# Evaluate the model with the same Confusion Matrix as we used before
library(caret)
cm <- confusionMatrix(predictions_advanced, test_set$Species)
print("Tuned model")
cm
Output for the model above

We have achieved 100% accuracy on the TEST data, which the model had not seen before. There are some caveats about overfitting, but since we are testing on previously unseen data, we are not doing anything wrong; we simply improved the accuracy of our model from 98% to 100%. (The difference between this output and the previous one is a single observation that was earlier classified as versicolor while in fact being virginica; that mistake is now fixed.)

Please note that we pushed for 100% only to show that you can control the quality of the generated tree with the control component. Be cautious with 100% accuracy; it is achievable mostly on toy datasets.

Remember, if your accuracy is already high (e.g., >95%), attempting to reach 100% accuracy may lead you down the path of overfitting. Always validate your model with a test set or via cross-validation to ensure it generalizes well to unseen data. In this case we have 100% on the test data, so we are good to go.
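
If you want an extra safeguard beyond a single train/test split, a quick k-fold cross-validation is easy to run; the sketch below uses caret’s train function with its built-in tuning over the complexity parameter cp (assuming caret is installed).

# 10-fold cross-validation of an rpart tree via caret (sketch)
library(caret)
set.seed(20)
cv_fit <- train(Species ~ ., data = iris,
                method = "rpart",
                trControl = trainControl(method = "cv", number = 10))
cv_fit  # reports cross-validated accuracy for several candidate cp values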

Pruning Decision Trees

In order to demonstrate pruning, we will use the tree established earlier, which reached 100% accuracy on the test data. Let’s pretend that we suspect overfitting, or that we would like to reduce the complexity of the tree. Pruning is what we need…

Pruning is another technique used to improve the performance of decision trees by removing branches that have weak predictive power. This is accomplished with the complexity parameter (CP), which controls the size of the decision tree and ultimately helps avoid overfitting.

The printcp function in R provides the CP table for our fitted rpart object. This table can help us identify the optimal CP that results in the most accurate model.

# Quality of our "overfitted" model
printcp(iris_tree_advanced)
Output for the code above
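
If you prefer a visual aid, rpart also provides the plotcp function, which plots the cross-validated error against cp and makes the elbow easier to spot.

# Visualise the CP table: cross-validated error vs. cp
plotcp(iris_tree_advanced)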

What do we see here?

  • CP (Complexity Parameter): This parameter determines a threshold under which the split of a node is not worth the complexity. In this case, we have values 0.548387, 0.370968, 0.016129, and 0.01.
  • nsplit: This is the number of splits made up to that point in the tree. For example, the row with nsplit = 1 represents the tree after the first split.
  • rel error (Relative Error): This is the error relative to the root node; it decreases with each split.
  • xerror (Cross-Validation Error): This is the estimate of prediction error for each split (and each value of CP) from cross-validation. The idea is to choose the smallest tree (i.e., fewest splits) such that the cross-validation error is within 1 standard error of the minimum.
  • xstd: This is the standard error of the cross-validated error.

From the table, we can see that the xerror does not decrease significantly beyond the third split (in fact, it increases after the third split), which suggests that pruning at the third split might be beneficial to prevent overfitting.

I want to stress that these are all valuable insights that can help improve your decision tree model’s accuracy, primarily by focusing on the right amount of pruning to avoid overfitting and reduce complexity.

The printcp output displays a table with CP values for each split (along with the number of splits, relative error, cross-validation error, and standard error). The CP values indicate how much the overall error rate decreases with each split. A large CP indicates that a split resulted in a significant decrease in error, while a smaller CP suggests a less impactful split.

Looking at the table, the cross-validation error (xerror) increases after the third split (from 0.14516 to 0.16129), suggesting that the model may be starting to overfit at that point (AND WE KNOW THAT AS WE MADE IT SO). Hence, pruning may indeed be beneficial here to prevent overfitting and simplify the model.

To perform pruning, we choose a CP value that balances model complexity (number of splits) with predictive accuracy (xerror). The optimal CP value is typically the one associated with the smallest tree whose xerror is within one standard error of the minimum xerror.
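
This one-standard-error rule can also be applied programmatically from the fitted tree’s cptable; the snippet below is a sketch of that selection, with variable names of my own choosing.

# Pick a CP by the 1-SE rule from the fitted tree's CP table (sketch)
cp_table  <- iris_tree_advanced$cptable
min_row   <- which.min(cp_table[, "xerror"])
threshold <- cp_table[min_row, "xerror"] + cp_table[min_row, "xstd"]
best_row  <- which(cp_table[, "xerror"] <= threshold)[1]  # smallest tree within 1 SE
optimal_cp <- cp_table[best_row, "CP"]
optimal_cp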

In this case, it makes sense to prune the tree at the third split, corresponding to a CP of 0.02, because the xerror doesn’t decrease significantly after the third split.

Let’s prune our tree with the prune function:

pruned_tree <- prune(iris_tree_advanced, cp = 0.02)
printcp(pruned_tree)

predictions_pruned <- predict(pruned_tree, test_set, type = "class")
predictions_pruned

# Evaluate the model
cm_pruned <- confusionMatrix(predictions_pruned, test_set$Species)
print("Tuned model")
cm_pruned
Output for the code above
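
If you would like to inspect the simplified structure visually, the pruned tree can be plotted; this sketch assumes the rpart.plot package is installed, which gives a nicer rendering than the base plot method.

# Visualise the pruned tree (requires the rpart.plot package)
library(rpart.plot)
rpart.plot(pruned_tree)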

We have restored the tree’s initial performance of 98% and avoided overfitting. Good job! 👏

Wrap-up

Today we’ve delved deeper into decision tree classification in R, focusing on advanced techniques of hyperparameter tuning and tree pruning. Starting with a model built using the rpart package that achieved nearly 98% accuracy on the Iris dataset, we looked for ways to boost performance.

By utilizing the rpart.control function to fine-tune the parameters such as minsplit, minbucket, and maxdepth, we managed to achieve 100% accuracy on the test dataset. Yet, it's important to remember the risk of overfitting, especially when a model exhibits such a high accuracy rate.

To counter potential overfitting and simplify our model, we explored tree pruning. Through the complexity parameter (CP), we controlled the size of the decision tree and enhanced its performance. Utilizing the printcp function in R, we identified the optimal CP and then pruned the decision tree with the prune function. The pruned tree successfully avoided overfitting and maintained the initial model's performance, demonstrating the effectiveness of these advanced techniques in refining decision tree models. These methods are applicable across various datasets and problem scenarios, making them invaluable tools in the vast realm of machine learning.

Please clap 👏 and subscribe if you want to support me. Thanks! ❤️‍🔥

