3 Simple Tricks That Will Change the Way You Debug PyTorch

In this article, you will learn how to:

  • find out why your training loss does not decrease,
  • implement automatic model verification and anomaly detection,
  • save valuable debugging time with PyTorch Lightning.
Vanilla MNIST PyTorch code, adapted from github.com/pytorch/examples
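The script is not reproduced here; in the spirit of the linked examples repository, its core is a small ConvNet plus a hand-written training loop, roughly like the following sketch (names and hyperparameters are assumed, not the exact gist):

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# the "engineering" half: device placement, loop, optimizer, logging, ...
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
train_loader = DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

model.train()
for epoch in range(3):
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(x), y)
        loss.backward()
        optimizer.step()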

Trick 0: Organize Your PyTorch

Before we debug this code, we will organize it into the Lightning format. PyTorch Lightning automates all the boilerplate/engineering code in a Trainer object and neatly organizes all the actual research code in the LightningModule, so we can focus on what's important:
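A reconstruction of what that looks like is sketched below (the class name LitClassifier is mine, and the accuracy import follows the Lightning 1.x API of the era; the sketch intentionally still contains the bugs we are about to hunt down):

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.metrics.functional import accuracy
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=0)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.nll_loss(self(x), y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.nll_loss(x, y)
        acc = accuracy(torch.max(output, dim=1)[1], y)
        self.log('val_loss', loss, on_epoch=True, reduce_fx=torch.mean)
        self.log('val_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(128, 1)])
train_loader = DataLoader(datasets.MNIST('.', train=True, download=True,
                                         transform=transform), batch_size=64)
val_loader = DataLoader(datasets.MNIST('.', train=False, download=True,
                                       transform=transform), batch_size=64)

trainer = pl.Trainer(max_epochs=3)
trainer.fit(LitClassifier(), train_loader, val_loader)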

Can you spot all the bugs in this code?

Trick 1: Sanity Checking the Validation Loop

If we run the above, we immediately get an error message complaining that the sizes don't match on line 65 in the validation step.

...
---> 65 loss = F.nll_loss(x, y)
     66 acc = accuracy(torch.max(output, dim=1)[1], y)
     67 self.log('val_loss', loss, on_epoch=True, reduce_fx=torch.mean)
...
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of size: [64]

The reason we hit this error immediately is Trick 1: before the real training begins, Lightning runs a couple of validation steps as a sanity check (the Trainer's num_sanity_val_steps argument, default 2), so broken validation code fails right away instead of after the first long training epoch. The traceback also points us straight at the bug: the loss is computed on the raw input x instead of the model output. The fix is a one-liner:

loss = F.nll_loss(output, y)
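In context, the repaired validation step then reads (a sketch consistent with the traceback above):

def validation_step(self, batch, batch_idx):
    x, y = batch
    output = self(x)
    loss = F.nll_loss(output, y)  # loss on the model output, not on the raw input
    acc = accuracy(torch.max(output, dim=1)[1], y)
    self.log('val_loss', loss, on_epoch=True, reduce_fx=torch.mean)
    self.log('val_acc', acc)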
PyTorch Lightning has logging to TensorBoard built in, and the logged curves show we are not done yet: in this example, neither the training loss nor the validation loss decreases.

Trick 2: Logging the Histogram of Training Data

It is important that you always check the range of the input data. If model weights and data are of very different magnitudes, learning can progress very slowly or not at all, and in the extreme case lead to numerical instability. This happens, for instance, when data augmentations are applied in the wrong order or when a normalization step is forgotten. Is this the case in our example? We should be able to find out by printing the min and max values. But wait! Littering the code with print statements is not a good solution: it pollutes the code unnecessarily, fills the terminal, and we would have to repeat the whole exercise the next time the problem comes up. Better: write a Callback class that does it for us, as sketched below!
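Here is a minimal sketch of such a callback (the hook signature matches the Lightning 1.x Callback API, which newer versions simplify; the class name InputMonitor is mine):

import pytorch_lightning as pl

class InputMonitor(pl.Callback):
    """Log histograms of the training batches to TensorBoard."""

    def __init__(self, log_every_n_steps=25):
        super().__init__()
        self.log_every_n_steps = log_every_n_steps

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        if (batch_idx + 1) % self.log_every_n_steps == 0:
            x, y = batch
            logger = trainer.logger
            logger.experiment.add_histogram('input', x, global_step=trainer.global_step)
            logger.experiment.add_histogram('target', y, global_step=trainer.global_step)

# the promised two lines in the training script:
trainer = pl.Trainer(callbacks=[InputMonitor()])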

A simple Callback that logs histograms of the training data to TensorBoard.
  1. It is separate from your research code; there is no need to modify your LightningModule!
  2. It is portable, so it can be reused in future projects, and adopting it requires changing only two lines of code: import the callback, then pass it to the Trainer.
  3. It can be extended by subclassing or combined with other callbacks.
transforms.Normalize(128, 1)  # wrong normalization
transforms.Normalize(mean=0.1307, std=0.3081)  # correct mean and std for MNIST
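In context, the fixed input pipeline looks like this (a minimal sketch; the surrounding code in the gist may differ). With Normalize(128, 1), the inputs that ToTensor had already scaled to [0, 1] end up centered around -128, far from the magnitude of the randomly initialized weights, which explains the flat loss curves:

from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                          # pixels become floats in [0, 1]
    transforms.Normalize(mean=0.1307, std=0.3081),  # standardize with the MNIST statistics
])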

Trick 3: Detecting Anomalies in the Forward Pass

After fixing the normalization issue, we now also get the expected histogram logged in TensorBoard. But unfortunately the loss is still not decreasing. Something is still wrong. Knowing that the data is correct, a good place to start looking for mistakes is the forward pass of the network. Common sources of errors are operations that manipulate the shape of tensors (e.g., permute, reshape, view, flatten) or operations that are applied to a single dimension (e.g., softmax). When these functions are applied to the wrong dimensions or in the wrong order, we usually get a shape mismatch error, but this is not always the case! These nasty bugs are hard to track down.
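One way to catch them anyway, sketched below with an assumed helper name: backpropagate from the output of a single sample and assert that no other sample in the batch receives a gradient. (If the model contains batch norm, run it in eval() mode first, since batch norm mixes samples by design.)

import torch

def verify_batch_independence(model, input_shape=(64, 1, 28, 28), sample_idx=0):
    """Assert that output n depends only on input n."""
    x = torch.randn(input_shape, requires_grad=True)
    output = model(x)
    # backprop only through the output of one sample in the batch
    output[sample_idx].sum().backward()
    # every other sample must receive exactly zero gradient on the input
    independent = x.grad.flatten(start_dim=1).abs().sum(dim=1) == 0
    independent[sample_idx] = True  # the chosen sample may (and should) have a gradient
    assert independent.all(), "The model mixes data across the batch dimension!"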

A quick sanity check that the model does not mix data across the batch.
The check reveals the culprit in our forward method, and the fix is again a single character:
output = F.log_softmax(x, dim=0)  # wrong: softmax over the batch dimension mixes the samples
output = F.log_softmax(x, dim=1)  # correct: softmax over the class dimension

Conclusion

Writing good code starts with organization. PyTorch Lightning takes care of that part by removing the boilerplate code surrounding training loop engineering, checkpoint saving, logging, etc. What is left is the actual research code: the model, the optimization, and the data loading. If something is not working the way we expect, it is likely a bug in one of these three parts of the code. In this blog post, we implemented two callbacks that help us 1) monitor the data that goes into the model, and 2) verify that the layers in our network do not mix data across the batch dimension. The concept of a callback is a very elegant way of adding arbitrary logic to an existing algorithm. Once implemented, it can be easily integrated into new projects by changing two lines of code.

Advanced Callbacks

For the sake of clarity, the code for the callbacks shown here is very simple and may not work right away with your models. However, it is not much effort to generalize it. In fact, I have already done it for you in this repository. The TrainingDataMonitor is a bit nicer because it works with multiple input formats (tuple, dict, list, etc.) and also creates a meaningful label for each histogram. In addition, there is a ModuleDataMonitor, which can even log the inputs and outputs of each layer in the network. The model verification is a bit more sophisticated and also works with multiple inputs and outputs. Finally, there is the official PyTorch Lightning Bolts collection of well-tested callbacks, losses, model components and more to enrich your Lightning experience.
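Pulling the ready-made monitors into a project looks roughly like this (exact import paths and arguments depend on the installed pl_bolts version):

from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor
from pytorch_lightning import Trainer

trainer = Trainer(callbacks=[
    TrainingDataMonitor(log_every_n_steps=25),  # histograms of the training batches
    ModuleDataMonitor(submodules=True),         # inputs and outputs of every layer
])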
