Keep Calm and Train a GAN. Pitfalls and Tips on Training Generative Adversarial Networks
Generative Adversarial Networks (GANs) are currently among the hottest topics in Deep Learning. The number of papers published on GANs has grown tremendously over the last several months. GANs have been applied to a great variety of problems, and in case you missed the train, here is a list of some cool applications of GANs.
Now, I had read a lot about GANs but had never played with one myself. So, after going through some inspiring papers and GitHub repos, I decided to try my hand at training a simple GAN, and I immediately ran into problems.
This article is targeted at deep learning enthusiasts who are starting out with GANs. Unless you are very lucky, training a GAN on your own for the first time can be a frustrating process that takes hours and hours to get right. Of course, with experience you will get good at training GANs, but for starters there are several things that can go wrong, and you might be clueless as to where to even start debugging. I want to share my observations and the lessons I learnt when training a GAN from scratch for the first time, in the hope that they save someone starting out a few hours of debugging.
Generative Adversarial Networks
Unless you have been living under a rock for the last year or so, you have heard about GANs: everyone in Deep Learning, and even some people outside it, has been talking about them. GANs, or Generative Adversarial Networks, are deep neural networks that act as generative models of data. This means that, given a set of training data, a GAN can learn to estimate the underlying probability distribution of that data. This is very useful because, among other things, we can then generate new samples from the learnt distribution that may not be present in the original training set. As listed in the link above, this has led to some really useful applications.
There are already several amazing resources by experts in the field explaining GANs and how they work, so I will not try to replicate their work. But for the sake of completeness, here is a quick overview.
Generative Adversarial Networks are actually two deep networks in competition with each other. Given a training set X (say, a few thousand images of cats), the Generator network, G(z), takes a random vector z as input and tries to produce images similar to those in the training set. The Discriminator network, D(x), is a binary classifier that tries to distinguish between the real cat images from the training set X and the fake cat images produced by the Generator. The Generator's job, then, is to learn the distribution of the data in X so that it can produce real-looking cat images that the Discriminator cannot tell apart from the ones in the training set. The Discriminator, in turn, needs to learn to keep up with a Generator that is constantly trying new tricks to generate fake cat images and fool it.
Ultimately, if everything goes well, the Generator (more or less) learns the true distribution of the training data and becomes really good at generating real-looking cat images. The Discriminator can no longer distinguish between training set cat images and generated cat images.
In this sense, the two networks are continuously trying to make sure the other does not do a good job at their task. So then, how can this work at all?
Another way to look at the GAN setup is that the Discriminator guides the Generator by telling it what real cat images look like, and eventually the Generator figures it out and starts generating real-looking cat images. Training a GAN is set up as a two-player minimax game from Game Theory, in which the two networks try to reach what is called a Nash equilibrium with respect to each other. Refer to the references at the bottom if you would like to learn about this in more detail.
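For readers who want the formal version, the original GAN paper (the first entry under GAN Papers below) writes this game as a single value function V(D, G). Here z is the random input vector drawn from a noise prior p_z, and D(x) is the Discriminator's estimated probability that x came from the training data rather than from the Generator:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The Discriminator maximizes V by pushing D(x) toward 1 on training images and D(G(z)) toward 0 on generated ones, while the Generator minimizes V by pushing D(G(z)) back up. Note that this formulation uses the usual Real = 1, Generated = 0 labels.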
Challenges in GAN Training
Coming back to actually training GANs: to start with something easy, I trained a GAN (a DC-GAN, to be precise) on the MNIST dataset using Keras with the TensorFlow backend. This was not too difficult, and after some minor tweaks to the Generator and Discriminator networks, the GAN was able to generate sharp images of MNIST digits.
Black-and-white digits are only so much fun. Color images of objects and people are what all the cool guys play with, and this is where things started to get tricky. After MNIST, the obvious next step was to generate CIFAR-10 images. After days and days of tweaking hyperparameters, changing network architectures, and adding and removing layers, I was finally able to generate decent-looking images similar to CIFAR-10.
I started out with a pretty deep (but mostly non-performing) network and ended up with a much simpler network that actually worked. As I adjusted the networks and the training process, the images generated after 15 epochs went from looking like this,
to this,
to ultimately this:
Below is a list of mistakes I realized I had made and things I learnt along the way. If you are new to GANs and are not seeing much success in training, looking at the following aspects might help:
Obligatory Disclaimer: This is just a list of things I tried and the results I got. I do not claim to have solved all GAN training problems.
1. Large kernels and more filters
Larger kernels cover more pixels of the previous layer's output and can therefore take more information into account. 5x5 kernels worked well with CIFAR-10; using 3x3 kernels in the Discriminator caused its loss to rapidly approach 0. For the Generator, you want larger kernels in the top convolutional layers to maintain some kind of smoothness. In the lower layers, I didn't see any major effect from changing the kernel size.
More filters can increase the number of parameters by a large amount, but they are usually desirable. I used 128 filters in almost all convolutional layers. Using fewer filters, especially in the Generator, made the final generated images too blurry, so it looks like more filters help capture additional information that eventually adds sharpness to the generated images.
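To put a number on "more pixels", here is a small sketch that computes the receptive field of a stack of convolutions with the standard recurrence (each layer widens the view by kernel_size - 1 steps of the current stride product). The three stride-2 layers are a hypothetical discriminator, not the exact architecture from this article; the point is only how quickly 5x5 kernels widen the view compared with 3x3 kernels.

```python
def receptive_field(layers):
    """Return the receptive field (in input pixels) of a stack of conv layers.

    `layers` is a list of (kernel_size, stride) tuples, ordered from the
    input side to the output side. Padding does not change the result.
    """
    field = 1  # a single output unit of an empty stack sees one input pixel
    jump = 1   # distance, in input pixels, between adjacent units of the current layer
    for kernel, stride in layers:
        field += (kernel - 1) * jump  # the kernel spans (kernel - 1) extra units of the current layer
        jump *= stride
    return field


# Hypothetical 3-layer discriminator, all stride-2 convolutions.
print(receptive_field([(5, 2)] * 3))  # 29
print(receptive_field([(3, 2)] * 3))  # 15
```

On a 32x32 CIFAR-10 image, the 5x5 stack lets each unit in the last layer see 29 pixels across, close to the full width, while the 3x3 stack sees less than half of it.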
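The parameter growth is easy to quantify: a square convolution has kernel_size x kernel_size x in_channels x filters weights plus one bias per filter, which is the count Keras reports for a Conv2D layer in model.summary(). A layer that goes from 64 to 128 filters, and also receives twice as many input channels, ends up with roughly four times as many weights. A small sketch (the helper name is illustrative):

```python
def conv2d_params(kernel_size, in_channels, filters, use_bias=True):
    """Number of trainable parameters in a square 2D convolution layer."""
    weights = kernel_size * kernel_size * in_channels * filters
    biases = filters if use_bias else 0
    return weights + biases


print(conv2d_params(5, 64, 64))    # 102464
print(conv2d_params(5, 128, 128))  # 409728
```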
2. Flip labels (Generated=True, Real=False)
Although it seems silly at first, one major trick that worked for me was to change label assignments.
If you are using Real Images = 1 and Generated Images = 0, it helps to have it the other way around. As we will see later, this helps with the gradient flow in the early iterations and helps get things moving.
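In code, flipping the labels only changes which constant you pass as a target. The detail that is easy to miss is that the Generator's target flips too: it is always trained toward whatever label the Discriminator uses for real images. A minimal NumPy sketch, assuming a Keras-style setup where these arrays are passed as the targets of train_on_batch calls; the function and constant names are illustrative.

```python
import numpy as np

# Flipped convention from this section: real images are labeled 0, generated images 1.
REAL = 0.0
GENERATED = 1.0


def training_targets(batch_size):
    """Hard targets for one GAN training step under the flipped convention.

    Returns (d_real, d_generated, g_target):
      d_real      - Discriminator targets for a batch of real images
      d_generated - Discriminator targets for a batch of generated images
      g_target    - targets when training the Generator through the frozen
                    Discriminator; the Generator wants its images classified
                    as real, so its target is the *real* label.
    """
    d_real = np.full((batch_size, 1), REAL)
    d_generated = np.full((batch_size, 1), GENERATED)
    g_target = np.full((batch_size, 1), REAL)
    return d_real, d_generated, g_target


d_real, d_gen, g_target = training_targets(4)
print(d_real.ravel(), d_gen.ravel(), g_target.ravel())  # [0. 0. 0. 0.] [1. 1. 1. 1.] [0. 0. 0. 0.]
```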
3. Soft and Noisy labels
This is extremely important when training the discriminator. Having hard labels (1 or 0) nearly killed all learning early on, leading the discriminator to approach 0 loss very rapidly. I ended up using a random number between 0 and 0.1 to represent 0 labels (real images) and a random number between 0.9 and 1.0 to represent 1 labels (generated images). This is not required when training the generator.
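A minimal NumPy sketch of the soft labels described above, using the same flipped convention (real = 0, generated = 1). The function name and the batch size of 64 are illustrative; only the Discriminator's targets are softened, and the Generator keeps hard targets.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only to make the sketch reproducible


def soft_discriminator_labels(batch_size, rng):
    """Soft Discriminator targets under the flipped convention.

    Real images get a label drawn uniformly from [0, 0.1) and generated
    images one drawn from [0.9, 1.0), instead of exactly 0 and 1.
    """
    real = rng.uniform(0.0, 0.1, size=(batch_size, 1))
    generated = rng.uniform(0.9, 1.0, size=(batch_size, 1))
    return real, generated


real_y, generated_y = soft_discriminator_labels(64, rng)

# The Generator is still trained with hard targets: all 0.0, i.e. "real".
generator_y = np.zeros((64, 1))
```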
It also helps to add some noise to the training labels. For 5% of the images fed to the Discriminator, the labels were randomly flipped, i.e., real was labeled as generated and generated was labeled as real.
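One way to add this label noise is sketched below. It flips each label independently with probability 0.05, so about 5% of a batch is flipped on average rather than exactly 5%; that is an assumption, since the text does not say which variant was used. Using 1 - y for the flip also works with the soft labels above, because it maps [0, 0.1) onto (0.9, 1.0] and vice versa. The helper name is illustrative.

```python
import numpy as np


def flip_some_labels(labels, flip_fraction, rng):
    """Return a copy of `labels` in which each entry is, with probability
    `flip_fraction`, swapped to the opposite class (y -> 1 - y).

    Works for hard labels and for soft labels in [0, 1]: a soft 'real' label
    of 0.03 becomes a 'generated' label of 0.97, and vice versa.
    """
    labels = np.asarray(labels, dtype=float).copy()  # never modify the caller's array
    flip_mask = rng.random(labels.shape) < flip_fraction
    labels[flip_mask] = 1.0 - labels[flip_mask]
    return labels


rng = np.random.default_rng(42)
real_y = np.zeros((64, 1))                           # real images under the flipped convention
noisy_real_y = flip_some_labels(real_y, 0.05, rng)   # on average ~5% now say "generated"
```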
4. Batch norm helps, but only if you have other things in place
Batch normalization definitely helps the final result. Adding Batch norm resulted in distinctly sharper generated images. But, if you have incorrectly set your kernels or filters, or if the discriminator loss quickly reaches 0, adding batch norm might not really help recover from that.
5. One class at a time
To make GANs easier to train, it helps to ensure that the input data has similar characteristics. For example, instead of training a GAN on all 10 classes of CIFAR-10, it is better to pick one class (say, cars or frogs) and train a GAN to generate images from that class. There are other variants of DC-GAN that do a better job of learning to generate images from several classes. Conditional GANs, for instance, take the class label as input and generate images conditioned on it. But if you are starting out with a plain DC-GAN, it is better to keep things simple.
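Restricting the training set to one class is a one-line filter on the labels. A minimal NumPy sketch; the helper name is illustrative, and the commented lines assume Keras is installed to download the real data, so the runnable part works on any image and label arrays.

```python
import numpy as np

# CIFAR-10 class names in the dataset's standard label order (0-9).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def select_class(images, labels, class_name):
    """Keep only the images whose label matches `class_name`.

    `labels` may have shape (N,) or (N, 1); keras.datasets.cifar10.load_data()
    returns labels with shape (N, 1), so they are flattened first.
    """
    class_id = CIFAR10_CLASSES.index(class_name)
    mask = np.asarray(labels).reshape(-1) == class_id
    return images[mask]


# With Keras installed, the real data would come from (assumed setup):
#   from tensorflow import keras
#   (x_train, y_train), _ = keras.datasets.cifar10.load_data()
#   cars = select_class(x_train, y_train, "automobile")  # 5000 training images
```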
6. Look at the Gradients
If possible, try to monitor the gradients along with the losses in the networks. These can help give a good idea about the progress of training and can even help in debugging if things are not really working well.
Ideally, the Generator should receive large gradients early in training, because it needs to learn how to generate real-looking data. The Discriminator, on the other hand, does not always get large gradients early on, because it can easily distinguish real from fake images. Once the Generator has been trained enough, it becomes harder for the Discriminator to tell real images from fake ones, so the Discriminator keeps making errors and receives strong gradients.
My first few versions of the GAN on CIFAR-10 cars had many convolutional and batch norm layers and no label flipping. Apart from the trends, it is also important to monitor the scale of the gradients: if the gradients at the Generator's layers are too small, learning might be slow or not happen at all. This is visible in this version of the GAN.
The scale of the gradients at the bottom-most layer of the Generator was too small for any learning to take place. The Discriminator gradients also stayed roughly constant throughout, hinting that the Discriminator was not really learning anything. Now, let's compare this to the gradients of a GAN that had all the changes described above and produced good, real-looking images:
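To see these trends you need a per-layer record of gradient magnitudes over time. Below is a framework-agnostic sketch that stores the L2 norm of each layer's gradient at every step; it assumes you can get the gradients as arrays, for example in TensorFlow 2 from tf.GradientTape.gradient(loss, model.trainable_variables). The class name and the layer names are hypothetical.

```python
from collections import defaultdict

import numpy as np


class GradientLog:
    """Record the L2 norm of each layer's gradient at every training step."""

    def __init__(self):
        # layer name -> list of gradient norms, one entry per recorded step
        self.history = defaultdict(list)

    def record(self, grads_by_layer):
        """`grads_by_layer` maps a layer name to its gradient (any array shape)."""
        for name, grad in grads_by_layer.items():
            # np.linalg.norm of an N-D array with default arguments is the
            # 2-norm of the flattened array.
            self.history[name].append(float(np.linalg.norm(np.asarray(grad))))

    def latest(self):
        """Most recent norm for every layer seen so far."""
        return {name: norms[-1] for name, norms in self.history.items()}


log = GradientLog()
log.record({"gen/conv1": np.array([3.0, 4.0]), "disc/conv1": np.zeros(3)})
print(log.latest())  # {'gen/conv1': 5.0, 'disc/conv1': 0.0}
```

Plotting each layer's history separately for the Generator and the Discriminator makes it easy to check whether the Generator's gradients are large early on and the Discriminator's pick up later.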
The scale of the gradients reaching the bottom layer of the Generator is clearly higher than in the previous version. Also, the gradients flow as expected as the training progresses with the Generator getting large gradients early on and the Discriminator getting consistently high gradients at the top layer once the Generator has been trained enough.
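Trends are easiest to judge by eye, but the scale problem can also be checked automatically, for instance by flagging layers whose average gradient magnitude is tiny. In this sketch, the function name, the layer names and the default threshold are all hypothetical placeholders; a sensible threshold depends on the model, the loss and the optimizer.

```python
import numpy as np


def vanishing_layers(grads_by_layer, threshold=1e-6):
    """Return the names of layers whose mean absolute gradient is below
    `threshold`, in the order they appear in `grads_by_layer`."""
    return [name for name, grad in grads_by_layer.items()
            if float(np.mean(np.abs(grad))) < threshold]


# Hypothetical gradients for two Generator layers with 5x5 kernels.
grads = {
    "gen/conv1": np.full((5, 5, 128, 3), 1e-9),    # far too small to drive learning
    "gen/conv4": np.full((5, 5, 128, 128), 1e-3),
}
print(vanishing_layers(grads))  # ['gen/conv1']
```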
7. No early stopping
A silly mistake I made, probably out of impatience, was to kill training after a few hundred mini-batches whenever the losses showed no discernible progress or the generated samples stayed noisy. It is tempting to restart the job early to save time rather than wait for training to finish, only to realize at the end that the network never learnt anything. But GANs take a long time to train, and the first few loss values and generated samples almost never show any trend or sign of progress. It is important to wait a while before killing the training process and tweaking something in your setup.
One exception to this rule is when the Discriminator loss rapidly approaches 0. If that happens, there is almost no chance of recovery, and it's better to restart training, probably after changing something in the networks or the training process.
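If you log the Discriminator loss, this exception can be turned into an explicit check, so that a run is only restarted for the one failure that rarely recovers and never out of impatience alone. A minimal sketch; the function name, the loss threshold and the window length are hypothetical and would need adjusting to your loss scale and logging frequency.

```python
def should_restart(d_losses, threshold=1e-3, patience=200):
    """Decide whether a GAN run looks unrecoverable.

    Slow or noisy progress is not a reason to stop. Only a Discriminator loss
    that has stayed below `threshold` for the last `patience` recorded steps
    triggers a restart. Both defaults are illustrative, not tuned values.
    """
    if len(d_losses) < patience:
        return False  # too early to tell: keep training
    return all(loss < threshold for loss in d_losses[-patience:])


print(should_restart([0.7] * 1000))               # False: no progress yet is not a failure
print(should_restart([0.7] * 100 + [1e-5] * 200))  # True: Discriminator loss collapsed to ~0
```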
The final GAN that ended up working looked like this:
So, that's it. I hope these pointers help anyone training their first DC-GAN from scratch. Here are some resources I followed, along with others containing tons of information about GANs:
GAN Papers:
Generative Adversarial Networks
Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
Improved Techniques for Training GANs
Other Links:
The GAN code in Keras for the final working version is available on my GitHub.