Tree-Based Machine Learning

Decision Tree — Part 1

Rina Buoy
The Startup
6 min read · Jan 12, 2020


Trees & Temple Ruins (Cambodia)

I had always underestimated the power of the decision tree algorithm until I started to write this post. Although the decision tree is algorithmically simple compared with SVMs, neural networks, etc., it is undoubtedly intuitive and easy to implement. However, this simplicity comes at the cost of overfitting (high variance). To overcome overfitting, the random forest and boosted-trees algorithms come to our rescue.

This is the first part of the decision tree series which covers the following:

Decision Tree — Part 1 (this post)

Random Forest — Part 2 (next post)

Boosted Trees — Part 3 (next post)

Decision Tree Algorithm

Binary Tree Structure

For the sake of convenience, we will explain the decision tree from a classification perspective and show how to use the decision tree for regression afterwards.

Here are some characteristics of a classification tree (source):

1- A sequence of if-else questions about individual features (a minimal sketch follows this list)

2- Objective: infer class labels

3- Able to capture non-linear relationships between features and labels

4- Does not require feature scaling (e.g. standardization)
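To make the first point concrete, a single path through a tree is just a chain of nested if-else checks. The feature names (tumor_radius, tumor_texture) and thresholds in the sketch below are made up purely for illustration:

# A hypothetical hand-written "tree" with made-up features and thresholds,
# only to illustrate the if-else view of a decision tree
def classify(tumor_radius, tumor_texture):
    if tumor_radius <= 15.0:          # root question
        if tumor_texture <= 20.0:     # internal node question
            return "benign"           # leaf prediction
        return "malignant"            # leaf prediction
    return "malignant"                # leaf prediction

print(classify(12.3, 18.5))  # -> benign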

An example of a binary decision tree graph for the breast cancer dataset is given below.

Binary decision tree for the breast cancer dataset (source)

A decision tree consists of nodes, and a node can be either a question or a prediction, depending on its position. There are three types of nodes:

1- Root: the node at the top of the tree. It is a question which gives rise to two child nodes.

2- Internal node: a parent node which gives rise to two child nodes.

3- Leaf: a node at the bottom of the tree. It has no further child nodes and gives a prediction.


Attribute Splitting

The root node and the internal nodes each represent a splitting attribute (feature). Intuitively, the best splitting attribute is the one which makes the data samples of the child nodes as pure as possible. Pure means that the data samples belong to the same class/label.

If the attribute is binary (True/False), splitting is straightforward. If the attribute is continuous, the best split point must be determined. An example of splitting on a continuous attribute is given below.
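One common way to find such a split point is to try the midpoints between consecutive sorted feature values and keep the one with the lowest weighted impurity of the two children. A small sketch of this idea, using a made-up toy feature (real implementations are more optimised, but the idea is the same):

import numpy as np

def gini(labels):
    # Gini impurity of a set of class labels
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split_point(values, labels):
    # Try the midpoints between consecutive sorted values and keep the one
    # that minimises the weighted impurity of the two children
    order = np.argsort(values)
    values, labels = values[order], labels[order]
    best_sp, best_impurity = None, np.inf
    for i in range(1, len(values)):
        if values[i] == values[i - 1]:
            continue
        sp = (values[i] + values[i - 1]) / 2.0
        left, right = labels[:i], labels[i:]
        weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if weighted < best_impurity:
            best_sp, best_impurity = sp, weighted
    return best_sp, best_impurity

# Toy example: one continuous feature, two classes
x = np.array([2.0, 3.5, 4.1, 6.8, 7.2, 9.0])
y = np.array([0, 0, 0, 1, 1, 1])
sp, imp = best_split_point(x, y)
print(sp, imp)  # best split point 5.45 with weighted Gini impurity 0.0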

A decision tree is built by recursively and greedily branching on the best attribute until one of the following stopping criteria is met:

1- The data samples under the lowest node are pure (from the same class/label).

2- The lowest node cannot be further split (not enough data samples).

3- The tree depth reaches the maximum depth (a deeper tree is more prone to overfitting).


Information Gain & Purity Measure

At each branching step, the best splitting attribute is the one with the highest information gain (IG). The IG of an attribute (f) at an arbitrary split point (sp) is given by:

IG(f, sp) = I(parent) - (N_left / N) · I(left) - (N_right / N) · I(right)

I is the impurity index, N is the total number of data samples under the parent node, and N_left and N_right are the numbers of samples sent to the left and right child nodes. The common impurity indices are the Gini index and entropy, the formulas of which are given below:

Gini = 1 - Σ p(i)²

Entropy = - Σ p(i) · log₂ p(i)

where p(i) is the fraction of data samples belonging to class i in the node.

Sample calculations of the Gini index and entropy for different levels of impurity are shown below. Gini and entropy are at their minimum value (zero) when a node contains data samples from a single class, and at their maximum value when a node contains an even class distribution.

Sample calculations of the Gini index and entropy for nodes of varying impurity (source)
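These sample calculations can be reproduced in a few lines of Python; the class proportions below are illustrative:

import numpy as np

def gini(props):
    # Gini impurity from class proportions
    props = np.asarray(props)
    return 1.0 - np.sum(props ** 2)

def entropy(props):
    # Entropy (base 2) from class proportions; empty classes are skipped
    props = np.asarray(props)
    props = props[props > 0]
    return -np.sum(props * np.log2(props)) + 0.0  # + 0.0 avoids printing -0.0 for a pure node

for props in [(1.0, 0.0), (0.75, 0.25), (0.5, 0.5)]:
    print(props, round(gini(props), 3), round(entropy(props), 3))
# A pure node gives 0 for both; a 50/50 node gives the maxima (0.5 and 1.0)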

Regression Tree

The Gini index and entropy apply only to classification problems. For regression problems, the impurity index at a given node is the mean squared error (MSE):

I(node) = MSE(node) = (1 / N_node) · Σ (y_i - ŷ_node)²

where ŷ_node is the mean of the target values of the N_node data samples in the node:

ŷ_node = (1 / N_node) · Σ y_i

For prediction, a leaf node simply returns ŷ_leaf, the mean target value of the training samples that end up in that leaf.
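A minimal sketch of these two formulas in Python (the variable names are illustrative):

import numpy as np

def mse_impurity(y):
    # MSE impurity of a node: mean squared deviation of the targets from the node mean
    return np.mean((y - np.mean(y)) ** 2)

# The node mean is also what the node would predict if it were a leaf
targets = np.array([22.0, 24.0, 30.0, 31.0])
print(mse_impurity(targets), np.mean(targets))  # impurity of the node and its leaf prediction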

Simple Python Implementation

The Python code below implements a simple version of a decision tree classifier, using the Gini index to measure node impurity. Scikit-learn offers a more optimised and sophisticated decision tree classifier and regressor. However, for pedagogical purposes, this implementation helps us understand how a decision tree works under the hood.
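A minimal sketch of such a classifier is given below; the class name SimpleDecisionTree, the dictionary-based node representation and the hyper-parameter names are illustrative choices rather than a standard API.

import numpy as np

class SimpleDecisionTree:
    # A minimal Gini-based classification tree (illustrative only)

    def __init__(self, max_depth=3, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree_ = None

    @staticmethod
    def _gini(y):
        # Gini impurity of the labels in a node
        _, counts = np.unique(y, return_counts=True)
        p = counts / counts.sum()
        return 1.0 - np.sum(p ** 2)

    def _best_split(self, X, y):
        # Return the (feature, split point) with the lowest weighted child impurity,
        # or None if no split improves on the parent's impurity
        best, best_score = None, self._gini(y)
        for f in range(X.shape[1]):
            for sp in np.unique(X[:, f])[:-1]:
                left = X[:, f] <= sp
                score = (left.sum() * self._gini(y[left]) +
                         (~left).sum() * self._gini(y[~left])) / len(y)
                if score < best_score:
                    best, best_score = (f, sp), score
        return best

    def _leaf(self, y):
        # A leaf predicts the majority class of its samples
        values, counts = np.unique(y, return_counts=True)
        return {'leaf': values[np.argmax(counts)]}

    def _build(self, X, y, depth):
        # Stop when the node is pure, too small, or too deep
        if (len(np.unique(y)) == 1 or len(y) < self.min_samples_split
                or depth >= self.max_depth):
            return self._leaf(y)
        split = self._best_split(X, y)
        if split is None:
            return self._leaf(y)
        f, sp = split
        left = X[:, f] <= sp
        return {'feature': f, 'split': sp,
                'left': self._build(X[left], y[left], depth + 1),
                'right': self._build(X[~left], y[~left], depth + 1)}

    def fit(self, X, y):
        self.tree_ = self._build(np.asarray(X), np.asarray(y), 0)
        return self

    def predict(self, X):
        preds = []
        for x in np.asarray(X):
            node = self.tree_
            while 'leaf' not in node:
                node = node['left'] if x[node['feature']] <= node['split'] else node['right']
            preds.append(node['leaf'])
        return np.array(preds)

# Quick sanity check on the breast cancer dataset (unoptimised, so it takes a few seconds)
from sklearn.datasets import load_breast_cancer
data = load_breast_cancer()
clf = SimpleDecisionTree(max_depth=3).fit(data.data, data.target)
print((clf.predict(data.data) == data.target).mean())  # training accuracy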

Decision Tree in Scikit-Learn

Classification Tree

To demonstrate the usage of the decision tree classifier in scikit-learn, we use the Breast Cancer Diagnosis dataset. The dataset has 30 numerical attributes plus a label column which takes two possible classes, benign and malignant. There are 569 rows. Here is the Python code.

# Import DecisionTreeClassifier
from sklearn.tree import DecisionTreeClassifier
# Import train_test_split
from sklearn.model_selection import train_test_split
# Import accuracy_score
from sklearn.metrics import accuracy_score
# Import dataset
from sklearn.datasets import load_breast_cancer
from sklearn.tree import export_graphviz
from io import StringIO  # sklearn.externals.six has been removed from recent scikit-learn versions
from IPython.display import Image
import pydotplus
# Get features and labels from the breast cancer dataset
data = load_breast_cancer()
X= data.data
y=data.target
# Split dataset into 80% train, 20% test
X_train, X_test, y_train, y_test= train_test_split(X, y,test_size=0.2,stratify=y,random_state=1)
# Instantiate dt, set 'criterion' to 'gini'
dt = DecisionTreeClassifier(criterion='gini',max_depth=5, random_state=1)
# Fit dt to the training set
dt.fit(X_train,y_train)
# Predict test-set labels
y_pred= dt.predict(X_test)
# Evaluate test-set accuracy
print(accuracy_score(y_test, y_pred))
dot_data = StringIO()
export_graphviz(dt, out_file=dot_data,
                filled=True, rounded=True,
                special_characters=True,
                feature_names=data.feature_names,
                class_names=data.target_names)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('dt_classification_graph.png')
Image(graph.create_png())

By running the above code, the resulting accuracy is ~96%. The code also exports the tree diagram shown below.
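As a side note, if graphviz and pydotplus are not available, newer versions of scikit-learn can draw a similar diagram directly with plot_tree (reusing dt and data from the code above):

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

plt.figure(figsize=(20, 10))
plot_tree(dt, feature_names=data.feature_names, class_names=list(data.target_names),
          filled=True, rounded=True)
plt.show()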

Regression Tree

We use the Boston house-prices dataset to demonstrate a regression tree (note that load_boston has been removed from recent versions of scikit-learn, so this example requires an older version). The Python code is given below.

# Import DecisionTreeRegressor
from sklearn.tree import DecisionTreeRegressor
# Import train_test_split
from sklearn.model_selection import train_test_split
# Import mean_squared_error
from sklearn.metrics import mean_squared_error
# Import dataset
from sklearn.datasets import load_boston
from sklearn.tree import export_graphviz
from io import StringIO  # sklearn.externals.six has been removed from recent scikit-learn versions
from IPython.display import Image
import pydotplus
import matplotlib.pyplot as plt
# Get features and labels from the Boston housing dataset
data = load_boston()
X= data.data
y=data.target
# Split dataset into 80% train, 20% test
X_train, X_test, y_train, y_test= train_test_split(X, y,test_size=0.2,random_state=1)
# Instantiate dt (DecisionTreeRegressor uses MSE as its default splitting criterion)
dt = DecisionTreeRegressor(min_samples_leaf=0.1, max_depth=5, random_state=1)
# Fit dt to the training set
dt.fit(X_train,y_train)
# Predict test-set labels
y_pred= dt.predict(X_test)
# Evaluate test-set MSE
print(mean_squared_error(y_test, y_pred))
plt.scatter(y_pred, y_test)
plt.ylabel('actual')
plt.xlabel('predicted')
dot_data = StringIO()
export_graphviz(dt, out_file=dot_data,
                filled=True, rounded=True,
                special_characters=True,
                feature_names=data.feature_names)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('dt_regression_graph.png')
Image(graph.create_png())
plt.show()

Here is the resulting cross-plot of predicted versus actual house prices:

Here is the resulting regression tree diagram.

Wrap-up

A decision tree is an intuitive, interpretable, yet powerful machine learning model. However, a decision tree is very sensitive to hyper-parameter values. Poor hyper-parameter values lead to either under-fitting (high bias) or over-fitting (high variance). The process of choosing the right set of hyper-parameters is known as 'hyper-parameter tuning'.


Rina Buoy is an applied NLP researcher at Techo Startup Center (TSC).