Finetune Language Models

Sharath S Hebbar
3 min read · Dec 19, 2023


Fine-tuning an LLM.

Language models, particularly those based on transformer architectures, have proven to be incredibly powerful tools in natural language processing (NLP). Pre-trained language models such as OpenAI’s GPT (Generative Pre-trained Transformer) series, Google’s BERT (Bidirectional Encoder Representations from Transformers), and T5 (Text-to-Text Transfer Transformer) have achieved state-of-the-art results on a wide range of NLP tasks. However, these pre-trained models might not be optimal for every specific use case or domain. Fine-tuning provides a way to adapt them to specific tasks or datasets, offering improved performance.

Newer open models are also being released by several organizations, such as Meta, Microsoft, Deci (DeciLM), and Mistral AI.

Fine-tuning a large language model, however, is an intensive task: it requires expertise in ML, a well-curated dataset, and access to substantial computing power, which is often out of reach for an individual (and even for some organizations).

Instead, we can fine-tune smaller models that fit on a local laptop or in Google Colab.

We will be fine-tuning GPT-2, specifically the smallest version of the model, with 124M parameters.
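
If you want to double-check that figure, you can count the parameters once the checkpoint is loaded. This is just an optional sanity check, not one of the fine-tuning steps below.

from transformers import AutoModelForCausalLM

# Load the GPT-2 small checkpoint and count its parameters (roughly 124M).
model = AutoModelForCausalLM.from_pretrained("gpt2")
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")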

GPT-2 architecture (source: Hugging Face).

Link to the pre-trained model: https://huggingface.co/gpt2

Link to the dataset: https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
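
Before writing any preprocessing code, it helps to peek at one record of the dataset to see how it is structured. Here is a minimal sketch that uses streaming so nothing large is downloaded (the split name matches the one used later):

from datasets import load_dataset

# Stream a single record to inspect its structure.
sample = next(iter(load_dataset("HuggingFaceH4/ultrachat_200k", split="train_gen", streaming=True)))
print(sample.keys())          # expect 'prompt', 'prompt_id', and 'messages'
print(sample["messages"][0])  # each message is a dict with 'role' and 'content'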

Here are the steps to follow:

Importing Libraries

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

Set the device to CUDA so the model runs on the GPU when one is available

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()  # clear any cached GPU memory from previous runs
device
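
If you are on Colab, it is also worth confirming which GPU the runtime actually gave you before training. This check is optional:

if torch.cuda.is_available():
    # Print the GPU name and its total memory.
    props = torch.cuda.get_device_properties(0)
    print(props.name, f"{props.total_memory / 1e9:.1f} GB")
else:
    print("No GPU found; training will be very slow on CPU.")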

Select your model and dataset

model_name = "gpt2"
dataset_name = "HuggingFaceH4/ultrachat_200k"

# Load the pre-trained model and move it to the GPU (if available)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Load the train and test splits of the dataset
train_dataset = load_dataset(dataset_name, split='train_gen')
train_dataset.to_pandas()  # quick look at the data as a DataFrame

test_dataset = load_dataset(dataset_name, split='test_gen')
test_dataset.to_pandas()
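
The full splits are large, so a complete pass may not fit within a free Colab session. If compute is tight, you can optionally train on a smaller subset; the sizes below are arbitrary choices, not part of the original recipe:

# Optionally shrink the splits so the run fits in a Colab session.
train_dataset = train_dataset.shuffle(seed=42).select(range(20_000))
test_dataset = test_dataset.shuffle(seed=42).select(range(2_000))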

Dataset preparation and tokenization

def prepare_datasets(data):
    # Concatenate the conversation into a single "role: content" prompt,
    # one message per line, leaving out the final message.
    prompt = ""
    for i in range(len(data['messages']) - 1):
        if len(prompt) > 0:
            prompt = prompt + "\n" + f"""{data['messages'][i]['role']}: {data['messages'][i]['content']}"""
        else:
            prompt = f"""{data['messages'][i]['role']}: {data['messages'][i]['content']}"""
    data['query'] = prompt
    return data


def tokenize_datasets(dataset):
    # Tokenize the prepared prompts, truncating to 128 tokens.
    tokenized_dataset = dataset.map(
        lambda example: tokenizer(
            example['query'],
            truncation=True,
            max_length=128,
        ),
        batched=True,
        remove_columns=['query']
    )
    return tokenized_dataset

# Dataset preparation
train_dataset = train_dataset.map(
    prepare_datasets, remove_columns=['prompt', 'prompt_id', 'messages']
)

test_dataset = test_dataset.map(
    prepare_datasets, remove_columns=['prompt', 'prompt_id', 'messages']
)

# Tokenization
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS

train_dataset = tokenize_datasets(train_dataset)
test_dataset = tokenize_datasets(test_dataset)

# Data collator for causal language modeling (mlm=False)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
data_collator
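
To verify that the preprocessing did what we expect, it helps to decode one tokenized example and look at the batch the collator builds. This is a quick sanity check, not part of the training pipeline:

# Decode the first tokenized example back into text.
print(tokenizer.decode(train_dataset[0]["input_ids"]))

# With mlm=False the collator creates causal-LM labels: a copy of input_ids
# with pad-token positions set to -100 so they are ignored by the loss.
batch = data_collator([train_dataset[i] for i in range(2)])
print(batch["input_ids"].shape, batch["labels"].shape)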

Training Arguments

batch_size = 4
training_args = TrainingArguments(
    output_dir="./models/chat_gpt2",
    gradient_accumulation_steps=batch_size,
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="epoch",   # evaluate on the test split every epoch
    save_strategy="epoch",         # must match the evaluation strategy for load_best_model_at_end
    load_best_model_at_end=True,
    save_total_limit=2,
    fp16=True,                     # mixed-precision training (requires a GPU)
    learning_rate=2e-05,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    report_to="none",              # disable external loggers such as W&B
)


training_args

Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
)

trainer.train()
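
Once training finishes, you can check the held-out loss and try a quick generation to see how the chat-style data changed the model. A minimal sketch; the prompt follows the same "role: content" format used in prepare_datasets:

import math

# Held-out loss and perplexity on the test split.
eval_results = trainer.evaluate()
print(f"eval loss: {eval_results['eval_loss']:.3f}, perplexity: {math.exp(eval_results['eval_loss']):.1f}")

# Quick generation test.
prompt = "user: What are the benefits of regular exercise?\nassistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=64, do_sample=True, top_p=0.9, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))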

Pushing it to HuggingFace 🤗

MODEL_PATH = "<Yourname/Name_of_the_model>"  # your Hugging Face username / repo name
model.push_to_hub(
    MODEL_PATH, token="<HF_Token>"  # an access token with write permission
)
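
It is also a good idea to push the tokenizer alongside the model so the repo is self-contained, and once the push completes you can load the fine-tuned model back from the Hub like any other checkpoint (replace MODEL_PATH with the repo name you pushed to):

# Push the tokenizer as well (optional but recommended).
tokenizer.push_to_hub(MODEL_PATH, token="<HF_Token>")

# Load the fine-tuned model back from the Hub and try it in a pipeline.
from transformers import pipeline

chat_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
chat_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
generator = pipeline("text-generation", model=chat_model, tokenizer=chat_tokenizer)
print(generator("user: How do I learn Python?\nassistant:", max_new_tokens=64)[0]["generated_text"])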

References:

  1. https://github.com/SharathHebbar/HuggingFace-handson/tree/main/chat_gpt2
  2. https://github.com/SharathHebbar/Transformers/blob/main/Basics/6_Instruction_following_using_GPT.ipynb

Learn more about Transformers: https://github.com/SharathHebbar/Transformers
