Solving Optimization Problems with JAX

Mazeyar Moeini
May 25 · 8 min read
Joseph-Louis Lagrange & Isaac Newton, JAX logo by Google

1 Introduction

What is JAX? As described by the main JAX webpage, JAX is Autograd and XLA, brought together for high-performance machine learning research. JAX essentially augments the NumPy library into a new library with automatic differentiation (Autograd), vectorized mapping (vmap), and just-in-time compilation (JIT), all compiled with Accelerated Linear Algebra (XLA), with Tensor Processing Unit (TPU) support and much more. With these features, problems that depend on linear algebra and matrix methods can be solved more efficiently. The purpose of this article is to show that these features can indeed be used to solve a range of simple to complex optimization problems with matrix methods, and to provide an intuitive understanding of the mathematics and the implementation behind the code.

2 Grad, Jacobians and Vmap

Grad takes the automatic derivative of a scalar-valued function: it creates a new function that evaluates the gradient of the given function. If we call grad(grad(f)), we get the second derivative. For vector-valued functions, jacfwd and jacrev build the full Jacobian, and vmap maps any of these functions over a whole batch of inputs without writing an explicit loop.
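As a minimal sketch of these building blocks (the function f and the sample points below are purely illustrative):

import jax
import jax.numpy as jnp

# f(x) = x^3; grad gives f'(x) = 3x^2, grad(grad(f)) gives f''(x) = 6x
f = lambda x: x ** 3
df = jax.grad(f)        # first derivative
d2f = jax.grad(df)      # second derivative
print(df(2.0))          # 12.0
print(d2f(2.0))         # 12.0
# vmap evaluates the derivative over a whole batch of points at once
xs = jnp.linspace(0.0, 2.0, 5)
print(jax.vmap(df)(xs))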

3 Single Variable Optimization

3.1 Gradient Descent

Let’s imagine that we have the following optimization problem from UC Davis: a rectangular piece of paper is 12 inches long and 6 inches wide. The lower right-hand corner is folded over so as to reach the leftmost edge of the paper. Find the minimum length L of the resulting crease.

Image by University of California Davis
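A minimal sketch of gradient descent on this problem could look like the following. It assumes the standard single-variable reduction of the problem, where x is the fold distance along the 6-inch edge and the crease length is L(x) = √(2x³ / (2x − 6)); the variable names, learning rate, and iteration count are illustrative choices, not the article’s exact code.

import jax
import jax.numpy as jnp

# Crease length as a function of the fold distance x (assumed form from the
# standard worked solution of this problem): L(x)^2 = 2 x^3 / (2 x - 6)
def crease_length(x):
    return jnp.sqrt(2.0 * x ** 3 / (2.0 * x - 6.0))

grad_L = jax.grad(crease_length)

x = 5.0                      # initial guess; the formula needs x > 3
learning_rate = 0.1
for _ in range(200):         # plain gradient descent on L(x)
    x = x - learning_rate * grad_L(x)

print(x, crease_length(x))   # ~4.5 and a crease of ~7.79 inches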

3.2 Newton’s Method

The same problem can be solved using Newton’s Method. Usually, Newton’s Method is used for finding a root of a function, i.e. solving an equation such as f(x) = x² − 2 = 0, with the iteration:

xₙ₊₁ = xₙ − f(xₙ) / f′(xₙ)
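To minimize the crease length we apply the same update to the derivative instead, looking for the root of L′(x) = 0, which gives xₙ₊₁ = xₙ − L′(xₙ) / L″(xₙ). Here is a minimal sketch; the starting points and iteration counts are arbitrary, and the crease-length formula is the same assumption as above.

import jax
import jax.numpy as jnp

# Root finding: solve f(x) = x^2 - 2 = 0 with x_{n+1} = x_n - f(x_n)/f'(x_n)
f = lambda x: x ** 2 - 2.0
df = jax.grad(f)
x = 1.0
for _ in range(6):
    x = x - f(x) / df(x)
print(x)                     # ~1.41421, i.e. sqrt(2)

# Minimization: apply the same update to L'(x) = 0 for the crease problem
def crease_length(x):
    return jnp.sqrt(2.0 * x ** 3 / (2.0 * x - 6.0))

dL = jax.grad(crease_length)     # L'(x)
d2L = jax.grad(dL)               # L''(x)
x = 5.0
for _ in range(10):
    x = x - dL(x) / d2L(x)       # x_{n+1} = x_n - L'(x_n)/L''(x_n)
print(x, crease_length(x))       # ~4.5 inches, crease ~7.79 inches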

4 Multivariable Optimization

4.1 The Jacobian

In multivariable problems, we define functions such that f(X), where X = [x0, x1, x2, …, xn]. The Jacobian ∇f collects the partial derivatives of f with respect to each of these variables; for a vector-valued f it is a matrix with one row per output and one column per variable.

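As a sketch, JAX builds the full Jacobian of a vector-valued function with jacfwd (or jacrev); the particular function below is just an illustrative example, not the one from the original post.

import jax
import jax.numpy as jnp

# A vector-valued function f(X) = [x0*x1, x0 + x2, sin(x1*x2)]
def f(X):
    x0, x1, x2 = X
    return jnp.array([x0 * x1, x0 + x2, jnp.sin(x1 * x2)])

X = jnp.array([1.0, 2.0, 3.0])
J = jax.jacfwd(f)(X)   # 3x3 matrix of partial derivatives df_i / dx_j
print(J)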

4.2 The Hessian

In multivariable problems we define functions such that f(X), where X = [x0, x1, x2, …, xn]. Previously we defined the Jacobian ∇f. The Hessian is just ∇(∇f), which requires differentiating each entry of the Jacobian with respect to every variable, thus increasing the dimension by one: for a scalar function, the gradient is a vector and the Hessian is a square matrix of second partial derivatives.

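A minimal sketch in JAX simply composes the derivative transforms; the example function is again illustrative, and JAX also ships a jax.hessian convenience wrapper that does the same thing.

import jax
import jax.numpy as jnp

# A scalar function of several variables
def f(X):
    x0, x1, x2 = X
    return x0 ** 2 * x1 + jnp.sin(x1 * x2)

# Hessian = Jacobian of the gradient (forward-over-reverse is the usual mix)
hessian = jax.jacfwd(jax.grad(f))

X = jnp.array([1.0, 2.0, 3.0])
print(hessian(X))   # 3x3 symmetric matrix of second partial derivatives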

5 Multivariable Constrained Optimization

Multivariable constrained optimization uses the same techniques as before but requires a different way of framing the problem. For constrained optimization we will use Lagrange multipliers. The classic Lagrange condition requires solving ∇f = λ∇g together with the constraint g = c. A computer has no way of solving this symbolically. Rather, we can rewrite the condition as ∇f − λ∇g = 0 and bundle it with the constraint into the Lagrangian L(X, λ) = f(X) − λ(g(X) − c); finding the stationary points where ∇L = 0 is now an unconstrained problem that the tools from the previous sections can handle.

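As a sketch of this idea, the following solves a small illustrative problem (not the one from the original post): maximize f(x, y) = xy subject to x + y = 10. We write the Lagrangian and run Newton’s method on the system ∇L = 0, using JAX for the gradient and the Hessian of L. The expected stationary point is x = y = λ = 5.

import jax
import jax.numpy as jnp

# Lagrangian for: maximize x*y subject to x + y = 10 (illustrative problem)
def lagrangian(v):
    x, y, lam = v
    return x * y - lam * (x + y - 10.0)

grad_L = jax.grad(lagrangian)    # residual: we want grad_L(v) = 0
hess_L = jax.jacfwd(grad_L)      # Jacobian of that residual

v = jnp.array([1.0, 2.0, 1.0])   # initial guess for (x, y, lambda)
for _ in range(20):
    # Newton step on the system grad_L(v) = 0
    v = v - jnp.linalg.solve(hess_L(v), grad_L(v))

print(v)   # -> approximately [5.0, 5.0, 5.0]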

6 Three Variable Multivariable Constrained Optimization

Problems in real life usually have more than two variables to be optimized, and the optimization hyperparameters need to be fine-tuned. As the complexity of optimization problems increases, other methods should be considered. For now we can use the models from the previous section and simply increase the number of variables. Luckily, JAX adjusts for this automatically; we only need to change the L function in the code.

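For example, here is the same Newton-on-the-Lagrangian sketch with three variables, on an illustrative problem that is not from the original post: maximize xyz subject to x + y + z = 9, whose symmetric solution is x = y = z = 3 with λ = 9. Note that only the L function changes; the Newton loop carries over unchanged.

import jax
import jax.numpy as jnp

# Lagrangian for: maximize x*y*z subject to x + y + z = 9 (illustrative)
def lagrangian(v):
    x, y, z, lam = v
    return x * y * z - lam * (x + y + z - 9.0)

grad_L = jax.grad(lagrangian)
hess_L = jax.jacfwd(grad_L)

v = jnp.array([2.0, 3.0, 4.0, 5.0])   # initial guess for (x, y, z, lambda)
for _ in range(20):
    v = v - jnp.linalg.solve(hess_L(v), grad_L(v))

print(v)   # -> approximately [3.0, 3.0, 3.0, 9.0]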

7 Multivariable MultiConstrained Optimization

In the final part of this tutorial we will look at one of the most advanced types of optimization problems: multivariable, multiconstrained optimization. Some of the problems at the beginning are admittedly better solved by hand; however, as complexity increases, numerical methods become necessary. Gradient descent, no matter the number of epochs or how carefully the hyperparameters are tuned, can never fully guarantee the global optimum, but it is still far better than a random guess.

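As a sketch, a second constraint simply adds a second multiplier to the Lagrangian. The small problem below is illustrative rather than taken from the original post: minimize x² + y² + z² subject to x + y + z = 1 and x − y = 1, whose solution works out to (5/6, −1/6, 1/3) with multipliers 2/3 and 1.

import jax
import jax.numpy as jnp

# Lagrangian with one multiplier per constraint (illustrative problem):
# minimize x^2 + y^2 + z^2  s.t.  x + y + z = 1  and  x - y = 1
def lagrangian(v):
    x, y, z, l1, l2 = v
    f = x ** 2 + y ** 2 + z ** 2
    g1 = x + y + z - 1.0
    g2 = x - y - 1.0
    return f - l1 * g1 - l2 * g2

grad_L = jax.grad(lagrangian)
hess_L = jax.jacfwd(grad_L)

v = jnp.zeros(5)          # initial guess for (x, y, z, lambda1, lambda2)
for _ in range(20):
    v = v - jnp.linalg.solve(hess_L(v), grad_L(v))

print(v)   # -> approximately [0.833, -0.167, 0.333, 0.667, 1.0]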
