Understanding GANs — Deriving the Adversarial loss from scratch

Generative adversarial networks, or GANs for short, tackle an unsupervised learning task in which the generator model learns to discover patterns in the input data so that it can generate new samples resembling the training data.

The main idea behind a GAN is adversarial training, where two neural networks compete against each other and improve themselves to compete better.

The generator takes a noise vector as input and transforms it into a fake training sample, which is then passed through the discriminator. The discriminator takes both real samples (from the training data) and fake samples (produced by the generator) and tries to discriminate between them. In other words, the generator tries to fool the discriminator by showing it fake training samples, while the discriminator tries to be as clever as possible.
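To make the two roles concrete, here is a minimal sketch in PyTorch (my illustration, not code from the article), assuming flattened 28x28 images such as MNIST and a 100-dimensional noise vector; the layer sizes are assumptions chosen for brevity.

    import torch
    import torch.nn as nn

    # Illustrative sizes (assumptions): 100-d noise, flattened 28x28 images.
    NOISE_DIM, IMG_DIM = 100, 28 * 28

    # Generator: maps a noise vector Z to a fake sample G(Z).
    generator = nn.Sequential(
        nn.Linear(NOISE_DIM, 256),
        nn.ReLU(),
        nn.Linear(256, IMG_DIM),
        nn.Tanh(),  # outputs scaled to [-1, 1]
    )

    # Discriminator: maps a (real or fake) sample to a probability in (0, 1).
    discriminator = nn.Sequential(
        nn.Linear(IMG_DIM, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1),
        nn.Sigmoid(),  # so D(X) and D(G(Z)) are probabilities
    )

    z = torch.randn(16, NOISE_DIM)   # a batch of noise vectors Z
    fake = generator(z)              # G(Z): fake samples
    score = discriminator(fake)      # D(G(Z)): estimated probability of "real"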

The main idea is that if the generator fools the discriminator, the discriminator should improve itself. On the other hand, if the discriminator classifies fake and real samples perfectly, the generator should improve itself so that it can fool the discriminator!

The two possible cases are:

  1. The generator fools the discriminator, meaning the discriminator failed to classify a fake sample. In this case, the discriminator should improve, so the loss is backpropagated through the discriminator only!
  2. The discriminator does a good job at classifying fake and real images, which means the fake images are not good enough to confuse it. The generator should therefore improve itself, so the loss is backpropagated through the generator network only (see the sketch after this list)!
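In framework terms, this selective backpropagation is implemented by gradient routing; below is a hedged PyTorch sketch reusing the illustrative networks from above (the loss expressions anticipate the form derived later in the article).

    import torch

    z = torch.randn(1, NOISE_DIM)

    # Case 1 - improve the discriminator: detach() blocks gradients from
    # reaching the generator, so the loss backpropagates through D only.
    fake = generator(z).detach()
    d_loss = -torch.log(1 - discriminator(fake)).mean()
    d_loss.backward()

    # Case 2 - improve the generator: gradients flow back through D's graph
    # here, but only the generator's parameters are updated by its optimizer.
    g_loss = torch.log(1 - discriminator(generator(z))).mean()
    g_loss.backward()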

BUT, the question is how can the generator fool the discriminator?

Intuitively, the generator learns the probability distribution of our training data.

In one sentence: the generator learns to approximate the distribution of the actual training data and then samples from the learned distribution. During training there is a fight between the generator and the discriminator, and the two are trained alternately, each while the other is kept fixed!

Deriving the adversarial loss:

The discriminator is nothing but a classifier that performs binary classification (real or fake). So, what loss function do we use for binary classification? Isn't it binary cross-entropy?

The binary cross-entropy loss, for a true label y and a predicted probability \hat{y}, is given below.

    L(y, \hat{y}) = -[y \log \hat{y} + (1 - y) \log(1 - \hat{y})]
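As a quick numeric sanity check (my example, not from the article), this formula matches PyTorch's built-in nn.BCELoss:

    import torch
    import torch.nn as nn

    y_hat = torch.tensor([0.9, 0.2])  # predicted probabilities
    y = torch.tensor([1.0, 0.0])      # true labels: 1 = real, 0 = fake

    # Manual binary cross-entropy, averaged over the two samples.
    manual = -(y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat)).mean()

    print(nn.BCELoss()(y_hat, y), manual)  # both print ~0.1643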

Z: Noise vector (its dimension is a hyperparameter).

G(Z): Output of the generator when given the noise vector Z.

X: Real training data.

D(G(Z)): Output of the discriminator when given the fake generated data G(Z).

D(X): Output of the discriminator when given real training data X.

The discriminator takes either X or G(Z) as input. Since the discriminator is nothing but a binary classifier, we label the real samples X as 1 and the fake samples G(Z) as 0.

We want the discriminator to output D(X) close to 1 for every real sample and D(G(Z)) close to 0 for every fake sample. Right?

So,

The discriminator should maximize log(D(X)), and since log is a monotonic function, log(D(X)) is automatically maximized when the discriminator maximizes D(X).

On the other hand,

The discriminator also needs to maximize log(1 - D(G(Z))), which means it has to minimize D(G(Z)).

So, the loss function for the discriminator (for a single sample) becomes

    L_D = \log D(X) + \log(1 - D(G(Z)))

This is just the binary cross-entropy from above with y = 1 for the real sample and y = 0 for the fake sample, with the sign flipped so that the discriminator maximizes it.

The discriminator will maximize D(X) and minimize D(G(Z)), thereby maximizing the loss function above.

Note that D(X) and D(G(Z)) are both probability values, so each lies between 0 and 1.

Now, the loss function of the discriminator over a batch is

    L_D = \mathbb{E}_{X \sim P(X)}[\log D(X)] + \mathbb{E}_{Z \sim P(Z)}[\log(1 - D(G(Z)))]

where P(X) is the probability distribution of the real training data and P(Z) is the probability distribution of the noise vector Z. Typically, P(Z) is Gaussian or uniform.
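Here is a hedged sketch of this batch objective in PyTorch, reusing the illustrative generator, discriminator, NOISE_DIM, and IMG_DIM from earlier; real_batch stands in for a batch of real training data and is an assumption of the sketch.

    import torch

    batch_size = 16
    real_batch = torch.randn(batch_size, IMG_DIM)  # placeholder for real X

    z = torch.randn(batch_size, NOISE_DIM)         # Z ~ P(Z), here Gaussian
    fake_batch = generator(z)                      # G(Z)

    d_real = discriminator(real_batch)             # D(X)
    d_fake = discriminator(fake_batch)             # D(G(Z))

    # Empirical estimate of E[log D(X)] + E[log(1 - D(G(Z)))];
    # the discriminator wants to maximize this quantity.
    d_value = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()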

The generator needs to fool the discriminator by generating images that look as real as possible. This means the generator should generate a G(Z) that the discriminator, when given it, will label as 1.

So, the discriminator wants to make D(G(Z)) equal to 0, and the generator wants to make D(G(Z)) equal to 1.

So, from binary cross-entropy, the generator's loss for a single sample is

    L_G = \log(1 - D(G(Z)))

The above is for just one sample. Over a batch, it becomes

    L_G = \mathbb{E}_{Z \sim P(Z)}[\log(1 - D(G(Z)))]

The generator will minimize the above loss function, and to minimize it, the generator must maximize D(G(Z)). Now it is very clear: the discriminator wants to minimize D(G(Z)) and the generator wants to maximize D(G(Z)).
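Continuing the previous sketch, the generator's side of the objective is just as short (again my illustration, reusing the assumed networks):

    z = torch.randn(batch_size, NOISE_DIM)
    g_loss = torch.log(1 - discriminator(generator(z))).mean()
    # The generator minimizes this, which pushes D(G(Z)) toward 1.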

Understand that the generator is never going to see any real data, but for completeness the generator's loss function can be written as follows!

    L_G = \mathbb{E}_{X \sim P(X)}[\log D(X)] + \mathbb{E}_{Z \sim P(Z)}[\log(1 - D(G(Z)))]

Note that the generator has no control over the first term, so it will only minimize the second term.

Assume that D denotes the parameters of the discriminator and G denotes the parameters of the generator. Then we can write the loss function as the minimax objective

    \min_G \max_D V(D, G) = \mathbb{E}_{X \sim P(X)}[\log D(X)] + \mathbb{E}_{Z \sim P(Z)}[\log(1 - D(G(Z)))]

This means the discriminator parameters (D) maximize the loss function while the generator parameters (G) minimize it.

The adversarial loss can be optimized by gradient descent. But while training a GAN, we do not train the generator and discriminator simultaneously: while training the generator we freeze the discriminator, and vice versa!

The original GAN paper provides pseudo-code showing how a GAN is trained.
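In that spirit, here is a hedged Python sketch of the training loop, building on the earlier illustrative networks; the optimizers, learning rates, and data_loader are assumptions of the sketch, not details from the article or a verbatim copy of the paper's pseudo-code.

    import torch

    # Illustrative optimizers; the learning rate is an assumption.
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
    opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)

    for real_batch in data_loader:  # data_loader: hypothetical source of real X
        n = real_batch.size(0)

        # Discriminator step: ascend log D(X) + log(1 - D(G(Z))), implemented
        # as descent on the negation. detach() freezes the generator here.
        z = torch.randn(n, NOISE_DIM)
        fake = generator(z).detach()
        d_value = torch.log(discriminator(real_batch)).mean() \
                + torch.log(1 - discriminator(fake)).mean()
        opt_d.zero_grad()
        (-d_value).backward()
        opt_d.step()

        # Generator step: descend log(1 - D(G(Z))). Gradients flow through
        # the discriminator's graph, but only opt_g updates parameters, so
        # the discriminator is effectively frozen.
        z = torch.randn(n, NOISE_DIM)
        g_loss = torch.log(1 - discriminator(generator(z))).mean()
        opt_g.zero_grad()
        g_loss.backward()
        opt_g.step()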

I think this was not too hard to follow! :)

References:

Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. (2014). Generative Adversarial Networks. arXiv:1406.2661.
