Callbacks in Keras
Callbacks API
A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch, before or after a single batch, etc).
You can use callbacks to:
- Write TensorBoard logs after every batch of training to monitor your metrics
- Periodically save your model to disk
- Do early stopping
- Get a view on internal states and statistics of a model during training
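To make the idea of "actions at various stages of training" concrete, here is a toy sketch in plain Python. It is not how Keras implements callbacks: `ToyCallback`, `LossRecorder`, and `toy_fit` are hypothetical names, and the loss values are made up. It only illustrates the pattern, where the training loop calls a hook on every callback at fixed points.

```python
class ToyCallback:
    """Hypothetical base class: hooks do nothing unless overridden."""

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass


class LossRecorder(ToyCallback):
    """Keeps the loss reported at the end of every epoch."""

    def __init__(self):
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        self.losses.append((logs or {}).get("loss"))


def toy_fit(epochs, callbacks):
    """Stand-in for a training loop; no real model is trained here."""
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        logs = {"loss": 1.0 / (epoch + 1)}  # made-up loss value
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)


recorder = LossRecorder()
toy_fit(3, [recorder])
print(recorder.losses)  # [1.0, 0.5, 0.3333333333333333]
```

Keras follows the same general shape, with more hooks (batch-level, train begin/end) and with real metrics in `logs`.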
Available callbacks
- Base Callback class
- ModelCheckpoint
- TensorBoard
- EarlyStopping
- LearningRateScheduler
- ReduceLROnPlateau
- RemoteMonitor
- LambdaCallback
- TerminateOnNaN
- CSVLogger
- ProgbarLogger
- BackupAndRestore
In this post I’ll use EarlyStopping, ReduceLROnPlateau, and a custom callback.
Read more: https://keras.io/api/callbacks/
Let’s work through a code example. I’ll be using this dataset: https://www.kaggle.com/grassknoted/asl-alphabet
1. Import the required modules
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
from warnings import filterwarnings as filt
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, MaxPool2D, Conv2D, Flatten, Input, Dropout, BatchNormalization
import tensorflow.keras as keras
from tensorflow.keras.callbacks import ReduceLROnPlateau, Callback, EarlyStopping

plt.style.use('fivethirtyeight')
plt.rcParams['figure.figsize'] = (12, 6)
filt('ignore')
2. Use Keras ImageDataGenerator to augment the images and create the training and validation generators (train_gen and dev_gen, used in step 6)
3. Define the model architecture
def CNN():
    # Input: 180x180 RGB images
    ip = Input(shape = (180, 180, 3))
    # Six conv + max-pool stages; each pool halves the spatial size
    c1 = Conv2D(filters = 32, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(ip)
    mp1 = MaxPool2D(pool_size=(2,2))(c1)
    c2 = Conv2D(filters = 32, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp1)
    bn2 = BatchNormalization()(c2)
    mp2 = MaxPool2D(pool_size=(2,2))(bn2)
    c3 = Conv2D(filters = 64, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp2)
    bn3 = BatchNormalization()(c3)
    mp3 = MaxPool2D(pool_size=(2,2))(bn3)
    c4 = Conv2D(filters = 64, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp3)
    bn4 = BatchNormalization()(c4)
    mp4 = MaxPool2D(pool_size=(2,2))(bn4)
    c5 = Conv2D(filters = 128, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp4)
    bn5 = BatchNormalization()(c5)
    mp5 = MaxPool2D(pool_size=(2,2))(bn5)
    c6 = Conv2D(filters = 128, kernel_size = (3,3), strides = 1, padding = 'same', activation = 'relu')(mp5)
    bn6 = BatchNormalization()(c6)
    mp6 = MaxPool2D(pool_size=(2,2))(bn6)
    # Classifier head
    f = Flatten()(mp6)
    bn7 = BatchNormalization()(f)
    dp = Dropout(0.2)(bn7)
    h1 = Dense(512, activation = 'relu')(dp)
    dp = Dropout(0.25)(h1)
    h2 = Dense(256, activation = 'relu')(dp)
    # softmax pairs with categorical_crossentropy for 29 mutually exclusive classes
    op = Dense(29 , activation = 'softmax')(h2)
    return keras.Model(inputs = ip, outputs = op)
4. Compile the model:
model1 = CNN()
model1.compile(
    loss = 'categorical_crossentropy',
    metrics = ['accuracy'],
    optimizer = keras.optimizers.RMSprop(learning_rate = 0.1)
)

model1.summary()

Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 180, 180, 3)] 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 180, 180, 32) 896
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 90, 90, 32) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 90, 90, 32) 9248
_________________________________________________________________
batch_normalization_5 (Batch (None, 90, 90, 32) 128
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 45, 45, 32) 0
_________________________________________________________________
conv2d_8 (Conv2D) (None, 45, 45, 64) 18496
_________________________________________________________________
batch_normalization_6 (Batch (None, 45, 45, 64) 256
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 22, 22, 64) 0
_________________________________________________________________
conv2d_9 (Conv2D) (None, 22, 22, 64) 36928
_________________________________________________________________
batch_normalization_7 (Batch (None, 22, 22, 64) 256
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 11, 11, 64) 0
_________________________________________________________________
conv2d_10 (Conv2D) (None, 11, 11, 128) 73856
_________________________________________________________________
batch_normalization_8 (Batch (None, 11, 11, 128) 512
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 5, 5, 128) 0
_________________________________________________________________
conv2d_11 (Conv2D) (None, 5, 5, 128) 147584
_________________________________________________________________
batch_normalization_9 (Batch (None, 5, 5, 128) 512
_________________________________________________________________
max_pooling2d_11 (MaxPooling (None, 2, 2, 128) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 512) 0
_________________________________________________________________
batch_normalization_10 (Batc (None, 512) 2048
_________________________________________________________________
dropout_2 (Dropout) (None, 512) 0
_________________________________________________________________
dense_3 (Dense) (None, 512) 262656
_________________________________________________________________
dropout_3 (Dropout) (None, 512) 0
_________________________________________________________________
dense_4 (Dense) (None, 256) 131328
_________________________________________________________________
dense_5 (Dense) (None, 29) 7453
=================================================================
Total params: 692,157
Trainable params: 690,301
Non-trainable params: 1,856
_________________________________________________________________
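The shapes and parameter counts in the summary can be checked by hand. The sketch below assumes the standard formulas: a k×k convolution has k·k·in·out weights plus out biases, a dense layer has in·out weights plus out biases, and BatchNormalization keeps four values per channel, two of them non-trainable. The helper names are illustrative only.

```python
def conv_params(kernel, c_in, c_out):
    # kernel x kernel weights per (input, output) channel pair, plus one bias per filter
    return kernel * kernel * c_in * c_out + c_out


def dense_params(n_in, n_out):
    # weight matrix plus one bias per output unit
    return n_in * n_out + n_out


def bn_params(channels):
    # gamma and beta (trainable) + moving mean and variance (non-trainable)
    return 4 * channels


# Six 2x2 max-pools, each halving the size with floor division:
# 180 -> 90 -> 45 -> 22 -> 11 -> 5 -> 2
size = 180
for _ in range(6):
    size //= 2
flat = size * size * 128  # Flatten output: 2 * 2 * 128

conv_total = sum(
    conv_params(3, c_in, c_out)
    for c_in, c_out in [(3, 32), (32, 32), (32, 64), (64, 64), (64, 128), (128, 128)]
)
bn_total = sum(bn_params(c) for c in [32, 64, 64, 128, 128, flat])
dense_total = sum(
    dense_params(n_in, n_out) for n_in, n_out in [(flat, 512), (512, 256), (256, 29)]
)

total = conv_total + bn_total + dense_total
non_trainable = bn_total // 2  # only the moving statistics

print(flat)                   # 512
print(total)                  # 692157
print(total - non_trainable)  # 690301
print(non_trainable)          # 1856
```

These match the Flatten shape and the three totals reported by model1.summary().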
5. Import the necessary callbacks from Keras and initialize them
from tensorflow.keras.callbacks import ReduceLROnPlateau, Callback, EarlyStopping

# Multiply the learning rate by 0.95 when val_loss has not improved for 3 epochs
lr_decay = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.95,
    patience=3,
    verbose=1,
    mode="auto",
    min_lr=0.00001,
)

# Stop when val_loss has not improved by at least 0.005 for 4 epochs
er_stopping = EarlyStopping(
    monitor="val_loss",
    min_delta=0.005,
    patience=4,
    verbose=1,
    mode="auto",
    restore_best_weights=True,
)

class StopTraining(Callback):
    """Stop once val_accuracy has reached `thresh` in `times` epochs."""

    def __init__(self, thresh = 0.95, times = 3):
        super().__init__()
        self.thresh = thresh
        self.times = times
        self.reached = 0

    def on_epoch_end(self, epoch, logs = None):
        val_acc = (logs or {}).get("val_accuracy")
        if val_acc is None:  # no validation metrics were logged this epoch
            return
        if val_acc >= self.thresh:
            self.reached += 1
            print(f"Reached {self.thresh:.0%} validation accuracy {self.reached} / {self.times} ...")
            if self.reached >= self.times:
                self.model.stop_training = True

stopTr = StopTraining(times = 3)
- StopTraining : this class checks the validation accuracy at the end of every epoch. Once the validation accuracy has reached or exceeded the threshold in n epochs (here, ≥ 0.95 in at least 3 epochs, not necessarily consecutive), it stops the training.
- EarlyStopping : https://keras.io/api/callbacks/early_stopping
- ReduceLROnPlateau : https://keras.io/api/callbacks/reduce_lr_on_plateau
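To get a feel for what factor=0.95 and patience=3 mean over a 15-epoch run, the following plain-Python sketch mimics the plateau rule under simplifying assumptions: no cooldown, "min" mode, and an improvement threshold of 1e-4 (which I believe is the Keras default for `min_delta`). `simulate_reduce_lr` is a hypothetical helper, not part of Keras, and the loss history is invented as a worst case in which val_loss never improves after the first epoch.

```python
def simulate_reduce_lr(val_losses, lr, factor, patience, min_lr, min_delta=1e-4):
    """Return the learning rate in effect after each epoch (simplified rule)."""
    best = float("inf")
    wait = 0
    schedule = []
    for loss in val_losses:
        if loss < best - min_delta:
            # Improvement: remember it and reset the patience counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # Plateau: shrink the learning rate, but never below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
        schedule.append(lr)
    return schedule


initial_lr = 0.1
worst_case = [1.0] * 15  # invented: no improvement after epoch 1

schedule = simulate_reduce_lr(
    worst_case, lr=initial_lr, factor=0.95, patience=3, min_lr=0.00001
)

# Count the epochs at which the learning rate actually dropped
previous = [initial_lr] + schedule[:-1]
reductions = sum(1 for before, after in zip(previous, schedule) if after < before)

print(reductions)              # 4
print(round(schedule[-1], 6))  # 0.081451
```

Even in this worst case the learning rate only falls from 0.1 to about 0.0815 in 15 epochs, so a factor of 0.95 is a very gentle decay; a smaller factor would react more strongly to a plateau.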
6. Pass these three callbacks through the callbacks argument when fitting the model
# fit accepts generators directly; fit_generator is deprecated in TensorFlow 2.x
his = model1.fit(
    train_gen,
    validation_data = dev_gen,
    epochs = 15,
    callbacks = [lr_decay, stopTr, er_stopping]
)
7. Training stops as soon as either EarlyStopping or StopTraining triggers, which helps limit overfitting. ReduceLROnPlateau does not stop training; it only lowers the learning rate when the validation loss stalls.
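Which stopping rule fires first depends on the training history. The sketch below replays both rules in plain Python on an invented history; the two helper functions are hypothetical simplifications of EarlyStopping (min mode, min_delta=0.005, patience=4) and of the StopTraining class above, not the Keras code itself.

```python
def early_stopping_epoch(val_losses, min_delta, patience):
    """Index of the epoch at which a patience-based rule would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        wait += 1
        if loss < best - min_delta:  # counts only if better by more than min_delta
            best = loss
            wait = 0
        if wait >= patience and epoch > 0:
            return epoch
    return None


def threshold_stop_epoch(val_accs, thresh, times):
    """Index of the epoch at which accuracy has reached thresh `times` times."""
    reached = 0
    for epoch, acc in enumerate(val_accs):
        if acc >= thresh:
            reached += 1
            if reached >= times:
                return epoch
    return None


# Invented validation history, for illustration only
val_loss = [2.1, 1.2, 0.7, 0.4, 0.30, 0.299, 0.298, 0.30, 0.31, 0.30]
val_acc = [0.40, 0.70, 0.85, 0.93, 0.95, 0.96, 0.94, 0.95, 0.96, 0.96]

es_epoch = early_stopping_epoch(val_loss, min_delta=0.005, patience=4)
st_epoch = threshold_stop_epoch(val_acc, thresh=0.95, times=3)

print(es_epoch)  # 8
print(st_epoch)  # 7
```

On this invented history the accuracy rule fires at epoch index 7, one epoch before the patience rule would. Note that the small loss improvements at indices 5 and 6 do not reset the patience counter, because they are smaller than min_delta.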