A bird’s-eye view of Google JAX

Abhijit Gupta
Published in Geek Culture
4 min read · Jun 13, 2021


If I were to summarize what Google JAX is, I would say it’s a heterogeneous mixture of functional programming (FP) style and differentiable NumPy operations, running on accelerators.

The familiarity of NumPy combined with FP is what makes it special. Its side-effect-free way of doing things makes it safe, so to speak: arrays are immutable, so in-place modification is not allowed. Some might think that this hinders performance, but that is usually not the case, because inside compiled code the compiler can often turn functional updates into in-place ones. JAX also offers asynchronous dispatch: control returns to the user before the computation has completed. Essentially, we get a future-like object (a promise) that is decoupled from the computation that produces it. This paradigm facilitates flexibility and distributed computing. Another highlight is just-in-time compilation (jit), which lets us compile multiple operations together using XLA (Accelerated Linear Algebra, an optimizing compiler for linear algebra). Besides that, we get a vectorized map via the vmap API.
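To make immutability, jit and asynchronous dispatch a bit more concrete, here is a minimal sketch. The function name `scaled_sum` and the tiny arrays are my own illustrative choices, not part of any JAX API.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)  # [0., 1., 2., 3., 4.]

# JAX arrays are immutable: `x[0] = 10.0` raises a TypeError.
# The functional alternative returns a new array and leaves `x` untouched.
y = x.at[0].set(10.0)

@jax.jit  # trace the function once, then compile it as a whole with XLA
def scaled_sum(a, b):
    # hypothetical example function: several operations compiled together
    return jnp.sum(2.0 * a + b)

# Dispatch is asynchronous: this call may return before the value is ready.
result = scaled_sum(x, y)

# block_until_ready() waits for the computation to finish, which mostly
# matters when timing code.
result = result.block_until_ready()
print(float(result))  # 40.0
```

Converting the result with `float(...)` or printing it also forces the computation to finish, so the explicit wait is rarely needed outside of benchmarks.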

Let’s talk about the grad function that JAX offers. From an FP viewpoint,

grad :: Differentiable f => f -> f’

That is, for a differentiable function f, we get its gradient function. grad(f) is the function that computes the gradient, and grad(f)(x) is the gradient of f evaluated at x. To illustrate how one can use grad and vmap together, here’s a simple function:

plot_func_and_deriv(lambda x: x**3)

jax.grad(f)(x) is the gradient of f evaluated at x.
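The plotting helper itself is not defined in this article, so here is a small sketch of the values such a helper would presumably draw. The names `cube`, `dcube` and `d2cube` are assumptions made for illustration.

```python
import jax.numpy as jnp
from jax import grad, vmap

def cube(x):
    return x ** 3

# grad(cube) differentiates a scalar-valued function at a single point;
# vmap lifts that scalar derivative to a whole batch of points.
dcube = vmap(grad(cube))

# Transformations compose, so the second derivative is grad applied twice.
d2cube = vmap(grad(grad(cube)))

xs = jnp.linspace(-2.0, 2.0, 5)  # [-2., -1., 0., 1., 2.]
print(cube(xs))    # function values, x**3
print(dcube(xs))   # first derivative, 3 * x**2
print(d2cube(xs))  # second derivative, 6 * x
```

Calling `grad(cube)` directly on `xs` would fail, because grad expects a scalar output; wrapping it in vmap is what makes the batched call work.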

The ubiquitous map is defined as:

map :: (a -> b) -> [a] -> [b]

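The augmented vmap takes it a step further and gives us the benefits of auto-vectorization: we write a function for a single example and vmap maps it over a batch axis. By default, axis 0 of every argument is mapped over (in_axes=0) and the batch appears on axis 0 of the output (out_axes=0). Some use cases for vmap are: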

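import jax.numpy as jnp
from jax import random, vmap
key = random.PRNGKey(0)
mat_key, x_key, xt_key = random.split(key, 3)  # separate keys so the arrays are independent
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(x_key, (10, 100))
def apply_matrix(v):
    return jnp.dot(mat, v)  # dot, not vdot: vdot flattens both arguments
# Each call below maps a matrix-vector product over a batch of vectors
vmap(lambda mat, v: jnp.dot(mat, v), (None, 0))(mat, batched_x)  # (10, 150)
vmap(lambda v: jnp.dot(mat, v), 0)(batched_x)  # (10, 150)
vmap(lambda v: jnp.dot(mat, v), 1, 0)(random.normal(xt_key, (100, 10)))  # batch on input axis 1 -> (10, 150)
vmap(lambda v: jnp.dot(mat, v), 1, 1)(random.normal(xt_key, (100, 10)))  # batch on output axis 1 -> (150, 10)
vv = lambda v1, v2: jnp.vdot(v1, v2)
mv = vmap(vv, (0, None), 0)  # ([b, a], [a]) -> [b]
mm = vmap(mv, (None, 1), 0)  # normally (None, 0); here the batch sits on axis 1 because the input is transposed
mm(mat, batched_x.T)  # (10, 150)
vmap(mv, (None, 0), 0)(mat, batched_x)  # (10, 150)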
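The effect of in_axes and out_axes is easier to see on small arrays. The following is a sketch with made-up shapes (a 3x4 matrix and a batch of 5 vectors) that compares each vmap variant against an ordinary matrix product; `matvec` is a hypothetical helper name.

```python
import jax.numpy as jnp
from jax import random, vmap

key_m, key_x = random.split(random.PRNGKey(0))
mat = random.normal(key_m, (3, 4))        # maps a 4-vector to a 3-vector
batched_x = random.normal(key_x, (5, 4))  # 5 vectors, batch on axis 0

def matvec(v):
    # written for a single vector only; vmap adds the batch dimension
    return jnp.dot(mat, v)

rows_out = vmap(matvec, in_axes=0, out_axes=0)(batched_x)   # batch in on axis 0, out on axis 0
cols_in = vmap(matvec, in_axes=1, out_axes=0)(batched_x.T)  # batch in on axis 1, out on axis 0
cols_out = vmap(matvec, in_axes=0, out_axes=1)(batched_x)   # batch in on axis 0, out on axis 1

print(rows_out.shape, cols_in.shape, cols_out.shape)  # (5, 3) (5, 3) (3, 5)

# The same result computed without vmap, for comparison.
reference = batched_x @ mat.T
```

All three variants compute the same numbers as `reference`; only the position of the batch axis differs.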

A feature worth mentioning is that we can register our own custom data types by implementing the pytree interface. A pytree is a tree-like structure built out of container-like Python objects, with arrays as its leaves. Once a type is registered, JAX function transformations can be applied to functions that accept pytrees of arrays as input and produce them as output.

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Point:
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
    def __repr__(self):
        return f"Point({self.x}, {self.y}, {self.z})"
    def tree_flatten(self):
        # Returns (children, aux_data): the leaves and any static metadata
        return ((self.x, self.y, self.z), None)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilds a Point from the flattened leaves
        return cls(*children)

Now, we can define arbitrary functions that operate on our data type and make them differentiable:

import jax.numpy as jnp
from jax import grad
def dist_orig(pt: Point):
    return jnp.sqrt(pt.x**2 + pt.y**2 + pt.z**2)
grad(dist_orig)(Point(1., 2., 3.))  # returns a Point holding the partial derivatives
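Built-in containers such as dicts, lists and tuples are already pytrees, so the same idea works without any registration. Here is a minimal sketch; the `params` dict, the toy `loss` and the step size 0.1 are assumptions chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# A nested container of arrays is a pytree out of the box.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

def loss(p):
    # hypothetical toy loss over the whole parameter tree
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2

leaves = jax.tree_util.tree_leaves(params)  # flat list of the arrays in the tree
grads = jax.grad(loss)(params)              # same dict structure as params

# tree_map applies a function leaf by leaf across matching trees,
# here one gradient-descent step.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
print(updated)
```

The gradient comes back with the same structure as the input, which is exactly what happens with the registered Point class as well.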

Let’s talk about two fundamental operations: the Jacobian-vector product (JVP) and the vector-Jacobian product (VJP).

Jacobian-vector product

A JVP is the product of the Jacobian matrix of an operator with a given vector, i.e. the directional derivative of the operator along that vector. It captures crucial information about the local geometry of the input-output mapping of a deep neural network (DN), which is one of the main reasons behind its popularity. Unfortunately, JVPs can be computationally expensive for real-world DN architectures.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
# sigmoid is not defined in the article; this tanh-based form is an assumption
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)
# Linear logistic model
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)  # inputs is the data matrix
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.39],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# The loss is a scalar
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
from jax import jvp
# Isolate the function from the weight vector to the predictions
def f(W):
    return predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, primals=(W,), tangents=(v,))
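To check what jvp returns, one can compare it with the full Jacobian multiplied by the tangent vector. The sketch below uses a small toy function from R^3 to R^2 that I made up for this purpose; it is not the logistic model above.

```python
import jax.numpy as jnp
from jax import jacfwd, jvp

def f(x):
    # hypothetical toy function R^3 -> R^2, chosen only for illustration
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.1, 0.2, 0.3])

# One forward pass gives f(x) and J(x) @ v without building J.
y, jvp_out = jvp(f, (x,), (v,))

# The explicit 2x3 Jacobian, computed here only for comparison.
J = jacfwd(f)(x)
print(jvp_out)
print(J @ v)
```

Both printed vectors agree up to floating-point error, but the jvp call never materializes the Jacobian, which is the point of the operation.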

Vector-Jacobian products (VJPs) form the backbone of reverse-mode automatic differentiation.

jax.vjp(fun, *primals) returns a pair (primals_out, vjp_fun). Here vjp_fun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.

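from jax import vjp
y, vjp_fun = vjp(lambda W: predict(W, b, inputs), W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
(W_cotangent,) = vjp_fun(u)  # vjp_fun returns a tuple, one entry per primal
identity = jnp.eye(*y.shape, dtype=jnp.float32)
print("Recovering Jacobian elements row-wise!")
# Each standard basis covector picks out one row of the Jacobian
print(*(vjp_fun(row) for row in identity), sep="\n")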
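The row-by-row recovery can also be done in one shot by mapping vjp_fun over the identity matrix with vmap. The sketch below again uses a made-up toy function from R^3 to R^2 rather than the logistic model, and compares the result with jax.jacrev.

```python
import jax.numpy as jnp
from jax import jacrev, vjp, vmap

def f(x):
    # hypothetical toy function R^3 -> R^2, chosen only for illustration
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
y, vjp_fun = vjp(f, x)

# Pull back a single covector: u^T J, with the same shape as x.
u = jnp.array([1.0, -1.0])
(u_J,) = vjp_fun(u)

# Pull back every standard basis covector at once: each one yields a row of J.
basis = jnp.eye(y.shape[0])
(J_rows,) = vmap(vjp_fun)(basis)  # shape (2, 3)

print(J_rows)
print(jacrev(f)(x))  # reverse-mode Jacobian, for comparison
```

One backward pass per output is needed to build the whole Jacobian this way, which is why reverse mode suits functions with few outputs, such as scalar losses.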

With the examples above, I tried to illustrate the key highlights of JAX, including its FP-inspired approach, which encourages composability and results in clean code. Check out the official documentation if the article has stirred your interest.