Writing a deep learning repo #3
Continuing my series of blog posts on writing generic DL code for advanced deep learning architectures, this post focuses mainly on the Wasserstein GAN and how we can train it well. It is shorter and more concise than the earlier posts, and mainly tries to show how to train more complicated GAN models without running into mode collapse.
I highly recommend going through my blog post on writing GAN code before you go ahead and write the code for Energy-Based GANs or Wasserstein GANs.
NN graph and more
The key difference between a Wasserstein GAN and a standard GAN lies in what the discriminator outputs. As mentioned before, the standard (original) GAN learns by optimizing the probability with which the discriminator predicts whether a given input sample is real or fake. In the Wasserstein GAN, instead of squashing the output into a single number that is read as that probability, we keep a single unbounded number (or a higher-dimensional vector) that effectively acts as a representation of the distribution the sample came from. The discriminator then tries to maximize the distance between this representation for samples drawn from the target distribution and for samples drawn from the fake distribution, while the generator tries to minimize that distance by fooling the discriminator.
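Written out, the game described above takes roughly the following form. The symbols are assumptions made for this sketch and are not taken from the code: D is the discriminator, G the generator, P_r the target distribution, p(z) the noise prior, and the maximum runs over 1-Lipschitz functions D.

```latex
\min_{G} \; \max_{\|D\|_{L} \le 1} \;
  \mathbb{E}_{x \sim P_r}\left[ D(x) \right]
  - \mathbb{E}_{z \sim p(z)}\left[ D(G(z)) \right]
```

The first expectation is the mean score on real samples and the second is the mean score on generated samples, so the discriminator widens the gap between them while the generator tries to close it.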
So we can change the discriminator as follows:
The only key difference is that we still produce a single output value but use it without a sigmoid layer. The output is therefore no longer interpreted as a probability, only as a raw score.
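To make the change concrete, here is a minimal NumPy sketch of a discriminator whose last layer is a plain linear output. It is not the code from this repo: the two-layer shape, the leaky ReLU, and the names init_critic and critic_forward are all assumptions made for illustration.

```python
import numpy as np

def init_critic(in_dim, hidden_dim, rng):
    """Create weights for a tiny two-layer discriminator (hypothetical helper)."""
    return {
        "W1": rng.normal(0.0, 0.02, size=(in_dim, hidden_dim)),
        "b1": np.zeros(hidden_dim),
        "W2": rng.normal(0.0, 0.02, size=(hidden_dim, 1)),
        "b2": np.zeros(1),
    }

def leaky_relu(x, alpha=0.2):
    """Leaky ReLU; the slope 0.2 is an assumed value."""
    return np.where(x > 0, x, alpha * x)

def critic_forward(params, x):
    """Return one unbounded score per sample. Note: no sigmoid at the end."""
    hidden = leaky_relu(x @ params["W1"] + params["b1"])
    return (hidden @ params["W2"] + params["b2"]).ravel()

rng = np.random.default_rng(0)
params = init_critic(in_dim=784, hidden_dim=128, rng=rng)
batch = rng.normal(size=(16, 784))   # stand-in for 16 flattened images
scores = critic_forward(params, batch)
print(scores.shape)                  # one raw score per sample
```

Because nothing squashes the output, the scores can fall anywhere on the real line rather than inside [0, 1].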
We also replace the loss function module (earlier the class cross-entropy module) with a loss generator module, which simply calculates the mean of the required values.
With these pieces we can update the loss function from the original GAN, which gives us an alternate build_model function.
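As a rough sketch of what "only the mean" amounts to, the two losses can be written in a few lines of NumPy. The function names critic_loss and generator_loss are hypothetical, and the sign convention (both losses are minimized) is an assumption of this sketch.

```python
import numpy as np

def critic_loss(real_scores, fake_scores):
    """Loss the discriminator minimizes: mean(fake) - mean(real).

    Minimizing it pushes real scores up and fake scores down,
    which widens the gap between the two distributions.
    """
    return np.mean(fake_scores) - np.mean(real_scores)

def generator_loss(fake_scores):
    """Loss the generator minimizes: -mean(fake).

    Minimizing it raises the scores of generated samples.
    """
    return -np.mean(fake_scores)

real_scores = np.array([2.0, 3.0, 1.0])    # mean 2.0
fake_scores = np.array([-1.0, 0.5, 2.0])   # mean 0.5
print(critic_loss(real_scores, fake_scores))  # -1.5
print(generator_loss(fake_scores))            # -0.5
```

There is no logarithm and no cross entropy anywhere; the raw scores are averaged directly.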
But there is one problem with this. In a standard GAN the last layer is a sigmoid, whose slope is bounded (it peaks at 0.25 at an input of 0), so the gradients flowing back through it stay within a controlled range; this helps us train better and avoid mode collapse. This is not the case with Wasserstein GANs. The gradients being backpropagated may be much larger or much smaller than 1, which makes the models quite difficult to train.
Therefore it is important to add another term to the loss function that acts as a gradient penalty. To compute it, we interpolate between the original image and the generated image using a random weight epsilon, and measure the gradient of the discriminator output with respect to this interpolated image. We then add a loss term that pushes the norm of this gradient towards 1, which in effect imposes the condition that the discriminator is 1-Lipschitz. This matters because the Kantorovich-Rubinstein duality, which lets us write the Wasserstein distance as a maximum over such functions, holds only for 1-Lipschitz functions; these are always continuous, though not necessarily differentiable everywhere, and that is enough for us to train with them.
Therefore the real build_model function should be:
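The gradient penalty term itself can be sketched as follows. This is an illustration under stated assumptions, not the repo's build_model: the discriminator is a tiny two-layer leaky-ReLU network whose input gradient is written out by hand (a real framework would get it from automatic differentiation), the weight lam=10.0 is a commonly used value rather than one given in this post, and the names critic_input_grad and gradient_penalty are hypothetical.

```python
import numpy as np

ALPHA = 0.2  # leaky ReLU slope, an assumed value

def critic_input_grad(params, x):
    """Analytic d(score)/d(x) for a two-layer leaky-ReLU discriminator."""
    pre = x @ params["W1"] + params["b1"]
    slope = np.where(pre > 0, 1.0, ALPHA)          # derivative of leaky ReLU
    return (slope * params["W2"].ravel()) @ params["W1"].T

def gradient_penalty(params, real, fake, rng, lam=10.0):
    """lam * mean((||grad at x_hat||_2 - 1)^2) on random interpolates x_hat."""
    eps = rng.uniform(size=(real.shape[0], 1))     # one epsilon per sample
    x_hat = eps * real + (1.0 - eps) * fake        # point between real and fake
    grads = critic_input_grad(params, x_hat)
    norms = np.sqrt(np.sum(grads ** 2, axis=1) + 1e-12)
    return lam * np.mean((norms - 1.0) ** 2)

rng = np.random.default_rng(0)
params = {
    "W1": np.eye(2), "b1": np.zeros(2),
    "W2": np.array([[3.0], [4.0]]), "b2": np.zeros(1),
}
real = rng.uniform(1.0, 2.0, size=(8, 2))
fake = rng.uniform(1.0, 2.0, size=(8, 2))
# All inputs are positive, so the gradient is (3, 4) everywhere: norm 5,
# and the penalty is about 10 * (5 - 1)**2 = 160.
gp = gradient_penalty(params, real, fake, rng)
print(gp)
```

The penalty is added to the discriminator loss only; the generator loss stays the plain negative mean score.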
With this done, our model is complete for the Wasserstein GAN setting.
Training the Wasserstein GAN
The Wasserstein GAN is more powerful and more stable than the standard GAN, but that does not mean it never undergoes mode collapse, even though the approach is more dynamic and robust. Most of this can be avoided with a few hacks that help us train it better.
The easiest hack is to train the discriminator more times than the generator, as we did in the GAN setting. Another hack is to occasionally boost the training of the discriminator: learning at later stages is quite slow and weak, so a larger amount of discriminator training will not hurt the generator and actually helps it train faster.
The last hack, which is more of a regularizing hack, is that after a point we can start creating adversarial examples from our own runs of the model. The epsilon perturbation gives us an alternate image, and after a point that image is about as close to (or as far from) the real image as the generated image is. We can therefore add it to the dataset, automatically creating more training data from the examples available to us. This also helps us regularize better, since during the initial epochs of training these examples would be quite far from the original (or true) distribution.
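The first two hacks can be sketched as a small schedule that decides how many discriminator updates run before each generator update. The helper names and all the numbers (n_critic=5, a boost of 100 steps every 500 generator iterations) are assumptions chosen for illustration, not values taken from this post.

```python
def critic_steps(gen_iter, n_critic=5, boost_every=500, boost_steps=100):
    """Number of discriminator updates to run before generator update gen_iter."""
    if gen_iter > 0 and gen_iter % boost_every == 0:
        return boost_steps      # occasional boost of the discriminator
    return n_critic             # the usual "train the discriminator more" ratio

def training_schedule(num_gen_iters, **kwargs):
    """Yield ('critic', i) and ('generator', i) events in the order they would run."""
    for i in range(num_gen_iters):
        for _ in range(critic_steps(i, **kwargs)):
            yield ("critic", i)
        yield ("generator", i)

# Tiny example: 2 discriminator steps per generator step, boosted to 4 every 2nd iteration.
events = list(training_schedule(3, n_critic=2, boost_every=2, boost_steps=4))
print(sum(1 for kind, _ in events if kind == "critic"))     # 8
print(sum(1 for kind, _ in events if kind == "generator"))  # 3
```

In an actual training loop each "critic" event would be one discriminator optimizer step on a fresh batch, and each "generator" event one generator step.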
Due to lack of space, I did not save the intermediate training stages. But as you can see, the Wasserstein GAN is able to create near-real numbers in all orientations, and it learns to do this in a completely unsupervised manner.