Solving Optimization Problems with JAX

Mazeyar Moeini
Published in The Startup · 8 min read · May 25, 2020
Joseph-Louis Lagrange & Isaac Newton, JAX logo by Google

1 Introduction

What is JAX? As described on the main JAX webpage, JAX is Autograd and XLA, brought together for high-performance machine learning research. JAX essentially augments the numpy library into a new library with automatic differentiation (Autograd), vectorized mapping (vmap), and just-in-time compilation (JIT), all compiled with Accelerated Linear Algebra (XLA), with Tensor Processing Unit (TPU) support and much more. With all of these features, problems that depend on linear algebra and matrix methods can be solved more efficiently. The purpose of this article is to show that these features can indeed be used to solve a range of simple to complex optimization problems with matrix methods, and to provide an intuitive understanding of the mathematics and implementation behind the code.

First, we will need to import the JAX libraries, as well as nanargmin/nanargmax from numpy, as they are not implemented in JAX yet. If you are using Google Colab, no installation of JAX is required; JAX is open sourced and maintained by Google.
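A minimal sketch of the imports (the aliases are a choice, not mandated by JAX; the later sections rely on grad, jacfwd, and vmap):

```python
import jax.numpy as jnp                   # JAX's drop-in replacement for numpy
from jax import grad, jacfwd, vmap, jit   # autodiff, Jacobians, vectorized mapping, JIT
from numpy import nanargmin, nanargmax    # not yet available in jax.numpy at the time of writing
```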

2 Grad, Jacobians and Vmap

Grad is best used for taking the automatic derivative of a function. It creates a function that evaluates the gradient of a given function. If we called grad(grad(f)), this would be the second derivative.
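For example, a minimal sketch (the cubic test function is an arbitrary choice):

```python
from jax import grad

def f(x):
    return x**3 + 2.0 * x      # f'(x) = 3x^2 + 2,  f''(x) = 6x

df = grad(f)                   # first derivative, returned as a new function
ddf = grad(grad(f))            # second derivative

print(df(2.0))                 # 14.0
print(ddf(2.0))                # 12.0
```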

Jacobian is best used for taking the automatic derivative of a function with a vector input. We can see that it returns the expected vector from a circle function.
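A minimal sketch, assuming the circle function has the form f(X) = x0² + x1²:

```python
import jax.numpy as jnp
from jax import jacfwd

def circle(X):
    # assumed form of the circle function: f(X) = x0^2 + x1^2
    return X[0]**2 + X[1]**2

J = jacfwd(circle)                       # Jacobian (here, the gradient) as a function
print(J(jnp.array([1.0, 2.0])))          # expected vector: [2., 4.]
```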

Even more interesting is how we can compute the Hessian of a function by computing the Jacobian twice; this is what makes JAX so powerful! We see that the function hessian takes in a function and returns a function as well.
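A minimal sketch of that construction, reusing the circle function from above:

```python
import jax.numpy as jnp
from jax import jacfwd

def hessian(f):
    # the Hessian is the Jacobian of the Jacobian
    return jacfwd(jacfwd(f))

def circle(X):
    return X[0]**2 + X[1]**2

print(hessian(circle)(jnp.array([1.0, 2.0])))   # [[2., 0.], [0., 2.]]
```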

It should be noted that the gradients are computed with automatic differentiation, which is much more accurate and efficient compared to finite differences.

3 Single Variable Optimization

3.1 Gradient Descent

Let's imagine that we have the following optimization problem from UC Davis: a rectangular piece of paper is 12 inches long and 6 inches wide. The lower right-hand corner is folded over so as to reach the leftmost edge of the paper. Find the minimum length L of the resulting crease.

Image by University of California Davis

After doing some trigonometry, we can find the length of the crease with respect to the variable x to be:

L(x) = sqrt(2x³ / (2x − 6))

To find the minimum, we would have to check all the critical points where L′ = 0. Although this is a relatively simple optimization problem, it would still lead to a messy derivative requiring the chain rule and the quotient rule. As these problems only become more complex, it is wise to use numerical methods to solve them. Jumping over to JAX, we can define the function in Python.
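A minimal sketch of the definition:

```python
import jax.numpy as jnp

def L(x):
    # crease length as a function of the fold position x
    return jnp.sqrt(2 * x**3 / (2 * x - 6))
```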

Then, using grad(L) we can find the derivative of L and minimize this using stepwise gradient descent.
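A minimal sketch of that loop, continuing from the definition of L above; the step size, the domain grid, and the epoch count here are assumptions:

```python
import jax.numpy as jnp
from jax import grad, vmap
from numpy import nanargmin

dL = grad(L)                                # derivative of the crease length

def minGD(x):
    return x - 0.01 * dL(x)                 # one gradient-descent step

domain = jnp.linspace(3.01, 6.0, num=200)   # start just above x = 3 to avoid dividing by zero

for epoch in range(500):
    domain = vmap(minGD)(domain)            # update every candidate point in parallel

values = vmap(L)(domain)                    # evaluate the objective over the whole domain
best = nanargmin(values)                    # nanargmin ignores any NaNs outside the valid domain
print(domain[best], values[best])           # argmin (≈ 4.5) and minimum crease length (≈ 7.794)
```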

We can see how simple things become with JAX; the actual optimization happens in just a few lines of code! Notice how vmap is used in each epoch to map the minGD step over the whole domain, and then used again to map the objective function L over the domain to find the minimum and its argmin.

The numerical answer has a 0.001851% error relative to the exact answer, 9*sqrt(3)/2; this error is acceptable, given that the true value is irrational to begin with.

3.2 Newton’s Method

The same problem can be solved using Newton's Method. Usually, Newton's Method is used for finding a root of a function, such as f(x) = x² − 2 = 0, using the iteration:

xₙ₊₁ = xₙ − f(xₙ)/f′(xₙ)

This can easily be used for optimization, given that we instead search for f′(x) = 0:

xₙ₊₁ = xₙ − f′(xₙ)/f″(xₙ)

Newton’s Method for optimization can easily be implemented with JAX.
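A minimal sketch, assuming a starting guess of x = 4.0 and a fixed number of iterations:

```python
import jax.numpy as jnp
from jax import grad

def L(x):
    return jnp.sqrt(2 * x**3 / (2 * x - 6))

dL = grad(L)                    # L'
ddL = grad(grad(L))             # L'' obtained by applying grad twice

x = 4.0                         # starting guess
for _ in range(10):
    x = x - dL(x) / ddL(x)      # Newton step toward L'(x) = 0

print(x, L(x))                  # ≈ 4.5 and ≈ 7.794 (9*sqrt(3)/2)
```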

Notice how easily L″ is calculated, simply by applying grad twice.

Newton's Method has the added advantage of quadratic convergence: the error is roughly squared at each step.

4 Multivariable Optimization

4.1 The Jacobian

In multivariable problems, we define functions f(X) where X = [x0, x1, x2, …, xn]. When the number of variables increases, we can no longer use the ordinary derivative; we need the Jacobian, also written as ∇f.

The Jacobian is the derivative of a multivariable function; it captures how each variable affects the function. Since these are the first derivatives, we can again use them to optimize a multivariable function.

Now, implementing this with JAX is just as simple as in the single-variable case; we will optimize a two-variable function.

Notice again how easily JAX allows us to calculate the Jacobian.

Similar to last time, once we have the optimization function we can run it through a loop.
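The article's original two-variable function is not reproduced here, so the sketch below uses a hypothetical bowl-shaped stand-in; the step size, the random starting points, and the epoch count are likewise assumptions. The pattern is the same as before: jacfwd gives the Jacobian, vmap maps the update over a batch of candidate points, and nanargmin picks the best result.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, vmap
from numpy import nanargmin

def f(X):
    # hypothetical stand-in objective with its minimum at (1, -2)
    return (X[0] - 1.0)**2 + (X[1] + 2.0)**2

J = jacfwd(f)                                   # gradient of f as a function

def minGD(X):
    return X - 0.1 * J(X)                       # one gradient-descent step

key = jax.random.PRNGKey(0)
points = jax.random.uniform(key, (100, 2), minval=-5.0, maxval=5.0)

for epoch in range(200):
    points = vmap(minGD)(points)                # update all candidate points in parallel

values = vmap(f)(points)
best = nanargmin(values)
print(points[best], values[best])               # roughly [1., -2.] and 0.
```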

Then we check the results.

4.2 The Hessian

In multivariable problems we define functions f(X) with X = [x0, x1, x2, …, xn]. Previously we defined the Jacobian, ∇f. The Hessian is just ∇(∇f), which requires differentiating each entry of the Jacobian with respect to every variable, thus increasing the dimension.

Using the Hessian in optimization is really similar to Newton's Method; in fact, it is the direct analogue:

Xₙ₊₁ = Xₙ − H(Xₙ)⁻¹∇f(Xₙ)

We can observe that, where it is not possible to divide by a matrix, we multiply by its inverse. There is a mathematical explanation for this using the quadratic term of a Taylor expansion; however, it is too lengthy to cover here. Again, using Autograd it is incredibly easy to calculate the Hessian.
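A minimal sketch of the Newton update with the Hessian, again using a hypothetical stand-in objective (the function, starting point, and iteration count are assumptions):

```python
import jax.numpy as jnp
from jax import jacfwd

def f(X):
    # hypothetical stand-in objective
    return (X[0] - 1.0)**2 + (X[1] + 2.0)**2 + X[0] * X[1]

J = jacfwd(f)                   # gradient
H = jacfwd(jacfwd(f))           # Hessian: the Jacobian computed twice

X = jnp.array([0.0, 0.0])
for _ in range(20):
    # multiply by the inverse Hessian instead of "dividing" by a matrix
    X = X - jnp.linalg.inv(H(X)) @ J(X)

print(X, f(X))                  # stationary point of the stand-in function
```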

5 Multivariable Constrained Optimization

Multivariable constrained optimization uses the same techniques as before but requires a different way of framing the problem. For constrained optimization we will use Lagrange multipliers. The classic Lagrange condition requires solving ∇f = λ∇g, but our numerical approach has no way of solving this symbolically. Rather, we can form the Lagrangian L(X) = f(X) − λg(X) and solve ∇f − λ∇g = 0, which turns the problem into an unconstrained one.

Just like the other optimization problems, we have a function that needs to be driven to zero: ∇L = 0. Note that solving ∇L = 0 is no different from solving a system of nonlinear equations. Our final iterative equation will look similar:

Xₙ₊₁ = Xₙ − H(Xₙ)⁻¹∇L(Xₙ), where H is the Hessian of L.

The reason the Hessian is involved again is that minimizing L and solving ∇L = 0 are the same statement. Also, when using Lagrange multipliers we have to introduce the new variable λ into the code: L(X) will take in X where X = [x0, x1, λ].

Let's say we have the objective function f(X) and the constraint g(X); in the code, λ is l[2].
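The exact objective and constraint used in the article are not reproduced here; however, the reported results below (a minimum of −8 at (sqrt(2), −1) with λ = −4) are consistent with f = 4·x0²·x1 subject to x0² + x1² = 3, so this sketch assumes those, along with a starting point and iteration count of its own:

```python
import jax.numpy as jnp
from jax import jacfwd

def f(l):
    # assumed objective, consistent with the reported results
    return 4.0 * l[0]**2 * l[1]

def g(l):
    # assumed constraint: x0^2 + x1^2 = 3
    return l[0]**2 + l[1]**2 - 3.0

def lagrange(l):
    # l = [x0, x1, lambda]; the multiplier is l[2]
    return f(l) - l[2] * g(l)

gradL = jacfwd(lagrange)            # ∇L, which we drive to zero
hessL = jacfwd(gradL)               # Hessian of the Lagrangian

l = jnp.array([1.0, -1.0, -1.0])    # starting guess
for _ in range(50):
    l = l - jnp.linalg.inv(hessL(l)) @ gradL(l)

print(l, f(l))                      # roughly [1.414, -1., -4.] and -8.0
```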

The correct minimum is −8, the argmin should be (sqrt(2), −1), and since we included λ in our calculation we also find the Lagrange multiplier, −4.0.

6 Three Variable Multivariable Constrained Optimization

Problems in real life usually have more than two variables to be optimized, and optimization hyperparameters need to be fine-tuned. As the complexity of optimization problems increases, other methods should be considered. For now, we can use the models from the previous section and just increase the number of variables. Luckily, JAX will automatically adjust for this; we just need to adjust the L function in the code.

Let's attempt to solve a problem with real-life applications, found in Paul's Online Notes: find the dimensions of the box with the largest volume if the total surface area is 64 cm². Our objective function is f(x) = x0*x1*x2 and the constraint is g(x) = 2*x0*x1 + 2*x1*x2 + 2*x0*x2 − 64. First we have to define the functions; then the only thing we have to change is the index of the list feeding into Lagrange.
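A sketch of the updated definitions, with the multiplier moved to l[3]:

```python
import jax.numpy as jnp

def f(l):
    # volume of the box
    return l[0] * l[1] * l[2]

def g(l):
    # surface-area constraint: 2*x0*x1 + 2*x1*x2 + 2*x0*x2 = 64
    return 2*l[0]*l[1] + 2*l[1]*l[2] + 2*l[0]*l[2] - 64.0

def lagrange(l):
    # l = [x0, x1, x2, lambda]; the multiplier is now l[3]
    return f(l) - l[3] * g(l)
```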

This part of the code stays exactly the same except we add a learning rate of 0.1 to gain greater accuracy. We might also have to increase the total epochs.

The real answer is sqrt(32/3)³ ≈ 34.837187 and the length of each side should be sqrt(32/3) ≈ 3.265985; the calculation is almost perfect, as the errors are negligible in real life. It's important to note that without the learning rate the optimization is unlikely to succeed, and accuracy was increased by doubling the number of epochs. Hopefully it is now obvious how more variables can be included in the optimization model.

7 Multivariable MultiConstrained Optimization

In the final part of this tutorial we will look at one of the most advanced types of optimization problems: multivariable, multi-constrained optimization. Some of the earlier problems are admittedly better solved by hand; however, as complexity increases, other numerical methods might be needed. Gradient descent, no matter how many epochs and hyperparameters, can never 100% guarantee the best result, but it is almost always better than a random guess.

Let's start by trying to maximize the objective function f(x0, x1) subject to the constraints g(x0, x1) and h(x0, x1).

More problems like this can be found at Duke University. The general form of the Lagrangian condition can be written such that the Jacobian of the objective, minus each constraint's Jacobian multiplied by its respective lambda, is equal to zero:

∇f − λ1∇g − λ2∇h = 0

We see now that we have to define all three functions. Note that l[2] = λ1 and l[3] = λ2.
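Since the article's f, g, and h are not reproduced here, the sketch below uses simple hypothetical placeholders just to show the structure of the two-multiplier Lagrangian and the same Newton-style update as before:

```python
import jax.numpy as jnp
from jax import jacfwd

def f(l):
    # hypothetical objective (not the article's)
    return l[0] * l[1]

def g(l):
    # hypothetical first constraint, g(x0, x1) = 0
    return l[0] + l[1] - 4.0

def h(l):
    # hypothetical second constraint, h(x0, x1) = 0
    return l[0] - l[1] - 2.0

def lagrange(l):
    # l = [x0, x1, lambda1, lambda2]; l[2] = lambda1, l[3] = lambda2
    return f(l) - l[2] * g(l) - l[3] * h(l)

gradL = jacfwd(lagrange)
hessL = jacfwd(gradL)

l = jnp.array([1.0, 1.0, 1.0, 1.0])
for _ in range(20):
    l = l - jnp.linalg.inv(hessL(l)) @ gradL(l)

print(l)    # x0, x1 and the two multipliers for the placeholder problem
```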

Once again, the results come out to the expected values.

Hopefully, by now you have found a new interest in optimization problems and, more importantly, realized how the features JAX offers make solving such problems easier. Furthermore, machine learning libraries offer fast and reliable tools for problem-solving that can be used outside the machine learning domain.

Further Readings and Sources:

Linear Algebra and Learning from Data by Gilbert Strang

Link to Google CoLab:

Solving Optimization Problems with JAX CoLab

Link to LaTeX-PDF version:

Solving Optimization Problems with JAX PDF

