Generative Adversarial Networks are a powerful class of neural networks with remarkable applications. They essentially consist of a system of two neural networks — the Generator and the Discriminator — dueling each other.
Given a set of target samples, the Generator tries to produce samples that can fool the Discriminator into believing they are real. The Discriminator tries to distinguish real (target) samples from fake (generated) samples. Using this iterative training approach, we eventually end up with a Generator that is really good at generating samples similar to the target samples.
GANs have a plethora of applications, as they can learn to mimic data distributions of almost any kind. Popularly, GANs are used for removing artefacts, super resolution, pose transfer, and literally any kind of image translation, as shown below:
However, they are excruciatingly difficult to work with, owing to their fickle stability. Needless to say, many researchers have proposed brilliant solutions to mitigate some of the problems involved with training GANs. However, research in this area has evolved so fast that it has become hard to keep track of interesting ideas. This blog makes an effort to list some popular techniques that are commonly used to make GAN training stable.
Drawbacks of using GANs — An Overview
GANs are difficult to work with for a bunch of reasons. Some of them are listed below in this section.
1. Mode collapse
Natural data distributions are highly complex and multimodal. That is, the data distribution has a lot of “peaks” or “modes”. Each mode represents a concentration of similar data samples that is distinct from other modes.
During mode collapse, the generator produces samples that belong to a limited set of modes. This happens when the generator believes that it can fool the discriminator by locking on to a single mode. That is, the generator produces samples exclusively from this mode.
The discriminator eventually figures out that samples from this mode are fake. As a result, the generator simply locks on to another mode. This cycle repeats indefinitely, and this essentially limits the diversity of the generated samples. For a more detailed explanation, you can check out this blog.
2. Convergence
A common question in GAN training is “when do we stop training them?”. Since the Generator loss improves when the Discriminator loss degrades (and vice-versa), we cannot judge convergence based on the value of the loss function. This is illustrated by the image below:
3. Quality
As with the previous problem, it is difficult to quantitatively tell when the generator produces high quality samples. Additional perceptual regularization added to the loss function can help mitigate the situation to some extent.
4. Metrics
The GAN objective function indicates how well the Generator or the Discriminator is performing with respect to its opponent. It does not, however, represent the quality or the diversity of the output. Hence, we need distinct metrics that can measure these properties.
Before we dive deep into techniques that can aid performance, let us review some terminologies. This will simplify explanations of the techniques presented in the next section.
1. Infimum and Supremum
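Put simply, the infimum is the greatest lower bound of a set, and the supremum is the least upper bound of a set. They differ from the minimum and maximum in the sense that the infimum and supremum need not belong to the set. For example, the open interval (0, 1) has infimum 0 and supremum 1, but contains neither, so it has no minimum or maximum.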
2. Divergence Measures
Divergence measures represent the distance between two distributions. Conventional GANs essentially minimize the Jensen-Shannon divergence between the real data distribution and the generated data distribution. GAN loss functions can be modified to minimize other divergence measures such as the Kullback-Leibler divergence or the Total Variation Distance. Popularly, the Wasserstein GAN minimises the Earth Mover distance.
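To make these measures concrete, here is a minimal NumPy sketch of the KL and JS divergences for discrete distributions. The function names and the small `eps` guard are my own choices for illustration, not taken from any particular library.

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as 1-D arrays."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # terms where p(i) = 0 contribute nothing to the sum
    # eps (an assumption of this sketch) avoids division by zero when q(i) = 0
    return float(np.sum(p[mask] * np.log(p[mask] / (q[mask] + eps))))

def js_divergence(p, q):
    """Jensen-Shannon divergence: a symmetrised KL against the mixture m."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
```

Note that for two distributions with no overlap, such as `[1, 0]` and `[0, 1]`, the JS divergence saturates at log 2 no matter how far apart the distributions are.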
3. Kantorovich-Rubinstein Duality
Some divergence measures are intractable to optimize in their naive form. However, their dual form (replacing infimum with supremum or vice-versa) may be tractable to optimize. The duality principle lays a framework for transforming one form to another. For a very detailed explanation about the same, you can check out this blog post.
4. Lipschitz continuity
A Lipschitz continuous function is limited in how fast it can change. For a function to be Lipschitz continuous, the absolute value of the slope of the function’s graph (for any pair of points) cannot be more than a real value K. Such functions are also known as K-Lipschitz continuous.
Lipschitz continuity is desired in GANs as it bounds the gradients of the discriminator, essentially preventing the exploding gradient problem. Moreover, the Kantorovich-Rubinstein duality requires it for a Wasserstein GAN, as mentioned in this excellent blog post.
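For reference, the terms above can be written in symbols as follows. The notation is my own assumption (P_r and P_g for the real and generated distributions, \Pi(P_r, P_g) for the set of joint distributions with those marginals, f for the discriminator); these are the standard textbook forms rather than anything specific to this blog.

```latex
% K-Lipschitz continuity of a function f
|f(x_1) - f(x_2)| \le K \, |x_1 - x_2| \quad \text{for all } x_1, x_2

% Earth Mover (Wasserstein-1) distance, primal form: an infimum over joint distributions
W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} \big[ \| x - y \| \big]

% Dual form via the Kantorovich-Rubinstein duality: a supremum over 1-Lipschitz functions
W(P_r, P_g) = \sup_{\| f \|_L \le 1} \; \mathbb{E}_{x \sim P_r} [ f(x) ] - \mathbb{E}_{x \sim P_g} [ f(x) ]
```

The infimum in the primal form ranges over all ways of transporting one distribution onto the other, which is what makes it intractable. The supremum in the dual form ranges over functions, which a neural network can approximate.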
Techniques for Improving Performance
There are a plethora of tricks and techniques that can be used for making GANs more stable and powerful. To keep this blog concise I’ve only explained techniques that are either relatively new or complex. I’ve listed out other miscellaneous tricks and techniques at the end of this section.
1. Alternative Loss Functions
One of the most popular fixes to the shortcomings of GANs is the Wasserstein GAN. It essentially replaces the Jensen-Shannon divergence of conventional GANs with the Earth Mover distance (Wasserstein-1 distance or EM distance). The original form of the EM distance is intractable, and hence we use its dual form (obtained via the Kantorovich-Rubinstein duality). This requires the discriminator to be 1-Lipschitz, which is maintained by clipping the weights of the discriminator.
The advantage of using the Earth Mover distance is that it is continuous even when the real and generated data distributions are disjoint, unlike the JS or KL divergence. Also, there is a correlation between the generated image quality and the loss value (Source). The disadvantage is that we need to perform several discriminator updates per generator update (as per the original implementation). Moreover, the authors themselves claim that weight clipping is a terrible way to enforce the 1-Lipschitz constraint.
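As a rough sketch of the WGAN recipe described above, the losses and the clipping step might look like this in NumPy. The function names are hypothetical, the scores are assumed to be raw (unbounded) discriminator outputs, and the threshold `c=0.01` is only an illustrative default.

```python
import numpy as np

def critic_loss(real_scores, fake_scores):
    """Loss minimised by the discriminator: -(E[f(real)] - E[f(fake)])."""
    return float(np.mean(fake_scores) - np.mean(real_scores))

def generator_loss(fake_scores):
    """Loss minimised by the generator: -E[f(fake)]."""
    return float(-np.mean(fake_scores))

def clip_weights(weights, c=0.01):
    """Clamp every parameter array to [-c, c]; applied after each discriminator update."""
    return [np.clip(w, -c, c) for w in weights]
```

In a training loop, one would run several discriminator updates (each followed by `clip_weights`) for every generator update, as the paragraph above mentions.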
Another interesting solution is to use mean squared loss instead of log loss. The authors of the LSGAN argue that the conventional GAN loss function does not provide much incentive to “pull” the generated data distribution close to the real data distribution.
The log loss in the original GAN loss function does not bother about the distance of the generated data from the decision boundary (the decision boundary separates real and fake data). LSGAN on the other hand penalizes generated samples that are far away from the decision boundary, essentially “pulling” the generated data distribution closer to the real data distribution. It does this by replacing the log loss with mean squared loss. For a detailed explanation of the same, check out this blog.
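The difference between the two losses can be seen with a small numeric sketch. The function names and the label values (1 for real, 0 for fake) are assumptions made for illustration. The log loss here is the commonly used non-saturating generator loss, applied to raw discriminator outputs.

```python
import numpy as np

def log_loss_generator(d_logits):
    """Non-saturating log loss: -log(sigmoid(z)), computed stably as log(1 + exp(-z))."""
    return np.logaddexp(0.0, -np.asarray(d_logits, dtype=float))

def ls_loss_generator(d_outputs, target=1.0):
    """Least-squares generator loss: squared distance of the output from the 'real' label."""
    return 0.5 * (np.asarray(d_outputs, dtype=float) - target) ** 2

def ls_loss_discriminator(d_real, d_fake, real_label=1.0, fake_label=0.0):
    """Least-squares discriminator loss over a batch of real and fake outputs."""
    d_real = np.asarray(d_real, dtype=float)
    d_fake = np.asarray(d_fake, dtype=float)
    return float(0.5 * np.mean((d_real - real_label) ** 2)
                 + 0.5 * np.mean((d_fake - fake_label) ** 2))
```

For a generated sample that the discriminator scores at 5.0 (confidently on the "real" side, but far from the boundary), the log loss is close to zero while the least-squares loss is 8.0. The generator therefore still receives a strong signal to move that sample closer.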
2. Two Timescale Update Rule (TTUR)
In this method, we use a different learning rate for the discriminator and the generator (Source). Typically, a slower update rule is used for the generator and a faster update rule is used for the discriminator. Using this method, we can perform generator and discriminator updates in 1:1 ratio, and just tinker with the learning rates. Notably, the SAGAN implementation uses this method.
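A bare-bones sketch of the idea, assuming plain gradient descent rather than the adaptive optimizers normally used in practice. The helper `ttur_step`, the gradient callables, and the learning-rate values are all hypothetical and chosen only to show the faster discriminator and slower generator.

```python
def ttur_step(theta_g, theta_d, grad_g, grad_d, lr_g=1e-4, lr_d=4e-4):
    """One discriminator update and one generator update (1:1 ratio),
    each with its own learning rate.

    grad_g and grad_d are callables returning the gradient of the
    generator / discriminator loss at the current parameters.
    """
    theta_d = theta_d - lr_d * grad_d(theta_g, theta_d)  # faster discriminator step
    theta_g = theta_g - lr_g * grad_g(theta_g, theta_d)  # slower generator step
    return theta_g, theta_d
```

With a deep learning framework, the same effect is usually obtained by creating two optimizers, one per network, and giving them different learning rates.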
3. Gradient Penalty
In the paper Improved Training of WGANs, the authors claim that weight clipping (as originally performed in WGANs) leads to optimization issues. They claim that weight clipping forces the neural network to learn “simpler approximations” to the optimal data distribution, leading to lower quality results. They also claim that weight clipping leads to the exploding or vanishing gradient problem if the WGAN hyperparameter (the clipping threshold) is not set properly. The authors introduce a simple gradient penalty, which is added to the loss function, such that the above problems are mitigated. Moreover, 1-Lipschitz continuity is maintained, as in the original WGAN implementation.
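The penalty can be sketched as follows: sample points between real and generated data, and penalise the discriminator when the norm of its gradient with respect to those points drifts away from 1. To keep the example free of an autograd library, I assume a toy discriminator whose input gradient can be written by hand. The toy model, the function names, and the weight `lam=10.0` are illustrative assumptions.

```python
import numpy as np

def critic(x, W, v):
    """Toy discriminator f(x) = sum_j v_j * tanh(W_j . x); x has shape (n, d)."""
    return np.tanh(x @ W.T) @ v

def critic_input_grad(x, W, v):
    """Analytic gradient of the toy discriminator with respect to each input row."""
    h = np.tanh(x @ W.T)              # (n, hidden)
    return ((1.0 - h ** 2) * v) @ W   # (n, d)

def gradient_penalty(real, fake, W, v, lam=10.0, rng=None):
    """lam * E[(||grad_x f(x_hat)||_2 - 1)^2] at random interpolates x_hat."""
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.uniform(size=(real.shape[0], 1))   # one mixing weight per real/fake pair
    x_hat = eps * real + (1.0 - eps) * fake      # points between real and fake samples
    grad_norm = np.linalg.norm(critic_input_grad(x_hat, W, v), axis=1)
    return float(lam * np.mean((grad_norm - 1.0) ** 2))
```

In a real model, the input gradient would come from automatic differentiation, and the penalty would simply be added to the discriminator loss.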
The authors of DRAGAN claim that mode collapse occurs when the game played by the GAN (i.e. discriminator and generator going against each other) reaches a “local equilibrium state”. They also claim that the gradients contributed by the discriminator around such states are “sharp”. Naturally, using a gradient penalty will help us circumvent these states, greatly enhancing stability and reducing mode collapse.
4. Spectral Normalization
Spectral normalization is a weight normalization technique that is typically used on the Discriminator to enhance the training process. This essentially ensures that the Discriminator is K-Lipschitz continuous.
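A minimal sketch of the idea for a single weight matrix: estimate the largest singular value by power iteration and divide the weights by it. The function name, the argument names, and the fixed default seed are my own choices for illustration.

```python
import numpy as np

def spectral_normalize(W, n_iters=1, u=None, rng=None):
    """Divide W by a power-iteration estimate of its largest singular value.

    Returns the normalised matrix and the vector u, which can be reused
    on the next call so that the estimate improves over time.
    Requires n_iters >= 1.
    """
    W = np.asarray(W, dtype=float)
    rng = np.random.default_rng(0) if rng is None else rng
    if u is None:
        u = rng.normal(size=W.shape[0])
    u = u / np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u
        v = v / np.linalg.norm(v)
        u = W @ v
        u = u / np.linalg.norm(u)
    sigma = u @ W @ v  # estimate of the spectral norm of W
    return W / sigma, u
```

Once the estimate has converged, the normalised matrix has a largest singular value of about 1, so the linear map it defines cannot stretch any input by more than a factor of 1.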
5. Unrolling and Packing
As stated in this excellent blog, one way to prevent mode hopping is to peek into the future and anticipate counterplay when updating parameters. Unrolled GANs enable the Generator to fool the Discriminator after the Discriminator has had a chance to respond (taking counterplay into account).
Another way of preventing mode collapse is to “pack” several samples belonging to the same class before passing them to the Discriminator. This method is incorporated in PacGAN, whose authors report a decent reduction in mode collapse.
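For flat feature vectors, packing can be sketched as a simple reshape: groups of m samples of the same kind (all real or all generated) are joined into one discriminator input. The helper `pack` and its argument names are hypothetical, and this sketch only covers flat vectors.

```python
import numpy as np

def pack(samples, m):
    """Group m samples of the same kind into one discriminator input.

    samples: (n, d) array with n divisible by m; returns (n // m, m * d).
    """
    n, d = samples.shape
    if n % m != 0:
        raise ValueError("number of samples must be divisible by the packing degree m")
    return samples.reshape(n // m, m * d)
```

The discriminator then judges each packed group jointly, which makes a lack of diversity within a group easier to detect.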
6. Stacking GANs
A single GAN may not be powerful enough to handle a task effectively. We could instead use multiple GANs placed consecutively, where each GAN solves an easier version of the problem. For instance, FashionGAN used two GANs to perform localized image translation.
Taking this concept to the extreme, we can gradually increase the difficulty of the problem presented to our GANs. For instance, Progressive GANs (ProGANs) can generate high quality images of excellent resolution.
7. Relativistic GANs
Conventional GANs measure the probability of the generated data being real. Relativistic GANs measure the probability of the generated data being “more realistic” than the real data. We can measure this “relative realism” using an appropriate distance measure, as mentioned in the RGAN paper.
The authors also mention that the discriminator output should converge to 0.5 when it has reached the optimal state. However, conventional GAN training algorithms force the discriminator to output “real” (i.e. 1) for any image. This, in a way, prevents the discriminator from reaching its optimal value. The relativistic method solves this issue as well, and has pretty remarkable results, as shown below.
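One simple way to express "relative realism" is to compare the raw discriminator output of a real sample with that of a generated sample. The sketch below assumes paired batches of raw (pre-sigmoid) outputs and uses a sigmoid of their difference. The function names are hypothetical, and this is only one of the possible formulations.

```python
import numpy as np

def _log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -np.logaddexp(0.0, -z)

def relativistic_d_loss(c_real, c_fake):
    """Discriminator loss: push real outputs above the paired fake outputs."""
    c_real = np.asarray(c_real, dtype=float)
    c_fake = np.asarray(c_fake, dtype=float)
    return float(-np.mean(_log_sigmoid(c_real - c_fake)))

def relativistic_g_loss(c_real, c_fake):
    """Generator loss: push fake outputs above the paired real outputs."""
    c_real = np.asarray(c_real, dtype=float)
    c_fake = np.asarray(c_fake, dtype=float)
    return float(-np.mean(_log_sigmoid(c_fake - c_real)))
```

When real and generated samples receive the same output, the sigmoid of the difference is 0.5 and both losses equal log 2, which matches the idea that an optimal discriminator cannot tell the two apart.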
8. Self Attention Mechanism
The authors of Self Attention GANs claim that convolutions used for generating images look at information that is spread locally. That is, they miss out on relationships that span globally due to their restricted receptive field.
Self-Attention Generative Adversarial Networks allow attention-driven, long-range dependency modeling for image generation tasks. The self-attention mechanism is complementary to the normal convolution operation. The global information (long range dependencies) aids in generating images of higher quality. The network can choose to ignore the attention mechanism, or consider it along with normal convolutions. For a detailed explanation, you can check out their paper.
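A simplified sketch of such a layer on a flattened feature map, where each row is one spatial position. The projection matrices `Wq`, `Wk`, `Wv` and the scalar `gamma` are hypothetical names. The residual form `gamma * attention + x` is what lets the network ignore attention (gamma near 0) or blend it with the convolutional features.

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv, gamma=0.0):
    """x: (n_positions, channels) flattened feature map.

    Wq, Wk: (channels, k) projections; Wv: (channels, channels).
    Returns the output features and the (n_positions, n_positions) attention map.
    """
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    attn = softmax(q @ k.T, axis=-1)   # each position attends to every position
    out = gamma * (attn @ v) + x       # gamma = 0 leaves the input features untouched
    return out, attn
```

Because every position attends to every other position, a feature in one corner of the image can directly influence a feature in the opposite corner, which a small convolution kernel cannot do in a single layer.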
9. Miscellaneous Techniques
Here is a list of some additional techniques (not exhaustive!) that are used to improve GAN training:
- Feature Matching
- Mini Batch Discrimination
- Historical Averaging
- One-sided Label Smoothing
- Virtual Batch Normalization
Now that we have established methods to improve training, we need to quantitatively prove it. The following metrics are often used to measure the performance of a GAN:
1. Inception Score
The inception score measures how “real” the generated data is:
$$IS(G) = \exp\left( \mathbb{E}_{x \sim p_g} \, D_{KL}\big( p(y|x) \,\|\, p(y) \big) \right)$$
The equation has two components, p(y|x) and p(y). Here, x is the image that is produced by the Generator (p_g denotes the distribution of generated images), and p(y|x) is the probability distribution obtained when you pass image x through a pre-trained Inception Network (pretrained on the ImageNet dataset, as in the original implementation). Also, p(y) is the marginal probability distribution, which can be calculated by averaging p(y|x) over a few distinct samples of generated images (x). These two terms represent two different qualities that are desirable in real images:
- The generated image must have objects that are “meaningful” (objects are clear, and not blurry). This means that p(y|x) should have “low entropy”. In other words, our Inception Network must be strongly confident that the generated image belongs to a particular class.
- The generated images should be “diverse”. This means that p(y) should have “high entropy”. In other words, the generator should produce images such that each image represents a different class label (ideally).
If a random variable is highly predictable, it has low entropy (i.e. p(y|x) must be a distribution with a sharp peak). On the contrary, if it is unpredictable, it has high entropy (i.e. p(y) must be a uniform distribution). If both these traits are satisfied, we should expect a large KL divergence between p(y|x) and p(y). Naturally, a large Inception Score (IS) is better. For a deeper analysis of the Inception Score, you can check out this paper.
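Given the class probabilities p(y|x) for a batch of generated images, the score can be sketched in a few lines of NumPy. The function name and the `eps` guard are my own. Obtaining the probabilities from an Inception Network is assumed to happen elsewhere.

```python
import numpy as np

def inception_score(p_yx, eps=1e-12):
    """p_yx: (n_images, n_classes) array, each row a class distribution p(y|x)."""
    p_yx = np.asarray(p_yx, dtype=float)
    p_y = p_yx.mean(axis=0, keepdims=True)  # marginal p(y), averaged over images
    # KL(p(y|x) || p(y)) for every image; eps keeps log() finite at zero probabilities
    kl = np.sum(p_yx * (np.log(p_yx + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))
```

Two extreme cases show the behaviour: if every image gets the same prediction, the score is 1 (its minimum), while confident predictions spread evenly over K classes give a score of about K.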
2. Fréchet Inception Distance (FID)
A drawback of the Inception Score is that statistics of the real data are not compared with the statistics of the generated data (Source). Fréchet distance resolves the drawback by comparing the mean and covariance of the real and generated images. Fréchet Inception Distance (FID) performs the same analysis, but on the feature maps produced by passing the real and generated images through a pre-trained Inception-v3 Network (Source). The equation is described as follows:
$$FID = \|\mu_r - \mu_g\|_2^2 + \mathrm{Tr}\left( \Sigma_r + \Sigma_g - 2 (\Sigma_r \Sigma_g)^{1/2} \right)$$
Here, (\mu_r, \Sigma_r) and (\mu_g, \Sigma_g) are the mean and covariance of the Inception features of the real and generated images, respectively.
A lower FID score is better, as it indicates that the statistics of the generated images are very similar to those of the real images.
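The comparison of feature means and covariances can be sketched in NumPy as below. The function name is hypothetical, the inputs are assumed to be feature vectors already extracted by an Inception Network, and the matrix square root is handled through eigenvalues to avoid an extra dependency.

```python
import numpy as np

def frechet_distance(feats_real, feats_gen):
    """Frechet distance between Gaussians fitted to two (n_samples, n_features) arrays.

    Computes ||mu_r - mu_g||^2 + Tr(cov_r + cov_g - 2 * (cov_r cov_g)^(1/2)).
    """
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_gen, rowvar=False)
    # Tr((cov_r cov_g)^(1/2)) equals the sum of the square roots of the eigenvalues
    # of cov_r @ cov_g; small negative or imaginary parts are numerical noise.
    eigvals = np.linalg.eigvals(cov_r @ cov_g)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(np.sum((mu_r - mu_g) ** 2)
                 + np.trace(cov_r) + np.trace(cov_g) - 2.0 * tr_sqrt)
```

As a sanity check, two identical feature sets give a distance of about 0, and shifting one set by a constant vector increases the distance by the squared length of that shift.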
The research community has produced numerous solutions and hacks to overcome the shortcomings of GAN training. However, it is difficult to keep track of significant contributions due to the sheer volume of new research. The details shared in this blog are not exhaustive for the same reason, and may become outdated in the near future. Nevertheless, I hope this blog serves as a guideline for people looking for methods to improve the performance of their GANs.