A must-have training trick for VAEs (variational autoencoders)
The trick is called the Cyclical KL Annealing Schedule, described in a paper by Duke University and Microsoft Research, Redmond.
The VAE is a powerful deep generative model commonly used in NLP tasks. Although the VAE itself is not the focus of this article, a brief introduction helps in understanding the trick. When training a VAE, we use the training data itself as the target and compress the data into a low-dimensional space. A VAE has two parts: an encoder and a decoder. The encoder maps the data into a low-dimensional latent space, while the decoder reconstructs the original data from that latent representation. This forces the model to express the training data compactly and to group similar data points together in the latent space; how similarity between neighboring points is defined depends on the design of your objective function.
A good representation matters a great deal in NLP, because a word or sentence is usually described in very high dimensions and is therefore hard to model. With a meaningful, compact representation, we can do better on downstream tasks. For instance, drug-discovery researchers used to search a chemical space of roughly 10⁶⁰ molecules for a new candidate; with a learned latent space, they can narrow the search and navigate within the desired region.
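To make the setup concrete, here is a minimal sketch of such an encoder/decoder pair, assuming PyTorch and a plain feed-forward architecture. The class name `SimpleVAE` and all dimensions are placeholders for illustration, not the model from the paper.

```python
import torch
import torch.nn as nn

class SimpleVAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.to_mu = nn.Linear(hidden_dim, latent_dim)       # mean of q(z|x)
        self.to_logvar = nn.Linear(hidden_dim, latent_dim)   # log-variance of q(z|x)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.to_mu(h), self.to_logvar(h)
        # Reparameterization trick: sample z while keeping the graph differentiable.
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return self.decoder(z), mu, logvar
```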
All training tricks for VAEs serve one purpose: to better represent the original data with minimal loss of information.
The objective function has two components: the reconstruction loss and the Kullback–Leibler divergence term (KL loss). The former indicates how well the VAE can reconstruct the input sequence from the latent space, while the latter measures how close two distributions are to each other. In real-world applications, we usually compute the KL divergence between the latent distribution and the standard normal distribution. As we can see in the function below, the coefficient beta controls how much the KL divergence weighs in the total loss.
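Here is a sketch of such a loss function, assuming a PyTorch model that returns the reconstruction together with the mean and log-variance of the latent Gaussian (as in the sketch above). MSE stands in for the reconstruction term; in an NLP setting it would typically be a cross-entropy over tokens.

```python
import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, logvar, beta=1.0):
    # Reconstruction loss: how well the decoder rebuilds the input.
    recon = F.mse_loss(x_hat, x, reduction="sum")
    # KL loss: divergence between q(z|x) = N(mu, sigma^2) and the standard normal N(0, I).
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # beta controls how much the KL term weighs in the total loss.
    return recon + beta * kl
```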
Here we run into the notorious KL-vanishing problem: the KL term becomes vanishingly small (close to zero) during training. It typically happens when the decoder works in an auto-regressive fashion, and it leads to a less interesting, if not meaningless, representation in which the latent distribution simply matches the standard normal prior and carries little information about the input.
One way to alleviate the KL-vanishing issue is to apply an annealing schedule to the KL term.
The traditional approach is monotonic annealing. By starting with a small coefficient beta, we force the model to focus on reconstructing the input sequences rather than on minimizing the KL loss. As beta increases, the model gradually pays more attention to the shape of the latent distribution, until beta eventually reaches one or a pre-specified value.
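A minimal sketch of a linear monotonic schedule; the warm-up over the first half of training is an arbitrary illustrative choice.

```python
def monotonic_beta(step, total_steps, beta_max=1.0, warmup_fraction=0.5):
    # Ramp beta linearly from 0 to beta_max over the first `warmup_fraction`
    # of training, then hold it constant for the remaining steps.
    warmup_steps = max(1, int(total_steps * warmup_fraction))
    return min(beta_max, beta_max * step / warmup_steps)
```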
Researchers from Duke and Microsoft Research proposed the cyclical annealing schedule, in which the traditional monotonic annealing is repeated multiple times over the course of training. This schedule helps build a better-organized latent space at a trivial additional computational cost; the paper includes an official illustration of both schedules.
The paper considers three candidates for how to increase the weight of the KL term within each cycle: linear, sigmoid, and cosine; a sketch of all three is given below. Empirically, I found them to behave roughly the same during training, and the performance on the downstream task improved by double-digit percentages. The official GitHub repo of the paper is here.
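Here is a sketch of a cyclical schedule with the three ramp shapes. It follows the spirit of the paper's schedule, but the exact implementation in the official repo may differ; the sigmoid steepness and the default arguments are my own illustrative choices.

```python
import math

def cyclical_beta(step, total_steps, n_cycles=4, ratio=0.5, mode="linear"):
    # Split training into n_cycles equal cycles. Within each cycle, ramp beta
    # from 0 to 1 during the first `ratio` fraction, then hold it at 1.
    cycle_len = total_steps / n_cycles
    t = (step % cycle_len) / cycle_len    # position within the current cycle, in [0, 1)
    tau = min(t / ratio, 1.0)             # ramp progress, clipped at 1
    if mode == "linear":
        return tau
    if mode == "sigmoid":
        # Steepness of 12 is arbitrary; it keeps the ramp close to 0 and 1 at the ends.
        return 1.0 / (1.0 + math.exp(-12.0 * (tau - 0.5)))
    if mode == "cosine":
        return 0.5 * (1.0 - math.cos(math.pi * tau))
    raise ValueError(f"unknown mode: {mode}")
```

Inside the training loop, you would compute something like `beta = cyclical_beta(step, total_steps)` at every step and pass it to the loss function sketched earlier.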
Reference:
Fu et al., Cyclical Annealing Schedule: A Simple Approach to Mitigating KL Vanishing, NAACL 2019. https://aclanthology.org/N19-1021.pdf
Extra thoughts
Try thinking of the ELBO as a two-objective optimization problem. Then we can apply this trick to any multi-objective optimization problem, especially when one component of the loss function causes the model to bypass all the other loss components.