Cross Validation in R

Fatih Emre Ozturk, MSc
5 min read · Aug 18, 2023

--

Given all the observations in a dataset, an important question is which of them should be used to train the machine learning model. One option is to split the dataset randomly, at a certain proportion. However, a random split may not always give good results. For instance, it may not fully reflect the characteristics and distribution of the dataset. In particular, if the dataset contains imbalanced or rare classes, the test set may not adequately represent those rare classes, which weakens the model’s ability to learn them accurately. A random split also carries the risk that information leaks between the training and test data, which can cause the model to underperform in real-world applications. Lastly, the model may overfit the training data: although it performs well on the test set, its ability to generalize to new data may be low. So, given all this, how should we split the dataset?

In cases where we have no prior information about how the train/test split should be made, the safest method is cross validation. Rather than worrying too much about which observations are good for training and which are good for testing, cross validation uses all observations for both in an iterative way, meaning that we use them in turns.

Step-by-Step Cross Validation

Instead of splitting the dataset into two parts, cross-validation usually involves splitting the data into smaller subsets and using these subsets in different ways to train and evaluate the model. This gives a better understanding of how the model performs against different data samples.

For the sake of the example, assume that we have a data set containing two variables: spending and income. The dataset would look like the following:

As you can see, there is a clear trend: people who earn more also tend to spend more, so the two variables are strongly related. We decide to fit a line to the data with linear regression (in the R code below, income is modeled as a function of spending). Yet we do not know which observations to use for testing and which for training. We will use cross validation for this.

The first step is to randomly assign the observations to different groups. In this example, we will divide the data into 4 groups, where each group consists of 2 points. The following illustration shows how this is done:

Now, in the first iteration, we will use Groups 1, 2, and 3 for training and Group 4 for testing. This will look like the following:

Then we can measure the evaluation metric for each observation in the test set. However, it does not stop here: we keep iterating this process until Groups 1, 2, and 3 have each had their turn as the test set.

Since we have 4 groups of observations, we will do 4 iterations. The number of iterations here is also called the number of folds; thus, this is called 4-fold cross validation.
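The grouping step described above can be sketched in a few lines of base R. This is an illustrative sketch only (the seed and the use of `split` are my choices, not part of the original example): shuffle the row indices, then deal them into 4 groups of 2.

```r
set.seed(1) # arbitrary seed for a reproducible shuffle
n <- 8      # 8 observations, as in the illustration

shuffled <- sample(seq_len(n))                # random permutation of the row indices
groups <- split(shuffled, rep(1:4, each = 2)) # deal them into 4 groups of 2 each
groups # a list of 4 index vectors, one per group
```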

For instance, assume that we run all 4 iterations on this dataset and use the sum of squared residuals (SSR) for evaluation. Because each iteration trains on a different combination of observations (different groups), each iteration results in a different fitted line. A different fitted line, combined with different test data, means that each iteration gives us different prediction errors.

In the last step of cross validation, we can average the prediction errors to get a general sense of how well this model will perform on future data. The groups used in the iteration with the lowest error are then selected as the test and training sets.

Leave-One-Out Cross Validation: a variant that uses all but one observation for training and the single remaining observation for testing, repeating this until every observation has had its turn as the test point.
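Leave-one-out can be written as a short loop. A minimal base-R sketch, using the same ten observations that are created in the R section below (the variable names `errors`, `fit`, and `pred` are mine):

```r
# the same ten observations used in the R section of this post
data <- data.frame(
  spending = c(200, 150, 300, 250, 180, 350, 280, 220, 280, 320),
  income   = c(4000, 3500, 4500, 4000, 3800, 5000, 4200, 3700, 4300, 4800)
)

errors <- numeric(nrow(data))
for (i in seq_len(nrow(data))) {
  fit <- lm(income ~ spending, data = data[-i, ]) # train on all rows but row i
  pred <- predict(fit, newdata = data[i, ])       # predict the held-out row
  errors[i] <- (data$income[i] - pred)^2          # its squared error
}
mean(errors) # leave-one-out estimate of the prediction error
```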

How do we decide whether to use k-fold or Leave-One-Out Cross Validation? When the dataset is large, 10-fold cross validation is usually best; when the dataset is small, Leave-One-Out Cross Validation is the better choice.
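For reference, caret can also run both schemes automatically through its trainControl/train interface; a brief sketch (the object names are mine, and the commented train call assumes the data frame created in the next section):

```r
library(caret)

ctrl_kfold <- trainControl(method = "cv", number = 10) # 10-fold cross validation
ctrl_loocv <- trainControl(method = "LOOCV")           # leave-one-out

# model <- train(income ~ spending, data = data, method = "lm", trControl = ctrl_kfold)
```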

Cross Validation in R

It is possible to do cross validation with the caret package in R, and we will use the caret package in this post as well. But for a clearer understanding of all the steps described above, we will do cross validation in a more manual way. First, let us create our dataset:

# Data creation
spending <- c(200, 150, 300, 250, 180, 350, 280, 220, 280, 320)
income <- c(4000, 3500, 4500, 4000, 3800, 5000, 4200, 3700, 4300, 4800)
data <- data.frame(spending, income)

Since we used 4-fold cross validation in the example above, let us create the folds using the createFolds function from the caret package:

library(caret) # provides createFolds

set.seed(123) # set a seed so the random fold assignment is reproducible
# returnTrain = FALSE: each element of the returned list holds the *test* indices of one fold
folds <- createFolds(data$spending, k = 4, list = TRUE, returnTrain = FALSE)

Now, let us use a for loop to iterate over each fold:

for (i in 1:4) {
  cat("Iteration", i, "\n")

  # Create training and test sets: fold i is the test set,
  # the remaining three folds together form the training set
  train_indices <- unlist(folds[-i])
  test_indices <- folds[[i]]

  train_data <- data[train_indices, ]
  test_data <- data[test_indices, ]

  cat("Training observations:", nrow(train_data), "\n")
  cat("Test observations:", nrow(test_data), "\n")

  # Train the linear regression model
  model <- lm(income ~ spending, data = train_data)

  # Predict using the test data
  predictions <- predict(model, newdata = test_data)

  # Calculate the sum of squared residuals (SSR)
  residuals <- test_data$income - predictions
  SSR <- sum(residuals^2)
  cat("SSR:", SSR, "\n\n")
}
Iteration 1
Training observations: 8
Test observations: 2
SSR: 191913.1

Iteration 2
Training observations: 8
Test observations: 2
SSR: 204990.6

Iteration 3
Training observations: 7
Test observations: 3
SSR: 105164.3

Iteration 4
Training observations: 7
Test observations: 3
SSR: 170357.6

As you can see, when the SSR values of all iterations are compared, the third iteration has the lowest value. For this reason, the split used in the third iteration stands out as the best one. Now let us look up the indices of the test and training observations for that split:

# indices of the test set
folds[3]
$Fold3
[1] 1 6 8

# indices of the training set (the union of the remaining folds)
folds[-3]
$Fold1
[1] 7 9

$Fold2
[1] 3 5

$Fold4
[1]  2  4 10
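Finally, the "average the errors" step from the walkthrough can be added on top of the same kind of loop. A minimal sketch that recreates the folds (using the createFolds defaults, which return test-index lists) and collects each fold's SSR in a vector; `ssr_values` is my own name, not from the code above:

```r
library(caret)

set.seed(123)
folds <- createFolds(data$spending, k = 4) # default: a list of test-index vectors

ssr_values <- numeric(length(folds))
for (i in seq_along(folds)) {
  train_data <- data[unlist(folds[-i]), ]  # three folds for training
  test_data  <- data[folds[[i]], ]         # one fold for testing
  model <- lm(income ~ spending, data = train_data)
  predictions <- predict(model, newdata = test_data)
  ssr_values[i] <- sum((test_data$income - predictions)^2)
}
mean(ssr_values) # average SSR across the 4 folds
```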

Just like always:

“In case I don’t see ya, good afternoon, good evening, and good night!”

Reference and Further Reading

James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. An Introduction to Statistical Learning. Vol. 112. New York: Springer, 2013.

Starmer, Josh. The StatQuest Illustrated Guide to Machine Learning!!!: Master the Concepts, One Full-Color Picture at a Time, from the Basics All the Way to Neural Networks. BAM!, 2022.
