10 Must-Know Models for ML Beginners: Decision Trees

Dagang Wei
Feb 17, 2024

This article is part of the series 10 Must-Know Models for ML Beginners.

Introduction

In the realm of machine learning (ML), decision trees reign as one of the most intuitive and interpretable algorithms. Their ability to mimic human decision-making processes makes them a popular choice for solving a wide array of problems. Let’s dive in and demystify decision trees.

What are Decision Trees?

Imagine a decision tree as a flowchart-like structure. It starts with a root node representing a question or a feature in your data. The branches extending from the root node depict possible values or ranges of that feature. Subsequent nodes in the tree pose further questions about other features, and the process continues. The very end of each branch, called a leaf node, signifies a class label (in classification problems) or a predicted numerical value (in regression problems).
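
To make the flowchart picture concrete, here is a minimal sketch of a hand-written tree expressed as nested conditions. The function name, features, and thresholds are all hypothetical and chosen only for illustration; a learned tree would derive them from data.

```python
# Hypothetical toy tree: decide whether to play tennis.
# The features and the humidity threshold are made up for illustration.
def play_tennis(outlook: str, humidity: float, windy: bool) -> str:
    # Root node: test the "outlook" feature.
    if outlook == "sunny":
        # Internal node: test humidity; both outcomes end in leaf nodes.
        return "no" if humidity > 70 else "yes"
    if outlook == "rainy":
        # Internal node: test wind; both outcomes end in leaf nodes.
        return "no" if windy else "yes"
    # Any other outlook (e.g. "overcast") goes straight to a leaf.
    return "yes"


print(play_tennis("sunny", 85, False))  # no
```

Each path from the first `if` to a `return` corresponds to one root-to-leaf path in the tree.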

How Decision Trees Work

Building a decision tree means repeatedly choosing a test to apply at each node until every path ends in a prediction. Each internal node represents a “test” on an attribute (e.g., whether a coin flip comes up heads or tails), each branch represents an outcome of that test, and each leaf node represents a class label (the decision reached after all tests along the path have been applied). The paths from root to leaf represent classification rules. A typical construction procedure follows these steps.

1. Select the Best Attribute

Use a statistical measure to select the attribute that best splits the data into subsets. The most common measures are Gini Impurity, Entropy (used in the Information Gain metric), and the Gain Ratio. Entropy measures the disorder or uncertainty in the class labels of a dataset, and Information Gain is the reduction in entropy after the dataset is split on an attribute. Gini Impurity is the probability that a randomly chosen element would be mislabeled if it were labeled at random according to the class distribution of its subset. The attribute whose split yields the highest Information Gain, or the lowest weighted Gini Impurity across the resulting subsets, is typically chosen.
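
The measures above can be sketched in a few lines of plain Python. This is an illustrative sketch, not a library API: the function names are my own, and labels are assumed to be given as simple lists.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in bits) of a non-empty list of class labels."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def gini(labels):
    """Gini impurity of a non-empty list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def information_gain(parent, children):
    """Entropy of the parent minus the size-weighted entropy of the children."""
    n = len(parent)
    weighted = sum(len(ch) / n * entropy(ch) for ch in children if ch)
    return entropy(parent) - weighted


def weighted_gini(children):
    """Size-weighted Gini impurity of the subsets produced by a split."""
    n = sum(len(ch) for ch in children)
    return sum(len(ch) / n * gini(ch) for ch in children if ch)


parent = ["a", "a", "b", "b"]
perfect_split = [["a", "a"], ["b", "b"]]
mixed_split = [["a", "a", "b"], ["b"]]

print(entropy(parent))                          # 1.0
print(gini(parent))                             # 0.5
print(information_gain(parent, perfect_split))  # 1.0
print(weighted_gini(perfect_split))             # 0.0
print(information_gain(parent, mixed_split))    # smaller gain than the perfect split
```

A split that separates the classes perfectly removes all uncertainty, so it has the maximum information gain and a weighted Gini impurity of zero.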

2. Split the Data

Based on the best attribute selected, split the data into subsets, one for each value (or range of values) of that attribute. Create a decision node that records the chosen attribute at each step.
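
As a rough sketch of what "splitting" means in practice, the two helpers below partition rows either by each distinct value of a categorical attribute or by a threshold on a numeric one. The helper names and the sample rows are assumptions made for this example.

```python
from collections import defaultdict


def split_by_value(rows, feature_index):
    """Multiway split: one subset per distinct value of a categorical feature."""
    subsets = defaultdict(list)
    for row in rows:
        subsets[row[feature_index]].append(row)
    return dict(subsets)


def split_by_threshold(rows, feature_index, threshold):
    """Binary split: rows below the threshold go left, the rest go right."""
    left = [row for row in rows if row[feature_index] < threshold]
    right = [row for row in rows if row[feature_index] >= threshold]
    return left, right


# Hypothetical rows: (outlook, humidity)
rows = [("sunny", 85), ("rainy", 70), ("sunny", 60)]

print(split_by_value(rows, 0))
# {'sunny': [('sunny', 85), ('sunny', 60)], 'rainy': [('rainy', 70)]}

print(split_by_threshold(rows, 1, 75))
# ([('rainy', 70), ('sunny', 60)], [('sunny', 85)])
```

The binary threshold form is the one used by the from-scratch example later in this article.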

3. Recursive Splitting

Repeat the process for each child subset of the split data. Evaluate all the remaining attributes and choose the best attribute to split the data on. Continue this process recursively for each branch, using only the data points that reach the branch. The recursion stops when one of the following conditions is met:

  • Every element in the subset belongs to the same class (the subset is pure).
  • There are no more attributes to select, but the samples still do not all belong to the same class (handled by majority voting).
  • There are no samples left.
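
The three stopping conditions map directly onto the base cases of a recursive function. The sketch below is a deliberate simplification: it assumes categorical features stored in dictionaries and, to stay short, splits on the next attribute in the list instead of scoring attributes. The names `grow` and `majority_class` are hypothetical.

```python
from collections import Counter


def majority_class(labels):
    """Most frequent label in a non-empty list."""
    return Counter(labels).most_common(1)[0][0]


def grow(rows, labels, attributes, parent_majority=None):
    # No samples left: fall back to the parent's majority class.
    # (Only reachable if a caller passes an empty subset, e.g. when
    # branching over a fixed domain of values.)
    if not rows:
        return parent_majority
    # Every element belongs to the same class: the subset is pure.
    if len(set(labels)) == 1:
        return labels[0]
    # No attributes left but classes are still mixed: majority vote.
    if not attributes:
        return majority_class(labels)

    # Simplification: take the next attribute instead of the best-scoring one.
    attr, remaining = attributes[0], attributes[1:]
    node = {"attribute": attr, "children": {}}
    fallback = majority_class(labels)
    for value in sorted({row[attr] for row in rows}):
        idx = [i for i, row in enumerate(rows) if row[attr] == value]
        node["children"][value] = grow(
            [rows[i] for i in idx], [labels[i] for i in idx], remaining, fallback
        )
    return node


rows = [
    {"outlook": "sunny", "windy": False},
    {"outlook": "sunny", "windy": True},
    {"outlook": "rainy", "windy": False},
]
labels = ["yes", "no", "yes"]
tree = grow(rows, labels, ["outlook", "windy"])
print(tree)
```

In a real implementation, the attribute would be picked with a measure such as Information Gain or Gini Impurity at every call.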

4. Pruning the Tree

Once the tree is built, it might be overly complex and overfit the training data, leading to poor generalization on unseen data. Pruning is the process of removing parts of the tree that do not provide additional power to classify instances. This can be done by removing branches that have little importance and replacing them with leaf nodes. Two common pruning techniques are pre-pruning (stop growing the tree earlier, before it perfectly classifies the training data) and post-pruning (allow the tree to classify the training set perfectly, then remove nodes or branches that do not contribute to the tree’s accuracy on a validation set).
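
Here is one way the two pruning styles might look with scikit-learn. Pre-pruning is shown with a depth limit. Post-pruning is shown with cost-complexity pruning (the `ccp_alpha` parameter), with the pruning strength picked on a validation set; this is one specific post-pruning technique, not the only one. The dataset, split sizes, and random seeds are assumptions for the example.

```python
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_blobs(n_samples=300, centers=5, n_features=2, random_state=6)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# Pre-pruning: stop growing early by capping the depth.
pre = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

# Post-pruning: grow the full tree, then try each candidate pruning strength.
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
path = full.cost_complexity_pruning_path(X_train, y_train)

best_alpha, best_score = 0.0, -1.0
for alpha in path.ccp_alphas:
    alpha = max(float(alpha), 0.0)  # guard against tiny negative rounding errors
    candidate = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha)
    score = candidate.fit(X_train, y_train).score(X_val, y_val)
    if score >= best_score:  # on ties, prefer the larger alpha (smaller tree)
        best_alpha, best_score = alpha, score

post = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha)
post.fit(X_train, y_train)

print("full tree leaves:  ", full.get_n_leaves())
print("post-pruned leaves:", post.get_n_leaves())
print("pre-pruned depth:  ", pre.get_depth())
```

Larger values of `ccp_alpha` remove more of the tree, so the loop effectively searches for the smallest tree that still does well on the validation data.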

5. Stopping Criteria

Define conditions under which the tree will stop growing. These conditions can include a minimum number of samples required to split a node, a maximum depth of the tree, or a minimum impurity reduction needed to justify a further split. Such criteria are the usual way to implement pre-pruning.
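
In scikit-learn, these stopping criteria correspond to constructor parameters of `DecisionTreeClassifier`. The sketch below compares an unrestricted tree with a restricted one; the specific parameter values are arbitrary choices for illustration, not recommendations.

```python
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier

X, y = make_blobs(n_samples=300, centers=5, n_features=2, random_state=6)

# No stopping criteria: the tree grows until every leaf is pure.
unrestricted = DecisionTreeClassifier(random_state=0).fit(X, y)

# Stopping criteria (values chosen only for illustration).
restricted = DecisionTreeClassifier(
    max_depth=4,                 # maximum depth of the tree
    min_samples_split=10,        # minimum samples required to split a node
    min_impurity_decrease=0.01,  # minimum impurity reduction to justify a split
    random_state=0,
).fit(X, y)

print("unrestricted:", unrestricted.get_depth(), "levels,",
      unrestricted.get_n_leaves(), "leaves")
print("restricted:  ", restricted.get_depth(), "levels,",
      restricted.get_n_leaves(), "leaves")
```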

When to Use It

Decision Trees are particularly useful when:

  • Interpretability is Key: In domains such as finance, healthcare, and policy-making, where understanding the decision-making process is crucial.
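  • Handling Categorical Data: In principle, trees can split directly on categorical variables without the need for dummy variables, although some implementations (scikit-learn's, for example) expect numerically encoded inputs.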
  • Non-Linear Data: They can capture non-linear relationships between features and the target.
  • Quick Prototyping: Their simplicity makes them an excellent choice for baseline models and quick prototyping.
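
To illustrate the interpretability point, a fitted scikit-learn tree can be printed as plain-text rules with `export_text`. The feature names `x0` and `x1` are placeholders I chose for the two synthetic features.

```python
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = make_blobs(n_samples=300, centers=5, n_features=2, random_state=6)

# A shallow tree keeps the printed rules short enough to read.
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# "x0" and "x1" are placeholder names for the two synthetic features.
rules = export_text(clf, feature_names=["x0", "x1"])
print(rules)
```

Each indented line is a threshold test, and each `class:` line is a leaf, so the printout can be read as a set of if/else rules.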

When to Consider Alternatives

Despite their advantages, Decision Trees have limitations that might necessitate alternative models:

  • Overfitting: They are prone to overfitting, especially with complex datasets, leading to poor generalization on unseen data.
  • Instability: Small changes in the data can lead to drastically different trees being generated.
  • Performance: For some complex, high-dimensional datasets, other models like Random Forests or Gradient Boosting Machines (which are ensembles of Decision Trees) or deep learning models may yield better performance.
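
The overfitting concern can be checked empirically by comparing training and held-out accuracy. The sketch below does this for a single unrestricted tree and a Random Forest on deliberately noisy synthetic data. The dataset settings (`cluster_std=2.5`, 600 samples) and seeds are assumptions for the example, and the exact scores will depend on them.

```python
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Overlapping clusters make it easier to see overfitting.
X, y = make_blobs(
    n_samples=600, centers=5, n_features=2, cluster_std=2.5, random_state=6
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

single_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
forest = RandomForestClassifier(n_estimators=100, random_state=0)
forest.fit(X_train, y_train)

for name, model in [("single tree", single_tree), ("random forest", forest)]:
    train_acc = model.score(X_train, y_train)
    test_acc = model.score(X_test, y_test)
    print(f"{name}: train={train_acc:.3f} test={test_acc:.3f}")
```

A large gap between the training and test scores of the single tree is the typical symptom of overfitting; whether the ensemble closes that gap depends on the data.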

Example: Multi-Class Classification with a Decision Tree

The code is available in this colab notebook.

import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Function to calculate the weighted Gini impurity of a split
def gini_impurity(groups, classes):
    n_instances = float(sum([len(group) for group in groups]))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue
        score = 0.0
        for class_val in classes:
            p = [row[-1] for row in group].count(class_val) / size
            score += p * p
        gini += (1.0 - score) * (size / n_instances)
    return gini

# Function to find the best split
def best_split(dataset):
    unique_classes = list(set(row[-1] for row in dataset))  # Get unique class labels

    # Initialize variables for tracking the best split
    b_index, b_value, b_score, b_groups = 999, 999, 999, None

    for feature_index in range(len(dataset[0]) - 1):  # Check each feature (excluding the target)
        for row in dataset:
            feature_value = row[feature_index]
            # Try splitting using the current feature value
            groups = split_group(feature_index, feature_value, dataset)

            # Calculate Gini impurity to evaluate the potential split
            gini = gini_impurity(groups, unique_classes)

            # Update if a better split is found
            if gini < b_score:
                b_index, b_value, b_score, b_groups = feature_index, feature_value, gini, groups

    # Return the details of the best split found
    return {'index': b_index, 'value': b_value, 'groups': b_groups}


# Function to split dataset based on an attribute and its value
def split_group(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

# Function to create a terminal leaf node (majority class)
def terminate_node(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

# Recursive function to build the decision tree
def build_tree(node, max_depth, min_size, depth):
    left, right = node['groups']
    del (node['groups'])
    # Check for no split
    if not left or not right:
        node['left'] = node['right'] = terminate_node(left + right)
        return
    # Check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = terminate_node(left), terminate_node(right)
        return
    # Process left child
    if len(left) <= min_size:
        node['left'] = terminate_node(left)
    else:
        node['left'] = best_split(left)
        build_tree(node['left'], max_depth, min_size, depth + 1)
    # Process right child
    if len(right) <= min_size:
        node['right'] = terminate_node(right)
    else:
        node['right'] = best_split(right)
        build_tree(node['right'], max_depth, min_size, depth + 1)

def build_decision_tree(X, y, max_depth=4, min_size=1):
    """Build a decision tree.

    Args:
        X (numpy.ndarray): Input features.
        y (numpy.ndarray): Target labels.
        max_depth (int): Maximum depth of the decision tree.
        min_size (int): A child group with this many samples or fewer
            becomes a leaf instead of being split further.

    Returns:
        The root node of the built decision tree.
    """
    # Append the labels as the last column so each row carries its target
    data = np.column_stack((X, y))
    root = best_split(data)
    build_tree(root, max_depth, min_size, depth=1)
    return root

# Make a prediction with a decision tree
def predict(node, row):
    if row[node['index']] < node['value']:
        if isinstance(node['left'], dict):
            return predict(node['left'], row)
        else:
            return node['left']
    else:
        if isinstance(node['right'], dict):
            return predict(node['right'], row)
        else:
            return node['right']

# Visualize the decision tree boundaries
def visualize_tree(X, y, tree):
    # Plotting setup
    cmap = ListedColormap(['tab:orange', 'tab:blue', 'tab:red', 'tab:green', 'tab:purple'])
    h = .02
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))

    # Classify mesh points using the decision tree, creating the background plot
    Z = np.array([predict(tree, x) for x in np.c_[xx.ravel(), yy.ravel()]])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=cmap, alpha=0.2)

    # Plot samples
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors='k')
    plt.title("Decision Tree Visualization")
    plt.show()


# Generate sample data with 5 centers, y is in [0, 4]
X, y = make_blobs(n_samples=300, centers=5, n_features=2, random_state=6)

# Build the decision tree (adjust parameters if needed)
tree = build_decision_tree(X, y, max_depth=4, min_size=1)

# Evaluate the model on the training data (no held-out test set is used here)
accuracy = sum(predict(tree, row) == row[-1] for row in np.column_stack((X, y))) / float(len(X))
print('Accuracy: %.3f' % accuracy)

# Visualize results
visualize_tree(X, y, tree)

Output:

Accuracy: 0.993

Conclusion

Decision trees offer a remarkably understandable way to learn patterns from data. Their strengths lie in interpretability, dealing with mixed data types, and capturing complex relationships. However, watch out for overfitting and instability. If you need to truly understand why your model is making predictions, or have a variety of feature types in your data, decision trees are a superb starting point in your machine learning journey!
