Generate Images Using Variational Autoencoder (VAE)

DiShi Zhu
5 min read · Apr 19, 2020


Co-author: Vera Tang

In this post, we introduce the variational autoencoder (VAE) and use it to generate new images of handwritten digits, using MNIST as training data. The VAE is a generative model that can generate fictional data by capturing the characteristics of the training data. It is built on top of neural networks, but unlike a CNN or other discriminative models (logistic regression, SVM, etc.), which learn parameters to find a boundary that separates the different classes, the VAE learns parameters that model the distribution of the data points.

Autoencoder

In order to understand variational autoencoders, it is important to first understand what an autoencoder is. An autoencoder is a very simple neural network structure that consists of two parts: the encoder and the decoder. The encoder learns to compress the original input into a much smaller vector space, and the decoder learns to reconstruct the original input from that compressed form, with some loss.

[Fig. 1]
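As a concrete illustration, here is a minimal sketch of a plain autoencoder, assuming TensorFlow/Keras and flattened 28×28 MNIST images (the 32-dimensional code size and the variable names are illustrative assumptions, not from the original post):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Encoder: compress a 784-pixel image into a 32-dimensional code.
inputs = tf.keras.Input(shape=(784,))
code = layers.Dense(32, activation="relu")(inputs)

# Decoder: reconstruct the 784 pixels from the code (with some loss).
outputs = layers.Dense(784, activation="sigmoid")(code)

autoencoder = tf.keras.Model(inputs, outputs)
autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
# Trained to reproduce its own input: autoencoder.fit(x_train, x_train, ...)
```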

VAE

We observed that traditional autoencoders learn to compress and reconstruct data, but they do not really help with generating new data. This is where the variational autoencoder (VAE) comes in handy. A VAE learns the distribution of the data instead of just a compressed representation, and by sampling from that distribution we can decode and generate new data. The encoder learns parameters φ to compress a data input x into a latent vector z; the encoding z is drawn from a Gaussian density whose parameters are produced by the encoder with weights φ. The decoder takes the encoding z as input and parametrizes the reconstruction x′ with parameters θ, so that the output x′ is drawn from the distribution of the data.

[Fig. 2]

Build Encoder

We used two dense layers to construct the encoder.
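A minimal sketch of such an encoder, assuming TensorFlow/Keras (the framework, the 256-unit hidden layer, and the variable names are our assumptions; the 2-dimensional latent vector matches the vectors of size two we feed the decoder later):

```python
import tensorflow as tf
from tensorflow.keras import layers

original_dim = 784   # flattened 28x28 MNIST image
latent_dim = 2       # size of the latent vector z

# A hidden dense layer, followed by dense heads that output the
# parameters (mean and log-variance) of the Gaussian q_phi(z|x).
encoder_inputs = tf.keras.Input(shape=(original_dim,))
h = layers.Dense(256, activation="relu")(encoder_inputs)
z_mean = layers.Dense(latent_dim)(h)
z_log_var = layers.Dense(latent_dim)(h)

encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var], name="encoder")
```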

The Encoder graph looks like the following:

Encoder’s graph

Build Decoder

We use the same layers as the encoder to reconstruct the images.
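Continuing the sketch above, the decoder mirrors the encoder's dense layers, mapping a latent vector back to 784 pixel probabilities:

```python
# Decoder: latent vector z -> reconstructed image x'.
latent_inputs = tf.keras.Input(shape=(latent_dim,))
h_dec = layers.Dense(256, activation="relu")(latent_inputs)
decoder_outputs = layers.Dense(original_dim, activation="sigmoid")(h_dec)

decoder = tf.keras.Model(latent_inputs, decoder_outputs, name="decoder")
```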

Decoder’s graph

Instantiate VAE
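Wiring the two sketches together might look like the following (the `Sampling` layer, which draws z from the encoder's mean and log-variance, is defined in the Sampling Trick section below):

```python
# End-to-end VAE: x -> (z_mean, z_log_var) -> z -> x'.
z = Sampling()([z_mean, z_log_var])   # reparameterized draw from q_phi(z|x)
vae_outputs = decoder(z)
vae = tf.keras.Model(encoder_inputs, vae_outputs, name="vae")
```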

VAE Graph

Loss function

The most important question is how to define the loss function: what metric do we use to evaluate and improve the model? The loss function consists of two terms.

The first term is the negative log-likelihood of the decoder, which measures, for each data point i, how effectively the latent vector z is reconstructed into x′. In the code, we compare the reconstruction x′ against the original input x and use the built-in binary cross-entropy loss function.

The second term is the KL divergence, a regularizer that measures the information lost when we use the encoder q_φ to produce z; in the VAE, the prior p(z) is a standard Gaussian distribution N(0, 1). This encourages the encoder to produce encodings z that stay close to a Gaussian, which keeps the z representations of the different kinds of data (in our case, the digits of the MNIST dataset) sufficiently diverse yet close together in the latent space. Otherwise, the VAE could cheat and map each data point to its own region of space without actually learning, for example, that all data points representing digit 2 should have a much smaller Euclidean distance to each other than to data points of digit 1 or digit 4.
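Putting the two terms together, the loss for a single data point x_i is l_i(θ, φ) = −E_{z∼qφ(z|x_i)}[log pθ(x_i|z)] + KL(qφ(z|x_i) ‖ p(z)). A sketch of how this might be implemented, continuing the Keras sketches above (the add_loss wiring follows the common Keras VAE examples of the time; the rescaling by original_dim is our assumption):

```python
# Reconstruction term: binary cross-entropy between input x and
# reconstruction x'. Keras averages over the pixel axis, so rescale
# to a sum over all 784 pixels.
reconstruction_loss = original_dim * tf.keras.losses.binary_crossentropy(
    encoder_inputs, vae_outputs)

# KL term: KL(q_phi(z|x) || N(0, I)) in closed form for a diagonal Gaussian.
kl_loss = -0.5 * tf.reduce_sum(
    1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)

vae.add_loss(tf.reduce_mean(reconstruction_loss + kl_loss))
vae.compile(optimizer="adam")            # loss already attached via add_loss
# vae.fit(x_train, epochs=..., batch_size=...)
```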

Loss function implementation

Sampling Trick

The other problem is how to train the VAE. Like most neural networks, we use gradient descent and back-propagation to minimize the loss function. However, given that z is drawn from qφ(z|x), how can we calculate the derivative of (a function of) z with respect to φ? Since qφ(z|x) is a Gaussian distribution, z can be re-parameterized as z = µ + σ ⋅ ϵ, where ϵ is a number randomly sampled from Normal(0, 1). The randomness is moved into ϵ, so the derivatives of z with respect to µ and σ can be calculated.
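A sketch of the trick as a custom Keras layer, continuing the sketches above (the layer name is ours):

```python
# Reparameterization: z = mu + sigma * epsilon, epsilon ~ Normal(0, 1).
# The randomness lives in epsilon, so gradients can flow through
# z_mean and z_log_var (note sigma = exp(0.5 * log_var)).
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
```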

Implementation of sampling trick

Generating Handwritten Images

After we have trained the VAE, we use only the decoder to generate images, by feeding it a latent vector of size two. We generate an interpolation graph by sweeping the two latent coordinates over values from -3 to 4, generating 196 (14 × 14) of those images.
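A sketch of the generation loop, continuing the sketches above (the 14 × 14 grid is inferred from the 196 images; matplotlib is our choice for plotting):

```python
import numpy as np
import matplotlib.pyplot as plt

# Sweep both latent coordinates from -3 to 4 on a 14 x 14 grid,
# decoding 196 latent vectors of size two into 28 x 28 images.
n = 14
grid = np.linspace(-3, 4, n)
figure = np.zeros((28 * n, 28 * n))

for i, yi in enumerate(grid):
    for j, xi in enumerate(grid):
        z_sample = np.array([[xi, yi]])        # a latent vector of size two
        x_decoded = decoder.predict(z_sample)  # only the decoder is used
        digit = x_decoded[0].reshape(28, 28)
        figure[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="gray")
plt.show()
```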

The resulting interpolation graph is:

interpolation graph

Pros and Cons

As we can see from the interpolation result of the generator (decoder), the digits 4 and 9 are very close in Euclidean distance, and it is genuinely easy for a person's handwriting to be indistinguishable between the two. However, we can also see that the resulting images are not sharp; we could usually get a better result with a Generative Adversarial Network (GAN), because a GAN does not learn from a user-defined loss function but from a loss learned from the data, with a classifier trying to distinguish generated data from original data. This makes the loss much more expressive, so a GAN generates sharper images, but it is also much more difficult to train. In addition, although a GAN can generate better results, it has no encoder, so it never learns to compress an image the way a VAE does.

References

[Fig. 1] https://blog.goodaudience.com/using-tensorflow-autoencoders-with-music-f871a76122ba

[Fig. 2] https://rpubs.com/zkajdan/533047
