Debugging Neural Networks: A Checklist
You’ve framed your problem, prepared your datasets, designed your models and revved up your GPUs. With bated breath, you start training your neural network, hoping to return in a few days to great results.
When you do return, though, you find a very different picture. Your network seems to do no better than random guessing. Or, if it is a classification model, it has curiously learned to assign every input to a single dominant class. You scratch your head wondering what went wrong, and hit a wall. What’s more, since you’re programming at a higher layer of abstraction, you have no intuitive sense of what’s going on with your matrices and activation functions.
This isn’t a problem faced only by beginners. Empirically, it happens to even the more experienced among us, especially as the complexity of the model, the dataset and the underlying problem increases. So if you find yourself in this situation, don’t fret. To tackle it, we’ve put together a short checklist that might help you find a way out of this hole. It was written specifically with image classification in mind, but most of the advice is generic enough to apply to other types of networks as well.
The Checklist
Inputs
- Overfit on a small dataset: Prepare a tiny dataset of 50 or so records, turn off regularization (usually dropout layers and L1/L2 losses), and make sure the training loss converges to zero over multiple epochs. If it doesn’t, your problem lies deeper than the data, most likely in your model or training code.
- Train on standard corpora: To rule out your dataset itself as the source of the issue, switch to a standard dataset such as ImageNet or CIFAR-10.
- Mean centering: Make sure that your inputs have zero mean. For images, calculate the mean for each pixel across your entire training dataset and subtract the resulting mean image from each of your inputs. Turn off variance normalization and whitening if you’d like to keep it simple.
- Balanced dataset: If you’re doing fine on small datasets but not on larger ones, check that your input contains sufficient entries from all of your classes. A highly imbalanced dataset can bias your network toward one class or another. The simplest first step is to undersample the larger classes and generate a training set that is equally balanced across all classes.
- Moving beyond undersampling: If you’re doing well with balanced datasets but are fretting over discarding most of your data, there are several more advanced techniques you can move on to. This article summarizes some of these approaches.
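The mean-centering and undersampling steps above can be sketched in plain Python as follows. This is a minimal, dependency-free illustration that assumes each image is a flattened list of pixel values of equal length; the helper names `mean_image`, `mean_center` and `undersample` are hypothetical, not from any library. In practice you would use your framework’s array operations, and you should subtract the same training-set mean from validation and test inputs.

```python
import random
from collections import defaultdict


def mean_image(images):
    """Per-pixel mean over the training set (each image is a flat list of pixels)."""
    n = len(images)
    return [sum(pixel_values) / n for pixel_values in zip(*images)]


def mean_center(images, mean):
    """Subtract the training-set mean image from every input."""
    return [[p - m for p, m in zip(image, mean)] for image in images]


def undersample(examples, labels, seed=0):
    """Randomly keep only as many examples per class as the smallest class has."""
    by_class = defaultdict(list)
    for x, y in zip(examples, labels):
        by_class[y].append(x)
    smallest = min(len(xs) for xs in by_class.values())
    rng = random.Random(seed)
    balanced = [(x, y) for y, xs in by_class.items() for x in rng.sample(xs, smallest)]
    rng.shuffle(balanced)  # avoid feeding the classes in contiguous blocks
    return balanced


# Usage on toy data (two-pixel "images"):
train = [[0.0, 2.0], [2.0, 4.0], [4.0, 6.0]]
mu = mean_image(train)  # mu == [2.0, 4.0]
centered = mean_center(train, mu)  # reuse the same mu for validation/test images
balanced = undersample(list(range(7)), ["cat"] * 5 + ["dog"] * 2)
```

Undersampling here throws away randomly chosen examples from the larger classes, which is exactly the data loss the next item warns about.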
Model Architecture & Initialization
- Simplify your model: If you believe that the fault lies in your model itself, simplify it to one or two hidden layers. This should help you isolate whether the architecture of your model is where your problem lies.
- Weight initialization (shallow models): If you initialize all of your weights to zero (or to the same constant), every hidden node in a layer receives the same gradient, so back-propagation either has no effect or moves all hidden nodes in the same direction. You have to introduce a small amount of random perturbation to break this symmetry and get your network learning. For small models only a handful of layers deep, weights drawn from a zero-mean Gaussian with a standard deviation of around 1e-2 should do the trick.
- Weight initialization (deeper models): With deeper models, using the same small-scale initialization for every layer can backfire once training begins. As your inputs propagate forward through the network, activations get multiplied by small weights again and again and shrink toward zero in the deeper layers. As a result, your gradients during back-propagation become tiny and your model learns almost nothing. Use batch normalization layers to combat this issue, and ideally train with a reasonably large batch size.
- Pre-trained models: If you’re using a standard architecture such as Inception or ResNet, initialize your model from publicly available checkpoints. These pre-trained weights are usually the product of weeks of training, so if the fault lies in the sheer complexity of your problem and the time it takes to learn, you will have given yourself a head start.
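To see why initialization scale matters with depth, the sketch below pushes a random batch through a stack of tanh layers initialized from a zero-mean Gaussian with standard deviation 1e-2 and records the standard deviation of the activations after each layer, with and without batch normalization. It is a toy experiment under stated assumptions: the width, depth, batch size and tanh activation are our choices, the batch normalization is simplified (no learnable scale or shift, no running statistics), and all function names are hypothetical.

```python
import math
import random


def gaussian_layer(fan_in, fan_out, std, rng):
    """Zero-mean Gaussian weights with the given standard deviation; biases start at zero."""
    weights = [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]
    biases = [0.0] * fan_out
    return weights, biases


def layer_forward(batch, weights, biases, batch_norm=False, eps=1e-5):
    """Affine transform, optional simplified batch norm, then tanh."""
    pre = [[sum(w * x for w, x in zip(row, example)) + b
            for row, b in zip(weights, biases)]
           for example in batch]
    if batch_norm:
        # Normalize each unit across the batch (no learnable scale/shift, no running stats).
        for j in range(len(pre[0])):
            column = [row[j] for row in pre]
            mean = sum(column) / len(column)
            var = sum((c - mean) ** 2 for c in column) / len(column)
            for row in pre:
                row[j] = (row[j] - mean) / math.sqrt(var + eps)
    return [[math.tanh(v) for v in row] for row in pre]


def std_of(batch):
    """Standard deviation over every activation in the batch."""
    values = [v for row in batch for v in row]
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


def activation_stds(depth=8, width=64, batch_size=16, init_std=1e-2,
                    batch_norm=False, seed=0):
    """Return the activation standard deviation after each layer."""
    rng = random.Random(seed)
    layers = [gaussian_layer(width, width, init_std, rng) for _ in range(depth)]
    batch = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch_size)]
    stds = []
    for weights, biases in layers:
        batch = layer_forward(batch, weights, biases, batch_norm)
        stds.append(std_of(batch))
    return stds


print([f"{s:.1e}" for s in activation_stds()])
print([f"{s:.2f}" for s in activation_stds(batch_norm=True)])
```

Without batch normalization, each layer scales the spread of its input by roughly sqrt(64) × 0.01 = 0.08, so after eight layers the activations (and therefore the gradients flowing back through them) are vanishingly small. With batch normalization, every layer starts from unit-variance pre-activations and the spread stays roughly constant.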
Loss
- Learning rate: Check whether your loss drops over time. Depending on how steeply or slowly it falls, tune your learning rate, ideally in logarithmic steps (e.g. 1e-1, 1e-2, 1e-3).
- Cross-entropy loss: If you’re using cross-entropy loss, check that your initial loss is approximately -ln(1/NUM_CLASSES), i.e. ln(NUM_CLASSES). If it’s not, you likely have a critical bug on hand. Note that this check applies when your one-hot labels are compared with softmax probabilities (or when the loss function applies the softmax to the logits internally), not with raw logits, and it assumes that a freshly initialized network predicts roughly uniform probabilities.
- Regularization loss: If your regularization loss exceeds your cross-entropy loss by a factor of 10 or more, reduce your regularization lambda or decrease the magnitude of your initial weights. Such a skewed loss over prolonged periods may indicate that your network is more concerned with penalizing large weights than with converging to a solution.
- Steep loss reduction: At times, your loss may drop steeply over a short initial period of training and then stabilize. This is a strong indication that your initial weights are poorly scaled.
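The initial-loss check can be automated. The sketch below computes softmax cross-entropy directly from raw logits (applying the softmax internally via the log-sum-exp trick for numerical stability), simulates a freshly initialized classifier by drawing tiny random logits, and compares the mean loss with -ln(1/NUM_CLASSES). It also builds a logarithmic learning-rate grid. The random logits stand in for your model’s first forward pass; that substitution and all function names are assumptions for illustration.

```python
import math
import random


def softmax_cross_entropy(logits, label):
    """Cross-entropy for one example, applying softmax to raw logits internally.

    Uses log-sum-exp: loss = log(sum_k exp(z_k)) - z_label.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def initial_loss_check(num_classes=10, num_examples=500, logit_std=1e-2, seed=0):
    """Mean loss of a stand-in 'fresh' model that emits tiny random logits.

    Returns (observed mean loss, expected -ln(1/num_classes)).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_examples):
        logits = [rng.gauss(0.0, logit_std) for _ in range(num_classes)]
        total += softmax_cross_entropy(logits, rng.randrange(num_classes))
    return total / num_examples, -math.log(1.0 / num_classes)


def log_spaced_learning_rates(start=1e-1, steps=5):
    """Learning rates one order of magnitude apart, starting at `start`."""
    return [start * 10.0 ** -k for k in range(steps)]


observed, expected = initial_loss_check()
print(f"initial loss {observed:.3f}, expected about {expected:.3f}")
print(log_spaced_learning_rates())
```

In a real run you would replace the simulated logits with your model’s outputs on the first batch; a large gap between the two printed losses points at how labels and logits reach the loss function.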
Activation Functions
- Saturated Tanh and Sigmoid: Tanh and Sigmoid functions suffer from the saturation problem: at their extremes, their derivatives approach zero, so gradients vanish and your model stops learning. With bad weight initialization, it may not take long for your network to get to this point. I recommend trying other activation functions, especially ReLUs, to determine whether activation is where the problem lies.
- Dead ReLUs: ReLUs aren’t a magic bullet, since they can “die”: a ReLU whose input is negative outputs zero and passes back zero gradient, and if that happens for every input, the unit stops updating altogether. If most of your neurons die within a short period of training, a large chunk of your network might stop learning very early. If you find yourself in this situation, take a closer look at your initial weights. If you really need a magic bullet, initialize your biases to a small positive constant (e.g. 0.01). If that doesn’t work, you can experiment with Maxout and Leaky ReLUs.
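You can measure both problems directly on a batch of pre-activations (the inputs to the activation function). The dependency-free sketch below assumes pre-activations arrive as one row per example; the thresholds (|tanh(z)| > 0.99 for saturation, “never positive in this batch” for a dead ReLU) are our assumptions rather than standard definitions, and `simulate_layer` is a hypothetical stand-in for one of your layers.

```python
import math
import random


def dead_relu_fraction(pre_activations):
    """Fraction of units whose input is <= 0 for every example in the batch.

    pre_activations[i][j] is the input to unit j for example i. Such a unit
    outputs zero and receives zero gradient for the whole batch.
    """
    num_units = len(pre_activations[0])
    dead = sum(1 for j in range(num_units)
               if all(row[j] <= 0.0 for row in pre_activations))
    return dead / num_units


def saturated_tanh_fraction(pre_activations, threshold=0.99):
    """Fraction of tanh outputs with |tanh(z)| > threshold.

    There the derivative 1 - tanh(z)**2 is below 1 - threshold**2 (about 0.02).
    """
    outputs = [math.tanh(z) for row in pre_activations for z in row]
    return sum(1 for a in outputs if abs(a) > threshold) / len(outputs)


def simulate_layer(weight_std, bias=0.0, width=32, batch_size=64, seed=0):
    """Hypothetical stand-in for one dense layer fed unit-Gaussian inputs."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, weight_std) for _ in range(width)] for _ in range(width)]
    inputs = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch_size)]
    return [[sum(w * x for w, x in zip(row, example)) + bias for row in weights]
            for example in inputs]


print("tanh saturated, std=1.0: ", saturated_tanh_fraction(simulate_layer(1.0)))
print("tanh saturated, std=0.01:", saturated_tanh_fraction(simulate_layer(0.01)))
print("dead ReLUs, bias=-10:    ", dead_relu_fraction(simulate_layer(0.1, bias=-10.0)))
```

With large initial weights most tanh outputs land in the flat region, while a strongly negative bias leaves every ReLU in the layer dead for the whole batch.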
Other Tips
- Gradient clipping: If you’re using CNNs, exploding gradients are less likely to be a problem. If you do face this problem, though, and can find no natural way around it, try experimenting with gradient clipping.
- Plot your gradients: Visualize your gradients at each layer of your network. If most of them are at zero, then most of your neurons may be dead, your network may have stopped learning, or back-propagation is simply not happening.
- Ratio of weight updates to weight magnitudes: Compare the magnitude of each layer’s updates (learning rate × gradient) with the magnitude of its weights. If this ratio is too low (typically less than 1e-4), your network is learning too slowly, or is learning very little.
- Plot variance of weights: Check that the variance of your weights remains consistent as you go deeper into your network. If you observe a massive skew here, your initialization or batch normalization is likely not working as intended.
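Several of these tips reduce to a few per-layer statistics. The sketch below clips gradients by their global L2 norm and prints, for each layer, the weight variance, the gradient norm and the update-to-weight ratio, flagging ratios under 1e-4. It assumes each layer’s weights and gradients are flattened into plain lists and that the update is a vanilla SGD step (learning rate × gradient); with momentum or adaptive optimizers you would compute the ratio from the actual parameter change instead. It prints rather than plots to stay dependency-free, and all names are hypothetical.

```python
import math


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def clip_by_global_norm(grads, max_norm):
    """Rescale per-layer gradient lists so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for layer in grads for g in layer))
    if total <= max_norm:
        return grads
    scale = max_norm / total
    return [[g * scale for g in layer] for layer in grads]


def update_to_weight_ratio(weights, grads, lr):
    """||lr * grad|| / ||weights|| for one layer, assuming a plain SGD step."""
    return lr * l2_norm(grads) / l2_norm(weights)


def variance(values):
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def layer_report(weights_per_layer, grads_per_layer, lr, low_ratio=1e-4):
    """Print per-layer statistics and flag layers whose weights barely move."""
    for i, (w, g) in enumerate(zip(weights_per_layer, grads_per_layer)):
        ratio = update_to_weight_ratio(w, g, lr)
        flag = "  <- learning very slowly?" if ratio < low_ratio else ""
        print(f"layer {i}: var(w)={variance(w):.2e}  |grad|={l2_norm(g):.2e}  "
              f"update/weight={ratio:.2e}{flag}")


weights = [[0.5, -0.3, 0.2, 0.1], [0.01, -0.02, 0.03, 0.0]]
grads = [[0.1, 0.2, -0.1, 0.0], [0.0, 0.0, 1e-6, 0.0]]
layer_report(weights, clip_by_global_norm(grads, max_norm=1.0), lr=0.01)
```

In this toy input the second layer’s gradient is tiny relative to its weights, so it is the one that gets flagged; the gradient norm is already below 1.0, so clipping leaves it unchanged.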
Despite all of these steps, if your network still doesn’t look like it’s headed in the right direction, then you ought to look for more fundamental errors in either your code or how your problem is framed. On the other hand, if you’ve now resolved your problems, congratulations! Start tuning your hyper-parameters and tweaking your model, and you’ll be on your way.