A Basic Variational Autoencoder in PyTorch Trained on the CelebA Dataset

Moshe Sipper, Ph.D.
Published in The Generator · 5 min read · Oct 31, 2023

Pretty much from scratch, fairly small, and quite pleasant (if I do say so myself)…

AI-generated image (craiyon)

I recently found myself in need of a way to encode images into latent embeddings, tweak the embeddings, and then generate new images from them. There are powerful methods for creating embeddings, and others for generating images from embeddings. If you want to do both, a natural and fairly simple approach is to use a variational autoencoder.

Such a deep network does both encoding and decoding, and it is also fairly simple, so I could use it down the road in my research without worrying too much about hidden complexities in the encoding-decoding phase. I also wanted as much control as possible over the software's innards.

So, with all these specs in mind, I gathered bits and pieces from GitHub, sprinkled some of my own magic, and ended up with a nice, simple variational autoencoder. I’ll describe the main pieces below, with the full package available at:

It’s rather smallish and completely self-contained — which was my intention!

Autoencoders

To keep this article short and palatable I’ll refrain from providing a lengthy overview of variational autoencoders. Besides, you’ll find excellent articles on the basics right here on Medium. I’ll just provide three quick pics.

This is what a basic autoencoder looks like:

Source: https://commons.wikimedia.org/wiki/File:Autoencoder_schema.png

In a nutshell, the network compresses the input data into a latent vector (also called an embedding), and then decompresses it back. These two phases are known as encode and decode.

A variational autoencoder (VAE) looks very similar, except for the embedding part in the middle. Instead of a vector in latent space, the encoder of a VAE outputs parameters of a predefined distribution in the latent space, for every input:

Source: https://commons.wikimedia.org/wiki/File:Reparameterized_Variational_Autoencoder.png
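In code, sampling from that distribution is usually done with the reparameterization trick shown in the figure: the encoder predicts a mean and a log-variance, and the sample is written as the mean plus scaled standard-normal noise, so that gradients can flow back to both parameters. Here is a minimal sketch, assuming a diagonal Gaussian parameterized by `mu` and `log_var` (the convention used in PyTorch-VAE); the function name is illustrative.

```python
import torch

def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Sample z ~ N(mu, sigma^2) while keeping gradients w.r.t. mu and log_var."""
    std = torch.exp(0.5 * log_var)   # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)      # standard-normal noise, same shape as std
    return mu + eps * std            # z = mu + sigma * eps

# A batch of 4 inputs, each encoded as a 128-dimensional Gaussian
mu = torch.zeros(4, 128, requires_grad=True)
log_var = torch.zeros(4, 128, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()                   # gradients reach both distribution parameters
print(z.shape)                       # torch.Size([4, 128])
```

Because the randomness lives entirely in `eps`, the sampling step is differentiable with respect to `mu` and `log_var`, which is what lets the encoder be trained by backpropagation.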

One final pic: If we’re dealing with inputs that are images, we’ll want a convolutional VAE, something like this:

Source: https://github.com/arthurmeyer/Saliency_Detection_Convolutional_Autoencoder

Note #1: Observe how the encoder part adds more and more filters in each layer, with the images getting smaller and smaller; the reverse happens with the decoder.

Note #2: Be careful with terminology. If there is one input channel, the terms filter and kernel mean basically the same thing. With more than one channel, each filter is a collection of kernels, one per input channel. Check out this great Medium article: “Intuitively Understanding Convolutions for Deep Learning”.
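PyTorch makes the distinction visible in a convolution layer's weight tensor, whose shape is (out_channels, in_channels, kernel height, kernel width). The layer sizes below are only for illustration, chosen to match the first encoder layer used later:

```python
import torch.nn as nn

# 3 input channels (RGB) -> 32 output channels, with 3x3 kernels
conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)

# 32 filters, each a stack of 3 kernels of size 3x3 (one kernel per input channel)
print(conv.weight.shape)      # torch.Size([32, 3, 3, 3])

one_filter = conv.weight[0]
print(one_filter.shape)       # torch.Size([3, 3, 3])

# With a single input channel, each filter holds exactly one kernel
gray = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
print(gray.weight[0].shape)   # torch.Size([1, 3, 3])
```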

CelebA

The dataset I’ll be working with is CelebA, which contains 202,599 images of celebrity faces.

It can be accessed through torchvision:

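from torchvision import transforms
from torchvision.datasets import CelebA

path = './data/'  # root folder that contains the celeba/ directory
# Assumed preprocessing: resize to IMAGE_SIZE x IMAGE_SIZE (150) and convert to tensors
transform = transforms.Compose([transforms.Resize((150, 150)), transforms.ToTensor()])
train_dataset = CelebA(path, split='train', transform=transform)
test_dataset = CelebA(path, split='valid', transform=transform)  # or 'test'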

VAE class

My VAE is based on this PyTorch example and on the vanilla VAE model of the PyTorch-VAE repo (it shouldn’t be too hard to replace the vanilla VAE I’m using with any of the other models in PyTorch-VAE).

The file vae.py contains the VAE class along with definitions of the image size, the dimension of the latent space (i.e., the length of both the mean vector and the variance vector), and the path to the dataset:

CELEB_PATH = './data/'
IMAGE_SIZE = 150
LATENT_DIM = 128
image_dim = 3 * IMAGE_SIZE * IMAGE_SIZE

In the VAE class I used the following numbers of filters for the hidden (convolutional) layers:

hidden_dims = [32, 64, 128, 256, 512]

The encoder looks like this:

in_channels = 3  # RGB input
modules = []
for h_dim in hidden_dims:
    # Each block roughly halves the spatial size (stride 2) and uses h_dim filters
    modules.append(
        nn.Sequential(
            nn.Conv2d(in_channels, out_channels=h_dim,
                      kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(h_dim),
            nn.LeakyReLU())
    )
    in_channels = h_dim
self.encoder = nn.Sequential(*modules)

Then come the two linear layers that produce the latent vectors (mean and variance) from the flattened encoder output:

# self.size: side length of the encoder's last feature map (5 for 150 x 150 inputs)
self.fc_mu = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)
self.fc_var = nn.Linear(hidden_dims[-1] * self.size * self.size, LATENT_DIM)
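The value of `self.size` is not shown above, so here is a standalone sanity check (outside any class, on a dummy batch of random images) that rebuilds the encoder and the two linear heads for 150×150 inputs. It assumes `self.size` is the side length of the encoder's last feature map, and it names the second head's output `log_var` following the PyTorch-VAE convention; both are assumptions about vae.py.

```python
import torch
import torch.nn as nn

IMAGE_SIZE = 150
LATENT_DIM = 128
hidden_dims = [32, 64, 128, 256, 512]

# Rebuild the encoder as above to inspect the shapes
in_channels = 3
modules = []
for h_dim in hidden_dims:
    modules.append(nn.Sequential(
        nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(h_dim),
        nn.LeakyReLU()))
    in_channels = h_dim
encoder = nn.Sequential(*modules)

# Each stride-2 conv maps n -> floor((n + 2*1 - 3) / 2) + 1 = floor((n - 1) / 2) + 1
size = IMAGE_SIZE
for _ in hidden_dims:
    size = (size - 1) // 2 + 1
print(size)  # 5  (150 -> 75 -> 38 -> 19 -> 10 -> 5)

x = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)   # dummy batch of 2 images
features = encoder(x)
print(features.shape)  # torch.Size([2, 512, 5, 5])

fc_mu = nn.Linear(hidden_dims[-1] * size * size, LATENT_DIM)
fc_var = nn.Linear(hidden_dims[-1] * size * size, LATENT_DIM)
flat = torch.flatten(features, start_dim=1)     # (2, 512 * 5 * 5) = (2, 12800)
mu, log_var = fc_mu(flat), fc_var(flat)
print(mu.shape, log_var.shape)  # torch.Size([2, 128]) torch.Size([2, 128])
```

So each linear head maps 12,800 features to a 128-dimensional vector.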

And finally we go “backwards” with the decoder:

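modules = []  # start a fresh list; otherwise the encoder blocks would be included again
hidden_dims.reverse()  # [512, 256, 128, 64, 32]

for i in range(len(hidden_dims) - 1):
    # Each block doubles the spatial size and reduces the number of filters
    modules.append(
        nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[i],
                               hidden_dims[i + 1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[i + 1]),
            nn.LeakyReLU())
    )

self.decoder = nn.Sequential(*modules)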

That’s the gist of it — there are a few more bits and pieces in vae.py to complete the VAE class.
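To see what the decoder blocks do to the shapes, this sketch feeds random latent vectors through them. The `decoder_input` linear layer, which turns a latent vector into a 512 × 5 × 5 feature map, is a hypothetical stand-in for one of those bits and pieces; it is not copied from vae.py.

```python
import torch
import torch.nn as nn

LATENT_DIM = 128
hidden_dims = [512, 256, 128, 64, 32]   # already reversed, as in the decoder code
size = 5                                # side of the encoder's last feature map (150x150 inputs)

# Hypothetical projection from a latent vector back to a 512 x 5 x 5 feature map
decoder_input = nn.Linear(LATENT_DIM, hidden_dims[0] * size * size)

modules = []
for i in range(len(hidden_dims) - 1):
    modules.append(nn.Sequential(
        nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3,
                           stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(hidden_dims[i + 1]),
        nn.LeakyReLU()))
decoder = nn.Sequential(*modules)

z = torch.randn(2, LATENT_DIM)
h = decoder_input(z).view(-1, hidden_dims[0], size, size)
out = decoder(h)
# Output side: (n - 1)*2 - 2*1 + (3 - 1) + 1 + 1 = 2n, so 5 -> 10 -> 20 -> 40 -> 80
print(out.shape)  # torch.Size([2, 32, 80, 80])
```

The blocks shown therefore end at a 32-channel 80 × 80 map; presumably the remaining layers of vae.py take this to a 3-channel image at the input resolution, so that the MSE loss can compare it with the original.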

Training

The file trainvae.py contains the code to train the VAE we’ve just coded. Honestly, nothing fancy… There are three main functions: train (which also outputs the loss value as training proceeds), test (which also builds a small sample of reconstructed images), and loss_function. Training and testing are fairly run-of-the-mill, and the loss function is the standard VAE loss, with a reconstruction component (MSE) and a KL divergence component.
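A minimal sketch of such a loss function follows. It assumes the encoder outputs a log-variance (as in PyTorch-VAE) and uses a sum-over-pixels, mean-over-batch reduction. The reduction and the `kld_weight` parameter are my assumptions, not necessarily what trainvae.py does. The KL term is the closed form for a diagonal Gaussian measured against a standard-normal prior.

```python
import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, log_var, kld_weight=1.0):
    """VAE loss: MSE reconstruction term plus KL divergence to N(0, I).

    kld_weight is a hypothetical knob; 1.0 gives the plain, unweighted sum.
    """
    # Reconstruction term: squared error summed over pixels, averaged over the batch
    recon = F.mse_loss(recon_x, x, reduction='sum') / x.size(0)
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
    kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
    return recon + kld_weight * kld, recon, kld
```

The KL term is zero exactly when every latent distribution equals the prior (`mu = 0`, `log_var = 0`). It pulls the encodings toward N(0, I), which is what makes decoding random latent vectors meaningful later on.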

The main loop over epochs performs four operations: 1) train, 2) test, 3) generate random latent vectors and call decode to produce the corresponding images, and 4) save the epoch’s model to a .pth file.

Here are the outputs of a sample run. With 20 training epochs you end up with 20 reconstructed-image files, 20 latent-sampling files, and 20 model (.pth) files:

Here’s reconstruction_20.png, with the top row showing 8 original pics and the bottom row showing their respective reconstructions by the trained VAE.

Reconstructed (output) images from the model at epoch 20.

And here’s sample_20.png, which shows 64 images generated from random latent vectors:

Just for fun I added a small bit of code — genpics.py — which picks a random image from the dataset and generates 7 reconstructions. Here are some examples (leftmost image is original):
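Why does one image yield 7 different reconstructions? A VAE samples its latent vector from the predicted distribution, so each sample decodes to a slightly different face. One plausible way to do this is sketched below. It assumes `model.encode(x)` returns `(mu, log_var)` and `model.decode(z)` returns images; whether genpics.py works exactly like this is an assumption, and the function name is hypothetical.

```python
import random
import torch

def multiple_reconstructions(model, dataset, n_recon=7):
    """Pick a random image and decode n_recon different latent samples of it."""
    model.eval()
    x, _ = dataset[random.randrange(len(dataset))]     # CelebA items are (image, target)
    x = x.unsqueeze(0)                                  # add batch dimension: (1, 3, H, W)
    with torch.no_grad():
        mu, log_var = model.encode(x)                   # assumed: the two encoder heads
        std = torch.exp(0.5 * log_var)
        z = mu + torch.randn(n_recon, mu.size(1)) * std  # n_recon samples from q(z|x)
        recons = model.decode(z)                        # (n_recon, 3, H, W)
    return torch.cat([x, recons])                       # original first, then reconstructions
```

Passing the returned tensor to `save_image` with `nrow=8` gives a single row with the original on the left.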
