Building a Simple Python-Based GAN in 5 minutes

A beginner-level tutorial

Nandhini Swaminathan
The Research Nest
4 min read · Dec 27, 2022


Generative Adversarial Networks, or GANs, have created an uproar in academic circles for their abilities. A machine that can produce new, seemingly inspired works stirs both awe and unease, and it naturally makes one curious: how do you build one?

A Generative Adversarial Network (GAN) is a deep learning model that generates new, synthetic data similar to some input data. GANs consist of two neural networks: a generator and a discriminator. The generator is trained to produce synthetic data realistic enough to pass for the input data, while the discriminator is trained to distinguish between synthetic and real data.

A generative model learns the intrinsic distribution function of the input data f(x), allowing it to generate both synthetic inputs x’ and outputs y’, typically given some hidden (latent) variables. GANs are popular because they tend to produce the sharpest images among generative models, though in practice their adversarial training can be unstable and needs careful tuning.
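
In the original formulation by Goodfellow et al., the two networks play a minimax game over a value function, where D(x) is the discriminator’s estimated probability that x is real and G(z) is the generator’s output for a noise vector z:

min_G max_D V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]

The discriminator tries to maximize this value while the generator tries to minimize it; the training loop below approximates the game with alternating gradient steps on each network.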

The Code

This code trains the GAN for a given number of epochs, where an epoch is defined as one pass through the entire dataset. During each epoch, the code iterates over the data in the data loader (which should be a PyTorch DataLoader object that wraps your dataset) and trains both the discriminator and generator on each batch.

The generator is trained by trying to fool the discriminator, which is trained to distinguish real images from fake ones. The loss function used here is binary cross-entropy, a common choice for GANs. Both networks are optimized with Adam, an adaptive variant of stochastic gradient descent.
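
The code below assumes a dataloader already exists. Here is a minimal sketch of one, assuming the MNIST dataset from torchvision (which matches the 784 = 28 × 28 input size used later) and normalization to [-1, 1] to match the generator’s tanh output:

import torch
from torchvision import datasets, transforms

# Scale images to [-1, 1] so real data matches the generator's tanh range
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_set = datasets.MNIST(root="data", train=True, download=True,
                           transform=transform)
dataloader = torch.utils.data.DataLoader(train_set, batch_size=64,
                                         shuffle=True)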

1. First, import the necessary libraries and define the generator and discriminator models.
import torch
import torch.nn as nn
import torch.optim as optim
  • The generator should be a neural network that takes in a random noise vector and generates synthetic data, while the discriminator should be a neural network that takes in real or synthetic data and outputs the probability that its input is real.
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.tanh(self.fc2(x))  # outputs in [-1, 1]
        return x

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))  # probability that x is real
        return x
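
Before wiring these into a training loop, a quick shape sanity check can help. The sizes below are assumptions that match the setup in the next step:

# Quick shape check with hypothetical sizes (see step 2)
G = Generator(input_size=100, hidden_size=256, output_size=784)
D = Discriminator(input_size=784, hidden_size=256, output_size=1)

z = torch.randn(8, 100)   # a batch of 8 noise vectors
fake = G(z)               # shape (8, 784), values in [-1, 1]
scores = D(fake)          # shape (8, 1), probabilities in (0, 1)
print(fake.shape, scores.shape)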

2. In the following code block, we set up the environment for the GAN. This includes:

  • Setting the sizes of the input, hidden, and output layers for the discriminator and generator networks
  • Creating instances of the Generator and Discriminator classes
  • Setting up the loss function and optimizers
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the layer sizes (784 = 28 x 28, a flattened MNIST-style image)
input_size = 784   # dimensionality of a real, flattened image
hidden_size = 256
output_size = 1    # the discriminator outputs a single probability
noise_size = 100   # dimensionality of the generator's input noise

# Create the discriminator and generator
# The generator maps noise vectors to flattened images (noise_size in,
# input_size out); the discriminator maps images to a probability.
discriminator = Discriminator(input_size, hidden_size, output_size).to(device)
generator = Generator(noise_size, hidden_size, input_size).to(device)

# Set the loss function and optimizers
# (plain BCELoss, since the discriminator already applies a sigmoid)
loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)

# Set the number of epochs
num_epochs = 200

# Training loop (dataloader yields batches of images and labels)
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        # Get the batch size and flatten the images for the linear layers
        batch_size = real_images.size(0)
        real_images = real_images.view(batch_size, -1).to(device)

3. In the code below, the generator is trained by trying to fool the discriminator, which is trained to distinguish real and fake images. To do this:

  • We give the generator a batch of noise vectors as input, and it generates a batch of fake images. These fake images are then passed through the discriminator, which produces a prediction for each image in the batch.
  • The generator’s loss is calculated from those predictions, back-propagated through the generator, and used by the Adam optimizer to update the generator’s parameters. Each update moves the parameters in a direction that reduces the loss and improves the generator’s ability to fool the discriminator.
        # Generate fake images from random noise
        noise = torch.randn(batch_size, noise_size).to(device)
        fake_images = generator(noise)

        # Train the discriminator on real and fake images
        # (detach the fakes so this step does not touch the generator's graph)
        d_real = discriminator(real_images)
        d_fake = discriminator(fake_images.detach())

        # Calculate the loss: real images target 1, fake images target 0
        real_loss = loss_fn(d_real, torch.ones_like(d_real))
        fake_loss = loss_fn(d_fake, torch.zeros_like(d_fake))
        d_loss = real_loss + fake_loss

        # Backpropagate and optimize the discriminator
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train the generator: it wants the discriminator to label
        # its fake images as real (target 1)
        d_fake = discriminator(fake_images)
        g_loss = loss_fn(d_fake, torch.ones_like(d_fake))

        # Backpropagate and optimize the generator
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Print the loss every 50 batches
        if (i + 1) % 50 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, len(dataloader),
                          d_loss.item(), g_loss.item()))

And… that’s all. A quick, basic GAN model, ready to train.
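
Once training finishes, you can sample from the generator directly. A minimal sketch, assuming the variables defined above and reshaping each flat 784-vector back into a 28 × 28 image:

# Generate a few images from the trained generator
generator.eval()
with torch.no_grad():
    noise = torch.randn(16, noise_size).to(device)
    samples = generator(noise).view(-1, 28, 28).cpu()

# `samples` holds 16 images in [-1, 1]; rescale with (samples + 1) / 2
# before plotting or saving them.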
