Implementation of Semi-Supervised Generative Adversarial Networks in Keras
Everyone has heard of supervised learning and unsupervised learning, but there is also a family of techniques in between the two called semi-supervised learning.
Supervised learning is the most popular technique in machine learning, but it has one big disadvantage: it requires a lot of labeled data, and labeling data takes a lot of time and effort. This is where semi-supervised learning comes into the picture.
What is Semi-Supervised Learning?
Semi-supervised learning is a technique where we train our model on a small set of labeled data together with a large amount of unlabeled data.
It is an appealing approach, and it drastically reduces the cost of labeling data.
What are Generative Adversarial Networks?
Generative Adversarial Networks (GANs) are a class of generative models designed by Ian Goodfellow and his colleagues in 2014. They were a breakthrough in generative modeling.
In a GAN, two neural networks try to defeat each other: one network's loss is the other network's gain. The two networks are called the generator and the discriminator.
The generator tries to generate images similar to the training data, and the discriminator's job is to classify images as real (from the training data) or fake (from the generator). The discriminator tries not to get fooled by the generator, while the generator tries its best to fool the discriminator. This is why GANs are described in game-theoretic terms: the generator and the discriminator are players in a game, each trying to overcome the other.
How to use GAN for Semi-Supervised Learning?
Below is the model for the semi-supervised GAN.
Let's understand the model.
The discriminator is passed three types of images: labeled training images, unlabeled training images, and fake images generated by the generator. Its job is not only to distinguish real images from fake ones but also to classify the labeled training images into their correct classes.
The Discriminator has two different outputs:
- A softmax activation for classifying labeled data into its correct class, i.e., the supervised discriminator.
- A custom activation for classifying images as real or fake, i.e., the unsupervised discriminator. We'll cover this custom activation in the implementation.
The powerful aspect of this setup is that the discriminator is trained not only on the labeled data but also on a large amount of unlabeled data: to discriminate between real and fake images, it has to learn to extract useful image features. This adversarial training helps the discriminator classify the labeled data more accurately, because it picks up patterns while learning to separate real from fake that it would never learn from a small set of labeled examples alone.
So let's see the implementation of the above model in Keras.
The dataset I am going to use is the good old MNIST dataset (everyone's favourite).
First, import all the necessary packages.
z_dim is the dimension of the random normal noise vector that we pass to our generator.
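As a minimal sketch of this setup (the import list and the `num_classes`/`img_shape` names are my assumptions; `z_dim = 100` is the value used below):

```python
# Assumed setup for the tutorial's code; z_dim = 100 matches the article,
# the other constants are illustrative assumptions for MNIST.
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Dense, Reshape, Flatten, Dropout, Conv2D,
                                     Conv2DTranspose, BatchNormalization,
                                     LeakyReLU, Activation, Lambda)

z_dim = 100          # dimension of the random noise vector fed to the generator
num_classes = 10     # MNIST digits 0-9
img_shape = (28, 28, 1)
```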
We need to prepare two datasets as follows:
- Supervised discriminator: the dataset will be a small subset of the complete training set.
- Unsupervised discriminator: the dataset will be the complete training set.
Two helper functions are worth explaining here.
batch_labeled() prepares the dataset for the supervised discriminator by selecting a subset of samples along with their labels. We are going to use just 100 labeled examples (i.e., 10 examples per class) to train our classifier.
batch_unlabeled() randomly samples images from the dataset, with the number of images equal to the given batch_size.
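The two helpers described above could look something like this (the function names follow the article; the bodies are a plausible reconstruction, not the author's exact code):

```python
import numpy as np

def batch_labeled(X, y, n_per_class=10, num_classes=10):
    """Select n_per_class labeled examples from each class (100 total here)."""
    xs, ys = [], []
    for c in range(num_classes):
        idx = np.flatnonzero(y == c)[:n_per_class]  # indices of class c
        xs.append(X[idx])
        ys.append(y[idx])
    return np.concatenate(xs), np.concatenate(ys)

def batch_unlabeled(X, batch_size=32):
    """Randomly sample batch_size images, discarding the labels."""
    idx = np.random.randint(0, X.shape[0], batch_size)
    return X[idx]
```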
We use the Keras Sequential API to build our generator. It is a basic generator model, as our task is not complicated.
A Dense layer is used to project our noise vector, of shape (batch_size, 100), to shape (batch_size, 7, 7, 256) so that we can apply Conv2DTranspose layers to it to generate an image.
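A generator along these lines might be sketched as follows (the exact layer widths and kernel sizes are my assumptions; the Dense-to-7x7x256 projection and Conv2DTranspose upsampling follow the description above):

```python
from tensorflow.keras import Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Dense, Reshape, Conv2DTranspose,
                                     BatchNormalization, LeakyReLU, Activation)

def build_generator(z_dim=100):
    """DCGAN-style generator: noise vector -> 28x28x1 image."""
    return Sequential([
        Input(shape=(z_dim,)),
        Dense(7 * 7 * 256),               # project noise to a 7x7x256 block
        Reshape((7, 7, 256)),
        Conv2DTranspose(128, kernel_size=3, strides=2, padding="same"),  # -> 14x14
        BatchNormalization(),
        LeakyReLU(),
        Conv2DTranspose(64, kernel_size=3, strides=1, padding="same"),   # -> 14x14
        BatchNormalization(),
        LeakyReLU(),
        Conv2DTranspose(1, kernel_size=3, strides=2, padding="same"),    # -> 28x28
        Activation("tanh"),               # pixel values in [-1, 1]
    ])
```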
This is our discriminator, which takes an image (real or fake) as input and outputs whether the image is real ("1") or fake ("0").
Note that the output is taken directly from a Dense layer without any activation, so these values can be negative as well. It will become clear in the next step why this is done.
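A sketch of such a discriminator backbone (layer sizes are illustrative assumptions; the important part, matching the text, is that the final Dense layer has no activation):

```python
from tensorflow.keras import Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, LeakyReLU, Flatten, Dropout, Dense

def build_discriminator(img_shape=(28, 28, 1), num_classes=10):
    """Convolutional backbone; the final Dense layer outputs raw logits."""
    return Sequential([
        Input(shape=img_shape),
        Conv2D(32, kernel_size=3, strides=2, padding="same"),
        LeakyReLU(),
        Conv2D(64, kernel_size=3, strides=2, padding="same"),
        LeakyReLU(),
        Conv2D(128, kernel_size=3, strides=2, padding="same"),
        LeakyReLU(),
        Flatten(),
        Dropout(0.5),
        Dense(num_classes),   # raw logits z_1..z_10, no activation
    ])
```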
build_discriminator_supervised() takes as input the discriminator model created in the step above and applies a softmax activation to classify the input image into one of 10 classes (for MNIST).
build_discriminator_unsupervised() takes the same base discriminator and applies a custom_activation() to its output.
Recall that when we created the discriminator, we left its output as the raw, activation-free output of a Dense layer. The custom activation then computes

y = Z(x) / (Z(x) + 1), where Z(x) = sum_k exp(z_k)

Here z is the vector of raw Dense-layer outputs, k runs over the classes (10 in our case), and y lies between 0 and 1.
Because the "realness" of an image is accumulated across all the class logits, the activation sums exp(z_k) over all classes and squashes the result into [0, 1] to obtain the probability that the image is real. This trick turned out to be much more effective than using a plain sigmoid.
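The activation is easy to sanity-check in plain NumPy (this is a standalone illustration of the formula, not the Keras layer itself):

```python
import numpy as np

def custom_activation(logits):
    """Map raw class logits z_1..z_k to a real/fake probability in (0, 1)."""
    Z = np.sum(np.exp(logits), axis=-1)   # Z(x) = sum_k exp(z_k)
    return Z / (Z + 1.0)                  # y = Z / (Z + 1)
```

For example, with all 10 logits equal to zero, Z = 10 and the output is 10/11, i.e. about 0.91; very negative logits push the output toward 0.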
The supervised and unsupervised discriminators share the same weights; only their output heads differ.
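One way to build the two weight-sharing heads (function names follow the article; the bodies are a plausible reconstruction of what they do):

```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Lambda

def build_discriminator_supervised(discriminator):
    # softmax over the 10 raw logits -> class probabilities
    return Sequential([discriminator, Activation("softmax")])

def build_discriminator_unsupervised(discriminator):
    # custom activation over the SAME logits -> probability the image is real
    def real_fake(z):
        Z = tf.reduce_sum(tf.exp(z), axis=-1, keepdims=True)
        return Z / (Z + 1.0)
    return Sequential([discriminator, Lambda(real_fake)])
```

Because both heads wrap the same `discriminator` model object, its convolutional weights are shared between them.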
build_gan() takes both the generator and the discriminator as input and merges them into a single model, which we will call the composite model. It is used to train the generator network.
Note that the discriminator here is the unsupervised discriminator, since the GAN part deals only with unlabeled images.
Training Semi-Supervised GAN
First, let's build our models by calling all the model-building functions we created above and compiling the results.
Note that while building the GAN (composite model), we keep discriminator_unsupervised non-trainable. GAN training alternates between training the discriminator and training the generator, and while training the generator we don't want the discriminator to be updated.
The loss function used is binary cross-entropy for the unsupervised discriminator and categorical cross-entropy for the supervised discriminator.
Now the main training loop
We start by training the supervised discriminator on a batch of real labeled images. Then we train the unsupervised discriminator on real unlabeled images (labeled "1") and fake images (labeled "0") generated by the generator. Finally, we train our generator through the discriminator using the composite model defined earlier.
So the way training of Generator works is as follows :
- The generator generates a batch of fake images.
- These generated fake images are passed through the discriminator, but the target label given to the discriminator is real, i.e., "1". Note: the discriminator is not updated during this step.
- If the generated images do not look realistic, the loss will be high; to minimize it, the generator must start producing realistic-looking images, which is exactly our goal.
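Putting the whole loop together, here is a self-contained sketch of one training iteration. To keep it runnable on its own, it uses tiny Dense stand-ins and random data instead of the MNIST models built earlier, but the three steps are the same ones just described:

```python
# Stand-in models and random data; shapes and layer sizes are illustrative.
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Lambda
from tensorflow.keras.utils import to_categorical

z_dim, num_classes, batch_size = 8, 10, 16

generator = Sequential([Input(shape=(z_dim,)),
                        Dense(28 * 28, activation="tanh")])
backbone = Sequential([Input(shape=(28 * 28,)), Dense(num_classes)])  # raw logits

# supervised head: softmax over the class logits
d_sup = Sequential([backbone, Activation("softmax")])
d_sup.compile(loss="categorical_crossentropy", optimizer="adam")

# unsupervised head: custom activation Z / (Z + 1)
def real_fake(z):
    Z = tf.reduce_sum(tf.exp(z), axis=-1, keepdims=True)
    return Z / (Z + 1.0)

d_unsup = Sequential([backbone, Lambda(real_fake)])
d_unsup.compile(loss="binary_crossentropy", optimizer="adam")

# --- one training iteration on random stand-in data ---
x_lab = np.random.randn(batch_size, 28 * 28).astype("float32")
y_lab = to_categorical(np.random.randint(0, num_classes, batch_size), num_classes)
x_unl = np.random.randn(batch_size, 28 * 28).astype("float32")
z = np.random.randn(batch_size, z_dim).astype("float32")

d_sup.train_on_batch(x_lab, y_lab)                         # 1. supervised D
d_unsup.train_on_batch(x_unl, np.ones((batch_size, 1)))    # 2a. real -> "1"
x_fake = generator.predict(z, verbose=0)
d_unsup.train_on_batch(x_fake, np.zeros((batch_size, 1)))  # 2b. fake -> "0"

# composite model: discriminator frozen, so step 3 only updates the generator
d_unsup.trainable = False
gan = Sequential([generator, d_unsup])
gan.compile(loss="binary_crossentropy", optimizer="adam")
g_loss = gan.train_on_batch(z, np.ones((batch_size, 1)))   # 3. G via frozen D
```

In the real loop, the composite model is built and compiled once before training starts, and the three `train_on_batch` steps repeat every iteration.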
As the generator trains on more images it starts generating better ones, and the discriminator improves too, because it does not want its loss to rise while classifying images as real or fake. This is exactly why the discriminator and generator are in a game against each other: when one gets better, the other must improve as well.
After training, we keep the discriminator and discard the generator, because our main aim was to build a classifier for MNIST.
The accuracy obtained on the test set is approximately 90.67%, which is quite impressive given that we used only 100 labeled images (10 per class) to train our discriminator. This is a powerful classifier that does not need a lot of labeled data. Of course, performance improves even further with, say, 1,000 labeled images, which is still very little compared to modern datasets.
These are fake images generated by the generator; they are quite realistic, with a few exceptions.
You can also try training the same discriminator model without the GAN on the 100 labeled images and see how it performs. It will no doubt be worse than the discriminator trained with the GAN.
This is the GitHub repo linked to this implementation.
If you want to learn more about GANs, try reading the book GANs in Action: Deep Learning with Generative Adversarial Networks by Jakub Langr and Vladimir Bok.
This is my first article on Medium, so please let me know if you liked it. Thanks for reading!