Deep Learning Memory Usage and PyTorch Optimization Tricks

Mixed precision training and gradient checkpointing on a ResNet

Quentin Febvre
Oct 29 · 3 min read

Shedding some light on the causes behind the CUDA out-of-memory error, and an example of how to reduce your memory footprint by 80% with a few lines of code in PyTorch

Understanding memory usage in deep learning models training

In this first part, I will explain how a deep learning model that uses a few hundred MB for its parameters can crash a GPU with more than 10 GB of memory during training!

So where does this need for memory come from? Below are the two main high-level reasons why deep learning training needs to store information:

  • information necessary to backpropagate the error (gradients of the loss w.r.t. the activations)
  • information necessary to compute the gradient of the model parameters
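To make these two kinds of information concrete, here is a minimal NumPy sketch of a dense layer with a hand-written backward pass. The `Dense` class and its attribute names are hypothetical illustrations, not PyTorch internals. The error sent back to the previous layer (`grad_in`) only needs the weights, while the parameter gradient (`grad_W`) needs the input saved during the forward pass.

```python
import numpy as np

class Dense:
    """Hypothetical fully connected layer y = x @ W + b (illustration only)."""

    def __init__(self, in_features, out_features, rng):
        self.W = rng.standard_normal((in_features, out_features)).astype(np.float32)
        self.b = np.zeros(out_features, dtype=np.float32)
        self.saved_input = None

    def forward(self, x):
        # The input must stay in memory until backward: it is needed for dL/dW.
        self.saved_input = x
        return x @ self.W + self.b

    def backward(self, grad_out):
        x = self.saved_input
        # Gradient of the parameters: requires the stored input x.
        self.grad_W = x.T @ grad_out
        self.grad_b = grad_out.sum(axis=0)
        # Error backpropagated to the input: requires only the weights W.
        grad_in = grad_out @ self.W.T
        # Once both gradients are computed, the saved input can be released.
        self.saved_input = None
        return grad_in


rng = np.random.default_rng(0)
layer = Dense(4, 3, rng)
x = rng.standard_normal((8, 4)).astype(np.float32)
y = layer.forward(x)
grad_in = layer.backward(np.ones_like(y))
print(grad_in.shape, layer.grad_W.shape)  # (8, 4) (4, 3)
```

Between `forward` and `backward`, the layer holds a reference to the whole input batch, which is exactly the memory cost discussed in this section.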

If there is one thing you should take away from this article, it is this:

As a rule of thumb, each layer with learnable parameters will need to store its input until the backward pass.

This means that every batchnorm, convolution, and dense layer will store its input until it has computed the gradient of its parameters.
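A quick back-of-the-envelope calculation shows why stored inputs, not parameters, dominate memory. The layer shape, batch size and image size below are hypothetical values chosen for illustration (a 3×3 convolution with 64 input and output channels, a batch of 32 feature maps of 112×112, float32), not measurements from the article.

```python
# Compare the parameter memory of one conv layer with the memory of the
# input it must keep until the backward pass. All sizes are assumptions.
BYTES_PER_FLOAT32 = 4

def conv2d_param_bytes(c_in, c_out, k, bias=True):
    # Weights have shape (c_out, c_in, k, k), plus one bias per output channel.
    n_params = c_out * c_in * k * k + (c_out if bias else 0)
    return n_params * BYTES_PER_FLOAT32

def activation_bytes(batch, channels, height, width):
    # A dense float32 tensor of shape (batch, channels, height, width).
    return batch * channels * height * width * BYTES_PER_FLOAT32

params = conv2d_param_bytes(c_in=64, c_out=64, k=3)
saved_input = activation_bytes(batch=32, channels=64, height=112, width=112)

print(f"parameters : {params / 2**20:.2f} MiB")
print(f"saved input: {saved_input / 2**20:.2f} MiB")
print(f"ratio      : {saved_input / params:.0f}x")
```

This prints 0.14 MiB for the parameters against 98.00 MiB for the saved input, a ratio of about 696. Repeat that over dozens of layers and a model with a few hundred MB of parameters can easily exhaust a 10 GB GPU.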

Even some layers without any learnable parameters need to store data! This is because we need to backpropagate the error back to the input, which we do thanks to the chain rule:

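Chain rule (a_i being the activations of layer i, and L the loss):

$$\frac{\partial L}{\partial a_{i}} = \frac{\partial L}{\partial a_{i+1}} \, \frac{\partial a_{i+1}}{\partial a_{i}}$$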

The culprit in this equation is the local derivative of a layer's output with respect to its input. Depending on the layer, it will:

  • depend on the parameters of the layer (dense, convolution…)
  • depend only on the layer's output (sigmoid, since σ'(x) = σ(x)(1 - σ(x)))
  • depend on the values of the inputs (e.g. MaxPool, ReLU…)
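The three cases can be sketched as small NumPy "backward" functions that take the incoming gradient dL/d(output) and return dL/d(input). The function names are hypothetical; the point is which forward-pass data each one must receive: the weights for a dense layer, the output for a sigmoid, and the input for a ReLU.

```python
import numpy as np

def dense_backward_input(grad_out, W):
    # y = x @ W  ->  dL/dx = dL/dy @ W.T : only the parameters are needed.
    return grad_out @ W.T

def sigmoid_backward(grad_out, y):
    # y = sigmoid(x)  ->  dy/dx = y * (1 - y) : the stored output is enough.
    return grad_out * y * (1.0 - y)

def relu_backward(grad_out, x):
    # y = max(x, 0)  ->  dy/dx = 1 where x > 0, else 0 : needs the input values.
    return grad_out * (x > 0)


rng = np.random.default_rng(0)
x = rng.standard_normal(5)
y = 1.0 / (1.0 + np.exp(-x))
g = np.ones_like(x)

# Check the sigmoid formula against a central finite difference.
eps = 1e-6
numeric = (1.0 / (1.0 + np.exp(-(x + eps))) - 1.0 / (1.0 + np.exp(-(x - eps)))) / (2 * eps)
print(np.allclose(sigmoid_backward(g, y), numeric))  # True
```

Note that `sigmoid_backward` never sees `x`: an implementation can keep the sigmoid's output instead of its input.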

For example, if we take a ReLU activation layer, the minimum information we need is the sign of the input.

Possible implementations include:

  • We store the whole input layer
  • We store a binary mask of the signs (that takes less memory)
  • We check if the output is stored by the next layer. If so, we get the sign info from there and we don’t need to store additional data
  • Maybe some other smart optimization I haven’t thought of…
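The storage strategies above can be compared with a short NumPy sketch. The activation shape is a hypothetical example, and bit-packing with `np.packbits` is one possible way to realize the "binary mask" idea, not necessarily what any framework does. All variants produce the same gradient; only the memory kept between forward and backward differs.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 64, 56, 56)).astype(np.float32)  # hypothetical activation
grad_out = np.ones_like(x)

# 1) Store the whole input: 4 bytes per element in float32.
saved_full = x
grad_full = grad_out * (saved_full > 0)

# 2a) Store a boolean mask of the signs: 1 byte per element in NumPy.
saved_mask = x > 0
grad_mask = grad_out * saved_mask

# 2b) Pack the mask into bits: 1 bit per element.
saved_bits = np.packbits(x > 0)
mask = np.unpackbits(saved_bits, count=x.size).reshape(x.shape).astype(bool)
grad_bits = grad_out * mask

# 3) Reuse the ReLU output if it is stored anyway: y > 0 exactly where x > 0.
y = np.maximum(x, 0)
grad_from_output = grad_out * (y > 0)

for name, arr in [("float32 input", saved_full), ("bool mask", saved_mask), ("packed bits", saved_bits)]:
    print(f"{name:14s}: {arr.nbytes / 2**20:6.2f} MiB")
```

Going from the float32 input to a packed bit mask divides the stored data by 32, and reusing an output that is already kept costs no extra memory at all.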

Now let’s take a closer look at a concrete example: ResNet18!

Continue reading “Deep Learning Memory Usage and PyTorch Optimization Tricks”

Sicara's blog

We build tailor-made AI and Big Data solutions for amazing clients

Thanks to Juliep and Arnault Chazareix

Written by Quentin Febvre, Agile Data Scientist, Working @Sicara_fr

