Simple GAN using PyTorch

Nikolaj Goodger
2 min read · Aug 30, 2020


Generated MNIST digits after 20 epochs

Generative Adversarial Networks (GANs) are a super cool way to model a high-dimensional distribution using deep neural networks, and they can famously generate photorealistic images. For this implementation, though, my goal was something generic and robust that can be easily understood and adapted to other use cases.

This implementation borrows heavily from DCGAN, specifically the PyTorch DCGAN Tutorial. GANs that use the original discriminator loss function, as DCGAN does, can be difficult to train and suffer from undesirable behavior such as mode collapse (where the GAN loses the ability to model part or all of the training data distribution). Significant research has gone into mitigating these issues, and one improvement that has come out of it is the Wasserstein GAN. It replaces the discriminator with a critic that scores how real an image looks rather than classifying it as real or fake. This is achieved with the Wasserstein loss, which was later improved by the addition of a Gradient Penalty. This loss function has been used to produce some pretty amazing results, such as the high-resolution celebrity photos generated by Progressive GAN.
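To make the critic idea concrete, here is a minimal sketch of the WGAN-GP critic loss in PyTorch. The networks and names here (`critic`, `generator`, `real`, `noise`) are illustrative stand-ins rather than the exact models used in the implementation, and `lambda_gp = 10` is the penalty weight suggested in the WGAN-GP paper:

```python
import torch
import torch.nn as nn

# Tiny stand-in networks just so the sketch runs; the real models are deeper.
critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
generator = nn.Sequential(nn.Linear(64, 28 * 28), nn.Tanh(), nn.Unflatten(1, (1, 28, 28)))

def gradient_penalty(critic, real, fake):
    # Random interpolation between real and fake images
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=real.device)
    interpolated = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)

    # Critic score for each interpolated image
    scores = critic(interpolated)

    # Gradient of the scores with respect to the interpolated images
    gradients = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # so the penalty itself can be backpropagated
    )[0]

    # Penalize the gradient norm for deviating from 1
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# Critic update: push real scores up, fake scores down, plus the penalty
lambda_gp = 10
real = torch.randn(8, 1, 28, 28)  # stand-in for a batch of MNIST images
noise = torch.randn(8, 64)
fake = generator(noise)
loss_critic = (critic(fake.detach()).mean()
               - critic(real).mean()
               + lambda_gp * gradient_penalty(critic, real, fake))
```

The generator's loss is then simply `-critic(fake).mean()`: it tries to push the critic's score for generated images up.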

To keep things simple and relatable, the MNIST dataset was chosen, but it should be straightforward to swap in another dataset and generate new classes of images, as sketched below. The implementation currently requires a GPU, as training on a CPU would be very slow, at least with the configured model size. I hope you find it useful!
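If you do want to try a different dataset, the change is mostly confined to the data loading. Here is a rough sketch using torchvision (the batch size and normalization values are illustrative; match them to whatever the model expects):

```python
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # scale grayscale pixels to [-1, 1]
])

# Swap MNIST for e.g. FashionMNIST or CIFAR10; if the new dataset has a
# different image size or channel count, the models must be updated to match.
dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
```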

There is also a conditional implementation available.
