Where did all the memory go?
LLM finetuning version.
The last time I thought about memory was when I started using Chrome. Recently, I was trying to fine-tune a small LLM (:P) with ~13 GB of model weights on an A-100 40 GB card, and it failed with the dreaded CUDA OOM error. Cutting the batch size to single digits and significantly shortening the sequence length didn’t help, which led me to wonder: where did all my memory go?
Context:
Models have grown roughly 1000x in parameter count over the last 5 years (BERT -> GPT4), but GPU memory has grown only 5x (V-100 16 GB -> A-100 80 GB). So how do I even fine-tune large models on common cloud GPUs?
Existing solutions such as data parallelism won’t help fit a model into device memory, and model parallelism is tricky to implement and not very efficient throughput-wise. Before diving into solutions, let’s figure out what’s happening with the memory.
What’s taking up memory?
- During training, the three main sources of memory consumption are optimizer states, gradients, and parameters. Besides these, activations and temporary buffers consume the rest of the memory, and memory fragmentation adds to the woes.
- Let’s assume we are training with the AdamW optimizer and keeping everything in fp32 (4 bytes per value).
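Model weights (parameters) = 4 bytes * num_params = N
Optimizer states (AdamW’s two moment estimates) = 2 * 4 bytes * num_params = 2N
Gradients = 4 bytes * num_params = N
The total memory for training is therefore 4N, i.e. 16 bytes per parameter, before counting activations!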
Let’s take a concrete example:
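- The Llama-7B checkpoint (~6.7B parameters) is around 12-13 GB in fp16. Applying the 4N rule to that size means we need ~48 GB+ of GPU memory per card to fine-tune Llama-7B, and since the rule assumes 4-byte fp32 weights, the full fp32 AdamW figure is closer to 16 bytes * 6.7B, i.e. ~107 GB. The typical A-100 card available on AWS has only 40 GB of memory.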
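A quick way to sanity-check these numbers is to compute the breakdown directly. This is a minimal sketch under the assumptions above (fp32 weights, fp32 gradients and both AdamW moments in fp32; activations, buffers and fragmentation excluded). The function name is ours, and 6.7e9 is Llama-7B's approximate parameter count.

```python
def adamw_fp32_memory_gb(num_params: float) -> dict:
    """Weights, AdamW optimizer states and gradients, all kept in fp32.

    Activations, temporary buffers and fragmentation are NOT included.
    """
    n = 4 * num_params  # N: 4 bytes per fp32 weight
    breakdown = {
        "weights": n,
        "optimizer": 2 * n,  # AdamW's two moments (m and v), 4 bytes each
        "gradients": n,
    }
    breakdown["total"] = sum(breakdown.values())  # 4N = 16 bytes per parameter
    return {name: size / 1e9 for name, size in breakdown.items()}


if __name__ == "__main__":
    for name, gb in adamw_fp32_memory_gb(6.7e9).items():
        print(f"{name:>10}: {gb:6.1f} GB")
```

For ~6.7B parameters this comes to about 26.8 GB of weights and ~107 GB in total, so a 40 GB card is out of reach before a single activation is stored.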
Activations consume GPU memory too. Their size depends on your batch size and sequence length (the easiest knobs to turn).
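To see how those two knobs scale, here is a rough estimate using the per-layer activation formula from Korthikanti et al. (2022), s * b * h * (34 + 5 * a * s / h) bytes, which was derived for GPT-style transformer layers with 16-bit activations and no recomputation or model parallelism. Plugging in a Llama-7B-like shape (32 layers, hidden size 4096, 32 heads) is an approximation, since Llama's MLP differs from the one the formula assumes, so treat the output as an order of magnitude rather than a measurement.

```python
def activation_memory_gb(batch: int, seq_len: int, hidden: int, heads: int, layers: int) -> float:
    """Rough activation memory for a stack of transformer layers.

    Uses s*b*h*(34 + 5*a*s/h) bytes per layer (Korthikanti et al., 2022):
    16-bit activations, no activation checkpointing, no tensor/sequence parallelism.
    """
    s, b, h, a = seq_len, batch, hidden, heads
    per_layer_bytes = s * b * h * (34 + 5 * a * s / h)
    return layers * per_layer_bytes / 1e9


if __name__ == "__main__":
    # Llama-7B-like shape (assumed): 32 layers, hidden size 4096, 32 attention heads
    for batch, seq_len in [(1, 512), (1, 2048), (4, 2048)]:
        gb = activation_memory_gb(batch, seq_len, hidden=4096, heads=32, layers=32)
        print(f"batch={batch} seq_len={seq_len}: ~{gb:.1f} GB")
```

The estimate grows linearly with batch size and faster than linearly with sequence length (the attention term is quadratic in s); at batch 1 and 2048 tokens it already lands around 30 GB.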
How do you finetune large models on commodity cloud GPUs?
- In Distributed Data Parallel (DDP) training, each worker owns a replica of the model and processes its own batch of data, then uses all-reduce to sum gradients across workers. While DDP has become very popular, it takes more GPU memory than it needs because the model weights, gradients and optimizer states are replicated across all DDP workers.
- If you cannot fit the model parameters (p), optimizer states (o), and gradients (g) on a single device, you cannot train with DDP.
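The DDP data flow can be sketched in a few lines of plain Python. This toy is not PyTorch's implementation: the function name is hypothetical, the all-reduce is simulated with a list sum, the update is plain SGD, and gradients are averaged after summing (as PyTorch DDP does) so every replica applies the same update.

```python
def ddp_step(params, local_grads, lr):
    """One simulated DDP step.

    params: the model weights, replicated identically on every worker.
    local_grads: one gradient list per worker, each computed on that worker's own batch.
    Returns the updated replica held by each worker.
    """
    world_size = len(local_grads)
    # all-reduce: every worker ends up with the element-wise sum of all gradients
    summed = [sum(values) for values in zip(*local_grads)]
    averaged = [g / world_size for g in summed]
    # every worker applies the identical update, so the replicas stay in sync;
    # in real DDP each worker also keeps full gradients and optimizer state
    return [[p - lr * g for p, g in zip(params, averaged)] for _ in range(world_size)]


if __name__ == "__main__":
    replicas = ddp_step([1.0, 2.0], [[0.2, 0.4], [0.6, 0.0]], lr=0.5)
    print(replicas)
```

Every worker ends up with an identical full copy, which is exactly the redundancy that makes DDP memory-hungry.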
Fully Sharded Data Parallel (FSDP) is a type of data parallelism that shards model parameters, optimizer states and gradients across DDP ranks. As a result, FSDP’s per-GPU memory footprint is smaller than DDP’s on every worker.
- A lower memory footprint makes training and fine-tuning huge models feasible on lower-configuration GPUs. Both PyTorch [1] (FSDP) and DeepSpeed [2] (ZeRO stage 3) offer fully sharded implementations with memory optimizations to tackle this.
- The same optimizations also help fit larger batch sizes for a training job, and increasing the batch size can improve training throughput significantly.
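The per-GPU difference can be made concrete with the AdamW breakdown above (4 + 8 + 4 = 16 bytes per parameter in fp32). This sketch assumes fully sharded training splits parameters, gradients and optimizer states evenly across ranks, and it ignores activations, communication buffers and the full parameters each rank temporarily gathers for the layer it is computing. The function name is hypothetical.

```python
def per_gpu_state_gb(num_params: float, world_size: int, sharded: bool,
                     bytes_per_param: int = 16) -> float:
    """Per-GPU memory for parameters + gradients + optimizer states.

    bytes_per_param=16 matches fp32 AdamW (4 weights + 4 grads + 8 optimizer).
    DDP replicates everything; fully sharded training divides it by world_size.
    """
    total_gb = num_params * bytes_per_param / 1e9
    return total_gb / world_size if sharded else total_gb


if __name__ == "__main__":
    for world_size in (1, 4, 8):
        ddp = per_gpu_state_gb(6.7e9, world_size, sharded=False)
        fsdp = per_gpu_state_gb(6.7e9, world_size, sharded=True)
        print(f"{world_size} GPUs: DDP {ddp:.1f} GB/GPU, FSDP {fsdp:.1f} GB/GPU")
```

With 8 GPUs the sharded states for ~6.7B parameters drop to about 13.4 GB per card, leaving room on a 40 GB A-100 for activations and the gathered layer, whereas DDP still needs ~107 GB on every card.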
How does FSDP work internally?
Even though the parameters are sharded across different GPUs, the computation for each microbatch of data is still local to each GPU worker. This conceptual simplicity makes FSDP easier to understand and more applicable to a wide range of usage scenarios.
Setup:
- Shard model parameters so that each rank keeps only its own shard. (Note: DDP keeps all parameters on every rank.)
In the forward pass:
- Run all_gather [communication] to collect all shards from all ranks to recover the full parameter in each rank. Run forward computation. [compute]
- Discard the shards just collected from other ranks, keeping only its own shard.
In the backward pass:
- Run all_gather [communication] to collect all shards from all ranks to recover the full parameter in each rank. Run backward computation. [compute]
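- Run reduce_scatter to reduce gradients across ranks, leaving each rank with only the gradient shard for its own parameters. [communication]
- Discard the shards just collected from other ranks, keeping only its own shard.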
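Here is a pure-Python toy of one such training step on two simulated ranks. It is not how PyTorch or DeepSpeed implement FSDP: the collectives are simulated with lists, the model is a single linear layer with a squared-error loss, and for brevity forward and backward share one all_gather (real FSDP frees the gathered parameters after the forward pass and gathers them again for backward). All function names are hypothetical.

```python
def shard(flat, world_size):
    """Split a flat parameter into equal contiguous shards (len must divide evenly)."""
    size = len(flat) // world_size
    return [flat[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards):
    """Every rank receives the concatenation of all shards, i.e. the full parameter."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def reduce_scatter(full_grads):
    """Sum the full gradients across ranks; rank r keeps only shard r of the sum."""
    summed = [sum(values) for values in zip(*full_grads)]
    return shard(summed, len(full_grads))


def forward_backward(w, x, target):
    """Linear model pred = w . x with loss 0.5 * (pred - target)**2; returns dloss/dw."""
    err = sum(wi * xi for wi, xi in zip(w, x)) - target
    return [err * xi for xi in x]


def fsdp_step(param_shards, batches, lr):
    world_size = len(param_shards)
    gathered = all_gather(param_shards)  # [communication] full params on every rank
    full_grads = []
    for rank in range(world_size):
        x, target = batches[rank]  # each rank computes on its own micro-batch
        full_grads.append(forward_backward(gathered[rank], x, target))  # [compute]
    del gathered  # discard the gathered params; each rank keeps only its shard
    grad_shards = reduce_scatter(full_grads)  # [communication]
    # each rank updates only its own shard (gradients averaged, as in DDP)
    return [[p - lr * g / world_size for p, g in zip(ps, gs)]
            for ps, gs in zip(param_shards, grad_shards)]


if __name__ == "__main__":
    shards = shard([1.0, 2.0, 3.0, 4.0], world_size=2)
    batches = [([1.0, 0.0, 1.0, 0.0], 1.0), ([0.0, 1.0, 0.0, 1.0], 2.0)]
    print(fsdp_step(shards, batches, lr=0.1))
```

Concatenating the returned shards gives the same parameters (up to floating-point rounding) that a DDP step with averaged gradients would produce, yet outside the gather no rank holds the full parameters or the full reduced gradient.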
Other memory optimizations:
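- Overlap communication & compute: FSDP needs to communicate more than DDP because parameters have to be all-gathered before each use and gradients reduce-scattered afterwards. Communication for upcoming layers (e.g. prefetching their parameters) is overlapped with the current layer’s compute to improve efficiency.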
- Offload to CPU: You can keep model parameters and optimizer states in host (CPU) memory instead of GPU memory. With CPU-side optimizers such as CPUAdam, you can fine-tune a model on GPU + CPU devices with competitive training efficiency and far lower cost.
- Activation checkpointing (or gradient checkpointing) reduces memory usage by discarding the activations of certain layers during the forward pass and recomputing them during the backward pass, trading extra compute for memory.
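For the CPU-offload option, DeepSpeed exposes it through the ZeRO section of its JSON config. The sketch below builds such a config as a Python dict. The keys follow DeepSpeed's documented config schema as we understand it, while the batch size, accumulation steps and learning rate are placeholder values, not tuned settings. With the optimizer offloaded, DeepSpeed runs the optimizer step on the CPU with its CPU Adam implementation.

```python
import json

# Placeholder values (micro-batch size, accumulation steps, lr) are assumptions;
# the structure follows DeepSpeed's ZeRO configuration keys.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},  # A-100s support bf16
    "optimizer": {"type": "AdamW", "params": {"lr": 2e-5, "weight_decay": 0.0}},
    "zero_optimization": {
        "stage": 3,  # shard parameters, gradients and optimizer states (FSDP-style)
        "offload_optimizer": {"device": "cpu", "pin_memory": True},  # optimizer states + step on CPU
        "offload_param": {"device": "cpu", "pin_memory": True},  # parameters kept in host RAM
    },
}

if __name__ == "__main__":
    print(json.dumps(ds_config, indent=2))
```

Pinned host memory speeds up the CPU-GPU copies; the trade-off is that training now also depends on host RAM size and PCIe bandwidth.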
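The recompute-for-memory trade of activation checkpointing can be shown on a toy chain of scalar tanh layers. This pure-Python sketch is not PyTorch's torch.utils.checkpoint, and the function names are hypothetical. It stores only the input to every `every`-th layer during the forward pass, recomputes each segment's activations during the backward pass, and should return the same gradient as storing every activation.

```python
import math


def layer(w, x):
    return math.tanh(w * x)


def layer_grad(w, y):
    # d/dx tanh(w * x) = w * (1 - tanh(w * x)**2) = w * (1 - y**2), y = layer output
    return w * (1.0 - y * y)


def grad_store_all(weights, x):
    """Backprop d(output)/d(input) keeping every layer's activation in memory."""
    acts = []
    for w in weights:
        x = layer(w, x)
        acts.append(x)
    g = 1.0
    for w, y in zip(reversed(weights), reversed(acts)):
        g *= layer_grad(w, y)
    return g, len(acts)  # gradient, number of stored activations


def grad_checkpointed(weights, x, every):
    """Same gradient, but only segment inputs (checkpoints) survive the forward pass."""
    checkpoints = []
    for i, w in enumerate(weights):
        if i % every == 0:
            checkpoints.append(x)  # input to the segment starting at layer i
        x = layer(w, x)  # intermediate activations are dropped right away
    g = 1.0
    for s in reversed(range(len(checkpoints))):
        seg = weights[s * every:(s + 1) * every]
        x, seg_acts = checkpoints[s], []
        for w in seg:  # recompute this segment's activations (the extra compute)
            x = layer(w, x)
            seg_acts.append(x)
        for w, y in zip(reversed(seg), reversed(seg_acts)):
            g *= layer_grad(w, y)
    return g, len(checkpoints)


if __name__ == "__main__":
    ws = [0.9, 1.1, 0.7, 1.3, 0.8, 1.2, 1.0, 0.6]
    print(grad_store_all(ws, 0.5))
    print(grad_checkpointed(ws, 0.5, every=4))
```

Here the checkpointed version keeps 2 checkpoints plus at most one recomputed 4-layer segment at a time instead of all 8 activations, at the cost of running each layer's forward computation twice.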
Learnings:
- Memory optimizations are necessary to fine-tune larger models. The Fully Sharded Data Parallel approach offers a way to reduce memory pressure on the device without considerably impacting throughput.
- With careful experimentation, the memory optimizations in FSDP can get you higher throughput, lower cost, or possibly both.