Gradient Descent: How Machine Learning Learns (Part II)
Welcome to Part II of understanding gradient descent. If you haven’t yet, check out Part I for a conceptual understanding. Now, let’s dive into the code, fitting a linear regression from scratch with our own gradient descent algorithm!
Data Processing
We start by importing the libraries we need: pandas for data manipulation, requests to query our data, and statsmodels for a formal linear regression to compare our model to.
We’ll pull our data from the World Bank Indicators API, which allows us to get major demographic indicators without needing any authentication.
I chose total fertility rate and crude birth rate because they are strongly correlated, which makes linear regression a suitable model.
More information is available from the request than this function collects, but I focused on the regression variables and their identifiers.
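A fetch helper along these lines would do the job. The endpoint and response shape follow the public v2 Indicators API, but the function names and column names here are my own stand-ins, not necessarily the post’s exact code:

```python
import requests
import pandas as pd

def records_to_df(records):
    """Keep only the regression-relevant fields from the API's
    record list, dropping rows with missing values."""
    rows = [
        {"country": r["country"]["value"],
         "year": r["date"],
         "indicator_value": r["value"]}
        for r in records if r["value"] is not None
    ]
    return pd.DataFrame(rows)

def fetch_indicator(indicator_id, country="all", per_page=2000):
    """Query the World Bank Indicators API (no authentication needed)
    and return a DataFrame of country/year/value rows."""
    url = f"https://api.worldbank.org/v2/country/{country}/indicator/{indicator_id}"
    resp = requests.get(url, params={"format": "json", "per_page": per_page})
    resp.raise_for_status()
    meta, records = resp.json()  # the API returns [metadata, data]
    return records_to_df(records)
```

For reference, `SP.DYN.TFRT.IN` is the World Bank code for total fertility rate and `SP.DYN.CBRT.IN` for crude birth rate.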
Model Benchmark
We’ll run a regression of crude birth rate on total fertility rate using statsmodels to compare our model performance against.
We aim to get as close as possible to this regression’s slope (coef for indicator_value_1) and intercept (coef for const) with our own optimization model!
Building Gradient Descent — Set Up
To initialize the slope, I approximated the average rate of change of these variables. For the intercept, I used the minimum value of crude birth rate.
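In code, that initialization might look like the following (synthetic stand-in data; the variable names are mine):

```python
import numpy as np

rng = np.random.default_rng(0)
tfr = rng.uniform(1.0, 7.0, 100)               # x: total fertility rate (stand-in)
cbr = 10 + 4.5 * tfr + rng.normal(0, 1, 100)   # y: crude birth rate (stand-in)

# slope: average rate of change of y across the observed range of x
m = (cbr.max() - cbr.min()) / (tfr.max() - tfr.min())
# intercept: minimum observed crude birth rate
b = cbr.min()
```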
Next, we’ll calculate the loss function (the sum of squared residuals) for each iteration of the training process, so it’s best to define a function.
Let’s use this function to see what the error of our initialized (m, b) is, and what the error of the ideal regression coefficients is.
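A sketch of the loss function and that comparison, again on stand-in data (here 4.5 and 10.0 play the role of the benchmark coefficients; the function name is mine):

```python
import numpy as np

def sum_squared_residuals(x, y, m, b):
    """Loss: sum of squared residuals for the line y = m*x + b."""
    residuals = y - (m * x + b)
    return float(np.sum(residuals ** 2))

rng = np.random.default_rng(0)
tfr = rng.uniform(1.0, 7.0, 100)               # stand-in data
cbr = 10 + 4.5 * tfr + rng.normal(0, 1, 100)

m0 = (cbr.max() - cbr.min()) / (tfr.max() - tfr.min())
b0 = cbr.min()
init_loss = sum_squared_residuals(tfr, cbr, m0, b0)        # error at the initial guess
ideal_loss = sum_squared_residuals(tfr, cbr, 4.5, 10.0)    # error near the ideal fit
print(init_loss, ideal_loss)
```

The initial guess should give a noticeably larger loss than the ideal coefficients — that gap is what gradient descent will close.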
When running our model, we’ll also be calculating the rate of change in the loss function (its gradient vector) for each iteration of the training process, so let’s define a function for that as well.
Note that this function returns both the partial derivative with respect to the slope m and the partial derivative with respect to the intercept b (the two axes of the loss surface) so that we can update them simultaneously. Whether this is split into 1 or 2 functions isn’t relevant, but both partials must be calculated before either m or b is updated.
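For the sum-of-squared-residuals loss L(m, b) = Σ(yᵢ − (m·xᵢ + b))², the partials are ∂L/∂m = −2·Σxᵢ(yᵢ − (m·xᵢ + b)) and ∂L/∂b = −2·Σ(yᵢ − (m·xᵢ + b)). A sketch of the gradient function (names are mine):

```python
import numpy as np

def gradient(x, y, m, b):
    """Gradient of the sum-of-squared-residuals loss at (m, b).
    Returns both partials together so the caller can update
    m and b simultaneously."""
    residuals = y - (m * x + b)
    dm = -2.0 * np.sum(x * residuals)  # partial with respect to m
    db = -2.0 * np.sum(residuals)      # partial with respect to b
    return dm, db
```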
Building Gradient Descent — Iteration and Visualization
We’ll be visualizing our model’s iteration process, so we’ll need 3D plotting libraries:
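Something like the following (the original’s exact imports may differ):

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the "3d" projection
```

On recent matplotlib versions the `Axes3D` import is optional — `fig.add_subplot(projection="3d")` works on its own.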
Finally, we’ll define a function to plot our loss function as a surface and (m, b) value at each iteration as a point. If our model is working correctly, we should see the points move toward the low point of the surface. We’ll discuss this more later, but passing the learning rate and max iterations as easily-changeable arguments will help us explore what suits our model.
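The descent loop itself might look like this sketch (plotting omitted so the example stays self-contained; the surface plot would use `ax.plot_surface` over a grid of (m, b) values and scatter the path points on top). The function and variable names are my own stand-ins:

```python
import numpy as np

def gradient_descent(x, y, m, b, learning_rate=0.0003, max_iter=100, tol=0.001):
    """Step (m, b) down the negative gradient of the sum of squared
    residuals, stopping early if the loss changes by less than tol.
    Returns the final parameters and the path taken, so the points
    can be plotted on the loss surface."""
    def loss(m, b):
        return float(np.sum((y - (m * x + b)) ** 2))

    path = [(m, b, loss(m, b))]
    for _ in range(max_iter):
        residuals = y - (m * x + b)
        dm = -2.0 * np.sum(x * residuals)
        db = -2.0 * np.sum(residuals)
        # simultaneous update: both partials computed before either step
        m, b = m - learning_rate * dm, b - learning_rate * db
        path.append((m, b, loss(m, b)))
        if abs(path[-2][2] - path[-1][2]) < tol:
            break
    return m, b, path
```

Exposing `learning_rate` and `max_iter` as arguments makes the experiments below — trying rates from .0001 to .0006 and iteration counts from 100 to 1000 — a one-line change.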
As you can see, it seems like we’re having success at 100 iterations and a .0003 learning rate. The runtime isn’t too bad either.
This looks to be a solid learning rate — we smoothly and efficiently move from our initial point (the blue arrow) towards our end goal, the minimum of the loss function surface (the green arrow). The slope and intercept returned by this function are also pretty close to the statsmodels benchmark.
Truth be told, this was found after much experimentation. If we try a smaller learning rate like .0001, we don’t get as close to the optimal values.
If we try a larger learning rate like .0004 or .0005, we see the ping-pong behavior of the algorithm before it converges.
And with any rate .0006 or higher, the algorithm diverges entirely.
For all of the learning rates that converged, the algorithm ran for the full 100 iterations rather than stopping early because the loss changed by less than .001 between iterations. Thus, by increasing the maximum iterations from 100 to 1000, we get even closer to the optimal coefficients at the cost of longer training time.
Conclusion
Though our model requires a carefully chosen learning rate, it succeeded in finding close-to-optimal regression slope and intercept values. More sophisticated gradient descent implementations use dedicated strategies for initializing values (such as a multi-start procedure) and for choosing a learning rate, but our implementation still demonstrates the core concepts of the algorithm.
Full code linked here!