A Brief Introduction To GANs
GANs, or Generative Adversarial Networks, are a type of neural network architecture that allows neural networks to generate data. In the past few years, they’ve become one of the hottest subfields in deep learning, going from generating fuzzy images of digits to photorealistic images of faces.
Variants of GANs have now done insane stuff, like converting images of zebras to horses and vice versa.
I found GANs fascinating, so I decided to write this article: explaining the math and code behind them seemed like the best way to understand them better myself.
Here’s a link to a GitHub repo I made for GAN resources: sarvasvkulpati/Awesome-GAN-Resources.
So how do GANs work?
GANs learn a probability distribution of a dataset by pitting two neural networks against each other.
Here’s a great article that explains probability distributions and other concepts for those who aren’t familiar with them: “Probability concepts explained: probability distributions (introduction part 3)”.
One model, the generator, acts like a painting forger: it tries to create images that look very similar to those in the dataset. The other model, the discriminator, acts like the police and tries to detect whether the images it sees are real or fake.
What basically happens is that the forger keeps getting better at making fakes, while the police keep getting better at detecting them. The two models keep trying to beat each other until, after many iterations, the generator ideally creates images indistinguishable from the real dataset.
Training generative adversarial networks involves two objectives:
- The discriminator maximizes the probability of assigning the correct label to both training examples and images generated by the generator, i.e., the police become better at telling fakes from real paintings.
- The generator minimizes the probability that the discriminator correctly identifies what it generates as fake, i.e., the forger becomes better at creating fakes.
Let’s try and encode these two ideas into a program.
In this tutorial, we’ll be following the code in the sarvasvkulpati/intro_to_gans repo on GitHub, a personal project to understand GANs better.
GANs need a dataset to learn from, so for this tutorial we’ll be using the classic hello world of machine learning: MNIST, a dataset of handwritten digits.
The generator also needs random input vectors to generate images, and for these we’ll be using numpy.
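Both ingredients can be sketched in a few lines of NumPy. This is a minimal sketch, assuming the images arrive as a uint8 array of shape (n, 28, 28), which is what Keras’s `mnist.load_data()` returns for each split. Scaling pixels to [-1, 1] is an assumption chosen to match a tanh output on the generator, and standard-normal noise of size 100 matches the input size used later on; the helper names `prepare_images` and `sample_noise` are hypothetical, and a random array stands in for MNIST so the snippet runs offline.

```python
import numpy as np

LATENT_DIM = 100  # size of the generator's input noise vector (assumed)

def prepare_images(images):
    """Flatten (n, 28, 28) uint8 images to (n, 784) floats scaled to [-1, 1]."""
    flat = images.reshape(len(images), 28 * 28).astype(np.float32)
    return (flat - 127.5) / 127.5  # 0 -> -1.0, 255 -> 1.0

def sample_noise(batch_size, rng):
    """Draw a batch of standard-normal noise vectors for the generator."""
    return rng.normal(0.0, 1.0, size=(batch_size, LATENT_DIM))

rng = np.random.default_rng(0)
# Stand-in for MNIST: random uint8 "images" with the same shape and dtype.
stand_in_mnist = rng.integers(0, 256, size=(64, 28, 28), dtype=np.uint8)
x = prepare_images(stand_in_mnist)
z = sample_noise(16, rng)
print(x.shape, x.min() >= -1.0, x.max() <= 1.0)
print(z.shape)
```

With the real dataset you would pass the training images from `mnist.load_data()` to `prepare_images` once, before training, and call `sample_noise` every time the generator needs a fresh batch.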
The GAN Function
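The GAN plays a minimax game over a value function V(D, G): the discriminator tries to maximize it, while the generator tries to minimize it. This is the equation that defines what a GAN is doing:
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$
Here p_data is the distribution of the real images and p_z is the distribution the input noise z is drawn from.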
To anyone who isn’t well versed in the math behind it, this looks terrifying, but the idea it represents is simple yet powerful: it’s just a mathematical representation of the two objectives defined above.
The generator is defined by G(z), which converts some noise z we input into some data, like images.
The discriminator is defined by D(x), which outputs the probability that the input x came from the real dataset rather than from the generator.
We want the discriminator’s predictions on real data to be as close to 1 as possible, and its predictions on the generator’s outputs to be as close to 0 as possible. To achieve this, the objective function uses the log-likelihoods log D(x) and log(1 - D(G(z))).
The log makes sure that the closer a prediction gets to the incorrect value, the more heavily it is penalized: for a real image, log D(x) heads toward negative infinity as D(x) approaches 0.
Here’s an explanation of log loss if you aren’t sure what it does: “Understanding binary cross-entropy / log loss: a visual explanation”.
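To see how the log rewards confident correct predictions and punishes confident mistakes, here is a minimal pure-Python sketch that estimates V(D, G) on a single batch: it averages log D(x) over the discriminator’s outputs on real images and log(1 - D(G(z))) over its outputs on generated ones. The probability lists are made-up numbers for illustration, not outputs of a trained model, and `value_estimate` is a hypothetical helper name.

```python
import math

def value_estimate(d_real, d_fake):
    """Batch estimate of V(D, G): mean log D(x) + mean log(1 - D(G(z)))."""
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# A discriminator that is right and confident: D(x) near 1, D(G(z)) near 0.
good = value_estimate(d_real=[0.95, 0.9], d_fake=[0.05, 0.1])
# A discriminator that is fooled: it rates the fakes as real and the reals as fake.
fooled = value_estimate(d_real=[0.1, 0.05], d_fake=[0.9, 0.95])
print(round(good, 3), round(fooled, 3))

# The penalty for a real image, -log D(x), grows sharply as D(x) -> 0.
for p in [0.9, 0.5, 0.1, 0.01]:
    print(p, round(-math.log(p), 3))
```

The discriminator wants this value as high as possible (close to 0), which is why the fooled discriminator scores far lower than the confident, correct one.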
Coding the Generator
The generator is just a vanilla neural network model that takes a random 100-dimensional input vector and outputs a 784-dimensional vector, which, when reshaped, becomes a 28x28 pixel image.
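A minimal NumPy sketch of what such a generator computes in a forward pass is below. The three hidden layers of 256, 512 and 1024 units, the LeakyReLU(0.2) activations, the tanh output and the random initialization are all illustrative assumptions, and the weights are untrained; in practice a framework such as Keras builds these layers and learns the weights.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """LeakyReLU: keep positive values, scale negative ones by alpha."""
    return np.where(x > 0, x, alpha * x)

def init_mlp(layer_sizes, rng):
    """Untrained weights (small Gaussian) and zero biases for each dense layer."""
    return [(rng.normal(0.0, 0.02, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]

def generator_forward(params, z):
    """Map noise of shape (batch, 100) to flattened images of shape (batch, 784)."""
    h = z
    for w, b in params[:-1]:
        h = leaky_relu(h @ w + b)
    w, b = params[-1]
    return np.tanh(h @ w + b)  # tanh keeps every pixel in [-1, 1]

rng = np.random.default_rng(0)
gen_params = init_mlp([100, 256, 512, 1024, 784], rng)  # hypothetical layer sizes
z = rng.normal(size=(8, 100))
fake_images = generator_forward(gen_params, z)
print(fake_images.shape)                      # (8, 784): one flat image per noise vector
print(fake_images.reshape(-1, 28, 28).shape)  # (8, 28, 28) after reshaping
```

The tanh output is what makes scaling the real images to [-1, 1] a natural choice: real and generated pixels then live on the same scale.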
Coding the Discriminator
The discriminator is another neural network that takes the output of the previous network, a 784-dimensional vector, and outputs a probability between 0 and 1 that it came from the training dataset.
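Continuing the same hedged NumPy sketch, a discriminator is a mirror image of the generator: 784 inputs, shrinking hidden layers, and a single sigmoid unit that turns the final score into a probability. The layer sizes (1024, 512, 256), the LeakyReLU(0.2) activations and the untrained random weights are assumptions for illustration, not the tutorial’s exact model.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """LeakyReLU: keep positive values, scale negative ones by alpha."""
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    """Squash logits into probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def init_mlp(layer_sizes, rng):
    """Untrained weights (small Gaussian) and zero biases for each dense layer."""
    return [(rng.normal(0.0, 0.02, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]

def discriminator_forward(params, x):
    """Map flattened images (batch, 784) to P(real) with shape (batch,)."""
    h = x
    for w, b in params[:-1]:
        h = leaky_relu(h @ w + b)
    w, b = params[-1]
    return sigmoid(h @ w + b).ravel()  # one sigmoid unit -> one probability per image

rng = np.random.default_rng(1)
disc_params = init_mlp([784, 1024, 512, 256, 1], rng)  # hypothetical layer sizes
images = rng.uniform(-1.0, 1.0, size=(8, 784))          # stand-in batch in [-1, 1]
probs = discriminator_forward(disc_params, images)
print(probs.shape)  # (8,)
print(probs)        # untrained, so every value sits close to 0.5
```

Before training, the discriminator has no opinion, so its outputs hover around 0.5; training pushes them toward 1 for real images and 0 for generated ones.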
Compiling it into a GAN
We now compile both models into a single adversarial network: its input is a 100-dimensional noise vector, which passes through the generator, and its output is the discriminator’s prediction on the generated image.
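Conceptually, the combined network is just the composition D(G(z)). The NumPy sketch below, with small hypothetical layer sizes and untrained weights, shows that composition plus a generator loss computed as binary cross-entropy against a target of 1, i.e. -log D(G(z)). That target is an assumption here: it is a common choice that differs slightly from minimizing log(1 - D(G(z))) in V(D, G) but gives the generator stronger gradients early in training. When training through the combined model only the generator’s weights should change; in Keras this is usually arranged by setting `discriminator.trainable = False` before compiling the combined model.

```python
import numpy as np

def sigmoid(x):
    """Squash logits into probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def init_mlp(layer_sizes, rng):
    """Untrained weights (small Gaussian) and zero biases for each dense layer."""
    return [(rng.normal(0.0, 0.02, size=(a, b)), np.zeros(b))
            for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]

def mlp_forward(params, x, out_act):
    """Dense forward pass: LeakyReLU(0.2) hidden layers, then out_act on the last."""
    h = x
    for w, b in params[:-1]:
        h = h @ w + b
        h = np.where(h > 0, h, 0.2 * h)
    w, b = params[-1]
    return out_act(h @ w + b)

def gan_forward(gen_params, disc_params, z):
    """Combined model: 100-dim noise -> generator -> discriminator -> P(real)."""
    images = mlp_forward(gen_params, z, np.tanh)
    return mlp_forward(disc_params, images, sigmoid).ravel()

def generator_loss(gen_params, disc_params, z):
    """Binary cross-entropy of the combined output against target 1: mean -log D(G(z)).

    Only gen_params would be updated from this loss; disc_params stay fixed.
    """
    p = gan_forward(gen_params, disc_params, z)
    return float(-np.mean(np.log(p)))

rng = np.random.default_rng(2)
gen_params = init_mlp([100, 128, 784], rng)  # hypothetical, much smaller than a real model
disc_params = init_mlp([784, 128, 1], rng)
z = rng.normal(size=(4, 100))
print(gan_forward(gen_params, disc_params, z).shape)
print(generator_loss(gen_params, disc_params, z))  # near log(2) while D outputs ~0.5
```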
Training the GAN
- First, we load the data and split it into batches to feed into our model.
- Next, we initialize our GAN using the functions defined above.
- This is our training loop, where we run for the specified number of epochs.
- We generate some random noise and sample a batch of real images from our dataset.
- We generate some images using the generator and create a vector X that contains both fake and real images.
- We create a vector Y holding the “correct answers” for X, with the fake images labeled 0 and the real images labeled 0.9. They’re labeled 0.9 instead of 1 because it helps the GAN train better, a method called one-sided label smoothing.
- We need to alternate training between the discriminator and the generator, so here we update the discriminator.
- Finally, we update the generator.
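The data-handling part of one iteration can be sketched in NumPy, assuming `x_train` already holds flattened images scaled to [-1, 1]. The function names, the stand-in generator and the all-ones targets for the generator update are assumptions for illustration (all-ones targets are a common pattern when the generator is trained through the combined model); the actual weight updates would be done by your framework, for example with Keras’s `train_on_batch`.

```python
import numpy as np

LATENT_DIM = 100
REAL_LABEL = 0.9  # one-sided label smoothing: real targets are 0.9, not 1
FAKE_LABEL = 0.0

def make_discriminator_batch(x_train, generate, batch_size, rng):
    """Build (X, Y) for one discriminator update: real images first, then fakes."""
    idx = rng.integers(0, len(x_train), size=batch_size)
    real = x_train[idx]
    noise = rng.normal(0.0, 1.0, size=(batch_size, LATENT_DIM))
    fake = generate(noise)
    X = np.concatenate([real, fake])
    Y = np.concatenate([np.full(batch_size, REAL_LABEL),
                        np.full(batch_size, FAKE_LABEL)])
    return X, Y

def make_generator_batch(batch_size, rng):
    """Fresh noise plus all-ones targets: the generator is rewarded when D says 'real'."""
    noise = rng.normal(0.0, 1.0, size=(batch_size, LATENT_DIM))
    return noise, np.ones(batch_size)

rng = np.random.default_rng(0)
x_train = rng.uniform(-1.0, 1.0, size=(1000, 784))  # stand-in for scaled MNIST
w = rng.normal(0.0, 0.02, size=(LATENT_DIM, 784))

def generate(noise):
    """Hypothetical stand-in generator: a single tanh layer with fixed weights."""
    return np.tanh(noise @ w)

X, Y = make_discriminator_batch(x_train, generate, batch_size=32, rng=rng)
noise, y_gen = make_generator_batch(32, rng)
print(X.shape, Y[:3], Y[-3:])
print(noise.shape, y_gen[:3])
```

Inside the loop, the discriminator would be trained on (X, Y), and then the combined model on (noise, y_gen) with the discriminator frozen, which gives the alternating updates described above.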
Our first GAN
Once we run the code above, we’ve effectively created our first GAN that generates digits from scratch!
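To look at the results, the generator’s flattened tanh outputs have to be rescaled to pixel values and tiled into a sheet. Here is a minimal NumPy sketch, assuming generator outputs in [-1, 1]; the `to_grid` helper is hypothetical, and a random tanh projection stands in for a trained generator.

```python
import numpy as np

def to_grid(flat_images, rows, cols):
    """Turn (rows*cols, 784) outputs in [-1, 1] into one (rows*28, cols*28) uint8 image."""
    imgs = flat_images.reshape(rows, cols, 28, 28)
    pixels = np.clip((imgs + 1.0) * 127.5, 0, 255).astype(np.uint8)  # [-1, 1] -> [0, 255]
    # (rows, cols, 28, 28) -> (rows, 28, cols, 28) so each pixel row spans the whole grid row.
    return pixels.transpose(0, 2, 1, 3).reshape(rows * 28, cols * 28)

rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.05, size=(100, 784))         # stand-in for trained generator weights
samples = np.tanh(rng.normal(size=(16, 100)) @ w)  # 16 flattened "digits"
grid = to_grid(samples, rows=4, cols=4)
print(grid.shape, grid.dtype)
```

With matplotlib, `plt.imshow(grid, cmap="gray")` then displays all 16 samples as a single 4x4 sheet.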
Hopefully, this article provided a useful introduction to Generative Adversarial Networks and how to make one. In the near future, I’ll be writing a lot more about machine learning, so stay tuned!
Thanks for reading,
Here are some other posts I’ve written:
Linear Regression From Scratch With Python: implementing one of the most basic concepts in data science