Fine-tune 4-bit Llama-2-7B with Flash Attention Using DPO

Drishti Sushma
5 min read · Sep 11, 2023

Introduction

This post documents my attempt to fine-tune a 4-bit Llama-2-7B model on the Stack Exchange dataset using DPO. The training run stopped prematurely for reasons I have not yet identified, and since I am running low on compute I cannot retrain at the moment. I hope to complete the training as soon as possible.

Now let’s begin!

What is DPO?

DPO (Direct Preference Optimization) offers a streamlined method for optimizing LLMs, such as GPT-4 or Claude, against human preferences. Traditional approaches use reinforcement learning (RL) to train models on human feedback. This process, known as Reinforcement Learning from Human Feedback (RLHF), involves building a good reward function and carefully training the model to produce sensible text that aligns with human expectations.

DPO simplifies the RLHF process. Instead of using RL and a reward model, DPO employs a direct binary cross-entropy loss to fine-tune models. This method is significantly more straightforward and eliminates many complexities associated with RLHF.
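To make the "binary cross-entropy loss" concrete, here is a minimal sketch of the DPO loss for a single preference pair. It assumes you already have the summed log-probabilities of the chosen and rejected responses under both the model being trained and the frozen reference model; the function and argument names are my own and are not TRL's API.

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss for one preference pair, given summed sequence log-probs."""
    # How much more (or less) the policy favors each response than the reference.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_logratio - rejected_logratio)
    # Binary cross-entropy with the label "chosen wins": -log(sigmoid(logits)).
    return math.log1p(math.exp(-logits))

# When the policy equals the reference, the loss is log(2).
baseline = dpo_loss(-10.0, -12.0, -10.0, -12.0)
# When the policy shifts probability toward the chosen response, the loss drops.
improved = dpo_loss(-8.0, -14.0, -10.0, -12.0)
print(baseline, improved)
```

The loss only falls when the policy raises the chosen response relative to the rejected one by more than the reference model does, which is why no separate reward model is required.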

How does DPO differ from PPO?

Traditional RLHF first trains an auxiliary reward model and then optimizes the language model against it with an RL algorithm such as PPO. DPO skips the reward modeling step entirely. It uses an analytical mapping from the reward function to the optimal RL policy, which allows the language model to be optimized directly on preference data. The policy is still kept close to a frozen reference model, but no RL-based optimization is needed.
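For readers who want the mapping written out, the following is how it appears in the DPO paper. The symbols are not used elsewhere in this post: \pi_\theta is the model being trained, \pi_{\mathrm{ref}} the frozen reference model, \beta a scaling hyperparameter, and y_w and y_l the chosen and rejected responses for prompt x.

```latex
% Reward implied by a policy (Z(x) depends only on the prompt)
r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)

% Substituting into the pairwise preference model cancels Z(x) and gives the DPO loss
\mathcal{L}_{\mathrm{DPO}} = -\mathbb{E}_{(x, y_w, y_l)} \left[ \log \sigma \left(
    \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
  - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
\right) \right]
```

Because the reward is expressed through the policy itself, minimizing this loss on preference pairs takes the place of the separate reward-model and RL stages.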

Training with TRL’s DPO:

The TRL library, which supports DPO, provides tools for the entire RLHF pipeline. However, with DPO, only supervised fine-tuning (SFT) and data annotation for preference labels are needed. The DPOTrainer in TRL optimizes the model using the preference data.

For example, to train with the Stack Exchange preference dataset, you’d need to format the data appropriately. After processing the data, you can use the DPOTrainer, which needs the base model from the SFT pipeline, a reference model, and other necessary parameters to begin training.

So in short, training with TRL’s DPO takes the following three steps:

  1. a supervised fine-tuning (SFT) step
  2. annotating data with preference labels
  3. passing the preference data from step 2, in a very specific format, to the DPOTrainer in TRL

Specifically, each entry should be a dictionary with the following three keys:

  • prompt: the context prompt given to the model at inference time for text generation.
  • chosen: the preferred response to that prompt.
  • rejected: the response that is not preferred and should not be generated for that prompt.
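As a quick illustration of this format, the sketch below builds a one-entry batch and checks it. The key names `prompt`, `chosen` and `rejected` match the mapping function used later in this post; the helper `validate_preference_batch` and the example text are made up for illustration.

```python
REQUIRED_KEYS = ("prompt", "chosen", "rejected")

def validate_preference_batch(batch):
    """Check that a batch has the three keys with lists of equal length."""
    missing = [k for k in REQUIRED_KEYS if k not in batch]
    if missing:
        raise KeyError(f"missing keys: {missing}")
    lengths = {len(batch[k]) for k in REQUIRED_KEYS}
    if len(lengths) != 1:
        raise ValueError("prompt, chosen and rejected must have the same length")
    return lengths.pop()

# Hypothetical example entry, not taken from the dataset.
example = {
    "prompt": ["Question: How do I reverse a list in Python?\n\nAnswer: "],
    "chosen": ["Use `my_list[::-1]` or `my_list.reverse()`."],
    "rejected": ["You cannot reverse a list."],
}
n = validate_preference_batch(example)
print(f"{n} valid preference pair(s)")
```

Running a check like this before handing data to the trainer catches a missing or misnamed column early.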

Now let’s begin with the training process!

Setup and Preliminary Configuration

The foundational tools and libraries for our experiment include accelerate, peft, bitsandbytes, transformers, and trl, among others.

!pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.5.0
!pip install -q sentencepiece

A hardware check ensures compatibility with Flash Attention, which requires a GPU with CUDA compute capability 8.0 or higher.

!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
!pip install ninja packaging
!MAX_JOBS=4 pip install flash-attn --no-build-isolation

Now we will define a function that maps the dataset entries to return the desired dictionary.

from typing import Dict, List, Union

def return_prompt_and_responses(samples) -> Dict[str, Union[str, List[str]]]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],  # rated better than k
        "rejected": samples["response_k"],  # rated worse than j
    }
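To see what this mapping produces, here is a small self-contained run on a toy batch. The function is repeated so the snippet runs on its own, and the question and answers are invented for illustration rather than taken from the dataset.

```python
def return_prompt_and_responses(samples):
    # Same mapping as above: works on a batch, i.e. a dict of lists.
    return {
        "prompt": ["Question: " + q + "\n\nAnswer: " for q in samples["question"]],
        "chosen": samples["response_j"],    # rated better than k
        "rejected": samples["response_k"],  # rated worse than j
    }

# Hypothetical batch with a single entry.
toy_batch = {
    "question": ["What does `git stash` do?"],
    "response_j": ["It shelves uncommitted changes so you can restore them later."],
    "response_k": ["It deletes your repository."],
}

out = return_prompt_and_responses(toy_batch)
print(out["prompt"][0])
print(sorted(out.keys()))
```

Note that the function expects batched input, which is why it is later used with `batched=True`.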

Loading the Dataset and Processing it

from datasets import load_dataset
dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl",
)

We obtain the Stack-Exchange dataset, suitably formatted for our task. The data is shuffled, segmented, and processed to generate prompts and responses, and is further partitioned into training and test sets.

original_columns = dataset.column_names
dataset = dataset.shuffle(seed=42).select(range(2000))
dataset
# map() returns a new dataset, so the result must be assigned back
dataset = dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns,
)

dataset = dataset.train_test_split(seed=42, shuffle=True, test_size=0.1)
print(dataset)

Model Loading and Flash Attention Implementation

Before Llama-2-7b is loaded, its attention implementation is patched to use Flash Attention when the GPU supports it.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

model_id = "NousResearch/Llama-2-7b-hf"
use_flash_attention = False

# replace attention with flash attention if the GPU supports it
# (utils.llama_patch is a local helper module that must be on the import path)
if torch.cuda.get_device_capability()[0] >= 8:
    from utils.llama_patch import replace_attn_with_flash_attn
    print("Using flash attention")
    replace_attn_with_flash_attn()
    use_flash_attention = True

Once it’s confirmed that it’s using Flash Attention, the model is loaded with 4-bit quantization settings.

# load the base model in 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1


# Validate that the model is using flash attention, by comparing doc strings
if use_flash_attention:
    from utils.llama_patch import forward
    assert base_model.model.layers[0].self_attn.forward.__doc__ == forward.__doc__, "Model is not using flash attention"



tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
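To get a feel for why 4-bit loading matters here, the following back-of-the-envelope sketch compares weight storage at 16 bits and 4 bits per parameter. It assumes a nominal 7 billion parameters and ignores quantization constants, activations, optimizer state and the LoRA adapters, so treat the numbers as rough lower bounds rather than measurements.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate weight storage in gigabytes (10**9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

N_PARAMS = 7e9  # nominal parameter count for a "7B" model (assumption)

fp16_gb = weight_memory_gb(N_PARAMS, 16)
nf4_gb = weight_memory_gb(N_PARAMS, 4)
print(f"16-bit: {fp16_gb:.1f} GB, 4-bit: {nf4_gb:.1f} GB")
```

Under these assumptions the weights shrink to roughly a quarter of their 16-bit size, which is what makes fine-tuning on a single GPU feasible.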

LoRA Integration and Model Preparation

LoRA (Low-Rank Adaptation) is incorporated atop the quantized base model. Post-LoRA integration, the model undergoes additional configurations to make it training-ready.

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# add LoRA layers on top of the quantized base model
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)


# prepare model for training
base_model = prepare_model_for_kbit_training(base_model)
base_model = get_peft_model(base_model, peft_config)
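As a rough illustration of what `r=64` and `lora_alpha=16` imply, the sketch below counts the parameters LoRA adds to a single square projection matrix. It assumes a 4096 x 4096 weight, which matches the hidden size of Llama-2-7B; which layers actually receive adapters depends on the PEFT defaults and is not covered here.

```python
def lora_extra_params(d_in, d_out, r):
    """Parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

HIDDEN = 4096  # hidden size of Llama-2-7B
r, lora_alpha = 64, 16

full = HIDDEN * HIDDEN                       # frozen weights in the projection
extra = lora_extra_params(HIDDEN, HIDDEN, r)  # trainable LoRA weights
scaling = lora_alpha / r                      # factor applied to the LoRA update

print(f"frozen: {full}, trainable: {extra}, ratio: {extra / full:.4f}, scaling: {scaling}")
```

Even at a fairly high rank of 64, the adapter is about 3% of the size of the matrix it modifies, and the update is scaled by `lora_alpha / r`.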

Extending Model Compatibility with Torch Bfloat16

Now we apply `upcast_layer_for_flash_attention` so that the model’s layers are compatible with torch.bfloat16.

from utils.llama_patch import upcast_layer_for_flash_attention
model = upcast_layer_for_flash_attention(base_model, torch.bfloat16)

Defining Training Arguments

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="llama-7-int4-stack-exchange",
    max_steps=100,
    per_device_train_batch_size=6 if use_flash_attention else 4,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="steps",
    learning_rate=2e-4,
    bf16=True,
    tf32=True,
    lr_scheduler_type="constant",
    disable_tqdm=True,  # disable tqdm since with packing the progress values are incorrect
)

Setting Up the SFTTrainer

from trl import SFTTrainer

max_seq_length = 2048 # max sequence length for model and packing of the dataset

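trainer = SFTTrainer(
    model=base_model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    # with packing, formatting_func must return one string per example;
    # this assumes the dataset already has the prompt/chosen/rejected columns
    formatting_func=lambda sample: sample["prompt"] + sample["chosen"],
    args=training_args,
)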
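The `packing=True` option is easier to reason about with a small example. The sketch below is a simplified illustration of the idea, not TRL's implementation: tokenized examples are joined into one stream, separated by an end-of-sequence token, and the stream is cut into blocks of `seq_length` tokens. The function name and the toy token ids are made up.

```python
def pack_sequences(tokenized_examples, seq_length, eos_token_id):
    """Concatenate examples (separated by EOS) and cut into fixed-length blocks."""
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        stream.append(eos_token_id)
    # Drop the trailing remainder that does not fill a whole block.
    n_blocks = len(stream) // seq_length
    return [stream[i * seq_length:(i + 1) * seq_length] for i in range(n_blocks)]

# Three toy "tokenized" examples, packed into blocks of 4 tokens with EOS id 0.
blocks = pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_length=4, eos_token_id=0)
for block in blocks:
    print(block)
```

Packing avoids wasting compute on padding, at the cost of blocks that can start or end in the middle of an example.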

Training and Saving the Model

trainer.train()

# save model
trainer.save_model()

DPO (Direct Preference Optimization) Training

After the supervised fine-tuning phase, the model is trained with DPO to align it with the preference data. This uses a separate trainer, the DPOTrainer, together with a frozen reference copy of the SFT model, and ends with saving the model.

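# note: AutoPeftModelForCausalLM may require a newer peft release than the 0.4.0 pinned above
from peft import AutoPeftModelForCausalLM
from trl import DPOTrainer

# script_args comes from an argument parser that is not shown here;
# model_name_or_path should point to the saved SFT model
model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,  # location of saved SFT model
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
# frozen reference model that the policy is compared against
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,  # same model as the main one
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,  # strength of the pull toward the reference model
    train_dataset=dataset["train"],  # needs prompt/chosen/rejected columns
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()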

Conclusion

Our attempt to train the 4-bit Llama-2-7B model using DPO ran into problems, but the potential of DPO is evident. DPO offers a streamlined approach to fine-tuning language models that sidesteps the complexities of RL-based training. The TRL library simplifies this further by providing the essential tools for the process. Whatever the setup, it is vital to understand the steps involved and the ultimate goal: a model that aligns with human preferences.
