Using GANs for semi-supervised learning
In supervised learning, we have a training set of inputs x and class labels y. We train a model that takes x as input and gives y as output.
In semi-supervised learning, our goal is still to train a model that takes x as input and generates y as output. However, not all of our training examples have a label y. We need to develop an algorithm that is able to get better at classification by studying both labeled (x,y) pairs and unlabeled x examples.
We will use a GAN discriminator as an (n+1)-class classifier. It recognizes the n classes of labeled data, plus an (n+1)th class for fake images that come from the generator. The discriminator trains on real labeled images, real unlabeled images, and fake images. Because it draws on three sources of data instead of just one, it can generalize to the test set better than a traditional classifier trained only on the small labeled set.
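The mix of labeled and unlabeled real examples described above can be represented with a binary mask over the training set. The following is only a minimal sketch: the helper name `make_label_mask`, the sizes, and the random selection are assumptions for illustration, not part of the original notes.

```python
import random

def make_label_mask(num_examples, num_labeled, seed=0):
    """Return a list of 0/1 flags with exactly `num_labeled` ones.

    A 1 means "use this example's label"; a 0 means "pretend it is unlabeled".
    """
    rng = random.Random(seed)
    labeled = set(rng.sample(range(num_examples), num_labeled))
    return [1 if i in labeled else 0 for i in range(num_examples)]

# Hypothetical sizes: 10 real examples, of which only 3 keep their labels.
mask = make_label_mask(num_examples=10, num_labeled=3)
print(mask, "labeled examples:", sum(mask))
```

Every real example still reaches the discriminator; the mask only decides which ones contribute to the supervised part of the loss.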
The generator is typical: it outputs fake images with the same dimensions as the real images.
Things to note:
- The output activation is tanh (values range from -1 to +1), so we need to scale the real images to the same range before we train the discriminator on them.
- After the initial dense layer, the hidden layers use standard transposed convolutions, batch norm, and leaky ReLU.
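The first note above, matching the real images to the tanh range, can be sketched as a simple rescaling. This assumes 8-bit pixel values in [0, 255]; the function name `scale` and its parameters are illustrative and do not come from the code in these notes.

```python
def scale(pixels, feature_range=(-1.0, 1.0), max_value=255.0):
    """Linearly map pixel values from [0, max_value] to feature_range."""
    lo, hi = feature_range
    return [lo + (p / max_value) * (hi - lo) for p in pixels]

# A black, a mid-gray, and a white pixel (assumed 8-bit range).
scaled = scale([0, 127.5, 255])
print(scaled)
```

With this scaling, real and generated images occupy the same value range, so the discriminator cannot separate them by range alone.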
def generator(z, output_dim, reuse=False, alpha=0.2, training=True, size_mult=128):
    with tf.variable_scope('generator', reuse=reuse):
        # First fully connected layer
        x1 = tf.layers.dense(z, 4 * 4 * size_mult * 4)
        # Reshape it to start the convolutional stack
        x1 = tf.reshape(x1, (-1, 4, 4, size_mult * 4))
        x1 = tf.layers.batch_normalization(x1, training=training)
        x1 = tf.maximum(alpha * x1, x1)  # leaky ReLU

        x2 = tf.layers.conv2d_transpose(x1, size_mult * 2, 5, strides=2, padding='same')
        x2 = tf.layers.batch_normalization(x2, training=training)
        x2 = tf.maximum(alpha * x2, x2)

        x3 = tf.layers.conv2d_transpose(x2, size_mult, 5, strides=2, padding='same')
        x3 = tf.layers.batch_normalization(x3, training=training)
        x3 = tf.maximum(alpha * x3, x3)

        # Output layer
        logits = tf.layers.conv2d_transpose(x3, output_dim, 5, strides=2, padding='same')
        out = tf.tanh(logits)

        return out
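To see why this stack ends at the 32x32 resolution the discriminator expects, it helps to trace the spatial size through the layers. The sketch below relies on the fact that a transposed convolution with `padding='same'` multiplies the spatial size by its stride; the helper names are made up for illustration. It also spells out the leaky ReLU that the code writes as `tf.maximum(alpha * x, x)`.

```python
def transpose_conv_same_size(in_size, stride):
    """Spatial output size of a transposed convolution with 'same' padding."""
    return in_size * stride

def leaky_relu(x, alpha=0.2):
    """Scalar version of tf.maximum(alpha * x, x), valid for 0 < alpha < 1."""
    return max(alpha * x, x)

# The generator reshapes to 4x4, then applies three stride-2 transposed convolutions.
sizes = [4]
for stride in (2, 2, 2):
    sizes.append(transpose_conv_same_size(sizes[-1], stride))
print("spatial sizes:", sizes)

# Positive inputs pass through; negative inputs are scaled by alpha.
print(leaky_relu(3.0), leaky_relu(-2.0))
```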
The discriminator is more complex than the generator here, because it also has to learn from the unlabeled data.
Things to note:
- Dropout is used aggressively and widely. We are applying a big hammer (a deep convolutional network) to a small set of labeled training examples (most of the training examples are unlabeled), so the network overfits the training data easily. Dropout helps the discriminator generalize.
- We don't use batch normalization in the first layer of the discriminator, because it is important for the discriminator to see the distribution of the real and fake data.
- We don't use batch normalization in the last layer of the discriminator. BN would set the mean of each feature to the BN mu parameter. This layer is used for the feature matching loss, which only works if the feature means can differ between the run on real data and the run on generator samples.
- As in a standard GAN discriminator, the intermediate layers use convolution, batch norm, and leaky ReLU, and we downsample with strided convolutions instead of max pooling.
- For the last layer we use global average pooling, instead of a dense layer, to get the features.
- class_logits: the inputs to a softmax distribution over the different classes.
- gan_logits: defined so that P(input is real | input) = sigmoid(gan_logits).
def discriminator(x, reuse=False, alpha=0.2, drop_rate=0., num_classes=10, size_mult=64):
    with tf.variable_scope('discriminator', reuse=reuse):
        x = tf.layers.dropout(x, rate=drop_rate/2.5)

        # Input layer is 32x32x3 (no batch norm on this layer)
        x1 = tf.layers.conv2d(x, size_mult, 3, strides=2, padding='same')
        relu1 = tf.maximum(alpha * x1, x1)
        relu1 = tf.layers.dropout(relu1, rate=drop_rate)

        x2 = tf.layers.conv2d(relu1, size_mult, 3, strides=2, padding='same')
        bn2 = tf.layers.batch_normalization(x2, training=True)
        relu2 = tf.maximum(alpha * bn2, bn2)

        x3 = tf.layers.conv2d(relu2, size_mult, 3, strides=2, padding='same')
        bn3 = tf.layers.batch_normalization(x3, training=True)
        relu3 = tf.maximum(alpha * bn3, bn3)
        relu3 = tf.layers.dropout(relu3, rate=drop_rate)

        x4 = tf.layers.conv2d(relu3, 2 * size_mult, 3, strides=1, padding='same')
        bn4 = tf.layers.batch_normalization(x4, training=True)
        relu4 = tf.maximum(alpha * bn4, bn4)

        x5 = tf.layers.conv2d(relu4, 2 * size_mult, 3, strides=1, padding='same')
        bn5 = tf.layers.batch_normalization(x5, training=True)
        relu5 = tf.maximum(alpha * bn5, bn5)

        x6 = tf.layers.conv2d(relu5, 2 * size_mult, 3, strides=2, padding='same')
        bn6 = tf.layers.batch_normalization(x6, training=True)
        relu6 = tf.maximum(alpha * bn6, bn6)
        relu6 = tf.layers.dropout(relu6, rate=drop_rate)

        # Note: x7 reads relu5 (4x4 for 32x32 inputs), so the x6/bn6/relu6 branch above
        # does not feed the outputs. relu6 would be 2x2, too small for a 3x3 'valid' convolution.
        x7 = tf.layers.conv2d(relu5, 2 * size_mult, 3, strides=1, padding='valid')
        # Don't use bn on this layer, because bn would set the mean of each feature
        # to the bn mu parameter.
        # This layer is used for the feature matching loss, which only works if
        # the means can be different when the discriminator is run on the data than
        # when the discriminator is run on the generator samples.
        relu7 = tf.maximum(alpha * x7, x7)

        # Flatten it by global average pooling
        features = tf.reduce_mean(relu7, (1, 2))

        # Set class_logits to be the inputs to a softmax distribution over the different classes.
        # extra_class is assumed to be a module-level flag defined elsewhere:
        # 1 adds an explicit fake-class logit, 0 fixes that logit at 0.
        class_logits = tf.layers.dense(features, num_classes + extra_class)

        # Set gan_logits such that P(input is real | input) = sigmoid(gan_logits).
        # Keep in mind that class_logits gives you the probability distribution over all the real
        # classes and the fake class. You need to work out how to transform this multiclass softmax
        # distribution into a binary real-vs-fake decision that can be described with a sigmoid.
        # Numerical stability is very important.
        # You'll probably need to use this numerical stability trick:
        #   log sum_i exp a_i = m + log sum_i exp(a_i - m).
        # This is numerically stable when m = max_i a_i.
        # (It helps to think about what goes wrong when...
        #   1. One value of a_i is very large
        #   2. All the values of a_i are very negative
        # This trick and this value of m fix both those cases, but the naive implementation and
        # other values of m encounter various problems)
        if extra_class:
            real_class_logits, fake_class_logits = tf.split(class_logits, [num_classes, 1], 1)
            assert fake_class_logits.get_shape()[1] == 1, fake_class_logits.get_shape()
            fake_class_logits = tf.squeeze(fake_class_logits)
        else:
            real_class_logits = class_logits
            fake_class_logits = 0.

        mx = tf.reduce_max(real_class_logits, 1, keep_dims=True)
        stable_real_class_logits = real_class_logits - mx
        gan_logits = tf.log(tf.reduce_sum(tf.exp(stable_real_class_logits), 1)) + tf.squeeze(mx) - fake_class_logits

        out = tf.nn.softmax(class_logits)

        return out, class_logits, gan_logits, features
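The `gan_logits` expression can be checked numerically. Summing the softmax probabilities of the real classes gives P(real), and that sum equals sigmoid(logsumexp(real logits) - fake logit). The sketch below uses plain Python and made-up logit values to show the identity and the max-subtraction trick; it is an illustration, not the TensorFlow code above.

```python
import math

def logsumexp(values):
    """Stable log(sum(exp(v))): subtract the max before exponentiating."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def naive_logsumexp(values):
    """Direct formula, shown only to demonstrate how it breaks."""
    return math.log(sum(math.exp(v) for v in values))

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(x):
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

# Made-up logits: three real classes plus one explicit fake class.
real_logits = [2.0, -1.0, 0.5]
fake_logit = 1.0

gan_logit = logsumexp(real_logits) - fake_logit
p_real_from_sigmoid = sigmoid(gan_logit)
p_real_from_softmax = sum(softmax(real_logits + [fake_logit])[:-1])
print(p_real_from_sigmoid, p_real_from_softmax)

# The two failure cases named in the code comments: very large and very negative logits.
for logits in ([1000.0, 1000.0], [-1000.0, -1000.0]):
    try:
        naive = naive_logsumexp(logits)
    except (OverflowError, ValueError) as err:
        naive = "fails ({})".format(type(err).__name__)
    print(logits, "stable:", logsumexp(logits), "naive:", naive)
```

When there is no explicit fake-class logit, the code fixes it at 0, which corresponds to `fake_logit = 0.0` in this sketch.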
Model Loss functions and optimizer:
d_loss: the loss for the discriminator is a combination of:
- The loss for the GAN problem, where we minimize the cross-entropy for the binary real-vs-fake classification problem: d_loss_real and d_loss_fake.
- The loss for the multi-class classification problem on the real labeled data: d_loss_class. Only the samples in the minibatch that have labels contribute to d_loss_class. A minibatch may contain no labeled samples at all; in that case the following expression evaluates to 0/1, so d_loss_class is 0: "tf.reduce_sum(label_mask * class_cross_entropy) / tf.maximum(1., tf.reduce_sum(label_mask))"
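The masked average in that expression can be illustrated without TensorFlow. This is a small sketch with made-up per-example losses; the function name `masked_mean_loss` is an assumption for illustration.

```python
def masked_mean_loss(per_example_loss, label_mask):
    """Average the loss over labeled examples only.

    The denominator is clamped to at least 1, so a minibatch with no labeled
    examples yields 0 / 1 = 0 instead of a division by zero.
    """
    total = sum(m * l for m, l in zip(label_mask, per_example_loss))
    return total / max(1.0, float(sum(label_mask)))

losses = [0.5, 2.0, 1.5]  # hypothetical cross-entropy values

some_labeled = masked_mean_loss(losses, [1, 0, 1])  # averages 0.5 and 1.5
none_labeled = masked_mean_loss(losses, [0, 0, 0])  # no labels in the batch
print(some_labeled, none_labeled)
```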
g_loss: the loss for the generator is the "feature matching" loss invented by Tim Salimans at OpenAI. This loss minimizes the absolute difference between the expected features on the data and the expected features on the generated samples. It works better for semi-supervised learning than the traditional GAN losses. Over time it pushes the generator to produce samples whose feature statistics resemble those of the real data.
Note that with the feature matching loss, the generator's samples will not be as good as those from a traditional DCGAN. The goal here, however, is high discriminator accuracy on the multi-class classification problem using a small set of labeled data and a large set of unlabeled data.
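As a rough illustration of feature matching, the sketch below computes the loss on tiny made-up feature batches in plain Python. Rows are examples and columns are features; the helper names are assumptions. It also shows why sample quality can suffer: the loss only compares batch means, so it can be zero even when no generated sample matches any real one.

```python
def batch_mean(features):
    """Mean of each feature column over the batch (rows are examples)."""
    n = len(features)
    return [sum(col) / n for col in zip(*features)]

def feature_matching_loss(data_features, sample_features):
    """Mean absolute difference between the batch-mean feature vectors."""
    data_moments = batch_mean(data_features)
    sample_moments = batch_mean(sample_features)
    diffs = [abs(d - s) for d, s in zip(data_moments, sample_moments)]
    return sum(diffs) / len(diffs)

data = [[1.0, 2.0], [3.0, 4.0]]        # batch mean is [2.0, 3.0]
samples = [[0.0, 3.0], [2.0, 3.0]]     # batch mean is [1.0, 3.0]
collapsed = [[2.0, 3.0], [2.0, 3.0]]   # identical samples, same mean as data

loss = feature_matching_loss(data, samples)
collapsed_loss = feature_matching_loss(data, collapsed)
print(loss, collapsed_loss)
```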
def model_loss(input_real, input_z, output_dim, y, num_classes, label_mask, alpha=0.2, drop_rate=0.):
    """
    Get the loss for the discriminator and generator
    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param output_dim: The number of channels in the output image
    :param y: Integer class labels
    :param num_classes: The number of classes
    :param label_mask: 1 for examples treated as labeled, 0 for examples treated as unlabeled
    :param alpha: The slope of the left half of leaky ReLU activation
    :param drop_rate: The probability of dropping a hidden unit
    :return: A tuple of (discriminator loss, generator loss, correct, masked_correct, generator output)
    """
    # These numbers multiply the size of each layer of the generator and the discriminator,
    # respectively. You can reduce them to run your code faster for debugging purposes.
    g_size_mult = 32
    d_size_mult = 64

    # Here we run the generator and the discriminator
    g_model = generator(input_z, output_dim, alpha=alpha, size_mult=g_size_mult)
    d_on_data = discriminator(input_real, alpha=alpha, drop_rate=drop_rate, size_mult=d_size_mult)
    d_model_real, class_logits_on_data, gan_logits_on_data, data_features = d_on_data
    d_on_samples = discriminator(g_model, reuse=True, alpha=alpha, drop_rate=drop_rate, size_mult=d_size_mult)
    d_model_fake, class_logits_on_samples, gan_logits_on_samples, sample_features = d_on_samples

    # Here we compute `d_loss`, the loss for the discriminator.
    # This should combine two different losses:
    #  1. The loss for the GAN problem, where we minimize the cross-entropy for the binary
    #     real-vs-fake classification problem.
    #  2. The loss for the SVHN digit classification problem, where we minimize the cross-entropy
    #     for the multi-class softmax. For this one we use the labels. Don't forget to
    #     use `label_mask` to ignore the examples that we are pretending are unlabeled for the
    #     semi-supervised learning problem.
    # The arguments of the next two calls are reconstructed from the description of the loss:
    # binary cross-entropy on gan_logits, with target 1 for real data and 0 for generated samples.
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=gan_logits_on_data,
                                                labels=tf.ones_like(gan_logits_on_data)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=gan_logits_on_samples,
                                                labels=tf.zeros_like(gan_logits_on_samples)))
    y = tf.squeeze(y)
    # extra_class is assumed to be the same module-level flag used by the discriminator.
    class_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=class_logits_on_data,
                                                                  labels=tf.one_hot(y, num_classes + extra_class,
                                                                                    dtype=tf.float32))
    class_cross_entropy = tf.squeeze(class_cross_entropy)
    label_mask = tf.squeeze(tf.to_float(label_mask))
    # Average over labeled examples only; the maximum() guards against a batch with no labels.
    d_loss_class = tf.reduce_sum(label_mask * class_cross_entropy) / tf.maximum(1., tf.reduce_sum(label_mask))
    d_loss = d_loss_class + d_loss_real + d_loss_fake

    # Here we set `g_loss` to the "feature matching" loss invented by Tim Salimans at OpenAI.
    # This loss consists of minimizing the absolute difference between the expected features
    # on the data and the expected features on the generated samples.
    # This loss works better for semi-supervised learning than the traditional GAN losses.
    data_moments = tf.reduce_mean(data_features, axis=0)
    sample_moments = tf.reduce_mean(sample_features, axis=0)
    g_loss = tf.reduce_mean(tf.abs(data_moments - sample_moments))

    # Count correct class predictions over all real examples and over the labeled subset.
    pred_class = tf.cast(tf.argmax(class_logits_on_data, 1), tf.int32)
    eq = tf.equal(tf.squeeze(y), pred_class)
    correct = tf.reduce_sum(tf.to_float(eq))
    masked_correct = tf.reduce_sum(label_mask * tf.to_float(eq))

    return d_loss, g_loss, correct, masked_correct, g_model
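For the binary real-vs-fake part of `d_loss`, it can help to see what a sigmoid cross-entropy on logits computes. The sketch below uses the numerically stable form max(x, 0) - x*z + log(1 + exp(-|x|)), which is the formula given in the TensorFlow documentation for `tf.nn.sigmoid_cross_entropy_with_logits`. The logit values are made up, and the plain-Python helpers are illustrative stand-ins rather than the TensorFlow ops.

```python
import math

def sigmoid_cross_entropy_with_logits(logit, label):
    """Stable binary cross-entropy for one logit x and target z in {0, 1}."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def mean(values):
    values = list(values)
    return sum(values) / len(values)

# Hypothetical gan_logits for a batch of real images and a batch of generated images.
gan_logits_on_data = [3.0, 1.5]
gan_logits_on_samples = [-2.0, 0.5]

# Real images get target 1, generated images get target 0.
d_loss_real = mean(sigmoid_cross_entropy_with_logits(x, 1.0) for x in gan_logits_on_data)
d_loss_fake = mean(sigmoid_cross_entropy_with_logits(x, 0.0) for x in gan_logits_on_samples)
print(d_loss_real, d_loss_fake)
```

The stable form never exponentiates a large positive number, so confident logits such as 1000 give a finite loss instead of an overflow.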
Credits: From lecture notes: https://classroom.udacity.com/nanodegrees/nd101/syllabus