Memory optimization: Cure Out Of Memory errors like a doctor

Ambrose Ling
Published in deMISTify
13 min read · Apr 28, 2024

Have you ever tried training your own LLaMA model, fine-tuning Mistral 7B, fine-tuning your own version of Stable Diffusion, or training just ANY machine learning model, only to keep hitting CUDA Out Of Memory (OOM) errors midway? Maybe it's exactly these errors that are preventing you from turning exciting ML ideas into reality on your consumer-level hardware. You may think, “Why not just buy a GPU with more memory?” Well, when you're a broke uni student like me, or one of the many others in the world who can't afford that, you need a better solution. So how can we reduce the significant memory cost of training these large models?

In this article, I break down how your model consumes memory during training, explain why memory is allocated the way it is, and show how some of the latest research advances in memory optimization can be applied to reduce different parts of that memory footprint.

I realized that debugging OOM errors has a lot in common with how a doctor treats patients. A patient comes to the doctor with certain symptoms, and the doctor thoroughly analyzes where these symptoms come from and which parts of the body show them. Based on their professional knowledge and experience, the doctor diagnoses the likely cause of the disease and then provides the appropriate treatment. Let's see how we can apply this analogy to memory optimization!

Consultation & Diagnosis: Breaking down memory allocations

To start a consultation, doctors usually ask the patient questions and use tools like stethoscopes, blood pressure monitors and thermometers to get a better sense of the patient's health. In our case, the patient is a model with memory problems, and to understand how it uses memory during training we also need diagnostic tools. One tool I found very useful is the PyTorch memory profiler, which lets you visualize the memory allocations of your training run. Its comprehensive API shows you execution times as well as the amount of memory allocated by individual operations in the model. One especially useful feature is a view of your model's memory usage over time during training, broken down by category. A detailed blog post on how to use the PyTorch profiler is available here[1].
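As a starting point, here is a minimal CPU-only sketch of collecting per-operator memory statistics with `torch.profiler`. The tiny model and input sizes are hypothetical stand-ins. On a CUDA machine, the blog post [1] additionally passes `ProfilerActivity.CUDA` and `with_stack=True`, and calls `prof.export_memory_timeline("memory.html", device="cuda:0")` to produce a categorized memory-over-time plot like Figure 1; treat that as a pointer to check against the documentation for your PyTorch version.

```python
import torch
from torch.profiler import profile, ProfilerActivity

torch.manual_seed(0)
# Hypothetical stand-in model and batch; replace with your own.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 10)
)
x = torch.randn(64, 512)

# profile_memory=True records how much memory each operator allocates and frees.
with profile(activities=[ProfilerActivity.CPU], profile_memory=True,
             record_shapes=True) as prof:
    loss = model(x).sum()   # forward pass
    loss.backward()         # backward pass

# Operators sorted by the memory they allocated themselves.
report = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10)
print(report)
```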

Before we dive into the specifics of the profiler and all that jazz, let's understand what happens under the hood in a typical training loop. In my example, I'm simply training a model to perform image classification.

```python
for epoch in range(epochs):
    for x, y_true in dataloader:
        optimizer.zero_grad()        # clear gradients left over from the previous step
        y = model(x)                 # FORWARD PASS
        loss = loss_fn(y, y_true)    # LOSS COMPUTATION
        loss.backward()              # BACKWARD PASS
        optimizer.step()             # GRADIENT DESCENT
```

As you iterate through each epoch, you fetch a batch of data from your dataloader. In the forward pass, your model takes the input batch and passes it through its hidden layers. During that process, PyTorch dynamically constructs a computation graph (see Figure 2 below). In this graph, a node represents a tensor operation and an edge represents a parent-child relationship between operations (the output of operation A is the input to operation B). As torch executes each forward operation, it saves the intermediate outputs needed later for the backward pass; we call these activations. At the same time, autograd records the derivative function of each operation, building the graph for the backward pass. Only when we compute the loss and call loss.backward() (the backward pass) are the nodes of this backward graph evaluated, so that gradients are computed and stored in memory[2].
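To see this graph for yourself, the sketch below runs a tiny forward computation and walks autograd's backward graph through the loss's `grad_fn` and each node's `next_functions`. The tensors and the helper `backward_graph_nodes` are hypothetical; exact node names can vary between PyTorch versions.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, 3, requires_grad=True)   # a "parameter" (leaf tensor)
x = torch.randn(4, 3)                        # input data, no gradient needed

h = torch.relu(x @ w)    # each op records a backward node as it runs
loss = h.sum()

def backward_graph_nodes(fn, names=None):
    """Depth-first walk over autograd's backward graph, collecting node types."""
    names = [] if names is None else names
    if fn is None:              # inputs that don't require grad have no node
        return names
    names.append(type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        backward_graph_nodes(next_fn, names)
    return names

names = backward_graph_nodes(loss.grad_fn)
print(names)             # node types from the loss back to the leaf w

loss.backward()          # evaluates the backward graph and fills w.grad
print(w.grad.shape)      # torch.Size([3, 3]), same shape as w
```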

With the gradients ready, optimizer.step() applies a gradient descent step, nudging the weights in the direction that decreases the loss.
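Putting the loop above into a runnable form, here is a minimal sketch on hypothetical toy data (a small MLP classifying 20-dimensional vectors into 3 classes, not the ViT from Figure 1). It uses plain SGD so that `optimizer.step()` is literally `p -= lr * p.grad`, and it records each batch loss.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 256 samples, 20 features, 3 learnable classes.
X = torch.randn(256, 20)
Y = (X[:, 0] > 0).long() + (X[:, 1] > 0).long()   # labels in {0, 1, 2}
dataloader = DataLoader(TensorDataset(X, Y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(5):
    for x, y_true in dataloader:
        optimizer.zero_grad()          # drop gradients from the previous step
        y = model(x)                   # forward pass: builds graph, saves activations
        loss = loss_fn(y, y_true)      # loss computation
        loss.backward()                # backward pass: fills p.grad, frees saved activations
        optimizer.step()               # gradient descent: p -= lr * p.grad
        losses.append(loss.item())

print(losses[0], losses[-1])
```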

If we use the PyTorch profiler, we can obtain a memory consumption graph for our training loop, with memory usage over time broken down by category.

Figure 1: Memory graph over time for training a ViT for image classification

Let's break down the categories that consume the most memory and understand the behaviors we observe:

1. Parameters (green)

These are the weights of your model that store the strength of connections we have between artificial neurons. These weights get multiplied with the input in order to propagate to the output. When we train our neural network, we increase or decrease these weights depending on the data it sees in order to optimize a loss function. The parameters occupy constant memory as there is only 1 copy of the weights and biases for the model. We do not need to allocate additional memory as training the model only requires us to update the values of these weights and biases.
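Parameter memory is easy to compute up front: it is the number of parameters times the bytes per element (4 bytes in float32). The small two-layer model below is a hypothetical example; the same arithmetic gives 7e9 × 4 bytes = 28 GB just for the float32 weights of a 7B-parameter model.

```python
import torch
from torch import nn

# Hypothetical example model.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

n_params = sum(p.numel() for p in model.parameters())
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(n_params, param_bytes / 2**20, "MiB")   # float32: 4 bytes per parameter
```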

2. Gradients (dark blue)

Gradients measure how much the loss changes with respect to a change in each weight. Each layer of a deep learning model is essentially a differentiable function, and the model is a chain of such functions. To perform gradient descent, we compute the gradient of the loss with respect to every weight, and each gradient tensor has the same shape as the parameter it belongs to.
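A quick way to confirm that gradients cost as much memory as the parameters: before the first backward pass no gradient memory exists, and afterwards every parameter's `.grad` has the same shape and dtype as the parameter itself. The model below is a hypothetical example.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
x = torch.randn(8, 1024)

# Before backward(), no gradient tensors have been allocated.
grads_before = [p.grad for p in model.parameters()]

model(x).pow(2).mean().backward()

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
grad_bytes = sum(p.grad.numel() * p.grad.element_size() for p in model.parameters())
print(param_bytes == grad_bytes)   # gradients mirror the parameters one-to-one
```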

3. Activations (red)

Activations are the intermediate results of your model that are stored in memory while the computation graph is constructed. Activation memory is harder to predict than parameter memory because it depends on the batch size, the input size (e.g. image resolution or sequence length) and the model architecture.
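To see how activation memory scales with batch size, this sketch uses `torch.autograd.graph.saved_tensors_hooks` to intercept every tensor autograd saves for the backward pass during one forward pass, skipping the parameters themselves and counting each underlying storage once. The helper `saved_activation_bytes` and the model sizes are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))

def saved_activation_bytes(model, batch_size, in_features=256):
    """Bytes of non-parameter tensors autograd keeps alive for backward."""
    param_storages = {p.untyped_storage().data_ptr() for p in model.parameters()}
    seen = {}  # storage pointer -> size in bytes (dedupes shared storages)

    def pack(t):
        s = t.untyped_storage()
        if s.data_ptr() not in param_storages:
            seen[s.data_ptr()] = s.nbytes()
        return t

    def unpack(t):
        return t

    x = torch.randn(batch_size, in_features)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        model(x)
    return sum(seen.values())

a8, a16, a32 = (saved_activation_bytes(model, b) for b in (8, 16, 32))
print(a8, a16, a32)
```

For this model the saved activation bytes grow linearly with the batch size, which is why shrinking the batch is often the first thing to try after an OOM error.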

NOTE on both activations and gradients!!!

In Figure 1, we can see cyclic peaks in activation memory. We can also see that as activation memory decreases, more and more memory is allocated for gradients. To explain this, we can use an animation made by the creators of one of the memory optimization techniques we will cover later. It depicts the vanilla back-propagation algorithm.

Each node in the top row represents the activations of a layer in your model, while the nodes in the bottom row represent the gradients of the loss with respect to the activations and parameters of these layers. During the forward pass the nodes are evaluated in order, and during the backward pass in reverse order. Purple nodes are the ones occupying memory at a given time, and an arrow represents a data dependency for computation.

Figure 2: Diagram of a full computation graph with vanilla back-propagation from https://github.com/cybertronai/gradient-checkpointing

This computation graph explains the memory behavior in Figure 1. Activation memory (red) increases as more layers are evaluated and more activation tensors are stored, since nothing in the backward graph has been evaluated yet. Once the forward pass is complete, the backward pass gradually releases activations as soon as they are no longer needed for gradient computation, while gradient memory (blue) grows until gradients have been computed for all layers.
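The rise-then-fall pattern can be reproduced with a toy simulation that counts how many layer activations are alive at each step. This is a simplified model (hypothetical helper, one activation per layer, freed as soon as that layer's gradient is done), not a measurement of real PyTorch memory.

```python
def vanilla_backprop_live_activations(n_layers):
    """Number of stored activations after each step of a simplified
    forward + backward pass over a chain of n_layers layers."""
    live = set()
    trace = []
    # Forward: every layer's output is kept because backward will need it.
    for i in range(n_layers):
        live.add(i)
        trace.append(len(live))
    # Backward: layers are visited in reverse; once layer i's gradient is
    # computed, its saved activation can be freed.
    for i in reversed(range(n_layers)):
        live.discard(i)
        trace.append(len(live))
    return trace

trace = vanilla_backprop_live_activations(8)
print(trace)   # [1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]
```

The peak equals the number of layers, so vanilla backprop needs activation memory that grows linearly with depth.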

4. Optimizer States (yellow)

Optimizer states are the quantities an optimizer keeps for each parameter between steps. The role of the optimizer is to update the parameters of a network using the gradient of the loss with respect to each weight, ∂L/∂w, at iteration t.

For this article, we will focus on one of the most popular optimizers in modern deep learning, Adam. It is well known for its training stability compared to optimizers such as AdaGrad and RMSProp, although, as we will see, that stability comes at a memory cost.

Figure 3: Adam optimization algorithm
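Since Figure 3 is an image, here is the standard Adam update from Kingma and Ba written out in LaTeX, with gradient g_t, step size α, decay rates β₁ and β₂, and a small constant ε:

```latex
\begin{aligned}
g_t &= \nabla_{\theta} f_t(\theta_{t-1}) \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^{2} \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^{t}}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^{t}} \\
\theta_t &= \theta_{t-1} - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
```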

In Adam, we update the parameters using exponential moving averages of the gradient (the first moment vector) and of the squared gradient (the second moment vector)[3]. What makes Adam distinctive is that it effectively adapts the step size of each parameter individually: the first moment acts as momentum, and the second moment captures how variable that parameter's gradients are. In terms of memory, this algorithm requires storing a first moment vector and a second moment vector, each with as many elements as your model has parameters. This is why the optimizer states in Figure 1 occupy approximately 2x the memory of the parameters: Adam tracks two quantities (momentum and variance) for every parameter. You may also notice that this memory is not allocated until the first training iteration. This is because PyTorch creates the optimizer states lazily, on the first call to optimizer.step() after the first forward and backward pass.
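The sketch below implements that update rule in NumPy to make the memory cost concrete: a hypothetical `Adam` class that, like PyTorch, allocates its `m` and `v` arrays lazily on the first step, each the same shape as the parameter. The quadratic objective is just a toy example.

```python
import numpy as np

class Adam:
    """Minimal Adam for a list of NumPy parameter arrays (updated in place).
    Keeps two state arrays, m and v, per parameter."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params, self.lr, self.eps = params, lr, eps
        self.b1, self.b2 = betas
        self.t = 0
        self.m = None   # allocated lazily on the first step
        self.v = None

    def step(self, grads):
        if self.m is None:
            self.m = [np.zeros_like(p) for p in self.params]
            self.v = [np.zeros_like(p) for p in self.params]
        self.t += 1
        for p, g, m, v in zip(self.params, grads, self.m, self.v):
            m *= self.b1
            m += (1 - self.b1) * g            # first moment (momentum)
            v *= self.b2
            v += (1 - self.b2) * g * g        # second moment (variance)
            m_hat = m / (1 - self.b1 ** self.t)
            v_hat = v / (1 - self.b2 ** self.t)
            p -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

    def state_bytes(self):
        return sum(a.nbytes for a in self.m + self.v) if self.m else 0

# Toy problem: minimize sum(w**2), whose gradient is 2 * w.
w = np.array([1.0, -2.0, 3.0])
opt = Adam([w], lr=0.1)
bytes_before = opt.state_bytes()        # 0: no state before the first step
for _ in range(200):
    opt.step([2 * w])
print(bytes_before, opt.state_bytes(), w.nbytes)   # 0 48 24
```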

We have now observed these memory behaviors and made a preliminary diagnosis. What comes next? Treatment! Once a doctor understands the underlying illness and makes a medically reasoned diagnosis, they decide on a treatment plan to help the patient recover. Picking the right memory optimization technique follows exactly the same idea: different methods tackle different memory bottlenecks, so whether a method helps depends on where your bottleneck is. Here I present 5 research advances in memory optimization that you may find helpful.

Treatment: Memory optimization techniques

1. Gradient checkpointing (Paper)

Gradient checkpointing trades extra computation for lower memory by recomputing some activations instead of storing all of them. In the approach implemented by Tim Salimans and Yaroslav Bulatov[4], we checkpoint (preserve) a subset of nodes in memory and use these checkpoints to recompute each of the remaining nodes at most once during the backward pass. Note that checkpoint nodes are kept in memory after the forward pass.

Figure 4: adding a checkpoint node to computation graph

Let’s assume that we mark the circled node as the checkpoint.

Figure 5: gradient checkpointed computation graph

In the checkpointed computation graph, stored activation nodes are released as soon as they are no longer required, and the nodes between checkpoints are recomputed from the nearest checkpoint when the backward pass needs them.

The authors found that, for a network with n layers, dividing it into roughly sqrt(n) segments of sqrt(n) layers each and checkpointing the node at each segment boundary works best[4]. Recomputing the activations of all the other nodes then costs at most one extra forward pass. In exchange, memory grows sublinearly, as O(sqrt(n)), in the number of layers rather than linearly. Intuitively, if you compare vanilla backprop with gradient-checkpointed backprop, fewer nodes are consuming memory at any given time (fewer purple nodes at once).
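PyTorch ships its own implementation of this idea in `torch.utils.checkpoint` (a different code base from the TensorFlow package in [4]). The sketch below uses `checkpoint_sequential` to split a hypothetical stack of 16 Linear+Tanh blocks into sqrt(16) = 4 segments, and checks that the gradients match a normal backward pass, since checkpointing only trades memory for recomputation and should not change the result.

```python
import math
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
n_layers = 16
layers = []
for _ in range(n_layers):
    layers += [nn.Linear(128, 128), nn.Tanh()]
model = nn.Sequential(*layers)

x = torch.randn(32, 128, requires_grad=True)

# Vanilla: every intermediate activation is kept until backward.
model(x).sum().backward()
grad_vanilla = model[0].weight.grad.clone()
model.zero_grad()

# Checkpointed: only segment inputs are kept; each segment's forward pass
# is re-run during backward to rebuild the activations it needs.
segments = int(math.sqrt(n_layers))   # 4 segments of 4 Linear+Tanh blocks
out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()
grad_ckpt = model[0].weight.grad

print(torch.allclose(grad_vanilla, grad_ckpt, atol=1e-6))
```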

2. Flash Attention (Paper)

In transformer-based architectures, the attention mechanism is the backbone of the model's ability to establish meaningful connections between tokens in sequential data, yet it can impose significant compute overhead and memory consumption. The authors of FlashAttention take advantage of the hierarchical structure of GPU memory by running attention computations in on-chip SRAM, which is much faster but much smaller, rather than in HBM (high bandwidth memory)[5].

Figure 6: Flash attention diagram with memory hierarchy
Figure 7: Standard attention algorithm

In the standard attention implementation, the operation Q•K^T computes attention scores between all pairs of tokens. This step, along with the softmax, requires repeatedly reading and writing the N×N intermediate matrices to and from HBM, which leads to a large number of memory accesses. Furthermore, previous implementations require storing the intermediate matrices S and P in order to compute gradients with respect to Q, K and V during backpropagation, so memory usage grows quadratically with sequence length.
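For reference, the standard algorithm in Figure 7 can be sketched in NumPy as follows. It materializes the full N × N matrices S and P, which is exactly the quadratic memory cost FlashAttention avoids. The sizes are hypothetical, and the 1/√d scaling is the usual Transformer convention.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference attention that materializes the full N x N matrices S and P."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                  # N x N attention scores
    S_max = S.max(axis=1, keepdims=True)      # subtract row max for stability
    P = np.exp(S - S_max)
    P /= P.sum(axis=1, keepdims=True)         # N x N row-wise softmax
    O = P @ V                                 # N x d output
    return O, S, P

rng = np.random.default_rng(0)
N, d = 1024, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, S, P = standard_attention(Q, K, V)
print(S.nbytes / 2**20, "MiB for S alone")   # 8.0 MiB in float64
```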

Figure 8: Flash Attention algorithm. Red box indicates section of faster attention computation moved to SRAM

FlashAttention uses tiling and recomputation to reduce memory accesses and to avoid storing large intermediate activations. Tiling divides the attention computation into blocks: each block is loaded into SRAM for fast computation, and the partial results are combined and written back to HBM. Recomputation lets us skip saving the intermediate matrices S and P (a selective form of gradient checkpointing): given the output O and the softmax normalization statistics (m and l), S and P can be recomputed cheaply during the backward pass from the blocks of Q, K and V already in SRAM.
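The tiling idea can be sketched in NumPy: the function below processes K and V one block at a time and keeps only the running row statistics m (row-wise max) and l (softmax denominator), rescaling earlier partial results whenever the max changes. This is a simplified, forward-only illustration (only K and V are tiled, no SRAM management, no backward pass), not the FlashAttention kernel; `tiled_attention` and `block_size` are hypothetical names.

```python
import numpy as np

def tiled_attention(Q, K, V, block_size=128):
    """Attention forward pass computed one K/V block at a time, without ever
    building the full N x N score matrix."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)            # unnormalized running output
    m = np.full((N, 1), -np.inf)    # running row-wise max of scores
    l = np.zeros((N, 1))            # running row-wise sum of exp(scores - m)
    for start in range(0, N, block_size):
        Kb = K[start:start + block_size]          # one tile of keys
        Vb = V[start:start + block_size]          # matching tile of values
        Sb = Q @ Kb.T * scale                     # N x block scores only
        m_new = np.maximum(m, Sb.max(axis=1, keepdims=True))
        Pb = np.exp(Sb - m_new)
        correction = np.exp(m - m_new)            # rescale earlier partial sums
        l = correction * l + Pb.sum(axis=1, keepdims=True)
        O = correction * O + Pb @ Vb
        m = m_new
    return O / l, m, l

# Check against a direct softmax(QK^T / sqrt(d)) V computation.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((512, 64)) for _ in range(3))
O_tiled, m, l = tiled_attention(Q, K, V, block_size=128)

S = Q @ K.T / np.sqrt(64)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
print(np.allclose(O_tiled, P @ V))   # True
```

Only m and l (one value each per row) survive the loop, which is what makes it possible to recompute S and P later instead of storing them.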

3. 8-bit Adam optimizers (Paper)

As discussed, optimizer states keep track of gradient statistics over the course of training, and for Adam these statistics consume a considerable amount of memory. This work uses quantization to store the optimizer states in 8 bits rather than 32 bits, while preserving performance through block-wise dynamic quantization[6].
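To put “considerable” into numbers, here is a back-of-the-envelope estimate of Adam's two state tensors at 32 bits versus 8 bits. It ignores the small overhead of per-block quantization constants, and `adam_state_gb` is a hypothetical helper.

```python
def adam_state_gb(n_params, bits_per_value):
    """Approximate memory for Adam's two state tensors (m and v), in GB."""
    return 2 * n_params * bits_per_value / 8 / 1e9

print(adam_state_gb(7e9, 32))   # 56.0 GB for a 7B-parameter model
print(adam_state_gb(7e9, 8))    # 14.0 GB with 8-bit states
```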

In brief, quantization maps real numbers, usually stored in 32 bits, onto a small set of values D that can be indexed by k-bit integers. To quantize a tensor, we normalize it into the range of the quantization data type, find the closest value in D for each element, and store that value's k-bit index.

Figure 9: equation for quantization
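Written out, the quantize/dequantize pair described in the paragraph below can be transcribed as follows (check the paper [6] for its exact notation):

```latex
T^{Q}_{i} = \operatorname*{arg\,min}_{j \in \{0, \dots, 2^{k}-1\}}
  \left| Q^{\mathrm{map}}_{j} - \frac{T_{i}}{N} \right|,
\qquad
T^{D}_{i} = Q^{\mathrm{map}}\!\left(T^{Q}_{i}\right) \cdot N,
\qquad
N = \max_{i} \left| T_{i} \right|
```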

This equation establishes the relationship between the quantized tensor (T^Q) and the de-quantized tensor (T^D). N is the normalization constant that rescales the tensor into the quantization range (usually N = max(|T|)), and Q^map is the mapping [0, 2^k − 1] → D from k-bit indices to values in D. If you wish to learn more about the details of quantization, check out this awesome article by Jeremy Qu!

Building on these ideas, block-wise quantization reduces quantization error and improves training stability by chunking the optimizer state tensors into blocks and computing a separate normalization constant for each block.

Figure 10: block wise dynamic quantization

Blocking makes quantization more robust to outliers in the input tensor: it minimizes the unused range of the quantization buckets and confines the effect of an outlier to a single block. (Learn more about it here!)
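The sketch below illustrates block-wise absmax quantization in NumPy with one normalization constant N per block. As a simplifying assumption it uses a linear 8-bit code book (256 evenly spaced values in [-1, 1]) instead of the dynamic quantization map from the paper; `quantize`, `dequantize` and `QMAP` are hypothetical names. It compares one block for the whole tensor against 256-element blocks on a tensor with a single large outlier.

```python
import numpy as np

# Hypothetical 8-bit code book: 256 evenly spaced values in [-1, 1].
# (The paper uses a non-linear dynamic map; a linear one keeps this short.)
QMAP = np.linspace(-1.0, 1.0, 256)

def quantize(x, block_size):
    """Block-wise absmax quantization: one normalization constant per block."""
    blocks = x.reshape(-1, block_size)
    N = np.abs(blocks).max(axis=1, keepdims=True)    # one constant per block
    N[N == 0] = 1.0                                  # avoid division by zero
    normalized = blocks / N                          # now in [-1, 1]
    # Index of the closest code-book entry for every element, stored as uint8.
    codes = np.abs(normalized[..., None] - QMAP).argmin(axis=-1).astype(np.uint8)
    return codes, N

def dequantize(codes, N):
    return (QMAP[codes] * N).reshape(-1)

rng = np.random.default_rng(0)
state = rng.standard_normal(4096) * 0.01   # e.g. values of an Adam state tensor
state[0] = 10.0                            # a single large outlier

errors = {}
for block_size in (4096, 256):             # 4096 = one block for the whole tensor
    codes, N = quantize(state, block_size)
    errors[block_size] = np.abs(dequantize(codes, N) - state).mean()
print(errors)
```

With a single block, the outlier sets N for every element, so small values are rounded very coarsely; with 256-element blocks, only the outlier's own block pays that price.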

4. LoRA fine-tuning (Paper)

Low Rank Adaptation is a fine tuning technique that significantly reduces memory consumption by making the weight updates much more parameter efficient.

Figure 11: LoRA parametrization, x is the input to both W and A•B

When we call a matrix low rank, we mean that some of its rows (or columns) can be expressed as linear combinations of the others, i.e. they are not all linearly independent, so its rank r is smaller than min(m, n) for an m × n matrix. Hence we can find a lower-dimensional representation that still captures the essential information in the matrix[7].

Using low-rank matrices for weight updates makes it much easier to fine-tune large models for different downstream tasks. When we fine-tune, for example, the weight matrix W (d × k) of a linear layer, instead of learning a full update ΔW (d × k) and computing W + ΔW, we represent ΔW as the product BA of two low-rank matrices, where B is d × r, A is r × k and r ≪ min(d, k). During fine-tuning, the pre-trained weights W are frozen, while B and A are trainable. The authors of LoRA observed that the update matrices have a low “intrinsic dimension”, so this approach can achieve performance comparable to full-parameter fine-tuning[8].

With this new method, our forward pass would look something like this:

Figure 11: LoRA modified forward pass equation (W_o is frozen, BA are low rank adapted matrices)
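The same forward pass in LaTeX, using the LoRA paper's notation (W_0 for the frozen weights; the paper additionally scales ΔWx by a constant α/r):

```latex
h = W_0 x + \Delta W x = W_0 x + B A x,
\qquad
W_0 \in \mathbb{R}^{d \times k},\;
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)
```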

LoRA has allowed many researchers to fine-tune large language models, diffusion models and more for their own applications and use cases. In terms of memory, LoRA drastically reduces the number of trainable parameters: for GPT-3 175B, the LoRA paper reports about 10,000x fewer trainable parameters and a 3x reduction in GPU memory requirements[8].
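A minimal PyTorch sketch of a LoRA-wrapped linear layer follows. `LoRALinear` is a hypothetical class (libraries such as Hugging Face PEFT provide production implementations); it follows the paper's initialization of A with small random values and B with zeros, so the adapted layer starts out identical to the pretrained one. Because only A and B require gradients, only they get gradients and optimizer states, which is where the memory savings come from.

```python
import torch
from torch import nn

class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update: W0 x + (alpha/r) * B A x."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                        # W0 and bias stay frozen
        d, k = base.out_features, base.in_features
        self.A = nn.Parameter(torch.randn(r, k) * 0.01)    # r x k, small random init
        self.B = nn.Parameter(torch.zeros(d, r))           # d x r, zero init: BA = 0 at start
        self.scaling = alpha / r

    def forward(self, x):
        # (x A^T) B^T equals x (BA)^T without ever forming the d x k matrix BA.
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scaling

torch.manual_seed(0)
layer = LoRALinear(nn.Linear(4096, 4096), r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in layer.parameters() if not p.requires_grad)
print(trainable, frozen)   # 65536 16781312

x = torch.randn(2, 4096)
same_at_init = torch.allclose(layer(x), layer.base(x))   # True: B starts at zero
layer(x).sum().backward()
```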

5. GaLore: Gradient Low-Rank Projection (Paper)

GaLore is a much more recent work (published in March 2024) that made significant breakthroughs in memory efficiency and builds upon the low-rank ideas behind LoRA. Instead of constraining the weight update itself to be low rank, GaLore projects the gradients into a low-rank subspace, while the full set of parameters remains trainable[9].

Figure 12: GaLore pseudocode algorithm

Their formulation retains and trains all the original parameters of the model. However, when we perform a gradient update (optimizer.step()), GaLore uses projection matrices P and Q to project the gradient matrix G into a small r-dimensional subspace (R = PᵀGQ, in practice often with just one of the two), runs the optimizer there, and projects the result back to the full weight shape. GaLore is more memory efficient than LoRA due to several properties of its methodology:

a) LoRA must store activations for both the frozen original weights and the additional low-rank adapter path, while GaLore only stores the activations of the original parameter set.

b) GaLore does not need to store separate low-rank matrices B and A, because the update is applied directly to the pretrained weights during each weight update (i.e. the last step in Figure 12).

c) GaLore works with Adam and further saves memory on optimizer states, because these statistics have the same (low-rank) dimensions as the projected gradients. It can also be combined with 8-bit Adam, which further reduces the memory needed to store these statistics.
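Below is a simplified NumPy sketch of a GaLore-style update for a single weight matrix, loosely following the algorithm in Figure 12: it uses a one-sided projection P (the top-r left singular vectors of the gradient, refreshed every few steps), runs Adam on the r × n projected gradient, and projects the update back to the full m × n shape. `GaLoreAdamSketch`, its hyperparameters and the toy least-squares objective are hypothetical; this is not the official GaLore implementation.

```python
import numpy as np

class GaLoreAdamSketch:
    """Simplified GaLore-style optimizer for one m x n weight matrix."""

    def __init__(self, shape, rank=4, lr=1e-2, scale=1.0, update_proj_every=200,
                 betas=(0.9, 0.999), eps=1e-8):
        m, n = shape
        self.rank, self.lr, self.scale = rank, lr, scale
        self.T = update_proj_every
        self.b1, self.b2 = betas
        self.eps = eps
        self.P = None                      # m x r projection matrix
        self.M = np.zeros((rank, n))       # Adam first moment: r x n, not m x n
        self.V = np.zeros((rank, n))       # Adam second moment: r x n
        self.t = 0

    def step(self, W, G):
        if self.t % self.T == 0:
            # Refresh the subspace from the top-r left singular vectors of G.
            U, _, _ = np.linalg.svd(G, full_matrices=False)
            self.P = U[:, :self.rank]
        self.t += 1
        R = self.P.T @ G                                   # r x n low-rank gradient
        self.M = self.b1 * self.M + (1 - self.b1) * R
        self.V = self.b2 * self.V + (1 - self.b2) * R * R
        m_hat = self.M / (1 - self.b1 ** self.t)
        v_hat = self.V / (1 - self.b2 ** self.t)
        N = m_hat / (np.sqrt(v_hat) + self.eps)            # Adam direction, r x n
        W -= self.lr * self.scale * (self.P @ N)           # project back, update W in place

# Toy problem: fit W to a random target with loss 0.5 * ||W - W_target||^2.
rng = np.random.default_rng(0)
W_target = rng.standard_normal((64, 32))
W = np.zeros((64, 32))
opt = GaLoreAdamSketch(W.shape, rank=4, lr=0.05, update_proj_every=10)

def loss(W):
    return 0.5 * np.sum((W - W_target) ** 2)

initial = loss(W)
for _ in range(300):
    opt.step(W, W - W_target)          # gradient of the loss w.r.t. W
final = loss(W)
print(initial, final)
print(opt.M.shape, opt.P.shape)        # (4, 32) (64, 4)
```

The Adam states M and V have shape r × n instead of m × n, which is the optimizer-state saving described in c).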

Figure 13: GaLore VS LoRA comparison

GaLore made it possible to pre-train a 7B-parameter LLaMA model on a single consumer-level GPU, the 24 GB RTX 4090, without model parallelism, checkpointing or offloading (the paper's 24 GB result uses its 8-bit variant)[9]. This is a significant milestone in making the training and fine-tuning of billion-parameter models accessible to the ML/AI community.

Final Remarks

To summarize, this table gives an overview of the memory bottlenecks each of the methods above may help with.

Figure 14: Overall comparison of all the methods

With the right memory optimization methods, you can significantly reduce the memory cost of training your models and turn those ML project ideas of yours into reality on your own hardware!

References

  1. https://pytorch.org/blog/understanding-gpu-memory-1/
  2. https://pytorch.org/blog/computational-graphs-constructed-in-pytorch/
  3. https://towardsdatascience.com/the-math-behind-adam-optimizer-c41407efe59b#e382
  4. https://github.com/cybertronai/gradient-checkpointing
  5. https://arxiv.org/pdf/2205.14135.pdf
  6. https://arxiv.org/pdf/2110.02861.pdf
  7. https://blog.ml6.eu/low-rank-adaptation-a-technical-deep-dive-782dec995772
  8. https://arxiv.org/pdf/2106.09685.pdf
  9. https://arxiv.org/pdf/2403.03507.pdf
