Fine-tune LLMs for free on custom text data: A Step-by-step Tutorial


What’s Llama-2?

Llama-2 is a family of openly available LLMs released by Meta. Llama-2 7B is the smallest model in this family by parameter count. The “chat” variant of Llama-2 7B is optimized for dialogue: it is tuned to generate responses in a conversational context, which makes it well suited to applications like chatbots and virtual assistants. Because it is smaller, the Llama-2 7B chat model is faster and cheaper to run than its larger counterparts in the Llama-2 family (13B and 70B), at the expense of some accuracy.


What’s fine-tuning and why fine-tune?

Fine-tuning an LLM means taking a pre-trained model like Llama-2, which has already been trained on massive datasets, and making small changes to its weights to optimize its performance on a new, specific task or dataset. The overall architecture of the pre-trained model stays the same during fine-tuning. With parameter-efficient methods such as the LoRA approach used in this tutorial, only a small set of weights is trained to learn the important features of the training dataset.

Fine-tuning offers several advantages:

  1. Cost-effective and efficient: Training an LLM from scratch is extremely time-consuming and computationally expensive. Fine-tuning is a great alternative because it builds on an existing pre-trained model, significantly reducing time and compute while still achieving good results.
  2. Improved performance: Pre-trained LLMs have already learned from massive amounts of data (~2 trillion tokens for Llama-2), so fine-tuning lets us take advantage of that knowledge to improve performance on our new, specific task or dataset.

Let’s get started with fine-tuning Llama-2

This tutorial is based on the Google Colab notebook found here, where you can run all the cells in order and get your own fine-tuned Llama-2 chatbot!

1. Environment setup and library imports

In this tutorial, we’ll use the Nvidia T4 GPU with 16 GB of VRAM offered in the free tier of Google Colab. If you’re running the notebook on your own GPU, that works too! The code below restricts PyTorch to GPU 0, which is the T4 on Colab or the first GPU (in PCI bus order) if you have several, and then checks that a GPU is available.

!pip install GPUtil

import os

# Select the GPU before CUDA is initialized: these variables are read when
# the CUDA runtime starts, so set them before calling into torch.cuda
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Set to the GPU ID (0 for T4)

import torch
import GPUtil

GPUtil.showUtilization()

if torch.cuda.is_available():
    print("GPU is available!")
else:
    print("GPU not available.")

Now that you’ve established your GPU connection, it’s time to install (and import) the necessary libraries for fine-tuning.

!pip install git+https://github.com/huggingface/peft.git
!pip install accelerate
!pip install -i https://pypi.org/simple/ bitsandbytes
!pip install transformers==4.30
!pip install datasets
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig,LlamaTokenizer
from huggingface_hub import notebook_login
from datasets import load_dataset
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
from datetime import datetime

if 'COLAB_GPU' in os.environ:
    from google.colab import output
    output.enable_custom_widget_manager()

Llama-2 is governed by Meta’s license, so before you can download the model weights and tokenizer, visit Meta’s website to accept the license and request access to the models on Hugging Face (access is usually granted in less than a day).

Once you have access to the Llama-2 models, log in to Hugging Face and enter your access token when prompted, so that the notebook can download the model.

if 'COLAB_GPU' in os.environ:
    !huggingface-cli login
else:
    notebook_login()

2. Load and configure your model

Having completed our setup, it’s time to load our model (Llama-2 7B Chat) for QLoRA: the pre-trained weights are quantized to 4 bits here, and small LoRA adapters are trained on top of them in step 4. Quantization reduces memory requirements so that fine-tuning stays within the 16 GB of GPU memory.

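Note: In the code below, we load the base model’s weights in the 4-bit NormalFloat (nf4) datatype and use double quantization (quantizing the quantization constants as well) for further memory savings. Computation, however, runs in 16-bit precision (bfloat16) instead of the default float32 to speed up computation of the hidden states. Keep in mind that the T4 does not natively support bfloat16; torch.float16 is the usual alternative on that GPU.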
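As a rough illustration of why 4-bit loading matters on a 16 GB card, the plain-Python sketch below estimates the memory taken by the weights alone. It assumes about 6.7 billion parameters for Llama-2 7B and the blockwise scheme described in the QLoRA paper (64 weights per quantization block; double quantization stores an 8-bit constant per block plus one 32-bit constant per 256 blocks). Activations, LoRA weights, optimizer state and any layers left unquantized are ignored, so real usage is higher.

```python
NUM_PARAMS = 6.7e9  # assumption: approximate parameter count of Llama-2 7B

def weight_gb(num_params: float, bits_per_param: float) -> float:
    """Memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

# Overhead of the per-block quantization constants, in bits per weight,
# assuming blocks of 64 weights as described in the QLoRA paper:
# - single quantization: one 32-bit absmax constant per block
# - double quantization: one 8-bit constant per block, plus one 32-bit
#   constant for every 256 of those 8-bit constants
single_quant_overhead = 32 / 64
double_quant_overhead = 8 / 64 + 32 / (64 * 256)

estimates = {
    "float32": weight_gb(NUM_PARAMS, 32),
    "bfloat16": weight_gb(NUM_PARAMS, 16),
    "nf4": weight_gb(NUM_PARAMS, 4 + single_quant_overhead),
    "nf4 + double quant": weight_gb(NUM_PARAMS, 4 + double_quant_overhead),
}

for label, gb in estimates.items():
    print(f"{label:>20}: {gb:6.2f} GB")
```

Under these assumptions, float32 weights alone would not fit in the T4's memory and 16-bit weights would leave little room for training, whereas 4-bit weights take only a few gigabytes. Double quantization trims the constant overhead from 0.5 to about 0.127 bits per weight.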

base_model_id = "meta-llama/Llama-2-7b-chat-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # load the base model weights in 4 bits
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",             # 4-bit NormalFloat datatype
    bnb_4bit_compute_dtype=torch.bfloat16  # dtype used for computation
)

model = AutoModelForCausalLM.from_pretrained(base_model_id,
                                             quantization_config=bnb_config)

3. Loading and tokenizing the dataset

Most private data is in unstructured formats, like text files or PDFs. Reformatting it into structured data like JSON or CSV files could give better training results, since it creates a clear mapping between question-answer pairs. However, this is labor-intensive and only ideal when the data consists exclusively of neatly organized Q&A pairs that follow a predictable structure. We understand this, so this tutorial focuses on fine-tuning Llama-2 solely on unstructured .txt files!

Llama-2’s pretraining data has a cutoff of September 2022 (with some tuning data up to July 2023), so the model has no knowledge of the Hawaii wildfires of August 2023. For this tutorial, we’ll use data about those wildfires sourced from the report of the Maui Police Department found here. We copied the text of the PDF into multiple text files without any additional formatting.

We’ll clone the GitHub repository containing the text files, and load them as training data.

!git clone https://github.com/poloclub/Fine-tuning-LLMs.git

# The paths below are relative to the current working directory; point them
# at wherever hawaii_wf_1.txt ... hawaii_wf_11.txt live in the cloned repository
train_dataset = load_dataset("text", data_files={"train":
                    [f"hawaii_wf_{i}.txt" for i in range(1, 12)]}, split='train')
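The "text" loader treats each line of each file as one training example, so the line breaks you copied from the PDF decide how the text is split. As far as we understand its default behavior, blank lines come through as empty examples. The sketch below imitates that line-by-line loading in plain Python (it is not the datasets implementation) and shows the effect of dropping blank lines; the helper name, sample file and its contents are hypothetical.

```python
import os
import tempfile

def load_text_lines(paths, drop_empty=False):
    """Plain-Python imitation of line-by-line loading: one {"text": line} per line."""
    records = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f.read().splitlines():  # drops each line's trailing newline
                if drop_empty and not line.strip():
                    continue  # skip blank lines, which carry no training signal
                records.append({"text": line})
    return records

# Hypothetical sample file standing in for one of the hawaii_wf_*.txt files
with tempfile.TemporaryDirectory() as tmp:
    sample = os.path.join(tmp, "sample.txt")
    with open(sample, "w", encoding="utf-8") as f:
        f.write("First paragraph of the report.\n\nSecond paragraph.\n")
    all_lines = load_text_lines([sample])
    non_empty = load_text_lines([sample], drop_empty=True)

print(len(all_lines), len(non_empty))  # 3 2
```

To apply the same filtering to the real dataset, the datasets library's `Dataset.filter` can do it, for example `train_dataset.filter(lambda x: x["text"].strip() != "")`.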

With our data loaded, we need to tokenize it (break sequences of text into smaller units called “tokens”) before passing it to Llama-2 for fine-tuning. We initialize the LlamaTokenizer from the pre-trained Llama-2-7B-chat model with add_eos_token=True, so that every sequence ends with the end-of-sequence (EOS) token and the model learns where a sequence ends. We also set the pad_token, which pads shorter sequences in a batch to the length of the longest one, to the EOS token, because the Llama-2 tokenizer does not define a pad token by default.

tokenizer = LlamaTokenizer.from_pretrained(base_model_id, use_fast=False,
                                           trust_remote_code=True,
                                           add_eos_token=True)  # append EOS to every sequence

# Llama-2 has no pad token by default, so reuse the EOS token for padding
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
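To see what padding with the EOS token means during training, here is a plain-Python sketch of what the data collator used in step 5 (`DataCollatorForLanguageModeling` with `mlm=False`) does to a batch, as we understand it: sequences are padded to the longest one, the attention mask marks real tokens with 1, and the labels are a copy of the input IDs with every pad-token position set to -100 so the loss ignores it. The model itself shifts the labels by one position to predict the next token. The helper name and token IDs are hypothetical, and EOS = 2 is an assumption for illustration.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default

def collate_causal_lm(batch, pad_token_id, padding_side="right"):
    """Pad lists of token IDs to the longest one; build attention masks and labels."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask, labels = [], [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        pad_ids, pad_mask = [pad_token_id] * n_pad, [0] * n_pad
        real_mask = [1] * len(ids)
        if padding_side == "right":
            padded, mask = ids + pad_ids, real_mask + pad_mask
        else:  # left padding
            padded, mask = pad_ids + ids, pad_mask + real_mask
        input_ids.append(padded)
        attention_mask.append(mask)
        # labels copy the inputs, with every pad-token position set to IGNORE_INDEX
        labels.append([IGNORE_INDEX if t == pad_token_id else t for t in padded])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

EOS_ID = 2  # assumption: EOS token ID, reused as the pad token as in the code above
batch = [[1, 306, 763, EOS_ID], [1, 450, EOS_ID]]  # hypothetical tokenized lines
out = collate_causal_lm(batch, pad_token_id=EOS_ID)
print(out["input_ids"])       # [[1, 306, 763, 2], [1, 450, 2, 2]]
print(out["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
print(out["labels"])          # [[1, 306, 763, -100], [1, 450, -100, -100]]
```

Because the pad token is the EOS token here, the genuine EOS at the end of each sequence is also labeled -100, so under this setup the model gets no training signal for producing EOS.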

Our tokenizer is configured, which means it’s now time to tokenize our training data!

tokenized_train_dataset = []
for phrase in train_dataset:
    tokenized_train_dataset.append(tokenizer(phrase['text']))

4. Configuring the model with LoRA

We’re one step away from training the model! First, we enable gradient checkpointing, which trades extra computation time for lower memory usage during training. We then set up our LoRA configuration to reduce the number of trainable parameters, which significantly reduces the memory and time required for fine-tuning. LoRA freezes the large pre-trained weight matrices and learns each update as the product of two much smaller low-rank matrices, which drastically reduces the number of parameters that need to be fine-tuned. Here, LoRA is applied to both the attention projections and the MLP projections of each layer (see target_modules). Refer to the LoRA documentation to learn more about the parameters and use cases.
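To make "drastically reduces" concrete: for a frozen weight matrix with d_in inputs and d_out outputs, LoRA trains B (d_out × r) and A (r × d_in), that is r·(d_in + d_out) parameters instead of d_in·d_out. The sketch below applies this to r=8 and the seven target modules configured in this section. The layer shapes (hidden size 4096, MLP size 11008, 32 layers) are assumed from Llama-2 7B's published configuration, and the helper names are hypothetical.

```python
# Assumed Llama-2 7B shapes (from the model's published config)
HIDDEN, INTERMEDIATE, LAYERS = 4096, 11008, 32
R = 8  # LoRA rank used in LoraConfig below

# (in_features, out_features) of each targeted linear layer in one decoder layer
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(d_in, d_out, r):
    """Parameters in A (r x d_in) plus B (d_out x r)."""
    return r * d_in + d_out * r

def frozen_params(d_in, d_out):
    """Parameters in the frozen weight matrix (Llama's linear layers have no bias)."""
    return d_in * d_out

lora_total = LAYERS * sum(lora_params(i, o, R) for i, o in TARGET_SHAPES.values())
frozen_total = LAYERS * sum(frozen_params(i, o) for i, o in TARGET_SHAPES.values())

print(f"trainable LoRA params: {lora_total:,}")  # 19,988,480
print(f"params in the targeted frozen matrices: {frozen_total:,}")
print(f"ratio: {lora_total / frozen_total:.4%}")
```

Under these assumptions, roughly 20 million LoRA parameters are trained, about 0.3% of the weights in the targeted matrices.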

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    # rank of the update matrices
    # Lower rank results in smaller matrices with fewer trainable params
    r=8,

    # scaling factor: the LoRA update is multiplied by lora_alpha / r,
    # so a larger value gives the update more weight
    lora_alpha=64,

    # modules to apply the LoRA update matrices to
    # (attention and MLP projections of every decoder layer)
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "down_proj",
        "up_proj",
        "o_proj"
    ],

    # "none": no bias terms are trained
    bias="none",

    # dropout applied to the inputs of the LoRA layers during training
    # (regularization); increasing it may lead to underfitting
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
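The sketch below spells out the LoRA forward pass for a single linear layer in plain Python, following the formulation in the LoRA paper: the frozen weight W is combined with a trainable low-rank update scaled by lora_alpha / r, that is h = Wx + (alpha / r) * B(Ax). It uses tiny hypothetical matrices and omits dropout; PEFT's actual layers operate on tensors. Because B starts at zero (PEFT's default initialization, as we understand it), the adapted layer initially behaves exactly like the base layer.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, alpha):
    """h = W x + (alpha / r) * B (A x), where r is the number of rows of A."""
    r = len(A)
    base = matvec(W, x)               # frozen pre-trained weight
    update = matvec(B, matvec(A, x))  # trainable low-rank path
    return [b + (alpha / r) * u for b, u in zip(base, update)]

# Tiny hypothetical layer: d_in = 3, d_out = 2, rank r = 1
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]
A = [[0.5, 0.5, 0.5]]    # r x d_in
B_init = [[0.0], [0.0]]  # d_out x r, starts at zero
x = [1.0, 2.0, 3.0]

print(lora_forward(W, A, B_init, x, alpha=8))     # same as matvec(W, x): [7.0, 2.0]
B_trained = [[0.1], [-0.1]]  # hypothetical values after some training
print(lora_forward(W, A, B_trained, x, alpha=8))  # roughly [9.4, -0.4]
```

With r=8 and lora_alpha=64 as configured above, the scale factor is 64 / 8 = 8; the example uses rank 1 with alpha 8 to get the same factor.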

5. Training the model

It’s finally time to train our Llama-2 model on our new data (yay!). We’ll use the Transformers library to create a Trainer object for training the model. The Trainer takes the model (Llama-2 7B chat with its LoRA adapters), the training dataset, the training arguments (defined below), and a data collator as input.

Training time depends on the size of the training data, the number of epochs, and the GPU used. With the sample Hawaii wildfire dataset on Google Colab’s T4 GPU, training for 3 epochs should take around 1 hour 30 minutes.

When fine-tuning on your own private data, we highly recommend adjusting the training parameters, particularly the learning rate and number of epochs, to get good performance from the fine-tuned model. While doing this, beware of overfitting!

Keep in mind that a higher learning rate may lead to faster convergence but can overshoot the optimal solution, while a lower learning rate trains more slowly but can fine-tune more precisely. Similarly, more epochs let the model learn more from the data but increase the risk of overfitting.
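The learning-rate trade-off can be seen on a toy problem. The sketch below runs plain gradient descent on the one-dimensional function f(x) = x², whose gradient is 2x, so every step multiplies x by (1 - 2·lr). It is only a qualitative illustration under these toy assumptions, not a model of LLM training, and the helper name is hypothetical.

```python
def gradient_descent(lr, x0=1.0, steps=20):
    """Minimize f(x) = x**2 (gradient 2x) from x0 and return the final iterate."""
    x = x0
    for _ in range(steps):
        x -= lr * 2 * x  # one gradient step
    return x

for lr in (0.01, 0.1, 0.9, 1.1):
    print(f"lr={lr}: x after 20 steps = {gradient_descent(lr):.4g}")
```

A very small rate (0.01) is still far from the minimum at 0 after 20 steps, 0.1 gets close, 0.9 overshoots on every step and oscillates around the minimum while still converging, and 1.1 overshoots so far that it diverges.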

trainer = transformers.Trainer(
    model=model,  # llama-2-7b-chat model with LoRA adapters
    train_dataset=tokenized_train_dataset,  # training data that's tokenized
    args=transformers.TrainingArguments(
        output_dir="./finetunedModel",  # directory where checkpoints are saved
        per_device_train_batch_size=2,  # number of samples processed in one forward/backward pass per GPU
        gradient_accumulation_steps=2,  # [default = 1] number of forward/backward passes whose gradients are accumulated before each optimizer step
        num_train_epochs=3,  # [IMPORTANT] number of complete passes through the entire training dataset
        learning_rate=1e-4,  # [IMPORTANT] smaller LR for better finetuning
        bf16=False,  # bf16 mixed-precision training is off (the T4 has no native bf16 support)
        optim="paged_adamw_8bit",  # 8-bit AdamW with paged optimizer states to limit GPU memory spikes
        logging_dir="./logs",  # directory to save training log outputs
        save_strategy="epoch",  # [default = "steps"] save a checkpoint at the end of every epoch
        save_steps=50,  # only used when save_strategy="steps"; ignored here
        logging_steps=10,  # log the training loss every 10 optimizer steps
    ),

    # forms batches from lists of train_dataset elements; with mlm=False it pads
    # each batch and sets up labels for causal (next-token) language modeling
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# the key/value cache only speeds up generation and is incompatible with
# gradient checkpointing, so disable it during training
# (set it back to True before running inference)
model.config.use_cache = False

# train the model based on the above config
trainer.train()
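To relate the training arguments to the checkpoint folders, here is a rough sketch of the step arithmetic. With per_device_train_batch_size=2 and gradient_accumulation_steps=2, each optimizer step sees 2 × 2 = 4 examples on a single GPU, and the Trainer names each checkpoint after the global optimizer step (checkpoint-<step>). The helper functions and dataset size are hypothetical, and the per-epoch formula is an approximation: the exact count can differ by a step for uneven dataset sizes or across transformers versions.

```python
import math

def optimizer_steps_per_epoch(num_examples, per_device_batch_size,
                              grad_accum_steps, num_devices=1):
    """Approximate optimizer steps per epoch: batches per epoch // accumulation steps."""
    batches = math.ceil(num_examples / (per_device_batch_size * num_devices))
    return max(batches // grad_accum_steps, 1)

def epoch_checkpoint_names(num_examples, epochs, per_device_batch_size, grad_accum_steps):
    """With save_strategy="epoch", one checkpoint-<global_step> folder per epoch."""
    steps = optimizer_steps_per_epoch(num_examples, per_device_batch_size, grad_accum_steps)
    return [f"checkpoint-{steps * (epoch + 1)}" for epoch in range(epochs)]

# hypothetical dataset size; use len(tokenized_train_dataset) for your data
print(epoch_checkpoint_names(1000, epochs=3, per_device_batch_size=2, grad_accum_steps=2))
# ['checkpoint-250', 'checkpoint-500', 'checkpoint-750']
```

For example, if the checkpoint-1455 loaded in the next section was saved at the end of the third epoch, it would correspond under this formula to 1455 / 3 = 485 optimizer steps per epoch.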

6. Loading your fine-tuned model

If you’ve made it this far, congratulations! You’ve successfully fine-tuned Llama-2 on your own data. Now, let’s load the fine-tuned model using the same BitsAndBytesConfig as before. Make sure to choose the model checkpoint with the lowest training loss.
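One way to find the checkpoint with the lowest training loss is to read the trainer_state.json file that the Trainer writes into each checkpoint folder; its log_history list holds the logged losses (every logging_steps=10 steps here). The sketch below, with hypothetical helper names, takes the most recent checkpoint's log history and, for each checkpoint-<step> folder, looks up the last training loss logged at or before that step. Since losses are logged every 10 steps rather than exactly at each checkpoint, this is an approximation.

```python
import json
import os
import re

def checkpoint_steps(output_dir):
    """Sorted global steps of the checkpoint-<step> folders in output_dir."""
    steps = []
    for name in os.listdir(output_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        if match:
            steps.append(int(match.group(1)))
    return sorted(steps)

def loss_at_steps(log_history, steps):
    """Last logged training loss at or before each step (None if nothing was logged)."""
    logged = sorted((e["step"], e["loss"]) for e in log_history if "loss" in e)
    result = {}
    for target in steps:
        earlier = [loss for step, loss in logged if step <= target]
        result[target] = earlier[-1] if earlier else None
    return result

def pick_checkpoint(output_dir="./finetunedModel"):
    """Path of the checkpoint whose last logged training loss is lowest."""
    steps = checkpoint_steps(output_dir)
    if not steps:
        raise FileNotFoundError(f"no checkpoint-<step> folders in {output_dir}")
    # the most recent checkpoint's trainer_state.json holds the full log history
    state_path = os.path.join(output_dir, f"checkpoint-{steps[-1]}", "trainer_state.json")
    with open(state_path, encoding="utf-8") as f:
        log_history = json.load(f)["log_history"]
    losses = {s: l for s, l in loss_at_steps(log_history, steps).items() if l is not None}
    best = min(losses, key=losses.get)
    return os.path.join(output_dir, f"checkpoint-{best}")
```

Calling `pick_checkpoint("./finetunedModel")` returns a path you can pass to `PeftModel.from_pretrained` in place of the hard-coded checkpoint folder.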

import torch
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig, LlamaTokenizer
from peft import PeftModel

base_model_id = "meta-llama/Llama-2-7b-chat-hf"

# same 4-bit NF4 quantization config as used for training
nf4Config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,  # same as before
    quantization_config=nf4Config,  # same quantization config as before
    device_map="auto",
    trust_remote_code=True,
    use_auth_token=True
)

# for inference, load the tokenizer without add_eos_token so that the
# prompt is not terminated with an EOS token before generation starts
tokenizer = LlamaTokenizer.from_pretrained(base_model_id, use_fast=False,
                                           trust_remote_code=True)

# attach the trained LoRA adapters; replace checkpoint-1455 with the
# checkpoint folder you chose in ./finetunedModel
modelFinetuned = PeftModel.from_pretrained(base_model,
                                           "finetunedModel/checkpoint-1455")

7. Test your fine-tuned model

It’s now time to test your own fine-tuned Llama-2! Below are some example responses to questions about details of the 2023 Hawaii wildfires.

### ENTER YOUR QUESTION BELOW

question = "Just answer this question: Tell me about the role of Maui Emergency Management Agency (MEMA) in the 2023 wildfires??"

# Format the question
eval_prompt = f"{question}\n\n"

promptTokenized = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

modelFinetuned.eval()
with torch.no_grad():
    print(tokenizer.decode(modelFinetuned.generate(**promptTokenized, max_new_tokens=1024)[0],
                           skip_special_tokens=True))
torch.cuda.empty_cache()
Just answer this question: Tell me about the role of Maui Emergency Management Agency (MEMA) in the 2023 wildfires??

MEMA played a crucial role in coordinating and facilitating the response to the 2023 wildfires. MEMA’s primary role is to coordinate and facilitate the response to emergencies and disasters in Maui County.
In the case of the 2023 wildfires, MEMA’s responsibilities included:

1. Coordinating with other emergency response agencies, such as the Maui Police Department, Maui Fire Department, and the Hawaii National Guard, to ensure a unified response.
2. Activating the Maui Emergency Operations Center (EOC), which served as the centralized location for coordination of response and support activities.
3. Providing situational reports and information to the County of Maui Mayor’s office and other stakeholders.
4. Coordinating with media outlets to ensure accurate and timely dissemination of information to the public.
5. Managing and coordinating volunteer and donated resources.
6. Coordinating with other government agencies, such as the Federal Emergency Management Agency (FEMA), the American Red Cross, and other support agencies.
7. Providing support to emergency shelters and other response efforts.
8. Coordinating with the Maui County Corporation Counsel’s office to ensure legal support for emergency response and recovery efforts.
9. Coordinating with the Maui Police Department’s Wailuku and Lanai districts to ensure support for the Wailuku and Lanai communities.
10. Coordinating with the Maui Fire Department’s Wailuku and Lanai stations to ensure support for the Wailuku and Lanai fire units.

In summary, MEMA played a crucial role in coordinating and facilitating the response to the 2023 wildfires in Maui County. MEMA’s responsibilities included coordinating with other emergency response agencies,
activating the Maui EOC, providing situational reports and information, managing and coordinating volunteer and donated resources,
and providing support to emergency shelters and other response efforts.

Another example:

# User enters question below
user_question = "When did the Hawaii wildfires take place?"

# Format the question
eval_prompt = f"Question: {user_question}. Just answer this question accurately\n\n"

promptTokenized = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

modelFinetuned.eval()
with torch.no_grad():
    print(tokenizer.decode(modelFinetuned.generate(**promptTokenized, max_new_tokens=1024)[0],
                           skip_special_tokens=True))
torch.cuda.empty_cache()
Question: When did the Hawaii wildfires take place?. Just answer this question accurately

Answer: The Hawaii wildfires took place from August 8, 2023 to August 12, 2023.
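Notice that each printed response starts with the prompt itself. For decoder-only models like Llama-2, `generate` returns the prompt tokens followed by the newly generated ones, and the code above decodes the whole sequence. The plain-Python sketch below shows the slicing that keeps only the new tokens; the helper name and token IDs are hypothetical.

```python
def new_tokens_only(output_ids, prompt_ids):
    """Drop the echoed prompt: generate() output = prompt tokens + new tokens."""
    if output_ids[:len(prompt_ids)] != prompt_ids:
        raise ValueError("output does not start with the prompt tokens")
    return output_ids[len(prompt_ids):]

prompt_ids = [1, 894, 29901]          # hypothetical token IDs of the prompt
output_ids = [1, 894, 29901, 450, 2]  # hypothetical generate() output for one sequence
print(new_tokens_only(output_ids, prompt_ids))  # [450, 2]
```

Applied to the code above, the same idea would be `output = modelFinetuned.generate(**promptTokenized, max_new_tokens=1024)[0]` followed by `tokenizer.decode(output[promptTokenized["input_ids"].shape[1]:], skip_special_tokens=True)`.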

We can see from the above examples that the model answers these questions well and has picked up details of the 2023 wildfire incident!

This brings us to the end of the tutorial! Feel free to tinker around with the notebook and fine-tune your personal Llama-2 chatbot on your private data and have fun :)

Credits


Polo Club of Data Science | Georgia Tech

Published in Polo Club of Data Science | Georgia Tech

At Georgia Tech, we develop scalable, interpretable, trustworthy tools for understanding large-scale data and complex ML models, solving real world problems in human-centered AI (interpretable, fair, safe AI; adversarial ML), cybersecurity, and social good (health, energy).
