MNIST-GAN: Detailed step by step explanation & implementation in code

Don’t know anything about GANs? You’ve come to the right place!

Garima Nishad
Intel Student Ambassadors
8 min read · Aug 1, 2019

In this blog post, I will give an introduction to GANs through an example. We’ll build a generative adversarial network (GAN) and train it on the MNIST dataset: after seeing pictures of many real handwritten digits, it will be able to generate new ones! Note that the model here uses only fully-connected layers, unlike a DCGAN, which is built from convolutional layers. This post gives a thorough explanation of the implementation and sheds light on how and why the model works.
No prior experience with GANs is required, though a first-timer may need to spend some time reasoning about what is really happening under the hood. Let’s start from the beginning!

Google Trend’s Interest over time for the term “GANs”

GANs were first introduced in 2014 by Ian Goodfellow and others in Yoshua Bengio’s lab. Since then, GANs have exploded in popularity. Here are a few examples to check out:

So, the idea behind Generative Adversarial Nets is that you have two networks, a generator G and a discriminator D, competing against each other:

Generator: The generator makes “fake” data to pass to the discriminator. Its only job is to produce fake images that look like the training images.
Discriminator: The discriminator sees both real training data and the generator’s fakes, and predicts whether the data it has received is real or fake. Its only job is to look at an image and output whether it is a real training image or a fake image from the generator.

The key takeaway is that:

  • The generator is trained to fool the discriminator: it wants to output data that looks as close as possible to real training data.
  • The discriminator is a classifier that is trained to figure out which data is real and which is fake.

During training, the generator is constantly trying to outsmart the discriminator by generating better and better fakes, while the discriminator is working to become a better detective and correctly classify real and fake images. The equilibrium of this game is reached when the generator produces perfect fakes that look as if they came directly from the training data, and the discriminator can do no better than guess, assigning a 50% probability of being real to every image it sees.
What ends up happening is that the generator learns to make data that the discriminator cannot distinguish from real data.
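
This two-player game is usually written as a minimax objective. The formula below is the standard formulation from the original 2014 GAN paper rather than something derived in this post. Here x is a real image, z is the random input given to the generator, and D outputs the probability that its input is real:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator tries to make both terms large (real images scored near 1, fakes near 0), while the generator tries to make the second term small. At the equilibrium described above, the best the discriminator can do is D(x) = 1/2 for every input.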

The general structure of a GAN is shown in the diagram above, using MNIST images as data. The latent sample is a random vector that the generator uses to construct its fake images.
This is often called a latent vector and that vector space is called latent space. As the generator trains, it figures out how to map latent vectors to recognizable images that can fool the discriminator.
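
As a minimal sketch of what a latent sample looks like in PyTorch, the snippet below draws a small batch of random vectors. The uniform range [-1, 1) and the names z_size and n_samples are assumptions made for illustration; a standard normal distribution (torch.randn) is another common choice.

```python
import torch

z_size = 100     # length of each latent vector (assumed for illustration)
n_samples = 16   # how many latent vectors to draw

# torch.rand samples uniformly from [0, 1); shift and scale to [-1, 1).
z = torch.rand(n_samples, z_size) * 2 - 1

print(z.shape)  # one row per latent vector
```

Each row of z is one point in latent space, and the generator will turn each row into one image.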

If you’re interested in generating only new images, you can throw out the discriminator after training.
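
To make that concrete, here is a hedged sketch of generating images with a generator alone. The small nn.Sequential model is only a stand-in for a trained generator (it is untrained, so its output is noise), and the layer sizes are assumptions chosen to keep the example short.

```python
import torch
import torch.nn as nn

z_size = 100

# Stand-in for a trained generator: latent vector in, flat 28x28 image out.
G = nn.Sequential(
    nn.Linear(z_size, 128),
    nn.LeakyReLU(0.2),
    nn.Dropout(0.3),
    nn.Linear(128, 784),
    nn.Tanh(),
)

G.eval()  # switch off dropout while generating
with torch.no_grad():  # no gradients are needed for sampling
    z = torch.rand(16, z_size) * 2 - 1
    samples = G(z)

# Reshape flat outputs to 28x28 and map tanh's [-1, 1] range to [0, 1] for display.
images = (samples.view(-1, 28, 28) + 1) / 2
print(images.shape)
```

Notice that no discriminator appears anywhere in this snippet: once training is over, sampling needs only G and a source of random latent vectors.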

I’ll show you how to define and train these adversarial networks in PyTorch and generate new images!

  • Download the MNIST data:
  1. Define the number of subprocesses to use for data loading.
  2. Define how many samples to load per batch, i.e. the batch size. Values between 32 and 128 are common choices.
  3. Define a transform that converts the data to torch.FloatTensor.
  4. Get the training dataset.
  5. Prepare the data loader, which serves the data in batches of the size chosen above.
import torch
from torchvision import datasets
import torchvision.transforms as transforms

# 1: number of subprocesses to use for data loading
num_workers = 0
# 2: how many samples per batch to load
batch_size = 64
# 3: convert images to torch.FloatTensor
transform = transforms.ToTensor()
# 4: get the training dataset
train_data = datasets.MNIST(root='data', train=True,
                            download=True, transform=transform)
# 5: prepare the data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           num_workers=num_workers)
  • Visualize the data that you’re dealing with:
  1. Obtain one batch of training images. Here, “dataiter” iterates over batches of images and labels from the dataset.
  2. Get one image from the batch. You can change the index in images[0] to view another image.
import numpy as np
import matplotlib.pyplot as plt

# 1: obtain one batch of training images
dataiter = iter(train_loader)
images, labels = next(dataiter)  # the built-in next() works across PyTorch versions
images = images.numpy()
# 2: get one image from the batch and display it
img = np.squeeze(images[0])
fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')

You should get an output like this:

  • Define the Model :

A GAN is comprised of two adversarial networks, a discriminator and a generator.

Discriminator :

The discriminator network is going to be a pretty typical linear classifier. To make this network a universal function approximator, we’ll need at least one hidden layer, and these hidden layers should have one key attribute:

All hidden layers will have a Leaky ReLU activation function applied to their outputs.

But why use a Leaky ReLU?
A normal ReLU outputs zero for every negative input, so no gradient flows back through units that are switched off. A leaky ReLU is like a normal ReLU, except that there is a small non-zero output for negative input values, which lets gradients keep flowing backward through the layer.
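
A small sketch comparing the two activations on a handful of made-up input values. The slope of 0.2 matches the value used in the networks in this post; the input tensor is purely illustrative.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 3.0], requires_grad=True)

relu_out = F.relu(x)                             # negative inputs are clamped to 0
leaky_out = F.leaky_relu(x, negative_slope=0.2)  # negative inputs are scaled by 0.2

# Gradient of sum(leaky_relu(x)) with respect to each input value.
leaky_out.sum().backward()
leaky_grad = x.grad.clone()

print(relu_out)
print(leaky_out)
print(leaky_grad)  # 0.2 for negative inputs, 1 for positive inputs
```

The gradient for the negative inputs is small but not zero, which is exactly what keeps the learning signal alive.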

  1. Define all hidden linear layers.
  2. Create a final fully-connected layer.
  3. In the forward function, first flatten the image.
  4. Pass it through all hidden layers, applying Leaky ReLU and dropout after each one.
  5. Finally, apply the final layer, which returns a raw score (a logit) with no activation.
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(Discriminator, self).__init__()

        # 1
        self.fc1 = nn.Linear(input_size, hidden_dim*4)
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim)

        # 2
        self.fc4 = nn.Linear(hidden_dim, output_size)

        # dropout layer
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # 3
        x = x.view(-1, 28*28)
        # 4
        x = F.leaky_relu(self.fc1(x), 0.2)  # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # 5
        out = self.fc4(x)
        return out

Generator:

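The generator network will be almost exactly the same as the discriminator network, except that the layers grow wider instead of narrower and we apply a tanh activation function to the output layer.

But why tanh?
Generators have been found to perform best with a tanh output, which scales the output to be between -1 and 1, instead of 0 and 1. Because the real MNIST images have pixel values between 0 and 1, they need to be rescaled to the same -1 to 1 range before the discriminator sees them.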
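
One practical consequence of a tanh output: real images come out of ToTensor() with pixel values in [0, 1], so they should be rescaled to [-1, 1] before being shown to the discriminator alongside generated images. A minimal sketch, using a random tensor as a stand-in for a real MNIST batch:

```python
import torch

# Stand-in for one batch from the data loader: values lie in [0, 1].
images = torch.rand(64, 1, 28, 28)

# Map [0, 1] to [-1, 1] so real images share the range of tanh outputs.
real_images = images * 2 - 1

print(real_images.min().item(), real_images.max().item())
```

In a training loop this rescaling would be applied to every real batch before the discriminator's forward pass.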

  1. Here you’ve defined all hidden linear layers
  2. Then created a final fully-connected layer
  3. Add a dropout layer to avoid overfitting
  4. Create all hidden layers in forward function
  5. Eventually, add a final layer with tanh applied
import torch

class Generator(nn.Module):

    def __init__(self, input_size, hidden_dim, output_size):
        super(Generator, self).__init__()

        # 1
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim*2)
        self.fc3 = nn.Linear(hidden_dim*2, hidden_dim*4)

        # 2
        self.fc4 = nn.Linear(hidden_dim*4, output_size)

        # 3
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # 4
        x = F.leaky_relu(self.fc1(x), 0.2)  # (input, negative_slope=0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # 5: tanh squashes the output into the range (-1, 1)
        out = torch.tanh(self.fc4(x))
        return out

Model hyperparameters:

Now comes the part in which you can experiment the most.

  1. Size of the input image to the discriminator (28*28 = 784)
  2. Size of the discriminator output (real or fake)
  3. Size of the last hidden layer in the discriminator
  4. Size of the latent vector to give to the generator
  5. Size of the generator output (generated image)
  6. Size of the first hidden layer in the generator
# Discriminator hyperparameters
# 1
input_size = 784
# 2
d_output_size = 1
# 3
d_hidden_size = 32

# Generator hyperparameters
# 4
z_size = 100
# 5
g_output_size = 784
# 6
g_hidden_size = 32
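
To get a feel for what these numbers imply, the sketch below works out the layer widths and parameter counts that follow from them. The helper functions (d_layer_widths, g_layer_widths, count_params) are hypothetical names introduced only for this illustration; they mirror the 4x, 2x, 1x pattern used in the two classes above.

```python
import torch.nn as nn

input_size, d_hidden_size, d_output_size = 784, 32, 1
z_size, g_hidden_size, g_output_size = 100, 32, 784

def d_layer_widths(inp, hidden, out):
    # The discriminator narrows: input -> 4h -> 2h -> h -> output
    return [inp, hidden * 4, hidden * 2, hidden, out]

def g_layer_widths(inp, hidden, out):
    # The generator widens: z -> h -> 2h -> 4h -> output
    return [inp, hidden, hidden * 2, hidden * 4, out]

def count_params(widths):
    # Build one Linear layer per consecutive pair and count weights + biases.
    layers = [nn.Linear(a, b) for a, b in zip(widths[:-1], widths[1:])]
    return sum(p.numel() for layer in layers for p in layer.parameters())

d_widths = d_layer_widths(input_size, d_hidden_size, d_output_size)
g_widths = g_layer_widths(z_size, g_hidden_size, g_output_size)

print(d_widths, count_params(d_widths))  # [784, 128, 64, 32, 1] 110849
print(g_widths, count_params(g_widths))  # [100, 32, 64, 128, 784] 114800
```

Doubling either hidden size roughly doubles the cost of the layers next to the 784-wide image, which is where most of the parameters live.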

Build a complete network:

Now we’re instantiating the discriminator and generator from the classes defined above.

# instantiate discriminator and generator
D = Discriminator(input_size, d_hidden_size, d_output_size)
G = Generator(z_size, g_hidden_size, g_output_size)
# check that they are as you expect
print(D)
print()
print(G)

You should get an output like this :

Define Losses:

  • For the discriminator, the total loss is the sum of the losses for real and fake images, d_loss = d_real_loss + d_fake_loss.
  • Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.
  • The generator loss will look similar, only with flipped labels. The generator’s goal is to get D(fake_images) = 1.
  • In this case, the labels are flipped to represent that the generator is trying to fool the discriminator into thinking that the images it generates (fakes) are real!
# Calculate losses
def real_loss(D_out, smooth=False):
    batch_size = D_out.size(0)
    # label smoothing
    if smooth:
        # smooth, real labels = 0.9
        labels = torch.ones(batch_size)*0.9
    else:
        labels = torch.ones(batch_size)  # real labels = 1

    # numerically stable loss
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out.squeeze(), labels)
    return loss

def fake_loss(D_out):
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)  # fake labels = 0
    criterion = nn.BCEWithLogitsLoss()
    # calculate loss
    loss = criterion(D_out.squeeze(), labels)
    return loss
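
The losses above rely on nn.BCEWithLogitsLoss, which expects raw scores (logits) and applies the sigmoid internally. The sketch below, using made-up logits for a batch of four images, checks that it agrees with a sigmoid followed by nn.BCELoss, and shows the effect of smoothing the real labels to 0.9.

```python
import torch
import torch.nn as nn

# Hypothetical raw discriminator outputs (logits) for four images.
logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
labels = torch.ones(4)  # "real" labels

# BCEWithLogitsLoss applies the sigmoid internally...
loss_with_logits = nn.BCEWithLogitsLoss()(logits, labels)
# ...so it matches an explicit sigmoid followed by plain BCELoss.
loss_two_step = nn.BCELoss()(torch.sigmoid(logits), labels)

# Label smoothing: use a target of 0.9 instead of 1.0 for real images.
loss_smooth = nn.BCEWithLogitsLoss()(logits, labels * 0.9)

print(loss_with_logits.item(), loss_two_step.item(), loss_smooth.item())
```

This is also why the discriminator defined earlier returns the output of its last linear layer directly, with no sigmoid of its own. For these particular logits the smoothed loss comes out slightly higher, which discourages the discriminator from becoming overconfident on real images.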

Define Optimizers:

We want to update the generator and discriminator variables separately. So, we’ll define two separate Adam optimizers.

import torch.optim as optim
lr = 0.002
d_optimizer = optim.Adam(D.parameters(), lr)
g_optimizer = optim.Adam(G.parameters(), lr)

Training:

Training will involve alternating between training the discriminator and the generator. We’ll use our functions real_loss and fake_loss to help us calculate the discriminator losses in all of the following cases.

Discriminator training:

  • Compute the discriminator loss on real, training images
  • Generate fake images
  • Compute the discriminator loss on fake, generated images
  • Add up the real and fake loss
  • Perform backpropagation + an optimization step to update the discriminator’s weights

Generator training:

  • Generate fake images
  • Compute the discriminator loss on fake images, using flipped labels!
  • Perform backpropagation + an optimization step to update the generator’s weights
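
The two lists above can be sketched as a single training step. This is a minimal sketch under stated assumptions, not the exact loop used for the results in this post: the tiny nn.Sequential networks are stand-ins so the snippet runs on its own, sample_z and train_step are hypothetical helper names, latent vectors are drawn uniformly from [-1, 1), and the fake images are detached during the discriminator step so that step does not backpropagate into the generator.

```python
import torch
import torch.nn as nn
import torch.optim as optim

z_size = 100
criterion = nn.BCEWithLogitsLoss()

# Tiny stand-in networks; swap in the Discriminator and Generator classes.
D = nn.Sequential(nn.Flatten(), nn.Linear(784, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
G = nn.Sequential(nn.Linear(z_size, 32), nn.LeakyReLU(0.2), nn.Linear(32, 784), nn.Tanh())
d_optimizer = optim.Adam(D.parameters(), lr=0.002)
g_optimizer = optim.Adam(G.parameters(), lr=0.002)

def sample_z(n):
    # Latent vectors drawn uniformly from [-1, 1) (an assumed choice).
    return torch.rand(n, z_size) * 2 - 1

def train_step(real_images):
    n = real_images.size(0)
    real_images = real_images * 2 - 1  # rescale [0, 1] to [-1, 1]

    # Discriminator step: real images labelled 1, fake images labelled 0.
    d_optimizer.zero_grad()
    d_real_loss = criterion(D(real_images).view(-1), torch.ones(n))
    fake_images = G(sample_z(n)).detach()  # do not update G in this step
    d_fake_loss = criterion(D(fake_images).view(-1), torch.zeros(n))
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()

    # Generator step: fresh fakes scored against flipped labels (all 1s).
    g_optimizer.zero_grad()
    fake_images = G(sample_z(n))
    g_loss = criterion(D(fake_images).view(-1), torch.ones(n))
    g_loss.backward()
    g_optimizer.step()

    return d_loss.item(), g_loss.item()

# One step on a random stand-in batch shaped like MNIST images.
d_loss, g_loss = train_step(torch.rand(64, 1, 28, 28))
print(d_loss, g_loss)
```

A full run would call train_step once per batch from train_loader, for as many epochs as needed, and record the two losses for plotting.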

Also, for the sake of time, it will help to have a GPU. I have trained it on Intel DevCloud which shows excellent performance & the network trains in a flash!

After training for about 100 epochs, the loss for G & D was as shown in the image above. I’ve plotted the training losses for the generator and discriminator, recorded after each epoch, so that they’re easier to visualize.

Generator samples from training:

Below I’m showing the generated images as the network was training, every 10 epochs.

It starts out as all noise. Then it learns to make only the center white and the rest black. After that, you can start to see number-like structures, such as 1s and 9s, appear out of the noise.

If you need the full code implementation for this, do check out: GANs

In this blog post, I have introduced Generative Adversarial Networks. We explored the parts that make up a GAN and how they work together. Finally, we linked the theory with practice by building the pieces of a GAN that learns to create synthetic examples of MNIST digits.
Now that you’ve learned all of this, the next step is to keep reading and learning about the more advanced GAN methods that I listed at the beginning of this post.

Garima Nishad
Intel Student Ambassadors

A Machine Learning Research scholar who loves to moonlight as a blogger.