How to build a CGAN for generating MNIST digits in PyTorch

Simple Schwarz
4 min readJun 29, 2022


This post shows how to build a CGAN (Conditional Generative Adversarial Network) in PyTorch that generates synthetic handwritten digit images conditioned on a given label, using the MNIST dataset. All snippets are written in a Jupyter notebook. Before getting to the point, if you haven’t built a GAN model yet, please have a look at the following post first.

Import libraries

Define variables

#CUDA: If you want to use CUDA, change this to CUDA = True.

#DATA_PATH: Make sure to create a folder named “data” in the same folder as your ipynb file before running the code.

#batch_size: The number of samples fed to the network per update. Bigger batch sizes speed up training but require more memory.

#epochs: The number of times the entire training dataset is passed through the network. A larger epoch number is generally better, but you should watch out for overfitting.

#lr: The size of each step the network takes as it descends the error surface. A learning rate that is too big speeds up learning but is not guaranteed to find the minimum error.

#classes: The number of classes. The MNIST dataset has 10 labels, from 0 to 9, so this is set to 10.

#channels: All images in MNIST are single-channel (grayscale), so channels is 1.

#img_size: The original image size in MNIST is 28x28. I will resize the images from 28x28 to 64x64 with a resize module for the network.

#latent_dim: The size of the latent vector z (i.e. the size of the generator input). It is the length of the random noise vector fed to the generator.

#log_interval: How often (in batches) intermediate results are printed during training.
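Collected into one cell, the variables above might look like this. The specific values for lr, batch_size, latent_dim, and log_interval are common GAN defaults and are my assumptions, not taken from the original notebook.

```python
CUDA = False            # change to True to train on a GPU
DATA_PATH = './data'    # create the "data" folder next to the notebook first
batch_size = 128        # assumed value; any power of two is typical
epochs = 15             # the result at the end of the post uses 15 epochs
lr = 2e-4               # assumed; a common Adam learning rate for GANs
classes = 10            # MNIST labels 0-9
channels = 1            # grayscale images
img_size = 64           # images are resized from 28x28 to 64x64
latent_dim = 100        # assumed; a common size for the latent vector z
log_interval = 100      # print progress every 100 batches (assumed)
```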

Set up CUDA
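A minimal sketch of the CUDA setup, assuming the CUDA flag from the variables cell. The seed value 1 is an arbitrary choice for reproducibility.

```python
import torch

CUDA = False  # the flag defined in the variables cell
CUDA = CUDA and torch.cuda.is_available()  # fall back to CPU if no GPU
device = torch.device("cuda:0" if CUDA else "cpu")
if CUDA:
    torch.cuda.manual_seed(1)  # make GPU runs reproducible
```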

Prepare the dataset

#MNIST dataset: A dataset of handwritten digits containing 60,000 training images and 10,000 test images. If you set download = True, you don’t need to download the MNIST dataset yourself.

#transforms.Resize: The image size is changed from 28x28 to 64x64 (img_size).

#transforms.ToTensor: Convert a PIL image or numpy.ndarray to tensor.

#transforms.Normalize: Normalize a tensor image with mean and standard deviation. Why? To ensure that each input feature (each pixel, in the image case) has a similar data distribution. For example, one image may have pixel values ranging from 0 to 255 and another from 20 to 200. Normalizing the pixel values boosts learning performance and makes the network converge faster. Normalization is done by subtracting the mean from each pixel and then dividing the result by the standard deviation. I didn’t calculate the exact mean and standard deviation here and just used 0.5 for both, which maps the [0, 1] output of ToTensor to [-1, 1], the same range as the generator’s Tanh output.

#dataloader: DataLoader wraps an iterable around the Dataset to enable easy access to the samples. num_workers specifies how many subprocesses to use for data loading.

Generator

#The generator class is built on torch.nn.Module, which helps you build network models easily, and consists of three methods: __init__, _create_layer, and forward.

#The __init__ method is where we typically define the attributes of a class. You can do any setup here. I set the number of classes, the number of channels, the size of the image, the dimension of the latent vector, and an nn.Embedding module. This module is a simple lookup table that stores embeddings of a fixed dictionary and size. It is used to turn the label into a vector that is processed together with the random latent vector.

#_create_layer is where we define layers. The network consists of 5 linear layers, 3 of which are followed by batch normalization layers; the first 4 linear layers have LeakyReLU activation functions, while the last has a Tanh activation function. Batch normalization normalizes the extracted features in the hidden units to make training faster and more stable.

#The forward method is called when we use the network to make a prediction. torch.cat concatenates the given sequence of tensors along the given dimension, -1 here, so the label embedding is joined onto the latent vector.
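Since the notebook cell isn’t shown, here is a sketch of a generator matching the description above. The layer widths 128, 256, 512, 1024 are my assumptions, a common choice for MLP-style GANs.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, classes=10, channels=1, img_size=64, latent_dim=100):
        super().__init__()
        self.img_shape = (channels, img_size, img_size)
        # Lookup table turning a digit label into a dense vector
        self.label_embedding = nn.Embedding(classes, classes)
        self.model = nn.Sequential(
            # 5 linear layers: the first has no batch norm, the next 3 do;
            # the first 4 use LeakyReLU, the last uses Tanh
            *self._create_layer(latent_dim + classes, 128, normalize=False),
            *self._create_layer(128, 256),
            *self._create_layer(256, 512),
            *self._create_layer(512, 1024),
            nn.Linear(1024, channels * img_size * img_size),
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized images
        )

    def _create_layer(self, size_in, size_out, normalize=True):
        layers = [nn.Linear(size_in, size_out)]
        if normalize:
            layers.append(nn.BatchNorm1d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, noise, labels):
        # Concatenate the label embedding onto the latent vector (dim -1)
        z = torch.cat((self.label_embedding(labels), noise), -1)
        x = self.model(z)
        return x.view(x.size(0), *self.img_shape)
```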

Discriminator

#The discriminator outputs a single value indicating how likely an image is to be real, given the label information. The network consists of 5 linear layers, 2 of which are followed by dropout layers to prevent overfitting. Dropout makes all the nodes work well as a team, ensuring no single node becomes too weak or too strong, because some neurons are excluded from each particular forward or backward pass. The BCE loss function is typically used for binary classification tasks like this.
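A matching discriminator sketch; the layer widths and the dropout probability of 0.4 are my assumptions, not taken from the original notebook.

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, classes=10, channels=1, img_size=64):
        super().__init__()
        self.label_embedding = nn.Embedding(classes, classes)
        in_dim = classes + channels * img_size * img_size
        self.model = nn.Sequential(
            # 5 linear layers, 2 of them followed by dropout
            nn.Linear(in_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # single value: probability the image is real
        )

    def forward(self, image, labels):
        # Flatten the image and append the label embedding
        x = torch.cat((image.view(image.size(0), -1),
                       self.label_embedding(labels)), -1)
        return self.model(x)
```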

Define models and optimizers
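This step might look like the following. The tiny stand-in modules are only so the cell runs on its own; in the notebook you would pass the Generator and Discriminator defined in the previous sections. The Adam betas (0.5, 0.999) are an assumption, a common choice for GAN training.

```python
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cpu")  # the device chosen in the CUDA setup

# Stand-ins so this sketch is self-contained; replace with
# Generator() and Discriminator() from the sections above.
netG = nn.Linear(110, 64 * 64).to(device)
netD = nn.Linear(64 * 64, 1).to(device)

optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()  # binary cross-entropy for real/fake classification
```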

Train

#netG.train() and netD.train() turn on training mode. Then, fixed latent vectors (viz_noise) and labels (viz_label) are defined. They are used to periodically generate images during training so that we can track how the model is learning. torch.randn returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1. torch.LongTensor holds signed 64-bit integers, the data type expected for labels.
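The fixed visualization inputs described above, plus a sketch of one CGAN training step. The train_step function signature is mine, not the original notebook’s; it assumes the Generator and Discriminator from the earlier sections.

```python
import torch

batch_size, latent_dim, classes = 128, 100, 10
device = torch.device("cpu")

# Fixed inputs, reused at every log_interval so snapshots are comparable;
# cycling labels 0..9 puts each digit in a predictable grid position
viz_noise = torch.randn(batch_size, latent_dim, device=device)
viz_label = torch.LongTensor([i % classes for i in range(batch_size)]).to(device)

def train_step(real, labels, netG, netD, optimizerG, optimizerD, criterion):
    """One update of D on a real and a fake batch, then one update of G."""
    batch = real.size(0)
    real_target = torch.ones(batch, 1, device=real.device)
    fake_target = torch.zeros(batch, 1, device=real.device)

    # Discriminator: push real images towards 1 and fakes towards 0
    netD.zero_grad()
    loss_real = criterion(netD(real, labels), real_target)
    noise = torch.randn(batch, latent_dim, device=real.device)
    fake = netG(noise, labels)
    loss_fake = criterion(netD(fake.detach(), labels), fake_target)
    d_loss = loss_real + loss_fake
    d_loss.backward()
    optimizerD.step()

    # Generator: try to make D classify the fakes as real
    netG.zero_grad()
    g_loss = criterion(netD(fake, labels), real_target)
    g_loss.backward()
    optimizerG.step()
    return d_loss.item(), g_loss.item()
```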

Plot fake images

This is the result after 15 epochs of training. You can see the digits are sorted in order thanks to the label information. You can find the code here.
