Cross Validation in Machine Learning

Today, we’re diving into the concept of cross-validation, a super handy technique in machine learning. Let’s make it fun and easy to understand.

Nidhi Gahlawat
ILLUMINATION
4 min read · Aug 15, 2024


Setting the Scene

Imagine we have this cool dataset with information about how well people sleep, how much water they drink, and their daily junk food habits.

Dataset to predict risk of disease (20 rows). Image by Author

This dataset helps us predict whether someone will end up with health conditions. Now, we have just 20 rows in this dataset. To find the best model among k-Nearest Neighbors (k-NN), Decision Trees, and Logistic Regression, we’ll use cross-validation.


What is Cross-Validation?

Cross-validation is like a reality check for our models. It helps us figure out how well our model will perform on new, unseen data.


The idea is to split the dataset into several parts, train the model on some parts, and test it on the remaining part. In effect, it answers the question: “which model should we use?”


Types of Cross-Validation

1. K-Fold Cross-Validation

  • Picture this: we split our dataset into ‘k’ equally sized parts. Suppose ‘k’ is 4, so we divide our 20 rows into 4 folds of 5 rows each. Then we train and test the model so that each fold takes one turn as the test set while the remaining folds are used for training. In the first iteration, for example, the first three folds are used to train the k-NN model and the last fold is used for testing. We record the performance of each iteration.
Image by Author
  • We repeat this until each part has had its turn as the testing set. Then we calculate the overall performance. This process is repeated for all the machine learning models we want to check.
  • This gives us a solid performance evaluation for all models we want to compare: k-NN, Decision Trees, and Logistic Regression. We check which model performs best on unseen data. Maybe Decision Trees rock this time!
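The fold rotation described above can be sketched in plain Python. This is a minimal illustration of the index bookkeeping only (the 20-row dataset and k=4 come from the example; actual model training is omitted):

```python
def k_fold_indices(n_rows, k):
    """Split row indices 0..n_rows-1 into k equally sized folds and
    yield (train_indices, test_indices) pairs, one per fold."""
    fold_size = n_rows // k
    indices = list(range(n_rows))
    for i in range(k):
        test = indices[i * fold_size:(i + 1) * fold_size]
        train = indices[:i * fold_size] + indices[(i + 1) * fold_size:]
        yield train, test

# 20 rows, k=4: every fold has 5 test rows and 15 training rows
for train, test in k_fold_indices(20, 4):
    print(len(train), len(test))  # prints "15 5" on every iteration
```

In practice you would train each candidate model (k-NN, Decision Tree, Logistic Regression) on `train`, score it on `test`, and average the k scores per model before comparing.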

2. Stratified K-Fold Cross-Validation

  • Similar to K-Fold, but it’s a bit smarter. It ensures each fold has a proportional representation of each class.
Disadvantage of K-fold, Image by Author
  • So if our dataset has a split where 40% are positive cases and 60% are negative, plain k-fold could produce some folds that are almost all negative and others that are almost all positive, which would skew the evaluation.
  • Here’s how the training and testing data would look for k=4 in stratified k-fold cross-validation, where the same ratio of positive to negative cases is maintained in both the training and testing data. It’s like making sure each team in a game has a fair mix of players.
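A minimal sketch of the stratified assignment, assuming the 40/60 positive/negative mix from the example: group row indices by label first, then deal each class’s rows round-robin across the folds so every fold keeps roughly the same class ratio.

```python
from collections import defaultdict

def stratified_fold_assignment(labels, k):
    """Assign each row index to one of k folds so that every fold
    keeps roughly the same proportion of each class label."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(k)]
    for rows in by_class.values():
        # deal this class's rows across the folds round-robin
        for position, idx in enumerate(rows):
            folds[position % k].append(idx)
    return folds

# 20 rows: 8 positive (40%), 12 negative (60%), k = 4
labels = ["pos"] * 8 + ["neg"] * 12
folds = stratified_fold_assignment(labels, 4)
for fold in folds:
    print(sum(labels[i] == "pos" for i in fold), len(fold))  # prints "2 5": 2 positives per 5-row fold
```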

3. Leave-One-Out Cross-Validation (LOOCV)

Here’s a meticulous approach: each data point gets a moment to shine as the testing set, while the rest train the model. Perfect for tiny datasets, but watch out — it can be quite a task since it involves many repetitions. For our 20-row dataset, we train 20 times, each time with 19 rows for training and 1 row for testing.

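LOOCV is just k-fold with k equal to the number of rows; a quick sketch of the split generation for our 20-row example:

```python
def loocv_splits(n_rows):
    """Leave-one-out: each row is the test set exactly once,
    while all the remaining rows form the training set."""
    for i in range(n_rows):
        test = [i]
        train = [j for j in range(n_rows) if j != i]
        yield train, test

splits = list(loocv_splits(20))
print(len(splits))        # 20 iterations for the 20-row dataset
print(len(splits[0][0]))  # 19 training rows each time
```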

4. Repeated K-Fold Cross-Validation

This one’s all about consistency. We repeat the K-Fold process multiple times with different random splits. For example, if k=4 and we repeat this 5 times, we end up with 20 training and testing iterations. It’s like checking your work multiple times to ensure robustness.
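The repetition can be sketched by reshuffling the row order with a fresh random draw before each k-fold pass. This is an illustrative stdlib-only version (the seed parameter is just for reproducibility):

```python
import random

def repeated_k_fold(n_rows, k, n_repeats, seed=0):
    """Shuffle the row indices before each repeat, then perform a
    standard k-fold split on the shuffled order."""
    rng = random.Random(seed)
    fold_size = n_rows // k
    for _ in range(n_repeats):
        order = list(range(n_rows))
        rng.shuffle(order)  # a different random split each repeat
        for i in range(k):
            test = order[i * fold_size:(i + 1) * fold_size]
            train = order[:i * fold_size] + order[(i + 1) * fold_size:]
            yield train, test

# k=4 repeated 5 times over 20 rows -> 20 train/test iterations
print(sum(1 for _ in repeated_k_fold(20, 4, 5)))  # prints 20
```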

Wrapping It Up

So, there you have it! Cross-validation is essential for evaluating and selecting models in machine learning. By applying it to our dataset, we can confidently choose the best model among k-NN, Decision Trees, and Logistic Regression. This ensures our chosen model will perform well on new, unseen data, giving us reliable and robust predictions.
