Coding your first GAN algorithm with Keras

Brijesh Modasara
Published in Analytics Vidhya
May 13, 2020

Get some coffee, put on your headphones and let’s get started with coding your first GAN algorithm! If you are unfamiliar with GANs or how they work, check out my article here.

Get ready for coding.

We will start straight away without wasting any time. The GAN implementation consists of the following components:

  1. Import necessary Libraries & Dataset
  2. Sample Real Images
  3. Generate Fake Images
  4. Creating Models: Generator, Discriminator, GAN
  5. Training
  6. Model Evaluation
  7. Predictions

1. Import necessary Libraries & Dataset

We will start by importing the essential libraries. The neural network is built with the Keras API on the TensorFlow backend. I am using Keras version 2.2.5 and TensorFlow version 1.15.0.

After importing the libraries, we need to import the training dataset. We will use the MNIST dataset to train the model. MNIST is a collection of 70,000 images of handwritten digits from 0 to 9, each 28 x 28 pixels. This dataset is particularly useful for beginners because it is publicly available, nearly balanced across the ten classes, and requires very little pre-processing. All the images are grayscale and stored as 8-bit unsigned integers, so each pixel value lies in the range 0 to 255.

Sample images from the MNIST dataset

The MNIST dataset is bundled with Keras (keras.datasets.mnist), so we only need to load it and assign it to the respective variables. The only pre-processing required is to rescale the unsigned integer pixel values from [0, 255] to floating-point values in [0, 1].
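A minimal sketch of this step follows; it is not the author’s original listing. It assumes tf.keras from TensorFlow 2.x rather than the standalone Keras 2.2.5 / TensorFlow 1.15 setup mentioned above, and the helper names load_mnist and preprocess_images are hypothetical. Besides scaling, it adds a channel axis, which Keras convolutional layers expect for single-channel images.

```python
import numpy as np


def preprocess_images(x):
    """Scale uint8 pixels from [0, 255] to float32 in [0, 1] and add a channel axis."""
    x = x.astype("float32") / 255.0
    # (n, 28, 28) -> (n, 28, 28, 1): Conv2D layers expect an explicit channel dimension
    return np.expand_dims(x, axis=-1)


def load_mnist():
    """Load MNIST through Keras and return the pre-processed training images."""
    # Imported here so preprocess_images can be used without TensorFlow installed
    from tensorflow import keras

    # The digit labels are not needed: the GAN learns from the images alone
    (x_train, _), (_, _) = keras.datasets.mnist.load_data()
    return preprocess_images(x_train)  # shape (60000, 28, 28, 1)
```

load_data() downloads the dataset on first use and caches it locally.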

2. Sample Real Images

To sample real images, we will create a function called sample_real_images, which selects a batch of images from the entire dataset. Its inputs are the entire dataset (x_train) and the batch_size. Throughout this article, label “1” denotes real images and “0” denotes fake images, so every sampled real image is labelled “1”.
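A minimal NumPy sketch of sample_real_images, assuming x_train is an array with one image per entry along its first axis. Drawing the indices with np.random.randint (sampling with replacement) and shaping the labels as (batch_size, 1) to match a discriminator with a single sigmoid output are assumptions, not details taken from the original code.

```python
import numpy as np


def sample_real_images(x_train, batch_size):
    """Pick `batch_size` random images from the dataset and label them as real ("1")."""
    # Random indices (with replacement) into the first axis of the dataset
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    x_real = x_train[idx]
    # One label per image; shape (batch_size, 1) matches a single sigmoid output unit
    y_real = np.ones((batch_size, 1))
    return x_real, y_real
```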

3. Generate Fake Images

Generating fake images is a little trickier and takes more steps. Every image is generated from a vector of n elements sampled from a standard normal (Gaussian) distribution. Consequently, to generate a batch of images, we need a batch of vectors. To accomplish this, we will create two functions: generate_latent_points, which generates a batch of vectors, and generate_fake_images, which generates images from these vectors. The inputs to generate_latent_points are the latent_dim (the length of each vector) and the batch_size.

Next, we define generate_fake_images, which calls generate_latent_points to produce a batch of vectors and then generates one image from each vector. Its inputs are a model (the generator), the length of the vector (latent_dim) and the batch_size. Because these are fake images, they are all labelled “0”.
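The two functions can be sketched with NumPy as follows. The generator is assumed to be a Keras model (any object with a Keras-style predict method works), and the (batch_size, 1) label shape is an assumption matching a single-output discriminator.

```python
import numpy as np


def generate_latent_points(latent_dim, batch_size):
    """Draw `batch_size` latent vectors of length `latent_dim` from a standard normal."""
    # np.random.randn samples N(0, 1); one row per image to be generated
    return np.random.randn(batch_size, latent_dim)


def generate_fake_images(model, latent_dim, batch_size):
    """Run the generator on random latent vectors and label the outputs as fake ("0")."""
    x_input = generate_latent_points(latent_dim, batch_size)
    # Keras-style predict(): one generated image per latent vector
    x_fake = model.predict(x_input, verbose=0)
    y_fake = np.zeros((batch_size, 1))
    return x_fake, y_fake
```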

4. Creating Models

Let’s recall the GAN model architecture before moving forward. As shown in the diagram below, a GAN consists of a Generator model and a Discriminator model connected in sequence: the generator’s output is the discriminator’s input.

GAN Model Architecture

So we will first create standalone Discriminator and Generator models and then combine them into a complete GAN model. There are other ways to build a GAN, but I found this approach simple to understand and implement.

The Discriminator model is a convolutional neural network for binary classification. Its task is to differentiate between real and fake images. We take two equally sized mini-batches of images, one of real images and one of generated images, and give them to the Discriminator as input. The output of the Discriminator is the probability that an image belongs to the real (“1”) class rather than the fake (“0”) class.

The following settings have been found empirically to work well for the discriminator:

  1. LeakyReLU activation function for the hidden layers with a slope of 0.2
  2. Dropout regularization with a dropout rate of 0.3 to 0.4
  3. Adam optimizer with a learning rate of 0.0002 and a momentum term (beta_1) of 0.5
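One way to build such a discriminator with tf.keras, applying the three settings listed above. The number of layers, the 64 filters and the 3 x 3 kernels are assumptions for illustration, and downsampling is done with strided convolutions instead of pooling, a common choice in GAN discriminators.

```python
from tensorflow import keras


def create_discriminator(in_shape=(28, 28, 1)):
    """Binary CNN classifier that outputs the probability that an image is real."""
    model = keras.Sequential([
        keras.Input(shape=in_shape),
        # Strided convolutions downsample 28x28 -> 14x14 -> 7x7
        keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(0.2),  # slope 0.2 for negative inputs
        keras.layers.Dropout(0.4),
        keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(0.2),
        keras.layers.Dropout(0.4),
        keras.layers.Flatten(),
        # Single sigmoid unit: P(real); P(fake) is 1 minus this value
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    # Adam with learning rate 0.0002 and beta_1 (the "momentum" term) 0.5
    opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
    return model
```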

Now let’s create the Generator model. For this example, we will use a 100-element latent vector to generate an image. The length of 100 is arbitrary; we are free to choose 50, 200, 500 or even 1000 as the length of the vector.

In a convolutional neural network, an image is transformed into multiple feature maps, which are then pooled and flattened before passing through a couple of fully connected Dense layers. The output of the network is a vector of class scores for the image. Because the generator creates an image from a vector, it roughly reverses this process:

  1. The 100-element vector will be connected to a Dense layer with enough nodes to create multiple feature maps upon reshaping.
  2. Upsampling (sometimes called deconvolution) is done with Conv2DTranspose layers. We use a stride of (2x2) to double the image size along each side, and a kernel size that is a multiple of the stride, such as 4x4 (a common practice that helps avoid checkerboard artifacts).
  3. LeakyReLU activation function for the hidden layers with a slope of 0.2
  4. The output layer uses a sigmoid activation so that pixel values lie between 0 and 1.
  5. BatchNormalization is optional, but it is often recommended because it can stabilize training.
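A sketch of a generator following these five points, assuming tf.keras. The 128 filters, the 7 x 7 starting size and the final 7 x 7 Conv2D output layer are illustrative choices, not necessarily the author’s. Two stride-2 Conv2DTranspose layers take the feature maps from 7 x 7 to 14 x 14 and then to 28 x 28, matching MNIST, and the model is returned uncompiled.

```python
from tensorflow import keras


def create_generator(latent_dim=100):
    """Map a latent vector to a 28x28x1 image; deliberately left uncompiled."""
    n_nodes = 128 * 7 * 7  # enough units to reshape into 128 feature maps of 7x7
    model = keras.Sequential([
        keras.Input(shape=(latent_dim,)),
        keras.layers.Dense(n_nodes),
        keras.layers.LeakyReLU(0.2),
        keras.layers.Reshape((7, 7, 128)),
        # Stride 2 doubles each side (7 -> 14); kernel 4 is a multiple of the stride.
        # An optional keras.layers.BatchNormalization() could follow each transpose layer.
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(0.2),
        # 14 -> 28
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(0.2),
        # One output channel; sigmoid keeps pixels in [0, 1] like the scaled MNIST data
        keras.layers.Conv2D(1, (7, 7), activation="sigmoid", padding="same"),
    ])
    return model
```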

It is important to understand why we do not compile the generator model. The generator is not trained directly like the discriminator; instead, it is updated through the composite GAN model, and its performance is measured by how well its fake images fool the discriminator into classifying them as real. This will become clearer once we dive into training the models.

We will now create the complete GAN model with the create_gan_model function. This composite model feeds the output of the generator (a batch of fake images) into the discriminator. We set the trainable property of the discriminator’s weights to False so that the discriminator’s weights are not changed when the composite model is trained to update the generator.
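A sketch of create_gan_model, assuming tf.keras and an assumed signature that takes the generator and discriminator as arguments. Reusing the discriminator’s Adam settings for the composite model is also an assumption.

```python
from tensorflow import keras


def create_gan_model(g_model, d_model):
    """Stack generator -> discriminator into one model used to update the generator."""
    # Freeze the discriminator inside the composite model. In Keras 2 / tf.keras the
    # flag is captured at compile time, so the already-compiled standalone
    # discriminator keeps training normally.
    d_model.trainable = False
    model = keras.Sequential([g_model, d_model])
    opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss="binary_crossentropy", optimizer=opt)
    return model
```

Only the generator’s weights remain trainable in the composite model, which is exactly what the generator update step needs.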

5. Training

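The figure below shows the two distinct training paths. The discriminator is trained and updated as a standalone model, while the generator is updated through the composite GAN model. Setting the discriminator’s trainable property to False earlier applies only to the discriminator inside the GAN model; it does not affect the standalone discriminator in any way, because Keras applies the trainable setting when a model is compiled, and the standalone discriminator was compiled before the flag was changed.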

For training, we will create a function called train, which alternately updates the discriminator and the generator. We will train for 100 epochs and evaluate performance every 10 epochs. Once training starts, we can see that after 30 epochs the generator is able to produce fairly realistic images.
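A compact sketch of train, assuming NumPy image arrays and Keras-style models that expose predict and train_on_batch. It inlines the sampling logic instead of calling sample_real_images and generate_fake_images so the block stands alone; the default batch_size of 256 and the eval_every / on_eval hook are hypothetical choices. It relies on Keras applying the trainable flag at compile time (Keras 2 / tf.keras behaviour), so the standalone discriminator keeps learning while its copy inside gan_model stays frozen.

```python
import numpy as np


def train(g_model, d_model, gan_model, x_train, latent_dim=100,
          n_epochs=100, batch_size=256, eval_every=10, on_eval=None):
    """Alternate discriminator and generator updates; return per-batch losses."""
    batches_per_epoch = x_train.shape[0] // batch_size
    half_batch = batch_size // 2
    history = []
    for epoch in range(1, n_epochs + 1):
        for _ in range(batches_per_epoch):
            # 1) Discriminator step: half a batch of real images (label 1) ...
            idx = np.random.randint(0, x_train.shape[0], half_batch)
            x_real, y_real = x_train[idx], np.ones((half_batch, 1))
            # ... plus half a batch of generated images (label 0)
            z = np.random.randn(half_batch, latent_dim)
            x_fake = g_model.predict(z, verbose=0)
            y_fake = np.zeros((half_batch, 1))
            x_d = np.vstack([x_real, x_fake])
            y_d = np.vstack([y_real, y_fake])
            # train_on_batch returns [loss, accuracy] here; keep the loss
            d_loss = float(np.ravel(d_model.train_on_batch(x_d, y_d))[0])

            # 2) Generator step through the composite model: the fake images are
            #    labelled "real" so the loss rewards fooling the frozen discriminator
            z_gan = np.random.randn(batch_size, latent_dim)
            y_gan = np.ones((batch_size, 1))
            g_loss = float(np.ravel(gan_model.train_on_batch(z_gan, y_gan))[0])
            history.append((epoch, d_loss, g_loss))
        # Periodic evaluation hook, e.g. a model_evaluation call (section 6)
        if on_eval is not None and epoch % eval_every == 0:
            on_eval(epoch)
    return history
```

Labelling the generated images as “1” in the generator step is what turns the discriminator’s judgement into a training signal for the generator.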

6. Model Evaluation

The convergence of GAN models is not fully understood, and there is no well-defined criterion for when to stop training. The images produced by the generator therefore need to be evaluated subjectively to judge the model’s performance. A common approach is to periodically check the discriminator’s accuracy and save the generator model along with the images it generates. We then inspect these images ourselves. For this purpose, we have three functions:

  1. model_evaluation: checks the accuracy of the discriminator on the generated and the real images, and saves the generator and discriminator models at regular intervals for further use and evaluation.
  2. save_image: saves the generated images.
  3. create_animation: creates an animation from the generated images, which helps visualize how the generator evolves during training.
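A sketch of model_evaluation, assuming Keras-style models and a discriminator compiled with metrics=["accuracy"], so that evaluate returns [loss, accuracy]. The sample size, the printed format and the file names written under save_dir are hypothetical choices.

```python
import os

import numpy as np


def model_evaluation(epoch, g_model, d_model, x_train, latent_dim=100,
                     n_samples=100, save_dir=None):
    """Report discriminator accuracy on real and fake images; optionally save models."""
    # Real images, label 1
    idx = np.random.randint(0, x_train.shape[0], n_samples)
    x_real, y_real = x_train[idx], np.ones((n_samples, 1))
    # Generated images, label 0
    z = np.random.randn(n_samples, latent_dim)
    x_fake = g_model.predict(z, verbose=0)
    y_fake = np.zeros((n_samples, 1))
    # evaluate() returns [loss, accuracy] because the discriminator tracks "accuracy"
    _, acc_real = d_model.evaluate(x_real, y_real, verbose=0)
    _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose=0)
    print(f"epoch {epoch}: accuracy real={acc_real:.0%}, fake={acc_fake:.0%}")
    if save_dir is not None:
        # Hypothetical naming scheme; .h5 files can be reloaded with keras.models.load_model
        g_model.save(os.path.join(save_dir, f"generator_epoch_{epoch:03d}.h5"))
        d_model.save(os.path.join(save_dir, f"discriminator_epoch_{epoch:03d}.h5"))
    return acc_real, acc_fake
```

Accuracies that stay near 100% on both real and fake images usually mean the discriminator is winning easily, which is one of the signals worth checking alongside the saved images.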

7. Predictions

Once training has finished, it’s time to make predictions and check the results. We will generate 20 images using the last saved generator model. The results below show that after 100 epochs of training, the model is able to generate fairly good images similar to the ones in the training dataset.
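A minimal prediction sketch, assuming the generator saved during training has been reloaded with keras.models.load_model. The helper names generate_digits and images_to_grid and the 4 x 5 layout for the 20 images are hypothetical; the grid is simply a convenient way to view all the digits in one plot.

```python
import numpy as np


def generate_digits(g_model, latent_dim=100, n_images=20):
    """Generate `n_images` digit images from random latent vectors."""
    # Same latent distribution as during training: standard normal vectors
    z = np.random.randn(n_images, latent_dim)
    return g_model.predict(z, verbose=0)  # shape (n_images, 28, 28, 1) for MNIST


def images_to_grid(images, n_rows=4, n_cols=5):
    """Tile (n, h, w, 1) images into a single (n_rows*h, n_cols*w) array."""
    n, h, w = images.shape[:3]
    if n != n_rows * n_cols:
        raise ValueError(f"expected {n_rows * n_cols} images, got {n}")
    grid = images[..., 0].reshape(n_rows, n_cols, h, w)
    # (rows, cols, h, w) -> (rows, h, cols, w) so that images in a row sit side by side
    return grid.transpose(0, 2, 1, 3).reshape(n_rows * h, n_cols * w)
```

For 20 MNIST-sized images, images_to_grid(generate_digits(g_model)) returns a 112 x 140 array that can be displayed with matplotlib’s plt.imshow(grid, cmap="gray").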

Digits generated by the GAN generator model

The complete code can be found on my GitHub here.

Next steps: we can experiment with different hyperparameters, such as the length of the input vector, the batch size, the number of epochs and the number of layers, to see whether the model can generate better results.

I hope you found something useful in this article. If you liked it, don’t forget to give it a few claps!

