Fine-Tuning BLOOMZ for Legal Question Answering
In this article, I explore the steps for fine-tuning the language model BLOOMZ with limited resources and a small dataset.
BLOOMZ is a family of language models resulting from multitask fine-tuning (also called instruction tuning) of the open-source pre-trained multilingual model BLOOM. Instruction tuning aims to improve the zero-shot task generalization of pre-trained large language models by tuning the model on a collection of datasets from a variety of tasks, such as question answering (QA), summarization, and translation. Given the variety of tasks, the pre-trained model learns which task to perform through a prompt.
In this article, I cover the steps to fine-tune BLOOMZ, more specifically the model bloomz-560m, on a small dataset composed of Portuguese questions and answers related to the Brazilian legal system. For this task, I used a desktop PC with a standalone NVIDIA RTX 3090 GPU with 24 GB of GDDR6X VRAM and a 16-core AMD Ryzen 9 5950X CPU with 32 GB of DDR4 RAM.
The first step towards fine-tuning BLOOMZ was to build the Portuguese question answering dataset, called LegalQA. Then, a dataset of concise and clear prompts was built for the task. Finally, the model bloomz-560m was tuned on the dataset of QA prompts, resulting in a fine-tuned model called LegalQA-bloom-560m.
The LegalQA Dataset
The LegalQA dataset was built by extracting text from PDF files downloaded from the website of Receita Federal, the Brazilian equivalent of the US Internal Revenue Service. In total, five PDF files were parsed, resulting in 396 pairs of questions and answers covering five different topics, ranging from the rural property tax (ITR, Imposto Territorial Rural) to common income tax returns.
The code implemented to parse the PDF files is available on GitHub, together with the final question answering dataset, the prompt dataset, and the code to fine-tune BLOOMZ.
import pandas as pd
from PyPDF2 import PdfReader

# create pandas dataframe to store the QnAs extracted from pdf files
df = pd.DataFrame()
filename = './rawdata/ITR2022.pdf'  # pdf file to be parsed
skipto = 12  # skip the first 12 pages
qa = []  # store questions and answers
# read file
reader = PdfReader(filename)
# get total number of pages
num_pages = len(reader.pages)
print('File %s has %d pages' % (filename, num_pages))
for idx in range(skipto, num_pages):
    page = reader.pages[idx]
    pagetext = page.extract_text()
    ...
    # extract questions and answers (parsing logic omitted; it depends on each PDF's layout)
    question = ...
    answer = ...
    ...
    qa.append({'question': question, 'answer': answer})
df_temp = pd.DataFrame.from_dict(qa)
df = pd.concat([df, df_temp])
# saving QnA data
df.to_csv('LegalQA_dataset.csv', index=False, sep='\t')
I plan to keep adding more data to this dataset by parsing other available PDFs :)
The LegalQA Prompt Dataset
Using the LegalQA dataset, a prompt dataset was built by adding a concise and clear natural language instruction to each question answering pair. Each prompt has the following format:
"Given the question delimited by triple backticks ```{question}```,
what is the answer? Answer: {answer}"
Even though the question answering dataset is composed of question and answer pairs written in Portuguese, I use English prompts to fine-tune BLOOMZ. This is because the model bloomz-560m was instruction-tuned on the xP3 dataset, and its authors recommend prompting it in English.
import pandas as pd
import json

# load the Portuguese legal QnA dataset LegalQA
dataset = pd.read_csv("LegalQA_dataset.csv", sep='\t')

def buildprompt(data):
    prompt = {}
    prompt['text'] = ("Given the question delimited by triple backticks "
                      "```{" + data['question'] + "}```, what is the answer? "
                      "Answer: {" + data['answer'] + "}")
    return prompt

dataset['prompt'] = dataset.apply(buildprompt, axis=1)
result = dataset['prompt'].to_list()
# save prompts to a json file
with open('prompts.json', 'w') as outfile:
    json.dump(result, outfile, ensure_ascii=False)
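As a quick sanity check (a convenience step, not part of the original pipeline), the generated file can be loaded back and the first prompt inspected:
import json

# reload the prompts and inspect the first one
with open('prompts.json') as infile:
    prompts = json.load(infile)
print(len(prompts))        # should equal the number of QA pairs (396)
print(prompts[0]['text'])  # first prompt, in the template shown above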
Fine-Tuning BLOOMZ
Considering the exploratory nature of this work and the resource limitations, I chose to tune the smallest BLOOMZ model, bloomz-560m.
I start by loading the bloomz-560m tokenizer and model, as well as the prompt dataset. Both model and tokenizer can be downloaded with the Hugging Face transformers library. The dataset could have been divided into train and evaluation sets; however, since I was not too concerned with the quality of the model (this is meant to be just a simple study), I used the entire dataset to train it. (A sketch of such a split appears after the data-preparation code below.)
import pandas as pd
import torch
import json
from transformers import (BloomTokenizerFast, BloomForCausalLM,
                          TrainingArguments, Trainer)
from datasets import load_dataset

# Loading bloomz model and tokenizer
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloomz-560m")
model = BloomForCausalLM.from_pretrained("bigscience/bloomz-560m").to("cuda")

# Loading the dataset prompts.json built from the Portuguese LegalQA dataset
dataset = load_dataset("json", data_files="prompts.json")

# prepare the data for training
def prepare_train_data(data):
    # prompt + completion
    text_input = data['text']
    # tokenize the input (prompt + completion) text
    tokenized_input = tokenizer(text_input, return_tensors='pt', padding=True)
    # generative models: labels are the same as the input
    tokenized_input['labels'] = tokenized_input['input_ids']
    return tokenized_input

train_dataset = dataset['train'].map(prepare_train_data,
                                     batched=True,
                                     remove_columns=["text"])
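For reference, if you did want a held-out evaluation set, the datasets library provides train_test_split. A minimal sketch, not used in the remainder of this article (the 10% split size is an arbitrary choice):
# optional: hold out 10% of the prompts for evaluation
split = dataset['train'].train_test_split(test_size=0.1, seed=42)
train_dataset = split['train'].map(prepare_train_data,
                                   batched=True,
                                   remove_columns=["text"])
eval_dataset = split['test'].map(prepare_train_data,
                                 batched=True,
                                 remove_columns=["text"])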
You can find below the arguments used during training. These were defined with the resource constraints in mind: per_device_train_batch_size=2 combined with gradient_accumulation_steps=4 yields an effective batch size of 8, while fp16 and gradient checkpointing keep the memory footprint low. With more resources, the model could be tuned for more epochs and with larger batch sizes, for instance.
# setting arguments to be used during training
training_arguments = TrainingArguments(
    'LegalQA-bloom-560m',
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    num_train_epochs=2,
    weight_decay=0.01,
    fp16=True,
    optim="adafactor",
    gradient_accumulation_steps=4,
    gradient_checkpointing=True
)
Now, it is time to train and save the model!
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset
)
trainer.train()
trainer.save_model()
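A small addition I would suggest (not in the original snippet): saving the tokenizer to the same output directory, so that both model and tokenizer can later be loaded from LegalQA-bloom-560m.
# save the tokenizer alongside the fine-tuned model
tokenizer.save_pretrained('LegalQA-bloom-560m')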
Let’s check the quality of our fine-tuned model.
I start by loading the pre-trained bloomz-560m model and testing it with a question very close to one of the questions in the LegalQA dataset.
import torch
from transformers import pipeline
from transformers import BloomTokenizerFast, BloomForCausalLM

# Loading the original model: bloomz-560m
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloomz-560m")
model = BloomForCausalLM.from_pretrained("bigscience/bloomz-560m",
                                         low_cpu_mem_usage=True).to("cpu")

prompt = ('Given the question delimited by triple backticks '
          '```{ O que é o imposto territorial rural? }```, what is the answer? '
          'Answer:')

generator = pipeline('text-generation',
                     model=model,
                     tokenizer=tokenizer,
                     do_sample=False)
result = generator(prompt, max_length=128)
print(result)
Output:
[{'generated_text': 'Given the question delimited by triple backticks
```{ O que é o imposto territorial rural? }```, what is the answer?
Answer: o que é o imposto territorial rural?'}]
The pre-trained model was not able to answer the given question; it merely echoed it back.
Now, let's try the same question with the fine-tuned model, LegalQA-bloom-560m.
import torch
from transformers import pipeline
from transformers import BloomTokenizerFast, BloomForCausalLM

# Loading the fine-tuned model: LegalQA-bloom-560m
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloomz-560m")
model = BloomForCausalLM.from_pretrained("LegalQA-bloom-560m",
                                         low_cpu_mem_usage=True).to("cpu")

prompt = ('Given the question delimited by triple backticks '
          '```{ O que é o imposto territorial rural? }```, what is the answer? '
          'Answer:')

generator = pipeline('text-generation',
                     model=model,
                     tokenizer=tokenizer,
                     do_sample=False)
result = generator(prompt, max_length=128)
print(result)
Output:
[{'generated_text': 'Given the question delimited by triple backticks
```{ O que é o imposto territorial rural? }```, what is the answer?
Answer: { O imposto territorial rural é o imposto devido pelo contribuinte,
diretamente, sobre as áreas de pastagem, pecuária, pecuária de corte,
pecuária de corte e de corte de corte, destinadas à produção de grãos,
vegetais e animais de corte, observado o disposto no art. 1º da Lei n º 9.393,
de 1996, e no art. 2º da Lei n º 10.165, de 2000, e no art.'}]
Perfect! Now the small BLOOMZ model, fine-tuned on a small dataset and with limited resources, can answer our question.
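As a side note, if you only want the completion without the echoed prompt, the text-generation pipeline accepts return_full_text=False. A minimal sketch using the generator defined above:
# return only the generated continuation, not the prompt itself
result = generator(prompt, max_length=128, return_full_text=False)
print(result[0]['generated_text'])  # just the answer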
"Sometimes it is just the journey…"