# 1 Introduction

What is JAX? As the main JAX webpage describes it, JAX is Autograd and XLA, brought together for high-performance machine learning research. In essence, JAX augments the NumPy API into a new library with automatic differentiation (grad), vectorized mapping (vmap), and just-in-time compilation (jit), all compiled with Accelerated Linear Algebra (XLA), with Tensor Processing Unit (TPU) support and much more. With these features, problems that depend on linear algebra and matrix methods can be solved more efficiently. The purpose of this article is to show that these features can indeed be used to solve a range of simple to complex optimization problems with matrix methods, and to provide an intuitive understanding of the mathematics and implementation behind the code.

# 2 Grad, Jacobians and Vmap

grad takes the automatic derivative of a function: given a function f, grad(f) returns a new function that evaluates the gradient of f at a point. Because the result is itself a function, the call composes: grad(grad(f)) returns a function that evaluates the second derivative.
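A minimal sketch of grad, its composition, and vmap (the function f and the sample points here are illustrative, not from the original):

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3  # f(x) = x^3

df = jax.grad(f)              # f'(x) = 3x^2
d2f = jax.grad(jax.grad(f))   # f''(x) = 6x

print(df(2.0))   # -> 12.0
print(d2f(2.0))  # -> 12.0

# vmap applies the derivative function over a whole batch of inputs
# without writing a Python loop.
xs = jnp.array([1.0, 2.0, 3.0])
print(jax.vmap(df)(xs))  # -> [ 3. 12. 27.]
```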

# 3 Single Variable Optimization

Let’s imagine that we have the following optimization problem, from UC Davis: a rectangular piece of paper is 12 inches long and 6 inches wide. The lower right-hand corner is folded over so as to reach the leftmost edge of the paper. Find the minimum length L of the resulting crease.
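A gradient-descent sketch for this problem, assuming the standard parameterization (not derived in the text): if x is the distance from the lower-left corner to the fold point along the bottom edge, the squared crease length for a 6-inch-wide sheet is L²(x) = 2x³/(2x − 6) on 3 < x ≤ 6, and minimizing L² gives the same x as minimizing L:

```python
import jax
import jax.numpy as jnp

# Squared crease length under the assumed parameterization.
def crease_sq(x):
    return 2.0 * x ** 3 / (2.0 * x - 6.0)

dcrease = jax.grad(crease_sq)

x = 5.0    # initial guess inside the valid range (3, 6]
lr = 0.01  # learning rate
for _ in range(500):
    x = x - lr * dcrease(x)  # gradient-descent update

print(x)                       # -> approx 4.5
print(jnp.sqrt(crease_sq(x)))  # minimum crease length, approx 7.794
```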

## 3.2 Newton’s Method

The same problem can be solved using Newton’s Method. Usually, Newton’s Method is used for solving a function that is equal to zero, such as f(x) = x² − 2 = 0, by repeatedly applying updates of the form:

x_{n+1} = x_n − f(x_n)/f′(x_n)
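A minimal Newton’s-method sketch for the example f(x) = x² − 2, using jax.grad to supply the derivative f′ automatically:

```python
import jax

def f(x):
    return x ** 2 - 2.0

df = jax.grad(f)  # f'(x) = 2x, computed automatically

x = 1.0  # initial guess
for _ in range(10):
    x = x - f(x) / df(x)  # Newton update

print(x)  # -> approx sqrt(2) = 1.41421...
```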

# 4 Multivariable Optimization

## 4.1 The Jacobian

In multivariable problems we define functions such that f(X), X = [x0, x1, x2, …, xn]. The Jacobian (∇f) collects the first partial derivatives of f with respect to each of those variables.
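A small sketch of computing a Jacobian with JAX (the vector-valued function here is an illustrative assumption):

```python
import jax
import jax.numpy as jnp

# A vector-valued function of X = [x0, x1].
def f(X):
    return jnp.array([X[0] ** 2 + X[1], X[0] * X[1]])

J = jax.jacobian(f)  # function that evaluates the full Jacobian matrix

X = jnp.array([2.0, 3.0])
print(J(X))
# Analytically [[2*x0, 1], [x1, x0]], so at (2, 3):
# -> [[4. 1.]
#     [3. 2.]]
```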

## 4.2 The Hessian

In multivariable problems we define functions such that f(X), X = [x0, x1, x2, …, xn]. Previously we defined the Jacobian (∇f). The Hessian is just the Jacobian of the gradient, ∇(∇f) or ∇²f, which requires differentiating each entry of the Jacobian with respect to every variable, thus increasing the dimension by one: a matrix of second derivatives rather than a vector of first derivatives.
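The “Jacobian of the gradient” construction can be written directly as a composition in JAX (the scalar function below is an illustrative assumption):

```python
import jax
import jax.numpy as jnp

def f(X):
    return X[0] ** 2 * X[1] + X[1] ** 3

# The Hessian is the Jacobian of the gradient: forward-over-reverse
# differentiation; jax.hessian(f) is equivalent.
hess = jax.jacfwd(jax.jacrev(f))

X = jnp.array([1.0, 2.0])
print(hess(X))
# Analytic Hessian [[2*x1, 2*x0], [2*x0, 6*x1]], so at (1, 2):
# -> [[ 4.  2.]
#     [ 2. 12.]]
```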

# 5 Multivariable Constrained Optimization

Multivariable constrained optimization uses the same techniques as before but requires a different way of framing the problem. For constrained optimization we will use Lagrange multipliers. The classic Lagrange condition requires solving ∇f = λ∇g. However, a computer has no way of solving this symbolically. Instead, we can rewrite it as ∇f − λ∇g = 0, i.e. look for stationary points of the Lagrangian L(x, λ) = f(x) − λg(x), which turns the constrained problem into an unconstrained root-finding problem.
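One way to make this concrete is to apply Newton’s method to the system ∇L = 0. The specific objective and constraint below (f(x, y) = x² + y² subject to x + y = 1) are illustrative assumptions, not from the original:

```python
import jax
import jax.numpy as jnp

def f(v):
    return v[0] ** 2 + v[1] ** 2

def g(v):
    return v[0] + v[1] - 1.0  # constraint g(x, y) = 0

# Lagrangian in the combined variable p = [x, y, lam].
def lagrangian(p):
    v, lam = p[:2], p[2]
    return f(v) - lam * g(v)

grad_L = jax.grad(lagrangian)   # stationarity conditions, grad L = 0
hess_L = jax.jacobian(grad_L)   # Jacobian of that system, for Newton steps

p = jnp.zeros(3)
for _ in range(5):
    p = p - jnp.linalg.solve(hess_L(p), grad_L(p))

print(p)  # -> [0.5 0.5 1. ]  (x = y = 0.5, lam = 1)
```

Because f is quadratic and g is linear here, the system ∇L = 0 is linear and Newton’s method lands on the answer in a single step; for general f and g it would take several iterations.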

# 6 Three Variable Multivariable Constrained Optimization

Problems in real life usually have more than two variables to be optimized, and optimization hyperparameters need to be fine-tuned. As the complexity of optimization problems increases, other methods should be considered. For now, we can use the models from the previous section and simply increase the number of variables. Luckily, JAX adjusts for this automatically; we only need to change the L function in the code.
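Extending the same Lagrangian pattern to three variables only means changing the L function and the size of p (the objective f = x² + y² + z² with constraint x + y + z = 1 is again an illustrative assumption):

```python
import jax
import jax.numpy as jnp

def f(v):
    return v[0] ** 2 + v[1] ** 2 + v[2] ** 2

def g(v):
    return v[0] + v[1] + v[2] - 1.0  # constraint g(x, y, z) = 0

# p = [x, y, z, lam]; only this function and the size of p change.
def lagrangian(p):
    v, lam = p[:3], p[3]
    return f(v) - lam * g(v)

grad_L = jax.grad(lagrangian)
hess_L = jax.jacobian(grad_L)

p = jnp.zeros(4)
for _ in range(5):
    p = p - jnp.linalg.solve(hess_L(p), grad_L(p))

print(p)  # -> approx [0.333 0.333 0.333 0.667]
```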

# 7 Multivariable MultiConstrained Optimization

In the final part of this tutorial we will look at one of the most advanced types of optimization problems: multivariable multiconstrained optimization. Some of the problems, in the beginning, are admittedly better solved by hand. However, as complexity increases, other numerical methods might be needed. Gradient descent, no matter how many epochs it runs for or how carefully its hyperparameters are tuned, can never guarantee the best result with 100% certainty, but it is always better than a random guess.
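With multiple constraints, the Lagrangian simply gains one multiplier per constraint. A sketch under assumed data (f = x² + y² + z² subject to x + y = 1 and y + z = 1; none of these come from the original):

```python
import jax
import jax.numpy as jnp

def f(v):
    return v[0] ** 2 + v[1] ** 2 + v[2] ** 2

def g1(v):
    return v[0] + v[1] - 1.0  # first constraint

def g2(v):
    return v[1] + v[2] - 1.0  # second constraint

# One multiplier per constraint: p = [x, y, z, lam1, lam2].
def lagrangian(p):
    v, lam1, lam2 = p[:3], p[3], p[4]
    return f(v) - lam1 * g1(v) - lam2 * g2(v)

grad_L = jax.grad(lagrangian)
hess_L = jax.jacobian(grad_L)

p = jnp.zeros(5)
for _ in range(5):
    p = p - jnp.linalg.solve(hess_L(p), grad_L(p))

print(p)  # -> approx [0.333 0.667 0.333 0.667 0.667]
```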

Written by


## Mazeyar Moeini

#### Computer Science Student | Deep Learning and Mathematics Enthusiast | Learner 