You’re fit and you know it: overfitting and cross-validation

Andy Elmsley
The Sound of AI
May 2, 2019

Welcome back, machine-learners. Last time we finished off our brief tour of Artificial Neural Networks (ANNs) with a look at their hyperparameters. As part of that investigation, I advised you to always train multiple models before you stick with a hyperparameter choice. This week we’ll dig a bit further into this idea by exploring overfitting, one of the most common pitfalls of machine learning, and how to use cross-validation to check that your model really is fit.

Clever Hans was able to perform basic maths, or was he?

Clever Hans — a machine learning fable

Before we dig into some theory, here’s a true story about a horse named ‘Clever Hans’. As his name suggests, Hans wasn’t like other horses. His owner Wilhelm soon discovered that he could be trained to read basic equations on a chalkboard and tap out the correct answer with his foot. This amazed Wilhelm, who promptly showed Hans off all across Germany. However, it was soon discovered that, unbeknown to Wilhelm, Hans only appeared clever and couldn’t really do maths. Hans’ cleverness hinged on having an up-close, unobstructed view of someone who knew the correct answer. Take that away, and his accuracy plummeted to zero. Hans was simply very good at reading people’s body language, which told him when to stop tapping his foot.

So, what does this tell us about machine learning? Just like Hans’ owner Wilhelm, as model trainers we can’t always be sure exactly what our model has learned from the data. A model may appear to perform superbly in training, but then make terrible predictions on unseen data in a production environment. This is currently a hot topic within the machine learning community (I even attended an academic conference on the subject). If you’d like to find out more, you can check out this paper.

How to be ‘too’ fit

Let’s take the following example. Here we have a dataset containing two classes represented by red crosses and blue circles.

We’d like to train a model that can classify the data into the two classes. After training, you might be chuffed to find that the model scores 100% accuracy on this dataset. But when you feed it some new data, the accuracy might collapse: a classic sign of overfitting.

The classification boundary our model is using might look something like this:

Here you might already be able to see the problem. The 100% training score is achieved because the model has fit a boundary tightly around every training point, but it doesn’t appear to have discovered the general trend in the data. A better generalised model might draw the line more like this:

Sure, the training accuracy of this model will be lower (we have two mistakes in blue and four mistakes in red), but the model has discovered a general pattern that should give it better accuracy on new data.
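The figures themselves aren’t reproduced here, but you can see the same symptom numerically. The sketch below is not from the original post; it uses scikit-learn (my choice of tooling, purely for illustration) to fit two classifiers on a noisy two-class dataset. The unconstrained model will typically score near 100% in training but noticeably worse on held-out data, while the simpler model generalises better.

```python
# Not from the original post: a small sketch of the overfitting symptom,
# assuming scikit-learn is available.
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# A noisy two-class dataset, similar in spirit to the red-cross/blue-circle example.
X, y = make_moons(n_samples=300, noise=0.3, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# A deep, unconstrained tree can carve a boundary around every training point...
overfit = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
# ...while a depth-limited tree is forced to learn the general trend instead.
general = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

print("Unconstrained tree: train %.2f, test %.2f"
      % (overfit.score(X_train, y_train), overfit.score(X_test, y_test)))
print("Depth-limited tree: train %.2f, test %.2f"
      % (general.score(X_train, y_train), general.score(X_test, y_test)))
```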

Cross-validation to the rescue

To summarise, we need to keep two things in mind when evaluating a trained model. The error on the training set only gives us part of the story; only by evaluating on a separate set of ‘unseen’ data can we say whether the model is performing well. This is exactly what cross-validation sets out to achieve.

In cross-validation, the dataset is split into chunks. A certain proportion — let’s say 80% — is used for training the model as usual. The remaining data (20%) is not used for backpropagation, but instead is used to keep track of a test error. When this error stops improving (or in most cases, gets worse), it’s a sure sign that we’re overfitting — so we can stop training at that point.

An example of cross-validation in action. The model is overfit after 6 epochs.

Let’s see what that looks like in code:
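The original snippet isn’t reproduced here, so here’s a minimal stand-in sketch. It assumes Keras as the library (not necessarily what the original post used) and toy data; the key parts are the validation_split argument, which holds 20% of the data out of backpropagation, and the EarlyStopping callback, which halts training when the held-out error stops improving.

```python
# A stand-in sketch, assuming Keras; the original snippet isn't shown here.
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping

# Toy data: replace with your own features and binary labels.
X = np.random.rand(1000, 10)
y = (X.sum(axis=1) > 5).astype(int)

model = Sequential([
    Dense(16, activation="relu", input_shape=(10,)),
    Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# validation_split=0.2 keeps 20% of the data out of backpropagation;
# EarlyStopping stops training once the validation loss stops improving.
early_stop = EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)
model.fit(X, y, epochs=100, validation_split=0.2, callbacks=[early_stop], verbose=0)
```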

Fitter, happier, more productive

That’s it for this week. We’ve covered the concept of overfitting and how to avoid it with cross-validation.

The cross-validation implementation I’ve introduced in this post is a very basic one — it’s far more common to train multiple models (usually up to ten) with a different chunk of test data each time. This is known as k-fold cross-validation. You can try extending the above example into a k-fold cross validator if you’re up for it.
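If you want a head start, here’s a rough sketch of the shape such a k-fold loop could take, using scikit-learn’s KFold (my choice, not necessarily what the GitHub example uses). Each fold holds out a different chunk of the data as the test set and trains on the rest; the final score is the mean across folds.

```python
# A minimal k-fold sketch (not from the original repo), assuming scikit-learn.
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

X, y = make_moons(n_samples=300, noise=0.3, random_state=0)

kf = KFold(n_splits=5, shuffle=True, random_state=0)
scores = []
for fold, (train_idx, test_idx) in enumerate(kf.split(X), start=1):
    # Each fold uses a different chunk as the test set and the rest for training.
    model = LogisticRegression().fit(X[train_idx], y[train_idx])
    score = model.score(X[test_idx], y[test_idx])
    scores.append(score)
    print("Fold %d accuracy: %.2f" % (fold, score))

print("Mean accuracy across folds: %.2f" % (sum(scores) / len(scores)))
```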

As always, you can find the source code for all the above examples on our GitHub.

To begin your AI-coding training at day one, go here. And give us a follow to receive updates on our latest posts.
