Implementing SRResnet/SRGAN Super-Resolution with Tensorflow

Sieun Park · Published in Analytics Vidhya · Mar 17, 2021 · 7 min read

Original Paper: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

Introduction

The paper proposes a residual-block-based neural network for super-resolving images, together with a VGG feature loss that improves on the plain MSE loss, which often fails to encourage fine detail in the super-resolved output. The SRGAN method from the paper additionally trains the model with an adversarial loss alongside this content loss to further improve perceptual reconstruction quality.

We summarized the concepts and methods of the paper in a previous post[2]. In this post, we will implement the network architecture, losses, and training procedure of the methods proposed in the paper. The complete code used in this post can be viewed in the COLAB notebook linked in the references.

Loading data

The paper trained its networks on random crops from the well-known ImageNet image recognition dataset. Although training on large amounts of data is beneficial, ImageNet proved too heavy for this experiment, so I used the tf_flowers dataset instead. Its 3,670 images might seem too small, but they are just enough for a toy dataset to evaluate and compare the performance of each training method from the paper.

import tensorflow as tf
import tensorflow_datasets as tfds

data = tfds.load('tf_flowers')
train_data = data['train'].skip(600)
test_data = data['train'].take(600)

We use the `tensorflow_datasets` module for loading the tf_flowers dataset and take the first 600 images as a validation dataset.

import matplotlib.pyplot as plt

@tf.function
def build_data(data):
    # Random 128x128 HR crop, scaled to [0, 1]
    cropped = tf.dtypes.cast(tf.image.random_crop(data['image'] / 255, (128, 128, 3)), tf.float32)
    # 32x32 low-resolution copy (4x downscaling)
    lr = tf.image.resize(cropped, (32, 32))
    # The HR target is rescaled to [-1, 1] to match the generator's tanh output
    return (lr, cropped * 2 - 1)

train_dataset_mapped = train_data.map(build_data, num_parallel_calls=tf.data.AUTOTUNE)
for x in train_dataset_mapped.take(1):
    plt.imshow(x[0].numpy())                                    # low-resolution input
    plt.show()
    plt.imshow(bicubic_interpolate(x[0].numpy(), (128, 128)))   # bicubic baseline (helper defined in the notebook)
    plt.show()
    plt.imshow((x[1].numpy() + 1) / 2)                          # HR target, mapped back to [0, 1] for display
    plt.show()

We then define a function that maps each image in the dataset to a random (128, 128) crop and a (32, 32) low-resolution copy of it; the high-resolution target is rescaled to [-1, 1] to match the generator's tanh output. We apply this function to our dataset with train_data.map(build_data, …), and it is re-applied before every training epoch so that fresh random crops are drawn each time.
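To make the shapes and value ranges explicit, a quick inspection of one mapped element (a sketch reusing the pipeline above) looks like this:

# Each mapped element is an (LR, HR) pair of unbatched tensors.
for lr, hr in train_dataset_mapped.take(1):
    print(lr.shape, lr.dtype)   # (32, 32, 3) float32, values in [0, 1]
    print(hr.shape, hr.dtype)   # (128, 128, 3) float32, values in [-1, 1]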

Model Definition

def residual_block_gen(ch=64, k_s=3, st=1):
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
    ])
    return model

def Upsample_block(x, ch=256, k_s=3, st=1):
    x = tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding='same')(x)
    x = tf.nn.depth_to_space(x, 2)  # sub-pixel convolution (pixel shuffler), 2x upscaling
    x = tf.keras.layers.LeakyReLU()(x)
    return x

input_lr = tf.keras.layers.Input(shape=(None, None, 3))
input_conv = tf.keras.layers.Conv2D(64, 9, padding='same')(input_lr)
input_conv = tf.keras.layers.LeakyReLU()(input_conv)

SRRes = input_conv
for x in range(5):
    res_output = residual_block_gen()(SRRes)
    SRRes = tf.keras.layers.Add()([SRRes, res_output])  # element-wise sum skip connection

SRRes = tf.keras.layers.Conv2D(64, 9, padding='same')(SRRes)
SRRes = tf.keras.layers.BatchNormalization()(SRRes)
SRRes = tf.keras.layers.Add()([SRRes, input_conv])  # long skip connection from the first conv
SRRes = Upsample_block(SRRes)
SRRes = Upsample_block(SRRes)  # two 2x upsampling blocks -> 4x super-resolution
output_sr = tf.keras.layers.Conv2D(3, 9, activation='tanh', padding='same')(SRRes)

SRResnet = tf.keras.models.Model(input_lr, output_sr)

We define the generator architecture using TensorFlow. residual_block_gen builds a single residual block, and element-wise sum skip connections are added around each block. Five residual blocks are chained, and the feature maps are then upsampled 4x through the pixel-shuffler method implemented in Upsample_block. Because the network is fully convolutional, we do not have to fix the input shape, so the model can process images of any size.
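As a quick sanity check, we can pass dummy inputs of different sizes through the model; a sketch like the following should confirm the arbitrary input size and the 4x upscaling:

# The fully convolutional generator accepts any spatial size and upscales by 4x.
dummy_small = tf.random.uniform((1, 32, 32, 3))
dummy_large = tf.random.uniform((1, 48, 64, 3))
print(SRResnet(dummy_small).shape)  # expected: (1, 128, 128, 3)
print(SRResnet(dummy_large).shape)  # expected: (1, 192, 256, 3)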

def residual_block_disc(ch=64, k_s=3, st=1):
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(ch, k_s, strides=(st, st), padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
    ])
    return model

input_lr = tf.keras.layers.Input(shape=(128, 128, 3))
input_conv = tf.keras.layers.Conv2D(64, 3, padding='same')(input_lr)
input_conv = tf.keras.layers.LeakyReLU()(input_conv)

# Seven conv blocks with increasing channel counts and alternating strides
channel_nums = [64, 128, 128, 256, 256, 512, 512]
stride_sizes = [2, 1, 2, 1, 2, 1, 2]

disc = input_conv
for x in range(7):
    disc = residual_block_disc(ch=channel_nums[x], st=stride_sizes[x])(disc)

disc = tf.keras.layers.Flatten()(disc)
disc = tf.keras.layers.Dense(1024)(disc)
disc = tf.keras.layers.LeakyReLU()(disc)
disc_output = tf.keras.layers.Dense(1, activation='sigmoid')(disc)

discriminator = tf.keras.models.Model(input_lr, disc_output)

The discriminator architecture is also implemented following the paper's specification. It is a conventional CNN that takes an image as input and predicts whether it is a real high-resolution image or a super-resolved one.
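A similar sketch confirms that the discriminator maps a batch of 128x128 images to one real/fake probability per image:

# The sigmoid output is a probability of the input being a real HR image.
dummy = tf.random.uniform((4, 128, 128, 3))
print(discriminator(dummy).shape)  # expected: (4, 1)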

Loss Implementation

def PSNR(y_true, y_pred):
    # PSNR = 20 * log10(MAX / RMSE), with the peak value MAX assumed to be 1
    mse = tf.reduce_mean((y_true - y_pred) ** 2)
    return 20 * log10(1 / (mse ** 0.5))

def log10(x):
    numerator = tf.math.log(x)
    denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator

def pixel_MSE(y_true, y_pred):
    return tf.reduce_mean((y_true - y_pred) ** 2)

We define the pixel-wise MSE loss and the PSNR metric for training and evaluation. These loss formulations are explained more in the previous post on the concepts of this paper.
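For reference, the metric above follows the usual PSNR = 20·log10(MAX/√MSE) definition with a peak value of 1 assumed; for a single image it should agree with TensorFlow's built-in tf.image.psnr, as in this quick sketch:

# Sanity check: compare the hand-written PSNR against tf.image.psnr (max_val=1.0).
a = tf.random.uniform((1, 128, 128, 3))
b = tf.random.uniform((1, 128, 128, 3))
print(PSNR(a, b).numpy())
print(tf.image.psnr(a, b, max_val=1.0).numpy())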

VGG19 = tf.keras.applications.VGG19(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
VGG_i, VGG_j = 2, 2

def VGG_loss(y_hr, y_sr, i_m=2, j_m=2):
    i, j = 0, 0
    accumulated_loss = 0.0
    for l in VGG19.layers:
        cl_name = l.__class__.__name__
        if cl_name == 'Conv2D':
            j += 1
        if cl_name == 'MaxPooling2D':
            i += 1
            j = 0
        if i == i_m and j == j_m:
            break
        y_hr = l(y_hr)
        y_sr = l(y_sr)
        if cl_name == 'Conv2D':
            # 0.006 rescales the VGG feature loss to a range comparable with the MSE loss
            accumulated_loss += tf.reduce_mean((y_hr - y_sr) ** 2) * 0.006
    return accumulated_loss

def VGG_loss_old(y_true, y_pred):
    accumulated_loss = 0.0
    for l in VGG19.layers:
        y_true = l(y_true)
        y_pred = l(y_pred)
        accumulated_loss += tf.reduce_mean((y_true - y_pred) ** 2) * 0.006
    return accumulated_loss

The VGG loss proposed in the paper compares intermediate activations of a pre-trained VGG-19 network for the ground-truth and predicted images. We forward-propagate both images through the VGG layers one by one and accumulate the MSE between the intermediate Conv2D outputs along the way. VGG_loss_old is the intuitive version that runs through the entire network, while VGG_loss stops at the layer selected by i_m and j_m.
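Note that the paper's φ_{i,j} loss compares the features of a single VGG layer, whereas the loop above accumulates the MSE at every Conv2D output it passes. If you want the single-layer variant, a minimal sketch using a truncated Keras model (assuming the standard tf.keras VGG19 layer names such as 'block2_conv2') could look like this:

# Truncated VGG19 that outputs the 'block2_conv2' activation (roughly the VGG 2.2 features).
vgg_extractor = tf.keras.models.Model(
    inputs=VGG19.input,
    outputs=VGG19.get_layer('block2_conv2').output)

def VGG_loss_single(y_hr, y_sr):
    # Compare only the chosen layer's feature maps; 0.006 keeps the scale
    # comparable to the pixel-wise MSE loss, as in the loop-based version.
    return tf.reduce_mean((vgg_extractor(y_hr) - vgg_extractor(y_sr)) ** 2) * 0.006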

cross_entropy = tf.keras.losses.BinaryCrossentropy()

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

The adversarial losses are defined as above. The code for the adversarial training procedure is mainly adapted from the TensorFlow DCGAN tutorial[3].
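As a tiny usage example with made-up discriminator probabilities (purely illustrative values), a confident, mostly correct discriminator yields a small discriminator_loss while the generator_loss stays large:

real_output = tf.constant([[0.9], [0.8]])  # scores for real HR images
fake_output = tf.constant([[0.2], [0.3]])  # scores for generated SR images
print(discriminator_loss(real_output, fake_output).numpy())  # small: discriminator is mostly right
print(generator_loss(fake_output).numpy())                   # larger: generator is being caught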

Training

generator_optimizer = tf.keras.optimizers.Adam(0.001)
discriminator_optimizer = tf.keras.optimizers.Adam(0.001)
adv_ratio = 0.001
evaluate = ['PSNR']

# Possible configurations (the last assignment is the one that takes effect):
# MSE
loss_func, adv_learning = pixel_MSE, False
# VGG 2.2
loss_func, adv_learning = lambda y_hr, y_sr: VGG_loss(y_hr, y_sr, i_m=2, j_m=2), False
# VGG 5.4
loss_func, adv_learning = lambda y_hr, y_sr: VGG_loss(y_hr, y_sr, i_m=5, j_m=4), False
# SRGAN-MSE
loss_func, adv_learning = pixel_MSE, True
# SRGAN-VGG 2.2
loss_func, adv_learning = lambda y_hr, y_sr: VGG_loss(y_hr, y_sr, i_m=2, j_m=2), True
# SRGAN-VGG 5.4
loss_func, adv_learning = lambda y_hr, y_sr: VGG_loss(y_hr, y_sr, i_m=5, j_m=4), True

# Loss actually used for this run
loss_func, adv_learning = pixel_MSE, False

We first define the optimizers, hyperparameters, and the loss function for the model to optimize. The snippet lists the loss configurations proposed in the paper; since each assignment overwrites the previous one, keep only the configuration you want to train (here, plain MSE).

The training step below is based on the TensorFlow DCGAN tutorial and generalizes over all the loss configurations above; adversarial training is only performed when adv_learning=True. We super-resolve the image with the generator, measure the reconstruction loss with the chosen metric, and record the gradients on a GradientTape. If the following code seems overly complicated, I strongly recommend having a look at the DCGAN tutorial.

@tf.function()
def train_step(data, loss_func=pixel_MSE, adv_learning=True, evaluate=['PSNR'], adv_ratio=0.001):
    logs = {}
    gen_loss, disc_loss = 0, 0
    low_resolution, high_resolution = data
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        super_resolution = SRResnet(low_resolution, training=True)
        gen_loss = loss_func(high_resolution, super_resolution)  # reconstruction (content) loss
        logs['reconstruction'] = gen_loss
        if adv_learning:
            real_output = discriminator(high_resolution, training=True)
            fake_output = discriminator(super_resolution, training=True)

            adv_loss_g = generator_loss(fake_output) * adv_ratio  # weighted adversarial term
            gen_loss += adv_loss_g

            disc_loss = discriminator_loss(real_output, fake_output)
            logs['adv_g'] = adv_loss_g
            logs['adv_d'] = disc_loss
    gradients_of_generator = gen_tape.gradient(gen_loss, SRResnet.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, SRResnet.trainable_variables))

    if adv_learning:
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    for x in evaluate:
        if x == 'PSNR':
            logs[x] = PSNR(high_resolution, super_resolution)
    return logs

Training is performed by the following code, which iterates over the dataset and calls the train_step function defined above for every batch. As mentioned earlier, the images are re-cropped before every epoch.

import tqdm

for epoch in range(50):
    # Re-map so that new random crops are drawn every epoch
    train_dataset_mapped = train_data.map(build_data, num_parallel_calls=tf.data.AUTOTUNE).batch(128)
    val_dataset_mapped = test_data.map(build_data, num_parallel_calls=tf.data.AUTOTUNE).batch(128)
    for image_batch in tqdm.tqdm(train_dataset_mapped, position=0, leave=True):
        logs = train_step(image_batch, loss_func, adv_learning, evaluate, adv_ratio)
    # Print the metrics of the last batch of the epoch
    for k in logs.keys():
        print(k, ':', logs[k], end=' ')
    print()
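Note that val_dataset_mapped is built in the loop but not otherwise used; if you want a held-out metric, a minimal sketch that reports the average validation PSNR after an epoch (reusing the PSNR function defined earlier) could look like this:

def validation_psnr(dataset):
    # Average PSNR over the validation set with the current generator weights
    psnr_metric = tf.keras.metrics.Mean()
    for lr_batch, hr_batch in dataset:
        sr_batch = SRResnet(lr_batch, training=False)
        psnr_metric.update_state(PSNR(hr_batch, sr_batch))
    return psnr_metric.result().numpy()

print('val PSNR:', validation_psnr(val_dataset_mapped))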

Evaluation

We visualize some example images super-resolved by the trained models. The first image is the original HR image, the second is the SR output, and the third and fourth are the low-resolution input and its bicubic interpolation. Although the models weren't trained for a long time, we can still compare their behavior: images generated by the models trained with VGG and adversarial losses appear perceptually sharper. Look closely at the reconstructed texture of the wood in the first picture.
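For reference, a comparison figure like the ones described above can be produced along these lines (a sketch; the bicubic baseline here simply uses tf.image.resize with method='bicubic'):

# Side-by-side comparison for one validation image: HR, SR, LR, bicubic.
for lr, hr in val_dataset_mapped.take(1):
    sr = SRResnet(lr, training=False)
    images = [(hr[0] + 1) / 2,                                       # HR target, back to [0, 1]
              (sr[0] + 1) / 2,                                       # super-resolved output
              lr[0],                                                 # low-resolution input
              tf.image.resize(lr[0], (128, 128), method='bicubic')]  # bicubic baseline
    titles = ['HR', 'SR', 'LR', 'Bicubic']
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, img, title in zip(axes, images, titles):
        ax.imshow(tf.clip_by_value(img, 0, 1))
        ax.set_title(title)
        ax.axis('off')
    plt.show()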

I didn't test all of the proposed losses. It would be great if you could share results after training more of the configurations and evaluating them with the code in my COLAB link, or after training on a bigger dataset such as ImageNet. I am also confident the model will perform better with more training epochs: at the current stage of training we can see checkerboard-like artifacts in the reconstructed images, caused by the immature ESPCN (pixel-shuffler) reconstruction layers. These can be reduced with more iterations of training, and even so the model already outperforms the MSE-based model perceptually.

The configurations compared were:

  • SRResNet + MSE
  • SRResNet + VGG 2.2
  • SRResNet + VGG 5.4
  • SRGAN 0.001 + MSE
  • SRGAN 0.001 + VGG 2.2
  • SRGAN 0.001 + VGG 5.4

References

My COLAB implementation of SRResnet/SRGAN: https://colab.research.google.com/drive/15MGvc5h_zkB9i97JJRoy_-qLtoPEU2sp?usp=sharing

[1] Ledig, Christian, et al. “Photo-realistic single image super-resolution using a generative adversarial network.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.

[2] Super Resolution with SRResnet, SRGAN. https://medium.com/analytics-vidhya/super-resolution-with-srresnet-srgan-2859b87c9c7f

[3] Tensorflow DCGAN Tutorial: https://www.tensorflow.org/tutorials/generative/dcgan
