Paper Insights: Gemma — Google’s Breakthrough in Open-Source AI from Gemini Research

Abhishek Selokar
8 min read · Aug 13, 2024


In 2024, Google unveiled Gemma, a groundbreaking series of open-source AI models set to transform the landscape of artificial intelligence. Named after the Latin word for ‘precious stone’, Gemma offers two powerful variants: a compact 2 billion parameter model and a more robust 7 billion parameter version. Despite their smaller size, these models outperform many larger competitors on key AI benchmarks, all while adhering to strict safety and ethical standards. Join us as we explore the technicalities of Gemma.


Table of Contents
· 1. Overview
· 2. Model Architecture
  2.1. Attention strategy
  2.2. RoPE Embeddings
  2.3. GeGLU Activations
  2.4. RMSNorm
· 3. Pretraining
  3.1. Training Data
  3.2. Filtering
· 4. Instruction Tuning
  4.1. Supervised Fine-Tuning
  4.2. Filtering
  4.3. Formatting
  4.4. Reinforcement Learning from Human Feedback
· 5. Evaluation

1. Overview

Gemma is inspired by the Gemini models (Google's most capable LLMs). Like Gemini, these models demonstrate robust generalist capabilities in text domains, coupled with state-of-the-art understanding and reasoning abilities at scale. They are trained on up to 6T tokens of text, using architectures, data, and training recipes similar to those of the Gemini model family.

Gemma is available in two sizes:

  • A 7 billion parameter model optimized for efficient deployment and development on GPU and TPU.
  • A 2 billion parameter model designed for CPU and on-device applications.

It beats state-of-the-art open-source models of the same or relatively larger size (LLaMA 2 7B and 13B, Mistral 7B) across a wide range of domains, such as coding, mathematics and science, question answering, and common-sense reasoning.


2. Model Architecture

Gemma is based on the transformer decoder (an autoregressive model) and is trained with a context length of 8192 tokens.

Source: Key Model Parameters

2.1. Attention strategy

The 7B model uses multi-head attention (MHA), while the 2B model uses multi-query attention (MQA).

Multi-Head Attention (MHA): In MHA, we use multiple “heads” to focus on different parts of the input data simultaneously. Each head looks at the input from a different perspective, capturing diverse information. This method gives high-quality results because it gathers a wide range of insights from the data. However, it’s computationally heavy because each head requires its own set of calculations.

Multi-Query Attention (MQA): In contrast, MQA keeps multiple query heads but shares a single key and value head across all of them. This makes inference faster and less memory-intensive, since only one set of key/value projections is computed and the key/value cache is just one head wide. The trade-off is that MQA may not capture as much detail or nuance, potentially leading to slightly lower-quality results compared to MHA.
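
To make the difference concrete, here is a minimal PyTorch sketch (toy shapes and weights, not Gemma's actual implementation) in which the only thing that changes between MHA and MQA is the number of key/value heads:

```python
import torch
import torch.nn.functional as F

def attention(x, wq, wk, wv, num_q_heads, num_kv_heads):
    B, T, D = x.shape
    head_dim = wq.shape[1] // num_q_heads
    q = (x @ wq).view(B, T, num_q_heads, head_dim).transpose(1, 2)
    k = (x @ wk).view(B, T, num_kv_heads, head_dim).transpose(1, 2)
    v = (x @ wv).view(B, T, num_kv_heads, head_dim).transpose(1, 2)
    if num_kv_heads == 1:
        # MQA: broadcast the single key/value head to every query head
        k = k.expand(B, num_q_heads, T, head_dim)
        v = v.expand(B, num_q_heads, T, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2).reshape(B, T, -1)

B, T, D, H = 2, 16, 256, 8
x = torch.randn(B, T, D)
wq = torch.randn(D, D)
# MHA: full-width key/value projections (8 KV heads)
y_mha = attention(x, wq, torch.randn(D, D), torch.randn(D, D), H, H)
# MQA: key/value projections are one head wide (1 KV head)
y_mqa = attention(x, wq, torch.randn(D, D // H), torch.randn(D, D // H), H, 1)
```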

2.2. RoPE Embeddings

Rotary Positional Embedding (RoPE) is used in each layer instead of absolute positional embeddings. Positional embeddings help the model understand the order of tokens in a sequence. Rather than adding a fixed vector per position, RoPE rotates the query and key vectors by position-dependent angles, so the attention score between two tokens depends on their relative distance. This lets the model generalize more gracefully across positions and sequence lengths.
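
A minimal sketch of one common RoPE formulation (the "split-half" variant; Gemma's exact implementation may differ in details):

```python
import torch

def rope(x, base=10000.0):
    """Rotate query/key vectors x of shape (seq_len, dim); dim must be even."""
    seq_len, dim = x.shape
    half = dim // 2
    # one rotation frequency per channel pair, decaying geometrically
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * freqs
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:]
    # standard 2D rotation applied to each (x1, x2) channel pair
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q, k = torch.randn(16, 64), torch.randn(16, 64)
q_rot, k_rot = rope(q), rope(k)  # scores now depend on relative offsets
```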

2.3. GeGLU Activations

GeGLU is a variant of GLU (Gated Linear Units). GLU is defined as the componentwise product of two linear transformations of the input, one of which is sigmoid-activated:

GLU(x, W, V, b, c) = σ(xW + b) ⊗ (xV + c)

When the sigmoid is replaced with the Gaussian Error Linear Unit (GELU), the result is GeGLU. The GELU activation is xΦ(x), where Φ(x) is the standard Gaussian cumulative distribution function:

GELU(x) = xΦ(x)

GeGLU(x, W, V, b, c) = GELU(xW + b) ⊗ (xV + c)
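
A minimal PyTorch sketch of a GeGLU feed-forward block following the formula above (layer sizes are illustrative, not Gemma's actual dimensions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGLUFeedForward(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w = nn.Linear(d_model, d_hidden)    # gate branch, GELU-activated
        self.v = nn.Linear(d_model, d_hidden)    # linear branch
        self.out = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        # GELU(xW + b) * (xV + c), then project back to d_model
        return self.out(F.gelu(self.w(x)) * self.v(x))

ff = GeGLUFeedForward(d_model=512, d_hidden=2048)
y = ff(torch.randn(2, 16, 512))
```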

2.4. RMSNorm

The input of each transformer sub-layer (both the attention layer and the feedforward layer) is normalized with RMSNorm to stabilize training.

RMSNorm regularizes the summed inputs to a neuron in one layer according to root mean square (RMS), giving the model re-scaling invariance property and implicit learning rate adaptation ability. — Source
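
A minimal RMSNorm sketch; note that unlike LayerNorm there is no mean subtraction or bias, only rescaling by the root mean square plus a learned gain:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-channel gain

    def forward(self, x):
        # divide by the root mean square over the last dimension
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight

norm = RMSNorm(512)
y = norm(torch.randn(2, 16, 512))
```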

3. Pretraining

3.1. Training Data

Gemma 2B and 7B models are trained on 3 trillion and 6 trillion tokens, respectively, primarily sourced from English web documents, mathematics, and code.

The vocabulary size is 256k tokens. The training data is prepared with a SentencePiece tokenizer, which splits digits, keeps extra whitespace, and relies on byte-level encodings for unknown tokens.
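
A hedged sketch of how a tokenizer with these properties could be trained using the sentencepiece library (the corpus path, vocab size, and model type are illustrative; Gemma's exact configuration is not published):

```python
import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="corpus.txt",               # hypothetical training corpus
    model_prefix="gemma_like",
    vocab_size=256000,
    model_type="bpe",
    split_digits=True,                # "1234" -> "1" "2" "3" "4"
    byte_fallback=True,               # unknown characters fall back to bytes
    remove_extra_whitespaces=False,   # keep runs of extra whitespace
)

sp = spm.SentencePieceProcessor(model_file="gemma_like.model")
print(sp.encode("price:  1234", out_type=str))
```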

3.2. Filtering

  • Pre-training data is first filtered to remove low-quality, harmful, and inappropriate content.
  • In Gemma, it is done using both heuristics and model-based classifiers.
  • All datasets used for evaluating the model are excluded from the pre-training data to avoid any overlap, ensuring that the model isn’t inadvertently tested on data it has already seen during training.

4. Instruction Tuning

After pre-training, the model becomes good at predicting the next word or token in a sequence. However, it still struggles to follow instructions accurately and may produce responses that aren’t quite right or don’t sound human-like. To address this, the model undergoes instruction fine-tuning, where it is further trained using a dataset of instructions paired with their correct responses. This helps the model learn how to generate more accurate and appropriate answers.
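
In code, this extra training step is usually the same next-token cross-entropy, restricted to the response tokens so the model is not also trained to reproduce the prompt. A minimal sketch of the generic recipe (toy shapes, not Gemma's actual training code):

```python
import torch
import torch.nn.functional as F

def sft_loss(logits, input_ids, prompt_len):
    """logits: (T, vocab) model outputs; input_ids: (T,) prompt + response."""
    shift_logits = logits[:-1]            # position t predicts token t + 1
    labels = input_ids[1:].clone()
    labels[: prompt_len - 1] = -100       # mask the loss over prompt tokens
    return F.cross_entropy(shift_logits, labels, ignore_index=-100)

T, V, prompt_len = 32, 1000, 10
loss = sft_loss(torch.randn(T, V), torch.randint(0, V, (T,)), prompt_len)
```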

4.1. Supervised Fine-Tuning

This stage involves refining the model by evaluating and selecting the best responses to prompts. The process uses a larger, more capable model as a judge to compare responses from the test model (the one being fine-tuned) and a baseline (an earlier or simpler version), focusing on key aspects like instruction following, factual accuracy, creativity, and safety. The judge model uses advanced prompting techniques such as Chain of Thought (CoT) to ensure alignment with human preferences, ultimately making the model more reliable and human-like in its outputs.
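
A hedged pseudocode sketch of this judge-based selection loop; the prompt wording and the model interfaces are hypothetical stand-ins, since the paper does not publish its exact judge prompts or models:

```python
# candidate, baseline, and judge are text-in, text-out callables (stand-ins)
JUDGE_PROMPT = """Compare the two responses on instruction following,
factuality, creativity, and safety. Reason step by step (chain of thought),
then end your answer with "A" or "B" for the better response.

Prompt: {prompt}
Response A: {a}
Response B: {b}"""

def select_preferred(prompts, candidate, baseline, judge):
    chosen = []
    for prompt in prompts:
        a, b = candidate(prompt), baseline(prompt)
        verdict = judge(JUDGE_PROMPT.format(prompt=prompt, a=a, b=b))
        if verdict.strip().endswith("A"):   # keep cases where the test model wins
            chosen.append((prompt, a))
    return chosen
```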

4.2. Filtering

The synthetic data generated is filtered to remove duplicates, harmful or inappropriate content, personally identifiable information (PII), and low-quality data.

4.3. Formatting

During the instruction fine-tuning stage, the data is first formatted. This is essential to ensure the model accurately recognizes conversational roles and turn-taking, leading to more coherent and contextually appropriate responses. Special control tokens are reserved in the vocabulary for this purpose; an example of the resulting dialogue format is shown after the list below.

Purpose of Formatting:

  1. Indicate Roles: The formatter marks different roles in a conversation, like identifying who is the user.
  2. Delineate Turns: It helps in clearly separating different turns in a conversation, especially for multi-turn interactions.
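
The formatter wraps every turn in the reserved control tokens <start_of_turn> and <end_of_turn>, each tagged with a role name (user or model). A minimal example of the resulting prompt (the dialogue content here is illustrative):

```python
# A prompt formatted with Gemma's control tokens; it ends with an opened
# model turn so that generation continues as the model's reply.
prompt = (
    "<start_of_turn>user\n"
    "What is RMSNorm?<end_of_turn>\n"
    "<start_of_turn>model\n"
)
```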

4.4. Reinforcement Learning from Human Feedback

It involves fine-tuning a model by using human-rated preference pairs to train a reward function. This reward function guides the model to optimize responses using a reinforcement learning algorithm. RLHF helps in making the model’s responses more accurate, relevant, and user-friendly. It helps improve the quality and safety of AI-generated content by incorporating direct human feedback into the training process.
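
As a hedged illustration, a reward function of this kind is commonly trained on preference pairs with a Bradley-Terry-style loss; the paper does not detail Gemma's exact reward model or RL algorithm:

```python
import torch
import torch.nn.functional as F

def preference_loss(reward_model, chosen_ids, rejected_ids):
    """reward_model maps a token-id tensor to a scalar reward (stand-in)."""
    r_chosen = reward_model(chosen_ids)
    r_rejected = reward_model(rejected_ids)
    # push the reward of the human-preferred response above the rejected one
    return -F.logsigmoid(r_chosen - r_rejected).mean()
```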

5. Evaluation

The model is evaluated on a variety of benchmarks, such as MMLU, TriviaQA, MATH, HumanEval, HellaSwag, TruthfulQA, and many more.

MMLU: This is a massive multitask test consisting of multiple-choice questions from various branches of knowledge. The test spans subjects in the humanities, social sciences, hard sciences, and other areas that are important for some people to learn. — Source

TruthfulQA is a benchmark to measure whether a language model is truthful in generating answers to questions. The benchmark comprises 817 questions that span 38 categories, including health, law, finance, and politics. — Source

HumanEval is used to measure functional correctness for synthesizing programs from docstrings. It consists of 164 original programming problems, assessing language comprehension, algorithms, and simple mathematics, with some comparable to simple software interview questions. — Source

HellaSwag is a benchmark to test commonsense natural language inference (NLI) about physical situations. — Source

TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts and independently gathered evidence documents, six per question on average, that provide high-quality distant supervision for answering the questions. — Source

Bias in Open-ended Language Generation Dataset (BOLD) is a dataset to evaluate fairness in open-ended language generation in English language. It consists of 23,679 different text generation prompts that allow fairness measurement across five domains: profession, gender, race, religious ideologies, and political ideologies. — Source

GSM8K (Grade School Math 8K) is a dataset of 8.5K high quality linguistically diverse grade school math word problems. The dataset was created to support the task of question answering on basic mathematical problems that require multi-step reasoning. — Source

The Gemma models are then compared with other similarly sized open-source models, such as LLaMA-2 and Mistral, on the aforementioned benchmarks. In most cases, Gemma outperforms the other models across multiple benchmarks.


Eager to explore the technical depths of Llama 3.1? Check out the in-depth blog for all the insights!
