# Introduction

Denoising Diffusion Probabilistic Models (DDPMs) are deep generative models that have recently attracted a lot of attention thanks to their impressive performance. Brand new models like OpenAI’s DALL-E 2 and Google’s Imagen are based on DDPMs. They condition the generator on text, so that it becomes possible to generate photo-realistic images from an arbitrary string of text.

For example, inputting “A photo of a Shiba Inu dog with a backpack riding a bike. It is wearing sunglasses and a beach hat” to the new Imagen model and “a corgi’s head depicted as an explosion of a nebula” to the DALL-E 2 model produces the following images:

These models are simply mind-blowing, but understanding how they work requires understanding the original work of Ho et al., “Denoising Diffusion Probabilistic Models”.

In this brief post, I will focus on creating a simple version of DDPM from scratch in PyTorch. In particular, I will re-implement the original paper by Ho et al. We will work with the classical and non-resource-hungry MNIST and Fashion-MNIST datasets, and try to generate images out of thin air. Let’s start with a little bit of theory.

# Denoising Diffusion Probabilistic Models

Denoising Diffusion Probabilistic Models (DDPMs) first appeared in this paper.

The idea is quite simple: given a dataset of images, we add a little bit of noise step by step. With each step, the image becomes less and less clear, until all that is left is noise. This is called the “forward process”. Then, we train a model that can undo each of these steps, and we call it the “backward process”. If we can successfully learn a backward process, we have a model that can generate images from pure random noise.

A step in the forward process makes the input image noisier: the image x at step t is sampled from a multivariate Gaussian distribution whose mean is a scaled-down version of the previous image (x at step t-1) and whose covariance matrix is diagonal and fixed. In other words, we perturb each pixel in the image independently by adding some normally distributed value. Following Ho et al., with $\beta_t$ denoting the noise coefficient of step t:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)$$

For each step, there is a different coefficient beta that tells how much we are distorting the image in that step. The higher beta is, the more noise is added to the image. We are free to pick the beta coefficients, but we should avoid steps where too much noise is added all at once, so that the overall forward process is “smooth”. In the original work by Ho et al., the betas are linearly spaced from 0.0001 to 0.02.
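To get a feeling for these numbers, here is a minimal sketch of the linear schedule in plain Python. The function names are mine and not part of the post's code; the second helper computes the running product of (1 - beta), which measures how much of the original image's variance survives after each step.

```python
def linear_beta_schedule(n_steps, min_beta=1e-4, max_beta=0.02):
    """Evenly spaced betas from min_beta to max_beta (both included)."""
    if n_steps == 1:
        return [min_beta]
    step = (max_beta - min_beta) / (n_steps - 1)
    return [min_beta + i * step for i in range(n_steps)]


def cumulative_alphas(betas):
    """Running product of (1 - beta) over the steps."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


betas = linear_beta_schedule(1000)
alpha_bars = cumulative_alphas(betas)
print(f"first beta: {betas[0]:.4f}, last beta: {betas[-1]:.4f}")
print(f"share of the original signal variance left after 1000 steps: {alpha_bars[-1]:.2e}")
```

With 1000 steps the running product ends up very close to zero, which is exactly the "all that is left is noise" situation described above.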

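A nice property of a Gaussian distribution is that we can sample from it by adding to the mean vector a normally distributed noise vector scaled by the standard deviation. This results in:

$$x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$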

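We now know how to get the next sample in the forward process by just scaling what we already have and adding some scaled noise. If we now consider that the formula is recursive, we can write:

$$x_t = \sqrt{1-\beta_t}\left(\sqrt{1-\beta_{t-1}}\,x_{t-2} + \sqrt{\beta_{t-1}}\,\epsilon_{t-1}\right) + \sqrt{\beta_t}\,\epsilon_t$$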

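If we keep doing this and simplify (a sum of independent Gaussian noise terms is again Gaussian, with variance equal to the sum of the individual variances), we can go all the way back and obtain a formula for getting the noisy sample at step t starting from the original non-noisy image x0:

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \alpha_t = 1 - \beta_t, \quad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$$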

Great. Now, no matter how many steps our forward process has, we always have a way to get the noisy image at step t directly from the original image.
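As a sanity check, we can verify numerically that applying the single-step rule over and over gives the same result as jumping directly to step t. The sketch below is plain Python and uses my own variable names: it tracks the coefficient multiplying the original image and the variance of the accumulated noise, and compares them with the square root of the running product of (1 - beta) and with one minus that product.

```python
import math

n_steps = 1000
betas = [1e-4 + i * (0.02 - 1e-4) / (n_steps - 1) for i in range(n_steps)]

signal = 1.0      # coefficient multiplying x0 after the steps applied so far
noise_var = 0.0   # variance of the accumulated Gaussian noise
alpha_bar = 1.0   # running product of (1 - beta)
max_gap = 0.0

for beta in betas:
    alpha = 1.0 - beta
    # One forward step: scale what we have by sqrt(alpha), add noise of variance beta
    signal = math.sqrt(alpha) * signal
    noise_var = alpha * noise_var + beta
    # Closed form: signal = sqrt(alpha_bar), noise variance = 1 - alpha_bar
    alpha_bar *= alpha
    max_gap = max(
        max_gap,
        abs(signal - math.sqrt(alpha_bar)),
        abs(noise_var - (1.0 - alpha_bar)),
    )

print(f"largest difference between step-by-step and closed form: {max_gap:.1e}")
```

The two computations agree up to floating-point rounding at every step.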

For the backward process, we know our model should also work as a Gaussian distribution, so we would just need the model to predict the distribution mean and standard deviation given the noisy image and time step. In practice, in this first paper on DDPMs the covariance matrix is kept fixed, so we only really want to predict the mean of the Gaussian (given the noisy image and the time step we are at currently):

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2 I\right)$$

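Now, it turns out that the optimal mean value to be predicted is just a function of terms that we are already familiar with. Writing $\alpha_t = 1 - \beta_t$, $\bar{\alpha}_t$ for the product of $\alpha_1$ through $\alpha_t$, and $\epsilon$ for the noise that turned x0 into the noisy image at step t:

$$\tilde{\mu}_t(x_t, \epsilon) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)$$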

So, we can further simplify our model and just predict the noise epsilon with a function $\epsilon_\theta(x_t, t)$ of the noisy image and the time step.

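Our loss function is then just a scaled version of the Mean Squared Error (MSE) between the real noise that was added and the one predicted by our model. Ho et al. report that dropping the time-dependent scaling factor works well in practice, which leaves the plain MSE used in this post:

$$L = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right)\right\rVert^2\right]$$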

Once the model is trained (Algorithm 1 in the paper), we can use the denoising model to sample new images (Algorithm 2 in the paper).
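To see the sampling procedure in isolation, here is a toy one-dimensional sketch in plain Python. It is not the implementation used later in the post: the "network" is replaced by a hypothetical oracle that returns the exact noise for a dataset in which every sample equals the same number `target`. The function name and its arguments are my own.

```python
import math
import random


def sample_with_oracle(target, n_steps=200, seed=0):
    """Run the sampling loop in 1-D with a perfect noise predictor."""
    rng = random.Random(seed)
    betas = [1e-4 + i * (0.02 - 1e-4) / (n_steps - 1) for i in range(n_steps)]
    alphas = [1.0 - b for b in betas]
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)

    x = rng.gauss(0.0, 1.0)  # start from pure noise
    for t in reversed(range(n_steps)):
        # Oracle standing in for the trained network: when all data equals
        # `target`, the noise contained in x at step t is known exactly.
        eps = (x - math.sqrt(alpha_bars[t]) * target) / math.sqrt(1.0 - alpha_bars[t])
        # Denoising step: remove the predicted noise and rescale
        x = (x - (1.0 - alphas[t]) / math.sqrt(1.0 - alpha_bars[t]) * eps) / math.sqrt(alphas[t])
        if t > 0:
            # Fresh noise is added at every step except the last one
            x += math.sqrt(betas[t]) * rng.gauss(0.0, 1.0)
    return x


print(sample_with_oracle(0.7))
```

With a perfect noise predictor the last step lands on the data point itself, up to rounding. A trained network only approximates this oracle, which is why real samples are merely close to the data distribution.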

# Let’s get coding

Now that we have a rough understanding of how diffusion models work, it’s time to implement something of our own. You can run the following code yourself in this Google Colab Notebook or with this GitHub repository.

As usual, imports are trivially our first step.

```python
# Import of libraries
import random
import imageio
import numpy as np
from argparse import ArgumentParser

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import einops
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST, FashionMNIST

# Setting reproducibility
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Definitions
STORE_PATH_MNIST = "ddpm_model_mnist.pt"
STORE_PATH_FASHION = "ddpm_model_fashion.pt"
```

Next, we define a few parameters for our experiment. In particular, we decide whether to run the training loop, whether to use the Fashion-MNIST dataset, and a few training hyper-parameters.

```python
no_train = False
fashion = True
batch_size = 128
n_epochs = 20
lr = 0.001
store_path = "ddpm_fashion.pt" if fashion else "ddpm_mnist.pt"
```

Next, we would really like to display images. Both the training images and those generated by the model are of interest to us. We write a utility function that given some images, will display a square (or as close as it gets) grid of sub-figures:

```python
def show_images(images, title=""):
    """Shows the provided images as sub-pictures in a square"""

    # Converting images to CPU numpy arrays
    if type(images) is torch.Tensor:
        images = images.detach().cpu().numpy()

    # Defining number of rows and columns
    fig = plt.figure(figsize=(8, 8))
    rows = int(len(images) ** (1 / 2))
    cols = round(len(images) / rows)

    # Populating figure with sub-plots
    idx = 0
    for r in range(rows):
        for c in range(cols):
            fig.add_subplot(rows, cols, idx + 1)

            if idx < len(images):
                plt.imshow(images[idx][0], cmap="gray")
                idx += 1
    fig.suptitle(title, fontsize=30)

    # Showing the figure
    plt.show()
```
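The "square (or as close as it gets)" part comes down to two lines of arithmetic. The small sketch below isolates them in a helper of my own naming, `grid_shape`, so we can see which grid is chosen for a few batch sizes.

```python
def grid_shape(n_images):
    """Rows and columns picked by the grid logic in show_images."""
    rows = int(n_images ** 0.5)
    cols = round(n_images / rows)
    return rows, cols


for n in (16, 100, 128, 10):
    rows, cols = grid_shape(n)
    print(f"{n} images -> {rows} x {cols} grid ({rows * cols} slots)")
```

For perfect squares the grid is exact. For other sizes the number of slots can be slightly larger than the batch (128 images) or slightly smaller (10 images), in which case the last images are simply not drawn.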

To test this utility function, we load our dataset and show the first batch. Important: Images must be normalized in the range [-1, 1], as our network will have to predict noise values that are normally distributed:

```python
# Shows the first batch of images
def show_first_batch(loader):
    for batch in loader:
        show_images(batch[0], "Images in the first batch")
        break


# Loading the data (converting each image into a tensor and normalizing between [-1, 1])
transform = Compose([
    ToTensor(),
    Lambda(lambda x: (x - 0.5) * 2)
])
ds_fn = FashionMNIST if fashion else MNIST
dataset = ds_fn("./datasets", download=True, train=True, transform=transform)
loader = DataLoader(dataset, batch_size, shuffle=True)

# Optionally, show a batch of regular images
show_first_batch(loader)
```
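The normalization is a simple affine map: `ToTensor` yields pixel values in [0, 1], and the `Lambda` stretches them to [-1, 1]. The plain-Python sketch below spells out the map and its inverse; the two function names are mine, and the inverse is only an illustration of how one could map generated values back to the pixel range.

```python
def to_model_range(pixel):
    """Map a pixel value from [0, 1] to [-1, 1], as the Lambda transform does."""
    return (pixel - 0.5) * 2


def to_pixel_range(value):
    """Inverse map, from [-1, 1] back to [0, 1]."""
    return value / 2 + 0.5


for pixel in (0.0, 0.25, 0.5, 1.0):
    print(pixel, "->", to_model_range(pixel))
```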

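Great! Now that we have this nice utility function, we will use it for images generated by our model later on as well. Before we start actually dealing with the DDPM model, we pick the device to run on: a GPU if one is available (on Colab, typically a Tesla T4 for non-Pro users), otherwise the CPU. The rest of the post assumes it is stored in a variable named `device`, for example `device = torch.device("cuda" if torch.cuda.is_available() else "cpu")`.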

# The DDPM Model

Now that we got the trivial stuff out of the way, it’s time to work on the DDPM. We will create a MyDDPM PyTorch module that will be responsible for storing the beta and alpha values and applying the forward process. For the backward process, the MyDDPM module simply relies on the network passed to its constructor:

```python
# DDPM class
class MyDDPM(nn.Module):
    def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None, image_chw=(1, 28, 28)):
        super(MyDDPM, self).__init__()
        self.n_steps = n_steps
        self.device = device
        self.image_chw = image_chw
        self.network = network.to(device)
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(
            device)  # Number of steps is typically in the order of thousands
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.tensor([torch.prod(self.alphas[:i + 1]) for i in range(len(self.alphas))]).to(device)

    def forward(self, x0, t, eta=None):
        # Make input image more noisy (we can directly skip to the desired step)
        n, c, h, w = x0.shape
        a_bar = self.alpha_bars[t]

        if eta is None:
            eta = torch.randn(n, c, h, w).to(self.device)

        noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        return noisy

    def backward(self, x, t):
        # Run each image through the network for each timestep t in the vector t.
        # The network returns its estimation of the noise that was added.
        return self.network(x, t)
```

Note that the forward process is independent of the network used to denoise, so technically we could already visualize its effect. At the same time, we can also create a utility function that applies Algorithm 2 (sampling procedure) to generate new images. We do so with two DDPM-specific utility functions:

```python
def show_forward(ddpm, loader, device):
    # Showing the forward process
    for batch in loader:
        imgs = batch[0]

        show_images(imgs, "Original images")

        for percent in [0.25, 0.5, 0.75, 1]:
            show_images(
                ddpm(imgs.to(device),
                     [int(percent * ddpm.n_steps) - 1 for _ in range(len(imgs))]),
                f"DDPM Noisy images {int(percent * 100)}%"
            )
        break
```

To generate images, we start with random noise and let t go from T back to 0. At each step, we estimate the noise as eta_theta and apply the denoising function. Finally, extra noise is added as in Langevin dynamics.

```python
def generate_new_images(ddpm, n_samples=16, device=None, frames_per_gif=100, gif_name="sampling.gif", c=1, h=28, w=28):
    """Given a DDPM model, a number of samples to be generated and a device, returns some newly generated samples"""
    frame_idxs = np.linspace(0, ddpm.n_steps, frames_per_gif).astype(np.uint)
    frames = []

    with torch.no_grad():
        if device is None:
            device = ddpm.device

        # Starting from random noise
        x = torch.randn(n_samples, c, h, w).to(device)

        for idx, t in enumerate(list(range(ddpm.n_steps))[::-1]):
            # Estimating noise to be removed
            time_tensor = (torch.ones(n_samples, 1) * t).to(device).long()
            eta_theta = ddpm.backward(x, time_tensor)

            alpha_t = ddpm.alphas[t]
            alpha_t_bar = ddpm.alpha_bars[t]

            # Partially denoising the image
            x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)

            if t > 0:
                z = torch.randn(n_samples, c, h, w).to(device)

                # Option 1: sigma_t squared = beta_t
                beta_t = ddpm.betas[t]
                sigma_t = beta_t.sqrt()

                # Option 2: sigma_t squared = beta_tilda_t
                # prev_alpha_t_bar = ddpm.alpha_bars[t-1] if t > 0 else ddpm.alphas[0]
                # beta_tilda_t = ((1 - prev_alpha_t_bar)/(1 - alpha_t_bar)) * beta_t
                # sigma_t = beta_tilda_t.sqrt()

                # Adding some more noise like in Langevin Dynamics fashion
                x = x + sigma_t * z

            # Adding frames to the GIF
            if idx in frame_idxs or t == 0:
                # Putting digits in range [0, 255]
                normalized = x.clone()
                for i in range(len(normalized)):
                    normalized[i] -= torch.min(normalized[i])
                    normalized[i] *= 255 / torch.max(normalized[i])

                # Reshaping batch (n, c, h, w) to be a (as much as it gets) square frame
                frame = einops.rearrange(normalized, "(b1 b2) c h w -> (b1 h) (b2 w) c", b1=int(n_samples ** 0.5))
                frame = frame.cpu().numpy().astype(np.uint8)

                # Rendering frame
                frames.append(frame)

    # Storing the gif
    with imageio.get_writer(gif_name, mode="I") as writer:
        for idx, frame in enumerate(frames):
            writer.append_data(frame)
            if idx == len(frames) - 1:
                for _ in range(frames_per_gif // 3):
                    writer.append_data(frames[-1])
    return x
```
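The two options for sigma in the comments of the sampling code are closer to each other than they may look. The plain-Python sketch below (my own helper name, same linear schedule with 1000 steps) computes the ratio between the Option 2 variance and the Option 1 variance for every step at which noise is added.

```python
def variance_ratios(n_steps=1000, min_beta=1e-4, max_beta=0.02):
    """Ratio beta_tilde_t / beta_t for t = 1 .. n_steps - 1."""
    betas = [min_beta + i * (max_beta - min_beta) / (n_steps - 1) for i in range(n_steps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    # beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t, so the
    # ratio to beta_t is just the first factor. Step t = 0 is skipped because
    # no noise is added at the very last sampling step.
    return [(1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t]) for t in range(1, n_steps)]


ratios = variance_ratios()
print(f"ratio at t=1: {ratios[0]:.3f}, ratio at the noisiest step: {ratios[-1]:.6f}")
```

Option 2 never adds more noise than Option 1. The two are almost identical for most of the trajectory and only differ noticeably in the last few denoising steps.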

Everything that concerns DDPM is on the table now. We simply need to define the model that will actually do the job of predicting the noise in the image given the image and the current time step. To do that, we will create a custom U-Net model. It goes without saying that you are free to use any other model of your choice.

# The U-Net

We start the creation of our U-Net by creating a block that will keep spatial dimensionality unchanged. This block will be used at every level of our U-Net.

```python
class MyBlock(nn.Module):
    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super(MyBlock, self).__init__()
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
        self.activation = nn.SiLU() if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        out = self.ln(x) if self.normalize else x
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.activation(out)
        return out
```

The tricky thing in DDPMs is that our image-to-image model has to be conditioned on the current time step. To do so in practice, we use sinusoidal embeddings and small MLPs with one hidden layer. The resulting tensors are added channel-wise to the input of each level of the U-Net.

```python
def sinusoidal_embedding(n, d):
    # Returns the standard positional embedding
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    embedding[:,::2] = torch.sin(t * wk[:,::2])
    embedding[:,1::2] = torch.cos(t * wk[:,::2])

    return embedding
```
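To make the embedding less abstract, here is a plain-Python sketch that builds the row for a single time step. It mirrors the indexing of the tensor version in this post (sine and cosine interleaved, sharing the even-indexed frequencies) and assumes an even dimension `d`; the function name is mine.

```python
import math


def sinusoidal_row(t, d):
    """Embedding of one time step t, following the indexing used in the post."""
    # Same frequencies as wk[:, ::2] in the tensor version
    freqs = [1 / 10_000 ** (2 * j / d) for j in range(0, d, 2)]
    row = []
    for w in freqs:
        # Even positions hold the sine, odd positions the cosine
        row.extend([math.sin(t * w), math.cos(t * w)])
    return row


print(sinusoidal_row(0, 8))
print([round(v, 3) for v in sinusoidal_row(5, 8)])
```

Every time step gets a distinct vector with entries in [-1, 1], and nearby steps get similar vectors, which is what lets the small MLPs turn the step index into a smooth conditioning signal.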

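We create a small utility function that builds an MLP with one hidden layer (two linear layers with a SiLU activation in between), which will be used to map the positional embeddings. It is written as a method of the U-Net class, hence the `self` argument.

```python
def _make_te(self, dim_in, dim_out):
    return nn.Sequential(
        nn.Linear(dim_in, dim_out),
        nn.SiLU(),
        nn.Linear(dim_out, dim_out)
    )
```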

Now that we know how to deal with the time information, we can create a custom U-Net network. We will have 3 down-sample parts, a bottleneck in the middle of the network, and 3 up-sample steps with the usual U-Net skip connections (concatenations).

```python
class MyUNet(nn.Module):
    def __init__(self, n_steps=1000, time_emb_dim=100):
        super(MyUNet, self).__init__()

        # Sinusoidal embedding
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            MyBlock((1, 28, 28), 1, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10)
        )
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            MyBlock((10, 14, 14), 10, 20),
            MyBlock((20, 14, 14), 20, 20),
            MyBlock((20, 14, 14), 20, 20)
        )
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            MyBlock((20, 7, 7), 20, 40),
            MyBlock((40, 7, 7), 40, 40),
            MyBlock((40, 7, 7), 40, 40)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            MyBlock((40, 3, 3), 40, 20),
            MyBlock((20, 3, 3), 20, 20),
            MyBlock((20, 3, 3), 20, 40)
        )

        # Second half
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1)
        )

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            MyBlock((80, 7, 7), 80, 40),
            MyBlock((40, 7, 7), 40, 20),
            MyBlock((20, 7, 7), 20, 20)
        )

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            MyBlock((40, 14, 14), 40, 20),
            MyBlock((20, 14, 14), 20, 10),
            MyBlock((10, 14, 14), 10, 10)
        )

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            MyBlock((20, 28, 28), 20, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10, normalize=False)
        )

        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
        # x is (N, 1, 28, 28); the time embedding is added channel-wise at every level
        t = self.time_embed(t)
        n = len(x)
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))  # (N, 20, 14, 14)
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1))  # (N, 40, 7, 7)

        out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))  # (N, 40, 3, 3)

        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1))  # (N, 20, 7, 7)

        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))  # (N, 10, 14, 14)

        out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)

        out = self.conv_out(out)  # (N, 1, 28, 28)

        return out

    def _make_te(self, dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )
```
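The spatial sizes in the comments (28, 14, 7, 3 and back) follow from the usual output-size formulas for convolutions and transposed convolutions, assuming no dilation and no output padding. The sketch below, with helper names of my own, replays the kernel, stride, and padding values of the down-sampling and up-sampling layers in plain Python.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# Down-sampling path
s1 = conv_out(28, 4, 2, 1)                        # down1
s2 = conv_out(s1, 4, 2, 1)                        # down2
s3 = conv_out(conv_out(s2, 2, 1), 4, 2, 1)        # down3 (two convolutions)

# Up-sampling path
u1 = conv_transpose_out(conv_transpose_out(s3, 4, 2, 1), 2, 1)  # up1 (two layers)
u2 = conv_transpose_out(u1, 4, 2, 1)              # up2
u3 = conv_transpose_out(u2, 4, 2, 1)              # up3

print("down:", 28, s1, s2, s3)
print("up:  ", s3, u1, u2, u3)
```

The extra kernel-size-2 layers in `down3` and `up1` are what make the odd 7x7 resolution map cleanly to 3x3 and back, so that the concatenations with `out3`, `out2`, and `out1` line up.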

Now that we defined our denoising network, we can proceed to instantiate a DDPM model and play with some visualizations.

# Some visualizations

We instantiate the DDPM model using our custom U-Net as follows.

```python
# Defining model
n_steps, min_beta, max_beta = 1000, 10 ** -4, 0.02  # Originally used by the authors
ddpm = MyDDPM(MyUNet(n_steps), n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)
```

Let’s check what the forward process looks like:

```python
# Optionally, show the diffusion (forward) process
show_forward(ddpm, loader, device)
```

We haven't trained the model yet, but we can already call the function that generates new images, for example `show_images(generate_new_images(ddpm), "Images generated before training")`, and see what happens.

Not surprisingly, the untrained network only produces noise. However, we will re-use this same method later on, once the model has finished training.

# Training loop

We now implement Algorithm 1 to learn a model that will know how to denoise images. This corresponds to our training loop.

```python
def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):
    mse = nn.MSELoss()
    best_loss = float("inf")
    n_steps = ddpm.n_steps

    for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):
        epoch_loss = 0.0
        for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):
            # Loading data
            x0 = batch[0].to(device)
            n = len(x0)

            # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
            eta = torch.randn_like(x0).to(device)
            t = torch.randint(0, n_steps, (n,)).to(device)

            # Computing the noisy image based on x0 and the time-step (forward process)
            noisy_imgs = ddpm(x0, t, eta)

            # Getting model estimation of noise based on the images and the time-step
            eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1))

            # Optimizing the MSE between the noise plugged and the predicted noise
            loss = mse(eta_theta, eta)
            optim.zero_grad()
            loss.backward()
            optim.step()

            epoch_loss += loss.item() * len(x0) / len(loader.dataset)

        # Display images generated at this epoch
        if display:
            show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}")

        log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"

        # Storing the model
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)


# Training (skipped when no_train is set), using the hyper-parameters defined earlier
if not no_train:
    training_loop(ddpm, loader, n_epochs, optim=Adam(ddpm.parameters(), lr), device=device, store_path=store_path)
```

As you can see, in our training loop we simply sample some images and some random time steps for each of them. We then make them noisy with the forward process and run the backward process on those noisy images. The MSE between the actual noise added and the one predicted by the model is optimized.

By default, I set the number of training epochs to 20, as each epoch takes about 24 seconds (roughly 8 minutes of training in total). Note that it is possible to obtain even better performance with more epochs, a better U-Net, and other tricks. In this post, I omit those for simplicity.

# Testing the model

Now that the job is done, we can simply enjoy the results. We load the best model obtained during training according to the MSE loss function, set it to evaluation mode, and use it to generate new samples.

```python
# Loading the trained model
# (the U-Net is built with the same n_steps used for training)
best_model = MyDDPM(MyUNet(n_steps), n_steps=n_steps, device=device)
best_model.load_state_dict(torch.load(store_path, map_location=device))
best_model.eval()
print("Model loaded")

print("Generating new images")
generated = generate_new_images(
    best_model,
    n_samples=100,
    device=device,
    gif_name="fashion.gif" if fashion else "mnist.gif"
)
show_images(generated, "Final result")
```

The cherry on the cake is the fact that our generation function automatically creates a nice GIF of the diffusion process. We can display that GIF in Colab, for example by passing the file's bytes to `IPython.display.Image`.

And we are done! We finally have our DDPM model working!

# Further Improvements

Further improvements have since been made to allow the generation of higher-resolution images, accelerate sampling, and obtain better sample quality and likelihood. Imagen and DALL-E 2 are based on improved versions of the original DDPMs.

# More references

For more references regarding DDPMs, I strongly recommend reading the outstanding post by Lilian Weng, as well as Niels Rogge and Kashif Rasul’s amazing Hugging Face blog post. Other authors are also mentioned at the end of the Colab Notebook.

# Conclusion

Diffusion models are generative models that learn to denoise images iteratively. Starting from some noise, it is then possible to ask the model to de-noise the sample until a realistic image is obtained.

We created a DDPM from scratch in PyTorch and had it learn to de-noise MNIST / Fashion-MNIST images. The model, after training, was finally capable of generating new images out of random noise. Quite magical, right?

The Colab Notebook with the shown implementation is freely accessible at this link, while the GitHub repository contains .py files. If you found this story useful, consider giving it a clap 👏. If you feel like something is unclear, don’t hesitate to contact me directly! I’d be glad to discuss it with you.


BSc in CS, MSc in AI. Currently Ph.D. Student in Machine Learning & Computer Vision. Passionate about AI and how to use it against climate change.