Gemma — A new Large Language Model

BavalpreetSinghh
5 min read · Feb 25, 2024


This guide will help you get started with Gemma using KerasNLP. Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models. KerasNLP is a collection of natural language processing (NLP) models implemented in Keras and runnable on JAX, PyTorch, and TensorFlow.

In this tutorial, you’ll use Gemma to generate text responses to several prompts. If you’re new to Keras, you may want to read “Getting Started with Keras” first. You will learn more about KerasNLP as you work through the tutorial.

Photo by Priscilla Du Preez 🇨🇦 on Unsplash

Gemma Setup

To complete this tutorial, you first need access to Gemma. We will use Kaggle because it offers free access and doesn’t require a cloud account. The steps are:

- Sign in or register at Kaggle.com.
- Open the Gemma model card and select “Request Access.”
- Complete the consent form and accept the terms and conditions.

Once access is granted, you can continue with the rest of the tutorial.
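Inside a Kaggle notebook, access is tied to your signed-in account, so nothing else is needed. If you run the tutorial elsewhere (for example, on your own machine), the Kaggle-hosted presets are downloaded with your Kaggle API credentials, which can be supplied through the KAGGLE_USERNAME and KAGGLE_KEY environment variables. Below is a minimal sketch, assuming you have created an API token on your Kaggle account page; the helper name set_kaggle_credentials is hypothetical.

```python
import os


def set_kaggle_credentials(username: str, key: str, overwrite: bool = False) -> None:
    """Expose Kaggle API credentials through environment variables.

    Hypothetical helper: values that are already set are kept unless
    overwrite=True, so credentials configured by the environment survive.
    """
    if not username or not key:
        raise ValueError("Both a Kaggle username and an API key are required.")
    for name, value in (("KAGGLE_USERNAME", username), ("KAGGLE_KEY", key)):
        if overwrite or name not in os.environ:
            os.environ[name] = value


# Usage outside Kaggle (placeholders: replace with your own username and key):
# set_kaggle_credentials("your-kaggle-username", "your-kaggle-api-key")
```

Keeping existing values by default avoids silently replacing credentials that a notebook environment has already configured.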

Install dependencies

pip install --upgrade keras-nlp
pip install --upgrade keras

You must reinstall Keras 3 after installing KerasNLP. keras-nlp depends on tensorflow-text, which installs tensorflow==2.15, and because TensorFlow 2.15 is pinned to Keras 2, this overwrites your Keras installation with keras==2.15. This step is temporary and should no longer be needed once TensorFlow 2.16 is released. For more details, refer to the Keras documentation.
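To confirm that the reinstall worked, you can read the installed package metadata instead of importing Keras (importing it would lock in a backend before you choose one below). A minimal sketch using only the standard library; the helper names are hypothetical.

```python
from importlib import metadata
from typing import Optional


def major_version(version: str) -> int:
    """Return the major component of a version string such as '3.0.5'."""
    return int(version.split(".")[0])


def installed_keras_major() -> Optional[int]:
    """Return the installed Keras major version, or None if Keras is absent."""
    try:
        return major_version(metadata.version("keras"))
    except metadata.PackageNotFoundError:
        return None


def check_keras3() -> None:
    """Raise if Keras is missing or older than Keras 3."""
    major = installed_keras_major()
    if major is None:
        raise RuntimeError("Keras is not installed; run `pip install --upgrade keras`.")
    if major < 3:
        raise RuntimeError(
            f"Found Keras {major}.x; re-run `pip install --upgrade keras` "
            "after installing keras-nlp to restore Keras 3."
        )
```

Run check_keras3() right after the two pip commands; if it raises with "Found Keras 2.x", the tensorflow-text dependency has downgraded Keras again.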

Import packages

import keras
import keras_nlp
import os

Choosing a backend

In Keras 3, you can choose TensorFlow, JAX, or PyTorch as the backend, and all three work with this tutorial. With Keras 3, any model can be instantiated as a PyTorch Module, exported as a TensorFlow SavedModel, or instantiated as a stateless JAX function. This lets you keep a single implementation of your components (e.g., a single model.py along with a checkpoint file) and use it across all frameworks with consistent numerical results.

Keras reads the KERAS_BACKEND environment variable when it is first imported, so this line must run before the import keras line above (for example, directly after import os):

os.environ["KERAS_BACKEND"] = "torch"  # Or "tensorflow" or "jax". Must run before `import keras`.
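Because a backend change after import silently has no effect, a small guard can catch the ordering mistake early. A minimal sketch, assuming you only need the three backends this tutorial covers; select_keras_backend is a hypothetical helper, not part of Keras.

```python
import os
import sys

# The three backends discussed in this tutorial.
SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch")


def select_keras_backend(name: str) -> str:
    """Set KERAS_BACKEND, refusing to do so once Keras has been imported."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}.")
    if "keras" in sys.modules:
        # Keras picks its backend at import time, so changing the variable now would do nothing.
        raise RuntimeError("Set KERAS_BACKEND before `import keras`.")
    os.environ["KERAS_BACKEND"] = name
    return name
```

Call select_keras_backend("torch") in the first cell of the notebook, then import keras and keras_nlp.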

Model creation

KerasNLP provides implementations of many popular model architectures. For Gemma, these include the GemmaTokenizer, GemmaPreprocessor, and GemmaCausalLMPreprocessor layers and the GemmaBackbone and GemmaCausalLM models.

Let’s look at GemmaBackbone first. The backbone is the foundation of the Gemma model: it contains the embedding lookups and the transformer layers. It doesn’t produce predictions over the entire vocabulary; instead, it outputs the final hidden state for each token. For text generation, use the higher-level keras_nlp.models.GemmaCausalLM, which turns those hidden states into scores over the vocabulary.
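The difference between the two outputs is mostly one of shape: the backbone gives one hidden vector per token, while a causal LM gives one score (logit) per vocabulary entry per token. Here is a toy NumPy sketch, assuming a language-model head that reuses the token embedding matrix (weight tying); the sizes and arrays are random placeholders, not Gemma’s real dimensions or weights.

```python
import numpy as np

# Toy sizes; the real Gemma 2B model is far larger.
batch_size, num_tokens, hidden_dim, vocab_size = 2, 5, 8, 32

rng = np.random.default_rng(0)
# Stand-in for the backbone output: one hidden vector per token.
hidden_states = rng.normal(size=(batch_size, num_tokens, hidden_dim))
# Stand-in for the token embedding matrix, shape [vocab_size, hidden_dim].
embedding_matrix = rng.normal(size=(vocab_size, hidden_dim))

# Projecting onto the transposed embedding matrix yields one logit per vocabulary entry.
logits = hidden_states @ embedding_matrix.T
print(hidden_states.shape, logits.shape)  # (2, 5, 8) (2, 5, 32)
```

Each logit is a dot product between a token’s hidden state and one vocabulary embedding, which is why the last axis grows from hidden_dim to vocab_size.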

Therefore, in this tutorial we’ll build the model with GemmaCausalLM, an end-to-end Gemma model for causal language modeling. Causal language modeling means predicting the next token from the tokens that precede it.
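In code, causal language modeling comes down to two ideas: each position’s target is the token that follows it, and each position may only attend to itself and earlier positions. A small NumPy sketch of both, using hypothetical token IDs. KerasNLP takes care of this for you, so this is only for intuition.

```python
import numpy as np


def next_token_pairs(token_ids):
    """Split a sequence into model inputs and next-token targets."""
    token_ids = np.asarray(token_ids)
    return token_ids[:-1], token_ids[1:]


def causal_mask(length):
    """Boolean mask where row i allows attention only to positions 0..i."""
    return np.tril(np.ones((length, length), dtype=bool))


inputs, targets = next_token_pairs([2, 17, 84, 5, 91])  # hypothetical token IDs
print(inputs.tolist(), targets.tolist())  # [2, 17, 84, 5] [17, 84, 5, 91]
print(causal_mask(3).astype(int))  # lower-triangular 3x3 matrix of ones
```

The mask is what keeps the model from "peeking" at the token it is supposed to predict.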

from_preset method

GemmaCausalLM.from_preset()

To instantiate GemmaCausalLM, call from_preset() with a preset, which bundles the model architecture with its associated weights.

The main arguments are:

  • preset: A string specifying one of the following presets: "gemma_2b_en", "gemma_instruct_2b_en", "gemma_7b_en", or "gemma_instruct_7b_en".
  • load_weights: This parameter determines whether pre-trained weights should be loaded into the model. By default, it's set to True.
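gemma_model = keras_nlp.models.GemmaCausalLM.from_preset(
    "gemma_2b_en",
    # load_weights=False builds the architecture with randomly initialized weights;
    # omit it (the default is True) to load the pretrained weights.
    load_weights=False,
)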

Call summary() to print the model’s layers and parameter counts:

gemma_model.summary()
Gemma 2B model summary

Text Generation

Now for text generation. The model provides a generate() method for this task: pass it a prompt, and it returns the prompt followed by the generated text. You can cap the total length of the sequence, in tokens and including the prompt, with the optional max_length argument.

Here’s how it works:

  • The generate method takes inputs and generates text based on them. You can set the sampling method for generation using the compile() method.
  • If the inputs are provided as a tf.data.Dataset, the outputs will be generated batch by batch and then concatenated. Otherwise, all inputs will be treated as a single batch.
  • If a preprocessor is attached to the model, inputs will be preprocessed within the generate() function. These inputs should match the structure expected by the preprocessor layer, typically raw strings. If there's no preprocessor attached, inputs should align with the structure expected by the backbone.
prompt = "i am listening to"

# Use top-k sampling for generation.
gemma_model.compile(sampler="top_k")
gemma_model.generate(prompt, max_length=10)
# output (the weights are random because load_weights=False, so the text is meaningless):
# 'i am listening toांत Painting poses perbaikan poses'

# Switch to beam search with two beams.
gemma_model.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2))
gemma_model.generate(prompt, max_length=10)
# output:
# 'i am listening to leds leds RSVP RSVP RSVP'
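Conceptually, generate() runs an autoregressive loop: the next token is chosen from the tokens so far and appended, until the sequence reaches max_length tokens (prompt included) or an end token appears. Below is a simplified pure-Python sketch of that loop. next_token is a hypothetical stand-in for "model plus sampler", and the real KerasNLP implementation works on padded, batched tensors rather than Python lists.

```python
from typing import Callable, List, Optional


def generate_ids(
    prompt_ids: List[int],
    next_token: Callable[[List[int]], int],
    max_length: int,
    end_id: Optional[int] = None,
) -> List[int]:
    """Extend prompt_ids one token at a time until max_length or end_id."""
    ids = list(prompt_ids)
    while len(ids) < max_length:
        token = next_token(ids)  # in the real model: forward pass + sampler
        ids.append(token)
        if end_id is not None and token == end_id:
            break
    return ids


# Toy "model" that always predicts the previous ID plus one.
print(generate_ids([2, 10, 11], lambda ids: ids[-1] + 1, max_length=6))  # [2, 10, 11, 12, 13, 14]
```

Because max_length counts the prompt, a longer prompt leaves fewer tokens for the continuation.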

You can recompile the model with different `keras_nlp.samplers` objects to control how text is generated. By default, the model uses “greedy” sampling. Other samplers available for experimentation include beam, contrastive, random, and top_p, and you can subclass keras_nlp.samplers.Sampler to define a custom one.

Experiment with these samplers to see how each one affects the generated text and to tailor generation to your needs.
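To build intuition for what a sampler does with the model’s next-token logits, here is a NumPy sketch of greedy, top-k, and top-p selection on a single logits vector. It follows the ideas behind the greedy, top-k, and top-p samplers in keras_nlp.samplers, but it is a simplified illustration, not their implementation (which operates on batched tensors).

```python
import numpy as np


def softmax(logits):
    """Convert logits to probabilities in a numerically stable way."""
    z = np.exp(logits - np.max(logits))
    return z / z.sum()


def greedy(logits):
    """Always pick the highest-scoring token."""
    return int(np.argmax(logits))


def top_k(logits, k, rng):
    """Sample among the k highest-scoring tokens, renormalizing their probabilities."""
    logits = np.asarray(logits, dtype=float)
    top = np.argsort(logits)[-k:]
    return int(rng.choice(top, p=softmax(logits[top])))


def top_p(logits, p, rng):
    """Sample from the smallest set of most likely tokens whose total probability reaches p."""
    probs = softmax(np.asarray(logits, dtype=float))
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, p)) + 1  # include the token that crosses p
    keep = order[:cutoff]
    return int(rng.choice(keep, p=probs[keep] / probs[keep].sum()))


logits = np.array([2.0, 1.0, 0.5, -1.0])
rng = np.random.default_rng(0)
print(greedy(logits), top_k(logits, 2, rng), top_p(logits, 0.9, rng))
```

Greedy is deterministic, while top-k and top-p trade some likelihood for variety: with k=1 or a very small p, both collapse back to greedy.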

You can also assess the quality of a generated sequence from its token IDs with the score() method. To demonstrate the functionality, we create the target IDs artificially (at random), because TensorFlow Addons has installation issues on the Kaggle platform.

score method

GemmaCausalLM.score(
    token_ids,
    padding_mask=None,
    scoring_mode="logits",
    layer_intercept_fn=None,
    target_ids=None,
)

Arguments:

  • token_ids: A tensor of shape [batch_size, num_tokens] containing the tokens to be scored. Typically, this tensor captures the output from a call to GemmaCausalLM.generate(), encompassing tokens for both the input text and the model-generated text.
  • padding_mask: A tensor of shape [batch_size, num_tokens] indicating the tokens that should be preserved during generation. This is primarily an artifact required by the GemmaBackbone and doesn't significantly affect the computation of this function. If omitted, this function generates a tensor of appropriate shape using keras.ops.ones().
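  • scoring_mode: Specifies the type of scores to return, either "logits" or "loss". Both are computed per input token: "logits" mode returns a tensor of shape [batch_size, num_tokens, vocab_size], and "loss" mode returns one of shape [batch_size, num_tokens].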
  • layer_intercept_fn: An optional function for augmenting activations with additional computation, typically used for interpretability research. This function is passed the activations as its first parameter and a numeric index associated with that backbone layer.
  • target_ids: A tensor of shape [batch_size, num_tokens] containing the predicted tokens against which the loss should be computed. If a span of tokens is provided (sequential truthy values along axis=1 in the tensor), the loss will be computed as the aggregate across those tokens.
import keras

# Generate two continuations, then re-tokenize them for scoring.
generations = gemma_model.generate(
    ["what is", "Where are you"],
    max_length=10,
)
preprocessed = gemma_model.preprocessor.generate_preprocess(generations)
generation_ids = preprocessed["token_ids"]
padding_mask = preprocessed["padding_mask"]

# Artificial targets: random token IDs with the same [batch_size, num_tokens]
# shape as generation_ids, drawn from the model's vocabulary.
vocab_size = gemma_model.preprocessor.tokenizer.vocabulary_size()
target_ids = keras.random.randint(
    tuple(generation_ids.shape), minval=0, maxval=vocab_size, dtype="int32"
)

losses = gemma_model.score(
    token_ids=generation_ids,
    padding_mask=padding_mask,
    scoring_mode="loss",
    target_ids=target_ids,
)
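With scoring_mode="loss", score() returns one loss value per token. Assuming that per-token loss is the usual sparse categorical cross-entropy between the logits and target_ids, the NumPy sketch below reproduces the computation on toy arrays. It also shows a more meaningful alternative to random targets: the token IDs shifted left by one position, so each position is scored against the token that actually follows it. After the shift, the last position wraps around to the first token and should be ignored. The function name per_token_loss is hypothetical.

```python
import numpy as np


def per_token_loss(logits, target_ids):
    """Sparse categorical cross-entropy at every position.

    logits: [batch, num_tokens, vocab_size]; target_ids: [batch, num_tokens].
    Returns an array of shape [batch, num_tokens].
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick out the log-probability assigned to each target token.
    picked = np.take_along_axis(log_probs, np.asarray(target_ids)[..., None], axis=-1)
    return -picked[..., 0]


token_ids = np.array([[2, 5, 7, 3]])                # hypothetical token IDs
target_ids = np.roll(token_ids, shift=-1, axis=1)   # next-token targets: [[5, 7, 3, 2]]
uniform_logits = np.zeros((1, 4, 10))               # no preference among 10 tokens
print(per_token_loss(uniform_logits, target_ids))   # every entry equals log(10), about 2.303
```

A model with no preference scores log(vocab_size) everywhere; lower values mean the model assigned more probability to the target token.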

Link to the notebook — https://www.kaggle.com/code/bavalpreet26/gemma-a-new-llm

Unfortunately, getting the output requires upgrading to Google Cloud AI Platform Notebooks, but the code above still shows how the losses are calculated.

In the next tutorial, we will learn how to fine-tune the Gemma 2B model.


BavalpreetSinghh

Data Scientist at Tatras Data | Mentor @ Humber College | Consultant @ SL2