On the intuition behind deep learning & GANs — towards a fundamental understanding
A generative adversarial network (GAN) is composed of two separate networks - the generator and the discriminator. It poses the unsupervised learning problem as a game between the two. In this post we will see why GANs have so much potential, and frame GANs as a boxing match between two opponents.
Intuition behind deep learning
Deep learning is famously biologically inspired, and many of its major concepts are intuitive and grounded in reality. A defining property of deep learning is that it is hierarchical: the layers in a network, and the representations they learn, build on each other. The real world is hierarchical too: electrons, protons, neutrons -> atoms -> molecules -> … It makes sense that the best way to model a hierarchical world is hierarchically, and this is a big part of why deep learning has been so successful at providing simple, elegant, and general solutions to very difficult problems.
Motivating unsupervised learning
‘Adversarial training is the coolest thing since sliced bread.’ — Yann LeCun, Director of AI Research at Facebook and Professor at NYU
Now let’s apply this biologically inspired mindset to the way we currently train our networks. Supervised learning is the standard in machine learning today: every training sample requires a ground-truth annotation/label. But most learning in the real world is unsupervised. Just think about how we learn to walk, talk, etc… While supervised learning has performed well on many tasks, unsupervised learning seems to be the key to real artificial intelligence.
It’s often impractical to accurately annotate data. Ideally an unsupervised model could be trained on data that doesn’t have the required annotations, and then fine-tuned with a much smaller properly annotated dataset. Circling back to the hierarchical view of the world, it should be possible to train AI to understand the world’s basic building blocks, and then build on top of that existing knowledge base, fine-tuning it in a more supervised manner for specific use-cases.
Unsupervised learning - a concrete example
Suppose a convolutional neural network is trained on millions of unlabelled images of skin. Some of these images could be of healthy skin, others of diseased skin, and everything in between. Eventually the network would gain a very deep understanding of skin and all its intricacies. A specific use-case (e.g. diagnosing skin cancer instantly and accurately) could then be built on top of this network.
Since the model has already learned general, powerful representations of the most important information contained in images of skin, it should be able to quickly learn the new task of diagnosing skin cancer with a much smaller labelled dataset than if it were trained using only supervised methods. This is the basic concept of transfer learning & fine-tuning.
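A minimal sketch of this idea in a toy setting, using pure Python so it runs without a framework. The names here are hypothetical: `pretrained_features` is a hand-written stand-in for a network already trained on unlabelled data and is kept frozen, and only a small logistic-regression head is trained on a handful of labelled examples.

```python
import math


def pretrained_features(x):
    # Hypothetical stand-in for a network pretrained on unlabelled data.
    # It has no trainable parameters here, so fine-tuning never changes it.
    return [x[0] + x[1], x[0] - x[1]]


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def fine_tune_head(samples, labels, lr=0.5, epochs=200):
    """Train only a logistic-regression head on top of the frozen features."""
    w, b = [0.0, 0.0], 0.0
    for _ in range(epochs):
        for x, y in zip(samples, labels):
            f = pretrained_features(x)
            p = sigmoid(w[0] * f[0] + w[1] * f[1] + b)
            err = p - y  # gradient of binary cross-entropy w.r.t. the logit
            w = [w[0] - lr * err * f[0], w[1] - lr * err * f[1]]
            b -= lr * err
    return w, b


def predict(w, b, x):
    f = pretrained_features(x)
    return 1 if sigmoid(w[0] * f[0] + w[1] * f[1] + b) > 0.5 else 0


# A deliberately tiny labelled set (toy data): label 1 when x0 + x1 > 0.
small_x = [[1.0, 0.5], [0.8, 0.9], [-1.0, -0.4], [-0.6, -0.9]]
small_y = [1, 1, 0, 0]
w, b = fine_tune_head(small_x, small_y)
print([predict(w, b, x) for x in small_x])
```

Because the features already separate the classes well, four labelled examples are enough for the head; in a real system the frozen part would be a large pretrained network rather than a two-line function.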
GANs are one of the most promising areas of research in unsupervised learning and we will see that they are a simple, general approach to learning powerful representations from data.
Let’s break down a GAN into its basic components:
Data: Mathematically, we think about a dataset as samples from a true data distribution. This data could be anything: images, speech, sensor readings, etc…
Generator: Takes some code (i.e. random noise) as input, and transforms it, outputting a sample of data. The goal of the generator is to eventually output diverse data samples from the true data distribution.
Discriminator: Takes a sample of data as input, and classifies it as real (from the true data distribution) or fake (from the generator). The goal of the discriminator is to discriminate between real and generated samples with high accuracy.
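The three components can be sketched as plain Python functions. This is a toy under stated assumptions, not a real architecture: the data distribution is a hypothetical 1-D Gaussian, the generator is an affine map of its noise code, and the discriminator is a logistic-regression classifier that returns the probability that its input is real.

```python
import math
import random


def sample_data(rng):
    """Toy stand-in for the true data distribution: a 1-D Gaussian (assumed)."""
    return rng.gauss(4.0, 1.0)


def generator(code, params):
    """Transforms a code (here random noise) into a data sample: a * z + b."""
    a, b = params
    return a * code + b


def discriminator(sample, params):
    """Returns the estimated probability that `sample` is real: sigmoid(w*x + c)."""
    w, c = params
    return 1.0 / (1.0 + math.exp(-(w * sample + c)))


rng = random.Random(42)
fake = generator(rng.gauss(0.0, 1.0), params=(1.0, 0.0))
real = sample_data(rng)
print(discriminator(real, (1.0, -2.0)), discriminator(fake, (1.0, -2.0)))
```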
The overall goal of a standard GAN is to train a generator that generates diverse data samples from the true data distribution, so that the discriminator can do no better than a 50/50 guess when classifying samples as real or generated. In the process of training, both the generator and the discriminator learn powerful, hierarchical representations of the underlying data that can then transfer to a variety of specific tasks (classification, segmentation, etc.) and use-cases.
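The 50/50 outcome can be stated precisely with the standard GAN minimax objective, where the generator G maps noise z drawn from p_z to samples whose distribution is p_g. For a fixed generator, the discriminator that maximizes V is D*_G below, and it equals 1/2 everywhere once p_g matches p_data:

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

D^*_G(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}, \qquad p_g = p_{\text{data}} \implies D^*_G(x) = \frac{1}{2}
```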
Understanding the GAN’s training procedure
The pseudo-code below might be confusing at first, so we’ll step through it with a simple real-world example of the adversarial learning procedure right after.
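# Keras-style pseudo-code for one training iteration. `generator`,
# `discriminator` and `combined` are Keras models; `real_images`, `noise`,
# `real_labels` (all ones) and `fake_labels` (all zeros) are placeholder batches.

# 1) train the discriminator to classify a batch of images from our
#    dataset as real
discriminator.train_on_batch(real_images, real_labels)

# 2) train the discriminator to classify a batch of images generated by
#    our current generator as fake
generated_images = generator.predict(noise)
discriminator.train_on_batch(generated_images, fake_labels)

# 3) train the generator to trick the discriminator into classifying a
#    batch of generated images as real. The key here is that the
#    discriminator is frozen (not trainable) in this step, but its loss
#    function's gradients are back-propagated through the combined network
#    to the generator, which updates its weights along these gradients
combined.train_on_batch(noise, real_labels)

# where combined is a model that consists of the generator and
# discriminator joined together such that: input => generator =>
# generator_output => discriminator => classification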
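To make the three steps concrete without a deep learning framework, here is a minimal pure-Python sketch on a toy 1-D problem. Everything in it is an assumption for illustration: the "true data distribution" is a Gaussian with mean 1, the generator is an affine map G(z) = a * z + b of Gaussian noise, and the discriminator is logistic regression D(x) = sigmoid(w * x + c), with gradients written out by hand instead of back-propagated by a library. Because this discriminator is linear, it can only tell the two distributions apart by their means, so the toy illustrates the procedure rather than a useful GAN.

```python
import math
import random

# Toy assumptions (not from the article): 1-D Gaussian data, an affine
# generator and a logistic-regression discriminator with hand-written gradients.
REAL_MEAN, REAL_STD = 1.0, 1.0
LR, BATCH, STEPS = 0.05, 64, 2000


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def train_toy_gan(seed=0):
    rng = random.Random(seed)
    a, b = 1.0, -1.0  # generator G(z) = a * z + b, starts off-target
    w, c = 0.0, 0.0   # discriminator D(x) = sigmoid(w * x + c)

    for _ in range(STEPS):
        real = [rng.gauss(REAL_MEAN, REAL_STD) for _ in range(BATCH)]
        noise = [rng.gauss(0.0, 1.0) for _ in range(BATCH)]
        fake = [a * z + b for z in noise]

        # Step 1: train D to label real samples 1. For binary cross-entropy on
        # a sigmoid output, d(loss)/d(logit) = D(x) - target.
        gw = sum((sigmoid(w * x + c) - 1.0) * x for x in real) / BATCH
        gc = sum(sigmoid(w * x + c) - 1.0 for x in real) / BATCH
        w, c = w - LR * gw, c - LR * gc

        # Step 2: train D to label generated samples 0 (target 0).
        gw = sum(sigmoid(w * x + c) * x for x in fake) / BATCH
        gc = sum(sigmoid(w * x + c) for x in fake) / BATCH
        w, c = w - LR * gw, c - LR * gc

        # Step 3: the "combined" step. D is frozen (w and c are not updated);
        # the loss of labelling generated samples as real, -log D(G(z)), is
        # back-propagated through D into the generator parameters a and b.
        ga = sum((sigmoid(w * (a * z + b) + c) - 1.0) * w * z for z in noise) / BATCH
        gb = sum((sigmoid(w * (a * z + b) + c) - 1.0) * w for z in noise) / BATCH
        a, b = a - LR * ga, b - LR * gb

    return a, b, w, c


a, b, w, c = train_toy_gan(seed=0)
print(f"generated mean ~ {b:.3f} (real mean {REAL_MEAN}), "
      f"D(real mean) ~ {sigmoid(w * REAL_MEAN + c):.3f}")
```

After training, `b` (the mean of the generated samples) should sit near `REAL_MEAN` and the discriminator's output should hover around 0.5, the 50/50 outcome described earlier, at least for the one statistic this linear discriminator can see.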
We are all very familiar with the general concept of GANs and adversarial learning whether we realize it or not. For example, consider learning to play a song on guitar:
- Listen to the song and figure out how to map it to the guitar (step 1 in the training procedure above).
- Try to play the song, listening to what you play and paying attention to how it differs from the actual song (step 2).
- Play the song again, trying to fix these differences (step 3).
We repeat some variation of this procedure until what we are playing sounds close enough to the actual song and we are happy. In practice, steps 2 & 3 are pretty much merged together, and step 1 is partially memorized and revisited every once in a while when the memory needs to be refined.
As you become a more skilled guitarist, your ability to learn new songs improves until you reach a point where you can play songs you’ve never heard or played before with very little practice (the analogue of transfer learning/fine-tuning).
In this example, the song is the data, our ears/brain are the discriminator, and our hands/brain are the generator. This is probably similar to how we learned to move, talk, etc… Taking this one step further, think about when a deaf person talks: their speech can sound atypical because they don’t have an auditory discriminator to facilitate the adversarial learning (they may pick up on other cues, like people’s reactions, which serve as a form of weak discriminator).
Now that we’ve built up some intuition behind GANs, let’s see how they are currently implemented in software. Think about the similarities & differences between GANs in reality & software along the way. To highlight one difference: the adversarial learning procedure that occurs in reality seems collaborative between the generator and discriminator, while the software implementation of GANs seems adversarial (… a boxing match).
Training a GAN — a boxing match between the generator and discriminator
At first it might seem like the discriminator is the coach and the generator is the boxer. But really they are both boxers, and the real data is the coach. The key point is that only the discriminator has direct access to the data.
The discriminator is a boxer that learns from a coach (the larger the real dataset, the more experienced the coach) while the generator is a boxer who can only learn from his sparring partner (the discriminator).
In step 1 of the training procedure above, the discriminator is trained for a round on the heavy bag by his coach. The coach critiques his technique and the discriminator adapts. In step 2, the discriminator watches a round of the generator shadowboxing, studying the generator and preparing accordingly for their upcoming round of sparring.
Now step 3, sparring! The generator is a scrappy boxer from Philly who is relaxed & focused when sparring, studying every movement and mistake the discriminator makes and learning from it - adapting after each round. The discriminator hates sparring, and is so scared and nervous every time that he learns absolutely nothing from it. The discriminator may be more athletically gifted and talented than the generator (it’s easier to classify data as real/fake than it is to actually generate realistic data), but the generator’s mindset helps level the playing field. Even though the generator doesn’t have a coach (no access to the real dataset), it learns so much from the discriminator during sparring that it picks up on the fundamental things the discriminator was taught by his coach.
This process goes on for rounds and rounds until eventually the discriminator and generator are both well-rounded boxers ready to compete. The coach has taught the discriminator every important detail of the game he knows, and the generator and discriminator have learned a lot from each other in their sparring wars. Ideally they are both so equally matched at the end of training that a match between them would have 50/50 odds.
As you dive deeper into GANs you will see that one of the major difficulties we currently face is getting these networks to converge properly: we want the generator and discriminator to reach some desired equilibrium, but in practice this frequently doesn’t happen. There is a lot of information and research on what can go wrong (https://www.quora.com/Do-generative-adversarial-networks-always-converge) and a growing body of advice on how to counteract these problems (https://github.com/soumith/ganhacks).
Just to highlight a few of the most common failure cases of GANs:
- The discriminator becomes too strong too quickly and the generator ends up not learning anything. In our boxing analogy, this would be like the discriminator getting so good that the generator is completely out-matched and is just a punching bag in sparring, unable to learn anything since the discriminator makes no mistakes and leaves no openings for the generator to work with. In terms of the training procedure above, this means that in step 3 the discriminator classifies generated data as fake so accurately and confidently that the gradients back-propagated from its loss carry almost nothing for the generator to learn from.
- The generator learns only very specific weaknesses of the discriminator and exploits them to trick the discriminator into classifying generated data as real, instead of learning to represent the true data distribution. In our boxing analogy, this would be like the generator learning every little weakness of the discriminator and capitalizing on those in any way possible rather than actually learning the fundamentals and skill of boxing. In a match against an opponent who doesn’t share the same weaknesses, the generator would be useless! And everything the discriminator learns from the generator would be useless too, because in real matches opponents will not behave like the generator.
- The generator learns only a very small subset of the true data distribution. In our boxing analogy, this would be like our generator only learning a good jab and hiding behind it, developing no other tools. This leads to the discriminator learning very little from the generator and placing too much importance on representing this small subset of the data distribution. An extreme example of this in practice (often called mode collapse) is when the generator produces the same data sample for every possible input, so there is no variation among its outputs.
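The first failure case has a precise form in the loss. Write the discriminator's output on a generated sample as D(G(z)) = sigmoid(s), where s is its logit. The minimax generator loss log(1 - D(G(z))) has gradient -sigmoid(s) with respect to s, which vanishes as the discriminator grows confident (s very negative). The commonly used alternative, minimizing -log D(G(z)), has gradient sigmoid(s) - 1, which does not shrink in that regime; step 3 above, which trains on generated images labelled as real, corresponds to this second loss when binary cross-entropy is used. A minimal sketch comparing the two gradients (the function names are illustrative):

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def saturating_grad(logit):
    """d/d(logit) of the minimax generator loss log(1 - D), with D = sigmoid(logit)."""
    return -sigmoid(logit)


def non_saturating_grad(logit):
    """d/d(logit) of the alternative generator loss -log(D)."""
    return sigmoid(logit) - 1.0


# As the discriminator becomes more confident that the sample is fake,
# the logit becomes more negative and D(G(z)) approaches 0.
for logit in [0.0, -5.0, -10.0]:
    print(f"D(G(z))={sigmoid(logit):.5f}  "
          f"saturating={saturating_grad(logit):.5f}  "
          f"non-saturating={non_saturating_grad(logit):.5f}")
```

This only shows the factor contributed by the loss itself; it mitigates the saturation but does not by itself guarantee that an overly strong discriminator leaves the generator useful gradients.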
The above analogies are a work in progress and more may be added later, especially if requested or suggested.
Now that we have a fundamental understanding of GANs, let’s revisit their purpose: to learn powerful representations from unlabelled data (i.e. to map data from its original dimension to a much smaller dimension that captures its most important features, so that less labelled data is required to achieve the desired performance). After training a GAN, most current methods use either the discriminator as a base model for transfer learning and the fine-tuning of a production model, or the generator as a source of data used to train a production model. In our boxing analogy, this means that the discriminator gets his boxing license and competes but the generator doesn’t. It’s unfortunate, because the generator seems like he has the potential to be the better boxer, and he is either completely discarded or only used as a sparring partner/coach for the production model.
‘What I cannot create, I do not understand.’ (Richard Feynman)
A well-trained generator has learned the true data distribution so well that it can generate samples belonging to it from a much lower-dimensional input. This suggests that it has developed extremely powerful representations of the data. It would be ideal to leverage what the generator has learned directly in production models, but I don’t know of any methods to do this. If you do, please comment.
For a clean & simple implementation of a standard GAN (as well as other types of GANs like InfoGAN and ACGAN) see:
GAN-Sandbox - Vanilla GAN implemented on top of keras/tensorflow enabling rapid experimentation & research. Branches…github.com
There are types of GANs that produce an extremely valuable generator, even if it’s only a ‘sparring partner/coach’:
Apple’s Learning from Simulated and Unsupervised Images through Adversarial Training (S+U Learning) lays down the…medium.com
Waya.ai is a company whose vision is a world where medical conditions are addressed early on, in their infancy. This approach will shift the health-care industry from a constant fire-fight against symptoms to a preventative approach where root causes are addressed and fixed. Our first step to realize this vision is easy, accurate and available diagnosis. Our current focus is concussion diagnosis, recovery tracking & brain health monitoring. Please get in contact with me if this resonates with you!