Why you should be plotting learning curves in your next machine learning project
by Adrià Luz
The bias-variance dilemma is a widely known problem in the field of machine learning. It is so important that if you don’t get the trade-off right, it won’t matter how many hours or how much money you throw at your model.
In the illustration above, you can get a feel for what bias and variance are and how they can affect your model’s performance. The first chart shows a model (blue line) that is underfitting the training data (red crosses). This model is biased because it “assumes” that the relationship between the size of a house and its market price is linear when it is not. Plotting the data as a scatter plot is always helpful, as it reveals the true relationship between the variables: here, a quadratic function fits the data “just right” (second chart).

The third chart is a clear example of overfitting. The model’s high complexity allows it to fit the training data very closely, too closely. Although this model might perform really well on the training data, its performance on test data (i.e. data it has never seen before) will be much worse. In other words, this model suffers from high variance: what it learns depends heavily on the particular training sample, so it won’t be good at making predictions on new observations. Because the main point of building a machine learning model is to make accurate predictions on new data, you should focus on making sure it generalises well to unseen observations rather than on maximising its performance on your training set.
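To make the three charts concrete, here is a minimal NumPy sketch that fits polynomials of increasing degree to noisy data. The quadratic data, noise level and chosen degrees are assumptions for illustration, not the data behind the charts. Degree 1 plays the role of the underfitting straight line, degree 2 the “just right” fit, and degree 9 the over-complex model.

```python
import numpy as np

# Synthetic data (an assumption for illustration): a quadratic relationship plus noise.
rng = np.random.default_rng(0)

def make_data(n):
    x = rng.uniform(0.0, 1.0, n)
    y = (2 * x - 1) ** 2 + rng.normal(0.0, 0.05, n)
    return x, y

x_train, y_train = make_data(15)   # small training set
x_test, y_test = make_data(200)    # "unseen" data

def mse(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))

results = {}
for degree in (1, 2, 9):
    # Least-squares polynomial fit; Polynomial.fit rescales x internally for numerical stability.
    model = np.polynomial.Polynomial.fit(x_train, y_train, degree)
    results[degree] = (mse(y_train, model(x_train)), mse(y_test, model(x_test)))
    print(f"degree {degree}: train MSE={results[degree][0]:.4f}, test MSE={results[degree][1]:.4f}")
```

The degree-1 model does badly on both sets (high bias). The degree-9 model achieves the lowest training error of the three, but that number says little about how it will do on the test set.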
What can you do if your model performance is not so good?
There are several things you can do:
- Get more data
- Try a smaller set of features (reduce model complexity)
- Try adding/creating more features (increase model complexity)
- Try decreasing the regularisation parameter λ (increase model complexity)
- Try increasing the regularisation parameter λ (decrease model complexity)
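To see why λ works as a complexity dial, here is a minimal NumPy sketch of ridge regression (L2 regularisation) in closed form. Ridge is an assumption here: it is one common way to apply a regularisation parameter, and the data, polynomial features and λ values are made up for illustration. As λ grows, the weights shrink and the training error can only go up, which is the “decrease model complexity” effect listed above.

```python
import numpy as np

def ridge_fit(X, y, lam):
    """Closed-form ridge regression: minimises ||Xw - y||^2 + lam * ||w||^2."""
    n_features = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(n_features), X.T @ y)

rng = np.random.default_rng(1)
# Hypothetical setup: 30 noisy samples of a sine curve, expanded into features x, x^2, ..., x^8.
x = rng.uniform(-1.0, 1.0, 30)
y = np.sin(3 * x) + rng.normal(0.0, 0.1, 30)
X = np.vander(x, N=9, increasing=True)[:, 1:]  # drop the constant column
X = X - X.mean(axis=0)                         # centre features so no intercept term is needed
y_c = y - y.mean()

train_mse, weight_norm = [], []
for lam in (0.0, 0.01, 1.0, 100.0):
    w = ridge_fit(X, y_c, lam)
    train_mse.append(float(np.mean((X @ w - y_c) ** 2)))
    weight_norm.append(float(np.linalg.norm(w)))
    print(f"lambda={lam:>6}: train MSE={train_mse[-1]:.4f}, ||w||={weight_norm[-1]:.2f}")
```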
The question now is: “How do I know which of those things to try first?” The answer is: “Well, it depends.” It depends mainly on whether your model is suffering from high bias or from high variance.
You might now be wondering: “OK, so my model is not performing as expected… but how do I know whether it has a bias problem or a variance problem?” The answer: learning curves!
Learning curves plot your chosen evaluation metric (e.g. RMSE or accuracy) on the training and validation sets against the size of the training set. They are an extremely useful tool for diagnosing your model’s performance, because they can tell you whether your model is suffering from high bias or high variance.
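The definition above translates almost directly into a loop. This is a minimal NumPy sketch in which the helper name `learning_curve_errors`, the straight-line model and the synthetic data are all assumptions for illustration. For each training-set size m, it fits on the first m examples and then records the error on those m examples and on a fixed validation set. Because the model is a straight line and the data is curved, both errors should settle at a similar, high value.

```python
import numpy as np

def learning_curve_errors(fit, x_train, y_train, x_val, y_val, sizes):
    """For each training-set size m: fit on the first m examples, then record the
    training MSE (on those m examples) and the validation MSE (on the whole validation set)."""
    train_err, val_err = [], []
    for m in sizes:
        predict = fit(x_train[:m], y_train[:m])
        train_err.append(float(np.mean((predict(x_train[:m]) - y_train[:m]) ** 2)))
        val_err.append(float(np.mean((predict(x_val) - y_val) ** 2)))
    return np.array(train_err), np.array(val_err)

def fit_line(x, y):
    # Hypothetical model choice: a straight line, which underfits curved data (high bias).
    return np.polynomial.Polynomial.fit(x, y, deg=1)

rng = np.random.default_rng(2)
x = rng.uniform(0.0, 1.0, 300)
y = (2 * x - 1) ** 2 + rng.normal(0.0, 0.05, 300)
x_tr, y_tr, x_val, y_val = x[:200], y[:200], x[200:], y[200:]  # data is already in random order

sizes = [5, 10, 20, 50, 100, 200]
train_err, val_err = learning_curve_errors(fit_line, x_tr, y_tr, x_val, y_val, sizes)
for m, tr, va in zip(sizes, train_err, val_err):
    print(f"m={m:>3}: train MSE={tr:.4f}, validation MSE={va:.4f}")
```

Plotting `train_err` and `val_err` against `sizes` gives the learning curves themselves.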
If your learning curves look like this, it means your model is suffering from high bias. Both the training and validation (or cross-validation) errors are high, and they don’t seem to improve with more training examples. The fact that your model performs similarly badly on the training and validation sets suggests that it is underfitting the data and therefore has high bias.
On the other hand, if your learning curves look like this, your model might have a high-variance problem. In this chart, the validation error is much higher than the training error, which suggests that you are overfitting the data.
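The two patterns above can be turned into a rough rule of thumb. The sketch below is a hypothetical helper, not a standard function. `target_error` is the error you would be happy with (for example, a business requirement or human-level performance) and `gap_tolerance` is the train/validation gap you are willing to accept. You have to choose both thresholds yourself, and the helper only looks at the errors at the largest training-set size.

```python
def diagnose(train_error, val_error, target_error, gap_tolerance):
    """Rough, hypothetical heuristic for reading the right-hand end of a learning curve.

    train_error / val_error: errors at the largest training-set size (lower is better).
    target_error: the error you would be satisfied with (you choose this).
    gap_tolerance: the largest train/validation gap you consider acceptable (you choose this).
    """
    problems = []
    if train_error > target_error:
        # Even the training error is too high: the model is underfitting.
        problems.append("high bias")
    if val_error - train_error > gap_tolerance:
        # Validation error is much higher than training error: the model is overfitting.
        problems.append("high variance")
    return problems or ["no obvious problem"]

print(diagnose(train_error=0.09, val_error=0.10, target_error=0.01, gap_tolerance=0.02))
# ['high bias']
print(diagnose(train_error=0.001, val_error=0.08, target_error=0.01, gap_tolerance=0.02))
# ['high variance']
```

A model can also show both problems at once, in which case the helper returns both labels.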
What can you do if your model performance is not so good? (pt. II)
Cool, so you have now identified what’s going on with your model and are in a great position to decide what to do next.
If your model has high bias, you should:
- Try adding/creating more features
- Try decreasing the regularisation parameter λ
These two changes increase your model’s complexity and should therefore help solve your underfitting problem.
If your model has high variance, you should:
- Get more data
- Try a smaller set of features
- Try increasing the regularisation parameter λ
When your model is overfitting the training data, you can try either reducing its complexity or getting more data. As you can see above, the learning curves of a high-variance model suggest that, with enough data, the validation and training errors will end up closer to each other. An intuitive explanation is that the more data you give your model, the smaller the gap between the model’s complexity and the underlying complexity of your data becomes, leaving the model less room to fit noise.
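You can check this intuition by holding a deliberately over-complex model fixed and only varying the amount of training data. This NumPy sketch fits a degree-9 polynomial to synthetic quadratic data and reports the gap between validation and training error. The data, noise level, degree and sizes are assumptions for illustration. The gap should shrink as the training set grows.

```python
import numpy as np

rng = np.random.default_rng(3)

def make_data(n):
    # Synthetic quadratic data with Gaussian noise (an assumption for illustration).
    x = rng.uniform(0.0, 1.0, n)
    return x, (2 * x - 1) ** 2 + rng.normal(0.0, 0.05, n)

x_val, y_val = make_data(1000)  # fixed validation set
gaps = []
for m in (15, 50, 200, 1000):
    x_tr, y_tr = make_data(m)
    # Degree 9 is far more complex than the underlying quadratic: a high-variance model.
    model = np.polynomial.Polynomial.fit(x_tr, y_tr, 9)
    train_mse = float(np.mean((model(x_tr) - y_tr) ** 2))
    val_mse = float(np.mean((model(x_val) - y_val) ** 2))
    gaps.append(val_mse - train_mse)
    print(f"m={m:>4}: train MSE={train_mse:.5f}, validation MSE={val_mse:.5f}, gap={gaps[-1]:.5f}")
```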
Python implementation and real-life example
I wrote this function to plot the learning curves of a model. Feel free to use it in your own work!
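As a sketch of what such a function can look like (not the author’s original code), here is a helper built on scikit-learn’s `learning_curve` and matplotlib. The name `plot_learning_curves`, its defaults, and the synthetic dataset and Random Forest settings in the usage lines are assumptions. It plots scores, where higher is better (e.g. accuracy), rather than errors, so the curves appear flipped vertically compared with error-based charts.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs on headless machines
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

def plot_learning_curves(estimator, X, y, cv=5, scoring=None,
                         train_sizes=np.linspace(0.1, 1.0, 5)):
    """Plot mean training and cross-validation scores against training-set size."""
    sizes, train_scores, val_scores = learning_curve(
        estimator, X, y, cv=cv, scoring=scoring, train_sizes=train_sizes,
        shuffle=True, random_state=0)
    # Each row holds the scores of one training-set size across the CV folds.
    train_mean, val_mean = train_scores.mean(axis=1), val_scores.mean(axis=1)
    train_std, val_std = train_scores.std(axis=1), val_scores.std(axis=1)

    fig, ax = plt.subplots()
    ax.plot(sizes, train_mean, "o-", label="Training score")
    ax.plot(sizes, val_mean, "o-", label="Cross-validation score")
    # Shaded bands show +/- one standard deviation across the CV folds.
    ax.fill_between(sizes, train_mean - train_std, train_mean + train_std, alpha=0.2)
    ax.fill_between(sizes, val_mean - val_std, val_mean + val_std, alpha=0.2)
    ax.set_xlabel("Training set size")
    ax.set_ylabel(scoring or "Score")
    ax.set_title("Learning curves")
    ax.legend(loc="best")
    return fig, sizes, train_mean, val_mean

# Usage on a synthetic classification problem (hypothetical data, not the author's).
X, y = make_classification(n_samples=500, n_features=20, random_state=0)
fig, sizes, train_mean, val_mean = plot_learning_curves(
    RandomForestClassifier(n_estimators=50, random_state=0), X, y, scoring="accuracy")
```

Call `fig.savefig(...)` to write the chart to disk. For an error metric such as RMSE, scikit-learn’s scorer is negated (`scoring="neg_root_mean_squared_error"`), so flip the sign of the scores before plotting if you want error curves.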
I thought I would end this post by showing you a real-life example of a learning curves plot, created with the above code.
From the plot, it is very clear that my Random Forest model is suffering from high bias: the training and validation curves are very close together, and the accuracy is not great, at around the 70% mark. Knowing this helped me decide what my next step should be to improve my model’s performance. Because I had a high-bias problem, I knew that getting more training data wasn’t going to help by itself, and that increasing the complexity of my model by engineering new, more relevant features was likely to deliver the greatest impact.
Next time you have a poorly performing model in front of you, remember to plot the learning curves, analyse them, and work out whether you have a bias or a variance problem. Knowing this will help you decide on your next steps, and it could save you countless headaches and hours wasted on work that is not going to help your model.