Training a Conditional DC-GAN on CIFAR-10

After some promising results and tons of learning (summarized in my previous post) with a basic DC-GAN on CIFAR-10 data, I wanted to play some more with GANs. One issue with a traditional DC-GAN is that the training data is expected to have similar properties for training to converge properly. For instance, in the case of CIFAR-10, training the DC-GAN on images of a single class was much easier and more likely to produce sharp images than training on all 10 classes. In that post on GAN learnings, I had casually mentioned Conditional GANs as an improvement over traditional GANs when the training data comes from different classes. This post describes how to set up a Conditional DC-GAN to generate images from all the classes of CIFAR-10 data.

Conditional Generative Adversarial Networks

Generative Adversarial Networks have two models, a Generator model G(z) and a Discriminator model D(x), in competition with each other. G tries to estimate the distribution of the training data, and D tries to estimate the probability that a data sample came from the original training data rather than from G. During training, the Generator learns a mapping from a prior distribution p(z) to the data space, so that G(z) is a generated sample. The Discriminator D(x) outputs the probability that a given x came from the actual training data.

This model can be modified to include additional inputs, y, on which the models are conditioned. y can be any kind of auxiliary information, for example, class labels. The conditioning is achieved by simply feeding y to both the Generator, G(z|y), and the Discriminator, D(x|y).
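Written out, the conditional GAN objective is the standard two-player minimax game with y fed to both networks (here p_data(x) denotes the training-data distribution and p(z) the prior on the latent variable, as above):

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p(z)}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]
$$

D is rewarded for assigning high probability to real pairs (x, y) and low probability to generated pairs (G(z|y), y), while G is rewarded when D is fooled on generated pairs.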

The original paper on Conditional GANs used fully connected networks for both the Generator and the Discriminator and was trained on MNIST data to produce digit images. We will be training a Conditional Deep Convolutional GAN on CIFAR-10 data, so the way we provide the conditioning input differs slightly from the paper.

Constructing the GAN

For the CIFAR-10 data, the conditioning input will be the class label of the image, in a one-hot representation. We define a tensor variable to hold this input.
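A minimal sketch of the one-hot encoding, assuming the labels are integer class indices with shape (N, 1), which is the layout `keras.datasets.cifar10.load_data()` returns. The helper name `one_hot` is hypothetical; `keras.utils.to_categorical` does the same job.

```python
import numpy as np

NUM_CLASSES = 10  # CIFAR-10 has 10 classes

def one_hot(labels, num_classes=NUM_CLASSES):
    """Convert integer class labels (shape (N,) or (N, 1)) to float32 one-hot rows."""
    labels = np.asarray(labels).reshape(-1)
    # Indexing the identity matrix by label picks out the matching one-hot row.
    return np.eye(num_classes, dtype=np.float32)[labels]

# Example: labels 3 ("cat") and 0 ("airplane") in CIFAR-10's (N, 1) layout.
y = one_hot(np.array([[3], [0]]))
print(y.shape)  # (2, 10)
```

Each row has a single 1.0 at the class index; this vector is what gets fed to both networks as y.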

We then define the Generator to accept this tensor as an input along with the latent variable tensor. This is done using the Keras Concatenate layer.
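Below is a minimal sketch of such a Generator using `tensorflow.keras`. The layer sizes, the 100-dimensional latent vector and the names `build_generator` and `LATENT_DIM` are illustrative assumptions, not the exact architecture from the complete code. The conditioning tensor is an `Input` of shape `(10,)` that is concatenated with the latent vector before the first Dense layer.

```python
import numpy as np
from tensorflow.keras import layers, Model

LATENT_DIM = 100   # assumed size of the latent vector z
NUM_CLASSES = 10   # one-hot conditioning vector for CIFAR-10

def build_generator(latent_dim=LATENT_DIM, num_classes=NUM_CLASSES):
    z_in = layers.Input(shape=(latent_dim,), name="z")
    cond_in = layers.Input(shape=(num_classes,), name="class_onehot")

    # Conditioning: the one-hot label is simply appended to the latent vector.
    x = layers.Concatenate()([z_in, cond_in])
    x = layers.Dense(4 * 4 * 256)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Reshape((4, 4, 256))(x)

    # Each stride-2 transposed convolution doubles the spatial size: 4 -> 8 -> 16.
    for filters in (128, 64):
        x = layers.Conv2DTranspose(filters, 4, strides=2, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    # Final upsampling to 32x32x3; tanh keeps pixels in [-1, 1], so real
    # images should be scaled to the same range.
    img = layers.Conv2DTranspose(3, 4, strides=2, padding="same", activation="tanh")(x)
    return Model([z_in, cond_in], img, name="generator")

generator = build_generator()
z = np.random.normal(size=(2, LATENT_DIM)).astype("float32")
labels = np.eye(NUM_CLASSES, dtype="float32")[[3, 0]]
fake = generator.predict([z, labels], verbose=0)
print(fake.shape)  # (2, 32, 32, 3)
```

Concatenating at the input keeps the Generator an ordinary DCGAN stack; the label only changes what the first Dense layer sees.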

Now that we have the Generator defined, let's define the Discriminator. Given an input image and a class label for the image, the job of the Discriminator is to decide whether the image is a real image of that class or not. To do this, we need to provide the conditioning input to the Discriminator as well, and this is where things get tricky because of the Convolutional layers (remember, we are using a DCGAN, so the Discriminator has Conv2D layers). It wasn't really intuitive to me how the conditioning input could be applied to the Convolutional layers of the Discriminator. The only place that seemed appropriate was at the input of the top Dense layer. This setup also makes sense intuitively: the Discriminator learns high-level features from the image and uses them in conjunction with the conditioning input to make the final decision.

We define the discriminator model as follows:
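A minimal sketch of a Discriminator that follows this design, again with `tensorflow.keras`. The filter counts, the Adam settings (learning rate 2e-4, beta_1 = 0.5, a common DCGAN choice) and the name `build_discriminator` are assumptions for illustration. The image passes through Conv2D layers alone; the one-hot label joins only at the Dense head.

```python
import numpy as np
from tensorflow.keras import layers, Model, optimizers

NUM_CLASSES = 10

def build_discriminator(img_shape=(32, 32, 3), num_classes=NUM_CLASSES):
    img_in = layers.Input(shape=img_shape, name="image")
    cond_in = layers.Input(shape=(num_classes,), name="class_onehot")

    # Convolutional feature extractor sees only the image: 32 -> 16 -> 8 -> 4.
    x = img_in
    for filters in (64, 128, 256):
        x = layers.Conv2D(filters, 3, strides=2, padding="same")(x)
        x = layers.LeakyReLU(0.2)(x)
    x = layers.Flatten()(x)

    # Conditioning enters at the input of the top Dense layers, next to the
    # high-level image features.
    x = layers.Concatenate()([x, cond_in])
    x = layers.Dense(256)(x)
    x = layers.LeakyReLU(0.2)(x)
    validity = layers.Dense(1, activation="sigmoid")(x)

    model = Model([img_in, cond_in], validity, name="discriminator")
    # Assumed optimizer settings, not taken from the original code.
    model.compile(optimizer=optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
                  loss="binary_crossentropy")
    return model

discriminator = build_discriminator()
images = np.random.uniform(-1, 1, size=(2, 32, 32, 3)).astype("float32")
labels = np.eye(NUM_CLASSES, dtype="float32")[[3, 0]]
print(discriminator.predict([images, labels], verbose=0).shape)  # (2, 1)
```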

Finally, we can define our GAN model, keeping in mind the conditioning inputs that are required:
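A sketch of the combined model under the same assumptions. The hypothetical `build_cgan` feeds one set of labels to both networks, so G(z|y) is judged by D(·|y). The small Dense-only stand-ins exist only so the block runs on its own; in practice you would pass the convolutional Generator and Discriminator.

```python
import numpy as np
from tensorflow.keras import layers, Model, optimizers

LATENT_DIM, NUM_CLASSES = 100, 10  # assumed sizes

def build_cgan(generator, discriminator, latent_dim=LATENT_DIM, num_classes=NUM_CLASSES):
    # Freeze D inside the combined model so that training the GAN updates only G.
    discriminator.trainable = False
    z_in = layers.Input(shape=(latent_dim,))
    cond_in = layers.Input(shape=(num_classes,))
    fake_img = generator([z_in, cond_in])
    # The same label conditions both networks.
    validity = discriminator([fake_img, cond_in])
    gan = Model([z_in, cond_in], validity, name="cgan")
    gan.compile(optimizer=optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
                loss="binary_crossentropy")
    return gan

# Tiny stand-in models (hypothetical) so this block is self-contained.
z = layers.Input(shape=(LATENT_DIM,))
c = layers.Input(shape=(NUM_CLASSES,))
g_out = layers.Reshape((32, 32, 3))(
    layers.Dense(32 * 32 * 3, activation="tanh")(layers.Concatenate()([z, c])))
tiny_g = Model([z, c], g_out)

img = layers.Input(shape=(32, 32, 3))
c2 = layers.Input(shape=(NUM_CLASSES,))
d_out = layers.Dense(1, activation="sigmoid")(
    layers.Concatenate()([layers.Flatten()(img), c2]))
tiny_d = Model([img, c2], d_out)

gan = build_cgan(tiny_g, tiny_d)
zs = np.random.normal(size=(2, LATENT_DIM)).astype("float32")
ys = np.eye(NUM_CLASSES, dtype="float32")[[1, 7]]
print(gan.predict([zs, ys], verbose=0).shape)  # (2, 1)
```

In a training loop, `discriminator.train_on_batch(...)` would update D on real and generated images with their labels, and `gan.train_on_batch([z, labels], targets)` would update G. How a `trainable` change after `compile()` is honoured differs between Keras versions, so check your version's behaviour (or toggle `discriminator.trainable` around each step).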

Training the GAN

The training process is as usual: alternating between training the Discriminator and the Generator. We make use of some tricks to make training easier, however; the detailed list is in my previous post. Here, we use flipped labels and soft targets, and add noise to the Discriminator's targets.
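A NumPy sketch of how these three target tricks could be combined. The smoothing width of 0.1, the 5% swap probability and the name `discriminator_targets` are assumptions for illustration, not values from the original code. With flipped labels, real images are labelled close to 0 and generated images close to 1.

```python
import numpy as np

def discriminator_targets(batch_size, rng, smooth=0.1, noise_prob=0.05):
    """Return (real_targets, fake_targets) using flipped, soft and noisy labels.

    Flipped: real images are labelled ~0, generated images ~1.
    Soft:    targets are drawn uniformly from [0, smooth) and [1 - smooth, 1)
             instead of being exactly 0 and 1.
    Noisy:   each target is swapped to the opposite class with probability
             `noise_prob` (assumed value).
    """
    real = rng.uniform(0.0, smooth, size=(batch_size, 1))
    fake = rng.uniform(1.0 - smooth, 1.0, size=(batch_size, 1))
    # Occasionally swap targets so the Discriminator never becomes fully confident.
    flip_real = rng.random((batch_size, 1)) < noise_prob
    flip_fake = rng.random((batch_size, 1)) < noise_prob
    real[flip_real] = 1.0 - real[flip_real]
    fake[flip_fake] = 1.0 - fake[flip_fake]
    return real, fake

rng = np.random.default_rng(0)
real_targets, fake_targets = discriminator_targets(64, rng)
# With flipped labels, the Generator wants D to say "real", i.e. target 0.
gen_targets = np.zeros((64, 1))
```

Only the Discriminator's targets get the soft and noisy treatment here; the Generator's targets (through the combined model) stay at the hard "real" value.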

If you train the GAN with everything mentioned above, you will probably end up getting a result like this:

Not impressive, is it? Apart from the fact that the images are blurry, the bigger problem is that several images ended up looking similar (some examples highlighted), even when they belong to different classes.

Mode collapse

The above problem, referred to as Mode Collapse, is extremely common in GAN training and is a major issue; a lot of work is being done to reduce its impact on the training process. Let's try to understand why it happens, and then we can look at some simple tricks that fix mode collapse to some extent.

What is Mode Collapse?
Mode Collapse refers to the scenario in which the Generator produces the same (or almost the same) images every time and is still able to successfully fool the Discriminator. Not only is mode collapse pretty common, it is also triggered unpredictably, making it very difficult to train and evaluate GANs. The underlying reason behind mode collapse is simple:

Real-world data usually has a multi-modal distribution. That is, the distribution has some peaks, corresponding to regions of high probability, where the data usually resides. If the Generator is somehow able to identify one of these peaks (modes) and the Discriminator has not been trained well, the Discriminator will fail to recognize that the generated data is simply coming from a single mode. One way to fix this would be to have the Discriminator assign low probability to samples generated from this mode. The Generator can, however, simply switch to a different mode. In essence, the Generator keeps switching between a small number of modes rather than generating from the entire distribution. The Discriminator is unable to keep up with the Generator's switching, and hence the result is a bunch of similar-looking images.

Tricks to tackle Mode Collapse
Mode collapse is an active area of research, but there are some tricks that can be used to reduce the severity of the problem:

1. Minibatch Discrimination
Can we penalize the Generator for generating similar-looking samples directly? Turns out, yes. The idea is to let the Discriminator look at an entire batch of samples instead of judging each sample in isolation. A minibatch discrimination layer computes, for every sample, features that measure how similar it is to the other samples in the batch, and the Discriminator uses these features alongside the sample's own features. If several samples in a batch are similar, the Discriminator can easily detect that, and hence the Generator is forced to generate diverse samples.
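A NumPy sketch of the forward pass of the minibatch discrimination layer described by Salimans et al. (2016), "Improved Techniques for Training GANs". Each sample's feature vector is projected through a tensor T into B rows of size C; for each row, the layer sums exp(-L1 distance) to every sample in the batch (including itself). In a real model T is a learned weight; here it is random, and the sizes and the function name are illustrative assumptions.

```python
import numpy as np

def minibatch_discrimination(features, T):
    """Append batch-similarity features to each sample's features.

    features: (n, A) intermediate Discriminator features for n samples.
    T:        (A, B, C) projection tensor (learned in practice).
    Returns:  (n, A + B) array: original features plus B similarity scores.
    """
    # Project every sample to B "kernels", each a C-dimensional row: (n, B, C).
    M = np.tensordot(features, T, axes=([1], [0]))
    # L1 distance between every pair of samples, per kernel: (n, n, B).
    l1 = np.abs(M[:, None, :, :] - M[None, :, :, :]).sum(axis=3)
    # Similarity exp(-distance), summed over the batch: (n, B).
    o = np.exp(-l1).sum(axis=1)
    return np.concatenate([features, o], axis=1)

rng = np.random.default_rng(0)
T = rng.normal(size=(5, 3, 4))  # A=5 features, B=3 kernels, C=4 dims per kernel
collapsed = minibatch_discrimination(np.ones((4, 5)), T)
print(collapsed[:, 5:])  # close to 4.0 everywhere: each sample matches all 4 in the batch
```

A batch of identical samples scores about n on every similarity feature, while a diverse batch scores close to 1 (only the self-term survives). This gives the Discriminator an easy signal for collapsed batches.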

2. Wasserstein GANs
Traditional GANs try to minimize the Jensen-Shannon (JS) divergence between the Generator's distribution and the real data distribution. Minimizing the Wasserstein distance instead has been found to work better, and mode collapse is much less severe in Wasserstein GANs. Refer to the paper on WGANs for more details.
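For a flavour of what changes, here is a NumPy sketch of the loss terms and weight clipping from the original WGAN formulation (Arjovsky et al., 2017). The critic outputs unbounded scores (no sigmoid), and its weights are clipped to [-c, c] with c = 0.01 as in that paper. The function names are hypothetical, and this is a separate technique from the experience replay used in this post.

```python
import numpy as np

def critic_loss(critic_real, critic_fake):
    """Critic loss to minimise: E[D(G(z))] - E[D(x)].

    Minimising this maximises E[D(x)] - E[D(G(z))], which, under the Lipschitz
    constraint, estimates the Wasserstein distance up to a multiplicative constant.
    """
    return float(np.mean(critic_fake) - np.mean(critic_real))

def generator_loss(critic_fake):
    """Generator loss to minimise: -E[D(G(z))]."""
    return float(-np.mean(critic_fake))

def clip_weights(weights, c=0.01):
    """Clip every weight array to [-c, c] to keep the critic (roughly) Lipschitz."""
    return [np.clip(w, -c, c) for w in weights]

real_scores = np.array([2.0, 4.0])  # raw critic outputs on real images
fake_scores = np.array([1.0, 1.0])  # raw critic outputs on generated images
print(critic_loss(real_scores, fake_scores))  # -2.0
print(generator_loss(fake_scores))            # -1.0
```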

3. Experience Replay
Every now and then, we can show previously generated samples to the Discriminator, which makes it harder for the Generator to fool the Discriminator by simply switching back to a mode the Discriminator has forgotten about. This is easy to implement and is what we will be doing for our GAN.

There are several more ways to deal with mode collapse, but explaining them all would require a separate post.

Ideally, to implement experience replay, we would maintain a pool of previously generated samples and, during replay, pick a random subset of them; as new samples come in, the oldest ones would be removed from the pool. To avoid storing too many samples in memory, however, we implement a very naive version: every minibatch, we randomly pick one generated sample and save it. After N such samples have been collected, we show them to the Discriminator and empty the set of samples. The code to do this looks like this:
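A minimal NumPy sketch of the naive scheme just described. The class name `NaiveReplay` is hypothetical, `capacity` plays the role of N, and the random arrays stand in for the Generator's output. Each generated image is stored together with its conditioning label, because the Discriminator needs both.

```python
import numpy as np

class NaiveReplay:
    """Keep one random generated sample per minibatch; once `capacity` (N)
    samples are collected, hand them all to the Discriminator and start over."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.images, self.labels = [], []

    def add_one(self, gen_images, gen_labels, rng):
        # Keep a single randomly chosen (image, label) pair from this minibatch.
        i = rng.integers(len(gen_images))
        self.images.append(gen_images[i])
        self.labels.append(gen_labels[i])

    def ready(self):
        return len(self.images) >= self.capacity

    def flush(self):
        # Return the collected samples as one batch and empty the buffer.
        batch = np.stack(self.images), np.stack(self.labels)
        self.images, self.labels = [], []
        return batch

rng = np.random.default_rng(0)
replay = NaiveReplay(capacity=3)
for step in range(3):
    fake_images = rng.uniform(-1, 1, size=(8, 32, 32, 3))  # stand-in for G's output
    fake_labels = np.eye(10)[rng.integers(10, size=8)]      # their conditioning labels
    replay.add_one(fake_images, fake_labels, rng)
if replay.ready():
    old_images, old_labels = replay.flush()
    # e.g. discriminator.train_on_batch([old_images, old_labels], fake_targets)
```

Because the buffer is emptied after each replay, memory use is bounded by N images. The trade-off is that replayed samples only span the last N minibatches rather than the whole training history.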

Finally, after adding all these changes, the result of the Conditional GAN is shown below. There is still a lot of room for improvement, but we have surely come a long way. Spending some more time on fine-tuning the network should yield better results.

The complete code is available here.

Related Papers:

Machine Learning and Artificial Intelligence Research | Amateur Astronomer
