Distributed full fine-tuning of Llama2 on Kubernetes

Jiahao
Dec 19, 2023

Despite the popularity of PEFT for LLMs, full fine-tuning is still an important use case to consider from an MLOps perspective, as it is the most resource-intensive workload that can be performed with LLMs. For example, full fine-tuning of a relatively small 7.5b model with mixed-precision Adam requires ~120GB of memory (VRAM/RAM) for the parameters, the gradients and, most importantly, the optimizer states. That is 16 bytes per parameter in total (2 for fp16 parameters, 2 for fp16 gradients and 12 for the fp32 optimizer states), so the requirement grows linearly with the number of parameters. Below is a great reference diagram from the ZeRO paper, which shows this baseline memory consumption and the memory consumption per device when the model states are distributed over N_d GPUs in different ways.

From the paper "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models". Note that distributing all parameters, gradients and optimizer states across all GPUs is the most memory-efficient option.
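
The per-device numbers in the figure follow from simple arithmetic. Below is a minimal sketch of the ZeRO paper's model-state formulas, assuming mixed-precision Adam (2 bytes for fp16 parameters, 2 bytes for fp16 gradients and K = 12 bytes of optimizer state per parameter). The function name is illustrative, and activations and temporary buffers are deliberately not counted.

```python
def zero_model_state_gb(params_billion: float, n_devices: int, k: int = 12) -> dict:
    """Per-device model-state memory (GB) for each ZeRO stage, following the ZeRO paper.

    Assumes fp16 parameters (2 bytes), fp16 gradients (2 bytes) and K bytes of
    optimizer state per parameter (K = 12 for mixed-precision Adam: fp32 master
    weights, momentum and variance). Activations and buffers are not included.
    """
    psi = params_billion  # parameters in billions, so bytes come out directly in GB
    return {
        "baseline": (2 + 2 + k) * psi,                      # nothing sharded
        "P_os": 2 * psi + 2 * psi + k * psi / n_devices,    # optimizer states sharded
        "P_os+g": 2 * psi + (2 + k) * psi / n_devices,      # + gradients sharded
        "P_os+g+p": (2 + 2 + k) * psi / n_devices,          # + parameters sharded
    }


if __name__ == "__main__":
    # The figure's example: a 7.5b model on 64 GPUs.
    for stage, gb in zero_model_state_gb(7.5, 64).items():
        print(f"{stage:>9}: {gb:6.1f} GB")
```

For the figure's 7.5b model on 64 GPUs this prints 120.0, 31.4, 16.6 and 1.9 GB. The baseline row is the ~120GB quoted above, and sharding everything divides it by the number of devices.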

Fully fine-tuning a 7.5b model on a single GPU is therefore impractical: ~120GB exceeds the 80GB of even NVIDIA's latest H100. Thankfully, there are methods such as ZeRO and FSDP that distribute the memory usage across multiple GPUs (or even offload it to CPU RAM).

Personally, my aim is to support LLM training (full fine-tuning included) with our experiment manager on our existing Kubernetes cluster. This article documents the process of getting full fine-tuning to work on a Kubernetes cluster with limited resources, along with the lessons learned. The repo can be found here.

Setup

Kubernetes Cluster

The cluster contains multiple nodes, each with 4 V100 GPUs (32GB VRAM each). Note that these GPUs are rather outdated and do not support bfloat16. Training can be done with two kinds of setups: non-distributed (single node, single GPU) and distributed (single node with multiple GPUs, or multiple nodes with multiple GPUs).

Training Script

The training script is largely based on the llama-recipes repo by Meta, which contains scripts to perform full fine-tuning. For distributed setups, torch.distributed.fsdp is used to shard the model parameters, gradients and optimizer states across the GPUs. The model is sharded by wrapping each LlamaDecoderLayer (a building block of the Llama2 model that contains the attention and MLP sublayers) in its own FSDP unit. Each GPU permanently stores only its shard of every unit; during the forward and backward passes it gathers the full parameters of one FSDP unit (here, one LlamaDecoderLayer) at a time, performs the computation, frees the gathered parameters and proceeds to the next unit. This way, the memory consumption of training the Llama2 model is distributed over several GPUs.
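
To get a feel for why per-layer wrapping helps, here is a rough, hypothetical memory model (not how FSDP itself is implemented): every rank keeps a 1/N shard of each unit permanently, and only one unit is gathered in full at a time. The Llama2-7b shapes (32 layers, hidden size 4096, MLP size 11008, vocabulary 32000) come from its published config. Treating the embeddings, final norm and lm_head as one extra root unit, and ignoring prefetching, gradients, optimizer states and activations, are simplifying assumptions.

```python
import math


def llama_decoder_layer_params(hidden: int, intermediate: int) -> int:
    """Parameters in one LlamaDecoderLayer (Llama2-7b/13b style: no bias, no GQA)."""
    attention = 4 * hidden * hidden       # q_proj, k_proj, v_proj, o_proj
    mlp = 3 * hidden * intermediate       # gate_proj, up_proj, down_proj
    norms = 2 * hidden                    # input and post-attention RMSNorm weights
    return attention + mlp + norms


def fsdp_peak_param_bytes(unit_sizes: list[int], world_size: int, bytes_per_param: int) -> int:
    """Toy estimate of peak parameter memory on one rank under full sharding."""
    # Persistent: each rank stores roughly 1/world_size of every unit.
    sharded = sum(math.ceil(n / world_size) for n in unit_sizes) * bytes_per_param
    # Transient: the largest unit is materialized in full while it computes.
    gathered = max(unit_sizes) * bytes_per_param
    return sharded + gathered


if __name__ == "__main__":
    hidden, intermediate, layers, vocab = 4096, 11008, 32, 32000
    decoder_units = [llama_decoder_layer_params(hidden, intermediate)] * layers
    root_unit = 2 * vocab * hidden + hidden  # embed_tokens + lm_head + final norm (assumed root unit)
    units = decoder_units + [root_unit]

    total = sum(units)
    print(f"total params: {total:,}")
    print(f"unsharded fp16 params: {total * 2 / 1e9:.1f} GB")
    print(f"per-rank peak on 4 GPUs: {fsdp_peak_param_bytes(units, 4, 2) / 1e9:.1f} GB")
```

Under these assumptions the parameter count comes out at 6,738,415,616 (matching Llama2-7b's commonly reported 6.74B), and the fp16 parameter footprint per rank drops from about 13.5GB to about 3.9GB on 4 GPUs.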

Model and Dataset

The 7b and 13b variants of Llama2 are downloaded from Meta and converted to HuggingFace format using the suggested conversion script. The SAMSum dataset is used, as suggested by the llama-recipes repo, and saved in .jsonl format.

Training Image

The training image is built with the following Dockerfile:

FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y \
    python3.10 \
    python3-pip \
    vim \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir torch torchvision torchaudio torchviz --extra-index-url https://download.pytorch.org/whl/cu118
RUN pip install --no-cache-dir -r requirements.txt

COPY model-7b/ /app/model-7b/
COPY model-13b/ /app/model-13b/
COPY dataset/ /app/dataset/

COPY src/ /app/

CMD ["python3", "finetuning.py", "--split_slice", "1%", "--use_peft", "--quantization"]

The model weights and dataset files are copied into the image itself, resulting in a ~53GB image, which is not recommended! Alternatively, mount the weights from a K8s PersistentVolume (PV) or download them from a bucket. Note that the --split_slice argument is implemented to control (and reduce) the amount of training and validation data used, since the main focus is being able to train rather than the quality of training. By default, all of the SAMSum dataset is used.
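
The repo's exact implementation of --split_slice is not reproduced here. Below is a minimal, hypothetical sketch of the idea using only the standard library: read a .jsonl file and keep the first N rows, where N is given either as a count or as a percentage. The dialogue/summary field names follow the SAMSum dataset; the function names are made up for illustration.

```python
import json


def resolve_slice(spec: str, n_rows: int) -> int:
    """Turn '1%' or '500' into a row count (percentages rounded to the nearest row)."""
    if spec.endswith("%"):
        return round(n_rows * float(spec[:-1]) / 100)
    return min(int(spec), n_rows)


def load_jsonl_slice(path: str, spec: str = "100%") -> list[dict]:
    """Load a .jsonl file (one JSON object per line) and keep the first slice of rows."""
    with open(path, encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    return rows[: resolve_slice(spec, len(rows))]


if __name__ == "__main__":
    # SAMSum's training split has 14,732 dialogues, so 1% keeps 147 of them.
    print(resolve_slice("1%", 14732))  # 147
```

Each row would then be a dict such as {"dialogue": ..., "summary": ...}, ready for tokenization.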

The entry command above runs LoRA PEFT of the 7b model quantized to 4 bits, using only 1% of the SAMSum dataset (147 training samples). This run uses ~7GB of VRAM and fits on my own machine with a 3080 Ti (16GB VRAM). Note that the default batch size is 1 if unspecified, which differs from the llama-recipes default of 4; it was lowered because of GPU memory constraints.
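
A back-of-the-envelope check of why this fits in ~7GB: with 4-bit quantization the frozen base weights take about half a byte per parameter, and LoRA only trains small low-rank matrices. The sketch below assumes rank-8 adapters on q_proj and v_proj (a common LoRA choice; check the repo's PEFT config for the actual values) and ignores quantization constants, layers kept in higher precision, activations and the CUDA context.

```python
def lora_trainable_params(num_layers: int, adapted: list[tuple[int, int]], rank: int) -> int:
    """LoRA adds A (rank x in_features) and B (out_features x rank) per adapted Linear."""
    per_layer = sum(rank * (in_f + out_f) for in_f, out_f in adapted)
    return num_layers * per_layer


def quantized_weight_bytes(num_params: int, bits: int) -> float:
    """Storage for the frozen base weights at a given bit width (no quantization metadata)."""
    return num_params * bits / 8


if __name__ == "__main__":
    llama2_7b_params = 6_738_415_616
    hidden, layers = 4096, 32
    # Assumed targets: q_proj and v_proj, both 4096 -> 4096 in Llama2-7b.
    trainable = lora_trainable_params(layers, [(hidden, hidden), (hidden, hidden)], rank=8)
    print(f"trainable LoRA params: {trainable:,} ({100 * trainable / llama2_7b_params:.2f}%)")
    print(f"4-bit base weights: {quantized_weight_bytes(llama2_7b_params, 4) / 1e9:.2f} GB")
```

Under these assumptions only about 4.2M parameters (0.06%) are trained, and the quantized base weights take roughly 3.4GB; the rest of the observed ~7GB would likely go to activations and runtime overhead.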

Deployment

Single node single GPU

The training image can be deployed as a pod with an entry point similar to the Dockerfile's CMD (it simply runs the Python script):

python3 finetuning.py \
--use_peft \
--quantization

Single node multi GPU

The training image can be deployed as a pod with an entry point that uses torchrun to run the script:

torchrun \
--nnodes 1 \
--nproc_per_node 4 \
finetuning.py \
--enable_fsdp \
--use_fp16 \
--fsdp_config.optimizer SGD

The above entry point uses the 4 GPUs of a node to fully fine-tune the 7b model in mixed precision with SGD (note that an SGD optimizer option is not available in the llama-recipes code). The model is successfully fine-tuned with a peak GPU memory usage of ~25GB per device.

Multi node multi GPU

To perform multi-node training, PyTorch uses a distributed synchronization mechanism called rendezvous (rdzv), with two main backends: c10d and etcd-v2. The c10d backend was used to avoid introducing etcd as an additional dependency.
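
Once rendezvous completes, torchrun exports environment variables such as RANK, LOCAL_RANK and WORLD_SIZE to every worker process. The sketch below reads them (the single-process defaults are an assumption, for running without torchrun) and shows the rank layout for the 2 node x 2 GPU job used later, assuming every node runs the same number of workers. Which pod becomes node 0 is decided by the rendezvous, not by the command line.

```python
import os


def dist_env() -> dict[str, int]:
    """Read the variables torchrun sets for each worker (defaults = single process)."""
    return {
        "rank": int(os.environ.get("RANK", 0)),              # global rank across all nodes
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),  # GPU index on this node
        "world_size": int(os.environ.get("WORLD_SIZE", 1)),  # total worker processes
    }


def rank_layout(nnodes: int, nproc_per_node: int) -> list[list[int]]:
    """Global ranks per node, assuming contiguous ranks per node (homogeneous job)."""
    return [
        [node * nproc_per_node + local for local in range(nproc_per_node)]
        for node in range(nnodes)
    ]


if __name__ == "__main__":
    print(dist_env())
    print(rank_layout(nnodes=2, nproc_per_node=2))  # [[0, 1], [2, 3]]
```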

For the multi node multi GPU setup, one pod is deployed per node (refer to the yaml files here and here for a 2 node example). A headless K8s service is required per pod so that the pods' IP addresses can be resolved during synchronization. In addition, each pod's service DNS name must be passed explicitly to torchrun with the --local_addr flag. All training processes must be able to reach a common rdzv endpoint for synchronization, which in this case is hosted in the first pod, llm-test.

K8s cluster setup for multi node multi GPU training
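
As an illustration of the per-pod headless service, the sketch below builds one Service and one Pod per node as plain Python dicts and prints them as a JSON List, which kubectl apply also accepts. The namespace and pod names match the DNS names used in the commands below (<service>.<namespace>.svc.cluster.local); the image name, app label and node names are hypothetical, and volumes and other settings are left out.

```python
import json

NAMESPACE = "llm-test"  # matches llm-test.llm-test.svc.cluster.local


def headless_service(name: str) -> dict:
    """Headless Service whose DNS name resolves straight to the selected pod's IP."""
    return {
        "apiVersion": "v1",
        "kind": "Service",
        "metadata": {"name": name, "namespace": NAMESPACE},
        "spec": {
            "clusterIP": "None",          # headless: no virtual IP, DNS returns pod IPs
            "selector": {"app": name},    # hypothetical label, must match the pod below
            "ports": [{"name": "rdzv", "port": 29500}],
        },
    }


def training_pod(name: str, node: str, gpus: int, command: list[str]) -> dict:
    """One training pod pinned to one node, requesting that node's GPUs."""
    return {
        "apiVersion": "v1",
        "kind": "Pod",
        "metadata": {"name": name, "namespace": NAMESPACE, "labels": {"app": name}},
        "spec": {
            "restartPolicy": "Never",
            "nodeSelector": {"kubernetes.io/hostname": node},  # one pod per node
            "containers": [{
                "name": "trainer",
                "image": "llm-finetune:latest",  # hypothetical image tag
                "command": command,
                "resources": {"limits": {"nvidia.com/gpu": gpus}},
            }],
        },
    }


if __name__ == "__main__":
    items = []
    for pod, node in [("llm-test", "node0"), ("llm-test-2", "node1")]:
        cmd = ["torchrun", "--nnodes", "2", "--nproc_per_node", "2"]  # placeholder: use the full commands
        items += [headless_service(pod), training_pod(pod, node, 2, cmd)]
    print(json.dumps({"apiVersion": "v1", "kind": "List", "items": items}, indent=2))
```

The output could then be applied with something like python3 manifests.py | kubectl apply -f - (the script name is hypothetical).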

An example entry point for the first pod llm-test in node0 is:

torchrun \
--nnodes 2 \
--nproc_per_node 2 \
--rdzv-id=123 \
--rdzv-endpoint=llm-test.llm-test.svc.cluster.local:29500 \
--rdzv-backend=c10d \
--local_addr llm-test.llm-test.svc.cluster.local \
finetuning.py \
--enable_fsdp \
--use_fp16 \
--fsdp_config.optimizer SGD \
--split_slice 1%

An example entry point for the second pod llm-test-2 in node1 is:

torchrun \
--nnodes 2 \
--nproc_per_node 2 \
--rdzv-id=123 \
--rdzv-endpoint=llm-test.llm-test.svc.cluster.local:29500 \
--rdzv-backend=c10d \
--local_addr llm-test-2.llm-test.svc.cluster.local \
finetuning.py \
--enable_fsdp \
--use_fp16 \
--fsdp_config.optimizer SGD \
--split_slice 1%
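
The two entry points differ only in --local_addr. A small, hypothetical helper that builds the argument list for each pod, using the same flags as the commands above, makes that explicit; the function name and defaults are illustrative.

```python
import shlex

RDZV_ENDPOINT = "llm-test.llm-test.svc.cluster.local:29500"  # hosted by the first pod


def torchrun_args(pod_name: str, nnodes: int = 2, nproc_per_node: int = 2) -> list[str]:
    """Entry point for one pod; every pod shares the rdzv id, endpoint and backend."""
    return [
        "torchrun",
        "--nnodes", str(nnodes),
        "--nproc_per_node", str(nproc_per_node),
        "--rdzv-id=123",
        f"--rdzv-endpoint={RDZV_ENDPOINT}",
        "--rdzv-backend=c10d",
        # Each pod advertises its own headless-service DNS name.
        "--local_addr", f"{pod_name}.llm-test.svc.cluster.local",
        "finetuning.py",
        "--enable_fsdp",
        "--use_fp16",
        "--fsdp_config.optimizer", "SGD",
        "--split_slice", "1%",
    ]


if __name__ == "__main__":
    for pod in ["llm-test", "llm-test-2"]:
        print(shlex.join(torchrun_args(pod)))
```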

The example entry points use 4 GPUs (2 GPUs per node) to fully fine-tune the 7b model in mixed precision with SGD. The model is successfully fine-tuned with a peak GPU memory usage of ~25GB per device, the same as in the single node multi GPU setup. However, with the same amount of data (split_slice of 1%), multi node training of the 7b model is ~55x slower than single node training (an average epoch time of 1487s with 2 nodes and 4 GPUs vs 27s with 1 node and 4 GPUs), due to the cost of synchronization over the network. Multi node training could likely be improved with better connectivity between nodes (such as InfiniBand) or an optimized deployment, for example hosting the rdzv endpoint on a node instead of in a pod (although in my case the nodes themselves do not run any training processes), or using the TorchX Kubernetes scheduler.
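
A rough estimate shows why the network dominates. With full sharding, each training step all-gathers the parameters in the forward pass, all-gathers them again in the backward pass and reduce-scatters the gradients, so with ring-style collectives each rank moves about 3 x (N-1)/N x the model size per step. The sketch assumes fp16 communication and treats the inter-node link as the bottleneck; the bandwidth figures are hypothetical, not measurements from this cluster.

```python
def fsdp_bytes_per_rank_per_step(num_params: int, world_size: int, bytes_per_elem: int = 2) -> float:
    """Approximate bytes each rank sends per step under full sharding (ring collectives).

    Two all-gathers of the parameters (forward and backward) plus one reduce-scatter
    of the gradients; each moves (N - 1) / N of the full tensor per rank.
    """
    model_bytes = num_params * bytes_per_elem
    return 3 * (world_size - 1) / world_size * model_bytes


if __name__ == "__main__":
    per_step = fsdp_bytes_per_rank_per_step(6_738_415_616, world_size=4)
    print(f"~{per_step / 1e9:.1f} GB per rank per step")
    # Hypothetical link speeds, converted from bits to bytes per second.
    for label, bandwidth in [("10 Gbit/s link", 10e9 / 8), ("100 Gbit/s link", 100e9 / 8)]:
        print(f"{label}: ~{per_step / bandwidth:.1f} s of communication per step")
```

For the 7b model on 4 GPUs this is roughly 30GB per rank per step, i.e. tens of seconds per step over a hypothetical 10 Gbit/s link, whereas inside a single node the same traffic stays on the much faster intra-node interconnect.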

Memory Usage

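For the successful runs, both training and validation loss and perplexity decrease every epoch. Also note that the 7b model could not be trained with Adam on 4 × 32GB GPUs (128GB in total), even though the estimate at the beginning is only ~120GB, and that for a slightly larger 7.5b model. That estimate only covers model states (parameters, gradients and optimizer states) and excludes activations, temporary buffers and the CUDA context; the gap might also be due to the torch implementation of Adam trading off memory for speed. For other details, refer to the repo's README.md.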
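
A rough per-GPU estimate helps explain the Adam result. It assumes the sharded parameters and gradients are kept in fp32 (as with FSDP mixed precision, where only the gathered copies used for compute are fp16), that Adam adds two fp32 state tensors per parameter (exp_avg and exp_avg_sq), and that plain SGD without momentum keeps none. The repo's actual SGD settings are not shown here, so the momentum variant is included for comparison. Activations, gathered fp16 units, communication buffers and the CUDA context all come on top of these numbers.

```python
BYTES_FP32 = 4

# Assumed fp32 optimizer state tensors per parameter, following PyTorch's Adam and SGD.
OPTIMIZER_STATES = {
    "Adam": 2,           # exp_avg and exp_avg_sq
    "SGD": 0,            # no momentum: no state
    "SGD+momentum": 1,   # momentum_buffer
}


def model_state_gb_per_gpu(num_params: int, n_gpus: int, optimizer: str) -> float:
    """fp32 params + fp32 grads + optimizer states, fully sharded over n_gpus."""
    tensors = 2 + OPTIMIZER_STATES[optimizer]
    return num_params * tensors * BYTES_FP32 / n_gpus / 1e9


if __name__ == "__main__":
    for opt in OPTIMIZER_STATES:
        gb = model_state_gb_per_gpu(6_738_415_616, 4, opt)
        print(f"{opt:>12}: {gb:5.1f} GB of 32 GB per V100")
```

Under these assumptions Adam needs about 27GB of model states per GPU before any activations, leaving little headroom on a 32GB V100, whereas plain SGD needs about 13.5GB.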

Hence, the minimum resources found to fully fine-tune a Llama2 model on this cluster are:

  • 7b in full precision: 4 V100 GPUs (preferably in a single node)
  • 7b in mixed precision: 4 V100 GPUs (preferably in a single node)
  • 13b in full precision: 8 V100 GPUs (across 2 nodes)
  • 13b in mixed precision: 6 V100 GPUs (across 2 nodes)

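Memory requirements could possibly be reduced further by using FlashAttention-2 during training, with the aim of training the 13b model on a single node. Note, however, that FlashAttention-2 targets Ampere and newer GPUs (such as the A100 and H100), so it would not run as-is on the V100s used here.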
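
Part of what FlashAttention saves is the attention score matrix, which standard attention materializes with shape (batch, heads, seq_len, seq_len) in every layer, whereas FlashAttention computes attention in tiles without storing that matrix in full. The sketch below estimates its size for one layer; the batch size and sequence lengths are hypothetical inputs, and 32 heads corresponds to Llama2-7b.

```python
def attention_scores_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Size of one layer's (batch, heads, seq_len, seq_len) attention score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    for seq_len in (512, 2048, 4096):
        mib = attention_scores_bytes(batch=1, heads=32, seq_len=seq_len) / 2**20
        print(f"seq_len={seq_len:>4}: {mib:7.0f} MiB per layer (fp16)")
```

This prints 16, 256 and 1024 MiB per layer. Because the tensor grows quadratically with sequence length and is typically kept for the backward pass in every layer, avoiding it matters most for long sequences.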
