Decision Tree Algorithm from 0


A Decision Tree is a supervised learning algorithm most commonly used for classification problems. It works with both continuous and categorical input variables. In this post, we'll see how it is implemented from scratch, without relying on prebuilt libraries like scikit-learn, and walk through a multiclass classification use case.

Important terminologies to keep in mind while dealing with decision trees:

Root Node: The topmost node representing the entire population or sample in a decision tree, which subsequently divides into two or more subsets.

Decision Node: A node that splits into further sub-nodes, indicating a decision point in the tree.

Leaf/Terminal Node: A node that does not split into any further sub-nodes, representing a final decision or outcome in the tree.

Pruning: The process of removing sub-nodes from a decision node to simplify the model, effectively the opposite of splitting.

Branch/Sub-Tree: A section of the decision tree, representing a subset of the entire structure, starting from a node and including all its descendants.

Entropy: A measure of the randomness or impurity in a dataset, used to determine the best way to split the data in decision trees.

Gini Index: A metric that measures the impurity of a node in a decision tree, calculated as the probability of a randomly chosen element being incorrectly classified.

Information Gain: The reduction in entropy or impurity from a split, used to determine the most informative feature to split on at each step in the decision tree (a small worked sketch of entropy and information gain follows this list).

Maximum Depth: The maximum number of levels a decision tree can have, which can be controlled to prevent overfitting.

Minimum Samples Split: The minimum number of samples required to split an internal node, used to control the growth of the decision tree and avoid overfitting.

Minimum Samples Leaf: The minimum number of samples required to be at a leaf node, which helps prevent the model from creating leaf nodes with too few samples.

Feature Importance: A score that indicates how important each feature is for the decision tree in predicting the target variable, often used for feature selection.

Cross-Validation: A technique for assessing how a decision tree model will generalize to an independent dataset, typically by dividing the data into multiple training and test sets.

Bagging (Bootstrap Aggregating): An ensemble technique that involves training multiple decision trees on different subsets of the data and combining their predictions to improve overall performance.

Random Forest: An ensemble method that combines multiple decision trees, each trained on random subsets of the data and features, to enhance prediction accuracy and control overfitting.

Hyperparameters: Parameters that are set before the learning process begins, such as maximum depth, minimum samples split, and minimum samples leaf, which control the structure and complexity of the decision tree.

Overfitting: A modeling error that occurs when a decision tree model captures noise or random fluctuations in the training data instead of the underlying pattern, leading to poor generalization to new data.

Underfitting: A scenario where a decision tree model is too simple to capture the underlying patterns in the data, leading to poor performance on both training and test data.
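
Entropy and information gain are defined above but are not actually used in the implementation that follows, which relies on Gini impurity instead. As a quick illustration of the idea, here is a minimal sketch of how they could be computed with NumPy (the function names and the toy label array are only for illustration):

# Illustrative sketch only -- the decision tree below uses Gini impurity, not entropy.
import numpy as np

def entropy(y):
    # H(S) = -sum(p_i * log2(p_i)) over the classes present in y
    _, counts = np.unique(y, return_counts=True)
    probs = counts / counts.sum()
    return -np.sum(probs * np.log2(probs))

def information_gain(parent, left, right):
    # Reduction in entropy achieved by splitting `parent` into `left` and `right`
    n = len(parent)
    weighted = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - weighted

labels = np.array([0, 0, 1, 1, 1, 2])
print(entropy(labels))                                   # ~1.459 bits
print(information_gain(labels, labels[:2], labels[2:]))  # ~0.918 bits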

Gini Impurity:

It measures how often a randomly chosen element of the dataset would be misclassified if it were labeled randomly according to the distribution of labels in the subset. It is calculated using the following formula:

Gini = 1 − Σ (p_i)², with the sum running over i = 1, …, C

  • C is the total number of classes.
  • p_i is the probability of picking an element of class i (i.e., the proportion of items of class i in the set).
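
For example, a node containing 4 samples of class A and 2 samples of class B has Gini impurity 1 − ((4/6)² + (2/6)²) ≈ 0.444. A perfectly pure node has impurity 0, and the impurity grows as the class mix becomes more even.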

Steps to our Decision Tree Algorithm:

  1. Create a Node class that we will use to build our tree node by node. It simply initializes the attributes of a tree node: the split feature, the threshold, the left and right child nodes, and the leaf value.
  2. Create a DecisionTree class that contains all the decision tree functionality: building the tree using the Node class above, the Gini impurity function, and the fit and predict methods.
  3. Create a function for Gini impurity using the formula discussed above; it will help us decide the best split.
  4. Create a function (best_split) that finds the best feature and threshold to split the data (X and y values) using the Gini impurity function. Each time it is called, it returns the updated best_feature and best_threshold values to the build_tree function.
  5. Create a recursive function build_tree that builds the complete tree node by node using the Node class, based on the best_feature and best_threshold values returned by best_split.
  6. Create fit, predict_sample, and predict functions to fit the model and make predictions using the built tree.
import numpy as np
import pandas as pd


class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature        # index of the feature used for the split
        self.threshold = threshold    # threshold value for the split
        self.left = left              # left child (samples where feature < threshold)
        self.right = right            # right child (samples where feature >= threshold)
        self.value = value            # predicted class if this is a leaf node


class DecisionTree:
    def __init__(self, max_depth=5):
        self.max_depth = max_depth
        self.root = None

    @staticmethod
    def gini_impurity(y):
        # Gini = 1 - sum(p_i^2) over all classes present in y
        classes = np.unique(y)
        impurity = 1.0
        for c in classes:
            prob_c = np.sum(y == c) / len(y)
            impurity -= prob_c ** 2
        return impurity

    @staticmethod
    def majority_class(y):
        # Most frequent label in y, used as the value of a leaf node
        values, counts = np.unique(y, return_counts=True)
        return values[np.argmax(counts)]

    def best_split(self, X, y):
        best_gini = float('inf')
        best_feature = None
        best_threshold = None
        for feature in range(X.shape[1]):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                left_mask = X[:, feature] < threshold
                right_mask = X[:, feature] >= threshold
                if np.any(left_mask) and np.any(right_mask):
                    # Weighted Gini impurity of the two child nodes
                    gini = (len(y[left_mask]) * self.gini_impurity(y[left_mask]) +
                            len(y[right_mask]) * self.gini_impurity(y[right_mask])) / len(y)
                    if gini < best_gini:
                        best_gini = gini
                        best_feature = feature
                        best_threshold = threshold
        return best_feature, best_threshold

    def build_tree(self, X, y, depth=0):
        # Stop if the node is pure or the maximum depth is reached
        if len(np.unique(y)) == 1 or depth == self.max_depth:
            return Node(value=self.majority_class(y))

        feature, threshold = self.best_split(X, y)
        if feature is None:
            return Node(value=self.majority_class(y))

        left_mask = X[:, feature] < threshold
        right_mask = X[:, feature] >= threshold

        left_node = self.build_tree(X[left_mask], y[left_mask], depth + 1)
        right_node = self.build_tree(X[right_mask], y[right_mask], depth + 1)

        return Node(feature, threshold, left_node, right_node)

    def fit(self, X, y):
        self.root = self.build_tree(X, y)

    def predict_sample(self, sample, node):
        # Walk down the tree until a leaf node is reached
        if node.value is not None:
            return node.value
        if sample[node.feature] < node.threshold:
            return self.predict_sample(sample, node.left)
        return self.predict_sample(sample, node.right)

    def predict(self, X):
        return np.array([self.predict_sample(sample, self.root) for sample in X])

Sample Dataset to test our multiclass classification ML model:

As an example, take the dataset provided by XYZ Telecom Company for analyzing customer churn. It includes each customer's age, account balance, and tenure with the company, along with an indicator of whether the customer has churned (i.e., stopped using the company's services). Note that this churn label has only two classes, but the same code works unchanged when there are more.
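
The original churn_data.csv file is not included here, so if you want to run the example end to end you can create a tiny stand-in with the same assumed layout (age, balance, tenure, churned). The values below are made up purely for illustration:

# Hypothetical stand-in for churn_data.csv -- column names and values are assumptions.
import pandas as pd

sample_data = pd.DataFrame({
    'age':     [25, 42, 35, 51, 23, 38, 47, 29],
    'balance': [1200.50, 4500.00, 3000.50, 800.75, 950.00, 5200.25, 150.00, 2750.60],
    'tenure':  [1, 6, 4, 2, 1, 8, 1, 3],
    'churned': [False, True, True, False, False, True, False, True],
})
sample_data.to_csv('churn_data.csv', index=False)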

Finally, use the above class in your main function and make predictions:


if __name__ == "__main__":

    data = pd.read_csv('churn_data.csv')
    X = data.iloc[:, :-1].values  # Features, i.e. columns 0, 1, 2
    y = data.iloc[:, -1].values   # Labels, i.e. column 3

    # Initialize and fit the decision tree
    tree = DecisionTree(max_depth=5)
    tree.fit(X, y)

    # Predict for new samples of values
    sample1 = np.array([25, 1200.50, 1])  # Gives False
    sample2 = np.array([35, 3000.50, 4])  # Gives True

    predictions = tree.predict(np.array([sample1, sample2]))
    print(f'Predicted class for sample1: {predictions[0]}')
    print(f'Predicted class for sample2: {predictions[1]}')
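
If you want a rough sense of how well the tree generalizes, you can also hold out part of the data as a test set and measure accuracy. Here is a minimal sketch, assuming the same churn_data.csv layout as above:

# Rough hold-out evaluation sketch -- assumes DecisionTree and churn_data.csv as above.
import numpy as np
import pandas as pd

data = pd.read_csv('churn_data.csv')
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values

rng = np.random.default_rng(42)
indices = rng.permutation(len(X))   # shuffle the row indices
split = int(0.8 * len(X))           # 80% train / 20% test
train_idx, test_idx = indices[:split], indices[split:]

tree = DecisionTree(max_depth=5)
tree.fit(X[train_idx], y[train_idx])

accuracy = np.mean(tree.predict(X[test_idx]) == y[test_idx])
print(f'Test accuracy: {accuracy:.2f}')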

Conclusion

This code demonstrates a simple implementation of a decision tree classifier from scratch. You can enhance it by adding features such as pruning, directly handling categorical variables, and optimizing the splitting criteria.
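
As one small example of such an enhancement, a min_samples_split hyperparameter (not part of the class above, and named here only for illustration) could stop the tree from splitting nodes that contain too few samples. A minimal sketch, assuming the Node and DecisionTree classes from this post:

# Hypothetical extension: stop splitting when a node holds too few samples.
class DecisionTreeWithMinSplit(DecisionTree):
    def __init__(self, max_depth=5, min_samples_split=2):
        super().__init__(max_depth=max_depth)
        self.min_samples_split = min_samples_split

    def build_tree(self, X, y, depth=0):
        # Turn the node into a leaf if it is pure, too deep, or too small
        if (len(np.unique(y)) == 1 or depth == self.max_depth
                or len(y) < self.min_samples_split):
            return Node(value=self.majority_class(y))
        return super().build_tree(X, y, depth)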

There You Go! Congratulations!⛳🏆

If you found this helpful, a humble clap would mean the world.

Thank you for taking the time to read this blog. Enjoy Data Science!
