Generating New Realities: Crafting a Simple GAN with PyTorch

The Programming Geek
Jan 29, 2024

Generative Adversarial Networks (GANs) have transformed machine learning and artificial intelligence, making it possible to generate realistic images, music, text, and more. In this guide, we’ll build a simple but complete GAN using PyTorch, a leading deep learning library, covering both the fundamental concepts behind GANs and how to implement one from scratch.

Understanding GANs

Generative Adversarial Networks, conceptualized by Ian Goodfellow and his colleagues in 2014, are composed of two neural networks that are trained simultaneously through adversarial processes. These networks are:

  • The Generator: This network learns to produce data that resembles the data it was trained on, with the goal of making outputs indistinguishable from real samples.
  • The Discriminator: As the Generator’s counterpart, the Discriminator’s task is to tell the real data apart from the fake data the Generator produces.

The training of GANs can be thought of as a game: the Generator is trying to produce counterfeit currency, while the Discriminator acts like the police, trying to detect the fake bills. Ideally, the Generator eventually becomes so good at creating forgeries that the Discriminator can no longer tell real from fake.
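For readers who like the math, this adversarial game corresponds to the minimax objective from Goodfellow et al.’s original paper, where D(x) is the Discriminator’s estimated probability that x is real and G(z) is the Generator’s output for random noise z:

min_G max_D V(D, G) = E_x~p_data[ log D(x) ] + E_z~p_z[ log(1 − D(G(z))) ]

The Discriminator tries to maximize this value (classify correctly), while the Generator tries to minimize it (fool the Discriminator).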

Setting Up the Environment

Before diving into the code, ensure that you have Python and PyTorch installed in your environment. PyTorch can be easily installed using pip:

pip install torch torchvision

Make sure to check the PyTorch website for the latest installation instructions tailored to your operating system and CUDA version (if applicable).
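If you want to confirm whether PyTorch can see a GPU, a quick check like the one below works; everything in this article also runs on a CPU, just more slowly:

import torch

# Prints True if a CUDA-capable GPU is available to PyTorch
print(torch.cuda.is_available())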

The Data

For our example, we’ll be using the MNIST dataset, which contains 60,000 training images of handwritten digits (28×28 grayscale) and is commonly used as a benchmark for image processing systems. PyTorch provides a straightforward way to load it:

import torch
from torchvision.datasets import MNIST
from torchvision import transforms

# Convert images to tensors and normalize pixel values to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
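It’s worth sanity-checking what the loader yields. Each batch should be a tensor of shape [64, 1, 28, 28] with pixel values normalized to [-1, 1]:

# Pull one batch from the loader and inspect it
images, labels = next(iter(train_loader))
print(images.shape)                              # torch.Size([64, 1, 28, 28])
print(images.min().item(), images.max().item())  # -1.0 1.0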

The Generator

The Generator is a neural network that takes random noise as input and outputs a flattened vector with the same number of pixels as an MNIST image (28×28 = 784). Here’s a simple architecture for the Generator, using PyTorch’s Sequential model:

import torch
from torch import nn

latent_dim = 100  # Size of the random noise vector

# Define the Generator's architecture: noise in, flattened 28x28 image out
G = nn.Sequential(
    nn.Linear(latent_dim, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 512),
    nn.LeakyReLU(0.2),
    nn.Linear(512, 1024),
    nn.LeakyReLU(0.2),
    nn.Linear(1024, 28*28),
    nn.Tanh()
)

The Tanh function is used in the last layer to output values between -1 and 1, which matches the normalization we applied to the MNIST images.
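As a quick, optional sanity check, feeding a batch of noise vectors through G should produce flattened images of length 784 (28×28):

z = torch.randn(2, latent_dim)
print(G(z).shape)  # torch.Size([2, 784])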

The Discriminator

The Discriminator is another neural network that classifies its input as real or fake. Here’s a simple Discriminator using PyTorch’s Sequential model:

# Define the Discriminator's architecture: flattened image in, probability out
D = nn.Sequential(
    nn.Linear(28*28, 1024),
    nn.LeakyReLU(0.2),
    nn.Dropout(0.3),
    nn.Linear(1024, 512),
    nn.LeakyReLU(0.2),
    nn.Dropout(0.3),
    nn.Linear(512, 256),
    nn.LeakyReLU(0.2),
    nn.Dropout(0.3),
    nn.Linear(256, 1),
    nn.Sigmoid()
)

The Sigmoid function outputs a probability, indicating whether the input data is real (closer to 1) or fake (closer to 0).
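Again, a quick shape check: the Discriminator takes a flattened image and returns one probability per sample. (Random noise stands in for real images here, purely to verify the shapes.)

x = torch.randn(2, 28*28)  # stand-in for a batch of flattened images
print(D(x).shape)          # torch.Size([2, 1]), each value in (0, 1)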

Loss Functions and Optimizers

GANs need a loss for the Generator and another for the Discriminator. Both are computed with the same Binary Cross Entropy loss (BCELoss); what differs is the target labels each network is trained against:

# Loss function
criterion = nn.BCELoss()

# Optimizers (beta1 = 0.5 instead of the default 0.9 is a common
# choice that tends to stabilize GAN training)
G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
D_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
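To build intuition for what BCELoss rewards, here is a tiny illustrative example (the numbers are arbitrary): a confident, correct prediction produces a small loss.

# A Discriminator output of 0.9 for a real sample (label 1.0)
# gives a loss of -log(0.9) ≈ 0.105
loss = criterion(torch.tensor([0.9]), torch.tensor([1.0]))
print(loss.item())  # ≈ 0.1054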

Training the GAN

Training a GAN involves alternating between training the Discriminator and the Generator. Here’s the high-level training loop:

# Number of epochs (200 passes over MNIST; reduce this for a quick test run)
num_epochs = 200

# Training loop
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Flatten the images for the Discriminator
        images = images.view(images.size(0), -1)

        # Real labels are 1, fake labels are 0
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)

        ############################
        # Train the Discriminator
        ############################
        D_optimizer.zero_grad()

        # Compute BCELoss using real images
        outputs = D(images)
        D_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Generate fake images
        z = torch.randn(images.size(0), latent_dim)
        fake_images = G(z)

        # Compute BCELoss using fake images
        # (detach so gradients don't flow into the Generator here)
        outputs = D(fake_images.detach())
        D_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Optimize the Discriminator
        D_loss = D_loss_real + D_loss_fake
        D_loss.backward()
        D_optimizer.step()

        ############################
        # Train the Generator
        ############################
        G_optimizer.zero_grad()

        # Generate fake images
        z = torch.randn(images.size(0), latent_dim)
        fake_images = G(z)

        # Compute BCELoss using fake images, with reversed labels:
        # the Generator wants the Discriminator to output "real"
        outputs = D(fake_images)
        G_loss = criterion(outputs, real_labels)

        # Optimize the Generator
        G_loss.backward()
        G_optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], D_loss: {:.4f}, G_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch+1, num_epochs, i+1, len(train_loader), D_loss.item(), G_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

In the training loop, note how we first train the Discriminator on both real and fake data to improve its accuracy, and then train the Generator to produce data that the Discriminator classifies as real. The fake_images.detach() call in the Discriminator step is important: it stops gradients from flowing back into the Generator while only the Discriminator is being updated.
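Training for 200 epochs can take a while, so it’s worth saving the learned weights once training finishes. A minimal sketch using PyTorch’s standard state_dict mechanism (the file names here are just examples):

# Save the trained weights (file names are arbitrary)
torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')

# Later, restore them into freshly constructed models
G.load_state_dict(torch.load('generator.pth'))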

Visualizing the Results

It’s essential to visualize the generated images to assess the performance of our GAN. After training, we can use the Generator to create images:

# Generate fake images for visualization
# (torch.no_grad() skips gradient tracking, which also lets us
# convert the outputs to NumPy for plotting)
with torch.no_grad():
    z = torch.randn(16, latent_dim)
    fake_images = G(z)

fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
fake_images = (fake_images + 1) / 2  # Rescale images from [-1, 1] to [0, 1]

# Plot the fake images in a 4x4 grid
import matplotlib.pyplot as plt

fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for ax, img in zip(axes.flatten(), fake_images):
    ax.axis('off')
    ax.imshow(img.squeeze().numpy(), cmap='gray')
plt.show()
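As an alternative to matplotlib, torchvision includes a convenience helper that writes an image grid straight to disk (the output file name is just an example):

from torchvision.utils import save_image

# Writes the 16 generated images as a 4x4 grid to a PNG file
save_image(fake_images, 'gan_samples.png', nrow=4)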

Conclusion

We’ve explored the fundamental principles behind GANs and walked through a practical example of building and training a simple GAN with PyTorch. While our example is basic, it captures the essence of the technique, and you now have the foundation to start experimenting with more sophisticated generative models.

As with all machine learning models, remember that practice and experimentation lead to mastery. So keep tinkering, keep learning, and most importantly, keep generating!
