Understanding CycleGANs using examples & code

Training CycleGAN for season translation using TensorFlow 2

Mehul Gupta
Data Science in your pocket
10 min read · Sep 14, 2021


After covering basic GANs (with a sample model) in my last post, let's take a step further and explore an advanced GAN variant, CycleGAN, which has some fascinating applications in the field of image translation.

What the heck is image translation?

It's basically tweaking an image so that its domain/style changes to another one while the core content stays intact. For example: changing the season of a landscape image while keeping every other feature of the image intact. So, if we have an image of a house in Summer, it should get translated into an image depicting Winter, with the house left as it is.

How can this be done? An early guess might be that we have an X →Y pairing of images: for each input X from domain 1 (say, Summer), a Y (the same scene in Winter), and we train some autoencoder-style model to get this translation done. Right?

But if you ponder for a minute, preparing such a training dataset is a monumental task. For example, for Summer to Winter translation of the same location, you may have to wait for a year to get the pictures. In other cases, Y just doesn't exist. For example: converting MF Hussain's painting into a Da Vinci painting. You can imagine how Da Vinci would have painted the same painting that MF Hussain painted, but you can't have an actual ground truth.

So, in short, you don’t have an X →Y mapping of the same scene.

Now what? It looks like an impossible task. This is where CycleGANs come to the rescue. The best part is that they don't require an X →Y mapping, just samples from both domains (input & output; random Summer & Winter images in our case) that are independent. By independent, we mean random images from both domains, not necessarily pairs.

So, how does CycleGAN achieve this? By bringing a few subtle changes to the basic GAN structure. Let's discuss them.

Multiple Generators & Discriminators in play

In a general GAN, we have a Generator generating fake images & a Discriminator telling real images from fake ones, all within one domain. Now, as we have images coming from 2 domains (be it Summer →Winter or Winter →Summer, MF Hussain →Da Vinci or Da Vinci →MF Hussain), we have a Generator & Discriminator pair for each domain so that the translation can happen both ways: each generator translates into one domain, and that domain's discriminator judges its output.

Bring on the Cycle

As the name suggests, CycleGAN consists of a cyclic structure formed between these multiple generators & discriminators.

Let's assume A=Summer, B=Winter. The cyclic flow goes something like this:

A. Random samples from domain A (Summer) are fed to generator_A_B (Summer →Winter), which takes in Summer images & translates them to Winter. So, instead of taking in random noise (as in a general GAN), the generator takes images from one domain & is trained to translate them into the other domain.

B. The image generated for domain B (Winter) by generator_A_B is then fed to the 2nd generator, generator_B_A (Winter →Summer), which should reproduce the original input image from domain A.

The same flow then goes vice-versa for Winter →Summer.

Image_A →Generator_A_B →Image_B →Generator_B_A →Image_A

&

Image_B →Generator_B_A →Image_A →Generator_A_B →Image_B

Loss function going complex

In a general GAN, it's the discriminator's error in classifying real vs fake samples that we use to train both the generator & the discriminator. This time, though, we have a few more terms, which make the loss function a little more complex:

  • Adversarial loss: the loss from the discriminator's real vs fake classification of the translated images.
  • Identity loss: a loss term calculated at the pixel level. What would you want generator_A_B (i.e. Summer →Winter) to produce if a Winter image is fed in? Ideally, it shouldn't change anything, as the input image is already in the desired domain. This term is believed to help preserve colors in the output images and avoid unnecessary changes.
  • Cyclic loss: in the cyclic structure above, we pass an image from one domain through both generators sequentially, which should reproduce the input image. This term is the pixel-wise mean absolute error (L1, 'mae' in the code below) between that input & output image, since ideally the two should be identical.
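To make the three terms concrete, here is a minimal NumPy sketch of the objective for generator_A_B, assuming the generators and the discriminator are plain callables on image batches. The weights 1, 5, 10, 10 mirror the loss_weights passed to composite_model() later in this post; the toy generators and the always-unsure discriminator are hypothetical stand-ins, not trained models.

```python
import numpy as np


def mse(a, b):
    """Mean squared error, used for the adversarial (least-squares) term."""
    return float(np.mean((a - b) ** 2))


def mae(a, b):
    """Mean absolute error (L1), used for the identity and cycle terms."""
    return float(np.mean(np.abs(a - b)))


def generator_loss_A_B(real_A, real_B, g_AB, g_BA, d_B, weights=(1, 5, 10, 10)):
    """Weighted objective for the Summer (A) -> Winter (B) generator."""
    fake_B = g_AB(real_A)
    scores = d_B(fake_B)
    adversarial = mse(scores, np.ones_like(scores))   # want D_B to call every fake patch "real"
    identity = mae(g_AB(real_B), real_B)              # a Winter input should pass through unchanged
    forward_cycle = mae(g_BA(fake_B), real_A)         # A -> B -> A reconstruction
    backward_cycle = mae(g_AB(g_BA(real_B)), real_B)  # B -> A -> B reconstruction
    terms = (adversarial, identity, forward_cycle, backward_cycle)
    return sum(w * t for w, t in zip(weights, terms)), terms


def g_AB(x):   # hypothetical "Summer -> Winter" generator: brighten slightly
    return x + 0.1


def g_BA(x):   # hypothetical exact inverse of g_AB
    return x - 0.1


def d_B(x):    # hypothetical discriminator: 0.5 on every cell of a 120x120 patch map
    return np.full((x.shape[0], 120, 120, 1), 0.5)


rng = np.random.default_rng(0)
real_A = rng.random((2, 128, 128, 3))
real_B = rng.random((2, 128, 128, 3))
total, terms = generator_loss_A_B(real_A, real_B, g_AB, g_BA, d_B)
print(round(total, 6))  # 0.75
```

Because the toy generators invert each other exactly, both cycle terms vanish; the total comes only from the unsure discriminator (0.25) and the identity term (5 × 0.1).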

As we are done with the basics, it's time for some action I guess !!

We will be training a CycleGAN for Summer →Winter & vice-versa translation using TensorFlow 2

Note: Kindly download the dataset from here. All the code below, along with the trained models, can be checked out at

1. Importing required packages

# Jupyter-only magic; remove it when running as a plain Python script
%matplotlib inline
import os
import glob

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Activation, Dropout, Lambda
from tensorflow.keras.models import Sequential, Model
from tensorflow_addons.layers import InstanceNormalization  # pip install tensorflow-addons

2. Creating TensorFlow pipeline for training purposes

def preprocess(records):
    # Scale uint8 pixel values to float32 in [0, 1]
    images = records['image']
    images = tf.cast(images, tf.float32) / 255.0
    return images

def tf_pipeline(dataset):
    dataset = tf.data.Dataset.from_tensor_slices({'image': dataset})
    dataset = dataset.map(preprocess)
    # Infinite stream of shuffled batches of 16 images
    dataset = dataset.repeat().shuffle(100).batch(16).prefetch(1)
    return dataset

def tf_data(path):
    # path='train' reads trainA/* (Summer) and trainB/* (Winter)
    trainingA = []
    for x in glob.glob(path + 'A/*'):
        # convert('RGB') keeps grayscale/RGBA files from breaking the (128, 128, 3) stack
        image = Image.open(x).convert('RGB').resize((128, 128))
        trainingA.append(np.array(image))

    trainingB = []
    for x in glob.glob(path + 'B/*'):
        image = Image.open(x).convert('RGB').resize((128, 128))
        trainingB.append(np.array(image))

    a, b = tf_pipeline(trainingA), tf_pipeline(trainingB)
    return iter(a), iter(b)

trainA, trainB = tf_data('train')

Note: For any confusion related to tf.data.Dataset, do look out for an explanation in my previous post

  • tf_data() reads image data from both domains (folder trainA: Summer images; trainB: Winter images; path + 'A/*' resolves to trainA/* for path='train'). It calls tf_pipeline() to convert the image arrays into TensorFlow datasets.
  • preprocess() applies basic normalization, scaling pixel values to [0, 1]. More preprocessing can be added for better results.
  • Calling tf_data() initializes an iterator over each of the two TensorFlow datasets, i.e. trainA & trainB respectively.
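For readers new to tf.data, the plain-Python sketch below emulates what repeat().shuffle(100).batch(16) does, assuming the buffer behaviour documented for Dataset.shuffle(): fill a buffer of buffer_size elements, emit a random one, and refill its slot from the upstream data. shuffle_batches is a hypothetical helper for illustration only; it skips map() and prefetch(), which don't change the order.

```python
import random
from itertools import islice


def shuffle_batches(items, buffer_size=100, batch_size=16, seed=0):
    """Plain-Python emulation of dataset.repeat().shuffle(buffer_size).batch(batch_size).

    Assumes `items` is a non-empty list; the output stream is infinite, like repeat().
    """
    rng = random.Random(seed)

    def repeat():
        while True:                   # repeat(): cycle over the data forever
            yield from items

    stream = repeat()
    buffer = list(islice(stream, buffer_size))   # shuffle(): fill the buffer first
    while True:
        batch = []
        for _ in range(batch_size):              # batch(): group consecutive outputs
            i = rng.randrange(len(buffer))
            batch.append(buffer[i])              # emit a random buffered element...
            buffer[i] = next(stream)             # ...and refill its slot from the stream
        yield batch


batches = shuffle_batches(list(range(1500)))     # stand-in for ~1500 image indices
first = next(batches)
print(len(first), max(first) < 115)  # 16 True
```

One consequence worth knowing: with ~1500 images per domain and a buffer of 100, early batches can only contain images near the start of the glob order (the first batch draws only from items 0 to 114), so a larger buffer gives a more thorough shuffle.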

3. Declaring a few constants

input_dim = (128,128,3) #input/output image dimension
depth = 4 #network depth
kernel = 3 #kernel size for Conv2D
n_batch = 16 #batch_size
epochs = 10
steps = round(1500/n_batch) #steps per epoch, we have ~1500 samples per domain so calculating steps using it
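Since both datasets repeat forever, an "epoch" here is simply steps batches. A quick check of the arithmetic behind steps, using the ~1500 images per domain stated above:

```python
import math

n_samples = 1500   # approximate images per domain, as stated in the constants above
n_batch = 16

steps = round(n_samples / n_batch)            # round(93.75) -> 94
steps_ceil = math.ceil(n_samples / n_batch)   # ceil also gives 94 here
images_per_epoch = steps * n_batch            # 94 * 16 = 1504 images drawn per "epoch"

print(steps, steps_ceil, images_per_epoch)    # 94 94 1504
```

round() and math.ceil() agree for a batch size of 16, but not for every batch size: with n_batch = 64, round(1500/64) gives 23 (1472 images per epoch), while math.ceil gives 24.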

4. Discriminator structure (one per domain, serving the A →B and B →A translations)

def discriminator(input_dim, depth, kernel):
    layers = [Input(shape=input_dim)]
    for i in range(1, depth):
        # Unpadded ('valid') conv: height and width shrink by kernel-1 each time
        layers.append(Conv2D(16 * i, kernel_size=kernel))
        layers.append(InstanceNormalization())
        layers.append(Activation('relu'))
        layers.append(Dropout(0.2))
    # PatchGAN head: one real/fake score per spatial location instead of a single scalar
    layers.append(Conv2D(1, kernel_size=kernel))
    model = Sequential(layers)
    model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam())
    return model

discriminator_A = discriminator(input_dim, depth, kernel)
discriminator_B = discriminator(input_dim, depth, kernel)

The structure looks pretty similar to the one declared (and explained) in my previous GAN post. As discussed above, we declare 2 discriminators, one per domain.

The only notable change is the last layer, which is a Conv2D layer.

Why a Conv2D layer at the end of a discriminator? As it's essentially a binary classifier, shouldn't we be using a Flatten() & Dense() combo, as in a traditional discriminator?

This takes inspiration from PatchGAN, where the discriminator, instead of giving a single real/fake verdict for the entire image, outputs real/fake labels at the region level. So, each patch of the image is classified, rather than the entire image as one.

But why not a traditional discriminator? What extra advantage do we get when producing labels at the regional level?

The major advantage is that such a discriminator pays attention to local style and texture during translation. When we translate Summer →Winter or vice-versa, what we actually want is the same content (the house, hills and other objects should stay) but a changed style (from Summer to Winter). Since each output value judges only one patch, the generator is penalized wherever a patch still looks like the wrong season, which pushes the style change to happen across the entire image and not just parts of it.

How does the ground truth look for such discriminators?

If you remember, for a traditional discriminator, we have a single value (0 or 1 as the label). Here, we have a 2d array per image filled with either 0s (for fake) or 1s (for real).
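How big is that 2d array, and how large is each patch? The small sketch below assumes stride-1, unpadded ('valid', the Keras default) convolutions, as in discriminator() above: each k x k layer trims k-1 pixels from height and width and widens each output unit's receptive field by k-1. patchgan_geometry is a hypothetical helper, not part of the original code.

```python
def patchgan_geometry(input_size=128, n_convs=4, kernel=3):
    """Output map side and receptive field for stacked stride-1 'valid' convolutions."""
    size, receptive_field = input_size, 1
    for _ in range(n_convs):
        size -= kernel - 1              # a 'valid' k x k conv trims k-1 pixels per axis
        receptive_field += kernel - 1   # and widens what each output unit sees by k-1
    return size, receptive_field


# discriminator() with depth=4: three hidden Conv2D layers plus the final Conv2D(1)
out_size, rf = patchgan_geometry()
print(out_size, rf)  # 120 9
```

So, with the constants above, each image gets a 120x120 grid of labels, and every score judges a 9x9 pixel region of the input.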

5. Defining our Generator

def generator(input_dim, depth, kernel):
    layers = [Input(shape=input_dim)]
    # Encoder: unpadded convolutions (128 -> 122 with depth=4, kernel=3)
    for i in range(1, depth):
        layers.append(Conv2D(16 * i, kernel_size=kernel))
        layers.append(InstanceNormalization())
        layers.append(Activation('relu'))
        layers.append(Dropout(0.2))

    # Decoder: stride-1 transposed convolutions (122 -> 128)
    for i in range(1, depth):
        layers.append(Conv2DTranspose(16 * i, kernel_size=kernel))
        layers.append(InstanceNormalization())
        layers.append(Activation('relu'))
        layers.append(Dropout(0.2))

    # Resize back to 128x128 as a safeguard, then a 1x1 projection to 3 RGB channels (linear output)
    layers.append(Lambda(lambda images: tf.image.resize(images, [128, 128]), name='Reshape'))
    layers.append(Conv2DTranspose(3, kernel_size=1, activation=None))
    return Sequential(layers)

generator_A_B = generator(input_dim, depth, kernel)
generator_B_A = generator(input_dim, depth, kernel)

The generator looks similar to the one used in my last GAN post, where it is also explained. Here too, as with the discriminators, we declare 2 generators, one for each translation direction.
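A similar back-of-the-envelope trace for generator(), assuming the Keras output-size rules for 'valid' padding at stride 1: Conv2D maps a side of n to n - k + 1, and Conv2DTranspose maps it to n + k - 1. The helper names are hypothetical.

```python
def conv_valid(size, k):
    """Side length after a stride-1, unpadded Conv2D."""
    return size - k + 1


def conv_transpose_valid(size, k):
    """Side length after a stride-1, unpadded Conv2DTranspose."""
    return size + k - 1


def trace_generator(input_size=128, depth=4, kernel=3):
    sizes = [input_size]
    for _ in range(1, depth):                              # encoder
        sizes.append(conv_valid(sizes[-1], kernel))
    for _ in range(1, depth):                              # decoder
        sizes.append(conv_transpose_valid(sizes[-1], kernel))
    sizes.append(conv_transpose_valid(sizes[-1], 1))       # final 1x1 projection
    return sizes


print(trace_generator())  # [128, 126, 124, 122, 124, 126, 128, 128]
```

Because each stride-1 transposed convolution exactly undoes one encoder convolution, the decoder lands back on 128x128, so for a 128x128 input the Lambda resize layer is effectively a no-op.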

6. Time to get the cyclic connection in !!

def composite_model(g1, d, g2, image_dim):
    # Only g1 is updated through this model; d and g2 are frozen here
    g1.trainable = True
    g2.trainable = False
    d.trainable = False

    # Adversarial loss: source image -> g1 -> d
    input_img = Input(shape=image_dim)
    g1_out = g1(input_img)
    d_out = d(g1_out)

    # Identity loss: target-domain image -> g1 (should come out unchanged)
    input_id = Input(shape=image_dim)
    g1_out_id = g1(input_id)

    # Cycle loss, forward cycle: source -> g1 -> g2
    g2_out = g2(g1_out)

    # Cycle loss, backward cycle: target -> g2 -> g1
    g2_out_id = g2(input_id)
    output_g1 = g1(g2_out_id)

    model = Model([input_img, input_id], [d_out, g1_out_id, g2_out, output_g1])
    # MSE for the adversarial output, MAE for identity and both cycle outputs
    model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10],
                  optimizer=tf.keras.optimizers.Adam())
    return model

composite_A_B = composite_model(generator_A_B, discriminator_B, generator_B_A, input_dim)
composite_B_A = composite_model(generator_B_A, discriminator_A, generator_A_B, input_dim)

This one is a hard nut to crack.

1. composite_model() helps us build the cyclic connection.

2. This function takes in both generators and the discriminator that judges the 1st generator's output domain. Only the 1st generator is trainable within the composite model.

Take the composite_A_B model as an example:

  • Adversarial loss:

Input(Summer) →Generator_A_B(Summer →Winter) →Discriminator_B(Winter or not) →d_out

  • Identity loss:

Input(Winter) →Generator_A_B →g1_out_id

  • Cycle loss (Forward cycle)

Input(Summer) →Generator_A_B →Generator_B_A →g2_out

  • Cycle loss (Backward cycle)

Input(Winter) →Generator_B_A →Generator_A_B →output_g1

3. The model is defined with

  • Inputs: Input_A (Summer), Input_B (Winter)
  • Outputs: d_out, g1_out_id, g2_out, output_g1

MSE is used for the adversarial term & MAE for the other three, with loss_weights of 1, 5, 10 and 10 respectively, so the cycle terms are weighted most heavily.

Similarly, we define composite_B_A to learn the Winter →Summer translation.

7. Generating samples & labels for training

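def generate_real(dataset, batch_size, patch_size):
    # Real images pass through unchanged; every patch is labelled 1 ("real")
    labels = np.ones((batch_size, patch_size, patch_size, 1))
    return dataset, labels

def generate_fake(dataset, g, batch_size, patch_size):
    # Translate the batch with generator g; every patch is labelled 0 ("fake")
    predicted = g(dataset)
    labels = np.zeros((batch_size, patch_size, patch_size, 1))
    return predicted, labels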

This code block generates real & fake samples during training, along with the labels to assign (1 for real & 0 for fake). The only thing to note is that the labels aren't single values but a 2d array per image, because of the PatchGAN idea used in the discriminator. With the discriminator defined above (four unpadded 3x3 convolutions), a 128x128 image produces a 120x120 output map, so patch_size is 120; the training call reads it from discriminator_A.output_shape[1].

8. We must checkpoint the models we train as training CycleGAN takes a good chunk of time

checkpoint_dir = './cyclegan'
checkpoint = tf.train.Checkpoint(generator_A_B=generator_A_B, generator_B_A=generator_B_A,
                                 discriminator_A=discriminator_A, discriminator_B=discriminator_B,
                                 composite_A_B=composite_A_B, composite_B_A=composite_B_A)
# Keep the 3 most recent checkpoints inside checkpoint_dir
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
# Resume from the latest checkpoint if there is one (no-op on the first run)
checkpoint.restore(manager.latest_checkpoint)

This is similar to the checkpointing mechanism explained in the previous post.

9. Now comes the training !!

def train(discriminator_A, discriminator_B, generator_A_B, generator_B_A,
          composite_A_B, composite_B_A, epochs, batch_size, steps, n_patch):
    for epoch in range(epochs):
        for step in range(steps):
            # Real batches from each domain, labelled 1 on every patch
            x_real_A, y_real_A = generate_real(next(trainA), batch_size, n_patch)
            x_real_B, y_real_B = generate_real(next(trainB), batch_size, n_patch)

            # Translated (fake) batches, labelled 0 on every patch
            x_fake_A, y_fake_A = generate_fake(x_real_B, generator_B_A, batch_size, n_patch)
            x_fake_B, y_fake_B = generate_fake(x_real_A, generator_A_B, batch_size, n_patch)

            # Update generator_A_B (via composite_A_B), then discriminator_A
            g_A_B_loss, _, _, _, _ = composite_A_B.train_on_batch(
                [x_real_A, x_real_B], [y_real_B, x_real_B, x_real_A, x_real_B])
            disc_A_real_loss = discriminator_A.train_on_batch(x_real_A, y_real_A)
            disc_A_fake_loss = discriminator_A.train_on_batch(x_fake_A, y_fake_A)

            # Update generator_B_A (via composite_B_A), then discriminator_B
            g_B_A_loss, _, _, _, _ = composite_B_A.train_on_batch(
                [x_real_B, x_real_A], [y_real_A, x_real_A, x_real_B, x_real_A])
            disc_B_real_loss = discriminator_B.train_on_batch(x_real_B, y_real_B)
            disc_B_fake_loss = discriminator_B.train_on_batch(x_fake_B, y_fake_B)

            print(epoch, step, 'g_A_B_loss', g_A_B_loss, 'g_B_A_loss', g_B_A_loss)

        manager.save()  # checkpoint at the end of every epoch

train(discriminator_A, discriminator_B, generator_A_B, generator_B_A,
      composite_A_B, composite_B_A, epochs, n_batch, steps, discriminator_A.output_shape[1])

Let’s go through this

  • For every step in an epoch, real & fake samples (the fakes produced by running a generator over the other domain's images) are generated for both domains, along with labels
  • The two composite models alongside the discriminators are trained over these samples

Considering composite_A_B.train_on_batch(), the targets are passed in the same order as the model outputs:

  1. y_real_B (ground truth for the adversarial term)
  2. x_real_B (GT for the identity term)
  3. x_real_A (GT for the forward cycle term)
  4. x_real_B (GT for the backward cycle term)
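Because composite_A_B and composite_B_A share the same output order, their targets follow one pattern. A hypothetical helper such as composite_targets (not part of the original code) makes that pattern explicit and guards against mixing up the list:

```python
def composite_targets(real_src, real_dst, real_labels_dst):
    """Targets for a composite model with outputs [d_out, g1_out_id, g2_out, output_g1]."""
    return [
        real_labels_dst,  # adversarial: the target-domain discriminator should say "real"
        real_dst,         # identity: a target-domain image should come out unchanged
        real_src,         # forward cycle: src -> dst -> src
        real_dst,         # backward cycle: dst -> src -> dst
    ]


# Usage inside train(), e.g.:
# composite_A_B.train_on_batch([x_real_A, x_real_B], composite_targets(x_real_A, x_real_B, y_real_B))
# composite_B_A.train_on_batch([x_real_B, x_real_A], composite_targets(x_real_B, x_real_A, y_real_A))
```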

Once training is done, we may see how our generators perform over sample test data.

Summer to Winter

testA, testB = tf_data('test')
x_real_A, _ = generate_real(next(testA), n_batch, 0)
images_B, _ = generate_fake(x_real_A, generator_A_B, n_batch, 0)
fig, ax = plt.subplots(n_batch, figsize=(75, 75))
for index, img in enumerate(zip(x_real_A, images_B)):
    # Original | translated side by side; clip the linear generator output to [0, 1]
    concat_numpy = np.clip(np.hstack((img[0], img[1])), 0, 1)
    ax[index].imshow(concat_numpy)
fig.tight_layout()

Left (Summer, original), right (Winter, translated)

Similarly, Winter to Summer

x_real_B, _ = generate_real(next(testB), n_batch, 0)
images_A, _ = generate_fake(x_real_B, generator_B_A, n_batch, 0)
fig, ax = plt.subplots(n_batch, figsize=(75, 75))
for index, img in enumerate(zip(x_real_B, images_A)):
    # Clip to [0, 1] as above so imshow gets a valid float image
    concat_numpy = np.clip(np.hstack((img[0], img[1])), 0, 1)
    ax[index].imshow(concat_numpy)
fig.tight_layout()

Left (Winter, original), right (Summer, translated)

A few observations about this entire setup & the end results

  • The end results look decent, though there is plenty of room for improvement: 1) the networks could be made more complex than a basic CNN, 2) some image preprocessing could be added, 3) these networks were trained for only a few epochs due to hardware constraints; with more rigorous training, the results should improve.
  • Looking at the results, I observe 1) some unnecessary patches in the translated images, 2) some fadedness/greyness/dullness added to the Summer →Winter translated images, giving them the 'Winter touch', 3) similarly, for Winter →Summer, warmth/orangeness added to the translated images, giving them the 'Summer touch'.
  • A common misunderstanding is that, when converting Winter →Summer, the snow should melt from the mountains, or for Summer →Winter, snow should appear on the objects present. Remember that the content/entities don't change during image translation. So, if a mountain has snow in the Winter image, it will still have it after translation to Summer; only the style of the image changes, keeping everything else intact.

A big thanks to Machine Learning Mastery for assistance with the code part.

And with this, it's a wrap. I'll be back next time with InfoGAN.
