Decision Tree Regression: With Python Code

Sreejith P
3 min read · Feb 24, 2023

Decision trees are a non-parametric machine learning method that can be used for both regression and classification tasks. In this blog, we will focus on decision tree regression, which builds a tree to predict a continuous target variable: each internal node splits the data on a threshold of one feature, and each leaf predicts a constant value (the mean of the training targets that reach it). We will use Python and the scikit-learn library to implement decision tree regression on a sample dataset.
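To make "building a decision tree" concrete: at each node, a regression tree using scikit-learn's default squared-error criterion looks for the feature and threshold that minimize the total squared error of the two child nodes, where each child predicts its own mean. The function below is a minimal, hypothetical sketch of that search for a single feature (the name best_split and the toy data are illustrative); it is not scikit-learn's implementation, which is written in optimized Cython and also loops over all features.

```python
import numpy as np

def best_split(x, y):
    """Return (threshold, sse) for the best binary split of one feature.

    Each child predicts the mean of its targets; the split with the lowest
    summed squared error (SSE) across both children wins.
    """
    order = np.argsort(x)
    x_sorted, y_sorted = x[order], y[order]
    best = (None, np.inf)
    for i in range(1, len(x_sorted)):
        if x_sorted[i] == x_sorted[i - 1]:
            continue  # no threshold can separate identical feature values
        left, right = y_sorted[:i], y_sorted[i:]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if sse < best[1]:
            # place the threshold halfway between neighbouring feature values
            best = ((x_sorted[i - 1] + x_sorted[i]) / 2, sse)
    return best

# Toy data with two obvious groups of target values
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([5.0, 6.0, 5.5, 20.0, 21.0, 19.5])
threshold, sse = best_split(x, y)
print(threshold)  # 6.5
```

The chosen threshold falls between 3 and 10, the gap that separates the low targets from the high ones; growing a full tree repeats this search recursively inside each child.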

1. Importing the Required Libraries

Let’s begin by importing the required libraries. We will use pandas to load and manipulate the dataset, and scikit-learn to build and evaluate the decision tree regression model.

import pandas as pd
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

2. Loading the Dataset

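We will use the Boston Housing dataset, which contains 506 samples, 13 input features, and a continuous target variable: the median value of owner-occupied homes in thousands of dollars. Older tutorials load it with scikit-learn's load_boston function, but that function was deprecated in scikit-learn 1.0 and removed in 1.2, partly because of ethical concerns about one of its features. The code below instead reads the original copy from CMU's StatLib archive, following the workaround given in scikit-learn's deprecation notice, and converts it into a pandas DataFrame for easy manipulation.

import numpy as np
# load the dataset: in the StatLib file each record is wrapped across two lines
data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]
# create a DataFrame from the dataset
feature_names = ["CRIM", "ZN", "INDUS", "CHAS", "NOX", "RM", "AGE",
                 "DIS", "RAD", "TAX", "PTRATIO", "B", "LSTAT"]
df = pd.DataFrame(data, columns=feature_names)
df['target'] = target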
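If the Boston data cannot be loaded in your environment (for example, on a scikit-learn version without load_boston, or without network access), a synthetic regression problem of the same shape can stand in so the remaining steps still run. This is a hypothetical substitute generated with make_regression, with made-up column names x0 to x12; it is not housing data, so the MSE values you get will not be comparable to results on the real dataset.

```python
import pandas as pd
from sklearn.datasets import make_regression

# Hypothetical stand-in: 506 rows and 13 features, mirroring the Boston shape.
# Values are synthetic; the column names x0..x12 are placeholders.
X_syn, y_syn = make_regression(n_samples=506, n_features=13, noise=10.0, random_state=0)
df = pd.DataFrame(X_syn, columns=[f"x{i}" for i in range(13)])
df['target'] = y_syn
print(df.shape)  # (506, 14)
```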

3. Preparing the Data

Next, we will prepare the data for training and testing the decision tree regression model. We first separate the input features from the target variable, then split the rows into a training set (80%) and a testing set (20%).

# separate the input features from the target variable
X = df.drop('target', axis=1)
y = df['target']
# split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
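With test_size=0.2 and the dataset's 506 rows, train_test_split rounds the test set size up: ceil(0.2 * 506) = 102 rows for testing, leaving 404 for training. The sketch below checks this using placeholder arrays of the same shape (hypothetical zeros, not the real data).

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder arrays with the same shape as the Boston features/target
X = np.zeros((506, 13))
y = np.zeros(506)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# A float test_size is converted to a row count with ceil(): ceil(101.2) = 102
print(X_train.shape, X_test.shape)  # (404, 13) (102, 13)
```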

(The random_state argument in the call to train_test_split seeds the random number generator that the function uses to shuffle the rows before splitting them. It does not configure the model itself; estimators such as DecisionTreeRegressor accept their own random_state parameter for their own sources of randomness.

Fixing random_state to a value such as 42 makes the split reproducible: the same rows end up in the training and testing sets every time the code runs on the same data. This is important for comparing different models and for debugging, because any change in the score can then be attributed to the model rather than to a different split.)
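A minimal sketch of what a fixed seed buys you, on a tiny toy array (hypothetical data): two calls with the same random_state return exactly the same test rows, while a call without random_state draws a fresh shuffle each time it runs.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(20).reshape(10, 2)  # 10 toy rows, 2 features
y = np.arange(10)

# Same seed: identical split on every call
train_a, test_a, _, _ = train_test_split(X, y, test_size=0.2, random_state=42)
train_b, test_b, _, _ = train_test_split(X, y, test_size=0.2, random_state=42)
print(np.array_equal(test_a, test_b))  # True

# No seed: the shuffle (and therefore the test rows) can change between runs
_, test_c, _, _ = train_test_split(X, y, test_size=0.2)
```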

4. Building the Model

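We can now build the decision tree regression model using the DecisionTreeRegressor class from scikit-learn. We will set the maximum depth of the tree to 3, which means that any path from the root to a leaf passes through at most 3 decision nodes, so the tree has at most 2^3 = 8 leaves. Limiting the depth stops the tree from growing until it memorizes the training data, which helps prevent overfitting and usually improves generalization to unseen data. The value 3 is a reasonable starting point rather than a tuned choice.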

# build the decision tree regression model
# (random_state fixes how ties between equally good splits are broken)
tree_reg = DecisionTreeRegressor(max_depth=3, random_state=42)
# fit the model to the training data
tree_reg.fit(X_train, y_train)
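Whether max_depth=3 is the right amount of restriction depends on the data. One way to check, sketched below under the assumption that you compare a handful of candidate depths with 5-fold cross-validation via cross_val_score: it uses synthetic make_regression data as a hypothetical stand-in so it runs on its own, so the printed numbers say nothing about the Boston results. On the tutorial's data you would pass X_train and y_train instead, keeping the test set untouched for the final evaluation.

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

# Hypothetical stand-in data; replace with X_train, y_train from the tutorial
X, y = make_regression(n_samples=400, n_features=13, noise=10.0, random_state=0)

scores = {}
for depth in [1, 2, 3, 5, 8, None]:  # None lets the tree grow until leaves are pure
    model = DecisionTreeRegressor(max_depth=depth, random_state=42)
    # scikit-learn scorers follow "higher is better", so MSE comes back negated
    mse = -cross_val_score(model, X, y, cv=5, scoring="neg_mean_squared_error").mean()
    scores[depth] = mse
    print(f"max_depth={depth}: mean CV MSE = {mse:.1f}")

# The depth with the lowest average validation error
best_depth = min(scores, key=scores.get)
print("best max_depth:", best_depth)
```

Very shallow trees tend to underfit and unrestricted trees tend to overfit, so the cross-validated error is typically lowest somewhere in between.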

5. Evaluating the Model

We can evaluate the performance of the decision tree regression model using the mean squared error (MSE) metric, which measures the average squared difference between the predicted and actual target values. We will use the mean_squared_error function from scikit-learn to compute the MSE on the testing data.

# make predictions on the testing data
y_pred = tree_reg.predict(X_test)
# calculate the mean squared error on the testing data
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")
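The MSE is MSE = (1/n) * sum((y_i - ŷ_i)^2), where y_i is an actual target value and ŷ_i the corresponding prediction. Because the errors are squared, its units are the square of the target's (here, squared thousands of dollars), so the root mean squared error (RMSE) is often easier to read. The sketch below computes both by hand on a tiny made-up example and checks the MSE against mean_squared_error; the helper mse_by_hand is illustrative only.

```python
import numpy as np
from sklearn.metrics import mean_squared_error

def mse_by_hand(y_true, y_pred):
    # average of the squared differences between actual and predicted values
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.mean((y_true - y_pred) ** 2)

y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]
# errors 0.5, -0.5, 0.0, -1.0 -> squares 0.25, 0.25, 0.0, 1.0 -> 1.5 / 4 = 0.375
print(mse_by_hand(y_true, y_pred))         # 0.375
print(mean_squared_error(y_true, y_pred))  # 0.375

# RMSE is back in the target's own units
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
print(f"{rmse:.4f}")  # 0.6124
```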

6. Visualizing the Tree

We can visualize the decision tree regression model using the plot_tree function from scikit-learn’s tree module. This generates a graphical representation of the decision tree, which can help us understand how the model makes predictions based on the input features.

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# plot the decision tree
plt.figure(figsize=(10,6))
plot_tree(tree_reg, filled=True, feature_names=list(X.columns))
plt.show()
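For a text version of the tree, or to confirm what the "value" in each plotted leaf means, scikit-learn's export_text function prints the split rules, and the fitted model's apply() method reports which leaf each sample lands in. A minimal sketch on synthetic data (the dataset and the feature names x0 to x2 are hypothetical stand-ins): with the default squared-error criterion, each leaf's value is the mean target of the training samples that reach it, which is why a regression tree's predictions are piecewise constant.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor, export_text

# Hypothetical stand-in data; feature names x0..x2 are placeholders
X, y = make_regression(n_samples=200, n_features=3, noise=5.0, random_state=0)
tree = DecisionTreeRegressor(max_depth=2, random_state=42).fit(X, y)

# Split rules as indented text, one line per node
rules = export_text(tree, feature_names=["x0", "x1", "x2"], decimals=2)
print(rules)

# apply() returns the leaf id of each sample; compare each leaf's prediction
# with the mean target of the training samples routed to that leaf
leaf_ids = tree.apply(X)
matches = [
    np.isclose(y[leaf_ids == leaf].mean(), tree.predict(X[leaf_ids == leaf][:1])[0])
    for leaf in np.unique(leaf_ids)
]
print(all(matches))  # True
print(tree.get_depth(), tree.get_n_leaves())  # depth at most 2, at most 4 leaves
```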

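In this blog, we have implemented decision tree regression in Python using scikit-learn: we loaded the Boston Housing data into a DataFrame, split it into training and testing sets, fit a tree limited to a depth of 3, evaluated it with the mean squared error on the held-out data, and visualized the learned splits with plot_tree.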
