JAX: Fast as PyTorch, Simple as NumPy
JAX is a new competitor to TensorFlow and PyTorch. JAX emphasises simplicity without sacrificing speed and scalability. Because JAX requires less boilerplate code, programs are shorter, closer to the math, and therefore easier to understand.
TL;DR:
- 🐍 Access NumPy functions using `import jax.numpy` and SciPy functions with `import jax.scipy`.
- 🔥 Speed up with just-in-time compilation by decorating with `@jax.jit`.
- ∇ Take derivatives using `jax.grad`.
- ➡️ Vectorise with `jax.vmap` and parallelise across devices with `jax.pmap`.
This post is a shortened version of a talk I gave last spring at PyGrunn 11, one of the largest Python conferences in Europe.
Functional Programming
JAX follows a functional programming philosophy. This means that your functions must be self-contained, or pure: side effects are not allowed. Essentially, a pure function looks like a mathematical function (Fig. 1). Input comes in, something comes out, but there is no communication with the outside world.
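To see why purity matters, here is a small sketch (the names `calls` and `square` are hypothetical, chosen for illustration). A Python side effect inside a jitted function runs only while JAX traces the function, not on every call, so any communication with the outside world silently disappears after the first call.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical log, used only to observe the side effect

@jax.jit
def square(x):
    calls.append("traced")  # Python side effect: executes during tracing only
    return x ** 2

first = square(jnp.array(2.0))   # traced and compiled here, so "traced" is logged
second = square(jnp.array(3.0))  # same shape and dtype: compiled code is reused, nothing is logged
```

After both calls, `calls` holds a single entry even though `square` was called twice.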
❌ Example #1
The following snippet is an example that is not functionally pure.
```python
import jax.numpy as jnp

bias = jnp.array(0)

def impure_example(x):
    total = x + bias
    return total
```
Notice that `bias` is defined outside `impure_example`. During compilation (see below), the value of `bias` is baked in when the function is traced, so the compiled function no longer reflects later changes to `bias`.
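A minimal sketch of this pitfall, assuming `impure_example` is wrapped in `jax.jit`: after the global `bias` is rebound, a second call with an argument of the same shape and dtype reuses the cached compiled function and still sees the old value.

```python
import jax
import jax.numpy as jnp

bias = jnp.array(0)

@jax.jit
def impure_example(x):
    # `bias` is read from the enclosing (global) scope instead of being passed in.
    total = x + bias
    return total

before = impure_example(1.0)  # traced here: bias == 0 is captured as a constant

bias = jnp.array(100)         # rebind the global
after = impure_example(1.0)   # cache hit: the compiled code still adds 0, not 100
```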
✅ Example #2
Here is an example that is pure.
```python
def pure_example(x, weights, bias):
    activation = weights @ x + bias
    return activation
```
Here, `pure_example` is self-contained: all parameters are passed as arguments.
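Because every input is an explicit argument, the pure version is safe to compile. In this sketch (the identity `weights` and the input `x` are illustrative values, not from the post), a new `bias` is picked up on every call, even though the compiled function is reused.

```python
import jax
import jax.numpy as jnp

@jax.jit
def pure_example(x, weights, bias):
    # All inputs arrive as arguments, so nothing can go stale between calls.
    activation = weights @ x + bias
    return activation

weights = jnp.eye(2)        # illustrative 2x2 identity matrix
x = jnp.array([1.0, 2.0])

out_a = pure_example(x, weights, jnp.array(0.0))
out_b = pure_example(x, weights, jnp.array(100.0))  # same shapes: no retrace, new bias is used
```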
🎰 Deterministic Samplers
A deterministic algorithm cannot produce true randomness. Instead, libraries such as NumPy and TensorFlow track a pseudo-random number generator state to generate “random” samples. A direct consequence of functional programming is that random functions work differently in JAX. Since global state is not allowed, you need to explicitly pass in a pseudo-random number generator (PRNG) key every time you sample a random number (Fig. 2).
```python
import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)
```
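Sampling again with the same key returns the same number, so you are responsible for advancing the “random state” yourself for any subsequent calls, by splitting the key into fresh subkeys: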
```python
key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# ...
```
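The sketch below makes both points concrete (the seed `0` and the number of subkeys are arbitrary choices). Reusing a key reproduces the same sample, and `jax.random.split` can also produce several subkeys in one call through its `num` argument, which avoids splitting repeatedly in a loop.

```python
import jax

key = jax.random.PRNGKey(0)

# Same key in, same number out: sampling is a pure function of the key.
a = jax.random.uniform(key)
b = jax.random.uniform(key)

# Split once into a new carry key plus three independent subkeys.
key, *subkeys = jax.random.split(key, num=4)
samples = [jax.random.uniform(k) for k in subkeys]
```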
🔥 jit
You can speed up your code by just-in-time compiling your JAX instructions. For example, to compile a scaled exponential linear unit (SELU) function, use the NumPy functions from `jax.numpy` and add the `jax.jit` decorator to the function as follows:
```python
import jax.numpy as jnp
from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
    return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)
```
Under the hood, JAX traces your instructions and converts them into an intermediate representation called a jaxpr. This allows the XLA (accelerated linear algebra) compiler to generate efficient, optimised code for your accelerator.
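You can inspect the traced representation yourself with `jax.make_jaxpr`. This sketch defines `selu` without the decorator so that the plain and compiled versions can be compared (the name `selu_jit` and the test input are illustrative). The first call to the compiled function triggers tracing and compilation, and later calls with the same shape and dtype reuse the compiled code.

```python
import jax
import jax.numpy as jnp

def selu(x, α=1.67, λ=1.05):
    return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

selu_jit = jax.jit(selu)

x = jnp.arange(-2.0, 3.0)  # [-2., -1., 0., 1., 2.]

# Print the jaxpr that JAX traces from the Python function.
print(jax.make_jaxpr(selu)(x))

# JAX dispatches asynchronously; block_until_ready() waits for the result,
# which matters when timing compiled code.
y = selu_jit(x).block_until_ready()
```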
∇ grad
One of the most powerful features of JAX is how easily you can take gradients. `jax.grad` transforms a function into a new function that computes its derivative using automatic differentiation.
```python
from jax import grad

def f(x):
    return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))
```
As the example shows, you’re not limited to first-order derivatives: you can take the n-th order derivative by nesting `grad` n times.
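Evaluating these functions confirms the calculus: for f(x) = x + ½x², we expect f′(x) = 1 + x, f″(x) = 1 and f‴(x) = 0. The helper `nth_derivative` below is a hypothetical convenience, not part of JAX; it simply applies `jax.grad` n times in a loop.

```python
import jax

def f(x):
    return x + 0.5 * x**2  # f'(x) = 1 + x, f''(x) = 1, f'''(x) = 0

df_dx = jax.grad(f)
d2f_dx2 = jax.grad(jax.grad(f))

def nth_derivative(fun, n):
    # Hypothetical helper: wrap `fun` in jax.grad n times.
    for _ in range(n):
        fun = jax.grad(fun)
    return fun

d3f_dx3 = nth_derivative(f, 3)
```

Pass floating-point inputs such as `2.0`: `jax.grad` raises an error for integer inputs.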
➡️ vmap and pmap
Getting all the batch dimensions right in matrix multiplications can require serious mental gymnastics. JAX’s vectorising map, `vmap`, alleviates this burden by vectorising your function for you. Essentially, any code that applies a function `f` to each example in a batch, for instance in a Python loop, is a candidate to be replaced by `vmap`. Let’s look at an example.
To compute the linear function:
```python
def linear(x):
    return weights @ x
```
across a batch of examples [x₁, x₂, …], we could naively (without `vmap`) implement it as follows:
```python
def naively_batched_linear(X_batched):
    return jnp.stack([linear(x) for x in X_batched])
```
Instead, by vectorising `linear` with `vmap`, we can compute the entire batch in one go:
```python
from jax import vmap

def vmap_batched_linear(X_batched):
    return vmap(linear)(X_batched)
```
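Here is a self-contained sketch that checks both versions agree. The shapes (8 examples, 4 inputs, 3 outputs) and random weights are assumptions for illustration. It also shows a purer variant, `linear_pure` (a hypothetical name), where `weights` is passed explicitly and `in_axes=(None, 0)` tells `vmap` to broadcast `weights` while mapping over the leading axis of `X_batched`.

```python
import jax
import jax.numpy as jnp

# Illustrative weights: 3 outputs, 4 inputs.
weights = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
X_batched = jax.random.normal(jax.random.PRNGKey(1), (8, 4))  # batch of 8 examples

def linear(x):
    return weights @ x  # x: (4,) -> result: (3,)

naive = jnp.stack([linear(x) for x in X_batched])  # Python loop over examples
vectorised = jax.vmap(linear)(X_batched)           # one batched call

def linear_pure(weights, x):
    return weights @ x

# None: do not map over weights; 0: map over the first axis of X_batched.
vectorised_pure = jax.vmap(linear_pure, in_axes=(None, 0))(weights, X_batched)
```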
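Bonus: to distribute your workload across accelerators, you can play the same game: replace `vmap` with `pmap` and your computation will be split across multiple devices. Note that `pmap` maps over the leading axis of its input, and the size of that axis must not exceed the number of available devices.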
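A minimal sketch of `pmap`, assuming nothing about your hardware: it sizes the leading axis with `jax.local_device_count()`, so it also runs on a machine with a single CPU device (where there is simply one shard). The squaring function is an arbitrary example.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# One row of 4 values per device; the leading axis is the one pmap splits.
xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)

out = jax.pmap(lambda x: x ** 2)(xs)  # each device squares its own row
```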
To learn more, I recommend Laurence Moroney’s introductory video. For further reading, take a look at the JAX docs.