A Crash Course on VAEs, VQ-VAEs, and VAE-GANs
A complete understanding of what I learned from training VAEs for half a year
Over the past half-year, I have been experimenting with different generative models in deep learning, or more specifically, different types of VAEs. In short, the Variational AutoEncoder is a generative model capable of reproducing images and, along the way, producing a latent vector for each image, which can be used to train other neural networks or for latent space manipulation. With the vanilla or plain VAE, I tried producing high-quality images from ImageNet, only to realize that it simply isn't capable of representing these types of images in its latent space (mostly because every output looked like a potato).
I then discovered the VQ-VAE, which solved the latent space problem I encountered with the normal VAE by making the latent space discrete instead of continuous. Adding on to this madness, we have the VAE-GAN, which combines the best of the VAE and the GAN to generate hyper-realistic images. Now, of course, we could go one step further and tack on all the acronyms to make a VQ-VAE-GAN, but that just combines the concepts of the other networks and doesn't need its own explanation.
Furthermore, the VQ-VAE can match GANs in result quality, and often generates more diverse content, so the VAE-GAN has become almost redundant at this point. Adding on to that, VAEs are much easier to train, and so in the future, they might be used almost exclusively for image tasks such as generation, latent space manipulation, and text-to-image synthesis.
The vanilla VAE is one of the most well-known generative models (apart from the GAN). It has three parts: a downsampling encoder, a latent vector usually called z, and an upsampling decoder. The encoder usually consists of convolutions, and the decoder uses either transposed convolutions or blocks that pair a plain upsampling layer with a convolution that retains the spatial dimensions produced by the upsample. The upsample-plus-convolution approach usually produces smoother results than plain transposed convolutions.
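To make the two decoder options concrete, here is a minimal PyTorch sketch of both upsampling blocks. The channel counts and kernel sizes are illustrative assumptions, not values from any particular model; both blocks double the spatial resolution.

```python
import torch
import torch.nn as nn

# Option 1: a plain transposed convolution that doubles spatial size.
transposed = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)

# Option 2: upsample first, then convolve. The kernel_size=3, padding=1
# convolution keeps the spatial dimensions produced by the upsample,
# which tends to avoid the checkerboard artifacts transposed
# convolutions can introduce.
upsample_conv = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(64, 32, kernel_size=3, padding=1),
)

x = torch.randn(1, 64, 16, 16)      # a dummy 16x16 feature map
print(transposed(x).shape)          # torch.Size([1, 32, 32, 32])
print(upsample_conv(x).shape)       # torch.Size([1, 32, 32, 32])
```

Both blocks produce the same output shape, so they are drop-in replacements for each other inside a decoder.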
In between the convolutions, we find the latent vector, the hallmark of a VAE. The latent vector is sampled from a Gaussian parameterized by the mean and standard deviation that the encoder outputs (as shown in the diagram above). This latent vector can be used to generate random images (similar to a GAN) given a random latent vector, but it is especially useful for image manipulation. In an ideal case, you can take the latent vector of a person with glasses, subtract the latent vector for glasses, decode the result, and get the same person without glasses. This latent space manipulation is one of the reasons that VAEs haven't been replaced by GANs. The latent space is continuous, meaning any point can be sampled, and the model is trained to decode every point as accurately as possible.
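Sampling from that Gaussian is usually done with the reparameterization trick, so gradients can flow back through the sampling step. This is a minimal sketch; the batch size and latent dimension are arbitrary assumptions for illustration.

```python
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps, with eps ~ N(0, I). Writing the sample this
    # way keeps the path from the encoder outputs (mu, logvar) to z
    # differentiable, since the randomness lives only in eps.
    std = torch.exp(0.5 * logvar)   # logvar = log(sigma^2), so std = sigma
    eps = torch.randn_like(std)
    return mu + eps * std

mu = torch.zeros(4, 128)       # hypothetical batch of 4, latent dim 128
logvar = torch.zeros(4, 128)   # logvar = 0 corresponds to sigma = 1
z = reparameterize(mu, logvar)
print(z.shape)                 # torch.Size([4, 128])
```

With mu = 0 and sigma = 1, z is simply a draw from the standard normal prior, which is exactly how random images are generated at inference time.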
The loss function is one of the most interesting parts of a VAE. Alongside the usual mean squared error on the pixel values of the images, there is a KL divergence loss. Put simply, the KL divergence controls the 'scaling' of the latent space by pulling the encoder's distribution toward a standard normal prior. If the latent space is too spread out, it develops empty regions that decode to inaccurate results; a similar problem emerges when the latent space is too compressed. The KL loss optimizes that scaling of the latent space, essentially keeping it continuous, and gives the VAE the variational part of its name.
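The two terms can be combined in a few lines. This sketch uses the closed-form KL divergence between the encoder's Gaussian and a standard normal prior; the beta weight is an illustrative knob, not part of the original formulation.

```python
import torch
import torch.nn.functional as F

def vae_loss(recon, target, mu, logvar, beta=1.0):
    # Reconstruction term: summed per-pixel squared error.
    mse = F.mse_loss(recon, target, reduction="sum")
    # KL divergence between N(mu, sigma^2) and N(0, I), in closed form:
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kl

# Sanity check: a perfect reconstruction with mu = 0, logvar = 0
# (i.e. the encoder already matches the prior) gives zero loss.
x = torch.randn(2, 3, 8, 8)
mu = torch.zeros(2, 16)
logvar = torch.zeros(2, 16)
loss = vae_loss(x, x, mu, logvar)
print(loss.item())  # 0.0
```

Raising beta above 1 trades reconstruction accuracy for a better-behaved latent space, which is the same scaling trade-off described above.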
Above are some images generated by the VAE. Due to the continuity of the latent space, the images have smoother edges than those of a typical GAN, which makes them less realistic. When I tried to use the VAE on ImageNet, this problem manifested even more strongly: every image started looking more and more like a potato, and it was clear the vanilla VAE just wasn't cut out for this type of job. The VAE also suffers from posterior collapse, where the KL term becomes 0. This brings us to our first modification to the VAE, which van den Oord et al. called the VQ-VAE, or Vector Quantized VAE.
Note: this is an extremely condensed intuition of what goes on in a VAE. Many people have written great articles, like this one by Joseph Rocca, which go into a lot more depth about what happens inside a VAE.
The next step in the evolution of the VAE comes in rethinking the hallmark of a vanilla VAE, the latent space. This doesn't mean removing the latent space; it means making it discrete instead of continuous, hence the Vector Quantization part of the name. The images below are a visualization of these respective latent spaces.
This shows the fundamental difference between the VAE and the VQ-VAE. The points in the latent space are discrete in the VQ-VAE but continuous in the VAE.
But wait, if we make the latent space discrete, won't we make the outputs unrealistic? That's where the nearest-embedding function comes in: it finds the nearest point in the latent space to sample from. This solves the problem of fuzzy outputs, because if the VQ-VAE samples from a point it has trained on, the result will naturally be sharper.
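The nearest-embedding lookup is just a nearest-neighbor search over the codebook of latent points. This is a toy sketch with a two-vector codebook; real models use hundreds of codebook entries and higher-dimensional embeddings.

```python
import torch

def nearest_embedding(z_e, codebook):
    # z_e: (N, D) encoder outputs; codebook: (K, D) learned embeddings.
    # torch.cdist computes all pairwise Euclidean distances, and argmin
    # snaps each encoder output to its closest codebook vector.
    distances = torch.cdist(z_e, codebook)   # (N, K)
    indices = distances.argmin(dim=1)        # (N,)
    return codebook[indices], indices

codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0]])   # toy K=2, D=2 codebook
z_e = torch.tensor([[0.1, -0.1], [0.9, 1.2]])       # two encoder outputs
z_q, idx = nearest_embedding(z_e, codebook)
print(idx)   # tensor([0, 1])
```

Whatever the encoder produces, the decoder only ever sees one of the K trained codebook vectors, which is why the outputs stay sharp.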
There is a catch, however. If the VQ-VAE only samples from points it has trained on, then it will only output images that look like images in the training set. We can solve this problem by adding extra points during training around each point in the training set. For simplicity, I will refer to these as the surrounding points. We can call the number of surrounding points k, which is usually set to 256 or 512. During training, these surrounding points are trained along with the training data, so they become as sharp and accurate as the original points.
The loss function of a VQ-VAE remains largely the same as a vanilla VAE's. However, there is no need for the KL divergence loss, because the nearest-embedding function takes over its role; and with no KL term to become zero, the VQ-VAE also avoids posterior collapse. We do, however, need a loss for the vector quantization itself, a loss that fine-tunes the values of the surrounding points in the latent space. This helps the latent space behave more continuously while still retaining the good quality of the generated images.
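The quantization loss is usually written as two mean-squared-error terms with stop-gradients, plus a straight-through estimator so gradients can cross the non-differentiable nearest-embedding step. This is a sketch of those pieces, with the commitment weight beta set to the commonly used 0.25.

```python
import torch
import torch.nn.functional as F

def vq_loss(z_e, z_q, beta=0.25):
    # Codebook loss pulls the codebook embeddings toward the encoder
    # outputs; the commitment loss (weighted by beta) keeps the encoder
    # from drifting away from the codebook. detach() plays the role of
    # the stop-gradient operator.
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    return codebook_loss + beta * commitment_loss

def straight_through(z_e, z_q):
    # Forward pass uses the quantized z_q; backward pass copies the
    # gradient straight through to the encoder output z_e.
    return z_e + (z_q - z_e).detach()

z = torch.randn(4, 8)
print(vq_loss(z, z).item())  # 0.0 when encoder output equals its codebook vector
```

Together these terms do the fine-tuning described above: the codebook vectors track the data, and the encoder commits to the codebook.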
In my experience, this type of VAE works well enough to contain the whole of ImageNet in its latent space, meaning it can reconstruct, and return a reasonably accurate latent vector for, almost any natural image. For me, this model was good enough, and it will probably be good enough for many applications where the quality of the images matters less than their content. It is also much easier to train and far more reproducible than a GAN. Surprisingly, the VQ-VAE-2 surpassed BigGAN in image quality and in several generation diversity metrics. Another way to get a similar result is by using the discriminator of a GAN.
While GANs are capable of generating hyper-realistic images that can even fool humans sometimes (see which face is real), you still might need some parts of a VAE, like the latent space, to generate images with the intended characteristics. To do this, we simply attach the discriminator of a GAN after the VAE's decoder and add a discriminator loss component to the loss function of the VAE, which helps the images become sharper and more realistic. In my experience, however, this is usually harder to train than a normal VQ-VAE, and the VQ-VAE is usually better for most applications.
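The combined generator-side objective can be sketched as a reconstruction term plus an adversarial term. The adversarial weight here is a hypothetical hyperparameter, and the discriminator itself is assumed to exist elsewhere and output raw logits.

```python
import torch
import torch.nn.functional as F

def vae_gan_generator_loss(recon, target, disc_logits, adv_weight=0.1):
    # The usual VAE reconstruction term...
    recon_loss = F.mse_loss(recon, target)
    # ...plus an adversarial term that rewards reconstructions the
    # discriminator classifies as real (label 1). The discriminator is
    # trained separately with the opposite labels, as in a standard GAN.
    adv_loss = F.binary_cross_entropy_with_logits(
        disc_logits, torch.ones_like(disc_logits)
    )
    return recon_loss + adv_weight * adv_loss

recon = target = torch.randn(2, 3, 8, 8)   # perfect reconstruction
logits = torch.zeros(2, 1)                 # an undecided discriminator
loss = vae_gan_generator_loss(recon, target, logits)
```

The KL (or vector quantization) term from the underlying VAE would be added on top of this; it is omitted here to keep the sketch focused on the discriminator component.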
When GANs started picking up steam, the VAE was left in the dust and regarded as an inferior model because of its lower-quality images. Then, when the VQ-VAE was designed, that changed: the VQ-VAE produced images on par with GANs and had many advantages that GANs didn't have. Now, researchers can finely control the content of generated images without compromising their quality. Hopefully, you now have a comprehensive understanding of the different types of VAEs and some basic understanding of what happens inside the generative model. Thank you so much for reading, and have a good one.