Training GANs on Spatio-Temporal Data: A Practical Guide (Part 3)

Shantanu Chandra
Published in AI FUSION LABS
10 min read · Apr 25, 2024

Part 3: Special case of spatio-temporal data

In Parts 1 and 2 of this practical guide to training GANs, we discussed some of the causes of the most notorious GAN training instabilities and their empirical solutions.

In this article, we focus specifically on training GANs on spatio-temporal (ST) data. ST data generation usually relies on recurrent neural networks (RNNs) in an auto-regressive setting to model patterns both within and across time steps. Multivariate time-series generation is one such use case: it requires capturing the correlations among multiple variables at a given time step, as well as their joint evolution over time. On top of the usual GAN training instabilities, spatio-temporal generation poses some additional challenges:

1. Evaluating generation quality: GAN training losses are almost never indicative of generation quality (as covered in previous articles of the series). Practitioners working with images often judge training progress by visually inspecting generated samples, but such visual evaluation is not feasible for spatio-temporal data. It is therefore important to track the right metrics, so that model performance can be compared objectively and the best model chosen during fine-tuning.

2. Pronounced vanishing gradients: ST data generation involves RNNs, which are well known to be susceptible to vanishing gradients. Sometimes multiple RNNs are stacked sequentially (auto-encoders + generator + discriminator), so gradients are unrolled through time across several networks (e.g., from the last time step of the discriminator back to the first time step of the generator), and each component receives a heavily muted training signal.

We will discuss how to detect the instabilities by objectively measuring GAN performance via relevant metrics that are indicative of the generation quality as well as diversity. We will then go on to detail some instabilities that arise specifically in the generation of ST data and their possible solutions.

1. GAN evaluation metrics and training dynamics

GAN training losses are almost never indicative of generation quality. That said, in a healthy run the G and D losses should ideally follow these trends:

a) The D-loss should be high at the start of training, since the D has to learn from scratch to distinguish real data from generated data. It should then drop sharply: the D's task is comparatively easy at this stage, because the G is only beginning to learn the data distribution and its samples are poor. As training progresses, the G ideally gets better at fooling the D, so the D-loss should rise again and oscillate around 0.5. Why 0.5? We use a binary cross-entropy (BCE) loss with a balanced number of real and generated samples, so a D-loss in this region means the D can barely tell real samples from generated ones. The D-loss should thus look something like Fig 1 below.

Fig 1 : D-loss from our experiments

b) The G-loss behaves inversely to the D-loss. Initially, the D is poorly trained and cannot reliably distinguish real data from generated data, so the G fools it easily and the G-loss is low early in training, even though the D's feedback gives only weak signals for improving generation quality. As the D gets better at its task, the G can no longer fool it with ‘gibberish’ and is penalized heavily for sub-par generation, causing the G-loss to rise sharply. Eventually the G improves, and the loss drops again, ideally plateauing in the range of roughly 1.0 to 1.2.

Fig 2 : G-loss from our experiments
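To put a number on the D-loss plateau from a), the sketch below computes the mean binary cross-entropy of a discriminator over a balanced batch. It is an illustrative calculation, not code from our experiments; the helper name `d_bce_loss` is hypothetical, and it assumes the reported D-loss is the mean over real and generated samples. A D at pure chance (outputting 0.5 everywhere) scores ln 2 ≈ 0.693, so a plateau around 0.5 corresponds to a D that is only slightly better than chance.

```python
import math

def d_bce_loss(p_real, p_fake):
    """Mean BCE of a D over a balanced batch (real = 1, generated = 0).

    p_real: D's probability of 'real' for each real sample.
    p_fake: D's probability of 'real' for each generated sample.
    """
    losses = [-math.log(p) for p in p_real]         # real: -log D(x)
    losses += [-math.log(1.0 - p) for p in p_fake]  # generated: -log(1 - D(G(z)))
    return sum(losses) / len(losses)

print(round(d_bce_loss([0.5] * 4, [0.5] * 4), 4))   # 0.6931, i.e. ln 2: D at pure chance
print(round(d_bce_loss([0.6] * 4, [0.35] * 4), 4))  # 0.4708: D slightly better than chance
```

If a framework sums the real and generated terms instead of averaging them, every value above doubles, so it is worth checking how the logged D-loss is reduced before comparing it against a target range.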

While GAN losses can be a misleading indicator of actual generation quality, other simple metrics can be very effective and can become the core of tracking and improving it. The following simple metrics help decipher actual GAN performance:

1. Precision/recall of D’s binary classification: although the F1 score summarizes precision and recall and is commonly used to measure binary classification performance, it is important to look closely at its constituent metrics individually. The simple example below demonstrates why. Take a moment to think about why each scenario has been tagged as good or bad before reading further (real = label 1, generated = label 0):

a. Precision = 0.8, recall = 0.5, F1 ≈ 0.6: BAD

b. Precision = 0.5, recall = 0.8, F1 ≈ 0.6: GOOD

Although the F1 score is the same in both scenarios, they paint drastically different pictures of the underlying model.

Case (a) is undesirable because of the low recall for the real class (remember that real samples are labeled 1). It indicates that the D is classifying real samples as generated, i.e., it cannot even recognize a real sample as real. The D HAS TO be able to recognize real samples as real; only then, when it errs by tagging generated samples as real, can we assume that these generated samples are actually good. Low recall implies a sub-optimal D that cannot give rich and useful signals to the G, and the G can easily fool such a D even by generating gibberish.

Case (b) is desirable because of the low precision and high recall for the real class. High recall indicates that the D identifies real samples as real (i.e., it is well trained), while low precision indicates that it also predicts generated samples as real. This is the situation we want, as it means the G is able to fool even a good D. This reasoning held up well in the model tuning phase of our experiments: among models with similar F1, those with lower precision generated significantly better data, and this also correlated well with other generation quality metrics such as FID.

Fig 3 : Precision (left) and Recall (right) from our experiments
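The two scenarios can be reproduced from raw discriminator outputs. The sketch below is a minimal illustration under stated assumptions: the D outputs a probability of "real", a sample counts as predicted real at a hypothetical threshold of 0.5, and the function name `real_class_metrics` and the toy batches are made up for illustration. It computes precision, recall, and F1 for the real class and shows that swapping precision and recall leaves F1 unchanged.

```python
def real_class_metrics(d_real, d_fake, threshold=0.5):
    """Precision, recall and F1 of the D for the real class (label 1).

    d_real: D's probabilities of 'real' on real samples.
    d_fake: D's probabilities of 'real' on generated samples.
    threshold: assumed decision threshold for calling a sample real.
    """
    tp = sum(p >= threshold for p in d_real)  # real samples called real
    fn = len(d_real) - tp                     # real samples called generated
    fp = sum(p >= threshold for p in d_fake)  # generated samples called real
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Case (a), BAD: the D misses half of the real samples (low recall).
case_a = real_class_metrics(d_real=[0.9] * 4 + [0.2] * 4, d_fake=[0.7] + [0.1] * 7)
# Case (b), GOOD: the D recognises real samples but is often fooled by the G (low precision).
case_b = real_class_metrics(d_real=[0.9] * 8 + [0.2] * 2, d_fake=[0.7] * 8 + [0.1] * 2)
print(case_a)  # precision 0.8, recall 0.5, F1 ≈ 0.615
print(case_b)  # precision 0.5, recall 0.8, F1 ≈ 0.615
```

Because F1 is the harmonic mean, it is symmetric in precision and recall, which is exactly why it cannot separate case (a) from case (b); logging the two components separately for the real class is what makes the difference visible.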

2. Feature matching loss for diversity: while the precision/recall of the discriminator is indicative of generation quality, the feature matching loss (Salimans et al., 2016) helps track the diversity of the generated samples. It is important to pick the model that performs well on both generation quality and diversity. In many of our experiments, we noted that although the D scores improved consistently for some runs (precision decreasing, recall high), the feature matching loss started to increase midway. This is indicative of mode collapse, and training should be restarted at this point, deploying the methods discussed in the subsequent sections.

Fig 4 : Feature-matching loss from our experiments
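Feature matching, as defined by Salimans et al. (2016), compares the mean activations of an intermediate D layer on real and generated batches. The numpy sketch below computes that quantity as a tracking metric and adds a simple rising-trend check. Assumptions: which D layer supplies the features is a design choice left open here, the toy example uses raw bimodal samples as stand-ins for D features, and `fm_loss_rising` with its `window` comparison is a hypothetical heuristic, not the rule we used.

```python
import numpy as np

def feature_matching_loss(feat_real, feat_fake):
    """Squared L2 distance between mean D features of real and generated batches.

    feat_real, feat_fake: arrays of shape (batch, feature_dim) taken from an
    intermediate layer of the D.
    """
    mean_real = np.asarray(feat_real).mean(axis=0)
    mean_fake = np.asarray(feat_fake).mean(axis=0)
    return float(np.sum((mean_real - mean_fake) ** 2))

def fm_loss_rising(history, window=5):
    """Hypothetical heuristic: True if the average FM loss over the last
    `window` checkpoints exceeds the average over the `window` before it."""
    if len(history) < 2 * window:
        return False
    recent = np.mean(history[-window:])
    previous = np.mean(history[-2 * window:-window])
    return bool(recent > previous)

# Toy bimodal "features": a diverse G covers both modes, a collapsed G covers one.
rng = np.random.default_rng(0)
real = np.concatenate([rng.normal(-2.0, 0.5, (128, 4)), rng.normal(2.0, 0.5, (128, 4))])
diverse = np.concatenate([rng.normal(-2.0, 0.5, (128, 4)), rng.normal(2.0, 0.5, (128, 4))])
collapsed = rng.normal(2.0, 0.5, (256, 4))  # every sample drawn from a single mode
fm_diverse = feature_matching_loss(real, diverse)
fm_collapsed = feature_matching_loss(real, collapsed)
print(fm_diverse < fm_collapsed)  # True: collapsing onto one mode shifts the feature mean
```

Because this loss compares only feature means, a G that collapsed exactly onto the mean feature vector would score well; that is one reason to read it alongside the D's precision/recall rather than on its own.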

These simple metrics helped us not only to detect some pitfalls early in training, but also to compare models objectively and choose the model that best balances generation quality and diversity. We now look at strategies for tackling one of the most pronounced training instabilities in ST data generation.

2. Pronounced vanishing gradients

When working with ST data, the G and D are typically sequential processing units such as RNNs. However, stacking RNNs sequentially in the GAN setting leads to vanishing gradients, on top of any already introduced by the G and D imbalance discussed in Part 1. The following illustration shows how this phenomenon unfolds:

Fig 5 : Gradient flow in G when using sequence level D

Let us unpack the illustration above, which I created for demonstration purposes.

In the usual setup, the G generates all the time steps (i.e., one entire sequence), which are then fed to the D one step at a time. As in any sequence classification task, the final hidden state of the D's RNN is then used to predict whether the given sequence is real or generated. In a simplified view of backpropagation, the gradient at each time step is computed as:

downstream gradient = local gradient × upstream gradient.

For simplicity, assume that during G training (i.e., with the D frozen) every local gradient is ½ and that we start with a loss of 1. In this case, with just 4 time steps, the gradient that reaches the G's first generated time step is reduced to just 0.1875! This only gets worse as the sequence length increases, as it does in real-world use cases, and it prevents the G from learning good long-range temporal dependencies from the D's feedback. In practice, we found that this architecture gave the G gradients in the range of just 1e-8, while the D had gradients in the range of 1e-3, since the D is closer to where the gradients originate (the final time step of the D's RNN).
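The rule "downstream gradient = local gradient × upstream gradient", applied hop after hop, is what makes the signal decay exponentially with sequence length. The sketch below follows a single backward path under an assumed wiring (one hop from the loss into the D's final output, T − 1 hops back through the D's recurrence, one hop into the G's first generated step). That wiring is a simplification of Fig 5, so for T = 4 it gives 0.03125 rather than the figure's 0.1875; what it shows is the trend, not the figure's exact values. The function name `backprop_chain` is hypothetical.

```python
def backprop_chain(num_hops, local_grad=0.5, upstream=1.0):
    """Apply 'downstream gradient = local gradient x upstream gradient' once per hop."""
    grad = upstream
    for _ in range(num_hops):
        grad = local_grad * grad
    return grad

# Assumed single-path wiring for a sequence-level D on a length-T sequence:
# 1 hop from the loss into D's final output, T - 1 hops back through D's
# recurrence, and 1 hop into the G's first generated step: T + 1 hops.
for T in (4, 16, 64):
    print(T, backprop_chain(T + 1))  # T=4 prints "4 0.03125"
```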

Solution:

1. CNN as Discriminator: one of the most popular solutions is to use as the discriminator a CNN adopted from the sentence classification literature (Kim, 2014) (Fig 6). Its 2D kernel of size h × k (where h is the window size in the temporal dimension and k is the number of features at each step) spans all the features, so it effectively slides only along the temporal dimension, i.e., it acts as a 1D convolution over time with k input channels. This setup avoids the compounded vanishing gradients caused by stacking two RNNs one after the other.
Fig 6 : CNN architecture for sequence classification

2. Time-step level RNN Discriminator: the other solution, which we deployed effectively, is to keep using RNNs (owing to an advantage of RNNs over CNNs covered below), but instead of classifying the sequence as real or fake only at the end, we made the D predict at each time step. This changes the D's task from “is this entire sequence real or fake?” to the more granular “is the sequence seen SO FAR real or fake?”. As seen in Fig 7, this simple change to the D's formulation leads to much stronger gradients all the way back to the G's first time step: 0.6875, about 3.7x the previous case. This is because gradients now originate not just at the end of the sequence but at every time step of generation. The effect becomes more prominent for longer sequences: the sequence-level gradient keeps shrinking toward zero, while the time-step level gradient stays bounded away from it.

Fig 7 : Gradient flow in G when using time-step level D
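The two remedies differ mainly in what the D reads and outputs. Below is a minimal numpy forward-pass sketch of each, under stated assumptions: a Kim-style CNN with a single window size, no convolution bias, and no dropout; a single-layer Elman RNN standing in for the D's RNN (the cell type is not specified above, and an LSTM or GRU would be more typical); random, untrained weights; and the sequence's label repeated at every step for the time-step level loss. All function names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cnn_discriminator(seq, kernels, w_out, b_out=0.0):
    """Kim-style CNN D (forward pass only): one probability per sequence.

    seq:     (T, k) sequence with k features per time step.
    kernels: (n_filters, h, k); each h x k filter spans all k features, so it
             slides only along time (equivalent to a 1D conv with k channels).
    """
    T = seq.shape[0]
    n_filters, h, _ = kernels.shape
    # One activation per filter per window position -> (n_filters, T - h + 1)
    fmap = np.array([[np.sum(kernels[f] * seq[t:t + h]) for t in range(T - h + 1)]
                     for f in range(n_filters)])
    pooled = np.maximum(fmap, 0.0).max(axis=1)  # ReLU, then max-over-time pooling
    return float(sigmoid(pooled @ w_out + b_out))

def timestep_rnn_discriminator(seq, W_x, W_h, w_out, b_out=0.0):
    """Elman-RNN D that answers 'is the sequence seen so far real?' at every
    step. Returns per-step probabilities of shape (T,). A sequence-level D
    would use only the last entry."""
    hidden = np.zeros(W_h.shape[0])
    probs = []
    for x_t in seq:
        hidden = np.tanh(W_x @ x_t + W_h @ hidden)
        probs.append(sigmoid(hidden @ w_out + b_out))
    return np.array(probs)

def timestep_bce(step_probs, is_real):
    """BCE averaged over all time steps; every step gets the sequence's label."""
    p = np.clip(step_probs, 1e-7, 1 - 1e-7)
    return float(-np.mean(np.log(p) if is_real else np.log(1 - p)))

# Toy usage with random weights (illustrative only, nothing is trained here).
rng = np.random.default_rng(0)
T, k, n_hidden, n_filters, h = 12, 3, 8, 4, 3
seq = rng.normal(size=(T, k))  # stands in for one generated sequence
p_seq = cnn_discriminator(seq, rng.normal(size=(n_filters, h, k)), rng.normal(size=n_filters))
step_probs = timestep_rnn_discriminator(
    seq, rng.normal(size=(n_hidden, k)), 0.1 * rng.normal(size=(n_hidden, n_hidden)),
    rng.normal(size=n_hidden))
loss_fake = timestep_bce(step_probs, is_real=False)
print(step_probs.shape)  # (12,)
```

The CNN returns a single score per sequence, whereas the time-step level RNN returns one score per step, so its loss (and the G's adversarial loss, if formulated the same way) receives a contribution at every step rather than only at the end.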

This re-formulation of sequence-level D to time-step level D has two advantages:

1. Stronger gradients: in our experiments, the time-step level D gave the G much stronger gradients to learn from throughout the length of the sequence (Fig 8); in the toy example above, they asymptotically converge to 0.5 instead of to 0 as with the sequence-level D. The G's gradients were now in the range of 1e-3, similar to both the D's and those obtained with the 1D CNN discriminator. However, even though both options exhibited similar gradients, the RNN formulation had an additional advantage over the 1D CNN.

Fig 8 : Gradient flow in G when using sequence level D

2. More informed learning: although the time-step level RNN gives gradients of similar strength to the CNN, the information it conveys is richer. The time-step level D gives the G more granular feedback about which time steps were generated well, which were not, and how the sequence should evolve. Such credit assignment also happens with the sequence-level D, but it is more explicit here, since the D is specifically trained to pick up on time-step level discrepancies. This translated into better generation quality in our experiments, as the time-step level RNN generated more coherent sequences than its CNN counterpart.
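The contrast between a vanishing and a bounded gradient can be sketched with the same toy hop-counting idea. Assumed wiring (a simplification, not the exact graphs of Figs 5 and 7): a loss attached to the D's output at step t reaches the G's first step after t + 1 hops, each multiplying by ½. A sequence-level D has a loss only at the final step; a time-step level D has one at every step, and their contributions add. Under this assumption the finite-T values differ from the figures, but the limits match the behaviour described above: 0 for the sequence-level D versus 0.5 for the time-step level D. The function name `grad_to_first_step` is hypothetical.

```python
def grad_to_first_step(T, local_grad=0.5, per_step_losses=False):
    """Toy single-path estimate of the gradient reaching the G's first step.

    Assumed wiring: a loss attached to the D's output at step t travels
    1 hop into that output, t - 1 hops back through the D's recurrence and
    1 hop into the G's first generated step, i.e. t + 1 hops in total, each
    multiplying by `local_grad`.
    """
    # Sequence-level D: a loss only at the final step.
    # Time-step level D: a loss at every step, and the contributions add up.
    loss_steps = range(1, T + 1) if per_step_losses else [T]
    return sum(local_grad ** (t + 1) for t in loss_steps)

for T in (4, 16, 64):
    seq_level = grad_to_first_step(T)
    step_level = grad_to_first_step(T, per_step_losses=True)
    print(f"T={T:>2}  sequence-level={seq_level:.3e}  time-step level={step_level:.4f}")
```

The time-step level sum is a geometric series, 1/4 + 1/8 + ... , which approaches 0.5 as T grows, while the single sequence-level term halves with every extra step.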

Conclusion

In this article, we saw how to track training progress and objectively assess generation quality for spatio-temporal data. These metrics helped us significantly during model tuning, and also helped us detect common GAN training pitfalls and take corrective action that vastly improved generation quality. We also discussed solutions to some pitfalls specific to the spatio-temporal use case. These methods worked well in our experiments on generating multivariate time series with mixed data types. Note that the recommendations presented in this series are guidelines rather than hard rules, but we feel these learnings can be a great starting point for any practitioner! We are certain there are many other aspects to the complex but interesting process of training GANs, all the more so in the under-explored space of spatio-temporal data. So please let us know if you have faced other challenges and how you went about solving them! With that, I conclude this series. I hope you had as much fun reading it as I had putting it together!

References

Augustus Odena, Christopher Olah, Jonathon Shlens. Conditional Image Synthesis with Auxiliary Classifier GANs. 2016.

Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio. Generative Adversarial Nets. Advances in Neural Information Processing Systems, 2014.

Mario Lucic, Karol Kurach, Marcin Michalski, Sylvain Gelly, Olivier Bousquet. Are GANs Created Equal? A Large-Scale Study. Advances in Neural Information Processing Systems, 2018.

Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen. Improved Techniques for Training GANs. Advances in Neural Information Processing Systems, 2016.

Yoon Kim. Convolutional Neural Networks for Sentence Classification. Proceedings of the Conference on Empirical Methods in Natural Language Processing (EMNLP), 2014.

About the author: Shantanu is an AI Research Scientist at the AI Center of Excellence lab at ZS. He did his Bachelor’s in Computer Science Engineering and his Master’s in Artificial Intelligence (cum laude) from the University of Amsterdam, with a thesis at the intersection of geometric deep learning and NLP in collaboration with Facebook AI, London, and King’s College London. His research areas include Graph Neural Networks (GNNs), NLP, multi-modal AI, deep generative models, and meta-learning.
