PokeGAN: Generating Fake Pokemon with a Generative Adversarial Network
The Goal
As the title points out, the goal of this post is to walk through the process of creating fake Pokemon using a Generative Adversarial Network (GAN). Hopefully by the end of this post, there will be some serviceable Pokemon created by this system!
What is a GAN?
A GAN is a form of unsupervised learning in which two neural networks compete against each other. For the task of image generation, the first neural network tries to generate fake images from a seed of random numbers, or even from starter images. This neural network is known as the generator. The opponent: a neural network that learns to differentiate between real and fake images. This agent is called the discriminator.
What can GANs Do?
GANs are used to generate data; it's in the name. GANs can be used to make realistic images, such as human faces that aren't actually real, and they can augment pre-existing images. An open and interesting research application for GANs is using them to augment small datasets by generating more samples for another AI agent to train on. While GANs are often used with images, they can also generate text, audio, or other data. GANs have been a topic at the forefront of ethics in machine learning, especially due to their ability to generate "deepfakes": fabricated images, audio, or video that depict a real person doing or saying things they never did, often in ways that paint the subject in a bad light.
I will outline an approach for creating deep convolutional GANs, or DCGANs for short. All this means is that the neural networks I am using are convolutional neural networks, rather than standard linear networks. On to the fake Pokemon!
Data Processing
I am using a Kaggle dataset created by user kvpratama, with images of over 800 Pokemon. For GANs, ~800 images is a small dataset; during training of my early models, I was unable to get more than indistinct outlines of what could be Pokemon due to the small training set. To solve this problem, I used PyTorch transforms to create mirrored and color-shifted copies of the training images. This tripled the size of my training data and produced much better results.
One additional thing I could have done is sort the Pokemon into classes. Two routes would have made sense here: sorting the Pokemon by type (fire, water, flying, rock, etc.) or by dominant color. Since the dataset doesn't provide either label, I chose to hold off on that for a future version. Additionally, the dataset is already small, so segmenting it further would likely have cost more than the benefit of grouping similar-looking Pokemon.
I normalized my image data to the range -1 to 1 instead of 0 to 1. This helps with updating the weights for the generator.
Here’s a sample of a Pokemon batch:
One other thing that worked well for me was reducing the image size. The source images are 256x256, and loading them into Kaggle’s GPU caused me many memory headaches, as my networks had to be bigger to accommodate the larger images. By reducing the image size to 64x64, my training times went down and my results improved at the cost of slightly worse resolution. In the future, I might try doing 128x128 sized images to see if I can get the best of both worlds.
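Putting the data processing together, a minimal sketch of the pipeline could look like the following. The augmentation choices (horizontal flips and a hue shift), the batch size, and the dataset path are illustrative rather than exact:

```python
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import transforms as T
from torchvision.datasets import ImageFolder

image_size = 64
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)  # shifts pixel values to roughly [-1, 1]

def make_transform(extra=None):
    # Resize to 64x64, optionally apply one augmentation, then normalize to [-1, 1]
    steps = [T.Resize(image_size), T.CenterCrop(image_size)]
    if extra is not None:
        steps.append(extra)
    steps += [T.ToTensor(), T.Normalize(*stats)]
    return T.Compose(steps)

data_dir = "../input/pokemon-images-dataset"  # assumed Kaggle input path

# Original, mirrored, and color-shifted copies, roughly tripling the training data
dataset = ConcatDataset([
    ImageFolder(data_dir, transform=make_transform()),
    ImageFolder(data_dir, transform=make_transform(T.RandomHorizontalFlip(p=1.0))),
    ImageFolder(data_dir, transform=make_transform(T.ColorJitter(hue=0.5))),
])
train_dl = DataLoader(dataset, batch_size=64, shuffle=True,
                      num_workers=2, pin_memory=True)
```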
The Model
Discriminator
First up is the discriminator model. Here’s the PyTorch implementation of the architecture that worked the best for me:
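A sketch of the architecture along those lines, assuming 64x64 RGB inputs, a kernel size of 4, and a constant 128 filters throughout (the exact values are representative):

```python
import torch.nn as nn

discriminator = nn.Sequential(
    # 3 x 64 x 64 -> 128 x 32 x 32
    nn.Conv2d(3, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # 128 x 32 x 32 -> 128 x 16 x 16
    nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # 128 x 16 x 16 -> 128 x 8 x 8
    nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # 128 x 8 x 8 -> 128 x 4 x 4
    nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # 128 x 4 x 4 -> 1 x 1 x 1, no batch normalization at the output
    nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0, bias=False),
    nn.Flatten(),
    nn.Sigmoid(),
)
```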
This discriminator features 5 convolutional layers with LeakyReLU activations. Batch normalization is applied after each convolutional stage except at the output.
I chose a constant, high number of filters rather than ramping up to a high count and stopping. Constant filter counts consistently outperformed changing ones in my experiments, which is why the network jumps to 128 filters as soon as possible and stays there until the last layer.
If you’re interested in what worked well for 256x256 images, it was the first discriminator I tried:
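A sketch of a 256x256 discriminator in that spirit, assuming the channel counts ramp up to the 512 maximum mentioned below (the layer sizes here are illustrative):

```python
import torch.nn as nn

def disc_block(in_ch, out_ch):
    # Downsample by 2x: Conv -> BatchNorm -> LeakyReLU
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(0.2, inplace=True),
    )

# 3 x 256 x 256 -> 1 x 1 x 1, with channels ramping up to 512
discriminator_256 = nn.Sequential(
    disc_block(3, 64),      # 256 -> 128
    disc_block(64, 128),    # 128 -> 64
    disc_block(128, 256),   # 64 -> 32
    disc_block(256, 512),   # 32 -> 16
    disc_block(512, 512),   # 16 -> 8
    disc_block(512, 512),   # 8 -> 4
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),  # 4 -> 1
    nn.Flatten(),
    nn.Sigmoid(),
)
```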
I think this model could have performed even better without memory limits. With pictures containing 16x more pixels than the 64x64 version, I was only able to comfortably train this discriminator with up to 512 channels. That's only 4x the 128 channels I used in the 64x64 case, and I could have gone even higher there.
However, Kaggle is free and I’m relatively cheap. As long as some realistic Pokemon pops out, I’m happy.
Generator
The generator I used is similar to the discriminator I described above:
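A sketch of such a generator, assuming a 128-dimensional latent vector, kernel size 4, a constant 128 filters, and a Tanh output to match the [-1, 1] image normalization:

```python
import torch.nn as nn

latent_size = 128  # length of the random input vector (an assumed value)

generator = nn.Sequential(
    # latent_size x 1 x 1 -> 128 x 4 x 4
    nn.ConvTranspose2d(latent_size, 128, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    # 128 x 4 x 4 -> 128 x 8 x 8
    nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    # 128 x 8 x 8 -> 128 x 16 x 16
    nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    # 128 x 16 x 16 -> 128 x 32 x 32
    nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    # 128 x 32 x 32 -> 3 x 64 x 64, Tanh to match the [-1, 1] normalization
    nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
)
```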
This features transposed convolution layers, which upsample the random input vector into a full image. Unlike the discriminator, whose goal is to condense an image down to a single real-or-fake value, the generator takes in a random vector and converts it into an image. The transposed convolutions make that possible by progressively expanding the vector until it becomes a (hopefully) realistic-looking Pokemon.
Training the GAN
Hyperparameters
GANs are very sensitive to hyperparameter changes, and there are a lot of parameters to change. To name a few: batch size, input vector size, number of filters, kernel sizes, learning rates, number of training epochs. The list goes on, but in general I followed these rules (a sample configuration is sketched after the list):
- Use a large number of filters. This gives the convolutional layers more capacity to capture detail.
- Use smaller batch sizes. The discriminator can outlearn the generator quickly with large batch sizes.
- Keep the kernel sizes medium. I tried making the kernel size small and the results weren't good. I might try moving up to 5 or 7 in future models, but 4 worked well and made the layer-sizing math easy.
- Slightly high learning rates are better than conservative ones. Training near 0.003 outperformed 0.001 for my models, probably because the generator needs to learn quickly enough to keep up with the discriminator.
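For concreteness, here is a sample configuration following these rules. The values are representative rather than definitive, and the epoch count in particular is just a starting point to tune:

```python
# Representative hyperparameters following the rules above (illustrative values)
image_size = 64     # downsized from the original 256x256
batch_size = 64     # kept small so the discriminator doesn't outpace the generator
latent_size = 128   # length of the random input vector fed to the generator
n_filters = 128     # constant, relatively large filter count
kernel_size = 4     # medium kernels; 5 or 7 are candidates for future experiments
lr = 0.003          # slightly aggressive learning rate
epochs = 300        # number of training epochs (a guess; tune based on results)
```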
Tricks
Along with the rules above, I picked up some hot tips for training from this article. In particular, I found that using noisy labels instead of hard labels made a huge difference in my GAN's performance. I also flipped my labels (0 = real, 1 = fake) and noticed slight improvements. If that isn't intuitive to you, sticking with the usual 0 = fake, 1 = real convention won't cost you much.
Discriminator
Here’s the discriminator’s training function:
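A sketch of that function, using the flipped, noisy labels described above (the 0.1 noise range and the function signature are illustrative, and it assumes the `generator` and `discriminator` sketches from earlier):

```python
import torch
import torch.nn.functional as F

def train_discriminator(real_images, opt_d, batch_size, latent_size, device):
    """One discriminator update, using flipped, noisy labels (0 = real, 1 = fake)."""
    opt_d.zero_grad()

    # Score the real images; targets are noisy values near 0 ("real")
    real_preds = discriminator(real_images)
    real_targets = torch.rand(real_images.size(0), 1, device=device) * 0.1
    real_loss = F.binary_cross_entropy(real_preds, real_targets)

    # Generate a batch of fake images from random latent vectors
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Score the fakes; targets are noisy values near 1 ("fake")
    fake_preds = discriminator(fake_images)
    fake_targets = 1.0 - torch.rand(batch_size, 1, device=device) * 0.1
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)

    # Only the discriminator's weights are updated here
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()

    return loss.item(), real_preds.mean().item(), fake_preds.mean().item()
```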
The discriminator makes predictions on real data, then on fake data freshly produced by the generator. These scores feed a loss that updates the discriminator here; the generator gets its own update in the next step.
Generator
Here is the generator’s training function:
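A sketch of that function, again assuming flipped labels, so the generator is rewarded when the discriminator scores its fakes close to 0 ("real"):

```python
import torch
import torch.nn.functional as F

def train_generator(opt_g, batch_size, latent_size, device):
    """One generator update: make the discriminator score fakes as real (0)."""
    opt_g.zero_grad()

    # Generate fake images from a latent batch of random vectors
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # The generator wants the discriminator to output "real" (0 under flipped labels)
    preds = discriminator(fake_images)
    targets = torch.zeros(batch_size, 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)

    # Only the generator's weights are updated here
    loss.backward()
    opt_g.step()

    return loss.item()
```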
The generator reads a "latent batch," which is a set of random input vectors. It generates images and feeds them to the discriminator, which then predicts how fake they are. If the discriminator spots the fakes easily, the generator receives a large update; if it doesn't, the discriminator will likely receive the larger update in its own training step. In this way, the two models push each other to improve.
Overall Training Loop
This is the full training loop! After loading the models on the GPU and choosing a learning rate and number of training epochs, I just hit run on this block:
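A sketch of that loop, wiring together the training functions above (the Adam optimizer settings and the per-epoch logging are assumptions):

```python
import torch

def fit(epochs, lr, train_dl, latent_size, device):
    """Overall training loop: alternate discriminator and generator updates."""
    torch.cuda.empty_cache()
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    history = {"loss_d": [], "loss_g": [], "real_score": [], "fake_score": []}
    for epoch in range(epochs):
        for real_images, _ in train_dl:
            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            # Train the discriminator on real and fake images, then the generator
            loss_d, real_score, fake_score = train_discriminator(
                real_images, opt_d, batch_size, latent_size, device)
            loss_g = train_generator(opt_g, batch_size, latent_size, device)

        # Record per-epoch stats for the charts in the Results section
        history["loss_d"].append(loss_d)
        history["loss_g"].append(loss_g)
        history["real_score"].append(real_score)
        history["fake_score"].append(fake_score)
        print(f"Epoch {epoch + 1}/{epochs}: loss_g={loss_g:.4f}, loss_d={loss_d:.4f}, "
              f"real_score={real_score:.4f}, fake_score={fake_score:.4f}")

    return history

# Load the models onto the GPU, pick a learning rate and epoch count, and run
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
discriminator.to(device)
generator.to(device)
history = fit(epochs=300, lr=0.003, train_dl=train_dl, latent_size=128, device=device)
```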
and then wait for good results!
Results
Pictures
The results of this GAN are some decent (and many really mediocre) looking fake Pokemon. You could probably pick out a couple that could inspire a future Pokemon, but none are quite game-ready yet:
Fancy Charts
To analyze the GAN's performance, we can look at some charts to see how well the generator and discriminator grappled with each other throughout the training cycle. Ideally, neither totally dominates the other, though it is a good sign if the discriminator gradually learns to distinguish real from fake images over time. Here are the losses for each model over time (lower is better):
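If you track losses in a history dictionary like the one in the training-loop sketch above, a few lines of matplotlib are enough to reproduce these charts:

```python
import matplotlib.pyplot as plt

# Plot the per-epoch losses recorded by the fit() sketch above (lower is better)
plt.plot(history["loss_d"], label="Discriminator")
plt.plot(history["loss_g"], label="Generator")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Generator and discriminator losses")
plt.show()
```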
In general, the losses for both models were relatively good throughout. The noisy labels prevented the discriminator's loss from hitting 0, which kept the generator from losing out over time. It would be better, however, if the values were calibrated so the noisy labels didn't have such a large effect.
Additionally, I have data for the scores the discriminator gave to the generated data and the real data. Remember that with the flipped labels, a lower score means the discriminator thinks the image is real:
Over time, the discriminator starts to tell which Pokemon are real and which are fake. This is encouraging, but it could be better if the discriminator didn't figure out the generator so quickly.
Conclusions and Future Work
In this article, I presented many different models and showed the results from the best ones. The results don't scream Pokemon to me, but the potential is there with a bit more work and fine tuning. GANs are difficult to get right, but I think this is a good first introduction and something to expand on in the future.
I have outlined several ways to improve the current system. Among them, I want to try larger kernel sizes at the 64x64 image size and see if I get better results. I may also increase the number of filters at this image size to get more detailed images. If I can find a way to manage my Kaggle GPU memory better, I will move up to 128x128 or larger images. In the future, I may also look for additional datasets with art styles similar to Pokemon so I can generate crossover characters.
As I continue to work on this project, I will post updates here if I achieve better results, and on whether this future work pans out (or not).
Resources
Full Notebook
Kaggle: https://www.kaggle.com/jkleiber/pokegan
Github: https://github.com/jkleiber/PokeGAN
Acknowledgements
DCGAN Example Framework from Jovian.ml: Link
Tips and Tricks for GANs (in article): Link
More Tips and Tricks (used for batch size, but good ideas for future development): Link