Generating images with a GAN trained on the MNIST dataset

What is a GAN?

GAN, short for Generative Adversarial Network, is a powerful and innovative class of machine learning models used in the field of deep learning. GANs were introduced by Ian Goodfellow and his colleagues in 2014. They are designed to generate data, such as images, music, or text, that is indistinguishable from real data, and they are particularly renowned for their ability to create highly realistic and diverse synthetic content.

A GAN consists of two neural networks, the Generator and the Discriminator, which are trained simultaneously through a competitive process:

  1. Generator: The Generator network takes random noise or some other form of input and attempts to generate data that resembles the real data it’s been trained on. Over time, through the learning process, it becomes increasingly proficient at creating convincing fake data.
  2. Discriminator: The Discriminator network, on the other hand, is tasked with distinguishing between real data and data produced by the Generator. It learns to differentiate genuine data from the synthetic data generated by the Generator.

The key concept in GANs is adversarial training. The Generator tries to improve its ability to produce realistic data, while the Discriminator strives to get better at distinguishing real from fake data. This adversarial process continues iteratively until the Generator generates data that is difficult for the Discriminator to distinguish from real data. At this point, the Generator has effectively learned to generate highly convincing and realistic synthetic data.
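Formally, this is the minimax game introduced by Goodfellow et al. (2014), where the Discriminator D maximizes and the Generator G minimizes the value function:

min_G max_D V(D, G) = E_x~p_data [log D(x)] + E_z~p_z [log(1 - D(G(z)))]

In words: D is rewarded for assigning high probability to real samples and low probability to generated ones, while G is rewarded for making D misclassify its output.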

Block diagram of a GAN

GANs have a wide range of applications, including image generation, style transfer, data augmentation, super-resolution, and even drug discovery. They have revolutionized the field of generative modeling and have contributed significantly to the development of artificial intelligence techniques for creating and enhancing data. However, GANs also present challenges such as mode collapse and training instability, which researchers continue to address to make them even more effective and reliable.

DALL·E is perhaps the most famous name that comes to mind when we hear about generative image models (strictly speaking, DALL·E is transformer-based rather than a GAN, but it popularized the idea of generating images from scratch).

Generative models like DALL·E work on high-resolution images. The best way to start working with GANs is a project on low-resolution images. In this article, we generate images by training a GAN on the MNIST dataset.

What is the MNIST dataset?

The MNIST database of handwritten digits, available from Yann LeCun's website, has a training set of 60,000 examples and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image.

It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal effort on preprocessing and formatting.

MNIST dataset sample images

So how do we get started with our problem?

As mentioned above, we train the generator-discriminator pair using the MNIST dataset and generate new images from it.

First, we load the necessary dependencies.

We import PyTorch, a popular deep learning framework; torchvision helps us work with images. We use the MNIST dataset version provided by torchvision, which is commonly used for sample problems and is easy to load into Python.

# Import dependencies

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

Then, we set the random seeds:

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

Setting random seeds in deep learning projects is a good practice to ensure reproducibility of results. By fixing the random seed, you can make sure that the random initialization of model weights, data shuffling, and other random processes are consistent across different runs of your code.
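If you also want reproducible results on the GPU, a couple of extra settings help (a hedged aside, not part of the original recipe; they can slow training slightly):

# Optional: stricter reproducibility on CUDA (trades some speed for determinism)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False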

# Define hyperparameters
num_epochs = 100
batch_size = 64
learning_rate = 0.0002
image_size = 28
image_channels = 1
latent_dim = 100

Here, we define several hyperparameters that are crucial for training a Generative Adversarial Network (GAN). Let’s go through each of them:

  • num_epochs: This hyperparameter specifies the number of training epochs. An epoch is one complete pass through the entire training dataset. In this case, we set it to 100, meaning that the GAN will be trained for 100 epochs.
  • batch_size: The batch size determines how many data samples are processed in each forward and backward pass during training. A larger batch size can lead to faster convergence but requires more memory.
  • learning_rate: The learning rate controls the step size during the optimization process (typically gradient descent). It's a crucial hyperparameter that influences the training speed and stability. A smaller learning rate may lead to more stable training but require more epochs for convergence.
  • image_size and image_channels: These parameters specify the size and number of channels (e.g., 1 for grayscale or 3 for RGB) of the input images. In this case, the MNIST dataset contains grayscale 28x28 images.
  • latent_dim: The latent dimension represents the size of the random noise vector that serves as input to the generator (see the snippet below). A larger latent dimension can allow the GAN to capture more complex patterns but may also require a more complex generator network. We set it to 100.
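To make latent_dim concrete, here is a minimal sketch of how a batch of latent vectors looks for this generator; the trailing 1x1 spatial dimensions are what the transposed convolutions will later upsample:

# A batch of latent vectors: one 100-dimensional noise vector per image
z = torch.randn(batch_size, latent_dim, 1, 1)
print(z.shape)  # torch.Size([64, 100, 1, 1])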
# Create a custom dataset for MNIST
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

mnist_dataset = datasets.MNIST(root='./data',
                               train=True,
                               transform=transform,
                               download=True)

data_loader = DataLoader(dataset=mnist_dataset,
                         batch_size=batch_size,
                         shuffle=True)

With this custom dataset and DataLoader, you are ready to train your GAN model on the MNIST dataset. The DataLoader will handle the loading of data in batches, and you can iterate over it during training.
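As a quick sanity check, you can pull one batch and confirm the shapes and value range; after Normalize((0.5,), (0.5,)), pixel values lie in [-1, 1], which matches the Tanh output range of the generator defined below:

# Inspect one batch from the DataLoader
images, labels = next(iter(data_loader))
print(images.shape)                              # torch.Size([64, 1, 28, 28])
print(images.min().item(), images.max().item())  # approximately -1.0 and 1.0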

Now comes the task of defining the generator and discriminator network.

# Define the generator network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, 7, 1, 0),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, image_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

This network is responsible for generating fake images from random noise (latent vectors). This generator architecture is a common choice for GANs and is often referred to as a "DCGAN" (Deep Convolutional GAN) generator. It uses transposed convolutional layers to gradually upsample the input noise into a realistic image.
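You can verify the upsampling path by tracing shapes through the network (using the transposed-convolution size formula H_out = (H_in - 1) * stride - 2 * padding + kernel_size):

# (N, 100, 1, 1) -> (N, 256, 7, 7) -> (N, 128, 14, 14) -> (N, 1, 28, 28)
z = torch.randn(1, latent_dim, 1, 1)
print(Generator()(z).shape)  # torch.Size([1, 1, 28, 28])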

Next, we define the discriminator network

# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(image_channels, 128, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, 7, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

This network is responsible for distinguishing between real and fake images, and this architecture is also a common choice for GANs. During training, the generator tries to produce fake images that the discriminator cannot easily tell apart from real ones, while the discriminator learns to separate the two. This adversarial process continues iteratively until the generator produces realistic images.
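The discriminator mirrors the generator, downsampling a 28x28 image to a single probability (by the convolution size formula H_out = (H_in + 2 * padding - kernel_size) / stride + 1, the spatial size goes 28 -> 14 -> 7 -> 1):

# (N, 1, 28, 28) -> (N, 128, 14, 14) -> (N, 256, 7, 7) -> (N, 1, 1, 1)
x = torch.randn(1, image_channels, image_size, image_size)
print(Discriminator()(x).shape)  # torch.Size([1, 1, 1, 1])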

# Initialize the generator and discriminator
generator = Generator().cuda()
discriminator = Discriminator().cuda()

After defining the two classes, we create an instance of each. The .cuda() method moves a model's parameters to the GPU to speed up training. On Google Colab, you can enable a free GPU by changing the runtime type, or you can use your local GPU.
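If you want the same code to fall back to the CPU when no GPU is available, a device-agnostic variant (a sketch, not what the rest of this article uses) looks like this:

# Pick GPU if available, otherwise CPU; .to(device) replaces .cuda() throughout
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)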

# Define loss and optimizers
criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

We define the loss function and the optimizers required for the GAN: binary cross-entropy (BCELoss), since the discriminator performs binary real-vs-fake classification, and one Adam optimizer for each network.
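As an aside, the DCGAN paper (Radford et al., 2016) recommends Adam with beta1 = 0.5 for more stable GAN training; if you see oscillating losses, a variant worth trying is:

# Adam with DCGAN-recommended betas (an alternative, not the article's setting)
optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))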

Now that we have defined everything, it's time to train the model on the MNIST dataset.

# Training loop
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(data_loader):
        real_images = real_images.cuda()
        batch_size = real_images.size(0)

        # Train discriminator with real images
        optimizer_d.zero_grad()
        label_real = torch.ones(batch_size, 1).cuda()
        output_real = discriminator(real_images).view(-1, 1)
        loss_real = criterion(output_real, label_real)
        loss_real.backward()

        # Train discriminator with fake images
        noise = torch.randn(batch_size, latent_dim, 1, 1).cuda()
        fake_images = generator(noise)
        label_fake = torch.zeros(batch_size, 1).cuda()
        output_fake = discriminator(fake_images.detach()).view(-1, 1)
        loss_fake = criterion(output_fake, label_fake)
        loss_fake.backward()
        optimizer_d.step()

        # Train generator
        optimizer_g.zero_grad()
        output = discriminator(fake_images).view(-1, 1)
        loss_g = criterion(output, label_real)
        loss_g.backward()
        optimizer_g.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(data_loader)}], '
                  f'D_real: {output_real.mean():.4f}, D_fake: {output_fake.mean():.4f}, '
                  f'Loss_D: {loss_real.item() + loss_fake.item():.4f}, Loss_G: {loss_g.item():.4f}')

    # Generate and save sample images at the end of each epoch
    with torch.no_grad():
        fake_samples = generator(torch.randn(64, latent_dim, 1, 1).cuda())
        fake_samples = fake_samples.cpu()
    fake_grid = torchvision.utils.make_grid(fake_samples, padding=2, normalize=True)
    plt.imshow(np.transpose(fake_grid.numpy(), (1, 2, 0)))
    plt.axis('off')
    plt.show()

This training loop follows the typical training procedure for GANs: the discriminator and generator are updated alternately, with the generator trying to produce increasingly convincing fake images while the discriminator becomes better at distinguishing real from fake. Note the fake_images.detach() call in the discriminator update: it stops gradients from flowing back into the generator there, so the generator is only updated by its own loss afterwards. The loop also lets you monitor training progress and visualize generated samples at the end of each epoch.

This code gives you a grid of generated images after each epoch. Training for the full 100 epochs took around 30 minutes.

Generated images after 100 epochs

# Save the trained generator model
torch.save(generator.state_dict(), 'generator.pth')

You can use this code to save the trained generator’s weights.
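To reuse the model later, a minimal loading sketch (assuming Generator is defined as above and 'generator.pth' is the file saved by the code):

# Reload the generator weights and sample new digits
generator = Generator()
generator.load_state_dict(torch.load('generator.pth', map_location='cpu'))
generator.eval()
with torch.no_grad():
    samples = generator(torch.randn(16, latent_dim, 1, 1))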

Conclusion

GANs have in fact revolutionized the process of image generation; we have even reached a point where some fear that artists could become obsolete because of such state-of-the-art AI models.

The problem discussed in this article is a rather simple one involving low-resolution images. Even then, we can see that the generated images aren’t perfect. Tuning the hyperparameters can produce an even better result.
