Building reliable machine learning models with cross-validation

Gideon Mendels
Aug 6, 2018

Cross-validation is a technique for measuring and evaluating the performance of machine learning models. During training we split the training set into a number of partitions and train and test on different subsets of those partitions.

Cross-validation is frequently used to train, evaluate and finally select a machine learning model for a given dataset, because it helps assess how well a model’s results will generalize to an independent data set in practice. Most importantly, cross-validation generally gives a less biased (less optimistic) estimate of model performance than simpler methods such as a single train/test split.

This tutorial will focus on one variant of cross-validation named k-fold cross-validation.

In this tutorial we’ll cover the following:

  1. Overview of K-Fold Cross-Validation
  2. Example using Scikit-Learn and Comet.ml

K-Fold Cross-Validation

Cross-validation is a resampling technique used to evaluate machine learning models on a limited data set.

The most common form of cross-validation is the k-fold method. The training set is split into K partitions; the model is trained on K-1 of them, and the test error is computed from its predictions on the remaining partition. This is repeated so that each partition serves as the test set exactly once, and the K test errors are then averaged.

The same procedure is described by the following steps:

  1. Split the training set into K partitions (K=10 is a common choice)

For each partition:

  2. Set the partition as the test set
  3. Train a model on the rest of the partitions
  4. Measure performance on the test set
  5. Retain the performance metric

After all partitions have been used:

  6. Explore model performance over the different folds
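
To make the loop concrete, here is a minimal sketch of the procedure above in plain Python. The names k_fold_indices, cross_validate and train_and_score are made up for this illustration and are not part of any library; train_and_score stands in for whatever code trains a model on the training indices and returns a metric for the test indices.

```python
import random
from statistics import mean


def k_fold_indices(n_samples, k, seed=0):
    """Shuffle the indices 0..n_samples-1 and split them into k partitions."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Partition i takes every k-th index starting at i, so sizes differ by at most 1.
    return [indices[i::k] for i in range(k)]


def cross_validate(n_samples, k, train_and_score, seed=0):
    """Run the k-fold loop and return (per-fold scores, average score).

    train_and_score(train_idx, test_idx) is a hypothetical callback that
    trains a model on train_idx and returns its metric on test_idx.
    """
    folds = k_fold_indices(n_samples, k, seed)
    scores = []
    for i, test_idx in enumerate(folds):
        # Every partition except the i-th one forms the training set for this round.
        train_idx = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        # Retain the metric so the folds can be compared afterwards.
        scores.append(train_and_score(train_idx, test_idx))
    return scores, mean(scores)
```

Keeping the per-fold scores, rather than only their average, is what lets us explore how much performance varies from fold to fold.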

Cross-validation is commonly used because it is easy to interpret and because it generally gives a less biased, less optimistic estimate of model performance than other methods, such as a simple train/test split. One of its biggest downsides is the increased training time, since we are essentially training K models instead of one.

Cross-validation example using scikit-learn

Scikit-learn is a popular machine learning library that also provides many tools for data sampling, model evaluation and training. We’ll use the KFold class to generate our folds. Here’s a basic overview:
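
The original embedded snippet did not survive in this copy of the post, so the following is a small stand-in rather than the author's code. It uses toy data invented for the illustration and shows how KFold.split yields one pair of train/test index arrays per fold.

```python
import numpy as np
from sklearn.model_selection import KFold

# Toy data: 6 samples with 2 features each (made up for illustration).
X = np.arange(12).reshape(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])

# shuffle=True with a fixed random_state gives randomized but reproducible folds.
kf = KFold(n_splits=3, shuffle=True, random_state=42)

fold_sizes = []
for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
    # split() returns integer index arrays, which we use to slice the data.
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    fold_sizes.append((len(train_idx), len(test_idx)))
    print(f"fold {fold}: train={train_idx}, test={test_idx}")
```

With 6 samples and 3 splits, every fold trains on 4 samples and tests on 2, and each sample lands in the test set exactly once.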

Now let’s walk through an end-to-end example using scikit-learn and Comet.ml.

This example trains a text classifier on the newsgroups dataset (you can find it here). Given a piece of text (string), the model classifies it into one of the following classes: “atheism”, “christian”, “computer graphics”, “medicine”.
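
The original code for this example is not included in this copy of the post, so what follows is a sketch under several assumptions rather than the author's implementation. We assume the four classes correspond to the 20 newsgroups categories listed in CATEGORIES, and we pick a TF-IDF plus linear classifier pipeline purely for illustration. The log_metric argument is a hypothetical hook: any callable taking a name and a value works, and with Comet.ml one option would be to pass the log_metric method of an Experiment object.

```python
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import KFold
from sklearn.pipeline import make_pipeline

# Assumed mapping of the four classes to 20 newsgroups category names.
CATEGORIES = ["alt.atheism", "soc.religion.christian", "comp.graphics", "sci.med"]


def cross_validate_text_classifier(texts, labels, n_splits=10, log_metric=None):
    """Train one classifier per fold; return (per-fold accuracies, mean accuracy)."""
    texts = np.asarray(texts, dtype=object)
    labels = np.asarray(labels)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=0)
    accuracies = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(texts)):
        # Build a fresh pipeline per fold so the vectorizer never sees held-out text.
        model = make_pipeline(TfidfVectorizer(), SGDClassifier(random_state=0))
        model.fit(texts[train_idx], labels[train_idx])
        acc = model.score(texts[val_idx], labels[val_idx])
        accuracies.append(acc)
        if log_metric is not None:
            log_metric(f"accuracy_fold_{fold}", acc)
    mean_acc = float(np.mean(accuracies))
    if log_metric is not None:
        log_metric("accuracy_mean", mean_acc)
    return accuracies, mean_acc


def run_newsgroups_example(log_metric=print):
    """Fetch the training split (downloaded on first use) and cross-validate on it."""
    data = fetch_20newsgroups(subset="train", categories=CATEGORIES)
    return cross_validate_text_classifier(data.data, data.target, log_metric=log_metric)
```

Calling run_newsgroups_example() prints one accuracy per fold followed by the mean; swapping print for a tracking callback sends the same numbers to an experiment tracker instead.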

On every fold we report the accuracy to Comet.ml, and finally we report the average accuracy across all folds. After the experiment finishes, we can visit Comet.ml and examine our model:

[Figure: Comet.ml chart of accuracy per fold]

This chart was automatically generated by Comet.ml. The rightmost bar (in purple) represents the average accuracy across folds. As you can see, some folds perform significantly better than the average, which shows how important k-fold cross-validation is.

You might have noticed that we didn’t compute the test accuracy. The test set should not be used in any way until you’re completely finished with all experimentation. If we change hyperparameters or model types based on the test accuracy we’re essentially over-fitting our hyperparameters to the test distribution.
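
One way to follow this advice is to set the test set aside before any cross-validation happens. The sketch below uses synthetic data and a logistic regression model chosen only for illustration: it selects a hyperparameter by cross-validating on the development portion, and touches the held-out test set once at the very end.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split

# Synthetic binary classification data (made up for illustration).
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

# Set the test set aside first; it plays no part in model selection.
X_dev, X_test, y_dev, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)

# Compare hyperparameter candidates using cross-validation on the dev data only.
cv_scores = {}
for C in (0.01, 1.0, 100.0):
    model = LogisticRegression(C=C, max_iter=1000)
    cv_scores[C] = cross_val_score(model, X_dev, y_dev, cv=5).mean()
best_C = max(cv_scores, key=cv_scores.get)

# Only after all choices are made: refit on the dev data and evaluate once.
final_model = LogisticRegression(C=best_C, max_iter=1000).fit(X_dev, y_dev)
test_accuracy = final_model.score(X_test, y_test)
print(f"best C={best_C}, test accuracy={test_accuracy:.3f}")
```

Because best_C is picked from cross-validation scores alone, the final test accuracy stays an honest estimate of how the chosen model behaves on unseen data.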



Gideon Mendels is the CEO and co-founder of Comet.ml.

About Comet.ml: Comet.ml is doing for ML what GitHub did for code. Our lightweight SDK enables data science teams to automatically track their datasets, code changes and experimentation history. This way, data scientists can easily reproduce their models and collaborate on model iteration within their team!

Thanks to Cecelia Shao
