Training GANs on Spatio-Temporal Data: A Practical Guide (Part 2)

Shantanu Chandra · AI FUSION LABS · Jul 19, 2023

Part 2: Practical tips for stabilizing GAN training

In Part 1 of this practical guide to training GANs, we discussed a) how an imbalance between discriminator (D) and generator (G) training causes mode collapse and muted learning due to vanishing gradients, and b) GANs’ sensitivity to hyperparameters.

In this article, we provide multiple solutions to each of these instabilities. These solutions have worked well empirically in our experiments, after extensively trying every trick in the book to stabilize GAN training. We order the list by ease of implementation and impact, as a recommendation for iteratively enhancing your GAN training. Note that all the solutions discussed here are generic to any form of GAN training, and are directly relevant to the spatio-temporal use case as well. The problems and solutions specific to spatio-temporal GANs alone will be discussed in the final part of the series.

Taming GAN instabilities

1. Imbalance between Generator and Discriminator

As discussed in the previous article, an imbalance between the training of G and D (i.e., either G or D is disproportionately better than the other) causes vanishing gradients, as well as mode collapse when G has no incentive to produce diverse samples to fool its competitor. To tackle this issue, the usual solutions revolve around:

· Changing the cost function for a better optimization goal.

· Adding additional penalties to the cost function to enforce constraints (e.g. diversity).

· Avoiding overconfidence and overfitting.

The solutions to both vanishing gradients and mode collapse are discussed in detail in the following sub-sections. In each sub-section we first list the (ranked) suggested solutions and the intuition behind each, and then summarize the takeaways at the end.

1.1 Vanishing Gradients:

To mitigate vanishing gradients, the usual strategies deployed revolve around making the task of D harder and giving G a chance to catch up. This is motivated by the belief that “it is easy to tell if a painting is Van Gogh or not, but it is very difficult to actually make one.” Thus, the underlying assumption is that the task of G is much harder than that of D.

1. One-sided label smoothing: if the D is overconfident in its predictions, it leads to vanishing gradients and the G cannot learn from such observations — predicting real samples as ~1 (e.g., 0.999) and generated samples as ~0 (e.g., 0.0001) gives ~0 loss. A simple yet highly effective technique to tackle this issue is to convert all the “1” ground-truth labels of real data to a range of [0.7, 1.2], and the “0” labels of generated data to [0.0, 0.3]. This penalizes the D when it gets too confident in its predictions, and ensures gradient flow even under correct-prediction scenarios, enabling the G to learn from these instances. Note that this is done ONLY when updating the weights of D and NOT during G updates (hence the name “one-sided”).

2. One-sided label flipping: if you notice that the D loss still drops to 0 quickly, you can further cripple the over-performing D. Practitioners usually flip the labels of a small fraction of real and generated data (real data labels randomly flipped from 1 to 0; generated data from 0 to 1). This adds noise to D training and prevents it from becoming too powerful at any stage of training. Again, this is done only for D updates. A minimal sketch of both tricks follows.
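Below is a minimal PyTorch sketch of both tricks applied to D’s targets, under the ranges suggested above (the flip probability `flip_p` is a hypothetical setting, not a prescribed value):

```python
import torch

def d_targets(n, real, flip_p=0.05):
    """Smoothed + occasionally flipped targets, for D updates only (sketch).
    Real labels ~ U[0.7, 1.2], fake labels ~ U[0.0, 0.3]; a small random
    fraction (flip_p, a hypothetical setting) gets the opposite range."""
    real_t = torch.empty(n, 1).uniform_(0.7, 1.2)   # smoothed "real" labels
    fake_t = torch.empty(n, 1).uniform_(0.0, 0.3)   # smoothed "fake" labels
    t, opposite = (real_t, fake_t) if real else (fake_t, real_t)
    flip = torch.rand(n, 1) < flip_p                # labels to flip
    return torch.where(flip, opposite, t)

# Use with nn.BCEWithLogitsLoss when updating D; G updates keep hard labels.
```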

3. Since the task of G is more difficult, G is usually trained for x steps (~2–5), while keeping the D fixed, before training the D again. This allows the generator to bridge the gap between the p and q distributions early in training and get meaningful feedback from the D to improve its generations. We advise not spending too much time on this step before trying the other suggestions in this list: the different phases of GAN training cannot be controlled by a fixed heuristic of relative updates, since the ideal ratio varies dynamically with the evolving learning dynamics between G and D and cannot be pre-defined. That said, a 2:1 ratio of G:D training steps is a reasonable default, as in the sketch below.
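A self-contained toy sketch of a 2:1 G:D update schedule on synthetic 1-D data (all sizes, architectures, and learning rates here are hypothetical illustrations, not our production setup):

```python
import torch
import torch.nn as nn

# Toy setup to illustrate a 2:1 G:D update schedule (all sizes hypothetical).
z_dim, x_dim = 8, 2
G = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, x_dim))
D = nn.Sequential(nn.Linear(x_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
g_opt = torch.optim.Adam(G.parameters(), lr=1e-4)
d_opt = torch.optim.Adam(D.parameters(), lr=4e-4)
bce = nn.BCEWithLogitsLoss()
ones, zeros = torch.ones(64, 1), torch.zeros(64, 1)

for step in range(1000):
    real = torch.randn(64, x_dim) + 3.0       # stand-in for real data

    # Train G for two steps while D is held fixed.
    for _ in range(2):
        g_opt.zero_grad()
        fake = G(torch.randn(64, z_dim))
        bce(D(fake), ones).backward()         # G wants D to output "real"
        g_opt.step()

    # Then a single D step on real and (detached) generated samples.
    d_opt.zero_grad()
    fake = G(torch.randn(64, z_dim)).detach()
    (bce(D(real), ones) + bce(D(fake), zeros)).backward()
    d_opt.step()
```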

4. Use batch normalization (BN) in D in conjunction with tip #5 below. Batch normalization normalizes the inter-layer outputs of a neural network. It aids in stabilizing the training process (by reducing covariate shift) and improves generalization by preventing overfitting.

5. Feed generated and real samples to D separately, as in the sketch below. This little trick prevents the D from using shortcuts to make classifications that give the G no feedback to improve its generations. The purpose of BN is to reduce internal covariate shift in activation maps by making all activations roughly equally distributed (zero mean and unit standard deviation). The network then has no need to keep adapting to changes in the distributions of activations that occur as weights change during training, and such normalization simplifies learning significantly. At the very beginning of GAN training, however, real and fake samples in a mini-batch have very different distributions, so if we try to normalize them together, we will not end up with well-centered data. Moreover, the distribution of such normalized data will keep changing during training (because the G progressively produces better results), and the discriminator would have to keep adapting to these changes.
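A sketch of the D update done as two separate forward passes (reusing the hypothetical names from the toy sketch above, and assuming D contains BN layers), so BN statistics are computed over real-only and fake-only mini-batches and never mixed:

```python
# Separate forward passes: BN inside D computes batch statistics over
# real-only and fake-only mini-batches, never over a mixed batch.
d_opt.zero_grad()

loss_real = bce(D(real), ones)                 # pass 1: real samples only
fake = G(torch.randn(64, z_dim)).detach()
loss_fake = bce(D(fake), zeros)                # pass 2: generated samples only

(loss_real + loss_fake).backward()
d_opt.step()
```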

6. Use other loss functions that offer a more stable gradient distribution, such as WGAN, RSGAN, etc. However, a Google Brain study empirically showed that no loss function is a clear winner (Lucic et al., 2018), and the choice of GAN loss function is still a hazy area that is yet to be conquered.
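For reference, a minimal sketch of the WGAN objectives (the Lipschitz constraint that WGAN additionally requires — weight clipping or a gradient penalty — is omitted here for brevity):

```python
import torch

def wgan_losses(critic, real, fake):
    """WGAN critic/generator objectives (sketch). The critic outputs
    unbounded scores rather than probabilities; a Lipschitz constraint
    (weight clipping or gradient penalty) is also required in practice."""
    d_loss = critic(fake).mean() - critic(real).mean()  # critic minimizes this
    g_loss = -critic(fake).mean()                       # generator minimizes this
    return d_loss, g_loss
```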

7. Use some form of pre-training of the generator on a supervised task, so that it broadly aligns with the latent space and learns to capture some basic features of the task (such as edges and contours in the case of image generation). This helps bridge the gap between p and q when adversarial training begins, preventing the vanishing gradients that the JSD produces when there is a huge disconnect between G’s outputs and real-world data.

1.2 Mode Collapse:

To mitigate this, the usual strategies deployed are:

1. Use labels: whenever possible, use labels with an auxiliary classifier GAN (ACGAN) setting (Odena et al., 2016). This encourages the G to map different regions of the latent space to the different labels used as conditional inputs, which prevents the G from generating the same output irrespective of its input, hence preventing mode collapse. A minimal conditioning sketch follows the figure.

Fig 5: The ACGAN architecture
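A minimal sketch of ACGAN-style conditioning (all layer sizes hypothetical): the generator consumes a label embedding alongside the noise, and the discriminator carries an auxiliary head that classifies the input’s label in addition to the real/fake score.

```python
import torch
import torch.nn as nn

n_classes, z_dim, x_dim = 10, 64, 128          # hypothetical sizes

class CondG(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(n_classes, z_dim)
        self.net = nn.Sequential(nn.Linear(2 * z_dim, 256), nn.ReLU(),
                                 nn.Linear(256, x_dim))

    def forward(self, z, y):
        # Condition on the label by concatenating its embedding with z.
        return self.net(torch.cat([z, self.embed(y)], dim=1))

class ACDisc(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(x_dim, 256), nn.LeakyReLU(0.2))
        self.adv = nn.Linear(256, 1)           # real vs. fake score
        self.aux = nn.Linear(256, n_classes)   # auxiliary class logits

    def forward(self, x):
        h = self.trunk(x)
        return self.adv(h), self.aux(h)
```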

2. Feature matching: this promotes diversity in generation by changing the goal of the G from fooling the D at any cost to matching the latent features of real data. This involves minimizing the L2 distance between the means of the respective feature vectors at the batch level. The minibatch setting introduces randomness, which makes it harder to overfit to a single mode; a minimal sketch follows.
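A sketch of the feature-matching objective, assuming a hypothetical `d_features` callable that returns an intermediate activation of D:

```python
import torch

def feature_matching_loss(d_features, real, fake):
    """Squared L2 distance between the batch means of D's intermediate
    features on real vs. generated data (sketch). The real-side mean is
    detached: this gradient should drive G, not D."""
    f_real = d_features(real).mean(dim=0).detach()
    f_fake = d_features(fake).mean(dim=0)
    return ((f_real - f_fake) ** 2).sum()
```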

3. Minibatch discrimination: to tackle mode collapse, real and generated data are fed to the D separately in different batches, and the similarity of a data point x with the other data points in the same batch is computed. The similarity o(x) is then concatenated with the features of the penultimate layer of the discriminator to classify whether x is real or generated. In the event of mode collapse, the similarity of generated data starts to increase, and the D can use this information to start classifying the generated images correctly again, penalizing the G for its lack of diversity. The exact mechanism is slightly more involved than feature matching, but is claimed to work better in practice (Salimans et al., 2016) and is summarized below.

Here, x_i is the input image and x_j ranges over the other images in the same batch. These operations are summarized in the sketch below.
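One common way to implement this layer, following Salimans et al. (2016) — tensor sizes are hypothetical: penultimate features are projected through a learned tensor, pairwise L1 similarities are computed within the batch, and the resulting statistics o(x) are concatenated back onto the features.

```python
import torch
import torch.nn as nn

class MinibatchDiscrimination(nn.Module):
    """Minibatch discrimination layer (sketch). Appends per-sample
    similarity statistics o(x) to D's penultimate features."""
    def __init__(self, in_feats, n_kernels, kernel_dim):
        super().__init__()
        self.T = nn.Parameter(0.1 * torch.randn(in_feats, n_kernels * kernel_dim))
        self.n_kernels, self.kernel_dim = n_kernels, kernel_dim

    def forward(self, f):                                  # f: (N, in_feats)
        m = (f @ self.T).view(-1, self.n_kernels, self.kernel_dim)
        diff = m.unsqueeze(0) - m.unsqueeze(1)             # (N, N, K, C)
        c = torch.exp(-diff.abs().sum(dim=3))              # pairwise similarity
        o = c.sum(dim=1) - 1.0                             # drop self term exp(0)=1
        return torch.cat([f, o], dim=1)                    # (N, in_feats + K)
```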

Takeaways:

1. Playing with the “capacity” of G and D (i.e., their relative parameter counts) will most probably not make much of a difference. Training dynamics are affected more by the loss function being optimized and the exact experimental setup than by the relative model sizes of G and D.

2. Trying to optimize the G:D training-step ratio is largely a futile exercise. One intuition urges you to train the G more; another shows this can be harmful. It is hard to design such a static heuristic for a training process that is highly dynamic and sensitive. Many have tried this and failed.

3. Try one-sided label smoothing and label flipping as the first step. They are very simple to integrate but highly effective; we saw huge learning gains from these two simple tricks.

4. Use batch normalization and feed the generated and real samples separately to the D.

5. Use alternative loss functions such as WGAN, RSGAN, etc. to stabilize training via better gradients (but take them with a pinch of salt; there are no clear winners here).

6. Employ auxiliary classifier GAN framework (when labels are available), feature matching, and minibatch discrimination to promote diversity and tackle mode collapse.

2. Sensitivity to Hyper-parameters

GANs are VERY sensitive to hyper-parameters, period. Although it takes a lot of patience and time to optimally tune the hyper-parameters, this exercise can prove to be decisive in making or breaking your architecture performance. To aid this process, the general tips are:

1. Learning rate (LR): the learning rate is one of the most important hyperparameters; it can make or break your training, and there are multiple heuristics to keep in mind while choosing one:

a) Two time-scale update rule (TTUR): this essentially means using different learning rates for G and D, with the LR of G lower than that of D. This ensures that G takes small steps to fool D, which helps prevent mode collapse: if G’s steps are too fast and too precise early in training, it is more likely to latch onto a single mode that fools the D to win the adversarial game. A minimal optimizer setup is sketched below.

b) LR should depend on the batch size: a higher LR is fine for bigger batch sizes, as they offer less noisy updates across batches that could otherwise cause huge fluctuations in GAN training. But it is advisable to be on the conservative side with the LR in general.
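A minimal TTUR setup (the exact rates and betas here are hypothetical; G and D as in the earlier sketches):

```python
import torch

# Two time-scale update rule: G's learning rate set below D's.
g_opt = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
d_opt = torch.optim.Adam(D.parameters(), lr=4e-4, betas=(0.5, 0.999))
```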

2. Batch size: larger batch sizes are preferred, as more modes get covered within a batch, which prevents the G from learning heavily from any single dominant mode and falling prey to mode collapse.

3. Early stopping: GAN training will always fluctuate, and a common mistake is to stop training prematurely when the G loss starts diverging (especially early in training). Do not employ heuristic-based early stopping; instead, track the evaluation metrics to look for mode collapse or vanishing gradients, and restart training based on this observed behavior.

Takeaways:

1. Learning rate: train G and D with different LRs, with the LR of G preferably lower than D’s.
2. Batch size: prefer larger batch sizes to cover more modes in a mini-batch.
3. Early stopping: do not employ heuristic-based early stopping; instead, track the evaluation metrics to look for mode collapse or vanishing gradients and restart training based on this observed behavior.

This brings us to the end of the second blog of the series. In this part, we deep-dived into the potential solutions to the GAN instabilities discussed in Part 1. Note that the recommended ranked list of solutions is based on our experience and experiments, but can vary for your specific use case.

In the next and final part of the series, we will explore the special case of spatio-temporal data. We will first talk about objective evaluation metrics to track during training to detect some of the discussed pitfalls. Finally, we will elucidate some instabilities that arise specifically in training on spatio-temporal data and potential solutions to them.

References

Odena, A., Olah, C., & Shlens, J. (2016). Conditional Image Synthesis with Auxiliary Classifier GANs.

Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., & Bengio, Y. (2014). Generative Adversarial Nets. Advances in Neural Information Processing Systems.

Lucic, M., Kurach, K., Michalski, M., Gelly, S., & Bousquet, O. (2018). Are GANs Created Equal? A Large-Scale Study. Advances in Neural Information Processing Systems.

Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., & Chen, X. (2016). Improved Techniques for Training GANs. Advances in Neural Information Processing Systems.

About the author: Shantanu is an AI Research Scientist at the AI Center of Excellence lab at ZS. He did his Bachelor’s in Computer Science Engineering and Master’s in Artificial Intelligence (cum laude) at the University of Amsterdam, with a thesis at the intersection of geometric deep learning and NLP in collaboration with Facebook AI, London and King’s College London. His research areas include Graph Neural Networks (GNNs), NLP, multi-modal AI, deep generative models, and meta-learning.
