A Super Simple Explanation to Regression Trees and Random Forest Regressors

Sreevidya Raman · Published in Analytics Vidhya · 5 min read · Apr 18, 2021

Objective

What I try to achieve through this article (and, hopefully, an entire series to follow) is a comprehensive walk-through of basic ML concepts. To the extent possible, these concepts will be illustrated through toy examples worked out in Python and/or by hand. My main source of information (and inspiration) is the very funny, very knowledgeable Josh Starmer.

Audience

Anyone new to Data Science/feeling rusty on their ML basics.

What is a Regression Tree (aka. Decision Tree Regressor)?

Regression Trees are an intuitive, simple algorithm used for problems that have a continuous Y variable.

What kinds of data can they use as features?

  • Numeric
  • Categorical
    - Binary
    - Multiclass

Toy example

In this toy example, we’re trying to predict life expectancy based on annual income. To keep things simple, we’re going to have just this one X variable that happens to be continuous. If you want to know how to work with categorical variables, you can check out A Super Simple Explanation to Decision Tree Classifier which has detailed, hand-worked calculations.
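To make the walk-through concrete, here's a minimal sketch of the toy data in Python. The life expectancies are the ones used in the calculations below; the income values are my assumption, chosen so that the midpoints between them match the split thresholds we'll derive shortly.

```python
import numpy as np

# Assumed toy data: the life expectancies are taken from the worked
# example below; the income values are hypothetical, picked so their
# midpoints land on the thresholds 6.75, 7.5 and 8.25.
income = np.array([6.5, 7.0, 8.0, 8.5])        # X: annual income, sorted
life_exp = np.array([75.0, 85.0, 80.0, 67.0])  # Y: life expectancy
```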

Clearly, a linear regression will not work here¹. It looks like an increase in annual income leads to an increase in life expectancy only up to a point, after which it starts to fall. An inverse U-shape would fit this data better than a line.

So how does a Regression Tree figure out where to make the splits?

Step 1. Sort the X variable (income) in ascending order.
Already done.

Step 2. Find the means between subsequent Xs.
These means are the potential thresholds on which to split. So in this example, with four observations, there are only three candidate thresholds to evaluate before making a final prediction.
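In code, finding these thresholds is a one-liner on the (already sorted) toy data from the sketch above:

```python
# Candidate thresholds: midpoints between consecutive sorted incomes
thresholds = (income[:-1] + income[1:]) / 2
print(thresholds)  # [6.75 7.5  8.25]
```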

If you’ve read A Super Simple Explanation to Decision Tree Classifier, then Steps 1 & 2 will be familiar to you.

Step 3. Calculate the mean Ys (life expectancy) and Sum of Squared Errors (SSEs) corresponding to every split in X.

Split on X=6.75: the left group is {75} (mean 75, SSE = 0) and the right group is {85, 80, 67} (mean ≈ 77.33), giving a total SSE ≈ 172.67.

Split on X=7.5: the left group is {75, 85} (mean 80, SSE = 50) and the right group is {80, 67} (mean 73.5, SSE = 84.5), giving a total SSE = 134.5.

Split on X=8.25: the left group is {75, 85, 80} (mean 80, SSE = 50) and the right group is {67} (mean 67, SSE = 0), giving a total SSE = 50.

Clearly, the split at X = 8.25 results in the lowest SSE (=50), so that will be the first threshold to split on.
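Here's how that search might look in code, continuing the sketch above. For each candidate threshold we split Y into two groups, predict each group's mean, and total up the squared errors:

```python
def split_sse(x, y, t):
    """Total SSE when each side of threshold t predicts its group mean."""
    sse = 0.0
    for group in (y[x <= t], y[x > t]):
        if group.size:
            sse += ((group - group.mean()) ** 2).sum()
    return sse

for t in thresholds:
    print(f"Split on X={t}: SSE = {split_sse(income, life_exp, t):.2f}")
# Split on X=6.75: SSE = 172.67
# Split on X=7.5: SSE = 134.50
# Split on X=8.25: SSE = 50.00
```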

We see that the node and the leaves in the diagram also mention the MSE, so let’s just quickly see what that’s all about.

When we’re at the node, we haven’t made any splits yet. At this point, the naïve prediction is the average of Life Expectancies = (75 + 85 + 80 + 67)/4 = 76.75.

Corresponding to this naïve prediction, the MSE for all observations is: ((75 − 76.75)² + (85 − 76.75)² + (80 − 76.75)² + (67 − 76.75)²)/4 = 176.75/4 = 44.1875.

After splitting at X = 8.25, the MSEs in the leaves are 50/3 ≈ 16.67 for the left leaf ({75, 85, 80}, predicted mean 80) and 0 for the right leaf ({67}).
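These numbers can be checked against scikit-learn. A depth-1 DecisionTreeRegressor fitted on the assumed toy data makes exactly one split and reports the same MSEs (its default "squared_error" impurity):

```python
from sklearn.tree import DecisionTreeRegressor, export_text

stump = DecisionTreeRegressor(max_depth=1)  # allow exactly one split
stump.fit(income.reshape(-1, 1), life_exp)

print(export_text(stump, feature_names=["income"]))
# |--- income <= 8.25
# |   |--- value: [80.00]
# |--- income >  8.25
# |   |--- value: [67.00]

print(stump.tree_.impurity)  # [44.1875 16.6667 0.] -> root MSE, then leaf MSEs
```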

Step 4. Repeat Steps 2 & 3 for each of the leaves separately until we have built a fully grown Decision Tree.

Because we haven’t restricted the depth of our tree, it makes perfect predictions on the training data, and so the MSE at every leaf is 0.
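In the same sketch, growing the tree without a depth limit shows this memorization at work: every training point ends up in its own leaf, so the predictions are perfect and each leaf's MSE is 0.

```python
full_tree = DecisionTreeRegressor()  # no max_depth: fully grown tree
full_tree.fit(income.reshape(-1, 1), life_exp)

print(full_tree.predict(income.reshape(-1, 1)))  # [75. 85. 80. 67.] -> perfect fit
print(full_tree.tree_.impurity[full_tree.tree_.children_left == -1])  # leaf MSEs: all 0
```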

Random Forest Regressors

Now, here’s the thing. At first glance, this looks like a brilliant algorithm to fit to any data with a continuous dependent variable, but as it turns out, Decision Trees are very prone to overfitting (they fit well on the training data, but not so well on the test data). To mitigate this problem, we can build Random Forests. For a detailed discussion of how Random Forest works, head over to A Super Simple Explanation to Random Forest Classifier. There are exactly two differences between Random Forest Classifiers and Random Forest Regressors:

  1. The way that the splits are decided. While in Random Forest Classifiers, splits are based on entropy, in Random Forest Regressors, they’re based on MSE.
  2. The aggregation methodology. In classification problems, Random Forests employ what is called a ‘majority vote’ whereby the prediction that is most common for an observation across trees is the observation’s final prediction. In regression problems on the other hand, the mean of predictions across trees will be the final prediction for the observation.
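Here's a minimal sketch of that averaging on the assumed toy data (the hyperparameters are arbitrary). The forest's prediction for a point is just the mean of its individual trees' predictions:

```python
from sklearn.ensemble import RandomForestRegressor

forest = RandomForestRegressor(n_estimators=100, random_state=0)
forest.fit(income.reshape(-1, 1), life_exp)

# Each tree votes with a number; the forest reports the mean.
per_tree = [tree.predict([[7.0]])[0] for tree in forest.estimators_]
print(sum(per_tree) / len(per_tree))  # mean of the per-tree predictions...
print(forest.predict([[7.0]])[0])     # ...which is exactly the forest's prediction
```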

The tediousness of these calculations really makes one appreciate the speed with which these algorithms run in Python, doesn’t it? You can find code for building and visualizing Regression Trees at: https://github.com/sreevidyaraman/Regression-Tree.

[1] For those of you uncomfortable with linear regression, you’re not alone. I have a master’s degree in Economics and linear regression still trips me up. I’ll write an article on it soon that should clear up a lot of the confusion around the topic. For the time being, it’s enough to know that a linear regression would not work here because the relationship between income and life expectancy is not linear.

References

[1] Starmer, Josh. “Regression Trees, Clearly Explained!!!” StatQuest with Josh Starmer, YouTube, 20 Aug. 2019, https://www.youtube.com/watch?v=g9c66TUylZ4.
