The auto-encoder is a class of models that learns hidden states from training examples and reconstructs examples from those hidden states. The Variational Auto-Encoder (VAE) (paper: https://arxiv.org/abs/1312.6114) is a generative variant: the encoder outputs means and variances of the latent features, resampled features are drawn from a normal distribution with those parameters, and the resampled features are used as hidden states to generate new examples. The process is easy to follow in the PyTorch example (https://github.com/pytorch/examples/blob/master/vae/main.py).
For the MNIST dataset, the encoder works as follows:
image -> flatten -> Linear(784, 400) -> ReLU -> features
features -> Linear(400, 20) -> mean
features -> Linear(400, 20) -> log(Variance)
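The encoder steps above can be sketched as a small PyTorch module (the layer names fc1/fc21/fc22 follow the linked example):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Maps a flattened 28x28 MNIST image to the mean and log(Variance)
    of a 20-dimensional Gaussian over the hidden states."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 400)   # image -> features
        self.fc21 = nn.Linear(400, 20)   # features -> mean
        self.fc22 = nn.Linear(400, 20)   # features -> log(Variance)

    def forward(self, x):
        h = F.relu(self.fc1(x.view(-1, 784)))  # flatten, then ReLU features
        return self.fc21(h), self.fc22(h)
```

Feeding a batch of images returns two (batch, 20) tensors, one for the means and one for the log-variances.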
The reparameterization works as follows:
hidden_states = mean + sqrt(Variance) * Normal(0, 1)
The decoder works as follows:
hidden_states -> Linear(20, 400) -> ReLU -> Linear(400, 784) -> Sigmoid -> image
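These two steps can be sketched as follows. Sampling noise from Normal(0, 1) and scaling it by sqrt(Variance) = exp(0.5 * log(Variance)) is the reparameterization trick, which keeps the sampling step differentiable with respect to mean and log(Variance); the final Sigmoid (as in the linked example) keeps pixel values in [0, 1]:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    """Maps 20-dimensional hidden states back to a flattened 784-pixel image."""
    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def forward(self, z):
        # hidden_states -> ReLU features -> pixels in [0, 1]
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))

def reparameterize(mean, logvar):
    # std = sqrt(Variance) = exp(0.5 * log(Variance))
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)   # eps ~ Normal(0, 1)
    return mean + eps * std       # hidden_states = mean + sqrt(Variance) * eps
```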
The loss function is a combination of a reconstruction loss and a regularization term.
The reconstruction loss is just the pixel-wise binary cross-entropy between the training image and the generated image.
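In the linked example the reconstruction term is the binary cross-entropy summed over all 784 pixels and over the batch; a minimal sketch:

```python
import torch
import torch.nn.functional as F

def reconstruction_loss(recon_x, x):
    # Pixel-wise binary cross-entropy between the decoder output
    # (values in [0, 1]) and the flattened original image, summed
    # over pixels and batch as in the linked example.
    return F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
```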
The KL divergence is used as the regularization term. The VAE assumes the hidden states follow a Normal(0, 1) distribution, while the encoder produces Normal(mean, sqrt(Variance)) for the hidden states, so there is a discrepancy between the two. The KL divergence measures the difference between two distributions, and for these two Gaussians it has the closed form:
KL = -0.5 * sum(1 + log(Variance) - mean^2 - Variance)
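This closed form is straightforward to compute from the encoder outputs; a sketch matching the loss_function in the linked example:

```python
import torch

def kl_divergence(mean, logvar):
    # KL(Normal(mean, sqrt(Variance)) || Normal(0, 1)), summed over
    # hidden dimensions and batch:
    #   -0.5 * sum(1 + log(Variance) - mean^2 - Variance)
    return -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
```

When mean = 0 and log(Variance) = 0 the two distributions coincide and the KL term is exactly zero, which is why this term pulls the latent distribution toward Normal(0, 1).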
From the following reconstructed and generated images, we can see that the KL divergence loss, acting as a regularization term, helps the model generalize.