A Simple AutoEncoder and Latent Space Visualization with PyTorch

Tingsong Ou
Dec 26, 2022 · 6 min read


I. Introduction

Playing with AutoEncoders is always fun for newcomers to deep learning, like me, thanks to their beginner-friendly logic, handy architecture (well, at least not as complicated as Transformers), visualizable latent space, and the interactivity of generating outputs. In this story, I'll build a simple AutoEncoder model from scratch, along with some methods for visualizing the hidden states to make learning a bit more fun.

Colab: https://colab.research.google.com/github/terrence-ou/DL_Playground/blob/main/%5BHC%5DAutoEncoder.ipynb

II. Preliminaries

Here we first import the necessary libraries:
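A minimal set of imports for this walkthrough might look like the following; the exact list in the notebook may differ slightly.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary
import matplotlib.pyplot as plt

# Use the GPU when available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")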

III. Dataset

We're using the FashionMNIST dataset for this task. Here is a link to the dataset on Kaggle: https://www.kaggle.com/datasets/zalando-research/fashionmnist/code. The dataset is already included in the torchvision library, so we can directly import and process it with a few lines of code.

The first step is to write a collate function that converts the images from PIL format to torch tensors and pads 2 pixels on each side:
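A sketch of such a collate function, assuming the raw dataset yields (PIL image, label) pairs:

pad_transform = transforms.Compose([
    transforms.ToTensor(),   # PIL image -> float tensor in [0, 1]
    transforms.Pad(2),       # pad 2 pixels on each side: 28x28 -> 32x32
])

def collate_fn(batch):
    # batch is a list of (PIL image, label) pairs
    images = torch.stack([pad_transform(img) for img, _ in batch])
    labels = torch.tensor([label for _, label in batch])
    return images, labels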

Then we download/load the train and validation datasets and put them into data loaders:
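Loading FashionMNIST through torchvision and wrapping it in data loaders could look like this; BATCH_SIZE = 64 is assumed to match the batch shape in the sanity check below.

BATCH_SIZE = 64

train_set = datasets.FashionMNIST(root="./data", train=True, download=True)
valid_set = datasets.FashionMNIST(root="./data", train=False, download=True)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                          shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=collate_fn)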

We can use the following code to inspect the dataset and see whether everything is on the right track:
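One possible sanity check: grab a batch, print its shape, and plot a few samples with their class names.

images, labels = next(iter(train_loader))
print(images.shape)   # expected: torch.Size([64, 1, 32, 32])

fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for ax, img, label in zip(axes, images, labels):
    ax.imshow(img.squeeze(), cmap="gray")
    ax.set_title(train_set.classes[label])
    ax.axis("off")
plt.show()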

Figure 1 Samples of the dataset

Once we see the plots above and get a batch shape of torch.Size([64, 1, 32, 32]), we are good to move on to the AutoEncoder part.

IV. Encoder-Decoder Stack


Here we will implement a mirrored encoder-decoder stack with three convolutional layers on each side, for simplicity.

# Model parameters:
LAYERS = 3
KERNELS = [3, 3, 3]
CHANNELS = [32, 64, 128]
STRIDES = [2, 2, 2]
LINEAR_DIM = 2048

Encoder
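The notebook contains the full Encoder implementation; a sketch consistent with the summary below (Conv → BatchNorm → GELU → Dropout blocks, then a linear projection to a 2-dimensional latent vector) might look like this. The dropout probability is an assumption.

LATENT_DIM = 2

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        layers = []
        in_channels = 1
        for i in range(LAYERS):
            layers += [
                nn.Conv2d(in_channels, CHANNELS[i], KERNELS[i],
                          stride=STRIDES[i], padding=1),
                nn.BatchNorm2d(CHANNELS[i]),
                nn.GELU(),
                nn.Dropout2d(0.25),   # dropout rate is an assumption
            ]
            in_channels = CHANNELS[i]
        # 128 x 4 x 4 feature map -> flat 2048 vector -> 2D latent code
        layers += [nn.Flatten(), nn.Linear(LINEAR_DIM, LATENT_DIM)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)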

torchsummary is quite a convenient tool for checking and debugging the model's architecture; we can inspect the layers, the tensor shape at each layer, and the number of parameters of the model.
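Assuming the Encoder sketch above, the report below can be reproduced roughly with:

encoder = Encoder().to(DEVICE)
# the device argument matches wherever the model lives ("cuda" or "cpu")
summary(encoder, input_size=(1, 32, 32), device=DEVICE.type)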

----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 16, 16] 320
BatchNorm2d-2 [-1, 32, 16, 16] 64
GELU-3 [-1, 32, 16, 16] 0
Dropout2d-4 [-1, 32, 16, 16] 0
Conv2d-5 [-1, 64, 8, 8] 18,496
BatchNorm2d-6 [-1, 64, 8, 8] 128
GELU-7 [-1, 64, 8, 8] 0
Dropout2d-8 [-1, 64, 8, 8] 0
Conv2d-9 [-1, 128, 4, 4] 73,856
BatchNorm2d-10 [-1, 128, 4, 4] 256
GELU-11 [-1, 128, 4, 4] 0
Dropout2d-12 [-1, 128, 4, 4] 0
Flatten-13 [-1, 2048] 0
Linear-14 [-1, 2] 4,098
================================================================
Total params: 97,218
Trainable params: 97,218
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.45
Params size (MB): 0.37
Estimated Total Size (MB): 0.83
----------------------------------------------------------------

Decoder

The Decoder is a mirrored Encoder in our case, so it's important to verify each layer's input and output shape. We should also tweak the padding and output_padding parameters of the transposed convolutional layers so that the output (generated) image has the same dimensions as the input (original) image.
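A Decoder sketch consistent with the summary below; note the output_padding=1 needed to exactly double the spatial size at each transposed convolution, and the final 1×1 convolution mapping back to a single channel. The dropout rate and the final sigmoid are assumptions.

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(LATENT_DIM, LINEAR_DIM)
        self.conv_blocks = nn.Sequential(
            # [128, 4, 4] -> [128, 8, 8]
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.Dropout2d(0.25),
            # [128, 8, 8] -> [64, 16, 16]
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Dropout2d(0.25),
            # [64, 16, 16] -> [32, 32, 32]
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.GELU(),
            nn.Dropout2d(0.25),
            # [32, 32, 32] -> [1, 32, 32]
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def forward(self, z):
        x = self.linear(z)
        x = x.view(-1, 128, 4, 4)            # reshape the flat vector into a feature map
        return torch.sigmoid(self.conv_blocks(x))   # squash outputs to [0, 1] (assumed)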

And the summary of the Decoder architecture:

----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 1, 2048] 6,144
ConvTranspose2d-2 [-1, 128, 8, 8] 147,584
BatchNorm2d-3 [-1, 128, 8, 8] 256
GELU-4 [-1, 128, 8, 8] 0
Dropout2d-5 [-1, 128, 8, 8] 0
ConvTranspose2d-6 [-1, 64, 16, 16] 73,792
BatchNorm2d-7 [-1, 64, 16, 16] 128
GELU-8 [-1, 64, 16, 16] 0
Dropout2d-9 [-1, 64, 16, 16] 0
ConvTranspose2d-10 [-1, 32, 32, 32] 18,464
GELU-11 [-1, 32, 32, 32] 0
Dropout2d-12 [-1, 32, 32, 32] 0
Conv2d-13 [-1, 1, 32, 32] 33
================================================================
Total params: 246,401
Trainable params: 246,401
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 1.52
Params size (MB): 0.94
Estimated Total Size (MB): 2.46
----------------------------------------------------------------

AutoEncoder

Now that we have the Encoder and Decoder ready, let's put them together in an nn.Module class to build the AutoEncoder.
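A minimal wrapper could look like this:

class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        z = self.encoder(x)       # latent code, shape [batch, 2]
        return self.decoder(z)    # reconstruction, shape [batch, 1, 32, 32]

model = AutoEncoder().to(DEVICE)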

V. Visualizing the Latent Space

Before moving to the training part, let's spend some time writing a function to visualize our model's latent space, which is the dataset's low-dimensional representation. You may have already noticed, and rightly doubted, that a two-dimensional latent space is insufficient to represent the input's information. The 2D latent dimension is chosen purely to make the latent space easy to visualize; increasing the dimension is definitely a good move for better reconstruction.

We can write functions to save scatter plots of the latent space during training, or to show them after training.
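One way to do this (a sketch, assuming the model and DEVICE defined above; saving figures to disk is omitted): encode the validation set and scatter the 2D latent codes colored by class label.

@torch.no_grad()
def plot_latent_space(model, loader, epoch=None):
    model.eval()
    codes, targets = [], []
    for images, labels in loader:
        z = model.encoder(images.to(DEVICE))
        codes.append(z.cpu())
        targets.append(labels)
    codes = torch.cat(codes).numpy()
    targets = torch.cat(targets).numpy()

    plt.figure(figsize=(6, 5))
    scatter = plt.scatter(codes[:, 0], codes[:, 1], c=targets,
                          cmap="tab10", s=2, alpha=0.7)
    plt.colorbar(scatter, label="class")
    title = "Latent space" if epoch is None else f"Latent space (epoch {epoch})"
    plt.title(title)
    plt.show()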

Here's a plot generated during training; it shows how the model's latent space distribution evolves over time. Although the classes form clusters during training, several classes are mixed with others, which causes ambiguity when sampling and generating new images.

Figure 2 Latent space distribution in training steps

VI. Training

One more step before writing training and validation functions is defining the objective function and optimization method.

The vanilla AutoEncoder is a self-supervised model: the input is also the ground truth for optimizing the network, so we can use MSE (Mean Squared Error) loss to evaluate the pixel-wise error between the input and the reconstructed image. We have a wide choice of optimizers; here I chose AdamW, as I have used it quite a lot over the past several months.
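With those choices, the setup is just a few lines; the learning rate and epoch count here are taken from the training log shown further below.

EPOCHS = 3
LR = 5e-4

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)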

Train and Validation Functions

The validation function is a bit simpler as we don’t need to update the model in it.
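Hedged sketches of the two loops, relying on the criterion and optimizer defined above; the actual notebook code may organize them differently.

def train_epoch(model, loader):
    model.train()
    total_loss = 0.0
    for images, _ in loader:
        images = images.to(DEVICE)
        optimizer.zero_grad()
        recon = model(images)
        loss = criterion(recon, images)   # the input is also the target
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

@torch.no_grad()
def validate_epoch(model, loader):
    model.eval()
    total_loss = 0.0
    for images, _ in loader:
        images = images.to(DEVICE)
        recon = model(images)
        total_loss += criterion(recon, images).item()
    return total_loss / len(loader)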

Then we train the model for the number of epochs we set earlier. FashionMNIST is a toy dataset that doesn't need much training; the train and validation losses are already very low after three epochs, with little room left for improvement.
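The outer loop simply alternates the two functions and logs the losses; a minimal version, with an optional latent-space snapshot per epoch:

for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch(model, train_loader)
    valid_loss = validate_epoch(model, valid_loader)
    # plot_latent_space(model, valid_loader, epoch)  # optional: snapshot the latent space
    print(f"Epoch {epoch}/{EPOCHS}")
    print(f"Train loss: {train_loss:.4f}  Validation loss: {valid_loss:.4f}  lr: {LR}")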

Epoch 1/3
Train loss: 0.0285 Validation loss: 0.0255 lr: 0.0005
Epoch 2/3
Train loss: 0.0248 Validation loss: 0.0238 lr: 0.0005
Epoch 3/3
Train loss: 0.0235 Validation loss: 0.0230 lr: 0.0005

VII. Results

We can now plot and check the latent space again. The clustering of classes is better than in the earlier training steps, but some classes are still mixed within the same clusters. This problem can be addressed by increasing the latent dimension or by using other loss functions such as center loss or triplet loss. Another limitation is that the AutoEncoder is a deterministic method that cannot truly create new samples; to reconstruct an original image, the decoder needs exactly the latent code the encoder produced for it. But for this post, we only concentrate on the vanilla AutoEncoder optimized with MSE loss.

Figure 3 Latent space distribution after 3 epochs of training

Then we can draw some random points in this region and reconstruct images from them:
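A sketch of this step; the sampling bounds are a rough guess at a box covering the latent distribution in Figure 3, not values from the notebook.

@torch.no_grad()
def sample_and_reconstruct(model, n_samples=8, low=-3.0, high=3.0):
    model.eval()
    # draw random 2D latent points; the bounds are assumed from the scatter plot
    z = torch.empty(n_samples, LATENT_DIM).uniform_(low, high).to(DEVICE)
    images = model.decoder(z).cpu()

    fig, axes = plt.subplots(1, n_samples, figsize=(1.5 * n_samples, 2))
    for ax, img in zip(axes, images):
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    plt.show()

sample_and_reconstruct(model)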

Figure 4 Reconstructed images from random samples

Feel free to explore different parameter combinations or play with AutoEncoders on other datasets; I hope you enjoy exploring the unlimited possibilities of even this simple setting. Thanks for taking the time to read this story!

