13. Introduction to Deep Learning with Computer Vision — Learning Rates & Mathematics — Part 1
Written by Nilesh Singh and Praveen Kumar
Having spent enough time on Activation functions, Batch Normalization, Convolution types, Data Augmentation, and Receptive fields, we now understand how models use different strategies to learn the features of a class and improve that learning by capturing progressively more complex features. So far, however, we have simply taken it on faith that the model learns these features without a hitch. In this article, we will try to gain a deeper understanding of how the model actually learns.
We will dedicate a few articles (starting with this one) to learning rates: the mathematics behind them, popular learning rate methods, and finally the backpropagation algorithm along with its mathematics.
Let’s get started…
What is the Learning Rate?
It is essentially a hyperparameter that tells our model how fast it should try to capture features and learn about a class. For intuition, let’s say we have a 1-year-old baby and our task is to teach the baby to recognize an apple.
Now imagine you hold the apple in front of the baby for 1 second and hope it has seen and learned all the features of an apple; or you hold it for 1 minute and hope that this time the baby has learned it; or even for a day or a month. Depending on how long we hold up our metaphorical apple, the baby will learn different features of it. This rate at which we show the apple to the baby is, loosely, the learning rate parameter.
Now you might feel that the longer we show the apple, the better the baby learns the features. However, that is not the case. The longer we show the apple to the baby, the more slowly it absorbs each feature. The model would therefore take much longer to learn all the features, and we cannot keep waiting for ages just to recognize an apple. What happens when we show the apple for only 1 second? In this case, the baby is hard done by: it is just 1 year old, and asking it to learn all the features of an apple within 1 second is very harsh. There are other issues as well, depending on the learning rate value. We shall discuss them in detail as we move deeper.
Hence, choosing the right learning rate is very crucial for the model.
NOTE: In terms of models, a learning rate of 0.0001 is considered very slow (analogous to 1 month in the above example) and 0.9 is considered fast (analogous to 1 second in the above example). The smaller the value, the slower the learning phase, and vice versa.
Let’s now put on our technical hats and learn what all the fuss about the learning rate in a CNN is and what happens in the background.
Learning rate & Loss function relation:
The loss function is a method of evaluating how well your algorithm models your dataset. In other words, it measures how well your equation (network) fits the data at hand. If the equation fits the training data perfectly, the loss will be zero. As a rule of thumb, we don’t want that, because in most cases it would mean a horrible case of ‘overfitting’, or a first step in the creation of Ultron. So, in practice, we expect our loss to approach zero, but we remain wary of it ever actually reaching zero.
If your model’s predictions are off, the loss function will output a higher number; if they’re pretty good, it’ll output a lower number. By ‘predictions being off’ we mean the difference between the predicted model output and the actual expected output. A loss function computes this value using a mathematical formula, for example cross-entropy loss, mean-squared error, Huber loss, or hinge loss. If you want a deeper understanding of loss functions, do check out this excellent blog.
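To make this concrete, here is a minimal NumPy sketch (our own illustration, not from any particular library) of two of the losses mentioned above. Notice how predictions that are further off produce a larger number:

```python
import numpy as np

def mse(y_true, y_pred):
    """Mean-squared error: average of the squared differences."""
    return np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2)

def cross_entropy(y_true, y_pred, eps=1e-12):
    """Cross-entropy for a one-hot target and predicted probabilities."""
    y_pred = np.clip(np.asarray(y_pred), eps, 1.0)
    return -np.sum(np.asarray(y_true) * np.log(y_pred))

# Predictions that are "off" give a higher loss than good ones.
good = mse([1.0, 2.0, 3.0], [1.1, 1.9, 3.0])
bad  = mse([1.0, 2.0, 3.0], [3.0, 0.0, 6.0])

# Confident, correct class probabilities give a low cross-entropy.
ce = cross_entropy([0, 1, 0], [0.1, 0.8, 0.1])
```

Both functions map a (prediction, target) pair to a single non-negative number, which is exactly what the training loop will try to drive toward zero.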
Now things start getting interesting at this point. Let’s take a moment and recall what has happened to this point.
We have our data at hand; we wrote a network to fit that data; we entered the training phase; and after each subsequent forward pass, we get a model that is trying hard to fit the data. We then evaluate this model by computing a loss function, which is essentially a function of the predicted output and the expected output.
Alright, now let’s move forward.
From Fig 1, we understand that during the training phase, the loss will increase or decrease depending on the learning rate parameter.
The learning rate is essentially a value between 0 and 1 that scales the weight updates made when the network backpropagates. One can think of it as the magnitude of the weight updates that happen during backpropagation. It is a hyperparameter used in minimizing the loss function.
If the learning rate is very high, the loss can start to rise exponentially. Each learning rate has its pros and cons. One could think a low learning rate would be best; however, it would take a very long time to train such a network (think smaller weight updates, so slower convergence). Hence we look for a good learning rate value. There are several methods to help you find an optimal learning rate. We shall discuss them as we move along.
Now let’s try to understand the loss function more deeply to have a fair idea of why a small variation in learning rate value could be disastrous or beneficial for the model.
When we talk about the learning rate, we always refer to the way our weights are updated. Eventually, it’s the weights that matter to the model. These weights are updated after each training step. Here is the simple weight-update equation:
updated_weight= old_weight - learning_rate * derivative_of_loss
Mathematically, it can be represented as

W_new = W_old − α · ∂J/∂W

In this equation, W is the weight, α is the learning rate, and J is the loss function, which is partially differentiated with respect to the weight. The last term, ∂J/∂W, is called the gradient. The learning rate is multiplied by this derivative of the loss function: if the learning rate is high, the weights are decreased (or increased, depending on the sign of the gradient) by a large amount; if the learning rate is low, the weights are updated very slowly. Let’s visualize this equation along with our loss function curve.
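The update rule above can be sketched in a few lines of Python. The loss J(w) = (w − 3)² here is our own hypothetical example (its gradient is 2·(w − 3)), chosen only so that the minimum is known to sit at w = 3:

```python
def gradient_step(w, lr):
    """One gradient-descent step on the toy loss J(w) = (w - 3)**2."""
    grad = 2 * (w - 3)      # derivative of the loss wrt the weight
    return w - lr * grad    # updated_weight = old_weight - lr * gradient

w = 0.0
for _ in range(50):
    w = gradient_step(w, lr=0.1)
# after 50 steps, w has moved very close to the minimum at w = 3
```

Each step moves the weight a fraction (set by the learning rate) of the way along the negative gradient, which is exactly what the equation above describes.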
In the above Fig, the blue curve is imagined to be the loss curve (y-axis). We want to reach the bottom of this curve, where the loss is smallest; that point is called the minima. Along the x-axis, we have the weights.
Now let's visualize scenarios for both high and low learning rates.
As you can see, in the first part of the figure the learning rate is too small, so convergence takes a lot of time, i.e., many steps. This is not the only problem with this approach, as we will see further on. In the second part of the figure, the learning rate is too large, so our weights jump around instead of approaching the global minima.
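We can reproduce both behaviours numerically. In this sketch (a made-up example on the toy loss J(w) = w², whose gradient is 2w), a tiny learning rate crawls toward the minimum, a moderate one converges quickly, and a learning rate that is too large makes the weight overshoot and grow at every step:

```python
def run(lr, steps=30, w0=10.0):
    """Gradient descent on the toy loss J(w) = w**2; returns final |w|."""
    w = w0
    for _ in range(steps):
        w = w - lr * 2 * w   # gradient of w**2 is 2*w
    return abs(w)

small = run(lr=0.01)   # converges, but very slowly
decent = run(lr=0.3)   # converges quickly
large = run(lr=1.1)    # overshoots: |w| grows on every step
```

With lr = 1.1, each update multiplies the weight by (1 − 2·1.1) = −1.2, so the iterate oscillates in sign while its magnitude explodes, which is the "jumping around" in the figure taken to its extreme.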
We can visualize these as mountainous valleys, where we want to get to the bottom of the valley by navigating the tough landscape that lies ahead of us. With this analogy in mind, let’s try to understand some of the common terminology that you will be using throughout your journey of deep learning:
1. Local vs Global minima:
Let’s assume that you started your journey at point A and followed the path down to point B. Now when you look around you, it can safely be concluded that you have reached the absolute bottom of our valley, right? Hold that thought for a moment.
Now when we see the whole picture, point B is definitely not where we want to be; in fact, for all practical purposes we are trapped. Point B is called a local minima, and point E is called the global minima. Our goal should always be to reach the global minima. But it is generally not practical to reach a global minimum, so we at least target a local minimum that is close to the global one.
One interesting thing to note here is that if our learning rate is small, there is no way we can climb out of our local minimum B: the weight updates are simply too small to escape the basin.
2. Saddle Points:
If on your journey to the bottom of the valley, you find yourself stuck at a point like that red dot in the figure, then you are stuck at something called a saddle point.
Saddle points are points that have a zero gradient but are not extrema. That is just a fancy way of saying that a saddle point is neither a maximum nor a minimum. In the figure, if you move to the left or right, the loss increases, while in the other two directions, the loss decreases.
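The textbook example of such a point is the surface f(x, y) = x² − y² (our own illustration, not the figure’s exact surface): at the origin, the gradient vanishes, yet moving along x increases the loss while moving along y decreases it:

```python
# A classic saddle: f(x, y) = x**2 - y**2 has zero gradient at (0, 0),
# but (0, 0) is neither a maximum nor a minimum.
def f(x, y):
    return x**2 - y**2

def grad(x, y):
    """Gradient of f: (df/dx, df/dy)."""
    return (2 * x, -2 * y)

gx, gy = grad(0.0, 0.0)   # both components are zero at the saddle
up = f(0.1, 0.0)          # step along x: the loss goes up
down = f(0.0, 0.1)        # step along y: the loss goes down
```

Because the gradient is exactly zero at the saddle, a plain gradient-descent update makes no progress there at all, which is why these points can stall training.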
Saddle points are much more difficult to escape. You might ask: what’s the difficulty in that, just go in one of the directions where the loss decreases, right? Well… just look at point D in figure 6 above.
That’s it for this blog. In the next one, we’ll cover ways to adjust our learning rates so that our model can escape the loss plateaus.