JAX: Fast as PyTorch, Simple as NumPy

Hylke C. Donker
4 min read · Jul 5, 2023


JAX logo. Google LLC, Public domain, via Wikimedia Commons.

JAX is a new competitor to TensorFlow and PyTorch. It emphasises simplicity without sacrificing speed and scalability. Since JAX requires less boilerplate code, programs are shorter, closer to the math, and thus easier to understand.

TL;DR:

  • 🐍 Access NumPy functions using import jax.numpy and SciPy functions with import jax.scipy.
  • 🔥 Speed up with just-in-time compilation by decorating with @jax.jit.
  • ∇ Take derivatives using jax.grad.
  • ➡️ Vectorise with jax.vmap and parallelise across devices with jax.pmap.

This post is a shortened version of a talk I gave last spring at PyGrunn 11, one of the largest Python conferences in Europe.

Functional Programming

Fig. 1: A pure function looks like a mathematical function. Image by Author.

JAX follows a functional programming philosophy. This means that your functions must be self-contained, or pure: side effects are not allowed. Essentially, a pure function looks like a mathematical function (Fig. 1). Input comes in, something comes out, but there is no communication with the outside world.

❌ Example #1

The following snippet is an example that is not functionally pure.

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
    total = x + bias
    return total

Notice that bias lives outside impure_example. During just-in-time compilation (see below), the value of bias may be cached, so the compiled function no longer reflects changes to bias.
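To make the pitfall concrete, here is a minimal sketch (redefining impure_example with the jit decorator that is introduced below): once the function has been traced, the stale value of bias is baked into the compiled code.

import jax

@jax.jit
def impure_example(x):
    return x + bias

print(impure_example(1.0))  # 1.0: traced while bias == 0.
bias = jnp.array(10)
print(impure_example(1.0))  # Still 1.0: the cached trace ignores the new bias.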

✅ Example #2

Here is an example that is pure.

def pure_example(x, weights, bias):
    activation = weights @ x + bias
    return activation

Here, pure_example is self-contained: all parameters are passed as arguments.
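As a quick usage sketch (the shapes below are only an illustration), everything the function needs is passed in explicitly:

x = jnp.arange(3.0)           # Shape (3,).
weights = jnp.ones((2, 3))    # Shape (2, 3).
bias = jnp.zeros(2)           # Shape (2,).
activation = pure_example(x, weights, bias)  # Shape (2,).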

🎰 Deterministic Samplers

Fig. 2: Random functions require a pseudo-random number generator key. Image by Author.

In computers, true randomness does not exist. Instead, libraries such as NumPy and TensorFlow track a pseudo-random number state to generate “random” samples. A direct consequence of functional programming is that random functions work differently. Since a global state is no longer allowed, you need to explicitly pass in a pseudo-random number generator (PRNG) key every time you sample a random number (Fig. 2).

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)
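
Because sampling is a deterministic function of the key, re-using the same key reproduces the exact same sample:

u_again = jax.random.uniform(key)
assert u == u_again  # Same key, same sample.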

Moreover, you are responsible for advancing the “random state” for any subsequent calls.

key = jax.random.PRNGKey(43)
# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
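
If you need several samples, you can also split off multiple subkeys in one call; a small sketch:

# Split into one carry-over key and three fresh subkeys.
key, *subkeys = jax.random.split(key, num=4)
samples = [jax.random.uniform(subkey) for subkey in subkeys]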


🔥 jit

You can speed up your code by just-in-time compiling your JAX instructions. For example, to compile your scaled exponential linear unit (SELU) function, use the NumPy functions from jax.numpy and add the jax.jit decorator to the function as follows:

from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
    return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

Under the hood, JAX traces your instructions and converts them into jaxpr, JAX's intermediate representation. This allows the accelerated linear algebra (XLA) compiler to generate highly optimised code for your accelerator.
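You can inspect this intermediate representation yourself; a quick sketch using jax.make_jaxpr:

import jax

print(jax.make_jaxpr(selu)(3.0))  # Prints the jaxpr traced from selu.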

∇ grad

One of the most powerful features of JAX is that you can easily take gradients. With jax.grad, you construct a new function that computes the derivative of the original.

from jax import grad

def f(x):
    return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))

As you can see in the example, you’re not limited to first order derivatives. You can take the n-th order derivative by simply chaining the grad function n times in sequence.
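Evaluating these derivatives works like calling any other Python function. For f(x) = x + 0.5x², the first derivative is 1 + x and the second derivative is 1:

print(df_dx(1.0))    # 2.0
print(d2f_dx2(1.0))  # 1.0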

➡️ vmap and pmap

Matrix multiplications require serious mental gymnastics to get all the batch dimensions right. JAX’s vectorise-map function vmap alleviates this burden by vectorising your function. Basically, every code chunk that applies a function f element-wise is a candidate to be replaced by vmap. Let’s look at an example.

To compute the linear function:

def linear(x):
    return weights @ x

across a batch of examples [x₁, x₂,..], we could naively (without vmap) implement it as follows:

def naively_batched_linear(X_batched):
    return jnp.stack([linear(x) for x in X_batched])

Instead, by vectorising linear with vmap we can compute the entire batch in one go:

from jax import vmap

def vmap_batched_linear(X_batched):
    return vmap(linear)(X_batched)

Bonus: To distribute your workload across accelerators, you can play the same game: replace vmap with pmap and your computation will scale out across multiple devices.
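
As a quick sanity check (with hypothetical shapes: weights of shape (3, 4) and a batch of eight examples), both versions give the same result, but the vmap version avoids the Python loop:

weights = jnp.ones((3, 4))
X_batched = jnp.ones((8, 4))
assert jnp.allclose(naively_batched_linear(X_batched),
                    vmap_batched_linear(X_batched))  # Both have shape (8, 3).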

To learn more, I recommend Laurence Moroney’s introductory video. For further reading, take a look at the JAX docs.


Hylke C. Donker

Data scientist @ University Medical Centre Groningen (NL).