Optimization Challenges in Deep Learning

Tejovk
11 min read · Aug 5, 2023
Image via: https://d2r55xnwy6nx47.cloudfront.net/uploads/2022/01/Gradient-descent.jpg

We have been using optimization methods for many years, and they have typically produced the desired results. However, there are some edge cases where these algorithms are difficult to apply. In this blog, we will discuss these cases and explore some strategies for overcoming them.

Optimization in the case of convex functions is a relatively straightforward task. In machine learning, for example, the loss function for linear regression is the mean squared error (MSE), which is convex in the model's weights. This means that every local minimum of the MSE is also a global minimum, so we do not need to worry about getting trapped in a suboptimal local optimum.
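To make the convexity point concrete, here is a minimal sketch that runs plain gradient descent on the MSE of a one-feature linear model from two very different starting points. The toy data, the learning rate, and the helper names (`mse_gradient`, `fit`) are assumptions chosen for illustration.

```python
def mse_gradient(w, b, xs, ys):
    """Gradient of the mean squared error for the 1-D model y_hat = w * x + b."""
    n = len(xs)
    grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return grad_w, grad_b


def fit(w, b, xs, ys, lr=0.05, steps=5000):
    """Plain gradient descent starting from the initial guess (w, b)."""
    for _ in range(steps):
        grad_w, grad_b = mse_gradient(w, b, xs, ys)
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


# Toy data generated from y = 2x + 1 (hypothetical, for illustration only).
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]

# Two very different initial guesses.
fit_a = fit(-10.0, 10.0, xs, ys)
fit_b = fit(25.0, -40.0, xs, ys)
print(fit_a, fit_b)
```

Because the loss is convex, both runs should end at essentially the same weights, no matter where they start.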

However, a neural network has many parameters (weights and biases) to train, and it is built by stacking many non-linear functions on top of each other. This composition generally produces a non-convex loss function. A non-convex function can have many local minima, and only the lowest of them is the global minimum. Our goal is to get as close as possible to the global minimum without getting trapped in a poor local minimum.

The Initial Guess Problem

Choosing the right hyperparameters is one of the crucial steps in optimization. The learning rate is one such hyperparameter, and it influences the rate of convergence. If the learning rate is too small, we need many updates before reaching the optimum; if it is too large, the updates can diverge and never converge. Hence our goal is to find a learning rate that gets us to the optimum in a reasonable number of steps. Luckily, there are many libraries and frameworks that help in selecting good hyperparameters, such as Optuna, Hyperopt, scikit-learn's GridSearchCV, Ray Tune, and many more.
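The effect of the learning rate is easy to see on a toy problem. The sketch below minimizes f(x) = x² with three learning rates; the function, the starting point, and the specific rates are assumptions picked only to show the three regimes.

```python
def gradient_descent(lr, x0=5.0, steps=50):
    """Minimize f(x) = x**2 (gradient 2x) and return the final x."""
    x = x0
    for _ in range(steps):
        x -= lr * 2 * x
    return x


too_small = gradient_descent(lr=0.001)   # barely moves away from x0
reasonable = gradient_descent(lr=0.1)    # ends very close to the minimum at 0
too_large = gradient_descent(lr=1.1)     # every step overshoots further: divergence

print(too_small, reasonable, too_large)
```

With this particular function, any learning rate above 1.0 makes each step land farther from the minimum than the previous one.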

Image via: link

The learning rate also affects whether we can escape local optima. A local minimum sits at the bottom of a basin, so the gradients around it point back toward it, which makes it difficult to leave. A high learning rate can cause the model to overshoot the local minimum and continue searching for a better solution, whereas a low learning rate can leave the model stuck in the local minimum.

Image via: Link

Even after tackling all these challenges, there is no guarantee that we will stay at the optimum once we reach it. For example, when we train a model over several epochs, we generally observe the loss decreasing from one epoch to the next, which means the parameters are moving toward an optimum. However, this is not always the case: the loss can also increase from one epoch to the next, for instance when an update step overshoots the optimum.

The Local Optima

Before we dig deep into this topic, we must understand one thing: real-world problems are often associated with noise and uncertainty. This leads to a more challenging optimization landscape, where it may not matter if we are in a local optimum (with negligible error) or the global optimum. In such cases, pursuing the global optimum may be less critical because the solution may be sensitive to the noise, and a local optimum may still yield satisfactory results.

Another aspect to keep in mind is the trade-offs and practicality. In some situations, the difference between the local optimum and the global optimum may be negligible in terms of practical outcomes. Achieving the global optimum may involve significant effort and cost, while the improvement gained may not justify the investment.

In the above image, we might accept the local minimum at c instead of the global minimum, but we must not settle for the local minima at a and b since they are far from the global optimum. Image via: Link

Plateau and Saddle Points

The fundamental logic behind most descent algorithms is that the weight updates stop the moment we reach a point where the gradient (derivative) is zero. This is because the negative gradient gives the direction of steepest descent, and if the gradient is zero, there is no further descent to be made. That's the reason why descent algorithms use an update rule of the form:

W_new = W_old - η · ∂L/∂W_old

Here η is the learning rate and L is the loss. The weights stop changing the moment the gradient reaches 0.

Typically, we want our descent algorithm to converge and stop at the global minimum, which has a gradient of 0. However, during this entire procedure, we encounter a few obstacles in the form of plateaus and saddle points. Let’s briefly take a look at these:

Plateau: these are regions where the loss surface is nearly flat, so the gradient of the loss function is very small. In other words, the loss function has a very shallow slope in multiple dimensions, making it difficult for the optimization algorithm to make significant progress toward the minimum. When the gradients are close to zero, the optimization process slows down, and it may take a long time to converge to the optimal solution. Plateaus are particularly problematic in deep learning because neural networks can have a large number of parameters, and the likelihood of encountering flat regions increases with the model’s complexity.

It is evident in the image that the gradient in the plateau region is negligible; hence, the algorithm takes a long time to converge, or it may end up stopping there.

Saddle points: these are critical points on the loss surface where the gradient is zero, but the surface curves upward in some directions and downward in others. A saddle point is not an extremum (neither a minimum nor a maximum).

The point where the red and green curves intersect is the saddle point. Notice that it is a minimum of the green curve and, at the same time, a maximum of the red curve. Image via: Link

At a saddle point, the loss function may have both directions of ascent and descent, making it challenging for the optimization algorithm to decide which way to proceed. As a result, the optimization algorithm can get stuck at the saddle point and struggle to find the path leading to the global minimum.
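A small sketch makes this concrete. It assumes the textbook saddle f(x, y) = x² - y², whose only critical point is the saddle at (0, 0); the function, the starting points, and the helper name `descend` are illustrative choices.

```python
def descend(x, y, lr=0.1, steps=100):
    """Gradient descent on f(x, y) = x**2 - y**2, which has a saddle at (0, 0)."""
    for _ in range(steps):
        grad_x, grad_y = 2 * x, -2 * y
        x -= lr * grad_x
        y -= lr * grad_y
    return x, y


# Starting exactly on the x-axis: the y-gradient is always 0, so the iterate
# slides into the saddle and stays there.
stuck = descend(1.0, 0.0)

# A tiny offset in y is amplified at every step, so the iterate escapes
# along the descending direction.
escaped = descend(1.0, 1e-6)

print(stuck, escaped)
```

This also hints at why noise in the updates (as in stochastic gradient descent) can help: a small perturbation away from the saddle is enough to find the descending direction.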

Let’s now talk about a few algorithms that help us avoid getting stuck at these points.

Regular gradient descent algorithms do not optimize further once they reach a point with a gradient of 0. To address this, we can introduce a momentum term into the algorithm. Momentum accumulates the gradients of the previous iterations, so the update can remain non-zero even where the current gradient vanishes, and it helps the algorithm converge faster even with a relatively small learning rate. Momentum also smooths the update direction, which makes the optimizer less prone to oscillating. Hence the momentum term helps us in two ways: it improves robustness and it results in faster convergence.
Examples of such momentum-based algorithms are Stochastic Gradient Descent with Momentum and Nesterov Accelerated Gradient (which uses a look-ahead gradient).
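As a rough illustration of the faster-convergence claim, the sketch below compares plain gradient descent with classical (heavy-ball) momentum on an elongated quadratic bowl. The objective, the learning rate of 0.02, and the momentum coefficient of 0.9 are assumptions made for this example, not recommended defaults.

```python
def grad(x, y):
    """Gradient of f(x, y) = 0.5 * (x**2 + 50 * y**2), an elongated bowl."""
    return x, 50.0 * y


def run(lr=0.02, beta=0.0, steps=200):
    """Gradient descent with heavy-ball momentum; beta=0 is plain gradient descent."""
    x, y = 1.0, 1.0
    vx, vy = 0.0, 0.0
    for _ in range(steps):
        gx, gy = grad(x, y)
        # The velocity is a decaying sum of past gradients, so it can stay
        # non-zero even where the current gradient is tiny.
        vx = beta * vx - lr * gx
        vy = beta * vy - lr * gy
        x += vx
        y += vy
    return (x * x + y * y) ** 0.5  # distance from the minimum at (0, 0)


plain = run(beta=0.0)
with_momentum = run(beta=0.9)
print(plain, with_momentum)
```

With the same learning rate and the same number of steps, the momentum run should end noticeably closer to the minimum, because the accumulated velocity speeds up progress along the shallow x direction.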

Vanishing Gradients

Before getting into the topic of vanishing and exploding gradients, we must have a thorough understanding of the chain rule of backpropagation. So let us start with that:

The notations:
fₐᵦ represents the bᵗʰ neuron of the aᵗʰ layer; this function computes the weighted sum of its inputs, adds a bias, and applies the activation function to the result.
Oₐᵦ represents the output of the fₐᵦ neuron.
Wₓᵧᵃ represents the weight connecting the xᵗʰ neuron of the previous layer to the yᵗʰ neuron of the aᵗʰ layer.

Once you understand the above network, you will be able to follow how the neurons relate to one another.

The general weight update formula is W_new = W_old - η · ∂L/∂W_old, where η is the learning rate, and the part we need to work on is the derivative of the loss L.

Let’s try to break down the derivative of the loss into smaller derivatives. We know that the loss is calculated from the value predicted by the neural network. This predicted value depends on the weight W³₁₁, which is why we look at how the loss changes with respect to W³₁₁.

From the network in the figure, we can observe a relationship between O₂₁, O₃₁, and W³₁₁. If the output of the f₂₁ neuron changes, the weighted sum sent as input to f₃₁ changes, which in turn changes the output O₃₁.

Keeping all these things in mind, we can finally write down a chain of derivatives that expresses the derivative of the loss with respect to W³₁₁.

The chain rule.
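Since the figure itself is not reproduced here, the chain can be written out as follows. This is a sketch that assumes the notation above, with L the loss and O₃₁ the network's only output. The second line, for a weight one layer earlier, is an added illustration rather than something taken from the original figure.

```latex
\frac{\partial L}{\partial W^{3}_{11}}
  = \frac{\partial L}{\partial O_{31}} \cdot \frac{\partial O_{31}}{\partial W^{3}_{11}}

\frac{\partial L}{\partial W^{2}_{11}}
  = \frac{\partial L}{\partial O_{31}} \cdot \frac{\partial O_{31}}{\partial O_{21}} \cdot \frac{\partial O_{21}}{\partial W^{2}_{11}}
```

Notice that the deeper the weight sits in the network, the more factors the chain contains.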

Now that we have understood the chain rule of backpropagation, we can talk about the problem of vanishing gradients.

The problem of vanishing gradients occurs during optimization when we use activation functions like the sigmoid or tanh (basically all the activation functions that convert the input value to a narrow range, for example, between 0 and 1).

From the chain rule, we know that the derivative of the loss with respect to a weight is a product of terms, and that these terms include the derivatives of the activations in the hidden layers. This is where the actual problem arises. The derivative of the sigmoid activation function lies in a very narrow range, from 0 to 0.25. If we use such activation functions in the hidden layers, we end up multiplying many values between 0 and 0.25, which can make the derivative of the loss with respect to W_old negligible.

Proof for the derivative of sigmoid ranging between 0–0.25:

Images Via: Link
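For readers without access to the image, the bound can be derived in a few lines. This is the standard derivation, written here with σ for the sigmoid:

```latex
\sigma(x) = \frac{1}{1 + e^{-x}}

\sigma'(x) = \frac{e^{-x}}{(1 + e^{-x})^{2}}
           = \frac{1}{1 + e^{-x}} \cdot \frac{e^{-x}}{1 + e^{-x}}
           = \sigma(x)\,\bigl(1 - \sigma(x)\bigr)

0 < \sigma(x) < 1
\;\Longrightarrow\;
0 < \sigma'(x) \le \tfrac{1}{4},
\quad \text{with equality at } x = 0 \text{ where } \sigma(0) = \tfrac{1}{2}
```

The last step uses the fact that the product s(1 - s) is largest at s = 1/2.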

The resulting graph of the sigmoid's derivative:

Image Via: Link

With such a tiny derivative, the weight updates are negligible for any reasonable learning rate. Convergence becomes very slow, and training might end up stalling at some arbitrary point on the optimization curve. Scary, isn’t it?
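To get a feel for how quickly the product shrinks, here is a minimal sketch that multiplies sigmoid derivatives across layers. It assumes the best case for the sigmoid, where every pre-activation is exactly 0 and each factor is at its maximum of 0.25; the helper names are made up for this example.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_derivative(z):
    s = sigmoid(z)
    return s * (1.0 - s)


def chained_factor(depth, z=0.0):
    """Product of `depth` sigmoid derivatives, all evaluated at the same z."""
    return sigmoid_derivative(z) ** depth


for depth in (1, 5, 10, 20):
    print(depth, chained_factor(depth))
```

Even in this best case the factor falls below one in a million after ten layers, and any pre-activation away from 0 makes it smaller still.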

That’s the reason why we avoid sigmoid activations in hidden layers and use ReLU instead. ReLU has issues of its own, such as dead neurons (the "dying ReLU" problem), which leaky ReLU addresses, but we won’t discuss them here. That’s all about vanishing gradients; let’s now explore exploding gradients.

Exploding Gradients

Exploding gradients are another challenge we face during optimization. They are typically caused by poor weight initialization, specifically by initializing the weights with large values.

Let’s take the network below as an example:

In the above network, Z is the weighted sum of the inputs plus the bias, so Z = W²₁₁ · O₁₁ + bias. For the sake of simplicity, let's assume the bias is 0. Since we are using the sigmoid as the activation function, O₂₁ = σ(Z).

Applying the backpropagation chain rule to this network produces the term dO₂₁/dO₁₁, which can be expanded as:

dO₂₁/dO₁₁ = dσ(Z)/dZ · dZ/dO₁₁, since O₂₁ = σ(Z)

Let’s now look at the derivative of Z with respect to O₁₁:

dZ/dO₁₁ = d(W²₁₁ · O₁₁)/dO₁₁ = W²₁₁

We end up with the weight W²₁₁ itself.

If the chosen weight is very large, this term makes the overall derivative large as well, and multiplying several such terms across layers makes it larger still. The difference between the new weight and the old weight then becomes huge, so the optimizer keeps overshooting: it oscillates around the optimum, or diverges, instead of converging. This is the exploding gradient problem.
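The growth can be sketched numerically. Each layer contributes a factor σ'(Z) · W to the chain, and the code below multiplies that factor across layers. As a loud simplification, Z is held fixed at 0 for every layer, so σ'(Z) = 0.25; in a real network, large weights also change Z, which this toy calculation ignores. The helper names and weight values are assumptions for illustration.

```python
import math


def layer_factor(weight, z=0.0):
    """sigmoid'(z) * weight: the factor one layer contributes to the chain."""
    s = 1.0 / (1.0 + math.exp(-z))
    return s * (1.0 - s) * weight


def chained(weight, depth):
    """Product of the per-layer factor over `depth` identical layers."""
    return layer_factor(weight) ** depth


small_weights = chained(weight=1.0, depth=10)   # factor 0.25 per layer: vanishes
large_weights = chained(weight=20.0, depth=10)  # factor 5.0 per layer: explodes
print(small_weights, large_weights)
```

Whenever the per-layer factor is larger than 1, the product grows exponentially with depth, which is exactly the behaviour described above.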

So we can clearly say that this problem arises from poor weight initialization. It can be handled with weight initialization techniques such as Xavier/Glorot initialization for the sigmoid activation function and He initialization for the ReLU activation function.
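Both schemes boil down to picking the spread of the initial weights from the layer's size. Below is a minimal pure-Python sketch of the commonly used Glorot uniform and He normal variants; the function names, the layer sizes, and the list-of-lists weight layout are assumptions for illustration, and deep learning frameworks ship their own implementations.

```python
import math
import random


def xavier_uniform(fan_in, fan_out, rng):
    """Glorot/Xavier uniform: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out))."""
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)]
            for _ in range(fan_in)]


def he_normal(fan_in, fan_out, rng):
    """He normal: zero-mean Gaussian with std = sqrt(2 / fan_in)."""
    std = math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)]
            for _ in range(fan_in)]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
w_sigmoid = xavier_uniform(256, 128, rng)  # for a sigmoid/tanh layer
w_relu = he_normal(256, 128, rng)          # for a ReLU layer
```

In both cases, wider layers get smaller initial weights, which keeps the per-layer factors in the chain rule from being systematically larger or smaller than 1.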

Ill-Conditioned Hessian Matrix

Hopefully you are already familiar with the terms Hessian and condition number of a matrix; if not, here is the gist. The Hessian of a scalar function is the matrix of its second-order partial derivatives. It is a square matrix whose number of rows and columns equals the number of variables of the function.

In mathematical terms, we can also define the Hessian matrix as the Jacobian of the gradient of the function, where the Jacobian is the matrix of first-order partial derivatives of a vector-valued function.

The condition number of a matrix tells us how well-behaved the matrix is in further computations. Formally, it measures how much the output of a function can change for a small change in its input.

For a symmetric matrix such as the Hessian, it is mathematically defined as the ratio of the largest to the smallest eigenvalue in absolute value. More generally, it is the product of the norm of the matrix and the norm of its inverse.
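For a 2x2 symmetric Hessian the eigenvalues have a closed form, so the condition number can be computed by hand. The sketch below does this for two simple quadratics; the functions chosen and the helper names are assumptions for illustration.

```python
import math


def symmetric_2x2_eigenvalues(a, b, d):
    """Eigenvalues of the symmetric matrix [[a, b], [b, d]], smallest first."""
    mean = (a + d) / 2.0
    radius = math.sqrt(((a - d) / 2.0) ** 2 + b ** 2)
    return mean - radius, mean + radius


def condition_number(a, b, d):
    """Ratio of the largest to the smallest eigenvalue in absolute value."""
    low, high = symmetric_2x2_eigenvalues(a, b, d)
    smallest, largest = sorted((abs(low), abs(high)))
    return largest / smallest


# Hessian of f(x, y) = x**2 + y**2 is [[2, 0], [0, 2]]: a round bowl.
well_conditioned = condition_number(2.0, 0.0, 2.0)

# Hessian of f(x, y) = x**2 + 100 * y**2 is [[2, 0], [0, 200]]: a narrow valley.
ill_conditioned = condition_number(2.0, 0.0, 200.0)

print(well_conditioned, ill_conditioned)
```

For larger matrices, a numerical library would normally be used instead of a closed form.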

A matrix is said to be ill-conditioned if its condition number is very high: a small change in the input then produces a large change in the output. Can you relate this property to neural networks? There too, a minute change in the parameters can result in an entirely different value of the performance measure.

Condition Number and Convergence Rate:

  1. Good Condition Number: If the condition number of the Hessian is small (close to 1), the eigenvalues are well-scaled (that is, comparatively close to each other), and the Hessian is well-conditioned. In this case, the optimization algorithm can converge quickly, since the direction and magnitude of the steepest descent align well with the local curvature of the objective function.
  2. Poor Condition Number: On the other hand, if the condition number of the Hessian is large, the eigenvalues differ significantly in scale, and the Hessian is ill-conditioned. In such cases, the optimization algorithm may take much longer to converge, or it may oscillate and struggle to converge at all.
Optimization with a well-conditioned vs. an ill-conditioned Hessian of a function. Image courtesy: Link
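The difference in convergence speed can be measured on a toy problem. The sketch below runs gradient descent on f(x, y) = 0.5 · (x² + c · y²), whose Hessian has eigenvalues 1 and c, so its condition number is c for c ≥ 1. The step-size choice of 1/c (one over the largest eigenvalue), the tolerance, and the function name are assumptions for this example.

```python
def steps_to_converge(curvature_y, tol=1e-6, max_steps=100_000):
    """Gradient descent on f(x, y) = 0.5 * (x**2 + curvature_y * y**2) from (1, 1).

    Returns the number of steps needed to get within `tol` of the minimum.
    """
    lr = 1.0 / curvature_y  # step size limited by the steepest direction
    x, y = 1.0, 1.0
    for step in range(1, max_steps + 1):
        x -= lr * x
        y -= lr * curvature_y * y
        if (x * x + y * y) ** 0.5 < tol:
            return step
    return max_steps


well = steps_to_converge(1.0)    # condition number 1
ill = steps_to_converge(100.0)   # condition number 100
print(well, ill)
```

The steep direction forces a small step size, and that small step size then makes progress along the shallow direction painfully slow.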

We can also use the Hessian matrix to determine whether a critical point is a minimum, a maximum, or a saddle point. If all the eigenvalues of the Hessian are positive, the point is a local minimum; if all are negative, it is a local maximum; and if the eigenvalues are mixed, i.e., both positive and negative, the point is a saddle point.
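This test translates directly into code. The sketch below classifies a critical point from a 2x2 symmetric Hessian using the closed-form eigenvalues. The function name and tolerance are assumptions, and the "inconclusive" branch for a zero eigenvalue is an addition not covered by the rule above.

```python
import math


def classify_critical_point(a, b, d, tol=1e-12):
    """Classify a critical point from its symmetric Hessian [[a, b], [b, d]]."""
    mean = (a + d) / 2.0
    radius = math.sqrt(((a - d) / 2.0) ** 2 + b ** 2)
    low, high = mean - radius, mean + radius  # the two eigenvalues
    if low > tol:
        return "minimum"        # all eigenvalues positive
    if high < -tol:
        return "maximum"        # all eigenvalues negative
    if low < -tol and high > tol:
        return "saddle point"   # mixed signs
    return "inconclusive"       # at least one eigenvalue is (near) zero


print(classify_critical_point(2.0, 0.0, 2.0))    # Hessian of x**2 + y**2
print(classify_critical_point(-2.0, 0.0, -2.0))  # Hessian of -(x**2 + y**2)
print(classify_critical_point(2.0, 0.0, -2.0))   # Hessian of x**2 - y**2
```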

“In the world of neural networks, the road to convergence is paved with the potholes of optimization challenges.” — Anonymous

POST SCRIPT

This blog post has discussed the challenges faced during optimization. I will be describing the different optimization algorithms for some of these cases in another blog post soon.

REFERENCES:

https://cedar.buffalo.edu/~srihari/CSE676/8.2%20NNOptimization.pdf

https://www.youtube.com/playlist?list=PLZoTAELRMXVPGU70ZGsckrMdr0FteeRUi
