# The Batch Normalisation in Keras is not broken

Recently there has been a blog post claiming that batch normalisation in Keras is broken. In this post I explain why it is not, and describe the actual ‘fixes’ needed to make Keras work well for transfer learning.

*tl;dr — solutions:*

- add a `training` kwarg to the `keras_applications` models
- reset the moving statistics of the batch normalisation layers

### Training kwarg

The proposed patch altered Keras’ API so that frozen batch norm layers infer from `trainable = False` that they should use the moving statistics instead of the batch statistics. However, this overloads the existing interface of passing `training = False` in the call to the batch norm layer, which isn’t good form.

As Francois mentioned in the PR, the following is required to use the inference statistics:

```python
x = BatchNormalization()(x, training=False)
```

**However, no one has updated the `keras_applications` module to expose the kwarg needed for inference-time scripts.** This needs declaring in code because the alternative is to walk the TensorFlow graph def (which assumes the TF backend) and rewire it to use the moving statistics. This is possible; if you want to know how, please comment.
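To make the distinction concrete, here is a minimal numpy sketch (with hypothetical values, not Keras’ actual implementation) contrasting batch statistics with stale moving statistics in a batch norm layer:

```python
import numpy as np

def batchnorm(x, gamma, beta, mean, var, eps=1e-3):
    # Normalise with the supplied statistics, then scale and shift.
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
# Activations from a new domain whose distribution differs from pre-training.
x = rng.normal(loc=5.0, scale=2.0, size=(32, 4))

gamma, beta = 1.0, 0.0
# Moving statistics inherited from pre-training (hypothetical values).
moving_mean, moving_var = 0.0, 1.0

train_out = batchnorm(x, gamma, beta, x.mean(axis=0), x.var(axis=0))  # batch statistics
infer_out = batchnorm(x, gamma, beta, moving_mean, moving_var)        # moving statistics

print(abs(train_out.mean()))  # ~0: batch statistics centre the activations
print(infer_out.mean())       # ~5: stale moving statistics do not
```

A frozen layer that still normalises with batch statistics therefore behaves very differently at training time than at inference time, which is the discrepancy the original blog post observed.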

So, **Fix 1** is to add the option of calling the pre-trained models in inference mode, along with updating the surrounding documentation on fine-tuning models.

### Reset the moving statistics

However, there is another fix, one which adds only two extra lines to the dummy model proposed in the original blog post: reset the moving statistics. This is especially useful when fine-tuning models during transfer learning, and is also described in [1].
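To see why simply continuing training is not enough, consider the exponential moving average Keras uses to update these statistics (`BatchNormalization`’s momentum defaults to 0.99). A toy sketch with hypothetical numbers:

```python
# Exponential moving average of the form Keras uses during training
# (BatchNormalization's momentum defaults to 0.99).
def ema_update(moving, batch, momentum=0.99):
    return momentum * moving + (1 - momentum) * batch

moving_mean = 0.0   # stale statistic from the source domain (hypothetical)
target = 5.0        # true activation mean of the new domain (hypothetical)

for _ in range(100):  # 100 training batches
    moving_mean = ema_update(moving_mean, target)

print(moving_mean)  # still well short of 5.0 after 100 batches
```

With a momentum of 0.99 the moving statistics crawl towards the new dataset’s values, so resetting them lets the fine-tuned model converge on the correct statistics much sooner.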

```python
import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresize
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as K

seed = 42
epochs = 10
records_per_class = 100

# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()

def filter_resize(category):
    # We do the preprocessing here instead of in the Generator to get around a bug in Keras 2.1.5.
    return [preprocess_input(imresize(img, (224, 224)).astype('float'))
            for img in x[y.flatten() == category][:records_per_class]]

x = np.stack(filter_resize(3) + filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1, 0]] * records_per_class + [[0, 1]] * records_per_class)

# We will use a pre-trained model and fine-tune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)

for layer in model.layers[:140]:
    if hasattr(layer, 'moving_mean') and hasattr(layer, 'moving_variance'):
        # Keep the batch norm layers trainable and reset their moving statistics.
        layer.trainable = True
        K.eval(K.update(layer.moving_mean, K.zeros_like(layer.moving_mean)))
        K.eval(K.update(layer.moving_variance, K.zeros_like(layer.moving_variance)))
    else:
        layer.trainable = False

for layer in model.layers[140:]:
    layer.trainable = True

model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs,
                    validation_data=ImageDataGenerator().flow(x, y, seed=42))

# Store the model on disk.
model.save('tmp.h5')

# In every test we clear the session and reload the model to force the learning-phase value to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))

print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))
```

This lets the model learn the statistics of the dataset to which it is being transferred, instead of using the previous model’s statistics. My personal preference is using the current dataset’s statistics over freezing the moving statistics, as this allows the model to correct for differences such as fine-tuning on images whose characteristics differ from ImageNet; medical images are one example.

An even better approach is not to freeze the initial layers at all, but to set a smaller learning rate for them, allowing them to adapt without drastic forgetting.
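A minimal sketch of that idea, using plain numpy SGD with hypothetical layer groups and rates (Keras’ built-in optimizers do not support per-layer rates out of the box, so this only illustrates the principle):

```python
import numpy as np

# Two hypothetical layer groups: early (pre-trained) and late (new head).
params = {'early': np.array([1.0]), 'late': np.array([1.0])}
lrs    = {'early': 0.001, 'late': 0.01}  # early layers step 10x more slowly

def sgd_step(params, grads, lrs):
    # Plain SGD, but each layer group uses its own learning rate.
    return {k: params[k] - lrs[k] * grads[k] for k in params}

grads = {'early': np.array([1.0]), 'late': np.array([1.0])}
params = sgd_step(params, grads, lrs)
print(params['early'][0], params['late'][0])  # the early group moved 10x less
```

The early layers still adapt to the new domain, but slowly enough that the pre-trained features are not destroyed in the first few epochs.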

### Conclusion

In conclusion, Keras is not broken, but its documentation needs clarifying on these matters.

### References:

[1] Li, Yanghao, et al. “Revisiting Batch Normalization for Practical Domain Adaptation.” ICLR 2017.