Optimizing FLAN T5: A Practical Guide to PEFT with LoRA & Soft Prompts

Jack Harding
Published in Nerd For Tech
May 24, 2024

Fine-tuning large language models is crucial for adapting them to specific tasks and improving their performance on those tasks. However, the traditional fine-tuning process is resource-intensive and demands significant computational power. This challenge has led to more efficient fine-tuning methods, collectively known as Parameter-Efficient Fine-Tuning (PEFT).

PEFT techniques, such as Low-Rank Adaptation (LoRA) and soft prompts, offer a promising solution to reduce the computational burden and cost associated with fine-tuning large models. These methods optimize only a subset of the model parameters or introduce additional trainable components, significantly decreasing the required resources while maintaining or even improving model performance.

In this article, I explore the theory behind PEFT, explaining how these methods work and what they offer over traditional fine-tuning. I also show how to implement them using the Hugging Face Transformers and PEFT libraries.

By the end of this article, you will have a clear understanding of the advantages of PEFT methods, the process of fine-tuning FLAN T5 using these techniques, and an evaluation of their performance. This knowledge will equip you to apply PEFT in your own NLP projects, potentially saving significant computational resources and costs.

Theoretical Background

Fine-tuning takes a large pretrained language model, whose original training was resource-intensive, and specialises it to a particular dataset instead of training a new model from scratch. PEFT is a more refined approach: the original model weights are kept frozen, and a small number of new trainable parameters (for example, low-rank adapter matrices or prompt embeddings) are added and trained on the new data.

LoRA

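LoRA keeps the pretrained weight matrix frozen and learns its update as the product of two much smaller matrices, a low-rank decomposition. The rank r sets the inner dimension of those two matrices: a d × k weight update is represented by a d × r matrix times an r × k matrix, so only r(d + k) values are trained instead of d·k. A larger rank can capture more, but in practice the gains tend to level off at around 8 or 16, beyond which increasing it yields diminishing returns. Training so few parameters makes training far faster and cheaper. To push this further, quantization can be used to reduce the memory each parameter occupies.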

Decomposition matrix to reduce the output
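To make the decomposition concrete, here is a small NumPy sketch of a LoRA-adapted linear layer. It is an illustration, not the PEFT library's implementation: the class name `LoRALinear`, the 768 × 768 shape (FLAN-T5-base's hidden size) and r = 8, alpha = 16 are assumptions chosen for the example. Following the original LoRA formulation, the update is scaled by alpha / r and B starts at zero, so training begins from the unmodified pretrained layer.

```python
import numpy as np


class LoRALinear:
    """Frozen weight W0 plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, d_out: int, d_in: int, r: int, alpha: float, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.W0 = rng.standard_normal((d_out, d_in)) * 0.02  # frozen "pretrained" weight
        self.A = rng.standard_normal((r, d_in)) * 0.01       # trainable, small random init
        self.B = np.zeros((d_out, r))                         # trainable, zero init
        self.scale = alpha / r

    def forward(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, d_in). The low-rank path never builds a full d_out x d_in update.
        return x @ self.W0.T + self.scale * (x @ self.A.T) @ self.B.T

    def trainable_params(self) -> int:
        return self.A.size + self.B.size

    def frozen_params(self) -> int:
        return self.W0.size


layer = LoRALinear(d_out=768, d_in=768, r=8, alpha=16)
x = np.ones((2, 768))
print(layer.frozen_params())     # 589824
print(layer.trainable_params())  # 12288
# With B initialised to zero, the adapted layer starts out identical to the frozen one.
print(np.allclose(layer.forward(x), x @ layer.W0.T))  # True
```

Only `A` and `B` (about 2% of this layer's parameters at r = 8) would receive gradients; `W0` stays fixed.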

Soft Prompts

Soft prompts are not to be confused with prompt engineering, where a human writes the prompt text at inference time (a “hard” prompt). A soft prompt is a set of trainable embedding vectors, sometimes called virtual tokens, that live in the model's embedding space and are prepended to the input embeddings; their values are learned during fine-tuning. The base model's weights stay frozen and only these prompt embeddings are trained. This means a separate soft prompt can be trained for each task and swapped in very quickly. From a software engineering perspective, this also decouples different applications, making future development easier to manage.

Fine-tuning vs Soft prompt
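The mechanism can be sketched in a few lines of NumPy: a frozen embedding table looks up the real tokens, and a small trainable matrix of virtual-token embeddings is prepended to them. The vocabulary size, embedding width, number of virtual tokens and task names below are illustrative assumptions; in the Hugging Face PEFT library the corresponding configuration class is `PromptTuningConfig`.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model, n_virtual = 1000, 64, 20  # illustrative sizes

# Frozen base-model embedding table (never updated during prompt tuning).
frozen_embeddings = rng.standard_normal((vocab_size, d_model))

# One trainable soft prompt per task; these are the only weights that get trained.
soft_prompts = {
    "summarise": rng.standard_normal((n_virtual, d_model)) * 0.01,
    "classify": rng.standard_normal((n_virtual, d_model)) * 0.01,
}


def embed_with_soft_prompt(token_ids: list[int], task: str) -> np.ndarray:
    """Prepend the task's virtual-token embeddings to the frozen token embeddings."""
    token_embeds = frozen_embeddings[token_ids]  # (seq_len, d_model)
    return np.concatenate([soft_prompts[task], token_embeds], axis=0)


inputs = embed_with_soft_prompt([5, 42, 7], task="summarise")
print(inputs.shape)  # (23, 64)
```

Switching tasks only changes which matrix is prepended (for example `task="classify"`); the frozen model and its embedding table are shared by every task.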

Fine-Tuning the FLAN T5 Model

FLAN-T5

FLAN-T5 is a variant of the T5 (Text-To-Text Transfer Transformer) model, designed to enhance the capabilities of the original T5 through instruction fine-tuning on a much broader range of tasks and datasets. Developed by Google Research, FLAN-T5 is fine-tuned on a large collection of tasks phrased as natural-language instructions, such as translation, summarization, and question answering, including some chain-of-thought examples. This training regimen aims to improve the model's generalization and adaptability, making it more robust and versatile across natural language processing tasks.

Imports

After installing a long list of pip dependencies (listed on my GitHub), the dataset and model can be loaded:

import torch
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

dataset = load_dataset("knkarthick/dialogsum")

model_name = 'google/flan-t5-base'
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
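The training snippets later in the article pass a `tokenized_datasets` object to the `Trainer`, but its construction is not shown. Below is a minimal sketch of one way to build it; the prompt template, the `max_length` of 512 and the helper names are assumptions, not the author's code. DialogSum exposes `dialogue` and `summary` columns, and the tokenizer is passed in as an argument so the functions themselves need no model download.

```python
def build_summary_prompt(dialogue: str) -> str:
    """Wrap a dialogue in an instruction-style prompt (the template is an assumption)."""
    return f"Summarize the following conversation.\n\n{dialogue}\n\nSummary: "


def tokenize_batch(batch: dict, tokenizer, max_length: int = 512) -> dict:
    """Turn a batch of DialogSum rows into input_ids / labels for a seq2seq model."""
    prompts = [build_summary_prompt(d) for d in batch["dialogue"]]
    model_inputs = tokenizer(prompts, max_length=max_length,
                             padding="max_length", truncation=True)
    labels = tokenizer(text_target=batch["summary"], max_length=max_length,
                       padding="max_length", truncation=True)
    # -100 marks padding positions so the loss ignores them.
    model_inputs["labels"] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs


# Usage with the dataset and tokenizer loaded above:
# tokenized_datasets = dataset.map(
#     lambda batch: tokenize_batch(batch, tokenizer),
#     batched=True,
#     remove_columns=dataset["train"].column_names,
# )
```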

To establish a baseline, I tested zero-shot inference with the original model, which gave pretty underwhelming responses.

Fine-Tuning

The goal is for the model to describe what happened in a customer support chat. I tried fully fine-tuning the model on the dataset above but stopped after 20 minutes. Luckily, I had a fully fine-tuned FLAN-T5 model to hand from a Coursera course I recently completed. Here is a snippet of what full fine-tuning would look like.

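import time
from transformers import Trainer, TrainingArguments

# Assumes `original_model` from above and `tokenized_datasets`, i.e. the DialogSum
# splits tokenized into input_ids and labels in a separate preprocessing step.
output_dir = f'./dialogue-summary-training-{int(time.time())}'  # hypothetical path

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=1,
    max_steps=1  # Overrides num_train_epochs: a one-step demo, remove for a real run.
)

trainer = Trainer(
    model=original_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation']
)

# trainer.train() would start full fine-tuning; it updates original_model in place.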

PEFT

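As I mentioned before, the underlying FLAN-T5 weights are frozen. With LoRA, small trainable low-rank matrices are injected alongside the attention query (q) and value (v) projections, rather than a new layer being appended on top.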

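from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    r=32,  # Rank of the low-rank update matrices
    lora_alpha=32,  # Scaling factor: the update is scaled by lora_alpha / r
    target_modules=["q", "v"],  # T5 attention query and value projections
    lora_dropout=0.05,
    bias="none",  # Do not train any bias terms
    task_type=TaskType.SEQ_2_SEQ_LM  # FLAN-T5 is an encoder-decoder model
)
peft_model = get_peft_model(original_model, lora_config)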
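As a sanity check on how small the adapter is, the number of trainable parameters this config adds can be worked out by hand. The sketch assumes FLAN-T5-base's published configuration (hidden size 768, query and value projections of 768 × 768, 12 encoder blocks with self-attention, 12 decoder blocks with both self-attention and cross-attention); calling `peft_model.print_trainable_parameters()` reports the actual figure.

```python
# Assumed FLAN-T5-base dimensions (from its model config, not from the article).
D_MODEL = 768        # input size of each q / v projection
INNER_DIM = 768      # output size: 12 heads x 64 dims per head
ENCODER_BLOCKS = 12  # each has one self-attention layer
DECODER_BLOCKS = 12  # each has a self-attention and a cross-attention layer


def lora_params_per_module(r: int, d_in: int, d_out: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) to each targeted projection."""
    return r * d_in + d_out * r


def lora_trainable_params(r: int, target_modules=("q", "v")) -> int:
    attention_layers = ENCODER_BLOCKS * 1 + DECODER_BLOCKS * 2
    n_modules = attention_layers * len(target_modules)
    return n_modules * lora_params_per_module(r, D_MODEL, INNER_DIM)


print(lora_trainable_params(r=32))  # 3538944
```

That is roughly 3.5M trainable parameters against about 250M in the base model, or roughly 1.4%.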

Training the PEFT layers was significantly faster, taking around 6 minutes (TPU + m5.xl in AWS).

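peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,  # Higher learning rate than full fine-tuning.
    num_train_epochs=1,
    logging_steps=1,
    max_steps=1  # Overrides num_train_epochs: a one-step demo, remove for a real run.
)

peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_datasets["train"],
)
peft_trainer.train()

# Save only the LoRA adapter weights (a few MB) so they can be reloaded for inference.
peft_trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)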

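The next step is loading the trained LoRA adapter onto a base FLAN-T5 model for inference. Note the is_trainable=False: it loads the adapter in inference mode, so neither the frozen base weights nor the adapter weights are updated. The low-rank updates sit alongside the attention projections they were trained on, rather than being glued onto the end of the model.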

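from peft import PeftModel

# get_peft_model() injected LoRA layers into original_model in place, so attach
# the saved adapter to a freshly loaded base model instead.
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

peft_model = PeftModel.from_pretrained(
    peft_model_base,
    output_dir,
    torch_dtype=torch.bfloat16,
    is_trainable=False
)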

Evaluation

ROUGE-L (Recall-Oriented Understudy for Gisting Evaluation, Longest Common Subsequence variant) evaluates a generated summary by comparing it to a reference summary using their longest common subsequence (LCS): the longest sequence of words that appears in both texts in the same order, though not necessarily contiguously. Because it only rewards shared words that keep their order, it captures sentence-level structure as well as vocabulary. ROUGE-L helps assess how closely a model's output matches human-written summaries, which makes it a widely used, if imperfect, metric for evaluating summarization models.

Nice explainer from James Briggs
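The LCS idea behind ROUGE-L fits in a short function. This is a simplified sketch: it lowercases, splits on whitespace and reports the F1 of LCS-based precision and recall, whereas the `rouge` metric from Hugging Face `evaluate` applies its own tokenization and, with `use_stemmer=True`, Porter stemming, so its scores will not match this exactly.

```python
def lcs_length(a: list[str], b: list[str]) -> int:
    """Classic dynamic-programming longest common subsequence length."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, tok_a in enumerate(a, start=1):
        for j, tok_b in enumerate(b, start=1):
            if tok_a == tok_b:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return dp[len(a)][len(b)]


def rouge_l_f1(candidate: str, reference: str) -> float:
    """F1 of LCS precision (over the candidate) and recall (over the reference)."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        return 0.0
    precision = lcs / len(cand)
    recall = lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


score = rouge_l_f1("the cat lay on the mat", "The cat sat on the mat")
print(round(score, 4))  # 0.8333
```

Here the LCS is "the cat on the mat" (5 of 6 words on each side), so precision, recall and F1 are all 5/6.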

Each model's summaries are compared to the human-written reference summaries included in the original dataset. For simplicity, I only report the ROUGE-L (LCS) scores.

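import evaluate

rouge = evaluate.load('rouge')

# Assumes lists of generated summaries for the same dialogues:
# original_model_summaries (zero-shot), instruct_model_summaries (fully fine-tuned)
# and peft_model_summaries (LoRA), plus human_baseline_summaries (the references).
def compute_rouge(model_summaries):
    return rouge.compute(
        predictions=model_summaries,
        references=human_baseline_summaries[0:len(model_summaries)],
        use_aggregator=True,
        use_stemmer=True,
    )

original_model_results = compute_rouge(original_model_summaries)
instruct_model_results = compute_rouge(instruct_model_summaries)
peft_model_results = compute_rouge(peft_model_summaries)

print(original_model_results['rougeL'], instruct_model_results['rougeL'], peft_model_results['rougeL'])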

The original model achieves a ROUGE-L score of only 20.1%, the fully fine-tuned one 33.8%, and the PEFT-LoRA model 32.5%. In other words, the fully fine-tuned model scored only about 1.3 percentage points more than the PEFT model, despite taking over 30 minutes to train.
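The gap is easier to judge once it is expressed relative to the zero-shot baseline. Using the three ROUGE-L scores reported above:

```python
# ROUGE-L scores (%) reported in the article.
original, full_ft, peft_lora = 20.1, 33.8, 32.5

gap = full_ft - peft_lora                                      # absolute gap, in points
relative_gap = gap / full_ft                                   # gap relative to full fine-tuning
gain_retained = (peft_lora - original) / (full_ft - original)  # share of the improvement kept

print(f"{gap:.1f} points")                 # 1.3 points
print(f"{relative_gap:.1%} lower")         # 3.8% lower
print(f"{gain_retained:.1%} of the gain")  # 90.5% of the gain
```

On these numbers, LoRA recovers roughly 90% of the improvement that full fine-tuning delivers over the zero-shot model.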

For cost reasons, I could not rigorously compare the training times of the two approaches. A “bang for your buck” metric for training would have been very valuable.

Conclusion

The fully fine-tuned model performs marginally better than the PEFT one but uses far more resources. For most applications, choosing the cheaper, almost-as-good option seems like a no-brainer. That said, the results above reflect only a small subset of what LLMs might be used for. One potential drawback of PEFT is inference latency: an unmerged LoRA adapter adds extra computation to each forward pass, and a soft prompt lengthens every input. With LoRA, the adapter can be merged into the base weights after training to remove that overhead, but in high-performance applications response time may still be a key factor.
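Because the LoRA update is linear, the adapter can be folded into the frozen weight once training is done; in the PEFT library this is what `merge_and_unload()` does on a LoRA `PeftModel`. The NumPy sketch below, with illustrative random weights and r and alpha mirroring the `LoraConfig` above, checks that the merged weight gives the same output as the separate adapter path while needing only one matrix multiplication.

```python
import numpy as np

rng = np.random.default_rng(1)
d, r, alpha = 768, 32, 32  # illustrative; r and alpha mirror the LoraConfig above
scale = alpha / r

W0 = rng.standard_normal((d, d))  # frozen base weight
A = rng.standard_normal((r, d))   # trained LoRA factor (d_in -> r)
B = rng.standard_normal((d, r))   # trained LoRA factor (r -> d_out)
x = rng.standard_normal((4, d))   # a small batch of activations

# Unmerged: two extra matrix multiplications on every forward pass.
y_adapter = x @ W0.T + scale * (x @ A.T) @ B.T

# Merged: fold the update into one dense weight once, then a single matmul.
W_merged = W0 + scale * (B @ A)
y_merged = x @ W_merged.T

print(np.allclose(y_adapter, y_merged))  # True
```

The trade-off is that a merged model can no longer swap adapters cheaply, so merging suits a single-task deployment.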

Training LLMs can be very expensive, leaving most people out of the race when building applications with them. Advances such as PEFT and LoRA lower the bar for exploring this technology and seem to accommodate most non-critical requirements. The adapter's small storage size (~17MB) means it can easily be kept in memory, reducing the complexity of small applications. I have yet to test these models with quantization (QLoRA), which stores the frozen base model's weights at lower precision (4-bit) to further reduce memory and storage requirements.
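To illustrate what quantization trades away, here is a deliberately simple symmetric ("absmax") 8-bit scheme in NumPy. It is not what QLoRA does: QLoRA stores the frozen base weights in a 4-bit NormalFloat format with block-wise scaling. The principle of storing low-precision integers plus a scale factor is the same, and the 768 × 768 random matrix is an illustrative stand-in for a real weight.

```python
import numpy as np


def quantize_absmax_int8(w: np.ndarray) -> tuple[np.ndarray, float]:
    """Map float weights to int8 using a single per-tensor scale (assumes w is not all zero)."""
    scale = float(np.abs(w).max()) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
w = rng.standard_normal((768, 768)).astype(np.float32)

q, scale = quantize_absmax_int8(w)
w_hat = dequantize(q, scale)

print(w.nbytes // q.nbytes)                          # 4 (float32 -> int8)
print(np.abs(w - w_hat).max() <= scale / 2 + 1e-6)  # True: error is at most half a step
```

The rounding error per weight is bounded by half the step size, which is why a single outlier weight (inflating `scale`) hurts precision for everything else; block-wise scaling, as in QLoRA, limits that effect.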

As further advances shrink these models, the image of the data scientist training on huge server racks seems increasingly distant, with highly capable models becoming accessible on edge devices such as phones and microcontrollers. The result will be human-like algorithms available at all times, for better or for worse.
