From GANs to Wasserstein GANs !!!

Udith Haputhanthri
Published in the-ai.team
Apr 5, 2021
Fig 1. Advancements of GANs from StyleGAN — [6]

Introduction

In the literature on Generative Adversarial Networks (GANs), Wasserstein GANs (WGANs) have become a key concept because they train more stably than conventional GANs. In this article, I will go through the gradient-penalty-based WGAN.

The article is organized as follows:

  1. Intuition behind WGANs
  2. Comparison between GAN and WGAN
  3. The mathematical background of Gradient-Penalty based WGAN
  4. Implementation of WGAN on the CelebA face dataset from scratch using PyTorch
  5. Discussion of the results

If you are new to Generative Adversarial Networks, please check out my previous articles:

  1. Introduction to Deep Convolutional Generative Adversarial Networks using PyTorch
  2. Image-to-Image Translation Using Conditional DCGANs !!!

Intuition behind WGANs

GANs were first introduced by Ian J. Goodfellow et al. [4]. A GAN is a two-player min-max game played by a Generator and a Discriminator. The main issues of early GANs were mode collapse and vanishing gradients. Many techniques have been proposed over time to overcome these issues, and WGAN is one of them.
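For reference, the min-max game from the original GAN paper [4] is usually written as follows, where D(x) is the Discriminator's estimated probability that x is real and z is drawn from a noise prior p_z:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

The log terms require D to output a probability, which is why the conventional Discriminator ends in a Sigmoid.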

GAN vs WGAN

Compared to a conventional GAN, a WGAN makes several changes:

  1. Critic instead of Discriminator
  2. W-Loss instead of BCE Loss
  3. Weight regularization using Gradient Penalty/ Weight clipping

The Discriminator of the conventional GAN is replaced by a “Critic”. From an implementation perspective, the Critic is nothing more than a Discriminator without the Sigmoid activation in its final layer.
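To see what dropping the Sigmoid changes for the Generator, here is a small illustrative sketch (not code from the article). It compares the gradient the Generator receives under the original minimax objective log(1 − σ(s)) and under the W-loss −s, where s is the raw, pre-Sigmoid score assigned to a generated image. The helper name `generator_grads` is hypothetical.

```python
import torch
import torch.nn.functional as F


def generator_grads(score: float) -> tuple[float, float]:
    """Gradient of each Generator objective w.r.t. the raw score s."""
    # Minimax GAN: the Generator minimizes log(1 - sigmoid(s)).
    # log(1 - sigmoid(s)) == logsigmoid(-s), which is numerically stable.
    s_gan = torch.tensor(score, requires_grad=True)
    (grad_gan,) = torch.autograd.grad(F.logsigmoid(-s_gan), s_gan)

    # W-loss: the Generator minimizes -s (no Sigmoid on the score).
    s_w = torch.tensor(score, requires_grad=True)
    (grad_w,) = torch.autograd.grad(-s_w, s_w)
    return grad_gan.item(), grad_w.item()


# A strongly negative score means the fake is confidently rejected.
for s in (0.0, -5.0, -10.0):
    g_gan, g_w = generator_grads(s)
    print(f"s={s:6.1f}  minimax GAN grad={g_gan:.2e}  W-loss grad={g_w:.1f}")
```

As the score drops from 0 to −10, the minimax gradient shrinks from 0.5 to roughly 4.5e-5 in magnitude, while the W-loss gradient stays at −1. This is the vanishing-gradient behaviour the W-loss is meant to avoid.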

We will be discussing the WGAN loss function and weight regularization in a little while.

Mathematical background

Loss function

Here is the complete loss function of the gradient-penalty-based WGAN.

Eq 1. complete WGAN loss function with gradient penalty — [3]
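Written out as in Gulrajani et al. [3], with x̃ = G(z) (x_CURL in the text) a generated image, x a real image, x̂ a random interpolation between them, and λ the penalty weight (λ = 10 in [3]):

```latex
L = \underbrace{\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[D(\tilde{x})\right]
    - \mathbb{E}_{x \sim \mathbb{P}_r}\left[D(x)\right]}_{\text{original critic loss}}
  + \underbrace{\lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}
    \left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]}_{\text{gradient penalty}}
```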

Looks scary, right? Let’s break down the equation.

1st part: Original Critic Loss

Eq 2. Original critic loss
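In symbols, following [3], the original critic loss is the first part of eq 1, where x̃ = G(z) is written as x_CURL in the text:

```latex
\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[D(\tilde{x})\right]
  - \mathbb{E}_{x \sim \mathbb{P}_r}\left[D(x)\right],
\qquad \tilde{x} = G(z)
```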

The Generator tries to maximize the value of this expression, while the Critic tries to minimize it (make it as negative as possible). Note that x_CURL here denotes the images generated by the Generator, G(z).

Here D has no Sigmoid activation in its last layer, so D(·) can be any real value. When the Critic is optimized under the Lipschitz constraint described below, the negative of eq 2 approximates the Earth mover’s (Wasserstein-1) distance between the real and generated distributions [1]. What we are trying to do here is:

  1. Critic’s perspective: Minimize eq 2, i.e. make it as negative as possible, which pushes the Critic’s scores for real and generated images as far apart as it can. This reflects the Critic’s objective of giving higher scores to real images and lower scores to generated images.
  2. Generator’s perspective: Undo the Critic’s effort by pushing the scores in the opposite direction, i.e. maximize eq 2. This reflects the Generator’s objective of fooling the Critic into giving generated images higher scores.
  • You may have noticed that this is where the name Critic (rather than Discriminator) comes from: the Critic does not classify images as real or fake; it just gives each one an unbounded score.
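To make the Earth mover’s distance concrete: for two 1-D samples of equal size, it reduces to the average distance between their sorted values, i.e. the cost of moving each “pile of earth” to its matching position. The sketch below uses a hypothetical helper, `emd_1d`, purely for intuition (the Critic never computes this directly). It shows the distance growing steadily as two distributions move apart, even when they no longer overlap, which is what gives the Generator a useful signal.

```python
def emd_1d(xs: list[float], ys: list[float]) -> float:
    """Earth mover's (Wasserstein-1) distance between two equal-size 1-D samples."""
    if len(xs) != len(ys) or not xs:
        raise ValueError("samples must be non-empty and of equal size")
    # In 1-D the optimal transport plan matches the i-th smallest value of
    # one sample with the i-th smallest value of the other.
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


real = [0.0, 1.0, 2.0]
for shift in (0.5, 3.0, 10.0):
    fake = [r + shift for r in real]
    print(shift, emd_1d(real, fake))
```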

For this approximation to hold, the Critic function must be 1-Lipschitz continuous [1].

1-Lipschitz Continuity

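A function f is 1-Lipschitz continuous if |f(x₁) − f(x₂)| ≤ ‖x₁ − x₂‖ for all x₁ and x₂. For a differentiable function, this is equivalent to the norm of its gradient being less than or equal to 1 everywhere.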
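One way to build intuition for this definition is to estimate the largest slope |f(a) − f(b)| / |a − b| over many random pairs of points. The sketch below uses a hypothetical helper, `estimate_lipschitz`, for 1-D functions only. It gives a lower bound on the Lipschitz constant: close to 1 for sin, which is 1-Lipschitz, and 2 for f(x) = 2x, which is not.

```python
import math
import random


def estimate_lipschitz(f, low=-5.0, high=5.0, n_pairs=10_000, seed=0):
    """Lower bound on the Lipschitz constant of a 1-D function f on [low, high]."""
    rng = random.Random(seed)
    best = 0.0
    for _ in range(n_pairs):
        a, b = rng.uniform(low, high), rng.uniform(low, high)
        if a != b:
            best = max(best, abs(f(a) - f(b)) / abs(a - b))
    return best


print(estimate_lipschitz(math.sin))         # close to, but not above, 1
print(estimate_lipschitz(lambda x: 2 * x))  # 2, up to rounding
```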

To ensure 1-Lipschitz continuity, two main methods have been proposed in the literature.

  1. Weight Clipping: the original method, introduced in the WGAN paper [2]
  2. Gradient Penalty: introduced later as an improvement [3]
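For comparison with the gradient penalty used in the rest of this article, weight clipping [2] amounts to clamping every Critic parameter to a small interval [−c, c] after each Critic update. A minimal sketch, assuming a PyTorch Critic module; c = 0.01 is the value used in [2], and the function name `clip_critic_weights` and the toy critic are hypothetical:

```python
import torch
from torch import nn


@torch.no_grad()
def clip_critic_weights(critic: nn.Module, c: float = 0.01) -> None:
    """Weight clipping from the original WGAN: clamp every parameter to [-c, c]."""
    for param in critic.parameters():
        param.clamp_(-c, c)


# Example: a toy critic whose weights start outside the allowed range.
critic = nn.Sequential(nn.Linear(8, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
with torch.no_grad():
    for p in critic.parameters():
        p.fill_(0.5)
clip_critic_weights(critic)
print(max(p.abs().max().item() for p in critic.parameters()))
```

Clipping keeps the Critic K-Lipschitz for some K that depends on c, but, as [3] points out, it tends to push weights to the clipping boundaries and limit the Critic’s capacity. This motivated the gradient penalty.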

In this article, we will be focusing on the Gradient Penalty-based WGAN.

2nd part: Gradient Penalty

Eq 3. Gradient Penalty
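Written out as in [3], where λ weights the penalty (λ = 10 in [3]) and x̂ is the interpolated image written x(^) in the text:

```latex
\lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}
  \left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]
```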

This is the gradient penalty presented by Gulrajani et al. [3]. It pushes the L2 norm of the Critic’s gradient toward 1 by penalizing the squared difference between that norm and 1. Note that we should not push the Critic’s gradients toward 0, because that would cause the vanishing gradient problem.

Wait! What is x(^)?

By definition, 1-Lipschitz continuity requires the gradient norm to be ≤ 1 for every x, but enforcing this for every possible image is impractical. Instead, the penalty is evaluated at x(^): random interpolations between real and generated images, x(^) = εx + (1 − ε)x_CURL, with ε drawn uniformly from [0, 1] for each pair [3]. This way, the Critic’s gradients are regularized on a representative set of points between the real and generated images it sees during training.
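A minimal sketch of how the interpolated images x(^) can be built in PyTorch, assuming image batches of shape (N, C, H, W); the helper name `interpolate` is hypothetical. One ε is drawn per image and broadcast over its channels and pixels, so each x(^) lies on the straight line between one real and one generated image:

```python
import torch


def interpolate(real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Return x_hat = eps * real + (1 - eps) * fake with one eps ~ U[0, 1) per image."""
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    # fake is detached: the penalty regularizes the Critic, not the Generator.
    return eps * real + (1 - eps) * fake.detach()


real = torch.ones(4, 3, 64, 64)
fake = -torch.ones(4, 3, 64, 64)
x_hat = interpolate(real, fake)
print(x_hat.shape)  # torch.Size([4, 3, 64, 64])
```

Because ε has shape (N, 1, 1, 1), every pixel of a given x(^) uses the same mixing weight, and different images get different weights.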

Implementation

Here I will present the changes needed to turn a conventional GAN into a WGAN.

For the implementation below, I will reuse the models and training principles explained in more detail in my previous article on DCGAN.

Dataset

The CelebA face dataset is used for training. The scripts for downloading and preprocessing the data and building the data loaders are shown in Code 1.

Code 1. dataset

Generator and Critic

The Critic is the same as the Discriminator, except that it has no Sigmoid activation in its last layer.

Code 2. Generator and Critic

Supporting blocks for Generator and Critic are shown below in Code 3.

Code 3. Supporting blocks
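For readers who want something runnable, here is a DCGAN-style sketch of a Generator, a Critic, and their supporting blocks for 64×64 RGB images. All layer sizes, the names `gen_block`, `crit_block`, `Generator` and `Critic`, and the choice of normalization are assumptions rather than the author’s exact Code 2 and Code 3. In particular, the Critic uses `GroupNorm` with a single group (layer normalization over C, H, W), because [3] advises against BatchNorm in the critic when using a gradient penalty.

```python
import torch
from torch import nn


def gen_block(in_ch: int, out_ch: int, final: bool = False) -> nn.Sequential:
    """Transposed-conv block that doubles the spatial size."""
    if final:
        # Last block: no normalization, Tanh maps pixels to [-1, 1].
        return nn.Sequential(nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1), nn.Tanh())
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )


def crit_block(in_ch: int, out_ch: int, norm: bool = True) -> nn.Sequential:
    """Strided-conv block that halves the spatial size."""
    layers = [nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=not norm)]
    if norm:
        # One group = per-sample normalization over (C, H, W), i.e. layer norm.
        layers.append(nn.GroupNorm(1, out_ch))
    layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)


class Generator(nn.Module):
    def __init__(self, z_dim: int = 100, hidden: int = 64, channels: int = 3):
        super().__init__()
        self.z_dim = z_dim
        self.net = nn.Sequential(
            # (N, z_dim, 1, 1) -> (N, hidden*8, 4, 4)
            nn.ConvTranspose2d(z_dim, hidden * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(hidden * 8),
            nn.ReLU(),
            gen_block(hidden * 8, hidden * 4),        # -> 8x8
            gen_block(hidden * 4, hidden * 2),        # -> 16x16
            gen_block(hidden * 2, hidden),            # -> 32x32
            gen_block(hidden, channels, final=True),  # -> 64x64
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z.view(len(z), self.z_dim, 1, 1))


class Critic(nn.Module):
    def __init__(self, hidden: int = 64, channels: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            crit_block(channels, hidden, norm=False),  # -> 32x32
            crit_block(hidden, hidden * 2),            # -> 16x16
            crit_block(hidden * 2, hidden * 4),        # -> 8x8
            crit_block(hidden * 4, hidden * 8),        # -> 4x4
            # One unbounded score per image: no Sigmoid here.
            nn.Conv2d(hidden * 8, 1, 4, 1, 0),         # -> 1x1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).view(len(x))
```

The only WGAN-specific detail is the missing Sigmoid at the end of `Critic`. Everything else is a standard DCGAN layout.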

Loss function

The loss function is a bit trickier than usual because, unlike typical loss functions, it depends on the Critic’s gradients with respect to its input. Here we implement the W-loss with Gradient Penalty so that it can be plugged into the WGAN model later.

Code 4. Generator and Critic loss functions

Let's break down the loss functions shown in Code 4.

  1. Generator Loss: The Generator loss is not affected by the gradient penalty, so the Generator only has to maximize the D(x_CURL) = D(G(z)) term, which is equivalent to minimizing -D(G(z)) averaged over the batch. This is implemented in line 2.
  2. Critic Loss: The Critic loss contains both parts of eq 1. In line 6, the first two terms give the original critic loss explained in eq 2, while the last term gives the Gradient Penalty explained in eq 3.
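A minimal sketch of the two loss functions described above, assuming the Critic returns one score per image and that the gradient penalty `gp` has already been computed. The function names `get_gen_loss` and `get_crit_loss` and the default `c_lambda = 10` (the value used in [3]) are assumptions rather than the author’s exact Code 4:

```python
import torch


def get_gen_loss(crit_fake_pred: torch.Tensor) -> torch.Tensor:
    """Generator W-loss: maximize D(G(z)), i.e. minimize its negative mean."""
    return -crit_fake_pred.mean()


def get_crit_loss(crit_fake_pred: torch.Tensor, crit_real_pred: torch.Tensor,
                  gp: torch.Tensor, c_lambda: float = 10.0) -> torch.Tensor:
    """Critic loss: E[D(x_CURL)] - E[D(x)] (eq 2) plus the weighted penalty (eq 3)."""
    return crit_fake_pred.mean() - crit_real_pred.mean() + c_lambda * gp


fake_scores = torch.tensor([-1.0, -3.0])
real_scores = torch.tensor([2.0, 4.0])
print(get_gen_loss(fake_scores))  # tensor(2.)
print(get_crit_loss(fake_scores, real_scores, gp=torch.tensor(0.1)))
```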

The Gradient Penalty can be implemented as in Code 5 [1].

Code 5. Gradient Penalty

In Code 5, the get_gradient() function returns the gradients of the Critic’s outputs (mixed_scores) with respect to its inputs x_hat (the mixed images). These gradients are used in the gradient_penalty() function, which returns the mean squared difference between 1 and the L2 norm of the Critic’s gradients.

Minimizing the Critic’s loss therefore also reduces this gradient penalty, which encourages the Critic to stay 1-Lipschitz continuous. Note that this is a soft constraint: it penalizes violations rather than strictly guaranteeing the property.
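The two functions can be sketched as follows. This is modeled on the approach from the Coursera course [1] that the article cites, not the article’s exact code. It assumes a Critic whose score for each image depends only on that image (no BatchNorm), and it uses a toy linear critic only to make the result checkable.

```python
import torch
from torch import nn


def get_gradient(crit: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                 epsilon: torch.Tensor) -> torch.Tensor:
    """Gradients of the Critic's scores (mixed_scores) w.r.t. x_hat (mixed_images)."""
    # Mix real and generated images; detach so x_hat is a fresh leaf tensor
    # that we can differentiate with respect to.
    mixed_images = (epsilon * real + (1 - epsilon) * fake).detach().requires_grad_(True)
    mixed_scores = crit(mixed_images)
    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=mixed_images,
        # Ones as grad_outputs sums the scores; since each score depends only
        # on its own image, this still yields per-image gradients.
        grad_outputs=torch.ones_like(mixed_scores),
        # Keep the graph so the penalty can be backpropagated into the Critic.
        create_graph=True,
    )[0]
    return gradient


def gradient_penalty(gradient: torch.Tensor) -> torch.Tensor:
    """Mean squared difference between 1 and each image's gradient L2 norm."""
    gradient_norm = gradient.reshape(len(gradient), -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


# Toy linear critic on 1x2x2 "images": its input gradient equals its weight
# vector everywhere, so with all four weights at 0.5 the gradient norm is 1.
crit = nn.Sequential(nn.Flatten(), nn.Linear(4, 1, bias=False))
nn.init.constant_(crit[1].weight, 0.5)
real, fake = torch.randn(8, 1, 2, 2), torch.randn(8, 1, 2, 2)
epsilon = torch.rand(8, 1, 1, 1)
print(gradient_penalty(get_gradient(crit, real, fake, epsilon)))
```

With weights of 0.5 the printed penalty is 0. With weights of 1.0 the gradient norm becomes 2, and the penalty becomes (2 − 1)² = 1.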

Training

Training is almost the same as in the previous article; only the losses differ from the conventional GAN loss. I used WANDB (Weights & Biases) to log my results. If you want to log yours, it is a really good tool.

Code 6. training
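The core of such a training loop is sketched below on toy 2-D data, so that it runs in seconds. It is a self-contained illustration, not the author’s Code 6. The helper names, the toy MLPs, and the hyperparameters are assumptions: five Critic updates per Generator update, λ = 10, and Adam with lr = 1e-4 and betas = (0.0, 0.9), as suggested in [3]. Logging calls such as `wandb.log` are omitted.

```python
import torch
from torch import nn


def gradient_penalty(crit, real, fake):
    """(||grad D(x_hat)||_2 - 1)^2 averaged over the batch, as in eq 3."""
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    scores = crit(x_hat)
    (grad,) = torch.autograd.grad(outputs=scores, inputs=x_hat,
                                  grad_outputs=torch.ones_like(scores),
                                  create_graph=True)
    return ((grad.reshape(len(grad), -1).norm(2, dim=1) - 1) ** 2).mean()


def train_step(gen, crit, gen_opt, crit_opt, real, z_dim,
               crit_repeats=5, c_lambda=10.0):
    """One WGAN-GP iteration: several Critic updates, then one Generator update."""
    for _ in range(crit_repeats):
        crit_opt.zero_grad()
        fake = gen(torch.randn(len(real), z_dim, device=real.device)).detach()
        crit_loss = (crit(fake).mean() - crit(real).mean()
                     + c_lambda * gradient_penalty(crit, real, fake))
        crit_loss.backward()
        crit_opt.step()

    # Generator update: no gradient penalty, just -E[D(G(z))].
    # (Gradients also reach the Critic here, but crit_opt.zero_grad() clears them.)
    gen_opt.zero_grad()
    fake = gen(torch.randn(len(real), z_dim, device=real.device))
    gen_loss = -crit(fake).mean()
    gen_loss.backward()
    gen_opt.step()
    return crit_loss.item(), gen_loss.item()


torch.manual_seed(0)
z_dim = 8
gen = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, 2))
crit = nn.Sequential(nn.Linear(2, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-4, betas=(0.0, 0.9))
crit_opt = torch.optim.Adam(crit.parameters(), lr=1e-4, betas=(0.0, 0.9))
real = torch.randn(64, 2) + 3.0  # toy "real" data: a shifted Gaussian
for step in range(3):
    print(step, train_step(gen, crit, gen_opt, crit_opt, real, z_dim))
```

For the face model, the same `train_step` would take image batches from the data loader and the convolutional Generator and Critic. The ε shape adapts automatically to (N, 1, 1, 1).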

Results

Here are the results obtained after 10 epochs of training. As with the conventional GAN, the generated images become more realistic as training progresses. All the results are available in the WANDB project.

Fig 2. Results after 10 epochs

Conclusion

Generative Adversarial Networks have been a popular topic in the Deep Learning community. Because of the drawbacks of conventional GAN training, WGANs have become increasingly popular over time, mainly because they are more robust to mode collapse and less prone to vanishing gradients. In this article, I have gone through the implementation of a simple WGAN model capable of generating human faces. If you are curious to learn more about this area, Vincent Herrmann has written a great article on WGANs [5].

Feel free to check out the GitHub repository. Any comments, suggestions, and advice are greatly appreciated. ❤️

References

[1] GAN Specialization on Coursera

[2] Arjovsky, Martin et al. “Wasserstein GAN.”

[3] Gulrajani, Ishaan et al. “Improved Training of Wasserstein GANs.”

[4] Goodfellow, Ian et al. “Generative Adversarial Networks.”

[5] Herrmann, Vincent. “Wasserstein GAN and the Kantorovich-Rubinstein Duality.”

[6] Karras, Tero et al. “A Style-Based Generator Architecture for Generative Adversarial Networks.”
