Decision Trees (Part 2)

Dr. Roi Yehoshua · Published in AI Made Simple · Mar 21, 2023

In the first part of this article, we discussed what decision trees are and how to build them from a given data set. In this part, we will demonstrate how to use and configure decision trees in Scikit-Learn, how to deal with overfitting in decision trees using tree pruning, and how to use decision trees for regression problems.

Decision Trees in Scikit-Learn

The DecisionTreeClassifier class in sklearn.tree provides a decision tree classifier. Its implementation is based on the CART (Classification and Regression Trees) algorithm [1]. However, it currently does not support categorical features, so you will need to convert your categorical variables into numerical ones before training the model (e.g., by using OneHotEncoder).
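For example, here is a minimal sketch of encoding a categorical feature with OneHotEncoder before feeding it to the tree (the toy data in this snippet is made up for illustration):

from sklearn.preprocessing import OneHotEncoder

# Hypothetical toy data: a single categorical feature with three levels
X_cat = [['red'], ['green'], ['blue'], ['green']]

encoder = OneHotEncoder()
X_encoded = encoder.fit_transform(X_cat).toarray()  # dense one-hot matrix
print(encoder.categories_)  # [array(['blue', 'green', 'red'], dtype=object)]
print(X_encoded)            # each row is a one-hot vector, e.g., [0., 0., 1.]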

In the constructor of DecisionTreeClassifier you can specify the impurity measure that will be used for the node splits with a parameter named criterion. The available options are ‘gini’ for the Gini index (the default), ‘entropy’ for information gain, and ‘log_loss’ (which is equivalent to ‘entropy’, as both use Shannon entropy). Other important parameters of the constructor will be discussed later.
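For instance, to build a tree that splits nodes by information gain instead of the default Gini index (a minimal sketch; the random_state value is an arbitrary choice added for reproducibility):

from sklearn.tree import DecisionTreeClassifier

# Use entropy (information gain) instead of the default Gini index
clf = DecisionTreeClassifier(criterion='entropy', random_state=42)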

For example, let’s train a decision tree classifier on the Iris data set, but using only the first two features of each flower (the sepal length and sepal width), which makes this classification problem more challenging.

We first load the data set:

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data[:, :2] # we only take the first two features
y = iris.target
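We can then train the classifier and check its accuracy. The following is a minimal sketch of how this step might look (the train/test split and the random_state values are assumptions, not taken from the original):

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hold out a test set so we can detect overfitting later
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

print(f'Train accuracy: {clf.score(X_train, y_train):.4f}')
print(f'Test accuracy: {clf.score(X_test, y_test):.4f}')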
