Deriving DPO’s Loss
Direct preference optimisation (DPO) has become critical for aligning LLMs with human preferences. I have been talking to many people about it and noticed that many need help deriving its equations so they can build on top of it and improve it. This article is dedicated to a step-by-step derivation of DPO. I hope it’s helpful to you.
Bradley-Terry (BT) Model:
Imagine you are given two answers and would like to compare which answer is better. The BT model offers a way to execute that comparison such that:
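In the notation of the DPO paper, with a prompt x, two candidate answers y_1 and y_2, and a latent reward r(x, y), the BT model can be written as below. The symbol names are borrowed from the paper and are an assumption about what this article uses.

```latex
\[
p(y_1 \succ y_2 \mid x)
  = \frac{\exp\big(r(x, y_1)\big)}{\exp\big(r(x, y_1)\big) + \exp\big(r(x, y_2)\big)}
  = \sigma\big(r(x, y_1) - r(x, y_2)\big),
\qquad
\sigma(z) = \frac{1}{1 + e^{-z}} .
\]
```

Only the difference of the two rewards matters, which is worth keeping in mind: any term that depends on x alone will cancel.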
The annoying part of the above equation is the reward function! Of course, we could learn this reward and then optimise against it with something like PPO, as is done in RLHF. The problem is that learning this reward function can take time and effort. That is why DPO aims to eliminate it using available pairwise data!
What is our Goal?
Let us imagine we have the reward function for a second! We want to update our LLM parameters to generate answers that maximise such a reward while remaining close to our original reference model. This way, our LLM “aligns” with the reward while not losing its language properties, since we stay close to a reference LLM. Formally, we wish to solve the following optimisation problem:
Strategy: Remember, our goal is to eliminate the reward. We will achieve this by doing some math on the above optimisation problem and looking at the optimal policy of a KL-constrained RL problem. Papers usually state this optimal policy without derivation, but we are going to derive it next.
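In the DPO paper's notation, this KL-regularised objective can be written as follows. Here π is the policy being optimised, π_ref is the reference model, D is the prompt dataset, and β > 0 controls how strongly we stay close to the reference; these symbol names are assumed from the paper.

```latex
\[
\max_{\pi} \;
\mathbb{E}_{x \sim \mathcal{D},\; y \sim \pi(\cdot \mid x)} \big[\, r(x, y) \,\big]
\;-\; \beta \, \mathbb{D}_{\mathrm{KL}}\big[\, \pi(\cdot \mid x) \,\big\|\, \pi_{\mathrm{ref}}(\cdot \mid x) \,\big]
\]
```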
Before we start, we note that the optimisation problem above has a constraint that the LLM should always be a valid probability distribution. As such, our overall problem can be stated as:
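A sketch of the constrained problem, under two simplifying assumptions that are mine: we fix a single prompt x (the objective decouples across prompts), and we treat the answers y as a discrete set so that expectations become sums.

```latex
\[
\max_{\pi} \quad
\sum_{y} \pi(y \mid x)\, r(x, y)
\;-\; \beta \sum_{y} \pi(y \mid x) \log \frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}
\qquad \text{s.t.} \qquad
\sum_{y} \pi(y \mid x) = 1 .
\]
```

The non-negativity requirement π(y | x) ≥ 0 is left implicit here. The logarithm in the KL term already keeps the maximiser strictly positive wherever π_ref is positive.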
To solve the above problem, we follow the standard approach to constrained optimisation: we rewrite it in an unconstrained form by introducing a Lagrange multiplier, giving us the following:
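One way to write that Lagrangian, assuming a fixed prompt x, a discrete answer set, and a single multiplier λ attached to the normalisation constraint. The sign convention for λ and the Term I / II / III labels are my choice for illustration.

```latex
\[
\mathcal{L}(\pi, \lambda)
= \underbrace{\sum_{y} \pi(y \mid x)\, r(x, y)}_{\text{Term I}}
\;\underbrace{-\, \beta \sum_{y} \pi(y \mid x) \log \frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}}_{\text{Term II}}
\;+\; \underbrace{\lambda \Big( 1 - \sum_{y} \pi(y \mid x) \Big)}_{\text{Term III}}
\]
```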
To solve this unconstrained problem, we will consider its gradient and attempt to find the extrema. Don’t worry, it is straightforward! We will take it step by step!
From the above, we notice that there are three terms that we need to differentiate: the reward term (Term I), the KL term (Term II), and the Lagrange multiplier term (Term III). For Term I, we write:
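Taking Term I to be the expected reward for a fixed prompt x, and treating the probability π(y | x) of each individual answer as a free variable (both assumptions of this sketch), the derivative with respect to one particular π(y | x) picks out a single summand:

```latex
\[
\frac{\partial}{\partial \pi(y \mid x)}
\underbrace{\sum_{y'} \pi(y' \mid x)\, r(x, y')}_{\text{Term I}}
= r(x, y)
\]
```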
Now, consider Term II (the KL term):
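A sketch of this step, assuming Term II is the KL penalty with its minus sign and coefficient β included. Splitting the logarithm of the ratio gives two pieces, which I label Term II.1 and Term II.2, and each can be differentiated with respect to a single π(y | x):

```latex
\[
\text{Term II}
= \underbrace{-\,\beta \sum_{y'} \pi(y' \mid x) \log \pi(y' \mid x)}_{\text{Term II.1}}
\;+\; \underbrace{\beta \sum_{y'} \pi(y' \mid x) \log \pi_{\mathrm{ref}}(y' \mid x)}_{\text{Term II.2}}
\]
\[
\frac{\partial\, \text{Term II.1}}{\partial \pi(y \mid x)}
= -\beta \big( \log \pi(y \mid x) + 1 \big),
\qquad
\frac{\partial\, \text{Term II.2}}{\partial \pi(y \mid x)}
= \beta \log \pi_{\mathrm{ref}}(y \mid x)
\]
```

The "+1" in the first derivative comes from the product rule applied to π log π.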
Since Term II = Term II.1 + Term II.2, we can now write:
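Adding the two pieces, with Term II.1 taken as the entropy-like part and Term II.2 as the part involving π_ref (my labelling), the derivative of the whole KL term can be sketched as:

```latex
\[
\frac{\partial\, \text{Term II}}{\partial \pi(y \mid x)}
= -\beta \big( \log \pi(y \mid x) + 1 \big) + \beta \log \pi_{\mathrm{ref}}(y \mid x)
= -\beta \log \frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} \;-\; \beta
\]
```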
Jumping to the last term (Term III), we can easily see that:
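Assuming Term III is written as λ times (1 minus the sum of the policy's probabilities), which is a sign convention chosen for this sketch, its derivative is just a constant:

```latex
\[
\frac{\partial}{\partial \pi(y \mid x)}
\underbrace{\lambda \Big( 1 - \sum_{y'} \pi(y' \mid x) \Big)}_{\text{Term III}}
= -\lambda
\]
```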
As discussed earlier, our goal was to find the extremum of our optimisation objective. To do that, we now set the gradient to zero and solve for the LLM’s policy in terms of the other variables:
We are getting closer to Equation 4 on page 4 of the DPO paper, but we are not there yet! We still need to nail that ugly normalisation term. Let us think about this a bit!
Generally, when you have Lagrange multipliers, you would also take the gradient with respect to them and solve the resulting system. Rather than doing all that, notice that in our case the Lagrange multiplier is responsible for making the policy a valid probability distribution (remember the constraint we had from the beginning). So one way to determine its value is to ask which λ makes the optimal policy π* above sum to 1, i.e., become a valid probability distribution. Let us work through that a bit!
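Putting the three derivatives together, under the same assumptions as before (fixed prompt x, discrete answers, Term III written as λ(1 − Σ π)), the stationarity condition and the policy it implies can be sketched as:

```latex
\[
r(x, y) \;-\; \beta \log \frac{\pi^{*}(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} \;-\; \beta \;-\; \lambda = 0
\]
\[
\Longrightarrow \quad
\pi^{*}(y \mid x)
= \pi_{\mathrm{ref}}(y \mid x)\,
  \exp\!\Big( \tfrac{1}{\beta}\, r(x, y) \Big)\,
  \exp\!\Big( -1 - \tfrac{\lambda}{\beta} \Big)
\]
```

The last factor does not depend on y, so it can only act as a normalisation constant.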
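A sketch of that normalisation argument, assuming the stationary policy has the form π_ref · exp(r/β) · exp(−1 − λ/β), which is what the sign convention used in this sketch produces. The name Z(x) for the partition function follows the DPO paper.

```latex
\[
1 = \sum_{y} \pi^{*}(y \mid x)
  = \exp\!\Big( -1 - \tfrac{\lambda}{\beta} \Big)
    \sum_{y} \pi_{\mathrm{ref}}(y \mid x) \exp\!\Big( \tfrac{1}{\beta}\, r(x, y) \Big)
\]
\[
\Longrightarrow \quad
\exp\!\Big( -1 - \tfrac{\lambda}{\beta} \Big) = \frac{1}{Z(x)},
\qquad
Z(x) = \sum_{y} \pi_{\mathrm{ref}}(y \mid x) \exp\!\Big( \tfrac{1}{\beta}\, r(x, y) \Big)
\]
\[
\Longrightarrow \quad
\pi^{*}(y \mid x)
= \frac{1}{Z(x)}\, \pi_{\mathrm{ref}}(y \mid x)\, \exp\!\Big( \tfrac{1}{\beta}\, r(x, y) \Big)
\]
```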
… and that is Equation 4!
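If you would like to convince yourself numerically, here is a small sanity check in plain Python. It is only a sketch: the three-answer toy problem, the reward values, β, and all function names are made up for illustration. It builds the closed-form policy and checks that randomly sampled policies never achieve a higher value of the KL-regularised objective.

```python
import math
import random


def optimal_policy(pi_ref, rewards, beta):
    """Closed form: pi*(y) = pi_ref(y) * exp(r(y) / beta) / Z."""
    weights = [p * math.exp(r / beta) for p, r in zip(pi_ref, rewards)]
    z = sum(weights)  # partition function Z(x) for this single prompt
    return [w / z for w in weights]


def objective(pi, pi_ref, rewards, beta):
    """Expected reward minus beta * KL(pi || pi_ref) for one prompt."""
    expected_reward = sum(p * r for p, r in zip(pi, rewards))
    kl = sum(p * math.log(p / q) for p, q in zip(pi, pi_ref) if p > 0)
    return expected_reward - beta * kl


def random_policy(n, rng):
    """Draw a random probability vector over n answers."""
    raw = [rng.random() + 1e-9 for _ in range(n)]
    total = sum(raw)
    return [v / total for v in raw]


# Toy setup: one prompt, three possible answers (numbers are illustrative).
pi_ref = [0.5, 0.3, 0.2]
rewards = [1.0, 2.0, 0.5]
beta = 0.7

pi_star = optimal_policy(pi_ref, rewards, beta)
best = objective(pi_star, pi_ref, rewards, beta)

# Count how many random policies beat the closed-form one.
rng = random.Random(0)
beats = sum(
    objective(random_policy(3, rng), pi_ref, rewards, beta) > best + 1e-12
    for _ in range(10_000)
)
print("pi* =", pi_star)
print("objective at pi* =", best)
print("random policies that beat pi*:", beats)
```

Because the objective is concave in π, the stationary point is the global maximum, so the last count should be zero. As a bonus, the objective evaluated at π* equals β log Z(x).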
What remains to be done is to solve the above equation for r(x,y) and substitute the result into the BT model. Can you do it as an exercise?
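If you want to check your answer afterwards, one route is sketched here. It uses the DPO paper's notation, which is assumed rather than taken from this article: y_w is the preferred answer, y_l the rejected one, and π_θ the trainable policy that takes the place of π*.

```latex
\[
r(x, y) = \beta \log \frac{\pi^{*}(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)
\]
\[
p(y_w \succ y_l \mid x)
= \sigma\!\Big(
    \beta \log \frac{\pi^{*}(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
    - \beta \log \frac{\pi^{*}(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
  \Big)
\]
\[
\mathcal{L}_{\mathrm{DPO}}(\pi_{\theta}; \pi_{\mathrm{ref}})
= -\, \mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}
  \Big[ \log \sigma\!\Big(
    \beta \log \frac{\pi_{\theta}(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
    - \beta \log \frac{\pi_{\theta}(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
  \Big) \Big]
\]
```

The β log Z(x) term appears in both rewards and cancels in the difference, which is exactly why the intractable partition function never has to be computed.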
If you need any help, let me know!
The DPO paper can be found here: https://arxiv.org/pdf/2305.18290