Exploring GANs with Pokémon
To generate them is my real test, to train them is my cause
What is this project about?
This article is aimed at readers who are not familiar with machine learning, and in particular with artificial neural networks (ANNs). It is an attempt at a friendly introduction to the basics of ANNs. We will start with the fundamental mathematical prerequisites and then move step by step to more abstract and more complex concepts, aiming at an intuitive explanation of so-called Generative Adversarial Networks (GANs). I have chosen GANs as an example of a more advanced type of ANN because of their powerful and exciting approach to generating visually appealing digital images. The goal of this project has been the automatic generation of completely new Pokémon sprite images (GitHub).
What is Machine Learning?
Machine Learning (ML) is a computer science discipline that exploits patterns and regularities in available data to solve specific tasks without the need for explicit instructions. In ML, algorithms and statistical models are used, for example, to detect anomalies in data, to assign data points to classes, to identify natural clusters in a dataset, to predict future values of a time series based on the past, or to find lower-dimensional representations of data points. Usually, the weights and parameters of a general mathematical model are optimized so that the resulting model explains the underlying data points as well as possible, or at least well enough. Due to the ubiquity of data today, ML currently receives a lot of attention.
Mathematical Background
There are, from my point of view, three main mathematical concepts one has to be familiar with in order to intuitively grasp the powerful idea behind ANNs:
- Vector Spaces: A vector is a mathematical object that consists of a fixed number of components, each of which can take on a numeric value. The number of components is called the dimensionality of the vector and is never negative. In the other direction, however, there is no upper limit to the size of a vector. Hence, it can be one-, two-, three-, or in general n-dimensional, where n can be any positive integer. If the values of the components are either continuous or drawn from an unbounded set of integers, there are infinitely many value combinations. The space of all vectors of the same dimensionality n is called an n-dimensional vector space. In this space, every position corresponds to exactly one n-dimensional vector. If a certain set of vectors is of particular interest, we can analyze these vectors according to their proximity and structure in the underlying vector space. In addition, vectors can be transformed into other vectors of the same vector space, or even into vectors of a completely different one. For instance, if we take the two-dimensional vector space with elements such as (0, 0) or (23, 42) and sum the components of each vector (0+0=0 and 23+42=65), we have found a mapping from a two-dimensional onto a one-dimensional space.
- Mathematical Functions: A function is an abstract machine that receives some input, processes it, and returns the resulting output. The space of possible inputs is called the domain of the function, while the target space of the outputs is called the codomain. Thus, a mathematical function is a relation between the elements of two sets or vector spaces. There can be different functions over the same sets, differing in their inner workings. This is an incredibly general concept. In mathematical notation, a function is typically denoted by a single lower-case letter. Functions can be seen as algorithms and vice versa, since both transform a well-defined input into a well-defined output. If this inner transformation is hidden or so complex that you cannot tell what exactly happens to the inputs, while you can still observe the outputs, the function is often said to be a black box. You use a lot of black-box functions (or algorithms) in your daily life: your computer or laptop, your visual cortex, your washing machine, or even your fellow humans. Computers receive user inputs, and users expect them to behave as expected, but hardly any user knows how a computer works at the deepest layers of abstraction.
- Differentiation and Gradients: Imagine you are provided with a mathematical function that takes in a single value and outputs another single value, and thus maps one one-dimensional vector space onto another. Now you are asked to find the input value at which the output value is the smallest. You are forced to sample vectors (in this case single numbers) from the domain to get a feeling for which inputs lead to which outputs. However, without the assumption that similar input values lead to similar output values, you have practically no chance of discovering the minimum. And even with this assumption, the naive approach of sampling relies to a certain extent on pure luck. What we basically want to achieve by sampling is to identify where the function rises and where it falls. Fortunately, we can derive this information directly from the analytical expression in the definition of our function through a process called differentiation. Readers who have taken a calculus course in high school or college will be familiar with procedures such as the chain rule or the product rule. After applying these rules to the objective function to obtain the first derivative, we sample one input value from the domain. This starting point is then fed through the derivative, which tells us how strongly the function is rising or falling at that location. If it is rising, we decrease the input value. If it is falling, we increase it. The more steeply it rises, the more we decrease the input value; the more steeply it falls, the more we increase it. If the function is neither rising nor falling, we have reached a flat point, which after a series of such downhill steps is typically at or close to a (local) minimum. This intuitive but powerful optimization procedure is called Gradient Descent, since the value of the derivative at a certain location is called the gradient of the function at that point. This concept is not restricted to one-dimensional functions. It can be applied to any function that more or less satisfies the assumption made above.
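To make the Gradient Descent procedure described above concrete, here is a minimal sketch in plain Python. The objective function `f`, its hand-derived derivative `df`, the step size, and the number of iterations are all illustrative assumptions and are not taken from the project.

```python
def f(x):
    """Illustrative objective: a parabola with its minimum at x = 3."""
    return (x - 3.0) ** 2 + 1.0


def df(x):
    """First derivative of f, obtained by hand with the power and chain rules."""
    return 2.0 * (x - 3.0)


def gradient_descent(derivative, start, step_size=0.1, iterations=200):
    """Repeatedly move the input against the sign of the derivative.

    A rising function (positive derivative) decreases x, a falling one
    increases x, and steeper slopes cause larger moves.
    """
    x = start
    for _ in range(iterations):
        x = x - step_size * derivative(x)
    return x


x_min = gradient_descent(df, start=-10.0)
print(x_min, f(x_min))
```

Starting far away at -10, the procedure should end up very close to the true minimum at 3. The step size matters: if it is chosen too large for the function at hand, the updates can overshoot and fail to converge.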
What is an Artificial Neural Network?
An artificial neural network (ANN) is a general function approximator that is able to learn to mimic the behavior of a black-box function when provided only with input and output example pairs. To this end, an ANN is typically given by a chain of parameterized transformations that successively map the input vector through several different vector spaces until the final output vector is reached. The effects of these transformations are specified by the values of the ANN parameters. By changing and adapting these weights, the ANN is able to acquire the behavior of almost any function. However, the ANN itself is a black box as well, since we usually cannot make any sense of the often millions of parameters. The only thing we can judge is the behavior of the network when it is confronted with certain inputs.
In order to train such a network, we need two things: (1) a dataset of input-output example pairs, and (2) a procedure that optimizes the parameters of the ANN so that it learns over time to mimic the hidden black-box function behind the data. The latter is generally a variant of the Gradient Descent algorithm, while the former has to be carefully prepared for the ANN. ANNs are nothing more than complex functions with a lot of parameters. Through automatic differentiation, we can identify how we have to tweak these weights so that the neural network behaves more and more like the input-output example pairs dictate.
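The two ingredients of training can be sketched with the smallest possible "network": a single linear transformation with two parameters. Everything here is an illustrative assumption, including the hidden function `black_box`, the learning rate, and the number of epochs. The gradients are derived by hand, whereas real frameworks obtain them through automatic differentiation.

```python
def black_box(x):
    # Hidden function; the learner only ever sees its input-output pairs.
    return 2.0 * x + 1.0


# (1) A dataset of input-output example pairs.
pairs = [(x, black_box(x)) for x in (-2.0, -1.0, 0.0, 1.0, 2.0)]


def model(x, w, b):
    """A one-parameter-pair 'network': a single linear transformation."""
    return w * x + b


def train(pairs, learning_rate=0.05, epochs=2000):
    """(2) Gradient Descent on the mean squared error between model and examples."""
    w, b = 0.0, 0.0
    n = len(pairs)
    for _ in range(epochs):
        # Hand-derived gradients of the mean squared error w.r.t. w and b.
        grad_w = sum(2.0 * (model(x, w, b) - y) * x for x, y in pairs) / n
        grad_b = sum(2.0 * (model(x, w, b) - y) for x, y in pairs) / n
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


w, b = train(pairs)
print(w, b)
```

The learned parameters should approach the slope and offset of the hidden function, even though the training loop never looks inside `black_box`.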
The dataset: Why Pokémon?
Pokémon shaped my entire childhood and was the first thing that came to my mind when I was looking for an interesting dataset of small images. I mainly discovered and explored this world by playing a variety of traditional Pokémon video games, starting with Pokémon Blue (Generation 1) and ending with Pokémon Black (Generation 5). While there were 151 different Pokémon in the first generation, this number grew rapidly to 649 in the fifth generation. For this project, I compiled a custom dataset from the Pokémon sprite images of the first five generations. I discarded the images from the first two and the ones from the last generations because their styles deviate from the larger portion of images. Nevertheless, since all Pokémon from previous editions usually appear in a traditional Pokémon video game, I got several different images for each of the 649 Pokémon from the first five generations. The images of the same Pokémon species vary in terms of gender, shininess, and pose.
Images of different species additionally vary in terms of color, shape, and size. In total, the dataset consists of 11,779 Pokémon sprite images that have been resized to a common size of 96x96 RGBA pixels. Hence, every image pixel has a red (R), a green (G), a blue (B), and an alpha (A) transparency value, leading to a total of 96x96x4=36864 integer values to describe one of the images in its entirety. In other words, each of the images can be understood as a vector from a 36864-dimensional vector space, where each component can take on values in the range from 0 to 255. The larger the space, the more data points we need to come up with a good estimate for the ANN parameters. In fact, the number of required data points rises exponentially with the dimensionality of the underlying vector space (curse of dimensionality). In order to enable ANNs to cope with high-dimensional digital image data, so-called convolutional neural networks (CNNs) have been developed as a special kind of ANN. We will have a brief look at CNNs after the next short section.
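The view of an image as a vector can be checked with a few lines of Python. This is only an illustration with a blank placeholder image; the helper name `flatten_image` is hypothetical and not part of the project.

```python
import math

WIDTH, HEIGHT, CHANNELS = 96, 96, 4


def flatten_image(image):
    """Turn a HEIGHT x WIDTH grid of (R, G, B, A) pixels into one flat vector."""
    return [value for row in image for pixel in row for value in pixel]


# Placeholder image: every pixel is fully transparent black.
blank = [[(0, 0, 0, 0) for _ in range(WIDTH)] for _ in range(HEIGHT)]
vector = flatten_image(blank)
print(len(vector))  # 36864

# Curse of dimensionality: even a coarse grid with only two values per
# component has 2 ** 36864 cells, i.e. roughly 10 ** exponent of them.
exponent = len(vector) * math.log10(2)
print(exponent)
```

Even this crude two-values-per-component grid has a number of cells with more than ten thousand digits, which is why covering the space with examples is hopeless and the network has to exploit structure instead.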
Although we are not using any labels for these images, all Pokémon images actually do have natural labels that can be extracted from the National Pokédex, such as the primary type, the secondary type, or the species of the Pokémon.
Data Augmentation
Even though roughly 12 thousand small Pokémon images seem to be a lot of data, I decided on two additional techniques to increase the diversity even more, applied directly before an image is presented to the ANN. This is done on the fly without storing any intermediate results in a file. The two techniques are:
- Random Horizontal Flipping: Despite the fact that most of the original Pokémon sprites face left (due to their usage in the traditional Pokémon video games), I randomly flipped each image horizontally with a probability of fifty percent.
- Random Jitter: In addition to that, I experimented with random jittering translations of the images by slightly enlarging each image to 100x100 pixels and then randomly cropping it back to a size of 96x96 pixels.
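Both augmentation steps can be sketched in plain Python on an image stored as a list of pixel rows. This is a simplified illustration, not the project's pipeline: the function names are hypothetical, and nearest-neighbour resizing is an assumption, since the interpolation method is not stated above.

```python
import random


def random_horizontal_flip(image, rng, probability=0.5):
    """Mirror the image left-to-right with the given probability."""
    if rng.random() < probability:
        return [list(reversed(row)) for row in image]
    return [list(row) for row in image]


def resize_nearest(image, new_height, new_width):
    """Nearest-neighbour resize (an assumption; any interpolation would do)."""
    height, width = len(image), len(image[0])
    return [
        [image[r * height // new_height][c * width // new_width]
         for c in range(new_width)]
        for r in range(new_height)
    ]


def random_jitter(image, rng, enlarged=100, size=96):
    """Enlarge slightly, then crop a randomly placed size x size window."""
    bigger = resize_nearest(image, enlarged, enlarged)
    top = rng.randint(0, enlarged - size)
    left = rng.randint(0, enlarged - size)
    return [row[left:left + size] for row in bigger[top:top + size]]


def augment(image, rng):
    """Apply both techniques on the fly, right before the image is used."""
    return random_jitter(random_horizontal_flip(image, rng), rng)


rng = random.Random(0)
# Stand-in sprite: each "pixel" stores its own (row, column) position.
sprite = [[(r, c) for c in range(96)] for r in range(96)]
augmented = augment(sprite, rng)
print(len(augmented), len(augmented[0]))  # 96 96
```

Because a fresh random flip and crop offset are drawn every time `augment` is called, the network rarely sees exactly the same pixels twice, although the stored dataset never changes.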
What is a Convolutional Neural Network?
Digital images are an extraordinary challenge for ANNs for two reasons: (1) even small images easily exceed tens of thousands of dimensions, and (2) their two-dimensional spatial nature requires special treatment. Therefore, so-called Convolutional Neural Networks (CNNs) have been specifically designed to operate on digital image data.
A CNN is an ANN that uses convolutional operators in most of its layers to exploit the spatial structure of images and thereby reduce the number of parameters considerably. Each of these convolutional layers learns a set of convolutional kernels that are applied across the entire layer input. When chained together, these convolutions can represent increasingly abstract concepts the deeper the network is. While early layers might learn to detect features like straight lines or brightness, later layers can grasp more complex patterns and concepts, such as shapes or textures, based on the features of previous layers.
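The idea of one small kernel being applied across the entire input can be illustrated with a hand-written convolution. The tiny image and the vertical-edge kernel below are made-up examples; in a CNN the kernel values would be learned. As in most deep learning libraries, the kernel is not flipped, so strictly speaking this is a cross-correlation.

```python
def convolve2d(image, kernel):
    """Slide the kernel over the image (stride 1, no padding)."""
    k_h, k_w = len(kernel), len(kernel[0])
    out_h = len(image) - k_h + 1
    out_w = len(image[0]) - k_w + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(k_h) for j in range(k_w))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]


# 5x6 grayscale image: dark left half, bright right half.
image = [[0, 0, 0, 1, 1, 1] for _ in range(5)]

# Hand-picked kernel that responds to dark-to-bright transitions.
vertical_edge = [[-1, 0, 1],
                 [-1, 0, 1],
                 [-1, 0, 1]]

response = convolve2d(image, vertical_edge)
for row in response:
    print(row)  # [0, 3, 3, 0] for each of the three rows

# Parameter comparison for the same input and output sizes:
kernel_weights = 3 * 3              # shared across all positions
dense_weights = (5 * 6) * (3 * 4)   # one weight per input-output pair
print(kernel_weights, dense_weights)  # 9 360
```

The response is non-zero only where the window covers the edge, and the same nine weights are reused at every position, which is where the parameter savings compared to a fully connected layer come from.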
What is a Generative Adversarial Network?
Remember our goal of automatically generating fake Pokémon images based on the patterns and regularities present in the custom dataset. In order to achieve this, we have to design a mathematical function that takes in some random noise (for variation) and produces a deceptively real Pokémon sprite image. If we had such a generator function, we could simply sample the random noise input and rely on the generator to turn it into realistic fake images. This is an extremely hard task, which arguably involves some type of creativity. Solving this problem by sitting down and playing around with analytical formulas of functions that map random noise onto points in the vector space of images is futile, since even the simpler task of directly writing down a formula that just discriminates between Pokémon and non-Pokémon images is practically impossible. The generator is a hypothetical black-box function for which no input-output example pairs exist. Hence, we cannot even rely on ANNs to rescue us by using them to mimic the wanted generative behavior directly. This is where so-called Generative Adversarial Networks (GANs) come into play.
The concept of GANs was proposed by Goodfellow et al. in 2014 and is a powerful framework that allows us to apply ANNs to our problem, even though we do not have any particular input-output example pairs. Training a neural network to produce a certain output when provided with a certain available input is called supervised learning. If no suitable labels are available as target outputs, the corresponding learning problem is said to be unsupervised. The first type of GAN presented by Goodfellow et al. is an example of an unsupervised method and hence does not require any additional label information about the images we aim to fake.
In a GAN, two adversarially trained ANNs, the generator and the discriminator, work against each other to reach a common goal. Instead of directly engineering a perfect generator or a perfect discriminator by hand, we let the two ANNs iteratively play against each other in a zero-sum game, where one of them is trained to produce realistic fake images and the other is trained to distinguish fakes from real images. In this way, we neither have to provide any particular measure for the realness of the generated images, nor do we have to identify the patterns and regularities in the real dataset ourselves. Whenever the generator performs well, the discriminator's performance leaves something to be desired, and whenever the discriminator has no trouble telling the two types of images apart, the generator has a hard time fooling it. Initially, the parameters of both ANNs are randomly initialized, which leads to poor performance of the generator as well as of the discriminator. The discriminator has no clue about what Pokémon images typically look like, and the generator so far only produces random rubbish. During training, the discriminator is provided with binary labels telling it whether an input image is a fake previously produced by the generator or a sample from the real dataset. The generator learns from the feedback it gets from the discriminator. This training procedure often suffers from instabilities. If either one of the two adversaries is too strong for its counterpart, the weaker one has no chance to compete. These instabilities have to be avoided through a careful selection of the generator and discriminator ANNs and several different stabilization techniques. Over the iterations, the generator is thus hopefully guided by the discriminator to produce points in the space of images that lie on the same manifold as the real ones.
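The opposing objectives can be written down as two loss functions over the discriminator's realness scores. The sketch below uses binary cross-entropy and the non-saturating generator loss suggested by Goodfellow et al.; which exact loss the project uses is not stated here, so treat this choice, the function names, and the example scores as assumptions.

```python
import math


def binary_cross_entropy(prediction, label):
    """Penalty for predicting `prediction` in (0, 1) when the truth is `label`."""
    eps = 1e-12  # keeps log() away from zero
    p = min(max(prediction, eps), 1.0 - eps)
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


def discriminator_loss(scores_real, scores_fake):
    """The discriminator wants real images scored as 1 and fakes scored as 0."""
    real_terms = [binary_cross_entropy(s, 1.0) for s in scores_real]
    fake_terms = [binary_cross_entropy(s, 0.0) for s in scores_fake]
    return (sum(real_terms) + sum(fake_terms)) / (len(real_terms) + len(fake_terms))


def generator_loss(scores_fake):
    """The generator wants its fakes scored as 1 (non-saturating variant)."""
    return sum(binary_cross_entropy(s, 1.0) for s in scores_fake) / len(scores_fake)


# Made-up scores for a discriminator that sees through the fakes ...
strong_d = (discriminator_loss([0.9, 0.95], [0.1, 0.05]),
            generator_loss([0.1, 0.05]))
# ... and for one that is being fooled by them.
fooled_d = (discriminator_loss([0.6, 0.5], [0.7, 0.8]),
            generator_loss([0.7, 0.8]))
print(strong_d)
print(fooled_d)
```

In actual training, each iteration would alternate between a gradient step on the discriminator loss and one on the generator loss. The two example situations show the tug of war: whenever one loss is low, the other is high.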
If the reader wants to see mind-blowing results of advanced GAN architectures, I strongly recommend having a look at the hyperrealistic style-based face image generator published by researchers at NVIDIA in 2019.
The basic network architecture
In this project, I decided to base my implementation on a basic convolutional type of GAN and the guidelines described in the Deep Convolutional GAN (DCGAN) paper. The only major change that I made to the architecture is the application of so-called dropout layers after most of the up- or downsampling modules to prevent overfitting. The reader can find more technical details and the entire implementation publicly available on my GitHub. In addition to the DCGAN guidelines, I experimented with three different techniques that help to stabilize the training and avoid a common problem often referred to as mode collapse. Mode collapse describes the outcome in which the generator has specialized in producing only one particular type of fake image, ignoring the provided input noise almost completely.
Stability and Mode Collapse: Techniques for remedy
During my research on this project, I came across the following three often-applied stabilization techniques and decided to leverage all of them.
- One-sided Label Smoothing: Instead of telling the discriminator to produce overconfident probability scores near 1.0 for real images, we penalize overconfidence by smoothing the true label by a certain amount. I have chosen a smoothing of 0.1, so that the label for real images is decreased to 1.0 - 0.1 = 0.9. This prevents the discriminator from basing its predictions solely on a few apparent features.
- Decaying Instance Noise: Instead of feeding the (augmented) Pokémon images or the generated ones directly to the discriminator, we add some Gaussian random noise to each of the inputs to increase the difficulty for the discriminator and to widen the overlap of the true and the synthetic data distributions. The standard deviation of this instance noise is decayed over the iterations by a certain decay factor. For more information and theory, have a look at this excellent blog post.
- Spectral Normalization: This technique has shown by far the largest impact on the generated images. Spectral normalization of the ANN's layer weights ensures that the function the discriminator network represents satisfies an even stronger assumption than the one we made in the paragraph about differentiation and gradients: its output can only change at a bounded rate when the input changes (Lipschitz continuity). A spectrally normalized discriminator network produces a smoother and more stable realness decision. For more information about this game-changing yet simple concept, I refer to this well-written and visualized article.
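The three techniques can each be reduced to a few lines. The following sketch is illustrative only: the exponential decay schedule, the initial noise level of 0.1, the decay factor of 0.99, and all function names are assumptions, and spectral normalization is shown for a single small weight matrix using power iteration, which is how the largest singular value is commonly estimated.

```python
import math
import random


def smoothed_real_label(smoothing=0.1):
    """One-sided label smoothing: only the label for real images is lowered."""
    return 1.0 - smoothing


def instance_noise_std(initial_std, decay, iteration):
    """Noise level after `iteration` steps (exponential decay is an assumption)."""
    return initial_std * decay ** iteration


def add_instance_noise(values, std, rng):
    """Add zero-mean Gaussian noise to every component of a discriminator input."""
    return [v + rng.gauss(0.0, std) for v in values]


def spectral_norm(matrix, iterations=100):
    """Estimate the largest singular value of a non-zero matrix by power iteration."""
    rows, cols = len(matrix), len(matrix[0])
    v = [1.0 / math.sqrt(cols)] * cols
    for _ in range(iterations):
        u = [sum(matrix[i][j] * v[j] for j in range(cols)) for i in range(rows)]
        u_norm = math.sqrt(sum(x * x for x in u))
        u = [x / u_norm for x in u]
        v = [sum(matrix[i][j] * u[i] for i in range(rows)) for j in range(cols)]
        v_norm = math.sqrt(sum(x * x for x in v))
        v = [x / v_norm for x in v]
    w_v = [sum(matrix[i][j] * v[j] for j in range(cols)) for i in range(rows)]
    return math.sqrt(sum(x * x for x in w_v))


def spectrally_normalize(matrix):
    """Rescale the weights so that the largest singular value becomes 1."""
    sigma = spectral_norm(matrix)
    return [[w / sigma for w in row] for row in matrix]


rng = random.Random(0)
noisy = add_instance_noise([0.0, 0.5, 1.0], instance_noise_std(0.1, 0.99, 50), rng)

weights = [[3.0, 0.0], [0.0, 1.0]]  # stretches its input by up to a factor of 3
normalized = spectrally_normalize(weights)
print(smoothed_real_label(), spectral_norm(weights), spectral_norm(normalized))
```

After normalization, the layer can no longer amplify differences between two inputs, which is what keeps the discriminator's decision smooth. Deep learning libraries typically provide this as a ready-made wrapper around a layer.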
Results
In this section, I present the results for four different variants of the basic DCGAN-like implementation. All of these leverage one-sided label smoothing, decaying instance noise, and random flipping of images. The four variants differ in terms of whether or not they applied random jittering or spectral normalization.
Below you can find animations showing the progress of the generator for fixed random input seeds. A new frame has been produced every five iterations.
In order to show the variety of synthetic Pokémon images generated by each of these four models after 250 iterations, 64 randomly sampled images are presented for each of them below. The model without spectral normalization but with additional random jittering seems to have collapsed into a single mode. Applying spectral normalization to this particular model has helped to alleviate this problem. The decision of whether or not random jittering improves the quality of the results is left to the reader.
Latent Space Traversal
With the above-mentioned generators at hand, we can now try to figure out how the provided random input noise affects the generated output. The 128-dimensional random seed vector is sampled from a normal distribution before the generator magic starts to happen. We can view each point in this 128-dimensional vector space as a fake Pokémon sprite that just wants to be discovered. In order to see how adjacent points in this space relate in terms of their synthesized images, we will traverse along certain trajectories in this input space and visualize how the generated output changes. There are two types of traversals that make sense in the context of a Gaussian input vector: spherical and linear traversal. While the former takes an arbitrary vector and rotates it around the origin until a certain target point is reached, the latter follows a straight line from the origin to a certain target position. Hence, the spherical traversal fixes the vector's magnitude and the linear traversal fixes the vector's direction.
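Both traversals can be sketched as functions that return a list of seed vectors, each of which would then be fed to the generator. The code below is a minimal illustration with hypothetical function names. It implements the rotation as spherical linear interpolation and assumes that start and target have the same magnitude and do not point in the same or exactly opposite direction.

```python
import math
import random


def norm(vector):
    return math.sqrt(sum(x * x for x in vector))


def spherical_traversal(start, target, steps):
    """Rotate from start to target around the origin, keeping the magnitude fixed.

    Assumes both vectors have the same norm and are not (anti-)parallel.
    """
    dot = sum(s * g for s, g in zip(start, target))
    cos_omega = max(-1.0, min(1.0, dot / (norm(start) * norm(target))))
    omega = math.acos(cos_omega)  # angle between the two vectors
    path = []
    for k in range(steps):
        t = k / (steps - 1)
        a = math.sin((1.0 - t) * omega) / math.sin(omega)
        b = math.sin(t * omega) / math.sin(omega)
        path.append([a * s + b * g for s, g in zip(start, target)])
    return path


def linear_traversal(target, steps):
    """Walk from the origin to target on a straight line, keeping the direction fixed."""
    return [[(k / (steps - 1)) * x for x in target] for k in range(steps)]


rng = random.Random(0)
start = [rng.gauss(0.0, 1.0) for _ in range(128)]
target = [rng.gauss(0.0, 1.0) for _ in range(128)]
# Rescale the target so that both seeds lie on the same sphere.
scale = norm(start) / norm(target)
target = [scale * x for x in target]

rotation = spherical_traversal(start, target, steps=9)
scaling = linear_traversal(target, steps=9)
print(norm(rotation[4]), norm(scaling[4]))
```

Along the rotation, every intermediate seed keeps the magnitude of the start vector, whereas along the scaling path the magnitude grows from zero to that of the target while the direction stays the same.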
Spherical Traversal by Rotation
Linear Traversal by Scaling
Conclusion
- Pokémon images have a large variety in terms of colors, shapes, poses, and textures. Finding a continuous neural function that produces perfect deceptions is not an easy task since the number of different Pokémon images might be too small for the complexity of this task.
- Nevertheless, the proposed GAN was occasionally able to produce convincing attempts that resemble real Pokémon images in several important aspects: (1) strong black outlines, (2) vibrant colors, and (3) clear contours.
- Even though it is easy for us humans to discriminate between these fakes and the real Pokémon images, we might treat them as happy accidents and get inspired by them to fill the wonderful world of Pokémon with even more colorful creatures in our imagination.
What did I learn?
- GPU vs CPU: While working on this project, I spent countless hours testing different hyperparameters and architectures on a single CPU. This drove me crazy and motivated me to invest in an inexpensive GPU, which changed the game completely. I am now able to test models in a fraction of the time that I needed before. For instance, training the simple DCGAN-like model took me at least 40 minutes per epoch on my CPU. Remember that I trained each of the models for 250 epochs. This was only possible with my new GPU, which was able to do this work in a single night.
- Demystification of GANs: Before this project, GANs had always been a black-box concept for me. They seemed to be tremendously more complex than other types of ANNs and considerably harder to train. During this project, I have gained some experience and developed a sense of how GANs behave. This encourages me to pursue this path further in the future towards more advanced types and methods of generative models: Conditional GANs, ACGANs, Self-Attention Blocks, Conditional Batch Normalization, Style-GANs, Self-Modulation Blocks, …
References
- Generative Adversarial Networks, Goodfellow et al., Source: https://arxiv.org/abs/1406.2661
- Unsupervised representation learning with deep convolutional GANs, Radford et al., Source: https://arxiv.org/pdf/1511.06434.pdf
- How to train a GAN? Tips and tricks to make GANs work, Chintala et al., Source: https://github.com/soumith/ganhacks
- Instance Noise: A trick for stabilizing GAN training, F. Huszár, Source: https://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/
- Spectral Normalization Explained, C. Cosgrove, Source: https://christiancosgrove.com/blog/2018/01/04/spectral-normalization-explained.html
- Sampling Generative Networks, T. White, Source: https://arxiv.org/abs/1609.04468