Diffusion Models: The Catalyst for Breakthroughs in AI and Research

Nandini Lokesh Reddy
Aug 11, 2024


Midjourney, Stable Diffusion, DALL-E, and similar models can create stunning images from a simple prompt. But have you ever wondered how these models do this, or what makes them capable of such feats? You might have come across the concept of noise in the context of image classification or detection. This post explains how that same noise is harnessed to produce the desired images through a process known as diffusion.

Before we delve into diffusion models, it's important to understand generative models. As the name suggests, these models are designed to generate new data. They learn the patterns (the underlying distribution) of their training data and then produce a diverse range of new samples that resemble it.
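As a deliberately tiny illustration of "learn the distribution, then sample from it" (this is not a diffusion model), the sketch below fits a one-dimensional Gaussian to some made-up data by estimating its mean and standard deviation, then draws brand-new values from it. The data here are invented for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical training data: 1,000 values scattered around 5.0
data = rng.normal(loc=5.0, scale=2.0, size=1_000)

# "Training": estimate the parameters of a Gaussian from the data
mu, sigma = data.mean(), data.std()

# "Generation": draw new samples from the learned distribution
new_samples = rng.normal(loc=mu, scale=sigma, size=1_000)

print(f"data mean {mu:.2f}, generated mean {new_samples.mean():.2f}")
```

The generated values resemble the data statistically without copying any of them. A diffusion model plays the same role for images, where the distribution is far too complex to describe with two numbers.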

But how do these diffusion models actually assist in generating data? What exactly is this “noise”? What is diffusion, and what algorithm drives it?

Imagine we have a set of game characters in a specific format, and we need more characters in the same style. To create these new characters using the existing data, we rely on neural networks and the diffusion process.

However, before feeding data into the neural network, we need to present it in a way that lets the network learn:

a. Fine details: what makes up a game character, such as its colors, hair color, and the shape and size of its features.

b. General outlines: like the shape of the head, body, and other major features.

c. Everything in between.

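One way to make the network focus on either the finer details or the general outlines is to add different amounts of noise to the data: a little noise hides only the fine details, while a lot of noise leaves only the rough outline. Gradually adding noise like this is known as the "noising" (or forward diffusion) process, and it is a key step in learning to generate the desired images.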

This process is inspired by a concept from physics: imagine dropping ink into a glass of water. At first, you can clearly see where the ink drop hits, but as time passes, the ink disperses throughout the water until it eventually fades away.

Picture Courtesy: Byjus
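A minimal NumPy sketch of the noising process, assuming the linear noise schedule from the DDPM paper (Ho et al., 2020: per-step variance from 1e-4 to 0.02) and T = 500 steps; the names `betas`, `alpha_bar` and `add_noise` are illustrative, not taken from the author's code. The standard closed form lets us jump straight to any noise level t without adding noise step by step.

```python
import numpy as np

T = 500                                        # number of noise levels (assumed value)
betas = np.linspace(1e-4, 0.02, T)             # per-step noise variance, linear schedule
betas = np.concatenate([[0.0], betas])         # index 0 = clean image, 1..T = noisier and noisier
alpha_bar = np.cumprod(1.0 - betas)            # fraction of the original signal variance kept at level t

def add_noise(x0, t, eps):
    """Jump straight to noise level t: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 16, 16))  # stand-in for a 16x16 RGB sprite scaled to [-1, 1]
eps = rng.standard_normal(x0.shape)            # Gaussian noise, one value per pixel and channel

noisy = {t: add_noise(x0, t, eps) for t in (0, 50, 200, 500)}
for t in noisy:
    print(f"t={t:3d}  signal weight {np.sqrt(alpha_bar[t]):.3f}  noise weight {np.sqrt(1 - alpha_bar[t]):.3f}")
```

Like the ink drop, the sprite's share of each pixel shrinks steadily with t until, at t = T, almost nothing but noise is left.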

So, what happens when we add noise? How do these neural networks respond?
If this image is the input, the neural network should confidently identify it and say → “That’s Bob the Sprite!”

Bob the Sprite

When there’s a bit of noise or if the image isn’t clear enough, the neural network might say → “There’s some noise here,” and then suggest possible details to make it look more like Bob the Sprite.

Probably Bob

And as the noise increases, if only the outline of the Sprite remains, you might only get a vague sense that it’s a Sprite character — it could be Bob, Fred, or someone else entirely. At this point, the neural network would suggest more general details to identify the most likely Sprite.

Maybe Bob, or Fred

Finally, if the image looks like nothing recognizable, but you still want it to resemble a Sprite, you’d want the neural network to take that image and gradually transform it by proposing an outline of what a Sprite might look like.

No Idea

The neural network is trained to take various noisy images and transform them back into sprites. It learns how to remove the added noise and refine the shape to resemble a sprite.

The noise added to an image typically follows a normal distribution: the noise value for each pixel is drawn independently from a normal distribution, also known as a Gaussian distribution or bell-shaped curve.

So, when a neural network is tasked with creating a new sprite, it can sample noise from this normal distribution and gradually refine it to generate an entirely new sprite by progressively removing the noise.
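The starting point of generation can be sketched in a few lines of NumPy. The batch size (64) and sprite size (16x16 RGB) below are assumptions; the reverse iteration order mirrors the `range(timesteps, 0, -1)` loop used in the sampling code later in this post, with T = 500 assumed.

```python
import numpy as np

rng = np.random.default_rng(42)

n_sample, channels, height = 64, 3, 16    # assumed batch size and sprite size
# x_T: every pixel of every channel drawn independently from N(0, 1)
x_T = rng.standard_normal((n_sample, channels, height, height))
print(f"mean {x_T.mean():.3f}, std {x_T.std():.3f}")   # close to 0 and 1

T = 500
reverse_order = list(range(T, 0, -1))     # T, T-1, ..., 1: the order in which noise is removed
print(reverse_order[:3], "...", reverse_order[-3:])
```

Each different draw of `x_T` is a different starting point, which is why the same trained network can produce many different sprites.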

Now let's dive deeper into the steps:

1. Sampling:

At each step, the neural network attempts to accurately predict the noise.

Rather than predicting the sprite directly, the neural network focuses on predicting the noise. By subtracting this predicted noise from the sample, the result becomes progressively closer to a sprite. However, a single pass through the network is not enough to remove all of the noise at once, so sampling repeats this over many steps to achieve high-quality results.
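Why is predicting the noise enough? Under the closed-form noising formula x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, knowing the noise is equivalent to knowing the clean image. The NumPy sketch below (same assumed linear schedule, T = 500; the helper `estimate_x0` is illustrative) shows that a perfect noise prediction recovers the sprite exactly, while an imperfect one is amplified, which is one reason to remove noise in many small steps instead of one jump.

```python
import numpy as np

T = 500
betas = np.concatenate([[0.0], np.linspace(1e-4, 0.02, T)])
alpha_bar = np.cumprod(1.0 - betas)

rng = np.random.default_rng(1)
x0 = rng.uniform(-1.0, 1.0, size=(3, 16, 16))   # stand-in "sprite"
eps = rng.standard_normal(x0.shape)             # the noise that was actually added
t = 300
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

def estimate_x0(x_t, t, eps_pred):
    """Invert the noising formula: subtract the scaled predicted noise, then rescale."""
    return (x_t - np.sqrt(1.0 - alpha_bar[t]) * eps_pred) / np.sqrt(alpha_bar[t])

# A perfect noise prediction recovers the sprite exactly
print(np.allclose(estimate_x0(x_t, t, eps), x0))          # True

# A slightly wrong prediction is amplified by sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t) (> 1 here)
eps_rough = eps + 0.1 * rng.standard_normal(x0.shape)
err = np.abs(estimate_x0(x_t, t, eps_rough) - x0).mean()
print(f"mean abs error with a rough prediction: {err:.2f}")
```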

Now let's move on to the implementation:

Import necessary packages:

from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML

Define sampling: We start from pure noise and step backward through time, beginning at the final timestep, where the image is completely noisy, and moving toward the first step, where the image has been refined to resemble the desired output.

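@torch.no_grad()  # no gradients are needed while sampling; avoids building a graph across hundreds of steps
def sample_ddpm(n_sample, save_rate=20):
    # height, timesteps, device and nn_model are globals defined elsewhere (not shown in this excerpt)
    # x_T ~ N(0, I): start from pure Gaussian noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    intermediate = []  # snapshots of the samples, used to visualize the denoising process
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0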

At this stage, we feed the current noisy samples and the timestep into the neural network to obtain a prediction of the noise they contain. Subtracting this predicted noise moves the samples closer to the desired sprite, and at every step except the last, a small amount of fresh noise z is added back.

The exact update comes from DDPM (Denoising Diffusion Probabilistic Models), which specifies how much of the predicted noise to subtract and how much fresh noise to add at each step, refining the image step by step. In the code below, this is handled by the helper denoise_add_noise, which is not shown in this excerpt.

        eps = nn_model(samples, t)    # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, i, eps, z)  # DDPM update (helper not shown in this excerpt)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
Sampling Example
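The sampling loop calls `denoise_add_noise(samples, i, eps, z)`, which this excerpt does not define. Below is a NumPy sketch of the standard DDPM reverse step from Ho et al. (2020), assuming their linear schedule (beta from 1e-4 to 0.02) with T = 500. The schedule arrays `b_t`, `a_t` and `ab_t` are assumptions; the author's full code may compute or index its schedule differently.

```python
import numpy as np

T = 500
b_t = np.concatenate([[0.0], np.linspace(1e-4, 0.02, T)])  # beta_t (index 0 unused)
a_t = 1.0 - b_t                                            # alpha_t
ab_t = np.cumprod(a_t)                                     # alpha_bar_t

def denoise_add_noise(x, t, pred_noise, z=None, rng=None):
    """One DDPM reverse step x_t -> x_{t-1} (sketch).

    mean    = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
    x_{t-1} = mean + sqrt(beta_t) * z,  with z ~ N(0, I), and z = 0 on the final step
    """
    if z is None:
        if rng is None:
            rng = np.random.default_rng()
        z = rng.standard_normal(x.shape)
    mean = (x - pred_noise * (1.0 - a_t[t]) / np.sqrt(1.0 - ab_t[t])) / np.sqrt(a_t[t])
    return mean + np.sqrt(b_t[t]) * z

# Sanity check: at t = 1, with the true noise and z = 0, the step returns the clean image
rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 16, 16))
eps = rng.standard_normal(x0.shape)
x1 = np.sqrt(ab_t[1]) * x0 + np.sqrt(1.0 - ab_t[1]) * eps
print(np.allclose(denoise_add_noise(x1, 1, eps, z=0), x0))  # True
```

Using beta_t as the variance of the re-injected noise is one of the two variance choices discussed in the DDPM paper.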

If we only subtract the predicted noise at each step, without reintroducing any fresh noise (z = 0 at every step), the results look like this:

2. Neural Network:
The diffusion model uses a U-Net architecture, introduced in 2015. It takes an input image of a given shape and produces an output of the same shape; here, that output is the predicted noise.

Interestingly, U-Net was originally designed for biomedical image segmentation.

U-Net architecture.

The U-Net first embeds the input with convolutional layers, then down-samples it through several convolutional blocks, compressing the information into a smaller spatial representation. It then up-samples the data through an equal number of up-sampling blocks back to the original resolution. Skip connections pass the feature maps from each down-sampling level directly to the matching up-sampling level, so fine details lost during compression can be recovered.
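The data flow can be sketched without any learned weights. In this toy NumPy version (an illustration, not the post's model), average pooling stands in for the down-sampling blocks, nearest-neighbour repetition stands in for the up-sampling blocks, skip connections are channel concatenations, and a random 1x1 "convolution" maps the channels back so the output has the same shape as the input.

```python
import numpy as np

def down(x):
    """2x2 average pooling: halves height and width (stand-in for a learned down-sampling block)."""
    c, h, w = x.shape
    return x.reshape(c, h // 2, 2, w // 2, 2).mean(axis=(2, 4))

def up(x):
    """Nearest-neighbour up-sampling: doubles height and width (stand-in for a learned up-sampling block)."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

def toy_unet(x, w_out):
    d1 = down(x)                                # (C, H/2, W/2)
    d2 = down(d1)                               # (C, H/4, W/4): the compressed bottleneck
    u1 = np.concatenate([up(d2), d1], axis=0)   # up-sample, then concatenate the skip connection from d1
    u2 = np.concatenate([up(u1), x], axis=0)    # up-sample again, concatenate the input-resolution features
    # 1x1 "convolution": mix the channels back down to the input's channel count
    return np.einsum("oc,chw->ohw", w_out, u2)

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 16, 16))            # one 16x16 RGB image
w_out = rng.standard_normal((3, 9))             # u2 has 6 + 3 = 9 channels
out = toy_unet(x, w_out)
print(out.shape)                                # (3, 16, 16), same shape as the input
```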

Implementation:

class ContextUnet(nn.Module):
    # ResidualConvBlock, UnetDown, UnetUp and EmbedFC are helper modules not shown in this excerpt
    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features
        super().__init__()

        # number of input channels, number of intermediate feature maps and number of context features
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        # assume h == w, divisible by 4; with AvgPool2d(4) below, h/4 must be 4-7, i.e. h = 16, 20, 24 or 28
        self.h = height

        # Initialize the initial convolutional layer
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Initialize the down-sampling path of the U-Net with two levels
        self.down1 = UnetDown(n_feat, n_feat)      # -> (batch, n_feat, h/2, w/2)
        self.down2 = UnetDown(n_feat, 2 * n_feat)  # -> (batch, 2*n_feat, h/4, w/4)

        # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())  # -> (batch, 2*n_feat, 1, 1)

        # Embed the timestep and context labels with a one-layer fully connected neural network
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # Initialize the up-sampling path of the U-Net with three levels
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),  # up-sample 1x1 -> h/4
            nn.GroupNorm(8, 2 * n_feat),  # normalize
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Initialize the final convolutional layers to map to the same number of channels as the input image
        self.out = nn.Sequential(
            # reduce number of feature maps (in_channels, out_channels, kernel_size, stride, padding)
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),  # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),  # map to same number of channels as input
        )

    def forward(self, x, t, c=None):
        """
        x : (batch, in_channels, h, w) : input (noisy) image
        t : (batch, 1) : normalized time step, i / timesteps
        c : (batch, n_cfeat) : context label; None means no context
        Returns the predicted noise, with the same shape as x.
        """
        # pass the input image through the initial convolutional layer
        x = self.init_conv(x)
        # pass the result through the down-sampling path
        down1 = self.down1(x)      # (batch, n_feat, h/2, w/2)
        down2 = self.down2(down1)  # (batch, 2*n_feat, h/4, w/4)

        # convert the feature maps to a vector and apply an activation
        hiddenvec = self.to_vec(down2)

        # with no context given, use an all-zero context vector (unconditional generation)
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # embed context and timestep
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)  # (batch, 2*n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1 * up1 + temb1, down2)  # scale by context, shift by time; skip connection from down2
        up3 = self.up2(cemb2 * up2 + temb2, down1)  # same at the next level; skip connection from down1
        out = self.out(torch.cat((up3, x), 1))      # concatenate with the initial features, map back to in_channels
        return out

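The U-Net can also take in extra information through embeddings. In ContextUnet, the context and timestep embeddings modulate the feature maps in the up-sampling path (cemb * up + temb), so the same network can adapt its prediction to the current noise level and to what we ask it to generate.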
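The line `cemb1 * up1 + temb1` relies on broadcasting: each embedding has shape (batch, channels, 1, 1), so it supplies one number per channel that is stretched over every spatial position. The NumPy sketch below uses random numbers in place of the learned embeddings (sizes are assumptions) to show that the context scales each channel and the timestep shifts it, a pattern that resembles FiLM-style (feature-wise linear modulation) conditioning.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_feat, h = 4, 8, 8                             # assumed sizes

features = rng.standard_normal((batch, n_feat, h, h))  # feature maps inside the U-Net
cemb = rng.standard_normal((batch, n_feat, 1, 1))      # stand-in context embedding, one value per channel
temb = rng.standard_normal((batch, n_feat, 1, 1))      # stand-in timestep embedding, one value per channel

# Broadcasting stretches the (1, 1) embeddings over every spatial position:
# the context scales each channel and the timestep shifts it
modulated = cemb * features + temb
print(modulated.shape)  # (4, 8, 8, 8)
```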

3. Training:
The objective is for the neural network to learn to predict the noise in an image. During training, we take a sprite from the training data, add noise at a randomly chosen timestep, and train the network to predict exactly that noise, typically by minimizing the mean squared error between the predicted and the actual noise. Because sampling starts from fresh random noise each time, a well-trained network generates a different sprite on every run.
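One training step can be sketched in NumPy. This is a sketch under stated assumptions: the same linear schedule with T = 500, and a plain Python function `predict_noise` standing in for ContextUnet (no gradients or optimizer here). The loss is the mean squared error between predicted and true noise, the simplified objective from the DDPM paper. Two hypothetical predictors bracket what the loss means: an "oracle" that cheats by knowing the clean images, and one that always predicts zero.

```python
import numpy as np

T = 500                                                   # assumed number of timesteps
betas = np.concatenate([[0.0], np.linspace(1e-4, 0.02, T)])
alpha_bar = np.cumprod(1.0 - betas)
rng = np.random.default_rng(0)

def training_loss(predict_noise, x0):
    """One DDPM-style training step: noise a batch at random levels and score the noise prediction."""
    t = rng.integers(1, T + 1, size=x0.shape[0])          # random timestep per image, 1..T
    eps = rng.standard_normal(x0.shape)                   # the noise we add is the training target
    ab = alpha_bar[t][:, None, None, None]                # reshape so it broadcasts over (C, H, W)
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps      # noisy images
    eps_pred = predict_noise(x_t, t / T)                  # the model sees x_t and the normalized timestep
    return np.mean((eps_pred - eps) ** 2)                 # MSE between predicted and true noise

# Hypothetical batch of 32 "sprites" in [-1, 1]
x0 = rng.uniform(-1.0, 1.0, size=(32, 3, 16, 16))

def oracle(x_t, t_norm):
    """Cheats by knowing x0, so it recovers the exact noise (loss near 0)."""
    ab = alpha_bar[np.rint(t_norm * T).astype(int)][:, None, None, None]
    return (x_t - np.sqrt(ab) * x0) / np.sqrt(1.0 - ab)

def always_zero(x_t, t_norm):
    """Predicts no noise at all (loss near 1, the variance of the noise)."""
    return np.zeros_like(x_t)

loss_oracle = training_loss(oracle, x0)
loss_zero = training_loss(always_zero, x0)
print(f"oracle loss {loss_oracle:.2e}, zero-predictor loss {loss_zero:.3f}")
```

A real network sits between these two extremes, and training pushes its loss toward the oracle's.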

4. Control:
In this phase, we specify our desired output to the neural network, which then generates it according to our instructions. This control is facilitated through the use of embeddings, which guide the network in producing the desired result.

So, what are embeddings?
Embeddings are vectors of numbers that represent inputs such as text, class labels, or timesteps. For text, their key property is that they capture semantic meaning: text with similar meanings has similar vector representations. These vectors can also be combined with arithmetic operations, which makes it possible to manipulate meaning numerically.
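To make "similar meaning, similar vectors" concrete, the sketch below compares hand-made 4-dimensional vectors with cosine similarity. The words and numbers are invented for illustration; real text embeddings come from a trained model and have hundreds of dimensions.

```python
import numpy as np

# Made-up 4-dimensional "embeddings"
emb = {
    "knight":  np.array([0.9, 0.8, 0.1, 0.0]),
    "warrior": np.array([0.8, 0.9, 0.2, 0.1]),
    "apple":   np.array([0.1, 0.0, 0.9, 0.8]),
}

def cosine(u, v):
    """Cosine similarity: 1 means same direction, 0 means unrelated (orthogonal)."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

sim_kw = cosine(emb["knight"], emb["warrior"])
sim_ka = cosine(emb["knight"], emb["apple"])
print(f"knight vs warrior: {sim_kw:.2f}")
print(f"knight vs apple:   {sim_ka:.2f}")
```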

In the model above, the context embedding is used together with an embedding of the sampling timestep. Because the context is just a vector, it can also be set to combinations not seen during training, which can help generate novel objects. The context can include:

  • Vector for controlling generation: Guides the neural network in shaping the output.
  • Text embedding: Encodes semantic meaning from textual descriptions.
  • Categories: Defines specific classes or types for the generated output.
Example of control in a diffusion model
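For the "categories" case, a context is commonly a one-hot vector with one entry per category, passed as `c` with shape (batch, n_cfeat). The category names in this NumPy sketch are hypothetical; the blended vector illustrates the kind of unseen combination mentioned above, though whether it yields a sensible sprite depends on the trained model.

```python
import numpy as np

# Hypothetical sprite categories; ContextUnet's n_cfeat would equal len(categories)
categories = ["hero", "non-hero", "food", "spell", "side-facing"]
n_cfeat = len(categories)

def one_hot(name):
    """Context vector with a single 1 at the category's index."""
    c = np.zeros(n_cfeat)
    c[categories.index(name)] = 1.0
    return c

hero = one_hot("hero")                          # [1, 0, 0, 0, 0]
food = one_hot("food")                          # [0, 0, 1, 0, 0]
mixed = 0.5 * hero + 0.5 * food                 # a blend never seen as a single label in training
batch_context = np.stack([hero, food, mixed])   # shape (3, n_cfeat): one context row per sample
print(batch_context)
```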

Implementation:


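# n_feat, n_cfeat, height, lrate and device are hyperparameters defined elsewhere (not shown in this excerpt)
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)

optim = torch.optim.Adam(nn_model.parameters(), lr=lrate)

# Inside the training loop, for each batch of context labels c with shape (batch, n_cfeat):
# keep each sample's context with probability 0.9, i.e. zero it out for about 10% of samples
context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.9).to(device)
c = c * context_mask.unsqueeze(-1)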
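The masking step can be mirrored in NumPy to see its effect (batch size and category count are assumptions). Each sample's context is kept with probability 0.9, so roughly 10% of rows become all zeros, which matches the all-zero context that ContextUnet.forward uses when `c` is None. Training on both kinds of rows lets one network handle conditional and unconditional generation; this kind of context dropout is also the training recipe behind classifier-free guidance.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_cfeat = 10_000, 5                                   # assumed sizes

c = np.eye(n_cfeat)[rng.integers(0, n_cfeat, size=batch)]    # random one-hot context labels
keep = rng.binomial(1, 0.9, size=batch)                      # 1 = keep context, 0 = drop it (p = 0.9)
c_masked = c * keep[:, None]                                 # dropped rows become all zeros

dropped = np.mean(c_masked.sum(axis=1) == 0)
print(f"fraction of samples trained without context: {dropped:.3f}")   # roughly 0.1
```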

Examples of images generated with context embeddings:

Conclusion:

Diffusion models are emerging as foundational tools in cutting-edge research across many fields, including the life sciences. One notable application is drug discovery, where these models are used to generate novel molecules with potential therapeutic properties. Their ability to produce diverse, high-quality outputs makes them valuable for advancing research and innovation.

Stay tuned for more insights on Vision Models and their applications, as we continue to explore the transformative potential of these technologies.

References:

  1. Full code on my GitHub: https://github.com/NandiniLReddy/DiffusionModelWorking
  2. Code modified from: https://github.com/cloneofsimo/minDiffusion
  3. Diffusion model (Ho et al., 2020, Denoising Diffusion Probabilistic Models): https://arxiv.org/abs/2006.11239
  4. Denoising mechanism (Song et al., 2020, Denoising Diffusion Implicit Models): https://arxiv.org/abs/2010.02502
