GANs as a loss function.
In lesson 12 of the fast.ai course about Generative Adversarial Networks (GANs), the lecturer Jeremy Howard said:
“ … We’ll learn about Generative Adversarial Networks (GANs). This is, at its heart, a different kind of loss function.”
What does this even mean, you ask?
In this article, I will put the statement above in proper context and show you the simple yet beautiful idea of GAN as a learned loss function.
First, let’s start with some context:
Neural Networks as function approximators
In mathematics, you can think of a function as a machine: you feed it one or more numbers and it churns out one or more numbers.
That’s all well and good if you can express your function as a mathematical expression. But what if you cannot, or simply have not figured out how to, write your desired function as a bunch of additions and multiplications? Take, for example, a function that tells whether an input image shows a cat or a dog.
If you can’t formulate it, can you at least approximate it?
Here Neural Networks (NNs) come to the rescue. The Universal Approximation Theorem states that a sufficiently large network with enough hidden units can approximate any continuous function.
For an intuition of why this is the case, please consult the interactive demo in the book Neural Networks and Deep Learning by Michael Nielsen.
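To make this concrete, here is a minimal numpy sketch (my own illustration, not from the course): a single hidden layer of 50 random tanh units, with only the output weights fitted by least squares, is already enough to approximate sin(x) closely over an interval.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-np.pi, np.pi, 200)[:, None]
y = np.sin(x)

# One hidden layer of 50 tanh units: input weights are left random,
# and only the output weights are fitted by least squares. Even this
# crude setup shows the approximation power the theorem promises.
W = rng.normal(size=(1, 50))
b = rng.normal(size=50)
H = np.tanh(x @ W + b)                        # hidden activations, (200, 50)
beta, *_ = np.linalg.lstsq(H, y, rcond=None)  # fit output weights

err = np.max(np.abs(H @ beta - y))            # worst-case error on the grid
print(err)
```

A real NN would also train W and b by gradient descent; fixing them at random values just keeps the demo short.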
Explicit loss function in Neural Networks
With the power of Neural Networks at hand, instead of trying to explicitly formulate a function that classifies dogs and cats, we build a NN and gradually make it better at approximating this function.
To get better at approximating, the NN first needs to know how badly it is doing. The function that computes this error of the Neural Net is called the loss function.
There are a handful of loss functions, and which one to use depends on the task at hand. They all share one property, however: each can be expressed as a precise mathematical formula:
- L1 loss (Absolute error): Used for regression tasks.
- L2 loss (Squared error): Similar to L1 but more sensitive to outliers.
- Cross-entropy loss: Usually used in classification tasks.
- Dice loss: Used in segmentation tasks; closely related to the IoU metric.
- KL Divergence: For measuring the difference between two distributions.
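The first three are essentially one-liners. Here is a numpy sketch (the toy numbers are made up purely for illustration):

```python
import numpy as np

y_true = np.array([2.0, 0.5, -1.0])
y_pred = np.array([1.5, 1.0, -0.5])

l1 = np.mean(np.abs(y_pred - y_true))   # L1: mean absolute error
l2 = np.mean((y_pred - y_true) ** 2)    # L2: mean squared error

# Cross-entropy for one 3-class example: true class is 0,
# predicted class probabilities are [0.7, 0.2, 0.1]
probs = np.array([0.7, 0.2, 0.1])
ce = -np.log(probs[0])

print(l1, l2, ce)  # l1 = 0.5, l2 = 0.25, ce is about 0.357
```

Notice that every one of these is a fixed, hand-written formula; nothing about them is learned.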
The loss function is critical to building a good NN approximation. Understanding and choosing the appropriate loss function for the task at hand is one of the most important skills in building Neural Nets.
Designing better loss functions is also a very active area of research. For instance, the paper “Focal Loss for Dense Object Detection” introduces a new loss function, Focal loss, that deals with class imbalance in single-stage object detection models.
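The core of that loss fits in a few lines. Below is a hedged numpy sketch of its binary form, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t); the example probabilities are my own:

```python
import numpy as np

def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Binary focal loss: well-classified examples are down-weighted
    by the (1 - p_t)**gamma factor, so training focuses on hard ones."""
    p_t = np.where(y == 1, p, 1 - p)            # prob. of the true class
    alpha_t = np.where(y == 1, alpha, 1 - alpha)
    return -alpha_t * (1 - p_t) ** gamma * np.log(p_t)

# A confident, correct prediction contributes almost nothing...
easy = focal_loss(np.array([0.95]), np.array([1]))
# ...while a badly misclassified one keeps a large loss.
hard = focal_loss(np.array([0.05]), np.array([1]))
print(easy[0], hard[0])
```

Even this novel loss is still an explicit formula, which is exactly the limitation the next section is about.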
Limitation of explicit loss functions
The above loss functions work relatively well for classification, regression, and segmentation tasks, but when the output has a multi-modal distribution, they break down.
Take the task of colorizing a black-and-white picture, for example.
- Your inputs can be birds in black and white and your ground truth images are the same birds in blue.
- You use the L2 loss function to calculate the pixel-wise difference between your model’s color outputs and the blue-bird ground truth.
- Next, you have an image of a very similar bird in black and white, only now the ground truth image is this bird in red.
- The L2 loss function now tries to minimize the difference between the color outputs and red.
- From the feedback of the L2 loss, the model learns that for a similar bird it should output a color close to red, but also close to blue. What will it do?
- It will output the bird in a muddy in-between color, roughly the average of red and blue, since that is the safest way to minimize the distance to both, even though it has never seen a bird of that color during training.
- No real bird has that washed-out color, so you know your model is not realistic.
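You can verify this averaging effect with a few lines of numpy (a toy calculation of my own, treating each bird as a single RGB pixel): under L2, the mean of the two ground-truth colors scores strictly better than either honest answer.

```python
import numpy as np

red  = np.array([255.0, 0.0, 0.0])    # ground-truth color of bird 1
blue = np.array([0.0, 0.0, 255.0])    # ground-truth color of bird 2

def total_l2(c):
    """Summed squared error of one predicted color against both targets."""
    return np.sum((c - red) ** 2) + np.sum((c - blue) ** 2)

mean_color = (red + blue) / 2          # [127.5, 0, 127.5], seen on no real bird
print(total_l2(mean_color), total_l2(red), total_l2(blue))
```

The average wins (65025 versus 130050 for either true color), so a model trained with L2 is actively rewarded for the unrealistic in-between color.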
This averaging effect can lead to very unpleasant consequences in some cases. Take the task of predicting the next video frame, for instance: the next frame has many possibilities, and what you want is for the model to output one of them. But if you train your model with L2 or L1, it will average out all the possibilities and produce a very blurry image instead.
GAN as the new loss.
So, in the beginning, you don’t know the exact mathematical formula for a complicated function (e.g. a function that takes in an array of numbers and outputs a realistic image of a dog), so you use a Neural Net to approximate it.
A NN needs loss functions to tell it how good it currently is, but no explicit loss function can perform the task well.
Hmm, if only there were a way to approximate this Neural Net’s loss function directly, without having to know its mathematical formula explicitly… like using another Neural Net?
So what happens if you replace the explicit loss function with a NN as well? Congratulations, you just rediscovered GANs.
You can see this more clearly in the architectures of GAN and Alpha-GAN below. In these figures, white boxes represent inputs, pink and green boxes represent the networks you want to build, and blue boxes represent the loss functions.
In the case of the vanilla GAN, there is only one loss function: the Discriminator network D, which is itself a NN.
In the case of Alpha-GAN, there are three loss functions: the discriminator D for the input data, the latent-code discriminator C for the encoded latent variables, and the traditional pixel-wise L1 loss. Among these, D and C are not explicit loss functions but learned approximations, i.e. Neural Nets.
Follow the gradients.
So if the Generator network (and the Encoder in Alpha-GAN) is trained with the Discriminator, which is itself a NN, as its loss function, what loss function is the Discriminator trained with?
The task of the discriminator is to discriminate between the real data distribution and the generated data distribution. The labels to train it in a supervised manner come for free, so it is easy to train with an explicit loss function such as binary cross-entropy.
But since the discriminator serves as the loss function for the generator, the gradients computed through the discriminator’s binary cross-entropy loss are also backpropagated into the generator network.
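The whole interplay fits in a toy example. The sketch below is my own illustration in plain numpy, not any published architecture: a 1-D “dataset” of real numbers, a linear generator, and a logistic-regression discriminator. It shows both updates side by side: the discriminator is trained with an explicit binary cross-entropy loss, and the gradients flowing back through the discriminator then act as the generator’s loss.

```python
import numpy as np

rng = np.random.default_rng(0)
sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))

# Real data: samples from N(3, 0.5). Generator: G(z) = a*z + c, z ~ N(0, 1).
# Discriminator: D(x) = sigmoid(w*x + b), a logistic-regression "critic".
a, c = 1.0, 0.0          # generator parameters
w, b = 0.1, 0.0          # discriminator parameters
lr = 0.05

for step in range(2000):
    real = rng.normal(3.0, 0.5, size=64)
    z = rng.normal(size=64)
    fake = a * z + c

    # --- Discriminator update: explicit BCE, real labeled 1, fake labeled 0.
    p_real, p_fake = sigmoid(w * real + b), sigmoid(w * fake + b)
    # d(BCE)/d(logit) = p - label, chained through the logit w*x + b
    gw = np.mean((p_real - 1) * real) + np.mean(p_fake * fake)
    gb = np.mean(p_real - 1) + np.mean(p_fake)
    w -= lr * gw
    b -= lr * gb

    # --- Generator update: the discriminator IS the loss. Minimize -log D(G(z)).
    p_fake = sigmoid(w * (a * z + c) + b)
    g_logit = -(1 - p_fake)          # d(-log D)/d(logit)
    ga = np.mean(g_logit * w * z)    # chain rule through G's parameters
    gc = np.mean(g_logit * w)
    a -= lr * ga
    c -= lr * gc

# The generated mean c drifts toward the real mean of 3: the generator
# improved without us ever writing an explicit loss comparing it to real data.
print(a, c)
```

This linear toy generator may collapse toward a single point near the real mean (nothing in its logistic discriminator penalizes variance), but it makes the key mechanism visible: the only explicit loss in the system is the discriminator’s BCE, and the generator learns purely from gradients passed through another network.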
Looking at the flow of gradients in a GAN, it becomes easy to come up with new ideas for changing its path. What if the gradients from an explicit loss flowed back not just through the two NNs of discriminator and generator, but through three NNs? What could that be useful for? What if the gradients did not come from a traditional loss at all, but moved back and forth directly between these NNs? Reasoning from the fundamental idea makes it much easier to see the unexplored paths and unanswered questions.
By wrapping the traditional loss function in a Neural Network, GANs make it possible to use one NN as a loss function for another NN. This beautiful interplay between two networks has allowed deep Neural Nets to perform previously unattainable tasks such as generating realistic images.
By viewing GANs as, essentially, a learned loss function, I hope this post has helped you appreciate their simple yet powerful idea.