Generative Adversarial Networks: An Intuitive Explanation & Some Keras Code
Introduction
Generative Adversarial Networks (GANs), introduced by Ian Goodfellow and colleagues in 2014, offered a novel way to train generative models.¹ A GAN consists of two separate models locked in what is essentially a zero-sum game: a generator that produces fake examples is pitted against a discriminator that must classify each example as real or fake.
One analogy for a GAN is a fine art dealership. The generator is a forger who creates fake paintings that try to pass for the real deal. The discriminator is an authenticator who examines a large collection of fine art with the forger’s fakes mixed in and, like the discriminator, must label every painting as real or fake. As time passes, the forger improves by closing the gaps that previously gave the fakes away, and the authenticator notices more of the details that help confirm authenticity. For GANs, this passage of time that teaches the forger and the authenticator is analogous to the training that adjusts the weights of the generator and the discriminator.
GANs have a wide variety of applications, ranging from 3D object generation to video prediction.² As an emerging concept, GANs still have a great deal of potential left to explore. I encourage you to brainstorm your own creative applications after reading this tutorial because, when it comes to artificial intelligence, the possibilities are limitless.
*This tutorial assumes knowledge of introductory linear algebra, basic concepts behind convolutional neural networks, and moderate proficiency in Python.
Discriminator
The discriminator is a neural network that can range from very simple to very deep. It typically contains convolutions that take an entire image and narrow it down to a single decision node. The structure is flexible: the number of layers, their sizes, and the other hyperparameters are design choices left to the user.
Taking an image as input, the discriminator ultimately outputs the probability that the image is authentic.
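As a rough illustration of narrowing an image down to a single decision node, the sketch below pushes a flattened 28×28 image through one untrained logistic unit. The random weights, the helper name toy_discriminator, and the single-layer design are all assumptions made for illustration; this is not the model built later in this tutorial.

```python
import numpy as np

def sigmoid(z):
    # squash any real number into the range (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

def toy_discriminator(image, weights, bias):
    # hypothetical one-layer discriminator: flatten, weight, sum, squash
    return sigmoid(image.reshape(-1) @ weights + bias)

rng = np.random.default_rng(0)
image = rng.uniform(-1, 1, size=(28, 28))   # stand-in for one grayscale image
weights = rng.normal(0, 0.01, size=784)     # untrained, random weights
p_real = toy_discriminator(image, weights, bias=0.0)
print(p_real)  # a single value strictly between 0 and 1
```

With untrained weights the output carries no meaning yet; training is what turns this number into a useful estimate of authenticity.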
Generator
The generator is a neural network that works roughly in reverse of the discriminator. Taking in a latent vector of noise, it transforms a relatively small number of random values into an entire image, ready to be labeled by the discriminator. The generator typically uses transposed convolutions (often called deconvolutions) to upsample the data; however, as with the discriminator, its architecture depends entirely on the user’s preference and choice of hyperparameters.
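To make the noise-to-image idea concrete, here is a minimal NumPy sketch in which a single untrained layer expands a 100-value noise vector into a 28×28 image. The helper name toy_generator, the random weights, and the one-layer design are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 100  # size of the noise vector

def toy_generator(noise, weights, bias):
    # hypothetical one-layer generator: expand the noise to 784 values,
    # squash them into [-1, 1] with tanh, and reshape into an image
    return np.tanh(noise @ weights + bias).reshape(28, 28)

noise = rng.normal(0, 1, size=latent_dim)
weights = rng.normal(0, 0.1, size=(latent_dim, 784))  # untrained weights
fake_image = toy_generator(noise, weights, bias=np.zeros(784))
print(fake_image.shape)  # (28, 28)
```

An untrained generator like this one produces only static; the structure of the output comes entirely from the weights learned during training.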
Loss Function
Now that we have a clearer idea of what the generator and the discriminator do, we need to know how their weights are adjusted so that they produce meaningful data. If an untrained GAN were run, the generator would produce images resembling static and the discriminator would label images more or less at random.
Like common standalone convolutional neural networks, GANs use backpropagation to adjust weights. The gradient of the loss function tells us in which direction, and by how much, to adjust each weight in order to start producing meaningful results. But what is the loss function? Since we have two sub-models, we have two different loss functions to account for. To begin deriving them, we first have to think about the objectives of the generator and the discriminator and how we are going to penalize each one for missing its objective.
The discriminator has two terms in its loss function. On a scale from 0 to 1, with 0 representing a fake label and 1 representing a real label, the discriminator has two objectives based on the image input: first, if a fake image is inputted, the discriminator’s output should be as close to 0 as possible; second, if a real image is inputted, the discriminator’s output should be as close to 1 as possible.³ Adding the two terms together gives the discriminator’s total loss.
The generator must try to fool the discriminator, so its objective opposes the discriminator’s: the generator works to push the discriminator’s output as close as possible to 1 when a fake image is inputted.
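Written out, one common way to express these two terms is the binary cross-entropy form below. The notation is an assumption chosen to match the D(G(ϵ)) used in this section: x is a real image, ϵ is the latent noise vector, D is the discriminator, and G is the generator.

```latex
L_D = -\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
      \;-\; \mathbb{E}_{\epsilon \sim p_{\epsilon}}\big[\log\big(1 - D(G(\epsilon))\big)\big]
```

The first term shrinks as D(x) approaches 1 for real images, and the second shrinks as D(G(ϵ)) approaches 0 for fake ones, so minimizing L_D pursues both objectives at once.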
The zero-sum game shows in the opposing goals of the discriminator and the generator. The generator wants the discriminator’s output for a fake image, D(G(ϵ)), where ϵ is the latent noise vector, to be 1, while the discriminator wants that output to be 0. Thus, at equilibrium, we are left with a value somewhere between those two goals. Depending on the relative strengths of each model at the end of training, that value could be closer to 0 or to 1.
Take a moment to fully understand the purpose of the equations and the steps behind the derivations. Understanding the math behind the model will significantly strengthen your conceptual understanding. Now, we can derive a minimax equation that combines both loss functions.
At the end of training, the distribution of fake images should noticeably overlap the distribution of real images. That overlap explains both the discriminator’s difficulty in labeling and the generator’s success once it has been trained.
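The tug-of-war can be seen numerically. The sketch below evaluates, for a few hypothetical values of D(G(ϵ)), the discriminator’s penalty on a fake image, -log(1 - D(G(ϵ))), and a generator penalty of -log D(G(ϵ)). The generator penalty uses the commonly used non-saturating form, which is an assumption here, although it corresponds to training the generator against labels of 1 with binary cross-entropy.

```python
import numpy as np

# hypothetical discriminator outputs for a fake image, D(G(eps))
d_fake = np.array([0.1, 0.5, 0.9])

# discriminator wants d_fake near 0: its penalty grows as d_fake rises
d_penalty = -np.log(1.0 - d_fake)

# generator wants d_fake near 1: its penalty shrinks as d_fake rises
g_penalty = -np.log(d_fake)

for p, dl, gl in zip(d_fake, d_penalty, g_penalty):
    print(f"D(G(eps))={p:.1f}  discriminator penalty={dl:.3f}  generator penalty={gl:.3f}")
```

Whatever lowers one model’s penalty raises the other’s, and at an output of 0.5 the two penalties are equal.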
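For reference, the minimax objective from the original GAN paper¹ can be written as follows. The symbols follow this section’s notation rather than the paper’s (x is a real image, ϵ is the latent noise vector, D is the discriminator, and G is the generator):

```latex
\min_G \max_D V(D, G) =
    \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
    + \mathbb{E}_{\epsilon \sim p_{\epsilon}}\big[\log\big(1 - D(G(\epsilon))\big)\big]
```

The discriminator maximizes V by pushing D(x) toward 1 and D(G(ϵ)) toward 0, while the generator minimizes V by pushing D(G(ϵ)) toward 1.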
Coding an MNIST-trained GAN
Now it’s time to code our own GAN. This walkthrough uses the MNIST database, which contains 60,000 training examples and 10,000 test examples of handwritten digits. Our goal is for the generator to improve drastically over the epochs and eventually produce recognizable images of digits.
Before any code can be written, we need imports that give us the MNIST dataset and the key functions for building, training, and testing our GAN. A few extra imports are included so that you can experiment with different layers, optimizers, and hyperparameters and see how they affect your results.
import numpy as np
import keras as ker
from tqdm import tqdm
from keras.models import Model, Sequential
from keras.layers import Dense, Dropout, Flatten, Reshape, LeakyReLU, Conv2DTranspose, Input
from keras.layers.convolutional import Conv2D, MaxPooling2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.utils import np_utils
from keras.optimizers import Adam, SGD, RMSprop
from keras import backend as K
from keras import initializers
from numpy import expand_dims, ones, zeros, vstack
from numpy.random import rand, randint, randn
import matplotlib.pyplot as plt

#import MNIST database
from keras.datasets import mnist

#use Theano-style (channels-first) tensor ordering; later Keras releases
#drop this call in favor of K.set_image_data_format('channels_first')
K.set_image_dim_ordering('th')
Next, we need some setup. The MNIST dataset must be prepared so that the discriminator can read it, which means loading, scaling, and reshaping the handwritten digit images. Additionally, to make results reproducible, we fix the latent dimension and seed the random number generator. With reproducible results, isolating the effect of each hyperparameter will be easier later on.
# Load MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)

#for reproducible results
seed = 1
np.random.seed(seed)
latent_dim = 100
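The scaling step maps pixel values from the range [0, 255] to [-1, 1], the same range produced by a tanh activation, which is a common choice for a generator’s output layer. A minimal check on a few pixel values (the array below is made up for illustration):

```python
import numpy as np

# made-up 8-bit pixel values: black, mid-gray, white
pixels = np.array([0, 127, 255], dtype=np.uint8)

# same scaling as applied to X_train above
scaled = (pixels.astype(np.float32) - 127.5) / 127.5
print(scaled)  # values lie between -1 and 1

# inverting the scaling recovers displayable 0-255 values
restored = np.rint(scaled * 127.5 + 127.5).astype(np.uint8)
print(restored)
```

The inverse step is handy if you ever want to save generated images as ordinary 8-bit image files.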
Now it’s time to start building our GAN, beginning with the discriminator. As mentioned before, the discriminator can be very simple or very deep, with numerous convolutions. For this tutorial, we’ll stick to dense layers, leaky ReLUs, and dropout. Leaky ReLUs avoid the zero-slope problem that a standard ReLU has for negative inputs and have been noted to speed up GAN training. Dropout helps protect our discriminator from overfitting the training examples.
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(units=1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
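To see what the 0.2 passed to LeakyReLU does, here is a plain NumPy comparison of a standard ReLU and a leaky ReLU with that slope. The two helper functions are written here for illustration and are not Keras calls.

```python
import numpy as np

def relu(x):
    # standard ReLU: negative inputs are clamped to zero
    return np.maximum(0.0, x)

def leaky_relu(x, alpha=0.2):
    # leaky ReLU: negative inputs are scaled by alpha instead of zeroed
    return np.where(x > 0, x, alpha * x)

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(relu(x))        # negatives become 0
print(leaky_relu(x))  # negatives keep a small slope
```

Because negative inputs still produce a nonzero output, their gradient is nonzero too, so those units keep learning instead of going silent.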
Next, the generator takes our latent dimension as its input size so that it is compatible with the noise generated later for training. Moreover, since the generator feeds directly into the discriminator, its final layer must have the same number of units as the discriminator’s input layer. If the key does not fit the lock, the code will fail with a shape mismatch error. To make the generator more complicated, try adding some upsampling layers or even transposed convolution (deconvolution) layers.
generator = Sequential()
generator.add(Dense(256, input_dim=latent_dim))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
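Let’s combine our discriminator and generator to create our GAN. Setting discriminator.trainable to False before compiling the combined model means that training the GAN updates only the generator’s weights. Additionally, we need to set up lists that record each model’s loss per epoch.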
discriminator.trainable = False
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
gan.compile(loss='binary_crossentropy', optimizer='adam')

discrim_losses = []
gen_losses = []
After our generator has produced images, we need a place to store and view them. Create a folder for them before running the code; I named mine images, but you can change the name as long as you update the path in the code to match.
#print generated images
def printGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, latent_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    #format array of pictures
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)
Let’s also plot our losses neatly on a graph. The plot is saved to, and can be viewed in, the same folder as the generated images.
#plot the recorded discriminator and generator losses
def plotLoss(epoch):
    plt.figure(figsize=(9, 6))
    plt.plot(discrim_losses, label='Discriminator loss')
    plt.plot(gen_losses, label='Generator loss')
    plt.xlabel('Epoch #')
    plt.ylabel('Loss')
    #add the legend after plotting so that it picks up the labels
    plt.legend()
    plt.savefig('images/gan_loss_epoch_%d.png' % epoch)
Lastly, let’s write a training function and run our program. Using tqdm, we get a live progress bar showing the training status within the current epoch. As with the loss graph and the generated images, we will store the parameters of our GAN inside a folder, here named model_parameters; create it before training starts. Once again, feel free to change the destination name. We will save generated images on the first epoch and every 25 epochs after that, and produce a plot of the GAN losses at the end of training. To train for longer, simply increase the number of epochs.
#train gan
def train(epochs=1, batch_size=128):
    batch_count = X_train.shape[0] // batch_size
    print('Epochs:', epochs)
    print('Batch size:', batch_size)
    print('Batches per epoch:', batch_count)
    for e in range(1, epochs+1):
        print('Epoch %d' % e)
        for _ in tqdm(range(batch_count)):
            #create randomized noise and pick a random batch of real images
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            image_batch = X_train[np.random.randint(low=0, high=X_train.shape[0], size=batch_size)]
            #generate fake images
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])
            #label real images 0.9 (smoothed) and fake images 0
            Y = np.zeros(2*batch_size)
            Y[:batch_size] = 0.9
            #train the discriminator on the mixed batch
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(X, Y)
            #train the generator through the combined model, discriminator frozen
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise, y_gen)
        #store epoch loss (the loss of the epoch's last batch)
        discrim_losses.append(d_loss)
        gen_losses.append(g_loss)
        #save sample images and model weights on epoch 1 and every 25 epochs
        if e == 1 or e % 25 == 0:
            printGeneratedImages(e)
            generator.save('model_parameters/gan_generator_epoch_%d.h5' % e)
            discriminator.save('model_parameters/gan_discriminator_epoch_%d.h5' % e)
    #plot losses
    plotLoss(e)

#run
if __name__ == '__main__':
    train(200, 128)
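The label array built inside the training loop uses 0.9 rather than 1 for real images, a trick usually called one-sided label smoothing. Here is a standalone sketch of that step, with a small batch size chosen only for readability:

```python
import numpy as np

batch_size = 4  # small value for illustration; the tutorial uses 128

# the first half of the batch holds real images, the second half fakes
labels = np.zeros(2 * batch_size)
labels[:batch_size] = 0.9  # smoothed "real" label
print(labels)

# the generator update instead targets 1 for every fake image
generator_targets = np.ones(batch_size)
print(generator_targets)
```

Softening only the real label discourages the discriminator from becoming overconfident, which tends to keep its feedback to the generator more useful.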
Check out the full code with image examples: https://github.com/kev-guo/MNIST-GAN
Acknowledgments
This paper was written as part of a RETINA-AI Health Inc. summer internship.
References
[1]: Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio. (2014). “Generative Adversarial Networks”. arXiv:1406.2661.
[2]: Minchul Shin. (2018). Curated list of awesome GAN applications and demo. GitHub repository. https://github.com/nashory/gans-awesome-applications
[3]: Madhu Sanjeevi. (January 14, 2019). Ch:14 Generative Adversarial Networks (GAN’s) with Math. https://medium.com/deep-math-machine-learning-ai/ch-14-general-adversarial-networks-gans-with-math-1318faf46b43