How do neural networks learn so well?
Remember that time when DeepMind created an AI model capable of beating the world champion of the game Go? What about the time when everyone got hooked on posting deep fake videos, all generated by AI? If you don’t remember either of those, you’re perfectly excused. Since then, so many other breakthrough use cases have emerged for AI models. Take, for instance, the tool created to predict the shape of every known protein, or the one created to generate images based on textual prompts, or more recently, the internet’s newest chatbot obsession, ChatGPT, now powering some of Microsoft’s product suite.
Every single one of the breakthroughs above, and many more besides, shares a commonality: they were all created with some flavor of a ‘neural network’, an advanced class of AI architecture that typically requires lots of data and computing resources. But one of the secret sauces of neural networks lies not in their great scaling ability or large size, but in a clever learning algorithm they use called gradient descent. Without this algorithm, neural networks might never have been deployed at the scale they have, and problems in data science, medicine, and business might never have been solved. In this piece, we will discuss what gradient descent is and how it works, and along the way, we’ll learn a few things about loss functions and partial derivatives. Buckle up!
How do AI models actually learn?
To talk about gradient descent, an algorithm that helps with learning, we have to start with how AI models learn in the first place. Artificial intelligence (AI) is a broad term for systems that can intelligently make predictions or decisions that, to some extent, emulate complex human cognition. An AI system can be deterministic, that is, built with precise rules that will always lead to the same outcome as long as the inputs are the same. But more often than not, AI systems are stochastic, that is, built with probabilistic models that can give different outputs even when the inputs are the same. When you read about AI these days, you are probably reading about stochastic models, or rather, about them by their 21st-century name: machine learning. Neural networks, by the way, are great examples of machine learning models.
The great thing about machine learning is that we have a clear mechanism for improving a model’s accuracy: learning. Of course, models don’t have artificial brains that can remember things and apply them later. But models can still learn; it’s just that ‘learning’ means something different here. In machine learning, learning is the process of changing different aspects of the model and measuring how those changes impact the model’s accuracy. Since our goal is to have the most accurate model possible, successful learning means changing parameters until accuracy is maximized.
There are different ways to measure accuracy. In some use cases, it may be more important for the model to minimize false positives; in others, it may be more important to minimize false negatives. And, of course, there are many cases where the outcome is not binary: the price of a house, the probability of an event happening, and word similarity, to name a few. In general, even though there are use-case-specific notions of accuracy, we can always think of accuracy as an inverse measurement of error. In other words, the less error our models produce, the more accurate they are. We can, therefore, reframe model learning as the tuning of parameters that yields the minimum overall error, and we’ll stick to this definition for the rest of this piece.
What do we mean by minimum overall error?
To illustrate what minimum overall error is, we’ll start with a very simple example that does not involve a neural network. In our example, we are trying to predict the average price of a house as a function of how many bedrooms it has. We’ll plot the number of bedrooms on the horizontal axis and the price on the vertical axis. Let’s begin assuming our model knows nothing and simply makes random guesses at first — plotting the random guesses as dots, we have the following image.
Let’s then suppose we actually knock door-to-door on several houses and collect the true average price of homes given their number of bedrooms. Now we plot those true values too.
The overall error of our predictions can be calculated by subtracting each prediction from its true value, squaring that difference, and then adding everything up. Here’s a three-step guide:
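If you prefer code to pictures, here’s a minimal sketch of those three steps in Python. The predictions and prices are made-up numbers, not the ones from the figures:

```python
# Made-up predictions and door-to-door "true" prices (illustrative numbers only)
predicted = [150_000, 210_000, 340_000, 400_000]
actual = [175_000, 250_000, 310_000, 450_000]

# Step 1: subtract each prediction from the true value
errors = [a - p for a, p in zip(actual, predicted)]

# Step 2: square each of those differences
squared_errors = [e ** 2 for e in errors]

# Step 3: add everything up to get the sum of squared errors (SSE)
sse = sum(squared_errors)
print(sse)  # the overall error, in "dollars squared"
```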
In our case, our overall error — also known as the sum of squared errors — is 115,000,000,000 dollars squared. We can certainly do better than that. By the way, you may ask why we need to square the errors; many articles forget to explain this. We square them for several reasons. One is that, without squaring, the sign of each error (i.e., positive or negative) would affect the overall total: the errors could cancel each other out, giving us the illusion of a low error rate when every prediction was wrong! Check it out:
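To see the cancellation concretely, here’s a toy case, again with made-up numbers, where every prediction is wrong yet the raw errors sum to zero:

```python
# Two wrong predictions whose signed errors cancel each other out
predicted = [100_000, 300_000]
actual = [200_000, 200_000]

errors = [a - p for a, p in zip(actual, predicted)]
print(sum(errors))                  # 0 -> looks like a perfect model!
print(sum(e ** 2 for e in errors))  # 20_000_000_000 -> reveals the real error
```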
Some of you may wonder if we couldn’t solve this by simply taking the absolute value of each error. The answer is that it would solve the specific problem above, but it would miss out on two very important features. First, squaring is more punitive toward large errors: a large error contributes disproportionately more to the total than a small one. Second, the squared function is differentiable everywhere (the absolute value, by contrast, has no derivative at zero), and we’ll need to perform differentiations very soon.
How do we actually minimize overall error?
The housing price problem described above can be solved with linear regression. We can fit a line of best fit such that the slope (i.e., steepness) of the line yields predictions that minimize the overall squared error. This will make it far more likely that we do better than random guessing.
But how do we know which value to pick for the slope so we minimize the sum of the squared errors? One way to achieve this is to plot several different values for the slope on the x-axis and calculate their corresponding sum of the squared errors on the y-axis. Plotting each as a point may give us the shape of the overall function — we call this the loss function. Those familiar with calculus may remember we can take the derivative of the loss function and set it equal to zero to find where the minimum value of the loss function is. The corresponding x-axis value at that point is the slope that minimizes the sum of the squared errors.
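Here’s a hedged sketch of that idea in code, assuming a simple line through the origin (price = slope × bedrooms) and made-up data. Instead of solving the derivative analytically, it brute-forces many candidate slopes and keeps the one with the smallest sum of squared errors:

```python
# Made-up training data: (number of bedrooms, true average price)
data = [(1, 120_000), (2, 190_000), (3, 260_000), (4, 340_000)]

def sse(slope):
    # Sum of squared errors for the line: price = slope * bedrooms
    return sum((price - slope * beds) ** 2 for beds, price in data)

# Try many candidate slopes and keep the one that minimizes the SSE
candidates = range(50_000, 150_001, 1_000)
best_slope = min(candidates, key=sse)
print(best_slope, sse(best_slope))  # prints 88000 and its (still nonzero) SSE
```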
This is all great, and not that mathematically complicated. Even if we had a linear regression with multiple variables, the method outlined above could still work reasonably well. But what about cases, such as neural networks, where we have thousands or even millions of parameters attempting to classify non-linear, high-dimensional data such as word embeddings? Imagine trying to replicate the method above, but instead of considering only the slope, consider the value of each of the connecting lines below. Daunting, no?!
Trying to replicate the method above for such a problem would be computationally prohibitive: neural networks would take far, far longer to train, and the cost would be financially prohibitive for anyone running this on the cloud. Think of all the amazing products we have today, and the amazing discoveries, that might never have seen the light of day if it were this hard to minimize errors in large neural networks. Thankfully, there is a method out there that helps substantially.
A very, very basic primer on neural networks
We’ll be talking about gradient descent shortly — all in the context of a neural network. So before we get into the math and intuition for gradient descent, here’s what you need to know about neural networks for today.
We can break a neural network into three parts: an input layer, a hidden layer, and an output layer. Most neural networks actually have several hidden layers, but for simplicity’s sake, I only included one here. The input layer has one node per feature of the data, so a dataset with three features will have three nodes. The hidden layer has nodes too, each of which performs a calculation with the nodes from the input layer that connect to it. Nodes in the hidden layer then pass the results of their calculations on to the output layer, where the outcome of the problem we are trying to solve is determined. Here’s what a simple neural network would look like for a housing price prediction with three variables.
Nodes are connected to one another by lines that we call weights. These weights are parameters of the model that we can change to minimize the error. The other parameters we can change are the bias terms that exist within each hidden-layer neuron. In the illustration above, everything that is purple can be changed to decrease the overall error. But how can we change them efficiently, especially when we are dealing with networks with multiple hidden layers and thousands of nodes? That’s what gradient descent is here for.
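To make that anatomy concrete, here’s a hedged sketch of a forward pass through the network above: three input features, two hidden nodes (each with its own bias), and one output. The weight values, the features, and the choice of activation function are all made up for illustration:

```python
def forward(features, w_in, b_hidden, w_out):
    # w_in: 2 hidden nodes x 3 input features = 6 weights
    # b_hidden: one bias per hidden node = 2 biases
    # w_out: 2 hidden-to-output weights (8 weights + 2 biases in total)
    hidden = []
    for node in range(2):
        total = sum(w_in[node][i] * features[i] for i in range(3)) + b_hidden[node]
        hidden.append(max(0.0, total))  # ReLU activation (an assumption)
    return sum(w_out[node] * hidden[node] for node in range(2))

# Made-up features for one house: bedrooms, bathrooms, square footage
price = forward([3, 2, 1500],
                w_in=[[10.0, 5.0, 40.0], [20.0, 2.0, 30.0]],
                b_hidden=[1_000.0, 2_000.0],
                w_out=[1.5, 1.8])
print(price)  # the network's (currently untrained) price prediction
```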
Gradient Descent: an efficient approximation of the minimum error
Gradient descent is an algorithm used to find an approximation to the minimum point of a loss function. Because it computes an approximation, it does not need to solve the problem analytically. In other words, it does not need to know the entire shape of the function and set the derivative of that function to zero. This is important to us because in the example we are about to see, we are dealing with 10 dimensions, and I don’t know about you, but I can’t visualize them all! To demonstrate how gradient descent works, we will use our housing price example. Our goal is to find values for our 8 weights and 2 biases that minimize our loss function and give us the highest possible accuracy.
And without further ado, here are the four steps in the gradient descent algorithm and how they can be used to solve our problem.
1. Initialize random guesses for weights and biases
Just as with our one-dimensional house pricing problem, we begin with a random guess. But this time, we start with random initial values for all ten parameters: 8 weights and 2 biases. Then, we run the neural network forwards. The two purple nodes will perform calculations and pass the results along. Each purple node has its own function to transform the value it receives into something that can be translated into the price of a house, but that is for a different piece. The important part is that when training our model, we can compare the value the network predicts in the green node with the price we know the house costs, and, after doing this for all houses in our dataset, calculate the sum of the squared errors.
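Here’s what step 1 could look like in code: a hedged, self-contained sketch that initializes the 10 parameters randomly, runs the network forwards on a made-up dataset, and computes the sum of squared errors:

```python
import random

random.seed(0)

# Step 1: random initial guesses for all 10 parameters (8 weights, 2 biases)
params = [random.uniform(-1, 1) for _ in range(10)]

# Made-up dataset: (three features per house, true price)
houses = [([3, 2, 1500], 310_000), ([4, 3, 2200], 450_000), ([2, 1, 900], 180_000)]

def predict(features, p):
    # Unpack: 6 input-to-hidden weights, 2 hidden biases, 2 hidden-to-output weights
    w_in, b, w_out = [p[0:3], p[3:6]], p[6:8], p[8:10]
    hidden = [max(0.0, sum(w * f for w, f in zip(w_in[n], features)) + b[n])
              for n in range(2)]
    return sum(w_out[n] * hidden[n] for n in range(2))

# Run the network forwards on every house and add up the squared errors
sse = sum((price - predict(f, params)) ** 2 for f, price in houses)
print(sse)  # the (likely terrible) error of our random first guess
```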
2. Compute partial derivatives for each parameter and add them up
Let’s say we go through all the houses in our dataset, add up the squared errors, and arrive at a value for the sum of squared errors. Now what? We can’t plot it against ten dimensions. But we can do something nifty nonetheless: we compute the partial derivative of the squared error function with respect to each individual parameter, and then add them all up.
Partial derivatives are a little too complicated for me to cover in detail in this piece, but in simple terms, a partial derivative isolates one parameter at a time and gives us an equation that tells us how much changing that individual parameter changes the overall error function. It doesn’t give us a number, but a relationship. Then, we plug in the randomly initialized value for that parameter and evaluate the partial derivative at that specific point. Here’s an example for the first parameter w1 with a made-up partial derivative:
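The original figure isn’t reproduced here, so as a stand-in, suppose the made-up partial derivative for w1 comes out to:

$$\frac{\partial\,\text{SSE}}{\partial w_1} = 4w_1 - 6$$

If w1 was randomly initialized to 0.5, evaluating the expression at that point gives 4(0.5) − 6 = −4. The negative sign tells us that nudging w1 upward from 0.5 would decrease the overall error.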
If we do this for every single one of our 10 parameters, we end up with a vector of partial derivatives; we call this vector the gradient. We can then add up the partial derivatives in this vector and arrive at a single value. And, finally, we can plot that value against the corresponding sum of squared errors too!
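If you want to see the gradient as an object in code, here’s a hedged sketch that approximates each partial derivative numerically with finite differences. Real frameworks compute these far more efficiently with backpropagation, but the resulting vector plays the same role:

```python
def gradient(loss, params, eps=1e-6):
    # Approximate each partial derivative by nudging one parameter at a time
    grads = []
    for i in range(len(params)):
        nudged = params.copy()
        nudged[i] += eps
        grads.append((loss(nudged) - loss(params)) / eps)
    return grads

# Made-up two-parameter loss, just to show the shape of the output
loss = lambda p: (p[0] - 3) ** 2 + (p[1] + 1) ** 2
grad = gradient(loss, [0.0, 0.0])
print(grad)       # the gradient: one partial derivative per parameter, ~[-6, 2]
print(sum(grad))  # the single summed value we plot against the SSE, ~-4
```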
3. Use the gradient to generate the next guess
We can now make a second guess, but this time, it won’t be random! This time, it is computed by taking the original guess and subtracting from it the gradient times a learning rate. The learning rate is a parameter we choose ourselves: if we want the next guess to be very far from the original one, we make that value large. Usually, we keep it small so that the gradient dictates how far our next guess will be. Since we are subtracting the gradient, this also means that when the gradient is negative, we end up increasing the value of the next guess, and when the gradient is positive, we end up decreasing it. Here’s a guided example where we start with our current parameter value (2nd column) and end up with our new guess for the parameter value (5th column). We use a value of 0.05 for the learning rate all throughout.
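In code, the update is a single line per parameter. Here’s a hedged stand-in for the guided example, with made-up current values and gradients and the same 0.05 learning rate:

```python
learning_rate = 0.05

# Made-up (current parameter value, gradient) pairs standing in for the table
rows = [(2.0, -8.0), (1.5, 4.0), (-0.5, -2.0)]

for value, grad in rows:
    new_value = value - learning_rate * grad  # new guess = old guess - lr * gradient
    print(f"{value} -> {new_value}")          # e.g. 2.0 -> 2.4
```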
Notice how every time the gradient is negative, the new parameter value is larger than the previous one. Conversely, every time the gradient is positive, the new parameter value is smaller than the previous one. This brilliant mechanism means that, over time, we home in closer and closer on a value in the middle: one where the gradient will be close to zero. If we compute the sum of squared errors with the new parameter values, then compute and add up the partial derivatives (made-up values in this piece) of the new parameters, we can plot our second point.
4. Iterate until we’re happy
Now, we repeat this several times. How many times? That’s up to us; we call this parameter the number of epochs. If we set the epochs to five, then we loop over this process five times. Notice how, as we begin to converge on a value, the distance between guesses shrinks. That’s because the gradient is getting closer and closer to zero, so the difference between the previous value and the new value decreases. If we arrive at a gradient of exactly zero, the subtracted term will also be zero and the parameter value will stagnate, meaning we have probably hit the minimum! Using five epochs, here’s what our zig-zag looks like.
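Putting the four steps together, here’s a hedged end-to-end sketch on a made-up one-parameter loss function, with a deliberately large learning rate so the guesses overshoot and the zig-zag is visible in the printout:

```python
# Made-up toy loss with its minimum at w = 4
loss = lambda w: (w - 4) ** 2
grad = lambda w: 2 * (w - 4)  # the derivative of the loss above

w, learning_rate = 0.0, 0.6   # 0.6 is intentionally large to force overshooting
for epoch in range(5):        # 5 epochs, as in our example
    w = w - learning_rate * grad(w)
    print(f"epoch {epoch + 1}: w = {w:.4f}, loss = {loss(w):.4f}")
# w zig-zags: 4.8 -> 3.84 -> 4.032 -> 3.9936 -> 4.0013, homing in on 4
```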
And that’s it. Using this method, we find an approximation that gives us a derivative very close to zero. Rather than trying to visualize the loss function ahead of time as a sort of parabola, we simply zig-zag our way to a near-perfect answer. And the best part — this method works in any number of dimensions.
But what if there is more than a ‘minimum point’ in my loss function?
So far, whether we knew what the loss function looked like or not, all of our examples had a single clear minimum point. But in some cases, our loss function could have multiple minimum points, called local minima. How can gradient descent guarantee we won’t get ‘stuck’ at a local minimum simply because our initial guess happened to be closer to it? A good way to visualize this is to imagine our first guess as a blue circle. That circle will fall to a minimum point, but not the lowest point in the entire loss function; that would be a point earlier in the function, represented by where a purple circle would fall.
With gradient descent, there is no guarantee that we will land on the true minimum point. But this isn’t as big a problem in practice as it may seem in theory, because neural networks are usually trained several times with different initial guesses, which mitigates the risk of a single guess landing on the wrong value. In the context of the example above, we would train so many model versions that eventually we would guess an initial set of inputs corresponding to the purple sphere.
Another reason this doesn’t seem to matter much is that in high-dimensional spaces (which we, unfortunately, can’t visualize), local minima appear to be uncommon. This is more of a postulation than something we know concretely, since it is difficult to create faithful visual representations of what goes on in high-dimensional space. Nonetheless, even with a 3-D visualization created with an online visualizer, we can already see how some complex functions could conceivably end up having a single steeper path toward which most initial guesses will converge.
Additional Considerations
Different flavors of gradient descent
There are two final things I want you to know about gradient descent. The first is that researchers have found better ways to optimize it over time, and today there are several ‘flavors’ of gradient descent. Most neural networks these days are trained using mini-batch gradient descent, which computes each update from a small random subset of the data, balancing the robustness of the gradient estimate with computational efficiency. It will be interesting to see whether further optimizations can be made to gradient descent with more sophisticated cloud computing products, or even quantum algorithms, in a few years.
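As a hedged sketch of the mini-batch idea: rather than computing the gradient over the entire dataset (or over a single example at a time), each update uses a small random chunk of the data:

```python
import random

random.seed(0)

# Made-up one-parameter model: predicted price = w * bedrooms
data = [(b, 88_000 * b + random.uniform(-10_000, 10_000)) for b in (1, 2, 3, 4) * 5]

def minibatch_gradient(w, batch):
    # Partial derivative of the squared error, averaged over the mini-batch
    return sum(-2 * x * (y - w * x) for x, y in batch) / len(batch)

w, learning_rate, batch_size = 0.0, 0.02, 4
for epoch in range(5):
    random.shuffle(data)  # a fresh random split into mini-batches each epoch
    for i in range(0, len(data), batch_size):
        w -= learning_rate * minibatch_gradient(w, data[i:i + batch_size])

print(w)  # lands near the underlying slope of ~88,000
```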
Vanishing and Exploding Gradients
Since gradient descent is, in a sense, a form of feedback loop, we also have to watch out for two edge cases that can significantly hinder our models: vanishing gradients and exploding gradients. A vanishing gradient happens when our initial guess lands in a flat region of the loss function where the gradient is so close to zero that the point barely moves from one epoch to the next. By contrast, an exploding gradient happens when our initial guess lands in a region so steep that the gradient is enormous and our points move way too far.
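Here’s a hedged toy illustration of both symptoms on made-up one-parameter loss functions. In real networks these problems emerge from multiplying many layer derivatives together during backpropagation, but the per-step behavior looks like this:

```python
def descend(grad, w, learning_rate=0.1, epochs=5):
    history = []
    for _ in range(epochs):
        w = w - learning_rate * grad(w)
        history.append(round(w, 4))
    print(history)

# Vanishing: a nearly flat region, so the point barely moves each epoch
descend(grad=lambda w: 0.001 * w, w=5.0)  # [4.9995, 4.999, 4.9985, ...]

# Exploding: a very steep region, so each step overshoots further and further
descend(grad=lambda w: 100 * w, w=5.0)    # [-45.0, 405.0, -3645.0, ...]
```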
There are solutions to these problems, though. Some involve changing how nodes interact and communicate with one another; examples include regularization techniques, autoencoding, weight normalization, and many other complicated words that perhaps we can cover in future pieces. The point is that there is no shortage of methods to successfully address this problem; gradient descent still stands in the end.
Summary
Gradient descent is a clever algorithm that enables neural networks to learn the way they do. While it computes only an approximation of the minimum overall error in a model, that approximation is more than sufficient for the vast majority of cases. Moreover, the time and cost savings it enables are precisely why it has become a cornerstone of neural networks, and why it has been refined and improved over the years. While it’s important to note that gradient descent can run into issues such as vanishing or exploding gradients, there are many techniques out there that successfully address these. All modern neural networks use some form of gradient descent behind the curtains, so chances are, your Spotify recommendations are being helped by gradient descent, as are your Google searches, ChatGPT, and your newsfeed on LinkedIn, among many others.
Further Resources
The following resources are excellent complements to what you’ve read above. I particularly want to point you to 3Blue1Brown’s video on gradient descent, which is, in my opinion, the best visual explanation of this topic in video form. There is also a course on Udemy called ‘A deep understanding of deep learning’ by Mike Cohen, from which I largely drew inspiration for my visuals of vanishing and exploding gradients. I’ll link to the course below as well!
- Gradient Descent, How Neural Networks Learn (3Blue1Brown)
- A deep understanding of deep learning (Mike Cohen, Udemy)
- Gradient Descent (Kwiatkowski, 2021)
- A gentle introduction to neural networks (Fumo, 2017)
- Derivatives (Khan Academy)
- Partial Derivatives (Khan Academy)
- Exploding and Vanishing Gradients (Bohra, Analytics Vidhya)