Implementation of Semi-Supervised Generative Adversarial Networks in Keras

Build a powerful classifier using semi-supervised learning

Kunal Vaidya
Towards Data Science



Everyone has heard of supervised learning and unsupervised learning, but there is also another family of techniques in between them, called semi-supervised learning.

Supervised Learning is the most popular technique used in Machine Learning, but it has one disadvantage: it requires a lot of labeled data, and labeling data takes a lot of effort and time. This is where Semi-Supervised Learning comes into the picture.

What is Semi-Supervised Learning?

Semi-Supervised Learning is a technique where we train our model using only a small set of labeled data together with a large amount of unlabeled data.

It is an interesting approach, and the cost of labeling the data is reduced drastically.

What are Generative Adversarial Networks?

Generative Adversarial Networks (known as GANs) are a class of generative models designed by Ian Goodfellow and his colleagues in 2014. They were a breakthrough in terms of generative models.

In a GAN there are two neural networks trying to defeat each other (i.e. one network’s loss is the other network’s gain). The two neural networks are called the Generator and the Discriminator.

The generator model tries to generate images (similar to the training data), and the discriminator’s job is to classify images as real (from the training data) or fake (from the generator). Basically, the discriminator is trying not to get fooled by the generator, while the generator is trying its best to fool the discriminator. This is why GANs are described in game-theoretic terms: the generator and discriminator are in a game, each trying to overcome the other.

How to use GAN for Semi-Supervised Learning?

Below is the model for a Semi-Supervised GAN.

Semi-Supervised GAN, Source: Image by Author

Let’s understand the model.
The discriminator is shown three types of images, namely labeled training images, unlabeled training images, and fake images generated by the generator. Its job is not only to distinguish between real and fake images but also to classify the labeled training images into their correct classes.

The Discriminator has two different outputs:

  1. Softmax activation for classifying labeled data into the correct classes, i.e. this is the supervised discriminator.
  2. A custom activation for classifying images as real or fake. We will cover this custom activation in the implementation.

The powerful aspect of this setup is that the discriminator is trained not only on labeled data but also on a large amount of unlabeled data, since it must also discriminate between real and fake images. For that task the discriminator needs to learn to extract features that distinguish real images from fake ones. This adversarial training helps the discriminator classify the labeled data more accurately, because it picks up patterns it would never learn from just a small set of labeled examples.

So let’s see the implementation of the above model in Keras.
The dataset I am going to use is the good old MNIST dataset (everyone’s favourite).

Setup

First, import all the necessary packages.

z_dim is the dimension of the random normal noise vector that we will pass to our generator.
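The original import gist is not reproduced here, but a minimal setup might look like the following sketch. It assumes TensorFlow 2.x with its built-in Keras API; img_shape and num_classes are names I am assuming, while z_dim matches the article.

```python
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Conv2DTranspose, Dense, Dropout, Flatten,
                                     Lambda, LeakyReLU, Reshape)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import backend as K

img_shape = (28, 28, 1)   # MNIST image dimensions (assumed name)
z_dim = 100               # dimension of the random noise vector fed to the generator
num_classes = 10          # digits 0-9 (assumed name)
```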

Dataset

We need to prepare two datasets as follows:

  1. Supervised discriminator: the dataset will be a small subset of the complete training set.
  2. Unsupervised discriminator: the dataset will be the complete training set.

Two helper functions need a short explanation here.

batch_labeled() prepares the dataset for the supervised discriminator by selecting a subset of samples along with their labels. We are going to use just 100 labeled examples (i.e. 10 examples per class) to train our classifier.

batch_unlabeled() randomly samples images from the dataset, with the number of images equal to the given batch_size.
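A minimal sketch of what these helpers might look like, using the imports above and assuming x_train and y_train are already loaded (the exact signatures in the repo may differ):

```python
def batch_labeled(x_train, y_train, n_samples=100, num_classes=10):
    # Pick an equal number of labeled examples per class (10 per digit for 100 total).
    x_list, y_list = [], []
    per_class = n_samples // num_classes
    for c in range(num_classes):
        idx = np.where(y_train == c)[0][:per_class]
        x_list.append(x_train[idx])
        y_list.append(y_train[idx])
    return np.concatenate(x_list), np.concatenate(y_list)

def batch_unlabeled(x_train, batch_size):
    # Randomly sample a batch of images, ignoring the labels.
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    return x_train[idx]
```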

Generator Network


We use the Keras Sequential API to build our generator. It is a basic generator model, as our task is not complicated.

A Dense layer is used to project the noise vector z, which has shape (batch_size, 100), up to a volume of shape (batch_size, 7, 7, 256), so that we can apply Conv2DTranspose layers on it to generate an image.
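Here is a sketch of such a generator; the exact layer widths and kernel sizes are my assumptions, but the structure (Dense projection, Reshape, then upsampling with Conv2DTranspose) follows the description above:

```python
def build_generator(z_dim):
    model = Sequential([
        # Project the noise vector and reshape it into a small feature map.
        Dense(7 * 7 * 256, input_dim=z_dim),
        Reshape((7, 7, 256)),
        # Upsample 7x7 -> 14x14.
        Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),
        # Refine features without changing the spatial size.
        Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),
        # Upsample 14x14 -> 28x28 and squash pixel values to [-1, 1].
        Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'),
        Activation('tanh')
    ])
    return model
```

The tanh output means the training images should be scaled to [-1, 1] as well.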

Discriminator Network


This is our core discriminator network, which takes an image (real or fake) as input.

Note that the output has been kept as the raw output of a Dense layer, without any activation, so these values can be negative as well. It will become clear in the next step why it is done this way.
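A sketch of the core network; the convolution widths are my assumptions, but the key point from the text is the final Dense layer with num_classes raw outputs and no activation:

```python
def build_discriminator(img_shape, num_classes):
    model = Sequential([
        Conv2D(32, kernel_size=3, strides=2, padding='same',
               input_shape=img_shape),
        LeakyReLU(alpha=0.01),
        Conv2D(64, kernel_size=3, strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),
        Conv2D(128, kernel_size=3, strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),
        Dropout(0.5),
        Flatten(),
        # Raw logits, one per class; deliberately no activation here.
        Dense(num_classes)
    ])
    return model
```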

Supervised Discriminator


build_discriminator_supervised() takes as input the discriminator model we created in the step above and applies a softmax activation to classify the input image into one of the 10 classes (for MNIST).
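This head is just a thin wrapper; a minimal sketch:

```python
def build_discriminator_supervised(discriminator_net):
    # Wrap the core network and add a softmax over its 10 class logits.
    model = Sequential([discriminator_net, Activation('softmax')])
    return model
```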

Unsupervised Discriminator

build_discriminator_unsupervised() takes as input the discriminator we created before and applies a custom_activation() to its output.

Recall that when we created the discriminator, we kept its output as the raw output of a Dense layer, with no activation. The custom activation does the following:

Custom activation function:

y = Z(x) / (Z(x) + 1), where Z(x) = sum over k = 1..K of exp(z_k(x))

Here z_k is the k-th output of the Dense layer without any activation, K in our case is 10, and y lies between 0 and 1.

Since the probability of an image being real is spread over all the classes, the activation function sums over all classes and scales the output into [0, 1] to get the probability of the image being real. This trick was much more effective than using a sigmoid.

The supervised and unsupervised discriminators share the same weights; only the output layers on top of them are different.
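A sketch of the custom activation and the unsupervised head, assuming the Keras backend is imported as K (as in the setup above); note how both heads wrap the very same core network, which is what gives the weight sharing:

```python
def custom_activation(z):
    # z: class logits from the core network, shape (batch_size, num_classes).
    # Z(x) = sum_k exp(z_k); y = Z / (Z + 1) is the probability "real".
    Z = K.sum(K.exp(z), axis=-1, keepdims=True)
    return Z / (Z + 1.0)

def build_discriminator_unsupervised(discriminator_net):
    # Reuse the same core network (shared weights) with the custom activation on top.
    model = Sequential([discriminator_net, Lambda(custom_activation)])
    return model
```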

Building GAN

build_gan() takes as input both the generator and the discriminator and chains them together into a single model. This is done to train the generator network. We will call this the composite model.

Note that the discriminator here is the unsupervised discriminator, as the GAN part deals only with unlabeled images.
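A minimal sketch:

```python
def build_gan(generator, discriminator):
    # Chain generator -> unsupervised discriminator into one composite model.
    model = Sequential([generator, discriminator])
    return model
```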

Training Semi-Supervised GAN

First, let's build our models by calling and compiling all the model-building functions we created above.


Note that while building the GAN (composite model) we set discriminator_unsupervised to be non-trainable, because GAN training alternates between training the discriminator and training the generator; while training the generator we do not want the discriminator to get updated.

The cost function used is binary cross-entropy for the unsupervised discriminator and categorical cross-entropy for the supervised discriminator.
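Putting that together, a sketch of the model definitions might look like this (optimizer settings are my assumptions; the freeze-then-compile ordering follows the standard Keras GAN pattern described above):

```python
# Core network shared by both discriminator heads.
discriminator_net = build_discriminator(img_shape, num_classes)

# Supervised head: trained with categorical cross-entropy on labeled data.
discriminator_supervised = build_discriminator_supervised(discriminator_net)
discriminator_supervised.compile(loss='categorical_crossentropy',
                                 optimizer=Adam(),
                                 metrics=['accuracy'])

# Unsupervised head: trained with binary cross-entropy on real/fake data.
discriminator_unsupervised = build_discriminator_unsupervised(discriminator_net)
discriminator_unsupervised.compile(loss='binary_crossentropy', optimizer=Adam())

# Composite model used to train the generator; the discriminator is frozen here.
generator = build_generator(z_dim)
discriminator_unsupervised.trainable = False
gan = build_gan(generator, discriminator_unsupervised)
gan.compile(loss='binary_crossentropy', optimizer=Adam())
```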

Now the main training loop


We start by training the supervised discriminator on a batch of real labeled images. Then we train the unsupervised discriminator on real unlabeled images (labeled as “1”) and on fake images (labeled as “0”) generated by the generator. Finally, we train the generator through the discriminator using the composite model we defined earlier.
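A condensed sketch of that loop, using the helpers and models defined above; iterations and batch_size are values I am assuming:

```python
iterations = 8000
batch_size = 32

# Load MNIST once, scale to [-1, 1] to match the generator's tanh output.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype('float32') - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)

# 100 labeled examples (10 per class) for the supervised discriminator.
x_labeled, y_labeled = batch_labeled(x_train, y_train, n_samples=100)
y_labeled = to_categorical(y_labeled, num_classes)

real = np.ones((batch_size, 1))    # target label for real images
fake = np.zeros((batch_size, 1))   # target label for fake images

for iteration in range(iterations):
    # 1. Train the supervised discriminator on a batch of labeled images.
    idx = np.random.randint(0, x_labeled.shape[0], batch_size)
    d_loss_sup = discriminator_supervised.train_on_batch(x_labeled[idx],
                                                         y_labeled[idx])

    # 2. Train the unsupervised discriminator on real unlabeled and fake images.
    imgs_unlabeled = batch_unlabeled(x_train, batch_size)
    z = np.random.normal(0, 1, (batch_size, z_dim))
    gen_imgs = generator.predict(z)
    d_loss_real = discriminator_unsupervised.train_on_batch(imgs_unlabeled, real)
    d_loss_fake = discriminator_unsupervised.train_on_batch(gen_imgs, fake)

    # 3. Train the generator via the composite model; fakes are labeled "real".
    z = np.random.normal(0, 1, (batch_size, z_dim))
    g_loss = gan.train_on_batch(z, real)
```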

The way the training of the generator works is as follows:

  1. The generator generates a batch of fake images.
  2. These fake images are passed through the discriminator, and the target label given to the discriminator is real, i.e. “1”. Note: the discriminator is not updated during this step.
  3. If the generated images do not look realistic, the loss will be high, so to minimize the loss the generator starts generating realistic-looking images, which is our goal.

As the generator is trained on more batches it starts generating better images, and the discriminator also keeps getting better, because it does not want its loss on real/fake classification to grow. This is exactly why the discriminator and generator are in a game against each other: if one gets better, the other also needs to improve.

After training, we keep the discriminator and discard the generator, because our main aim was to build a classifier for MNIST.

Results

Source: Image by Author

The accuracy obtained on the test set is approximately 90.67%, which is quite impressive given that we used just 100 labeled images (10 per class) to train our discriminator. So this is a powerful classifier that does not need lots of labeled data. Of course, performance will improve even more if we use, say, 1,000 labeled images, which is still very little compared to modern datasets.
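Evaluating the trained classifier could look like the sketch below, using the names defined above (the exact accuracy will vary from run to run):

```python
# Evaluate the supervised discriminator (our classifier) on the MNIST test set.
x_test_scaled = (x_test.astype('float32') - 127.5) / 127.5
x_test_scaled = np.expand_dims(x_test_scaled, axis=-1)
y_test_cat = to_categorical(y_test, num_classes)

loss, acc = discriminator_supervised.evaluate(x_test_scaled, y_test_cat,
                                              verbose=0)
print(f'Test accuracy: {100 * acc:.2f}%')
```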

Fake Images Generated by Generator, Source: Image by Author

These are the fake images generated by the generator; they are quite realistic, with a few exceptions.
You can also try training the same discriminator model without the GAN on just the 100 labeled images and see how it performs. It will no doubt be worse than the discriminator trained with the GAN.

This is the GitHub repo linked to this implementation:

https://github.com/KunalVaidya99/Semi-Supervised-GAN

If you want to learn more about GANs, you can try reading the book GANs in Action: Deep Learning with Generative Adversarial Networks by Jakub Langr and Vladimir Bok.

This is my first article on Medium, so please let me know if you liked it. Thanks for reading!

References

Augustus Odena, “Semi-Supervised Learning with Generative Adversarial Networks,” arXiv:1606.01583. https://arxiv.org/pdf/1606.01583.pdf
