V-GAN (Variational Discriminator Bottleneck): an unfair fight between Generator and Discriminator
This blog post is about the new paper titled “Variational Discriminator Bottleneck: Improving Imitation Learning, Inverse RL, and GANs by Constraining Information Flow” by X. Peng et al. I will try to explain the core concept of the paper here. As always, I wanted to try the technique for myself, but the official code isn’t released yet. Specifically, I was quite excited about the V-GAN (Variational GAN) from the paper, so I implemented it myself and I am sharing the code along with this blog post. You can find the code at my repository. Currently, the repo contains an implementation of V-GAN and its variants, such as VGAN-GP. I will be working on adding VAIL and VAIRL soon.
The Problem with a fair fight:
A GAN, as powerful a technique as it is, can be notoriously fickle and unstable to train. Primarily, the strengths of the Generator and the Discriminator need to be balanced so that neither becomes too strong relative to the other. Generally, if the Discriminator becomes too strong, i.e. it can easily tell the real and generated samples apart, it ceases to supply plausible gradients to the Generator for training. This problem has been addressed in many prior works: off the top of my head, some of them are TTUR, Progressive Growing of GANs, WGAN (to some extent) and even my earlier proposal of MSG-GAN. According to TTUR, the learning rates (update scales) of G and D can be adjusted to overcome this problem [my opinion: too difficult to tune the rates without running the training a dozen times]. Progressively grown GANs don’t need much hyperparameter tuning, but, damn, they are slow to train.
So, basically, we need to handicap the Discriminator slightly so that it doesn’t get too powerful. One could ask: how about modelling the Discriminator as a smaller network? Unfortunately, there is a very thin line between not letting the Discriminator win and completely ruining the training. If the Discriminator has too few parameters, we restrict its capacity to discriminate between the distributions. “A pupil is only as good as its teacher,” as the old saying goes; the Discriminator cannot be too weak.
The Solution: An unfair fight between an equally potent Generator and Discriminator:
The core concept proposed by the paper is to enforce an Information Bottleneck between the Input images and the Discriminator’s internal representation of them.
As shown in the diagram, the Discriminator is now divided into two parts: an Encoder and the actual Discriminator. Note that the Generator stays the same. The Encoder is modelled as a ResNet similar in architecture to the Generator, while the Discriminator is a simple linear classifier. Note that the Encoder doesn’t output the internal codes of the images directly; instead, like a VAE’s encoder, it outputs the means and standard deviations of the distributions from which samples are drawn and fed to the Discriminator.
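A minimal PyTorch sketch of this split, using a small MLP as a stand-in for the paper’s ResNet encoder (the class and parameter names here are my own, not from the official code):

```python
import torch
import torch.nn as nn

class VDBDiscriminator(nn.Module):
    """Encoder + linear classifier, as in the variational discriminator bottleneck."""
    def __init__(self, in_features=3072, latent_dim=128):
        super().__init__()
        # Stand-in encoder: the paper uses a ResNet; an MLP keeps the sketch short.
        self.encoder = nn.Sequential(nn.Linear(in_features, 256), nn.ReLU())
        self.to_mu = nn.Linear(256, latent_dim)      # means of E(z|x)
        self.to_logvar = nn.Linear(256, latent_dim)  # log-variances of E(z|x)
        self.classifier = nn.Linear(latent_dim, 1)   # simple linear discriminator

    def encode(self, x):
        h = self.encoder(x)
        return self.to_mu(h), self.to_logvar(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        # Reparameterization trick, exactly as in a VAE encoder:
        # sample z ~ N(mu, sigma^2) and classify the sample, not the code.
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return self.classifier(z), mu, logvar
```

The key point the sketch shows: the classifier never sees a deterministic code, only a noisy sample from the encoder’s predicted distribution.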
The information bottleneck I(X, Z) ≤ Ic is enforced by adding another term to the loss function of the combined Discriminator (Encoder + classifier). The complete loss function is given by:

J(D, E) = min over D, E of max over β ≥ 0 of
E_{x~π*}[ E_{z~E(z|x)}[ −log D(z) ] ] + E_{x~G}[ E_{z~E(z|x)}[ −log(1 − D(z)) ] ] + β ( E_{x~p̃}[ KL[ E(z|x) ‖ r(z) ] ] − Ic )
Here, the first two terms are the classification objective of a normal GAN, while the term with the beta is the proposed regularization that enforces the information bottleneck. But wait, what? Where is the mutual information? We are reducing the KL divergence between the encoded distribution and r(z), which is something totally new. Relax: it’s just a variational upper bound on the mutual information, and r(z) is a prior of our choice. Guess what r should be? Yes, you guessed it: it is indeed the unit Gaussian. Please refer to the paper to see how the information bottleneck can be bounded by this KL divergence, and to the Wikipedia article on mutual information for better insight.
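For a diagonal-Gaussian encoder and a unit-Gaussian prior, the KL term has the familiar VAE closed form. A minimal PyTorch sketch (the function names are mine, not from the paper’s code):

```python
import torch

def kl_to_unit_gaussian(mu, logvar):
    # KL[N(mu, sigma^2) || N(0, I)] per sample, averaged over the batch.
    return 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(dim=1).mean()

def bottleneck_penalty(mu, logvar, beta, i_c):
    # beta * (E[KL[E(z|x) || r(z)]] - Ic): the regularizer added to the
    # usual discriminator classification loss.
    return beta * (kl_to_unit_gaussian(mu, logvar) - i_c)
```

Note that the penalty can go negative when the average KL is below Ic, which is what lets the constraint stay slack once it is satisfied.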
Now, to the best part: the value of beta, which mathematically is the Lagrange multiplier of the constrained objective, is inferred automatically using a technique called dual gradient descent! I found Jonathan Hui’s blog very helpful for getting my head around this complicated concept of dual gradient descent. So, the beta value is updated as:

β ← max( 0, β + α_β ( E_{x~p̃}[ KL[ E(z|x) ‖ r(z) ] ] − Ic ) )
Note that the p̃ distribution is the mixture of the real and generated images. And, finally, the Generator is updated as:

max over G of E_{x~G}[ −log( 1 − D( μ_E(x) ) ) ]
Please note the easy-to-miss μ inside D(·). For the Generator’s update pass of the VGAN, discrimination is performed directly on the means predicted by the Encoder, which are a good estimate of the expected values of those distributions.
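The dual-gradient-descent step for beta can be sketched as a plain scalar update, outside the optimizer that trains the network weights (the function and argument names here are my own):

```python
def update_beta(beta, kl_estimate, i_c, alpha_beta):
    """One dual gradient descent step on the Lagrange multiplier beta.

    kl_estimate is the batch average of KL[E(z|x) || r(z)] computed on the
    real/fake mixture p~; i_c is the bottleneck constraint Ic; alpha_beta is
    beta's dedicated learning rate.
    """
    # beta <- max(0, beta + alpha_beta * (E_p~[KL] - Ic));
    # the max(0, .) keeps the multiplier non-negative.
    return max(0.0, beta + alpha_beta * (kl_estimate - i_c))
```

Intuitively: whenever the encoder leaks more than Ic nats of information, beta grows and the penalty tightens; once the constraint is satisfied, beta decays back toward zero.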
The only added hyperparameters here are Ic, the information-bottleneck constraint, and α_β, the learning rate for the beta parameter.
My CelebA experiment:
I trained the VGAN-GP (just replace the normal GAN loss with the WGAN-GP loss) on the CelebA dataset; the results are shown in the first figure of this post. I used Ic = 0.2, as described in the paper, and the architectures for G and D are also as described in the paper. The authors trained the model for 300K iterations, but the results I display are at 62K iterations, which took me 22.5 hours to train. I will be training further, but I would really like readers and enthusiasts to take this forward, as I have made the code open-source.
The trained weights for this experiment are available at my shared drive here.
The technique is quite a clever use of knowledge from information theory. Basically, since the fight is unfair, the Discriminator is forced to focus only on the most discerning features of the input images and thereby guides the Generator better. My key observation in this experiment is that the Generator learned to model the light variations in the celebrities’ hair very early in training, which supports the original hypothesis that the Discriminator focuses on the most discerning features.
Food for thought: I have been wondering whether we could make Ic a trainable parameter too. Fundamentally, at the beginning of training the bottleneck should be tight, since the Discriminator should not win too early; but as the Generator gets better and there is indeed enough overlap between the supports of the two distributions, the bottleneck can be relaxed.
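One cheap way to approximate this idea, short of making Ic trainable, would be a fixed annealing schedule that starts with a tight bottleneck (small Ic) and relaxes it over training. A sketch; the function name and the start/end values are purely illustrative guesses, not from the paper:

```python
def annealed_ic(step, total_steps, ic_start=0.05, ic_end=0.5):
    """Linearly anneal Ic from a tight bottleneck to a relaxed one."""
    t = min(max(step / total_steps, 0.0), 1.0)  # training progress in [0, 1]
    return ic_start + t * (ic_end - ic_start)
```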
As always, please feel free to provide any feedback / improvements / suggestions. Contributions to the code / technique are most welcome.
Thank you for reading!