Visualizing Gradient Descent with Momentum in Python

This post visually demonstrates that gradient descent with momentum can converge faster than vanilla gradient descent when the loss surface is ravine-like, i.e., substantially steeper in one direction than in the others. Let’s start by creating a ravine loss surface and projecting it onto two dimensions.

L(x, y) = (1/16) x² + 9 y²

The figure below shows the ravine loss surface. It has a global minimum at the point (0, 0), where the loss is equal to zero.

Loss surface for L(x, y) = (1/16) x² + 9 y²
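For reference, the gradient of this loss is ∇L(x, y) = (x/8, 18y). The sketch below defines the loss and its gradient in plain Python, so it runs without NumPy. The function names `loss` and `grad` are my own choice and are not taken from the original repo. It also makes the ravine shape explicit: the curvature along y (18) is 144 times the curvature along x (1/8).

```python
def loss(x: float, y: float) -> float:
    """Ravine loss L(x, y) = x**2 / 16 + 9 * y**2."""
    return x**2 / 16 + 9 * y**2


def grad(x: float, y: float) -> tuple[float, float]:
    """Analytic gradient: dL/dx = x / 8, dL/dy = 18 * y."""
    return x / 8, 18 * y


# Second derivatives (curvatures) along each axis: 1/8 in x, 18 in y.
curvature_x, curvature_y = 1 / 8, 18
print(curvature_y / curvature_x)  # 144.0: y curves far more sharply than x

print(loss(0.0, 0.0))   # 0.0 at the global minimum
print(grad(-2.4, 0.2))  # gradient at the starting point (-2.4, 0.2)
```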

Gradient Descent with Momentum

Gradient descent is an optimization algorithm that finds a minimum of a given function by repeatedly stepping against its gradient. In machine learning, we use gradient descent to navigate the loss surface. The problem with vanilla gradient descent is that it converges slowly on certain loss surfaces, and one common trick for faster convergence is to add momentum.

The momentum method accumulates velocity in directions where the gradient keeps pointing the same way across iterations. It does this by adding a fraction of the previous weight update to the current one. We first show the math of momentum and then run some experiments to see whether momentum really helps us converge faster.

The following equations give the gradient descent with momentum update, where v is the velocity, w the weights and 𝛼 the learning rate. β, which lies in [0, 1), is the fraction of the previous weight update carried over into the current one. When β = 0, the update reduces to vanilla gradient descent.

v = βv + ∇L(w)
w = w − 𝛼v
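The two-line update translates almost directly into code. Below is a minimal sketch that assumes the ravine loss above, with gradient ∇L(x, y) = (x/8, 18y). The helpers `grad` and `run_momentum` are hypothetical names for illustration, not the original repo's code. The velocity starts at zero, and passing `beta=0.0` recovers vanilla gradient descent.

```python
def grad(w):
    """Gradient of L(x, y) = x**2 / 16 + 9 * y**2 at w = (x, y)."""
    x, y = w
    return (x / 8, 18 * y)


def run_momentum(w0, alpha, beta, iters):
    """Gradient descent with momentum: v = beta*v + grad(w); w = w - alpha*v.

    Returns the weights and velocities at every iteration
    (index 0 is the initial state, with zero velocity).
    """
    w = list(w0)
    v = [0.0, 0.0]                                     # velocity starts at zero
    ws, vs = [tuple(w)], [tuple(v)]
    for _ in range(iters):
        g = grad(w)
        v = [beta * vi + gi for vi, gi in zip(v, g)]   # accumulate velocity
        w = [wi - alpha * vi for wi, vi in zip(w, v)]  # step along the velocity
        ws.append(tuple(w))
        vs.append(tuple(v))
    return ws, vs


# beta = 0.0 is plain gradient descent; beta = 0.8 adds momentum.
ws_vanilla, _ = run_momentum((-2.4, 0.2), alpha=0.1, beta=0.0, iters=50)
ws_momentum, _ = run_momentum((-2.4, 0.2), alpha=0.1, beta=0.8, iters=50)
```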

Let’s first initialize the weights at (-2.4, 0.2) on the ravine loss surface we created earlier. Next, we try vanilla gradient descent (β = 0) with a learning rate 𝛼 = 0.1, run it for 50 iterations and see how the loss decays.

Vanilla gradient descent, β = 0

The loss is around 0.1 after 50 iterations, but the weights are still far from the global minimum at (0, 0), where the loss is zero. The gradient keeps changing direction because its y-component changes sign at every iteration: the steep y direction overshoots back and forth across the ravine, while the gradient along the shallow x direction is so small that the weights creep only slowly towards the optimum. We can add momentum to accelerate learning. Let’s try β = 0.8 and run the same number of iterations.
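Because this loss is quadratic, vanilla gradient descent acts on each coordinate independently: x_t = (1 − 𝛼/8)^t x_0 and y_t = (1 − 18𝛼)^t y_0. With 𝛼 = 0.1 the y-multiplier is −0.8, which is exactly the per-iteration sign flip in y. The x-multiplier is 0.9875, so x shrinks by only 1.25% per step. The sketch below works through this arithmetic. The function names are hypothetical, and the start point defaults to (-2.4, 0.2).

```python
def vanilla_factors(alpha: float) -> tuple[float, float]:
    """Per-iteration multipliers for x and y under vanilla GD on x**2/16 + 9*y**2."""
    return 1 - alpha / 8, 1 - 18 * alpha


def loss_after(alpha: float, iters: int, x0: float = -2.4, y0: float = 0.2) -> float:
    """Closed-form loss after `iters` vanilla GD steps."""
    fx, fy = vanilla_factors(alpha)
    x, y = fx**iters * x0, fy**iters * y0
    return x**2 / 16 + 9 * y**2


fx, fy = vanilla_factors(0.1)
print(fx, fy)               # about 0.9875 and -0.8: y flips sign every step
print(loss_after(0.1, 50))  # about 0.102, i.e. "around 0.1"

# Raising the learning rate does not rescue x: once alpha > 1/9,
# |1 - 18*alpha| > 1 and the y-coordinate diverges instead.
print(abs(vanilla_factors(0.12)[1]))  # about 1.16, i.e. > 1
```

The closed form is only possible because the two coordinates do not interact in this loss. On a general loss surface you would simply iterate the update.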

Gradient descent with momentum, β = 0.8

We now reach a loss of 2.8e-5 in the same number of iterations using momentum! Because the x-component of the gradient keeps the same sign from one iteration to the next (the descent direction always points towards positive x, i.e., towards the minimum), the velocity in x accumulates. Effectively, this gives a larger learning rate in the x direction, so momentum reaches a lower loss than vanilla gradient descent in the same number of iterations. This is also why momentum is said to dampen oscillations: gradient components that keep flipping sign tend to cancel in the velocity instead of accumulating.
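For this quadratic loss, momentum can be written per coordinate. Substituting 𝛼v_t = w_{t−1} − w_t into the update gives w_{t+1} = (1 + β − 𝛼h) w_t − β w_{t−1}, where h is the curvature along that axis (1/8 for x, 18 for y). With β = 0.8 and 𝛼 = 0.1, the y-coefficient 1 + β − 𝛼h is exactly zero, so y_{t+1} = −0.8 y_{t−1}. Meanwhile, a gradient g that stays roughly constant drives v towards g/(1 − β). This makes the effective step 𝛼/(1 − β), five times 𝛼 for β = 0.8. The sketch below (hypothetical helper `momentum_coordinate`, assuming zero initial velocity) evaluates the recurrence to check the order of magnitude of the loss.

```python
def momentum_coordinate(w0: float, h: float, alpha: float, beta: float,
                        iters: int) -> list[float]:
    """Trajectory of one coordinate of momentum GD on (h / 2) * w**2.

    Uses w_{t+1} = (1 + beta - alpha*h) * w_t - beta * w_{t-1}, derived from
    v = beta*v + h*w, w = w - alpha*v with zero initial velocity.
    """
    traj = [w0, (1 - alpha * h) * w0]  # first step is a plain gradient step
    c = 1 + beta - alpha * h
    while len(traj) <= iters:
        traj.append(c * traj[-1] - beta * traj[-2])
    return traj[: iters + 1]


alpha, beta, iters = 0.1, 0.8, 50
xs = momentum_coordinate(-2.4, 1 / 8, alpha, beta, iters)  # x**2/16 = (1/8)/2 * x**2
ys = momentum_coordinate(0.2, 18, alpha, beta, iters)      # 9*y**2 = 18/2 * y**2
print(xs[-1] ** 2 / 16 + 9 * ys[-1] ** 2)  # on the order of 1e-5

# Effective step size for a persistent gradient: alpha / (1 - beta)
print(alpha / (1 - beta))  # about 0.5, i.e. 5x the raw learning rate
```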

We might wonder: what if we choose β = 0.9? Wouldn’t it accumulate even more velocity and reach the global minimum faster?

Gradient descent with momentum, β = 0.9

The answer is no. The figure above shows that too much velocity makes the weights overshoot the global minimum with large steps. We can also plot velocity against iteration to see how the velocity changes in the x and y directions, respectively.

Velocity-iteration relation for weight x under different β. With β > 0, the velocity in the x direction accumulates.
Velocity-iteration relation for weight y under different β. The velocity in the y direction does not accumulate because the sign of the gradient keeps flipping.
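The trends in these velocity plots, and the overshoot at β = 0.9, can also be checked numerically. The sketch below is plain Python, and `run` is a hypothetical helper rather than the repo’s code. It records the x-trajectory and the velocities in both directions for β = 0, 0.8 and 0.9, using the same start (-2.4, 0.2), 𝛼 = 0.1 and 50 iterations.

```python
def run(beta: float, alpha: float = 0.1, iters: int = 50, w0=(-2.4, 0.2)):
    """Momentum GD on L(x, y) = x**2/16 + 9*y**2.

    Returns the x-trajectory and the x- and y-velocities at every iteration.
    """
    x, y = w0
    vx = vy = 0.0
    xs, vxs, vys = [x], [vx], [vy]
    for _ in range(iters):
        vx = beta * vx + x / 8     # dL/dx = x / 8
        vy = beta * vy + 18 * y    # dL/dy = 18 * y
        x, y = x - alpha * vx, y - alpha * vy
        xs.append(x)
        vxs.append(vx)
        vys.append(vy)
    return xs, vxs, vys


for beta in (0.0, 0.8, 0.9):
    xs, vxs, vys = run(beta)
    print(
        f"beta={beta}: peak |v_x|={max(map(abs, vxs)):.3f}, "
        f"peak |v_y|={max(map(abs, vys)):.3f}, "
        f"x passes the minimum: {max(xs) > 0}"
    )
```

With these settings, the x-velocity for β = 0.8 builds up to more than three times its vanilla peak. The y-velocity never exceeds its first value. Only β = 0.9 carries x past the minimum within 50 iterations.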

The Python code and the steps to generate these figures can be found in the GitHub repo. Matplotlib and NumPy are the required libraries.

Conclusion

We visually showed that gradient descent with momentum converges faster than vanilla gradient descent when the loss surface is ravine-like. We also learned that β cannot be too large, or the optimizer accumulates too much momentum and quickly runs past the optimum.

This post was motivated by Sebastian Ruder’s great overview of gradient descent optimization algorithms, which inspired me to visualize the momentum update in Python and understand it more clearly, and by David Robinson’s blog post advising every data scientist to start a blog. This is my first post, so any suggestions and comments are more than welcome!