Hyperparameter: Momentum

Yuanrui Dong
AI³ | Theory, Practice, Business
Sep 9, 2019 · 4 min read

When we use SGD (stochastic gradient descent, which in deep learning usually means mini-batch gradient descent) to train parameters, the loss sometimes decreases very slowly, and the optimizer may get stuck at a local minimum or at a point where the gradient is nearly zero, as shown in Fig 1 (the picture is from Li Hongyi, 《One Day Understanding Deep Learning》). This is not always what you want.

Fig 1. SGD without momentum

There are three kinds of places where this pseudo-optimal solution occurs:

  1. Plateau;
  2. Saddle point;
  3. Local minima.

Momentum is introduced to speed up learning, especially in the face of high curvature or small but consistent gradients, where plain gradient descent is slow. The main idea of momentum is to accumulate an exponentially decaying moving average of past gradients and keep moving in their direction.
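
As a rough sketch of what this means as an update rule (the names w, grad, v, lr, and beta below are illustrative, not from the article; PyTorch's built-in SGD with momentum uses a similar form), each step keeps a velocity v that decays old gradients by beta and adds the new one:

def sgd_momentum_step(w, grad, v, lr=0.1, beta=0.9):
    # v is an exponentially decaying sum of past gradients
    v = beta * v + grad
    # the parameter moves along the accumulated velocity instead of the raw gradient
    w = w - lr * v
    return w, v

With beta = 0 this reduces to plain SGD, while a larger beta lets past gradients carry the update across flat regions and noisy gradients.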

Fig 2. SGD with momentum

Using momentum does not guarantee that we reach the optimum. Fig 2 (the picture is from Li Hongyi, 《One Day Understanding Deep Learning》) illustrates this with four points: red represents the direction of gradient descent, dotted green represents the direction of momentum, and blue represents the actual direction of movement.

At the first point, the gradient descent direction points to the right. Since the momentum is initialized to zero, nothing has been accumulated yet, so the actual movement simply follows the gradient descent direction.

At the second point, the gradient still points down and to the right, and the ball now also carries momentum from moving right, which keeps pushing it to the right.

At the third point, which is a local minimum, the gradient is 0. Plain gradient descent would get stuck here, but the accumulated momentum still points to the right, so with momentum the ball keeps moving to the right.

At the fourth point, the gradient now points to the left. If the momentum here is larger than the gradient, the ball continues in the direction of momentum; it may even climb over the hill and escape the local minimum.

To better understand momentum, we create a random signal and plot the results for beta values of 0.5, 0.7, 0.9, and 0.99.

>>> import torch
>>> x = torch.linspace(-4, 4, 200)
>>> y = torch.randn(200) + 0.3
>>> betas = [0.5, 0.7, 0.9, 0.99]

The regular momentum function is shown below:

res = beta*avg + yi

Here res is the new result, beta is the momentum value, avg is the result from the previous step (the running accumulation, which starts at about 0.3, the mean of the signal, in this example), and yi is the i-th value of the signal.
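
The curves in Fig 3 could be produced with a loop like the following minimal sketch, reusing x, y, and betas from the snippet above (the loop structure and plotting details are my assumptions, not the article's exact code):

import matplotlib.pyplot as plt

for beta in betas:
    avg, res = 0.3, []                   # start the running average at 0.3, the signal mean
    for yi in y.tolist():
        avg = beta * avg + yi            # regular momentum: the new value is added at full weight
        res.append(avg)
    plt.plot(x, res, label=f"beta={beta}")
plt.legend()
plt.show()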

Fig 3. Random result with regular momentum function

The result is shown in Fig 3. When beta is small, the curve is still a little bumpy. When beta is 0.99, the answer is totally wrong. The reason is that with high momentum the accumulation grows far too large: each step adds the full new value on top of beta times the previous result, so for a roughly constant signal c the result settles around c/(1 - beta), about 100 times too high when beta is 0.99. Applied to gradients, this means the update is biased to be much larger than the actual gradient.

We can fix that by modifying the momentum function so that the new value is damped by (1 - beta):

res = beta*avg + (1-beta)*yi

This is exactly the familiar exponentially weighted moving average: it dampens the new value that we are adding in. The result is shown in Fig 4.
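
A sketch of the modified version, again reusing x, y, betas, and plt from the previous sketch (only the update line really changes; initializing avg with the first value is an assumption, consistent with the step-by-step description given later for Fig 5):

for beta in betas:
    avg, res = None, []
    for yi in y.tolist():
        # exponentially weighted moving average: damp the new value by (1 - beta)
        avg = yi if avg is None else beta * avg + (1 - beta) * yi
        res.append(avg)
    plt.plot(x, res, label=f"beta={beta}")
plt.legend()
plt.show()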

Fig 4. Random result with the modified momentum function

We can see that the curve settles to a roughly constant value (around the mean of the data) when the data is purely random noise.

But in practice we usually deal with a signal that has real structure, so does this still work well? To find out, we create a new signal, keeping the same momentum function res = beta*avg + (1-beta)*yi.

>>> y = 1 - (x/3) ** 2 + torch.randn(200) * 0.1
>>> y[0] = 0.5

At the start, we add an outlier. The result is shown in Fig 5, and it shows that we are in trouble. When beta is 0.99, the second output of the momentum function is 0.99 times the first item plus 0.01 times the second item, so the first item (the outlier) massively biases the curve, and it takes a very long time to get back close to the real values.

There is also the problem that we are always running a bit behind where we should be, because at each step we only take a small fraction, (1 - beta), of the new value.

Fig 5. Function result with momentum

To get a better result, we can use debiasing. Debiasing corrects the misleading information we may have from the very first batches.

And the correction is:

res = avg/(1-beta**(i+1))

where avg is the exponentially weighted moving average after step i (counting from 0), so the divisor undoes the shrinkage caused by starting the average at zero.
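
A sketch of the debiased version on this new signal, reusing the variables above (here the average deliberately starts at zero, since the division is exactly what corrects that zero start):

for beta in betas:
    avg, res = 0.0, []
    for i, yi in enumerate(y.tolist()):
        avg = beta * avg + (1 - beta) * yi        # same EWMA update as before
        res.append(avg / (1 - beta ** (i + 1)))   # debiasing correction
    plt.plot(x, res, label=f"beta={beta}")
plt.legend()
plt.show()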

The result is shown in Fig 6. It's pretty good: with debiasing, the curve locks onto the real signal very quickly, even with a bad starting point.

Fig 6. Function result with debiasing
