Deep Learning Memory Usage and Pytorch Optimization Tricks

Mixed precision training and gradient checkpointing on a ResNet

Quentin Febvre
Sicara's blog
3 min read · Oct 29, 2019


Read the original article on Sicara’s blog here.

Shedding some light on the causes behind the CUDA out of memory error, with an example of how to reduce your memory footprint by 80% with a few lines of code in PyTorch

Understanding memory usage when training deep learning models

In this first part, I will explain how a deep learning model that uses a few hundred MB for its parameters can crash a GPU with more than 10 GB of memory during training!

So where does this need for memory come from? Below I present the two main high-level reasons why a deep learning training run needs to store information:

  • information necessary to backpropagate the error (gradients of the loss w.r.t. the activations)
  • information necessary to compute the gradients of the model parameters

Gradient descent

If there is one thing you should take out from this article, it is this:

As a rule of thumb, each layer with learnable parameters will need to store its input until the backward pass.

This means that every batchnorm, convolution, or dense layer will store its input until it has been able to compute the gradients of its parameters.
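To make this more concrete, here is a minimal sketch (my illustration, assuming PyTorch and a CUDA device are available): the parameters of a small stack of convolutions weigh about a megabyte, while the inputs they keep around for the backward pass weigh hundreds of megabytes.

```python
import torch
import torch.nn as nn

device = "cuda"
# 8 identical convolutions: very few parameters, but each one saves its input
model = nn.Sequential(*[nn.Conv2d(64, 64, 3, padding=1) for _ in range(8)]).to(device)

params_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 2**20
print(f"parameters: {params_mb:.1f} MB")  # ~1 MB

x = torch.randn(16, 64, 128, 128, device=device)  # a single activation is ~64 MB
out = model(x)  # every convolution keeps a reference to its input for backward
print(f"after forward: {torch.cuda.memory_allocated() / 2**20:.0f} MB")  # hundreds of MB

out.sum().backward()  # saved inputs are freed once the parameter gradients are computed
print(f"after backward: {torch.cuda.memory_allocated() / 2**20:.0f} MB")
```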

Backpropagation of the gradients and the chain rule

Now even some layers without any learnable parameters need to store some data! This is because we need to backpropagate the error back to the input, and we do this thanks to the chain rule:

∂L/∂a_{i-1} = ∂L/∂a_i · ∂a_i/∂a_{i-1}

Chain rule (a_i being the activations of layer i, L the loss)

The culprit in this equation is the derivative of the output w.r.t. the input, ∂a_i/∂a_{i-1}. Depending on the layer, it will

  • be dependent on the parameters of the layer (dense, convolution…)
  • be dependent on nothing (sigmoid activation)
  • be dependent on the values of the inputs: e.g. MaxPool, ReLU…

For example, if we take a ReLU activation layer, the minimum information we need is the sign of the input.

Different implementations could look like this:

  • We store the whole input tensor
  • We store a binary mask of the signs, which takes less memory (see the sketch after this list)
  • We check if the output is stored by the next layer. If so, we get the sign info from there and we don’t need to store additional data
  • Maybe some other smart optimization I haven’t thought of…
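Here is a hedged sketch of that second option: a custom autograd ReLU (my illustration, not PyTorch's actual implementation, whose internals may differ) that saves only a boolean mask of the input signs for the backward pass.

```python
import torch

class MaskReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        mask = x > 0                   # boolean mask: 1 byte/element vs 4 for float32
        ctx.save_for_backward(mask)    # we never save the full input
        return x * mask

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output * mask      # dReLU/dx is 1 where x > 0, else 0

x = torch.randn(8, 64, 32, 32, requires_grad=True)
y = MaskReLU.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, (x > 0).float()))  # True: same gradient as a regular ReLU
```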

Example with ResNet18

Now let’s take a closer look at a concrete example: The ResNet18!
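Before going further, here is a rough sketch of the gap for ResNet18 itself, assuming torchvision is installed (the batch size and resolution below are arbitrary choices of mine, not the article's settings):

```python
import torch
import torchvision

device = "cuda"
model = torchvision.models.resnet18().to(device)

params_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 2**20
print(f"parameters: {params_mb:.0f} MB")  # around 45 MB in float32

x = torch.randn(64, 3, 224, 224, device=device)
torch.cuda.reset_peak_memory_stats()
model(x).sum().backward()  # one dummy forward + backward pass
peak_mb = torch.cuda.max_memory_allocated() / 2**20
print(f"peak memory during the step: {peak_mb:.0f} MB")  # several GB, mostly saved activations
```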

Continue reading “Deep Learning Memory Usage and Pytorch Optimization Tricks” on Sicara’s blog.
