understand by creating a model which generates images of handwritten digits similar to those from the MNIST database.
Introduction to Generative Modeling:
Generative modeling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data in such a way that the model can be used to generate or output new examples that plausibly could have been drawn from the original dataset. — Source.
To get a sense of the power of generative models, just visit thispersondoesnotexist.com. Every time you reload the page, a new image of a person’s face is generated on the fly. The results are pretty fascinating.
Deep neural networks are used mainly for supervised learning: classification or regression. Generative Adversarial Networks, or GANs, however, use neural networks for a very different purpose: generative modeling.
While there are many approaches used for generative modeling, a Generative Adversarial Network takes the following approach:
There are two neural networks: a Generator and a Discriminator. The generator generates a “fake” sample given a random vector/matrix, and the discriminator attempts to detect whether a given sample is “real” (picked from the training data) or “fake” (generated by the generator). Training happens in tandem: we train the discriminator for a few epochs, then train the generator for a few epochs, and repeat. This way both the generator and the discriminator get better at doing their jobs.
GANs, however, can be notoriously difficult to train and are extremely sensitive to hyperparameters, activation functions, and regularization. In this article, we’ll train a GAN to generate images of handwritten digits similar to those from the MNIST database.
Here’s what we’re going to do:
- Define the problem statement
- Load the data (with transforms and normalization)
- Denormalize for visual inspection of samples
- Define the Discriminator network
- Study the activation function: Leaky ReLU
- Define the Generator network
- Explain the output activation function: TanH
- Look at some sample outputs
- Define losses, optimizers, and helper functions for the training of Discriminator and Generator
- Train the model
1. Define the problem statement: Train a GAN to generate images of handwritten digits similar to those from the MNIST database.
2. Load the data (with transforms and normalization): We begin by downloading and importing the data as a PyTorch dataset using the MNIST helper class from torchvision.datasets.
3. Denormalize for visual inspection of samples: Since we have normalized our dataset of images, we need to define a helper to denormalize the images before viewing them. This function will also be useful for viewing the generated images.
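A possible denormalization helper, assuming the images were normalized to [-1, 1] with mean 0.5 and standard deviation 0.5; the name denorm is an illustrative choice. It simply inverts (x - 0.5) / 0.5.

```python
import torch


def denorm(x):
    """Map tensors normalized to [-1, 1] back to the [0, 1] range for display."""
    out = (x + 1) / 2
    # Clamp guards against values that fall slightly outside [-1, 1].
    return out.clamp(0, 1)
```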
4. Define the Discriminator network: The discriminator takes an image as input and tries to classify it as “real” or “generated”. In this sense, it’s like any other neural network. While we could use a CNN for the discriminator, we use a simple feedforward network with 3 linear layers to keep things simple. We’ll treat each 28x28 image as a vector of size 784. Just like any other binary classification model, the output of the discriminator is a single number between 0 and 1, which can be interpreted as the probability of the input image being real, i.e. picked from the training dataset.
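A sketch of such a discriminator, assuming a hidden width of 256 and a Leaky ReLU negative slope of 0.2 (both are assumptions, not values stated in the text). The final Sigmoid keeps the single output between 0 and 1.

```python
import torch
import torch.nn as nn

image_size = 28 * 28  # each 28x28 image is flattened into a 784-vector
hidden_size = 256     # assumed hidden width

# Three linear layers with LeakyReLU in between; the final Sigmoid squashes the
# output to (0, 1), read as the probability that the input image is real.
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid(),
)
```

For a batch of 16 flattened images, D(batch) returns a tensor of shape (16, 1).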
5. Study the activation function: Leaky ReLU: We use the Leaky ReLU activation in the discriminator. Unlike the regular ReLU function, Leaky ReLU lets a small gradient signal pass for negative inputs. As a result, the gradients from the discriminator flow more strongly into the generator. Instead of passing a gradient (slope) of 0 for negative inputs in the backward pass, it passes a small, nonzero gradient equal to its negative slope (for example, 0.2).
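The difference is easy to check with autograd. This small sketch compares the gradients of ReLU and Leaky ReLU (with an assumed negative slope of 0.2) at one negative and one positive input.

```python
import torch
import torch.nn.functional as F

# Two inputs: one negative, one positive.
x = torch.tensor([-2.0, 3.0], requires_grad=True)

# Gradient of ReLU: 0 for the negative input, 1 for the positive one.
F.relu(x).sum().backward()
relu_grad = x.grad.clone()

# Gradient of LeakyReLU with negative_slope=0.2: 0.2 for the negative input,
# 1 for the positive one, so some signal still flows back when x < 0.
x.grad = None
F.leaky_relu(x, negative_slope=0.2).sum().backward()
leaky_grad = x.grad.clone()

print(relu_grad)   # tensor([0., 1.])
print(leaky_grad)  # tensor([0.2000, 1.0000])
```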
6. Define the Generator network: The input to the generator is typically a vector or a matrix which is used as a seed for generating an image. Once again, to keep things simple, we’ll use a feedforward neural network with 3 layers, and the output will be a vector of size 784, which can be transformed into a 28x28 px image. The ReLU activation is used in the generator, with the exception of the output layer, which uses the Tanh function. We are taking
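A sketch of a matching generator, assuming a latent (seed) vector of size 64 and a hidden width of 256; both sizes are assumptions. ReLU is used in the hidden layers and Tanh on the output, so each of the 784 outputs lies in [-1, 1].

```python
import torch
import torch.nn as nn

latent_size = 64      # assumed size of the random seed vector
hidden_size = 256     # assumed hidden width
image_size = 28 * 28  # output vector of size 784

# ReLU in the hidden layers, Tanh on the output layer so every pixel lies in [-1, 1].
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh(),
)

z = torch.randn(2, latent_size)       # batch of 2 random seed vectors
images = G(z).reshape(-1, 1, 28, 28)  # view each 784-vector as a 1x28x28 image
```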
7. Explain the output activation function: TanH: We use the TanH activation function for the output layer of the generator because a bounded activation has been observed to let the model learn more quickly to saturate and cover the color space of the training distribution.
Note that since the outputs of the TanH activation lie in the range [-1, 1], we have applied the same scaling to the images in the training dataset, normalizing their pixel values to [-1, 1].
8. Look at some sample outputs: Let’s generate an output vector using the generator and view it as an image by transforming and denormalizing the output.
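A self-contained sketch of this step, assuming the generator layout described in step 6 with a 64-dimensional latent vector and 256 hidden units (assumed sizes). The untrained generator's output is reshaped to 28x28 and mapped from [-1, 1] back to [0, 1] for display.

```python
import torch
import torch.nn as nn

latent_size, hidden_size, image_size = 64, 256, 28 * 28  # assumed sizes

# Untrained generator: 3 linear layers, ReLU hidden activations, Tanh output.
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
)


def denorm(x):
    # Map [-1, 1] back to [0, 1] for display.
    return ((x + 1) / 2).clamp(0, 1)


with torch.no_grad():                     # no gradients needed just to look
    y = G(torch.randn(2, latent_size))    # two output vectors of size 784
gen_imgs = denorm(y.reshape(-1, 28, 28))  # two 28x28 images in [0, 1]

# To view one, e.g. with matplotlib:
# import matplotlib.pyplot as plt
# plt.imshow(gen_imgs[0], cmap="gray"); plt.show()
```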
As one might expect, the output from the untrained generator is basically random noise. Now we need to define helper functions to train our model.
9. Define losses, optimizers, and helper functions for training: Now we need to define a loss function and optimizers to train our discriminator and generator. Since the discriminator is a binary classification model, we can use the binary cross-entropy loss function to quantify how well it is able to differentiate between real and generated images.
For both the generator and the discriminator optimizers, we set the learning rate to 0.0002.
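A sketch of the loss and optimizer setup. The text only specifies binary cross-entropy and a learning rate of 0.0002; the choice of Adam and the layer sizes (64-dimensional latent vector, 256 hidden units) are assumptions. Each optimizer receives only its own network's parameters.

```python
import torch
import torch.nn as nn

latent_size, hidden_size, image_size = 64, 256, 28 * 28  # assumed sizes

D = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1), nn.Sigmoid(),
)
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One optimizer per network; each only updates its own parameters.
# Adam is an assumption here; the text only specifies the learning rate.
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
```

As a sanity check, a prediction of 0.5 for a real image (target 1) gives a BCE loss of ln 2, roughly 0.693.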
a. Discriminator Training: Here are the steps involved in training the discriminator.
- We expect the discriminator to output 1 if the image was picked from the real MNIST dataset, and 0 if it was generated.
- We first pass a batch of real images, and compute the loss, setting the target labels to 1.
- Then, we generate a batch of fake images using the generator, pass them into the discriminator, and compute the loss, setting the target labels to 0.
- Finally, we add the two losses and use the overall loss to perform gradient descent to adjust the weights of the discriminator.
It’s important to note that we don’t change the weights of the generator model while training the discriminator (d_optimizer only affects D.parameters()). Here is how we define a training function for the discriminator. We have also defined reset_grad, a function that resets the gradients to zero before each backward pass.
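A self-contained sketch of the discriminator step described above, assuming Adam optimizers and the layer sizes used elsewhere in this article's sketches (64-dimensional latent vector, 256 hidden units). Unlike a version that relies only on d_optimizer, this sketch also calls detach() on the fake images, a small deviation that keeps the generator out of the backward pass entirely.

```python
import torch
import torch.nn as nn

latent_size, hidden_size, image_size = 64, 256, 28 * 28  # assumed sizes

D = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1), nn.Sigmoid(),
)
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
)
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)  # Adam is assumed
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


def reset_grad():
    # Clear stale gradients in both networks before a new backward pass.
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


def train_discriminator(images):
    """One discriminator update on a batch of flattened real images."""
    batch_size = images.size(0)
    real_labels = torch.ones(batch_size, 1)   # target 1 for real images
    fake_labels = torch.zeros(batch_size, 1)  # target 0 for generated images

    # Loss on real images.
    outputs = D(images)
    d_loss_real = criterion(outputs, real_labels)
    real_score = outputs

    # Loss on fake images; detach() keeps the generator out of this backward pass.
    z = torch.randn(batch_size, latent_size)
    fake_images = G(z).detach()
    outputs = D(fake_images)
    d_loss_fake = criterion(outputs, fake_labels)
    fake_score = outputs

    # Add the two losses and update only the discriminator's weights.
    d_loss = d_loss_real + d_loss_fake
    reset_grad()
    d_loss.backward()
    d_optimizer.step()
    return d_loss, real_score, fake_score
```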
b. Generator Training: Since the outputs of the generator are images, it’s not obvious how we can train the generator. This is where we employ a rather elegant trick, which is to use the discriminator as a part of the loss function. Here’s how it works:
- We generate a batch of images using the generator and pass them into the discriminator.
- We calculate the loss by setting the target labels to 1 i.e. real. We do this because the generator’s objective is to “fool” the discriminator.
- We use the loss to perform gradient descent i.e. change the weights of the generator, so it gets better at generating real-like images.
Here’s what this looks like in code.
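A self-contained sketch of the generator step, again assuming Adam optimizers and the same assumed layer sizes (64-dimensional latent vector, 256 hidden units). The batch_size parameter of train_generator is a hypothetical choice. Gradients flow back through the discriminator into the generator, but only g_optimizer takes a step, so the discriminator's weights stay fixed.

```python
import torch
import torch.nn as nn

latent_size, hidden_size, image_size = 64, 256, 28 * 28  # assumed sizes

D = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1), nn.Sigmoid(),
)
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
)
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)  # Adam is assumed
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


def train_generator(batch_size):
    """One generator update; batch_size is a hypothetical parameter name."""
    z = torch.randn(batch_size, latent_size)
    fake_images = G(z)
    # Label the fakes as real (1): the generator improves by fooling D.
    labels = torch.ones(batch_size, 1)
    g_loss = criterion(D(fake_images), labels)

    reset_grad()
    g_loss.backward()   # gradients flow through D into G
    g_optimizer.step()  # only G's weights are updated
    return g_loss, fake_images
```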
10. Training the Model: With all the pieces in place, we are now ready to train the model. In each epoch, we train the discriminator first, and then the generator. The training might take a while if you’re not using a GPU.
Here’s what this looks like in code.
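A compact, self-contained sketch of the training loop. It assumes the discriminator and generator alternate on every mini-batch within each epoch, Adam optimizers, and the same assumed layer sizes (64-dimensional latent vector, 256 hidden units). The fit function name and its history return value are hypothetical. Any iterable of (images, labels) batches works as data_loader, for example a DataLoader over the normalized MNIST training set.

```python
import torch
import torch.nn as nn

latent_size, hidden_size, image_size = 64, 256, 28 * 28  # assumed sizes

D = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1), nn.Sigmoid(),
)
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
)
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)  # Adam is assumed
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


def train_discriminator(images):
    n = images.size(0)
    z = torch.randn(n, latent_size)
    # Real images should score 1, generated images 0; detach() keeps G fixed.
    d_loss = (criterion(D(images), torch.ones(n, 1))
              + criterion(D(G(z).detach()), torch.zeros(n, 1)))
    reset_grad()
    d_loss.backward()
    d_optimizer.step()
    return d_loss


def train_generator(n):
    z = torch.randn(n, latent_size)
    # The generator wants D to label its fakes as real (1).
    g_loss = criterion(D(G(z)), torch.ones(n, 1))
    reset_grad()
    g_loss.backward()
    g_optimizer.step()
    return g_loss


def fit(data_loader, num_epochs):
    history = []
    for epoch in range(num_epochs):
        for images, _ in data_loader:
            images = images.reshape(images.size(0), -1)  # flatten to (n, 784)
            d_loss = train_discriminator(images)         # discriminator first...
            g_loss = train_generator(images.size(0))     # ...then the generator
        history.append((d_loss.item(), g_loss.item()))
        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")
    return history
```

Recording the last batch's losses per epoch is a simple way to watch the two networks compete; a steadily collapsing d_loss usually means the discriminator is winning too easily.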
If you don’t have a GPU in your system, you can use Google Colab, a free Google service for data science and machine learning, and change the runtime type of your Jupyter notebook on Colab to GPU.
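A short sketch of how PyTorch code typically picks up the GPU once one is available (for example on a Colab GPU runtime). The tiny linear model is only a stand-in; the point is that the models and every batch must be moved to the same device.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available (e.g. a Colab GPU runtime), else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models and batches must live on the same device; a tiny stand-in model here.
model = nn.Linear(784, 1).to(device)
batch = torch.randn(8, 784, device=device)
out = model(batch)
print(device, out.shape)
```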
Once the model is trained, we can save the intermediate generated images to files and see how our images evolve from random noise into images resembling the MNIST digits.
Here is a link to a video showing how the images change during training.
Here is a summary of what we have done to train our Generative Adversarial Network.
We have trained a model that takes a set of real images and generates fake images similar to them. To do this, we first loaded data from the MNIST dataset, took a look at images from the dataset, defined a discriminator to differentiate between real and generated images, defined a generator to generate images, wrote training functions for the discriminator and the generator, and finally trained the complete model using hyperparameters such as the number of epochs and the learning rate. And that’s it.
As an exercise, try applying each technique independently and see how much each one affects the performance and the final output. As you try different experiments, you will start to cultivate an intuition for picking the right architectures, and you will keep getting better.
To help you, I am sharing the accompanying notebook in the references.
bhupendrasingh62435/06-mnist-gan - Jovian
If you enjoyed this or found something interesting, you can give me a clap and share it with your friends.