Based on this paper: https://arxiv.org/pdf/1502.03167.pdf
📌 Note: if you see an asterisk next to an italicized phrase, that means more info can be found in the appendix at the bottom of the article.
😕 The Problem
Internal covariate shift: The change in the distribution of network activations due to the change in network parameters during training
- Requires lower learning rates and careful parameter initialization, resulting in slower training
- Makes it hard to train models with saturating nonlinearities*
Matryoshka dolls are nested figurines, which get smaller and less ornate as you remove each layer. We can imagine a deep neural network like these dolls. That is to say, let’s think of the second through Nth layers as a smaller, less complex neural network nested inside the outermost one. Then the third through Nth layers would be another neural network within the former! And so on. With this view in mind, we can form an intuitive understanding of the problem.
Picture a deep neural network in training as a set of matryoshka dolls. The data starts at the “outermost network (doll)”, whose parameters are updated at every training step, so the distribution of the data it passes inward keeps changing. This is the key dilemma! Each nested network applies its own changes to the data, and these changes compound, so the distribution of inputs reaching the “innermost network” can shift wildly over the course of training. This makes training very inefficient, because each nested network must keep re-adapting to a moving input distribution instead of learning from a stable one.
💡 The Solution(s)
Batch normalization: A mechanism to reduce internal covariate shift by adding a step to training that normalizes (i.e. fixes the means and variances of) layer inputs
- Reduces dependence of gradients on the scale of the parameters or of their initial values
- Makes it possible to use saturating nonlinearities
- Allows for higher learning rates without risk of divergence
- Reduces the need for dropout
Whitening the Inputs
One way to get around internal covariate shift is to whiten the inputs, which means linearly transforming them to have zero means and unit variances* and to be decorrelated. However, fully whitening each layer's inputs is computationally expensive (it requires computing a covariance matrix and its inverse square root) and is not everywhere differentiable. The latter is a problem for backpropagation, which relies on taking derivatives to calculate the changes to layer parameters.
Normalization via Mini-Batch Statistics
The authors of the linked paper make two simplifications:
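- Instead of whitening the features in layer inputs and outputs jointly, they normalize each scalar feature independently to have zero mean and unit variance. To make sure the transformation inserted in the network can represent the identity transform, each normalized activation is then multiplied by a learned scale (γ) and shifted by a learned offset (β). If the network learns γ equal to the original standard deviation and β equal to the original mean, the original activations are recovered.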
- Since mini-batches are used in stochastic gradient training, each mini-batch produces estimates of the mean and variance of each activation. Thus, the statistics used for normalization can fully participate in the gradient backpropagation.
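Written out, the two simplifications give the paper's Batch Normalizing Transform (its Algorithm 1). It applies to a single activation x over a mini-batch ℬ = {x_1, …, x_m}, with a small constant ε added for numerical stability and learned parameters γ and β:

```latex
\begin{aligned}
\mu_{\mathcal{B}} &= \frac{1}{m}\sum_{i=1}^{m} x_i \\
\sigma_{\mathcal{B}}^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_{\mathcal{B}}\right)^2 \\
\hat{x}_i &= \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} \\
y_i &= \gamma \hat{x}_i + \beta
\end{aligned}
```

Choosing γ = √(σ_ℬ² + ε) and β = μ_ℬ gives back y_i = x_i. This is how the transform can represent the identity.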
Using the chain rule (the details won’t be covered here), the authors then show two things about the batch normalization transform. It is differentiable, so gradients can flow through it during backpropagation, and it introduces normalized activations into the network.
In practice, this means that as the model trains, each layer’s input distribution experiences less internal covariate shift. The “nested networks” (and therefore the entire network) avoid the problems outlined in the previous section, and the network learns more efficiently.
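The training-time forward pass described in this section can be sketched in pure Python. The sketch assumes each example is a list of feature values. The function name `batch_norm` and the toy data are illustrative, and the sketch leaves out the backward pass and the moving-average statistics the paper uses at inference time.

```python
import math

def batch_norm(batch, gamma, beta, eps=1e-5):
    """Training-mode batch normalization (illustrative sketch).

    batch: list of m examples, each a list of d feature values.
    gamma, beta: per-feature learned scale and shift (length d).
    Each feature (column) is normalized independently with the
    mean and variance of that feature over the mini-batch.
    """
    m = len(batch)
    d = len(batch[0])
    out = [[0.0] * d for _ in range(m)]
    for k in range(d):
        column = [example[k] for example in batch]
        mean = sum(column) / m
        var = sum((x - mean) ** 2 for x in column) / m  # 1/m variance, as in the transform
        for i, x in enumerate(column):
            x_hat = (x - mean) / math.sqrt(var + eps)
            out[i][k] = gamma[k] * x_hat + beta[k]
    return out

# Toy mini-batch: 4 examples, 2 features on very different scales.
batch = [[1.0, 200.0], [2.0, 220.0], [3.0, 180.0], [4.0, 240.0]]
normalized = batch_norm(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])

# For this batch, column means are 2.5 and 210.0 and variances are 1.25 and 500.0.
# Setting gamma = sqrt(var + eps) and beta = mean undoes the normalization.
identity = batch_norm(batch,
                      gamma=[math.sqrt(1.25 + 1e-5), math.sqrt(500.0 + 1e-5)],
                      beta=[2.5, 210.0])
```

After normalization, both columns of `normalized` have mean about 0 and variance about 1, even though the raw features differ in scale by two orders of magnitude. `identity` reproduces `batch`.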
Higher Learning Rates
In a deep network without normalization, a learning rate that is too high may result in:
- Exploding/vanishing gradients
- Getting stuck in local (rather than global) minima
Batch Normalization addresses these issues because normalizing activations means that:
- Backpropagation through a layer is unaffected by the scale of its parameters
- Changes in the parameters of one layer will not have an outsized effect on the next layer’s ability to learn its parameters
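The first point above follows from a property the paper highlights: scaling a layer's weights by a constant does not change the normalized output, since BN(Wu) = BN((aW)u). Below is a small sketch checking this for one unit. The helper names and numbers are made up for illustration, and the scale and shift γ and β are left at 1 and 0.

```python
import math

def normalize(values, eps=1e-5):
    # Normalize one activation across a mini-batch (gamma = 1, beta = 0).
    m = len(values)
    mean = sum(values) / m
    var = sum((v - mean) ** 2 for v in values) / m
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def pre_activations(weights, inputs):
    # Dot product of one weight vector with each input example in the batch.
    return [sum(w * x for w, x in zip(weights, example)) for example in inputs]

inputs = [[0.5, -1.0], [1.5, 2.0], [-0.5, 0.0], [2.0, 1.0]]
w = [0.3, -0.7]
scaled_w = [10.0 * wi for wi in w]  # same direction, 10x the scale

out = normalize(pre_activations(w, inputs))
out_scaled = normalize(pre_activations(scaled_w, inputs))
# out and out_scaled agree up to the tiny effect of eps
```

Because the layer's output does not depend on the weight scale, the gradient passed back to the layer's inputs does not depend on it either.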
To further understand this, let’s use an analogy: picture a snowball gaining mass as it rolls down a hill. Without normalization, data just “snowballs” its way through a network, meaning that later layers may have to learn with really large (or really small) input values. With normalization, the output of a layer is mapped to fit within a certain range before being handed off as input to the next layer (in our analogy, maybe it’s an exceptionally sunny day, and the snowball is melting at the same rate it accumulates more snow).
Regularizing the Model
When training with Batch Normalization, the network does not produce a deterministic output for a given training example. The normalized values depend on which other examples happen to share its mini-batch, so the same example can produce different values from one epoch to the next. The authors found that this effect helps the network generalize and reduces the need for dropout to prevent over-fitting.
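Here is a minimal sketch of that effect. The same activation value is normalized as part of two different made-up mini-batches, with γ = 1 and β = 0 assumed for simplicity.

```python
import math

def normalize(values, eps=1e-5):
    # Normalize one activation across a mini-batch (gamma = 1, beta = 0).
    m = len(values)
    mean = sum(values) / m
    var = sum((v - mean) ** 2 for v in values) / m
    return [(v - mean) / math.sqrt(var + eps) for v in values]

example = 2.0  # the same training example's activation
batch_a = [example, 1.0, 3.0, 0.0]
batch_b = [example, 5.0, 6.0, 4.0]

value_in_a = normalize(batch_a)[0]  # ≈ 0.447
value_in_b = normalize(batch_b)[0]  # ≈ -1.521
```

The example's own input never changed. Only its batch-mates did. This batch-dependent "noise" is what gives Batch Normalization its regularizing effect.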
Note: The authors acknowledge that the precise effect of Batch Normalization on gradient propagation remains an area of further study. Likewise, the recommendation to remove dropout is based on their own experiments rather than on theory. Many other machine learning techniques also rest on empirical evidence, sometimes more than on theory. ¯\_(ツ)_/¯
Accelerating Batch Normalization Networks
The authors recommend the following changes to take full advantage of their proposed method:
- Increase learning rate
- Remove Dropout
- Reduce the L2 weight regularization*
- Accelerate the learning rate decay (i.e., how quickly the learning rate decreases in value)
- Remove local response normalization*
- Shuffle training examples more thoroughly
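- Reduce photometric distortions of the training images (because Batch Normalization networks train faster and see each training example fewer times, it is better to let them focus on more “real” images)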
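To illustrate "accelerate the learning rate decay", here is a sketch of an exponential decay schedule. The paper reports decaying the learning rate exponentially and lowering it 6 times faster than in the original Inception network. The initial rate, decay rate, and step counts below are hypothetical values chosen for illustration, not the paper's hyperparameters.

```python
def exponential_decay(initial_lr, decay_rate, step, decay_steps):
    # Smooth exponential decay: the rate is multiplied by decay_rate
    # once every decay_steps training steps.
    return initial_lr * decay_rate ** (step / decay_steps)

# Hypothetical hyperparameters for illustration only.
initial_lr = 0.1
decay_rate = 0.9
baseline_decay_steps = 6000
faster_decay_steps = baseline_decay_steps // 6  # decays 6 times faster

steps = [0, 3000, 6000, 12000]
baseline = [exponential_decay(initial_lr, decay_rate, s, baseline_decay_steps) for s in steps]
accelerated = [exponential_decay(initial_lr, decay_rate, s, faster_decay_steps) for s in steps]
```

Both schedules start at the same rate. After step 0, the accelerated schedule is always lower, because it shrinks the rate over one sixth as many steps.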
🧽 What does “saturating nonlinearities” mean?
From the Wikipedia page on saturation arithmetic:
Saturation arithmetic is a version of arithmetic in which all operations such as addition and multiplication are limited to a fixed range between a minimum and maximum value.
Intuitively, something that is saturated — like a fully soaked sponge — cannot have any more added to it. Applying that analogy, a saturating function stops growing/shrinking as its inputs approach positive/negative infinity.
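To make the definition concrete, here is a sketch of saturating addition on the 8-bit range [0, 255], the same range as the pixel values in the next appendix entry. The function name and range are illustrative choices.

```python
def saturating_add(a, b, lo=0, hi=255):
    # Add normally, then clamp the result to [lo, hi].
    return max(lo, min(hi, a + b))

full = saturating_add(200, 100)    # 255: the "sponge" is full
empty = saturating_add(10, -50)    # 0: clamped at the minimum
inside = saturating_add(100, 50)   # 150: inside the range, unchanged
```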
Nonlinearities refers to the activation functions we apply to the outputs of our neural network’s layers. Putting the two ideas of saturation arithmetic and nonlinear activation functions together, let’s look at some examples:
- The Rectified Linear Unit (ReLU) function, f(x) = max(0, x), is considered non-saturating because as x → ∞, f(x) → ∞ (although it is flat at 0 for negative inputs). Using the sponge imagery, it would be like a sponge that can soak up an infinite amount of water, which is not very realistic.
- The sigmoid function, f(x) = 1 / (1 + e^(-x)), is saturating because as x → ∞, f(x) → 1 and as x → -∞, f(x) → 0. This would behave like an actual sponge: you can keep adding more water, but once it is completely soaked, it is unable to hold more.
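Saturation matters for training because a saturating function's gradient shrinks toward zero in its flat regions. A layer whose inputs drift there learns very slowly, and keeping inputs in a stable range is how Batch Normalization helps. The sketch below compares the sigmoid's derivative, s(x)(1 − s(x)), with ReLU's. The function names are illustrative.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    # Derivative of the sigmoid: s(x) * (1 - s(x)), which peaks at 0.25 when x = 0.
    s = sigmoid(x)
    return s * (1.0 - s)

def relu_grad(x):
    # Derivative of ReLU away from x = 0: 1 for positive inputs, 0 for negative inputs.
    return 1.0 if x > 0 else 0.0

for x in (0.0, 2.0, 5.0, 10.0):
    print(f"x={x:5.1f}  sigmoid'={sigmoid_grad(x):.6f}  relu'={relu_grad(x):.0f}")
```

By x = 10 the sigmoid's gradient has fallen below 10⁻⁴, while ReLU's gradient stays at 1 for any positive input.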
📈 What does “zero means and unit variances” mean?
When we say “zero means and unit variances”, we are talking about data that has a mean of zero and a variance (and therefore a standard deviation) of one. Converting data to meet these properties is called standardization, which is one form of feature scaling.
For example, let’s say we have a black-and-white image that is 300 by 300 pixels, and each pixel can take on values in the range [0, 255] (e.g., pixel 1 is 128, pixel 2 is 60, pixel 3 is 207, …).
- To get a mean of zero, we subtract the average pixel value from every pixel, so that the new average is zero. (This means that some values will become negative!)
- Then, to get unit variance, we divide each shifted value by the standard deviation of the pixel values. This does not squash the values into [-1, 1]: any pixel more than one standard deviation from the average ends up outside that range.
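Here is a minimal sketch of standardization, using only the three example pixel values listed above rather than a full 300 by 300 image. It divides by the standard deviation instead of squashing values into a fixed range.

```python
import statistics

# Toy "image": just the three example pixel values, not a full 300x300 image.
pixels = [128, 60, 207]

mean = statistics.fmean(pixels)
std = statistics.pstdev(pixels)  # population standard deviation

standardized = [(p - mean) / std for p in pixels]
# roughly [-0.061, -1.193, 1.254]: two of these fall outside [-1, 1]
```

The result has mean 0 and standard deviation 1. Two of the values still fall outside [-1, 1].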
Empirically, it has been found that gradient descent converges much faster with feature scaling than without it. So, machine learning engineers/scientists often pre-process data before handing it to a neural network.
High-level descriptions of some methods used to do that can be found on the Wikipedia page for feature scaling.
🏋️‍♀️ What is “L2 weight regularization”?
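L2 weight regularization (the same penalty used in “Ridge Regression”) adds the sum of the squared weights, multiplied by a regularization strength λ, as a penalty term to the loss function. Because large weights are penalized, each gradient step shrinks the weights toward zero (hence the name “weight decay”), which helps prevent overfitting.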
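Here is a sketch of the penalty and of why it "decays" weights under plain gradient descent. The penalty is written as λ·Σw², so its gradient is 2λw. Some texts use λ/2 instead so the gradient is simply λw. The helper names and numbers are hypothetical.

```python
def l2_penalty(weights, lam):
    # lam * sum of squared weights, added to the data loss.
    return lam * sum(w * w for w in weights)

def sgd_step(weights, data_grads, lr, lam):
    # The gradient of lam * w**2 is 2 * lam * w, so each update shrinks ("decays")
    # every weight by the factor (1 - 2 * lr * lam), then applies the data gradient.
    return [(1 - 2 * lr * lam) * w - lr * g for w, g in zip(weights, data_grads)]

weights = [2.0, -1.0, 0.5]
penalty = l2_penalty(weights, lam=0.5)
# With a zero data gradient, only the penalty acts: the weights shrink toward zero.
after = sgd_step(weights, data_grads=[0.0, 0.0, 0.0], lr=0.1, lam=0.5)
# after ≈ [1.8, -0.9, 0.45] (each weight multiplied by 1 - 2 * 0.1 * 0.5 = 0.9)
```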
Additional information can be found at this Towards Data Science article.
🔲 What is “Local Response Normalization”?
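Local Response Normalization (LRN) is an older normalization method, used in AlexNet. It divides each activation by a term computed from the squared activations of neighboring channels at the same spatial position. It has largely been eclipsed by Batch Normalization.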
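Here is a sketch of the cross-channel LRN formula from AlexNet, applied to the channel values at a single spatial position. Each activation a^i is divided by (k + α · Σ (a^j)²)^β, where the sum runs over the n channels centered on channel i. The defaults k = 2, n = 5, α = 10⁻⁴, β = 0.75 are the values AlexNet reported. The function name and channel values are illustrative. Library implementations can differ in details; PyTorch's `torch.nn.LocalResponseNorm`, for instance, divides α by n.

```python
def local_response_norm(activations, n=5, k=2.0, alpha=1e-4, beta=0.75):
    """Cross-channel LRN at a single spatial position (illustrative sketch).

    activations: list of N channel values a^0 .. a^(N-1) at one (x, y) location.
    Each output b^i = a^i / (k + alpha * sum of (a^j)^2 over the n channels
    centered on channel i, clipped at the first and last channel) ** beta.
    """
    N = len(activations)
    out = []
    for i, a_i in enumerate(activations):
        lo = max(0, i - n // 2)
        hi = min(N - 1, i + n // 2)
        sq_sum = sum(activations[j] ** 2 for j in range(lo, hi + 1))
        out.append(a_i / (k + alpha * sq_sum) ** beta)
    return out

channels = [1.0, 3.0, 0.5, 2.0, 4.0, 0.0, 1.5]
normalized = local_response_norm(channels)
```

Unlike Batch Normalization, LRN uses no mini-batch statistics and has no learned parameters. Each activation is scaled only by its neighbors in the same example.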
Additional information can be found at this Towards Data Science article.