Mastering Decision Tree Models for Smarter Decisions

Sachinsoni
13 min read · Aug 23, 2023


Imagine you have to make decisions step by step. A decision tree is like a map that helps you choose the best path. It starts with a question, and depending on your answer, it leads to more questions or a final answer. Each question is like a branch on a tree, and the final answers are like leaves. It’s a helpful way to make choices based on the information you have. Let’s start this amazing blog!

What is a Decision Tree?

a. Non-Parametric model : A decision tree is considered a non-parametric model because it doesn’t assume a specific mathematical form for the relationship between inputs and outputs.

b. White Box model : A decision tree is often called a “white box” model because its decision-making process is easy to understand and interpret.

c. Non-linear model : It can work with non-linear data and is the mother of all tree-based algorithms.

d. Giant if-else model : At its core, a decision tree is a giant if-else based model.
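To make point (d) concrete, here is a minimal, hand-written sketch of the kind of if-else structure a decision tree learns. The feature names and thresholds are purely hypothetical, not taken from any real dataset:

def predict_job_outcome(field, average_grade):
    # Hypothetical rules; a real tree learns these splits from data
    if field == "Science":
        if average_grade > 60:
            return "Employed"
        return "Unemployed"
    if average_grade > 75:
        return "Employed"
    return "Unemployed"

print(predict_job_outcome("Science", 72))  # -> Employed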

Intuition of Decision Tree :

For the table given below, a decision tree is formed based on if-else statements.

What is the implementation approach for decision trees in scikit-learn?

The CART (Classification and Regression Trees) algorithm constructs decision trees by selecting the best features to split the data based on criteria like Gini impurity (for classification) or mean squared error (for regression). It recursively partitions the data into branches, creating decision paths that lead to final predictions or classifications at the leaf nodes. The process continues until a stopping condition is met, and the resulting tree can be pruned for simplicity and generalization.

Let us take an example to understand how the CART algorithm works and how decision trees are formed.

I have the dataset shown below, in which Job_Outcome is the output column, and I am drawing a decision tree using the CART algorithm.

First, I take Field = Science.

When the field is categorized as Science and marked as True, there are 4 individuals employed and 1 individual unemployed. Consequently, the probabilities of being employed and unemployed are 4/5 and 1/5, respectively. Conversely, when this condition is not met, there are 3 individuals employed and 2 individuals unemployed, yielding probabilities of 3/5 for employment and 2/5 for unemployment.

Now we calculate the Gini impurity for the ‘yes’ case as well as the ‘no’ case.

Gini impurity for the ‘yes’ case = 1 - (4/5)² - (1/5)² = 8/25

Gini impurity for the ‘no’ case = 1 - (3/5)² - (2/5)² = 12/25

Now I take the weighted sum of the ‘yes’ and ‘no’ Gini impurities: (5/10)*(8/25) + (5/10)*(12/25) = 0.4
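As a quick check, here is a small snippet that reproduces these numbers. It assumes the class counts described above (4/1 in the ‘yes’ branch, 3/2 in the ‘no’ branch, 10 rows in total):

def gini(counts):
    # Gini impurity from a list of class counts at a node
    total = sum(counts)
    return 1 - sum((c / total) ** 2 for c in counts)

gini_yes = gini([4, 1])   # Field = Science: 4 employed, 1 unemployed -> 8/25 = 0.32
gini_no = gini([3, 2])    # Field != Science: 3 employed, 2 unemployed -> 12/25 = 0.48

# Weighted Gini of the split: each branch weighted by its share of the 10 rows
weighted = (5 / 10) * gini_yes + (5 / 10) * gini_no
print(gini_yes, gini_no, weighted)   # 0.32 0.48 0.4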

Now that we have calculated the Gini impurity for Field = Science, we will similarly calculate the Gini impurity for the Degree_Type and Average_Grade columns, and select the column with the lowest Gini impurity as the root node.

Now we move on to the Degree_Type column, where three possibilities occur: degree = UG, degree = PG and degree = PhD. We take them one by one, calculate the Gini impurity for each possibility, and then decide which split is to be taken.

Similarly, we calculate the Gini impurity for degree = PhD. Suppose its Gini impurity comes out to 0.4 and is the lowest of the three; then we ignore the splits for degree = UG and degree = PG, and select degree = PhD for the further process.

Now we move on to the Average_Grade column. One thing to notice is that it is a numerical column, so the method for calculating Gini impurity is slightly different in this case. See the below diagram for calculating Gini impurity in the case of a numerical column:

To begin, we arrange the values within the numerical column in ascending order. Then, we compute the means between consecutive pairs of values. For each of these mean values, we calculate the corresponding Gini impurity using the approach depicted in the diagram provided above. After evaluating the Gini impurity for each mean value, we identify the means associated with the lowest Gini impurity. These selected mean values serve as pivotal points to guide our data divisions, ensuring the creation of branches with minimal impurity during the construction of the Decision Tree.
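Since this threshold search is easier to follow with numbers attached, here is a sketch of it in Python. The grade values and labels below are made up for illustration; they are not the dataset used in this post:

import numpy as np

grades = np.array([55, 60, 62, 70, 75, 80, 85, 88, 90, 95])
labels = np.array([0, 0, 1, 0, 1, 1, 1, 1, 1, 1])  # 1 = employed, 0 = unemployed

def gini(y):
    # Gini impurity of a set of class labels
    if len(y) == 0:
        return 0.0
    p = np.bincount(y) / len(y)
    return 1 - np.sum(p ** 2)

order = np.argsort(grades)
grades, labels = grades[order], labels[order]

# Candidate thresholds: midpoints between consecutive sorted values
midpoints = (grades[:-1] + grades[1:]) / 2

def weighted_gini(t):
    left, right = labels[grades <= t], labels[grades > t]
    return (len(left) * gini(left) + len(right) * gini(right)) / len(labels)

best = min(midpoints, key=weighted_gini)
print("Best threshold:", best)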

We proceed to calculate the Gini impurity for the columns representing degree, average, and field. By comparing these Gini impurity values, we determine which column exhibits the lower impurity. This column is then chosen as the starting point or root node for constructing our Decision Tree. Following this, the entire process is reiterated for each subsequent branch, where the column with the least Gini impurity becomes the basis for further division. This iterative approach continues until we reach leaf nodes, completing the construction of the Decision Tree.

So, in this way, the CART algorithm works to create a decision tree for the provided dataset.

Visualizing Gini Impurity :

I have a bag in which a mixture of 3 oranges and 2 apples is kept. To help me understand, I created a grid where one axis represents my guesses and the other axis represents the actual labels of the fruits. Each fruit has a probability of 1/5 of being picked. This grid helps me visualize the concept of Gini impurity and its implications as I consider different ways of guessing the fruit types based on these probabilities. See below diagram:
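We can verify the picture with a tiny simulation. Guess and actual label are each drawn according to the class proportions (3/5 orange, 2/5 apple), so out of the 25 equally likely (guess, actual) cells in the grid, 12 are mismatches:

from itertools import product

fruits = ["orange"] * 3 + ["apple"] * 2

# Count the (guess, actual) pairs in the 5 x 5 grid that disagree
mismatch = sum(1 for guess, actual in product(fruits, fruits) if guess != actual)
print(mismatch / len(fruits) ** 2)        # 12/25 = 0.48

# Same value from the Gini formula
print(1 - (3 / 5) ** 2 - (2 / 5) ** 2)    # 0.48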

Geometric Intuition of Decision Trees :

In a decision tree, the process of “marking the cut” involves selecting a specific threshold or value for a feature that optimally divides the data into distinct branches. This threshold is chosen in a way that minimizes the impurity within each resulting branch.
See the below diagram of how a decision tree marks the cut:

Implementation of Decision Tree Code :
I am taking the Iris dataset as an example and using only two input columns: sepal length and sepal width.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import export_text

# Load iris dataset
iris = load_iris()
X = iris.data[:, :2] # we only take the first two features.
y = iris.target

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a DecisionTreeClassifier
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=40)
clf.fit(X_train, y_train)

# Check accuracy
y_pred = clf.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")

# Plot the decision tree
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=iris.feature_names[:2], class_names=iris.target_names)
plt.show()

# Plot the decision boundary
plt.figure(figsize=(8, 6))
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                     np.arange(y_min, y_max, 0.01))

Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.contourf(xx, yy, Z, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.show()

and the output we got is :

Decision Tree Created by CART algorithm for above IRIS dataset
Decision boundary for IRIS dataset

How does the CART algorithm work in the case of a regression problem?

In the case of regression, instead of calculating the weighted Gini impurity, we calculate the weighted mean squared error (variance) of the target in each branch. See the below diagram:
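As a rough sketch of the regression case, a DecisionTreeRegressor picks splits that minimize the weighted squared error of the target, and each leaf predicts the mean target value of the training samples falling into it. The data below is synthetic, purely for illustration:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = np.sort(rng.uniform(0, 10, 80)).reshape(-1, 1)
y = np.sin(X).ravel() + rng.normal(0, 0.1, 80)

# criterion="squared_error" is the regression analogue of Gini impurity
reg = DecisionTreeRegressor(max_depth=3, criterion="squared_error")
reg.fit(X, y)

# Each leaf outputs the mean target of its training samples
print(reg.predict([[2.5], [7.5]]))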

Overfitting in case of Decision Tree :

Overfitting in decision trees occurs when the tree becomes too complex, capturing noise in the training data and leading to poor predictions on new data. This is often indicated by deep trees with few samples per leaf, causing the model to fit training data precisely but generalize poorly.

Overfitting occurs in Decision Tree

To mitigate overfitting, techniques such as pruning, setting minimum samples per leaf, and limiting the tree’s depth are employed to strike a balance between complexity and generalization.

Decision Tree Trained for depth 3 to prevent overfitting
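To see the effect in code, we can compare an unrestricted tree against a depth-limited one on the same two-feature Iris split used earlier (this reuses X_train, X_test, y_train, y_test from the code above; exact numbers will vary):

deep = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
shallow = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

for name, model in [("unrestricted", deep), ("max_depth=3", shallow)]:
    # A large gap between train and test accuracy is a symptom of overfitting
    print(name,
          "train acc:", round(model.score(X_train, y_train), 2),
          "test acc:", round(model.score(X_test, y_test), 2),
          "depth:", model.get_depth())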

How Prediction is done in case of Decision Tree :

Prediction in a decision tree involves traversing the tree from the root node to a specific leaf node, guided by the conditions at each internal node. Starting at the root, the input data point’s features are compared to the conditions associated with the current node. Depending on whether the conditions are met, the traversal proceeds along the appropriate branch. This process continues recursively until a leaf node is reached, where the prediction is based on the majority class (in classification) or the average value (in regression) of the training samples that reached that leaf during training. The resulting prediction provides the outcome for the input data point based on the learned patterns in the decision tree.

Let us take an example to understand this concept:
Suppose you are working on the Iris dataset and have trained the decision tree shown in the figure:

Consider a query point with a sepal length (sl) of 3.9 and a sepal width (sw) of 3. Using this data, we can predict whether it belongs to the setosa, versicolor, or virginica class. Starting at the root node, we check if sl ≤ 5.45, which is true in this case, leading us along the “yes” branch. Then we evaluate sw ≤ 2.8, which is false, directing us to the “no” branch. Given that the majority class in this branch is setosa, our prediction is setosa for this query point.

The prediction process is similar for regression tasks; the only difference is that the leaf outputs the average target value of the training samples that reached it, instead of the majority class used in classification. The methodology remains consistent: traverse the decision tree, evaluate the conditions, and make predictions based on the patterns learned during training.
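Using the clf trained on the Iris data earlier in this post, we can let scikit-learn do this traversal for us. The exact thresholds depend on the fitted tree, but predict, decision_path and export_text make the path easy to inspect:

import numpy as np

query = np.array([[3.9, 3.0]])   # sepal length = 3.9, sepal width = 3.0

pred = clf.predict(query)
print("Predicted class:", iris.target_names[pred[0]])

# decision_path lists the node ids the query visits from root to leaf
node_indicator = clf.decision_path(query)
print("Nodes visited:", node_indicator.indices)

# export_text prints the learned if-else rules, handy for tracing by hand
print(export_text(clf, feature_names=iris.feature_names[:2]))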

Advantages of Decision Trees :

  • Simple to understand and to interpret. Trees can be visualized.
  • Requires little data preparation. Other techniques often require data normalization, dummy variables to be created, and blank values to be removed. Note, however, that older versions of scikit-learn's tree module do not support missing values.
  • The cost of using the tree (i.e., predicting data) is logarithmic in the number of data points used to train the tree.
  • Can work on non-linear datasets
  • Can give you feature importance.
# In the decision tree above we trained the model named clf; to see the
# feature importances we can use the following attribute:
clf.feature_importances_

Disadvantages of Decision Trees :

  • Decision-tree learners can create over-complex trees that do not generalize the data well. This is called overfitting. Mechanisms such as pruning, setting the minimum number of samples required at a leaf node or setting the maximum depth of the tree are necessary to avoid this problem.
  • Decision trees can be unstable because small variations in the data might result in a completely different tree being generated. This problem is mitigated by using decision trees within an ensemble.
  • Predictions of decision trees are neither smooth nor continuous, but piecewise constant approximations as seen in the figure. Therefore, they are not good at extrapolation.

This limitation is inherent to the structure of decision tree models. They are very useful for interpretability and for handling non-linear relationships within the range of the training data, but they aren’t designed for extrapolation. If extrapolation is important for your task, you might need to consider other types of models.
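A quick way to see this piecewise-constant behaviour is to train a regression tree on a simple linear relationship (synthetic data below) and query it outside the training range; the tree just repeats the value of its last leaf:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.arange(0, 10, 0.5).reshape(-1, 1)
y = 2 * X.ravel() + 1            # simple linear target

reg = DecisionTreeRegressor(max_depth=4).fit(X, y)

print(reg.predict([[9.5]]))      # inside the training range: close to 2*9.5 + 1
print(reg.predict([[20.0]]))     # outside the range: same constant as the last leaf
print(reg.predict([[100.0]]))    # still that same constant; no extrapolation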

How does a Decision Tree calculate feature importance?

The formula for calculating feature importance is shown in the below figure:
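Alongside the figure, here is a sketch of the idea in code. For each non-leaf node, the impurity decrease of its split (weighted by the fraction of samples reaching that node) is credited to the feature used at that node; the totals are then normalized to sum to 1. This mirrors how scikit-learn computes Gini importance, using the fitted tree's internal arrays:

import numpy as np

def gini_feature_importance(fitted_tree_clf):
    t = fitted_tree_clf.tree_
    importances = np.zeros(fitted_tree_clf.n_features_in_)
    n = t.weighted_n_node_samples
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:            # leaf node: no split, no contribution
            continue
        # impurity decrease produced by this node's split
        decrease = (n[node] * t.impurity[node]
                    - n[left] * t.impurity[left]
                    - n[right] * t.impurity[right])
        importances[t.feature[node]] += decrease
    importances /= n[0]           # scale by the total number of samples
    return importances / importances.sum()   # normalize to sum to 1

print(gini_feature_importance(clf))
print(clf.feature_importances_)   # should match (up to floating point)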

Concept of Pruning in Decision Trees

Pruning is a technique used in machine learning to reduce the size of decision trees and to avoid overfitting. Decision trees are susceptible to overfitting because they can potentially create very complex trees that perfectly classify the training data but fail to generalize to new data. Pruning helps to solve this issue by reducing the complexity of the decision tree, thereby improving its predictive power on unseen data.
There are two main types of pruning: pre-pruning and post-pruning.

1. Pre-pruning (Early stopping): This method halts the tree construction early. It can be done in various ways: by setting a limit on the maximum depth of the tree, setting a limit on the minimum number of instances that must be in a node to allow a split, or stopping when a split results in the improvement of the model’s accuracy below a certain threshold.

2. Post-pruning (Cost Complexity Pruning): This method allows the tree to grow to its full size, then prunes it. Nodes are removed from the tree based on the error complexity trade-off. The basic idea is to replace a whole subtree by a leaf node, and assign the most common class in that subtree to the leaf node.
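scikit-learn exposes post-pruning through the ccp_alpha parameter together with cost_complexity_pruning_path. A minimal sketch, reusing the X_train / X_test split from the Iris example above (in practice you would pick alpha by cross-validation rather than on the test set):

# Candidate alphas along the cost-complexity pruning path of a full-grown tree
full = DecisionTreeClassifier(random_state=42)
path = full.cost_complexity_pruning_path(X_train, y_train)

best_alpha, best_score = 0.0, 0.0
for alpha in path.ccp_alphas:
    pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha).fit(X_train, y_train)
    score = pruned.score(X_test, y_test)
    if score > best_score:
        best_alpha, best_score = alpha, score

print("best ccp_alpha:", best_alpha, "test accuracy:", round(best_score, 2))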

Pre-pruning Techniques :

  1. Maximum Depth: One of the simplest forms of pre-pruning is to set a limit on the maximum depth of the tree.

After applying max_depth = 2, the decision tree looks like:

Once the tree reaches the specified depth during training, no new nodes are created. This strategy is simple to implement and can effectively prevent overfitting, but if the maximum depth is set too low, the tree might be overly simplified and underfit the data.

2. Minimum Samples Split: This is a condition where a node will only be split if the number of samples in that node is above a certain threshold. If the number of samples is too small, then the node is not split and becomes a leaf node instead. This can prevent overfitting by not allowing the model to learn noise in the data.
After applying min_samples_split = 10, the decision tree looks like:

3. Minimum Samples Leaf: This condition requires that a split at a node must leave at least a minimum number of training examples in each of the leaf nodes. Like minimum samples split, this strategy can prevent overfitting by not allowing the model to learn from noise in the data.
After applying min_samples_leaf = 50, the decision tree looks like:

4. Maximum Leaf Nodes: This strategy limits the total number of leaf nodes in the tree. The tree stops growing when the number of leaf nodes equals the maximum number.
After applying max_leaf_nodes = 4, the decision tree looks like:

5. Minimum Impurity Decrease: This strategy allows a node to be split if the impurity decrease of the split is above a certain threshold. Impurity measures how mixed the classes within a node are. If the decrease is too small, the node becomes a leaf node.
After applying min_impurity_decrease = 0.1, the decision tree looks like:

6. Maximum Features: This strategy considers only a subset of features for deciding a split at each node. The number of features to consider can be defined and this helps in reducing overfitting.
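All six controls above map directly onto constructor arguments of DecisionTreeClassifier (and DecisionTreeRegressor). The values below are only illustrative, echoing the ones used in the figures; in practice each is tuned separately, for example with a grid search:

pruned_clf = DecisionTreeClassifier(
    max_depth=2,                # 1. maximum depth
    min_samples_split=10,       # 2. minimum samples required to split a node
    min_samples_leaf=50,        # 3. minimum samples required in each leaf
    max_leaf_nodes=4,           # 4. maximum number of leaf nodes
    min_impurity_decrease=0.1,  # 5. minimum impurity decrease to allow a split
    max_features="sqrt",        # 6. subset of features considered per split
    random_state=42,
)
pruned_clf.fit(X_train, y_train)
print("depth:", pruned_clf.get_depth(), "leaves:", pruned_clf.get_n_leaves())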

Advantages of Pre-pruning :

  1. Computational Efficiency: By limiting the size of the tree, pre-pruning can substantially reduce the computational cost of training and prediction.
  2. Simplicity: Pre-pruning criteria such as maximum depth or minimum number of samples per leaf are easy to understand and implement.

Disadvantages of Pre-Pruning:

1. Risk of Underfitting: If the stopping criteria are too strict, pre-pruning can halt the growth of the tree too early, leading to underfitting. The model may become overly simplified and fail to capture important patterns in the data.

2. Requires Fine-Tuning: The pre-pruning parameters (like maximum depth or minimum samples per leaf) often require careful tuning to find the right balance between underfitting and overfitting.

3. Short-sightedness: Pre-pruning can discard a split that looks weak on its own, even though better splits would have followed it further down the tree.

In the realm of machine learning, Decision Trees stand as a foundational algorithm known for their interpretability and adaptability. This overview navigates their structure, starting from basic splits based on impurity metrics like Gini, to finding optimal thresholds for feature divisions. Understand their interpretability as white-box models, while also addressing the challenge of overfitting through pruning and limitations. Uncover feature importance through metrics like Gini importance, quantifying which features drive predictions. This concise exploration equips you to harness the power of Decision Trees for insightful, data-driven decision-making.
