Training GANs on Spatio-Temporal Data: A Practical Guide (Part 1)

Shantanu Chandra
Published in AI FUSION LABS
8 min read · Apr 17, 2023

Part 1: A deep dive into the most notorious instabilities in GAN training.

GANs are by far the most popular deep generative models, mainly due to the impressive results they have produced on image generation tasks in recent years. However, GANs are notoriously hard to train because their adversarial design introduces a myriad of instabilities. If you have tried training a GAN on anything but MNIST, you will quickly realize that the claims about how painful they are to train (and the research areas devoted to fixing this) do not exaggerate the problem.

We will systematically go through the causes of these notorious instabilities and the solutions that worked well empirically in our experiments, after extensively trying almost every trick in the book. This three-part series, a practical guide to training GANs with an emphasis on spatio-temporal data generation, is structured as follows:

1. Part 1: A deep dive into the most notorious instabilities in GAN training.

2. Part 2: Possible solutions to the common pitfalls discussed in Part 1.

3. Part 3: The special case of training GANs on spatio-temporal data, including the metrics to track, its unique complexities, and their solutions.

The instabilities and solutions discussed in this series are model- and use-case-agnostic, and they apply to the spatio-temporal case too, so they serve as a good starting point for any GAN training exercise. In this article, we discuss why training GANs is so elusive by detailing the most notorious instabilities in GAN training. We study a) how an imbalance between discriminator (D) and generator (G) training causes mode collapse and stalled learning due to vanishing gradients; b) GANs' sensitivity to hyperparameters; and c) how the GAN loss can be misleading about model performance.

[Note: we assume that readers are familiar with GAN fundamentals and have some prior experience training GANs. We therefore skip the “What are GANs?” section here and refer readers to this article for a quick recap.]

Why is training GANs so elusive?

In this section, we detail some of the most notorious instabilities in GAN training; possible solutions for each of them that have worked well in practice in our experiments are covered in Part 2. That said, it is advisable to run your first few iterations in the vanilla setting to probe which of the following pitfalls show up for your architecture and task. You can then iteratively apply the solutions (which are ordered by their complexity and their effectiveness in our experience) to further stabilize training. Note that these tips only serve as a directional starting point, not an exhaustive list or a one-shot fix. Readers are advised to probe their architectures and training dynamics further to achieve the best results.

1. Imbalance between Generator and Discriminator

It is easy to tell whether a painting is a Van Gogh, but very difficult to actually paint one. Similarly, the task of G is generally believed to be more difficult than that of D. At the same time, how well G learns to generate realistic outputs depends on how well D is trained: an optimal D gives G rich signals to learn from and improve its generations. It is therefore important to balance the training of G and D to create good learning conditions.

GAN training is framed as a two-player, zero-sum, non-cooperative game whose solution is a Nash equilibrium. However, gradient descent is known to fail to converge for some cost functions, in particular in non-convex games. In the min-max game between G and D, an imbalance between their training steps therefore introduces several instabilities and yields sub-optimal gradients for learning. These instabilities are discussed below:

1. Vanishing gradients: Whether D should be better than G, or vice versa, for optimal GAN training can be answered by weighing the following arguments:

a) If the D gets too good too quickly, the gradients for G vanish and it is unable to ever catch up.

b) On the other hand, if D is sub-optimal, G can easily fool it even with gibberish, owing to D’s poor predictive performance. This again leaves G with no useful gradients to learn from, so its outputs stop improving.

Thus, in the ideal case, G and D should take turns outperforming each other in a cyclical manner. If you see one of the losses moving monotonically in either direction, your GAN training has most likely collapsed.

2. Mode collapse: If G is trained disproportionately more than D, it converges to repeatedly producing the same output that fools D well, with no incentive to diversify its samples.
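The convergence problem mentioned earlier shows up even in the simplest min-max game. The sketch below is a standard textbook toy, not taken from our experiments: it assumes a bilinear objective f(x, y) = x * y, where the scalar x plays the role of G (minimizing f) and y plays D (maximizing f), and both take simultaneous gradient steps with a hypothetical learning rate `lr`. The unique Nash equilibrium is (0, 0), yet the iterates circle around it and drift away, with the two players taking turns "winning", much like the cyclical G/D dynamics described above.

```python
import math


def simultaneous_gda(x0=1.0, y0=1.0, lr=0.1, steps=100):
    """Simultaneous gradient descent-ascent on the toy game f(x, y) = x * y.

    x (the "generator") minimizes f and y (the "discriminator") maximizes f.
    Returns the final point and the distance to the equilibrium (0, 0)
    recorded after every step.
    """
    x, y = x0, y0
    radii = [math.hypot(x, y)]
    for _ in range(steps):
        grad_x, grad_y = y, x  # df/dx = y, df/dy = x
        # Both players update at the same time from the same (x, y).
        x, y = x - lr * grad_x, y + lr * grad_y
        radii.append(math.hypot(x, y))
    return x, y, radii


x, y, radii = simultaneous_gda()
print(f"distance from equilibrium: start {radii[0]:.3f}, end {radii[-1]:.3f}")
```

Each simultaneous step multiplies the distance to the equilibrium by exactly sqrt(1 + lr**2), so a smaller learning rate slows the divergence but does not remove it.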

1.1 Vanishing Gradients:

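When trained against an optimal discriminator, the generator of the original GAN objective (Goodfellow et al., 2014) minimizes the Jensen-Shannon divergence (JSD) between the real data distribution p and the generated distribution q:

$$D_{JS}(p \,\|\, q) = \frac{1}{2} D_{KL}\left(p \,\middle\|\, \frac{p+q}{2}\right) + \frac{1}{2} D_{KL}\left(q \,\middle\|\, \frac{p+q}{2}\right)$$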

In this case, it is easy to see that JSD penalizes the generator if it misses some modes of the distribution (i.e., the penalty is high where p(x) > 0 but q(x) → 0), as well as if the generated data does not look real (i.e., high penalty if p(x) → 0 but q(x) > 0). This nudges the generator to produce better quality output, while also maintaining diversity.
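To make the two penalties concrete, here is a minimal sketch on a hypothetical three-bin discrete distribution (the bins and probabilities are made up for illustration). The JSD is computed as the average KL divergence of p and q to their mixture m = (p + q) / 2. Both a generator that drops a mode of p and one that puts mass where p is zero receive a positive penalty, and the divergence is symmetric and bounded by log 2.

```python
import math


def kl(p, q):
    """KL(p || q) in nats for discrete distributions; bins with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Jensen-Shannon divergence: mean KL of p and q to their mixture m."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


# Hypothetical real distribution: two modes, third bin never occurs in real data.
p = [0.5, 0.5, 0.0]
q_missing_mode = [1.0, 0.0, 0.0]    # q(x) -> 0 where p(x) > 0
q_unrealistic = [0.25, 0.25, 0.5]   # q(x) > 0 where p(x) -> 0

print("missing mode :", jsd(p, q_missing_mode))
print("unrealistic  :", jsd(p, q_unrealistic))
print("upper bound  :", math.log(2))
```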

However, this formulation leads to vanishing gradients for the generator once the discriminator becomes optimal. This is evident from the following example, where p and q are Gaussian and the mean of p is zero: the plot on the right shows that the gradient of the JS divergence vanishes going from q1 to q3. In regions where the loss saturates like this, the generator learns extremely slowly, if at all. This scenario typically arises early in training, when p and q are very different and D’s task is easy because G’s approximation is still far from the real distribution.
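The saturation can be reproduced numerically. The sketch below assumes unit-variance Gaussians p = N(0, 1) and q = N(mu, 1), a stand-in for the plotted example whose exact parameters are not given. It integrates the JSD on a grid with the midpoint rule and estimates its derivative with respect to mu by central finite differences. As mu grows, the JSD approaches its upper bound log 2 and the derivative, the signal that would move q toward p, shrinks toward zero.

```python
import math


def gaussian_pdf(x, mean, std=1.0):
    """Density of N(mean, std^2) at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def jsd_gaussians(mu, std=1.0, n=8000):
    """JSD (in nats) between p = N(0, std^2) and q = N(mu, std^2), midpoint rule."""
    lo = min(0.0, mu) - 12.0 * std  # +-12 std covers essentially all the mass
    hi = max(0.0, mu) + 12.0 * std
    dx = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * dx
        p = gaussian_pdf(x, 0.0, std)
        q = gaussian_pdf(x, mu, std)
        m = 0.5 * (p + q)
        if p > 0.0 and m > 0.0:
            total += 0.5 * p * math.log(p / m)
        if q > 0.0 and m > 0.0:
            total += 0.5 * q * math.log(q / m)
    return total * dx


def jsd_grad(mu, h=1e-3):
    """Central finite-difference estimate of d JSD / d mu."""
    return (jsd_gaussians(mu + h) - jsd_gaussians(mu - h)) / (2.0 * h)


for mu in [0.5, 1.0, 2.0, 4.0, 8.0]:
    print(f"mu={mu:4.1f}  JSD={jsd_gaussians(mu):.4f}  dJSD/dmu={jsd_grad(mu):.2e}")
print(f"upper bound log 2 = {math.log(2):.4f}")
```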

1.2 Mode Collapse:

Mode collapse is by far the most difficult problem to solve when training GANs. Although there are many intuitive explanations for it, in practice our understanding of mode collapse is still very limited. One intuitive explanation that has helped practitioners so far considers the extreme case where G is trained disproportionately more without enough updates to D. The generator eventually converges to the optimal image x* that fools D the most, i.e., the most realistic image from the discriminator’s perspective. In this case, x* becomes independent of the latent input z, meaning that G generates the same image for every z.

When D is eventually trained again, it learns to reject images of that mode as fake. This in turn forces the generator to look for the next vulnerable point and start generating that. This cat-and-mouse chase between D and G continues, and G becomes so focused on “cheating” that it never learns to cover the other modes. This is illustrated in the image above: the top row shows the ideal learning process that G should follow, while the bottom row shows mode collapse, where G produces just one mode really well while ignoring the others.
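The extreme case described above, where G keeps training against a D that is not updated, can be sketched in one dimension. This is a toy under loudly stated assumptions: a hypothetical frozen discriminator score D(x) = exp(-(x - 2)^2) that peaks at x* = 2, a linear generator g(z) = a*z + b, and five fixed latent samples. Plain gradient ascent on the mean score drives the slope a to zero, so the output no longer depends on z.

```python
import math

X_STAR = 2.0  # hypothetical point the frozen toy discriminator finds "most real"


def d_score(x):
    """Frozen toy discriminator: realness score that peaks at X_STAR."""
    return math.exp(-(x - X_STAR) ** 2)


def d_score_grad(x):
    """Derivative of d_score with respect to x."""
    return -2.0 * (x - X_STAR) * d_score(x)


def train_generator_only(a=1.0, b=0.0, lr=0.5, steps=500):
    """Gradient ascent on mean D(g(z)) for g(z) = a*z + b while D stays frozen."""
    zs = [-1.0, -0.5, 0.0, 0.5, 1.0]  # fixed latent samples
    for _ in range(steps):
        grads = [d_score_grad(a * z + b) for z in zs]
        grad_a = sum(g * z for g, z in zip(grads, zs)) / len(zs)  # chain rule: dg/da = z
        grad_b = sum(grads) / len(zs)                             # chain rule: dg/db = 1
        a += lr * grad_a
        b += lr * grad_b
    return a, b


a, b = train_generator_only()
print(f"a = {a:.6f}, b = {b:.6f}")  # in this toy, a approaches 0 and b approaches X_STAR
print([round(a * z + b, 4) for z in (-1.0, 0.0, 1.0)])  # (nearly) identical outputs
```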

2. Sensitivity to Hyper-parameters

GANs are VERY sensitive to hyperparameters, period. No cost function will work without good hyperparameters, so it is advisable to tune the hyperparameters extensively before experimenting with different loss functions. Tuning takes time and a lot of patience, and it is important to understand the basic training dynamics of your architecture before you start playing with advanced loss functions, which introduce their own sets of hyperparameters.
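One way to keep that tuning systematic is to enumerate an explicit grid before touching the loss function. The sketch below is hypothetical: the `GanHParams` fields and all candidate values are placeholders chosen for illustration, not settings we recommend or used. It includes the D-to-G learning-rate ratio and the number of D steps per G step because those directly control the G/D balance discussed in section 1. Each configuration would be passed to your own training loop.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class GanHParams:
    """One candidate configuration (hypothetical fields)."""
    lr_g: float      # generator learning rate
    lr_d: float      # discriminator learning rate
    beta1: float     # first-moment decay for an Adam-style optimizer
    d_steps: int     # discriminator updates per generator update
    batch_size: int


def build_grid():
    """Enumerate a small illustrative grid; all values are placeholders."""
    lrs_g = [1e-4, 2e-4]
    d_to_g_lr_ratios = [1.0, 2.0]  # lr_d = ratio * lr_g
    beta1s = [0.0, 0.5]
    d_steps = [1, 2]
    batch_sizes = [64]
    return [
        GanHParams(lr_g, lr_g * ratio, beta1, k, bs)
        for lr_g, ratio, beta1, k, bs in product(
            lrs_g, d_to_g_lr_ratios, beta1s, d_steps, batch_sizes
        )
    ]


configs = build_grid()
print(f"{len(configs)} configurations to try")
for cfg in configs[:3]:
    print(cfg)
```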

3. Correlation between GAN loss and generation quality

In typical classification tasks, the cost function correlates with the accuracy of the model (lower loss means fewer errors and hence higher accuracy). In GANs, however, the loss measures how well one participant is doing against the other in the min-max game (generator vs. discriminator). It is very common to see the generator loss increase while image quality improves. There is thus very little correlation between loss “convergence” and generation quality, and the unstable GAN loss is often misleading.

One highly effective and widely accepted technique in image generation is to track training progress by visually inspecting the generated images at different stages of training. However, this makes model comparison harder and further complicates tuning, as it is difficult to pick the best model from such subjective evaluation. During our experiments, we quickly realized that this critical aspect of GAN training, tracking generation progress with the right metrics, is also one of the most neglected when people talk about training GANs. Moreover, unlike images, spatio-temporal data cannot be evaluated “visually” in an effective way. It therefore becomes critical to design and track metrics for spatio-temporal data that objectively indicate model performance.
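As a minimal example of an objective metric that can be tracked alongside the losses, the sketch below computes the empirical 1-Wasserstein distance between equal-size 1-D samples of some scalar statistic of real and generated data. Which statistic to use (for instance, a per-timestep value of a spatio-temporal series) is an assumption left open here, and this is not necessarily one of the metrics used in this series. For equal-size 1-D samples, the distance reduces to the mean absolute difference between the sorted samples. The sample values are made up for illustration; a collapsed generator that emits the same value every time scores poorly no matter what the losses show.

```python
def wasserstein_1d(real, fake):
    """Empirical 1-Wasserstein distance between two equal-size 1-D samples."""
    if not real or len(real) != len(fake):
        raise ValueError("expected two non-empty samples of equal size")
    # In 1-D, the optimal transport plan matches the sorted samples in order.
    pairs = zip(sorted(real), sorted(fake))
    return sum(abs(r - f) for r, f in pairs) / len(real)


real = [0.0, 1.0, 2.0, 3.0]
collapsed = [1.5, 1.5, 1.5, 1.5]  # same output regardless of the latent input
diverse = [0.1, 0.9, 2.1, 2.9]

print("collapsed:", wasserstein_1d(real, collapsed))
print("diverse:  ", wasserstein_1d(real, diverse))
```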

Now that we have detailed some prominent GAN training pitfalls, the next question is how to detect and solve them. We discuss this in detail in the next post of this series, where we provide multiple solutions for each pitfall after extensively trying out every trick in the book. We order the solutions by ease of implementation and impact, to recommend a sequence of iterative enhancements to your GAN training.

References

Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. (2014). Generative Adversarial Nets. Advances in Neural Information Processing Systems 27.

About the author: Shantanu is an AI Research Scientist at the AI Center of Excellence lab at ZS. He holds a Bachelor’s in Computer Science Engineering and a Master’s in Artificial Intelligence (cum laude) from the University of Amsterdam, with a thesis at the intersection of geometric deep learning and NLP done in collaboration with Facebook AI, London and King’s College London. His research areas include Graph Neural Networks (GNNs), NLP, multi-modal AI, deep generative models, and meta-learning.
