DALL-E 2 from Scratch
Text-Conditioned Image Generation on FashionMNIST using CLIP Latents
Denoising diffusion probabilistic models (DDPM) are a popular type of generative AI model introduced by Ho et al. in 2020 and improved upon by Nichol et al. in 2021. The basic idea behind these models is that noise is added to images in the forward diffusion process in order to train the model to predict the noise that should be removed at a given timestep in the reverse diffusion process. When sampling images, you start with an image of pure noise and iteratively remove the model's predicted noise at each timestep until you get the final image.
In order to have a DDPM generate multiple types of images while still letting the user choose which type of image they want, the model needs to be conditioned on some input. Ramesh et al. introduced one such conditioning method called unCLIP, which is used in OpenAI’s DALL-E 2 model. In the method described by Ramesh et al., the input caption is first passed to a prior network which will use a trained CLIP model to get the CLIP text embeddings. These text embeddings are then used by a decoder-only transformer in order to generate possible CLIP image embeddings. The CLIP image embeddings generated by the prior network will be used by a decoder network, which consists of a UNet model, in order to condition the images that are created. In this article, we are going to be building a simple diffusion model using this process.
The Colab notebook with the code for this tutorial can be found here. Additionally, the GitHub repo can be found here.
Libraries and Modules
We are going to be building our models using PyTorch, so we will need to import the library plus others that we will be using in this tutorial.
import torch
import torch.nn as nn
We also need to import torchvision.transforms in order to resize the input images and convert them to tensors. Resizing the input images is optional; you just need to make sure that the image size is divisible by the patch size.
import torchvision.transforms as T
We are going to be using Adam or AdamW for our optimizer depending on whether we use weight decay, so we need to import both from torch.optim. We are also going to be using a learning rate scheduler, so we need to import lr_scheduler from torch.optim as well.
from torch.optim import Adam, AdamW, lr_scheduler
We are going to be loading the FashionMNIST dataset from HuggingFace for this tutorial, so we need to import load_dataset from datasets. We are also going to be using Dataset and DataLoader from PyTorch to help load the data, so we need to import those as well.
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
We are going to import pyplot from matplotlib in order to display the original images, noisy images, diffused images, etc.
import matplotlib.pyplot as plt
Calculating our timestep embeddings requires sine and cosine functions, so we need to import numpy.
import numpy as np
Finally, we are going to import dataclass and field from dataclasses in order to create config classes that hold our parameters.
from dataclasses import dataclass, field
Putting this all together, we get:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.optim import Adam, AdamW, lr_scheduler
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
CLIP
In order to create diffusion images from text, we are going to be using the embeddings from the CLIP model. The text embeddings that we get from CLIP are used to condition the prior model to diffuse corresponding image embeddings. These image embeddings are then used to condition the decoder model to help guide it to the desired image.
For more information about the CLIP model and how to implement it from scratch, I recommend reading my other Medium article.
My previous vision transformers article might be useful to read as well.
I will be leaving out some topics/code because I already went over them in my previous articles. Those topics include: multi-head attention, positional encodings, vision transformers, text transformers, CLIP, tokenization, and loading in the FashionMNIST dataset.
Forward Diffusion
Forward diffusion is the part of the diffusion process in which Gaussian noise is added to an image at each timestep over T steps. The amount of noise added at each timestep is determined by a variance schedule. The formula for forward diffusion is:
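$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1 - \beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\right)$$
Here, beta_t is the variance given by the schedule at timestep t and I is the identity matrix.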
That is, we keep applying Gaussian noise at every time step. While the formula above adds the noise iteratively, by using a reparameterization trick, we are able to directly sample a noisy image at any arbitrary timestep in closed form. We are going to be using this method for our forward diffusion process. The formula for this method is:
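$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I}), \qquad \alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$$
The alpha and alpha_bar terms correspond to the alphas and alpha_bars computed in the code below.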
For the variance schedule, in the original DDPM paper by Ho et al., a linear schedule was used. While this schedule worked well with high resolution images, when the images were smaller (64x64 or less), the forward process was too noisy at the end. In order to combat this, Nichol et al. suggested using a cosine schedule whose formula is shown below.
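$$\bar{\alpha}_t = \frac{f(t)}{f(0)}, \qquad f(t) = \cos\!\left(\frac{t/T + s}{1 + s} \cdot \frac{\pi}{2}\right)^{2}, \qquad \beta_t = 1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}$$
Here, T is the maximum number of timesteps and s is a small offset (0.008).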
These values are clipped to be less than or equal to 0.999 in order to prevent singularities near t = T. Because we are using the FashionMNIST dataset which contains low resolution images, the cosine beta schedule is the one we are going to use.
Putting this all together, the code for the forward diffusion process is going to look something like this:
# Gets elements at the given indices and reshapes them so they can broadcast against a tensor of the given shape
def extract_and_expand(x, idx, shape):
return x[idx].reshape(idx.shape[0], *((1, ) * (len(shape) - 1)))
# Returns beta schedule
def get_beta_schedule(schedule, max_time, s=0.008):
if schedule == "linear":
scale = 1000 / max_time
betas = torch.linspace(1e-4 * scale, 0.02 * scale, max_time)
elif schedule == "cosine":
t = torch.linspace(0, max_time, max_time + 1)
a_bars = torch.cos((((t / max_time) + s) / (1 + s)) * (np.pi / 2)) ** 2
a_bars = a_bars / a_bars[0]
betas = 1 - (a_bars[1:] / a_bars[:-1])
betas = torch.clamp(betas, min=0, max=0.999)
else:
Exception("Beta schedule not implemented.")
return betas
def get_schedule_values(config):
schedule_values = {}
schedule_values["betas"] = get_beta_schedule(config.prior.schedule, config.prior.max_time).to(config.device)
schedule_values["alphas"] = 1.0 - schedule_values["betas"]
schedule_values["alpha_bars"] = torch.cumprod(schedule_values["alphas"], axis = 0)
schedule_values["sqrt_recip_alphas"] = torch.sqrt(1.0 / schedule_values["alphas"])
schedule_values["sqrt_alpha_bars"] = torch.sqrt(schedule_values["alpha_bars"])
schedule_values["sqrt_one_minus_alpha_bars"] = torch.sqrt(1.0 - schedule_values["alpha_bars"])
schedule_values["alpha_bars_prev"] = torch.cat((torch.ones(1, device=config.device), schedule_values["alpha_bars"][:-1]))
schedule_values["sigma"] = schedule_values["betas"] * (1.0 - schedule_values["alpha_bars_prev"]) / (1.0 - schedule_values["alpha_bars"])
return schedule_values
# Gets noisy image at a certain timestep
def forward_diffusion(x_0, schedule_values, t):
noise = torch.randn_like(x_0)
sqrt_alpha_bars = extract_and_expand(schedule_values["sqrt_alpha_bars"], t, x_0.shape)
sqrt_one_minus_alpha_bars = extract_and_expand(schedule_values["sqrt_one_minus_alpha_bars"], t, x_0.shape)
x_noisy = (sqrt_alpha_bars * x_0) + (sqrt_one_minus_alpha_bars * noise)
return x_noisy, noise
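As a quick sanity check of the functions above, here is a minimal usage sketch. It assumes config is an instance of the FMNISTConfig defined in the Config section later in this article, and the batch is just random placeholder data:
# Sanity check: noise a dummy batch of images at random timesteps
schedule_values = get_schedule_values(config)
x_0 = torch.randn(8, 1, 32, 32, device=config.device)  # stand-in for a batch of (B, C, H, W) images
t = torch.randint(0, config.prior.max_time, (8,), device=config.device)  # one random timestep per image
x_t, noise = forward_diffusion(x_0, schedule_values, t)
print(x_t.shape, noise.shape)  # both torch.Size([8, 1, 32, 32])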
Timestep Embedding
The timestep embedding is an important part of diffusion because images at different timesteps contain different amounts of noise. In order to give this information to our model, we are going to use sinusoidal positional encodings, the same ones that are commonly used for Transformers. The main difference is that our input timesteps are most likely not going to be in sequential order or cover all possible timesteps, so we only need to fetch the positional encodings that correspond to the timesteps that are inputted.
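For reference, these are the standard Transformer sinusoidal encodings, where pos is the timestep and d is the model width:
$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$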
class SinusoidalPositionalEmbedding(nn.Module):
def __init__(self,
max_seq_length, # Maximum sequence length
width # Width of model
):
super().__init__()
# Create positional encodings
pe = torch.zeros(max_seq_length, width)
for pos in range(max_seq_length):
for i in range(width):
if i % 2 == 0:
pe[pos][i] = np.sin(pos / (10000 ** (i / width)))
else:
pe[pos][i] = np.cos(pos / (10000 ** ((i - 1) / width)))
self.register_buffer('pe', pe)
def forward(self, x):
# Get positional encodings corresponding to inputted timesteps
x = self.pe[x]
return x
These time encodings are then passed through an MLP to further capture temporal information.
self.time_mlp = nn.Sequential(
SinusoidalPositionalEmbedding(config.decoder.max_time, config.decoder.model_channels),
nn.Linear(config.decoder.model_channels, config.decoder.cond_channels),
nn.SiLU(),
nn.Linear(config.decoder.cond_channels, config.decoder.cond_channels)
)
There are multiple ways to condition the input on this timestep embedding, which are described in the Decoder: Residual Blocks section of this article.
Prior Model
The prior model is used to generate CLIP image embeddings from text captions. It is possible to forgo the prior model and condition on the CLIP text embeddings instead of the CLIP image embeddings generated by the prior, but Ramesh et al. found that using the prior model performed best.
There are two main model classes that can be used for the prior model: an autoregressive prior and a diffusion prior. For our model, we are going to be using the diffusion prior because Ramesh et al. 2022 found that it outperformed the autoregressive prior for comparable model size and reduced training compute. Classifier-free guidance can be implemented to improve sample quality for both the autoregressive and diffusion prior, but we are not going to be using it for this article.
The diffusion prior starts by taking the text captions as an input and getting the CLIP text and image embeddings. When loading the CLIP model in, all of the layers should be frozen and the mode should be set to eval.
def freeze_model(model, set_eval=True):
if set_eval:
model.eval()
for param in model.parameters():
param.requires_grad = False
# Constructor
self.clip = CLIP(config).to(config.device)
self.clip.load_state_dict(torch.load(config.clip.model_location, map_location=config.device))
freeze_model(self.clip)
# Forward
image_embeddings = self.clip.image_encoder(images) # (B, C, H, W) -> (B, latent_dim)
text_embeddings = self.clip.text_encoder(captions, mask=masks) # (B, text_seq_length) -> (B, latent_dim)
It then gets random timesteps for each of the images in the batch and uses them to get the noisy CLIP image embeddings from forward diffusion.
# Forward
timesteps = torch.randint(0, self.config.prior.max_time, (images.shape[0],)) # (B, )
noisy_image_embedding, _ = forward_diffusion(image_embeddings, self.schedule_values, timesteps)
The timesteps are then passed through an MLP in order to get the timestep embeddings.
# Constructor
self.time_mlp = nn.Sequential(
SinusoidalPositionalEmbedding(config.prior.max_time, config.latent_dim),
nn.Linear(config.latent_dim, config.latent_dim * config.prior.r_mlp, bias=config.prior.bias),
nn.SiLU(),
nn.Linear(config.latent_dim * config.prior.r_mlp, config.latent_dim, bias=config.prior.bias)
)
# Forward
timestep_embeddings = self.time_mlp(timesteps) # (B, ) -> (B, latent_dim)
Another important part of the prior model is the learned embedding. This embedding is a torch parameter that is going to be used to predict the final output. The learned embedding from the constructor needs to be expanded in the forward method so that there is an embedding for each item in the batch.
# Constructor
self.learned_embedding = nn.Parameter(torch.randn(config.latent_dim))
# Forward
learned_embeddings = self.learned_embedding.repeat(images.shape[0], 1) # (latent_dim) -> (B, latent_dim)
The text captions, CLIP text embeddings, timestep embeddings, noisy CLIP image embeddings, and learned embeddings are going to be concatenated into a sequence. All of these items are going to have shape (B, latent_dim) (the tokenized captions are first padded or truncated to length latent_dim), but we are going to add an extra dimension in the middle to give them shape (B, 1, latent_dim). We then concatenate along this new dimension, which gives the sequence shape (B, 5, latent_dim); the snippet below assumes each tensor has already been unsqueezed this way, as shown in the full forward method later. If using a dataset with higher quality images, one possible way to improve the quality of the model could be to pass the text embeddings, image embeddings, and/or timestep embeddings through convolution layers to increase the sequence length.
tokens = torch.cat((
captions, # Image Caption
text_embeddings, # CLIP Text Embedding
timestep_embeddings, # Timestep Embedding
noisy_image_embedding, # Noisy CLIP Image Embedding
learned_embeddings # Learned Embedding
), dim=1) # (B, 5, latent_dim)
This sequence is then passed through a decoder-only Transformer with a causal attention mask. A causal attention mask is a lower triangular matrix of ones which makes it so that tokens can only attend to themselves and to the tokens that came before them.
# Constructor
self.decoder = nn.ModuleList(
[TransformerBlock(
config.latent_dim,
cond_width=config.latent_dim,
n_heads=config.prior.n_heads,
dropout=config.prior.dropout,
r_mlp=config.prior.r_mlp,
bias=config.prior.bias
) for _ in range(config.prior.n_layers)]
)
self.register_buffer("causal_attention_mask", torch.tril(torch.ones(5, 5))[None, :])
# Forward
for block in self.decoder:
tokens = block(tokens, mask=self.causal_attention_mask)
Finally, we take the learned embedding token from the output of the Transformer and pass it through a LayerNorm and a Linear layer to get the predicted image embeddings.
# Constructor
self.output = nn.Sequential(
nn.LayerNorm(config.latent_dim),
nn.Linear(config.latent_dim, config.latent_dim, bias=config.prior.bias)
)
# Forward
pred_image_embeddings = self.output(tokens[:, -1, :])
While normally diffusion models would predict the noise and generate the sample by removing the noise iteratively, it was mentioned in the unCLIP paper that it was better to just predict the CLIP image embeddings directly. The loss function for this model should be a mean-squared error loss between the predicted and actual CLIP image embeddings.
loss = nn.functional.mse_loss(pred_image_embeddings, image_embeddings)
During sampling, in order to improve quality, the model should generate two samples of the CLIP image embeddings and choose the one that has the higher dot product with the CLIP text embedding.
def get_one_sample(self, text_embeddings, captions):
# Get image embeddings that are pure noise
noisy_image_embeddings = torch.randn(text_embeddings.shape, device=self.config.device)
# timestep is max for all items because image embeddings are pure noise
timesteps = torch.full((captions.shape[0],), self.config.prior.max_time - 1)
# Get timestep embeddings
timestep_embeddings = self.time_mlp(timesteps) # (B, ) -> (B, latent_dim)
timestep_embeddings = timestep_embeddings[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Expand learned embedding so that there is one for each item in batch
learned_embeddings = self.learned_embedding.repeat(captions.shape[0], 1) # (latent_dim) -> (B, latent_dim)
learned_embeddings = learned_embeddings[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
tokens = torch.cat((
captions, # Image Caption
text_embeddings, # CLIP Text Embedding
timestep_embeddings, # Timestep Embedding
noisy_image_embeddings, # Noisy CLIP Image Embedding
learned_embeddings # Learned Embedding
), dim=1) # (B, 5, latent_dim)
# Pass through transformer blocks with causal attention mask
for block in self.decoder:
tokens = block(tokens, mask=self.causal_attention_mask)
# Get learned embeddings and pass through output projection to get CLIP image embeddings
pred_image_embeddings = self.output(tokens[:, -1, :])
return pred_image_embeddings
def sample(self, captions, masks=None):
# Get CLIP text embeddings
t_emb = self.clip.text_encoder(captions, mask=masks) # (B, text_seq_length) -> (B, latent_dim)
text_embeddings = t_emb[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Make caption length equal to latent dimension
if self.config.text_seq_length >= self.config.latent_dim:
captions = captions[:, :self.config.latent_dim] # (B, max_seq_len) -> (B, latent_dim)
else:
captions = nn.functional.pad(captions, (0, self.config.latent_dim - self.config.text_seq_length)) # (B, max_seq_len) -> (B, latent_dim)
captions = captions[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Getting two samples
sample_1 = self.get_one_sample(text_embeddings, captions)
sample_2 = self.get_one_sample(text_embeddings, captions)
gen_image_embeddings = torch.zeros_like(sample_1)
# Choosing the samples with the higher dot product with text embeddings
for i in range(gen_image_embeddings.shape[0]):
if sample_1[i] @ t_emb[i] >= sample_2[i] @ t_emb[i]:
gen_image_embeddings[i] = sample_1[i]
else:
gen_image_embeddings[i] = sample_2[i]
return gen_image_embeddings
Putting all of this together, the code for the diffusion prior model should look something like this:
class DiffusionPrior(nn.Module):
def __init__(self, config):
super().__init__()
# Loading CLIP Model
self.clip = CLIP(config).to(config.device)
self.clip.load_state_dict(torch.load(config.clip.model_location, map_location=config.device))
freeze_model(self.clip)
self.config = config
self.time_mlp = nn.Sequential(
SinusoidalPositionalEmbedding(config.prior.max_time, config.latent_dim),
nn.Linear(config.latent_dim, config.latent_dim * config.prior.r_mlp, bias=config.prior.bias),
nn.SiLU(),
nn.Linear(config.latent_dim * config.prior.r_mlp, config.latent_dim, bias=config.prior.bias)
)
self.learned_embedding = nn.Parameter(torch.randn(config.latent_dim))
self.schedule_values = get_schedule_values(config)
# Transformer blocks
self.decoder = nn.ModuleList(
[TransformerBlock(
config.latent_dim,
cond_width=config.latent_dim,
n_heads=config.prior.n_heads,
dropout=config.prior.dropout,
r_mlp=config.prior.r_mlp,
bias=config.prior.bias
) for _ in range(config.prior.n_layers)]
)
# Output Projection
self.output = nn.Sequential(
nn.LayerNorm(config.latent_dim),
nn.Linear(config.latent_dim, config.latent_dim, bias=config.prior.bias)
)
self.register_buffer("causal_attention_mask", torch.tril(torch.ones(5, 5))[None, :])
def get_one_sample(self, text_embeddings, captions):
# Get image embeddings that are pure noise
noisy_image_embeddings = torch.randn(text_embeddings.shape, device=self.config.device)
# timestep is max for all items because image embeddings are pure noise
timesteps = torch.full((captions.shape[0],), self.config.prior.max_time - 1)
# Get timestep embeddings
timestep_embeddings = self.time_mlp(timesteps) # (B, ) -> (B, latent_dim)
timestep_embeddings = timestep_embeddings[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Expand learned embedding so that there is one for each item in batch
learned_embeddings = self.learned_embedding.repeat(captions.shape[0], 1) # (latent_dim) -> (B, latent_dim)
learned_embeddings = learned_embeddings[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
tokens = torch.cat((
captions, # Image Caption
text_embeddings, # CLIP Text Embedding
timestep_embeddings, # Timestep Embedding
noisy_image_embeddings, # Noisy CLIP Image Embedding
learned_embeddings # Learned Embedding
), dim=1) # (B, 5, latent_dim)
# Pass through transformer blocks with causal attention mask
for block in self.decoder:
tokens = block(tokens, mask=self.causal_attention_mask)
# Get learned embeddings and pass through output projection to get CLIP image embeddings
pred_image_embeddings = self.output(tokens[:, -1, :])
return pred_image_embeddings
def sample(self, captions, masks=None):
# Get CLIP text embeddings
t_emb = self.clip.text_encoder(captions, mask=masks) # (B, text_seq_length) -> (B, latent_dim)
text_embeddings = t_emb[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Make caption length equal to latent dimension
if self.config.text_seq_length >= self.config.latent_dim:
captions = captions[:, :self.config.latent_dim] # (B, max_seq_len) -> (B, latent_dim)
else:
captions = nn.functional.pad(captions, (0, self.config.latent_dim - self.config.text_seq_length)) # (B, max_seq_len) -> (B, latent_dim)
captions = captions[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Getting two samples
sample_1 = self.get_one_sample(text_embeddings, captions)
sample_2 = self.get_one_sample(text_embeddings, captions)
gen_image_embeddings = torch.zeros_like(sample_1)
# Choosing the samples with the higher dot product with text embeddings
for i in range(gen_image_embeddings.shape[0]):
if sample_1[i] @ t_emb[i] >= sample_2[i] @ t_emb[i]:
gen_image_embeddings[i] = sample_1[i]
else:
gen_image_embeddings[i] = sample_2[i]
return gen_image_embeddings
def forward(self, images, captions, masks=None):
# Get CLIP image embeddings
image_embeddings = self.clip.image_encoder(images) # (B, C, H, W) -> (B, latent_dim)
# Get CLIP text embeddings
text_embeddings = self.clip.text_encoder(captions, mask=masks) # (B, text_seq_length) -> (B, latent_dim)
text_embeddings = text_embeddings[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Make caption length equal to latent dimension
if self.config.text_seq_length >= self.config.latent_dim:
captions = captions[:, :self.config.latent_dim] # (B, max_seq_len) -> (B, latent_dim)
else:
captions = nn.functional.pad(captions, (0, self.config.latent_dim - self.config.text_seq_length)) # (B, max_seq_len) -> (B, latent_dim)
captions = captions[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Get random timesteps for forward diffusion
timesteps = torch.randint(0, self.config.prior.max_time, (images.shape[0],)) # (B, )
# Get timestep embeddings
timestep_embeddings = self.time_mlp(timesteps) # (B, ) -> (B, latent_dim)
timestep_embeddings = timestep_embeddings[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Perform forward diffusion to get noisy CLIP image embeddings
noisy_image_embedding, _ = forward_diffusion(image_embeddings, self.schedule_values, timesteps)
noisy_image_embedding = noisy_image_embedding[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
# Expand learned embedding so that there is one for each item in batch
learned_embeddings = self.learned_embedding.repeat(images.shape[0], 1) # (latent_dim) -> (B, latent_dim)
learned_embeddings = learned_embeddings[:, None, :] # (B, latent_dim) -> (B, 1, latent_dim)
tokens = torch.cat((
captions, # Image Caption
text_embeddings, # CLIP Text Embedding
timestep_embeddings, # Timestep Embedding
noisy_image_embedding, # Noisy CLIP Image Embedding
learned_embeddings # Learned Embedding
), dim=1) # (B, 5, latent_dim)
# Pass through transformer blocks with causal attention mask
for block in self.decoder:
tokens = block(tokens, mask=self.causal_attention_mask)
# Get learned embeddings and pass through output projection to get CLIP image embeddings
pred_image_embeddings = self.output(tokens[:, -1, :])
loss = nn.functional.mse_loss(pred_image_embeddings, image_embeddings)
return loss
Diffusion Decoder Overview
The diffusion decoder is the part of the model where the image is actually created. It does this by predicting the noise that should be removed at each timestep and iteratively removing the predicted noise from a noisy image. For the decoder of unCLIP, Ramesh et al. used a model that was extremely similar to the GLIDE model by Nichol et al., and we are going to follow that pretty closely as well.
The two model architectures that are most often used for the backbone of this are UNets and Transformers.
One of the main benefits of the Transformer approach is its scalability: it scales well with large datasets and more complex models. Another benefit is that while UNets are used primarily for images, Transformers are more flexible and can be used for other data types without significant changes.
UNets are the main choice for many image-related tasks. This is because they are good at capturing local information through their convolutional layers while maintaining high-resolution features through their skip connections. Another benefit of UNets is that the input and the output have the same shape, which is useful in image diffusion.
A UNet works by taking the input and passing it through the encoder layers, which identify/capture relevant features while decreasing the resolution. It is then passed through the decoder layers, which try to locate the features while increasing the resolution back to its original shape. Because spatial information is lost in the encoder layers, skip connections are added from the encoder to the decoder in order to help preserve it.
Decoder: Residual Blocks
Both encoder and decoder layers are made of residual blocks that contain the convolutions used to identify and locate features. These residual blocks usually consist of Normalization, Activation, Convolution, Normalization, Activation, Convolution; however, our residual blocks are going to be Normalization, Convolution, Normalization, Activation, Convolution, Normalization (shown in the figure). This architecture was shown to improve performance while still maintaining non-linearity in the paper by Han et al.
# Constructor
self.layers1 = nn.Sequential(
nn.GroupNorm(n_groups, d_in),
nn.Conv2d(d_in, d_out, kernel_size, padding=1)
)
self.layers2 = nn.Sequential(
nn.GroupNorm(n_groups, d_out),
nn.SiLU(),
nn.Dropout(p=dropout),
nn.Conv2d(d_out, d_out, kernel_size, padding=1),
nn.GroupNorm(n_groups, d_out)
)
# Forward
x = self.layers1(x_0)
x = self.layers2(x)
For the Normalization portion of our residual blocks, we are going to be using GroupNorm. GroupNorm is used instead of BatchNorm because its performance is independent of batch size which makes it perform better on small or variable batch sizes. Using GroupNorm also tends to improve the stability during training.
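As a small standalone check (not part of the model code), GroupNorm computes its statistics per sample, so the output for an image does not depend on the rest of the batch:
# GroupNorm normalizes each sample over channel groups, so batch size does not matter
x = torch.randn(4, 32, 16, 16)
norm = nn.GroupNorm(8, 32)  # 8 groups over 32 channels
print(torch.allclose(norm(x)[:1], norm(x[:1])))  # True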
The Activation portion of the residual blocks provides non-linearity. While ReLU is commonly used for residual networks, we are going to be using SiLU for our activation function. In the paper by Ramachandran et al., SiLU was shown to outperform ReLU and other activation functions despite models and hyperparameters being tuned specifically for ReLU. Because of SiLU's simplicity and similarity to ReLU, it can be easily adopted by just using it in ReLU's place.
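SiLU (also known as Swish) is just the input scaled by its own sigmoid:
$$\mathrm{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$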
For the Convolution, we are going to be using Conv2d with a 3x3 kernel size and SAME padding in order to preserve spatial dimensions. SAME padding works by adding zeros to the border of the input to ensure that the output shape is the same as the unpadded input shape. The first convolution projects the input from d_in channels to d_out channels while the second convolution keeps d_out channels.
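As a quick standalone check (again, not part of the model code), a 3x3 convolution with a padding of 1 keeps the height and width unchanged:
conv = nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1)
print(conv(torch.randn(8, 1, 32, 32)).shape)  # torch.Size([8, 32, 32, 32])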
The residual block is also conditioned on the embedding information that is passed in. One of the main ways conditioning is performed is by simply adding the embedding to the input. For our model, we are instead going to perform a linear projection in order to get a scale and a bias value; the input is then multiplied by the scale, and the bias is added afterwards. Nichol and Dhariwal showed that this method of conditioning improved the FID score compared to the addition method. One thing to note when coding this part is that the conditioning embedding is likely to have fewer dimensions than the input. Because of this, we need to add dimensions to the end of the embedding. For example, if the input had shape (B, C, L) and the embedding had shape (B, C), you would need to add a dimension to give it shape (B, C, 1).
# Constructor
self.use_scale_shift = use_scale_shift
self.cond_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(model_channels, d_out * 2 if use_scale_shift else d_out)
)
# Forward
emb = self.cond_layers(emb)
while len(emb.shape) < len(x.shape):
emb = emb[..., None]
if self.use_scale_shift:
y_s, y_b = emb.chunk(2, dim=1)
x = y_s * x + y_b
else:
x += emb
The conditioning is performed after the first convolution instead of at the beginning of the residual block because it allows the model to get the basic features before refining them with timestep information which results in better representations.
Finally skip connections are created by adding the original input of the ResNet block to the output. A 1x1 convolution is performed on the input if the number of channels needs to be changed in order to match the output. These skip connections are used to enhance feature propagation.
# Constructor
self.residual = nn.Conv2d(d_in, d_out, 1) if d_in != d_out else nn.Identity()
# Forward
x += self.residual(x_0)
Putting all of this together, the final code for the residual blocks should look something like this:
class ResidualBlock(nn.Module):
def __init__(self, d_in, d_out, cond_channels=128, n_groups=8, kernel_size=(3,3), dropout=0.0, use_scale_shift=True):
super().__init__()
self.use_scale_shift = use_scale_shift
self.layers1 = nn.Sequential(
nn.GroupNorm(n_groups, d_in),
nn.Conv2d(d_in, d_out, kernel_size, padding=1)
)
# Activation & Linear Projection for Embedding
self.cond_layers = nn.Sequential(
nn.SiLU(),
# d_out multiplied by 2 in order to split into scale & shift if necessary
nn.Linear(cond_channels, d_out * 2 if use_scale_shift else d_out)
)
self.layers2 = nn.Sequential(
nn.GroupNorm(n_groups, d_out),
nn.SiLU(),
nn.Dropout(p=dropout),
nn.Conv2d(d_out, d_out, kernel_size, padding=1),
nn.GroupNorm(n_groups, d_out)
)
self.residual = nn.Conv2d(d_in, d_out, 1) if d_in != d_out else nn.Identity()
def forward(self, x_0, emb):
x = self.layers1(x_0)
emb = self.cond_layers(emb)
# Adding dimensions to embedding
while len(emb.shape) < len(x.shape):
emb = emb[..., None]
# Conditioning input with embedding
if self.use_scale_shift:
# Getting scale and shift
y_s, y_b = emb.chunk(2, dim=1)
# Performing scale and shift
x = y_s * x + y_b
else:
# Adding embedding to input
x += emb
x = self.layers2(x)
# Skip Connection
x += self.residual(x_0)
return x
Decoder: Attention Blocks
In our model, we are going to place attention blocks after the residual blocks of the inner layers of the encoder and decoder as well as in the bottleneck (the bottom layer connecting the encoder and decoder). For example, if the encoder and decoder have four layers, there would be attention blocks on layers two and three.
Using attention blocks has multiple benefits for our model. One benefit is that they help capture information about the spatial relationships between different parts of the image. Another benefit is that they help weight how important the model's features are. Finally, attention blocks can be used to condition the image generation process.
There are multiple ways to implement the attention mechanism for the attention blocks. Some of the options include self-attention and cross-attention. The method that we are going to use is a mix of the two, which was used in the GLIDE model. For this method, we get keys and values from both the input and the conditioning information and concatenate them together to form the final keys and values.
Q, K, V = self.qkv(x).chunk(3, dim=-1)
Q = Q.view(B, L, self.n_heads, self.head_size).transpose(1, 2)
K = K.view(B, L, self.n_heads, self.head_size).transpose(1, 2)
V = V.view(B, L, self.n_heads, self.head_size).transpose(1, 2)
# Concatenating keys and values of input and condition
if cond is not None:
k_c, v_c = self.cond_kv(cond).chunk(2, dim=-1)
k_c = k_c.view(B, cond.shape[1], self.n_heads, self.head_size).transpose(1, 2)
v_c = v_c.view(B, cond.shape[1], self.n_heads, self.head_size).transpose(1, 2)
K = torch.cat((K, k_c), dim=-2)
V = torch.cat((V, v_c), dim=-2)
One thing to note is that the inputs are images with shape (B, C, H, W), but for our attention mechanism we want the shape to be (B, L, C). Because of this, we need to combine the H and W dimensions and transpose the result with the C dimension. After performing attention, we need to convert the output back to its original shape.
b, c, h, w = x_0.shape
# Changing shape to perform attention
x = x.permute(0, 2, 3, 1).view(b, h * w, c) # (B, C, H, W) -> (B, H * W, C)
# Attention
x = self.attention(x, cond)
# Changing back to original shape
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
The final code should look something like this:
class AttentionBlock(nn.Module):
def __init__(self, n_channels, cond_channels, n_groups=8, n_heads=1, dropout=0.0):
super().__init__()
assert n_channels % n_heads == 0, "n_channels must be divisible by n_heads"
self.n_heads = n_heads
self.head_size = n_channels // n_heads
self.scale = self.head_size ** -0.5
self.group_norm = nn.GroupNorm(n_groups, n_channels)
self.qkv = nn.Linear(n_channels, n_channels * 3)
self.cond_kv = nn.Linear(cond_channels, n_channels * 2)
self.out_proj = nn.Linear(n_channels, n_channels)
self.dropout = nn.Dropout(dropout)
def attention(self, x, cond=None):
B, L, _ = x.shape
# Getting queries, keys, and values for input
Q, K, V = self.qkv(x).chunk(3, dim=-1)
Q = Q.view(B, L, self.n_heads, self.head_size).transpose(1, 2)
K = K.view(B, L, self.n_heads, self.head_size).transpose(1, 2)
V = V.view(B, L, self.n_heads, self.head_size).transpose(1, 2)
# Concatenating keys and values of condition to keys and values of input
if cond is not None:
k_c, v_c = self.cond_kv(cond).chunk(2, dim=-1)
k_c = k_c.view(B, cond.shape[1], self.n_heads, self.head_size).transpose(1, 2)
v_c = v_c.view(B, cond.shape[1], self.n_heads, self.head_size).transpose(1, 2)
K = torch.cat((K, k_c), dim=-2)
V = torch.cat((V, v_c), dim=-2)
# Get dot product between queries and keys
attention = torch.matmul(Q, K.transpose(-2, -1))
# Scale
attention = attention * self.scale
# Applying softmax
attention = torch.softmax(attention, dim=-1)
# Get dot product with values
attention = torch.matmul(attention, V)
# Combine heads
attention = attention.transpose(1, 2)
attention = attention.contiguous().view(x.shape)
# Output projection
attention = self.out_proj(attention)
# Dropout
attention = self.dropout(attention)
return attention
def forward(self, x_0, cond=None):
b, c, h, w = x_0.shape
# Group normalization
x = self.group_norm(x_0)
# Changing shape to perform attention
x = x.permute(0, 2, 3, 1).view(b, h * w, c) # (B, C, H, W) -> (B, H * W, C)
# Attention
x = self.attention(x, cond)
# Changing back to original shape
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
# Residual connection
x = x + x_0
return x
Decoder: Downsampling
In between each of our encoder layers, we are going to need to reduce the resolution of the input. There are two main downsampling methods: pooling and convolutional layers. Pooling is parameter-free, which makes it more computationally efficient and can help with overfitting. The benefit of convolutional layers is that they do have parameters, which lets them learn to preserve important features. We are going to use a stride of 2 for our model, which will reduce the resolution of the inputs by a factor of 2.
We are going to code our model to be able to use either method, but when training our model, we are going to have it use the convolutional method.
class Downsample(nn.Module):
def __init__(self, n_channels, kernel_size=(3,3), stride=2, down_pool=False):
super().__init__()
if down_pool:
self.down = nn.AvgPool2d(stride)
else:
self.down = nn.Conv2d(n_channels, n_channels, kernel_size, stride=stride, padding=1)
def forward(self, x):
x = self.down(x)
return x
Decoder: Upsampling
In between each of our decoder layers, we are going to need to increase the resolution of the inputs that are passed through. In order to do this, we are going to first perform interpolation, which will multiply the height and width of the input by a factor of 2. Afterwards, we are going to pass it through a convolutional layer, which will learn and preserve important features while also ensuring that the number of channels is correct.
class Upsample(nn.Module):
def __init__(self, d_in, d_out, kernel_size=(3,3)):
super().__init__()
self.conv = nn.Conv2d(d_in, d_out, kernel_size, padding=1)
def forward(self, x):
x = nn.functional.interpolate(x, scale_factor=2)
x = self.conv(x)
return x
Decoder: Final Model
When training the decoder, the first thing that the model needs to do is set up the conditioning information. To do that, we first need to sample from the Prior model in order to get the CLIP image embeddings. Like with the CLIP model, when loading the Prior model in, the layers should be frozen and the mode should be set to eval.
# Constructor
self.prior = DiffusionPrior(config).to(config.device)
self.prior.load_state_dict(torch.load(config.prior.model_location, map_location=config.device))
freeze_model(self.prior)
# Forward
img_embeddings = self.prior.sample(caption, mask).to(x.device)
Afterwards, the inputted timesteps need to be used to get the timestep embeddings (see the Timestep Embedding section above). The CLIP image embeddings that were generated by the prior model are then projected and added to the timestep embeddings. These embeddings are the ones used to condition the model's residual blocks.
# Constructor
self.img_projection = nn.Sequential(
nn.Linear(config.latent_dim, config.decoder.cond_channels),
nn.SiLU(),
nn.Linear(config.decoder.cond_channels, config.decoder.cond_channels)
)
#Forward
c_emb = self.time_mlp(time) + self.img_projection(img_embeddings)
for module in self.encoder:
if isinstance(module, ResidualBlock):
x = module(x, c_emb)
For the attention conditioning information, we first need to pass the text captions through a text Transformer encoder. In the unCLIP paper, these text encodings were used because Ramesh et al. thought they would help the model learn aspects of natural language that CLIP couldn't capture. During testing, they found that it did not have much of an effect, so this part is optional.
# Constructor
self.text_embedding = nn.Embedding(config.vocab_size, config.latent_dim)
self.positional_encodings = nn.Parameter(torch.randn(config.text_seq_length,config.latent_dim) * (config.latent_dim ** -0.5))
self.text_encoder = nn.ModuleList(
[TransformerBlock(
config.latent_dim,
cond_width=config.latent_dim,
n_heads=config.decoder.n_heads,
dropout=config.decoder.dropout,
r_mlp=config.decoder.r_mlp,
bias=config.decoder.bias
) for _ in range(config.decoder.text_layers)]
)
self.final_ln = nn.LayerNorm(config.latent_dim)
# Function
def encode_text(self, text, mask=None):
x = self.text_embedding(text)
x = x + self.positional_encodings
for block in self.text_encoder:
x = block(x, mask=mask)
x = self.final_ln(x)
return x
# Forward
text_encodings = self.encode_text(text, mask)
After encoding the text, the CLIP image embeddings are projected into four extra tokens and concatenated onto the end of the encoded text.
# Constructor
self.get_img_tokens = nn.Linear(1, config.decoder.n_img_tokens)
# Forward
img_tokens = self.get_img_tokens(img_embeddings[..., None]).permute(0, 2, 1)
c_attn = torch.cat([text_encodings, img_tokens], dim=1)
After setting up the conditioning information, the noisy images are passed through an initial convolutional layer in order to get the number of channels to the initial model channels.
# Constructor
ch = config.decoder.model_channels
self.in_conv = nn.Conv2d(config.img_channels, ch, config.decoder.kernel_size, padding=1)
# Forward
x = self.in_conv(x)
Now that the noisy images have the desired number of channels, they can be passed through the UNet layers along with the conditioning information. The UNet is going to have four encoder and decoder layers with a bottleneck layer in between. For the number of channels at each layer, we are going to be using [1, 2, 4, 8] as the ratios of layer channels to model channels. The encoder, decoder, and bottleneck layers of the UNet are all going to have two residual blocks. In between the bottleneck residual blocks and after each residual block in the inner layers of the encoder and decoder, attention blocks are placed (Figure 4). Skip connections are placed between the residual blocks of the encoder and decoder.
# Config
model_channels:int = 32
channel_ratios:list[int] = field(default_factory=lambda: [1, 2, 4, 8])
n_layer_blocks:int = 2
# Constructor
# UNet Encoder Layers
self.encoder = nn.ModuleList([])
for r in config.decoder.channel_ratios:
for _ in range(config.decoder.n_layer_blocks):
self.encoder.append(ResidualBlock(ch, config.decoder.model_channels * r, config.decoder.cond_channels, config.decoder.n_groups, config.decoder.kernel_size, config.decoder.dropout, config.decoder.use_scale_shift))
ch = config.decoder.model_channels * r
if r != config.decoder.channel_ratios[0] and r != config.decoder.channel_ratios[-1]:
self.encoder.append(AttentionBlock(ch, config.latent_dim, config.decoder.n_groups, config.decoder.n_heads, config.decoder.dropout))
if r != config.decoder.channel_ratios[-1]:
self.encoder.append(Downsample(ch, config.decoder.kernel_size, config.decoder.stride, config.decoder.down_pool))
# UNet Bottleneck Layers
self.bottleneck = nn.ModuleList([])
for block in range(config.decoder.n_layer_blocks):
self.bottleneck.append(ResidualBlock(ch, ch, config.decoder.cond_channels, config.decoder.n_groups, config.decoder.kernel_size, config.decoder.dropout, config.decoder.use_scale_shift))
if block != config.decoder.n_layer_blocks - 1:
self.bottleneck.append(AttentionBlock(ch, config.latent_dim, config.decoder.n_groups, config.decoder.n_heads, config.decoder.dropout))
# UNet Decoder Layers
self.decoder = nn.ModuleList([])
for r in range(len(config.decoder.channel_ratios))[::-1]:
for _ in range(config.decoder.n_layer_blocks):
self.decoder.append(ResidualBlock(ch * 2, ch, config.decoder.cond_channels, config.decoder.n_groups, config.decoder.kernel_size, config.decoder.dropout, config.decoder.use_scale_shift))
if r != 0 and r!= len(config.decoder.channel_ratios) - 1:
self.decoder.append(AttentionBlock(ch, config.latent_dim, config.decoder.n_groups, config.decoder.n_heads, config.decoder.dropout))
if r != 0:
ch = config.decoder.model_channels * config.decoder.channel_ratios[r-1]
self.decoder.append(Upsample(config.decoder.model_channels * config.decoder.channel_ratios[r], ch, config.decoder.kernel_size))
# Forward
for module in self.encoder:
if isinstance(module, ResidualBlock):
x = module(x, c_emb)
self.connections.append(x)
elif isinstance(module, AttentionBlock):
x = module(x, cond=c_attn)
else:
x = module(x)
for module in self.bottleneck:
if isinstance(module, ResidualBlock):
x = module(x, c_emb)
elif isinstance(module, AttentionBlock):
x = module(x, cond=c_attn)
else:
x = module(x)
for module in self.decoder:
if isinstance(module, ResidualBlock):
x = torch.cat([x, self.connections.pop()], dim=1)
x = module(x, c_emb)
elif isinstance(module, AttentionBlock):
x = module(x, cond=c_attn)
else:
x = module(x)
The output of the UNet decoder layers is then passed through a GroupNorm and SiLU activation layer before using Conv2d to get the output back to the original number of channels.
# Constructor
self.output = nn.Sequential(
nn.GroupNorm(config.decoder.n_groups, config.decoder.model_channels),
nn.SiLU(),
nn.Conv2d(config.decoder.model_channels, config.img_channels, config.decoder.kernel_size, padding=1)
)
# Forward
x = self.output(x)
Putting everything together, the final code for the Decoder model will look something like this:
class Decoder(nn.Module):
def __init__(self, config):
super().__init__()
# Loading Prior Model
self.prior = DiffusionPrior(config).to(config.device)
self.prior.load_state_dict(torch.load(config.prior.model_location, map_location=config.device))
freeze_model(self.prior)
# MLP to get time embeddings
self.time_mlp = nn.Sequential(
SinusoidalPositionalEmbedding(config.decoder.max_time, config.decoder.model_channels),
nn.Linear(config.decoder.model_channels, config.decoder.cond_channels),
nn.SiLU(),
nn.Linear(config.decoder.cond_channels, config.decoder.cond_channels)
)
# MLP to project CLIP image embeddings
self.img_projection = nn.Sequential(
nn.Linear(config.latent_dim, config.decoder.cond_channels),
nn.SiLU(),
nn.Linear(config.decoder.cond_channels, config.decoder.cond_channels)
)
# Projection to get image tokens
self.get_img_tokens = nn.Linear(1, config.decoder.n_img_tokens)
# Embedding layer for text captions
self.text_embedding = nn.Embedding(config.vocab_size, config.latent_dim)
# Learned positional encodings for text captions
self.positional_encodings = nn.Parameter(torch.randn(config.text_seq_length,config.latent_dim) * (config.latent_dim ** -0.5))
# Transformer encoder blocks to encoder text captions
self.text_encoder = nn.ModuleList(
[TransformerBlock(
config.latent_dim,
cond_width=config.latent_dim,
n_heads=config.decoder.n_heads,
dropout=config.decoder.dropout,
r_mlp=config.decoder.r_mlp,
bias=config.decoder.bias
) for _ in range(config.decoder.text_layers)]
)
self.final_ln = nn.LayerNorm(config.latent_dim)
ch = config.decoder.model_channels
# Initial convolution
self.in_conv = nn.Conv2d(config.img_channels, ch, config.decoder.kernel_size, padding=1)
# UNet Encoder Layers
self.encoder = nn.ModuleList([])
for r in config.decoder.channel_ratios:
for _ in range(config.decoder.n_layer_blocks):
self.encoder.append(ResidualBlock(ch, config.decoder.model_channels * r, config.decoder.cond_channels, config.decoder.n_groups, config.decoder.kernel_size, config.decoder.dropout, config.decoder.use_scale_shift))
ch = config.decoder.model_channels * r
if r != config.decoder.channel_ratios[0] and r != config.decoder.channel_ratios[-1]:
self.encoder.append(AttentionBlock(ch, config.latent_dim, config.decoder.n_groups, config.decoder.n_heads, config.decoder.dropout))
if r != config.decoder.channel_ratios[-1]:
self.encoder.append(Downsample(ch, config.decoder.kernel_size, config.decoder.stride, config.decoder.down_pool))
# UNet Bottleneck Layers
self.bottleneck = nn.ModuleList([])
for block in range(config.decoder.n_layer_blocks):
self.bottleneck.append(ResidualBlock(ch, ch, config.decoder.cond_channels, config.decoder.n_groups, config.decoder.kernel_size, config.decoder.dropout, config.decoder.use_scale_shift))
if block != config.decoder.n_layer_blocks - 1:
self.bottleneck.append(AttentionBlock(ch, config.latent_dim, config.decoder.n_groups, config.decoder.n_heads, config.decoder.dropout))
# UNet Decoder Layers
self.decoder = nn.ModuleList([])
for r in range(len(config.decoder.channel_ratios))[::-1]:
for _ in range(config.decoder.n_layer_blocks):
self.decoder.append(ResidualBlock(ch * 2, ch, config.decoder.cond_channels, config.decoder.n_groups, config.decoder.kernel_size, config.decoder.dropout, config.decoder.use_scale_shift))
if r != 0 and r!= len(config.decoder.channel_ratios) - 1:
self.decoder.append(AttentionBlock(ch, config.latent_dim, config.decoder.n_groups, config.decoder.n_heads, config.decoder.dropout))
if r != 0:
ch = config.decoder.model_channels * config.decoder.channel_ratios[r-1]
self.decoder.append(Upsample(config.decoder.model_channels * config.decoder.channel_ratios[r], ch, config.decoder.kernel_size))
# Output projection
self.output = nn.Sequential(
nn.GroupNorm(config.decoder.n_groups, config.decoder.model_channels),
nn.SiLU(),
nn.Conv2d(config.decoder.model_channels, config.img_channels, config.decoder.kernel_size, padding=1)
)
# Skip connections
self.connections = []
def encode_text(self, text, mask=None):
x = self.text_embedding(text)
x = x + self.positional_encodings
for block in self.text_encoder:
x = block(x, mask=mask)
x = self.final_ln(x)
return x
def forward(self, x, time, caption=None, mask=None):
# Sample prior model to get CLIP image embeddings
img_embeddings = self.prior.sample(caption, mask).to(x.device)
# Get conditioning information for residual blocks
c_emb = self.time_mlp(time) + self.img_projection(img_embeddings)
# Get conditioning information for attention blocks
c_attn = self.get_img_tokens(img_embeddings[..., None]).permute(0, 2, 1)
if caption is not None:
c_attn = torch.cat([self.encode_text(caption, mask), c_attn], dim=1)
# Initial convolution
x = self.in_conv(x)
# UNet encoder layers
for module in self.encoder:
if isinstance(module, ResidualBlock):
x = module(x, c_emb)
self.connections.append(x)
elif isinstance(module, AttentionBlock):
x = module(x, cond=c_attn)
else:
x = module(x)
# UNet bottleneck layers
for module in self.bottleneck:
if isinstance(module, ResidualBlock):
x = module(x, c_emb)
elif isinstance(module, AttentionBlock):
x = module(x, cond=c_attn)
else:
x = module(x)
# UNet decoder layers
for module in self.decoder:
if isinstance(module, ResidualBlock):
x = torch.cat([x, self.connections.pop()], dim=1)
x = module(x, c_emb)
elif isinstance(module, AttentionBlock):
x = module(x, cond=c_attn)
else:
x = module(x)
# Output projection
x = self.output(x)
return x
Decoder: Loss
An important part of training the decoder model is the loss function. The first thing that we are going to do in our loss function is sample random timesteps for each image in the batch.
timesteps = torch.randint(0, config.decoder.max_time, (image.shape[0],), device=config.device, dtype=torch.long)
These timesteps are then used, along with the schedule values, to get the noisy images and noise added through forward diffusion.
schedule_values = get_schedule_values(config)
noisy_image, noise = forward_diffusion(image, schedule_values, timesteps)
The noisy images, timesteps, text captions, and text masks are then passed through the model to get the predicted noises.
pred_noise = decoder(noisy_image, timesteps, caption, mask)
Finally, we can use the predicted noise from the model and the actual noise from the forward diffusion function to calculate the loss. The loss that we are going to be using for the decoder is the mean squared error loss between the predicted noise and actual noise.
loss = nn.functional.mse_loss(pred_noise, noise)
Putting this all together we get:
# Calculating Loss
schedule_values = get_schedule_values(config)
timesteps = torch.randint(0, config.decoder.max_time, (image.shape[0],), device=config.device, dtype=torch.long)
noisy_image, noise = forward_diffusion(image, schedule_values, timesteps)
pred_noise = decoder(noisy_image, timesteps, caption, mask)
loss = nn.functional.mse_loss(pred_noise, noise)
Config
@dataclass
class CLIPConfig:
# Vision Transformer
patch_size:tuple[int,int] = (4,4)
vit_width:int = 256
vit_layers:int = 6
vit_heads:int = 8
# Text Transformer
text_width:int = 256
text_layers:int = 6
text_heads:int = 8
# Attention
dropout:float = 0.2
r_mlp:int = 4
bias:bool = False
# Training
augment_data:bool = True
num_workers:int = 0
batch_size:int = 128
lr:float = 5e-4
lr_min:float = 1e-5
weight_decay:float = 1e-4
epochs:int = 200
warmup_epochs:int = 5
grad_max_norm:float = 1.0
get_val_accuracy:bool = False
model_location:str = "../clip_fmnist.pt"
@dataclass
class PriorConfig:
# Diffusion
max_time:int = 1000
schedule:str = "cosine"
schedule_offset:float = 0.008
# Transformer Decoder
width:int = 256
n_layers:int = 6
n_heads:int = 8
# Attention
dropout:float = 0.2
r_mlp:int = 4
bias:bool = False
# Training
augment_data:bool = False
num_workers:int = 0
batch_size:int = 128
lr:float = 5e-4
lr_min:float = 1e-5
weight_decay:float = 1e-4
epochs:int = 150
warmup_epochs:int = 5
grad_max_norm:float = 1.0
model_location:str = "../prior_fmnist.pt"
@dataclass
class DecoderConfig:
# Diffusion
max_time:int = 1000
schedule:str = "cosine"
# UNet
n_groups:int = 8
kernel_size:tuple[int, int] = (3,3)
model_channels:int = 32
cond_channels:int = 128
channel_ratios:list[int] = field(default_factory=lambda: [1, 2, 4, 8])
n_layer_blocks:int = 2
dropout:float = 0.1
use_scale_shift:bool = True
n_heads:int = 8
stride:int = 2
down_pool:bool = False
r_mlp:int = 4
bias:bool = False
text_layers:int = 4
n_img_tokens:int = 4
# Training
augment_data:bool = False
num_workers:int = 0
batch_size:int = 32
lr:float = 5e-4
lr_min:float = 1e-5
weight_decay:float = 1e-4
epochs:int = 100
warmup_epochs:int = 5
grad_max_norm:float = 1.0
sample_after_epoch:bool = True
model_location:str = "../decoder_fmnist.pt"
@dataclass
class FMNISTConfig:
latent_dim:int = 256
# Dataset Info
dataset:str = "fashion_mnist"
data_location:str = "./../datasets"
img_size:tuple[int,int] = (32,32)
img_channels:int = 1
vocab_size:int = 256
text_seq_length:int = 64
# Data Augmentation / Normalization
prob_hflip:float = 0.5
crop_padding:int = 4
train_mean:list[float] = field(default_factory=lambda: [0.2855552])
train_std:list[float] = field(default_factory=lambda: [0.33848408])
# Training
train_val_split:tuple[int,int] = (50000, 10000)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Model Configs
clip = CLIPConfig()
prior = PriorConfig()
decoder = DecoderConfig()
Training
For this text-to-image diffusion model, we are actually going to have to train 3 separate models: CLIP, diffusion prior, and the decoder. Including all the code of the three training scripts would be a lot, so if you want to see them, you can find them in the GitHub repo that I linked in the introduction.
For the dataset, all of the images were resized from (28,28) to (32,32). The images were also normalized using the mean and standard deviation of the training split.
transform = T.Compose([
T.Resize(config.img_size),
T.ToTensor(),
T.Normalize(config.train_mean, config.train_std)
])
For the CLIP model, I also applied data augmentation to the training split using random horizontal flips and random crops.
transform = T.Compose([
T.Resize(config.img_size),
T.RandomHorizontalFlip(p=config.prob_hflip),
T.RandomCrop(config.img_size[0], padding=config.crop_padding),
T.ToTensor(),
T.Normalize(config.train_mean, config.train_std)
])
When training all three of the models, I used AdamW for the optimizer when using weight decay and Adam otherwise.
if config.clip.weight_decay == 0:
optimizer = Adam(clip.parameters(), lr=config.clip.lr)
else:
optimizer = AdamW(clip.parameters(), lr=config.clip.lr, weight_decay=config.clip.weight_decay)
For all of the models, I also used a cosine annealing learning rate scheduler with a linear warmup. The schedule was updated at the end of every epoch.
if config.clip.warmup_epochs > 0:
warmup = lr_scheduler.LinearLR(optimizer=optimizer, start_factor=(1 / config.clip.warmup_epochs), end_factor=1.0, total_iters=(config.clip.warmup_epochs - 1), last_epoch=-1)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=(config.clip.epochs - config.clip.warmup_epochs), eta_min=config.clip.lr_min)
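One way to chain the warmup and cosine schedules so that a single scheduler can be stepped once per epoch is PyTorch's SequentialLR. This is just a sketch of how they could be combined; the training scripts in the repo may step the two schedulers differently:
# Chain linear warmup followed by cosine annealing (a sketch, not necessarily the repo's exact code)
combined_scheduler = lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup, scheduler],
    milestones=[config.clip.warmup_epochs - 1]
)

# Called once at the end of every epoch
combined_scheduler.step()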
Gradient clipping with a max norm of 1.0 was also used for all of the models to help prevent exploding gradients.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.[model].grad_max_norm)
When training the CLIP model, the training data was randomly split into a training split and a validation split. The model weights were saved when the validation loss for an epoch was lower than or equal to the previous lowest validation loss.
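Below is a minimal sketch of that split-and-checkpoint logic. The full_train_set, train_loader, val_loader, train_one_epoch, and evaluate names are placeholders for illustration, not functions from the repo:
from torch.utils.data import random_split

# Randomly split the training data using the sizes from the config
train_set, val_set = random_split(full_train_set, config.train_val_split)  # full_train_set is a placeholder dataset

best_val_loss = float("inf")
for epoch in range(config.clip.epochs):
    train_loss = train_one_epoch(clip, train_loader, optimizer)  # placeholder training helper
    val_loss = evaluate(clip, val_loader)  # placeholder validation helper
    # Save whenever the validation loss matches or beats the best seen so far
    if val_loss <= best_val_loss:
        best_val_loss = val_loss
        torch.save(clip.state_dict(), config.clip.model_location)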
Sampling
The first step in the sampling process is to randomly generate an image that contains only noise.
B = prompt.shape[0]
# Get completely noisy image
img = torch.randn((B, config.img_channels, config.img_size[0], config.img_size[1]), device=config.device)
After getting the noisy image, we are going to compute the schedule values that we will need in order to calculate x_{t-1} (note that this version of get_schedule_values uses the decoder's schedule settings).
def get_schedule_values(config):
schedule_values = {}
schedule_values["betas"] = get_beta_schedule(config.decoder.schedule, config.decoder.max_time).to(config.device)
schedule_values["alphas"] = 1.0 - schedule_values["betas"]
schedule_values["alpha_bars"] = torch.cumprod(schedule_values["alphas"], axis = 0)
schedule_values["sqrt_recip_alphas"] = torch.sqrt(1.0 / schedule_values["alphas"])
schedule_values["sqrt_alpha_bars"] = torch.sqrt(schedule_values["alpha_bars"])
schedule_values["sqrt_one_minus_alpha_bars"] = torch.sqrt(1.0 - schedule_values["alpha_bars"])
schedule_values["alpha_bars_prev"] = torch.cat((torch.ones(1, device=config.device), schedule_values["alpha_bars"][:-1]))
schedule_values["sigma"] = schedule_values["betas"] * (1.0 - schedule_values["alpha_bars_prev"]) / (1.0 - schedule_values["alpha_bars"])
return schedule_values
schedule_values = get_schedule_values(config)
We then need to iteratively go through all of the timesteps, from 0 to max_time-1, in reverse order. We also need to expand the timestep so that there is a timestep for every item in the batch.
for t in range(0, config.decoder.max_time)[::-1]:
timesteps = torch.full((B,), t, device=config.device, dtype=torch.long)
For sampling at each timestep, the first step is to get the schedule values for only that timestep. After getting the timestep’s schedule values, they will need to be expanded so that the number of dimensions is equal to that of the images.
# Getting schedule values for timestep
sqrt_recip_alphas_t = extract_and_expand(schedule_values["sqrt_recip_alphas"], timesteps, img.shape)
betas_t = extract_and_expand(schedule_values["betas"], timesteps, img.shape)
sqrt_one_minus_alpha_bars_t = extract_and_expand(schedule_values["sqrt_one_minus_alpha_bars"], timesteps, img.shape)
sigma_t = extract_and_expand(schedule_values["sigma"], timesteps, img.shape)
Afterwards, we use the decoder model to predict the noise at timestep t.
# Predicting noise at timestep t with decoder
pred_noise = decoder(img, timesteps, caption=prompt, mask=mask)
If it is not the final timestep, we also need to generate random noise (z). We need this and sigma_t because our model predicts the noise from x_t to x_0 directly instead of from x_t to x_{t-1}. Because of this, the model predicts the noise from x_t to x_0, and the noise from x_{t-1} to x_0 is added back in, which makes it so that only the noise from x_t to x_{t-1} is removed. The sigma_t and z values are not needed at the final step because the predicted noise from x_t to x_0 is the same as the noise from x_t to x_{t-1}.
# Generating random noise
z = torch.randn_like(img) if t > 0 else 0
With the generated noise, schedule values, and predicted noise, we can calculate the image at timestep t-1.
# Calculating image at timestep t-1
img = sqrt_recip_alphas_t * (img - (betas_t / sqrt_one_minus_alpha_bars_t) * pred_noise) + (sigma_t * z)
img = torch.clamp(img, -1.0, 1.0)
Putting this together, the code for sampling an image will look something like this:
@torch.no_grad()
def sample_image(config, prompt, mask, schedule_values=None):
# Load decoder model
decoder = Decoder(config).to(config.device)
decoder.load_state_dict(torch.load(config.decoder.model_location, map_location=config.device))
decoder.eval()
B = prompt.shape[0]
# Get completely noisy image
img = torch.randn((B, config.img_channels, config.img_size[0], config.img_size[1]), device=config.device)
# Calculate schedule values
if schedule_values is None:
schedule_values = get_schedule_values(config)
for t in range(0, config.decoder.max_time)[::-1]:
# Setting the timesteps for all the items in the batch
timesteps = torch.full((B,), t, device=config.device, dtype=torch.long)
# Getting schedule values for timestep
sqrt_recip_alphas_t = extract_and_expand(schedule_values["sqrt_recip_alphas"], timesteps, img.shape)
betas_t = extract_and_expand(schedule_values["betas"], timesteps, img.shape)
sqrt_one_minus_alpha_bars_t = extract_and_expand(schedule_values["sqrt_one_minus_alpha_bars"], timesteps, img.shape)
sigma_t = extract_and_expand(schedule_values["sigma"], timesteps, img.shape)
# Predicting noise at timestep t with decoder
pred_noise = decoder(img, timesteps, caption=prompt, mask=mask)
# Generating random noise
z = torch.randn_like(img) if t > 0 else 0
# Calculating image at timestep t-1
img = sqrt_recip_alphas_t * (img - (betas_t / sqrt_one_minus_alpha_bars_t) * pred_noise) + (sigma_t * z)
img = torch.clamp(img, -1.0, 1.0)
return img
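As a usage sketch, sampling a batch of one image for a single caption might look like this (using the tokenizer helper from the CLIP article, which also appears in the Results section below):
config = FMNISTConfig()
tokens = tokenizer("An image of a sneaker", text_seq_length=config.text_seq_length)
caption = tokens[0][None].to(config.device)  # (1, text_seq_length)
mask = tokens[1][None].to(config.device)
img = sample_image(config, caption, mask)  # (1, 1, 32, 32)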
Results
To view the results of our model, I am going to display the reverse diffusion process for each of the captions. In order to view this, I created a modified version of the sample_image function that plots the images at ten timesteps during the reverse diffusion process.
# Displaying Results
config = FMNISTConfig()
captions = {
0: "An image of a t-shirt/top",
1: "An image of trousers",
2: "An image of a pullover",
3: "An image of a dress",
4: "An image of a coat",
5: "An image of a sandal",
6: "An image of a shirt",
7: "An image of a sneaker",
8: "An image of a bag",
9: "An image of an ankle boot"
}
sample_captions = torch.stack([tokenizer(x, text_seq_length=config.text_seq_length)[0] for x in captions.values()]).to(config.device)
sample_masks = torch.stack([tokenizer(x, text_seq_length=config.text_seq_length)[1] for x in captions.values()]).to(config.device)
for i in range(len(sample_captions)):
caption = sample_captions[None, (i % len(sample_captions))]
mask = sample_masks[None, (i % len(sample_masks))]
test = sample_plot_image(config, caption, mask)
Possible Ways to Improve Models
- Implementing classifier-free guidance
- Learn sigma and use hybrid loss objective in order to improve log-likelihoods.
- Increasing the number of tokens for the sequence in the prior model that is passed through the decoder-only Transformer.
- Learned padding for text captions
- K-Fold cross validation to help find optimal parameters (training time would significantly increase).
Relevant Papers
- Ho et al. 2020, DDPM Paper: https://arxiv.org/pdf/2006.11239
- Nichol et al. 2021, Improved DDPM: https://arxiv.org/pdf/2102.09672
- Nichol et al. 2022, GLIDE: https://arxiv.org/pdf/2112.10741
- Ramesh et al. 2022, unCLIP (DALL-E 2): https://arxiv.org/pdf/2204.06125
- Han et al. 2017, Pyramidal Residual Networks: https://arxiv.org/pdf/1610.02915
- Ramachandran et al. 2017, Activation Functions: https://arxiv.org/pdf/1710.05941
- Dhariwal et al. 2021, Diffusion Beats GANs: https://arxiv.org/pdf/2105.05233