Adversarial Diffusion Distillation: the return of GANs in generative modelling???

Ambrose Ling
Published in deMISTify, Jan 16, 2024
Image generated by DALL-E with the prompt: “A very powerful diffusion model”

Introduction

Generative AI is a familiar phrase that has infiltrated every part of our lives. Whether it is helping you write that essay due tomorrow with ChatGPT or using Adobe’s Generative Fill to make yourself look like you’re 6 foot 2, it is a piece of technology we cannot stop talking about. In particular, image generation with diffusion models, such as Midjourney, Stable Diffusion, and DALL-E, has been one of the fastest-moving areas of generative AI research in recent years. Deep learning scientists and researchers are constantly discovering new ways to improve both the generation quality and the generation speed of diffusion models, because who would want to wait 5 minutes just to generate something that’s not what you’re looking for?

In this article, I will introduce a recently developed image generation method called Adversarial Diffusion Distillation (ADD), which leverages the power of GANs as well as knowledge distillation techniques to reduce generation time while maintaining high sample quality. I will first give a brief overview of each of its components and then gradually put them all together. Let’s dive into it!

Adversarial

When we talk about “adversarial”, we are referring to a framework very similar to Generative Adversarial Networks (GANs). GANs are a type of generative model composed of two components: a generator and a discriminator.

Architecture of a GAN model, taken from https://developers.google.com/machine-learning/gan/gan_structure

You can imagine the generator as a counterfeiter and the discriminator as the police. The counterfeiter constantly tries to produce fake currency that is as realistic as possible, while the police try to detect the fakes. The generator takes random noise as input and maps it to a sample (for example, an image). In doing so, it learns to model the input data distribution so that its outputs look like samples drawn from it. The discriminator receives a sample and estimates the probability that it came from the real data distribution rather than from the generator. Both models are trained simultaneously on the same objective, but with opposing goals, as a two-player minimax game. From the original GAN paper, the training objective can be formulated as follows:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

This objective tells us that the discriminator D is trained to maximize the entire expression, i.e., to maximize the probability that D assigns the correct label to both real and generated samples. Conversely, the generator G is trained to minimize log(1 - D(G(z))), i.e., to minimize the probability that D correctly identifies a generated sample as fake.
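To make the two expectations concrete, here is a minimal sketch that estimates the GAN value function V(D, G) from a batch of discriminator outputs, assuming the discriminator has already produced probabilities D(x) for real samples and D(G(z)) for generated ones. The function names are hypothetical; in practice a framework such as PyTorch would compute these losses and their gradients automatically.

```python
import math


def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G).

    d_real: discriminator outputs D(x) on real samples, each in (0, 1)
    d_fake: discriminator outputs D(G(z)) on generated samples, each in (0, 1)
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def discriminator_loss(d_real, d_fake):
    # D maximizes V, which is the same as minimizing -V by gradient descent.
    return -gan_value(d_real, d_fake)


def generator_loss(d_fake):
    # G minimizes V; only the second expectation depends on G.
    return sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)


# A confident discriminator: high D(x) on real samples, low D(G(z)) on fakes.
print(gan_value([0.9, 0.8], [0.1, 0.2]))
# A maximally confused discriminator (D = 0.5 everywhere) gives V = -log 4.
print(gan_value([0.5], [0.5]), -math.log(4))
```

The value -log 4 at D = 1/2 is the global optimum derived in the original GAN paper, reached when the generator's distribution matches the data distribution.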

One strength of GANs is fast, efficient sampling: there is no iterative process, so you obtain a generated sample with a single pass of noise through the generator. However, GANs are difficult to train for several reasons:

  1. Mode collapse: the generator learns to produce only a narrow set of outputs that fool the discriminator instead of a diverse variety of images, which severely limits the diversity of its samples.
  2. Vanishing gradients: if the generator’s samples are poor early in training, the discriminator can reject them with high confidence. The term log(1 - D(G(z))) then saturates, and the generator receives very small gradient updates.
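The saturation can be seen with a little arithmetic. Assuming the discriminator ends in a sigmoid over a logit l, so that D(G(z)) = sigmoid(l), the sketch below compares the gradient of the original generator loss log(1 - D(G(z))) with that of the alternative loss -log D(G(z)), which the original GAN paper suggests for exactly this reason. The helper names are hypothetical.

```python
import math


def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))


def saturating_grad(logit):
    # d/dl log(1 - sigmoid(l)) = -sigmoid(l): tiny when D confidently rejects (l << 0)
    return -sigmoid(logit)


def non_saturating_grad(logit):
    # d/dl -log(sigmoid(l)) = -(1 - sigmoid(l)): close to -1 when l << 0
    return -(1.0 - sigmoid(logit))


for logit in (-6.0, 0.0, 6.0):
    print(f"D={sigmoid(logit):.4f}  "
          f"saturating={saturating_grad(logit):+.4f}  "
          f"non-saturating={non_saturating_grad(logit):+.4f}")
```

When the discriminator assigns a very negative logit to a fake, the saturating gradient is close to zero while the non-saturating one stays near -1. Both gradients have the same sign, so they push the generator in the same direction; only the magnitude differs.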

For these reasons, GANs have had difficulty achieving strong generation diversity, that is, capturing the full range of the data distribution.

Diffusion

Diffusion models are a different type of generative model, built around an iterative diffusion (forward) process and a denoising (reverse) process.

In the forward process, we iteratively destroy structure in the data by injecting Gaussian noise into the original signal x_0 over a number of steps, until it is transformed into (approximately) pure noise x_T ~ N(0,I). A noise schedule controls how much noise is added to the signal at each step.

In ADD, the forward diffusion process can be written in closed form using the reparametrisation trick:

$$x_s = \alpha_s x_0 + \sigma_s \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

Here x_0 is the original signal, epsilon is noise sampled from a standard Gaussian distribution, and x_s is the noisy signal after s steps of noising. The coefficients of x_0 and epsilon (α_s and σ_s) define the noise schedule: they dictate how much of the original signal to keep and how much pure noise to add. For more about how noise schedules work, you can check out these resources. At a high level, the schedule is just a way of representing the noise level present in the current sample.
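A minimal sketch of this closed-form noising step, with a flat list of floats standing in for an image. The cosine schedule is an assumption chosen for illustration (it satisfies alpha_s^2 + sigma_s^2 = 1); it is not necessarily the schedule ADD or Stable Diffusion uses, and the function names are hypothetical.

```python
import math
import random


def cosine_schedule(s, num_steps):
    """Hypothetical variance-preserving schedule: alpha_s**2 + sigma_s**2 == 1."""
    angle = 0.5 * math.pi * s / num_steps
    return math.cos(angle), math.sin(angle)


def forward_diffuse(x0, s, num_steps, rng=random):
    """Return (x_s, eps) with x_s = alpha_s * x0 + sigma_s * eps and eps ~ N(0, I)."""
    alpha_s, sigma_s = cosine_schedule(s, num_steps)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_s = [alpha_s * xi + sigma_s * ei for xi, ei in zip(x0, eps)]
    return x_s, eps


rng = random.Random(0)
x0 = [0.5, -0.2, 0.9]  # a toy "image" with three pixels
for s in (0, 5, 10):
    x_s, _ = forward_diffuse(x0, s, num_steps=10, rng=rng)
    print(s, [round(v, 3) for v in x_s])
```

At s = 0 the sample is the clean signal; at s = num_steps the signal coefficient is numerically zero, so x_s is essentially pure noise.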

In the reverse process, we train a neural network either to predict the noise epsilon contained in a noisy sample (so that it can be removed) OR to predict the original signal x_0 directly (which is what ADD’s diffusion model does). Starting from x_T, we repeatedly remove noise until we arrive at a denoised image. The reverse process is usually called the sampling process, and the number of times we repeat the noise removal is the number of sampling steps.
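Because x_s = alpha_s x_0 + sigma_s epsilon, the two prediction targets are interchangeable: given the schedule coefficients at the current step, a noise prediction can be turned into a clean-signal prediction and vice versa. A small sketch, with hypothetical function names and lists standing in for tensors:

```python
def x0_from_eps(x_s, eps_hat, alpha_s, sigma_s):
    # Solve x_s = alpha_s * x0 + sigma_s * eps for x0, using a predicted eps.
    return [(xi - sigma_s * ei) / alpha_s for xi, ei in zip(x_s, eps_hat)]


def eps_from_x0(x_s, x0_hat, alpha_s, sigma_s):
    # Solve the same identity for eps, using a predicted x0.
    return [(xi - alpha_s * x0i) / sigma_s for xi, x0i in zip(x_s, x0_hat)]


alpha_s, sigma_s = 0.6, 0.8  # example coefficients with alpha^2 + sigma^2 = 1
x0, eps = [1.0, -2.0], [0.5, 0.25]
x_s = [alpha_s * a + sigma_s * e for a, e in zip(x0, eps)]
print(x0_from_eps(x_s, eps, alpha_s, sigma_s))  # recovers x0 up to rounding
print(eps_from_x0(x_s, x0, alpha_s, sigma_s))   # recovers eps up to rounding
```

Dividing by alpha_s becomes numerically unstable near pure noise (alpha_s close to 0), and dividing by sigma_s near clean data, which is one practical reason to have a model predict one target rather than the other.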

The sampling process is an entire field of research in itself and can be completely independent of the forward process. Many samplers have been developed to reduce the number of sampling steps required while maintaining sample quality. We won’t cover sampling in depth in this article, since ADD’s architecture enables a very simple sampling process. However, I will attach some useful links you can explore if you are interested.

The most common architecture for this neural network is a U-Net, which is composed of a series of residual and attention blocks built from many convolutional layers that capture crucial spatial information about the image.

How are diffusion models trained?

Standard training procedure for diffusion models, taken from https://arxiv.org/pdf/2202.00512.pdf

Here is a standard way of training diffusion models:

  1. Sample a random image x from your dataset, a random timestep t, and a random noise vector epsilon
  2. Add noise to the image to reach noise level t using the forward process
  3. Predict the denoised version of the image and compute the loss (e.g., mean squared error) between the original image and the prediction
  4. Take a gradient descent step, and repeat from step 1 until convergence
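The four steps can be sketched as a toy training loop. Everything here is a simplifying assumption: the data are 1-D values drawn from N(0, 1), the schedule is a hypothetical cosine schedule, and the "network" is one scalar weight per timestep that predicts x0_hat = w[t] * x_t, trained with a squared-error loss and plain SGD. For this toy setup the best weight at timestep t works out to alpha_t (since E[x0 x_t] = alpha_t and E[x_t^2] = 1), which gives an easy sanity check. A real diffusion model replaces the weight table with a U-Net and the hand-written gradient with automatic differentiation.

```python
import math
import random


def alpha_sigma(t, num_steps):
    # Hypothetical variance-preserving cosine schedule: alpha**2 + sigma**2 == 1.
    angle = 0.5 * math.pi * t / num_steps
    return math.cos(angle), math.sin(angle)


def train_toy_denoiser(num_steps=4, iters=20000, seed=0):
    rng = random.Random(seed)
    w = [0.0] * (num_steps + 1)       # one weight per timestep (index 0 unused)
    counts = [0] * (num_steps + 1)
    for _ in range(iters):
        # 1. Sample a data point, a timestep, and a noise value.
        x0 = rng.gauss(0.0, 1.0)
        t = rng.randint(1, num_steps)
        eps = rng.gauss(0.0, 1.0)
        # 2. Noise the data point to level t with the forward process.
        alpha_t, sigma_t = alpha_sigma(t, num_steps)
        x_t = alpha_t * x0 + sigma_t * eps
        # 3. Predict the clean value and take the gradient of the squared error.
        x0_hat = w[t] * x_t
        grad = 2.0 * (x0_hat - x0) * x_t
        # 4. Gradient descent step with a decaying learning rate.
        counts[t] += 1
        w[t] -= 0.5 / (counts[t] + 10) * grad
    return w


weights = train_toy_denoiser()
for t in range(1, 5):
    print(t, round(weights[t], 3), "target alpha_t =", round(alpha_sigma(t, 4)[0], 3))
```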

Distillation

Distillation, or Knowledge Distillation, is a technique for “distilling” the knowledge of a larger, more powerful model (the teacher) into a smaller model (the student). Distillation is typically applied when we want to deploy smaller models on edge devices, or when we want to reduce a model’s memory demands while retaining most of the powerful model’s performance.

Above is a simple example of a basic knowledge distillation framework, where the difference between the teacher logits and the student logits is denoted as the distillation loss, and the student’s weights are updated based on that loss.
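A minimal sketch of such a distillation loss for a classifier. Rather than a raw difference of logits, it uses the temperature-softened formulation of Hinton et al., written here as a KL divergence, which is one common choice among several; the temperature value and the function names are assumptions for illustration.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(teacher_logits, student_logits, temperature=2.0):
    """KL(teacher || student) between temperature-softened distributions."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    # Scaling by T**2 keeps gradient magnitudes comparable across temperatures.
    return temperature ** 2 * kl


teacher = [2.0, 1.0, 0.1]
print(distillation_loss(teacher, [2.0, 1.0, 0.1]))  # student matches the teacher
print(distillation_loss(teacher, [0.1, 1.0, 2.0]))  # student disagrees
```

A higher temperature flattens both distributions, so the student also learns from the relative probabilities the teacher assigns to the wrong classes.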

In diffusion models, distillation can also be used to reduce the number of iterative sampling steps. This is achieved by distilling knowledge from a teacher model that needs many sampling steps into a student model that achieves similar generation quality with far fewer steps. Progressive Distillation and Consistency Model Distillation are two such applications of distillation to diffusion models.

Adversarial Diffusion Distillation

The ADD paper puts the three aforementioned concepts together to form ADD. ADD aims to solve a few problems:

  1. Diffusion models that need a large number of sampling steps are too slow to sample from
  2. Existing distillation methods still produce somewhat blurry results
  3. Sample quality deteriorates with fewer sampling steps

ADD’s method combines the superior sample quality of existing pretrained diffusion models with the speed of GANs. Let’s break down the training procedure:

ADD first samples a real image x_0 from the data distribution and applies the forward diffusion process to get a noisy image. The student and teacher models are both initialised from a pretrained diffusion model (specifically a Stable Diffusion UNet), but only the teacher’s parameters are frozen. The student predicts the original signal from the noisy image. The student’s generated samples are then passed to a discriminator, which has two parts: a frozen, pretrained feature network F that is a ViT (Vision Transformer), and a trainable discriminator head that is a classifier.

The authors also condition the discriminator on additional information, such as the text prompt or the original image, which encourages the ADD student to extract meaningful information from its input.

Next, we apply forward diffusion again to the student’s denoised prediction, and the result serves as the input to the teacher model. The teacher’s denoised prediction is then used as a reconstruction target (an example for the student to mimic).
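The data flow of one ADD training step, as described so far, can be sketched as follows. This is a simplified sketch under stated assumptions, not the paper’s implementation: `student` and `teacher` are hypothetical stand-ins for the UNets (plain functions over lists of floats), the schedule is a hypothetical cosine schedule, and the small set of student timesteps is an illustrative choice. Gradients and the discriminator are left out; the sketch only shows how the student input and the teacher’s reconstruction target are produced.

```python
import math
import random

NUM_STEPS = 1000
STUDENT_TIMESTEPS = [250, 500, 750, 1000]  # hypothetical small set of student noise levels


def alpha_sigma(t):
    # Hypothetical variance-preserving cosine schedule (alpha**2 + sigma**2 == 1).
    angle = 0.5 * math.pi * t / NUM_STEPS
    return math.cos(angle), math.sin(angle)


def diffuse(x0, t, rng):
    # Forward process: x_t = alpha_t * x0 + sigma_t * eps, eps ~ N(0, I).
    alpha_t, sigma_t = alpha_sigma(t)
    return [alpha_t * xi + sigma_t * rng.gauss(0.0, 1.0) for xi in x0]


def add_step_inputs(x0, student, teacher, rng):
    """Return the student's prediction, the teacher's target, and the teacher timestep."""
    # 1. Noise a real image to one of the student's noise levels s.
    s = rng.choice(STUDENT_TIMESTEPS)
    x_s = diffuse(x0, s, rng)
    # 2. The trainable student predicts the clean image from x_s in one step.
    x_hat = student(x_s, s)
    # 3. Re-noise the student's prediction to a random teacher timestep t,
    #    so the teacher sees the kind of noisy input it was trained on.
    t = rng.randint(1, NUM_STEPS)
    x_hat_t = diffuse(x_hat, t, rng)
    # 4. The frozen teacher denoises it; the result is the reconstruction target
    #    (in an autodiff framework this would be computed without gradients).
    target = teacher(x_hat_t, t)
    return x_hat, target, t


def toy_student(x, s):
    # Hypothetical stand-in for the student UNet.
    return [alpha_sigma(s)[0] * xi for xi in x]


def toy_teacher(x, t):
    # Hypothetical stand-in for the frozen teacher UNet.
    return [alpha_sigma(t)[0] * xi for xi in x]


x_hat, target, t = add_step_inputs([0.5, -0.2, 0.9], toy_student, toy_teacher, random.Random(0))
print("teacher timestep:", t)
print("student prediction:", [round(v, 3) for v in x_hat])
print("teacher target:", [round(v, 3) for v in target])
```

In the full method, x_hat is also what the discriminator scores for the adversarial term, while the teacher target feeds the distillation term.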

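The overall objective used to train the ADD student can be summarized as follows, where θ, φ and ψ denote the weights of the student, the discriminator and the teacher, x̂_θ(x_s, s) is the student’s prediction, and λ balances the two terms:

$$\mathcal{L} = \mathcal{L}_{\text{adv}}^{G}\big(\hat{x}_\theta(x_s, s), \phi\big) + \lambda \, \mathcal{L}_{\text{distill}}\big(\hat{x}_\theta(x_s, s), \psi\big)$$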

Adversarial Loss:

The adversarial loss term pushes the student model to generate images that are as realistic as possible, moving its samples closer to the real data distribution.
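For the adversarial term, the ADD paper uses a hinge loss on the discriminator’s raw scores (logits). The sketch below shows a generic single-head hinge formulation, assuming one scalar score per image; the paper’s discriminator uses several heads on ViT features, which this simplified version does not reproduce, and the function names are hypothetical.

```python
def discriminator_hinge_loss(real_scores, fake_scores):
    """Hinge loss for the discriminator: push real scores above +1, fake below -1."""
    real = sum(max(0.0, 1.0 - r) for r in real_scores) / len(real_scores)
    fake = sum(max(0.0, 1.0 + f) for f in fake_scores) / len(fake_scores)
    return real + fake


def generator_adversarial_loss(fake_scores):
    """Student (generator) loss: raise the discriminator's score on its samples."""
    return -sum(fake_scores) / len(fake_scores)


# Scores a discriminator might assign (hypothetical numbers).
real_scores, fake_scores = [1.5, 0.3], [-2.0, 0.4]
print(discriminator_hinge_loss(real_scores, fake_scores))
print(generator_adversarial_loss(fake_scores))
```

Once a real score exceeds +1 or a fake score drops below -1, the hinge contributes nothing, so the discriminator stops spending capacity on samples it already classifies with a margin.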

Distillation Loss:

The distillation loss term (see the figure) uses a distance metric d to measure the mismatch between the teacher’s prediction and the student’s generation. This term forces the student to mimic the compositionality present in the teacher’s predictions. Notice that we need to apply noise to the student’s output before passing it to the teacher. This is because the pretrained UNet was trained on diffused inputs (noise is always added to the image before it is given to the model), so feeding it the student’s clean, denoised output directly would fall outside the input distribution it was trained on.
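A sketch of the distillation term and of how it combines with the adversarial term, assuming d is the squared L2 distance and allowing an optional per-timestep weight c(t). The weighting function is a hypothetical parameter here (this sketch does not reproduce the paper’s choice), as is the value of λ in the usage line. The teacher’s target is treated as a constant: in an autodiff framework it would be detached so that only the student receives gradients.

```python
def squared_l2(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def add_distillation_loss(student_pred, teacher_target, t, weight_fn=lambda t: 1.0):
    """c(t) * d(student_pred, teacher_target) with d = squared L2 distance.

    teacher_target is treated as a constant (stop-gradient); weight_fn is a
    hypothetical stand-in for the timestep weighting c(t).
    """
    return weight_fn(t) * squared_l2(student_pred, teacher_target)


def total_student_loss(adv_loss, distill_loss, lam):
    # Overall objective: adversarial term + lambda * distillation term.
    return adv_loss + lam * distill_loss


distill = add_distillation_loss([0.5, -0.2], [0.4, 0.0], t=500)
print(distill)
print(total_student_loss(adv_loss=0.8, distill_loss=distill, lam=2.5))
```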

Notable Results

The authors report that their method outperforms a wide variety of state-of-the-art models in both inference speed and FID score (Fréchet Inception Distance, a measure of generation quality where lower is better), while needing only a single sampling step.
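FID fits a Gaussian to feature embeddings (from an Inception network) of real and of generated images and computes ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). The sketch below is a deliberately simplified version that assumes diagonal covariances, so the matrix square root reduces to per-dimension standard deviations, and uses tiny hand-made feature vectors. Real FID implementations use full covariance matrices of Inception features, so this only illustrates why lower is better.

```python
import math


def mean_and_var(features):
    """Per-dimension mean and (population) variance of a list of feature vectors."""
    n, dims = len(features), len(features[0])
    means = [sum(f[d] for f in features) / n for d in range(dims)]
    variances = [sum((f[d] - means[d]) ** 2 for f in features) / n for d in range(dims)]
    return means, variances


def fid_diagonal(real_features, fake_features):
    """Frechet distance between Gaussians fitted with diagonal covariances."""
    mu_r, var_r = mean_and_var(real_features)
    mu_g, var_g = mean_and_var(fake_features)
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))
    # With diagonal covariances, Tr(S_r + S_g - 2 (S_r S_g)^(1/2)) = sum (sd_r - sd_g)^2.
    cov_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var_r, var_g))
    return mean_term + cov_term


real = [[0.0, 1.0], [2.0, 3.0], [1.0, 2.0]]
close = [[0.1, 1.1], [2.1, 3.1], [1.1, 2.1]]   # small shift of the real features
far = [[5.0, -4.0], [6.0, -3.0], [5.5, -3.5]]  # very different features
print(fid_diagonal(real, close), fid_diagonal(real, far))
```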

The authors also conducted a variety of ablation experiments that informed a few important design decisions:

  1. Conditioning the discriminator on both the text and the image was shown to improve generation quality.
  2. The adversarial loss was crucial for producing high-fidelity samples, but both the adversarial and the distillation losses are needed to achieve the best performance.
  3. Initialising the student from a pretrained Stable Diffusion UNet was crucial to its performance.

The authors also performed user preference studies to evaluate generation quality, in which their method achieved a higher Elo score than the other state-of-the-art methods it was compared against.

Impact

ADD was truly a jaw-dropper: it was surprising to see GANs return to generative modelling in a way that complements existing diffusion models. It demonstrates a strong ability to maintain generation quality while enabling fast sampling, which puts it ahead of the state of the art. Diffusion models have long faced deployment barriers because of their large model size and their need for iterative sampling. The training techniques proposed by ADD could make diffusion models much easier to deploy, potentially even on edge devices, and at a greater scale than before. Users can embrace creative endeavours in image generation with great ease, bringing their artistic ideas to life in a matter of seconds. These possibilities are what make small breakthroughs like this so exciting to see.

References

  1. https://arxiv.org/abs/2202.00512
  2. https://arxiv.org/pdf/2311.17042.pdf
  3. https://arxiv.org/abs/1406.2661
  4. https://www.youtube.com/watch?v=tT9Lnt6stwA
