Generative Adversarial Network
In this blog, we’ll build a Generative Adversarial Network (GAN) trained on the MNIST (handwritten digit) dataset. From this, we’ll be able to generate new handwritten digits. GANs were introduced by Ian Goodfellow et al. in 2014. Since then, GANs have exploded in popularity. The technique can generate photographs that look at least superficially authentic to human observers, with many realistic characteristics.
In machine learning, GANs are a class of artificial intelligence algorithms used in unsupervised learning. The idea behind GANs is that you have two networks, a generator `G` and a discriminator `D`, competing with each other in a zero-sum game framework.
The generator makes fake data to pass to the discriminator; technically, the generative network learns to map from a latent space `z` to a particular data distribution of interest. The discriminator also sees real data and predicts whether the data it receives is real or fake, i.e., it discriminates between instances from the real data distribution and candidates produced by the generator.
The generator is trained to fool the discriminator: it wants to output data that looks as close as possible to real data, producing novel synthesized instances that appear to have come from the real data distribution. The discriminator, in turn, is trained to figure out which data is real and which is fake. What ends up happening is that the generator learns to make data that is indistinguishable from real data to the discriminator. Backpropagation is applied in both networks, so that the generator produces better images while the discriminator becomes more skilled at flagging synthetic images. The generator is typically a deconvolutional neural network, and the discriminator a convolutional neural network.
Import Data and Libraries
Model Inputs
First, we need to create the inputs for our TensorFlow graph. We need two inputs, one for the generator and one for the discriminator. Here we’ll call the generator input `inputs_z` and the discriminator input `inputs_real`.
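The TensorFlow placeholder code itself is omitted here; as a framework-agnostic numpy sketch of the shapes of the two input batches (the batch size, image size, and latent size are assumed, not taken from the notebook):

```python
import numpy as np

batch_size = 64
real_dim = 784   # 28x28 MNIST images, flattened (assumed)
z_dim = 100      # size of the latent vector (assumed)

# inputs_real: a batch of flattened images scaled to [-1, 1],
# matching the tanh output range of the generator
# (random values here stand in for real MNIST data)
inputs_real = np.random.uniform(-1, 1, size=(batch_size, real_dim))

# inputs_z: a batch of latent vectors drawn from a uniform prior
inputs_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))
```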
Generator
The input to the generator is a vector of randomly generated numbers drawn from the latent space (latent variables are variables that are not directly observed but are instead inferred through a mathematical model from other variables that are observed).
The generator is a neural network containing a hidden layer with Leaky ReLU activation and a `tanh` output. It tries to map the latent space to real dataset images using the backpropagation algorithm. Once trained, the generator can produce digit images from latent samples.
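The generator’s forward pass can be sketched in plain numpy (a minimal sketch of the architecture just described; the layer sizes and weight initialization are assumptions, and the weights here are untrained):

```python
import numpy as np

def leaky_relu(x, alpha=0.01):
    # Leaky ReLU: a small negative slope instead of zero for x < 0
    return np.where(x > 0, x, alpha * x)

def generator_forward(z, w1, b1, w2, b2):
    # hidden layer with Leaky ReLU activation
    h1 = leaky_relu(z @ w1 + b1)
    # tanh output keeps generated pixel values in [-1, 1],
    # matching the range of the rescaled real images
    return np.tanh(h1 @ w2 + b2)

rng = np.random.default_rng(0)
z_dim, h_dim, out_dim = 100, 128, 784   # assumed sizes
w1 = rng.normal(scale=0.02, size=(z_dim, h_dim)); b1 = np.zeros(h_dim)
w2 = rng.normal(scale=0.02, size=(h_dim, out_dim)); b2 = np.zeros(out_dim)

z = rng.uniform(-1, 1, size=(16, z_dim))       # 16 latent samples
fake_images = generator_forward(z, w1, b1, w2, b2)
```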
Discriminator
The discriminator is a classifier trained using supervised learning. It classifies whether an image is real (1) or fake (0). We train the discriminator using both the real dataset images and the images generated by the generator.
If the input image is from the MNIST dataset, the discriminator should classify it as real. If the input image is from the generator, the discriminator should classify it as fake. The discriminator network is almost the same as the generator network, except that we use a `sigmoid` output layer.
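A matching numpy sketch of the discriminator’s forward pass (again with assumed layer sizes and untrained weights; the sigmoid converts the single logit into a probability that the image is real):

```python
import numpy as np

def leaky_relu(x, alpha=0.01):
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def discriminator_forward(x, w1, b1, w2, b2):
    # hidden layer with Leaky ReLU, like the generator
    h1 = leaky_relu(x @ w1 + b1)
    logits = h1 @ w2 + b2        # one logit per image
    return sigmoid(logits)       # probability the image is real

rng = np.random.default_rng(1)
in_dim, h_dim = 784, 128        # assumed sizes
w1 = rng.normal(scale=0.02, size=(in_dim, h_dim)); b1 = np.zeros(h_dim)
w2 = rng.normal(scale=0.02, size=(h_dim, 1));      b2 = np.zeros(1)

images = rng.uniform(-1, 1, size=(8, in_dim))   # stand-in batch
p_real = discriminator_forward(images, w1, b1, w2, b2)
```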
Hyperparameters
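The hyperparameter values are not shown above; the following are typical choices for an MNIST GAN of this shape (every value here is an assumption for illustration, not the notebook’s exact settings):

```python
input_size = 784     # 28x28 MNIST images, flattened
z_size = 100         # size of the latent vector
g_hidden_size = 128  # generator hidden layer units
d_hidden_size = 128  # discriminator hidden layer units
alpha = 0.01         # Leaky ReLU negative slope
smooth = 0.1         # label smoothing: real labels become 1 - smooth = 0.9
learning_rate = 0.002
```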
Define Network
To build the network from the functions defined above, we connect the generator and the discriminator to produce a GAN.
- First, get our inputs, `input_real` and `input_z`, from `model_inputs` using the sizes of the input and `z`.
- Then, create the generator. This builds the generator with the appropriate input and output sizes.
- Then the discriminators. We build two of them, one for real data and one for fake data.
On the second discriminator we set `reuse=True`: since we want the weights to be the same for both real and fake data, we need to reuse the variables.
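The effect of `reuse=True` can be sketched without TensorFlow: the discriminator is one function with one set of weights, applied to both the real and the fake batch (a minimal numpy sketch; names and sizes are illustrative):

```python
import numpy as np

def leaky_relu(x, alpha=0.01):
    return np.where(x > 0, x, alpha * x)

def discriminator(x, weights):
    # the same weight dict is used for every call -- this is what
    # reuse=True achieves inside a TensorFlow variable scope
    h1 = leaky_relu(x @ weights["w1"] + weights["b1"])
    return h1 @ weights["w2"] + weights["b2"]   # logits

rng = np.random.default_rng(2)
weights = {
    "w1": rng.normal(scale=0.02, size=(784, 128)), "b1": np.zeros(128),
    "w2": rng.normal(scale=0.02, size=(128, 1)),   "b2": np.zeros(1),
}

real_batch = rng.uniform(-1, 1, size=(4, 784))
fake_batch = rng.uniform(-1, 1, size=(4, 784))

d_logits_real = discriminator(real_batch, weights)  # shared weights
d_logits_fake = discriminator(fake_batch, weights)  # shared weights
```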
Define Losses
Now we need to calculate the losses for both the generator and the discriminator, which is a little tricky.
For the discriminator, the total loss is the sum of the losses for real and fake images:
`d_loss = d_loss_real + d_loss_fake`
- For the real images, the logits `d_logits_real` are paired with labels that should be all ones, since these are all real images. To help the discriminator generalize better, the labels are smoothed slightly from 1.0 to 0.9.
- The fake data is similar: the fake logits `d_logits_fake` are paired with labels of all zeros.
Finally, the generator loss also uses `d_logits_fake`, the fake image logits, but now the labels are all ones. The generator is trying to fool the discriminator, so it wants the discriminator to output ones for fake images.
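The three loss terms can be sketched in numpy as sigmoid cross-entropy, mirroring what TensorFlow’s `tf.nn.sigmoid_cross_entropy_with_logits` computes (a minimal sketch; the logits here are random stand-ins, and the 0.9 smoothing factor comes from the text above):

```python
import numpy as np

def sigmoid_cross_entropy(logits, labels):
    # numerically stable form of -labels*log(p) - (1-labels)*log(1-p),
    # where p = sigmoid(logits)
    return (np.maximum(logits, 0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))

rng = np.random.default_rng(3)
d_logits_real = rng.normal(size=(16, 1))   # stand-in logits
d_logits_fake = rng.normal(size=(16, 1))

# discriminator: real images labeled 0.9 (smoothed), fakes labeled 0
d_loss_real = sigmoid_cross_entropy(d_logits_real, 0.9 * np.ones((16, 1))).mean()
d_loss_fake = sigmoid_cross_entropy(d_logits_fake, np.zeros((16, 1))).mean()
d_loss = d_loss_real + d_loss_fake

# generator: the same fake logits, but labeled 1 -- it wants the
# discriminator to call its images real
g_loss = sigmoid_cross_entropy(d_logits_fake, np.ones((16, 1))).mean()
```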
Optimizers
We are going to create two optimizers, one for the generator and one for the discriminator. To update the generator and discriminator variables separately, each optimizer needs a list of the variables specific to it. To get all the trainable variables, we use `tf.trainable_variables()`.
We have used variable scopes to start all of our generator variable names with `generator` and all of our discriminator variable names with `discriminator`. Now, we just need to iterate through the list from `tf.trainable_variables()` and keep the variables that start with `generator` in `g_vars` and those that start with `discriminator` in `d_vars`.
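The filtering step is plain string matching on variable names; a sketch with hypothetical names standing in for what `tf.trainable_variables()` would return:

```python
# hypothetical variable names, as produced by the two variable scopes
trainable_vars = [
    "generator/hidden/weights", "generator/hidden/biases",
    "generator/output/weights",
    "discriminator/hidden/weights", "discriminator/hidden/biases",
    "discriminator/output/weights",
]

# split the variables by scope-name prefix
g_vars = [v for v in trainable_vars if v.startswith("generator")]
d_vars = [v for v in trainable_vars if v.startswith("discriminator")]
```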
Network Training
Training Loss Visualization
During training, we kept lists of losses for the generator and discriminator to check how well our GAN trained. Now we can visualize them using `train_losses.pkl`.
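Saving and reloading the loss history with pickle can be sketched like this (the filename comes from the text; the loss values are made up for illustration, not real training output):

```python
import pickle

# losses recorded during training as (d_loss, g_loss) pairs
# (illustrative values only)
losses = [(1.38, 0.72), (1.21, 0.95), (1.05, 1.10)]

with open("train_losses.pkl", "wb") as f:
    pickle.dump(losses, f)

# later, reload the history for plotting
with open("train_losses.pkl", "rb") as f:
    train_losses = pickle.load(f)

d_losses = [d for d, g in train_losses]
g_losses = [g for d, g in train_losses]
```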
Generate New Sample
To generate a new sample, we load the saved model, initialize our session, and pass random noise to the generator.
Conclusion
I hope that in this blog you have understood the underlying architecture of Generative Adversarial Networks. GANs are one of the few successful and efficient techniques in unsupervised machine learning. They have been successfully applied in many fields, including music generation, interactive image editing, 3D shape estimation, drug discovery, generating books that read as if written by real authors, and much more. Some of the success stories include:
- Adobe Research is using GANs for designing products, generating novel imagery from scratch based on users’ scribbles.
- Facebook has built a real-time style transfer model running on mobile devices.
Source Code: https://github.com/llabhishekll/GAN-implementation