Paper Explained - Normalizer-Free Nets (NFNets): High-Performance Large-Scale Image Recognition Without Normalization

Nakshatra Singh
Published in Analytics Vidhya
Feb 15, 2021
NFNet-F1 model achieves comparable accuracy to an EffNet-B7 while being 8.7× faster to train. The image is taken from page 1 of the paper.

Introduction & Overview

So the point of this paper is to build networks, specifically convolutional residual-style networks, that have no batch normalization built into them. Without batch normalization, such networks usually don't perform as well or can't scale to larger batch sizes. This paper, however, builds networks that do scale to large batch sizes and are more efficient than previous state-of-the-art methods (like LambdaNets, I have also written a detailed article on it, click right here to check it out!!!🤞). The training latency vs accuracy graph shows that NFNet-F1 is 8.7× faster to train than EffNet-B7 for the same top-1 accuracy on ImageNet. The largest NFNet sets a new state of the art on ImageNet without any additional training data, and NFNets are also state of the art for transfer learning. At the time of writing, NFNets are ranked second on the global leaderboard, behind a method that uses semi-supervised pre-training and extra data.

What’s the problem with Batch Normalization?

If you have a data point that goes through a network, it experiences various transformations as it goes down the layers, and some of these transformations can be quite harmful if you build the network the wrong way. In machine learning it's good practice to centre the data around the mean and scale it to unit variance, but as you progress through the layers, especially layers like ReLU that only keep the positive part of the signal, the intermediate representations further down can become very skewed and no longer centred. Current machine learning methods work better when the data is well conditioned (i.e., centred around the mean, not very skewed, and so on).
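To make the ReLU point concrete, here is a small NumPy sketch on synthetic data (the sample size and seed are arbitrary choices for illustration, not anything from the paper). Zero-mean, unit-variance inputs come out of a ReLU with a mean of roughly 0.4, a standard deviation of roughly 0.58, and about half of the values set exactly to zero, so the next layer no longer sees centred data.

```python
import numpy as np

# Synthetic, well-conditioned pre-activations: mean 0, variance 1.
rng = np.random.default_rng(0)
x = rng.standard_normal(100_000)

# ReLU keeps only the positive part of the signal.
relu_out = np.maximum(x, 0.0)

print(f"before ReLU: mean={x.mean():+.3f}, std={x.std():.3f}")
print(f"after ReLU:  mean={relu_out.mean():+.3f}, std={relu_out.std():.3f}")
print(f"share of exact zeros after ReLU: {(relu_out == 0).mean():.2f}")
```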

Basic illustration of how data is transformed using batch norm in neural networks.

Batch normalization has 3 significant disadvantages. First, it is a surprisingly expensive computational primitive that incurs memory overhead: you need to compute the means and the scalings, and you need to store them in memory for the backpropagation algorithm. This increases the time required to evaluate the gradients in some networks.
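The memory cost is easy to see in a hand-written forward pass. Below is a minimal NumPy sketch of batch norm in training mode for inputs of shape (batch, features); the function name and the choice of what to cache are my own assumptions, loosely modelled on what a manual backward pass would reuse, not the paper's or any framework's actual implementation. Besides the output, a full-size tensor of normalised activations has to stay in memory until backpropagation reaches this layer.

```python
import numpy as np

def batch_norm_train_forward(x, gamma, beta, eps=1e-5):
    """Hypothetical batch norm forward pass (training mode) for x of shape (batch, features)."""
    mean = x.mean(axis=0)                 # per-feature mean over the batch
    var = x.var(axis=0)                   # per-feature (biased) variance over the batch
    inv_std = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mean) * inv_std          # normalised activations, same shape as x
    out = gamma * x_hat + beta            # learnable scale and shift
    cache = (x_hat, inv_std, gamma)       # kept in memory for the backward pass
    return out, cache

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(256, 64))   # off-centre activations
gamma, beta = np.ones(64), np.zeros(64)
out, (x_hat, inv_std, _) = batch_norm_train_forward(x, gamma, beta)

print("largest |per-feature mean| after BN:", np.abs(out.mean(axis=0)).max())
print("extra floats cached for backprop:", x_hat.size + inv_std.size)
```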

Secondly, it introduces a discrepancy between the behaviour of the model during training and at inference time. At inference time you don't want batch dependence: you want to be able to feed a single data point and always get the same result. That is why, at inference, batch norm uses running averages of the mean and variance collected during training instead of the statistics of the current batch.
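A minimal NumPy sketch of this train/inference split: during training the layer normalises with the current batch's statistics while tracking exponential moving averages, and at inference it switches to those running averages. The class name, the momentum value of 0.1 and the absence of a learnable scale/shift are simplifying assumptions for illustration. The same example gets different outputs in training mode depending on its batch-mates, but identical outputs in inference mode.

```python
import numpy as np

class TinyBatchNorm:
    """Hypothetical minimal batch norm for (batch, features) inputs, without gamma/beta."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def __call__(self, x, training):
        if training:
            mean, var = x.mean(axis=0), x.var(axis=0)
            # Exponential moving averages of the batch statistics, used later at inference.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps)

rng = np.random.default_rng(0)
bn = TinyBatchNorm(num_features=4)
for _ in range(200):  # "training" steps that update the running statistics
    bn(rng.normal(2.0, 3.0, size=(32, 4)), training=True)

example = rng.normal(2.0, 3.0, size=(1, 4))
batch_a = np.vstack([example, rng.normal(2.0, 3.0, size=(7, 4))])
batch_b = np.vstack([example, rng.normal(2.0, 3.0, size=(7, 4))])

# Inference mode: the example's output does not depend on its batch-mates.
eval_a, eval_b = bn(batch_a, training=False)[0], bn(batch_b, training=False)[0]
# Training mode: the example's output changes with the rest of the batch.
train_a, train_b = bn(batch_a, training=True)[0], bn(batch_b, training=True)[0]

print("inference outputs equal:", np.allclose(eval_a, eval_b))
print("training outputs equal: ", np.allclose(train_a, train_b))
```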

Thirdly, batch normalization breaks the independence between training examples in the mini-batch. This means that it now matters which other examples are in the batch.

This has 2 main consequences. Firstly, the batch size matters for batch normalization. With a small batch, the batch mean is a very noisy estimate of the true mean; with a large batch, it is a good estimate. And for some applications large batches are favourable anyway: they stabilise training, reduce training time, and more.
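How noisy is "noisy"? For independent samples, the standard deviation of a batch mean shrinks like σ/√B for batch size B. The NumPy sketch below checks this empirically on synthetic unit-variance activations for a single channel; the batch sizes and number of trials are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
population = rng.standard_normal(1_000_000)   # synthetic activations of a single channel

spread = {}
for batch_size in (4, 32, 256, 2048):
    # Draw 500 random mini-batches and measure how much their means scatter.
    idx = rng.integers(0, population.size, size=(500, batch_size))
    batch_means = population[idx].mean(axis=1)
    spread[batch_size] = batch_means.std()
    print(f"B={batch_size:5d}  std of batch mean={spread[batch_size]:.4f}  "
          f"1/sqrt(B)={1 / np.sqrt(batch_size):.4f}")
```

Each eightfold increase in batch size cuts the scatter of the batch mean by roughly √8 ≈ 2.8.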

Secondly, distributed training becomes very cumbersome. For example, in data parallelism your batch of data is split into, say, 3 parts (shards), and each shard is forward-propagated through an identical copy of the network on one of 3 machines. If the network contains batch norm layers, each machine has to forward-propagate its shard up to the batch norm layer and then exchange batch statistics with the other machines, because otherwise no machine has the mean and the variance over the whole batch you fed in. On top of that, because examples in a batch are no longer independent, batch normalization can enable the network to ‘cheat’ certain loss functions.
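The statistics exchange itself is simple to express. In the NumPy sketch below, three hypothetical workers each hold one shard; each one shares only its example count, per-channel sum and per-channel sum of squares, and adding these up (an all-reduce, in distributed-training terms) recovers exactly the mean and variance of the whole batch. The data and shard sizes are made up for illustration, and the E[x²] minus E[x]² formula used here is the simplest option rather than the most numerically stable one.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.normal(1.0, 2.0, size=(96, 8))   # global batch: 96 examples, 8 channels
shards = np.array_split(batch, 3)            # one shard per hypothetical worker

# Each worker computes small per-channel summaries of its own shard.
counts = [s.shape[0] for s in shards]
sums = [s.sum(axis=0) for s in shards]
sq_sums = [(s ** 2).sum(axis=0) for s in shards]

# "All-reduce": add the summaries from every worker.
n = sum(counts)
global_mean = sum(sums) / n
global_var = sum(sq_sums) / n - global_mean ** 2

# Without communication, each worker would only see its own shard's statistics.
local_means = [s.mean(axis=0) for s in shards]

print("matches full-batch mean:", np.allclose(global_mean, batch.mean(axis=0)))
print("matches full-batch var: ", np.allclose(global_var, batch.var(axis=0)))
print("worker 0 local mean differs:", not np.allclose(local_means[0], global_mean))
```

The communication is small (three short vectors per layer), but it forces every worker to pause at every batch norm layer until the exchange completes.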

Overview of how data parallelism works, with batch statistics communicated between the batch norm layers on each machine.

Paper Contributions

  1. The authors propose Adaptive Gradient Clipping (AGC), which clips gradients based on the unit-wise ratio of gradient norms to parameter norms, and they demonstrate that AGC allows us to train Normalizer-Free Networks with larger batch sizes and stronger data augmentations.
  2. The authors have designed a family of Normalizer-Free ResNets, called NFNets, which set new state-of-the-art validation accuracies on ImageNet for a range of training latencies. The NFNet-F1 model achieves similar accuracy to EfficientNet-B7 while being 8.7× faster to train, and the largest model sets a new overall state of the art without extra data, with 86.5% top-1 accuracy.
  3. The authors show that NFNets achieve substantially higher validation accuracies than batch-normalised networks when fine-tuning on ImageNet after pre-training on a large private dataset of 300 million labelled images. The best model achieves 89.2% top-1 after fine-tuning.

Adaptive Gradient Clipping (AGC)

Gradient clipping is often used in language modelling to stabilise training, and recent work shows that it allows training with larger learning rates than plain gradient descent. Gradient clipping is typically performed by constraining the norm of the gradient. Specifically, for gradient vector G = ∂L/∂θ, where L denotes the loss and θ denotes a vector with all model parameters, the standard clipping algorithm clips the gradient before updating θ as:

G → λ · G / ‖G‖ if ‖G‖ > λ, and G → G otherwise.

Formula snippet is taken from page 4 of the paper.
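Here is a minimal NumPy sketch of the standard rule: all gradients are treated as one vector G, and if ‖G‖ > λ the whole vector is rescaled to have norm exactly λ; otherwise it is left alone. The function names and the toy gradients are hypothetical, for illustration only. Notice that the same λ applies no matter how large the weights are, which is the non-adaptivity discussed next.

```python
import numpy as np

def global_norm(grads):
    """L2 norm of all gradient arrays viewed as one concatenated vector G."""
    return np.sqrt(sum(np.sum(g ** 2) for g in grads))

def clip_by_global_norm(grads, clip_threshold):
    """Standard clipping: if ||G|| > lambda, rescale G to lambda * G / ||G||."""
    norm = global_norm(grads)
    if norm > clip_threshold:
        return [g * (clip_threshold / norm) for g in grads]
    return grads  # small gradients pass through unchanged

# Toy gradients for two hypothetical parameter tensors; joint norm is sqrt(3^2 + 4^2) = 5.
grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]
clipped = clip_by_global_norm(grads, clip_threshold=1.0)
untouched = clip_by_global_norm(grads, clip_threshold=10.0)

print("norm before:", global_norm(grads))
print("norm after clipping at 1.0:", global_norm(clipped))
```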

During training it usually isn't good for the optimiser to take giant jumps, so gradient clipping simply says that whenever the gradient is very large, we rescale it down to a maximum norm. If the gradient is good we're surely going to see it again, but if it's a bad gradient we want to limit its impact. The problem is that training becomes very sensitive to the clipping threshold λ, because λ is a fixed number: it does not adapt to the scale of the parameters being updated.

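What AGC does is clip each unit's gradient based on the ratio of how large the gradient is to how large the weights it acts upon are, instead of on the gradient norm alone. Concretely, for the i-th unit (row) of a layer's weight matrix W with gradient G, if ‖G_i‖ / ‖W_i‖* > λ, AGC rescales G_i → λ · (‖W_i‖* / ‖G_i‖) · G_i, where ‖·‖ is the Frobenius norm and ‖W_i‖* = max(‖W_i‖, ε) with ε = 10⁻³, which prevents zero-initialised parameters from always having their gradients clipped to zero. A gradient is therefore only considered ‘too large’ relative to the parameters it updates. It might seem confusing at first, but I would encourage you to have a look at this paper and read page 4 thoroughly to understand AGC with more clarity.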
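A minimal NumPy sketch of unit-wise AGC for a single fully connected weight matrix of shape (units, fan_in), where each row is one unit. It rescales a row's gradient only when ‖G_i‖ / max(‖W_i‖, ε) exceeds λ, and then only far enough to bring that ratio down to λ. The function name, the λ = 0.01 default and the restriction to 2-D matrices are assumptions made for this illustration (for a convolution kernel you would take the norm over all fan-in dimensions of each output channel); ε = 10⁻³ follows the paper.

```python
import numpy as np

def adaptive_grad_clip(weight, grad, clip_factor=0.01, eps=1e-3):
    """Unit-wise AGC for a weight matrix of shape (units, fan_in); each row is one unit."""
    # ||W_i||*, floored at eps so zero-initialised weights can still receive updates.
    w_norm = np.maximum(np.linalg.norm(weight, axis=1, keepdims=True), eps)
    g_norm = np.linalg.norm(grad, axis=1, keepdims=True)
    max_norm = clip_factor * w_norm
    # Rescale only rows whose gradient-to-weight norm ratio exceeds clip_factor.
    scale = np.where(g_norm > max_norm, max_norm / np.maximum(g_norm, 1e-12), 1.0)
    return grad * scale

weight = np.array([[1.0, 0.0],
                   [0.0, 10.0]])   # unit norms: 1 and 10
grad = np.array([[0.5, 0.0],
                 [0.0, 0.05]])     # ratios ||G_i|| / ||W_i||: 0.5 and 0.005
clipped = adaptive_grad_clip(weight, grad, clip_factor=0.01)
print(clipped)
```

Here the first row's gradient is scaled down to norm 0.01 (λ times its weight norm of 1), while the second row, whose gradient is tiny compared with its weights, passes through unchanged. Standard clipping with one global λ cannot make that per-unit distinction.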

The clipping threshold λ is a scalar hyper-parameter which must be tuned. Empirically, the authors found that while this clipping algorithm enabled them to train at higher batch sizes than before, training stability was extremely sensitive to the choice of the clipping threshold, requiring fine-grained tuning when varying the model depth, the batch size, or the learning rate. Normalized optimizers (such as LARS), by contrast, ignore the scale of the gradient entirely by choosing an adaptive learning rate inversely proportional to the gradient norm.

Note that the optimal clipping parameter λ may depend on the choice of optimiser, learning rate and batch size. Empirically, the authors find λ should be smaller for larger batches.

Ablations for Adaptive Gradient Clipping (AGC)

Image is taken from page 4 of the paper.

For example, if you compare the networks in graph 1, in particular NF-ResNet and NF-ResNet + AGC, you can see that beyond a certain batch size (2048) the network without AGC simply collapses while the one with AGC keeps performing well. This seems to be the recipe for going to higher batch sizes. However, the authors point out that AGC is still sensitive to the clipping threshold λ. In graph 2, you can see that the best λ depends crucially on the batch size: at small batch sizes you can get away with a fairly large clipping threshold, but at large batch sizes you have to keep the threshold very low, because if you clip at a higher threshold, training collapses.


References

  1. High-Performance Large-Scale Image Recognition Without Normalization, 11 Feb 2021.

