GAN: Build Using TensorFlow
Let the Code Begin!
Introduction
In the previous section, we discussed the why, what & how of GANs. Now we will use TensorFlow to code a simple GAN
Input
We will train our GAN on the MNIST database
Output
At the end of this article, you will be able to generate handwritten digits using your trained Generator
You can find the Jupyter notebook on my GitHub
Data Flow
Time to create our Generative Adversarial Network. A GAN has two components:
- Discriminator: Differentiates between real & fake images (data)
- Generator: Creates fake images to fool the Discriminator
Let’s look at the architecture of GAN
- A real image X is fed to the Discriminator, D
- D predicts whether it is real or fake
- Next, a latent vector Z is created from random noise
- Z is fed into the Generator, G
- G creates a fake image
- This fake image is sent to the Discriminator, D
- D predicts whether this image is real or fake
This is how data flows within the network. The loss during training is used to update both models via Gradient Descent & back-propagation
Methods
There are two ways we can go about the training:
- Using Gradient Tape
- Using the fit function
Both are good options. In either case, we need to modify something called the train_step function, which is responsible for training the model. I simply prefer the log output of the fit function.
In the GAN Git repository, I have used both methods to train the model. You can choose whichever you prefer. But here we will use the fit function.
Import Data
Let’s import the MNIST database for our training. The MNIST database is a collection of handwritten digits. Each image is 28x28 pixels.
Code
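Loading MNIST via keras.datasets is one standard way to do this (the notebook's exact code may differ):

```python
from tensorflow import keras

# Load the MNIST handwritten-digit dataset: 60,000 training images,
# each a 28x28 grayscale array of uint8 pixel values in [0, 255].
# We only need the images themselves (not the labels) to train a GAN.
(train_images, _), (_, _) = keras.datasets.mnist.load_data()

print(train_images.shape)  # (60000, 28, 28)
```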
Normalization
We will normalize our images to the range [-1, 1]
Code
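A minimal sketch of the normalization step, written as a helper function (the function name is our own):

```python
import numpy as np

def normalize(images: np.ndarray) -> np.ndarray:
    """Scale uint8 pixels from [0, 255] to [-1, 1] and add a channel axis."""
    images = images.astype("float32")
    images = (images / 127.5) - 1.0          # maps 0 -> -1 and 255 -> 1
    return np.expand_dims(images, axis=-1)   # (N, 28, 28) -> (N, 28, 28, 1)
```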
Model
Let’s use a simple 3-layer Neural Network to construct both components
Code: Discriminator
- Input: The 28x28 image is flattened into a 784-long input vector
- Activation: LeakyReLU to add non-linearity
- Output Layer: Sigmoid activation is used for the prediction probability
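The bullets above translate to a model like this; the hidden-layer sizes (256, 128) are illustrative choices, not taken from the article:

```python
from tensorflow import keras
from tensorflow.keras import layers

def build_discriminator() -> keras.Model:
    """3-layer discriminator: flattened image in, real/fake probability out."""
    return keras.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Flatten(),                       # 28x28x1 -> 784
        layers.Dense(256),
        layers.LeakyReLU(0.2),                  # leaky ReLU for non-linearity
        layers.Dense(128),
        layers.LeakyReLU(0.2),
        layers.Dense(1, activation="sigmoid"),  # P(image is real)
    ])

discriminator = build_discriminator()
```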
Code: Generator
- Input: A latent vector of size 10
- Activation: ReLU is used for the hidden layers
- Output: A 784-long vector which we reshape into a 28x28x1 image
Ideally, we should use tanh activation on the Generator’s output layer, as we have normalized our data to be between [-1, 1]. But this resulted in a problem called mode collapse due to the simplicity of our model.
Mode collapse is a condition where the model, instead of capturing the variation in the real data, just learns the most common symbols/data points. Here is an example
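A sketch of such a generator. Hidden sizes (128, 256) are illustrative; the article only says tanh was avoided on the output layer, so the sigmoid used below is our assumption, not the article's exact choice:

```python
from tensorflow import keras
from tensorflow.keras import layers

LATENT_DIM = 10  # size of the random noise vector z (from the article)

def build_generator() -> keras.Model:
    """3-layer generator: latent vector in, 28x28x1 image out."""
    return keras.Sequential([
        layers.Input(shape=(LATENT_DIM,)),
        layers.Dense(128, activation="relu"),
        layers.Dense(256, activation="relu"),
        # Assumption: the article avoids tanh here (mode collapse);
        # the substitute activation is not stated, sigmoid is one option.
        layers.Dense(28 * 28, activation="sigmoid"),
        layers.Reshape((28, 28, 1)),            # 784 -> 28x28x1
    ])

generator = build_generator()
```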
Loss Functions
The loss function used in the GAN paper is

min_G max_D V(D, G) = E_{x∼p_data(x)}[log D(x)] + E_{z∼p_z(z)}[log(1 − D(G(z)))]

😨😱😭😭😭‼⁉
Don’t run! Let me make it intuitive and take it one term at a time
- x: Real Image fed to Discriminator
- z: Latent Vector fed to Generator
- D: Discriminator Model
- G: Generator Model
- G(z): Fake image generated after feeding z to G
- D(x): Probability given by D {0: Fake, 1: Real} for real image
- D(G(z)): Probability given by D {0: Fake, 1: Real} for Fake image
- E: Expectation or simply averaging loss for a batch
- log: Old school & simple log
It’s okay if you do not understand it all in the first go. Just take the big picture for now. It will get easier. If you have any doubts hit the comment box.
log(D(x)) will be 0 if D(x) = 1, i.e. the real image is correctly classified, & negative infinity if D(x) = 0, i.e. the real image is considered fake. Similarly, log(1 − D(G(z))) is 0 when the Discriminator correctly classifies the fake image G(z).
On the other hand, the goal of the Generator is to fool the Discriminator. It tries to maximize the value of D(G(z)). This is how both models are trained.
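In code, this objective is usually implemented with binary cross-entropy. A minimal sketch (the generator loss below is the common non-saturating form, maximizing log D(G(z)) rather than minimizing log(1 − D(G(z)))):

```python
import tensorflow as tf

cross_entropy = tf.keras.losses.BinaryCrossentropy()

def discriminator_loss(real_output, fake_output):
    """D should say 1 for real and 0 for fake:
    this is -E[log D(x)] - E[log(1 - D(G(z)))]."""
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

def generator_loss(fake_output):
    """G wants D to say 1 for its fakes: minimize -E[log D(G(z))]."""
    return cross_entropy(tf.ones_like(fake_output), fake_output)
```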
Training
We have two different models, the Discriminator (D) & the Generator (G), but we need to train them together. We will create a custom model (class) called GAN to combine the two. It will contain 3 functions:
- Constructor: Add our G & D models
- compile: Add our compiler, loss function & optimizer used in training
- train_step: Here we calculate the loss & gradients and perform backpropagation
Code
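A sketch of such a GAN class, following the standard Keras custom-train_step pattern; exact details may differ from the repository's version:

```python
import tensorflow as tf
from tensorflow import keras

class GAN(keras.Model):
    """Wraps a generator and a discriminator so keras' fit() trains both."""

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]

        # --- Train the discriminator on real images + fresh fakes ---
        z = tf.random.normal((batch_size, self.latent_dim))
        fake_images = self.generator(z)
        with tf.GradientTape() as tape:
            real_pred = self.discriminator(real_images)
            fake_pred = self.discriminator(fake_images)
            d_loss = (self.loss_fn(tf.ones_like(real_pred), real_pred)
                      + self.loss_fn(tf.zeros_like(fake_pred), fake_pred))
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights))

        # --- Train the generator to fool the discriminator ---
        z = tf.random.normal((batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            fake_pred = self.discriminator(self.generator(z))
            g_loss = self.loss_fn(tf.ones_like(fake_pred), fake_pred)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights))

        return {"d_loss": d_loss, "g_loss": g_loss}
```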
Saving Generated Image
While training, we are interested in saving the output of the Generator. We will use the function below to save the images
Code
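One way to write such a saving function (the function name, grid size, and file naming below are illustrative choices, not the article's exact code):

```python
import os
import numpy as np
import matplotlib.pyplot as plt

def save_generated_images(generator, epoch, latent_dim=10, n=16,
                          out_dir="generated"):
    """Sample n fake images from the generator and save them as one PNG grid."""
    os.makedirs(out_dir, exist_ok=True)
    z = np.random.normal(size=(n, latent_dim)).astype("float32")
    images = generator.predict(z, verbose=0)

    side = int(np.sqrt(n))                       # e.g. 16 images -> 4x4 grid
    fig, axes = plt.subplots(side, side, figsize=(side, side))
    for img, ax in zip(images, axes.flat):
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    fig.savefig(f"{out_dir}/epoch_{epoch:05d}.png")
    plt.close(fig)
```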
Now, we want these images every, say, 50 epochs. For this, we will use something called callbacks. A callback is a class that lets us run functions after each epoch (among other events)
Code
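A sketch of such a callback. The class name is our own, and save_fn is whatever saving function you use (such as the one above); it assumes the model being trained exposes a .generator attribute, as our GAN class does:

```python
from tensorflow import keras

class SampleSaver(keras.callbacks.Callback):
    """Save sample images from the generator every `interval` epochs."""

    def __init__(self, save_fn, interval=50):
        super().__init__()
        self.save_fn = save_fn
        self.interval = interval

    def on_epoch_end(self, epoch, logs=None):
        # keras counts epochs from 0, so epoch + 1 is the human-friendly number
        if (epoch + 1) % self.interval == 0:
            self.save_fn(self.model.generator, epoch + 1)
```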
Hyperparameters
Setting up the number of epochs, batch size, etc.
Code
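For example (the article trains for 5000 epochs and uses a latent vector of size 10; the remaining values here are illustrative defaults, not taken from the article):

```python
# Hyperparameters for training.
EPOCHS = 5000           # from the article
LATENT_DIM = 10         # latent vector size, from the article
BATCH_SIZE = 256        # assumption
LEARNING_RATE = 1e-4    # assumption
SAVE_INTERVAL = 50      # save generated images every 50 epochs
```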
Train
Finally, we can train the model & save it
Code
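Putting it all together might look like this. It assumes the pieces defined earlier in the article: the normalized train_images, the generator and discriminator models, the GAN class, a saving callback (here called SampleSaver) wired to a save_generated_images helper, and the hyperparameters; those helper names are our own illustrations:

```python
from tensorflow import keras

# Assemble the GAN from the two models and compile it with
# separate optimizers for D and G plus a binary cross-entropy loss.
gan = GAN(discriminator=discriminator, generator=generator,
          latent_dim=LATENT_DIM)
gan.compile(
    d_optimizer=keras.optimizers.Adam(LEARNING_RATE),
    g_optimizer=keras.optimizers.Adam(LEARNING_RATE),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# Train, saving sample images every SAVE_INTERVAL epochs.
gan.fit(
    train_images,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[SampleSaver(save_fn=save_generated_images,
                           interval=SAVE_INTERVAL)],
)

# Keep only the generator; that is all we need to create new digits.
gan.generator.save("generator.keras")
```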
Results
These are the results after 5000 epochs of training
During the initial stages, the Generator creates noise but later in training, it starts creating something resembling digits.
Tips
- This model requires a decent GPU for fast training
- If you are facing resource constraints, use Google Colab to run your experiment; there, it takes around 3 hours to train the model with the current hyperparameters
- When using Google Colab, make sure the page is not idle for long, else the training will stop
- Reduce the number of epochs or batch size as needed
This is it, folks! Now you know how to generate synthetic digits.
Have fun learning & experimenting!
If you have any doubts or feedback please comment. Or you can get in touch with me on LinkedIn or Twitter.