Callbacks in Keras

Jaabir
Published in featurepreneur
4 min read, Jan 6, 2022

Callbacks API

A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch, before or after a single batch, etc).

You can use callbacks to:

  • Write TensorBoard logs after every batch of training to monitor your metrics
  • Periodically save your model to disk
  • Do early stopping
  • Get a view on internal states and statistics of a model during training
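To make the stages mentioned above concrete, here is a minimal sketch of a callback that just records which hooks Keras calls on it. The class name `HookLogger`, the tiny model and the random data are placeholders made up for this illustration; they are not part of the example used in the rest of this post.

```python
import numpy as np
from tensorflow import keras


class HookLogger(keras.callbacks.Callback):
    """Hypothetical callback that records each hook Keras calls on it."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_{epoch}_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_{epoch}_end")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


# Placeholder data and model, just large enough to run two epochs.
x = np.random.rand(16, 4).astype("float32")
y = np.random.rand(16, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(loss="mse", optimizer="sgd")

logger = HookLogger()
model.fit(x, y, epochs=2, batch_size=8, verbose=0, callbacks=[logger])
print(logger.events)
# ['train_begin', 'epoch_0_begin', 'epoch_0_end',
#  'epoch_1_begin', 'epoch_1_end', 'train_end']
```

Batch-level hooks such as `on_train_batch_end` work the same way; they are left out here to keep the output short.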

Available callbacks

I’ll be using EarlyStopping, ReduceLROnPlateau and a custom callback.

Read more: https://keras.io/api/callbacks/

Let’s work through a code example. I’ll be using this dataset: https://www.kaggle.com/grassknoted/asl-alphabet

1. Import the required modules

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
from warnings import filterwarnings as filt
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, MaxPool2D, Conv2D, Flatten, Input, Dropout, BatchNormalization
import tensorflow.keras as keras
from tensorflow.keras.callbacks import ReduceLROnPlateau, Callback, EarlyStopping
plt.style.use('fivethirtyeight')
plt.rcParams['figure.figsize'] = (12, 6)
filt('ignore')

2. Use Keras data generators to augment the images and create the training and validation generators
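The generator code is not shown in this post, so the following is only a sketch of one common way to build `train_gen` and `dev_gen`. The function name `make_generators`, the directory path, the augmentation settings, the batch size and the 20% validation split are all assumptions. Only the 180x180 image size and the one-hot (categorical) labels follow from the model and loss used in the later steps.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator


def make_generators(data_dir, image_size=(180, 180), batch_size=64, val_split=0.2):
    """Build (train_gen, dev_gen) from a directory with one sub-folder per class.

    All augmentation values below are illustrative guesses, not the
    settings used in the original notebook.
    """
    # Training images are rescaled and lightly augmented.
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        rotation_range=10,
        zoom_range=0.1,
        width_shift_range=0.1,
        height_shift_range=0.1,
        validation_split=val_split,
    )
    # Validation images are only rescaled. Using the same validation_split
    # keeps the train/validation file split identical for both generators.
    dev_datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=val_split)

    train_gen = train_datagen.flow_from_directory(
        data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode="categorical",   # one-hot labels for categorical_crossentropy
        subset="training",
    )
    dev_gen = dev_datagen.flow_from_directory(
        data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode="categorical",
        subset="validation",
        shuffle=False,
    )
    return train_gen, dev_gen


# Hypothetical usage; replace the path with wherever the dataset was extracted:
# train_gen, dev_gen = make_generators("path/to/asl_alphabet_train")
```

Two separate `ImageDataGenerator` objects are used so that the validation images are not augmented.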

3. Initialize the model architecture

def CNN():
    # Input: 180x180 RGB images
    ip = Input(shape = (180, 180, 3))

    # Six conv blocks; each 2x2 max-pool halves the spatial size
    c1 = Conv2D(filters = 32, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(ip)
    mp1 = MaxPool2D(pool_size=(2,2))(c1)

    c2 = Conv2D(filters = 32, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp1)
    bn2 = BatchNormalization()(c2)
    mp2 = MaxPool2D(pool_size=(2,2))(bn2)

    c3 = Conv2D(filters = 64, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp2)
    bn3 = BatchNormalization()(c3)
    mp3 = MaxPool2D(pool_size=(2,2))(bn3)

    c4 = Conv2D(filters = 64, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp3)
    bn4 = BatchNormalization()(c4)
    mp4 = MaxPool2D(pool_size=(2,2))(bn4)

    c5 = Conv2D(filters = 128, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp4)
    bn5 = BatchNormalization()(c5)
    mp5 = MaxPool2D(pool_size=(2,2))(bn5)

    c6 = Conv2D(filters = 128, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp5)
    bn6 = BatchNormalization()(c6)
    mp6 = MaxPool2D(pool_size=(2,2))(bn6)

    # Classifier head
    f = Flatten()(mp6)
    bn7 = BatchNormalization()(f)
    dp = Dropout(0.2)(bn7)

    h1 = Dense(512, activation = 'relu')(dp)
    dp = Dropout(0.25)(h1)
    h2 = Dense(256, activation = 'relu')(dp)

    # softmax rather than sigmoid: the 29 classes are mutually exclusive and
    # categorical_crossentropy expects a probability distribution over them
    op = Dense(29, activation = 'softmax')(h2)

    return keras.Model(inputs = ip, outputs = op)

4. Compile the model:

model1 = CNN()
model1.compile(
    loss = 'categorical_crossentropy',
    metrics = ['accuracy'],
    optimizer = keras.optimizers.RMSprop(learning_rate = 0.1)
)
model1.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 180, 180, 3)]     0
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 180, 180, 32)      896
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 90, 90, 32)        0
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 90, 90, 32)        9248
_________________________________________________________________
batch_normalization_5 (Batch (None, 90, 90, 32)        128
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 45, 45, 32)        0
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 45, 45, 64)        18496
_________________________________________________________________
batch_normalization_6 (Batch (None, 45, 45, 64)        256
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 22, 22, 64)        0
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 22, 22, 64)        36928
_________________________________________________________________
batch_normalization_7 (Batch (None, 22, 22, 64)        256
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 11, 11, 64)        0
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 11, 11, 128)       73856
_________________________________________________________________
batch_normalization_8 (Batch (None, 11, 11, 128)       512
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 5, 5, 128)         0
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 5, 5, 128)         147584
_________________________________________________________________
batch_normalization_9 (Batch (None, 5, 5, 128)         512
_________________________________________________________________
max_pooling2d_11 (MaxPooling (None, 2, 2, 128)         0
_________________________________________________________________
flatten_1 (Flatten)          (None, 512)               0
_________________________________________________________________
batch_normalization_10 (Batc (None, 512)               2048
_________________________________________________________________
dropout_2 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_3 (Dense)              (None, 512)               262656
_________________________________________________________________
dropout_3 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_4 (Dense)              (None, 256)               131328
_________________________________________________________________
dense_5 (Dense)              (None, 29)                7453
=================================================================
Total params: 692,157
Trainable params: 690,301
Non-trainable params: 1,856
_________________________________________________________________
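The shapes and parameter counts in the summary above can be checked by hand. The short script below is only a sanity check written for this purpose. It redoes the arithmetic for the conv and dense layers, using the layer sizes from the CNN function.

```python
# Reproduce the conv/dense shapes and parameter counts from the summary.
size, channels = 180, 3                      # input is 180x180x3
conv_filters = [32, 32, 64, 64, 128, 128]    # filters of c1..c6

conv_params = []
for filters in conv_filters:
    # 3x3 kernel: (3 * 3 * input_channels + 1 bias) weights per filter
    conv_params.append((3 * 3 * channels + 1) * filters)
    channels = filters
    size //= 2   # 'same' conv keeps the size; each 2x2 max-pool halves it (rounding down)

flat = size * size * channels                # length of the Flatten output

dense_params = []
units_in = flat
for units in [512, 256, 29]:
    # fully connected: (inputs + 1 bias) weights per unit
    dense_params.append((units_in + 1) * units)
    units_in = units

print(conv_params)    # [896, 9248, 18496, 36928, 73856, 147584]
print(size, flat)     # 2 512
print(dense_params)   # [262656, 131328, 7453]
```

The spatial size goes 180, 90, 45, 22, 11, 5, 2, which is why the Flatten layer outputs only 2 * 2 * 128 = 512 values.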

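5. Import the necessary callbacks from Keras and initialize them

from tensorflow.keras.callbacks import ReduceLROnPlateau, Callback, EarlyStopping

# Multiply the learning rate by 0.95 when val_loss has not improved for 3 epochs
lr_decay = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.95,
    patience=3,
    verbose=1,
    mode="auto",
    min_lr=0.00001,
)

# Stop when val_loss has not improved by at least 0.005 for 4 epochs,
# then roll back to the weights of the best epoch
er_stopping = EarlyStopping(
    monitor="val_loss",
    min_delta=0.005,
    patience=4,
    verbose=1,
    mode="auto",
    restore_best_weights=True,
)

class StopTraining(Callback):
    """Stop once val_accuracy has reached `thresh` in `times` epochs (not necessarily consecutive)."""

    def __init__(self, thresh = 0.95, times = 3):
        super().__init__()
        self.thresh = thresh
        self.times = times
        self.reached = 0

    def on_epoch_end(self, epoch, logs = None):
        # val_accuracy is only in logs when validation data is passed to fit
        val_acc = (logs or {}).get("val_accuracy")
        if val_acc is not None and val_acc >= self.thresh:
            self.reached += 1
            print(f"Reached {self.thresh:.0%} accuracy {self.reached} / {self.times} ...")
            if self.reached >= self.times:
                self.model.stop_training = True

stopTr = StopTraining(times = 3)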
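It helps to see how far the ReduceLROnPlateau settings from this step can actually move the learning rate. The sketch below is plain arithmetic, not Keras code. It assumes the starting learning rate of 0.1 from step 4, and the helper name `lr_after` is made up for the illustration.

```python
import math

initial_lr = 0.1    # learning rate passed to RMSprop in step 4
factor = 0.95       # ReduceLROnPlateau factor
min_lr = 0.00001    # ReduceLROnPlateau lower bound


def lr_after(reductions):
    """Learning rate after the callback has fired `reductions` times."""
    return max(initial_lr * factor ** reductions, min_lr)


# Number of reductions needed before the learning rate hits min_lr.
reductions_to_floor = math.ceil(math.log(min_lr / initial_lr) / math.log(factor))

print(round(lr_after(1), 6))   # 0.095
print(round(lr_after(5), 6))   # 0.077378
print(reductions_to_floor)     # 180
```

With factor = 0.95 each reduction only trims the learning rate by 5%, so in a 15-epoch run with patience = 3 it stays close to its starting value. A smaller factor (0.1 to 0.5 is common) makes the callback act much more strongly.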

6. Pass these three callbacks to the callbacks argument when fitting the model

# fit_generator is deprecated; in TensorFlow 2.x, fit accepts generators directly
his = model1.fit(
    train_gen,
    validation_data = dev_gen,
    epochs = 15,
    callbacks = [lr_decay, stopTr, er_stopping]
)

7. Training stops as soon as EarlyStopping or the custom StopTraining callback triggers, while ReduceLROnPlateau only lowers the learning rate when val_loss plateaus. Stopping early in this way helps reduce overfitting.
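To see when EarlyStopping would trigger with patience = 4 and min_delta = 0.005, here is a simplified pure-Python re-implementation of its rule for a loss that should decrease. It is a sketch of the idea rather than Keras's actual code, and the function name and validation losses are invented for the example.

```python
def early_stop_epoch(val_losses, patience=4, min_delta=0.005):
    """Return the zero-based epoch at which training would stop, or None.

    An epoch counts as an improvement only if the loss drops below the
    best value seen so far by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up validation losses: real progress for 3 epochs, then tiny gains.
losses = [1.00, 0.80, 0.70, 0.699, 0.698, 0.697, 0.696, 0.50]
print(early_stop_epoch(losses))            # 6
print(early_stop_epoch([1.0, 0.9, 0.8]))   # None
```

The losses after epoch 2 still go down, but by less than min_delta, so they do not reset the patience counter. Training stops at epoch 6 and never reaches the last value of 0.50.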
