Everything you ever wanted to know about Generative Adversarial Networks (GANs)

“The most interesting idea in the last ten years in ML”

- Yann LeCun, Facebook AI Research Director

Generative Adversarial Networks (GANs) are a powerful machine learning framework for estimating generative models through an adversarial process. Two models are trained simultaneously: (i) a generative model G that captures the data distribution, and (ii) a discriminative model D that estimates the probability that a sample came from the training data rather than from G.

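Formally, training a GAN is a two-player minimax game between G and D with the value function V(D, G):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$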

where:

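G = generator

D = discriminator

p_data(x) = distribution of the real data

p_z(z) = prior distribution of the input noise fed to the generator

x = sample from p_data(x)

z = sample from p_z(z)

D(x) = the discriminator's estimated probability that x is real

G(z) = the generator's output (a synthetic sample) for the noise input z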
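To make the value function concrete, here is a small sketch for the special case where the real data and the generator are both discrete distributions over the same three outcomes. The probabilities are made-up illustrative numbers, assumed strictly positive so every logarithm is defined, and p_g (the generator's output distribution) plus all helper names are introduced here for illustration. It uses the result from the original GAN paper that, for a fixed generator, the best discriminator is D*(x) = p_data(x) / (p_data(x) + p_g(x)), and that V(D*, G) then equals -log 4 + 2·JSD(p_data || p_g), where JSD is the Jensen-Shannon divergence.

```python
import math

def value_function(p_data, p_g, d):
    """V(D, G) for discrete distributions over the same finite set of outcomes.

    p_data[i], p_g[i]: probability of outcome i under the real data and under the generator.
    d[i]: the discriminator's probability that outcome i is real.
    """
    real_term = sum(pd * math.log(di) for pd, di in zip(p_data, d))
    fake_term = sum(pg * math.log(1.0 - di) for pg, di in zip(p_g, d))
    return real_term + fake_term

def optimal_discriminator(p_data, p_g):
    # For a fixed generator: D*(x) = p_data(x) / (p_data(x) + p_g(x))
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]

def kl_divergence(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_divergence(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

p_data = [0.1, 0.4, 0.5]  # illustrative "real" distribution
p_g = [0.3, 0.3, 0.4]     # illustrative generator distribution

d_star = optimal_discriminator(p_data, p_g)
print(value_function(p_data, p_g, d_star))           # these two values agree
print(-math.log(4) + 2 * js_divergence(p_data, p_g))

# A perfect generator (p_g == p_data) leaves D* = 0.5 everywhere and V = -log 4
print(optimal_discriminator(p_data, p_data))              # [0.5, 0.5, 0.5]
print(value_function(p_data, p_data, [0.5, 0.5, 0.5]))   # ≈ -1.386
```

Because the Jensen-Shannon divergence is zero only when the two distributions are equal, V(D*, G) reaches its minimum of -log 4 exactly when the generator reproduces the data distribution, which is the equilibrium the minimax game aims for.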

The two models are set against each other with opposite goals:

i. The generator is trained to maximize the final classification error (between real and generated data).

ii. The discriminator is trained to minimize the final classification error.

The generator's objective is to increase the error rate of the discriminator: it learns from whether or not it succeeds in fooling the discriminator.

The discriminator's objective, in contrast, is to correctly tell real samples from the training data apart from the synthetic samples produced by the generator.
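In practice, the generator is often not trained on the minimax objective directly. The original GAN paper points out that early in training, when D confidently rejects generated samples, log(1 - D(G(z))) saturates, and suggests training G to maximize log D(G(z)) instead. The G_loss in the MNIST example later in this post (cross-entropy against labels of 1 for generated samples) is this non-saturating form. The sketch below, in plain Python with hypothetical helper names, compares the gradients of the two generator losses with respect to the discriminator's logit a, where D(G(z)) = sigmoid(a).

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

# Both losses are minimized by the generator:
#   minimax loss:         log(1 - D(G(z)))
#   non-saturating loss:  -log D(G(z))

def grad_minimax(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a)
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log sigmoid(a)] = -(1 - sigmoid(a))
    return -(1.0 - sigmoid(a))

# Early in training the discriminator easily rejects fakes, so a is very negative
for a in (-6.0, -2.0, 0.0):
    print(f"D(G(z))={sigmoid(a):.4f}  minimax grad={grad_minimax(a):.4f}  "
          f"non-saturating grad={grad_non_saturating(a):.4f}")
```

For a = -6 (D(G(z)) ≈ 0.0025), the minimax gradient is roughly -0.0025 while the non-saturating gradient is roughly -0.9975, so the generator still receives a useful learning signal when it is losing badly.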

Applications of GANs:

GANs have some very interesting uses, some of which are listed below:

a. Generate examples for image datasets

b. Generate photographs of human faces

c. Generate realistic photographs

d. Generate cartoon characters

e. Fill in images from an outline

Other miscellaneous applications include:

a. Fashion, art and advertising

b. Science

c. Video games

d. GANs can also be used to detect glaucomatous retinal images, generate photorealistic images, build 3D models of objects from images, augment training data, inpaint missing features in maps, etc.

Some impressive GAN libraries:

a. TF-GAN — This is a lightweight library for training and evaluating GANs. It can be installed with pip using pip install tensorflow-gan, and used with import tensorflow_gan as tfgan.

b. Torch-GAN — This is a PyTorch-based framework for designing and developing GANs. It provides building blocks for popular GANs and also allows customization for cutting-edge research.

c. Mimicry — This is a lightweight PyTorch library aimed towards the reproducibility of GAN research.

d. IBM GAN Toolkit — This toolkit aims to provide a highly flexible, no-code way of implementing GANs. Given the details of a GAN model in an intuitive config file or as command-line arguments, it generates the code for training the model.

e. Keras-GAN — This includes a collection of Keras implementations of GANs suggested in research papers.

f. PyTorch-GAN — This is a collection of PyTorch implementations of GANs proposed in research papers. PyTorch itself, a leading open-source deep learning framework, does not provide a built-in GAN implementation, but it provides the primitives needed to build GANs, including fully connected layers, convolutional layers, and training functions.

g. Py-GAN — This is a Python library that is useful for designing generative models based on statistical machine learning problems, for different types of GANs.

Different types of GANs:

a. Vanilla GANs — A vanilla GAN has two networks, a generator and a discriminator. Both networks are trained at the same time and compete against each other in a minimax game.

b. Conditional GAN (CGAN) — Conditional GAN is a GAN variant in which both the generator and the discriminator are conditioned on auxiliary data such as a class label during training.
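One common and simple way to condition both networks, sketched here as an assumption rather than the only option, is to concatenate a one-hot encoding of the class label to each network's input: the noise vector z for the generator and the flattened image for the discriminator. Sizes follow the MNIST example in this post (100-dim noise, 784-pixel images, 10 digit classes); `one_hot` and `concat_condition` are hypothetical helpers.

```python
import random

def one_hot(label, num_classes):
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def concat_condition(features, label, num_classes=10):
    # Append the one-hot class label to the network's input vector
    return list(features) + one_hot(label, num_classes)

# Generator input: 100-dim noise vector plus the label of the digit we want (here 7)
z = [random.uniform(-1, 1) for _ in range(100)]
g_input = concat_condition(z, label=7)

# Discriminator input: a flattened 28x28 image plus the label it is claimed to show
image = [0.0] * 784  # placeholder pixels for illustration
d_input = concat_condition(image, label=7)

print(len(g_input), len(d_input))  # 110 794
```

Because the discriminator also sees the label, a realistic-looking "3" paired with the label 7 can still be rejected, which pushes the generator to produce samples that match the requested class.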

c. Deep Convolutional GAN (DCGAN) — DCGAN is one of the most popular and successful network designs for GANs. It is composed mainly of convolutional layers, without max pooling or fully connected layers, and uses strided convolutions for downsampling and transposed convolutions for upsampling.
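The spatial sizes in a DCGAN-style network follow from standard convolution arithmetic. This sketch uses the output-size formulas for dilation 1 (the same ones given in the PyTorch Conv2d and ConvTranspose2d documentation) and assumes 4x4 kernels, stride 2 and padding 1, an illustrative choice rather than a fixed part of DCGAN. It shows how two strided convolutions shrink a 28x28 MNIST image to 7x7 and how two transposed convolutions grow 7x7 feature maps back to 28x28.

```python
def conv_out(size, kernel, stride, padding):
    # Output length of a strided convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding, output_padding=0):
    # Output length of a transposed convolution (dilation 1)
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Discriminator side: downsample a 28x28 image with stride-2 convolutions
s = 28
for _ in range(2):
    s = conv_out(s, kernel=4, stride=2, padding=1)
    print("downsampled to", s)  # 14, then 7

# Generator side: upsample 7x7 feature maps with stride-2 transposed convolutions
for _ in range(2):
    s = conv_transpose_out(s, kernel=4, stride=2, padding=1)
    print("upsampled to", s)    # 14, then 28
```

Letting the stride do the resizing is what allows DCGAN to drop max pooling: the network learns its own downsampling and upsampling instead of using a fixed pooling rule.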

d. Laplacian Pyramid GAN (LAPGAN) — LAPGAN combines the CGAN model with a Laplacian pyramid representation.
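A Laplacian pyramid stores an image as a coarse, low-resolution version plus a residual (high-frequency detail) at each finer level; in LAPGAN, a conditional GAN at each level generates that residual conditioned on the upsampled coarser image. The sketch below is a deliberately simplified 1-D version: it uses pair averaging for downsampling and value repetition for upsampling instead of the blur-based operators typically used for images, and it only shows the decomposition and exact reconstruction, not the GANs themselves. All function names are hypothetical, and the signal length is assumed to be divisible by 2**levels.

```python
def downsample(signal):
    # Halve the resolution by averaging neighbouring pairs (a simple box filter)
    return [(signal[i] + signal[i + 1]) / 2 for i in range(0, len(signal), 2)]

def upsample(signal):
    # Double the resolution by repeating each value (nearest neighbour)
    return [v for v in signal for _ in range(2)]

def build_pyramid(signal, levels):
    """Return (residuals, coarsest); residuals[k] = level k minus upsampled level k+1."""
    residuals = []
    current = signal
    for _ in range(levels):
        coarser = downsample(current)
        residuals.append([c - u for c, u in zip(current, upsample(coarser))])
        current = coarser
    return residuals, current

def reconstruct(residuals, coarsest):
    # Start from the coarsest level and add back each residual, finest last.
    # In LAPGAN, these residuals are what the per-level generators produce.
    current = coarsest
    for residual in reversed(residuals):
        current = [u + r for u, r in zip(upsample(current), residual)]
    return current

signal = [1.0, 3.0, 2.0, 6.0, 5.0, 5.0, 0.0, 4.0]
residuals, coarsest = build_pyramid(signal, levels=2)
print(coarsest)                                    # [3.0, 3.5]
print(reconstruct(residuals, coarsest) == signal)  # True
```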

e. Super Resolution GAN (SRGAN) — SRGAN combines a deep generator network with an adversarial (discriminator) network to produce high-resolution images from low-resolution inputs.
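In the SRGAN paper, the generator is trained with a perceptual loss: a content loss plus an adversarial loss weighted by 10^-3, where the adversarial term is the sum of -log D(G(I_LR)) over the generated images. The paper computes the content loss as an MSE between VGG feature maps of the super-resolved and high-resolution images; in this sketch the feature vectors are simply hypothetical lists of numbers passed in directly, and the helper names are illustrative.

```python
import math

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def adversarial_loss(d_fake_probs):
    # Sum of -log D(G(I_LR)) over generated images (non-saturating GAN loss)
    return sum(-math.log(p) for p in d_fake_probs)

def perceptual_loss(features_sr, features_hr, d_fake_probs, adv_weight=1e-3):
    # Content term plus a small weighted adversarial term
    return mse(features_sr, features_hr) + adv_weight * adversarial_loss(d_fake_probs)

sr_features = [0.2, 0.4, 0.1]   # hypothetical features of the super-resolved image
hr_features = [0.25, 0.35, 0.1] # hypothetical features of the real high-resolution image
print(perceptual_loss(sr_features, hr_features, d_fake_probs=[0.3, 0.6]))
```

The small adversarial weight keeps the content term dominant, so the output stays faithful to the input, while the adversarial term nudges the generator toward sharper, more natural-looking textures.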

Sample Python code for a GAN trained on the MNIST dataset:

```python
# This example uses the TensorFlow 1.x graph API (tf.placeholder,
# tf.Session, tf.layers); it will not run unmodified under TensorFlow 2.x.

# importing the necessary libraries and the MNIST dataset
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data")

# defining functions for the two networks.
# Both networks have two fully connected (dense) hidden layers with
# 128 units and leaky ReLU activations, followed by a dense output layer.

# Generator network: maps a 100-dim noise vector z to a 784-pixel
# (28x28) image with values in [-1, 1] (tanh output)
def generator(z, reuse=None):
    with tf.variable_scope('gen', reuse=reuse):
        hidden1 = tf.layers.dense(inputs=z, units=128,
                                  activation=tf.nn.leaky_relu)

        hidden2 = tf.layers.dense(inputs=hidden1,
                                  units=128, activation=tf.nn.leaky_relu)

        output = tf.layers.dense(inputs=hidden2,
                                 units=784, activation=tf.nn.tanh)

        return output

# Discriminator network: maps a 784-pixel image to the probability that
# it is real; the raw logits are also returned for the loss function
def discriminator(X, reuse=None):
    with tf.variable_scope('dis', reuse=reuse):
        hidden1 = tf.layers.dense(inputs=X, units=128,
                                  activation=tf.nn.leaky_relu)

        hidden2 = tf.layers.dense(inputs=hidden1,
                                  units=128, activation=tf.nn.leaky_relu)

        logits = tf.layers.dense(hidden2, units=1)
        output = tf.sigmoid(logits)

        return output, logits

# creating placeholders for the inputs: real images and noise vectors
tf.reset_default_graph()

real_images = tf.placeholder(tf.float32, shape=[None, 784])
z = tf.placeholder(tf.float32, shape=[None, 100])

G = generator(z)
D_output_real, D_logits_real = discriminator(real_images)
# reuse=True shares the discriminator's weights between real and fake inputs
D_output_fake, D_logits_fake = discriminator(G, reuse=True)

# defining the loss function (mean sigmoid cross-entropy on the logits)
def loss_func(logits_in, labels_in):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits_in, labels=labels_in))

# One-sided label smoothing: real images are labeled 0.9 instead of 1.0
# to keep the discriminator from becoming overconfident
D_real_loss = loss_func(D_logits_real, tf.ones_like(D_logits_real) * 0.9)
D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_fake))
D_loss = D_real_loss + D_fake_loss

# The generator is rewarded when the discriminator labels its samples as real
G_loss = loss_func(D_logits_fake, tf.ones_like(D_logits_fake))

# defining the learning rate, batch size,
# number of epochs and using the Adam optimizer
lr = 0.001  # learning rate

# Each optimizer must update only its own network, so split the trainable
# variables by the variable scope ('gen' or 'dis') they were created in
tvars = tf.trainable_variables()
d_vars = [var for var in tvars if 'dis' in var.name]
g_vars = [var for var in tvars if 'gen' in var.name]

D_trainer = tf.train.AdamOptimizer(lr).minimize(D_loss, var_list=d_vars)
G_trainer = tf.train.AdamOptimizer(lr).minimize(G_loss, var_list=g_vars)

batch_size = 100  # batch size
epochs = 500  # number of passes over the training set
init = tf.global_variables_initializer()

# creating a session to train the networks
samples = []  # one generated example per epoch

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(epochs):
        num_batches = mnist.train.num_examples // batch_size

        for i in range(num_batches):
            batch = mnist.train.next_batch(batch_size)
            batch_images = batch[0].reshape((batch_size, 784))
            # rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh output
            batch_images = batch_images * 2 - 1
            batch_z = np.random.uniform(-1, 1, size=(batch_size, 100))
            # alternate one discriminator update and one generator update
            _ = sess.run(D_trainer, feed_dict={real_images: batch_images, z: batch_z})
            _ = sess.run(G_trainer, feed_dict={z: batch_z})

        print("on epoch {}".format(epoch))

        # sample from the generator tensor G built above,
        # instead of adding new generator ops to the graph every epoch
        sample_z = np.random.uniform(-1, 1, size=(1, 100))
        gen_sample = sess.run(G, feed_dict={z: sample_z})

        samples.append(gen_sample)

# result after the first epoch (epoch 0)
plt.imshow(samples[0].reshape(28, 28), cmap='gray')
plt.show()

# result after the last epoch (epoch 499)
plt.imshow(samples[-1].reshape(28, 28), cmap='gray')
plt.show()
```
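The `loss_func` helper in the MNIST code wraps `tf.nn.sigmoid_cross_entropy_with_logits`, which TensorFlow documents as computing the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)) for a logit x and label z. Here is a pure-Python re-implementation of that per-element formula, for illustration only (it is not TensorFlow's code, and the function names are hypothetical). It also shows why the code passes logits rather than sigmoid outputs, and what the 0.9 smoothed label does.

```python
import math

def sigmoid_cross_entropy_with_logits(logit, label):
    # Stable form of -label*log(sigmoid(logit)) - (1 - label)*log(1 - sigmoid(logit))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def naive_cross_entropy(logit, label):
    # Same quantity computed through the sigmoid; breaks down for large |logit|
    p = 1.0 / (1.0 + math.exp(-logit))
    return -label * math.log(p) - (1 - label) * math.log(1 - p)

# The two forms agree for moderate logits...
print(sigmoid_cross_entropy_with_logits(2.0, 1.0), naive_cross_entropy(2.0, 1.0))

# ...but for a confidently wrong prediction, 1 - sigmoid(40.0) rounds to 0.0,
# so naive_cross_entropy(40.0, 0.0) raises ValueError (math domain error)
print(sigmoid_cross_entropy_with_logits(40.0, 0.0))  # 40.0

# With the smoothed label 0.9, the loss is smallest at sigmoid(logit) = 0.9,
# i.e. logit = log(9), so the discriminator is not pushed toward infinite confidence
best = math.log(9)
print(sigmoid_cross_entropy_with_logits(best, 0.9) < sigmoid_cross_entropy_with_logits(best + 1, 0.9))  # True
```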

In conclusion, GANs are still an active area of research and continue to be an exciting topic for machine learning enthusiasts to explore! In this blog, I have discussed how GANs work, where they are used, popular libraries and variants, and a sample implementation. Please visit https://www.ml-concepts.com/machine-learning-models/ for more interesting articles.
