Understanding GAN with codes

GAN implementation on Fashion-MNIST using Tensorflow 2

Mehul Gupta
Data Science in your pocket
9 min readAug 23, 2021



As of now, we are done with the topics covered in my earlier posts. It's high time we explore GANs, I guess!!

GANs, or Generative Adversarial Networks, are a combination of two networks (assume them to be A & B for now) that gives us a system which, when fed some random input (& I mean it), produces images similar to those in the training data. Before deep-diving into how this system of 2 networks works, we must understand what an adversarial system means:

An adversarial system consists of opposing powers/players/components whose motives are opposite to each other. This is similar to any 2-player game/contest where each player wishes to win, leading to the other player's loss. So, it's more of a fight between two players to outperform each other.

How is this related to GANs? We will figure that out shortly.

It’s time to understand the 2 networks, A & B

  • Generator: Keeping it simple, this network takes in a random input vector & is responsible for producing fake images similar to those in the training dataset. It is usually a decoder network.
  • Discriminator: This network's aim is to detect the fake images produced by the Generator from a mix of real (training) & fake images. It is essentially a binary classifier (real vs fake image detection). It is usually an encoder network.

Where is the Adversarial system?

In a typical GAN, the Generator & the Discriminator compete against each other to get better at their respective tasks. Hence:

The Discriminator wishes to minimize the classification error (real vs fake).

The Generator wishes to maximize the classification error (which would mean the Discriminator can no longer tell the difference between fake & real), hence generating images that look real.

& hence the fight between the two to outperform each other!!
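For the mathematically inclined, this tug of war is the classic minimax objective from the original GAN paper (Goodfellow et al., 2014):

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

Here x is a real training image, z is the random input vector, D(x) is the Discriminator's estimate that x is real & G(z) is a generated (fake) image; the Discriminator tries to push this expression up while the Generator tries to pull it down.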

Though the idea looks clear, you might still be unable to make out what the flow looks like. Here we go:

1. A random vector (or a batch of vectors) is generated (how? randomly).

2. This is fed to the Generator to generate images.

3. These generated images (fake) are mixed with training dataset images (real).

4. This mixed dataset is fed to the Discriminator, a binary classifier, which detects which samples are fake (i.e. generated by the Generator) & which ones are real (from the training data).

5. This classifier provides feedback to the Generator to improve the quality of the generated images. Simultaneously, the Discriminator also learns from the misclassifications it made. Hence, both networks gradually improve using backpropagation.

6. Eventually, the Generator becomes such a pro liar that the Discriminator gets confused & can no longer distinguish between real & fake images.

7. Once this system reaches Nash Equilibrium (ideally it should; in practice this seldom happens), the Discriminator is detached from the system & we are left with just the Generator to produce fake images.

That's too much theory; let's code it out using TensorFlow 2.

Note: The code below is available on GitHub at

  1. Import required libraries
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import os

2. We will be using the Fashion-MNIST dataset for training, available here. The idea is to generate samples similar to the Fashion-MNIST dataset.

#import csv (after unzipping)
train = pd.read_csv('fashion-mnist_train.csv')
#We won't require the 'label' column, hence it is dropped. Also, reshaping the
#flat vectors in the csv as images of dimension 28x28x1
train = train.drop(['label'],axis=1).to_numpy().reshape(-1,28,28,1)
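A quick shape check confirms the reshape (Fashion-MNIST's training split has 60,000 images):

#sanity check on the reshaped array
print(train.shape)   #expected: (60000, 28, 28, 1)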

3. Shouldn't we observe a few samples before going forward? Yes, as we must know what we are expecting as output.

#Declaring subplots
fig,axes = plt.subplots(1,5,figsize=(15,15))
#plotting 5 images (dropping the channel axis for imshow)
for index,x in enumerate(train[:5]):
    axes[index].imshow(x[:,:,0])

Fashion MNIST samples

4. The next step involves creating a preprocessing pipeline for easy use of the data in TensorFlow using tf.data. This gives us the ability to chain a number of steps together.

Note: Things may get tough from here!

def preprocess(records):
    images = records['image']
    images = tf.cast(images, tf.float32)/255.0
    return images

#converting numpy array to tensorflow dataset
dataset = tf.data.Dataset.from_tensor_slices({'image':train})
#Once converted to a tensorflow dataset, map() acts like apply() on a pandas
#dataframe i.e. mapping each entry to a function
dataset = dataset.map(preprocess)
#repeat(): repeat the dataset; with no argument it repeats indefinitely
#shuffle(100): shuffle the dataset using 100 as buffer_size
#batch(128): divide the dataset into batches of 128 elements each
#prefetch(1): prefetch 1 batch in advance before it is requested,
#hence reducing latency
dataset = dataset.repeat().shuffle(100).batch(128).prefetch(1)

On print(dataset), it should output

<PrefetchDataset shapes: (None, 28, 28, 1), types: tf.float32>
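If you want to peek at an actual batch, the dataset object is directly iterable:

#optional sanity check: pull one batch from the pipeline
sample_batch = next(iter(dataset))
print(sample_batch.shape)   #(128, 28, 28, 1)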

5. Declaring a few variables

input_shape = (28,28,1)
final_encoder_dim = 2 #discriminator output dimension
depth = 5 #depth for both discriminator & generator
kernel_size = 3
activation = 'tanh'
dropout = 0.1
decoder_input_dim = 4 #generator input embedding length
epochs = 100

6. Designing Discriminator

def discriminator(input_shape,dim,depth,kernel,dropout,activation):
    layers = []
    layers.append(InputLayer(input_shape=input_shape))
    for i in range(1,depth):
        layers.append(Conv2D(16*i,kernel_size=kernel))
        layers.append(BatchNormalization())
        layers.append(Activation('relu'))
        layers.append(Dropout(dropout))
    layers.append(Flatten())
    layers.append(Dense(128,activation='relu'))
    layers.append(Dense(dim))
    return Sequential(layers)

encoder = discriminator(input_shape, final_encoder_dim, depth, kernel_size, dropout, activation)

Let’s decode it one by one

  • The InputLayer() takes the input_shape i.e. 28x28x1. This layer acts as a placeholder for the actual input
  • Then we add a Convolution layer followed by BatchNormalization, ReLU activation & Dropout, repeating this sequence depending upon the depth required (here it is 5)
  • The rest of the code is similar to any CNN
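If you want to double-check the stacked layers, Keras can print the architecture of the model we just built:

#inspect the Discriminator's architecture & parameter counts
encoder.summary()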

Now, it's the Generator's turn

def generator(input_shape, depth, output_shape, kernel, dropout):
    layers = []
    layers.append(InputLayer(input_shape=(input_shape,)))
    layers.append(Dense(784,activation='relu'))
    layers.append(Reshape(target_shape=output_shape))
    for i in range(1,depth):
        layers.append(Conv2DTranspose(16*i,kernel_size=kernel))
        layers.append(BatchNormalization())
        layers.append(Activation('relu'))
        layers.append(Dropout(dropout))

    resizer = lambda name: Lambda(lambda images: tf.image.resize(images, [28,28]), name=name)
    layers.append(resizer('Reshape'))
    layers.append(Conv2DTranspose(1,kernel_size=1,activation=None))
    return Sequential(layers)

decoder = generator(decoder_input_dim, depth, input_shape, kernel_size, dropout)

  • InputLayer() takes an input_shape equal to (4,) in our case. This can be of any dimension
  • The Dense layer is used to map this low-dimensional vector to a higher dimension (784) & the high-dimensional embedding is then reshaped into 28x28x1
  • Similar to the Discriminator, depending upon the depth, a similar code block is appended, except the Conv2D layer is replaced by Conv2DTranspose, which is more of an inverse Conv2D layer that upsamples the data, unlike Conv2D which downsamples the image.

Why is Conv2DTranspose required? As we wish to generate an image (higher dimension) from a noise vector (lower dimension), we not only need to upsample the noise vector but do it intelligently so as to produce the desired output. Conv2DTranspose lets us learn the weights for proper upsampling rather than upsampling by some fixed rule. An UpSampling2D layer could also be used, but in that case there wouldn't be any learning going on while upsampling.

  • The resizer lambda function helps in resizing the images to the training images' dimensions & the last Conv2DTranspose layer reduces the number of channels to 1.
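Before moving on, a quick sanity check (not part of training) confirms that the Generator maps a noise vector to an image of the expected size:

#feed 2 random noise vectors of length 4 & verify the output shape
test_noise = tf.random.normal((2, decoder_input_dim))
print(decoder(test_noise, training=False).shape)   #expected: (2, 28, 28, 1)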

So finally, we have designed our Discriminator & Generator. These two will be trained simultaneously. Now we need to define the loss function & optimizers for the 2 networks

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
encoder_opt = tf.keras.optimizers.Adam()
decoder_opt = tf.keras.optimizers.Adam()
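A quick note on this choice: since the Discriminator outputs 2 raw logits (final_encoder_dim = 2) & the labels will simply be integers (0 = fake, 1 = real, as defined in the next step), SparseCategoricalCrossentropy with from_logits=True fits directly. A tiny illustrative example (values made up):

#illustrative only: 2 samples, raw 2-class logits vs integer labels
demo_logits = tf.constant([[2.0, -1.0], [0.5, 1.5]])
demo_labels = tf.constant([0, 1])
print(loss(demo_labels, demo_logits).numpy())   #averaged cross-entropy over the 2 samples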

Training the two networks is a hard nut to crack. To simplify things, the training has been broken down into 3 segments: 1) batch training, 2) epoch training & 3) complete training.

Let’s explore them one by one

def batch_training(training_data):
    batch_size = tf.shape(training_data)[0]
    real_data = training_data
    real_labels = tf.ones((batch_size,))
    fake_labels = tf.zeros((batch_size,))
    labels = tf.concat((real_labels,fake_labels),axis=0)
    noise = tf.random.normal((batch_size, decoder_input_dim), mean=0,stddev=1)

    with tf.GradientTape() as decoder_gt, tf.GradientTape() as encoder_gt:
        fake_images = decoder(noise,training=True)
        fake_labels_2 = encoder(fake_images,training=True)
        real_labels_2 = encoder(training_data,training=True)
        predicted_labels = tf.concat((real_labels_2,fake_labels_2),axis=0)
        discrim_loss = loss(labels,predicted_labels)
        gen_loss = loss(real_labels,fake_labels_2)

    dec_grad = decoder_gt.gradient(gen_loss, decoder.trainable_variables)
    enc_grad = encoder_gt.gradient(discrim_loss, encoder.trainable_variables)

    decoder_opt.apply_gradients(zip(dec_grad, decoder.trainable_variables))
    encoder_opt.apply_gradients(zip(enc_grad, encoder.trainable_variables))

    return discrim_loss, gen_loss

Let's understand what this code block is trying to do:

  • It, first of all, generates the ground truth for the Discriminator (the binary classifier). For this, it assigns label=1 to real images (from the training dataset) & label=0 to any image generated (or to be generated) by the Generator. So, if batch_size=100, we generate 200 labels: 100 labels of 1 & another 100 of 0
  • Let the Generator produce fake images (= batch_size of them).
  • Feed the two image batches (real from the training dataset & fake produced by the Generator) to the Discriminator & get its predictions.
  • Calculate the loss for the 2 networks using the ground truth defined earlier:

For the Discriminator = loss calculated over all samples (real & fake combined) that are misclassified.

For the Generator = loss calculated over the fake images that were correctly classified as fake, i.e. the cases where the Discriminator is still able to distinguish between real & fake samples.

  • The remaining two code couplets handle backpropagation, improving both the Generator & the Discriminator by calculating & applying gradients. This is done using the GradientTape() objects
  • The losses for the 2 networks are returned
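To make the label bookkeeping concrete, here is what the ground truth looks like for a toy batch of 3 real & 3 fake images (illustrative only):

#illustrative only: ground truth labels for 3 real + 3 fake images
demo_real = tf.ones((3,))    #[1., 1., 1.] -> real images
demo_fake = tf.zeros((3,))   #[0., 0., 0.] -> generated images
print(tf.concat((demo_real, demo_fake), axis=0))   #[1. 1. 1. 0. 0. 0.]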

Moving onto epoch & complete training


def epoch_training(data_iterator, steps_per_epoch, avg_gen_loss, avg_dis_loss):
    for x in range(steps_per_epoch):
        d_loss, g_loss = batch_training(next(data_iterator))
        avg_gen_loss.update_state(g_loss)
        avg_dis_loss.update_state(d_loss)

    gen_loss = avg_gen_loss.result()
    dis_loss = avg_dis_loss.result()

    tf.summary.scalar('gen_loss',gen_loss,step=encoder_opt.iterations)
    tf.summary.flush()
    tf.summary.scalar('dis_loss',dis_loss,step=decoder_opt.iterations)
    tf.summary.flush()

    avg_gen_loss.reset_state()
    avg_dis_loss.reset_state()
    return gen_loss.numpy(), dis_loss.numpy()


def complete_training(training_data, epochs):
    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(generator_optimizer=decoder_opt, discriminator_optimizer=encoder_opt, generator=decoder, discriminator=encoder)
    manager = tf.train.CheckpointManager(checkpoint, 'training_checkpoints', max_to_keep=5)
    checkpoint.restore(manager.latest_checkpoint)
    summary = tf.summary.create_file_writer('metrics/train')

    avg_generator_loss = tf.keras.metrics.Mean()
    avg_discriminator_loss = tf.keras.metrics.Mean()

    data_iterator = training_data.__iter__()
    for i in range(epochs):
        with summary.as_default():
            gen_loss, dis_loss = epoch_training(data_iterator, 100, avg_generator_loss, avg_discriminator_loss)
            print({'gen_loss':gen_loss,'dis_loss':dis_loss})
        manager.save()

That's a lot of code to begin with. Let's start off with complete_training(), as epoch_training() depends on objects created in it.

complete_training()

  • The first 5 lines set up a checkpointing mechanism in our training so as to save the trained models from time to time using tf.train.Checkpoint. The only point to note here is that a single Checkpoint object can save multiple models & optimizers.
  • checkpoint.restore helps in restoring the models & optimizers if a checkpoint already exists (helpful when we are doing the training in parts)
  • tf.summary is used to log any metrics or images we wish to store while training & plot on TensorBoard afterward.
  • avg_generator_loss & avg_discriminator_loss are defined to calculate the loss per epoch for the respective models by accumulating batch errors.
  • The TensorFlow dataset (Fashion-MNIST) is converted into an iterator (like a Python generator) using __iter__()
  • For each epoch, we call epoch_training(), which returns the 2 losses for the Generator & the Discriminator respectively. Using summary.as_default(), any tf.summary written inside the block gets dumped into this writer, as done in epoch_training().
  • manager.save() saves the current state of the objects (models & optimizers in our code) in the directory declared in the CheckpointManager.

Moving onto epoch_training()

  • For each step of an epoch, a batch from the data_iterator is fetched & fed to batch_training(), which returns the batch losses for the Generator & the Discriminator. next() helps in getting the next element of an iterator
  • These batch losses update the avg_generator_loss/avg_discriminator_loss objects declared in complete_training() & passed in as parameters. Once the epoch is done (all steps_per_epoch completed), these objects are reset for the next epoch & the averaged losses are returned back to complete_training().
  • Using tf.summary.scalar, these losses are logged in the file mentioned while declaring the tf.summary.create_file_writer object
  • What is tf.summary.flush()? It forces the file writer to write to the log files immediately, which otherwise may take time & some updates may get lost.
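Once a few epochs have been logged, the two loss curves can be viewed in TensorBoard by pointing it at the directory used above:

tensorboard --logdir metrics/train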

To initiate training, enter

complete_training(dataset,epochs)

DONE WITH TRAINING

Attached below are a few outputs generated after training the models for a few iterations (results will definitely improve when further trained):

#generating 5 images using random noise vectors
noise = tf.random.normal((5,4),mean=0,stddev=1)
images = decoder(noise,training=False)
fig,axes = plt.subplots(1,5,figsize=(15,15))
for index,x in enumerate(images):
    axes[index].imshow(x[:,:,0])

Common problems while training a GAN

GANs, as mentioned earlier, are a hard nut to crack when it comes to training. A few common issues are mentioned below

  • Oscillating loss: To be assured that your GAN is converging/improving, the losses shouldn't oscillate wildly. To indicate convergence, the loss should either 1) oscillate only with a small magnitude, or 2) move in one direction, i.e. steadily increasing or decreasing. Wildly oscillating losses aren't a good sign for your training & unfortunately are observed quite frequently.
  • Mode collapse: This happens when the Discriminator isn't able to detect a certain small set of images (or a single image) as fake. Hence, the Generator maps every incoming noise vector to this small set/single image (called a mode). This can lead to repetitive images being generated.
  • Uninformative loss: This is something I faced myself while training the above GAN. As both the Generator & the Discriminator improve over time, the loss at any point in time isn't comparable with its past values. So it might be the case that you observe the Generator's loss going sky high while the image quality produced is gradually improving !! Strange days

How to resolve these? I will come back to this shortly.
