Aligning LLMs with Direct Preference Optimization (DPO): background, overview, intuition and paper summary

Manish Chablani
Feb 9, 2024

Direct Preference Optimization (DPO) is a stable, performant, and computationally lightweight technique for aligning LLMs using a simple classification loss. DPO eliminates the need to sample from the LM during fine-tuning and does not require significant hyperparameter tuning.

Before we explain DPO, let’s go over a typical way to align LLMs: the RLHF (Reinforcement Learning from Human Feedback) pipeline, which uses a reward model (RM).

At its core, RLHF uses human feedback to align an LLM to a specific use case (for example, responding to instructions in a safe and useful way, avoiding biases, avoiding copyright violations, citing sources, or providing explanations along with solutions).

RLHF involves a reward model (RM), typically another LLM, that is trained on human feedback to output a score representing a reward. The underlying goal is a model or system that takes in a sequence of text and returns a scalar reward that numerically represents human preference.
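A common way to train the RM from human comparisons (used, for example, in InstructGPT) is a pairwise Bradley-Terry style loss: the RM should give the human-preferred completion a higher score than the rejected one. The minimal sketch below assumes the RM has already produced one scalar score per completion; the function names are hypothetical, and a real RM would be a neural network trained with an autodiff framework rather than plain floats.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def pairwise_reward_loss(score_chosen: float, score_rejected: float) -> float:
    """Bradley-Terry pairwise loss: -log sigmoid(r(x, y_w) - r(x, y_l))."""
    return -log_sigmoid(score_chosen - score_rejected)


# Hypothetical RM scores for two completions of the same prompt.
print(pairwise_reward_loss(2.0, 0.5))  # small loss: RM already prefers the chosen answer
print(pairwise_reward_loss(0.5, 2.0))  # larger loss: RM prefers the rejected answer
```

Minimizing this loss over many labeled pairs pushes the RM's score gap in the direction humans chose, which is how pairwise annotations become a scalar reward.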

Once the RM is trained from human annotations, it can be used to scale the fine-tuning (aligning) of the original pretrained LLM. During this process, the LLM being aligned is given many prompts, its completions are sampled, and the reward model’s scores are used to update its weights. This process is complex, computationally heavy, and memory intensive because it involves three LLMs: a frozen copy of the source LLM that serves as a reference, the RM, and the LLM that is being sampled and optimized (aligned/fine-tuned).

Here is a high-level overview (from https://huggingface.co/blog/rlhf):

Note that the KL prediction shift penalty in the figure above is a training trick used to prevent “reward hacking”, where the LLM learns to output particular text that scores highly with the reward model regardless of its actual quality. The KL penalty keeps the token distribution of the aligned LLM reasonably close to that of the source LLM.
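The KL penalty can be sketched as a change to the reward that the RL step optimizes: the RM score minus a coefficient times an estimate of the KL divergence between the policy and the frozen reference model. This minimal sketch assumes you already have per-token log-probabilities of one sampled completion under both models and uses the common single-sample estimate sum_t [log pi_policy(y_t) - log pi_ref(y_t)]; the function name and the default `kl_coef` are hypothetical.

```python
def kl_penalized_reward(
    rm_score: float,
    policy_logprobs: list[float],
    ref_logprobs: list[float],
    kl_coef: float = 0.1,  # hypothetical value; tuned per setup in practice
) -> float:
    """RM score minus a KL penalty estimated from the sampled tokens."""
    if len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must have the same length")
    # Single-sample KL estimate: sum over tokens of log pi_policy - log pi_ref.
    kl_estimate = sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))
    return rm_score - kl_coef * kl_estimate


# Hypothetical numbers: the policy assigns much higher probability to its own
# sampled tokens than the reference does, so part of the RM score is taken away.
print(kl_penalized_reward(1.0, [-0.5, -0.5], [-1.5, -1.5]))
```

If the policy has not drifted (identical log-probabilities), the penalty is zero and the RM score passes through unchanged.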

Now let’s look at DPO and its variants.

Intuitively, DPO uses preference data: given a context/prompt, there is a preferred (good) response and a dis-preferred (bad) response.

At the heart of DPO is a loss function that compares the likelihood of the preferred response with that of the dis-preferred response and optimizes the LLM toward that objective:

Dataset of preferences $\mathcal{D} = \{(x, y_w, y_l)\}$, where $x$ is a prompt and $y_w$, $y_l$ are the preferred and dis-preferred responses.

Policy reformulation for DPO:
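In the DPO paper (https://arxiv.org/abs/2305.18290), the starting point is the same KL-constrained reward maximization used in RLHF. Its optimal policy has a closed form, which can be inverted to write the reward in terms of the policy. Here $\pi_{\mathrm{ref}}$ is the reference (SFT) model, $\beta$ the KL coefficient, and $Z(x)$ a partition function that depends only on the prompt:

```latex
\max_{\pi_\theta}\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r(x, y)\big]
  - \beta\, \mathbb{D}_{\mathrm{KL}}\big[\pi_\theta(y \mid x) \,\|\, \pi_{\mathrm{ref}}(y \mid x)\big]

\pi_r(y \mid x) = \frac{1}{Z(x)}\, \pi_{\mathrm{ref}}(y \mid x) \exp\!\Big(\frac{1}{\beta}\, r(x, y)\Big),
\qquad Z(x) = \sum_{y} \pi_{\mathrm{ref}}(y \mid x) \exp\!\Big(\frac{1}{\beta}\, r(x, y)\Big)

r(x, y) = \beta \log \frac{\pi_r(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)
```

Because $\beta \log Z(x)$ depends only on $x$, it cancels whenever the rewards of two responses to the same prompt are subtracted, so the intractable partition function never has to be computed.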

Loss formulation:
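In the DPO paper, writing the reward as $\beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$ (up to a prompt-only term that cancels) and plugging it into the Bradley-Terry preference model $p(y_w \succ y_l \mid x) = \sigma(r(x, y_w) - r(x, y_l))$ gives the following negative log-likelihood over the preference dataset, with $\sigma$ the logistic function:

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) =
  -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[
    \log \sigma\!\left(
      \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
      - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
    \right)
  \right]
```

This is a binary classification (logistic) loss on pairs, which is why DPO needs no separate reward model and no sampling during training.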

Algorithm

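  • Sample a batch of (prompt, good response, bad response) triples from the preference dataset
  • Run each pair through 2 models: the active (policy) model being trained and the frozen reference model
  • Compute the DPO loss and backprop, updating only the active model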
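To make the three steps concrete, here is a toy, dependency-free sketch. It assumes a single prompt whose candidate responses are a small fixed set, so the "policy" is just a softmax over one logit per response, a stand-in for an LLM's sequence log-probabilities. For that parameterization the gradient of the DPO loss has a closed form, $-\beta\,\sigma(-h)\,(e_{y_w} - e_{y_l})$ where $h$ is the argument of $\log\sigma$, so no autodiff library is needed. All names and numbers are hypothetical.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Log-probabilities of a categorical distribution given its logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss_and_grad(policy_logits, ref_logprobs, chosen, rejected, beta):
    """DPO loss for one (x, y_w, y_l) triple and its gradient w.r.t. the logits.

    Toy parameterization: pi_theta(y|x) = softmax(policy_logits)[y].
    """
    logp = log_softmax(policy_logits)
    # h = beta * (log-ratio of chosen - log-ratio of rejected)
    h = beta * ((logp[chosen] - ref_logprobs[chosen])
                - (logp[rejected] - ref_logprobs[rejected]))
    loss = -log_sigmoid(h)
    # dL/dh = -sigmoid(-h); for a softmax, grad(log p[w] - log p[l]) = e_w - e_l
    coeff = -beta * math.exp(log_sigmoid(-h))
    grad = [0.0] * len(policy_logits)
    grad[chosen] += coeff
    grad[rejected] -= coeff
    return loss, grad


# Hypothetical setup: 3 candidate responses to one prompt.
ref_logprobs = log_softmax([0.0, 0.0, 0.0])  # frozen reference model (uniform)
policy_logits = [0.0, 0.0, 0.0]              # policy starts equal to the reference
preference_pairs = [(0, 1), (0, 2)]          # (chosen, rejected) index pairs
beta, lr = 0.1, 1.0

for _ in range(200):
    for chosen, rejected in preference_pairs:   # 1. take a preference pair
        loss, grad = dpo_loss_and_grad(         # 2. score it under policy and reference
            policy_logits, ref_logprobs, chosen, rejected, beta)
        policy_logits = [w - lr * g             # 3. gradient step on the policy only
                         for w, g in zip(policy_logits, grad)]

print([round(math.exp(v), 3) for v in log_softmax(policy_logits)])
```

Two things carry over to the real setting: the reference model is only evaluated, never updated, and the loss depends on the policy only through log-probabilities of whole responses, which for an LLM are obtained by summing token log-probabilities under each model.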

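β in the loss function above is one of the most important hyperparameters in DPO. It is a positive scalar, usually set between 0 and 1 in practice. It controls how strongly the policy is anchored to the reference model: a larger β keeps the aligned model closer to the reference, while a smaller β lets it move further away to fit the preference data.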
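One way to see what β does: in the DPO loss, β multiplies the log-ratio margin m between the preferred and dis-preferred responses, so the model's implied preference probability is σ(β·m). The hypothetical helper below (not from the paper) solves for the margin needed to reach a target preference probability of 0.9 at a few β values.

```python
import math


def required_log_ratio_margin(target_prob: float, beta: float) -> float:
    """Margin m solving sigmoid(beta * m) = target_prob.

    m = [log pi(y_w|x) - log pi_ref(y_w|x)] - [log pi(y_l|x) - log pi_ref(y_l|x)]
    """
    if not 0.0 < target_prob < 1.0 or beta <= 0.0:
        raise ValueError("need 0 < target_prob < 1 and beta > 0")
    # Inverse of the sigmoid (the logit), scaled by 1 / beta.
    return math.log(target_prob / (1.0 - target_prob)) / beta


for beta in (0.1, 0.5, 1.0):
    print(f"beta={beta}: margin={required_log_ratio_margin(0.9, beta):.2f}")
```

This prints margins of 21.97, 4.39 and 2.20: with a small β the policy's log-ratios must move much further from the reference model to fit the same preference, while a large β fits it with a small deviation.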

Credits:

https://docs.google.com/presentation/d/1S8ao40-CdclRU0D2D9FdyN5x8fZL1Iv5/

https://arxiv.org/abs/2305.18290

Manish Chablani

Head of AI @EightSleep, Marathoner. (Past: AI in healthcare @curaiHQ, DL for self-driving cars @cruise, ML @Uber, early engineer @MicrosoftAzure cloud)