Variational Autoencoder with missing data
In this post, I’m going to share my experiences with the flexibility of VAEs in the presence of missing data. All of my implementations are in PyTorch and can be found in my GitHub repo.
I assume you already know what a VAE is. If not, you can find good information on the web! I found this blog and its paper informative. In simple terms (Deep Learning terminology), a VAE is a network that learns the distribution of the data with an encoder network, by fitting it to a Gaussian distribution, and generates data with a decoder by sampling from the learned distribution. “Variational” means the encoder network estimates the µ and σ parameters of the Gaussian distribution (which we call the latent variables). The main idea is that we can derive any distribution from a Normal distribution [Wiki].
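To make this concrete, here is a minimal sketch of the standard VAE training objective (reconstruction error plus KL divergence), assuming the encoder outputs mu and log_var; the function name is my own and not necessarily what the repo uses:

```python
import torch
import torch.nn.functional as F

def vae_loss(x_recon, x, mu, log_var):
    # Reconstruction term: how well the decoder reproduces the input
    recon = F.binary_cross_entropy(x_recon, x, reduction="sum")
    # KL term: pulls q(z|x) = N(mu, sigma^2) towards the N(0, I) prior
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + kld
```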
However, real-world applications almost always contain missing values. What can we do about them? I conducted some experiments to find out how robust a VAE is in handling those missing values.
TL;DR
I found that, by using Dropout, a VAE can handle missing values very well.
I will demonstrate this with two datasets: first MNIST, and then a synthetic time series.
Testing with MNIST
I trained a VAE with fully connected encoder and decoder networks, a 2-dimensional latent space (µ and σ), and a dropout rate of 0.5.
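A minimal sketch of such a network follows; the latent dimension of 2 and the dropout of 0.5 come from the experiment, while the hidden size and the exact placement of the dropout layer are my assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2, p_drop=0.5):
        super().__init__()
        self.dropout = nn.Dropout(p_drop)  # randomly zeroes inputs, mimicking missing values
        self.enc = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # µ head
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # log σ² head
        self.dec1 = nn.Linear(latent_dim, hidden_dim)
        self.dec2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.enc(self.dropout(x)))
        return self.fc_mu(h), self.fc_logvar(h)

    def decode(self, z):
        return torch.sigmoid(self.dec2(F.relu(self.dec1(z))))

    def forward(self, x):
        mu, log_var = self.encode(x)
        std = torch.exp(0.5 * log_var)
        z = mu + std * torch.randn_like(std)  # reparameterization trick
        return self.decode(z), mu, log_var
```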
Below are a generated sample and reconstructed test samples at epoch 50:
Because the latent variables are 2-dimensional, we can draw them in a chart. The following chart is a scatter plot of the test samples at epoch 50. Here we used the µ variable for the visualization.
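A minimal sketch of how such a plot could be produced, assuming the VAE sketched above and a hypothetical test_loader yielding (image, label) batches:

```python
import matplotlib.pyplot as plt
import torch

# Hypothetical names: `model` is the VAE above, `test_loader` yields (image, label) batches
model.eval()
mus, labels = [], []
with torch.no_grad():
    for x, y in test_loader:
        mu, _ = model.encode(x.view(x.size(0), -1))  # take only the µ head
        mus.append(mu)
        labels.append(y)
mus, labels = torch.cat(mus), torch.cat(labels)

plt.scatter(mus[:, 0], mus[:, 1], c=labels, cmap="tab10", s=5)
plt.colorbar(label="digit class")
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()
```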
As an interesting visualization, we can watch how these charts evolve over the course of training through an animation:
Now we want to examine whether our VAE is robust to missing values. For that, we masked the first half of each input with zeros and show that the model can still reconstruct it quite well.
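A sketch of this masking, assuming “first half” means the top half of each 28×28 MNIST image (the helper name is hypothetical):

```python
def mask_first_half(x):
    # Zero out the top half of each flattened 28x28 MNIST image
    x = x.clone().view(-1, 28, 28)
    x[:, :14, :] = 0.0
    return x.view(-1, 784)

# Reconstruct from the masked input
recon, mu, log_var = model(mask_first_half(x))
```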
Of course, some samples are impossible to reconstruct correctly because, for example, the visible halves of a 3 and an 8 can be indistinguishable:
The distribution of the data in the latent space clearly shows how much error to expect in the reconstruction:
Testing with synthetic time-series data
Now we rerun our experiments on time-series data with the same network architecture. The data is randomly generated in 9 classes with different amplitudes and frequencies: the amplitudes are drawn from [0.25, 0.5, 1] and the frequencies from [1, 2, 4].
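A sketch of how such data could be generated; the series length, random phase, and noise level are my assumptions, only the amplitude and frequency sets come from the post:

```python
import numpy as np

def make_series(n_per_class=100, length=128):
    # 3 amplitudes x 3 frequencies = 9 classes of noisy sine waves
    amplitudes = [0.25, 0.5, 1.0]
    frequencies = [1, 2, 4]
    t = np.linspace(0, 2 * np.pi, length)
    X, y = [], []
    for label, (a, f) in enumerate([(a, f) for a in amplitudes for f in frequencies]):
        for _ in range(n_per_class):
            phase = np.random.uniform(0, 2 * np.pi)
            X.append(a * np.sin(f * t + phase) + 0.05 * np.random.randn(length))
            y.append(label)
    return np.array(X), np.array(y)
```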
Now let’s test this model with missing data:
Conclusion
In this blog post, I shared my experiences working with VAEs in PyTorch and demonstrated that, with Dropout, we can handle missing values in test data quite well.
Of course, we could investigate other missing-value patterns (e.g. random masking), but masking the first half of the data was the most severe pattern.
We could also increase the proportion of missing data; I didn’t include those results here, but this parameter is easy to change in the implementation.
I showed this with two different datasets, namely MNIST and synthetic time-series data.
The encoder/decoder networks I used were vanilla fully connected networks, simple enough to demonstrate the idea and expressive enough to learn our chosen datasets.