WGAN and WGAN-GP

KION KIM
Aug 23, 2018


GANs are notorious for their instability during training. There are two main streams of research to address this issue: one is to find an architecture that converges better, and the other is to fix the loss function, which is considered one of the primary reasons for the instability. WGAN belongs to the latter: it defines a new loss function based on a different distance measure between two distributions, called the Wasserstein distance. In the original (vanilla) GAN paper, it was shown that, when the discriminator is optimal, the adversarial pair of loss functions (for the discriminator and the generator) amounts to minimizing the Jensen-Shannon divergence between the real and generated distributions. For more detailed information about GAN, please refer to Introduction to WGAN.

Wasserstein distance

A crucial disadvantage of metrics based on the KL (Kullback-Leibler) divergence (the Jensen-Shannon divergence is a symmetrized, smoothed version of KL, obtained by averaging the KL divergence of each distribution to their mixture) is that they break down when the distributions do not share the same support. If one distribution puts mass where the other has none, the KL divergence explodes to infinity, and if the supports are disjoint, the Jensen-Shannon divergence stays at the constant log 2, so neither can represent how far apart the distributions really are. The WGAN paper has a nice illustration of this, and if you need a more detailed explanation, you can read this post.
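The effect is easy to reproduce numerically. The pure-Python sketch below (an illustration, not code from the original post) places a point mass at 0 for the "real" distribution and a point mass at `shift` for the "generated" one, similar in spirit to the WGAN paper's example. KL is infinite and JS stays at log 2 no matter how large the shift is, while the 1-Wasserstein distance, computed here with the one-dimensional formula (area between the two CDFs), grows with the shift. The helper names `kl`, `js`, and `w1_1d` are illustrative.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as dicts {point: probability}."""
    total = 0.0
    for x, px in p.items():
        if px == 0:
            continue
        qx = q.get(x, 0.0)
        if qx == 0:
            return math.inf  # p puts mass where q has none
        total += px * math.log(px / qx)
    return total

def js(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture m."""
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1_1d(p, q):
    """1-Wasserstein distance on the real line: integral of |CDF_p - CDF_q|."""
    points = sorted(set(p) | set(q))
    dist, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for a, b in zip(points, points[1:]):
        cdf_p += p.get(a, 0.0)
        cdf_q += q.get(a, 0.0)
        dist += abs(cdf_p - cdf_q) * (b - a)
    return dist

real = {0.0: 1.0}
for shift in [0.5, 1.0, 5.0]:
    fake = {shift: 1.0}
    # KL is infinite, JS is always log 2, W1 equals the shift
    print(shift, kl(real, fake), js(real, fake), w1_1d(real, fake))
```

Only the Wasserstein column tells a generator which way to move, which is the property WGAN builds on.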

This is not a big problem in classification tasks: an entropy-based loss for a categorical response has a limited number of categories, and the ground-truth distribution and its estimator share the same support. It is a totally different story for generation tasks, since we need to generate a small manifold inside a huge data space. Needless to say, it is very hard for tiny manifolds like these to share their support. Let's think about MNIST. Assuming 8-bit grayscale images, each image lives in a 784-dimensional space with 256⁷⁸⁴ possible elements, but the training set at hand contains just 60,000 images. I cannot say precisely how many, but meaningful images that look like handwritten digits are extremely rare in the entire space of 28 × 28 images.
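To get a feel for the scale, the short computation below counts the decimal digits of 256⁷⁸⁴ (8-bit grayscale, 28 × 28 pixels) and the base-10 logarithm of the fraction of that space covered by 60,000 training images. It only illustrates the size of the space; it says nothing about how many of those images actually look like digits.

```python
import math

pixels = 28 * 28                  # 784 pixels per MNIST image
levels = 256                      # 8-bit grayscale intensities per pixel
n_images = levels ** pixels       # exact integer: number of possible images
digits = len(str(n_images))       # number of decimal digits of that count
train_size = 60_000

# log10(train_size / n_images), computed in log space to avoid underflow
log10_fraction = math.log10(train_size) - pixels * math.log10(levels)
print(digits, log10_fraction)
```

The count has 1,889 digits, and the training set covers roughly a 10⁻¹⁸⁸³ fraction of the space.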

It is hard for a small amount of data to overlap in a huge space

The Wasserstein distance can measure how far apart two distributions are even when they do not share their supports. It is definitely a very good approach, but calculating the Wasserstein distance is not easy, since it involves another optimization problem itself. The original definition of the p-th Wasserstein distance between two probability measures P_r and P_θ on a metric space (M, d), where M is a set and d is a metric defined on M, is given below.
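For reference, the standard form of this definition is shown below, writing Π(P_r, P_θ) for the set of all joint distributions (couplings) γ on M × M whose marginals are P_r and P_θ:

```latex
W_p(\mathbb{P}_r, \mathbb{P}_\theta)
  = \left( \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_\theta)}
      \mathbb{E}_{(x, y) \sim \gamma}\big[ d(x, y)^p \big] \right)^{1/p}
```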

Basically, to get the distance between two probability measures, we need to consider all possible joint distributions (couplings) on M × M whose marginals are the two measures, and find the one that minimizes the expected distance between paired points. Searching the entire space of joint distributions defined on M × M is a big challenge.
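To make the size of that search concrete, here is a small pure-Python sketch for a special case: two empirical distributions with the same number n of equally weighted points on the real line, with d(x, y) = |x - y| and p = 1. In that case an optimal coupling can be taken to be a one-to-one matching (a consequence of the Birkhoff-von Neumann theorem), so a brute-force search has to try all n! matchings. In one dimension, simply pairing the sorted samples is known to be optimal, which gives a shortcut to check against. The function names are illustrative.

```python
import itertools

def w1_bruteforce(xs, ys):
    """Try every one-to-one matching (n! of them) and keep the cheapest."""
    n = len(xs)
    best = float("inf")
    for perm in itertools.permutations(range(n)):
        cost = sum(abs(xs[i] - ys[perm[i]]) for i in range(n)) / n
        best = min(best, cost)
    return best

def w1_sorted(xs, ys):
    """1-D shortcut: matching sorted samples is an optimal coupling."""
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

xs = [0.0, 1.0, 3.0]
ys = [2.0, 5.0, 4.0]
print(w1_bruteforce(xs, ys), w1_sorted(xs, ys))  # both 7/3
```

With 3 points there are only 6 matchings, but with 60,000 samples the count is astronomically large, and in high dimensions there is no sorting shortcut.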

The Kantorovich-Rubinstein duality turns the original optimization problem into a much simpler maximization problem under a certain constraint. Here is another definition of the Wasserstein distance (for p = 1, the Earth mover's distance) derived from the Kantorovich-Rubinstein duality.
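For p = 1, the standard form of the dual is shown below, where ‖f‖_L ≤ 1 means that f: M → ℝ is 1-Lipschitz. Allowing K-Lipschitz functions instead simply scales the result by K, which is why a K-Lipschitz critic estimates the distance only up to that constant.

```latex
W_1(\mathbb{P}_r, \mathbb{P}_\theta)
  = \sup_{\|f\|_L \le 1}
    \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_\theta}[f(x)],
\qquad
\sup_{\|f\|_L \le K}
    \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_\theta}[f(x)]
  = K \cdot W_1(\mathbb{P}_r, \mathbb{P}_\theta)
```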

We no longer need to search the M × M space; we only need to search for a function defined on M that satisfies the Lipschitz condition.

The function f in the above figure is just an example. What we are going to do is search for a function f^* that maximizes the difference of expectations among all K-Lipschitz functions. The main idea of WGAN is that a neural network can approximate f^*, which gives an estimate of the Wasserstein distance (scaled by the constant K).
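Here is a small pure-Python sketch of the dual view, using two 3-point samples whose W1 can be computed by hand: matching sorted samples gives (|0 - 2| + |1 - 4| + |3 - 5|) / 3 = 7/3. Every 1-Lipschitz function yields a sample estimate of E_real[f] - E_fake[f] that is at most W1, and the best one reaches it. `looks_lipschitz` only checks the condition on a finite grid of points, so it is a sanity check rather than a proof; all names are illustrative.

```python
import math

def dual_objective(f, real, fake):
    """Sample estimate of E_real[f(x)] - E_fake[f(x)]."""
    return sum(map(f, real)) / len(real) - sum(map(f, fake)) / len(fake)

def looks_lipschitz(f, points, k=1.0, tol=1e-12):
    """Check |f(a) - f(b)| <= k * |a - b| on every pair of grid points."""
    return all(abs(f(a) - f(b)) <= k * abs(a - b) + tol
               for a in points for b in points)

real = [0.0, 1.0, 3.0]
fake = [2.0, 4.0, 5.0]
grid = [i / 10 for i in range(-20, 81)]

candidates = {
    "zero": lambda x: 0.0,          # trivially 1-Lipschitz, objective 0
    "sin": math.sin,                # 1-Lipschitz, but not optimal here
    "neg_identity": lambda x: -x,   # 1-Lipschitz and optimal: reaches 7/3
}
for name, f in candidates.items():
    print(name, looks_lipschitz(f, grid), dual_objective(f, real, fake))
```

Since every fake point here sits to the right of its matched real point, f(x) = -x attains the maximum. A WGAN critic plays the role of f, with a neural network instead of a hand-picked formula.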

Since WGAN gives a real-valued distance between the distributions of real and generated data, it can be thought of as a more flexible version of GAN, which just says yes or no to the question “Are the two distributions the same?”.

Critic vs Discriminator

WGAN introduces a new concept called the ‘critic’, which corresponds to the discriminator in GAN. As briefly mentioned above, the discriminator in GAN only tells whether an incoming sample is fake or real, and as training goes on it becomes more accurate at making such decisions. In contrast, the critic in WGAN outputs an unbounded real-valued score and tries to approximate the optimal Lipschitz function f^* as tightly as possible, so that it measures the Wasserstein distance more accurately. This is done by updating the critic network under the constraint that it satisfies the Lipschitz continuity condition.
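The difference shows up directly in the losses. The sketch below is a hedged pure-Python illustration, not the author's Gluon code: it contrasts a vanilla GAN discriminator loss (binary cross-entropy on sigmoid outputs) with a WGAN critic loss on raw, unbounded scores. The critic loss is written as the negative of the mean difference, so that a minimizing optimizer maximizes the distance estimate. Function names are illustrative.

```python
import math

def softplus(u):
    """Numerically stable log(1 + exp(u))."""
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))

def gan_discriminator_loss(real_logits, fake_logits):
    """Vanilla GAN: binary cross-entropy, real labelled 1, fake labelled 0."""
    real_term = sum(softplus(-l) for l in real_logits) / len(real_logits)  # -log sigmoid(l)
    fake_term = sum(softplus(l) for l in fake_logits) / len(fake_logits)   # -log(1 - sigmoid(l))
    return real_term + fake_term

def wgan_critic_loss(real_scores, fake_scores):
    """WGAN critic: raw scores, no sigmoid; minimize -(mean real - mean fake)."""
    mean_real = sum(real_scores) / len(real_scores)
    mean_fake = sum(fake_scores) / len(fake_scores)
    return -(mean_real - mean_fake)

for s in [5.0, 50.0]:
    # the GAN loss saturates near 0, the critic loss keeps decreasing
    print(s, gan_discriminator_loss([s], [-s]), wgan_critic_loss([s], [-s]))
```

Because the critic's scores are unbounded, something must stop them from growing forever; that is exactly the job of the Lipschitz constraint.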

If you look at the final algorithms, GAN and WGAN look very similar from an algorithmic point of view, but their intuitions are as different as a variational autoencoder is from an autoencoder. One fascinating thing is that the derived loss function is even simpler than that of the original GAN algorithm: it is just the difference between two averages. The following equation is taken from the algorithm in the original WGAN paper.
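In the WGAN paper's notation, with critic f_w (weights w), generator g_θ, real samples x^(i), noise samples z^(i), and minibatch size m, the critic gradient in Algorithm 1 is the gradient of exactly that difference of two averages:

```latex
g_w \leftarrow \nabla_w \left[ \frac{1}{m} \sum_{i=1}^{m} f_w\big(x^{(i)}\big)
  - \frac{1}{m} \sum_{i=1}^{m} f_w\big(g_\theta(z^{(i)})\big) \right]
```

The critic then takes an ascent step w ← w + α · RMSProp(w, g_w) and clips w to [-c, c].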

Critic implementation

The entire algorithm is given below, with the critic implementation highlighted in a pink box. Given a batch of real data, the algorithm first compares it with a batch of generated images. To get a more accurate distance, it runs several critic updates (n_critic steps in the paper) so that the critic ends up close to the maximum difference of expectations between real and fake data, which is the Wasserstein distance. It may fail to find the exact distance, but we want to get as close as possible.

The relevant part of the implementation looks like this (it is written in Gluon).

According to the definition of the Wasserstein distance, we need to maximize the difference between the expectations under the two distributions. Since Gluon's built-in optimizers minimize, we defined the cost as the negative of the value we want to maximize.

At the end of each critic update step, the weights are clipped to a fixed box [-c, c], so that the function the critic network represents cannot violate the Lipschitz condition (it stays K-Lipschitz for some K that depends on c and the architecture). The authors did not like this heuristic, though; the WGAN paper itself calls weight clipping a clearly terrible way to enforce a Lipschitz constraint.
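As a minimal illustration of why clipping bounds the Lipschitz constant, the sketch below clips a flat list of weights (c = 0.01, the WGAN paper's default) and considers the simplest possible critic, a hypothetical linear one f(x) = w · x. Its Lipschitz constant in the L2 norm is ‖w‖₂, which clipping bounds by c·√d for d weights. For a real multi-layer critic the bound is looser and depends on the architecture, but the mechanism is the same.

```python
import math

def clip_weights(weights, c=0.01):
    """Clamp every weight into [-c, c], as WGAN does after each critic step."""
    return [max(-c, min(c, w)) for w in weights]

def linear_lipschitz_bound(weights):
    """For a linear critic f(x) = w . x, the L2 Lipschitz constant is ||w||_2."""
    return math.sqrt(sum(w * w for w in weights))

w = clip_weights([0.5, -0.003, -2.0, 0.007])
print(w, linear_lipschitz_bound(w))  # bounded by 0.01 * sqrt(4) = 0.02
```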

Since the first term of the Wasserstein distance estimate, the expectation over real data, does not involve the generator network's parameters θ, we can ignore it when updating the generator.

Considering only the second term, we can update the generator network as follows:
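In the WGAN paper's Algorithm 1 notation (critic f_w, generator g_θ, noise samples z^(i), minibatch size m, learning rate α), the generator step keeps only the second term and descends along its gradient with RMSProp:

```latex
g_\theta \leftarrow -\nabla_\theta \frac{1}{m} \sum_{i=1}^{m} f_w\big(g_\theta(z^{(i)})\big),
\qquad
\theta \leftarrow \theta - \alpha \cdot \mathrm{RMSProp}(\theta, g_\theta)
```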

The entire code can be found in the git repository.
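To see the whole loop without Gluon, here is a toy pure-Python sketch of the WGAN algorithm under heavy simplifying assumptions. Real data is 1-D and drawn from N(3, 1), the generator is g_θ(z) = z + θ with a single parameter, the critic is linear, f_w(x) = w · x, and plain gradient steps stand in for RMSProp. n_critic = 5 and c = 0.01 follow the paper's defaults; the other names and hyperparameters are illustrative and not taken from the repository.

```python
import random

def train_toy_wgan(real_mean=3.0, steps=1500, n_critic=5, clip=0.01,
                   lr_critic=0.05, lr_gen=0.5, batch=64, seed=0):
    """Toy 1-D WGAN: generator g(z) = z + theta, linear critic f(x) = w * x."""
    rng = random.Random(seed)
    theta, w = 0.0, 0.0
    for _ in range(steps):
        # critic: several ascent steps on mean f(real) - mean f(fake)
        for _ in range(n_critic):
            real = [rng.gauss(real_mean, 1.0) for _ in range(batch)]
            fake = [rng.gauss(0.0, 1.0) + theta for _ in range(batch)]
            grad_w = sum(real) / batch - sum(fake) / batch  # d/dw of the objective
            w += lr_critic * grad_w                         # gradient ascent
            w = max(-clip, min(clip, w))                    # weight clipping
        # generator: descend on -mean f(g(z)); its gradient w.r.t. theta is -w
        theta -= lr_gen * (-w)
    return theta, w

theta, w = train_toy_wgan()
print(theta, w)  # theta should end up near the real mean of 3
```

With a linear critic, clipping keeps it |w|-Lipschitz with |w| ≤ c, so the critic's objective is at most c times the distance between the two means. The generator simply follows the sign of w toward the real data.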

Penalization

In the original WGAN, the Lipschitz constraint was enforced by weight clipping, which left obvious room for improvement. Instead, the authors of Improved Training of Wasserstein GANs proposed a penalty on the norm of the critic's gradient with respect to its input, evaluated at points sampled on straight lines between real and generated samples, which pushes that norm toward 1. Since a differentiable function is 1-Lipschitz exactly when its gradient norm is at most 1 everywhere, this is a natural, soft way to make the critic network satisfy the Lipschitz condition. The following code shows the “penalty part” of the entire implementation.
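As a hedged stand-in for the penalty code, the sketch below computes the WGAN-GP penalty for a hypothetical tiny critic f(x) = Σ_j a_j tanh(b_j · x + c_j). Its input gradient can be written by hand here, whereas a real implementation would get it from automatic differentiation. Following the WGAN-GP paper, x̂ is sampled uniformly on the segment between a real and a generated sample, and the penalty is λ(‖∇f(x̂)‖₂ - 1)² with λ = 10. The critic's full loss is then mean f(fake) - mean f(real) + penalty.

```python
import math
import random

def critic(x, params):
    """Hypothetical critic f(x) = sum_j a_j * tanh(b_j . x + c_j)."""
    return sum(a * math.tanh(sum(bi * xi for bi, xi in zip(b, x)) + c)
               for a, b, c in params)

def critic_input_grad(x, params):
    """Analytic gradient of the critic with respect to its input x."""
    grad = [0.0] * len(x)
    for a, b, c in params:
        t = math.tanh(sum(bi * xi for bi, xi in zip(b, x)) + c)
        scale = a * (1.0 - t * t)          # derivative of a * tanh(u) w.r.t. u
        for k, bk in enumerate(b):
            grad[k] += scale * bk
    return grad

def gradient_penalty(real_batch, fake_batch, params, lam=10.0, rng=None):
    """lam * batch mean of (||grad_x f(x_hat)||_2 - 1)^2 at random interpolates."""
    rng = rng or random.Random()
    total = 0.0
    for x_real, x_fake in zip(real_batch, fake_batch):
        eps = rng.random()                 # eps ~ U[0, 1], one per sample
        x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(x_real, x_fake)]
        g = critic_input_grad(x_hat, params)
        norm = math.sqrt(sum(gi * gi for gi in g))
        total += (norm - 1.0) ** 2
    return lam * total / len(real_batch)

rng = random.Random(0)
real = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(8)]
fake = [[rng.gauss(2, 1), rng.gauss(2, 1)] for _ in range(8)]
params = [(0.5, [1.0, -2.0], 0.1), (-1.2, [0.3, 0.7], -0.4)]
print(gradient_penalty(real, fake, params, rng=rng))
```

Unlike clipping, the penalty does not restrict the weights directly; it only discourages input gradients whose norm strays from 1 along the paths between real and generated data.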

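Apart from replacing weight clipping with this penalty, the rest of the algorithm is essentially the same as that of WGAN, although the WGAN-GP paper also uses Adam instead of RMSProp and advises against batch normalization in the critic.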

Results and thoughts

After 400 epochs, I printed the generated images. Even after 400 epochs, I could not yet get perfect handwritten digit images.

In my experience, the two algorithms seem to be comparable. My personal feeling is that it is still very hard to generate images, even with WGAN and improved WGAN (WGAN-GP).
