Building a Decision Tree from Scratch in Python | Machine Learning from Scratch (Part III)
Build a better house price prediction model using a Decision Tree
TL;DR Build a Decision Tree regression model using Python from scratch. Compare the performance of your model with that of a Scikit-learn model. The Decision Tree is used to predict house sale prices and send the results to Kaggle.
Machine Learning from Scratch series:
- Smart Discounts with Logistic Regression
- Predicting House Prices with Linear Regression
- Building a Decision Tree from Scratch in Python
I am sorry, you might be losing sleep. Deep down you know your Linear Regression model ain’t gonna cut it. That housing market domination is still further down the road.
Can we improve on that? Can we have a model that makes better predictions?
Once again, we’re going to use the Kaggle data: “House Prices: Advanced Regression Techniques”. It contains 1460 training data points and 80 features that might help us predict the selling price of a house.
Decision tree models build structures like this:
The algorithms for building trees break a data set down into smaller and smaller subsets while the associated decision tree is incrementally developed. The final result is a tree with decision nodes and leaf nodes. A decision node has two or more branches. A leaf node represents a classification or a decision (for regression, a predicted value). The topmost decision node in a tree, which corresponds to the best predictor (most important feature), is called the root node.
Decision trees can handle both categorical and numerical data. They are used for classification and regression problems. They can handle missing data pretty well, too!
We’re going to use the same data we used with the Linear Regression model. However, we’re not going to do any scaling, just because we’re lazy (or it is not needed):
We’re going to use a new cost function: Root Mean Square Error (RMSE). It is the standard deviation of the residuals, that is, of how far the data points are from the model’s predictions. In other words, it tells you how concentrated the data is around the line of best fit.
RMSE is given by the formula:
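For reference, the standard textbook definition of RMSE is reproduced here. The symbols are assumptions chosen for this sketch: n is the number of data points, y_i is the true sale price and ŷ_i is the predicted one.

```latex
\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}
```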
We’ve already implemented MSE in previous parts, so we’re going to import an implementation here, in the name of readability (or holy laziness?):
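If you don’t have the previous parts at hand, here is a minimal NumPy sketch of such a helper. The function name `rmse` and its signature are assumptions made for illustration, not necessarily what the earlier parts used.

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root Mean Square Error between true values and predictions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Mean of the squared residuals, then the square root
    return np.sqrt(np.mean((y_true - y_pred) ** 2))

# One prediction is off by 2, the other is exact
print(rmse([3.0, 5.0], [1.0, 5.0]))
```

A perfect set of predictions gives an RMSE of 0, and larger errors are penalized more heavily because the residuals are squared before averaging.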
Using a prebuilt Decision Tree model
Let’s use the Decision Tree regressor from the scikit-learn library to get a quick feel for the model:
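A rough sketch of what such a quick experiment might look like is given here. The synthetic data stands in for the housing features, and the seed value, `max_depth=2` and `bootstrap=False` are assumptions made for this example rather than the settings used for the results in this post.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

RANDOM_SEED = 42  # assumed seed, only here to make the run reproducible

# Synthetic stand-in for the housing features and sale prices
rng = np.random.default_rng(RANDOM_SEED)
X = rng.uniform(0, 10, size=(200, 3))
y = 50_000 + 20_000 * X[:, 0] + rng.normal(0, 5_000, size=200)

# One estimator and no bootstrapping: the "forest" is a single tree fit on all rows
reg = RandomForestRegressor(
    n_estimators=1, max_depth=2, bootstrap=False, random_state=RANDOM_SEED
)
reg.fit(X, y)

preds = reg.predict(X)
print(preds[:5])
```

With a depth of 2 the tree has at most four leaves, so the model can only ever output four distinct prices. That keeps the structure easy to inspect.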
We are using RandomForestRegressor with 1 estimator, which (as long as bootstrapping is turned off) basically means we’re using a Decision Tree model. Here is the tree structure of our model:
You should receive the exact same model (if you’re running the code) since we are setting the random state. How many features did the model use?
Now that this model is ready to be used let’s evaluate its R² score:
The R² statistic gives us information about the goodness of fit of the model. A score of 1 indicates a perfect fit. Let’s have a look at the RMSE:
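To make the R² statistic less of a black box, here is a small NumPy sketch of the usual definition: 1 minus the ratio of the residual sum of squares to the total sum of squares. The function name `r_squared` is made up for this example.

```python
import numpy as np

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)          # unexplained variation
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total variation
    return 1.0 - ss_res / ss_tot

print(r_squared([1, 2, 3], [1, 2, 3]))  # perfect predictions
print(r_squared([1, 2, 3], [2, 2, 2]))  # always predicting the mean
```

Perfect predictions score 1, while a model that always predicts the mean of the targets scores 0.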
Building your own Decision Tree
Let’s start implementing our Node helper class:
Trees are recursive data structures and we’re going to take full advantage of that. Our Node class represents one decision point in our model. Each division within the model has 2 possible outcomes for finding a solution: go to the left or go to the right. That decision point also divides our data into two sets.
idxs stores indexes of the subset of the data that this Node is working with.
The decision (prediction) is based on the value that the Node holds. To make that prediction we’re just going to take the average of the data of the dependent variable for this Node.
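As a tiny illustration of that idea, the prediction of a node is simply the mean of the target values at the rows it owns. The helper name `node_value` and the toy numbers are made up for this sketch.

```python
import numpy as np

def node_value(y, idxs):
    """Prediction of a node: mean of the dependent variable over its rows."""
    return np.mean(y[idxs])

y = np.array([100.0, 200.0, 300.0, 400.0])  # dependent variable (toy sale prices)
idxs = np.array([1, 3])                     # rows this node is working with

print(node_value(y, idxs))  # 300.0
```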
find_varsplit finds where we should split the data. Let’s have a look at it:
First, we try to find a better feature to split on. If no such feature is found (we’re at a leaf node) we do nothing. Then we use the split value found by find_better_split, create the data for the left and right nodes and create each one using a subset of the data.
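The partitioning step described above can be sketched in isolation. This is a standalone illustration, not the method itself: the function name `split_indices` and the toy values are assumptions, and it presumes the split value has already been chosen.

```python
import numpy as np

def split_indices(x_col, idxs, split):
    """Divide a node's row indexes into left (<= split) and right (> split).

    x_col holds the chosen feature's values for the rows in idxs, in the same order.
    """
    lhs = np.nonzero(x_col <= split)[0]  # positions going to the left child
    rhs = np.nonzero(x_col > split)[0]   # positions going to the right child
    return idxs[lhs], idxs[rhs]

idxs = np.array([2, 5, 7, 9])            # rows owned by the current node
x_col = np.array([1.0, 4.0, 2.0, 8.0])   # feature values for rows 2, 5, 7, 9
left, right = split_indices(x_col, idxs, split=2.0)

print(left, right)
```

Each child node then receives only its own subset of indexes, which is what makes the recursion shrink at every level.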
Here are the
It is time to implement the workhorse of our algorithm
We are trying to split on each data point and let the best split win.
We’re going to create our split such that it has as low a standard deviation as possible. We find the split that minimizes the weighted average of the standard deviations of the two sides, which serves as a close proxy for minimizing RMSE.
If we find a better split we store the following information: index of the variable, split score and value of the split.
The score is a metric that tells us how effective the split was (note that leaf nodes do not have scores, so theirs stays at infinity). The method find_score calculates a weighted average of the standard deviations of the target values on the two sides of the split. If the score is lower than the previous one, we have a better split. Note that the score is initially set to infinity, so only leaf nodes and really shallow trees (and Thanos) have a score of infinity.
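Here is a standalone sketch of the search over a single feature, written as plain functions instead of methods. The `min_leaf` parameter (minimum number of rows on each side) and the toy data are assumptions for this example. The score is a weighted sum of standard deviations, which picks the same split as the weighted average because both sides always add up to the same number of rows.

```python
import numpy as np

def find_score(y, lhs, rhs):
    """Standard deviation of each side, weighted by how many rows it holds."""
    return y[lhs].std() * lhs.sum() + y[rhs].std() * rhs.sum()

def find_better_split(x, y, min_leaf=1):
    """Try every data point of one feature as a split and keep the best one."""
    best_score, best_split = float("inf"), None
    for candidate in x:
        lhs = x <= candidate  # boolean mask of rows going left
        rhs = x > candidate   # boolean mask of rows going right
        if lhs.sum() < min_leaf or rhs.sum() < min_leaf:
            continue          # one side would be too small
        score = find_score(y, lhs, rhs)
        if score < best_score:
            best_score, best_split = score, candidate
    return best_score, best_split

x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([5.0, 5.0, 5.0, 50.0, 50.0, 50.0])
score, split = find_better_split(x, y)

print(score, split)
```

On this toy data the best split is at 3.0, because it separates the cheap houses from the expensive ones perfectly and both sides end up with zero spread.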
Finally, let’s look at how we use all this to make predictions:
Once again, we’re exploiting the recursive nature of life. Starting at the tree root, predict_row checks whether we need to go to the left or the right node based on the split value we found. The recursion ends once we hit a leaf node. At that point, the answer/prediction is stored in the
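The recursion can be tried out on a tiny hand-built tree. The `TreeNode` dataclass and its field names (`val`, `var_idx`, `split`, `lhs`, `rhs`) are hypothetical stand-ins used only for this sketch.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

@dataclass
class TreeNode:
    """Hypothetical, already-fitted node used only to illustrate prediction."""
    val: float                         # prediction held by this node
    var_idx: Optional[int] = None      # feature index used for the split
    split: Optional[float] = None      # split value
    lhs: Optional["TreeNode"] = None   # child for values <= split
    rhs: Optional["TreeNode"] = None   # child for values > split

    @property
    def is_leaf(self):
        return self.lhs is None and self.rhs is None

def predict_row(node, xi):
    """Walk down the tree until a leaf is reached, then return its value."""
    if node.is_leaf:
        return node.val
    child = node.lhs if xi[node.var_idx] <= node.split else node.rhs
    return predict_row(child, xi)

tree = TreeNode(val=27.5, var_idx=0, split=3.0,
                lhs=TreeNode(val=5.0), rhs=TreeNode(val=50.0))

print(predict_row(tree, np.array([2.0])))   # 5.0
print(predict_row(tree, np.array([10.0])))  # 50.0
```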
Here is the complete source of our
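What follows is a compact sketch of how all the pieces described in this post could fit together, working on plain NumPy arrays. It is a reconstruction under assumptions: the names `Node`, `idxs`, `find_varsplit`, `find_better_split`, `find_score` and `predict_row` come from the text, while `min_leaf`, `val`, `var_idx`, `split`, `lhs`, `rhs` and the `DecisionTreeRegressor` wrapper are choices made for this example.

```python
import numpy as np

class Node:
    def __init__(self, x, y, idxs, min_leaf=5):
        self.x, self.y, self.idxs, self.min_leaf = x, y, idxs, min_leaf
        self.row_count = len(idxs)
        self.col_count = x.shape[1]
        self.val = np.mean(y[idxs])   # prediction held by this node
        self.score = float("inf")     # stays infinite for leaf nodes
        self.find_varsplit()

    def find_varsplit(self):
        # Look for the best split over every feature
        for c in range(self.col_count):
            self.find_better_split(c)
        if self.is_leaf:
            return
        x = self.x[self.idxs, self.var_idx]
        lhs = np.nonzero(x <= self.split)[0]
        rhs = np.nonzero(x > self.split)[0]
        # Each child only sees its own subset of the rows
        self.lhs = Node(self.x, self.y, self.idxs[lhs], self.min_leaf)
        self.rhs = Node(self.x, self.y, self.idxs[rhs], self.min_leaf)

    def find_better_split(self, var_idx):
        x = self.x[self.idxs, var_idx]
        for r in range(self.row_count):
            lhs = x <= x[r]
            rhs = x > x[r]
            if lhs.sum() < self.min_leaf or rhs.sum() < self.min_leaf:
                continue
            curr_score = self.find_score(lhs, rhs)
            if curr_score < self.score:
                self.var_idx = var_idx
                self.score = curr_score
                self.split = x[r]

    def find_score(self, lhs, rhs):
        # Standard deviation of each side, weighted by its row count
        y = self.y[self.idxs]
        return y[lhs].std() * lhs.sum() + y[rhs].std() * rhs.sum()

    @property
    def is_leaf(self):
        return self.score == float("inf")

    def predict(self, x):
        return np.array([self.predict_row(xi) for xi in x])

    def predict_row(self, xi):
        if self.is_leaf:
            return self.val
        node = self.lhs if xi[self.var_idx] <= self.split else self.rhs
        return node.predict_row(xi)


class DecisionTreeRegressor:
    def fit(self, X, y, min_leaf=5):
        X, y = np.asarray(X, dtype=float), np.asarray(y, dtype=float)
        self.dtree = Node(X, y, np.arange(len(y)), min_leaf)
        return self

    def predict(self, X):
        return self.dtree.predict(np.asarray(X, dtype=float))


# Toy data: one feature, two clearly separated price groups
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y = np.array([5.0, 5.0, 5.0, 50.0, 50.0, 50.0])
model = DecisionTreeRegressor().fit(X, y, min_leaf=3)
print(model.predict([[2.0], [11.0]]))
```

The `min_leaf` argument stops the tree from growing until every leaf holds a single row, which is a simple guard against overfitting.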
Let’s check how your Decision Tree regressor does on the training data:
Here is the R² score:
Our scikit-learn model gave us a score of
The RMSE score is:
The scikit-learn regressor gave us
Looks like your model is doing pretty good, eh? Let’s make a prediction on the test data and send it to Kaggle.
Sending your predictions to Kaggle
Let’s make our predictions and format the data as requested on Kaggle:
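A minimal sketch of the formatting step, assuming pandas is available: the competition expects a CSV with an Id column and a SalePrice column. The helper name `make_submission` and the three rows of numbers are placeholders made up for this example, not real predictions.

```python
import numpy as np
import pandas as pd

def make_submission(ids, predictions):
    """Put test-set ids and predicted prices into the two expected columns."""
    return pd.DataFrame({"Id": ids, "SalePrice": predictions})

# Placeholder values standing in for the real test ids and model output
ids = np.array([1461, 1462, 1463])
predictions = np.array([120000.0, 155000.5, 180250.0])

submission = make_submission(ids, predictions)
csv_text = submission.to_csv(index=False)  # pass a file path instead to write to disk
print(csv_text)
```

Passing `index=False` matters here, since otherwise pandas writes its own row index as an extra first column.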
Feel free to submit your csv file to Kaggle. Also, how can you improve?
Give yourself a treat, you just implemented a Decision Tree regressor!
Would it be possible to implement a Random Forest regressor on top of your model? How would that affect your Kaggle scores?
In the next part, you’re gonna do some unsupervised learning with k-means!