How to unit test machine learning code.
Over the past year, I’ve spent most of my working time doing deep learning research and internships. I spent a lot of that year making very big mistakes that helped me learn not just about ML, but about how to engineer these systems correctly and soundly. One of the main principles I learned during my time at Google Brain was that unit tests can make or break your algorithm and can save you weeks of debugging and training time.
However, there doesn’t seem to be a solid tutorial online on how to actually write unit tests for neural network code. Even places like OpenAI only found bugs by staring at every line of their code and trying to think about how it could cause a bug. Clearly, most of us don’t have that kind of time or self-hatred, so hopefully this tutorial can help you get started testing your systems sanely!
Let’s start off with a simple example. Try to find the bug in this code.
Do you see it? The network isn’t actually stacking. When I wrote this code, I copied and pasted the line slim.conv2d(…) and only modified the kernel sizes, never the actual input, so every layer reads the original input instead of the output of the layer before it.
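To make the mistake concrete, here is a rough, framework-free sketch of the same copy-paste bug. It is not the TF-Slim listing from this post: each "layer" is reduced to a scalar weight plus a ReLU, and every name in it (`layer`, `buggy_net`, `fixed_net`, the `conv1`..`conv3` keys) is made up for illustration.

```python
# Framework-free stand-in for the stacking bug. Each "layer" is a scalar
# weight followed by ReLU; all names are hypothetical, not the post's code.

def relu(v):
    return max(0.0, v)

def layer(inputs, weight):
    return relu(weight * inputs)

def buggy_net(x, w):
    net = layer(x, w["conv1"])
    net = layer(x, w["conv2"])  # bug: reads x, so conv1's output is discarded
    net = layer(x, w["conv3"])  # bug: reads x, so conv2's output is discarded
    return net

def fixed_net(x, w):
    net = layer(x, w["conv1"])
    net = layer(net, w["conv2"])  # each layer consumes the previous output
    net = layer(net, w["conv3"])
    return net

weights = {"conv1": 0.5, "conv2": 2.0, "conv3": 3.0}
changed = dict(weights, conv1=10.0)  # change only the first layer's weight

# The buggy output ignores conv1 entirely; the fixed output depends on it.
buggy_ignores_conv1 = buggy_net(1.0, weights) == buggy_net(1.0, changed)
fixed_uses_conv1 = fixed_net(1.0, weights) != fixed_net(1.0, changed)
print(buggy_ignores_conv1, fixed_uses_conv1)
```

Both versions run without complaint and return a plausible number, which is exactly why nothing flags the problem.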
I’m embarrassed to say that this actually happened to me about a week ago… But it’s an important lesson! These bugs are really hard to catch for a few reasons.
- This code never crashes, raises an error, or even slows down.
- This network still trains and the loss will still go down.
- The values converge after a few hours, but to really poor results, leaving you scratching your head as to what you need to fix.
When your only feedback is the final validation error, the only place you have to search is your entire network architecture. Needless to say, you’ll need a better system.
So how do we actually catch this before we do a full multi-day training session? Well, the easiest thing to notice is that the values of the earlier layers never actually reach any other tensors outside the function. So assuming we had some type of loss and an optimizer, these tensors never get optimized, so they will always have their default values.
We can detect it by simply taking a training step and comparing the variables’ values before and after.
Boom. In less than 15 lines of code, we have now verified that at least all of the variables that we created get trained.
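As a hedged illustration of the idea, here is a minimal sketch in plain Python. It assumes a toy scalar network and swaps the real optimizer for a finite-difference SGD step, so `train_step`, `untrained_variables`, and the two toy networks are hypothetical helpers rather than TensorFlow API. In a real test you would run one step of your actual train op and compare the trainable variables instead.

```python
# Sketch of the "did every variable change after one step?" test.
# Assumption: a toy scalar network and numerical gradients stand in for
# a real model and optimizer. All helper names are hypothetical.

def relu(v):
    return max(0.0, v)

def buggy_net(x, w):
    net = relu(w["conv1"] * x)
    net = relu(w["conv2"] * x)   # bug: should read net, not x
    return relu(w["conv3"] * x)  # bug: should read net, not x

def fixed_net(x, w):
    net = relu(w["conv1"] * x)
    net = relu(w["conv2"] * net)
    return relu(w["conv3"] * net)

def loss(net_fn, w, x=1.0, target=2.0):
    return (net_fn(x, w) - target) ** 2

def train_step(net_fn, w, lr=0.1, eps=1e-4):
    """One SGD step using central finite-difference gradients."""
    updated = {}
    for name in w:
        up = dict(w)
        up[name] = w[name] + eps
        down = dict(w)
        down[name] = w[name] - eps
        grad = (loss(net_fn, up) - loss(net_fn, down)) / (2 * eps)
        updated[name] = w[name] - lr * grad
    return updated

def untrained_variables(net_fn, w):
    """Names of variables whose value is identical before and after a step."""
    before = dict(w)
    after = train_step(net_fn, w)
    return sorted(name for name in before if before[name] == after[name])

weights = {"conv1": 0.5, "conv2": 0.7, "conv3": 0.9}
print(untrained_variables(buggy_net, weights))  # ['conv1', 'conv2']
print(untrained_variables(fixed_net, weights))  # []
```

The check itself is the three lines in `untrained_variables`; everything else is scaffolding to make the sketch runnable without a framework.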
This test is super simple and super useful. Let’s say that we fixed the previous issue and now we want to start adding some batch normalization. See if you can spot the bug.
Did you see it? This one is super subtle. You see, in TensorFlow batch_norm actually has is_training defaulted to False, so adding this line of code won’t actually normalize your input during training! Thankfully, the last unit test we wrote will catch this issue immediately! (I know, because this happened to me 3 days ago.)
Let’s do another example. This actually comes from a Reddit post I saw one day. I won’t get into too much detail, but basically the person wanted to create a classifier that gave an output in the range of (0, 1). See if you can find the bug.
Notice the bug? This one is really hard to spot beforehand, and can lead to super confusing results. Basically what is happening here is that prediction only has a single output. Softmax over a single value is always 1, so when you apply softmax cross entropy to it, the loss is always 0.
An easy way to test for this is to, well… make sure the loss is never 0.
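Here is a small sketch of that check in plain Python. The `softmax_cross_entropy` function is a hand-written stand-in for the framework loss, and `check_loss_nonzero` is a hypothetical helper name, so treat this as an illustration of the idea rather than the code from the Reddit post.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_cross_entropy(logits, labels):
    """Hand-written stand-in for a framework's softmax cross entropy."""
    probs = softmax(logits)
    return -sum(y * math.log(p) for y, p in zip(labels, probs))

def check_loss_nonzero(logits, labels):
    """Hypothetical test helper: fail loudly if the loss is exactly zero."""
    loss = softmax_cross_entropy(logits, labels)
    if loss == 0:
        raise AssertionError("loss is exactly 0; check the output shape")
    return loss

# Two-class output: the check passes and returns a positive loss.
two_class_loss = check_loss_nonzero([0.3, -1.2], [1.0, 0.0])

# Single-output prediction: softmax is [1.0], so the loss is 0 and
# the check raises.
try:
    check_loss_nonzero([0.3], [1.0])
    single_output_caught = False
except AssertionError:
    single_output_caught = True

print(two_class_loss > 0, single_output_caught)
```

With a single output the loss is 0 no matter what the logit or the label is, so one training batch is enough to trip the check.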
Another good test to do is similar to our first test, but backwards. You can make sure that only the variables you want to train actually get trained. Take for example a GAN. A common bug is forgetting to set which variables to train during which optimization. Code like this happens all the time.
The biggest issue here is that the optimizer has a default setting to optimize ALL of the variables. In advanced architectures like GANs, this is a death sentence to all of your training time. However, you can easily detect these mistakes by writing a test like this:
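What follows is a hedged sketch rather than a real GAN. It assumes a toy setup with one scalar generator weight and one scalar discriminator weight, a made-up loss, and a finite-difference `train_step` whose `var_list` argument mimics the `var_list` parameter of `Optimizer.minimize` in TensorFlow 1.x. The helper names and the `gen/` and `disc/` prefixes are illustrative assumptions.

```python
# Sketch of the "only the generator's variables changed" test.
# Assumption: toy scalar losses and numerical gradients stand in for a
# real GAN and optimizer. All names here are hypothetical.

def gan_losses(p, z=1.0, real=2.0):
    """Toy scalar losses with a GAN-like dependency structure."""
    fake = p["gen/w"] * z              # generator output
    fake_score = p["disc/w"] * fake    # discriminator score on fake sample
    real_score = p["disc/w"] * real    # discriminator score on real sample
    return {
        "gen": (fake_score - 1.0) ** 2,
        "disc": (real_score - 1.0) ** 2 + fake_score ** 2,
    }

def train_step(p, loss_name, var_list=None, lr=0.1, eps=1e-4):
    """One finite-difference SGD step; var_list=None updates everything."""
    names = list(p) if var_list is None else var_list
    updated = dict(p)
    for name in names:
        up = dict(p)
        up[name] = p[name] + eps
        down = dict(p)
        down[name] = p[name] - eps
        grad = (gan_losses(up)[loss_name]
                - gan_losses(down)[loss_name]) / (2 * eps)
        updated[name] = p[name] - lr * grad
    return updated

def changed_variables(before, after):
    return sorted(n for n in before if before[n] != after[n])

params = {"gen/w": 0.5, "disc/w": 0.3}
gen_vars = [n for n in params if n.startswith("gen/")]

# Default behaviour: the generator loss also moves the discriminator.
print(changed_variables(params, train_step(params, "gen")))
# Restricted step: only the generator's variables move.
print(changed_variables(params, train_step(params, "gen", var_list=gen_vars)))
```

The test then asserts that the second list contains only names under `gen/`, and that every name under `disc/` is untouched.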
A very similar test can be written for the discriminator. And this same test can be used for a lot of reinforcement learning algorithms as well. Many actor-critic models have separate networks that need to be optimized by different losses.
Here are some patterns I would recommend following for your tests.
- Keep them deterministic. It would really suck to have a test fail in a weird way, only to never be able to recreate it. If you really want randomized input, make sure to seed the random number generator so you can rerun the test easily.
- Keep the tests short. Don’t have a unit test that trains to convergence and checks against a validation set. You are wasting your own time if you do this.
- Make sure you reset the graph between each test.
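These three patterns fit naturally into a standard `unittest` test case. The sketch below is an assumption-heavy illustration: the model is a toy `y = w * x + b`, `make_params` and `train_step` are hypothetical helpers, and the `setUp` method is where a TensorFlow 1.x test would typically call `tf.reset_default_graph()` before rebuilding the model.

```python
import random
import unittest

def make_params(rng):
    """Hypothetical initializer: draws parameters from a seeded generator."""
    return {"w": rng.uniform(-1, 1), "b": rng.uniform(-1, 1)}

def train_step(params, x, target, lr=0.1):
    """One analytic SGD step on squared error for y = w * x + b."""
    error = params["w"] * x + params["b"] - target
    return {
        "w": params["w"] - lr * 2 * error * x,
        "b": params["b"] - lr * 2 * error,
    }

class TrainStepTest(unittest.TestCase):
    def setUp(self):
        # Fresh, seeded state for every test, so no test depends on
        # leftovers from another one. In TensorFlow 1.x this is where
        # tf.reset_default_graph() would go.
        self.rng = random.Random(0)
        self.params = make_params(self.rng)

    def test_all_variables_change(self):
        # Short test: a single step, no training to convergence.
        after = train_step(self.params, x=1.5, target=2.0)
        for name in self.params:
            self.assertNotEqual(self.params[name], after[name])

    def test_seed_reproduces_initial_values(self):
        # Deterministic: the same seed gives the same starting point.
        self.assertEqual(self.params, make_params(random.Random(0)))

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TrainStepTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Because each test rebuilds its state in `setUp`, the tests can run in any order and a failure can be reproduced by rerunning that one test.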
In conclusion, these black box algorithms still have lots of ways to be tested! Spending an hour writing a test can save you days of rerunning training sessions, and can greatly improve your research. Wouldn’t it suck to have to throw away perfectly good ideas because our implementations were buggy?
This list clearly isn’t comprehensive, but it’s a solid start! If you have extra advice or specific tests that you found to be helpful, please message me on Twitter! I’d love to make a part 2 of this.
All opinions in this piece are a reflection of my experiences and are not sponsored or supported by Google.