Introduction to Generative Adversarial Networks (GANs)

Brijesh Modasara
Apr 13, 2020

GANs are analogous to a never-ending game of cat and mouse

If you know who wins the fight between Tom and Jerry, you probably already have some idea of how GANs work. This article is my attempt to explain what GANs are, how they work, and how they are trained.

“The coolest idea in deep learning in the last 20 years.” — Yann LeCun on GANs.

This remark alone conveys the importance and novelty of GANs. GANs, or Generative Adversarial Networks, were introduced in 2014 by Ian Goodfellow and his team. A GAN is a method for generating new data (mostly images) that resembles the training data without copying any training example.

Introduction

Generative Adversarial Networks are an unsupervised machine learning technique used for generative modelling. A GAN comprises two parts:

  1. A Generator model G (Forger or Jerry Mouse): tasked with generating new images.
  2. A Discriminator model D (Police or Tom Cat): tasked with identifying whether the input image is real or fake (made by the Forger).

So how do these two models work?

The two models are adversaries playing a zero-sum game. The Generator model G generates a batch of (fake) images, and these, along with real images from the training data, are given to the Discriminator model D, which is supposed to tell the real images from the fake ones. The two models train each other in a cyclic manner until the generator is good enough to create images realistic enough to fool the discriminator. In mathematical terms, the generator tries to learn the data distribution underlying the real images and to reproduce a similar distribution when creating new images.

Model Architecture and Training

The Generator model G is a neural network with parameters θ and input vector z.

  1. A point in latent space, i.e., a 100-element vector z of Gaussian random numbers, is given as input to the network. The subsequent layers upsample the data to create the required image.
  2. In a convolutional neural network, there are different feature maps, each giving a different interpretation of the image. Similarly, when creating an image from an n-length vector, multiple low-resolution versions of the image need to be generated, which can be condensed into one image at the end. Therefore, the first layer must contain enough neurons to create multiple feature maps.
  3. For example, to create an image for the MNIST dataset (28 x 28 pixels, 1 channel), the first Dense layer contains 6272 neurons (128 x 7 x 7). The output of the first layer is then reshaped into 128 feature maps, each of size 7 x 7.
  4. Subsequent layers apply upsampling techniques to create a 28 x 28 image.
  5. The resulting image is then given as one of the inputs to the Discriminator model D.
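The shape arithmetic in the steps above can be checked with a short sketch. It assumes, purely for illustration, that each upsampling layer doubles the height and width (for example a stride-2 transposed convolution); the article does not specify the upsampling layers, and all names below are hypothetical.

```python
# Shape bookkeeping for the MNIST generator described above.
# Assumption (not from the article): each upsampling layer doubles
# the height and width of the feature maps.

LATENT_DIM = 100      # length of the noise vector z
FEATURE_MAPS = 128    # feature maps produced by the first Dense layer
BASE_SIZE = 7         # each feature map starts as 7 x 7
TARGET_SIZE = 28      # MNIST images are 28 x 28


def dense_units(feature_maps, size):
    """Number of neurons the first Dense layer needs before reshaping."""
    return feature_maps * size * size


def upsampling_sizes(start, target, factor=2):
    """Spatial sizes visited when repeatedly upsampling by `factor`."""
    sizes = [start]
    while sizes[-1] < target:
        sizes.append(sizes[-1] * factor)
    if sizes[-1] != target:
        raise ValueError("target size is not reachable with this factor")
    return sizes


units = dense_units(FEATURE_MAPS, BASE_SIZE)
sizes = upsampling_sizes(BASE_SIZE, TARGET_SIZE)
print(units)  # 6272
print(sizes)  # [7, 14, 28]
```

Under this assumption, two upsampling layers are enough to go from the 7 x 7 feature maps to a 28 x 28 image.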

The Discriminator model D is a binary classification model that classifies the images into two classes: real (“1”) or fake (“0”). It is a convolutional neural network with parameters Φ and the input x. The input x is a batch of images, half of which are real images from the training data and the other half are images generated by the Generator model G.

GAN model: Both the generator and discriminator models are combined to form a GAN model as shown in the figure below:

GAN Model Architecture

Model Training

In a single training step, a batch of latent noise vectors is given to the generator as input. The generator creates one image from each vector. These images are assigned the label “0” (fake) when they are used as input to the discriminator. An equally sized batch is sampled from the real (training) images and assigned the label “1” (real). Both batches are given to the discriminator, which classifies each image as real or fake. After the forward pass, two losses are computed independently: the discriminator loss and the generator loss. The discriminator loss is simply the classification error, and it is used to update the discriminator’s weights by conventional backpropagation.
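The batch assembly in this step can be sketched in plain Python. The generator below is a hypothetical stand-in that returns random pixel values, and the "real" images are random placeholders too; only the labelling and batching logic is the point.

```python
import random

IMAGE_PIXELS = 28 * 28
LATENT_DIM = 100


def sample_latent(batch_size, rng):
    """Draw a batch of Gaussian noise vectors z."""
    return [[rng.gauss(0.0, 1.0) for _ in range(LATENT_DIM)]
            for _ in range(batch_size)]


def fake_generator(z_batch, rng):
    """Hypothetical stand-in for G: one flat 'image' per latent vector."""
    return [[rng.random() for _ in range(IMAGE_PIXELS)] for _ in z_batch]


def discriminator_batch(real_images, half_batch, rng):
    """Return (images, labels): half real (label 1), half generated (label 0)."""
    real = rng.sample(real_images, half_batch)
    fake = fake_generator(sample_latent(half_batch, rng), rng)
    images = real + fake
    labels = [1] * half_batch + [0] * half_batch
    return images, labels


rng = random.Random(0)
# Placeholder "training set": 32 random images instead of real MNIST digits.
training_set = [[rng.random() for _ in range(IMAGE_PIXELS)] for _ in range(32)]
images, labels = discriminator_batch(training_set, half_batch=8, rng=rng)
print(len(images), sum(labels))  # 16 8
```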

In this scheme, how often the generator’s weights are updated depends on the discriminator’s performance on the fake images. If the discriminator is able to identify the fake images, the generator weights are updated every epoch. But if it cannot tell the fake images from the real ones (meaning that the generator is already creating very realistic images), the generator weights are updated only every few epochs. (The algorithm in the original paper uses a fixed schedule instead: k discriminator updates followed by one generator update.) The generator loss is backpropagated to update the weights of the generator network. During this step, the weights of the discriminator network are marked as untrainable, so that they are not changed during generator backpropagation.

“The same generated images have two different labels. When the Discriminator is trained, the label is 0, and when the Generator is trained, the label is 1.”

The discriminator outputs the probability that an image is real. When this probability is less than the threshold value (0.4 here), the image is classified as fake; in other words, the discriminator is less than 40% confident that the generated image is real. To train the generator to create more realistic images, the generator loss is computed from how real the discriminator judges the generated images to be. Therefore, when updating the weights of the generator, the generated images are labelled “1”. The error for the generator is then large for generated images that receive a low probability, and the weights are updated so that the generator creates more realistic images.
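A small numeric sketch shows why labelling the generated images as "1" produces a large error for unconvincing fakes. It uses the standard binary cross-entropy; the probabilities are made-up discriminator outputs, not measured results.

```python
import math


def binary_cross_entropy(target, prob):
    """Binary cross-entropy for one label (0 or 1) and one predicted probability."""
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))


# Hypothetical discriminator outputs D(G(z)) for three generated images.
for d_out in (0.1, 0.4, 0.9):
    generator_loss = binary_cross_entropy(1, d_out)  # label used when training G
    print(f"D(G(z)) = {d_out:.1f} -> generator loss = {generator_loss:.3f}")
```

The loss falls from about 2.30 when the discriminator gives the fake a probability of 0.1 to about 0.11 when it gives 0.9, so the least convincing images drive the largest updates.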

GAN Backpropagation for Discriminator and Generator

Loss functions for Discriminator and Generator Model

The loss functions for both the discriminator and the generator can be derived from the binary cross-entropy equation:

Binary Cross-entropy Loss Equation
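For reference, the standard binary cross-entropy over N samples, with true labels y_i in {0, 1} and predicted probabilities ŷ_i (symbols chosen here for illustration), can be written as:

$$
L = -\frac{1}{N}\sum_{i=1}^{N}\Big[\, y_i \log \hat{y}_i + (1 - y_i)\log\big(1 - \hat{y}_i\big) \Big]
$$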

Discriminator Loss: The discriminator loss quantifies the ability of the discriminator model to distinguish between real and fake images. It compares the discriminator’s predictions on real images to an array of 1’s and its predictions on fake images to an array of 0’s. The total loss for the discriminator is the sum of the losses on the real and the fake images.
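This comparison can be sketched in a few lines of plain Python. The probabilities are made-up discriminator outputs, and the function names are illustrative. Note that minimizing this cross-entropy is the same as maximizing log(D(x)) + log(1 - D(G(z))).

```python
import math


def bce(targets, probs):
    """Mean binary cross-entropy between 0/1 targets and predicted probabilities."""
    total = 0.0
    for t, p in zip(targets, probs):
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(targets)


def discriminator_loss(real_probs, fake_probs):
    """Loss on real images (targets all 1) plus loss on fake images (targets all 0)."""
    real_loss = bce([1] * len(real_probs), real_probs)
    fake_loss = bce([0] * len(fake_probs), fake_probs)
    return real_loss + fake_loss


# Made-up discriminator outputs: a confident discriminator vs. a confused one.
confident = discriminator_loss(real_probs=[0.9, 0.8], fake_probs=[0.1, 0.2])
confused = discriminator_loss(real_probs=[0.5, 0.5], fake_probs=[0.5, 0.5])
print(round(confident, 3), round(confused, 3))
```

A discriminator that outputs 0.5 for everything scores 2 x log(2), about 1.386, which is the value it ends up with when it can no longer separate real from fake.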

Discriminator Loss for real and fake images
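Substituting the labels into binary cross-entropy gives the two terms. For a real image x the label is 1 and the prediction is D(x); for a generated image G(z) the label is 0 and the prediction is D(G(z)). Dropping the leading minus sign, so that a better discriminator means a larger value (the names L_real and L_fake are chosen here for illustration):

$$
L_{\text{real}} = \log D(x), \qquad L_{\text{fake}} = \log\big(1 - D(G(z))\big)
$$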

The objective of the discriminator is to maximize both loss functions, and hence the total loss for the discriminator is defined as:

Total discriminator loss
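In the notation of the original GAN paper, with x a real image and z a latent vector, and omitting the expectation over samples, this objective can be written as:

$$
\max_{D} \; \Big[ \log D(x) + \log\big(1 - D(G(z))\big) \Big]
$$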

Generator Loss: The generator loss quantifies the ability of the generator model to fool the discriminator into believing the fake images to be real images. Thus, if the generator is trained well, the generated (fake) images are classified by the discriminator as real (or 1), and therefore the output of the discriminator for generated images is compared to an array of 1’s.

The generator, on the other hand, is trying to fool the discriminator, and therefore it tries to minimize the loss function:

Total generator loss
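In the notation of the original GAN paper, with x a real image and z a latent vector, and omitting the expectation over samples, the generator's objective can be written as:

$$
\min_{G} \; \Big[ \log D(x) + \log\big(1 - D(G(z))\big) \Big]
$$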

The generator cannot directly affect the log(D(x)) term, and therefore for the generator, minimizing the loss is equivalent to minimizing log(1 - D(G(z))).
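The original GAN paper notes that early in training, when the discriminator easily rejects the fakes, log(1 - D(G(z))) saturates, and it suggests training the generator to maximize log(D(G(z))) instead. That alternative corresponds to labelling the generated images as "1". The sketch below compares the gradient of both objectives with respect to the discriminator's pre-sigmoid output; it assumes the discriminator ends in a sigmoid, and the function names are illustrative.

```python
import math


def sigmoid(a):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(logit):
    """d/d(logit) of log(1 - sigmoid(logit)), the objective G minimizes."""
    return -sigmoid(logit)


def non_saturating_grad(logit):
    """d/d(logit) of -log(sigmoid(logit)), i.e. cross-entropy with label 1."""
    return -(1.0 - sigmoid(logit))


# Negative logits mean the discriminator is confident the image is fake.
for logit in (-5.0, 0.0, 5.0):
    print(f"logit={logit:+.1f}  D={sigmoid(logit):.4f}  "
          f"saturating={saturating_grad(logit):+.4f}  "
          f"non-saturating={non_saturating_grad(logit):+.4f}")
```

When the discriminator is confident that an image is fake (logit of -5), the first gradient is close to zero while the second is close to -1, so the label-1 form gives the generator a much stronger signal exactly when it needs it most.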

Therefore the combined loss function of GAN is:

Loss of GAN Model
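As given in the original GAN paper (reference 1), the combined objective is the minimax value function, where p_data is the distribution of the real images and p_z is the distribution of the latent vectors:

$$
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_{z}(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$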

The GAN algorithm, as described in the original Generative Adversarial Nets paper by Ian Goodfellow et al., is as follows:

Algorithm as described in original Generative Adversarial Nets
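The alternating updates of that algorithm (k discriminator steps by gradient ascent, then one generator step by gradient descent) can be illustrated with a toy one-dimensional sketch. Everything specific here is an assumption made for illustration: the real data is drawn from N(3, 1), the generator is G(z) = a*z + b, the discriminator is D(x) = sigmoid(w*x + c), and the gradients are written out by hand. It is not tuned to converge and only shows the order of the updates.

```python
import math
import random


def sigmoid(a):
    """Numerically safe logistic function."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def train_toy_gan(iterations=200, k=1, m=16, lr=0.05, seed=0):
    """Toy 1-D version of the alternating GAN updates (illustrative only)."""
    rng = random.Random(seed)
    a, b = 1.0, 0.0   # generator parameters: G(z) = a*z + b
    w, c = 0.1, 0.0   # discriminator parameters: D(x) = sigmoid(w*x + c)
    for _ in range(iterations):
        for _ in range(k):
            # Discriminator step: ascend mean[log D(x) + log(1 - D(G(z)))].
            xs = [rng.gauss(3.0, 1.0) for _ in range(m)]
            gs = [a * rng.gauss(0.0, 1.0) + b for _ in range(m)]
            grad_w = grad_c = 0.0
            for x, g in zip(xs, gs):
                d_real, d_fake = sigmoid(w * x + c), sigmoid(w * g + c)
                grad_w += (1.0 - d_real) * x - d_fake * g
                grad_c += (1.0 - d_real) - d_fake
            w += lr * grad_w / m
            c += lr * grad_c / m
        # Generator step: descend mean[log(1 - D(G(z)))]; D is held fixed.
        zs = [rng.gauss(0.0, 1.0) for _ in range(m)]
        grad_a = grad_b = 0.0
        for z in zs:
            d_fake = sigmoid(w * (a * z + b) + c)
            grad_a += -d_fake * w * z
            grad_b += -d_fake * w
        a -= lr * grad_a / m
        b -= lr * grad_b / m
    return a, b, w, c


a, b, w, c = train_toy_gan()
print(f"generator: a={a:.3f}, b={b:.3f}; discriminator: w={w:.3f}, c={c:.3f}")
```

Note that the discriminator parameters w and c are not touched during the generator step, which mirrors marking the discriminator as untrainable while the generator is updated.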

Cool Applications

GANs have a wide range of applications and are used by researchers around the globe. Mostly they are developed for generating images similar to those in the training data, for image transformation such as day-to-night photo conversion, or for creating new anime characters. One interesting application is generating music.

Generating human faces:

Realistic faces generated by GAN. Source

Image-to-Image Translation

Image Transformation Source

Semantic-Image to Photo Translation

Semantic Image to Street View. Source

Anime Character Generation

Source

Generating Music

Melodies generated using different models of MidiNet. Source

Stay tuned for a Keras implementation of a GAN on the MNIST dataset.

Thanks for reading!

References:

  1. https://arxiv.org/pdf/1406.2661.pdf
  2. https://machinelearningmastery.com/how-to-develop-a-generative-adversarial-network-for-an-mnist-handwritten-digits-from-scratch-in-keras/
  3. https://medium.com/@jonathan_hui/gan-some-cool-applications-of-gans-4c9ecca35900
  4. https://machinelearningmastery.com/impressive-applications-of-generative-adversarial-networks/
