A Comparison of WGAN Implementations (WGAN-GP and WGAN-SN)

Brad Brown
Sep 13, 2019


A WGAN (Wasserstein GAN) is a type of generative adversarial network used to generate high-quality fake images from a random input vector. In this experiment, I implemented two different improvements to the WGAN in PyTorch to see which performs better in terms of training speed and the quality of the generated images. (GitHub: https://github.com/BradleyBrown19/WGANOptimizations).

WGAN with spectral norm after 230k iterations

What is a WGAN?

I’m assuming prior knowledge of GANs, so I recommend reading the following if you are unfamiliar with them: https://towardsdatascience.com/image-generator-drawing-cartoons-with-generative-adversarial-networks-45e814ca9b6b

There is a lot of complicated math in the formulation of WGANs, but I’m going to try to keep it simple and explain it intuitively.

The main idea behind the WGAN is to minimize the Earth Mover’s distance. The Earth Mover’s distance (equation 1) gets its name because it can be thought of as the minimum work required to transform one pile of dirt into another, where the work of each move is the amount of dirt moved times the distance it is moved.

In our case, the dirt piles are the probability distributions of the real images, Pr, and the generated images, Pg.

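$$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}\left[\lVert x - y \rVert\right]$$

Equation 1: Wasserstein-1 or Earth Mover’s Distance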

Now, let’s break down the above equation. Π(Pr, Pg) is the set of all joint distributions γ(x, y) whose marginals are Pr and Pg. Each γ is a transport plan: it says how much dirt to move from each point x to each point y. For a given plan, the expected cost E(x,y)∼γ[‖x − y‖] is the amount of dirt moved times the distance ‖x − y‖ it travels, and the infimum (greatest lower bound) over all plans picks out the cheapest one. In other words, we are looking for the least amount of work needed to transform the distribution of generated images into the distribution of real ones.
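To make the dirt-pile picture concrete, here is a minimal pure-Python sketch (not from the original repository) for one special case: 1-D distributions given as the same number of equally weighted samples. In that case the cheapest transport plan simply pairs the samples in sorted order, so the infimum in equation 1 reduces to an average distance. The function name emd_1d is hypothetical.

```python
def emd_1d(real_samples, fake_samples):
    """Earth Mover's distance between two 1-D empirical distributions.

    Assumes both lists have the same length and every sample carries equal
    mass, so the optimal transport plan pairs the i-th smallest real sample
    with the i-th smallest fake sample.
    """
    if len(real_samples) != len(fake_samples):
        raise ValueError("this sketch only handles equal-sized sample sets")
    real_sorted = sorted(real_samples)
    fake_sorted = sorted(fake_samples)
    # Each pair moves 1/n of the total dirt a distance |x - y|.
    total = sum(abs(x - y) for x, y in zip(real_sorted, fake_sorted))
    return total / len(real_sorted)


# Shifting every sample by 1 moves all of the dirt a distance of 1.
print(emd_1d([0.0, 1.0, 2.0], [1.0, 2.0, 3.0]))  # 1.0
# Only the positions matter, not the order the samples are listed in.
print(emd_1d([2.0, 0.0], [0.0, 4.0]))  # 1.0
```

For images the distributions live in a very high-dimensional space and no such shortcut exists, which is why the WGAN goes through the dual form instead.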

That was a mouthful.

Using something called the Kantorovich-Rubinstein duality (see appendix of: https://arxiv.org/pdf/1701.07875.pdf), equation 1 can be transformed into:

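$$W(P_r, P_g) = \sup_{\lVert f \rVert_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]$$

Equation 2: Simplified (dual) form of the Earth Mover’s distance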

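where the supremum is taken over all 1-Lipschitz functions f, i.e. functions that satisfy the constraint:

$$|f(x_1) - f(x_2)| \le \lVert x_1 - x_2 \rVert \quad \text{for all } x_1, x_2$$

1-Lipschitz constraint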
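The dual form can be checked on a toy example. The sketch below (plain Python, hypothetical helper name dual_objective) evaluates the expression inside the supremum of equation 2 for a few hand-picked 1-Lipschitz functions. The real samples are the fake samples shifted by 1, so the Earth Mover’s distance between them is 1 (every grain of dirt moves a distance of 1). Every candidate gives a lower bound on that distance and the identity function reaches it; the critic’s job is to find such a maximizing f automatically.

```python
def dual_objective(f, real_samples, fake_samples):
    """Value inside the supremum of equation 2 for one candidate function f."""
    mean_real = sum(f(x) for x in real_samples) / len(real_samples)
    mean_fake = sum(f(x) for x in fake_samples) / len(fake_samples)
    return mean_real - mean_fake


real = [1.0, 2.0, 3.0]
fake = [0.0, 1.0, 2.0]  # real = fake + 1, so the Earth Mover's distance is 1

# All three candidates are 1-Lipschitz: |f(a) - f(b)| <= |a - b|.
candidates = {
    "identity": lambda x: x,
    "half_slope": lambda x: 0.5 * x,
    "v_shape": lambda x: abs(x - 1.5),
}
for name, f in candidates.items():
    print(name, dual_objective(f, real, fake))
# identity 1.0, half_slope 0.5, v_shape 0.0: none exceeds 1, the identity attains it
```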

What is f? How do we find it to maximize equation 2?

Drumroll please… we use a neural net, i.e. the critic. The critic is very similar to the discriminator in a normal GAN; however, instead of outputting a probability that the image is real, it outputs an unbounded scalar score representing how ‘real’ the image is. The only architectural change is that the critic omits the final sigmoid layer, because it is no longer needed. The critic cannot cover every possible f to find the supremum in equation 2, but because neural networks are very flexible function approximators, it provides a good estimate.

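Another change with the WGAN is that the log functions in the losses are no longer needed: the critic is trained to maximize the difference between its mean score on real images and its mean score on generated images (the expression inside equation 2), and the generator is trained to raise the critic’s mean score on generated images.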
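A minimal sketch of the two WGAN losses next to the usual GAN discriminator loss, written in plain Python over lists of critic outputs so that it runs without PyTorch. The function names are hypothetical, and the sign convention (every loss is minimized) is an assumption consistent with the description above.

```python
import math


def wgan_critic_loss(real_scores, fake_scores):
    """Critic minimizes mean(D(fake)) - mean(D(real)); no sigmoid, no log."""
    return sum(fake_scores) / len(fake_scores) - sum(real_scores) / len(real_scores)


def wgan_generator_loss(fake_scores):
    """Generator minimizes -mean(D(fake)), i.e. pushes the critic's scores up."""
    return -sum(fake_scores) / len(fake_scores)


def gan_discriminator_loss(real_probs, fake_probs):
    """Standard GAN loss for comparison: binary cross-entropy on sigmoid outputs."""
    real_term = sum(math.log(p) for p in real_probs) / len(real_probs)
    fake_term = sum(math.log(1.0 - p) for p in fake_probs) / len(fake_probs)
    return -(real_term + fake_term)


# Critic scores are unbounded reals, not probabilities.
print(wgan_critic_loss([2.0, 4.0], [-1.0, 0.0]))  # -3.5
print(wgan_generator_loss([-1.0, 0.0]))  # 0.5
print(gan_discriminator_loss([0.5], [0.5]))  # 2 * ln(2), about 1.386
```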

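Now that we have everything set up, all that’s left is to satisfy the Lipschitz constraint. In the original WGAN this is done by simply clipping every weight of the critic after each update so that its absolute value stays below a small constant c, typically 0.01. This only guarantees that the critic is K-Lipschitz for some constant K rather than exactly 1-Lipschitz, which scales the estimated distance by K but does not change which generator minimizes it.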
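Weight clipping itself is a one-liner. In PyTorch it is usually written as a loop over the critic’s parameters calling p.data.clamp_(-c, c) after each optimizer step; the sketch below shows the same operation on a plain list of weights, with the hypothetical name clip_weights.

```python
def clip_weights(weights, c=0.01):
    """Original-WGAN weight clipping: clamp every critic weight into [-c, c].

    Applied after each critic optimizer step. A flat list of floats stands in
    for the critic's parameter tensors.
    """
    return [max(-c, min(c, w)) for w in weights]


print(clip_weights([0.5, -0.003, -2.0, 0.01]))  # [0.01, -0.003, -0.01, 0.01]
```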

Implementation 1: WGAN with Gradient Penalty and a Consistency Term

What is Gradient Penalty?

The big disadvantage of weight clipping is that it limits how expressive the critic can be. If the clipping range is too small, the critic cannot model complex functions, so its approximation of f is poor and its gradients can vanish; if the range is too large, the gradients can explode instead. Overall, the original WGAN is too sensitive to the choice of clipping value to be as effective as it could be.

Enter gradient penalty.

Gradient penalty is another, less restrictive method of enforcing the 1-Lipschitz constraint. It follows from the definition of a 1-Lipschitz function (see the 1-Lipschitz constraint above) that a differentiable function satisfies this constraint if and only if the norm of its gradient is at most 1 everywhere.

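$$L = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]$$

Gradient Penalty: critic loss with the gradient penalty term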

This is implemented by adding an extra gradient penalty term to the critic’s loss. For each pair of a real image x and a generated image x̃, we sample a random point x̂ = εx + (1 − ε)x̃ on the straight line between them (ε uniform in [0, 1]), compute the gradient of the critic’s output with respect to x̂, and penalize the squared difference between the norm of that gradient and 1, averaged over the batch. The term is multiplied by a hyper-parameter λ, for which 10 is the standard value.

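This requires an extra backward pass at every critic update to get the gradients with respect to x̂, and because the penalty depends on those gradients, training on it means backpropagating through them a second time. This makes the method computationally costly.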
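Here is a minimal sketch of the penalty term, assuming a critic that maps a flat list of floats to a scalar. A real PyTorch implementation would get the gradient with respect to x̂ from autograd (torch.autograd.grad with create_graph=True, so that the penalty can itself be backpropagated); to stay dependency-free, this sketch swaps in a central finite-difference estimate instead. The helper names and the two linear toy critics are hypothetical.

```python
import math
import random


def numerical_grad(f, x, h=1e-5):
    """Central-difference estimate of the gradient of scalar f at point x.

    Stand-in for autograd so the sketch needs only the standard library.
    """
    grad = []
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += h
        x_minus[i] -= h
        grad.append((f(x_plus) - f(x_minus)) / (2.0 * h))
    return grad


def gradient_penalty(critic, real_batch, fake_batch, lam=10.0, rng=random):
    """lam * mean over the batch of (||grad of critic at x_hat||_2 - 1)^2."""
    penalties = []
    for real, fake in zip(real_batch, fake_batch):
        # Random point on the straight line between a real and a fake sample.
        eps = rng.random()
        x_hat = [eps * r + (1.0 - eps) * fk for r, fk in zip(real, fake)]
        grad = numerical_grad(critic, x_hat)
        grad_norm = math.sqrt(sum(g * g for g in grad))
        penalties.append((grad_norm - 1.0) ** 2)
    return lam * sum(penalties) / len(penalties)


def steep_critic(x):
    # Linear critic with gradient (3, 4) everywhere: norm 5, violates the constraint.
    return 3.0 * x[0] + 4.0 * x[1]


def unit_critic(x):
    # Linear critic with gradient (0.6, 0.8): norm 1, so no penalty.
    return 0.6 * x[0] + 0.8 * x[1]


real_batch = [[1.0, 2.0], [0.5, -1.0]]
fake_batch = [[0.0, 0.0], [2.0, 1.0]]
rng = random.Random(0)
print(gradient_penalty(steep_critic, real_batch, fake_batch, rng=rng))  # about 160.0
print(gradient_penalty(unit_critic, real_batch, fake_batch, rng=rng))  # about 0.0
```

The steep critic has gradient norm 5 at every x̂, so each sample contributes (5 − 1)² = 16 and λ = 10 gives about 160; the second critic already has gradient norm 1 and is not penalized.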

What is a Consistency Term?

Further regularization is added to the neural network with the introduction of a consistency term.

This is calculated by passing the same real image through the critic twice, with dropout (a rate of about 0.5) active in the critic. From each pass we keep both the final output and the activations just before the final layer.

Dropout randomly zeroes about half of the critic’s hidden activations on each forward pass, so the two outputs differ from each other even though the input is the same.

The consistency term penalizes the differences between the two passes, pushing the critic to produce the same scalar output despite the dropout. This is done to increase the consistency and dependability of the critic’s output.
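A minimal sketch of the consistency term, assuming the formulation from Wei et al. (2018), “Improving the Improved Training of Wasserstein GANs”: the distance between the two outputs plus 0.1 times the distance between the two sets of pre-output activations, minus a margin M′, clipped at zero. The 0.1 feature weight and the default margin of 0.0 are assumptions rather than values taken from the original repository, and the tiny critic, its weights and all helper names are hypothetical.

```python
import math
import random


def dropout(values, p, rng):
    """Inverted dropout: zero each value with probability p, rescale the rest by 1/(1-p)."""
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


def toy_critic(x, hidden_weights, out_weights, p, rng):
    """Tiny critic: ReLU hidden layer, dropout, then a linear output.

    Returns (scalar output, hidden activations before the final layer).
    """
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in hidden_weights]
    hidden = dropout(hidden, p, rng)
    out = sum(w * h for w, h in zip(out_weights, hidden))
    return out, hidden


def consistency_term(out1, out2, feat1, feat2, margin=0.0, feat_weight=0.1):
    """max(0, |D(x') - D(x'')| + feat_weight * ||D_(x') - D_(x'')|| - margin)."""
    out_dist = abs(out1 - out2)
    feat_dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(feat1, feat2)))
    return max(0.0, out_dist + feat_weight * feat_dist - margin)


rng = random.Random(0)
x = [1.0, -0.5]
hidden_weights = [[0.8, -0.2], [0.1, 0.9], [-0.5, 0.4], [0.3, 0.3]]
out_weights = [1.0, -1.0, 0.5, 2.0]

# Two passes over the SAME real sample; dropout makes them differ.
out1, feat1 = toy_critic(x, hidden_weights, out_weights, p=0.5, rng=rng)
out2, feat2 = toy_critic(x, hidden_weights, out_weights, p=0.5, rng=rng)
print(consistency_term(out1, out2, feat1, feat2))
```

With p=0.0 the two passes are identical and the term is exactly zero, which is the behaviour the penalty pushes the critic towards even when dropout is on.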

Implementation 2: WGAN with Spectral Normalization

Again, to train a WGAN we need the critic to be Lipschitz continuous: |f(x₁) − f(x₂)| ≤ K‖x₁ − x₂‖ for some constant K. It’s easy to see from this inequality that, in the 1-D case, K must be at least the maximum absolute value of the derivative of the function. (For a visual proof of this see: christigancosgrove.com).
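A quick numerical way to see the 1-D claim: the steepest slope between neighbouring points on a fine grid approaches the maximum absolute derivative, which is the smallest valid K. The sketch below (plain Python, hypothetical function name) estimates it for sin(2x), whose derivative 2cos(2x) peaks at 2 in absolute value, and for the line 0.5x.

```python
import math


def empirical_lipschitz_1d(f, lo, hi, n=10001):
    """Largest slope |f(b) - f(a)| / (b - a) between neighbouring grid points.

    By the mean value theorem each slope is at most max |f'(x)| on [lo, hi],
    and on a fine grid the largest one gets very close to it.
    """
    xs = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    return max(abs(f(b) - f(a)) / (b - a) for a, b in zip(xs, xs[1:]))


print(empirical_lipschitz_1d(lambda x: math.sin(2.0 * x), -math.pi, math.pi))  # just under 2
print(empirical_lipschitz_1d(lambda x: 0.5 * x, -1.0, 1.0))  # 0.5
```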

This is where spectral normalization comes in. Since the WGAN critic is required to be Lipschitz continuous, we must find a way to bound its gradients.

For a linear function A: Rⁿ → Rᵐ, the spectral norm σ(A) is defined as the largest singular value of A, which is also the square root of the largest eigenvalue of AᵀA. A bit of linear algebra shows that this is also the Lipschitz constant of the linear function: ‖Ax − Ay‖ ≤ σ(A)‖x − y‖ for all x and y, and no smaller constant works.
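The claim that the spectral norm is the Lipschitz constant of a linear map can be checked numerically. This sketch assumes a 2x2 matrix so that the largest eigenvalue of AᵀA has a closed form; it then confirms on random pairs of points that ‖Ax − Ay‖ / ‖x − y‖ never exceeds σ(A) and gets close to it. The matrix and helper names are hypothetical.

```python
import math
import random


def spectral_norm_2x2(a):
    """Largest singular value of a 2x2 matrix: sqrt of the largest eigenvalue of A^T A."""
    (a11, a12), (a21, a22) = a
    # Entries of the symmetric matrix A^T A = [[p, q], [q, r]].
    p = a11 * a11 + a21 * a21
    q = a11 * a12 + a21 * a22
    r = a12 * a12 + a22 * a22
    largest_eig = (p + r) / 2.0 + math.sqrt(((p - r) / 2.0) ** 2 + q * q)
    return math.sqrt(largest_eig)


def matvec(a, x):
    return [a[0][0] * x[0] + a[0][1] * x[1], a[1][0] * x[0] + a[1][1] * x[1]]


def norm(x):
    return math.sqrt(sum(v * v for v in x))


A = [[2.0, 1.0], [0.0, 3.0]]
sigma = spectral_norm_2x2(A)
rng = random.Random(0)
worst_ratio = 0.0
for _ in range(1000):
    x = [rng.uniform(-1, 1), rng.uniform(-1, 1)]
    y = [rng.uniform(-1, 1), rng.uniform(-1, 1)]
    diff = [xi - yi for xi, yi in zip(x, y)]
    # By linearity, Ax - Ay = A(x - y).
    ratio = norm(matvec(A, diff)) / norm(diff)
    worst_ratio = max(worst_ratio, ratio)
print(sigma, worst_ratio)  # worst_ratio is at most sigma and very close to it
```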

More generally, the Lipschitz constant of a differentiable function is the largest spectral norm of its gradient over its domain:

$$\lVert f \rVert_{\text{Lip}} = \sup_{x} \sigma\left(\nabla f(x)\right)$$

Now we need to find (and bound) the Lipschitz constant of a composition of functions, since the critic is a stack of layers, so that the critic is Lipschitz continuous and works in a WGAN.

According to the chain rule, the gradient of a composition of functions is:

$$\nabla (g \circ f)(x) = \nabla g\left(f(x)\right) \, \nabla f(x)$$

The terms on the right are just gradient (Jacobian) matrices being multiplied together. Therefore, the spectral norm of the composition’s gradient is the spectral norm of this product, which is at most the product of the individual spectral norms.

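The above can be broken down into a final bound. For two functions, ‖g ∘ f‖_Lip ≤ ‖g‖_Lip · ‖f‖_Lip, so for a critic D built from linear layers with weights W₁, …, W_L and 1-Lipschitz activations in between:

$$\lVert D \rVert_{\text{Lip}} \le \prod_{l=1}^{L} \sigma(W_l)$$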

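Common activations such as ReLU and leaky ReLU are themselves 1-Lipschitz, so if we fix the spectral norm of every linear layer to 1, the Lipschitz constant of the whole composition is at most 1, and the critic satisfies the constraint required by the Kantorovich-Rubinstein duality.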

Now, the final thing left to do is to make sure every linear layer’s spectral norm is bounded. This can be done simply by replacing each weight matrix W with W/σ(W), where σ(W) is the largest singular value of W (its spectral norm).
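To see the whole argument in action, the sketch below builds a hypothetical two-layer critic (2x2 linear layer, ReLU, 2x2 linear layer) and measures the largest ratio ‖D(x) − D(y)‖ / ‖x − y‖ over random pairs. Before normalization it stays below σ(W₁)·σ(W₂); after dividing each weight matrix by its spectral norm it stays below 1. It uses the exact 2x2 formula for σ rather than power iteration, and a vector-valued output only to keep the example small; a real critic ends in a single scalar.

```python
import math
import random


def spectral_norm_2x2(a):
    """Exact largest singular value of a 2x2 matrix via the eigenvalues of A^T A."""
    (a11, a12), (a21, a22) = a
    p = a11 * a11 + a21 * a21
    q = a11 * a12 + a21 * a22
    r = a12 * a12 + a22 * a22
    return math.sqrt((p + r) / 2.0 + math.sqrt(((p - r) / 2.0) ** 2 + q * q))


def normalize(a):
    """Spectral normalization: W / sigma(W), which has spectral norm 1."""
    sigma = spectral_norm_2x2(a)
    return [[v / sigma for v in row] for row in a]


def critic(x, w1, w2):
    """Two linear layers with a ReLU (a 1-Lipschitz activation) in between."""
    h = [max(0.0, w1[i][0] * x[0] + w1[i][1] * x[1]) for i in range(2)]
    return [w2[i][0] * h[0] + w2[i][1] * h[1] for i in range(2)]


def worst_ratio(w1, w2, trials=2000, seed=0):
    """Largest observed ||D(x) - D(y)|| / ||x - y|| over random pairs of points."""
    rng = random.Random(seed)
    worst = 0.0
    for _ in range(trials):
        x = [rng.uniform(-2, 2), rng.uniform(-2, 2)]
        y = [rng.uniform(-2, 2), rng.uniform(-2, 2)]
        dx = math.dist(x, y)
        if dx > 0:
            worst = max(worst, math.dist(critic(x, w1, w2), critic(y, w1, w2)) / dx)
    return worst


W1 = [[2.0, 1.0], [0.0, 3.0]]
W2 = [[1.5, -0.5], [0.5, 2.0]]
bound = spectral_norm_2x2(W1) * spectral_norm_2x2(W2)
print(worst_ratio(W1, W2), "<=", bound)
print(worst_ratio(normalize(W1), normalize(W2)), "<= 1")
```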

Computing σ(W) exactly at every step would be expensive, so it is estimated with a nifty trick called power iteration, which after some manipulation boils down to the simple equation:

$$\sigma(W) \approx u^\top W v$$

where u and v are unit vectors in the codomain and domain of W, respectively, that approximate its top left and right singular vectors. This is great because we only need to keep a vector u for each weight matrix and update u and v with a single power-iteration step per training step, reusing u from the previous step. This is computationally inexpensive and makes this implementation much faster than WGAN-GP.

Here is a great algorithmic depiction of a spectral norm implementation:
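As a complement to that algorithm, here is a minimal pure-Python sketch of power iteration in the form described by Miyato et al. (2018), “Spectral Normalization for Generative Adversarial Networks”, which matches the steps above. The weight matrix and helper names are hypothetical. In practice PyTorch provides this as torch.nn.utils.spectral_norm, which wraps a layer and by default performs one power-iteration step per forward pass in training mode.

```python
import math
import random


def matvec(w, x):
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def matvec_t(w, y):
    """Computes W^T y without building the transpose."""
    return [sum(w[i][j] * y[i] for i in range(len(w))) for j in range(len(w[0]))]


def l2_normalize(v, eps=1e-12):
    n = math.sqrt(sum(x * x for x in v))
    return [x / (n + eps) for x in v]


def power_iteration_step(w, u):
    """One step: v = W^T u / ||W^T u||, u = W v / ||W v||, sigma ~= u^T W v."""
    v = l2_normalize(matvec_t(w, u))
    wv = matvec(w, v)
    u = l2_normalize(wv)
    sigma = sum(ui * wvi for ui, wvi in zip(u, wv))
    return u, sigma


def spectral_normalize(w, u):
    """Returns (W / sigma(W), updated u); u is kept between training steps."""
    u, sigma = power_iteration_step(w, u)
    return [[wij / sigma for wij in row] for row in w], u


W = [[2.0, 1.0, 0.0], [0.0, 3.0, 1.0]]  # 2x3 weight: domain R^3, codomain R^2
rng = random.Random(0)
u = l2_normalize([rng.gauss(0, 1) for _ in range(len(W))])  # random start in the codomain

for _ in range(20):
    # During training, one such step runs per optimizer update, reusing u.
    W_sn, u = spectral_normalize(W, u)

_, sigma = power_iteration_step(W, u)
print(sigma)  # close to the exact value sqrt(7.5 + sqrt(15.25)), about 3.377
```

Because u is carried over from one training step to the next and the weights change only slightly per step, a single iteration per step is usually enough to keep the estimate accurate.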

Final Thoughts

All in all, spectral normalization is the better of the two methods for training WGANs.

In terms of image quality, the two GANs performed very similarly, producing images of comparable quality at each epoch.

However, for 4700 iterations, the spectral norm implementation took only 25 minutes, whereas the gradient penalty implementation took 3 hours and 35 minutes. This makes sense given the extra overhead of the gradient penalty, which needs additional gradient computations at every critic update.
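Putting those timings on a per-iteration basis makes the gap easier to compare. This is simple arithmetic on the numbers reported above, nothing re-measured; the variable names are illustrative.

```python
iterations = 4700
sn_minutes = 25
gp_minutes = 3 * 60 + 35  # 3 h 35 min = 215 min

sn_sec_per_iter = sn_minutes * 60 / iterations
gp_sec_per_iter = gp_minutes * 60 / iterations
speedup = gp_minutes / sn_minutes

print(f"SN: {sn_sec_per_iter:.2f} s/iter, GP: {gp_sec_per_iter:.2f} s/iter, speedup: {speedup:.1f}x")
# SN: 0.32 s/iter, GP: 2.74 s/iter, speedup: 8.6x
```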

One thing to note, however, is that in the gradient penalty implementation the critic’s loss tracked the quality of the generated images, whereas with spectral normalization it did not. This is a nice property, because it gives you a way to tell whether your GAN is still improving.


Brad Brown

Software Engineer at the University of Waterloo. AI Enthusiast