Fine-tuning Gemma with LoRA on GCP

pritam sahoo
Google Cloud - Community
6 min read · Apr 15, 2024

My obsession with Gemma continues. Readers new to the Gemma models can refer to my previous blog post.

In brief, Gemma is a family of lightweight, state-of-the-art (SOTA) open models built with the same technology that powers Google's popular Gemini models.

In this blog, we will get started with fine-tuning Gemma using LoRA.

Let's first understand a bit about fine-tuning. One reason fine-tuning is gaining traction is that large language models (LLMs) are not trained for specific tasks or on domain-specific data. LLMs, often called foundation models, are primarily trained on massive, internet-scale corpora of text. Fully retraining a pre-trained LLM is technically challenging, and the cost of the required compute is one of the major concerns.

Let's look at the benefits of fine-tuning.

  1. Fine-tuning a pre-trained model is much faster and more cost-effective than training from scratch, since it requires far fewer computational resources.
  2. Better performance on domain-specific tasks, especially industry use cases in financial services, insurance, healthcare, etc.
  3. Let's not forget the democratization of GenAI models for individual users, i.e. developers and others who have less computational power.

Next, let's understand parameter-efficient fine-tuning, a.k.a. PEFT. It's a family of fine-tuning techniques that update only a small number of parameters/weights while still producing good output. Instead of altering all the parameters of the model, PEFT trains a small subset of them (or a small number of newly added ones), thereby reducing computational and memory requirements. PEFT plays a major role in fine-tuning, improving the performance of base/foundation LLMs on specific tasks. This is especially useful when adapting Google's LLMs, such as Gemini and its variants, PaLM, and even the open Gemma models.

We will explore fine-tuning Gemma models with LoRA. LoRA stands for Low-Rank Adaptation, introduced in the paper "LoRA: Low-Rank Adaptation of Large Language Models". It's a technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the weights/parameters of the base model and introducing a small number of new weights into the model.
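For reference, the original LoRA paper writes an adapted layer as the frozen pre-trained matrix W_0 plus a trainable low-rank product BA, scaled by α/r. Here α is a constant hyperparameter from the paper, not a setting used anywhere in this post:

```latex
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d \times k},\quad
B \in \mathbb{R}^{d \times r},\quad
A \in \mathbb{R}^{r \times k},\quad
r \ll \min(d, k)
```

Only A and B are trained, so an adapted d × k matrix adds r(d + k) trainable parameters instead of the dk that a full update would touch. B starts at zero, so training begins from the unmodified pre-trained model.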

A crucial point to consider: in LoRA, the starting hypothesis is super important. It assumes that the pre-trained model's weights are already close to a good solution for the downstream task, so the change needed to adapt them has a low intrinsic rank and can be captured by a small low-rank update.

Advantages of using LoRA as a fine-tuning technique:

  1. Reduces parameter count and memory footprint: LoRA significantly reduces the number of trainable parameters, making fine-tuning much more memory-efficient and computationally cheaper.
  2. Faster fine-tuning: far fewer parameters/weights are updated, and once the low-rank weights are merged into the base weights, inference adds no extra latency.
  3. Maintains performance: LoRA has been shown to perform comparably to full fine-tuning on several tasks.
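To make the parameter savings concrete, here is a minimal NumPy sketch of a LoRA-adapted linear layer. LoRALinear is a hypothetical class written for illustration; the 2048 × 2048 weight shape and rank 4 are arbitrary example values, and this is not how KerasNLP implements LoRA internally.

```python
import numpy as np


class LoRALinear:
    """Minimal sketch of a LoRA-adapted linear layer (illustrative only).

    The pre-trained weight W0 stays frozen; only the low-rank factors
    A (rank x k) and B (d x rank) would be updated during fine-tuning.
    """

    def __init__(self, w0: np.ndarray, rank: int, alpha: float = 1.0, seed: int = 0):
        d, k = w0.shape
        rng = np.random.default_rng(seed)
        self.w0 = w0                                     # frozen, shape (d, k)
        self.a = rng.normal(0.0, 0.01, size=(rank, k))   # trainable
        self.b = np.zeros((d, rank))                     # trainable, zero-init
        self.scale = alpha / rank

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x has shape (k,) or (batch, k); the result has shape (d,) or (batch, d).
        base = x @ self.w0.T
        update = (x @ self.a.T) @ self.b.T
        return base + self.scale * update

    def trainable_params(self) -> int:
        return self.a.size + self.b.size

    def frozen_params(self) -> int:
        return self.w0.size


w0 = np.random.default_rng(1).normal(size=(2048, 2048))
layer = LoRALinear(w0, rank=4)
x = np.ones(2048)
print(layer.trainable_params())       # 16384 = 4 * (2048 + 2048)
print(layer.frozen_params())          # 4194304 = 2048 * 2048
print(np.allclose(layer(x), w0 @ x))  # True, because B starts at zero
```

In this example the trainable factors are about 0.39% of the size of the frozen matrix they adapt, which is where LoRA's memory and compute savings come from.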

So let's get started with fine-tuning the Gemma model with LoRA.

For this demo, I will be using a Google Colab notebook with a T4 GPU for some extra horsepower.

Step 1: Get access to Gemma

To follow along in Colab, you will first need to complete the Gemma setup instructions, which show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma 2B model.
  • Generate and configure a Kaggle username and API key.

After you’ve completed the Gemma setup, move on to the next section, where you’ll set environment variables for your Colab environment.

Step 2: Select a GPU runtime (for example, T4) via Runtime > Change runtime type.

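Step 3: Configure your secrets, i.e. your Kaggle username and API key (generated from the Account tab on Kaggle), in the Colab Secrets panel. The code in Step 5 reads them under the names username and key.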

Step 4: Select the data for fine-tuning from Hugging Face. We will use the Databricks Dolly 15k dataset, which contains around 15,000 high-quality, human-generated prompt/response pairs designed specifically for fine-tuning LLMs.

[Screenshot: brief view of the Dolly 15k dataset]

Step 5: Set the environment variables by running the following commands in Colab:

import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get("username")
os.environ["KAGGLE_KEY"] = userdata.get("key")
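Before loading Gemma weights from Kaggle, it can help to confirm that both variables are actually set. The sketch below does that; missing_kaggle_credentials is a hypothetical helper written for this example, not part of Colab, Kaggle, or KerasNLP.

```python
import os


def missing_kaggle_credentials(env=None):
    """Return the Kaggle variable names that are unset or empty (hypothetical helper)."""
    env = os.environ if env is None else env
    required = ("KAGGLE_USERNAME", "KAGGLE_KEY")
    return [name for name in required if not env.get(name)]


missing = missing_kaggle_credentials()
if missing:
    print("Missing Kaggle credentials:", ", ".join(missing))
else:
    print("Kaggle credentials are set.")
```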

Step 6: Install the dependencies

# Quote "keras>=3" so the shell doesn't treat ">" as output redirection.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

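Step 7: Select the backend. You may choose JAX, TensorFlow, or PyTorch. Set KERAS_BACKEND before importing Keras, because the backend cannot be changed once Keras has been imported.

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"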

Step 8: Import the packages, i.e. Keras and KerasNLP.

import keras

import keras_nlp

Step 9: Download the dataset from Hugging Face.

!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl

Step 10: Preprocess the data. For this demo, I will use a subset of 1,000 examples instead of all 15K. For better fine-tuning results, you may use more examples.

import json

data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
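The same filtering and formatting logic can be wrapped in a function and tried on in-memory lines, which makes it easy to check without downloading the file. build_examples is a hypothetical helper, and the two records below are made up using Dolly's field names (instruction, context, response, category); they are not real dataset rows.

```python
import json

# Same prompt template as the Step 10 code.
TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"


def build_examples(lines, limit=1000):
    """Format Dolly-style JSONL lines as training strings (hypothetical helper).

    Records with a non-empty "context" are skipped, as in the loop above,
    and reading stops once `limit` examples have been collected.
    """
    examples = []
    for line in lines:
        if len(examples) >= limit:
            break
        features = json.loads(line)
        if features["context"]:
            continue
        # str.format ignores unused keys such as "context" and "category".
        examples.append(TEMPLATE.format(**features))
    return examples


sample_lines = [
    json.dumps({"instruction": "Name a primary color.", "context": "",
                "response": "Red.", "category": "open_qa"}),
    json.dumps({"instruction": "Summarize the passage.", "context": "Some passage.",
                "response": "A summary.", "category": "summarization"}),
]
print(build_examples(sample_lines))
# ['Instruction:\nName a primary color.\n\nResponse:\nRed.']

# With an empty response, the template becomes an inference prompt that
# ends exactly where the model is expected to start writing.
print(repr(TEMPLATE.format(instruction="What should I do on a trip to Europe?", response="")))
# 'Instruction:\nWhat should I do on a trip to Europe?\n\nResponse:\n'
```

Stopping early once the limit is reached gives the same 1,000 examples as reading everything and slicing, without parsing the rest of the file.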

Step 11: Now it's time to load the Gemma 2B base model. You may also try the Gemma 7B base model, though it needs considerably more GPU memory.

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

If everything is working, you will see the model summary printed below the cell.

Next, let's run inference on the base model before fine-tuning.

Pass the following prompt: “What should I do on a trip to Europe?”

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
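TopKSampler(k=5, seed=2) restricts each generation step to the five most likely next tokens. The idea can be sketched in pure Python as below; top_k_sample is a hypothetical helper for illustration only, and KerasNLP's actual sampler works on tensors and differs in its details.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Conceptual top-k sampling over a list of logits (not KerasNLP's code).

    Keeps the k highest-scoring token ids, turns their logits into softmax
    weights, and draws one id from that reduced distribution.
    """
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    max_logit = max(logits[i] for i in top)  # subtract the max for numerical stability
    weights = [math.exp(logits[i] - max_logit) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(2)  # fixed seed, similar in spirit to seed=2 above
logits = [2.0, 0.5, 1.5, -1.0, 3.0, 0.0, 1.0]
draws = [top_k_sample(logits, k=5, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))  # only ids among the 5 highest logits can appear
```

A small k keeps the output focused while still allowing some variety; k=1 reduces to always picking the single most likely token.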

You will see a very generic, bland, and not-so-great response from the base model, such as the one below:

— — — — — — — — — — — — — — — — — — — — —

Instruction:

What should I do on a trip to Europe?

Response:

It’s easy, you just need to follow these steps:

First you must book your trip with a travel agency.

Then you must choose a country and a city.

Next you must choose your hotel, your flight, and your travel insurance

And last you must pack for your trip.

— — — — — — — — — — — — — — —

Step 12: Let's fine-tune with LoRA using the Databricks Dolly 15K dataset.

The LoRA rank sets the size of the trainable low-rank matrices, which controls the expressiveness and precision of the fine-tuning adjustments. A lower rank means fewer trainable parameters and less computational power, but also a less precise adaptation. You may start with a rank of 4 or 8 for demos and experimentation.

gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Total params: 2,507,536,384 (9.34 GB)

Trainable params: 1,363,968 (5.20 MB)

Non-trainable params: 2,506,172,416 (9.34 GB)
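The summary figures can be sanity-checked with a little arithmetic. The first part only uses the numbers printed above. The second part is an assumption not stated in this post: it supposes LoRA is attached to the query and value projections of all 18 Gemma 2B decoder layers (hidden size 2048, 8 query heads, 1 key/value head, head dim 256), and under that assumption it reproduces the reported trainable count.

```python
# Figures copied from the gemma_lm.summary() output above.
total = 2_507_536_384
trainable = 1_363_968
non_trainable = 2_506_172_416

print(trainable + non_trainable == total)           # True
print(f"Trainable share: {trainable / total:.4%}")  # Trainable share: 0.0544%

# Reported sizes are consistent with 4 bytes per parameter in binary units.
print(f"{trainable * 4 / 2**20:.2f} MiB, {total * 4 / 2**30:.2f} GiB")  # 5.20 MiB, 9.34 GiB

# ASSUMPTION: LoRA on query and value projections, with factor shapes
# (heads, hidden, rank) and (rank, head_dim) in every decoder layer.
LAYERS, HIDDEN, Q_HEADS, KV_HEADS, HEAD_DIM = 18, 2048, 8, 1, 256


def lora_params(heads, rank):
    return heads * HIDDEN * rank + rank * HEAD_DIM


def total_lora_params(rank):
    return LAYERS * (lora_params(Q_HEADS, rank) + lora_params(KV_HEADS, rank))


print(total_lora_params(4))  # 1363968, matching the summary
for rank in (8, 16):
    print(rank, total_lora_params(rank))  # grows linearly with rank
```

Under this assumption, doubling the rank doubles the trainable parameters, which is the compute/precision trade-off described for the rank setting.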

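Enabling LoRA has significantly reduced the number of trainable parameters, as the summary above shows. Now run the training cell below in the Colab notebook. Be patient, as it will take some time; you should see the loss decrease as training progresses. Setting epochs=1 means a single pass over the 1,000 training examples, which with batch_size=1 is 1,000 training steps. The input sequence length is capped at 512 tokens to keep memory usage under control.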

# Limit the input sequence length to control memory usage.
gemma_lm.preprocessor.sequence_length = 512
# AdamW is a common optimizer for transformer models.
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
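AdamW applies weight decay directly to the weights instead of folding it into the gradient, and the training code above turns it off for bias and scale variables. A single-scalar sketch makes that behavior visible. adamw_step is a hypothetical helper; the beta and epsilon values are typical defaults chosen for illustration, and Keras's optimizer operates on whole tensors rather than one number at a time.

```python
import math


def adamw_step(param, grad, m, v, t, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7, decay=True):
    """One AdamW update for a single scalar parameter (illustrative sketch).

    Weight decay is decoupled: it shrinks the parameter directly. Passing
    decay=False mimics excluding a variable (e.g. a bias) from weight decay.
    """
    if decay:
        param -= lr * weight_decay * param
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# With a zero gradient, only the decoupled decay moves the parameter.
p, m, v = adamw_step(1.0, 0.0, 0.0, 0.0, t=1)
print(math.isclose(p, 1 - 5e-5 * 0.01))  # True
p_nd, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, decay=False)
print(p_nd)  # 1.0: excluded variables are not decayed
```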

The training log from this step should show the loss going down, even with just 1,000 examples.

Step 13: Let's run inference after fine-tuning.

Pass the same prompt again: “What should I do on a trip to Europe?”

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Let me know your results. The response should be noticeably better than before fine-tuning.

That's it, folks, for Gemma fine-tuning with LoRA. Stay tuned for more updates coming your way on QLoRA.

Signing off… Pritam
