Your Ultimate Guide to Instruction Fine-Tuning and Optimizing Google’s Gemma 2B Using LoRA
On February 21, 2024, Google released Gemma, a new family of lightweight open models. Available in 2B and 7B parameter sizes, Gemma models deliver strong performance for their compact size. Built from the same research and technology used to create Google’s Gemini models, Gemma takes its name from the Latin word for “precious stone.” Alongside the model weights, Google has released tools to encourage developer creativity, facilitate collaboration, and promote responsible use of Gemma models.
These models were trained on a large text dataset totaling 6 trillion tokens, drawn from diverse sources including web documents, code, and mathematical text. This broad training data exposes the model to varied linguistic styles, programming syntax, and mathematical concepts, enabling it to handle a wide range of tasks.
In this blog, we will cover two main topics:
1. Loading and utilizing the Gemma 2B model in 4-bit precision.
2. Learning how to instruction fine-tune the model efficiently with LoRA.
Before we dive in: to access the Gemma model artifacts, you must first accept Gemma’s license terms via the consent form on the model page. Once that’s settled, let’s jump right into the implementation.
You can access the notebook from here.
To access the Gemma 2B model from Hugging Face in Google Colab, you need to provide your Hugging Face access token. First, add the token to Colab’s Secrets panel (the code below reads a secret named HF_TOKEN_READ); then expose it as the HF_TOKEN environment variable:
import os
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN_READ')
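If you also run the notebook outside Colab, a small helper can fall back to a plain environment variable. This is a sketch under our own assumptions: the function name get_hf_token is hypothetical, and the default secret name HF_TOKEN_READ simply mirrors the code above.

```python
import os


def get_hf_token(secret_name: str = "HF_TOKEN_READ", env_var: str = "HF_TOKEN") -> str:
    """Hypothetical helper: prefer a Colab secret, else fall back to an environment variable."""
    try:
        # google.colab is only importable inside a Colab runtime.
        from google.colab import userdata
        token = userdata.get(secret_name)
    except ImportError:
        # Outside Colab, read the token from the environment instead.
        token = os.environ.get(env_var)
    if not token:
        raise RuntimeError(
            f"No Hugging Face token found: set the Colab secret {secret_name!r} "
            f"or the environment variable {env_var!r}."
        )
    return token
```

Usage would be os.environ["HF_TOKEN"] = get_hf_token(), which behaves like the Colab-only line above when run inside Colab.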
Next, install the required packages, pinned to the versions used in this guide:
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.0
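Because pip’s -q flag hides most of its output, you may want to confirm that the pinned versions were actually installed. This optional check is our own addition, using the standard library’s importlib.metadata; the PINNED dictionary simply mirrors the versions above, and find_version_mismatches is a hypothetical name.

```python
from importlib.metadata import PackageNotFoundError, version

# Mirrors the pip pins above.
PINNED = {
    "bitsandbytes": "0.42.0",
    "peft": "0.8.2",
    "trl": "0.7.10",
    "accelerate": "0.27.1",
    "datasets": "2.17.0",
    "transformers": "4.38.0",
}


def find_version_mismatches(pins):
    """Return {package: (wanted, installed_or_None)} for every pin that is not satisfied."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


print(find_version_mismatches(PINNED) or "All pinned versions are installed.")
```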
Assuming you have access to the model artifacts on the Hugging Face Hub, you can begin by downloading the tokenizer and the model. A BitsAndBytesConfig enables weight-only 4-bit quantization: the weights are stored in the NF4 format, while computations run in bfloat16.
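import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_id = "google/gemma-2b"
# Store weights in 4-bit NF4; run computations in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
# Load the quantized model and place it on the first GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    token=os.environ['HF_TOKEN']
)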
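To see why 4-bit loading matters, here is a back-of-the-envelope estimate of weight memory, plus a simplified blockwise 4-bit quantizer. Both are illustrative sketches: the 2.5-billion parameter count is a rounded assumption (model.num_parameters() gives the exact figure), the estimate ignores activations, quantization constants, and any layers left unquantized, and the quantizer uses 15 evenly spaced signed levels with a per-block absmax scale rather than the NF4 code book that bitsandbytes actually uses.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes), ignoring overheads."""
    return num_params * bits_per_param / 8 / 1e9


def quantize_blockwise(values, block_size=64):
    """Simplified symmetric absmax 4-bit quantization (NOT the NF4 code book).

    Each block keeps one float scale plus integer codes in [-7, 7].
    """
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero for all-zero blocks
        scale = absmax / 7
        scales.append(scale)
        codes.extend(round(v / scale) for v in block)
    return codes, scales


def dequantize_blockwise(codes, scales, block_size=64):
    """Map integer codes back to approximate floats using each block's scale."""
    return [c * scales[i // block_size] for i, c in enumerate(codes)]


NUM_PARAMS = 2.5e9  # assumption: rounded size of Gemma 2B
print(weight_memory_gb(NUM_PARAMS, 16))  # bfloat16 weights: 5.0
print(weight_memory_gb(NUM_PARAMS, 4))   # 4-bit weights: 1.25
```

Blockwise scaling is the key design choice: an outlier only degrades the precision of its own block rather than the whole tensor.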
Before starting the fine-tuning process, let’s run a quick test of the model’s capabilities by feeding it the beginning of a well-known quote and observing how it completes it.
text = "Quote: Imagination is more"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Model output:
Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.
-Albert Einstein
I
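The stray “I” at the end is expected: generate() stops once it has produced max_new_tokens=20 new tokens (or reaches an end-of-sequence token), even in the middle of a sentence. The toy loop below illustrates that stopping rule; generate_loop and replay_next_token are hypothetical names, and the “model” is a stand-in that just replays a fixed continuation, one word per token.

```python
def generate_loop(prompt_tokens, next_token, max_new_tokens, eos_token=None):
    """Toy autoregressive loop: append one token per step until a stop condition."""
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):  # hard cap on newly generated tokens
        token = next_token(tokens)  # a real model would score its whole vocabulary here
        tokens.append(token)
        if token == eos_token:  # optional early stop at end-of-sequence
            break
    return tokens


# Hypothetical stand-in for a model: replays a fixed continuation, one word per "token".
PROMPT = "Quote: Imagination is more".split()
CONTINUATION = ["important", "than", "knowledge.", "<eos>"]


def replay_next_token(tokens):
    produced = len(tokens) - len(PROMPT)
    return CONTINUATION[min(produced, len(CONTINUATION) - 1)]


print(" ".join(generate_loop(PROMPT, replay_next_token, max_new_tokens=2)))
# Quote: Imagination is more important than
```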
Google also released instruction-tuned versions of both the 2B and 7B models. These models expect a specific chat template for conversational use, formatted as follows:
<start_of_turn>user
How does the brain work?<end_of_turn>
<start_of_turn>model
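Rather than typing these markers by hand, you can build the prompt from a list of turns. The helper below is a hypothetical sketch that reproduces the format shown above and ends with an open model turn (including its trailing newline) so the model writes the reply. In practice, transformers also provides tokenizer.apply_chat_template, which applies the template that ships with the -it tokenizer.

```python
def build_gemma_prompt(turns):
    """Format (role, text) pairs in Gemma's chat format (hypothetical helper).

    Roles must be "user" or "model"; the prompt ends with an open model turn.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"unsupported role: {role!r}")
        parts.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")
    parts.append("<start_of_turn>model\n")  # generation continues from here
    return "".join(parts)


print(build_gemma_prompt([("user", "How does the brain work?")]))
```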
Let’s try it:
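import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_id = "google/gemma-2b-it"  # instruction-tuned variant
# Store weights in 4-bit NF4; run computations in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
# Load the quantized instruction-tuned model on the first GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    token=os.environ['HF_TOKEN']
)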
text = """<start_of_turn>user
Explain 'AMSI init bypass' and its purpose.<end_of_turn>
<start_of_turn>model"""
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
user
Explain 'AMSI init bypass' and its purpose.
modelSure, here's a detailed explanation of the AMSI init bypass feature and its purpose:
**AMSI init bypass:**
AMSI (Advanced Microcontroller Startup Interface) init bypass is a technique used in microcontroller initialization where the initialization process is
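Because skip_special_tokens=True strips the turn markers, the printed text runs the prompt and the answer together (note “modelSure”). A common remedy is to decode only the newly generated tokens: for decoder-only models, generate() returns the prompt followed by the continuation, so you can slice past the prompt length. The sketch below shows that slicing on plain Python lists with made-up token IDs; new_tokens_only is a hypothetical name.

```python
def new_tokens_only(prompt_ids, output_ids):
    """Return the tokens generate() appended after the prompt.

    Decoder-only models echo the prompt at the start of their output,
    so the continuation begins at index len(prompt_ids).
    """
    if output_ids[:len(prompt_ids)] != prompt_ids:
        raise ValueError("output does not start with the prompt")
    return output_ids[len(prompt_ids):]


# Made-up token IDs, for illustration only.
prompt_ids = [10, 11, 12, 13]
output_ids = prompt_ids + [20, 21, 22]
print(new_tokens_only(prompt_ids, output_ids))  # [20, 21, 22]
```

With the PyTorch tensors in the notebook, the equivalent is tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).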
The output reads fluently, but it is inaccurate: AMSI actually stands for Antimalware Scan Interface, a Windows security feature. Let’s improve the model by fine-tuning it on the ahmed000000000/cybersec dataset, which contains instruction and response pairs from the cybersecurity domain.
from datasets import load_dataset
data = load_dataset("ahmed000000000/cybersec")
Now, let’s set up the LoRA configuration.
from peft import LoraConfig
lora_config = LoraConfig(
    r=8,  # rank of the low-rank update matrices
    # Attach adapters to the attention and MLP projections.
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
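Under the hood, LoRA freezes each targeted weight matrix W and learns a low-rank update, so the layer computes W x + (alpha / r) · B (A x), where A has shape r × in_features and B has shape out_features × r. B starts at zero, so training begins exactly from the base model, and each adapted matrix adds only r · (in_features + out_features) trainable parameters (alpha corresponds to peft’s lora_alpha, which the config above leaves at its default). The pure-Python sketch below shows that forward pass and the parameter count. The Gemma 2B projection shapes (hidden size 2048, 256-dimensional key/value projections, MLP width 16384, 18 layers) are our assumptions; verify them against model.config before relying on the number.

```python
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """LoRA linear layer: W x + (alpha / r) * B (A x). W stays frozen; A and B are trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_params(in_features, out_features, r):
    """Trainable parameters added to one layer: A (r x in) plus B (out x r)."""
    return r * (in_features + out_features)


# Tiny layer: in=4, out=3, rank 2. B starts at zero, so the output equals the base layer's.
random.seed(0)
W = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(3)]
A = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(2)]
B = [[0.0, 0.0] for _ in range(3)]
x = [1.0, 2.0, 3.0, 4.0]
print(lora_forward(W, A, B, x, alpha=8, r=2) == matvec(W, x))  # True

# Assumed Gemma 2B shapes (in_features, out_features) per decoder layer; check model.config.
HIDDEN, KV_DIM, MLP_DIM, NUM_LAYERS = 2048, 256, 16384, 18
SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP_DIM),
    "up_proj": (HIDDEN, MLP_DIM),
    "down_proj": (MLP_DIM, HIDDEN),
}
per_layer = sum(lora_params(i, o, r=8) for i, o in SHAPES.values())
print(per_layer * NUM_LAYERS)  # 9805824 under these assumed shapes
```

Roughly 9.8 million trainable parameters is a small fraction of a model with billions of weights, which is why LoRA fits on a single Colab GPU.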
Now, let’s create a custom function to format the data into the Gemma instruction template format.
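def formatting_func(example):
    # SFTTrainer passes a batch (column -> list of values); return one string per row.
    return [
        f"<start_of_turn>user\n{instruction}<end_of_turn>\n<start_of_turn>model\n{response}<end_of_turn>"
        for instruction, response in zip(example['INSTRUCTION'], example['RESPONSE'])
    ]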
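It matters whether formatting_func formats every row of a batch or only index [0]. In the trl 0.7.x releases pinned above, our understanding is that SFTTrainer without packing (the default) tokenizes the dataset with datasets.map(batched=True), so formatting_func receives a batch (a dict mapping each column to a list of values) and every string it returns becomes one training example; the default batch size there appears to be 1,000 rows. The simulation below uses invented rows and a simplified string format instead of the Gemma template, and batched_map, first_row_only, and every_row are hypothetical names, to show how many examples survive each approach.

```python
def batched_map(rows, formatter, batch_size):
    """Mimic datasets.map(batched=True): pass column-oriented batches, collect all outputs."""
    outputs = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        batch = {column: [row[column] for row in chunk] for column in chunk[0]}
        outputs.extend(formatter(batch))
    return outputs


def first_row_only(batch):
    # Looks only at index [0], so each batch contributes a single string.
    return [f"{batch['INSTRUCTION'][0]} -> {batch['RESPONSE'][0]}"]


def every_row(batch):
    # Formats every (instruction, response) pair in the batch.
    return [f"{i} -> {r}" for i, r in zip(batch["INSTRUCTION"], batch["RESPONSE"])]


# 2,500 invented rows with the dataset's column names, processed 1,000 rows per batch.
rows = [{"INSTRUCTION": f"question {n}", "RESPONSE": f"answer {n}"} for n in range(2500)]
print(len(batched_map(rows, first_row_only, batch_size=1000)))  # 3
print(len(batched_map(rows, every_row, batch_size=1000)))       # 2500
```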
Next, initialize an SFTTrainer:
import transformers
from trl import SFTTrainer
trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=150,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)
model: The pre-trained language model (here, Gemma) that you want to fine-tune.
train_dataset: The training dataset used for fine-tuning.
args: Training arguments such as batch size, learning rate, and optimizer settings. Notably, per_device_train_batch_size sets the batch size per GPU; gradient_accumulation_steps determines how many batches are accumulated before each gradient update; warmup_steps sets the number of learning-rate warm-up steps; max_steps caps the total number of training steps; learning_rate sets the peak learning rate, reached at the end of warm-up; fp16 enables mixed-precision training for faster, more memory-efficient training; logging_steps controls how often training metrics are logged; output_dir specifies the directory where training outputs are saved; and optim="paged_adamw_8bit" selects the bitsandbytes paged 8-bit AdamW optimizer, which reduces optimizer memory.
peft_config: The LoRA (Low-Rank Adaptation) configuration. LoRA trains small low-rank adapter matrices instead of updating all of the model’s weights, which greatly reduces the memory and compute needed to fine-tune large language models.
formatting_func: The custom function that formats the training data into the Gemma instruction template, so that training examples match the input format the model expects.
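Two of these settings are easiest to understand with a little arithmetic. Gradient accumulation multiplies the per-device batch, so on a single GPU each optimizer step here sees 1 × 4 = 4 examples. And unless lr_scheduler_type is changed, TrainingArguments uses a linear schedule: the learning rate ramps from 0 to learning_rate over warmup_steps, then decays linearly to 0 at max_steps. The sketch below is our own plain-Python re-derivation of that schedule for illustration, not the transformers implementation itself; effective_batch_size and linear_warmup_lr are hypothetical names.

```python
def effective_batch_size(per_device, grad_accum, num_devices=1):
    """Examples contributing to each optimizer update."""
    return per_device * grad_accum * num_devices


def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at an optimizer step: linear warm-up, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


print(effective_batch_size(1, 4))  # 4
for step in (0, 1, 2, 76, 150):
    print(step, linear_warmup_lr(step, 2e-4, warmup_steps=2, total_steps=150))
```

With only two warm-up steps, the rate reaches its 2e-4 peak almost immediately and is halved by step 76.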
Finally, start training:
trainer.train()
TrainOutput(global_step=150, training_loss=0.5255898060897987, metrics={'train_runtime': 121.7084, 'train_samples_per_second': 4.93, 'train_steps_per_second': 1.232, 'total_flos': 1065426810224640.0, 'train_loss': 0.5255898060897987, 'epoch': 46.15})
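If you want to sanity-check a TrainOutput like the one above, a few lines of arithmetic recover the main quantities. The snippet only recombines values copied from that output with the batch settings from TrainingArguments. The last line assumes the Trainer’s epoch counter advances one batch of size 1 at a time, so dividing the examples processed by the epoch count gives the approximate number of training examples in one epoch.

```python
# Values copied from the TrainOutput above.
global_step = 150
train_runtime = 121.7084  # seconds
samples_per_second = 4.93
epoch = 46.15

# steps x per_device_train_batch_size x gradient_accumulation_steps
samples_seen = global_step * 1 * 4
print(samples_seen)                               # 600
print(round(samples_per_second * train_runtime))  # 600, consistent with the line above
print(round(global_step / train_runtime, 3))      # 1.232, matching train_steps_per_second
print(round(samples_seen / epoch))                # ~13 examples per epoch
```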
Now for the moment of truth: let’s test the fine-tuned model with the same prompt we used earlier.
text = """<start_of_turn>user
Explain 'AMSI init bypass' and its purpose.<end_of_turn>
<start_of_turn>model"""
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
user
Explain 'AMSI init bypass' and its purpose.
model
AMSI init bypass is a security feature in Windows that allows the System Management Interface (AMSI) to be initialized even when it is not necessary. This feature is designed to provide extra protection against certain types of exploits, by ensuring that the AM
text = """<start_of_turn>user
Explain 'APT groups and operations' and its purpose.<end_of_turn>
<start_of_turn>model"""
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
user
Explain 'APT groups and operations' and its purpose.
model
APT groups and operations refer to the organization and grouping of security patches and updates by release notes or patch management tools. This feature enables organizations to organize and manage patches and updates in a structured manner, making it easier to identify important security updates and apply
Conclusion
In conclusion, we have successfully fine-tuned the instruction-tuned Gemma 2B model on a cybersecurity instruction dataset with the help of LoRA. The training code above can be readily adapted to any other dataset of your choice. Thank you for reading, and we wish you continued success in your explorations of language modeling and beyond.