Semi-supervised learning with GANs

In this post I will cover a partial re-implementation of a recent paper on manifold regularization (Lecouat et al., 2018) for semi-supervised learning with Generative Adversarial Networks (Goodfellow et al., 2014). I will attempt to re-implement their main contribution, rather than getting all the hyperparameter details just right. Also, for the sake of demonstration, time constraints and simplicity, I will consider the MNIST dataset rather than the CIFAR10 or SVHN datasets as done in the paper. Ultimately, this post aims at bridging the gap between the theory and implementation for GANs in the semi-supervised learning setting. The code that comes with this post can be found here.

Generative Adversarial Networks

Let’s quickly go over Generative Adversarial Networks (GAN). In terms of the current pace within the AI/ML community, they have been around for a while (just about 4 years), so you might already be familiar with them. The 'vanilla' GAN procedure is to train a generator to generate images that are realistic and capable of fooling a discriminator. The generator generates the images by means of a deep neural network that takes in a noise vector z.

The discriminator (which is a deep neural network as well) is fed with the generated images, but also with some real data. Its job is to say whether each image is real (coming from the dataset) or fake (coming from the generator), which in terms of implementation comes down to binary classification. The image below summarizes the vanilla GAN setup.

Vanilla GAN setup.

Semi-supervised learning

Semi-supervised learning problems concern a mix of labeled and unlabeled data. Leveraging the information in both the labeled and unlabeled data to eventually improve the performance on unseen labeled data is an interesting and more challenging problem than merely doing supervised learning on a large labeled dataset. In this case we might be limited to having only about 200 samples per class. So what should we do when only a small portion of the data is labeled?

Note that adversarial training of vanilla GANs doesn't require labeled data. At the same time, the deep neural network of the discriminator is able to learn powerful and robust abstractions of images by gradually becoming better at discriminating fake from real. Whatever it's learning about unlabeled images will presumably also yield useful feature descriptors of labeled images. So how do we use the discriminator for both labeled and unlabeled data? Well, the discriminator is not necessarily limited to just telling fake from real. We could decide to train it to also classify the real data.

A GAN with a classifying discriminator would be able to exploit both the unlabeled as well as the labeled data. The unlabeled data will be used to merely tell fake from real. The labeled data would be used to optimize the classification performance. In practice, this just means that the discriminator has a softmax output distribution for which we minimize the cross-entropy. Indeed, part of the training procedure is just doing supervised learning. The other part is about adversarial training. The image below summarizes the semi-supervised learning setup with a GAN.

Semi-supervised learning setup with a GAN.

The implementation

Let's just head over to the implementation, since that might be the best way of understanding what's happening. The snippet below prepares the data. It doesn't really contain anything sophisticated. Basically, we take 400 samples per class and concatenate the resulting arrays to form our labeled subset. The unlabeled dataset consists of all train data (it also includes the labeled data, since we might as well use it anyway). As is customary for training GANs now, the output of the generator uses a hyperbolic tangent function, meaning its output is between -1 and +1. Therefore, we rescale the data to be in that range as well. Then, we create TensorFlow iterators so that we can efficiently go through the data later without having to struggle with feed dicts.
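The data preparation described above can be sketched in plain NumPy. This is a minimal illustration, not the post's original snippet: the function name `prepare_subsets`, the choice of taking the first samples of each class, and the random stand-in data are all assumptions, and the TensorFlow iterators are left out.

```python
import numpy as np


def prepare_subsets(images, labels, n_per_class, num_classes=10):
    """Return a small labeled subset plus the full, rescaled unlabeled set."""
    # Rescale uint8 pixels from [0, 255] to [-1, 1] to match a tanh generator.
    images = images.astype(np.float32) / 127.5 - 1.0
    # Assumption: take the first n_per_class examples of every class.
    picks = [np.flatnonzero(labels == c)[:n_per_class] for c in range(num_classes)]
    idx = np.concatenate(picks)
    # The unlabeled set is simply all training images (labels ignored).
    return images[idx], labels[idx], images


# Stand-in data with MNIST-like shapes (random pixels, not real digits).
rng = np.random.default_rng(0)
stand_in_images = rng.integers(0, 256, size=(5000, 28, 28, 1), dtype=np.uint8)
stand_in_labels = np.arange(5000) % 10  # exactly 500 examples per class

x_lab, y_lab, x_unlab = prepare_subsets(stand_in_images, stand_in_labels, n_per_class=400)
print(x_lab.shape, x_unlab.shape)
```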

Next up is to define the discriminator network. I have deviated quite a bit from the architecture in the paper. I’m going to play it safe here and just use Keras layers to construct the model. Actually, this enables us to very conveniently reuse all weights for different input tensors, which will prove to be useful later on. In short, the discriminator’s architecture uses 3 convolutions with 5x5 kernels and strides of 2x2, 2x2 and 1x1 respectively. Each convolution is followed by a leaky ReLU activation and a dropout layer with a dropout rate of 0.3. The flattened output of this stack of convolutions will be used as the feature layer.

The feature layer can be used for a feature matching loss (rather than a sigmoid cross-entropy loss as in vanilla GANs), which has proven to yield a more reliable training process. The part of the network up to this feature layer is defined in _define_tail in the snippet below. The _define_head method defines the rest of the network. The 'head' of the network introduces only one additional fully connected layer with 10 outputs that correspond to the logits of the class labels. Other than that, there are some methods to make the interface of a Discriminator instance behave similarly to that of a tf.keras.models.Sequential instance.

The generator's architecture also uses 5x5 kernels. Many implementations of DCGAN-like architectures use transposed convolutions (sometimes incorrectly referred to as 'deconvolutions'). I have decided to give the upsampling-convolution alternative a try. This should alleviate the issue of the checkerboard pattern that sometimes appears in generated images. Other than that, there are ReLU nonlinearities, and a first layer to go from the 100-dimensional noise to a (rather awkwardly shaped) 7x7x64 spatial representation.
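To make the shapes concrete, here is a small NumPy sketch of only the upsampling half of the upsample-then-convolve idea. It assumes nearest-neighbour upsampling by a factor of 2 applied twice (7x7 to 14x14 to 28x28); the helper `upsample_nearest` and the random projection matrix `w` are hypothetical, and the 5x5 convolutions that would follow each upsampling step are omitted.

```python
import numpy as np


def upsample_nearest(x, factor=2):
    """Nearest-neighbour upsampling of a (batch, height, width, channels) array."""
    x = np.repeat(x, factor, axis=1)     # repeat rows
    return np.repeat(x, factor, axis=2)  # repeat columns


rng = np.random.default_rng(0)
z = rng.standard_normal((8, 100)).astype(np.float32)

# Hypothetical dense projection from the 100-d noise to a 7x7x64 feature map.
w = (rng.standard_normal((100, 7 * 7 * 64)) * 0.02).astype(np.float32)
h = np.maximum(z @ w, 0.0).reshape(-1, 7, 7, 64)  # ReLU, then reshape

h14 = upsample_nearest(h)    # a 5x5 convolution would follow here
h28 = upsample_nearest(h14)  # and another convolution before the tanh output
print(h.shape, h14.shape, h28.shape)
```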

I have tried to make this model work with what TensorFlow's Keras layers have to offer so that the code would be easy to digest (and to implement of course). This also means that I have deviated from the architectures in the paper (e.g. I'm not using weight normalization). Because of this experimental approach, I have also experienced just how sensitive the training setup is to small variations in network architectures and parameters. There are plenty of neat GAN 'hacks' listed here which I definitely found insightful.

Putting it together

Let's do the forward computations now so that we see how all of the above comes together. This consists of setting up the input pipeline, noise vector, generator and discriminator. The snippet below does all of this. Note that when define_generator returns the Sequential instance, we can simply call it on the noise tensor z to obtain its output.

The discriminator will do a lot more. It will take (i) the 'fake' images coming from the generator, (ii) a batch of unlabeled images and finally (iii) a batch of labeled images (both with and without dropout to also report the train accuracy). We can just repeatedly call the Discriminator instance to build the graph for each of those outputs. Keras will make sure that the variables are reused in all cases. To turn off dropout for the labeled training data, we have to pass training=False explicitly.

The discriminator's loss

Recall that the discriminator will be doing more than just separating fake from real. It also classifies the labeled data. For this, we define a supervised loss which takes the softmax output. In terms of implementation, this means that we feed the unnormalized logits to tf.nn.sparse_softmax_cross_entropy_with_logits.
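For readers who want to see what such a sparse softmax cross-entropy computes, here is a minimal NumPy sketch. It is an illustration of the math rather than TensorFlow's implementation; the function name and the toy logits are assumptions.

```python
import numpy as np


def sparse_softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between integer labels and softmax(logits)."""
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for every example.
    return -log_probs[np.arange(len(labels)), labels].mean()


logits = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
loss = sparse_softmax_cross_entropy(logits, labels)
print(loss)
```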

Defining the loss for the unsupervised part is where things get a little bit more involved. Because the softmax distribution is overparameterized, we can fix the unnormalized logit of the fake class (i.e. the image coming from the generator) at 0. If we do so, the probability of an image being real just turns into:

Probability of observing a real image
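With the fake logit fixed at 0, the fake class contributes exp(0) = 1 to the softmax normalizer. Under that convention, and reconstructing from the surrounding definitions, the probabilities of the real and fake outcomes can be written as:

```latex
p(\text{real} \mid x) = \frac{Z(x)}{Z(x) + 1},
\qquad
p(\text{fake} \mid x) = \frac{1}{Z(x) + 1}
```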

where Z(x) is the sum of the unnormalized probabilities. Note that we currently only have the logits. Ultimately, we want to use the log-probability of the fake class to define our loss function. This can now be achieved by computing the whole expression in log-space:
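Assuming K class logits (K = 10 here) and Z(x) equal to the sum of their exponentials, a reconstruction of the log-space expressions consistent with the description that follows is:

```latex
\begin{aligned}
\log p(\text{real} \mid x)
  &= \log Z(x) - \log\big(1 + Z(x)\big) \\
  &= \operatorname{logsumexp}(l_1, \dots, l_K)
     - \operatorname{softplus}\big(\operatorname{logsumexp}(l_1, \dots, l_K)\big), \\
\log p(\text{fake} \mid x)
  &= -\operatorname{softplus}\big(\operatorname{logsumexp}(l_1, \dots, l_K)\big)
\end{aligned}
```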

Here, the lowercase l with subscripts denotes the individual logits. Divisions become subtractions and sums can be computed by the logsumexp function. Finally, we have used the definition of the softplus function:
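The standard definition of the softplus function is:

```latex
\operatorname{softplus}(x) = \log\big(1 + e^{x}\big)
```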

In general, if you have the log-representation of a probability, it is numerically safer to keep things in log-space for as long as you can, since we are able to represent much smaller numbers in that case.
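The benefit of staying in log-space can be seen in a small NumPy sketch. It assumes the convention that the log-probability of an image being real equals logsumexp of the class logits minus softplus of that value; the function name `log_prob_real` and the extreme logits are chosen purely for illustration.

```python
import numpy as np


def log_prob_real(logits):
    """Per-example log p(real | x) = logsumexp(l) - softplus(logsumexp(l))."""
    m = logits.max(axis=1)
    lse = m + np.log(np.exp(logits - m[:, None]).sum(axis=1))  # stable logsumexp
    return lse - np.logaddexp(0.0, lse)  # logaddexp(0, x) is a stable softplus


# Very negative logits: the unnormalized probabilities underflow in float64.
logits = np.array([[-800.0, -805.0, -810.0]])
naive_z = np.exp(logits).sum(axis=1)  # underflows to 0.0, so log(Z / (Z + 1)) breaks
log_space = log_prob_real(logits)     # remains finite
print(naive_z, log_space)
```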

We're not there yet. Generative adversarial training asks us to ascend the gradient of:
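Writing D(x) for the discriminator's probability that x is real and G(z) for the generator's output, the standard discriminator objective of Goodfellow et al. (2014) reads:

```latex
\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
+ \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```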

So whenever we call tf.train.AdamOptimizer.minimize we should descend:
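With D(x) the discriminator's probability that x is real and G(z) the generator's output, the quantity to minimize is the negated GAN objective. The symbol L_D is a name introduced here for convenience:

```latex
L_D = -\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
      - \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```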

Time to rewrite this though…

The first term on the right-hand side of the equation can be written:
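Assuming l(x) denotes the vector of class logits for input x and D(x) = Z(x) / (Z(x) + 1), the first term can be expanded as:

```latex
-\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
= -\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\operatorname{logsumexp}(l(x))\big]
  + \mathbb{E}_{x \sim p_{\text{data}}}\big[\operatorname{softplus}\big(\operatorname{logsumexp}(l(x))\big)\big]
```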

The second term of the right-hand side can be written as:
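Since 1 - D(x) = 1 / (Z(x) + 1) when the fake logit is fixed at 0, its logarithm is the negative softplus of the logsumexp of the class logits l. Under that assumption the second term becomes:

```latex
-\,\mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
= \mathbb{E}_{z \sim p_z}\big[\operatorname{softplus}\big(\operatorname{logsumexp}(l(G(z)))\big)\big]
```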

So that finally, we arrive at the following loss:
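In words, the unsupervised discriminator loss is minus the mean logsumexp of the logits for real images, plus the mean softplus of that logsumexp, plus the mean softplus of the logsumexp of the logits for generated images. The NumPy sketch below is an illustration under these assumptions (function names are hypothetical, and any constant scaling factor an implementation might apply is left out):

```python
import numpy as np


def logsumexp(logits):
    """Row-wise, numerically stable log(sum(exp(logits)))."""
    m = logits.max(axis=1)
    return m + np.log(np.exp(logits - m[:, None]).sum(axis=1))


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return np.logaddexp(0.0, x)


def unsupervised_discriminator_loss(logits_real, logits_fake):
    """-E[log D(x)] - E[log(1 - D(G(z)))] computed entirely in log-space."""
    lse_real = logsumexp(logits_real)
    lse_fake = logsumexp(logits_fake)
    return -lse_real.mean() + softplus(lse_real).mean() + softplus(lse_fake).mean()


rng = np.random.default_rng(0)
logits_real = rng.standard_normal((6, 10))  # stand-in logits for unlabeled images
logits_fake = rng.standard_normal((6, 10))  # stand-in logits for generated images
print(unsupervised_discriminator_loss(logits_real, logits_fake))
```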

Optimizing the discriminator

Let's set up the operations for actually updating the parameters of the discriminator. We will just resort to the Adam optimizer. While tweaking the parameters before I wrote this post, I figured I might slow down the discriminator by setting its learning rate at 0.1 times that of the generator. After that my results got much better, so I decided to leave it there for now. Notice also that we can very easily select the subset of variables corresponding to the discriminator by exploiting the encapsulation offered by Keras.

Adding some control flow to the graph

After we have the new weights for the discriminator, we want the generator’s update to be aware of the updated weights. TensorFlow will not guarantee that the updated weights will actually be used even if we were to redeclare the forward computation after defining the minimization operations for the discriminator. We can still force this by using tf.control_dependencies. Any operation defined in the scope of this context manager will depend on the evaluation of the ones that are passed to the context manager at instantiation. In other words, our generator’s update that we define later on will be guaranteed to compute the gradients using the updated weights of the discriminator.

The generator's loss and updates

In this implementation, the generator tries to minimize the L2 distance of the average features of the generated images vs. the average features of the real images. This feature-matching loss (Salimans et al., 2016) has proven to be more stable for training GANs than directly trying to optimize the discriminator’s probability for observing real data. It is straightforward to implement. While we’re at it, let’s also define the update operations for the generator. Notice that the learning rate of this optimizer is 10 times that of the discriminator.
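A minimal NumPy sketch of such a feature-matching loss is given below. It takes "L2 distance" literally as the Euclidean norm of the difference between batch-averaged features; implementations often use a squared or mean-squared variant instead, so treat the exact reduction as an assumption.

```python
import numpy as np


def feature_matching_loss(features_real, features_fake):
    """L2 distance between batch-averaged discriminator features."""
    mean_real = features_real.mean(axis=0)  # average over the batch of real images
    mean_fake = features_fake.mean(axis=0)  # average over the batch of generated images
    return np.linalg.norm(mean_real - mean_fake)


# Toy features: 5 real and 3 generated examples with 4 features each.
features_real = np.ones((5, 4))
features_fake = np.zeros((3, 4))
print(feature_matching_loss(features_real, features_fake))
```

Note that the two batches do not need the same size, since only their means are compared.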

Adding manifold regularization

Lecouat et al. (2018) propose to add manifold regularization to the feature-matching GAN training procedure of Salimans et al. (2016). The regularization forces the discriminator to yield similar logits (unnormalized log probabilities) for nearby points in the latent space in which z resides. It can be implemented by generating a second perturbed version of z and computing the generator's and discriminator's outputs once more with this slightly altered vector.

This means that the noise generation code looks as follows:
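A minimal sketch of such noise generation in NumPy, assuming a standard normal prior over a 100-dimensional z. The perturbation scale `epsilon` and the choice to normalize the perturbation direction to unit length are hypothetical choices made for illustration.

```python
import numpy as np


def sample_noise(batch_size, z_dim=100, epsilon=1e-5, rng=None):
    """Return a noise batch z and a slightly perturbed copy of it."""
    rng = np.random.default_rng() if rng is None else rng
    z = rng.standard_normal((batch_size, z_dim))
    # Random unit-length direction per example, scaled by a small epsilon.
    direction = rng.standard_normal((batch_size, z_dim))
    direction /= np.linalg.norm(direction, axis=1, keepdims=True)
    return z, z + epsilon * direction


z, z_perturbed = sample_noise(batch_size=16, rng=np.random.default_rng(0))
print(z.shape, z_perturbed.shape)
```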

The discriminator's loss will be updated as follows (note the 3 extra lines at the bottom):
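The extra regularization term can be sketched in NumPy as the mean squared L2 distance between the discriminator's logits for the images generated from z and from its perturbed copy. The function names and the `weight` value are assumptions for illustration, not values taken from the paper or the original script.

```python
import numpy as np


def manifold_penalty(logits_fake, logits_fake_perturbed):
    """Mean squared L2 distance between logits for z and for perturbed z."""
    diff = logits_fake - logits_fake_perturbed
    return np.mean(np.sum(diff ** 2, axis=1))


def regularized_discriminator_loss(loss_discriminator, logits_fake,
                                   logits_fake_perturbed, weight=1e-3):
    """Add the weighted manifold penalty to an existing discriminator loss."""
    penalty = manifold_penalty(logits_fake, logits_fake_perturbed)
    return loss_discriminator + weight * penalty


logits_a = np.zeros((4, 10))
logits_b = np.full((4, 10), 0.1)
print(regularized_discriminator_loss(1.0, logits_a, logits_b))
```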

Classification performance

So how does it really perform? I have provided a few plots below. There are many things I might try to squeeze out additional performance (for instance, just training for longer, using a learning rate schedule, implementing weight normalization), but the main purpose of writing this post was to get to know a relatively simple yet powerful semi-supervised learning approach. After 100 epochs of training, the mean test accuracy approaches 98.9 percent.

The full script can be found here. Thanks for reading!

Written by Jos van de Wolfshaar

Machine Learning Practitioner, Music Enthusiast, building next-gen communication technologies at MessageBird.