Implementing conditional variational auto-encoders (CVAE) from scratch
In the previous article we implemented a VAE from scratch and saw how we can use it to generate new samples by decoding latent codes drawn from the prior distribution. There was one problem, however: it was not easy to generate samples with specific properties.
For example, if I wanted to generate a bunch of 2s or 5s specifically, I’d be in trouble. Of course, one could generate a sufficiently large set of samples and then select those matching the desired criteria, but that would obviously take far too much time.
Also, what if we were to ask the following question: given the picture below, what are some reasonable completions of it that we can generate?
Generating images with a specific label
In order to generate an image with a specific label, our AE needs to learn how to decode the latent variable when given a hint. In that case we say that we condition our AE on some information. How can we do that? Well, one obvious idea is to pass a one-hot encoded digit label to the decoder so that it learns a decoding process conditioned on that label. On a diagram it would look something like this:
Notice how we now have an additional source of information for our decoder. Why do we do a linear projection and then a summation? Well, it’s just one of the options. The linear projection is there to match the dimensions of the code layer and the label information. We then simply sum them. We could also take a mean, do point-wise multiplication, or just concatenate the vectors (there is a sketch of the concatenation variant after the code below); anything like that will do. It’s just a way to communicate to the decoder what exactly we are trying to decode.
Then at inference time the only thing we do is pass the one-hot label of the digit we want to generate. Here is how it looks in the code:
class ConditionalVAE(VAE):
    # VAE implementation from the article linked above

    def __init__(self, num_classes):
        super().__init__()
        # Add a linear layer for the class label
        self.label_projector = nn.Sequential(
            nn.Linear(num_classes, self.num_hidden),
            nn.ReLU(),
        )

    def condition_on_label(self, z, y):
        projected_label = self.label_projector(y.float())
        return z + projected_label

    def forward(self, x, y):
        # Pass the input through the encoder
        encoded = self.encoder(x)
        # Compute the mean and log variance vectors
        mu = self.mu(encoded)
        log_var = self.log_var(encoded)
        # Reparameterize the latent variable
        z = self.reparameterize(mu, log_var)
        # Pass the latent variable through the decoder
        decoded = self.decoder(self.condition_on_label(z, y))
        # Return the encoded output, decoded output, mean, and log variance
        return encoded, decoded, mu, log_var

    def sample(self, num_samples, y):
        with torch.no_grad():
            # Generate random noise
            z = torch.randn(num_samples, self.num_hidden).to(device)
            # Pass the noise through the decoder to generate samples
            samples = self.decoder(self.condition_on_label(z, y))
            # Return the generated samples
            return samples
Notice the new `label_projector` layer that does the linear projection. Also, the latent code is conditioned on the label (via `condition_on_label`) both during the forward pass and during sampling.
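As mentioned above, concatenation is another perfectly valid way to inject the label. Below is a minimal sketch of that variant (the `ConcatConditionalVAE` and `concat_projector` names are just for illustration); it reuses the `ConditionalVAE` class above, so `forward` and `sample` are inherited unchanged.

class ConcatConditionalVAE(ConditionalVAE):
    # Sketch: condition by concatenating the one-hot label with the
    # latent code instead of projecting and summing
    def __init__(self, num_classes):
        super().__init__(num_classes)
        # Project the concatenated vector back to the width the decoder
        # expects (the inherited label_projector simply goes unused here)
        self.concat_projector = nn.Linear(self.num_hidden + num_classes, self.num_hidden)

    def condition_on_label(self, z, y):
        # Concatenate along the feature dimension, then map back to num_hidden
        return self.concat_projector(torch.cat([z, y.float()], dim=1))

Projecting back to `num_hidden` after the concatenation means the decoder itself stays untouched, which is the main appeal of this version.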
A CVAE is trained with exactly the same loss function as its non-conditioned buddy. Here is the loss:
def loss_function(recon_x, x, mu, logvar):
    # Compute the binary cross-entropy loss between the reconstructed output and the input data
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")
    # Compute the Kullback-Leibler divergence between the learned latent variable distribution and a standard Gaussian distribution
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Combine the two losses by adding them together and return the result
    return BCE + KLD
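For completeness, here is a rough sketch of what a training step could look like. The `train_loader`, the optimizer settings and the use of `F.one_hot` are assumptions about the surrounding training code rather than part of the model itself.

cvae = ConditionalVAE(num_classes=10).to(device)
optimizer = torch.optim.Adam(cvae.parameters(), lr=1e-3)

for images, labels in train_loader:
    # Flatten the 28x28 digit images and move everything to the device
    images = images.view(-1, 784).to(device)
    y = F.one_hot(labels, num_classes=10).to(device)

    # The forward pass is now conditioned on the labels
    encoded, decoded, mu, log_var = cvae(images, y)

    # Exactly the same loss as for the plain VAE
    loss = loss_function(decoded, images, mu, log_var)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()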
Now let’s generate some samples of the digit 8, shall we?
from torch.nn.functional import one_hot

num_samples = 10
random_labels = [8] * num_samples
show_images(
    cvae.sample(num_samples, one_hot(torch.LongTensor(random_labels), num_classes=10).to(device))
    .cpu()
    .detach()
    .numpy(),
    labels=random_labels,
)
Here are two more examples, a bunch of 3s and 0s:
They look almost exactly the same. I had to spend some time trying to figure out what was going on. My initial thought was that if the model samples such similar images, it must be getting a terrible score on the reconstruction loss, yet that loss was somehow not going down, even though the overall training loss had decreased significantly.
What I’ve realised is that the reconstruction loss was probably overshadowed by the KL loss. So I started down-weighting the KL loss, and eventually I got more and more diverse images. I ended up with the following balance between the reconstruction and KL losses:
loss = criterion(decoded, images) + 0.00001 * KLD
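If you want to experiment with this balance yourself, one option is to expose the weight as an explicit hyperparameter, in the spirit of the β-VAE papers linked below. The `beta` argument here is just for illustration, not something from the original loss:

def weighted_loss_function(recon_x, x, mu, logvar, beta=0.00001):
    # Reconstruction term, same as before
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")
    # KL term, scaled down by beta
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + beta * KLD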
However, there appears to be more to this story. Specifically, the balance between the two losses is basically a trade-off between being precise and being creative, much like the left and right brain hemispheres. Down-weighting the KL component imposes less of a penalty for deviating from the prior distribution, which leads to more creativity, which in turn leads to less interpretability. There is some great discussion of this in the following papers:
- https://arxiv.org/pdf/2006.13202.pdf — 𝜎-VAE paper
- https://openreview.net/pdf?id=Sy2fzU9gl — the original 𝛽-VAE paper
While the concept of conditioning image generation on something may sound easy, the implementation details are not that easy to get right on the first try.