Tutorial: Training on larger batches with less memory in AllenNLP

Evan Pete Walsh
Sep 8, 2020 · 7 min read

This is part of a series of mini-tutorials to help you with various aspects of the AllenNLP library.

👉 If you’re new to AllenNLP, consider first going through the official guide as these tutorials will be focused on more advanced use cases.

⚠️ Please keep in mind this was written for version 1.1 and greater of AllenNLP and may not be relevant for older versions.

Batch size is an important hyperparameter to tune when training deep neural networks. Using the largest batch size that fits in memory on your GPU is often a good starting point, but as state-of-the-art models get bigger and bigger, the hardware many of us have available can be limited, forcing us to use smaller batches.

Now, small batches aren’t necessarily bad. But the smaller the batch, the less information it contains, and so the gradient estimates coming from each batch will be less accurate. This can sometimes make it difficult for your model to converge.

Fortunately, there are several options available in AllenNLP that will allow you to train with larger batches by saving memory.

In this post, we will go through 3 of these options and show you how you can utilize them in your experiments.

Gradient accumulation (GA) is a simple technique that makes use of the fact that the loss for each batch is usually calculated by averaging the loss from each individual instance in the batch. As a result, the losses from each instance don’t necessarily need to be calculated in parallel; they could be calculated sequentially and then averaged together afterward.

And that’s the idea behind GA. We break up a batch into smaller partial batches and run each partial batch through the forward pass of our model. After each partial batch, we divide its loss by the number of partial batches and call the backward() method on the result. PyTorch adds newly computed gradients to the ones already stored for each parameter, so the gradients accumulate as we go. Each partial batch’s activations can also be freed as soon as its backward pass is done. Once every partial batch has been processed, the accumulated gradients match those for the full batch (assuming the partial batches are the same size), and the optimizer takes a single step.
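To make the idea concrete, here is a minimal sketch in plain Python (no PyTorch), using a one-parameter linear model whose gradient is computed by hand. The helper names `grad_mse` and `accumulated_grad` are hypothetical. The gradient of an average is the average of the gradients. So if you sum the gradients of each partial-batch loss, each divided by the number of partial batches, you get the full-batch gradient, provided the partial batches are the same size.

```python
import math


def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y) ** 2) over one (partial) batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, num_steps):
    """Split the batch into `num_steps` equal partial batches and accumulate."""
    size, rest = divmod(len(xs), num_steps)
    if rest:
        raise ValueError("this sketch assumes equally sized partial batches")
    grad = 0.0  # stands in for a parameter's .grad buffer
    for step in range(num_steps):
        part_x = xs[step * size:(step + 1) * size]
        part_y = ys[step * size:(step + 1) * size]
        # Plays the role of (partial_loss / num_steps).backward() in PyTorch:
        # the scaled partial gradient is added to the running total.
        grad += grad_mse(w, part_x, part_y) / num_steps
    return grad


xs = [i / 10 for i in range(64)]   # one "full" batch of 64 inputs
ys = [3.0 * x + 1.0 for x in xs]   # targets from a toy linear function
w = 0.5

full = grad_mse(w, xs, ys)                         # batch of 64 in one go
accum = accumulated_grad(w, xs, ys, num_steps=4)   # 4 partial batches of 16
print(math.isclose(full, accum))  # True
w -= 0.01 * accum  # a single optimizer step per effective batch
```

Only the running gradient total is kept between partial batches. That is where the memory saving comes from.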

In AllenNLP, you can utilize GA by just setting the num_gradient_accumulation_steps parameter of the trainer to an integer greater than 1. This gives you an effective batch size of num_gradient_accumulation_steps * batch_size.

For example, let’s say we want to use a batch size of 64, but we can only fit a batch size of 16 in memory. By setting the batch_size parameter of the data_loader to 16 and num_gradient_accumulation_steps to 4 in the trainer, we can achieve an effective batch size of 64.
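Here is what that example might look like in a config. It is written as a Python dict mirroring the JSON/Jsonnet config file, and it is only a partial sketch. A real config also needs sections such as dataset_reader and model, which are omitted here.

```python
import json

# Only the keys relevant to gradient accumulation; a real AllenNLP config
# also needs dataset_reader, model, train_data_path, and so on.
config = {
    "data_loader": {"batch_size": 16},
    "trainer": {"num_gradient_accumulation_steps": 4},
}

effective_batch_size = (
    config["data_loader"]["batch_size"]
    * config["trainer"]["num_gradient_accumulation_steps"]
)
print(json.dumps(config, indent=2))
print(effective_batch_size)  # 64
```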

Gradient checkpointing (GC) is a technique introduced in 2016. It lets you train an n-layer model using only O(sqrt(n)) memory, at the cost of one additional forward pass per batch [1].

In order to understand how GC works, it’s important to understand how backpropagation works.

Backpropagation is really just an application of the chain rule to calculate the derivative/gradient of the loss with respect to the trainable parameters of the model. As such, the gradient of each layer of the model is calculated as a function of the activations (outputs) of that layer and the gradient of the layer after it. Therefore, during the backward pass through the model, we still need to keep in memory all of the activations that were calculated on the forward pass through each layer, at least until we’ve finished calculating the gradient for the corresponding layer.
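A toy illustration of this dependency follows. It rests on two simplifications: a chain of scalar layers a_{i+1} = tanh(w_i * a_i), and a "loss" equal to the final activation. At every step of the backward pass, the code reads activations saved during the forward pass, which is why they must stay in memory. The result is checked against finite differences.

```python
import math


def forward(ws, x):
    """Chain of layers a_{i+1} = tanh(w_i * a_i); keeps every activation."""
    acts = [x]
    for w in ws:
        acts.append(math.tanh(w * acts[-1]))
    return acts


def backward(ws, acts):
    """Gradients of L = a_n with respect to each w_i, via the chain rule."""
    grads = [0.0] * len(ws)
    upstream = 1.0  # dL/da_n
    for i in reversed(range(len(ws))):
        a_in, a_out = acts[i], acts[i + 1]      # both saved on the forward pass
        local = upstream * (1.0 - a_out ** 2)   # through tanh' = 1 - tanh^2
        grads[i] = local * a_in                 # needs this layer's input
        upstream = local * ws[i]                # dL/da_i, sent to the layer before
    return grads


ws = [0.9, -1.3, 0.7, 1.1]
acts = forward(ws, 0.5)
grads = backward(ws, acts)

# Sanity check against central finite differences.
eps = 1e-6
for i in range(len(ws)):
    plus = ws[:i] + [ws[i] + eps] + ws[i + 1:]
    minus = ws[:i] + [ws[i] - eps] + ws[i + 1:]
    numeric = (forward(plus, 0.5)[-1] - forward(minus, 0.5)[-1]) / (2 * eps)
    assert math.isclose(grads[i], numeric, abs_tol=1e-6)
```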

👉 For information about the math behind backpropagation, check out this post by Terence Parr and Jeremy Howard on explained.ai [2].

The image below provides a visual of this process. The nodes on the top row represent the activations for each layer that are calculated on the forward pass of the model. The nodes on the bottom row represent the gradients corresponding to the same layers. The nodes that are lit up in purple at any given time are the nodes that need to be held in memory.


So the maximum number of nodes lit up at the same time represents the peak memory usage. In this case, that’s 7.

But how can we decrease that number? A naive approach would be to drop any activations that we don’t immediately need, and then just recalculate them when they’re needed:


As you can see, the maximum number of nodes we need to hold in memory with this approach is only 4! However, we’ve had to pay a significant computational cost since many of the forward pass activations needed to be calculated again during the backward pass.

There is, though, a clever way we could balance these trade-offs. Namely, we could keep certain intermediary activations in memory, called checkpoints:


Then only a fraction of the activations needs to be recomputed.


This results in a maximum of 5 nodes needing to be stored in memory at any given time, which is the trick behind GC.
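The trade-off can be explored with a small simulation. This is a toy model, not AllenNLP’s implementation, and the function name `simulate` is hypothetical. It counts only activations, not gradients, so its numbers differ from the figures above. On the forward pass it keeps every `segment`-th activation as a checkpoint. On the backward pass it recomputes each segment from its checkpoint before backpropagating through it.

```python
def simulate(n_layers: int, segment: int) -> tuple[int, int]:
    """Return (peak live activations, layer evaluations) for one training step.

    Activation i is the output of layer i; activation 0 is the input. On the
    forward pass only activations whose index is a multiple of `segment` are
    kept (the checkpoints). On the backward pass each segment is recomputed
    from its checkpoint, back-propagated through, and then freed.
    segment=1 keeps every activation, i.e. plain backpropagation.
    """
    live: set[int] = set()
    peak = 0
    evals = 0

    def store(i: int) -> None:
        nonlocal peak
        live.add(i)
        peak = max(peak, len(live))

    # Forward pass: compute everything, keep only checkpoints (and the output).
    store(0)
    for i in range(1, n_layers + 1):
        evals += 1
        store(i)                    # a_i is computed from a_{i-1}
        if (i - 1) % segment != 0:
            live.discard(i - 1)     # a_{i-1} is not a checkpoint: drop it

    # Backward pass: handle segments from the last one to the first.
    for start in reversed(range(0, n_layers, segment)):
        end = min(start + segment, n_layers)
        for i in range(start + 1, end):      # recompute dropped activations
            if i not in live:
                evals += 1
                store(i)
        for i in range(end, start - 1, -1):  # backprop through the segment,
            live.discard(i)                  # then free its activations
    return peak, evals


for n, seg in [(16, 1), (16, 4), (100, 1), (100, 10)]:
    print(n, seg, simulate(n, seg))
# 16 1 (17, 16)
# 16 4 (8, 28)
# 100 1 (101, 100)
# 100 10 (20, 190)
```

With segments of about sqrt(n) layers, peak memory grows roughly like 2·sqrt(n) instead of n. The number of layer evaluations stays under 2n, which is about one extra forward pass.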

In AllenNLP, GC is available in the PretrainedTransformerEmbedder and PretrainedTransformerMismatchedEmbedder classes. All you need to do is set the gradient_checkpointing parameter to True.
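A sketch of the relevant config fragment follows, again as a Python dict mirroring the JSON/Jsonnet config. Several values here are assumptions: the model name "bert-base-uncased" is just an example, and the "text_field_embedder" and "tokens" keys depend on your model class and token indexers.

```python
import json

embedder = {
    "type": "pretrained_transformer",  # or "pretrained_transformer_mismatched"
    "model_name": "bert-base-uncased",  # example model (assumption)
    "gradient_checkpointing": True,
}

# Where the embedder sits depends on your model; "text_field_embedder" and
# the "tokens" key are assumptions that must match your model and indexers.
model_fragment = {
    "text_field_embedder": {"token_embedders": {"tokens": embedder}},
}

print(json.dumps(model_fragment, indent=2))
```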

Last but not least is automatic mixed precision (AMP).

With PyTorch and most other deep learning frameworks, models are trained with single-precision (32-bit) floating-point numbers by default. However, it’s often possible to achieve comparable optimization performance using half-precision (16-bit) floating-point numbers for certain operations.

That’s where AMP comes in. When AMP is activated, deep learning frameworks that support it can automatically decide which operations can be done with 16-bit floats. This can save a ton of memory, since 16-bit floats take up half as much space as 32-bit floats, allowing you to train on larger batches.
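Python’s standard `struct` module can store floats in IEEE 754 half precision (format code "e"). That gives a quick way to see both sides of the trade-off: halved storage, but reduced precision and range. The tensor shape used below (batch 16, 512 tokens, hidden size 768) is only a hypothetical example.

```python
import struct

# Bytes per value in single precision ("f") and half precision ("e").
print(struct.calcsize("f"), struct.calcsize("e"))  # 4 2


def to_half(x: float) -> float:
    """Round-trip a Python float through 16-bit storage."""
    return struct.unpack("e", struct.pack("e", x))[0]


# Hypothetical activation tensor: batch 16 x 512 tokens x hidden size 768.
n_values = 16 * 512 * 768
print(n_values * 4 / 2**20, "MiB in fp32")  # 24.0 MiB in fp32
print(n_values * 2 / 2**20, "MiB in fp16")  # 12.0 MiB in fp16

print(to_half(0.1))   # 0.0999755859375 (only about 3 significant digits)
print(to_half(1e-8))  # 0.0 (too small for half precision: underflows)
```

The last two lines show why AMP does not simply convert everything to 16 bits. Some values lose too much precision or underflow to zero.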

Thanks to the PyTorch team, AMP was integrated into version 1.6 of PyTorch as the torch.cuda.amp module, which made it easier than ever to integrate AMP into AllenNLP as well.

As of AllenNLP 1.1, you can activate AMP simply by setting the use_amp parameter to True in the trainer part of your config.
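As a partial sketch, the trainer section might look like this, written as a Python dict mirroring the config file. PyTorch’s native AMP in version 1.6 targets CUDA, so the fragment also sets cuda_device. The device index 0 is an assumption about your machine.

```python
import json

trainer = {
    "use_amp": True,   # turn on automatic mixed precision
    "cuda_device": 0,  # AMP needs a GPU; device 0 is an assumption
}
print(json.dumps({"trainer": trainer}, indent=2))
```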

There you have it! Three techniques for increasing batch size when memory is a limiting factor:

  1. Gradient accumulation (GA).
    ➡️ Just set num_gradient_accumulation_steps in the trainer part of your config.
  2. Gradient checkpointing (GC).
    ➡️ Just set gradient_checkpointing to True in your pretrained transformer embedder.
  3. Automatic mixed precision (AMP).
    ➡️ Just set use_amp to True in the trainer part of your config.

GA is the simplest method, both conceptually and in implementation, and it is guaranteed to work out of the box with any model.

If you’re using a transformer-based model, both GC and AMP are good options to try as well.

AMP is especially useful with transformers because most of the operations inside the transformer layers, such as the large matrix multiplications, can be done with half-precision floats. Not only will this save a substantial amount of memory, but it can also lead to a big speedup. Half-precision values need half the memory bandwidth, and recent GPUs (for example, NVIDIA GPUs with Tensor Cores) can perform half-precision arithmetic much faster than full-precision arithmetic.

Of course, you could mix and match any of these together as well to run even bigger batches.

Keep in mind, a bigger batch size is not always better.

While larger batches will give you a better estimate of the gradient, the reduction in the amount of uncertainty is less than linear as a function of batch size. In other words, you get diminishing marginal returns by increasing batch size. For more information, see Chapter 8 of the Deep Learning book [3].
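Chapter 8 of the Deep Learning book frames this in terms of the standard error of the mean, σ/√n. Assume per-example gradients are independent with standard deviation σ, which is an idealization, and set σ = 1. Then quadrupling the batch size only halves the noise in the gradient estimate, while the compute per step quadruples. The helper name `standard_error` is hypothetical.

```python
import math


def standard_error(sigma: float, batch_size: int) -> float:
    """Standard error of a mean estimated from `batch_size` i.i.d. samples."""
    return sigma / math.sqrt(batch_size)


for b in (16, 64, 256, 1024):
    print(b, standard_error(1.0, b))
# 16 0.25
# 64 0.125
# 256 0.0625
# 1024 0.03125
```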

Further, if you continue to increase batch size, at some point you may start to see worse generalization performance. One potential explanation is that training on larger batches increases the probability of converging to sharp local minima [4, 5], which don’t generalize as well to new data.

So use these techniques with caution. Remember, batch size is ultimately just another hyperparameter that you should tune to your specific use case. It’s not “one size fits all”.

That’s it! Happy NLP-ing!

I hope this tutorial was helpful. If you find any issues, please leave a comment or open a new issue in the AllenNLP repo and give it the “Tutorials” tag.

AI2 Blog

AI for the Common Good.