Fine-tuning Llama 2 on Trainium instances

Randall DeFauw
Sep 19, 2023


Llama 2 is a powerful and popular large language model (LLM) published by Meta. There's a lot of interest in fine-tuning Llama 2 with custom data and instructions. Meta has provided a fine-tuning recipe, and AWS has published an example of how to fine-tune Llama 2 with SageMaker JumpStart.

AWS has also published an example of how to pretrain Llama 2 on Trainium instances. That example pretrains Llama 2 on a large corpus of unlabeled data, in this case one of the RedPajama datasets. I wanted to see how I could adapt it into a fine-tuning example instead, where a smaller, task-specific dataset is used to adapt the model to a single task. I chose the samsum dataset, a dialogue summarization dataset.

This dataset is labeled: each example contains the original dialogue and a human-written summary. To fine-tune on it, I first converted each example into a single prompt, using the same template as the Llama 2 fine-tuning recipe.

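```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Assumed inputs: the (gated) Llama 2 tokenizer and the samsum training split.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
dataset = load_dataset("samsum", split="train")
# Template modeled on the Llama 2 fine-tuning recipe.
prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n{summary}{eos_token}"

def apply_prompt_template(sample):
    return {
        "text": prompt.format(
            dialog=sample["dialogue"],
            summary=sample["summary"],
            eos_token=tokenizer.eos_token,
        )
    }

# Keep only the new "text" column, then write one JSON object per line.
dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
dataset.to_json("samsum-train.jsonl")
```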

Then I ran the preprocessing script from the pretraining example, precompiled the model with the Neuron compiler, and launched the fine-tuning job.
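I haven't reproduced the preprocessing script here, but causal-LM pretraining pipelines commonly tokenize every document, concatenate the token IDs, and slice them into fixed-length blocks. The sketch below shows that packing step in plain Python, assuming the script follows this common pattern; `pack_into_blocks` and `BLOCK_SIZE` are hypothetical names, and 4,096 is chosen only because it matches Llama 2's context length.

```python
from itertools import chain

# Hypothetical block size; Llama 2 uses a 4,096-token context window.
BLOCK_SIZE = 4096


def pack_into_blocks(tokenized_examples, block_size=BLOCK_SIZE):
    """Concatenate per-example token-ID lists and cut them into equal blocks.

    Trailing tokens that do not fill a whole block are dropped.
    """
    flat = list(chain.from_iterable(tokenized_examples))
    usable = (len(flat) // block_size) * block_size
    return [flat[i : i + block_size] for i in range(0, usable, block_size)]


# Toy usage with tiny "token IDs" and a block size of 4.
print(pack_into_blocks([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4))
```

One consequence of packing is that a single block can hold the end of one dialogue and the start of the next; the EOS token that the prompt template appends to every example marks where each one ends.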

One question might come to mind: why run a fine-tuning job on a Trainium instance at all? The main reason is cost. At the time of writing, a trn1.32xlarge has 512 GB of total accelerator memory and costs $21.50 per hour on-demand, while a p4d.24xlarge, with 320 GB of total GPU memory (eight 40 GB A100 GPUs), costs $32.77 per hour on-demand. Accelerators are not interchangeable, so a fair comparison would measure how long the same fine-tuning job takes on each instance, but given the size of these models, being frugal with resources makes sense.
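One way to frame that runtime comparison is a break-even ratio: dividing the p4d hourly price by the trn1 hourly price gives how much longer a job can run on Trainium before it stops being cheaper. The prices below are the on-demand rates quoted above; the 10-hour and 14-hour runtimes are hypothetical placeholders, not measurements.

```python
# On-demand hourly prices quoted in the text (September 2023).
TRN1_32XLARGE_PER_HOUR = 21.50
P4D_24XLARGE_PER_HOUR = 32.77


def job_cost(hours, hourly_rate):
    """Total on-demand cost of a job that runs for `hours`."""
    return hours * hourly_rate


# Trainium stays cheaper as long as its runtime is below this multiple
# of the p4d runtime for the same job.
break_even_ratio = P4D_24XLARGE_PER_HOUR / TRN1_32XLARGE_PER_HOUR
print(f"break-even runtime ratio: {break_even_ratio:.2f}x")

# Hypothetical runtimes: a 10-hour p4d job versus a 14-hour trn1 job.
p4d_cost = job_cost(10, P4D_24XLARGE_PER_HOUR)
trn1_cost = job_cost(14, TRN1_32XLARGE_PER_HOUR)
print(f"p4d: ${p4d_cost:.2f}  trn1: ${trn1_cost:.2f}")
```

With these prices the ratio is about 1.52, so a trn1.32xlarge job that takes up to roughly 50% longer than the same job on a p4d.24xlarge would still cost less.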

Finally, you may wonder about using parameter-efficient fine-tuning (PEFT) techniques like LoRA to reduce training time and cost. The SageMaker JumpStart example for fine-tuning Llama 2 uses these techniques, but my example does not. That's what I want to experiment with next.
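The savings from LoRA come from training far fewer parameters: it freezes a weight matrix W and learns a low-rank update B·A instead. As a rough illustration (not something measured in this experiment), the sketch below counts trainable parameters for a single 4096 × 4096 projection, 4096 being the hidden size of Llama 2 7B, with a hypothetical LoRA rank of 8. The function names are illustrative only.

```python
def full_param_count(d_in, d_out):
    """Trainable parameters when the whole weight matrix W (d_out x d_in) is updated."""
    return d_in * d_out


def lora_param_count(d_in, d_out, rank):
    """Trainable parameters when W is frozen and only the low-rank factors
    B (d_out x rank) and A (rank x d_in) are trained, so the update is B @ A."""
    return rank * (d_in + d_out)


# Illustrative values: one 4096 x 4096 projection with a hypothetical rank of 8.
d, r = 4096, 8
full = full_param_count(d, d)
lora = lora_param_count(d, d, r)
print(f"full: {full:,}  LoRA: {lora:,}  fraction: {lora / full:.2%}")
```

Here LoRA trains 65,536 parameters instead of 16,777,216, about 0.39% of the matrix. Because optimizer state scales with the number of trainable parameters, this is where much of the memory and cost reduction comes from.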


Randall DeFauw

I am a Sr. Principal Solutions Architect at AWS. The opinions I express here are my own and do not reflect the views of my employer.