3 Simple Tricks That Will Change the Way You Debug PyTorch
Every Deep Learning project is different. No matter how much experience you bring with you, there will always be new challenges and unexpected behavior you will struggle with. The skills and mindset you bring to the project determine how quickly you discover and adapt to the obstacles that stand in the way of success.
From a practical point of view, a Deep Learning project starts with the code. Organizing it is easy in the beginning, but as the project grows in complexity, more and more time is spent on debugging and sanity checking. Surprisingly, much of this can be automated. In this post, I will show you how to
- find out why your training loss does not decrease,
- implement automatic model verification and anomaly detection,
- save valuable debugging time with PyTorch Lightning.
For demonstration, we will use a simple MNIST classifier example that has a couple of bugs:
If you run this code, you will find that the loss does not decrease and, after the first epoch, the test loop crashes. What's wrong?
Trick 0: Organize Your PyTorch
Before we debug this code, we will organize it into the Lightning format. PyTorch Lightning automates all boilerplate/engineering code in a Trainer object and neatly organizes all the actual research code in the LightningModule so we can focus on what’s important:
Lightning takes care of many engineering patterns that are often a source of errors: the training, validation, and test loop logic, switching the model between train and eval mode, moving the data to the right device, checkpointing, logging, and much more.
Trick 1: Sanity Checking the Validation Loop
If we run the above, we immediately get an error message complaining that the sizes don’t match in line 65 in the validation step.
---> 65 loss = F.nll_loss(x, y)
66 acc = accuracy(torch.max(output, dim=1), y)
67 self.log('val_loss', loss, on_epoch=True,
...RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : 
You may have noticed that Lightning runs two validation steps before training begins. This is not a bug, it's a feature! It saves us a lot of time that would otherwise be wasted if the error only surfaced after a long training epoch. Because Lightning sanity-checks the validation loop up front, we can fix the error quickly, since it is now obvious that line 65 should read
loss = F.nll_loss(output, y)
as it does in the training step.
This was an easy fix because the stack trace told us what was wrong, and the mistake was obvious. The fixed code now runs without errors, but if we look at the loss value in the progress bar (or the plots in TensorBoard), we find that it is stuck at about 2.3. That number is a clue in itself: ln 10 ≈ 2.3 is exactly the negative log-likelihood of a model that assigns equal probability to all 10 digit classes, so the network does no better than a uniform guess. There could be many reasons for this: a wrong optimizer, a poorly chosen learning rate or learning rate schedule, a bug in the loss function, a problem with the data, etc.
Trick 2: Logging the Histogram of Training Data
It is important to always check the range of the input data. If the model weights and the data differ greatly in magnitude, learning can stall or progress very slowly, and in extreme cases training becomes numerically unstable. This happens, for instance, when data augmentations are applied in the wrong order or when a normalization step is forgotten. Is this the case in our example? We could find out by printing the min and max values. But wait! This is not a good solution, because it clutters the code, floods the terminal, and takes too much time to repeat later should we need to. Better: write a Callback class that does it for us!
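Before wrapping anything in a callback, it helps to pin down what we want to look at. The hypothetical helper below (the name `summarize_batch` is ours) reduces a batch to its range, mean, standard deviation, and a coarse histogram computed with `torch.histc`. The usage lines reproduce the kind of shift that a wrong `Normalize(128, 1)` causes, under the assumption that the pixels were already scaled to [0, 1].

```python
import torch


def summarize_batch(x: torch.Tensor, bins: int = 10) -> dict:
    """Range statistics and a coarse histogram of a tensor (hypothetical helper)."""
    x = x.detach().float()
    lo, hi = x.min().item(), x.max().item()
    return {
        "min": lo,
        "max": hi,
        "mean": x.mean().item(),
        "std": x.std().item(),
        # `bins` equal-width buckets spanning [min, max]; every element lands in one.
        "hist": torch.histc(x, bins=bins, min=lo, max=hi),
    }


images = torch.rand(64, 1, 28, 28)  # stand-in for a batch of pixels scaled to [0, 1]
good = summarize_batch(images)
bad = summarize_batch((images - 128) / 1)  # what Normalize(mean=128, std=1) does
print(f"good: [{good['min']:.2f}, {good['max']:.2f}]  bad: [{bad['min']:.2f}, {bad['max']:.2f}]")
```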
A callback in PyTorch Lightning can hold arbitrary code that is injected into the Trainer. This one computes a histogram of the input data before it goes into the training step. Wrapping this functionality in a callback class has the following advantages:
- It is separate from your research code; there is no need to modify your LightningModule!
- It is portable, so it can be reused in future projects, and enabling it requires changing only two lines of code: import the callback, then pass it to the Trainer.
- It can be extended by subclassing or combined with other callbacks.
Now with the new callback in action, we can open TensorBoard and switch to the “Histograms” tab to inspect the distribution of the training data:
The targets are in the range [0, 9], which is correct because MNIST has 10 digit classes, but the images have values between -130 and -127, which is clearly wrong! We quickly find that there is a problem with normalization in line 41:
transforms.Normalize(128, 1) # wrong normalization
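These two numbers are supposed to be the mean and standard deviation of the input data (in our case, the pixels in the images). Normalize computes (x - mean) / std for each channel, so a mean of 128 shifts pixel values in the range [0, 1] down to roughly [-128, -127], which matches the histogram. To fix this, we pass the true mean and standard deviation and also name the arguments to make the code self-explanatory: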
We can look these numbers up because they are already known for MNIST. For your own datasets, you have to compute them yourself.
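For a dataset whose statistics are not published, a single pass over a `DataLoader` is enough. The sketch below (the function name `channel_mean_std` is ours) accumulates per-channel sums and sums of squares. It assumes the loader yields `(images, labels)` pairs with images shaped `(N, C, H, W)`, and it returns the population standard deviation. Run it with the same transforms you apply before `Normalize` (for example, `transforms.ToTensor`, which scales pixels to [0, 1]). For MNIST, the commonly used values are a mean of 0.1307 and a standard deviation of 0.3081, i.e. `transforms.Normalize(mean=(0.1307,), std=(0.3081,))`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader: DataLoader):
    """Per-channel mean and (population) std over all images a loader yields."""
    total = total_sq = None
    count = 0  # number of values per channel seen so far
    for images, _ in loader:
        images = images.double()  # double precision keeps E[x^2] - E[x]^2 accurate
        dims = (0, 2, 3)  # reduce over batch, height and width; keep channels
        batch_sum = images.sum(dim=dims)
        batch_sq = images.pow(2).sum(dim=dims)
        total = batch_sum if total is None else total + batch_sum
        total_sq = batch_sq if total_sq is None else total_sq + batch_sq
        count += images.numel() // images.size(1)
    mean = total / count
    std = (total_sq / count - mean.pow(2)).sqrt()
    return mean.float(), std.float()


# Usage on stand-in data; with MNIST you would pass a loader over the training set.
images = torch.rand(500, 1, 28, 28)
loader = DataLoader(TensorDataset(images, torch.zeros(500)), batch_size=128)
mean, std = channel_mean_std(loader)
normalized = (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)
```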
After the normalization is applied, the pixels have a mean of 0 and a standard deviation of 1 across the dataset, which is the input scale that standard weight initialization schemes assume. We can confirm this by looking at the histogram in TensorBoard.
Trick 3: Detecting Anomalies in the Forward Pass
After fixing the normalization issue, we now also get the expected histogram logged in TensorBoard. Unfortunately, the loss is still not decreasing, so something else must be wrong. Knowing that the data is correct, a good place to start looking for mistakes is the forward pass of the network. Common sources of error are operations that manipulate the shape of tensors (e.g., permute, reshape, view, flatten) and operations that are applied along a single dimension (e.g., softmax). When these functions are applied to the wrong dimensions or in the wrong order, we usually get a shape mismatch error, but not always! These nasty bugs are hard to track down.
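To make the danger concrete, here is a tiny, self-contained illustration (our own, not taken from the article's model). Two reshapes each turn a batch of 2 samples with 3 features into a `(2, 3)` tensor. Only one of them keeps each row tied to its sample. The other silently interleaves values from different samples, and nothing raises an error.

```python
import torch

batch = torch.arange(6.0).reshape(2, 3)  # sample 0 = [0, 1, 2], sample 1 = [3, 4, 5]

correct = batch.view(2, -1)  # row i still holds sample i
buggy = batch.view(-1, 2).t()  # also shape (2, 3), but row 0 is [0, 2, 4]

print(correct.shape == buggy.shape)  # True: no shape mismatch to warn us
print(torch.equal(correct, buggy))  # False: the rows now mix both samples
```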
Let’s have a look at a technique that lets us detect such errors very quickly.
The idea is simple: if we change the n-th input sample, it should only affect the n-th output. If other outputs i ≠ n also change, the model mixes data, and that's not good! A reliable way to implement this test is to compute the gradient of the n-th output with respect to all inputs. The gradient must be zero for all i ≠ n (red in the animation above) and nonzero for i = n (green in the animation above). If these conditions are met, the model passes the test. Below is the implementation for n = 3:
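Here is a minimal sketch of such a check. It was written for this post rather than copied from the article's code, and the function name `verify_no_batch_mixing` is ours. Two details are design choices:

- The model is switched to eval mode, because BatchNorm in training mode combines statistics across the batch by design, and Dropout would add noise.
- The n-th output is multiplied by random weights before backpropagating. Simply summing a softmax row yields a constant (1) whose gradient is zero everywhere, which would make the check fail for the wrong reason.

```python
import torch
from torch import nn


def verify_no_batch_mixing(model: nn.Module, inputs: torch.Tensor, n: int = 3) -> bool:
    """Return True if the n-th output depends on the n-th input only."""
    was_training = model.training
    model.eval()  # BatchNorm (train mode) mixes samples by design; Dropout adds noise
    inputs = inputs.detach().clone().requires_grad_(True)
    try:
        with torch.enable_grad():
            output = model(inputs)
            # Random projection of output n so the signal cannot cancel out.
            weights = torch.randn_like(output[n])
            (grad,) = torch.autograd.grad((output[n] * weights).sum(), inputs)
    finally:
        model.train(was_training)  # restore the original mode
    # Total absolute gradient that reached each sample of the batch.
    per_sample = grad.flatten(start_dim=1).abs().sum(dim=1)
    others = torch.cat([per_sample[:n], per_sample[n + 1:]])
    return bool(per_sample[n] > 0) and bool((others == 0).all())


torch.manual_seed(0)
x = torch.rand(8, 1, 28, 28)
good = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
bad = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=0))
print(verify_no_batch_mixing(good, x))  # True
print(verify_no_batch_mixing(bad, x))  # False: log_softmax over the batch mixes samples
```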
And here is the same in a Lightning Callback:
Applying this test to the LitClassifier immediately reveals that it is mixing data. Now that we know what to look for, we quickly find a mistake in the forward method: the softmax in line 35 is applied to the wrong dimension:
output = F.log_softmax(x, dim=0)
It should instead be:
output = F.log_softmax(x, dim=1)
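The difference between the two lines is easy to see on random logits. In this short illustration (random data, not the article's model), `dim=0` turns every class column into a distribution over the batch, so perturbing sample 3 also changes the output of sample 0. With `dim=1`, each row is a distribution over the 10 classes, and the samples stay independent.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)  # a batch of 4 samples, 10 classes

wrong = F.log_softmax(logits, dim=0)  # normalizes each class across the batch
right = F.log_softmax(logits, dim=1)  # normalizes each sample across the classes

row_sums = right.exp().sum(dim=1)  # close to 1 for every sample
col_sums = wrong.exp().sum(dim=0)  # close to 1 for every class, i.e. across samples

# Perturb sample 3 only and see whether sample 0's output moves.
perturbed = logits.clone()
perturbed[3] += 5.0
leaks_wrong = not torch.allclose(F.log_softmax(perturbed, dim=0)[0], wrong[0])
leaks_right = not torch.allclose(F.log_softmax(perturbed, dim=1)[0], right[0])
print(leaks_wrong, leaks_right)  # True False
```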
And there you go, the classifier works now! The training and validation losses quickly decrease.
Writing good code starts with organization. PyTorch Lightning takes care of that part by removing the boilerplate code surrounding training loop engineering, checkpoint saving, logging, etc. What is left is the actual research code: the model, the optimization, and the data loading. If something does not work the way we expect it to, the bug is most likely in one of these three parts of the code. In this blog post, we implemented two callbacks that help us (1) monitor the data that goes into the model and (2) verify that the layers in our network do not mix data across the batch dimension. A callback is an elegant way of adding arbitrary logic to an existing algorithm. Once implemented, it can easily be integrated into new projects by changing two lines of code.
For clarity, the code for the callbacks shown here is very simple and may not work right away with your models. However, it takes little effort to generalize it. In fact, I have already done it for you in this repository. The TrainingDataMonitor is a bit nicer because it works with multiple input formats (tuple, dict, list, etc.) and also creates a meaningful label for each histogram. In addition, there is a ModuleDataMonitor, which can even log the inputs and outputs of each layer in the network. The model verification is a bit more sophisticated and also works with multiple inputs and outputs. Finally, there is the official PyTorch Lightning Bolts collection of well-tested callbacks, losses, model components, and more to enrich your Lightning experience.