Deep adversarial learning is finally ready 🚀 and will radically change the game
Adversarial learning is one of the most hyped areas in deep learning. If you browse arxiv-sanity, you’ll notice much of the most popular recent research explores this area.
This post will:
- Explain why we should care about adversarial learning
- Briefly introduce generative adversarial networks (GANs) and the major challenges associated with them
- Summarize recent research (Wasserstein GAN, Improved Training of Wasserstein GANs) that solves these challenges and stabilizes GAN training (implementation included)
Here is a presentation I gave on the topic at the Re-Work 2017 Deep Learning in Healthcare summit (video at videos.re-work.co).
Classical 🎻 machine learning -> Deep learning
In the opening lecture of a course I took at UIUC on analog signals and systems, the professor confidently asserted something along the lines of:
This is the most important course you will take, and abstraction is the most important concept in engineering.
The solution to complexity is abstraction, also known as information hiding. Abstraction is simply the removal of unnecessary detail. The idea is that to design a part of a complex system, you must identify what about that part others must know in order to design their parts, and what details you can hide. The part others must know is the abstraction.
Deep neural networks learn hierarchical representations of data. The layers in a network and the representations they learn build on each other, with layers representing data at a progressively higher level of abstraction. Given raw data, a question to ask the network, and an objective function to evaluate the network’s answer, a network learns to optimally represent (abstract) this data.
A 🔑 consequence of this concept is that feature engineering is learned and performed by the network. This contrasts with the classical machine learning approach, where features expected to contain information relevant to the task at hand are manually identified and extracted from the data, reducing the dimensionality of the input to the ‘learning’ algorithm.
When the underlying structure, patterns, and mechanisms of data are learned instead of hand-crafted ✍️, previously infeasible applications of AI are enabled and super-human performance is made possible.
Deep learning -> Deep adversarial learning
Years ago I had a boxing coach who wouldn’t let new boxers ask questions. New boxers asked the wrong questions, got answers they didn’t need, and then focused on the wrong things.
Asking the right questions takes as much skill as giving the right answers.
- Robert Half
The beauty of adversarial learning is that our networks learn entirely from data: the questions to ask, the corresponding answers, and the evaluation of these answers are all learned. This contrasts with the classical deep learning approach, where questions expected to be relevant to the task at hand are manually identified, and hand-crafted objective functions guide the optimization of our networks towards learning the corresponding answers.
DeepMind recently demonstrated the amazing potential of deep (adversarial) learning with AlphaGo, showing that AlphaGo invents new knowledge and teaches new theories in the game of Go. This ushered in a new era of Go and moved players past a local maximum they’d been stuck in for thousands of years. AlphaGo achieved this by learning an evaluation function that tells the system ‘the score’ at any given moment, rather than relying on a hand-crafted, pre-programmed one. AlphaGo was then trained against itself through millions of simulated games. Sound like adversarial learning 🤔?
AlphaGo didn’t just brute force 👊 its way towards becoming the best Go player in the world. It truly mastered the game, with all its subtleties and intricacies. This was possible because it wasn’t constrained by human input or by our (what we now realize is limited) understanding of the problem domain, whether in asking, answering, or evaluating questions. The next step will be to apply these learning approaches to the real world 🌏. It’s hard to imagine how AI will reinvent agriculture 🌱, healthcare 🏥, etc., but it will happen.
Generative adversarial networks
What I cannot build, I do not understand.
- Richard Feynman
The above quote motivated me to start working with GANs. GANs pose the training process as a game between two networks and allow adversarial learning on generic data.
With the goal of modeling the true data distribution, the generator learns to generate realistic samples of data while the discriminator learns to determine whether these samples are real. If we believe that the ultimate expression of understanding something is being able to recreate it, our goal seems like a worthy one. If we can successfully train our GAN to equilibrium (generated samples that even a perfect discriminator cannot distinguish from real samples), we should be able to apply this understanding of our data to almost any task with top performance 🎯.
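For reference, this game is usually written as the standard GAN minimax objective below (the textbook formulation, not something specific to this post), where the generator G maps noise z ~ p_z to samples with distribution p_g and the discriminator D outputs the probability that its input is real. The second line states the known result that, for a fixed G, plugging in the optimal discriminator turns the objective into a Jensen-Shannon divergence, which is why that divergence is the distance the original GAN minimizes.

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}, \qquad V(D^*, G) = 2\,\mathrm{JS}\left(p_{\text{data}} \,\|\, p_g\right) - \log 4
```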
GANs are difficult to optimize, and training is unstable. The network architectures must be carefully designed, and the balance between the generator and discriminator must be carefully maintained for training to converge. On top of that, mode dropping is typical in GANs (the generator learns only a very small subset of the true distribution), and GANs are difficult to debug because their learning curves are nearly meaningless.
GANs have still achieved state-of-the-art results, but these problems have limited their practical usefulness.
GANs are trained to minimize the distance between the generated and true data distributions. Initially, the Jensen-Shannon divergence was used as this distance metric. However, Wasserstein GAN (wGAN) provided extensive theoretical work showing that minimizing a reasonable and efficient approximation of the Earth Mover’s (EM) distance is a sound optimization problem, and showed empirically that doing so cures the main problems of GANs described above. For this approximation of the EM distance to be valid, the critic (what was called the discriminator before wGAN) must be Lipschitz continuous; wGAN enforced this by clipping the critic’s weights to a small range, which caused some training failures.
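To make weight clipping concrete, here is a minimal sketch in plain Python, assuming 1-D data and a hypothetical linear critic f(x) = w * x (a real critic is a deep network trained with an optimizer such as RMSProp). The critic ascends E[f(real)] - E[f(fake)], and after every step its weight is clipped to [-c, c], with c = 0.01 as in the wGAN paper. Because the clipped critic is c-Lipschitz, its maximized objective estimates the EM distance only up to the factor c.

```python
def critic(w, x):
    # Hypothetical linear critic f(x) = w * x (stand-in for a deep network)
    return w * x


def mean(values):
    return sum(values) / len(values)


def train_clipped_critic(real, fake, steps=100, lr=0.05, clip=0.01):
    """Gradient ascent on mean(f(real)) - mean(f(fake)) with wGAN-style weight clipping."""
    w = 0.0
    for _ in range(steps):
        # Gradient of the critic objective with respect to w (constant for a linear critic)
        grad_w = mean(real) - mean(fake)
        w += lr * grad_w                  # ascent step on the critic objective
        w = max(-clip, min(clip, w))      # weight clipping keeps f clip-Lipschitz
    estimate = mean([critic(w, x) for x in real]) - mean([critic(w, x) for x in fake])
    return w, estimate


if __name__ == "__main__":
    # A pure shift by 3, so the EM distance between the two samples is 3
    w, estimate = train_clipped_critic([3.0, 4.0, 5.0], [0.0, 1.0, 2.0])
    print(w, estimate / 0.01)
```

Here the weight saturates at the clipping bound and estimate / 0.01 comes out at about 3. A linear critic can only recover the EM distance for distributions that differ by a shift; a deep critic has far more capacity, which is where clipping starts to interfere with training.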
Improved Training of Wasserstein GANs enables very stable GAN training by penalizing the norm of the critic’s gradient with respect to its input, instead of clipping weights. This ‘gradient penalty’, evaluated at points sampled along straight lines between real and generated samples, is simply added to the critic’s Wasserstein distance estimate to form the total loss.
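The gradient penalty can be sketched the same way. This is a minimal illustration, assuming 1-D data and a hypothetical one-hidden-layer tanh critic whose input gradient is written out by hand (in practice you would get it from autodiff). For each real/generated pair it samples eps ~ U[0, 1], forms the interpolate x_hat = eps * x_real + (1 - eps) * x_fake, and penalizes (|f'(x_hat)| - 1)^2 with weight lam = 10, the value used in the Improved Training paper. The critic minimizes E[f(fake)] - E[f(real)] plus this penalty.

```python
import math
import random


def critic(params, x):
    # Hypothetical critic f(x) = sum_j a_j * tanh(v_j * x + c_j), params = [(a, v, c), ...]
    return sum(a * math.tanh(v * x + c) for a, v, c in params)


def critic_grad(params, x):
    # Closed-form input gradient: df/dx = sum_j a_j * v_j * (1 - tanh^2(v_j * x + c_j))
    return sum(a * v * (1.0 - math.tanh(v * x + c) ** 2) for a, v, c in params)


def gradient_penalty(params, real, fake, rng, lam=10.0):
    """lam * mean of (|f'(x_hat)| - 1)^2 over random interpolates of real/fake pairs."""
    pairs = list(zip(real, fake))
    total = 0.0
    for x_real, x_fake in pairs:
        eps = rng.random()                              # eps ~ U[0, 1]
        x_hat = eps * x_real + (1.0 - eps) * x_fake     # point on the line between the pair
        grad_norm = abs(critic_grad(params, x_hat))     # in 1-D the L2 norm is |.|
        total += (grad_norm - 1.0) ** 2
    return lam * total / len(pairs)


def critic_loss(params, real, fake, rng, lam=10.0):
    """WGAN-GP critic loss to minimize: E[f(fake)] - E[f(real)] + gradient penalty."""
    mean_real = sum(critic(params, x) for x in real) / len(real)
    mean_fake = sum(critic(params, x) for x in fake) / len(fake)
    return mean_fake - mean_real + gradient_penalty(params, real, fake, rng, lam)


if __name__ == "__main__":
    params = [(1.0, 0.5, 0.0), (-0.5, 2.0, 1.0)]
    print(critic_loss(params, [3.0, 4.0], [0.0, 1.0], random.Random(0)))
```

Note that the penalty pushes the gradient norm towards 1 rather than merely capping it: a perfectly flat critic, such as params = [(1.0, 0.0, 0.3)], pays the full penalty of lam because its gradient norm is 0.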
Finally, for the first time, we can train a wide variety of GAN architectures with almost no hyper-parameter tuning, including 101-layer ResNets and language models over discrete data 💪!
One of the 🔑 benefits of using the Wasserstein distance is that as the critic improves, the generator will receive improved gradients from it. When using the Jensen-Shannon divergence, gradients vanish as the discriminator improves and the generator has nothing to learn from (a major source of training instability).
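A toy calculation shows the difference. Assume (our own simplification, not an experiment from the papers) two point masses on a 1-D grid: the real data sits at position 0 and the generator's output at position theta. The sketch computes both distances directly. The Jensen-Shannon divergence stays at log 2 for every theta > 0, so it tells the generator nothing about which direction to move, while the EM distance equals theta and shrinks steadily as the generator approaches the data.

```python
import math


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(a, b):
        # Terms with a_i = 0 contribute nothing; b_i = m_i > 0 wherever a_i > 0
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def wasserstein_1d(p, q, positions):
    """EM distance in 1-D: the integral of |CDF_p - CDF_q| over sorted grid positions."""
    dist, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for i in range(len(positions) - 1):
        cdf_p += p[i]
        cdf_q += q[i]
        dist += abs(cdf_p - cdf_q) * (positions[i + 1] - positions[i])
    return dist


def point_mass(index, size):
    """Hypothetical helper: all probability mass on a single grid cell."""
    return [1.0 if i == index else 0.0 for i in range(size)]


if __name__ == "__main__":
    positions = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
    real = point_mass(0, len(positions))
    for theta in range(1, len(positions)):
        fake = point_mass(theta, len(positions))
        print(theta, js_divergence(real, fake), wasserstein_1d(real, fake, positions))
```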
I recommend reading both papers (Wasserstein GAN and Improved Training of Wasserstein GANs) if you are interested in gaining a solid theoretical understanding of these concepts.
The way I visualize GANs has changed with the introduction of this new objective function as illustrated below:
Adversarial learning frees our models from the constraints and limitations of our understanding of the problem domain: there is no preconception of what to learn, and the model is free to explore 🕵 the data.
In the next post we will see how we can utilize the representations learned by our generator for image classification.
I’ll be talking about GANs at the Deep Learning in Healthcare Summit in Boston on Friday, May 26th. Feel free to stop by and say 👋!