Gemma: Fine-tuning using LoRA (Low-Rank Adaptation)

BavalpreetSinghh
10 min read · Feb 26, 2024


In the last blog, I shared insights about what Gemma is and how to get started with it. As the next step, we will learn how to fine-tune Gemma using the Low-Rank Adaptation (LoRA) technique. Before getting into what LoRA is and what it does, let’s first go through what lower-rank matrices are.

Matrix

A matrix is a rectangular array of numbers arranged in rows and columns. For example:

A = [1 4 7
     2 5 8
     3 6 9]

Here, A is a 3x3 matrix because it has 3 rows and 3 columns.

Rank

The rank of a matrix is like figuring out how much truly independent information it carries. Imagine you have a bunch of rows in a matrix, and you want to know how many of them are actually telling you something new, not just repeating what you already know.

Let’s take a simple example:

M = [1 4 2
     2 5 4
     3 6 6]

The third column is just twice the first column, so the columns are not linearly independent. Therefore, the rank of matrix M is 2 (since only the first and second columns are independent).
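To make this concrete, here is a quick check with NumPy (a standalone illustration, separate from the fine-tuning code later in this post):

import numpy as np

M = np.array([[1, 4, 2],
              [2, 5, 4],
              [3, 6, 6]])

# Only two columns are linearly independent, so the rank is 2.
print(np.linalg.matrix_rank(M))  # -> 2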

Lower-Rank Matrices

A matrix is called a lower-rank matrix if its rank is lower than the highest rank it could have given its size. In simpler terms, if a matrix doesn’t carry enough unique information relative to its size, it is considered a lower-rank matrix.

Continuing with matrix M as an example, since its rank is 2 (which is less than the maximum possible rank for a 3x3 matrix, which is 3), matrix M is a lower-rank matrix.

Lower-rank matrices are useful because they indicate redundancy or patterns in the data. This can be valuable in various mathematical and computational applications, as it allows for more efficient storage, analysis, and manipulation of the data.

Now you can probably see why it was necessary to go through lower-rank matrices before understanding what LoRA does.

LoRA — Low Rank Adaptation

LoRA (Low-Rank Adaptation of Large Language Models) has gained popularity as a lightweight training method that slashes the number of parameters to train. Instead of adjusting every single weight in the model, LoRA adds a smaller set of new weights and focuses training solely on them. This approach speeds up training, saves memory, and produces smaller model artifacts (just a few hundred MBs), making them simpler to store and share.

Having a higher rank allows for more detailed adjustments, which can enhance precision but also increases the number of parameters to train. On the other hand, a lower rank reduces computational burden but might result in less precise adaptation.

For this tutorial, a LoRA rank of 4 is used. It’s recommended to start with a relatively small rank, like 4, 8, or 16, for computational efficiency during experimentation. Train your model with this rank initially and assess its performance improvement on your specific task. You can then gradually increase the rank in subsequent experiments to see whether it leads to further improvements in performance.
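Before touching a real model, it helps to see the core mechanics in isolation. Here is a minimal NumPy sketch of the idea: instead of learning a full d×d weight update, LoRA learns two small factors B (d×r) and A (r×d) whose product is the update. This is purely illustrative (the sizes are made up); KerasNLP handles all of this internally.

import numpy as np

d, r = 1024, 4                     # hidden size and LoRA rank (illustrative values)
W = np.random.randn(d, d)          # frozen pretrained weight, never updated
A = np.random.randn(r, d) * 0.01   # trainable factor, small random init
B = np.zeros((d, r))               # trainable factor, zero init so the update starts at 0

W_adapted = W + B @ A              # effective weight during fine-tuning

print(d * d)      # 1048576 parameters for a full update
print(2 * d * r)  # 8192 parameters for the LoRA factors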

Implementation

Install dependencies

!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"
!pip install pyarrow

Select Backend and Import Packages

In Keras 3, the backend is chosen through the KERAS_BACKEND environment variable, and it must be set before Keras is imported, so we do that first.

import os

# Select the backend before importing keras.
os.environ["KERAS_BACKEND"] = "torch"

import keras
import keras_nlp
import pyarrow.parquet as pq

Dataset

The No Robots dataset consists of 10,000 meticulously crafted instructions and demonstrations, curated by proficient human annotators. This dataset serves as a valuable resource for supervised fine-tuning (SFT), enabling language models to better understand and adhere to provided instructions. Inspired by OpenAI’s InstructGPT paper, “No Robots” primarily comprises single-turn instructions covering various categories such as Generation, Open QA, Brainstorm, Chat, Rewrite, Summarize, Coding, Classify, Closed QA, and Extract. I have downloaded the file and uploaded it to Kaggle for ease of use; you can also access it directly from Hugging Face.

# Step 1: Read Parquet file
table = pq.read_table('/kaggle/input/no-robots-sft/train_sft-00000-of-00001-8aba5401a3b757f5.parquet')

# Step 2: Convert table to DataFrame
df = table.to_pandas()

# Step 3: Access data
df.head() # Display first few rows of the DataFrame
Structure of the DataFrame created from the No Robots dataset
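Alternatively, if you would rather pull the data straight from Hugging Face instead of Kaggle, something like the following should work (this assumes the datasets library is installed; depending on the dataset version, the split may be named train or train_sft):

from datasets import load_dataset

# Alternative to the Kaggle parquet file above.
ds = load_dataset("HuggingFaceH4/no_robots", split="train")
df = ds.to_pandas()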

Let’s have a look at the content under the messages column.

df['messages'][0]
#output
array([{'content': 'Please summarize the goals for scientists in this text:\n\nWithin three days, the intertwined cup nest of grasses was complete, featuring a canopy of overhanging grasses to conceal it. And decades later, it served as Rinkert’s portal to the past inside the California Academy of Sciences. Information gleaned from such nests, woven long ago from species in plant communities called transitional habitat, could help restore the shoreline in the future. Transitional habitat has nearly disappeared from the San Francisco Bay, and scientists need a clearer picture of its original species composition—which was never properly documented. With that insight, conservation research groups like the San Francisco Bay Bird Observatory can help guide best practices when restoring the native habitat that has long served as critical refuge for imperiled birds and animals as adjacent marshes flood more with rising sea levels. “We can’t ask restoration ecologists to plant nonnative species or to just take their best guess and throw things out there,” says Rinkert.', 'role': 'user'},
{'content': 'Scientists are studying nests hoping to learn about transitional habitats that could help restore the shoreline of San Francisco Bay.', 'role': 'assistant'}],
dtype=object)

Let’s format it before giving it to Gemma for fine-tuning.

# Initialize an empty list to store the formatted instructions and responses
formatted_data = []

# Iterate over each row of the DataFrame
for index, row in df.iterrows():
    # Extract messages
    messages = row['messages']

    # Extract content from the first and second message
    instruction_content = messages[0]['content']
    response_content = messages[1]['content']

    # Create the formatted string
    formatted_string = f'Instruction:\n{instruction_content}\n\nResponse:\n{response_content}'

    # Append the formatted string to the list
    formatted_data.append(formatted_string)
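Equivalently, the same formatting can be written more compactly as a list comprehension:

# Same result as the loop above, in one expression.
formatted_data = [
    f"Instruction:\n{m[0]['content']}\n\nResponse:\n{m[1]['content']}"
    for m in df['messages']
]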

Let’s have a look at the data after it’s formatted.

# Print the first formatted example
for item in formatted_data:
    print(item)
    break
#output

```
Instruction:
Please summarize the goals for scientists in this text:

Within three days, the intertwined cup nest of grasses was complete, featuring
a canopy of overhanging grasses to conceal it. And decades later, it served as
Rinkert’s portal to the past inside the California Academy of Sciences. Information
gleaned from such nests, woven long ago from species in plant communities called
transitional habitat, could help restore the shoreline in the future. Transitional
habitat has nearly disappeared from the San Francisco Bay, and scientists need a
clearer picture of its original species composition—which was never properly
documented. With that insight, conservation research groups like the San Francisco
Bay Bird Observatory can help guide best practices when restoring the native habitat
that has long served as critical refuge for imperiled birds and animals as adjacent
marshes flood more with rising sea levels. “We can’t ask restoration ecologists to
plant nonnative species or to just take their best guess and throw things out there,”
says Rinkert.

Response:
Scientists are studying nests hoping to learn about transitional habitats that
could help restore the shoreline of San Francisco Bay.
```

We will use only 200 instances for tuning, as this is for learning purposes only.

formatted_data = formatted_data[:200]

Load Model

KerasNLP provides implementations of many popular model architectures. In this tutorial, we’ll construct a model using GemmaCausalLM, an end-to-end Gemma model tailored for causal language modeling. Causal language modeling entails predicting the subsequent token based on previous tokens.

Create the model using the from_preset method:

gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_model.summary()
PS: The summary output shows the total and trainable parameter counts.
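One practical note: the Gemma preset weights are hosted on Kaggle, so if you run this outside a Kaggle notebook you will likely need to provide Kaggle credentials before the from_preset call above can download them. The values below are placeholders for your own account details:

# Only needed outside Kaggle; get these from your Kaggle account settings.
os.environ["KAGGLE_USERNAME"] = "your_username"  # placeholder
os.environ["KAGGLE_KEY"] = "your_api_key"        # placeholder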

Inference before fine-tuning

In this section, you will query the model with various prompts to see how it responds.

Butter chicken Prompt
Query the model for suggestions on how to cook delicious butter chicken.

prompt = "how to cook delicious butter chicken"
print(gemma_model.generate(prompt, max_length=256))
#output

how to cook delicious butter chicken in 15 minutes

Answer:

Step 1/6
1. Start by preheating your oven to 350 degrees Fahrenheit.

Step 2/6
2. In a large skillet, heat 2 tablespoons of butter over medium heat.

Step 3/6
3. Add 1 pound of chicken breasts and cook for 5 minutes, or until they are lightly browned.

Step 4/6
4. Add 1 cup of chicken broth and 1 teaspoon of garam masala to the skillet and bring to a boil.

Step 5/6
5. Reduce the heat to low and simmer for 10 minutes, or until the chicken is cooked through.

Step 6/6
6. Serve the chicken with basmati rice and a side of naan bread. Enjoy!

Explanation prompt — concept named evaporation

Prompt the model to explain evaporation in terms simple enough for a younger child to understand.

prompt = "Explain the process of evaporation in a way that a school going 3rd standard child could understand."
print(gemma_model.generate(prompt, max_length=256))
#output

Explain the process of evaporation in a way that a school going 3rd standard child could understand.

Answer:

Step 1/5
1. When water is heated, it changes its state from liquid to gas. This process is called evaporation.

Step 2/5
2. The water molecules in the liquid state move faster and collide with each other and the walls of the container.

Step 3/5
3. Some of these water molecules escape from the surface of the liquid and enter the gas phase. This process is called vaporization.

Step 4/5
4. The vapor molecules move away from the liquid and form a layer of gas above the liquid.

Step 5/5
5. The gas molecules are lighter than the liquid molecules, so they rise up and escape into the air. This process is called diffusion. So, in simple terms, evaporation is the process of water molecules escaping from the surface of a liquid and forming a layer of gas above the liquid. This gas then rises up and escapes into the air, leaving behind a smaller amount of liquid. This process is repeated continuously, and the amount of water in the liquid decreases over time

Fine-tuning

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_model.backbone.enable_lora(rank=4)
gemma_model.summary()
Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 B to 1.3 M)

Hyperparameters for LoRA
We are defining only the rank; everything else will use the defaults. But it is good to know about the others.

We chose a rank of 4 for efficiency while maintaining solid performance, since higher ranks like 8 or 16 often yield minimal improvement. If you have the headroom, a rank of 8 is a reasonable default for hardware compatibility, giving manageable checkpoint sizes without sacrificing accuracy.

Alpha, which scales the learned weights, is commonly kept fixed at 16, following the LoRA paper and standard practice in the LLM community.

Research suggests that applying LoRA to all dense layers, rather than just the “Q” and “V” attention matrices, yields better results. Under the hood, however, KerasNLP’s enable_lora targets the “query_dense” and “value_dense” layers.

A base learning rate of 1e-4 is standard for fine-tuning LLMs with LoRA, despite occasional training-loss instabilities. Lowering it to 3e-5 can stabilize the process, as some practitioners have observed.
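For reference, here is how the alpha scaling mentioned above enters the picture in the original LoRA formulation (a conceptual sketch; KerasNLP applies its scaling internally):

# The low-rank update is scaled by alpha / rank before being added to the
# frozen weight. With alpha = 16 and rank = 4, the update is scaled by 4.0.
alpha, rank = 16, 4
scale = alpha / rank
# W_effective = W_frozen + scale * (B @ A)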

#finetuning
# To control memory usage, limit the input sequence length to 512.
gemma_model.preprocessor.sequence_length = 512

# Using AdamW - a common optimizer for transformer models.
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_model.fit(formatted_data, epochs=1, batch_size=1)
Training output: tuning on 200 instances of the No Robots data
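Once training finishes, you may want to persist the model so the LoRA-adapted weights are not lost when the notebook shuts down. A minimal sketch, assuming the Keras v3 saving format (the file name is just an example):

# Save the fine-tuned model in the Keras v3 format (path is illustrative).
gemma_model.save("gemma_2b_lora_finetuned.keras")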

Inference after fine-tuning

PS: When the notebook is saved on Kaggle, it re-runs all the cells, so you might see different outputs in the code. For ease of understanding, I have printed out the outputs I got while I was running inference.

prompt = "how to cook delicious butter chicken"
print(gemma_model.generate(prompt, max_length=256))
Output after fine-tuning

Although the output is truncated, the way the response is presented and structured clearly shows the improvement.

prompt = "Explain the process of evaporation in a way that a school going 3rd standard child could understand."
print(gemma_model.generate(prompt, max_length=256))
Output after fine-tuning

Although the output is truncated due to the limited max_length, we can clearly see that the model was hallucinating earlier; after fine-tuning, the responses are better, staying focused on the evaporation process.

Keep in mind that in this tutorial, we’re training the model on only a small part of the dataset for a single epoch and with a low LoRA rank setting. If you want improved responses from the fine-tuned model, you might want to try out the following:

  • Expanding the size of the dataset used for fine-tuning.
  • Adjusting the LoRA rank to a higher value.
  • Tweaking the hyperparameter values like learning_rate and weight_decay.
  • Increasing the number of training steps (epochs).

Summary

This tutorial provided an overview of LoRA fine-tuning on a Gemma model using KerasNLP. It explained the process and key considerations for fine-tuning, including rank selection, alpha scaling, target modules, and base learning rate.

Codebase — https://shorturl.at/almR7

Next Steps:
1. Learn how to generate text with a Gemma model.
2. Explore distributed fine-tuning and inference on a Gemma model.
3. Discover how to utilize Gemma open models with Vertex AI.
4. Dive into fine-tuning Gemma using KerasNLP and deploying to Vertex AI.

These resources will help you further understand and apply Gemma models for your specific tasks and projects.
