Building a Decision Tree From Scratch with Python

Enozeren
Oct 13, 2023

Decision trees are machine learning algorithms used for classification and regression tasks on tabular data. Even though a basic decision tree is rarely used on its own, many of the more sophisticated algorithms that are common in practice, such as random forests and boosted trees, are built on top of it.

Image 1 — Basic Decision Tree Structure — Image by Author — made with Canva

In this article I’m implementing a basic decision tree classifier in Python, and in the upcoming articles I will build Random Forest and AdaBoost on top of the basic tree built here.

The Decision Tree Medium Article series:

  1. Building a Decision Tree From Scratch (This article)
  2. Building a Random Forest From Scratch (Click here)
  3. Building AdaBoost (Boosted Trees) From Scratch (Click here)

Brief History of Decision Tree

The history of decision trees goes back to the 1950s. William Belson introduced an early tree-like structure in his 1959 paper “Matching and Prediction on the Principle of Biological Classification”. Later, more structured algorithms were proposed by Ross Quinlan (ID3, C4.5) and Leo Breiman (CART) in the 1980s. From the 1990s onward, researchers proposed more advanced algorithms built on top of simple decision trees. In this article we’ll be using the CART algorithm.

Building a Decision Tree From Scratch

A tree is basically structured as in Image 1. In this implementation we will build a decision tree classifier, so the output of the tree will be a categorical variable.

NOTE: To see the full code, visit the GitHub repository by clicking here. In this article we won’t go over all of the code.

To create our tree from scratch, we first create a class called DecisionTree in Python. To train the tree we will develop a “train” function, and to predict an output after training we will develop a “predict” function.

Image 2 — DecisionTree class

In Image 2 above, we see the functions we need for our class. We will go through the important ones together.
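Since Image 2 is a screenshot of the class outline, here is a rough sketch of the same skeleton as it is used in the rest of this article. The method names match the snippets that follow, but treat it as an outline rather than the full implementation from the repository:

import numpy as np


class DecisionTree():
    """Basic CART-style decision tree classifier (outline only)."""

    def __init__(self, max_depth=4, min_samples_leaf=1, min_information_gain=0.0) -> None:
        # Stopping criteria, explained later in the article
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf
        self.min_information_gain = min_information_gain

    # --- training helpers ---
    def _entropy(self, class_probabilities: list) -> float: ...
    def partition_entropy(self, label_subsets: list) -> float: ...
    def split(self, data: np.array, feature_idx: int, feature_val: float) -> tuple: ...
    def find_label_probs(self, data: np.array) -> np.array: ...
    def find_best_split(self, data: np.array) -> tuple: ...
    def create_tree(self, data: np.array, current_depth: int) -> "TreeNode": ...

    # --- public API ---
    def train(self, X_train: np.array, Y_train: np.array) -> None: ...
    def predict_one_sample(self, X: np.array) -> np.array: ...
    def predict_proba(self, X_set: np.array) -> np.array: ...
    def predict(self, X_set: np.array) -> np.array: ...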

Decide What is a Good Split

The main idea behind training a decision tree is deciding how to split the data. The best split is one where, after splitting, each resulting group contains data points of only a single class. We need a numerical metric that captures this property. Even though “Gini impurity” is the default splitting metric in the CART algorithm, in this implementation we will use entropy. Entropy is calculated with the function below (Grus, 2019) [1]:

    def _entropy(self, class_probabilities: list) -> float:
        return sum([-p * np.log2(p) for p in class_probabilities if p > 0])

This metric gives us a low entropy if the split data group has a dominant class, and a high entropy otherwise (see Image 3 below).

Image 3 — Entropy Example — Image by Author — made with Canva
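As a quick numeric check of that intuition, here is the same formula applied to a pure, a skewed, and an evenly mixed two-class group (a standalone version of _entropy, shown only for illustration):

import numpy as np

def entropy(class_probabilities: list) -> float:
    # Same formula as the _entropy method above
    return sum([-p * np.log2(p) for p in class_probabilities if p > 0])

print(entropy([1.0, 0.0]))    # pure group      -> 0.0 (lowest possible)
print(entropy([0.75, 0.25]))  # dominant class  -> ~0.81
print(entropy([0.5, 0.5]))    # evenly mixed    -> 1.0 (highest for two classes)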

We will be searching for splits that have the lowest entropy. The “find_best_split” function is built on top of this entropy metric.

Create Candidate Splits to Compare

Now, when we compare two candidate splits, we can tell which one is better by looking at the entropy score (lower is better). For example, when you have two numerical features, you can compare splitting on the median value of feature 1 against splitting on the median value of feature 2 and check the entropies of the resulting groups. If splitting on feature 1 yields lower entropy, you go with feature 1. At each level of the tree, this same procedure decides which feature to split on. This is a greedy approach, since we choose the best split at each level of the tree. The greedy approach has some disadvantages, and solutions to those disadvantages have been proposed, but in this implementation we stick with the greedy approach for the sake of simplicity.

Here is our function for finding the best split.

    def find_best_split(self, data: np.array) -> tuple:
        """
        Finds the best split (with the lowest entropy) given the data.
        Returns the two resulting groups along with the split feature index,
        split value, and partition entropy.
        """
        min_part_entropy = 1e6
        min_entropy_feature_idx = None
        min_entropy_feature_val = None

        # Try the median of every feature column as a split threshold
        # (the labels are stored in the last column)
        for idx in range(data.shape[1] - 1):
            feature_val = np.median(data[:, idx])
            g1, g2 = self.split(data, idx, feature_val)
            part_entropy = self.partition_entropy([g1[:, -1], g2[:, -1]])
            if part_entropy < min_part_entropy:
                min_part_entropy = part_entropy
                min_entropy_feature_idx = idx
                min_entropy_feature_val = feature_val
                g1_min, g2_min = g1, g2

        return g1_min, g2_min, min_entropy_feature_idx, min_entropy_feature_val, min_part_entropy
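find_best_split relies on two helpers that live in the repository but are not shown in the article: split, which partitions the rows by a threshold on one feature, and partition_entropy, which weights each group’s entropy by its relative size. A minimal sketch of how they could look, assuming (as in the rest of the code) that the labels sit in the last column and that partition_entropy receives the label arrays of the groups:

    def split(self, data: np.array, feature_idx: int, feature_val: float) -> tuple:
        """Partitions rows into two groups using a threshold on one feature column."""
        mask = data[:, feature_idx] < feature_val
        return data[mask], data[~mask]

    def partition_entropy(self, label_subsets: list) -> float:
        """Entropy of a partition: each group's entropy weighted by its relative size."""
        total_count = sum([len(labels) for labels in label_subsets])
        part_entropy = 0.0
        for labels in label_subsets:
            if len(labels) == 0:
                continue  # an empty group contributes nothing
            _, counts = np.unique(labels, return_counts=True)
            group_entropy = self._entropy(counts / len(labels))
            part_entropy += group_entropy * (len(labels) / total_count)
        return part_entropy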

Creating the Tree by Splitting Recursively

To create a tree we need a basic node structure with left and right children. Therefore, we create the “TreeNode” class below, where we store the split information and the links to the left and right nodes.

class TreeNode():
    def __init__(self, data, feature_idx, feature_val, prediction_probs, information_gain) -> None:
        self.data = data
        self.feature_idx = feature_idx
        self.feature_val = feature_val
        self.prediction_probs = prediction_probs
        self.information_gain = information_gain
        self.left = None
        self.right = None

Now we have everything we need to create a tree. We can write a recursive function that builds the tree and stops once the stopping criteria are satisfied.

    def create_tree(self, data: np.array, current_depth: int) -> TreeNode:

        # Check if the max depth has been reached (stopping criterion)
        if current_depth >= self.max_depth:
            return None

        # Find the best split
        split_1_data, split_2_data, split_feature_idx, split_feature_val, split_entropy = self.find_best_split(data)

        # Find the label probabilities for the node
        label_probabilities = self.find_label_probs(data)

        # Calculate the information gain
        node_entropy = self._entropy(label_probabilities)
        information_gain = node_entropy - split_entropy

        # Create the node
        node = TreeNode(data, split_feature_idx, split_feature_val, label_probabilities, information_gain)

        # Check if min_samples_leaf has been satisfied (stopping criterion)
        if self.min_samples_leaf > split_1_data.shape[0] or self.min_samples_leaf > split_2_data.shape[0]:
            return node
        # Check if min_information_gain has been satisfied (stopping criterion)
        elif information_gain < self.min_information_gain:
            return node

        current_depth += 1
        node.left = self.create_tree(split_1_data, current_depth)
        node.right = self.create_tree(split_2_data, current_depth)

        return node
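create_tree also uses find_label_probs, which is not reproduced in the article. Its job is to turn the labels in the last column into a probability vector over the classes seen during training (self.labels_in_train, which is set in the train function below). A minimal sketch under that assumption:

    def find_label_probs(self, data: np.array) -> np.array:
        """Fraction of rows in this node belonging to each class seen during training."""
        labels = data[:, -1]
        # One entry per class observed in the training set, in a fixed order
        return np.array([np.mean(labels == label) for label in self.labels_in_train])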

Stopping Criteria

As we saw in the “create_tree” function, the splitting is performed recursively, so we need to tell the function when to stop. The conditions that control this are called stopping criteria. In this implementation we use three of them:

  1. The depth of the tree (max_depth)
  2. The minimum acceptable number of samples in the leaf after split (min_samples_leaf)
  3. The minimum acceptable information gain (min_information_gain)

These stopping criteria are defined when we create the DecisionTree object. See the __init__ function below.

    def __init__(self, max_depth=4, min_samples_leaf=1, min_information_gain=0.0) -> None:
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf
        self.min_information_gain = min_information_gain

Training

Since we now have all the functions we need for training, we can take a look at the train function. It only takes the training set, starts the recursive tree-creation process, and stores the first node of the tree (a.k.a. the root).

    def train(self, X_train: np.array, Y_train: np.array) -> None:

        # Remember the classes seen in training, then concatenate features and labels
        self.labels_in_train = np.unique(Y_train)
        train_data = np.concatenate((X_train, np.reshape(Y_train, (-1, 1))), axis=1)

        # Start creating the tree
        self.tree = self.create_tree(data=train_data, current_depth=0)

Predicting

When making a prediction, we calculate the probability of the sample belonging to each class. Each leaf holds constant probabilities for the labels, and those probabilities are learned in the training phase. To make a new prediction, we take the unlabelled sample, start from the first node (the root), and follow the path whose conditions the sample satisfies. We can do this with a while loop that stops when there is no further node to go to, in other words when we reach a leaf node.

After predicting the probabilities, the “predict” function returns the most probable class as the prediction.

    def predict_one_sample(self, X: np.array) -> np.array:
        """Returns the prediction probabilities for a 1-dimensional sample"""
        node = self.tree

        # Finds the leaf which X belongs to
        while node:
            pred_probs = node.prediction_probs
            if X[node.feature_idx] < node.feature_val:
                node = node.left
            else:
                node = node.right

        return pred_probs

    def predict_proba(self, X_set: np.array) -> np.array:
        """Returns the predicted probabilities for a given data set"""

        pred_probs = np.apply_along_axis(self.predict_one_sample, 1, X_set)

        return pred_probs

    def predict(self, X_set: np.array) -> np.array:
        """Returns the predicted classes for a given data set"""

        pred_probs = self.predict_proba(X_set)
        preds = np.argmax(pred_probs, axis=1)

        return preds

Performance of Our Tree

Now let’s test how our tree performs on some easy benchmarks.

Iris Dataset

In this dataset we have 4 features and 1 label.

Iris Dataset

Let’s train our tree on 75% of the data.

DecisionTree trained on Iris Dataset
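The screenshots here come from a notebook; below is a minimal sketch of the same flow, assuming scikit-learn is used only to load the data and split it into train and test sets (the DecisionTree class is the one built in this article, and the exact accuracies will depend on the random split):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load the Iris data: 4 numeric features, 3 classes
X, y = load_iris(return_X_y=True)

# Train on 75% of the data, keep 25% for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

tree = DecisionTree(max_depth=4, min_samples_leaf=1, min_information_gain=0.0)
tree.train(X_train, y_train)

train_acc = np.mean(tree.predict(X_train) == y_train)
test_acc = np.mean(tree.predict(X_test) == y_test)
print(f"Train accuracy: {train_acc:.2f}, Test accuracy: {test_acc:.2f}")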

After training we can print and inspect the tree (a sketch of such a helper is shown below). When the condition at a node is satisfied, you go left; when it is not, you go right.
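The printing helper itself is only in the repository; here is a hypothetical recursive sketch of how a print_tree method could walk the nodes (indentation reflects depth, and the repository’s version may format things differently):

    def print_tree(self, node=None, level=0) -> None:
        """Recursively prints the tree, indenting nodes by their depth (hypothetical helper)."""
        if node is None:
            node = self.tree
        indent = "  " * level
        if node.left is None and node.right is None:
            # Leaf: show only the class probabilities stored in the node
            print(f"{indent}leaf -> probs {np.round(node.prediction_probs, 2)}")
        else:
            print(f"{indent}feature_{node.feature_idx} < {node.feature_val:.2f}")
            if node.left:
                self.print_tree(node.left, level + 1)
            if node.right:
                self.print_tree(node.right, level + 1)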

We can see that most of the leaves contain only one class, so our algorithm learned a reasonable tree. Let’s check the accuracy.

Performance on Iris Dataset

We have 94% accuracy on the training set and 89% accuracy on the test set. That is decent for such a simple from-scratch implementation.

Breast Cancer Dataset

Let’s test the model on the breast cancer dataset, which has many more features.

Breast Cancer Dataset
DecisionTree trained on Breast Cancer Dataset
Performance on Breast Cancer Dataset
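The notebook flow is identical to the Iris example; only the data loading changes. Assuming scikit-learn is used for loading here as well, the swap looks roughly like this:

from sklearn.datasets import load_breast_cancer

# 30 numeric features, binary label (malignant vs. benign)
X, y = load_breast_cancer(return_X_y=True)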

We get around 94% accuracy on both the train and test sets.

Conclusion

In this Medium article we have seen how to implement a basic classification tree, and it performs reasonably well on a couple of datasets. In the upcoming articles we will implement more sophisticated decision tree ideas, such as random forests and boosted trees. Check out the links below.

The Decision Tree Medium Article series:

  1. Building a Decision Tree From Scratch (This article)
  2. Building a Random Forest From Scratch (Click here)
  3. Building AdaBoost (Boosted Trees) From Scratch (Click here)

References

  1. Grus, Joel. Data Science from Scratch: First Principles with Python. O’Reilly Media, 2019.
  2. Quinlan, J. Ross. “Induction of decision trees.” Machine learning 1 (1986): 81–106.
  3. Breiman, L. (1984). Classification and Regression Trees (1st ed.). Routledge. https://doi.org/10.1201/9781315139470

Thanks for Reading!

Let me know if you have any comments 🙂
