A comprehensive guide to memory usage in PyTorch

Jacob Stern
Deep Learning for Protein Design
6 min read · Dec 13, 2021

Out-of-memory (OOM) errors are some of the most common errors in PyTorch. But there aren’t many resources out there that explain everything that affects memory usage at various stages of training and inference. This guide should help you figure out what is using up all of your memory in PyTorch, and help you avoid common pitfalls. If you use these tricks to cut down your memory consumption, you can train with bigger models and larger batches!

After reading this post, you should know exactly what is consuming your memory at each step of training and inference. If you want the shortcut formula for memory usage, scroll to the bottom of this post. There is also a Colab notebook that contains all of the code for this post.

Example

Here is a demo script:
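The full script is in the Colab notebook; below is a minimal sketch of the same idea, using a hypothetical two-layer model and printing `torch.cuda.memory_allocated()` after each step (the layer sizes, batch size, and choice of Adam are placeholders):

```python
import torch
from torch import nn

def print_memory(tag):
    # memory_allocated() reports memory occupied by live tensors on the current GPU
    print(f"{tag}: {torch.cuda.memory_allocated() / 1024**2:.1f} MiB")

device = torch.device("cuda")

# Step 1: move the model parameters to the GPU
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).to(device)
print_memory("after loading the model")

optimizer = torch.optim.Adam(model.parameters())
inputs = torch.randn(256, 4096, device=device)

# Step 2: the forward pass stores intermediate activations
loss = model(inputs).sum()
print_memory("after the forward pass")

# Step 3: the backward pass frees the activations and allocates the gradients
loss.backward()
print_memory("after the backward pass")

# Step 4: the optimizer step allocates the moment estimates (two per weight for Adam)
optimizer.step()
print_memory("after the optimizer step")
```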

The exact numbers in the output depend on your model, batch size, and GPU, but the pattern from step to step is the same.

So what is happening at each step?

Step 1 — model loading: Move the model parameters to the GPU. Current memory: model.

Step 2 — forward pass: Pass the input through the model and store the intermediate outputs (activations). Storing these activations is what takes up memory at this step. Though storing all activations is not strictly necessary (see “Gradient Checkpointing”), storing them is practically necessary for the efficiency of the backpropagation algorithm. Current memory: model + activations.

Step 3 — backward pass: Compute the gradients from the end of the network to the beginning, and discard the forward activations as you go. Because we have discarded the forward activations, the memory use after the backward pass is double the model size — one copy of the weights, and one copy of the gradients. Current memory: model + gradients.

Step 4 — optimizer step: Update the parameters, and keep track of the optimizer’s running state. Many optimizers track statistics such as estimates of the first and second moments of the gradient for each model weight. This takes up twice the model size for Adam (which tracks two moments), one times the model size for RMSprop (which tracks one moment), and zero times the model size for vanilla SGD (which tracks none; SGD with momentum stores one extra model-sized buffer). Current memory: model + gradients + gradient moments.

Step 5 — run the next iterations: After the gradients have been computed once and the optimizer has taken a step, the gradients and the gradient moments stick around. So your total maximum memory usage in future iterations will be: model + activations + gradients + gradient moments, which means the memory usage will increase on the second pass but then remain constant.

Now, with a basic understanding of what normally consumes memory, let’s look at some special cases and optimizations to save memory.

Mixed Precision Training

Mixed precision training is a technique that stores model weights and gradients in full 32-bit precision, but uses half precision for the forward pass. This halves the memory needed to store the forward activations. These memory savings are not reflected in the current PyTorch implementation of mixed precision (torch.cuda.amp), but are available in Nvidia’s Apex library with `opt_level="O2"` and are on the roadmap for the main PyTorch code base. As an added bonus, mixed precision training also speeds up calculations!
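If you want to experiment with it anyway, the usual `torch.cuda.amp` training-loop pattern looks roughly like this (the `model`, `dataloader`, `loss_fn`, and `optimizer` are assumed to already exist):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()
    # Run the forward pass (and loss) in half precision where it is safe to do so
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    # Scale the loss so small gradients don't underflow in half precision
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # unscales the gradients, then steps the optimizer
    scaler.update()
```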

Gradient Checkpointing

Gradient checkpointing is another way to save memory on the forward pass, but is a bit more involved than some of these other tricks. Check out this blog if you want to try it out!
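The short version: it trades compute for memory by not storing the activations inside a checkpointed segment and recomputing them during the backward pass. A minimal sketch with `torch.utils.checkpoint` (the `segment1` and `segment2` chunks of the model are placeholders):

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(x, segment1, segment2):
    # Activations inside each segment are recomputed during backward instead of stored
    x = checkpoint(segment1, x)
    x = checkpoint(segment2, x)
    return x
```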

Distributed Data Parallel (DDP) and memory usage

When using Distributed Data Parallel, you may see that your model takes up twice the amount of memory when you load it to the GPUs. This is because DDP creates “buckets” for each GPU, where it gathers gradients communicated from all other GPUs. So for distributed training, account for one more copy of the gradients.
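For reference, wrapping a model looks roughly like this (it assumes `torch.distributed.init_process_group` has already been called and that `rank` is the local GPU index):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

model = model.to(rank)                 # one replica of the model per process/GPU
model = DDP(model, device_ids=[rank])  # DDP allocates gradient buckets: roughly one extra copy of the gradients
```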

Loading a saved model and memory usage

There are several potential pitfalls for memory usage when loading a saved model.

Pitfall #1: Loading to a different device than the model was saved on. By default, PyTorch loads a saved model to the device that it was saved on. If that device happens to be occupied, you may get an out-of-memory error. To resolve this, make sure to specify the current device to load your model to.

Don’t do this
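A sketch of the problematic pattern (the checkpoint path and `model` are placeholders):

```python
# torch.load puts the tensors on whatever device they were saved from,
# which may be a GPU that is already full
model.load_state_dict(torch.load("checkpoint.pth"))
```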
Do this instead
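The same sketch with `map_location`, so the tensors land on the device you are actually using:

```python
device = torch.device("cuda:0")
model.load_state_dict(torch.load("checkpoint.pth", map_location=device))
```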

Pitfall #2: Not deleting a loaded checkpoint. If you load a checkpoint into a variable in one line and then load its state dict into the model in another, the checkpoint variable can hang around, consuming precious GPU memory. To avoid this, delete the checkpoint variable once the weights have been copied into the model.

Don’t do this
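A sketch of the problem, assuming the checkpoint dictionary stores the weights under a hypothetical `"state_dict"` key:

```python
checkpoint = torch.load("checkpoint.pth", map_location=device)  # `device` as defined above
model.load_state_dict(checkpoint["state_dict"])
# `checkpoint` still references the loaded tensors, so they keep occupying GPU memory
```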
Do this instead
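The same sketch, freeing the checkpoint once the weights have been copied into the model:

```python
checkpoint = torch.load("checkpoint.pth", map_location=device)
model.load_state_dict(checkpoint["state_dict"])
del checkpoint               # drop the extra reference to the loaded tensors
torch.cuda.empty_cache()     # release unoccupied cached memory back to the GPU
```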

Saving memory at inference time

All suggestions up to now have referred to model training. But when using a trained model (“inference”), we only need the model weights, so we don’t need to store forward activations, gradients, or gradient moments! Our memory usage is simply the model size (plus a small amount of memory for the current activation being computed). To do this, simply use the `with torch.no_grad():` context manager.
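For example (the `model` and `inputs` are placeholders):

```python
model.eval()              # switch off dropout and use running batch-norm statistics
with torch.no_grad():     # no activations are saved for a backward pass
    outputs = model(inputs)
```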

Conclusion: the formula

In summary, here are some of the biggest factors affecting your GPU usage.

  • Batch size: forward pass memory usage scales linearly with batch size.
  • Model size: model weights, gradients, and stored gradient momentum terms scale linearly with model size.
  • Optimizer choice: if you use a momentum-based optimizer, it can double or triple the amount of memory stored for your gradients.

Let m = model memory

Let f = the amount of memory consumed by the forward pass for a batch_size of 1.

Let g = m be the amount of memory for the gradients.

Let d = 1 if training on one GPU and 2 if training on >1 GPU.

Let o = the number of moments stored by the optimizer (typically 0 for SGD, 1 for RMSprop, or 2 for Adam)

Let b = 0.5 if using mixed precision training, and 1 if using full precision training.

Then for training,

Max memory consumption = m + f*batch_size*b + d*g + o*m

For inference,

Max memory consumption = m

With these formulas in hand, you can make an educated choice on memory trade-offs, enabling you to train larger models faster. If there are any other PyTorch memory pitfalls that you have run into, let me know in the comments and I’ll add them to the post. Here is a function that computes this formula for you:
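(This is a minimal sketch: the function name and default arguments are placeholders, and the variables follow the definitions above.)

```python
def estimate_max_memory(m, f, batch_size, o, num_gpus=1, mixed_precision=False):
    """Rough upper bound on training memory, in the same units as m and f.

    m: model memory, f: forward-pass memory at batch_size 1,
    o: number of optimizer moments (0 for SGD, 1 for RMSprop, 2 for Adam).
    """
    b = 0.5 if mixed_precision else 1.0   # half-precision forward pass
    d = 2 if num_gpus > 1 else 1          # DDP keeps an extra copy of the gradients
    g = m                                 # gradients take the same memory as the model
    return m + f * batch_size * b + d * g + o * m

# Example: a 500 MB model, 200 MB forward pass at batch size 1, Adam, one GPU
print(estimate_max_memory(m=500, f=200, batch_size=32, o=2))
```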

Follow me on Twitter: @jacobastern

