Understanding Variational Autoencoders

Rajinish Aneel Bhatia
Published in CodeX · 5 min read · Oct 8, 2021

This blog is not going to focus on the rigorous mathematics behind variational autoencoders; there are tons of sources available on the internet for that (Ali Ghodsi’s lecture is a great one). What I want to do is, hopefully, provide some useful intuition for what the variational autoencoder is and what its loss function actually means (this is the way I understand it). The videos and articles I have read mainly go into the nitty-gritty details of things without offering an approachable view.

The goal of variational autoencoders is to be generative: after we are done training the model, we want to sample from it things that are vaguely similar to what we trained the model on. Think of what a regular autoencoder does: say we have an image with 256 dimensions; first we encode it into some lower-dimensional space, say, a 2-dimensional one. Then the decoder decodes this lower-dimensional representation of the image. Now, I want you to imagine this process for the MNIST data: maybe the encoder encodes different digits in different areas of the 2-dimensional space. Imagine a different region for each digit in 2d space. The decoder then learns to decode these lower-dimensional representations of the images. So, for instance, say the digit 6 comes in. The encoder encodes it as a point somewhere in the region for the digit 6; the decoder then decodes that point. Now, if only we knew where each digit's region was, we could sample a random 2d vector from one of those regions and pass it to the decoder to get a new image. Maybe we could sample a 2d vector from the region of digit 9 or something. The thing is, the 2d space is vast and choosing a random 2d vector is not going to cut it, so we need to somehow force the encoder to place its representations in a confined space of some kind from which we can sample. This is where the cleverness of variational autoencoders comes in.
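The encode-then-decode pipeline described above can be sketched in a few lines of NumPy. This is only a shape-level illustration, not a trained model: the weights, the 256-dimensional input, and the tanh nonlinearity are all assumptions made for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

x = rng.standard_normal(256)           # a flattened 16x16 "image" (hypothetical input)
W_enc = rng.standard_normal((2, 256))  # encoder weights (untrained, shapes only)
W_dec = rng.standard_normal((256, 2))  # decoder weights (untrained)

z = np.tanh(W_enc @ x)   # 2-d latent representation of the image
x_hat = W_dec @ z        # reconstruction; training would push x_hat toward x

print(z.shape, x_hat.shape)  # (2,) (256,)
```

A real autoencoder would of course use several layers and train the weights to minimize the reconstruction error; the point here is only the dimensionality bottleneck: 256 in, 2 in the middle, 256 out.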

Here’s how we assume the data (MNIST) has been generated: there are different regions (probability distributions) for the digits in 2d, and with some probability we pick the region of a certain digit. Once a region is chosen, we assume it has a probability distribution of its own, say a normal distribution, from which a sample is generated. That sample is then fed to the decoder to generate the image. The idea, then, is to confine these regions of different digits to some closed space so that we can sample from it.
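This assumed generative story, pick a digit's region and then sample a point from that region's Gaussian, is a form of ancestral sampling from a mixture of Gaussians. A hypothetical sketch, with made-up region centers and spread:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical 2-d "regions": one Gaussian center per digit (invented for illustration)
centers = rng.standard_normal((10, 2))

digit = rng.integers(0, 10)            # step 1: pick a digit's region
z = rng.normal(centers[digit], 0.25)   # step 2: sample a 2-d point from that region's Gaussian
# step 3: a trained decoder would then map z to an image of that digit
print(digit, z)
```

In a trained VAE we never get to see these per-digit centers directly; the whole trick, described next, is to make them all live inside one region we know how to sample from.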

Here’s what a regular autoencoder’s (one-hidden-layer) pipeline looks like: we have a training example x; it gets mapped to a hidden state z of dimension n (the lower-dimensional representation), and then the decoder decodes this hidden state. The n-dimensional space is vast and the hidden state can be anywhere in it. In other words, there is some probability distribution p(z|x) over the state z. We don’t know this probability distribution; it is created internally by the encoder. For us to sample a hidden state z in n-dimensional space, we must get hold of p(z|x). You can read about p(z|x)’s intractability, but the main idea is this: since we don’t know p(z|x), the distribution created by the model, we’re going to force the model to have a distribution we want, say, the standard Gaussian distribution (zero mean, identity covariance matrix). What does that mean? It means that if we manage to do this, then p(z|x), the underlying encoder’s distribution, will be close to a standard Gaussian, so the state z, for each class, would land in a small region around the origin (within a few standard deviations of it) in 2d, in 3d, and so on, because that is where the standard Gaussian concentrates its mass.

The question is: how do we do this? We minimize the KL divergence between the model’s distribution and the standard Gaussian distribution. We have the neural network output a mean and a variance in n-dimensional space; this mean and variance parameterize the model’s Gaussian distribution, and a random hidden state z is sampled from it. The KL divergence between this distribution and the standard Gaussian contributes to the loss. The thing is, if we do only this, the model will simply output the standard Gaussian for every class, which will trouble the decoder, since it will be getting (almost) the same hidden states z for every class. So a simple solution is to add the reconstruction error to the loss. The model would like to output zero mean and identity covariance for every class, but it can’t, because then the reconstruction error would be high; it has to find a middle ground that balances the KL divergence loss against the reconstruction loss. What it will end up doing, intuitively, is the following: the encoder will output a mean close to 0 and a covariance close to the identity matrix so as to minimize the KL divergence loss, but this mean and variance will differ across classes so as to minimize the reconstruction error. Imagine the MNIST dataset going through this pipeline. Say the hidden state z is two-dimensional; then, for the digit 6, the model outputs some mean and variance close to the standard Gaussian’s (i.e., its latents lie near the origin), but away from the means and variances of the other digits (to keep the reconstruction error low). So you can imagine that after training we will have different Gaussians carved out near the origin, one for each digit. Now, to generate digits after training, we simply sample a state z from the standard Gaussian and pass it to the decoder.
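Both pieces of this loss can be written down concretely. For a diagonal Gaussian N(mu, diag(sigma²)), the KL divergence to the standard Gaussian has a closed form, and z is usually sampled via the reparameterization trick so that gradients can flow through the sampling step. A NumPy sketch; the log-variance parameterization is the common convention in VAE implementations, not something this post's description fixes:

```python
import numpy as np

def kl_to_standard_gaussian(mu, log_var):
    # Closed form of KL( N(mu, diag(exp(log_var))) || N(0, I) ),
    # summed over the latent dimensions.
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

def reparameterize(mu, log_var, rng):
    # z = mu + sigma * eps with eps ~ N(0, I): a sample from the encoder's
    # Gaussian that stays differentiable with respect to mu and log_var.
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

rng = np.random.default_rng(0)
mu = np.zeros(2)
log_var = np.zeros(2)   # sigma = 1, i.e. exactly the standard Gaussian

print(kl_to_standard_gaussian(mu, log_var))  # 0.0: no divergence penalty
z = reparameterize(mu, log_var, rng)         # a 2-d latent sample
```

The total loss would then be this KL term plus a reconstruction term (e.g., squared error between the decoder's output and x); the KL term is exactly what pulls every class's Gaussian toward the origin.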
Since the encoder’s underlying distribution was forced to be close to the standard Gaussian, all the digits’ hidden states will lie in that small region around the origin in 2d. Sampling from the standard Gaussian will give us a state z corresponding to some digit. We have managed to confine the space of hidden states from the infinite 2d plane to a small region around the origin. Isn’t that a very clever method? The underlying mathematics is a bit subtle and non-intuitive the first time you go through it (at least this was the case for me); hopefully going at it with a bit of intuition will be helpful.
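The "confined region" claim can be checked numerically: draw many 2-d samples from the standard Gaussian and see how far they land from the origin. A quick NumPy check; the 3-standard-deviation cutoff is an arbitrary but conventional choice:

```python
import numpy as np

rng = np.random.default_rng(0)

z = rng.standard_normal((10_000, 2))  # 10,000 candidate latent states
radii = np.linalg.norm(z, axis=1)     # distance of each sample from the origin
frac_close = np.mean(radii < 3.0)     # fraction landing within 3 standard deviations

print(frac_close)  # the overwhelming majority of samples are near the origin
```

Each such z is the kind of vector we would hand to the trained decoder to generate a new digit; almost none of them land in the far reaches of the plane where the decoder has never seen a latent.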
