Understanding DDPM objective, Part 3

Luv Verma
10 min read · Apr 12, 2023

--

Welcome to Part 3 of the exciting blog series, where we’ll dive deep into the world of diffusion models — a cutting-edge class of generative models that are taking the AI landscape by storm!

In this part 3 of the blog series, we will finally unveil the much-anticipated objective function from the Denoising Diffusion Probabilistic Models (DDPM) paper. Prepare to be amazed as we reveal how the training algorithm ingeniously builds upon this captivating objective function. The equation numbers in this blog continue from Part 2.

(link to part 1)

In continuation from the second blog (link to part 2)…

Let’s start by recapping equation 19 (in part 2 of the series) and continue the journey to reach the objective function:

Equation 19 in Part 2 was given as:

In its explanation, I wrote the following:

“Think of a scenario: we have an image that is completely noisy, we know we have to get back to the original image, and we have an objective function for that (equation 17), but it is intractable. What can we do? We have just learned the definition of KL-divergence and we know all about the forward diffusion process. The primary aim is to recover the original image (x(0)) in the reverse diffusion process. We have a neural network at our disposal and we know Bayes’ Theorem (Wikipedia rules!). Now, thinking it through: we know the original distribution (from the forward diffusion process), and we can obtain another distribution using the parametrized network. We can then compute the KL-divergence between the two and add it to the first term, which we know nothing about.”
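Since KL-divergence does the heavy lifting in this derivation, a quick numerical sanity check may help build intuition. Below is the standard closed form for the KL between two univariate Gaussians (a textbook result, not something specific to the paper):

```python
import math

def kl_gaussians(mu_q, var_q, mu_p, var_p):
    """Closed-form KL(q || p) for two 1-D Gaussians:

    KL(N(mu_q, var_q) || N(mu_p, var_p))
      = 0.5 * [ log(var_p/var_q) + (var_q + (mu_q - mu_p)^2)/var_p - 1 ]
    """
    return 0.5 * (math.log(var_p / var_q)
                  + (var_q + (mu_q - mu_p) ** 2) / var_p
                  - 1.0)

# KL is zero iff the two distributions coincide, and grows as they separate.
print(kl_gaussians(0.0, 1.0, 0.0, 1.0))  # 0.0
```

This zero-iff-equal property is exactly what makes KL a useful training signal: pushing it down pulls the parametrized distribution toward the known one.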

Moving forward:

Let’s first deal with term 2 in equation 19.

If we look at equations 19 and 20 carefully, we can see that we only need to deal with the terms inside the log in equation 20, since, apart from the KL-divergence, everything else cancels.

Therefore, with the help of the above explanation, equation 20 reduces to:

Wow, equation 21 is a relief, because now it is just a matter of applying Bayes’ Theorem. Let us start by applying it to the denominator of term 2, which gives us the following set of equations (equation 22).

Substituting from equation 22 into the 2nd term in equation 21, we get:

We started from some abstruse terms in equation 19 (which I tried to explain above), and by equation 23, we are left with the log of a fraction. Let us try to simplify further. Do not forget that the aim here is to reach a simpler objective function keeping in mind that we need to get back to the original image x(0).

Using the Markov property, we can simplify the denominator. Recall that we parametrized the transition function in the reverse direction, i.e., from time ‘t’ to time ‘t-1’. In a Markov model, the image at time ‘t-1’ depends only on its immediate parent and not on any other ancestor. Therefore, the denominator of equation 23 can be simplified as:

Similarly, using the Markov property and conditional independence (basic probability), we can also simplify the numerator:
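For reference, in standard DDPM notation the two Markov factorizations being used here (the reverse process for the denominator, the forward process for the numerator) read:

```latex
p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t),
\qquad
q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})
```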

Substituting terms from equations 24 and 25 into the RHS term in equation 23, we get as follows:

Please note that the products in equations 24 and 25 became summations in equation 26 because the log of a product is a sum of logs.

For simplifying equation 26 further, the DDPM paper has split the summation as follows:

Let’s look at term-2 in equation 27. Using Bayes Theorem, we can re-write the numerator of equation 27 as follows:

Let’s talk about the term q(x(t-1)|x(t)) in equation 28.

We know that image x(t) is noisy, and even with full information about x(t), it is hard to move in the right direction from time ‘t’ to time ‘t-1’. But can we say the same about the difficulty of this task if we assume the original image at time 0 is given? Obviously not! Knowing the original image, we can move in the right direction.

Therefore, let’s introduce x(0) in equation 28, and that leads to:
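The relation being introduced here is just Bayes’ rule with every term additionally conditioned on x(0):

```latex
q(x_t \mid x_{t-1}, x_0)
  = \frac{q(x_{t-1} \mid x_t, x_0)\; q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}
```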

Due to the relation in equation 29, the term 2 in equation 27 will be modified as follows:

Something to observe: term 2 in equation 30 can be simplified into a very simple form.

Now using equation 30 (which modified the term 2 in equation 27, by considering x(0)) and equation 31, let us re-try to simplify equation 27.

Just to reiterate, equation 27 was modified to become the following:

Modifying equation 32 above further, to get the following set of equations (equation 33):

The term 1 on the RHS of the last equation in the set above (equation 33) can be dropped: by construction of the forward process, q(x(T)|x(0)) is approximately a standard normal, and p(x(T)) is fixed to that same standard normal, so this term contains no trainable parameters and is approximately zero. Therefore, we arrive at the following set of equations:

How do we get further from equation 34? In term 1 (the KL-divergence term), we have q and p(θ). If you remember, earlier the issue was that it was impossible to find the image at time ‘t-1’ given only the image at time ‘t’ in the reverse diffusion process. However, the problem became tractable once we conditioned on the original image, which is great news, right? Do you know why? Because now we can represent it as a Gaussian (Normal) distribution.

We also know the formulation of p(θ) from equation 2, and we know that the variance is fixed.

According to the DDPM paper, the normal distribution used to represent the ‘q’ in term 1 in equation 34 is:
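For reference, in the notation of the DDPM paper (Ho et al., 2020, Eqs. 6–7), this tractable posterior is:

```latex
q(x_{t-1} \mid x_t, x_0)
  = \mathcal{N}\!\left(x_{t-1};\, \tilde\mu_t(x_t, x_0),\, \tilde\beta_t I\right),
\qquad
\tilde\mu_t(x_t, x_0)
  = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\, x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\, x_t,
\qquad
\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t
```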

Reiterating equation 2, we had:

From equation 35, equation 2, and the KL-divergence in equation 34, what we want is:

To achieve this, in the DDPM paper, the objective/loss function is specified as:

Since the term in the denominator is constant, we can ignore it (as is also done in the DDPM paper). Thus, the loss function can be re-written as:

Now, intuitively, the first term in equation 38 depends only on x(t) and x(0), so we can think of expressing it as a combination of the two, weighted by constants. Thus, let us assume we can represent the first term in equation 38 as:

Please note that the first term in equation 38 has a closed form as given in the DDPM paper:

However, for the purpose of deriving the objective function, I will keep things simple and use equation 39 in the explanation to reach the objective function.

Re-writing equations 8 and 9 using the reparametrization trick:

Rewriting x(0) from equation 41 in terms of x(t) and substituting it into equation 39, we get the following:
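This substitution relies on the forward-process identity x(t) = √(ᾱ_t)·x(0) + √(1−ᾱ_t)·ε, which can be inverted for x(0). A small NumPy round-trip check of that inversion (the function names here are illustrative, not from the paper):

```python
import numpy as np

def forward_sample(x0, alpha_bar_t, eps):
    # Reparametrization trick: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * eps

def recover_x0(xt, alpha_bar_t, eps):
    # Inversion used here: x0 = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t)
    return (xt - np.sqrt(1.0 - alpha_bar_t) * eps) / np.sqrt(alpha_bar_t)

rng = np.random.default_rng(0)
x0 = rng.standard_normal(4)
eps = rng.standard_normal(4)
xt = forward_sample(x0, 0.7, eps)
print(np.allclose(recover_x0(xt, 0.7, eps), x0))  # True
```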

The second term in equation 38 can also be re-written in a similar manner, as was done in equation 42 (the DDPM paper follows this; the constant remains the same):

Please note that the first term in equation 43 follows the structure of the closed-form solution (equation 38) and is thus given as:

Substituting equations 43 and 39 in equation 38, we will have the objective/loss function in the form of:

There, we have how the KL-divergence term from equation 34 looks. Therefore equation 34, can be re-written as:

Equation 47 contains the objective function we were looking for. However, some work is still needed on this equation. What happens to its second term? And how would we use the first term for training and sampling? Let’s move further.

In the DDPM paper, the second term has been simplified such that it is clearly differentiable with respect to the parameters (θ); however, the authors note: “However, we found it beneficial to sample quality (and simpler to implement) to train on the following variant of the variational bound”. This means we can keep just the first term in equation 47 for training. Let me re-write it in a very crude form:

Equation 48 is the first term from equation 47, written separately since it will be the objective function for training. However, equation 48 alone is not enough, since we still do not know how to use it. Let’s keep working on it. By substituting equation 41 (at least some of the constant factors derived above in equations 7, 8, and 9) into equation 48, we get:

Let’s break down the above equation.

  • x(0): The initial state of the diffusion process. x(0) is the original image (uncorrupted). How can I sample this?
  • ‘t’: time step. The time step ‘t’ is uniformly sampled between 1 and the final time step ‘T’.
  • ϵ: The noise sampled from a normal distribution, N(0, I).
  • ϵ(θ): The noise modeled by the DDPM, which depends on the model parameters ‘θ’, the initial state (original image) x(0), and the time step ‘t’. In short, ϵ(θ) is the neural network.
  • α̅_t: The decay factor for the diffusion process at time step ‘t’ (look at equation 7).
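Putting those pieces together, one Monte-Carlo sample of this objective can be sketched in NumPy as follows (here `eps_theta` is just a placeholder for the noise-prediction neural network, and `alpha_bar` for the cumulative schedule; both are illustrative assumptions):

```python
import numpy as np

def ddpm_loss_sample(x0, alpha_bar, eps_theta, rng):
    """One Monte-Carlo sample of the simplified DDPM objective.

    eps_theta(x_t, t) stands in for the network that predicts the noise;
    alpha_bar[t-1] is the cumulative product of (1 - beta) up to step t.
    """
    T = len(alpha_bar)
    t = int(rng.integers(1, T + 1))        # t ~ Uniform{1, ..., T}
    eps = rng.standard_normal(x0.shape)    # eps ~ N(0, I)
    a_bar = alpha_bar[t - 1]
    xt = np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * eps
    return np.mean((eps - eps_theta(xt, t)) ** 2)
```

A network that predicts the injected noise exactly drives this quantity to zero, which is precisely what training pushes toward.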

Since I have to minimize the objective function (equation 49) while training (isn’t that what we do in machine learning to reach the optimal set of parameters?), we need to know how to sample all the terms in equation 49. The DDPM paper writes the objective function so that it clearly states what will be sampled. It looks like this:

It is the same as equation 49; it just states clearly what has to be sampled during training (and it is how the DDPM paper presents the objective).
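For reference, this simplified objective as written in the DDPM paper (Eq. 14 in Ho et al., 2020) is:

```latex
L_{\text{simple}}(\theta)
  = \mathbb{E}_{t,\,x_0,\,\epsilon}
    \left[ \left\| \epsilon
      - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,x_0
        + \sqrt{1-\bar\alpha_t}\,\epsilon,\; t\right) \right\|^2 \right],
\quad t \sim \mathrm{Uniform}\{1,\dots,T\},\;
      x_0 \sim q(x_0),\;
      \epsilon \sim \mathcal{N}(0, I)
```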

Wow!! That’s it? With just mean squared error, randomly sampled noise, and neural-network predictions? It’s a true testament to the power of straightforward design, isn’t it?

Important: How can I sample x(0)?

x(0) represents a sample from the “real” data distribution (a sample from the set of original images) that you’re trying to model. The distribution q(x(0)) is the empirical data distribution derived from our dataset. It is important to note that q(x(0)) is not a specific distribution like a Gaussian or a Uniform; instead, it is the distribution of the actual data you are working with.

To sample x0 from q(x(0)), you would simply draw a random sample from your dataset. In practice, this usually involves dividing your dataset into mini-batches and iterating through the dataset, sampling a mini-batch of data points (x0’s) at each step during the training process.

For example, if we were training a DDPM on images, q(x(0)) would represent the distribution of all the images in your dataset, and sampling x(0) would mean selecting a random image or a mini-batch of images from your dataset to use for a training iteration.

Thus, on the basis of equation 50 and the information above, the training algorithm in DDPM (Algorithm 1, taken as-is from the paper) is as follows:
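As a concrete (and deliberately tiny) illustration of Algorithm 1, the loop below replaces the U-Net with a single scalar weight per time step, eps_theta(x_t, t) = w[t−1]·x_t, and performs plain SGD on the simplified objective. The schedule, batch size, learning rate, and toy “network” are all assumptions for illustration only; a real DDPM uses a U-Net and an optimizer such as Adam.

```python
import numpy as np

# Toy instantiation of the DDPM training loop (Algorithm 1). The "network"
# is a single scalar weight per time step: eps_theta(x_t, t) = w[t-1] * x_t.
rng = np.random.default_rng(1)
T = 10
betas = np.linspace(1e-4, 0.2, T)          # variance schedule beta_t
alpha_bar = np.cumprod(1.0 - betas)        # cumulative product a_bar_t
w = np.zeros(T)                            # toy "parameters" theta
lr = 0.1

for step in range(2000):
    x0 = rng.standard_normal(8)            # x0 ~ q(x0): toy 8-pixel "image"
    t = int(rng.integers(1, T + 1))        # t ~ Uniform{1, ..., T}
    eps = rng.standard_normal(8)           # eps ~ N(0, I)
    a = alpha_bar[t - 1]
    xt = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps
    pred = w[t - 1] * xt                   # eps_theta(x_t, t)
    grad = np.mean(2.0 * (pred - eps) * xt)  # gradient of ||eps - pred||^2
    w[t - 1] -= lr * grad                  # one SGD step on theta
```

Even this one-parameter-per-step “network” learns to point toward the injected noise, which is all Algorithm 1 asks of the real model.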

Please note: in the above blog, I have skipped some minute details (like the second term in equation 46) and information about sampling (Algorithm 2 in the DDPM paper), since my aim was to get to and explain the objective function. Next, I will write a blog on how to use the above equations in code, adding insights about the details missed here.

(link to part 1)…

(link to part 2)…

If you like it or find it useful, please clap and share.
