Implementation of Semi-Supervised Generative Adversarial Networks in Keras

Kunal Vaidya
Oct 13 · 7 min read

Most people have heard of supervised and unsupervised learning, but there is also a family of techniques in between the two, called semi-supervised learning.

Supervised learning is the most popular technique in machine learning, but it has one major disadvantage: it requires a lot of labeled data, and labeling data takes a lot of time and effort. This is where semi-supervised learning comes into the picture.

What is Semi-Supervised Learning?

Semi-supervised learning is a technique in which we train a model on a small set of labeled data together with a much larger set of unlabeled data.

This is an interesting approach, and it drastically reduces the cost of labeling data.

What are Generative Adversarial Networks?

Generative Adversarial Networks (GANs) are a class of generative models introduced by Ian Goodfellow and his colleagues in 2014. They were a breakthrough in generative modeling.

In a GAN, two neural networks compete against each other (i.e., one network’s loss is the other network’s gain). The two networks are called the Generator and the Discriminator.

The generator tries to generate images similar to the training data, and the discriminator’s job is to classify images as real (from the training data) or fake (from the generator). Basically, the discriminator is trying not to get fooled by the generator, while the generator is trying its best to fool the discriminator. This is why GAN training is described as a game-theoretic approach: the generator and the discriminator are players in a game, each trying to beat the other.

How to use a GAN for Semi-Supervised Learning?

Below is the architecture of a Semi-Supervised GAN:

Semi-Supervised GAN

Let’s understand the model
The discriminator receives three types of images: labeled training images, unlabeled training images, and fake images produced by the generator. Its job is not only to distinguish real images from fake ones but also to classify the labeled training images into their correct classes.

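The discriminator has two different outputs: a supervised output, which classifies a real image into one of the classes (the 10 digits for MNIST), and an unsupervised output, which predicts whether an image is real or fake.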

The powerful aspect of this setup is that the discriminator is trained not only on the labeled data but also on a large amount of unlabeled data. To tell real images from fake ones, it has to learn to extract useful features from all of the unlabeled images. This adversarial training helps it classify the labeled data more accurately, because while learning the real/fake task it picks up patterns that it would never learn from a small labeled set alone.

So let’s see the implementation of the above model in Keras.
The dataset I am going to use is the good old MNIST dataset (everyone’s favourite).


First, import all the required packages.

z_dim is the dimension of the random noise vector, drawn from a standard normal distribution, that we pass to the generator.
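
A minimal sketch of this setup, assuming TensorFlow 2.x with its bundled Keras 2 (tf.keras); the post's exact imports are not shown. The names img_shape, num_classes and sample_noise are illustrative, while z_dim = 100 matches the (batch_size, 100) noise shape mentioned in the generator section.

```python
# Assumes TensorFlow 2.x with tf.keras (Keras 2); the post's own imports are not shown.
import numpy as np
import tensorflow as tf

# Fix Python/NumPy/TensorFlow seeds for reproducibility (available in TF >= 2.7)
tf.keras.utils.set_random_seed(42)

# MNIST images: 28x28 pixels with a single grayscale channel
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

# Dimension of the noise vector z fed to the generator
z_dim = 100

# Number of classes for the supervised discriminator (digits 0-9)
num_classes = 10


def sample_noise(batch_size):
    """Draw a batch of z vectors from a standard normal distribution."""
    return np.random.normal(0.0, 1.0, size=(batch_size, z_dim)).astype(np.float32)
```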


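We need to prepare two datasets: a small labeled dataset for the supervised discriminator and an unlabeled dataset for the unsupervised discriminator.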

Only the batch_labeled() function needs much explanation here.

batch_labeled() prepares the dataset for the supervised discriminator by selecting a subset of samples along with their labels. We are going to use just 100 labeled examples (i.e., 10 examples per class) to train our classifier.

batch_unlabeled() randomly samples batch_size images from the dataset.
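
The two helpers could look roughly like the sketch below. The Dataset class, its constructor arguments, and the scaling of pixels to [-1, 1] (chosen to match a tanh generator output) are assumptions for illustration; picking exactly 10 examples per class follows the description above. It takes NumPy arrays, so it can be fed directly from tf.keras.datasets.mnist.load_data().

```python
import numpy as np


class Dataset:
    """Hypothetical helper holding MNIST-style arrays for SGAN training."""

    def __init__(self, x_train, y_train, x_test, y_test,
                 num_labeled_per_class=10, num_classes=10, seed=0):
        self.rng = np.random.default_rng(seed)
        self.x_train = self._preprocess(x_train)
        self.y_train = np.asarray(y_train).reshape(-1)
        self.x_test = self._preprocess(x_test)
        self.y_test = np.asarray(y_test).reshape(-1)

        # Keep the same number of labeled examples for every class
        # (10 per class x 10 classes = 100 labeled images)
        labeled_idx = np.concatenate([
            self.rng.choice(np.flatnonzero(self.y_train == c),
                            size=num_labeled_per_class, replace=False)
            for c in range(num_classes)
        ])
        self.x_labeled = self.x_train[labeled_idx]
        self.y_labeled = self.y_train[labeled_idx]

    @staticmethod
    def _preprocess(x):
        # Scale [0, 255] pixels to [-1, 1] and add a channel axis: (n, 28, 28, 1)
        x = np.asarray(x, dtype=np.float32) / 127.5 - 1.0
        return np.expand_dims(x, axis=-1)

    def batch_labeled(self, batch_size):
        """Random batch of labeled images and their integer labels."""
        idx = self.rng.integers(0, len(self.x_labeled), size=batch_size)
        return self.x_labeled[idx], self.y_labeled[idx]

    def batch_unlabeled(self, batch_size):
        """Random batch of training images; their labels are never used."""
        idx = self.rng.integers(0, len(self.x_train), size=batch_size)
        return self.x_train[idx]
```

Usage: (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(), then dataset = Dataset(x_train, y_train, x_test, y_test).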

Generator Network

We use the Keras Sequential API to build our generator model. It is a basic generator, since our task is not complicated.

A Dense layer projects the noise vector z, of shape (batch_size, 100), to 7 × 7 × 256 units, which are then reshaped to (batch_size, 7, 7, 256) so that Conv2DTranspose layers can upsample it into a 28 × 28 image.
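
A possible version of such a generator is sketched below. Everything after the 7 × 7 × 256 projection (filter counts, kernel size, LeakyReLU slope of 0.01, final tanh) is an assumption in a common DCGAN style, not necessarily the author's exact configuration.

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_generator(z_dim=100):
    """Map a z_dim noise vector to a 28x28x1 image with values in [-1, 1]."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(z_dim,)),
        # Project z to 7*7*256 units, then reshape into a 7x7x256 feature map
        layers.Dense(7 * 7 * 256),
        layers.Reshape((7, 7, 256)),
        # 7x7x256 -> 14x14x128
        layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.01),
        # 14x14x128 -> 14x14x64
        layers.Conv2DTranspose(64, kernel_size=3, strides=1, padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.01),
        # 14x14x64 -> 28x28x1; tanh keeps pixel values in [-1, 1]
        layers.Conv2DTranspose(1, kernel_size=3, strides=2, padding="same"),
        layers.Activation("tanh"),
    ])
```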

Discriminator Network

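This is our base discriminator network, which takes an image (real or fake) as input. Instead of a single real/fake output, it produces one raw score per class; both the real/fake decision and the class prediction are derived from these scores in the next two steps.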

Here the output is taken directly from a Dense layer (one unit per class) without any activation, so its values can be negative as well. Why it is done this way will become clear in the next steps.
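
One way to write such a base discriminator is sketched below. The three strided convolutions, the dropout rate and the LeakyReLU slope are assumptions; the detail taken from the text is the final Dense(num_classes) layer with no activation.

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_discriminator(img_shape=(28, 28, 1), num_classes=10):
    """Base network: image -> one unnormalized score (logit) per class."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=img_shape),
        # 28x28x1 -> 14x14x32
        layers.Conv2D(32, kernel_size=3, strides=2, padding="same"),
        layers.LeakyReLU(0.01),
        # 14x14x32 -> 7x7x64
        layers.Conv2D(64, kernel_size=3, strides=2, padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.01),
        # 7x7x64 -> 4x4x128
        layers.Conv2D(128, kernel_size=3, strides=2, padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.01),
        layers.Dropout(0.5),
        layers.Flatten(),
        # Raw class scores, deliberately left without an activation
        layers.Dense(num_classes),
    ])
```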

Supervised Discriminator

build_discriminator_supervised() takes as input the discriminator model we created in the previous step and applies a softmax activation to its output, classifying the input image into one of the 10 classes (for MNIST).
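
A minimal sketch of this function, assuming the base network from the previous step is passed in as discriminator. Wrapping it in a new Sequential model reuses the same layer objects, so the weights are shared rather than copied.

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_discriminator_supervised(discriminator):
    """Class head: softmax over the base network's per-class logits."""
    # Reusing the same `discriminator` instance shares its weights
    return tf.keras.Sequential([discriminator, layers.Activation("softmax")])
```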

Unsupervised Discriminator

build_discriminator_unsupervised() takes as input the same discriminator we created before and applies custom_activation() to its output.

When we created the discriminator, we kept its output as the raw output of a Dense layer, without any activation. The custom activation does the following:

Custom Activation Function

Here z = (z_1, …, z_k) is the output of the Dense layer without any activation, k is the number of classes (10 in our case), and y is a value between 0 and 1.
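
Because the formula itself was shown only as an image, here is the form it most likely takes, assuming the standard activation from Salimans et al. (2016), "Improved Techniques for Training GANs":

$$
y = \frac{Z(z)}{Z(z) + 1}, \qquad Z(z) = \sum_{i=1}^{k} e^{z_i}
$$

Dividing the numerator and the denominator by Z(z) shows that this is just a sigmoid applied to the log-sum-exp of the logits:

$$
y = \frac{1}{1 + Z(z)^{-1}} = \frac{1}{1 + e^{-\log Z(z)}} = \sigma\!\left(\log \sum_{i=1}^{k} e^{z_i}\right)
$$

So y approaches 1 when at least one class logit is large, and approaches 0 when all logits are very negative.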

An image is real if it belongs to any of the real classes, so the activation sums the unnormalized class scores exp(z_i) over all classes and scales the result into the range (0, 1), giving the probability that the image is real. This trick was much more effective than using a sigmoid.

The supervised and unsupervised discriminators share the same weights; only the activation applied to the output differs.
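
A sketch of custom_activation() and build_discriminator_unsupervised(), assuming the activation has the form y = Z / (Z + 1) with Z = Σ exp(z_i). The sketch computes the mathematically identical sigmoid(logsumexp(z)), a substitution made here for numerical stability (summing exp(z_i) directly can overflow for large logits); whether the author's code does the same is not known.

```python
import tensorflow as tf
from tensorflow.keras import layers


def custom_activation(z):
    """Real/fake probability from per-class logits z of shape (batch, k).

    Z = sum_i exp(z_i) and y = Z / (Z + 1), computed as sigmoid(logsumexp(z)),
    which is mathematically identical but cannot overflow exp().
    """
    return tf.sigmoid(tf.reduce_logsumexp(z, axis=-1, keepdims=True))


def build_discriminator_unsupervised(discriminator):
    """Real/fake head that shares all weights with the supervised head."""
    return tf.keras.Sequential([discriminator, layers.Lambda(custom_activation)])
```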

Building GAN

build_gan() takes both the generator and the discriminator as input and stacks them into a single model, which is used to train the generator. We will call it the composite model.

Note that the discriminator here is the unsupervised discriminator, since the GAN part of the model only works with unlabeled images.
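
A minimal sketch of build_gan(), assuming both arguments are Keras models: the generator's output image is fed straight into the unsupervised discriminator, which returns the probability that the image is real.

```python
import tensorflow as tf


def build_gan(generator, discriminator):
    """Composite model: noise -> generator -> unsupervised discriminator -> P(real)."""
    # Both sub-models are reused, not copied, so training this model updates the
    # generator's weights (and the discriminator's too, unless it is frozen).
    return tf.keras.Sequential([generator, discriminator])
```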

Training Semi-Supervised GAN

First, let’s build our models by calling the functions defined above and compiling them.

Note that while building the GAN (composite model), we set discriminator_unsupervised to non-trainable. GAN training alternates between training the discriminator and training the generator, and while training the generator we don’t want the discriminator’s weights to be updated.

The loss function is binary cross-entropy for the unsupervised discriminator and categorical cross-entropy for the supervised discriminator.
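
A sketch of the compilation step, assuming TensorFlow 2.x with Keras 2 (tf.keras), which records each model's trainable state when compile() is called. Keras 3 handles trainable differently, so this freezing pattern may need changes there. The function name compile_models and the default Adam() optimizers are assumptions for illustration.

```python
import tensorflow as tf
from tensorflow.keras.optimizers import Adam


def compile_models(generator, discriminator_supervised, discriminator_unsupervised):
    """Compile both discriminator heads, then build and compile the composite GAN."""
    # Supervised head: 10-way classification of the labeled images
    discriminator_supervised.compile(loss="categorical_crossentropy",
                                     metrics=["accuracy"], optimizer=Adam())
    # Unsupervised head: real (1) vs. fake (0)
    discriminator_unsupervised.compile(loss="binary_crossentropy", optimizer=Adam())

    # Freeze the discriminator *before* compiling the composite model, so that
    # gan.train_on_batch() only updates the generator. The two heads above were
    # compiled while trainable, so (in tf.keras 2.x) they still train normally.
    discriminator_unsupervised.trainable = False
    gan = tf.keras.Sequential([generator, discriminator_unsupervised])
    gan.compile(loss="binary_crossentropy", optimizer=Adam())
    return gan
```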

Now the main training loop

We start by training the supervised discriminator on a batch of real labeled images. Then we train the unsupervised discriminator on real unlabeled images (labeled “1”) and on fake images (labeled “0”) produced by the generator. Finally, we train the generator through the discriminator, using the composite model we defined earlier.
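
The loop below is a minimal sketch of those three steps. It assumes a dataset object with the batch_labeled() and batch_unlabeled() methods described earlier (returning NumPy arrays, with integer labels) and models compiled so that the discriminator is frozen inside gan, as described above; the function name train and the default batch size are illustrative.

```python
import numpy as np
import tensorflow as tf


def train(dataset, generator, discriminator_supervised,
          discriminator_unsupervised, gan, iterations,
          batch_size=32, z_dim=100, num_classes=10):
    """Run SGAN training; returns the losses from the last iteration."""
    real = np.ones((batch_size, 1), dtype=np.float32)   # target "1" = real
    fake = np.zeros((batch_size, 1), dtype=np.float32)  # target "0" = fake
    losses = None
    for _ in range(iterations):
        # 1) Supervised discriminator on a batch of real labeled images
        imgs, labels = dataset.batch_labeled(batch_size)
        labels = tf.keras.utils.to_categorical(labels, num_classes=num_classes)
        d_sup_loss, d_sup_acc = discriminator_supervised.train_on_batch(imgs, labels)

        # 2) Unsupervised discriminator: real unlabeled images -> 1, fakes -> 0
        imgs_unlabeled = dataset.batch_unlabeled(batch_size)
        z = np.random.normal(0.0, 1.0, (batch_size, z_dim)).astype(np.float32)
        gen_imgs = generator.predict(z, verbose=0)
        d_loss_real = discriminator_unsupervised.train_on_batch(imgs_unlabeled, real)
        d_loss_fake = discriminator_unsupervised.train_on_batch(gen_imgs, fake)
        d_unsup_loss = 0.5 * (d_loss_real + d_loss_fake)

        # 3) Generator via the composite model: fakes are labeled as real, so
        #    the generator is rewarded when the frozen discriminator is fooled
        z = np.random.normal(0.0, 1.0, (batch_size, z_dim)).astype(np.float32)
        g_loss = gan.train_on_batch(z, real)

        losses = (d_sup_loss, d_sup_acc, d_unsup_loss, g_loss)
    return losses
```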

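The generator is trained as follows: we sample a batch of noise vectors, pass them through the composite model, and label the resulting fake images as real (“1”). Because the discriminator is frozen inside the composite model, only the generator’s weights are updated, and they are pushed toward producing images that the discriminator classifies as real.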

As the generator is trained on more images, it starts generating better images, and the discriminator has to get better too, because its loss goes up whenever it misclassifies real or fake images. This is exactly why the discriminator and the generator are in a game against each other: if one gets better, the other also needs to improve.
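
To confirm that a generator step through the composite model leaves the discriminator untouched, one can compare weights before and after a single step. This helper is a hypothetical sanity check, not part of the original code; it assumes the models were compiled with the discriminator frozen as described above.

```python
import numpy as np
import tensorflow as tf


def generator_step_is_isolated(generator: tf.keras.Model,
                               discriminator: tf.keras.Model,
                               gan: tf.keras.Model,
                               z_dim: int = 100, batch_size: int = 16):
    """Return (discriminator_unchanged, generator_changed) after one GAN step."""
    z = np.random.normal(0.0, 1.0, (batch_size, z_dim)).astype(np.float32)
    # Inference-mode forward pass makes sure all weights exist before snapshotting
    gan(z, training=False)
    d_before = [w.copy() for w in discriminator.get_weights()]
    g_before = [w.copy() for w in generator.get_weights()]

    # One generator update: fake images labeled as real
    gan.train_on_batch(z, np.ones((batch_size, 1), dtype=np.float32))

    d_unchanged = all(np.array_equal(a, b)
                      for a, b in zip(d_before, discriminator.get_weights()))
    g_changed = any(not np.array_equal(a, b)
                    for a, b in zip(g_before, generator.get_weights()))
    return d_unchanged, g_changed
```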

After training, we keep only the discriminator and discard the generator, because our main aim was to build a classifier for MNIST.
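
Measuring the classifier's test accuracy could look like the sketch below, assuming the supervised discriminator was compiled with metrics=["accuracy"] and the test images were preprocessed the same way as the training images. The function name evaluate_classifier is illustrative.

```python
import tensorflow as tf


def evaluate_classifier(discriminator_supervised, x_test, y_test, num_classes=10):
    """Return the test-set accuracy of the supervised discriminator."""
    y_onehot = tf.keras.utils.to_categorical(y_test, num_classes=num_classes)
    _, accuracy = discriminator_supervised.evaluate(x_test, y_onehot, verbose=0)
    return accuracy
```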



The accuracy obtained on the test set is approximately 90.67%, which is quite impressive given that we used just 100 labeled images (10 per class) to train our discriminator. So this is a powerful classifier that does not need lots of labeled data. Performance should improve further if we use, say, 1,000 labeled images, which is still very few compared to modern datasets.

Fake Images Generated by Generator

These are the fake images generated by the generator, which are quite realistic with a few exceptions.
You can also try training the same discriminator model without the GAN on the 100 labeled images and see how it performs. It will no doubt perform worse than the discriminator trained with the GAN.
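
A sketch of that baseline: the same base network with a softmax head, trained with ordinary supervised learning on the labeled subset only. The function name, number of epochs and batch size are assumptions; build_discriminator stands for whatever function builds the base network.

```python
import tensorflow as tf
from tensorflow.keras import layers


def train_baseline(build_discriminator, x_labeled, y_labeled,
                   num_classes=10, epochs=30, batch_size=32):
    """Train the discriminator architecture as a plain classifier (no GAN)."""
    model = tf.keras.Sequential([build_discriminator(), layers.Activation("softmax")])
    model.compile(loss="categorical_crossentropy", optimizer="adam",
                  metrics=["accuracy"])
    y_onehot = tf.keras.utils.to_categorical(y_labeled, num_classes=num_classes)
    model.fit(x_labeled, y_onehot, epochs=epochs, batch_size=batch_size, verbose=0)
    return model
```

Evaluating this model on the same test set as the SGAN classifier gives a direct comparison.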

This is the GitHub repo linked to this implementation.

If you want to learn more about GANs, you can try reading the book GANs in Action: Deep Learning with Generative Adversarial Networks by Jakub Langr and Vladimir Bok.

This is my first article on Medium, so please let me know if you like it. Thanks for reading!


Analytics Vidhya

Analytics Vidhya is a community of Analytics and Data…
