The Batch Normalisation in Keras is not broken

Recently there has been a blog post claiming that batch normalisation in Keras is broken. In this post I explain why it is not, and describe the actual ‘fixes’ needed to make Keras work well for transfer learning.

tl;dr Solutions:

  • add a training kwarg to keras_applications models
  • reset the moving statistics of the batch normalisation layers

Training kwarg

The proposed patch altered Keras’ API so that a frozen batch norm layer (trainable = False) would automatically use its moving statistics instead of the batch statistics. However, this overloads the existing interface, where training = False is passed in the call to the batch norm layer, which isn’t good form.

As Francois mentioned in the PR, the following is required to make a layer use its inference statistics:

x = BatchNormalization()(x, training=False)

However, no one has updated the keras_applications module to expose the kwarg needed at inference time. This has to be declared in code because the alternative is to walk the TensorFlow graph def (which assumes you are on the TF backend) and rewire it to use the moving statistics. That is possible; if you want to know how, please comment.

So, Fix 1 is to make it possible to call pre-trained models in inference mode, and to update the surrounding documentation on fine-tuning.
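
A minimal sketch of what that could look like, assuming a hypothetical bn_training kwarg threaded through a keras_applications-style building block (the function name and signature are mine, not the library’s):

from keras.layers import Conv2D, BatchNormalization, Activation

def conv_bn_block(x, filters, bn_training=None):
    # Hypothetical building block: passing bn_training=False forces the batch norm
    # layer to use its moving statistics regardless of the global learning phase.
    x = Conv2D(filters, (3, 3), padding='same')(x)
    x = BatchNormalization()(x, training=bn_training)
    x = Activation('relu')(x)
    return x

A pre-trained model built from blocks like this could then be instantiated in inference mode for fine-tuning by passing bn_training=False down from the top-level constructor.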


Reset the moving statistics

However, there is another fix, which requires only two extra lines in the dummy model from the original blog post: resetting the moving statistics. This is especially useful when fine-tuning models for transfer learning. It is also described in [1].
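
For reuse outside a single script, the reset can be wrapped in a small helper (the function name is mine, not part of Keras); the full example below simply inlines the same two backend updates:

from keras import backend as K

def reset_bn_statistics(model):
    # Zero the moving statistics of every batch norm layer so they are
    # re-estimated on the new dataset during fine-tuning (TensorFlow backend).
    for layer in model.layers:
        if hasattr(layer, 'moving_mean') and hasattr(layer, 'moving_variance'):
            K.eval(K.update(layer.moving_mean, K.zeros_like(layer.moving_mean)))
            K.eval(K.update(layer.moving_variance, K.zeros_like(layer.moving_variance)))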

import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresize

from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as K

seed = 42
epochs = 10
records_per_class = 100

# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()

def filter_resize(category):
    # We do the preprocessing here instead of in the generator to work around a bug in Keras 2.1.5.
    return [preprocess_input(imresize(img, (224, 224)).astype('float'))
            for img in x[y.flatten() == category][:records_per_class]]

x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)

# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in model.layers[:140]:
    if hasattr(layer, 'moving_mean') and hasattr(layer, 'moving_variance'):
        # Keep the batch norm layers trainable and reset their moving statistics
        # so they are re-estimated on the new dataset rather than on ImageNet.
        layer.trainable = True
        K.eval(K.update(layer.moving_mean, K.zeros_like(layer.moving_mean)))
        K.eval(K.update(layer.moving_variance, K.zeros_like(layer.moving_variance)))
    else:
        layer.trainable = False
for layer in model.layers[140:]:
    layer.trainable = True
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs, validation_data=ImageDataGenerator().flow(x, y, seed=42))

# Store the model on disk
model.save('tmp.h5')
# In every test we clear the session and reload the model to force the learning_phase value to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should exactly match the validation accuracy from the last epoch.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))
print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))
print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the training accuracy from the last epoch.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

This lets the model learn the statistics of the dataset it is being transferred to, instead of keeping the previous model’s statistics. My personal preference is to use the current dataset’s statistics rather than fixing the moving statistics, because it allows the model to correct for any differences, such as fine-tuning on images whose characteristics differ from ImageNet’s; medical images are one example.

An even better approach is not to freeze the initial layers at all, but to give them a smaller learning rate, allowing them to adapt without drastically forgetting what they have learned.
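
Stock Keras does not expose per-layer learning rates, so here is a minimal sketch of one way to achieve this with tf.keras and a custom training step; the split index of 140 mirrors the script above, and the model is assumed to be an equivalent tf.keras model:

import tensorflow as tf

BOUNDARY = 140  # assumption: same split point as in the script above
slow_opt = tf.keras.optimizers.SGD(learning_rate=1e-4)  # pre-trained layers adapt slowly
fast_opt = tf.keras.optimizers.SGD(learning_rate=1e-2)  # new top layers learn quickly

slow_vars = [w for layer in model.layers[:BOUNDARY] for w in layer.trainable_weights]
fast_vars = [w for layer in model.layers[BOUNDARY:] for w in layer.trainable_weights]
loss_fn = tf.keras.losses.CategoricalCrossentropy()

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, slow_vars + fast_vars)
    # Apply a small learning rate to the early layers and a larger one to the new head.
    slow_opt.apply_gradients(zip(grads[:len(slow_vars)], slow_vars))
    fast_opt.apply_gradients(zip(grads[len(slow_vars):], fast_vars))
    return loss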

Conclusion

In conclusion, Keras is not broken, but its documentation on fine-tuning and batch normalisation needs clarifying.

References:

[1] Yanghao Li et al., “Revisiting Batch Normalization for Practical Domain Adaptation”, ICLR 2017.
