Latent Diffusion and Perceptual Latent Loss
This article presents a novel approach to training an image-generation model in much less time. The model works on VAE latents rather than pixels, and a pre-trained ImageNet latent classifier is used as part of the loss function.
The above image shows celebrity faces generated by a model trained on the CelebA-HQ dataset for 250 epochs. Remarkably, training the model from scratch (randomly initialised, not pre-trained) took less than 10 hours on a single NVIDIA RTX 3080 (10GB) card.
The image below shows bedroom images generated by a model trained on a 20% subset of the LSUN bedrooms dataset for 25 epochs. Again, training from scratch took less than 12 hours on a single NVIDIA RTX 3080 (10GB) card.
How large an improvement these images represent can be seen by comparing them with the output of a 2019 Wasserstein GAN, which was trained for more epochs.
This article gives an overview of the key concepts behind these results, which were achieved via a perceptual latent loss:
- Latent classification model
- Iterative refinement
- Training a Latent diffusion model
- A Perceptual image latent loss function
First, all of ImageNet (over 14 million images) is encoded into latent representations using a VAE.
A classification model is then trained on the latents of the entire ImageNet training set.
The classification head is then removed, and the activations of the remaining model are used as part of the loss function for training a diffusion U-Net for latent diffusion.
To generate images, the diffusion U-Net produces image latents via DDIM sampling, and these latents are then decoded into images by the VAE.
Perceptual Latent Loss
Perceptual loss functions (sometimes called feature loss) have been popular for several years as a way to improve generative models beyond what common per-pixel metrics such as MSE (L2) and MAE (L1) loss can achieve.
Near the end of the final lesson of the latest fast.ai course, Jeremy Howard proposed training a classifier on the full ImageNet dataset that could then be used to aid image generation, with one difference: the classifier operates on the latents produced by passing the ImageNet images through a pre-trained Variational Autoencoder (VAE).
What follows may be the first implementation of a pre-trained latent classifier used as part of the loss function for training a neural network for high-quality image generation.
The resulting model generates new samples by starting from random noise and iteratively applying a U-Net to remove that noise, reversing the diffusion process.
ImageNet Latent Classification Model
The latents are produced with the sd-vae-ft-ema Variational Autoencoder, released by Stability AI, which was trained with MSE, perceptual and GAN losses on the OpenImages dataset.
The Stable Diffusion Variational Autoencoder (SD-VAE) is a generative model that can be used for image generation and reconstruction.
The SD-VAE was originally trained on OpenImages but was fine-tuned on the Stable Diffusion training set. The fine-tuning process enriched the dataset with images of humans to improve the reconstruction of faces. The first fine-tuning process, ft-EMA, was resumed from the original checkpoint, trained for 313198 steps and uses EMA weights.
All 14 million ImageNet images are encoded into latents with the sd-vae-ft-ema VAE. Using a pre-trained VAE makes it possible to benefit from the compute that others have already invested in training it.
The VAE reduces the spatial resolution by a factor of 8, mapping 3-channel 256 x 256 pixel images to 4-channel 32 x 32 latents. These latent tensors are 48 times smaller than the original images. The examples below show ImageNet images encoded as latents, displaying the first 3 of the 4 channels. The basic geometry is still visible, although considerably changed.
The benefit is that the latents need 48 times less memory, and a convolution across the latent representation of an image needs much less compute.
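The 48 times figure follows directly from the tensor shapes. A quick check, assuming float32 storage for the byte counts:

```python
import math

image_shape = (3, 256, 256)              # RGB image: channels, height, width
latent_shape = (4, 256 // 8, 256 // 8)   # SD-VAE latent: 4 channels, spatial size reduced by 8

image_values = math.prod(image_shape)    # number of values per image
latent_values = math.prod(latent_shape)  # number of values per latent
ratio = image_values // latent_values

bytes_per_float32 = 4                    # assumption: tensors stored as float32
print(latent_shape, image_values, latent_values, ratio)
# prints: (4, 32, 32) 196608 4096 48
print(image_values * bytes_per_float32, latent_values * bytes_per_float32)
# prints: 786432 16384
```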
Encoding into latents is beneficial for generation because the VAE compresses each image into a 4,096-value latent (4 x 32 x 32) and reconstructs the image from it effectively. These are examples of the latents decoded back into 3-channel 256 x 256 images. Considering the degree of compression, these decoded images are impressive.
Classification model training
A model is created consisting of ResBlocks with dropout to be trained to classify the latent representations of the ImageNet images.
The model reached 65.7% ImageNet classification accuracy after 30 epochs, trained with data augmentation: the latents were padded and then randomly cropped back to 32 x 32 latent pixels, and Random Erase (from the augmentation notebook of fast.ai’s 2022 lesson 18) was applied.
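The augmentations described above can be sketched on a single (4, 32, 32) latent as follows. This is a NumPy illustration modelled on, not copied from, the fast.ai notebooks: the padding amount (4), zero padding, the maximum erase size (25% of each side) and the mean/std-matched noise fill are all assumptions.

```python
import numpy as np

def pad_and_random_crop(latent, pad, rng):
    """Zero-pad a (C, H, W) latent by `pad` pixels per side, then randomly crop back to (C, H, W)."""
    c, h, w = latent.shape
    padded = np.pad(latent, ((0, 0), (pad, pad), (pad, pad)), mode="constant")
    top = rng.integers(0, 2 * pad + 1)   # integers() excludes the upper bound
    left = rng.integers(0, 2 * pad + 1)
    return padded[:, top:top + h, left:left + w]

def random_erase(latent, max_frac, rng):
    """Overwrite one random rectangle with Gaussian noise matching the latent's mean and std."""
    out = latent.copy()
    c, h, w = out.shape
    eh = rng.integers(1, max(2, int(h * max_frac)) + 1)  # erase height in latent pixels
    ew = rng.integers(1, max(2, int(w * max_frac)) + 1)  # erase width in latent pixels
    top = rng.integers(0, h - eh + 1)
    left = rng.integers(0, w - ew + 1)
    out[:, top:top + eh, left:left + ew] = rng.normal(out.mean(), out.std(), size=(c, eh, ew))
    return out

rng = np.random.default_rng(0)
latent = rng.normal(size=(4, 32, 32)).astype(np.float32)  # stand-in for one encoded ImageNet image
aug = random_erase(pad_and_random_crop(latent, pad=4, rng=rng), max_frac=0.25, rng=rng)
print(aug.shape)  # (4, 32, 32)
```

Filling the erased region with noise at the latent's own statistics, rather than zeros, keeps the erased patch in a similar value range to the rest of the latent.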
Remarkably, training this latent classifier took less than 20 hours on a single NVIDIA RTX 3080 (10GB) card.
Iterative refinement — DDIM and DDPM
Denoising Diffusion Implicit Models (DDIM) sampling is a fast method of iterative refinement, faster than Denoising Diffusion Probabilistic Models (DDPM) sampling. To understand DDIM, it helps to understand DDPM first.
Denoising Diffusion Probabilistic Models (DDPM)
Imagine you have a clear, detailed photograph. Now picture gradually adding random noise to it until it becomes a completely unrecognizable mess of static. That’s the first part of DDPM: the ‘diffusion’ process, where we move from a clear image to noise.
Now, let’s reverse this process. This is where ‘denoising’ comes in. The goal of DDPM is to learn how to reverse this process, starting from the noise and slowly removing it to recreate the original image. It’s like having a puzzle where you only have a box of shapeless, colored pieces, and you’re learning to put it back together to form a beautiful picture.
The ‘probabilistic’ part means that the model deals with probabilities. Instead of being absolutely sure about each step, it makes educated guesses. Over time and with training, these guesses become more and more accurate.
In simple terms, DDPM takes a journey from clarity to chaos (adding noise) and then learns the path back from chaos to clarity (removing noise), doing this in a way that’s not certain but based on probabilities. It’s like an artist who first obscures a painting and then learns to restore it to its original beauty, but with a bit of guessing involved.
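The ‘adding noise’ half of DDPM has a convenient closed form: a noisy version x_t of a clean latent x_0 at step t can be sampled directly, without looping through all the earlier steps. The sketch below assumes the linear β schedule from the original DDPM paper (T = 1000, β from 1e-4 to 0.02); the schedule used for the models in this article is not stated, so treat these values as placeholders.

```python
import numpy as np

# Assumed noise schedule: linear betas as in the original DDPM paper
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)  # fraction of the signal variance that survives to step t

def add_noise(x0, t, rng):
    """Sample x_t ~ q(x_t | x_0) = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    eps = rng.standard_normal(x0.shape)
    xt = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps
    return xt, eps  # eps is what a noise-predicting network is trained to recover

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 32, 32))  # stand-in for a clean latent
x_early, _ = add_noise(x0, 10, rng)    # still very close to x0
x_late, _ = add_noise(x0, 999, rng)    # almost pure noise
```

The denoising half of DDPM learns to undo this one small step at a time, which is why DDPM sampling visits every one of the T steps.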
Denoising Diffusion Implicit Models (DDIM)
DDIM (Denoising Diffusion Implicit Models) is a variation of the original DDPM (Denoising Diffusion Probabilistic Models) concept. Both are used for generating images, but they differ in their approach and efficiency.
DDPM works through a series of steps where an image is gradually transformed from noise into a clear picture. It does this by learning how to reverse a process where a clear image is gradually converted into noise.
DDPM removes noise in small increments over a large number of steps. The process is generally slow because it requires many steps to denoise the image. DDIM is a modification of DDPM that aims to speed up the image generation process.
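DDIM changes how each denoising step is calculated. It uses the same kind of trained noise-prediction model as DDPM, but replaces DDPM’s random (Markovian) sampling process with a non-Markovian one that can be made fully deterministic. This lets the sampler jump over many timesteps at once, so far fewer denoising steps are needed.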
The key advantage of DDIM is its speed. It can produce images much faster than DDPM because it requires fewer steps to achieve similar results.
While both DDPM and DDIM share the same fundamental concept of transforming noise into a coherent image, DDIM does it more efficiently, requiring fewer steps and thus reducing the time it takes to generate an image. This efficiency makes DDIM particularly valuable in practical applications where speed is crucial.
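For reference, the deterministic DDIM update (the η = 0 case in the DDIM paper by Song, Meng and Ermon) can be written as follows. It assumes a network ε_θ trained to predict the added noise, the cumulative DDPM noise schedule ᾱ_t, and t′ as the next smaller timestep in a possibly strided schedule (for example, 50 sampling steps out of 1,000 training steps). The model first estimates the clean latent, then re-noises that estimate to the lower noise level of t′:

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}
\qquad
x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t'}}\,\epsilon_\theta(x_t, t),
\quad t' < t
```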
You can learn more about DDPM and DDIM in these Fast.ai lessons:
Training a Latent diffusion model
Training a model to remove noise from latents is similar to how Stable Diffusion works as a latent diffusion model, but without the text-conditioning (natural language model) element.
Both the model demonstrated here and Stable Diffusion are generative models that can create realistic images. They start from an image made up of random noise and gradually remove noise from it over several iterations. The noise is removed by a neural network trained on real images that have had noise added to them. The result is a new image that is similar to the real images in the dataset.
Space of all possible images
Try to imagine the roughly 200,000-dimensional space of all possible 256 x 256 RGB images (3 x 256 x 256 = 196,608 values). A random point in this space would be just noise. If two similar images were mapped into this space, they would be close to one another. According to the manifold hypothesis, plausible images lie on a lower-dimensional manifold within this higher-dimensional space: effectively, the plausible images cluster onto a surface, the manifold, within the space.
Noisy images and iteratively removing the noise
If an image is mapped into the latent space and random noise is added to it in steps, gradually corrupting it, it moves away from the manifold of real, plausible images. This noising process can be described by a stochastic differential equation (SDE). The model needs to achieve the reverse: move a noisy image back towards the manifold of real, plausible images. Deterministic samplers such as DDIM treat this reverse process as an ordinary differential equation (ODE). Such an ODE cannot be solved accurately in a single large step, so several noise-removal steps are applied iteratively.
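One standard way to make the SDE and ODE statements precise is the score-based formulation of diffusion models (Song et al., 2021). This is general background rather than something specific to this article’s implementation. The forward noising SDE, and the deterministic “probability flow” ODE that has the same marginal distributions and is integrated backwards from noise to data, are:

```latex
% Forward (noising) SDE: f is the drift, g the diffusion coefficient, w a Wiener process
dx = f(x, t)\,dt + g(t)\,dw

% Probability flow ODE, integrated from t = T (noise) back to t = 0 (data)
dx = \left[ f(x, t) - \tfrac{1}{2}\, g(t)^2\, \nabla_x \log p_t(x) \right] dt

% A noise-prediction network supplies the unknown score term:
\nabla_x \log p_t(x) \approx -\frac{\epsilon_\theta(x, t)}{\sqrt{1 - \bar\alpha_t}}
```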
In essence, the model takes noisy latents as input and outputs slightly less noisy latents. Applying the model repeatedly to less and less noisy inputs eventually produces a plausible image.
The iterative refinement uses DDIM (Denoising Diffusion Implicit Models) sampling described earlier.
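The iterative refinement loop can be sketched in NumPy as below. This is an illustration under assumptions, not the article’s code: `predict_noise` is a hypothetical stand-in for the trained U-Net (assumed here to predict the added noise), the schedule is the linear DDPM one, and `toy_predict_noise` is a hypothetical “perfect” model that always points towards one fixed latent, used only so the loop’s behaviour can be checked.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # assumed linear DDPM schedule
alpha_bars = np.cumprod(1.0 - betas)

def ddim_sample(predict_noise, shape, n_steps, rng):
    """Deterministic DDIM (eta = 0) sampling over an evenly strided subset of the T timesteps."""
    timesteps = np.linspace(T - 1, 0, n_steps).round().astype(int)  # e.g. 999, ..., 0
    x = rng.standard_normal(shape)                                   # start from pure noise
    for i, t in enumerate(timesteps):
        eps = predict_noise(x, t)
        # Estimate the clean latent from the current noisy latent and predicted noise
        x0_hat = (x - np.sqrt(1 - alpha_bars[t]) * eps) / np.sqrt(alpha_bars[t])
        if i == len(timesteps) - 1:
            return x0_hat                                            # final clean estimate
        t_prev = timesteps[i + 1]
        # Re-noise the estimate to the lower noise level of the next timestep
        x = np.sqrt(alpha_bars[t_prev]) * x0_hat + np.sqrt(1 - alpha_bars[t_prev]) * eps
    return x

rng = np.random.default_rng(0)
x_star = rng.standard_normal((4, 32, 32))  # hypothetical latent that the toy model denoises towards

def toy_predict_noise(x, t):
    """Hypothetical stand-in for the U-Net: returns the exact noise separating x from x_star."""
    return (x - np.sqrt(alpha_bars[t]) * x_star) / np.sqrt(1 - alpha_bars[t])

sample = ddim_sample(toy_predict_noise, shape=(4, 32, 32), n_steps=50, rng=rng)
print(np.allclose(sample, x_star))  # True
```

Only 50 of the 1,000 timesteps are visited, which is where DDIM’s speed advantage over DDPM comes from. In the full pipeline, the final latent would then be passed to the VAE decoder.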
The model is trying to generate latent pixels, each of which carries far more information than a conventional pixel. Neighbouring latent pixels are not likely to be similar in the way neighbouring conventional pixels are, because that redundancy has been compressed out. This makes predicting latent pixels a much more difficult task.
The model’s predicted latents then need to be decoded using the VAE into a conventional pixel image.
Using image latents within a perceptual latent loss function
Traditionally, perceptual loss is computed on image activations rather than latent activations. Both the network’s output and the target are passed through a pre-trained classifier (trained on ImageNet or another dataset), and the loss function penalises differences between their activations.
The classification head of the model is chopped off, in this case by removing the last 4 layers of the model shown above, back to the Adaptive Average Pool 2D layer. The activations from this layer are used to evaluate the perceptual latent loss.
Training the U-Net diffusion model
A U-Net diffusion architecture is used for the latent noise removing model.
When the U-Net diffusion model is trained to remove noise, the loss can now combine two terms: the MSE (L2) loss between the denoised latents and the target latents, and a latent activation loss that compares the activations of both at the Adaptive Average Pool 2D layer of the ImageNet latent classification model.
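A minimal NumPy sketch of the combined loss, under stated assumptions: the small random-weight `layers` list is a hypothetical stand-in for the trained ResBlock classifier (only its structure matters here), the 4-layer head (flatten, dropout, linear, log-softmax) is assumed, and the weighting `feat_weight=0.1` between the two terms is a placeholder because the article does not state one.

```python
import numpy as np

rng = np.random.default_rng(0)
W_conv = rng.standard_normal((64, 4)) * 0.5   # hypothetical 1x1 conv: 4 latent channels -> 64 features
W_fc = rng.standard_normal((1000, 64)) * 0.1  # hypothetical head weights for 1000 ImageNet classes

def log_softmax(x):
    """Numerically stable log-softmax over the class axis."""
    shifted = x - x.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

layers = [
    lambda x: np.einsum("oc,bchw->bohw", W_conv, x),  # body: stand-in for the ResBlocks
    lambda x: np.maximum(x, 0.0),                     # body: ReLU
    lambda x: x.mean(axis=(2, 3), keepdims=True),     # equivalent of AdaptiveAvgPool2d(1)
    lambda x: x.reshape(x.shape[0], -1),              # head: flatten
    lambda x: x,                                      # head: dropout (identity at eval time)
    lambda x: x @ W_fc.T,                             # head: linear classifier
    log_softmax,                                      # head: class log-probabilities
]
feature_layers = layers[:-4]  # chop off the 4-layer head, keeping everything up to the pooling layer

def run(stack, x):
    """Apply a list of layers in order."""
    for layer in stack:
        x = layer(x)
    return x

def perceptual_latent_loss(pred, target, feat_weight=0.1):
    """MSE on the latents plus MSE between the pooled classifier activations of pred and target."""
    mse = np.mean((pred - target) ** 2)
    feat_mse = np.mean((run(feature_layers, pred) - run(feature_layers, target)) ** 2)
    return mse + feat_weight * feat_mse

x = rng.standard_normal((2, 4, 32, 32))  # a batch of two stand-in latents
print(run(feature_layers, x).shape)      # (2, 64, 1, 1)
```

In PyTorch, the equivalent truncation would typically slice an `nn.Sequential` (for example `model[:-4]`), keep the classifier frozen, and compute the target’s activations without gradients.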
Comparison of Perceptual Latent Loss versus MSE only
Outputs from a generative model trained with the perceptual latent loss function after 250 epochs of training:
Outputs from a generative model trained with MSE only based loss function after 250 epochs of training:
The difference in quality is striking, although the MSE-only model’s output is itself remarkable given that its 250 epochs of training took less than 10 hours on a single NVIDIA RTX 3080 (10GB) card.
Fast.ai’s notebook training a latent classifier on ImageNet:
ImageNet Latent Classifier
Latent Diffusion generative AI trained on the CelebA-HQ dataset:
Latent Diffusion generative AI trained on the LSun dataset:
Fast.ai Practical Deep Learning for Coders Part 2 course:
Fast.ai part 2
Diffusion U-Net course lesson:
Fast.ai part 2 lesson 24— Diffusion U-Net
Latent Diffusion course lesson:
Fast.ai part 2 lesson 25 — Latent Diffusion