Gradient Descent — A deep dive
A deep dive into one of the most important algorithms, gradient descent, with a practical example using JAX.
When I was first introduced to this algorithm, it took me a while to find answers to the following questions:
How does the gradient descent algorithm work?
What is meant by a gradient?
How is a gradient calculated?
How are the parameters learned and how does a gradient help in learning?
In theory, I understood what the algorithm does, but I could not find a practical example good enough to help me understand it and answer the questions above. Then I was introduced to JAX, a library that can calculate gradients, so I decided to put together a toy example and experiment with it to understand the algorithm better.
This article first explains the gradient descent algorithm along with its variants, and then ends with a practical example using JAX. In the practical example, I try to show how the gradient descent algorithm works and how the network learns its parameters. So let us begin.
Gradient Descent Algorithm — The theory
To use a neural network for solving real-life problems, the network must learn its parameters on its own from the given data. For this purpose, a learning algorithm is used.
Consider the task of predicting house prices, where the network has to output a price given information about a house, e.g., its size, its location, etc. The network is provided with the data $x$ and $y$: the input $x$ is the information related to the house, and $y$ is its corresponding price.
Here, the network has to map $x$ to $y$ by learning a mapping function. The mapping function has weights as its parameters, and these parameters have to be learned using an optimizer. The more computational units (neurons) the network has, the more parameters it has. At first, the weights are initialized with random values; during training, they are learned by an optimizer such as gradient descent.
Let's consider an example network whose parameters are its weights $w$.
These parameters have to be learned for the network to give good predictions. The input is $x$, and $y$ is the desired (target) output. If the network predicts $y'$, we can measure how well it performed with a loss function. Different loss functions can be used; here we use the Mean Squared Error (MSE), $E = \frac{1}{n}\sum_{i=1}^{n}(y'_i - y_i)^2$, which measures the error our network makes when predicting the house price.
The code to calculate the MSE is shown below:
# preds = predictions made by our model, targets = actual ground truth, n = number of samples.
loss = np.sum((preds - targets) ** 2) / n
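Here $n$ denotes the number of samples. The task of an optimizer is to reduce this error as much as possible. To achieve that, we take the partial derivative of the error function $E$ with respect to each parameter. Collected together, these partial derivatives form the gradient, denoted by the symbol $\nabla$. For a network with parameters $w_1, \ldots, w_k$:
$$\nabla E = \left( \frac{\partial E}{\partial w_1}, \frac{\partial E}{\partial w_2}, \ldots, \frac{\partial E}{\partial w_k} \right)$$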
Using the calculated gradient, each parameter $w$ is then updated as follows:
$$w_{t+1} = w_t - \alpha \, \frac{\partial E}{\partial w_t}$$
Here, the parameter $w$ at time $t + 1$ is updated based on the partial derivative of the error function at time $t$. Each parameter moves in the direction opposite to its gradient, which is the direction in which the error decreases. $\alpha$ denotes the learning rate, which controls how large each update step is. The learning rate is a hyperparameter: hyperparameters are set before the experiment starts rather than learned during training. Both the learning rate and the choice of optimizer have to be fixed before the experiment starts.
The code for updating the parameter will look something like this:
# Update the parameter W using its gradient W_grad and the learning rate
# (here, 0.01).
W = W - ((0.01) * W_grad)
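The update rule can be tried on its own, without a network. The sketch below uses a hypothetical helper, `gradient_descent`, that repeatedly applies $w_{t+1} = w_t - \alpha \cdot \frac{dE}{dw}(w_t)$ to the illustrative one-parameter error $E(w) = (w - 3)^2$, whose derivative is $2(w - 3)$ and whose minimum is at $w = 3$. It also hints at why the learning rate is chosen with care: for this particular $E$, each step multiplies the distance to the minimum by $1 - 2\alpha$, so any $\alpha > 1$ makes the iterates overshoot further on every step.

```python
def gradient_descent(grad_fn, w0, learning_rate, steps):
    """Apply w <- w - learning_rate * grad_fn(w) for a fixed number of steps.

    Returns the list of all parameter values, starting with w0.
    """
    history = [w0]
    w = w0
    for _ in range(steps):
        w = w - learning_rate * grad_fn(w)
        history.append(w)
    return history


# E(w) = (w - 3)^2 has its minimum at w = 3; its derivative is 2 * (w - 3).
def grad_E(w):
    return 2.0 * (w - 3.0)


good = gradient_descent(grad_E, w0=0.0, learning_rate=0.1, steps=100)
too_large = gradient_descent(grad_E, w0=0.0, learning_rate=1.1, steps=20)

print(good[-1])       # very close to 3.0
print(too_large[-1])  # far from 3.0: the iterates oscillate and grow
```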
The behavior of gradient descent depends heavily on how much data it uses for each parameter update. Based on this, there are three variants: batch gradient descent, stochastic gradient descent, and mini-batch gradient descent [3].
Batch Gradient Descent
This variant is also called vanilla gradient descent. It uses the whole training dataset to compute the gradient of the error function. Because the entire dataset is processed for just one parameter update, it is slow and memory-hungry: each update takes a long time, and the whole training dataset may not fit in memory. It also cannot be used for online learning, since each update requires the whole training dataset rather than a single new training example [3].
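A minimal sketch of batch (vanilla) gradient descent, assuming a linear model $y' = w x + b$ trained with MSE on NumPy arrays: the gradient is computed over the entire training set, and the parameters are updated only once per pass. The function name, learning rate, and number of epochs are illustrative choices, and the data is the noise-free toy dataset used later in this article.

```python
import numpy as np


def batch_gradient_descent(x, y, learning_rate=0.05, epochs=500):
    """Fit y' = w * x + b with one parameter update per pass over ALL the data."""
    w, b = 0.0, 0.0
    n = len(x)
    for _ in range(epochs):
        residuals = (w * x + b) - y          # uses every training example
        grad_w = 2.0 / n * np.sum(residuals * x)
        grad_b = 2.0 / n * np.sum(residuals)
        w -= learning_rate * grad_w          # exactly one update per epoch
        b -= learning_rate * grad_b
    return w, b


x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
y = 2.0 * x - 1.0
w, b = batch_gradient_descent(x, y)
print(w, b)  # should approach 2 and -1
```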
Stochastic Gradient Descent
This variant updates the parameters after each training example and can therefore also be used for online learning. Each update is much faster and needs little memory, as only a single training example is needed to compute the gradient of the error function [3]. On the other hand, these frequent single-example updates have high variance, so the error fluctuates and its convergence to a local or global minimum is not stable [3]. If the learning rate is slowly decreased, however, this variant shows the same convergence behavior as vanilla gradient descent [3].
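For comparison, here is a minimal SGD sketch under the same assumptions (linear model, MSE, NumPy). The parameters are updated after every single training example, visited in a freshly shuffled order each epoch. Because this toy data is noise-free, the single-example updates still settle at the solution; on real, noisy data they would keep fluctuating around the minimum unless the learning rate is decreased. The function name and hyperparameters are illustrative.

```python
import numpy as np


def stochastic_gradient_descent(x, y, learning_rate=0.02, epochs=200, seed=0):
    """Fit y' = w * x + b, updating the parameters after EVERY single example."""
    rng = np.random.default_rng(seed)
    w, b = 0.0, 0.0
    for _ in range(epochs):
        for i in rng.permutation(len(x)):     # visit examples in random order
            residual = (w * x[i] + b) - y[i]  # error on one example only
            # Gradient of the squared error of this single example.
            w -= learning_rate * 2.0 * residual * x[i]
            b -= learning_rate * 2.0 * residual
    return w, b


x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
y = 2.0 * x - 1.0
w, b = stochastic_gradient_descent(x, y)
print(w, b)  # should approach 2 and -1
```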
Mini-batch Gradient Descent
It combines the best of stochastic and vanilla gradient descent by minimizing the error function on one mini-batch of the training dataset at a time. Typical mini-batch sizes are values such as 8, 32, or 64. Because each update uses a small batch of examples rather than a single one, the variance of the updates is reduced, which makes convergence more stable than with stochastic gradient descent and similar to that of vanilla gradient descent [3].
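And a minimal mini-batch sketch, again assuming a linear model trained with MSE in NumPy: each update averages the gradient over a small shuffled batch of examples. A batch size of 2 is used only because the toy dataset has six examples; with real datasets, sizes such as the 8, 32, or 64 mentioned above are more typical. The function name and hyperparameters are illustrative.

```python
import numpy as np


def minibatch_gradient_descent(x, y, batch_size=2, learning_rate=0.05,
                               epochs=300, seed=0):
    """Fit y' = w * x + b, updating once per mini-batch of `batch_size` examples."""
    rng = np.random.default_rng(seed)
    w, b = 0.0, 0.0
    n = len(x)
    for _ in range(epochs):
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]  # indices of one mini-batch
            xb, yb = x[idx], y[idx]
            residuals = (w * xb + b) - yb
            # MSE gradient averaged over this mini-batch only.
            grad_w = 2.0 / len(xb) * np.sum(residuals * xb)
            grad_b = 2.0 / len(xb) * np.sum(residuals)
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b


x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
y = 2.0 * x - 1.0
w, b = minibatch_gradient_descent(x, y)
print(w, b)  # should approach 2 and -1
```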
Gradient Descent Algorithm — In Practice
Now let us check in practice how the gradient helps the parameters learn the correct values. A practical example makes the theory we have covered easier to visualize.
Consider the following toy dataset [2]:
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
The above dataset can be mapped by the simple linear function $y = 2x - 1$. If we want to predict $y$ for a given $x$, the `predict` function looks something like this:
# Based on y = 2x - 1: the slope 2 corresponds to W and the intercept -1 to b.
def predict(W, b, inputs):
    return (inputs * W) + b
This function predicts the outcome for the `inputs` given the parameters $W$ and $b$. We would like our predictions to be accurate, i.e., to match the function $y = 2x - 1$.
Now let us create our loss function, i.e., the MSE function:
# W and b are the parameters, and 6 is n: our toy dataset has 6 samples.
# Here, inputs and targets are the xs and ys arrays defined above.
inputs, targets = xs, ys

def loss(W, b):
    preds = predict(W, b, inputs)
    return np.sum((preds - targets) ** 2) / 6
Let us plot the loss function. Here, `inputs` are our `xs` and `targets` are our `ys`. For simplicity, we keep $b$ constant at $-1$:
import matplotlib.pyplot as plt
import numpy as np

# Let us check how the loss function looks for different values of W.
w = np.asarray([-8, -3, -2, -1, 0, 1, 2, 3, 8, 10])
# Keeping b constant at -1.
l = [loss(wi, -1) for wi in w]

plt.plot(w, l, '-r', label='Loss Function')
plt.title('Graph of Loss Function')
plt.xlabel('Weight', color='#1C2833')
plt.ylabel('Loss', color='#1C2833')
plt.legend(loc='upper left')
plt.grid()
plt.show()
Keeping $b$ constant at $-1$ and evaluating $W$ over the range $(-8, 10)$, we get the above plot of the loss function. The plot clearly shows that for $W = 2$ and $b = -1$ we get the minimum loss of 0.
We can verify this by passing $(2, -1)$, which forms $y = 2x - 1$, directly to the loss function: the loss is 0.
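This can be checked numerically. The sketch below re-creates the toy dataset and the MSE loss with plain NumPy (no JAX needed) and evaluates it at $(W, b) = (2, -1)$. It also checks a fact that explains the shape of the plot: with $b = -1$ every residual equals $(W - 2)\,x_i$, so the loss reduces to $\frac{1}{6}\sum_i x_i^2 \,(W - 2)^2 = \frac{31}{6}(W - 2)^2$, a parabola whose minimum is at $W = 2$.

```python
import numpy as np

xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0])


def predict(W, b, inputs):
    return (inputs * W) + b


def loss(W, b):
    # MSE over the 6 samples of the toy dataset.
    preds = predict(W, b, xs)
    return np.sum((preds - ys) ** 2) / len(xs)


print(loss(2.0, -1.0))  # 0.0: (2, -1) fits the data exactly
print(loss(8.0, -1.0))  # 186.0: the loss at the starting point W = 8 used later

# With b fixed at -1, the loss is the parabola (31 / 6) * (W - 2) ** 2.
for W in [-8.0, 0.0, 2.0, 3.0, 10.0]:
    assert np.isclose(loss(W, -1.0), 31.0 / 6.0 * (W - 2.0) ** 2)
```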
From the image and plot above, you can get an idea of what we want to achieve. We want to learn the parameters $W$ and $b$ such that the loss is 0. The parameter values that fit our dataset best are $(2, -1)$. This is the goal of learning: to identify the parameters that best fit the dataset.
For simplicity, we will keep the parameter $b$ constant at $-1$ and let the gradient descent algorithm learn the parameter $W$.
Now that our goal is clear, let us look at how to achieve it programmatically.
Install the JAX package if you do not already have it on your machine:
pip install --upgrade jax jaxlib
Get our imports straight:
import numpy  # plain NumPy; "np" below refers to JAX's NumPy-compatible API
import jax.numpy as np
from jax import grad
from jax import value_and_grad
Initialize our toy dataset:
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
Next, we create the `predict` and `loss` functions and initialize the parameters $W$ and $b$. As mentioned earlier, we keep $b$ constant at $-1$. We initialize $W$ to 8 and let gradient descent move it toward 2:
def predict(W, b, inputs):
    return (inputs * W) + b

# Setting the inputs and targets.
inputs = xs
targets = ys

# Creating the loss function. n = 6 as our dataset has 6 elements.
def loss(W, b):
    preds = predict(W, b, inputs)
    return np.sum((preds - targets) ** 2) / 6

# Finally, initialize our parameters.
# Initializing W to 8.
W = numpy.array(8, dtype=numpy.float32)
# Keeping b constant at -1.
b = numpy.array(-1, dtype=numpy.float32)
The following is the most important part of the code. We calculate the gradient with respect to $W$ and update the parameter with the new value, repeating this in a loop:
# Keeping the W parameter history.
W_array = []
# Keeping the loss history.
L_array = []

# Looping 50 times.
for i in range(50):
    print('Loop starts -----------------------------------')
    print('W_old', W)
    # Appending W to a list for visualization later.
    W_array.append(W)
    # Calculating the loss and the gradient with respect to W (argument 0).
    loss_value, W_grad = value_and_grad(loss, 0)(W, b)
    L_array.append(loss_value)
    if loss_value != 0:
        # Learning rate is 0.01.
        W = W - ((0.01) * W_grad)
        print('W_new', W)
        print('loss is ', loss_value)
        print('Loop ends ------------------------------------')
    else:
        print('W', W)
        print('loss is ', loss_value)
        print('Loop ends ------------------------------------')
        break
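The `value_and_grad` API returns the loss value together with the gradient of the loss function with respect to $W$. Its second argument, `0`, tells JAX to differentiate with respect to the first argument of `loss`, which is `W`: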
# Calculating the loss and the gradient with respect to W.
loss_value, W_grad = value_and_grad(loss, 0)(W, b)
This `W_grad` is used to update the parameter:
W = W - ((0.01) * W_grad)
As simple as that. Now let us look at the output. For clarity, I show only the results of the first three and the last three loops.
The results of the first three loops are as follows:
Loop starts -----------------------------------
W_old 8.0
W_new 7.38
loss is 186.0
Loop ends ------------------------------------
Loop starts -----------------------------------
W_old 7.38
W_new 6.8240666
loss is 149.54607
Loop ends ------------------------------------
Loop starts -----------------------------------
W_old 6.8240666
W_new 6.3255796
loss is 120.236694
Loop ends ------------------------------------
$W$ started at 8, and after three updates it reached about 6.33. The loss started at 186.0 and had dropped to about 120.24 by the third loop. Both show a decreasing trend.
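To see that there is nothing magical about what `value_and_grad` returns, the first updates can be reproduced by hand. For this loss, $\frac{\partial L}{\partial W} = \frac{2}{6}\sum_i (W x_i + b - y_i)\,x_i$. At $W = 8$, $b = -1$ the residuals are $6x_i$, so the gradient is $\frac{2}{6}\cdot 6\sum_i x_i^2 = 2 \cdot 31 = 62$, and the first update gives $8 - 0.01 \cdot 62 = 7.38$, matching the output above. The helper `loss_and_grad` below is hypothetical: a hand-derived stand-in for `value_and_grad(loss, 0)` written with float64 NumPy, so its trailing digits can differ slightly from the float32 values JAX prints.

```python
import numpy as np

xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0])
n = len(xs)


def loss_and_grad(W, b):
    """Hand-derived counterpart of value_and_grad(loss, 0)(W, b) for the MSE loss."""
    residuals = (xs * W + b) - ys
    value = np.sum(residuals ** 2) / n
    W_grad = 2.0 / n * np.sum(residuals * xs)
    return value, W_grad


W, b = 8.0, -1.0
for step in range(3):
    loss_value, W_grad = loss_and_grad(W, b)
    W_new = W - 0.01 * W_grad
    print(f"W_old {W:.7f}  W_new {W_new:.7f}  loss {loss_value:.6f}")
    W = W_new
```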
Now let us look at the last three loops:
Loop starts -----------------------------------
W_old 2.0356295
W_new 2.0319479
loss is 0.0065588956
Loop ends ------------------------------------
Loop starts -----------------------------------
W_old 2.0319479
W_new 2.0286465
loss is 0.005273429
Loop ends ------------------------------------
Loop starts -----------------------------------
W_old 2.0286465
W_new 2.0256863
loss is 0.004239871
Loop ends ------------------------------------
Here, $W$ approaches 2 and the loss approaches 0.
Our parameter has been learned: $W$ started at 8 and has now reached about 2.026. This is how the gradient descent algorithm learns the parameter and decreases the loss.
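Because this toy problem is so simple, the whole trajectory can also be written in closed form, which is a handy way to check the printout. With $b = -1$ the gradient is $\frac{2}{6}\sum_i x_i^2\,(W - 2) = \frac{31}{3}(W - 2)$, so each update multiplies the distance to the optimum by the same factor:
$$W_{t+1} - 2 = \left(1 - \alpha\,\tfrac{31}{3}\right)(W_t - 2) \quad\Rightarrow\quad W_t = 2 + (W_0 - 2)\left(1 - \alpha\,\tfrac{31}{3}\right)^t$$
With $\alpha = 0.01$ and $W_0 = 8$ the factor is about $0.8967$, which gives $W_{50} \approx 2.0257$, as in the last loop. The same formula suggests this loop would diverge for any $\alpha > 6/31 \approx 0.19$, where the factor drops below $-1$. The sketch below, in plain Python with illustrative function names, compares the closed form with the iterative update.

```python
def closed_form_W(t, W0=8.0, lr=0.01, sum_x_sq=31.0, n=6):
    """W_t = 2 + (W0 - 2) * (1 - lr * 2 * sum_x_sq / n) ** t for the toy problem (b = -1)."""
    factor = 1.0 - lr * 2.0 * sum_x_sq / n
    return 2.0 + (W0 - 2.0) * factor ** t


def iterate_W(steps, W0=8.0, lr=0.01, sum_x_sq=31.0, n=6):
    """Apply the gradient descent update W <- W - lr * dL/dW `steps` times."""
    W = W0
    for _ in range(steps):
        grad = 2.0 * sum_x_sq / n * (W - 2.0)  # dL/dW with b fixed at -1
        W = W - lr * grad
    return W


print(closed_form_W(50), iterate_W(50))  # both about 2.0257
```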
One final plot makes the learning process even clearer. The red line shows the `loss` function, and the blue line shows how the parameter $W$ moved during learning:
import matplotlib.pyplot as plt

# np is still jax.numpy here; re-importing NumPy as np would shadow it
# and change the np.sum used inside loss().
w = [-3, -2, -1, 0, 1, 2, 3]
l = [loss(wi, -1) for wi in w]

plt.plot(w, l, '-r', label='Loss Function')
plt.plot(W_array, L_array, '-b', label='Parameter learning')
plt.title('Plot learning and loss function')
plt.xlabel('Weight', color='#1C2833')
plt.ylabel('Loss', color='#1C2833')
plt.legend(loc='upper left')
plt.grid()
plt.show()
This post explained the gradient descent algorithm and showed it in action with a practical example. I have also uploaded the code to GitHub. Feel free to experiment with different initial values of the parameters.
I hope you enjoyed reading this as much as I enjoyed creating it. Thank you for your time, and please share your feedback on the article; it keeps me motivated to write more of these.
References:
- [1] Mean squared error. Wikipedia. https://en.wikipedia.org/wiki/Mean_squared_error
- [2] lmoroney/dlaicourse, "Course 1 - Part 2 - Lesson 2 - Notebook". GitHub. https://github.com/lmoroney/dlaicourse/blob/master/Course%201%20-%20Part%202%20-%20Lesson%202%20-%20Notebook.ipynb
- [3] Sebastian Ruder. "An overview of gradient descent optimization algorithms". In: CoRR abs/1609.04747 (2016). URL: http://arxiv.org/abs/1609.04747
- [4] Marcin Andrychowicz et al. "Learning to learn by gradient descent by gradient descent". In: CoRR abs/1606.04474 (2016). URL: http://arxiv.org/abs/1606.04474