Sherlock Holmes Q&A Enhanced with Gemma 2b-it Fine-Tuning

Luca Massaron
Apr 14, 2024

Learn how fine-tuning can be an alternative to building a RAG

credits: DALL·E 3

Source: https://www.kaggle.com/code/lucamassaron/sherlock-holmes-q-a-with-gemma-fine-tuning

If you want an LLM to answer questions about a topic it has not been trained on, an alternative to building a RAG is to fine-tune a model specifically to answer questions on that topic.

In this Kaggle Notebook tutorial, I use Gemma 2b-it and Hugging Face packages to build a specialized Gemma model for answering tricky questions about Sherlock Holmes!

In the tutorial, the steps shown are:

  1. retrieving a knowledge base from Wikipedia (but you can use any text you want)
  2. leveraging Gemma to build meaningful Q&A pairs based on the knowledge base
  3. training Gemma on the Q&A data using 4-bit quantization and LoRA
  4. saving the trained LoRA weights and merging them back into Gemma

In the end, given some patience in gathering enough data and processing it in a Q&A form, you will have a very specialized Gemma model on the topic you want (not necessarily Sherlock Holmes)!

We start on the Kaggle notebook with code that installs several Python packages using pip:

  • This first line installs the PyTorch library, which is used for deep learning tasks, particularly neural networks. The -q flag quiets the installation process (no output except for errors), and -U ensures that if PyTorch is already installed, it will be updated to the latest version. The --index-url flag specifies a custom package index URL; in this case, it downloads the PyTorch wheel built for CUDA 11.7.
  • The following line installs a package named bitsandbytes from the Python Package Index (PyPI). Similar to the previous line, -q makes the installation quiet, -U updates the package if it’s already installed, and -i specifies the package index URL.
  • Next, it installs the transformers library, which provides state-of-the-art natural language processing models like BERT, GPT, etc. The flags -q and -U have the same meaning as before.
  • The following line installs the accelerate library, which provides utilities for high-performance computing, particularly in the context of deep learning. Again, -q and -U are used for quiet installation and updating the package.
  • The following line installs the datasets library, which provides easy access to various datasets for machine learning tasks. Once more, -q and -U are used for quiet installation and updating.
  • The following line installs the trl library, a full-stack library by Hugging Face that provides a set of tools to train transformer language models, from the supervised fine-tuning step (SFT) and reward modeling (RM) to the Proximal Policy Optimization (PPO) step. As before, -q and -U are used for quiet installation and updating.
  • The following line installs the peft library, which is a Python library by HuggingFace for efficient adaptation of pre-trained language models (PLMs) to various downstream applications without fine-tuning all the model’s parameters. PEFT methods only fine-tune a small number of (extra) model parameters, thereby significantly decreasing the computational and storage costs.
  • Finally, the last line installs the wikipedia-api library, which provides an easy interface to interact with Wikipedia data. -q and -U are used for quiet installation and updating.
!pip install -q -U torch --index-url https://download.pytorch.org/whl/cu117
!pip install -q -U -i https://pypi.org/simple/ bitsandbytes
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U trl
!pip install -q -U peft
!pip install -q -U wikipedia-api

The code imports the os module and sets two environment variables:

  • CUDA_VISIBLE_DEVICES: This environment variable tells PyTorch which GPUs to use. In this case, the code sets the environment variable to 0, meaning that PyTorch will use the first GPU.
  • TOKENIZERS_PARALLELISM: This environment variable tells the Hugging Face Transformers library whether to parallelize the tokenization process. In this case, the code sets the environment variable to false, meaning the tokenization process will not be parallelized.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

The code import warnings; warnings.filterwarnings("ignore") imports the warnings module and sets the warning filter to ignore, so all warnings are suppressed and not displayed. During training, many warnings do not prevent fine-tuning, but they can be distracting and make you wonder whether you are doing things correctly.

import warnings
warnings.filterwarnings("ignore")

In the following cell, there are all the other imports for running the notebook:

import re
import numpy as np
import pandas as pd
from tqdm import tqdm
import wikipediaapi
import torch
import torch.nn as nn
import transformers
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          TrainingArguments,
                          )
from datasets import Dataset
from peft import LoraConfig, PeftConfig
import bitsandbytes as bnb
from trl import SFTTrainer

The following cell defines a function that returns the device onto which to map the model and the data when working with PyTorch (used by the Hugging Face packages). It works on CPU-only machines, on GPUs, and on macOS with MPS.

def define_device():
    """Define the device to be used by PyTorch"""

    # Get the PyTorch version
    torch_version = torch.__version__

    # Print the PyTorch version
    print(f"PyTorch version: {torch_version}", end=" -- ")

    # Check if the MPS (Metal Performance Shaders) backend is available on macOS
    if torch.backends.mps.is_available():
        # If MPS is available, print a message indicating its usage
        print("using MPS device on MacOS")
        # Define the device as MPS
        defined_device = torch.device("mps")
    else:
        # If MPS is not available, determine the device based on GPU availability
        defined_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Print a message indicating the selected device
        print(f"using {defined_device}")

    # Return the defined device
    return defined_device
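
As a minimal usage sketch (my own illustration, not part of the original notebook), you could call the function once and reuse the returned device whenever you move tensors or models:

# Minimal usage sketch (hypothetical): pick the device once and reuse it
device = define_device()
example_tensor = torch.tensor([1, 2, 3]).to(device)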

Step 1: get the knowledge base

Apart from the first two functions, which help clean the text of tags and formatting, the following code extracts references, such as pages or other Wikipedia categories, using the extract_wikipedia_pages function. Then, the get_wikipedia_pages function crawls all the pages and information related to some initial Wikipedia category or page.

# Pre-compile the regular expression pattern for better performance
BRACES_PATTERN = re.compile(r'\{.*?\}|\}')


def remove_braces_and_content(text):
    """Remove all occurrences of curly braces and their content from the given text"""
    return BRACES_PATTERN.sub('', text)


def clean_string(input_string):
    """Clean the input string."""

    # Remove extra spaces by splitting the string by spaces and joining back together
    cleaned_string = ' '.join(input_string.split())

    # Remove consecutive carriage return characters until there are no more consecutive occurrences
    cleaned_string = re.sub(r'\r+', '\r', cleaned_string)

    # Remove all occurrences of curly braces and their content from the cleaned string
    cleaned_string = remove_braces_and_content(cleaned_string)

    # Return the cleaned string
    return cleaned_string
def get_wikipedia_pages(categories):
    """Retrieve Wikipedia pages from a list of categories and extract their content"""

    # Create a Wikipedia object
    wiki_wiki = wikipediaapi.Wikipedia('Gemma AI Assistant (gemma@example.com)', 'en')

    # Initialize lists to store explored categories and Wikipedia pages
    explored_categories = []
    wikipedia_pages = []

    # Iterate through each category
    print("- Processing Wikipedia categories:")
    for category_name in categories:
        print(f"\tExploring {category_name} on Wikipedia")

        # Get the Wikipedia page corresponding to the category
        category = wiki_wiki.page("Category:" + category_name)

        # Extract Wikipedia pages from the category and extend the list
        wikipedia_pages.extend(extract_wikipedia_pages(wiki_wiki, category_name))

        # Add the explored category to the list
        explored_categories.append(category_name)

    # Extract subcategories and remove duplicate categories
    categories_to_explore = [item.replace("Category:", "") for item in wikipedia_pages if "Category:" in item]
    wikipedia_pages = list(set([item for item in wikipedia_pages if "Category:" not in item]))

    # Explore subcategories recursively
    while categories_to_explore:
        category_name = categories_to_explore.pop()
        print(f"\tExploring {category_name} on Wikipedia")

        # Extract more references from the subcategory
        more_refs = extract_wikipedia_pages(wiki_wiki, category_name)

        # Iterate through the references
        for ref in more_refs:
            # Check if the reference is a category
            if "Category:" in ref:
                new_category = ref.replace("Category:", "")
                # Add the new category to the explored categories list
                if new_category not in explored_categories:
                    explored_categories.append(new_category)
            else:
                # Add the reference to the Wikipedia pages list
                if ref not in wikipedia_pages:
                    wikipedia_pages.append(ref)

    # Initialize a list to store extracted texts
    extracted_texts = []

    # Iterate through each Wikipedia page
    print("- Processing Wikipedia pages:")
    for page_title in tqdm(wikipedia_pages):
        try:
            # Make a request to the Wikipedia page
            page = wiki_wiki.page(page_title)
            # Check if the page summary does not contain certain keywords
            if "Biden" not in page.summary and "Trump" not in page.summary:
                # Append the page title and summary to the extracted texts list
                if len(page.summary) > len(page.title):
                    extracted_texts.append(page.title + " : " + clean_string(page.summary))
                # Iterate through the sections in the page
                for section in page.sections:
                    # Append the page title and section text to the extracted texts list
                    if len(section.text) > len(page.title):
                        extracted_texts.append(page.title + " : " + clean_string(section.text))

        except Exception as e:
            print(f"Error processing page {page_title}: {e}")

    # Return the extracted texts
    return extracted_texts
def extract_wikipedia_pages(wiki_wiki, category_name):
    """Extract all references from a category on Wikipedia"""

    # Get the Wikipedia page corresponding to the provided category name
    category = wiki_wiki.page("Category:" + category_name)

    # Initialize an empty list to store page titles
    pages = []

    # Check if the category exists
    if category.exists():
        # Iterate through each article in the category and append its title to the list
        for article in category.categorymembers.values():
            pages.append(article.title)

    # Return the list of page titles
    return pages

To gather the information necessary to answer the trickiest questions about Sherlock Holmes and his world, I've chosen to begin with a series of topics related to Conan Doyle and his writings.

categories = ["Sherlock_Holmes", "Arthur_Conan_Doyle", "A_Scandal_in_Bohemia",
"The_Adventures_of_Sherlock_Holmes", "A_Study_in_Scarlet", "The_Sign_of_the_Four",
"The_Memoirs_of_Sherlock_Holmes", "The_Hound_of_the_Baskervilles",
"The_Return_of_Sherlock_Holmes", "The_Valley_of_Fear", "His_Last_Bow",
"The_Case-Book_of_Sherlock_Holmes", "Canon_of_Sherlock_Holmes", "Dr._Watson",
"221B_Baker_Street", "Mrs._Hudson", "Professor_Moriarty", "The_Strand_Magazine",
]
extracted_texts = get_wikipedia_pages(categories)
print("Found", len(extracted_texts), "Wikipedia pages")
- Processing Wikipedia categories:
Exploring Sherlock_Holmes on Wikipedia
Exploring Arthur_Conan_Doyle on Wikipedia
Exploring A_Scandal_in_Bohemia on Wikipedia
Exploring The_Adventures_of_Sherlock_Holmes on Wikipedia
Exploring A_Study_in_Scarlet on Wikipedia
Exploring The_Sign_of_the_Four on Wikipedia
Exploring The_Memoirs_of_Sherlock_Holmes on Wikipedia
Exploring The_Hound_of_the_Baskervilles on Wikipedia
Exploring The_Return_of_Sherlock_Holmes on Wikipedia
Exploring The_Valley_of_Fear on Wikipedia
Exploring His_Last_Bow on Wikipedia
Exploring The_Case-Book_of_Sherlock_Holmes on Wikipedia
Exploring Canon_of_Sherlock_Holmes on Wikipedia
Exploring Dr._Watson on Wikipedia
Exploring 221B_Baker_Street on Wikipedia
Exploring Mrs._Hudson on Wikipedia
Exploring Professor_Moriarty on Wikipedia
Exploring The_Strand_Magazine on Wikipedia
Exploring Works originally published in The Strand Magazine on Wikipedia
Exploring Non-free The Strand Magazine magazine covers on Wikipedia
Exploring The Strand Magazine editors on Wikipedia
Exploring Films based on The Hound of the Baskervilles on Wikipedia
Exploring Works by Arthur Conan Doyle on Wikipedia
Exploring Cultural depictions of Arthur Conan Doyle on Wikipedia
Exploring Arthur Conan Doyle characters on Wikipedia
Exploring Works based on Sherlock Holmes on Wikipedia
Exploring Sherlock Holmes short story collections on Wikipedia
Exploring Sherlock Holmes short stories on Wikipedia
Exploring Sherlock Holmes scholars on Wikipedia
Exploring Sherlock Holmes novels on Wikipedia
Exploring Sherlock Holmes navigational boxes on Wikipedia
Exploring Sherlock Holmes lists on Wikipedia
Exploring Dartmoor on Wikipedia
Exploring Sherlock Holmes characters on Wikipedia
- Processing Wikipedia pages:
Found 2451 Wikipedia pages

Step 2: convert the knowledge base into a Q&A dataset

Now, having collected our knowledge base on Sherlock Holmes, we need to leverage Gemma to convert it into something more useful for training a model. The idea is to use a Q&A approach (e.g., Q: Who is Sherlock's nemesis? A: Prof. Moriarty).
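
To make the target format concrete, here is a minimal sketch of my own (mirroring how each training sample is assembled later in the notebook, as the question and the answer joined by a newline):

# Illustrative only: the plain-text shape of one fine-tuning sample
qa_example = "Who is Sherlock Holmes's nemesis?\nProfessor Moriarty."
print(qa_example)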

First, let's load Gemma 2b-it into memory, quantizing it to a 4-bit version using BitsAndBytes.

model_name = "/kaggle/input/gemma/transformers/2b-it/3"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)
model.config.use_cache = False
model.config.pretraining_tp = 1

max_seq_length = 1024
tokenizer = AutoTokenizer.from_pretrained(model_name, max_seq_length=max_seq_length)

A simple function can wrap up all the steps necessary to query Gemma on a topic or pose it a question. The function lets you specify different temperatures and can either print the answer to stdout or return it as a string.

def question_gemma(question, model=model, tokenizer=tokenizer, temperature=0.0, return_answer=False):
    input_ids = tokenizer(question, return_tensors="pt").to("cuda")
    if temperature > 0:
        do_sample = True
    else:
        do_sample = False
    outputs = model.generate(**input_ids,
                             max_new_tokens=256,
                             do_sample=do_sample,
                             temperature=temperature)
    result = str(tokenizer.decode(outputs[0])).replace("<bos>", "").replace("<eos>", "").strip()
    if return_answer:
        return result
    else:
        print(result)

We can immediately test it on some general questions about Sherlock Holmes. You will be amazed by the answers! Gemma 2b is already quite knowledgeable about Sherlock.

question_gemma("Was Sherlock Holmes a real person?")
The premise of your question is incorrect. Sherlock Holmes is a fictional character in literature and not a real person.
question_gemma("How is Sherlock Holmes so smart?")
Sherlock Holmes is one of the most brilliant and complex characters in literature. He is known for his sharp intellect, deductive reasoning, and observation skills. How is he so smart?
**1. Analytical Mind:**
Sherlock Holmes is an analytical thinker who can break down complex problems into smaller, more manageable pieces. He is able to identify patterns and relationships between different pieces of information, which allows him to deduce conclusions that are not immediately obvious.
**2. Deductive Reasoning:**
Sherlock Holmes is a master of deductive reasoning. He is able to draw conclusions from a single piece of evidence, and he is always looking for patterns and inconsistencies in evidence. This allows him to solve mysteries and identify the truth behind the deception.
**3. Observation Skills:**
Sherlock Holmes is an expert observer who can notice even the smallest details in a scene. He is able to use these details to piece together the truth and to identify the guilty party.
**4. Imagination and Creativity:**
Sherlock Holmes is a highly imaginative and creative character. He is able to come up with new ideas and solutions to problems, which allows him to outsmart his opponents.
**5. Perseverance and Determination:**
Sherlock Holmes is a persistent and determined character who will not give up
question_gemma("Does Sherlock Holmes die in the The Adventure of the Final Problem?")
No, Sherlock Holmes does not die in The Adventure of the Final Problem.
question_gemma("Who is Sherlock Holmes’s nemesis?")
The answer is Professor Moriarty. Professor Moriarty is a fictional character in Sherlock Holmes stories. He is a brilliant criminal who is Holmes's nemesis.
question_gemma("Does Sherlock Holmes have a museum?")
Yes, Sherlock Holmes does have a museum in Baker Street, London, England. It is called the Baker Street Gallery and is a museum dedicated to the life and work of Sherlock Holmes. The museum houses a collection of Holmes's personal belongings, including his personal library, furniture, and other items.
question_gemma("In what Sherlock Holmes is knowledgable?")
Sherlock Holmes is knowledgeable in a wide range of subjects, including:
- Forensic science
- Criminology
- Logic
- Deduction
- History
- Literature
- Music
- Art
- Science
- Medicine
He is also a master of disguise and deception, and he often uses his knowledge to outsmart his opponents.
question_gemma("What was Arthur Conan Doyle's belief about paranormal phenomena?")
Arthur Conan Doyle was a staunch skeptic and did not believe in paranormal phenomena. He was skeptical of any claims that could not be verified through scientific observation or logic.

Quite impressive! However, we want to become even more expert and learn the most intricate facts about Sherlock Holmes and his creator, Sir Arthur Conan Doyle. Hence, we iterate through the Wikipedia knowledge base, asking Gemma to produce a question from each text and to answer it.

Using all of the knowledge base and generating multiple Q&A pairs from the same text helps build out the fine-tuning training data. Asking for multiple questions is a necessity because Gemma will pick just one topic from the text, and it will tend to answer briefly.

We can control how Gemma returns the question and answer by asking it to reply in JSON format, in the form {"question": "...", "answer": "..."}. That way, it is easy to retrieve the data from the output text using a regex.

qa_data = []


def extract_json(text, word):
    pattern = fr'"{word}": "(.*?)"'
    match = re.search(pattern, text)
    if match:
        return match.group(1)
    else:
        return ""


no_extracted_texts = 300  # increment this number up to len(extracted_texts)
question_ratio = 24  # decrement this number to produce more questions (suggested: 24)

for i in tqdm(range(len(extracted_texts[:no_extracted_texts]))):
    question_text = f"""Create a question and its answer from the following piece of information,
put all the necessary information into the question (do not assume the reader knows the text),
and return it exclusively in JSON format in the format {'{"question": "...", "answer": "..."}'}
Here is the piece of information to elaborate:
{extracted_texts[i]}
OUTPUT JSON:
"""
    no_questions = min(1, len(extracted_texts[i]) // question_ratio)
    for j in range(no_questions):
        result = question_gemma(question_text, model=model, temperature=0.9, return_answer=True)
        result = result.split("OUTPUT JSON:")[-1]
        question = extract_json(result, "question")
        answer = extract_json(result, "answer")
        qa_data.append(f"{question}\n{answer}")
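
As a quick sanity check (my own addition, not in the original notebook), you can verify that extract_json pulls both fields out of a typical JSON-formatted completion:

# Hypothetical sanity check for the regex-based JSON extraction
sample_output = '{"question": "Where does Sherlock Holmes live?", "answer": "221B Baker Street, London."}'
print(extract_json(sample_output, "question"))  # Where does Sherlock Holmes live?
print(extract_json(sample_output, "answer"))    # 221B Baker Street, London.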

Now that the dataset has been gathered, it is time to turn it into an HF Dataset.

max_seq_length = 1024

train_data = (pd.DataFrame(qa_data, columns=["text"])
              .sample(frac=1, random_state=5)
              .drop_duplicates()
              )
train_data = Dataset.from_pandas(train_data)

Step 3: fine-tune the Gemma model

In the following cells, LoRA is set, and the training parameters are defined. Afterward, the fine-tuning can start.

output_dir = "gemma_sherlock"

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    optim="paged_adamw_32bit",
    save_steps=0,
    logging_steps=25,
    learning_rate=5e-4,
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=False,
    evaluation_strategy='steps',
    eval_steps=500,
    eval_accumulation_steps=1,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    peft_config=peft_config,
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    args=training_arguments,
    packing=False,
)
trainer.train()
TrainOutput(global_step=37, training_loss=2.9550353900806323, metrics={'train_runtime': 108.5072, 'train_samples_per_second': 2.765, 'train_steps_per_second': 0.341, 'total_flos': 119181907488768.0, 'train_loss': 2.9550353900806323, 'epoch': 0.99})

After we finish, we can ask a tricky question. Gemma's ability to answer depends on how good the previously produced Q&A data was! Remember, the more data you extract, the better (redundancy also helps, i.e., similar questions with differently phrased answers).

question_gemma("What was Arthur Conan Doyle's belief about paranormal phenomena?",
model=model, tokenizer=tokenizer)

What was Arthur Conan Doyle's belief about paranormal phenomena?
Doyle believed that paranormal phenomena was real and that it was caused
by a combination of factors, including psychological factors, physical factors,
and environmental factors.

Amazed that the answer is different from the un-tuned Gemma's? Actually, this new answer is the more correct one (read: https://blog.bookstellyouwhy.com/sir-arthur-conan-doyles-proclivity-for-the-paranormal)

Step 4: save the LoRA weights and merge them into Gemma

Now, the tricky part is saving the trained LoRA weights, reloading them, and merging them with the original Gemma model. The result is our new fine-tuned Gemma!

trainer.save_model()
tokenizer.save_pretrained(output_dir)
('gemma_sherlock/tokenizer_config.json',
'gemma_sherlock/special_tokens_map.json',
'gemma_sherlock/tokenizer.model',
'gemma_sherlock/added_tokens.json',
'gemma_sherlock/tokenizer.json')

This cell cleans up the CPU and GPU memory.

import gc

del [model, tokenizer, peft_config, trainer, train_data, bnb_config, training_arguments]
del [TrainingArguments, SFTTrainer, LoraConfig, BitsAndBytesConfig]

for _ in range(10):
    torch.cuda.empty_cache()
    gc.collect()

Now we proceed to the merging procedure:

from peft import AutoPeftModelForCausalLM

finetuned_model = output_dir
compute_dtype = getattr(torch, "float16")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoPeftModelForCausalLM.from_pretrained(
    finetuned_model,
    torch_dtype=compute_dtype,
    return_dict=False,
    low_cpu_mem_usage=True,
    device_map="auto",
)

merged_model = model.merge_and_unload()
merged_model.save_pretrained("./gemma_sherlock_merged",
                             safe_serialization=True,
                             max_shard_size="2GB")
tokenizer.save_pretrained("./gemma_sherlock_merged")
('./gemma_sherlock_merged/tokenizer_config.json',
'./gemma_sherlock_merged/special_tokens_map.json',
'./gemma_sherlock_merged/tokenizer.model',
'./gemma_sherlock_merged/added_tokens.json',
'./gemma_sherlock_merged/tokenizer.json')

Again, memory cleaning.

import gc

del [model, tokenizer, merged_model, AutoPeftModelForCausalLM]
for _ in range(10):
    torch.cuda.empty_cache()
    gc.collect()

The final step is reloading the fine-tuned model and trying it out!

from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig)

model_name = "./gemma_sherlock_merged"
compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)
model.config.use_cache = False
model.config.pretraining_tp = 1

max_seq_length = 1024
tokenizer = AutoTokenizer.from_pretrained(model_name, max_seq_length=max_seq_length)

Now it is time for the last test:

question_gemma("What was the Strand magazine?",
model=model, tokenizer=tokenizer)
What was the Strand magazine?
The Strand magazine was a British magazine that was founded in 1903.
It was one of the first magazines to feature a variety of writers,
including J. B. Priestley, H. G. Wells, and Agatha Christie.
The magazine also featured illustrations by artists such as
Aubrey Beardsley and Arthur Rackham. The Strand magazine was a
popular magazine in Britain and was read by millions of people.
It was also translated into other languages, including German,
French, and Spanish.

Nice answer!

We conclude the tutorial here. By following the same steps, you can fine-tune Gemma for any topic.
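
For instance, here is a minimal sketch of how you might point the same pipeline at another topic (the category names below are purely illustrative and should be checked against Wikipedia before use):

# Purely illustrative: reuse the same crawling and Q&A pipeline on different categories
categories = ["Hercule_Poirot", "Agatha_Christie"]
extracted_texts = get_wikipedia_pages(categories)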

Enjoy fine-tuning with Google Gemma!

#GemmaSprint

Google Cloud credits were provided for this project

