How PyTorch has made GAN training super easy!

Maham Shafiq
Published in Analytics Vidhya
6 min read · Jun 2, 2020

Generative Adversarial Networks have amazed me ever since I dived into the field of Machine Learning. GANs are Machine Learning models that can imagine new things. Generative Adversarial Networks (GANs) were introduced for the first time in 2014, a genuinely new idea in the field of Machine Learning. GANs were introduced with the goal of generating artificial samples or images that are novel yet indistinguishable from real image samples.

What is GAN?

GANs generate the most probable outcome for a given sequence of input samples. As an example, a generative model can generate the next likely frame based on previously fed sample frames. Keeping this concept in mind, we can now tackle GANs. This model is an unsupervised neural network architecture that outperforms traditional nets. To be more precise, GANs are a new way of training neural networks. A GAN contains two independent networks that work distinctly and act as adversaries:

1- Generator takes random noise as input and runs that noise through a differentiable function to transform and reshape it into a recognizable structure. The output is a realistic image. The choice of input noise determines which image comes out of the generator network.

2- Discriminator is a neural network that undergoes training and then acts as a classifier that can discriminate between real images (from the dataset) and fake images (produced by the generator).

What can we do with GANs?

With time I realized that GANs have super cool applications. GANs have produced a ton of super interesting results in previous years. Game development and animation production are expensive and hire many production artists for relatively routine tasks. GANs can auto-generate and colorize anime characters.

The StackGAN model can convert the textual description of a bird into high-resolution bird images. Pix2Pix GANs convert crude sketches into realistic images. Image-to-image translation can even be trained as an unsupervised task. GANs can convert a photo of a face into a cartoon of the face. And my favorite one, StyleGAN, can generate high-resolution images.

Why Pytorch?

In early 2017, PyTorch was released and has been making a pretty big impact in the deep learning community. The Facebook AI Research team developed it as an open-source project. Over time it was adopted by teams everywhere in industry and academia. In my experience, it’s the best framework for learning deep learning and just a delight to work with in general. By the end of this article, you’ll have learned to train your own GAN. Isn’t that COOL?

Fashion MNIST GANs using PyTorch

Let’s get our hands dirty with coding. In this article, I’ll be building a generative adversarial network (GAN) trained on the Fashion MNIST dataset; this article is for absolute beginners to GANs. From this, we’ll be able to generate new fashion items!

The diagram shows the general structure of a GAN using Fashion MNIST images as data. The latent sample is a random vector that the generator uses to construct its fake images. This vector is often called a **latent vector**, and that vector space is called the **latent space**. As the generator trains, it figures out how to map latent vectors to recognizable images that can fool the discriminator.

This article is inspired by the GANs lesson in Udacity’s Deep Learning Nanodegree program, but I applied the model to the Fashion MNIST dataset with some changes.

Step#01 — Importing Necessary Libraries and Loading the Fashion MNIST Dataset

To get started we’ll need to install PyTorch. If you have Python >= 3.5 and Jupyter installed, you can run this locally. Alternatively, you can run this in Google Colab (if you have a Google account). Before starting, we make all the necessary imports.

Now we’ll define the number of subprocesses to use for data loading and the number of samples per batch; an ideal batch size ranges from 32 to 128. We convert the data to torch.FloatTensor, then get the training dataset and prepare a data loader, which loads the data in the batch size mentioned above.

Step#02 — Visualizing the Dataset

Step#03 — Defining the Discriminator of our Model

The discriminator network is going to be a pretty typical linear classifier. To make this network a universal function approximator, we’ll need at least one hidden layer, and these hidden layers should have one key attribute: a Leaky ReLU activation, which lets gradients flow backwards through the layer even when the input is negative.
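A sketch of such a discriminator, following the Udacity lesson’s style of shrinking hidden layers; the exact layer sizes and the 0.3 dropout rate are my assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_dim, output_size):
        super().__init__()
        # three hidden linear layers, shrinking toward the output
        self.fc1 = nn.Linear(input_size, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = x.view(-1, 28 * 28)                   # flatten the image
        # Leaky ReLU keeps gradients flowing for negative inputs
        x = self.dropout(F.leaky_relu(self.fc1(x), 0.2))
        x = self.dropout(F.leaky_relu(self.fc2(x), 0.2))
        x = self.dropout(F.leaky_relu(self.fc3(x), 0.2))
        return self.fc4(x)                        # raw logit; sigmoid lives in the loss
```

Returning a raw logit (rather than a sigmoid output) pairs naturally with `BCEWithLogitsLoss` later, which is more numerically stable.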

Step#04 — Defining the Generator of our Model

The generator network will be almost exactly the same as the discriminator network, except that we’re applying a tanh activation function to our output layer.
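A matching generator sketch, mirroring the discriminator’s layers in reverse and ending with tanh so outputs land in [-1, 1]; the sizes are again my assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, input_size, hidden_dim, output_size):
        super().__init__()
        # mirror of the discriminator: hidden layers grow toward the output
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim * 4)
        self.fc4 = nn.Linear(hidden_dim * 4, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = self.dropout(F.leaky_relu(self.fc1(x), 0.2))
        x = self.dropout(F.leaky_relu(self.fc2(x), 0.2))
        x = self.dropout(F.leaky_relu(self.fc3(x), 0.2))
        return torch.tanh(self.fc4(x))            # squash pixel values into [-1, 1]
```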

Step#05 — Defining Hyperparameters for the Model and Building the Complete Network

Hyperparameters, in this case, would be the size of the input layer, the Generator & Discriminator’s hidden layers, and output size.

Now we’re instantiating the discriminator and generator from the classes defined above. Make sure you’ve passed in the correct input arguments.

A detailed description of the generator and discriminator is then printed.
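A sketch of that cell; the sizes follow the lesson’s convention but are my assumptions, and the `nn.Sequential` stand-ins below are condensed equivalents of the class-based networks so the snippet runs on its own:

```python
import torch.nn as nn

# Hyperparameters (assumed values; adjust freely)
input_size = 784      # 28*28 flattened Fashion MNIST image
d_hidden = 32         # base size of the discriminator's hidden layers
d_output = 1          # single real/fake logit
z_size = 100          # dimension of the latent vector
g_hidden = 32         # base size of the generator's hidden layers
g_output = 784        # generated image, flattened

# Condensed stand-ins for the Discriminator and Generator classes above
D = nn.Sequential(
    nn.Flatten(),
    nn.Linear(input_size, d_hidden * 4), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(d_hidden * 4, d_hidden * 2), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(d_hidden * 2, d_hidden), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(d_hidden, d_output),
)
G = nn.Sequential(
    nn.Linear(z_size, g_hidden), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(g_hidden, g_hidden * 2), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(g_hidden * 2, g_hidden * 4), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(g_hidden * 4, g_output), nn.Tanh(),
)
print(D)
print(G)
```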

Step#06 — Defining Discriminator and Generator Losses

Discriminator Losses

  • For the discriminator, the total loss is the sum of the losses for real and fake images, d_loss = d_real_loss + d_fake_loss.
  • Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.

Generator Loss

The generator loss will look similar only with flipped labels. The generator’s goal is to get D(fake_images) = 1. In this case, the labels are flipped to represent that the generator is trying to fool the discriminator into thinking that the images it generates (fakes) are real!
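The `real_loss` and `fake_loss` helpers aren’t reproduced in the article; a sketch using `BCEWithLogitsLoss`, where the optional `smooth` flag for one-sided label smoothing is my addition:

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()   # sigmoid + binary cross-entropy in one stable op

def real_loss(D_out, smooth=False):
    """Loss against 'real' labels (optionally smoothed to 0.9)."""
    batch_size = D_out.size(0)
    labels = torch.ones(batch_size) * (0.9 if smooth else 1.0)
    return criterion(D_out.squeeze(), labels)

def fake_loss(D_out):
    """Loss against 'fake' labels (all zeros)."""
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)
    return criterion(D_out.squeeze(), labels)
```

The generator reuses `real_loss` on fake images, which is exactly the “flipped labels” trick described above.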

Step#07 — Optimizers and Training

We want to update the generator and discriminator variables separately. So, we’ll define two separate Adam optimizers.
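A sketch of the two optimizers; the learning rate of 0.002 is my assumption, and the tiny stand-in networks only keep the snippet self-contained (in the article, D and G are the networks built earlier):

```python
import torch.nn as nn
import torch.optim as optim

# stand-ins for the networks instantiated earlier
D = nn.Sequential(nn.Flatten(), nn.Linear(784, 1))
G = nn.Sequential(nn.Linear(100, 784), nn.Tanh())

lr = 0.002   # assumed learning rate

# one optimizer per network, so each can be updated independently
d_optimizer = optim.Adam(D.parameters(), lr)
g_optimizer = optim.Adam(G.parameters(), lr)
```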

Training will involve alternating between training the discriminator and the generator. We’ll use our functions real_loss and fake_loss to help us calculate the discriminator losses in all of the following cases.

Discriminator training

  1. Compute the discriminator loss on real, training images
  2. Generate fake images
  3. Compute the discriminator loss on fake, generated images
  4. Add up the real and fake loss
  5. Perform backpropagation + an optimization step to update the discriminator’s weights

Generator training

  1. Generate fake images
  2. Compute the discriminator loss on fake images, using flipped labels!
  3. Perform backpropagation + an optimization step to update the generator’s weights
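The two lists of steps above can be sketched as one alternating update. This helper takes the networks, one real batch, and the optimizers as arguments; the 0.9 label-smoothing value and the latent size handling are my assumptions:

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def train_step(D, G, real_images, z_size, d_optimizer, g_optimizer):
    """One alternating GAN update: discriminator first, then generator."""
    batch_size = real_images.size(0)
    real_images = real_images * 2 - 1            # rescale [0, 1] -> [-1, 1] to match tanh

    # --- discriminator: real loss + fake loss ---
    d_optimizer.zero_grad()
    d_real_loss = bce(D(real_images).squeeze(1),
                      torch.full((batch_size,), 0.9))      # smoothed real labels
    z = torch.randn(batch_size, z_size)
    fake_images = G(z)
    d_fake_loss = bce(D(fake_images.detach()).squeeze(1),  # detach: don't update G here
                      torch.zeros(batch_size))
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()

    # --- generator: same loss with flipped labels ---
    g_optimizer.zero_grad()
    z = torch.randn(batch_size, z_size)
    g_loss = bce(D(G(z)).squeeze(1),
                 torch.ones(batch_size))                   # G wants D to say "real"
    g_loss.backward()
    g_optimizer.step()
    return d_loss.item(), g_loss.item()
```

In the full loop, this helper would be called once per batch in every epoch, with the returned losses appended to a list for plotting later.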

Saving Samples

As we train, we’ll also print out some loss statistics and save some generated “fake” samples.
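A sketch of the sampling-and-saving cell; the fixed latent batch, the sample count, and the pickle filename are my choices:

```python
import pickle
import torch

samples = []                      # generated images collected across epochs
fixed_z = torch.randn(16, 100)    # fixed latent vectors, so progress is comparable

# inside the epoch loop you would do something like:
#   G.eval()                      # turn off dropout while sampling
#   samples.append(G(fixed_z).detach())
#   G.train()

# after training, save everything for later viewing
with open('train_samples.pkl', 'wb') as f:
    pickle.dump(samples, f)
```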

Training loss

Here we’ll plot the training losses for the generator and discriminator, recorded after each epoch.
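A sketch of the plotting cell, with placeholder loss values standing in for the numbers recorded during training:

```python
import numpy as np
import matplotlib.pyplot as plt

# in the real run, `losses` is filled during training with one
# (d_loss, g_loss) pair per epoch; these values are placeholders
losses = [(1.2, 0.9), (1.0, 1.1), (0.9, 1.3)]

arr = np.array(losses)
fig, ax = plt.subplots()
ax.plot(arr[:, 0], label='Discriminator')
ax.plot(arr[:, 1], label='Generator')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Losses')
ax.legend()
plt.show()
```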

Generated Samples

I trained for 200 epochs and got these results. These are samples from the final training epoch. You can see the generator is able to reproduce images like a shoe, a shirt, a cap, and a skirt. Since this is just a sample, it isn’t representative of the full range of images this generator can make. For better results, you can play with the parameters, for example increasing the number of epochs or trying different optimizers.

This is not the end; I am exploring GANs further. Insha Allah, I will write more on different GAN architectures soon. Stay tuned!

References:

I recommend you refer to my https://medium.com/@mahamshafiq98/machine-learning-basics-524e63c8238 article as an additional resource. Udacity’s Deep Learning Nanodegree briefly covers this topic as I mentioned earlier, so do check it out!
