Generative Adversarial Networks: An Intuitive Explanation & Some Keras Code

Kevin Y. Guo
A.I./Machine Learning Tutorials
Jan 28, 2020

Introduction

Generative Adversarial Networks (GANs), first introduced by Ian Goodfellow and his collaborators in 2014, offer a novel way to train generative models.¹ A GAN is split into two sub-models and is essentially a zero-sum game: a generator that produces fake examples is pitted against a discriminator that must classify examples as real or fake.

Fig 1. The framework of a basic GAN. The discriminator is trained by image batches of real images before it is pitted against the generator. The discriminator then takes in image batches that are a randomized mixture of real and fake images and labels them. The generator takes in a latent vector, or noise, and attempts to generate fake images from those randomized values.

One analogy for a GAN is a fine art dealership. The generator would be a forger who creates fake paintings that try to pass for the real deal. The discriminator would be an authenticator who examines a myriad of genuine works with the forger’s fakes mixed in. The authenticator, like the discriminator, must label each of these paintings as real or fake. As time passes, the forger improves by fixing the flaws that previously gave away fake paintings, and the authenticator finds more details that help confirm authenticity. For GANs, this passage of time that teaches the forger and authenticator is analogous to the training that adjusts the weights of the generator and discriminator.

GANs have a wide variety of applications, ranging anywhere from 3D object generation to video prediction.² As an emerging concept, GANs still have so much potential that needs to be explored. I encourage you to try to brainstorm your own creative applications after reading this tutorial because when it comes to artificial intelligence, the boundaries are limitless.

*This tutorial assumes knowledge of introductory linear algebra, basic concepts behind convolutional neural networks, and moderate proficiency in Python.

Discriminator

The discriminator is a neural network that can range from simplistic to very deep. The model typically contains convolutions, taking an entire image and narrowing it down to a single decision node. The beauty of neural networks is that their structure is largely up to the user: the architecture and the hyperparameters are design decisions.

Fig 2. Example of a very simple discriminator.

Taking an image as input, the discriminator ultimately outputs the probability that the image it is given to label is authentic.
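
To make this concrete, here is a minimal Keras sketch of a small convolutional discriminator. The layer sizes are illustrative choices, not the exact model shown in Fig 2, and a channels-last (28, 28, 1) image format is assumed:

from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense, LeakyReLU

# A toy discriminator: an image goes in, a single real/fake probability comes out.
toy_discriminator = Sequential()
toy_discriminator.add(Conv2D(32, (3, 3), strides=2, input_shape=(28, 28, 1)))
toy_discriminator.add(LeakyReLU(0.2))
toy_discriminator.add(Conv2D(64, (3, 3), strides=2))
toy_discriminator.add(LeakyReLU(0.2))
toy_discriminator.add(Flatten())
toy_discriminator.add(Dense(1, activation='sigmoid'))  # probability that the input is real
toy_discriminator.compile(loss='binary_crossentropy', optimizer='adam')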

Generator

The generator is a neural network that is essentially the mirror image of the discriminator. Taking in a latent vector of noise, the generator transforms a relatively small number of randomized values into an entire image ready to be labeled by the discriminator. The generator typically contains transposed convolutions (often called deconvolutions) to upsample the data; however, as with the discriminator, the framework of the generator is entirely dependent on the user’s preferences and hyperparameters.

Fig 3. Example of a very simple generator.
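
In the same spirit, here is a minimal Keras sketch of a generator built from transposed convolutions. Again, the layer sizes are illustrative, not the exact model in Fig 3, and a channels-last format is assumed:

from keras.models import Sequential
from keras.layers import Dense, Reshape, Conv2DTranspose, LeakyReLU

# A toy generator: a 100-dimensional noise vector is upsampled to a 28x28 image.
toy_generator = Sequential()
toy_generator.add(Dense(7 * 7 * 64, input_dim=100))
toy_generator.add(LeakyReLU(0.2))
toy_generator.add(Reshape((7, 7, 64)))
toy_generator.add(Conv2DTranspose(32, (4, 4), strides=2, padding='same'))   # 7x7 -> 14x14
toy_generator.add(LeakyReLU(0.2))
toy_generator.add(Conv2DTranspose(1, (4, 4), strides=2, padding='same', activation='tanh'))  # 14x14 -> 28x28
# No compile step is needed here: in a GAN, the generator is trained through the combined model.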

Loss Function

Now that we have a clearer idea of the functionality of the generator and the discriminator, it is key to know how we adjust their respective weights in order to produce meaningful data. At first, if an untrained GAN were to run, the generator would produce images similar to static and the discriminator would aimlessly label images.

Fig 4. The images produced by a generator after only 1 epoch of training. After very little training, we can see that the images are slowly shifting from static and are taking form. However, there are still many epochs to go before numbers can be confidently identified.

Similar to standalone convolutional neural networks, GANs use backpropagation to adjust weights. The gradient of the loss function tells us the direction and magnitude by which to adjust each weight in order to start producing meaningful results. But what is the loss function? Since we have two sub-models, we have two different loss functions to account for. To begin deriving them, we first have to think about the objectives of both the generator and the discriminator and how we are going to penalize each of them.

The discriminator has two terms in its loss function. On a scale from 0 to 1, with 0 representing a fake label and 1 representing a real label, the discriminator has two objectives based on the input image: first, if a fake image is inputted, the discriminator’s output should be pushed as close to 0 as possible; second, if a real image is inputted, the discriminator’s output should be pushed as close to 1 as possible.³ Adding both terms together gives the discriminator’s loss function in its totality.

Fig 5. The derivation of the discriminator loss function. The first half of the discriminator loss function attempts to improve real image recognition while the other half attempts to improve fake image recognition.
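
For readers following along without the figure, the loss that Fig 5 derives is the standard discriminator objective from the original GAN paper¹, written here in LaTeX with ϵ denoting the latent noise vector:

$$\max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] \;+\; \mathbb{E}_{\epsilon \sim p_{\epsilon}}\big[\log\big(1 - D(G(\epsilon))\big)\big]$$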

The generator must try to fool the discriminator, and hence the generator’s loss function is simply the opposite of the discriminator’s. The generator works to push the discriminator’s output as close to 1 as possible when a fake image is inputted.

Fig 6. The derivation of the generator loss function. This loss function can be considered just the negative of the discriminator loss function. This is why a GAN can be classified as a zero-sum game.
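
Correspondingly, the generator minimizes the second term of that objective:

$$\min_G \; \mathbb{E}_{\epsilon \sim p_{\epsilon}}\big[\log\big(1 - D(G(\epsilon))\big)\big]$$

In practice, many implementations instead maximize $\mathbb{E}_{\epsilon}[\log D(G(\epsilon))]$, the non-saturating variant, which gives stronger gradients early in training; the Keras code below takes this route by training the stacked model against "real" labels.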

The zero-sum game is reflected in the opposing goals of the discriminator and generator. The generator wants the discriminator’s output on a fake image, D(G(ϵ)), to be 1, while the discriminator wants that output to be 0. At the theoretical equilibrium, the discriminator can do no better than outputting 1/2; in practice, depending on the relative strengths of each model at the end of training, that value may land closer to 0 or to 1.
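
A supporting result from the original paper¹ makes this precise: for a fixed generator, the optimal discriminator is

$$D^{*}(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{g}(x)}$$

which equals 1/2 everywhere once the generated distribution $p_{g}$ matches $p_{\text{data}}$.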

Take a moment to fully understand the purpose of these equations and the steps behind their derivations. Understanding the math behind the scenes will significantly strengthen your conceptual grasp. Now, we can combine both loss functions into a single minimax equation.

Fig 7. Minimax equation specifically formulated for a GAN’s zero-sum game.
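
The standard form of this minimax game, which Fig 7 presents, is

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{\epsilon \sim p_{\epsilon}}\big[\log\big(1 - D(G(\epsilon))\big)\big]$$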

At the end of training, the distribution of generated images should largely overlap the distribution of real images. This overlap explains both the discriminator’s difficulty in labeling and the trained generator’s success.

Fig 8. The plot of real images (green) vs. fake images (red) at the end of training. Both images have very similar distributions which explains why the discriminator has such a difficult time differentiating at the end.
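
One way (not necessarily how Fig 8 was produced) to visualize this overlap is to histogram the discriminator’s scores on real versus generated images after training. This sketch assumes the generator, discriminator, X_train, and latent_dim objects defined in the walkthrough below:

import numpy as np
import matplotlib.pyplot as plt

# Score 1,000 real images and 1,000 generated images with the trained discriminator
real_scores = discriminator.predict(X_train[:1000]).ravel()
noise = np.random.normal(0, 1, size=[1000, latent_dim])
fake_scores = discriminator.predict(generator.predict(noise)).ravel()

plt.hist(real_scores, bins=50, color='green', alpha=0.5, label='Real images')
plt.hist(fake_scores, bins=50, color='red', alpha=0.5, label='Fake images')
plt.xlabel('Discriminator output')
plt.ylabel('Count')
plt.legend()
plt.show()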

Coding an MNIST-trained GAN

Now it’s time to try coding our own GAN. This walkthrough will use the MNIST database, containing 60,000 training examples and 10,000 test examples of handwritten digits. Our goal is to have the generator drastically improve throughout epochs and eventually produce recognizable images of numbers.

Before any code can be written, we need imports so that we have the MNIST dataset and the key building blocks for constructing, training, and testing our GAN. A few extra layers and optimizers are imported so you can experiment with different architectures and hyperparameters and see their effect on your results.

import numpy as np
import keras as ker
from tqdm import tqdm
from keras.models import Model, Sequential
from keras.layers import Dense, Dropout, Flatten, Reshape, LeakyReLU, Conv2DTranspose, Input
from keras.layers.convolutional import Conv2D, MaxPooling2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.utils import np_utils
from keras.optimizers import Adam, SGD, RMSprop
from keras import backend as K
from keras import initializers
from numpy import expand_dims, ones, zeros, vstack
from numpy.random import rand, randint, randn
import matplotlib.pyplot as plt

# import MNIST database
from keras.datasets import mnist

# use theano for tensor shape, try "image_dim_ordering" for updated API
K.set_image_dim_ordering('th')
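
Note that these imports target the standalone Keras API that was current when this tutorial was written. If you are instead on a recent TensorFlow 2.x install (an assumption about your environment), the rough equivalents would look like this:

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LeakyReLU, BatchNormalization
from tensorflow.keras.optimizers import Adam   # Adam(lr=...) becomes Adam(learning_rate=...)
from tensorflow.keras.datasets import mnist
from tensorflow.keras import backend as K

# 'th' ordering is called 'channels_first' in the newer API
K.set_image_data_format('channels_first')

(With the dense-only models used below, the image dimension ordering barely matters, since each image is flattened to a 784-dimensional vector anyway.)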

Next, we set up a few things that our models will rely on. We need to prepare the MNIST dataset so that it can be read by the discriminator: this means loading the handwritten digit images, scaling the pixel values to [-1, 1] (to match the generator’s tanh output), and flattening each 28×28 image into a 784-dimensional vector. Additionally, to create reproducible results, we keep the latent dimension constant and seed the random number generator. With reproducible results, isolating the effect of each hyperparameter will be easier later on.

# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = X_train.reshape(60000, 784)

# for reproducible results
seed = 1
np.random.seed(seed)
latent_dim = 100
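
A quick optional check, not part of the original code, to confirm the data landed in the expected shape and range:

# Expect (60000, 784), with pixel values scaled to roughly -1.0 and 1.0
print(X_train.shape, X_train.min(), X_train.max())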

Now it’s time to start building our GAN, beginning with the discriminator. The discriminator, as mentioned before, can be very simple or deep, containing numerous convolutions. For this tutorial, we’ll stick to dense layers, leaky ReLUs, and dropout. Leaky ReLUs avoid the zero-gradient region of a normal ReLU and have been noted to speed up GAN training. Dropout helps protect our discriminator from overfitting to the training examples.

discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(units=1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))

Following the discriminator, the generator takes in our latent vector dimension so it is compatible with the size of the noise generated later for training. Moreover, since the generator feeds directly into the discriminator, we must make sure that the generator’s output size matches the discriminator’s input size. If the key does not match the lock, Keras will raise a shape-mismatch error when the models are stacked. To make the generator more complex, try adding some upsampling layers or even transposed-convolution layers.

generator = Sequential()
generator.add(Dense(256, input_dim=latent_dim))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))

Let’s combine our discriminator and generator to create our GAN. The discriminator’s weights are frozen inside the combined model (discriminator.trainable = False) so that training the stacked GAN updates only the generator. Additionally, we set up lists to record the respective losses per epoch.

discriminator.trainable = False
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
gan.compile(loss='binary_crossentropy', optimizer='adam')

discrim_losses = []
gen_losses = []
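
If you want to confirm the stacking worked, an optional check (not in the original walkthrough) is to print the combined model’s structure:

# Confirm that 100-dim noise flows through the generator's 784 outputs
# into the frozen discriminator's single sigmoid output.
gan.summary()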

After our generator has produced images, we need a place to store and view them. Create a folder of your choice; I defaulted the name to images but change it to your preference.

# print generated images
def printGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, latent_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    # format array of pictures
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)

Let’s also plot our losses neatly on a graph. Viewing and storage work the same way as for the generated images.

# plot the discriminator and generator losses
def plotLoss(epoch):
    plt.figure(figsize=(9, 6))
    plt.plot(discrim_losses, label='Discriminator loss')
    plt.plot(gen_losses, label='Generator loss')
    plt.xlabel('Epoch #')
    plt.ylabel('Loss')
    plt.legend()  # called after the plots so the labels are picked up
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)

Lastly, let’s write a training function and run our program. Using tqdm, we get a live progress bar showing training progress within each epoch. As with the loss graph and the generated images, we will save the GAN’s parameters into a folder (model_parameters below); once again, feel free to change the destination name. We will save generated images every 25 epochs and plot the GAN losses at the end of training. To train longer, simply increase the number of epochs.

# train gan
def train(epochs=1, batch_size=128):
    batch_count = X_train.shape[0] / batch_size
    print('Epochs:', epochs)
    print('Batch size:', batch_size)
    print('Batches per epoch:', batch_count)
    for e in range(1, epochs+1):
        print('Epoch %d' % e)
        for _ in tqdm(range(int(batch_count))):
            # create randomized noise and images
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            image_batch = X_train[np.random.randint(low=0, high=X_train.shape[0], size=batch_size)]
            # generate fake images
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])
            # label fake or not (0.9 is a smoothed "real" label)
            Y = np.zeros(2*batch_size)
            Y[:batch_size] = 0.9
            # train discriminator and generator
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(X, Y)
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise, y_gen)
        # store epoch loss
        discrim_losses.append(d_loss)
        gen_losses.append(g_loss)
        # print images and save weights every 25 epochs
        if e == 1 or e % 25 == 0:
            printGeneratedImages(e)
            generator.save('model_parameters/gan_generator_epoch_%d.h5' % e)
            discriminator.save('model_parameters/gan_discriminator_epoch_%d.h5' % e)
    # plot losses
    plotLoss(e)

# run
if __name__ == '__main__':
    train(200, 128)
Fig 9. Epoch 200 generator output.
Fig 10. Loss plot.
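
After training, any of the saved generator checkpoints can be reloaded to sample new digits. A minimal sketch, assuming the model_parameters folder and naming scheme from the training loop above:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import load_model

loaded_generator = load_model('model_parameters/gan_generator_epoch_200.h5')
noise = np.random.normal(0, 1, size=[16, 100])
samples = loaded_generator.predict(noise).reshape(16, 28, 28)

for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(samples[i], cmap='gray_r')
    plt.axis('off')
plt.show()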

Check out the full code with image examples: https://github.com/kev-guo/MNIST-GAN

Acknowledgments

This paper was written as part of a RETINA-AI Health Inc. summer internship.

References

[1]:Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville: “Generative Adversarial Networks”, 2014; arXiv:1406.2661.

[2]:Shin, Minchul, Curated list of awesome GAN applications and demo, (2018), GitHub repository, https://github.com/nashory/gans-awesome-applications

[3]:Madhu Sanjeevi. (January 14 2019). Ch:14 Generative Adversarial Networks (GAN’s) with Math. https://medium.com/deep-math-machine-learning-ai/ch-14-general-adversarial-networks-gans-with-math-1318faf46b43
