Stochastic Gradient Descent in Deep Learning

Biswadip Mandal
Published in Analytics Vidhya
5 min read · Apr 14, 2021

Neural networks often consist of millions of weights for which we need to find the right values. Optimizing these networks with the available data requires careful consideration of which optimizer to choose. In this post, I will discuss Gradient Descent and Stochastic Gradient Descent (SGD), and why SGD is preferred in deep learning.

Gradient Descent iteratively updates the parameters X by stepping against the gradient of the loss L, scaled by a learning rate α:

for k in [1, number_iterations]:
    X(k+1) = X(k) - α ∇L(X(k))
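As a concrete sketch, the update rule above can be written as a few lines of Python. The toy convex loss, learning rate, and iteration count below are my own illustrative choices, not from the post:

```python
def gradient_descent(grad_L, x0, alpha=0.1, num_iterations=100):
    """Repeat X(k+1) = X(k) - alpha * grad_L(X(k))."""
    x = x0
    for _ in range(num_iterations):
        x = x - alpha * grad_L(x)
    return x

# Toy convex loss L(x) = (x - 3)^2, whose gradient is 2 * (x - 3)
x_min = gradient_descent(lambda x: 2 * (x - 3), x0=0.0)
print(x_min)  # converges very close to the minimizer x = 3
```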

This works well for convex functions with an appropriate choice of α. But one of the most desirable properties of deep neural networks is that they are universal approximators, which means they must be able to model non-convex functions as well. The problem with non-convex functions is that your initial guess might not be near the global minimum, and gradient descent might converge to a local minimum. Consider the following case:

But wait, are we really looking for the global minimum in deep learning?

We are optimizing a loss or cost function, which most of the time represents some distance between the actual output and the prediction. We want to minimize the loss, but at the same time we don't want it to be too close to zero, as that would often lead to over-fitting. We want the model to work well on unseen data, and hence it needs to generalize. What we are looking for here is flatness.

Good local minima vs bad local minima
Rather than desiring the global minimum, we look for a `good local minimum` as opposed to a `bad local minimum`. A good local minimum is often one with more flatness.

How do you define flatness?
Well, it can be very intuitive to understand in 2D or 3D just by looking at the curves. In general, flatness has to do with smaller eigenvalues of the Hessian of the function. The greater the eigenvalues, the greater the curvature, and the more likely a critical point is a sharp minimum.
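A tiny illustration of this point (the two quadratic "losses" below are my own toy examples): a sharp bowl and a flat bowl differ only in the scale of their Hessian eigenvalues.

```python
import numpy as np

# For L(x, y) = a*x^2 + b*y^2, the Hessian is the constant matrix diag(2a, 2b).
H_sharp = np.diag([20.0, 20.0])  # L = 10x^2 + 10y^2: a sharp minimum
H_flat = np.diag([0.2, 0.2])     # L = 0.1x^2 + 0.1y^2: a flat minimum

print(np.linalg.eigvalsh(H_sharp))  # large eigenvalues: high curvature
print(np.linalg.eigvalsh(H_flat))   # small eigenvalues: flatness
```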

Stochastic Gradient Descent:
Stochastic Gradient Descent (SGD) replaces the costly operation of calculating the average loss over the whole dataset with drawing a random sample and calculating the loss for that sample at each iteration. This in turn changes the convergence behaviour compared to gradient descent.

for k in [1, number_iterations]:
    pick a random datapoint di
    calculate loss Li
    X(k+1) = X(k) - α ∇Li(X(k))
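Here is a minimal runnable version of the loop above. The per-point squared-error loss and all hyperparameters are illustrative assumptions on my part:

```python
import numpy as np

def sgd(grad_Li, data, x0, alpha=0.01, num_iterations=1000, seed=0):
    """Each step uses the gradient of the loss on ONE randomly drawn datapoint."""
    rng = np.random.default_rng(seed)
    x = x0
    for _ in range(num_iterations):
        d_i = data[rng.integers(len(data))]  # pick a random datapoint di
        x = x - alpha * grad_Li(x, d_i)      # X(k+1) = X(k) - alpha * grad Li(X(k))
    return x

# Toy problem: the mean of the data minimizes the average of Li(x, d) = (x - d)^2,
# whose per-point gradient is 2 * (x - d).
data = np.array([1.0, 2.0, 3.0, 4.0])
x_hat = sgd(lambda x, d: 2 * (x - d), data, x0=0.0)
print(x_hat)  # hovers near the mean 2.5 but never settles exactly (gradient noise)
```

Note that with a constant learning rate the iterate keeps bouncing around the minimizer; this residual noise is exactly the behaviour discussed below.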

An intuition as to why a flat minimum generalizes better
In previous work, Hinton & Van Camp (1993) and Hochreiter & Schmidhuber (1997) argued that flat minima require less information to describe and hence should generalize better than sharp ones.

A sharp minimum has very few incorrect predictions on the training data. However, a small change in the network parameters can change the outputs by a lot. For a classification task, this means the decision boundaries are very close to the points. In contrast, at a flat minimum the boundary is at a safe distance from the points, and a small change in the network parameters won't hurt accuracy much, making it a better generalization of the actual function we are trying to approximate. Have a look at the following visualization from the paper Understanding Generalization through Visualizations:

Left: Classifier boundary at a flat minimum; Right: Classifier boundary at a sharp minimum

Before we jump into how SGD converges to a flat minimum, let's understand the following global optimization technique.

Simulated Annealing
Simulated Annealing is a global optimization technique. It often takes steps that do not match the gradient direction. Whether it takes these abnormal steps depends on a quantity called energy. The energy is high at the beginning, which makes the algorithm take abnormal steps. These abnormal steps are necessary, as they make it possible to get out of a local minimum's well and explore the landscape for the global minimum. The energy decreases with the number of iterations, so the algorithm takes fewer abnormal steps and ultimately converges to a minimum. Compared to Gradient Descent, it has a much lower chance of getting stuck in a local minimum.

To abstract away from the messy implementation (the actual algorithm is more detailed), consider a helicopter assigned the task of finding and landing in the deepest valley of a mountain range, with limited fuel. The helicopter would like to explore the mountain range arbitrarily at first, but over time it will avoid arbitrary exploration and try to go deeper into the valley it is in, keeping the fuel limit in mind. The initial arbitrary exploration is what makes it possible for the helicopter to get out of a local valley and explore descents beyond it.
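A compact sketch of the idea (the cooling schedule, Gaussian proposal step, and toy function below are my assumptions, not a production implementation):

```python
import math
import random

def simulated_annealing(f, x0, T0=5.0, cooling=0.99, num_iterations=2000, seed=0):
    """Accept downhill moves always; accept uphill ("abnormal") moves with
    probability exp(-delta / T), where the energy T decays each iteration."""
    rng = random.Random(seed)
    x, T = x0, T0
    for _ in range(num_iterations):
        candidate = x + rng.gauss(0.0, 1.0)  # propose a random step
        delta = f(candidate) - f(x)
        if delta < 0 or rng.random() < math.exp(-delta / T):
            x = candidate  # high energy early on permits uphill exploration
        T *= cooling       # energy decreases with the iterations
    return x

# Toy non-convex function: a shallow minimum near x = 2, a deeper one near x = -2.
f = lambda x: 0.5 * x**4 - 4 * x**2 + x
x_best = simulated_annealing(f, x0=2.0)  # start in the shallow basin
```

Whether a single run actually escapes into the deeper basin depends on the seed and the cooling schedule; the point is only that high early energy makes such escapes possible.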

Even though SGD looks for a flat minimum rather than the global minimum, there are a few similarities between SGD and Simulated Annealing that let both escape local minima.

Why does SGD reach a flat minimum?
At the initial iterations, SGD carries a lot of noise (similar to Simulated Annealing), and the gradient can vary a lot from one iteration to the next. This behaviour lets SGD act like Simulated Annealing: the loss being very noisy at the initial steps provides the energy to get out of a local minimum's well. A flat minimum also has a greater area to be explored. Think of the helicopter trying to find a landing spot in a huge mountain range: the valley with the greater area (hence greater flatness) has a better chance of containing a landing spot.

There are some other advantages of using SGD:

  • SGD converges fast when examples are similar. If there are duplicate or similar datapoints, SGD can converge much faster, as after some time it is optimizing for points it has already seen.
  • SGD can train on the fly. If your data is not readily available and arrives as a stream, SGD can still be used.
  • One important benefit of SGD is that you process one datapoint per iteration. This is really helpful when you have datasets too large to hold in RAM. (I haven't worked with any such dataset; maybe I haven't worked with large enough datasets yet.)

References:

  1. Huang, Emam, Goldblum, Fowl, Terry, Huang, and Goldstein. Understanding Generalization through Visualizations. 2019.
  2. Hinton and Van Camp. Keeping the Neural Networks Simple by Minimizing the Description Length of the Weights. 1993.
  3. Hochreiter and Schmidhuber. Flat Minima. 1997.
