Image to Image Translation: GAN and Conditional GAN
Many computer vision and image processing problems require translating an input image into a corresponding output image. CNNs have become the common workhorse behind many such image prediction problems. A CNN learns to minimize a loss function, and although the learning process itself is automatic, a lot of manual effort still goes into designing an effective loss. In simple terms, we still need to tell the CNN what we want it to minimize.
For example, if we ask the CNN to minimize the Euclidean distance between predicted and ground-truth pixels, it will tend to produce blurry images: the Euclidean distance is minimized by averaging all plausible outputs, which causes blurring.
It would therefore be great if we could instead tell the network to "make the output indistinguishable from real images" and then automatically learn a loss function that satisfies this goal. This is exactly what GANs (Generative Adversarial Networks) and conditional GANs achieve. To understand the conditional GAN (cGAN), you first need to understand the standard GAN.
Generative Adversarial Networks:
Generative Adversarial Networks belong to the family of generative models, which means they are able to generate new content. Generative modeling is an unsupervised learning task that involves automatically discovering and learning the patterns in input data, in such a way that the model can generate new examples that could plausibly have been drawn from the original dataset. GANs are a clever way of training a generative model by framing the problem as a supervised learning problem with two sub-models: a generator and a discriminator.
These networks not only map input images to output images, but also learn a loss function to train this mapping. A GAN consists of a generator and a discriminator: the generator tries to produce an image that is indistinguishable from a real one, while the discriminator tries to detect whether a given image is real or fake.
The input to the generator is a vector of randomly generated numbers called a latent sample. To train the generator, we need to train the GAN network as a whole. But before talking about the GAN, let us look at the generator and the discriminator individually.
The discriminator classifies whether an image is real (1) or fake (0). It is trained on both images from the original dataset and images produced by the generator: if the input comes from the original dataset, the discriminator should classify it as real, and if it comes from the generator, it should classify it as fake.
To build the GAN, we combine the generator and the discriminator, with the generator placed on top of the discriminator. When a latent sample is fed into the GAN, the generator produces an image, which is then passed to the discriminator for classification. If the generator has succeeded in producing a convincing image, the discriminator assigns it a high probability of being real.
The generator initially produces garbage images and the loss value is very high, so back-propagation updates the generator's weights to produce more realistic images as training continues. In the image-to-image setting, the generator is updated via a weighted sum of the adversarial loss and an L1 loss, with the L1 loss weighted 100 to 1. This strongly encourages the generator toward plausible translations of the input image, and not just plausible images in the target domain.
Note: while training the generator through the GAN, we do not want the discriminator's weights to be updated, because we are using the discriminator purely as a classifier. So we set the discriminator to non-trainable during generator training.
But we also need to train the discriminator so that it learns to classify inputs as real or fake. So training takes place in a loop:
- First, set the discriminator to trainable.
- Train the discriminator on real images from the dataset and on images produced by the generator, labelled real and fake respectively.
- Then set the discriminator to non-trainable.
- Feed latent samples into the GAN, let the generator produce images, and use the discriminator to classify them. The discriminator's feedback is then back-propagated to update the generator's weights (a minimal sketch of one such iteration is shown below).
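As a minimal sketch of one iteration of this loop for a plain (unconditional) GAN, assuming hypothetical compiled Keras models d_model (the discriminator), g_model (the generator) and gan (the generator stacked on the frozen discriminator), it could look like this:

import numpy as np

def train_step(d_model, g_model, gan, real_images, latent_dim):
    batch_size = real_images.shape[0]
    # Train the discriminator: real images labelled 1, generated images labelled 0.
    # (d_model was compiled as trainable; inside the compiled gan it is frozen.)
    latent = np.random.randn(batch_size, latent_dim)
    fake_images = g_model.predict(latent)
    d_model.train_on_batch(real_images, np.ones((batch_size, 1)))
    d_model.train_on_batch(fake_images, np.zeros((batch_size, 1)))
    # Train the generator through the stacked GAN: generated images are labelled
    # "real" so the gradient pushes the generator toward fooling the discriminator.
    gan.train_on_batch(latent, np.ones((batch_size, 1)))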
Below is the objective function of the GAN, where G tries to minimize the objective against an adversarial D that tries to maximize it. In other words, the generator tries to produce images that are most likely to be classified as real, while the discriminator tries to detect the generator's images as fake. Equilibrium is reached when the discriminator can no longer reliably distinguish generated images from real ones, i.e. when the generator has succeeded in fooling the discriminator. Let's look at the cGAN now.
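In standard notation, this minimax objective reads:

\min_G \max_D \; \mathbb{E}_{y \sim p_{data}(y)}[\log D(y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]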
Conditional GAN:
Just as a GAN learns a generative model, a conditional GAN learns a conditional generative model. This makes the cGAN suitable for image-to-image translation, where we condition on an input image and generate a corresponding output image. Everything works as in the standard GAN, except that both networks receive an extra input condition: the generator generates an image based on it, and the discriminator classifies the generated image together with it.
The objective function is given below; again G tries to minimize the objective against an adversarial D that tries to maximize it, and an L1 reconstruction term is added. The final objective is:
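Written out as in the pix2pix paper (with lambda = 100 in this implementation):

\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}[\log D(x, y)] + \mathbb{E}_{x,z}[\log(1 - D(x, G(x, z)))]

\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[\lVert y - G(x, z) \rVert_1]

G^{*} = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)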
Without z, the network would still learn a mapping from x to y, but it would produce deterministic outputs and therefore fail to match any distribution other than a delta function. As in the pix2pix paper, noise is provided here only in the form of dropout, applied at both training and test time (hence the training=True argument on the dropout layers in the decoder below).
Here, I would like to show you the image-to-image translation architecture described in this paper. It is composed of a generator and a discriminator network: the generator is a U-Net model with skip connections, and the discriminator is a convolutional PatchGAN classifier.
Both the generator and the discriminator use modules of the form Convolution-BatchNormalization-ReLU. All convolutions use 4x4 spatial filters with a 2x2 stride. Convolutions in the encoder and in the discriminator down-sample by a factor of 2, while those in the decoder up-sample by a factor of 2.
The generator is a U-Net with skip connections between the encoder and the decoder. After the last layer in the decoder, a convolution is applied to map to the required number of output channels, followed by a Tanh activation. BatchNormalization is not applied to the first layer of the encoder. All ReLUs in the encoder are leaky, with a slope of 0.2, while the ReLUs in the decoder are not leaky.
The discriminator is a 70x70 PatchGAN. After its last layer, a convolution is applied to map to a 1-dimensional output, followed by a sigmoid activation. BatchNormalization is not applied to the first layer. All ReLUs are leaky with a slope of 0.2.
Data set:
Here you can find a list of datasets that can be used for image-to-image translation with a cGAN. I am using the Facades dataset here. Now let's look at the implementation.
To start with, we need to prepare the dataset. The original dataset contains the source and target images concatenated side by side, so we need to separate them. The function below splits each image and stores the source and target images in separate arrays (the first 256 columns of pixels contain the source image, and the next 256 columns contain the target image).
from os import listdir
from numpy import asarray
from keras.preprocessing.image import load_img, img_to_array  # keras.utils in newer Keras versions

def load_images(path, size=(256,512)):
    src_img, targ_img = list(), list()
    for filename in listdir(path):
        # Each file contains the source and target images side by side (256 columns each)
        image = load_img(path + '/' + filename, target_size=size)
        pixels = img_to_array(image)
        edge_image, real_image = pixels[:, :256], pixels[:, 256:]
        src_img.append(edge_image)
        targ_img.append(real_image)
    return [asarray(src_img), asarray(targ_img)]
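For example, assuming the Facades training images have been extracted into a local folder (hypothetically named 'facades/train/' here), the pairs can be loaded and scaled to [-1, 1], which matches the tanh output of the generator defined later:

# 'facades/train/' is a placeholder path; adjust it to wherever the dataset was extracted
src_images, tar_images = load_images('facades/train/')
src_images = (src_images - 127.5) / 127.5
tar_images = (tar_images - 127.5) / 127.5
print(src_images.shape, tar_images.shape)  # (N, 256, 256, 3) for both arrays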
Next, we create the models for the discriminator and the generator as explained earlier.
Discriminator:
import keras
from keras.layers import LeakyReLU
from keras.initializers import RandomNormal

def discriminator(image_shape):
    init = RandomNormal(stddev=0.02)
    in_src_image = keras.layers.Input(shape=image_shape)
    in_target_image = keras.layers.Input(shape=image_shape)
    merged = keras.layers.Concatenate()([in_src_image, in_target_image])
    d = keras.layers.Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    d = keras.layers.Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = keras.layers.BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = keras.layers.Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = keras.layers.BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = keras.layers.Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = keras.layers.BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = keras.layers.Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = keras.layers.BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # Patch output: one real/fake prediction per patch of the input
    d = keras.layers.Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = keras.layers.Activation('sigmoid')(d)
    model = keras.models.Model([in_src_image, in_target_image], patch_out)
    opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model
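As a quick sanity check, the discriminator can be built and inspected. The 0.5 loss weight halves the discriminator updates relative to the generator, as recommended in the pix2pix paper, and the output is a map of patch predictions rather than a single scalar:

d_model = discriminator((256, 256, 3))
d_model.summary()  # output shape (16, 16, 1): one real/fake prediction per patch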
Generator:
def encoder_block(input_tensor, n_filters, kernel_size=4, batchnorm=True):
    init = RandomNormal(stddev=0.02)
    # Down-sample by a factor of 2 with a strided convolution
    x = keras.layers.Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size), strides=(2,2), padding='same', kernel_initializer=init)(input_tensor)
    if batchnorm:
        x = keras.layers.BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)
    return x

def decoder_block(input_tensor, skip_layer, n_filters, dropout=True):
    init = RandomNormal(stddev=0.02)
    # Up-sample by a factor of 2 with a transposed convolution
    x = keras.layers.Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding="same", kernel_initializer=init)(input_tensor)
    x = keras.layers.BatchNormalization()(x, training=True)
    if dropout:
        # Dropout applied at both training and test time (the only source of noise)
        x = keras.layers.Dropout(0.5)(x, training=True)
    x = keras.layers.Concatenate()([x, skip_layer])
    x = keras.layers.Activation("relu")(x)
    return x

def generator(input_img=(256,256,3), batchnorm=True, dropout=True):
    init = RandomNormal(stddev=0.02)
    image = keras.layers.Input(shape=input_img)
    # Contracting path (encoder)
    c1 = encoder_block(image, 64, kernel_size=4, batchnorm=False)
    c2 = encoder_block(c1, 128, kernel_size=4, batchnorm=batchnorm)
    c3 = encoder_block(c2, 256, kernel_size=4, batchnorm=batchnorm)
    c4 = encoder_block(c3, 512, kernel_size=4, batchnorm=batchnorm)
    c5 = encoder_block(c4, 512, kernel_size=4, batchnorm=batchnorm)
    c6 = encoder_block(c5, 512, kernel_size=4, batchnorm=batchnorm)
    c7 = encoder_block(c6, 512, kernel_size=4, batchnorm=batchnorm)
    # Bottleneck
    m = keras.layers.Conv2D(512, (4,4), strides=(2,2), padding="same", kernel_initializer=init)(c7)
    m = keras.layers.Activation("relu")(m)
    # Expansive path (decoder) with skip connections to the matching encoder blocks
    d1 = decoder_block(m, c7, 512, dropout=dropout)
    d2 = decoder_block(d1, c6, 512, dropout=dropout)
    d3 = decoder_block(d2, c5, 512, dropout=dropout)
    d4 = decoder_block(d3, c4, 512, dropout=False)
    d5 = decoder_block(d4, c3, 256, dropout=False)
    d6 = decoder_block(d5, c2, 128, dropout=False)
    d7 = decoder_block(d6, c1, 64, dropout=False)
    # Map to 3 output channels with a tanh activation
    outputs = keras.layers.Conv2DTranspose(3, (4, 4), strides=(2,2), padding="same", kernel_initializer=init)(d7)
    outputs = keras.layers.Activation("tanh")(outputs)
    model = keras.models.Model(inputs=[image], outputs=[outputs])
    return model
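The generator can be checked in the same way; its output has the same spatial size as its input:

g_model = generator((256, 256, 3))
g_model.summary()  # final layer produces a 256x256x3 image through a tanh activation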
Next, combining the generator and the discriminator, we define the GAN model by placing the generator on top of the discriminator. Since we don't want to update the discriminator's weights while training the generator, we set discriminator.trainable to False.
def gan_model(generator, discriminator, input_img):
    # Freeze the discriminator's weights while the generator is trained through the GAN
    discriminator.trainable = False
    src_input = keras.layers.Input(shape=input_img)
    gen_output = generator(src_input)
    disc_output = discriminator([src_input, gen_output])
    model = keras.models.Model(inputs=src_input, outputs=[disc_output, gen_output])
    opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    # Adversarial loss (binary cross-entropy) plus L1 loss (MAE), weighted 1:100
    model.compile(loss=["binary_crossentropy", "mae"], optimizer=opt, loss_weights=[1, 100])
    return model
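The training loop itself is not shown here, but based on the description above it can be sketched roughly as follows (a minimal sketch, assuming the load_images, discriminator, generator and gan_model functions defined earlier and the placeholder 'facades/train/' folder; the patch label shape follows from the discriminator's output size):

import numpy as np

# Build the three models (the discriminator is frozen inside the stacked GAN)
image_shape = (256, 256, 3)
d_model = discriminator(image_shape)
g_model = generator(image_shape)
gan = gan_model(g_model, d_model, image_shape)

# Load the Facades pairs and scale them to [-1, 1] (as in the snippet above)
src_images, tar_images = load_images('facades/train/')
src_images = (src_images - 127.5) / 127.5
tar_images = (tar_images - 127.5) / 127.5

n_epochs, n_batch = 100, 1
n_patch = d_model.output_shape[1]  # side length of the PatchGAN output map
n_steps = (len(src_images) // n_batch) * n_epochs

for step in range(n_steps):
    # Pick a random (source, target) pair and label its patches as real (1)
    ix = np.random.randint(0, len(src_images), n_batch)
    x_real, y_real = src_images[ix], tar_images[ix]
    labels_real = np.ones((n_batch, n_patch, n_patch, 1))
    # Generate a fake target image and label its patches as fake (0)
    y_fake = g_model.predict(x_real)
    labels_fake = np.zeros((n_batch, n_patch, n_patch, 1))
    # Update the discriminator on real and generated pairs
    d_model.train_on_batch([x_real, y_real], labels_real)
    d_model.train_on_batch([x_real, y_fake], labels_fake)
    # Update the generator through the GAN: adversarial loss plus L1 loss toward the target
    gan.train_on_batch(x_real, [labels_real, y_real])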
These models are trained batch by batch for 100 epochs. Predictions and weights are recorded every 10 epochs (every 4,000 steps). Once training is complete, we compare the predicted images with the target images, and the model that gives the best result is saved.
In this case, we were able to achieve the best results after 60 epochs, i.e. after 24,000 steps, as shown in the figure below.
The first row shows the source images, the second row the images produced by the generator, and the third row the target images.
We can see that there isn't much difference between the generated and target images, which shows that the generator succeeded in fooling the discriminator by creating realistic images.
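A figure like the one described above can be produced with a short helper along these lines (a sketch using matplotlib; the [-1, 1] outputs are rescaled to [0, 1] for display):

import matplotlib.pyplot as plt

def plot_results(g_model, src, tar, n_samples=3):
    gen = g_model.predict(src[:n_samples])
    # Rows: source, generated, target; rescale from [-1, 1] to [0, 1]
    rows = [(src[:n_samples] + 1) / 2, (gen + 1) / 2, (tar[:n_samples] + 1) / 2]
    for r in range(3):
        for c in range(n_samples):
            plt.subplot(3, n_samples, r * n_samples + c + 1)
            plt.axis('off')
            plt.imshow(rows[r][c])
    plt.show()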
Conclusion:
Conditional Generative Adversarial Networks are more promising than standard Generative Adversarial Networks for image-to-image translation tasks, especially those involving highly structured graphical outputs.
Here is the link to the complete source code.