Plugging Into JAX

Trying the new new machine learning framework

Nick Doiron
The Startup
5 min read · Oct 5, 2020

In November 2018, a new library appeared on Google’s GitHub: JAX. In 2020 there’s been a jump in popularity, and rumors that Googlers prefer it to TensorFlow v2. Functions written in JAX run on multiple-GPU or TPU systems without awkward helper libraries and without moving data out of device memory.

As a developer who usually uses AutoKeras and Transformers to solve neural network problems, I’m cautiously optimistic about JAX. On its own it’s too low-level, closer to NumPy. If you try out one of the frameworks that add neural network support, you find a growing community, but not as many examples or StackOverflow threads as for PyTorch and TensorFlow.
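To give a sense of what "closer to NumPy" means, here is a minimal sketch (my own illustration, not taken from any of the frameworks below). The function `mean_squared` is a made-up example; `jax.grad` and `jax.jit` are the core JAX transformations for differentiation and compilation.

```python
import jax
import jax.numpy as jnp


def mean_squared(x):
    # Written like NumPy code, but operating on jax.numpy arrays
    return jnp.mean(x ** 2)


# jax.grad returns a new function that computes d(mean_squared)/dx
grad_fn = jax.grad(mean_squared)

# jax.jit compiles that function with XLA for whatever device is available
fast_grad_fn = jax.jit(grad_fn)

x = jnp.array([1.0, 2.0, 3.0])
value = mean_squared(x)
gradient = fast_grad_fn(x)  # analytically, this is 2 * x / len(x)
```

There is no layer, model, or optimizer concept here, which is the gap the frameworks in this post try to fill.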

For people who believe in a JAX future, this gap is an opportunity to build out the next big thing!

For this project, I tried solving a flower-classification challenge from Kaggle using four JAX frameworks: Flax/Linen, Haiku, Objax, and Elegy. These frameworks come from Google, DeepMind, and poets-ai.

[Figure: training loop of a neural network]

Neural Net Notes

If you’re unfamiliar with a training loop, each step of training the network involves:

  • input training data
  • make a prediction with network
  • use loss to compare the prediction and reality (this function can simply measure difference, or weight losses depending on seriousness of mistake)
  • use optimizer to correct model (there are different strategies based on stage of training, amount of loss, etc.)

This process repeats for every batch and every epoch of training.
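The steps above can be sketched in plain JAX, without any of the frameworks. This is a toy example under my own assumptions: a linear model on synthetic data, mean squared error as the loss, and plain gradient descent as the optimizer. The names `predict`, `loss_fn`, and `train_step` are hypothetical, not from any library.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Make a prediction with a tiny linear "network"
    w, b = params
    return x @ w + b


def loss_fn(params, x, y):
    # Loss: mean squared difference between prediction and reality
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y, learning_rate):
    # Optimizer: one step of plain gradient descent
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


# Synthetic training data (an assumption for this sketch)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = (jnp.zeros(3), jnp.array(0.0))
batch_size = 16
first_loss = None

for epoch in range(100):                            # every epoch
    for start in range(0, x.shape[0], batch_size):  # every batch
        xb = x[start:start + batch_size]
        yb = y[start:start + batch_size]
        params, loss = train_step(params, xb, yb, 0.1)
        if first_loss is None:
            first_loss = float(loss)

final_loss = float(loss_fn(params, x, y))
```

The frameworks below wrap these same pieces (model, loss, optimizer, loop) in different amounts of structure.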

Flax (and Linen)

Flax is a high-performance neural network library for JAX that is designed for flexibility

I believe that Flax was the first JAX framework to become publicly available. After I started tweeting about my project, a developer recommended switching to their beta API, called Linen. I took their ImageNet example and made a few minor changes to run it on Colab with the Imagenette dataset. The ImageNet dataset is too big for Colab and has many classes (1,000), so the 10-class Imagenette and its spinoff projects (such as Imagewoof) are easier to play around with.

[Figure: a template network in Flax/Linen]

I then made a few more changes to load the flowers challenge in the TFRecord format provided on Kaggle.

You might wonder: if JAX is an alternative to TensorFlow, why am I still importing TensorFlow and loading data in TFRecord / tf.data.Dataset formats?

  • A lot of datasets / benchmarks are available in this semi-standard format (there are some annoying differences, but what can you do)
  • Datasets can be divided into batches and managed so we don’t try to load the full dataset into RAM at one time
  • Image datasets have tools for cropping, mirroring, coloring, masking, etc. to augment training data and make the model more flexible.

These tools exist in the PyTorch ecosystem, too, but JAX and TensorFlow both come from Google, so the TensorFlow data tools are the closer fit.

Haiku

Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX’s pure function transformations.

Haiku’s ImageNet example shows how it includes ResNet101 as one of its standard building blocks. I successfully adapted their example for Imagenette and the flower classification dataset.

One particularly weird difference in Haiku (or at least in the way this example was configured) is that I needed to set the number of training, test, and validation examples, instead of passing separate datasets or setting a percentage split.

The validation set needed to be a multiple of the evaluation batch size, and Imagenette v0.1.0 has only 500 test images, so I kept getting errors on the eval step, after the train step had used the same code successfully.
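One way around the "multiple of the batch size" constraint is to trim the evaluation count before configuring the run. This is a minimal sketch of that arithmetic; `usable_eval_examples` is a hypothetical helper of mine, not part of Haiku, and the batch size of 128 is an assumption chosen only for illustration.

```python
def usable_eval_examples(num_examples, batch_size):
    """Largest count <= num_examples that splits evenly into full batches."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return (num_examples // batch_size) * batch_size


# 500 evaluation images with a (hypothetical) eval batch size of 128
trimmed = usable_eval_examples(500, 128)   # 3 full batches
dropped = 500 - trimmed                    # images left out of evaluation
```

The trade-off is that the leftover images are never evaluated; picking a batch size that divides the set evenly (for example 100 for 500 images) avoids dropping any.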

This brings up another issue: errors! The JAX/NVIDIA talk explains that JAX’s JIT compiler should accept good code and not silently fail or act strangely when you mess up. In my short experience with these frameworks, this meant lengthy error stacks saying ‘left hand side does not match right hand side’ or ‘pmap got arg of rank _’, where a more mature framework might have said ‘you put in the wrong number of classes’ or ‘expected dimensions: [devices, batch_size, height, width, colors]’ to avoid confusion and dread.
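For reference, here is a sketch of how a batch ends up in the [devices, batch_size, height, width, colors] layout that `jax.pmap` expects, with the kind of friendlier error message I was hoping for. `shard_batch` is a hypothetical helper written for this post, not something from Haiku or JAX, and the image size is arbitrary.

```python
import jax
import jax.numpy as jnp


def shard_batch(images):
    """Reshape [batch, H, W, C] into [devices, batch_per_device, H, W, C]."""
    n_devices = jax.local_device_count()
    batch_size = images.shape[0]
    if batch_size % n_devices != 0:
        raise ValueError(
            f"batch size {batch_size} cannot be split evenly across "
            f"{n_devices} devices; expected dimensions: "
            "[devices, batch_size, height, width, colors]"
        )
    per_device = batch_size // n_devices
    return images.reshape((n_devices, per_device) + images.shape[1:])


images = jnp.zeros((8, 32, 32, 3))
sharded = shard_batch(images)

# pmap maps the function over the leading (device) axis
per_device_mean = jax.pmap(lambda x: jnp.mean(x))(sharded)
```

On a single-device machine the leading axis is just 1, so the same code runs on a laptop CPU and on a multi-GPU or TPU host.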

Objax

Objax is designed by researchers for researchers with a focus on simplicity and understandability.

Objax is a new-ish framework and the first that I tried. I was able to read their ImageNet example and directly port it over to the flowers challenge without practicing on Imagenette.

There was a puzzling step where I needed to transpose images from CHW to HWC [height, width, color_channels] order, and StackOverflow answers had me involve TensorFlow. It’s not clear to me where I can bring standard TensorFlow, NumPy, or other operations into my code without losing the performance that JAX promises. For the time being, I’m happy the code runs.
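In hindsight, the transpose itself can probably stay inside JAX. This is a small sketch using `jnp.transpose` on a batch of images; `chw_to_hwc` is my own hypothetical helper, and I am assuming a leading batch axis, which may not match how the Objax example feeds its data.

```python
import jax.numpy as jnp


def chw_to_hwc(images):
    """Convert a batch from [N, C, H, W] to [N, H, W, C] order."""
    # The tuple lists, for each output axis, which input axis it comes from
    return jnp.transpose(images, (0, 2, 3, 1))


# A fake batch of 4 images, 3 channels, 32 pixels high, 48 pixels wide
batch_nchw = jnp.arange(4 * 3 * 32 * 48).reshape((4, 3, 32, 48))
batch_nhwc = chw_to_hwc(batch_nchw)
```

Because this is a `jax.numpy` call, it can sit inside a jitted function and be compiled with the rest of the model, rather than running as a separate TensorFlow step.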

Elegy

Elegy is a Neural Networks framework based on Jax inspired by Keras.

This is the newest of the pack, and a creation of poets-ai. This post initially skipped over Elegy, but the repo has since been revamped with new examples and ResNet modules. Their Imagenet example was easy to adapt to an Imagenette notebook. Here we can see:

  • how loss is measured (SparseCategoricalCrossentropy)
  • building an optimizer for the specific problem
  • a wrapper tfds2jax_generator to handle the awkwardness between TensorFlow datasets / batching and the JAX loops.

[Figure: code copied from the Elegy Imagenet example]
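To show what a sparse categorical cross-entropy loss measures, here is a from-scratch sketch in plain JAX. This is not Elegy's implementation; it is my own illustration, and it assumes the network outputs raw logits and the labels are integer class ids.

```python
import jax
import jax.numpy as jnp


def sparse_categorical_crossentropy(logits, labels):
    """Mean negative log-probability of the correct class.

    logits: [batch, num_classes] raw scores
    labels: [batch] integer class ids ("sparse" means not one-hot)
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability of each example's true class
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    return -jnp.mean(picked)


logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.0, 0.0, 0.0]])
labels = jnp.array([0, 2])
loss = sparse_categorical_crossentropy(logits, labels)
```

A confident, correct prediction (the first row) contributes a small loss, while a uniform guess over three classes (the second row) contributes log(3).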

For more libraries using JAX in neural networks, seq2seq, probabilistic programming, etc., see: https://news.ycombinator.com/item?id=22814870

Updates?

This article is from October 2020. For the latest recommended libraries and tutorials, please check https://github.com/n2cholas/awesome-jax or my page, github.com/mapmeld/use-this-now/blob/main/README.md#jax-tutorials
