Machine Learning Crash Course: Part 4 — The Bias-Variance Dilemma

13 Jul 2017 | Daniel Geng and Shannon Shih

Here’s a riddle:

So what does this have to do with machine learning? Well, it turns out that machine learning algorithms are not that much different from our friend Doge: they often run the risk of over-extrapolating or over-interpolating from the data that they are trained on.

There is a very delicate balancing act when machine learning algorithms try to predict things. On the one hand, we want our algorithm to model the training data very closely, otherwise we’ll miss relevant features and interesting trends. However, on the other hand we don’t want our model to fit too closely, and risk over-interpreting every outlier and irregularity.

Fukushima

The engineers of the nuclear power plant used earthquake data from the past 400 years to train a regression model. Their prediction looked something like this:

Source: Brian Stacey, Fukushima: The Failure of Predictive Models

The diamonds represent actual data while the thin line shows the engineers’ regression. Notice how their model hugs the data points very closely. In fact, their model makes a kink at around a magnitude of 7.3 — decidedly not linear.

In machine learning jargon, we call this overfitting. As the name implies, overfitting is when we train a predictive model that “hugs” the training data too closely. In this case, the engineers knew the relationship should have been a straight line but they used a more complex model than they needed to.

If the engineers had used the correct linear model, their results would have looked something like this:

Notice there’s no kink this time, so the line isn’t as steeply sloped on the right.

The difference between these two models? The overfitted model predicted an earthquake of at least magnitude 9 about once every 13,000 years, while the correct linear model predicted one roughly every 300 years. Because of this, the Fukushima Nuclear Power Plant was built to withstand only a magnitude 8.6 earthquake. The 2011 earthquake that devastated the plant was magnitude 9 (about 2.5 times stronger than a magnitude 8.6 earthquake).
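To make that concrete, here is a minimal sketch using synthetic, Gutenberg-Richter-style data (not the engineers' actual measurements): a straight-line fit and an overly flexible polynomial fit to the same noisy magnitude-frequency points can imply wildly different return periods once we extrapolate out to magnitude 9. The slope, noise level, and polynomial degree below are illustrative assumptions, and the exact numbers will change with the random seed.

```python
# Sketch only: synthetic magnitude vs. log10(annual frequency) data with a linear
# underlying trend, fit by a straight line and by an overly flexible polynomial.
import numpy as np

rng = np.random.default_rng(0)
mags = np.linspace(5.0, 8.0, 25)                                 # observed magnitudes
log_freq = 4.0 - 1.0 * mags + rng.normal(0, 0.15, mags.size)     # noisy log10(quakes/year)

centered = mags - 6.5                                            # center x to keep the fit well-behaved
linear = np.polyfit(centered, log_freq, deg=1)                   # the "correct" simple model
wiggly = np.polyfit(centered, log_freq, deg=6)                   # an overfit, kinked model

for name, coeffs in [("linear fit", linear), ("overfit fit", wiggly)]:
    log_rate_at_9 = np.polyval(coeffs, 9.0 - 6.5)                # extrapolate to magnitude 9
    return_period = 1.0 / 10 ** log_rate_at_9                    # years between magnitude-9+ quakes
    print(f"{name:11s}: one magnitude-9+ quake every ~{return_period:,.0f} years")
```

The point is not the particular numbers but how far the two extrapolations can diverge: the flexible model's behavior beyond the observed range is dominated by noise it should have ignored.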

Underfitting

While the linear model (the blue line) follows the data (the purple X’s), it misses the underlying curving trend of the data. A quadratic model would have been better.
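As a quick sketch of this picture, using made-up data with a quadratic trend rather than the figure's actual points, fitting a straight line leaves a much larger training error than fitting a model of the right shape:

```python
# Sketch: a linear fit underfits data generated from a quadratic trend plus noise.
import numpy as np

rng = np.random.default_rng(1)
x = np.linspace(-3, 3, 40)
y = 0.5 * x**2 + rng.normal(0, 0.3, x.size)      # curved "true" trend + noise

for degree in (1, 2):                            # 1 = underfit line, 2 = the right shape
    coeffs = np.polyfit(x, y, deg=degree)
    mse = np.mean((np.polyval(coeffs, x) - y) ** 2)
    print(f"degree {degree}: training MSE = {mse:.3f}")
```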

The Bias-Variance Dilemma

In the field of machine learning this incredibly important problem is known as the bias-variance dilemma. It's entirely possible to have state-of-the-art algorithms, the fastest computers, and the newest GPUs, but if your model overfits or underfits the training data, its predictive powers are going to be terrible no matter how much money or technology you throw at it.

The name bias-variance dilemma comes from two terms in statistics: bias, which corresponds to underfitting, and variance, which corresponds to overfitting.

Our example of underfitting from above. The blue line is our model, and the purple X’s are the data that we are trying to predict.

The drawing above depicts an example of high bias. In other words, the model is underfitting. The data points obviously follow some sort of curve, but our predictor isn’t complex enough to capture that information. Our model is biased in that it assumes that the data will behave in a certain fashion (linear, quadratic, etc.) even though that assumption may not be true. A key point is that there’s nothing wrong with our training — this is the best possible fit that a linear model can achieve. There is, however, something wrong with the model itself in that it’s not complex enough to model our data.

Approximating data using a very complex model. Notice how the line tends to over-interpolate between points. Even though there is a general upward trend in the data, the model predicts huge oscillations.

In this drawing, we see an example of a model with very high variance. In other words, a model that is overfitting. Again, the data points suggest a graceful curve, but our model uses a very complex curve to get as close to every data point as possible. Consequently, a model with high variance has very low bias because it makes almost no assumptions about the data. In fact, it adapts too much to the data.

Again, there’s nothing wrong with our training. In fact, our predictor hits every single data point and is by most metrics perfect. It is the model itself that’s the problem: it tries to account for every single piece of data perfectly, and so it over-generalizes from insignificant quirks of the training set, treating noise as if it were a real pattern. A model that over-generalizes in this way has high variance because its predictions change drastically based on insignificant details of the data.
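Here is a small, hedged sketch of the same phenomenon with synthetic data: a degree-9 polynomial passes through all ten noisy training points (near-zero training error) but, in a typical run, strays much further from the underlying curve than a simpler fit does.

```python
# Sketch: an interpolating polynomial has ~zero training error but chases the noise.
import numpy as np

rng = np.random.default_rng(2)
x_train = np.linspace(0, 1, 10)
y_train = np.sin(2 * np.pi * x_train) + rng.normal(0, 0.1, x_train.size)

x_dense = np.linspace(0, 1, 200)
y_true = np.sin(2 * np.pi * x_dense)             # the underlying curve the data came from

for degree in (3, 9):                            # 9 = enough flexibility to hit every point
    coeffs = np.polyfit(x_train, y_train, deg=degree)
    train_mse = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
    true_mse = np.mean((np.polyval(coeffs, x_dense) - y_true) ** 2)
    print(f"degree {degree}: training MSE {train_mse:.4f}, MSE vs. true curve {true_mse:.4f}")
```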

An example of high “variance” in everyday life from xkcd. ‘Girls sucking at math’ in this case is an overgeneralization based on only a single data point.

Explaining the Dilemma

For the case of high bias, we have a very simple model. In our example above we used a linear model, possibly the simplest model there is. And for the case of high variance, the model we used was super complex (think squiggly).

The image above should make it a bit clearer why the bias-variance trade-off exists. Whenever we choose a model with low complexity (like a linear model) and thus low variance, we’re also choosing a model with high bias. If we try to increase the complexity of our model, say by switching to a quadratic model, we lower the bias but pay for it with higher variance. The best we can do is settle somewhere in the middle of the spectrum, where the purple pointer is.
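A minimal sketch of this spectrum, assuming a synthetic sine-shaped dataset and polynomial models of increasing degree: as complexity grows, training error keeps dropping, while in a typical run the held-out error falls and then rises again, with the best model somewhere in the middle.

```python
# Sketch: sweep model complexity and compare training error with held-out error.
import numpy as np

rng = np.random.default_rng(3)
x = rng.uniform(-1, 1, 80)
y = np.sin(3 * x) + rng.normal(0, 0.2, x.size)
x_train, y_train = x[:50], y[:50]                # fit on the first 50 points
x_test, y_test = x[50:], y[50:]                  # evaluate on the remaining 30

for degree in (1, 3, 5, 9, 15):
    coeffs = np.polyfit(x_train, y_train, deg=degree)
    train_mse = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
    test_mse = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)
    print(f"degree {degree:2d}: train MSE {train_mse:.3f}, test MSE {test_mse:.3f}")
```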

Noise

In math terms, we say that y = f(x) + ϵ, where y is the training data we end up measuring, f(x) is that perfect function, and ϵ is the random error that we can’t avoid.

Our “perfect” function added to noise is what we end up measuring in the real world.

Our goal in machine learning is to retrieve the perfect function (or something reasonably close to it) from the noisy data. That is, from y we want to construct f̂(x), an approximation of f(x).

The bias-variance tradeoff results from using our necessarily flawed data, y, to reconstruct this perfect function f(x). We want to ignore the noise term ϵ, but because y = f(x)+ ϵ, the data we receive inevitably has the noise mashed up with the real data, f(x). So if we ignore the noise too much, we end up ignoring the real data as well.
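A tiny illustration of y = f(x) + ϵ, using a made-up f(x) (in practice we never get to see it): the learner only ever observes the noisy y.

```python
# Sketch: a known "perfect" function corrupted by irreducible noise.
import numpy as np

rng = np.random.default_rng(4)

def f(x):
    # the "perfect" function -- unknown in any real problem
    return np.sin(2 * np.pi * x)

x = np.linspace(0, 1, 10)
eps = rng.normal(0, 0.2, x.size)     # the irreducible noise term
y = f(x) + eps                       # all the learner ever observes is y

print("true f(x): ", np.round(f(x), 2))
print("measured y:", np.round(y, 2))
```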

This noise manifests as disorganized data points. What we want to do is recover the perfect function (in green) from these data points.

Resolving the Dilemma

As we’ve seen above, overfitting and underfitting have very clear signatures in training and test data. Overfitting results in low training error and high test error, while underfitting results in high error on both the training and test sets.

However, measuring training and test errors is hard when we have relatively few data points and our algorithms require a fair amount of data (which unfortunately happens quite often). In this case we can use a technique called cross-validation.

This is where we take our entire dataset and split it into k groups. For each of the k groups, we train on the remaining k−1 groups and validate on the held-out group. This way we make the most of our data, essentially taking one dataset and training on it k times.

Cross validation with k=3. We divide the entire training set into three parts, and then train our model three times using each of the three parts as the validation set with the remaining two parts making up our actual training set.
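Here is a minimal cross-validation sketch using scikit-learn's KFold with k = 3 on a small synthetic regression problem; the data and model are illustrative, not from the article.

```python
# Sketch: 3-fold cross-validation of a linear model on synthetic data.
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

rng = np.random.default_rng(5)
X = rng.uniform(-1, 1, (30, 1))
y = 2.0 * X[:, 0] + rng.normal(0, 0.1, 30)

scores = []
for train_idx, val_idx in KFold(n_splits=3, shuffle=True, random_state=0).split(X):
    model = LinearRegression().fit(X[train_idx], y[train_idx])       # train on k-1 folds
    mse = np.mean((model.predict(X[val_idx]) - y[val_idx]) ** 2)     # validate on the held-out fold
    scores.append(mse)

print("per-fold validation MSE:", np.round(scores, 4))
print("mean validation MSE:", np.mean(scores))
```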

As for what to do after you detect a problem? Well, having high bias is symptomatic of a model that is not complex enough. In that case, your best bet would be to just pick a more complex model.

The problem of high variance is a bit more interesting. One naive approach to reducing high variance is to use more data. Theoretically, with a complex enough model, the variance tends toward zero as the number of samples tends toward infinity. However, this approach is naive because the variance typically decreases fairly slowly, and (the larger problem) more data is almost always hard to come by. Unless you’re working at Google Brain (and even then), getting more data takes time, energy, and money.

A better approach to reducing variance is to use regularization. That is, in addition to rewarding your model when it fits the training data well, penalize it for growing too complex. Essentially, regularization injects “bias” into the model by telling it not to become too complex. Common regularization techniques include lasso and ridge regression, dropout for neural networks, and soft-margin SVMs.
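As a hedged example of regularization, the sketch below compares an unregularized degree-12 polynomial regression with a ridge-regularized version on synthetic data; the degree and the penalty strength alpha are arbitrary choices for illustration. In a typical run, the ridge model gives up a little training accuracy but lands much closer to the true curve.

```python
# Sketch: ridge regularization trades a bit of bias for a large drop in variance.
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

rng = np.random.default_rng(6)
X = np.sort(rng.uniform(0, 1, (25, 1)), axis=0)
y = np.sin(2 * np.pi * X[:, 0]) + rng.normal(0, 0.2, 25)

X_dense = np.linspace(0, 1, 200).reshape(-1, 1)
y_true = np.sin(2 * np.pi * X_dense[:, 0])       # the underlying curve

for name, reg in [("unregularized", LinearRegression()), ("ridge", Ridge(alpha=0.01))]:
    model = make_pipeline(PolynomialFeatures(degree=12), reg).fit(X, y)
    train_mse = np.mean((model.predict(X) - y) ** 2)
    true_mse = np.mean((model.predict(X_dense) - y_true) ** 2)
    print(f"{name:13s}: train MSE {train_mse:.3f}, MSE vs. true curve {true_mse:.3f}")
```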

Finally, ensemble learning (which we’ll be talking about next time!) offers a way to reduce the variance of a model without sacrificing bias. The idea is to have an ensemble of multiple classifiers (typically decision trees) trained on random subsets of the training data. To actually classify a data point, the ensemble of classifiers all “vote” on a classification.
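For a taste of the idea (more next time), here is a small sketch comparing a single decision tree with a 100-tree random forest, scikit-learn's bagged ensemble of decision trees, on a toy dataset; the dataset and tree count are illustrative, and the exact accuracies will vary.

```python
# Sketch: an ensemble of trees trained on random subsets of the data, voting on the label.
from sklearn.datasets import make_moons
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_moons(n_samples=400, noise=0.3, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

single_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
forest = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_train, y_train)

print("single decision tree test accuracy:", single_tree.score(X_test, y_test))
print("100-tree ensemble test accuracy:   ", forest.score(X_test, y_test))
```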

Conclusion

The bias-variance dilemma is unavoidable: a model that is too simple underfits and misses real trends, while a model that is too complex overfits and chases noise, and the irreducible noise in our data means we can never recover the perfect function exactly. The best we can do is watch our training and test errors (or cross-validate when data is scarce), and then reach for a more complex model, regularization, or an ensemble as the situation demands.