Building Inception-Resnet-V2 in Keras from scratch

Siladittya Manna
Published in The Owl
Apr 10, 2019 · 4 min read
Image taken from yeephycho

Both the Inception and Residual networks are SOTA architectures, which have shown very good performance with relatively low computational cost. Inception-ResNet combines the two architectures to further boost the performance.

Residual Inception blocks

Residual Inception Block (Inception-ResNet-A)
  1. Each Inception block is followed by a filter-expansion layer (a 1 × 1 convolution without activation), which scales up the dimensionality of the filter bank before the addition so that it matches the depth of the block's input.
  2. In Inception-ResNet, batch normalization is used only on top of the traditional layers, but not on top of the summations.

Scaling of Residuals

According to the authors, if the number of filters exceeded 1000, the residual variants started to exhibit instabilities and the network simply "died" early in training, meaning that the last layer before the average pooling started to produce only zeros after a few tens of thousands of iterations. This could not be prevented by lowering the learning rate or by adding an extra batch normalization to this layer.

Scaling down the residuals before adding them to the activations of the previous layer seemed to stabilize the training. Scaling factors between 0.1 and 0.3 were picked for this purpose.
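In Keras, this scaling can be implemented with a Lambda layer that multiplies the residual branch by a constant before the element-wise sum. A minimal sketch (scale_residual is a hypothetical helper name; 0.2 is just one value from the recommended range):

from tensorflow.keras.layers import Lambda

def scale_residual(shortcut, residual, scale=0.2):
    # shrink the residual branch by a constant factor in [0.1, 0.3]
    scaled = Lambda(lambda t: t * scale)(residual)
    # element-wise sum with the shortcut; no batch norm on the summation
    return Lambda(lambda z: z[0] + z[1])([shortcut, scaled])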


Building the Network

Importing the required libraries

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Concatenate, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers, activations
import os
from sklearn.utils import shuffle
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
#import itertools
#import shutil
%matplotlib inline

Image Pre-processing

def random_crop(img, random_crop_size):
    # Note: image_data_format is 'channels_last'
    h, w = img.shape[0], img.shape[1]
    dy, dx = random_crop_size
    x = np.random.randint(0, w - dx + 1)
    y = np.random.randint(0, h - dy + 1)
    return img[y:(y+dy), x:(x+dx), :]

def crop_generator(batches, crop_length):
    """Take as input a Keras ImageGen (Iterator)
    and generate random crops from the image batches
    generated by the original iterator."""
    sz = crop_length
    while True:
        batch_x, batch_y = next(batches)
        batch_crops = np.zeros((batch_x.shape[0], sz, sz, 3))
        for i in range(batch_x.shape[0]):
            batch_crops[i] = random_crop(batch_x[i], (sz, sz))
        yield (batch_crops, batch_y)

tr_datagen = ImageDataGenerator(rescale=1.0/255,
                                horizontal_flip=True,
                                vertical_flip=True)
ts_datagen = ImageDataGenerator(rescale=1.0/255)

train_gen = tr_datagen.flow_from_directory(train_path,
                                           target_size=(IMAGE_SIZE, IMAGE_SIZE),
                                           batch_size=train_batch_size,
                                           class_mode='categorical')
# CROP_LENGTH is the side of the square random crop; define it alongside IMAGE_SIZE
tr_crops = crop_generator(train_gen, CROP_LENGTH)

val_gen = ts_datagen.flow_from_directory(valid_path,
                                         target_size=(IMAGE_SIZE, IMAGE_SIZE),
                                         batch_size=val_batch_size,
                                         class_mode='categorical')
val_crops = crop_generator(val_gen, CROP_LENGTH)

# Note: shuffle=False causes the test dataset to not be shuffled
test_gen = ts_datagen.flow_from_directory(test_path,
                                          target_size=(IMAGE_SIZE, IMAGE_SIZE),
                                          batch_size=1,
                                          class_mode='categorical',
                                          shuffle=False)
# ts_crops = crop_generator(test_gen, CROP_LENGTH)

Convolutional Block
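
Nearly every layer in the network is a convolution followed by batch normalization and, usually, a ReLU, so it is convenient to wrap that pattern in a small helper. A minimal sketch using the layers imported above (the name conv2d_bn and its defaults are my own choices):

def conv2d_bn(x, filters, kernel_size, strides=1,
              padding='same', activation='relu'):
    # convolution -> batch normalization -> optional activation
    x = Conv2D(filters, kernel_size, strides=strides,
               padding=padding, use_bias=False)(x)
    x = BatchNormalization()(x)
    if activation is not None:
        x = Activation(activation)(x)
    return x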

Inception ResNet A block
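
A sketch of the 35 × 35 residual block built on conv2d_bn: three parallel branches are concatenated, expanded back to the input depth by a linear 1 × 1 convolution without batch normalization (points 1 and 2 above), scaled, and summed with the block input. The 0.17 scale follows the reference TF-Slim implementation:

def inception_resnet_a(x, scale=0.17):
    # three parallel branches over the 35x35 grid
    b0 = conv2d_bn(x, 32, 1)
    b1 = conv2d_bn(x, 32, 1)
    b1 = conv2d_bn(b1, 32, 3)
    b2 = conv2d_bn(x, 32, 1)
    b2 = conv2d_bn(b2, 48, 3)
    b2 = conv2d_bn(b2, 64, 3)
    mixed = Concatenate()([b0, b1, b2])
    # linear 1x1 filter expansion back to the input depth (channels_last); no BN
    up = Conv2D(backend.int_shape(x)[-1], 1)(mixed)
    # scaled residual summation followed by ReLU
    x = Lambda(lambda z: z[0] + scale * z[1])([x, up])
    return Activation('relu')(x)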

Inception ResNet B block
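
The 17 × 17 block follows the same pattern with two branches and the 7 × 7 convolution factorized into 1 × 7 and 7 × 1; the reference implementation scales it by 0.1. A sketch:

def inception_resnet_b(x, scale=0.1):
    b0 = conv2d_bn(x, 192, 1)
    b1 = conv2d_bn(x, 128, 1)
    b1 = conv2d_bn(b1, 160, (1, 7))
    b1 = conv2d_bn(b1, 192, (7, 1))
    mixed = Concatenate()([b0, b1])
    up = Conv2D(backend.int_shape(x)[-1], 1)(mixed)  # linear expansion, no BN
    x = Lambda(lambda z: z[0] + scale * z[1])([x, up])
    return Activation('relu')(x)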

Inception ResNet C block
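
The 8 × 8 block mirrors the same structure, factorizing 3 × 3 into 1 × 3 and 3 × 1, with a scale of 0.2 in the reference implementation. A sketch:

def inception_resnet_c(x, scale=0.2):
    b0 = conv2d_bn(x, 192, 1)
    b1 = conv2d_bn(x, 192, 1)
    b1 = conv2d_bn(b1, 224, (1, 3))
    b1 = conv2d_bn(b1, 256, (3, 1))
    mixed = Concatenate()([b0, b1])
    up = Conv2D(backend.int_shape(x)[-1], 1)(mixed)  # linear expansion, no BN
    x = Lambda(lambda z: z[0] + scale * z[1])([x, up])
    return Activation('relu')(x)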

Stem block
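
Inception-ResNet-v2 shares its stem with Inception-v4: it reduces a 299 × 299 × 3 input to a 35 × 35 × 384 feature map through strided convolutions, max-pooling, and branch concatenations. A sketch, again assuming the conv2d_bn helper above:

def stem(img_input):
    x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid')  # 149x149x32
    x = conv2d_bn(x, 32, 3, padding='valid')                     # 147x147x32
    x = conv2d_bn(x, 64, 3)                                      # 147x147x64
    b0 = MaxPooling2D(3, strides=2, padding='valid')(x)
    b1 = conv2d_bn(x, 96, 3, strides=2, padding='valid')
    x = Concatenate()([b0, b1])                                  # 73x73x160
    b0 = conv2d_bn(x, 64, 1)
    b0 = conv2d_bn(b0, 96, 3, padding='valid')
    b1 = conv2d_bn(x, 64, 1)
    b1 = conv2d_bn(b1, 64, (7, 1))
    b1 = conv2d_bn(b1, 64, (1, 7))
    b1 = conv2d_bn(b1, 96, 3, padding='valid')
    x = Concatenate()([b0, b1])                                  # 71x71x192
    b0 = conv2d_bn(x, 192, 3, strides=2, padding='valid')
    b1 = MaxPooling2D(3, strides=2, padding='valid')(x)
    return Concatenate()([b0, b1])                               # 35x35x384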

Inception-ResNet Network
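
The full network stacks the residual blocks in groups, with two grid-reduction modules in between (35 → 17 and 17 → 8), and finishes with global average pooling, dropout (the paper keeps units with probability 0.8) and a softmax classifier. The paper's schema stacks 5, 10 and 5 blocks (the reference TF-Slim implementation uses 10, 20 and 10). A sketch, with NUM_CLASSES as a placeholder for the number of classes in your dataset:

def reduction_a(x):
    # 35x35 -> 17x17; filter sizes (k, l, m, n) = (256, 256, 384, 384)
    b0 = MaxPooling2D(3, strides=2, padding='valid')(x)
    b1 = conv2d_bn(x, 384, 3, strides=2, padding='valid')
    b2 = conv2d_bn(x, 256, 1)
    b2 = conv2d_bn(b2, 256, 3)
    b2 = conv2d_bn(b2, 384, 3, strides=2, padding='valid')
    return Concatenate()([b0, b1, b2])

def reduction_b(x):
    # 17x17 -> 8x8
    b0 = MaxPooling2D(3, strides=2, padding='valid')(x)
    b1 = conv2d_bn(x, 256, 1)
    b1 = conv2d_bn(b1, 384, 3, strides=2, padding='valid')
    b2 = conv2d_bn(x, 256, 1)
    b2 = conv2d_bn(b2, 288, 3, strides=2, padding='valid')
    b3 = conv2d_bn(x, 256, 1)
    b3 = conv2d_bn(b3, 288, 3)
    b3 = conv2d_bn(b3, 320, 3, strides=2, padding='valid')
    return Concatenate()([b0, b1, b2, b3])

img_input = Input(shape=(299, 299, 3))
x = stem(img_input)
for _ in range(5):                       # 5 x Inception-ResNet-A
    x = inception_resnet_a(x, scale=0.17)
x = reduction_a(x)
for _ in range(10):                      # 10 x Inception-ResNet-B
    x = inception_resnet_b(x, scale=0.1)
x = reduction_b(x)
for _ in range(5):                       # 5 x Inception-ResNet-C
    x = inception_resnet_c(x, scale=0.2)
x = GlobalAveragePooling2D()(x)
x = Dropout(0.2)(x)                      # keep probability 0.8
x = Dense(NUM_CLASSES, activation='softmax')(x)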

Building the model

model = Model(img_input, x, name='inception_resnet_v2')

Model Summary

model.summary()

Save Model as ‘.png’

from tensorflow.keras.utils import plot_model
from IPython.display import SVG

plot_model(model, to_file='model_plot.png',
           show_shapes=True, show_layer_names=True)

Configure the Model for training

model.compile(Adam(lr=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

Model Checkpoint

filepath = "model.h5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc',
                             verbose=1, save_best_only=True, mode='max')

Early Stopping

early = EarlyStopping(monitor="val_loss",
                      mode="min",
                      patience=4, restore_best_weights=True)
callbacks_list = [checkpoint, early]

Training the Model

# pass tr_crops / val_crops here instead to train on the random crops
history = model.fit_generator(train_gen,
                              steps_per_epoch=train_steps,
                              validation_data=val_gen,
                              validation_steps=val_steps,
                              epochs=20, verbose=1,
                              callbacks=callbacks_list)

Plotting Training Statistics

# Training plots
epochs = [i for i in range(1, len(history.history['loss'])+1)]

plt.plot(epochs, history.history['loss'], color='blue', label="training_loss")
plt.plot(epochs, history.history['val_loss'], color='red', label="validation_loss")
plt.legend(loc='best')
plt.title('loss')
plt.xlabel('epoch')
plt.savefig(TRAINING_PLOT_FILE, bbox_inches='tight')
plt.show()

plt.plot(epochs, history.history['acc'], color='blue', label="training_accuracy")
plt.plot(epochs, history.history['val_acc'], color='red',label="validation_accuracy")
plt.legend(loc='best')
plt.title('accuracy')
plt.xlabel('epoch')
plt.savefig(VALIDATION_PLOT_FILE, bbox_inches='tight')
plt.show()

Evaluating the model

# make sure to load the best model
model.load_weights('model.h5')

predictions = model.predict_generator(test_gen,
                                      steps=num_test_images,
                                      verbose=1)
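
Because the test generator was created with shuffle=False, the rows of predictions stay aligned with test_gen.classes, so the sklearn metrics imported at the top can be applied directly. A minimal sketch:

y_true = test_gen.classes
y_pred = np.argmax(predictions, axis=-1)

print(classification_report(y_true, y_pred,
                            target_names=list(test_gen.class_indices.keys())))
print(confusion_matrix(y_true, y_pred))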

References

Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning: https://arxiv.org/abs/1602.07261

Random Image Cropping in Keras: JK Jung's Blog
