Gradient Descent — A deep dive

A deep dive into Gradient Descent, one of the most important algorithms in machine learning, with a practical example using JAX.

Prateek Bhatt
Analytics Vidhya
Jul 12, 2020


Initially, when I was introduced to this algorithm, it took me time to get answers to the following questions:

How does the gradient descent algorithm work?

What is meant by a gradient?

How to calculate a gradient?

How are the parameters learned and how does a gradient help in learning?

In theory, I understood what it does, but I could not find a practical example that was good enough for me and that answered the above-mentioned questions. I was then introduced to the library JAX, which can help in calculating gradients. I decided to put together a toy example and experiment with it to better understand the algorithm.

The following article is structured such that it first explains the Gradient Descent algorithm along with its variants, and then ends with a practical example using JAX. In the practical example, I try to show how the gradient descent algorithm functions and how the parameters are learned by the network. So let us begin.

Gradient Descent Algorithm — The theory

In order to use a neural network for solving real-life problems, the network must learn its parameters on its own from the given data. For this reason, a learning algorithm is used.

Consider the task of predicting house prices, where the task of a network is to output a price for the given information about a house, e.g. the size of the house, its location, etc. The network is provided with the data x and y. The input x is the information related to the house and y is its corresponding price.

Here, the network has to map x to y by learning a mapping function. The mapping function has weights as its parameters, and these parameters have to be learned using an optimizer. The network will have more parameters as we increase the number of computational units (neurons). At first, the weights are initialized with random values, and during training they are learned using an optimizer such as the gradient descent optimizer.

Let us consider an example where our network has the following parameters:

Network Parameters

These parameters have to be learned in order to provide a better result. The input is x and y is the expected output. If the network gives output y', we can calculate how well the network performed by implementing a loss function. We can use different loss functions, but let us consider that we are using the Mean Squared Error (MSE). Using this function we can calculate the error that our network made while predicting the house price.

Mean Squared Error [1]: $E = \frac{1}{n}\sum_{i=1}^{n}(y_i - y'_i)^2$

The code to calculate the MSE is shown below:
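A minimal sketch of this loss function, assuming `y` holds the true prices and `y_pred` the network's predictions as JAX arrays:

```python
import jax.numpy as jnp

def mse(y, y_pred):
    # Mean of the squared differences over all n samples
    return jnp.mean((y - y_pred) ** 2)
```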

Here n denotes the number of samples. The task of an optimizer is to reduce the error as much as possible, and to achieve that we take the partial derivatives of the error function E. This is denoted by the sign ∇ and is called the gradient. The partial derivative is taken w.r.t. each parameter, as shown below:

Partial Derivative: $\nabla E = \left(\frac{\partial E}{\partial w_1}, \frac{\partial E}{\partial w_2}, \ldots, \frac{\partial E}{\partial w_n}\right)$

Using the calculated gradients, each parameter is then updated as shown below:

Gradient Descent Optimizer [4]: $w_{t+1} = w_t - \alpha \nabla E(w_t)$

We can see here that the parameters w at time t + 1 are updated based on the partial derivative of the error function at time t. The α denotes the learning rate with which the parameters will be learned. The learning rate is one of the hyperparameters. Parameters that are set before starting the experiment are called hyperparameters; the learning rate and the choice of optimizer have to be fixed before starting the experiment.

The code for updating the parameter will look something like this:
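A minimal sketch of the update step, assuming `w_grad` holds the gradient of the error w.r.t. the parameter `w` (the learning rate value here is an arbitrary choice):

```python
learning_rate = 0.01  # the hyperparameter alpha

def update(w, w_grad):
    # Step against the gradient direction to reduce the error
    return w - learning_rate * w_grad
```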

The performance of the gradient descent optimizer greatly depends on the amount of data it uses to optimize the error function at once. There are three variants of the gradient descent optimizer based on the amount of data processed per update: batch gradient descent, stochastic gradient descent, and mini-batch gradient descent [3].

Batch Gradient Descent

This variant is also called vanilla gradient descent. In this variant, the whole training dataset is used to minimize the error function and calculate the gradient. As the whole training dataset is considered for just one update of the parameters, it brings time and memory issues: each update takes a long time, and the memory requirement is high, which makes it difficult to fit the whole training dataset in memory. It cannot be used for online learning, as the update happens for the whole training dataset and not for a single training example [3].

Stochastic Gradient Descent

This variant updates the parameters for each training example and can therefore be used for online training as well. It is much faster and does not require much memory, as only a single training example is in memory when calculating the gradient of the error function [3]. On one hand, this variant is promising; on the other hand, its convergence to a local/global minimum is not stable [3]. If the learning rate keeps decreasing, this variant does attain the same convergence behavior as vanilla gradient descent [3].

Mini-batch Gradient Descent

It combines the best of both the stochastic and vanilla gradient descent algorithms. It minimizes the error function for one mini-batch of the training dataset at a time. Usually, the mini-batch size is one of 8, 32, 64, and so on. The issue of unstable convergence is solved, as the parameters are updated not for a single training example but for a mini-batch. It shows similar convergence behavior to vanilla gradient descent [3].
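To make the three variants concrete, here is a minimal sketch of one update scheme for each, assuming a simple linear model with an MSE error function (the model, loss, and batching details here are illustrative assumptions, not taken from any particular implementation):

```python
import jax.numpy as jnp
from jax import grad

def loss(w, x, y):
    # MSE of a simple linear model y' = w * x (bias omitted for brevity)
    return jnp.mean((w * x - y) ** 2)

def batch_gd_step(w, xs, ys, alpha):
    # Batch (vanilla) GD: one update uses the entire training set
    return w - alpha * grad(loss)(w, xs, ys)

def sgd_epoch(w, xs, ys, alpha):
    # Stochastic GD: one update per training example
    for x, y in zip(xs, ys):
        w = w - alpha * grad(loss)(w, x, y)
    return w

def minibatch_epoch(w, xs, ys, alpha, batch_size=32):
    # Mini-batch GD: one update per small slice of the training set
    for i in range(0, len(xs), batch_size):
        xb, yb = xs[i:i + batch_size], ys[i:i + batch_size]
        w = w - alpha * grad(loss)(w, xb, yb)
    return w
```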

Gradient Descent Algorithm — In Practice

Now we can check in practice how the gradient helps the parameters learn in the correct direction. I feel a practical example is better for visualizing the theory we have learned.

Consider the following toy dataset [2]:
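The dataset from the referenced notebook [2] consists of six (x, y) pairs:

x: -1.0, 0.0, 1.0, 2.0, 3.0, 4.0
y: -3.0, -1.0, 1.0, 3.0, 5.0, 7.0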

The above dataset can be mapped by the simple linear function y = 2x - 1. If we want to predict y for a given x, the `predict` function will look something like this:
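A minimal sketch of such a function:

```python
def predict(W, b, x):
    # Linear model: y' = W * x + b
    return W * x + b
```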

So this function predicts the outcome for the inputs, given the parameters W and b. We would like our predictions to be accurate and match the function y = 2x - 1.

So let us create our loss function, i.e., the MSE function:
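A sketch of the MSE loss built on top of `predict`:

```python
import jax.numpy as jnp

def loss(W, b, x, y):
    # Mean squared error between predictions and true values
    y_pred = predict(W, b, x)
    return jnp.mean((y - y_pred) ** 2)
```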

Let us plot the loss function. As mentioned earlier, the inputs are our xs. We will be keeping b constant at -1 for simplicity:
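One possible way to produce such a plot, assuming matplotlib is available and `xs`, `ys` hold the dataset arrays shown below (the plotting code is an assumption, not from the original article):

```python
import matplotlib.pyplot as plt

Ws = jnp.linspace(-8, 10, 100)               # sweep a range of W values
losses = [loss(W, -1.0, xs, ys) for W in Ws]  # loss at each W, b fixed at -1
plt.plot(Ws, losses)
plt.xlabel("W")
plt.ylabel("loss")
plt.show()
```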

Plot: the loss function for W in (-8, 10) with b = -1

Keeping b constant at -1 and taking the range (-8, 10) for W, we get the above plot for the loss function. You can see clearly from the plot that if W is 2 and b is -1, we get the minimum loss of 0.

We can verify it by directly passing (2, -1) to the loss function, which forms y = 2x - 1; we will receive a loss of 0.
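Assuming the `loss` function and the dataset arrays defined in the code further below, this check is a one-liner:

```python
print(loss(2.0, -1.0, xs, ys))  # prints 0.0
```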

From the above image and plot, you can get an idea of what we want to achieve. We want to learn the parameters W and b such that we receive a loss of 0. The parameter values that fit our dataset best are (2, -1). This is the goal of learning: to identify the best parameters that fit the dataset.

For simplicity, we will keep the parameter b constant at -1 and let the parameter W be learned by the gradient descent algorithm.

Now that our goal is clear, let us take a look at how to do it programmatically.

We will install the JAX package if we do not already have it on our machine:
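For a standard CPU-only setup, the installation looks like this:

```
pip install jax jaxlib
```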

Get our imports straight:
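A sketch of the imports used in the rest of the example:

```python
import jax.numpy as jnp
from jax import value_and_grad
```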

Initialize our toy dataset:
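Using the (x, y) pairs shown earlier:

```python
xs = jnp.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
ys = jnp.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0])
```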

We reuse the `predict` and `loss` functions shown earlier and initialize the parameters W and b. As mentioned earlier, we will keep the parameter b constant at -1. We will initialize the parameter W to 8 and let it converge to 2 using the gradient descent algorithm:
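```python
W = 8.0   # initial value; will be learned and should converge to 2
b = -1.0  # kept constant throughout training
```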

The following code is the most important part. We will now calculate the gradient with respect to W and update the parameter with the new value. We will do this repeatedly in a loop:
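A sketch of such a loop; the learning rate of 0.01 and the iteration count are assumptions, chosen to be consistent with the values reported below:

```python
learning_rate = 0.01

for i in range(50):
    # Loss value and gradient w.r.t. W (argnums=0 selects the first argument)
    loss_value, W_grad = value_and_grad(loss, argnums=0)(W, b, xs, ys)
    W = W - learning_rate * W_grad
    print(f"Loop {i}: W = {W:.3f}, loss = {loss_value:.2f}")
```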

The API value_and_grad provides us with the gradient of the loss function w.r.t. W. It also provides us with the loss value:
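```python
loss_value, W_grad = value_and_grad(loss, argnums=0)(W, b, xs, ys)
```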

This W_grad is used to update the parameter:
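```python
W = W - learning_rate * W_grad  # gradient descent update, as in the loop above
```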

As simple as that. Now let us look at the output. I will just show the first 3 and the last 3 loop results for clarity.

The first three loop results are as follows:

The W started with the value 8 and after three loops it reached the value 6.32. The loss started at 186.0 and after three loops it reached 120.23. You can see that both of them have a decreasing trend.

Now let us look at the last three loops:

Here the W tends to reach the value of 2 and the loss tends to reach the value of 0.

Our parameter has been learned. We started with 8 and have now reached 2.025 for the parameter W. This is how the gradient descent algorithm helps in learning the parameter and decreasing the loss.

One final plot to make our learning process even clearer. The plot below shows how the learning took place. The red line shows the loss function and the blue line shows how the learning of the parameter W took place:

Plot: learning reaches a loss of 0

This blog showed and explained the gradient descent algorithm with a practical example. I have also uploaded the code to GitHub. Feel free to experiment with different initial values of the parameters.

I hope you enjoyed reading this as much as I did creating it. I appreciate your time; do give feedback on how you liked the article, as it keeps me motivated to write more.

References:

  1. https://en.wikipedia.org/wiki/Mean_squared_error
  2. https://github.com/lmoroney/dlaicourse/blob/master/Course%201%20-%20Part%202%20-%20Lesson%202%20-%20Notebook.ipynb
  3. Sebastian Ruder. “An overview of gradient descent optimization algorithms”. In: CoRR abs/1609.04747 (2016). URL: http://arxiv.org/abs/1609.04747.
  4. Marcin Andrychowicz et al. “Learning to learn by gradient descent by gradient descent”. In: CoRR abs/1606.04474 (2016). URL: http://arxiv.org/abs/1606.04474.
