Supervised Learning in R: Regression Trees

Fatih Emre Ozturk, MSc
6 min read · Oct 27, 2023

--

Regression trees are a regression technique based on decision trees: they use a tree structure to predict a continuous target variable. The tree splits the data through a series of decision nodes and produces a prediction for the target variable at its terminal nodes (leaves).

In the previous posts of our Supervised Learning in R series, we talked about linear regression. Compared with linear regression, a regression tree is an easier algorithm to understand: the tree structure can be drawn and read graphically, which makes both the model and its results easier to interpret. Regression trees also automatically reveal which variables matter most for predicting the target variable, and they rely on fewer assumptions than linear regression, which are further reasons to choose them.

In this post, we will examine step by step how a regression tree is built and how it is implemented in R.

Regression Trees | Step by Step

First, assume that we have a data set like the following:

All of the yellow points in the figure are observations of the predictors, and the graph represents the predictor space. In the first step of building a regression tree, we divide the predictor space into k distinct and non-overlapping regions, like the following (for the sake of the example, assume that we divide the predictor space into k = 3 regions):

Theoretically, the regions can have any shape. However, for simplicity and ease of interpretation of the estimated model, it is preferable to divide the space into high-dimensional rectangles, or boxes. When determining the regions, the aim is to find the boxes that minimize the RSS (residual sum of squares), and this search is repeated at every step of the recursive binary splitting. To go into more detail:

  • For any j and s, we define the pair of half-planes as R1(j, s) = {X | Xj < s} and R2(j, s) = {X | Xj ≥ s};
  • and we seek the values of j and s that minimize the total RSS over the two regions, i.e. the sum of (yi − ŷR1)² over observations in R1(j, s) plus the sum of (yi − ŷR2)² over observations in R2(j, s), where ŷR1 and ŷR2 are the mean responses of the training observations in each region. A minimal sketch of this search is given after this list.
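To make the search concrete, here is a minimal sketch in R of an exhaustive search for the best first split. The function best_split and the toy data frame are made up for this post (they are not part of the tree package, which uses a far more efficient implementation); the sketch simply tries every predictor j and every cut point s and keeps the pair with the lowest two-region RSS.

# Hypothetical helper for illustration only: exhaustive search for the
# best first split (predictor j, cut point s) by minimizing the RSS.
best_split <- function(X, y) {
  best <- list(rss = Inf, var = NA, s = NA)
  for (j in names(X)) {
    for (s in sort(unique(X[[j]]))) {
      left  <- y[X[[j]] <  s]   # observations in R1(j, s)
      right <- y[X[[j]] >= s]   # observations in R2(j, s)
      if (length(left) == 0 || length(right) == 0) next
      rss <- sum((left - mean(left))^2) + sum((right - mean(right))^2)
      if (rss < best$rss) best <- list(rss = rss, var = j, s = s)
    }
  }
  best
}

# Made-up example data: y mostly depends on whether x1 is below 5
set.seed(1)
toy <- data.frame(x1 = runif(20, 0, 10), x2 = runif(20, 0, 10))
y   <- ifelse(toy$x1 < 5, 10, 30) + rnorm(20)
best_split(toy, y)  # should report x1 with a cut point near 5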

Now back to our example above… In step 1, we obtained three different regions. Assume that the first region has an average training response of 10, the second region 20 and the third region 30. For any given observation, a value of 10 will be predicted if it is in the first region, 20 if it is in the second, and 30 if it is in the third.

This process is then repeated: at every step we choose the feature and the split point that give the best split, until a stopping criterion is reached.

Tree Pruning

In our example above, the predictor space was divided into three different regions. But what if it were split into 10 or even 15 regions? This is precisely why, when a very large tree is obtained, we prune it back to a sub-tree. The underlying logic is that a smaller tree with fewer splits can lead to lower variance and better interpretation, at the cost of some bias.

When pruning the tree, the important question is which subtree to choose. The goal is to select the subtree that leads to the lowest test error rate. So how do we estimate the test error? One option is cross-validation; however, estimating the error of every possible subtree with cross-validation would be quite cumbersome. Cost complexity pruning, also known as weakest link pruning, can be used instead. Rather than considering every subtree, we consider a sequence of trees indexed by alpha, a non-negative tuning parameter that controls the trade-off between the complexity of the subtree and its fit to the training data. For each value of alpha there is a subtree T that makes the penalized criterion RSS + α|T| as small as possible, where the RSS is summed over the terminal regions Rm of T (using the mean response ŷRm in each region) and |T| is the number of terminal nodes. Cross-validation is then used to select a value for alpha.
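As a rough illustration of how this criterion behaves, the sketch below fits a tree on simulated data (the data and object names are made up for this post) and prunes it with prune.tree() from the tree package, whose k argument plays the role of alpha. For each alpha it prints the number of leaves, the RSS, and the penalized value RSS + alpha·|T|: larger alpha values favour smaller subtrees.

library(tree)

# Simulated toy data, just so we have a tree to prune
set.seed(42)
sim <- data.frame(x = runif(200, 0, 10))
sim$y <- sin(sim$x) + rnorm(200, sd = 0.3)
full_tree <- tree(y ~ x, data = sim)

# k corresponds to alpha, the cost-complexity parameter
for (alpha in c(0.5, 2, 10, 30)) {
  sub      <- prune.tree(full_tree, k = alpha)
  n_leaves <- sum(sub$frame$var == "<leaf>")
  cat("alpha =", alpha,
      "| leaves =", n_leaves,
      "| RSS =", round(deviance(sub), 1),
      "| RSS + alpha*|T| =", round(deviance(sub) + alpha * n_leaves, 1), "\n")
}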

Regression Trees in R

For the application of regression trees in R, we will use the Hitters dataset from the ISLR package. Using this dataset, which contains various information about baseball players, we will try to predict players’ salaries. We will use two independent variables: Years and Hits. Years is the number of years a player has played in the major leagues, while Hits is the number of hits the player made in the previous season. Since there are NAs in the data set, we will simply remove them. To make the distribution of salary more bell-shaped, we will also log-transform it.

library(ISLR) # library for data set
df <- Hitters
df <- na.omit(df) # removing NAs
df$Salary <- log(df$Salary) # log transformation
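If you want to check the effect of the transformation (an optional sanity check, not required for the rest of the workflow), two quick histograms make the change in shape visible:

# Optional: compare the raw and log-transformed salary distributions
par(mfrow = c(1, 2))
hist(na.omit(Hitters$Salary), main = "Salary", xlab = "Salary")
hist(df$Salary, main = "log(Salary)", xlab = "log(Salary)")
par(mfrow = c(1, 1))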

We can build a regression tree by using the tree() function from the tree package as follows:

library(tree)
tree_hitters <- tree(Salary ~ Years + Hits, data = df)

To visualize the tree, we can use the ordinary plot() function. However, to label the nodes, we also have to call the text() function:

plot(tree_hitters)
text(tree_hitters)

With the summary() function, we can see the output of the model as follows:

summary(tree_hitters)
Regression tree:
tree(formula = Salary ~ Years + Hits, data = df)
Number of terminal nodes: 8
Residual mean deviance: 0.2708 = 69.06 / 255
Distribution of residuals:
Min. 1st Qu. Median Mean 3rd Qu. Max.
-2.2400 -0.2980 -0.0365 0.0000 0.3233 2.1520

This output provides a summary of the construction and performance of the regression tree. You can use this information to understand how well the model makes predictions and which variables were used to build it. The residual mean deviance and the distribution of the residuals are important for assessing the predictive ability of the model. Each component of the output can be interpreted as follows:

  • The Regression tree section describes the fitted model: it uses the variables “Years” and “Hits” to predict “Salary”, i.e. it shows which independent variables the tree relies on to predict the target variable.
  • “Number of terminal nodes” indicates the number of terminal nodes (leaves) in the tree structure. In this model there are 8 terminal nodes in total; these are the final nodes where the predictions are made.
  • The “Residual mean deviance” value is a statistical measure of how good the model’s predictions are; a lower value indicates a better fit. In this model the residual mean deviance is 0.2708 (a short verification of this number follows after this list).
  • The “Distribution of residuals” section summarizes how much the model’s predictions deviate from the actual values. The mean residual of 0.0000 tells us the model is unbiased on average, while the minimum and maximum show that individual errors can still be fairly large.
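If you want to see where the 0.2708 comes from, here is a short verification: for a regression tree the deviance is simply the residual sum of squares, and the residual mean deviance divides it by n minus the number of terminal nodes.

rss      <- deviance(tree_hitters)                  # 69.06 in the summary above
n_leaves <- sum(tree_hitters$frame$var == "<leaf>") # 8 terminal nodes
rss / (nrow(df) - n_leaves)                         # 69.06 / 255 = 0.2708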

Tree Pruning in R

As mentioned above, to choose the value of alpha (and hence the tree size) for pruning, we can use cross-validation as follows:

cv_hitters <- cv.tree(tree_hitters)
plot(cv_hitters$size, cv_hitters$dev, type = "b")
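If you prefer to read the best size off the object rather than the plot, the cross-validated deviances are stored in cv_hitters$dev. Note that cv.tree() uses random folds, so which.min() may not point at exactly 4 on every run; that is why looking for the point where the curve flattens is also common.

# Tree size with the lowest cross-validated deviance
best_size <- cv_hitters$size[which.min(cv_hitters$dev)]
best_size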

As you can see from the plot, the deviance stops decreasing meaningfully after a tree size of 4. For this reason, we prune the tree back to 4 terminal nodes. Now, let’s prune the tree:

prune_hitters <- prune.tree(tree_hitters, best = 4)
plot(prune_hitters)
text(prune_hitters, pretty = 0)

Now, it looks better.
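As a usage example, the pruned tree can make predictions for new players with predict(). The player below is made up for illustration; since we modeled log(Salary), exp() brings the prediction back to the original scale (salaries in the Hitters data are recorded in thousands of dollars).

# Hypothetical player: 6 years in the majors, 100 hits last season
new_player <- data.frame(Years = 6, Hits = 100)
pred_log   <- predict(prune_hitters, newdata = new_player)
exp(pred_log)  # predicted salary in thousands of dollars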

Just like always:

“In case I don’t see ya, good afternoon, good evening, and good night!”

Reference and Further Reading:

James, G., Witten, D., Hastie, T., & Tibshirani, R. (2013). An Introduction to Statistical Learning: with Applications in R. New York: Springer.
