Using and Finetuning Google’s State-of-the-Art Open Source Model Gemma-2B
On Feb 21, 2024, Google introduced Gemma, a family of lightweight, state-of-the-art open models, and the Keras team announced support for it in KerasNLP. The models come in two compact parameter sizes, 2B and 7B. Despite their small size, they deliver significant advances over similar and even larger open models.
Gemma is inspired by Google’s Gemini model. The name Gemma comes from the Latin gemma, meaning “precious stone.” Along with the model weights, Google has also released tools to support developer innovation, foster collaboration, and guide responsible use of Gemma models.
- Both Gemma 2B and Gemma 7B are released with pretrained and instruction tuned variants.
- Gemma models share technical and infrastructure components with Gemini, Google's largest and most capable AI model. As a result, Gemma 2B and 7B achieve best-in-class performance for their sizes compared to other open models. Gemma surpasses significantly larger models on key benchmarks while adhering to rigorous standards for safe and responsible outputs.
- Gemma models run across popular device types, including laptops, desktops, IoT, mobile and cloud, enabling broadly accessible AI capabilities.
- You can use Gemma with your favorite framework: it has reference implementations for inference and fine-tuning in Keras 3.0, native PyTorch, JAX, and Hugging Face Transformers.
- Google has released a new Responsible Generative AI Toolkit that provides guidance and essential tools for creating safer AI applications with Gemma.
- Google has provided toolchains for inference and supervised fine-tuning (SFT) across all major frameworks: JAX, PyTorch, and TensorFlow through native Keras 3.0.
- It is easy to use Gemma with Google Colab and Kaggle notebooks, alongside integration with popular tools such as Hugging Face, MaxText, NVIDIA NeMo and TensorRT-LLM.
- Gemma is optimized to work on NVIDIA GPUs, from the data center to the cloud to local RTX AI PCs, ensuring industry-leading performance and integration with cutting-edge technology.
- Pre-trained and instruction-tuned Gemma models can run on your laptop, workstation, or Google Cloud with easy deployment on Vertex AI and Google Kubernetes Engine (GKE).
- Gemma is optimized across multiple AI hardware platforms ensuring industry-leading performance, including NVIDIA GPUs and Google Cloud TPUs.
- According to its terms of use, responsible commercial usage and distribution are permitted for all organizations, regardless of size.
Using Gemma with the KerasNLP API
You can use Gemma models through the KerasNLP API and run them directly on a text prompt. You don't have to load a separate tokenizer, because tokenization is built into the model.
KerasNLP is a collection of natural language processing (NLP) models implemented in Keras and runnable on JAX, PyTorch, and TensorFlow.
You can fine-tune Gemma with LoRA, a parameter-efficient fine-tuning method, using a single line of code. You can also fine-tune Gemma models on multiple GPUs/TPUs with data-parallel and model-parallel distributed training options.
Gemma Access and Setup from Kaggle.com:
You can access Gemma from Kaggle.com or from Hugging Face (huggingface.co). You need a free account on whichever site you access the Gemma model from.
Check out Gemma's Kaggle Model Card.
Once you have a Kaggle account, request access to Gemma from your Kaggle account by filling in a consent form. Access is easy to obtain.
1. Authenticating yourself to Kaggle for model access through Colab:
Kaggle API Key: To use Gemma, you must provide your Kaggle username and a Kaggle API key. To generate and configure these values, follow these steps:
1. To generate a Kaggle API key, go to the Account tab of your Kaggle user profile and select Create New Token. This triggers the download of a kaggle.json file containing your API credentials.
2. Open kaggle.json in a text editor. The contents should look something like this:
{"username":"your_username","key":"012345678abcdef012345678abcdef1a"}
Once you have the kaggle.json file, open a new Colab notebook (you need a Google account for this) and then:
3. In Colab, select Secrets (🔑) and add your Kaggle username and Kaggle API key. Store your username under the name KAGGLE_USERNAME and your API key under the name KAGGLE_KEY.
Once you have created your KAGGLE_USERNAME and KAGGLE_KEY Colab secrets, you can set them as environment variables with this code from within Colab:
# Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY from Colab secrets
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
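If you are not using Colab, one way to set the same environment variables is to read them straight from kaggle.json. The sketch below uses a hypothetical helper, `load_kaggle_credentials`, which is not part of any Kaggle library; it assumes the file sits at `~/.kaggle/kaggle.json` (the location the Kaggle command-line tool conventionally uses) and has the `username` and `key` fields shown above.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json"):
    """Hypothetical helper: export kaggle.json values as KAGGLE_USERNAME / KAGGLE_KEY."""
    # expanduser() resolves "~" to the current user's home directory.
    creds = json.loads(Path(path).expanduser().read_text())
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]
```

Call `load_kaggle_credentials()` (or pass a different path) before loading the model. Reading the key from a file keeps it out of the notebook itself, which matters if you share the notebook.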
Authenticating using kagglehub
You can also use kagglehub to authenticate yourself to Kaggle while accessing a model.
You will need to install the kagglehub package with pip:
!pip install kagglehub
# Then import kagglehub and call its login().
import kagglehub
kagglehub.login()
# This opens a login prompt in which you enter your Kaggle username and API key,
# as found in the kaggle.json file.
2. Installing KerasNLP
KerasNLP is a natural language processing library that supports the end-to-end development cycle. It is an extension of the core Keras API; all high-level modules are Layers or Models.
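KerasNLP uses Keras 3, which can run on TensorFlow, PyTorch or JAX. Keras 3 is the default Keras for TensorFlow 2.16 and later; installing an older TensorFlow pulls in Keras 2, which is why Keras is upgraded after KerasNLP here.
Use this code to install KerasNLP and upgrade Keras: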
!pip install --upgrade keras-nlp
!pip install --upgrade keras
I used the PyTorch ("torch") backend to run my code in Google Colab on a free-tier T4 GPU.
3. Accessing the Gemma Model from Kaggle.com
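import os
# Choose the Keras backend before importing Keras; this article uses PyTorch.
os.environ["KERAS_BACKEND"] = "torch"  # or "jax" / "tensorflow"
import keras
import keras_nlp
# Download the pretrained Gemma 2B English preset (needs the Kaggle credentials set up above).
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")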
4. Let us have a look at the model summary
gemma_lm.summary()
The summary lists the model's layers and their parameter counts. Note that this model has no non-trainable parameters: every weight can be updated.
5. Pass a prompt to the model and ask it to generate text (max_length caps the combined length of the prompt and the generated text, in tokens):
gemma_lm.generate("What is the meaning of life?", max_length=64)
You should see output like the following:
“What is the meaning of life? The question is one of the most important questions in the world. It’s the question that has been asked by philosophers, theologians, and scientists for centuries. And it’s the question that has been asked by people who are looking for answers to their own lives”
6. You can also configure the sampling algorithm and batch the prompts together:
gemma_lm.compile(sampler="top_k")
gemma_lm.generate(
["What is the meaning of life?",
"How does the brain work?"],
max_length=64)
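Top-k sampling keeps only the k most likely next tokens, renormalizes their probabilities, and samples one of them. The minimal pure-Python sketch below illustrates the idea on a made-up logit vector; `top_k_sample` is a hypothetical helper, not KerasNLP's implementation, which operates on batched tensors inside `generate`.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample one token id from the k highest-scoring logits (top-k sampling)."""
    # Indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the surviving logits only (subtract the max for numerical stability).
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]
    # random.choices normalizes the weights, so they need not sum to 1.
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 0.5, 3.1, -1.0, 1.7, 0.0]  # made-up scores for a 6-token vocabulary
rng = random.Random(2)  # seeded, so the samples are reproducible
samples = [top_k_sample(logits, k=3, rng=rng) for _ in range(10)]
print(samples)  # every id is one of the 3 highest-scoring tokens: 2, 0 or 4
```

With k=1 this reduces to greedy decoding; a larger k gives more varied text. The `k` and `seed` arguments of `keras_nlp.samplers.TopKSampler` play the same roles as `k` and the seeded `random.Random` here.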
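You can also use a prompt template that specifies the format in which Gemma should respond to your prompt, together with a configured sampler object, like this:
# Instruction/response template (the fine-tuning data below uses the same format).
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
# Top-k sampling over the 5 most likely tokens, with a fixed seed for repeatability.
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))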
The printed output contains the formatted prompt followed by the pretrained model's response.
Now let us fine-tune Gemma using LoRA with rank 4 on the Databricks Dolly 15k dataset. This dataset contains 15,000 high-quality, human-generated prompt/response pairs specifically designed for fine-tuning LLMs.
Fine-tuning allows us to get better responses from the model.
Low-Rank Adaptation (LoRA) is an efficient fine-tuning technique. LoRA freezes the original weights of the LLM and learns a low-rank update: the product of two small trainable matrices, which is added to each targeted weight matrix. The LoRA rank is the inner dimension of those two matrices. It controls the expressiveness and precision of the fine-tuning adjustments. A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.
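To make the rank concrete: LoRA replaces a frozen weight matrix W with W + scale * (A @ B), where A is d_in x r and B is r x d_out, so only r * (d_in + d_out) numbers are trained instead of d_in * d_out. The pure-Python sketch below is illustrative only; the helper names are hypothetical, the 2048 x 2048 matrix size is an arbitrary example, and `scale` stands in for the alpha / rank scaling used in the LoRA paper.

```python
def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    return [
        [sum(x[i][t] * y[t][j] for t in range(len(y))) for j in range(len(y[0]))]
        for i in range(len(x))
    ]


def lora_effective_weight(w, a, b, scale=1.0):
    """Return W + scale * (A @ B): the frozen weight plus the low-rank LoRA update.

    w is d_in x d_out (frozen); a is d_in x r and b is r x d_out (trainable),
    where r is the LoRA rank.
    """
    delta = matmul(a, b)
    return [
        [w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
        for i in range(len(w))
    ]


def lora_param_count(d_in, d_out, rank):
    """Trainable LoRA parameters for one d_in x d_out matrix."""
    return rank * (d_in + d_out)


# Toy 3 x 3 layer with a rank-1 update: 6 trainable numbers instead of 9.
w = [[0.0] * 3 for _ in range(3)]
a = [[1.0], [2.0], [3.0]]  # 3 x 1
b = [[1.0, 0.0, -1.0]]     # 1 x 3
print(lora_effective_weight(w, a, b))

# Trainable parameters for an example 2048 x 2048 matrix at several ranks.
for r in (4, 8, 16):
    print(f"rank {r}: {lora_param_count(2048, 2048, r):,} LoRA vs {2048 * 2048:,} full")
```

Doubling the rank doubles the trainable parameters per matrix, which is the trade-off described above.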
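1. Import the required libraries and load the model
import os
# Choose the Keras backend before importing Keras; this article uses PyTorch.
os.environ["KERAS_BACKEND"] = "torch"  # or "jax" / "tensorflow"
import keras
import keras_nlp
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")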
2. Download the dataset from Hugging Face
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
3. Parse the dataset from JSON Lines format into a list of formatted strings. To keep things simple and fast, we skip examples that have a context and use only the first 1,000 remaining prompt/response pairs.
import json
# Format each example as a single instruction/response string.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        data.append(template.format(**features))
# Only use 1000 training examples, to keep it fast.
data = data[:1000]
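To see what one training string looks like without downloading the file, the sketch below runs the same filtering and formatting on two made-up records in the Dolly field layout (instruction, context, response, category). The records are invented for illustration and are not taken from the dataset.

```python
import json

# Two invented records in the Dolly JSONL field layout.
sample_jsonl = "\n".join([
    json.dumps({"instruction": "Name a primary color.", "context": "",
                "response": "Red.", "category": "open_qa"}),
    json.dumps({"instruction": "Summarize the passage.", "context": "Some reference text.",
                "response": "A summary.", "category": "summarization"}),
])

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
parsed = []
for line in sample_jsonl.splitlines():
    features = json.loads(line)
    if features["context"]:  # records with context are skipped, as in the loop above
        continue
    # str.format ignores unused keyword arguments such as "context" and "category".
    parsed.append(template.format(**features))

print(parsed[0])
```

Each training example is therefore a single string with the instruction and the expected response, which is the same shape as the prompt used for inference, minus the response text.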
4. Enable LoRA fine-tuning with rank 4
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
After enabling LoRA, the model summary shows only 1,363,968 trainable parameters, because the original Gemma weights are frozen and only the small LoRA matrices are trained.
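The figure of 1,363,968 can be reproduced by hand. The sketch below assumes that the gemma_2b_en backbone has 18 decoder layers, 8 query heads, a single shared key/value head, a hidden size of 2048 and a head dimension of 256; that `enable_lora` adds LoRA only to the query and value projections; and that Keras splits each projection kernel into an A factor of shape `kernel_shape[:-1] + (rank,)` and a B factor of shape `(rank, kernel_shape[-1])`. These are my assumptions, not statements from the source, but they account for the reported count exactly.

```python
# Assumed Gemma 2B attention configuration (see the lead-in).
NUM_LAYERS = 18
HIDDEN_DIM = 2048
NUM_QUERY_HEADS = 8
NUM_KV_HEADS = 1
HEAD_DIM = 256
RANK = 4


def einsum_lora_params(kernel_shape, rank):
    """LoRA params for one kernel: A is kernel_shape[:-1] + (rank,), B is (rank, kernel_shape[-1])."""
    a = rank
    for dim in kernel_shape[:-1]:
        a *= dim
    b = rank * kernel_shape[-1]
    return a + b


# Query and value projection kernels are (heads, hidden_dim, head_dim).
query = einsum_lora_params((NUM_QUERY_HEADS, HIDDEN_DIM, HEAD_DIM), RANK)
value = einsum_lora_params((NUM_KV_HEADS, HIDDEN_DIM, HEAD_DIM), RANK)
total = NUM_LAYERS * (query + value)
print(f"{total:,}")  # 1,363,968
```

Most of the count comes from the query projection, because it has eight heads while the value projection has one.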
5. Begin fine-tuning
We use an input sequence length of 512 and the AdamW optimizer with a learning rate of 5e-5 and a weight decay of 0.01, and we train for 1 epoch.
To get better responses from the fine-tuned model, you can increase the size of the fine-tuning dataset, train for more epochs, use a higher LoRA rank, and experiment with the learning rate and weight decay.
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
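AdamW applies weight decay separately from the gradient-based Adam step, shrinking each decayed weight by lr * weight_decay * w per step; `exclude_from_weight_decay` turns that shrinkage off for the bias and layer-norm scale variables. The scalar sketch below shows one such step. `adamw_step` is a hypothetical helper, the beta and epsilon values are assumed to match Keras's defaults, and it is not Keras's actual implementation.

```python
import math


def adamw_step(w, g, m, v, t, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7, decay=True):
    """One AdamW step for a scalar weight w with gradient g at step t (1-based).

    m and v are the running first and second moment estimates; returns (w, m, v).
    """
    # Adam moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    update = m_hat / (math.sqrt(v_hat) + eps)
    # Decoupled weight decay: added to the update directly rather than to the
    # gradient, so it is not rescaled by Adam's adaptive denominator.
    if decay:
        update += weight_decay * w
    return w - lr * update, m, v


# A weight with zero gradient: only weight decay moves it, slightly toward 0.
print(adamw_step(1.0, 0.0, 0.0, 0.0, t=1)[0])
# The same weight excluded from decay (like "bias" and "scale" above) is unchanged.
print(adamw_step(1.0, 0.0, 0.0, 0.0, t=1, decay=False)[0])
```

With a learning rate of 5e-5 and a weight decay of 0.01, each step shrinks a decayed weight by only a factor of 5e-7, so the decay acts as a gentle regularizer over many steps.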
6. Use the fine-tuned model for inference
Now that you have fine-tuned the model, its responses should be better than the ones we got earlier:
prompt = template.format(
instruction="What should I do on a trip to Europe?",
response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
There you go. Enjoy!!
References:
https://developers.googleblog.com/2024/02/gemma-models-in-keras.html