Object Oriented Programming in JAX with Haiku

Hylke C. Donker
2 min read · Jul 9, 2023


JAX helps you write short, simple, and blazing fast numerical code. Unlike TensorFlow and PyTorch, JAX adopts a functional programming style: functions passed to its transformations (such as jax.jit and jax.grad) must be “pure”, meaning their output depends only on their inputs and they have no side effects. In practice, we still want to re-use stateful components, such as neural network layers and their parameters. DeepMind’s Haiku marries the functional programming world of JAX with object-oriented programming (OOP). In short, Haiku lets you mix and match modules, OOP-style, and then convert the result into a pure functional program.

On top of that, Haiku ships with a library of neural network components: from simple convolutions and attention modules to complete ResNet architectures (such as hk.nets.ResNet50). And finally, my personal favourite: Haiku comes with a pseudo random number generator (PRNG) key iterator (haiku.PRNGSequence) that alleviates the key splitting burden.

Workflow

✨✨Three simple steps to enlightenment✨✨

transform: First, write a function containing Haiku modules. You then purify your function by decorating it with haiku.transform as follows:

import haiku as hk

@hk.transform
def forward(x):
    neural_net = hk.nets.MLP([300, 100, 10])
    return neural_net(x)

Here, we created a vanilla neural network (multi-layer perceptron, or MLP) and used x to make a prediction. Technically, forward is now a Transformed instance holding two pure functions: Transformed.init and Transformed.apply.

🔢init: Your neural network in forward has lots of parameters. Haiku needs to make one pass through your function to track and initialise the parameters. You do that by calling the init method of your transformed function with a pseudo random number generator key.

import jax.numpy as jnp

key_seq = hk.PRNGSequence(42)
x = jnp.ones([8, 784])  # dummy batch (illustrative shape): 8 examples, 784 features
params = forward.init(next(key_seq), x)

This returns a pytree, params, holding the initial weights and biases of your purified function, with one set for each layer. Notice that we used next(key_seq) to generate a new key from Haiku’s pseudo random number generator key sequence.

🔨apply: This method is the pure counterpart of your original forward function. Now that you have the accompanying params, you can call the function on a batch x as follows:

logits = forward.apply(params, next(key_seq), x)

That’s it, you’re done!

For further reading, I recommend the Haiku basics docs, which give a brief intro to the library.


Hylke C. Donker

Data scientist @ University Medical Centre Groningen (NL).