Feature Toggle in JAX

Maneesh Sutar
Thoughtworks: e4r™ Tech Blogs
8 min read · May 8, 2024

In the previous article, we talked about JAX and its use case in a CFD simulation. In this article, we take a deep dive into just-in-time compilation in JAX and show how it can be used as a feature toggle.

All code examples in this article are also available in this GitHub repo.

Just-in-time (JIT) compilation in JAX

How does JIT compilation work?

Python is an interpreted language, so a program is executed statement by statement. By default, JAX also executes operations one at a time, dispatching each one as it is encountered. Just-in-time (JIT) compilation compiles parts of a program at runtime, when information such as the shapes and data types of the inputs is known. General-purpose JIT compilers continuously analyse the running program and compile only the parts where the expected performance gain outweighs the compilation overhead. In JAX, the programmer instead marks explicitly which functions should be compiled.

The JAX library provides the jax.jit function, which wraps another Python function so that it is JIT compiled, or "jitted", at runtime. This lets the programmer decide which functions are suitable for JIT compilation.

In the following code snippet, we define a simple Python function, AXPY ("a times x plus y"), a standard linear algebra routine that is often used for benchmarking.

import jax
from jax import numpy as jnp

def AXPY(a: jnp.float32,
         x: jax.Array,
         y: jax.Array) -> jax.Array:
    ans = a * x + y
    return ans

To create a jitted version of AXPY, we can call:

AXPYjitted = jax.jit(AXPY)

This creates a Python variable, AXPYjitted, which is an instance of type jaxlib.xla_extension.PjitFunction.

At runtime, when the AXPYjitted function is called for the first time, JAX traces it. The tracer analyses the behaviour of the function for the given input arguments. Suppose AXPYjitted is called with the following parameters:

a = 0.5
x = jnp.full((5,5),2.0, dtype=jnp.float32)
y = jnp.full((5,5),3.0, dtype=jnp.float32)

AXPYjitted(a,x,y)

JAX replaces the input arguments with tracer objects that have the same shape and dtype as the arguments but are value agnostic. Calling print() on each object inside the AXPYjitted function gives the following output:

a =  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
x = Traced<ShapedArray(float32[5,5])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(float32[5,5])>with<DynamicJaxprTrace(level=1/0)>
ans = Traced<ShapedArray(float32[5,5])>with<DynamicJaxprTrace(level=1/0)>

JAX's JIT compiler analyses the operations performed on the tracer objects and produces a compiled XLA program. Once compilation is done, the jitted function is executed with the actual values of the input arguments and produces the output. Since tracer objects are value agnostic, the next time the jitted function is called with inputs of matching shape and dtype, JAX reuses the already compiled XLA program.
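The caching behaviour described above can be observed with a small experiment. The following is only an illustrative sketch and not part of the original example: the counter trace_count and the function axpy_counted are hypothetical names, and the counter relies on the fact that ordinary Python statements inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces


def axpy_counted(a, x, y):
    global trace_count
    trace_count += 1  # plain Python side effect: runs at trace time only
    return a * x + y


axpy_counted_jit = jax.jit(axpy_counted)

x = jnp.full((5, 5), 2.0, dtype=jnp.float32)
y = jnp.full((5, 5), 3.0, dtype=jnp.float32)

first = axpy_counted_jit(0.5, x, y)    # first call: traced and compiled
second = axpy_counted_jit(0.25, x, y)  # same shapes and dtypes: no new trace
count_same_signature = trace_count

x_small = jnp.ones((3, 3), dtype=jnp.float32)
third = axpy_counted_jit(0.5, x_small, x_small)  # new shape: traced again
count_new_shape = trace_count
```

Changing only the value of a leaves the trace count unchanged, while changing the shape of x and y forces a second trace.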

Jaxpr: a static intermediate representation

The sequence of operations that is compiled down to XLA is encoded in a JAX expression, or jaxpr. It is an intermediate representation: the same jaxpr can be executed on different devices (CPU, GPU or TPU). A programmer can inspect the jaxpr of their program to identify hotspots and rewrite the program as needed.

To print the jaxpr of the previously defined AXPY function, we call jax.make_jaxpr as shown in the example below:

a = 0.5
x = jnp.full((5,5),2.0)
y = jnp.full((5,5),3.0)

print(jax.make_jaxpr(AXPY)(a,x,y))

Output:

{ lambda ; a:f32[] b:f32[5,5] c:f32[5,5]. let
    d:f32[5,5] = mul a b
    e:f32[5,5] = add d c
  in (e,) }

There are a few things we need to understand in the resultant jaxpr:

  1. Even though Python is a dynamically typed language, the jaxpr contains the statically defined dtype and shape of each input argument of AXPY. JAX infers this information from the arguments passed to the function at runtime, i.e. the variables a, x and y in our example.
  2. “mul” and “add” are the jaxpr representations of the multiplication and addition operations. When the code is executed on a device (CPU/GPU/TPU), JAX takes care of calling the equivalent instruction for that device.

JAX follows a functional programming paradigm. It assumes that all functions are pure, i.e. they produce no side effects and their output depends only on their inputs. Side effects written in ordinary Python run while the function is being traced, but they are not recorded, so they are absent from the jaxpr and from the compiled code.

As an example, let’s modify the AXPY function to see how it affects the jaxpr:

b = 10.0 # global variable

def AXPY2(a: jnp.float32,
          x: jax.Array,
          y: jax.Array) -> jax.Array:
    # the main AXPY function
    ans = a * x + y

    # printing: side-effect
    print("Performing AXPY")

    # accessing global variable
    ans = ans + b
    return ans

The jaxpr of the AXPY2 function looks like this:

{ lambda ; a:f32[] b:f32[5,5] c:f32[5,5]. let
    d:f32[5,5] = mul a b
    e:f32[5,5] = add d c
    f:f32[5,5] = add e 10.0
  in (f,) }

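The print statement is executed once, while the function is being traced. Because it is a side effect, it is not part of the jaxpr and does not run when the compiled function is called again. JAX also treats the global variable as a constant and replaces it with the value it had at trace time, so later changes to b are not seen by the already compiled function.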
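The way a global variable is captured can be checked directly. The snippet below is a minimal sketch with hypothetical names (offset, add_offset) that are not part of the original example. It assumes the function is called twice with the same input shape and dtype, so the second call reuses the compiled code.

```python
import jax
import jax.numpy as jnp

offset = 10.0  # global variable read inside the jitted function


@jax.jit
def add_offset(x):
    # The value of `offset` at trace time is baked into the compiled code.
    return x + offset


x = jnp.zeros((2,), dtype=jnp.float32)
before = add_offset(x)  # traced here, while offset == 10.0

offset = 20.0           # rebinding the global does not trigger a new trace
after = add_offset(x)   # reuses the compiled code, which still adds 10.0
```

If a value like this has to change between calls, it is safer to pass it as an explicit argument than to read it from a global.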

JIT is not a solution for every problem

JIT compilation is an optimisation that can make JAX code run faster on a device, but it comes with trade-offs:

  1. JIT compilation adds overhead the first time a function is called. In a typical simulation that runs for many iterations, this one-time cost is amortised.
  2. If a jitted function is called with arguments of a different dtype or shape, JAX has to trace and compile it again.
  3. As discussed earlier, JIT is a poor fit for functions that rely on side effects, e.g. generating log files or printing errors to stdout, because those side effects run only during tracing.

Taking advantage of the JIT for feature toggle

But first, feature toggle?

Let’s create a simple pipeline function which takes two JAX arrays, performs some steps and outputs a JAX array. The output of the pipeline is the summation of the outputs of individual steps. The following code snippet describes the steps in the pipeline and the pipeline function.

from jax import jit

# The number after each step decides the order of the steps
def step1(a , b):
    return a + b

def step2(a , b):
    return a * b

def step3(a , b):
    return a - b

@jit
def pipeline(a , b):
    step1_out = step1(a, b)
    step2_out = step2(a, b)
    step3_out = step3(a, b)
    output = step1_out + step2_out + step3_out
    return output

The pipeline function is decorated with @jit. This is another way of declaring that a function should be JIT compiled.

As per a new requirement, we have been told that step1 must be disabled, and only step2 and step3 should be part of the pipeline. We can change the code of the original pipeline to fit the new requirements:

@jit
def pipeline_new(a , b):
    # step1 is not part of this pipeline
    step2_out = step2(a, b)
    step3_out = step3(a, b)
    output = step2_out + step3_out
    return output

But changing the code every time the requirements change is not an efficient solution. In such cases, we should introduce feature toggles. A feature toggle is a technique that allows a programmer to change system behaviour without changing the code. This can be achieved with an external configuration file: at runtime, the program reads the file and decides which code to execute through conditional statements that are already present in the code.

Setting up the pipeline with feature toggle

By default JAX, like Python, executes code line by line, so extra conditional statements for handling feature toggles would be evaluated on every call and add overhead. JIT compilation, however, happens at runtime, when it is already known which features are turned on or off. We can use this information during tracing so that only the required features become part of the compiled XLA program.

Continuing with the pipeline example, we can set up a config file that specifies which features should be enabled at runtime. The configuration file below reflects the new requirement mentioned above.

; features.ini
[Features]
enable_step1 = false
enable_step2 = true
enable_step3 = true

The data from the configuration file can be loaded into a dataclass object. The following snippet shows a “Features” data class with three attributes corresponding to the three parameters in the config file.

from chex import dataclass

@dataclass
class Features():
    enable_step1: bool
    enable_step2: bool
    enable_step3: bool

    # static arguments of a jitted function must be hashable ...
    def __hash__(self):
        return hash((self.enable_step1, self.enable_step2, self.enable_step3))

    # ... and comparable, so that JAX can look up previously compiled code
    def __eq__(self, other):
        assertions = []
        for key in self:
            assertions.append(self[key] == other[key])
        return all(assertions)
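The article does not show how the config file is read into a Features instance. One possible way, sketched below, uses Python's standard configparser module. The load_config function here is a hypothetical implementation. To keep the sketch self-contained, it also swaps the chex dataclass for a frozen standard-library dataclass, which is hashable and comparable by value without extra methods.

```python
import configparser
import os
import tempfile
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen: instances are hashable and compared by value
class Features:
    enable_step1: bool
    enable_step2: bool
    enable_step3: bool


def load_config(path: str) -> Features:
    """Read the [Features] section of an INI file (hypothetical helper)."""
    parser = configparser.ConfigParser()
    with open(path) as config_file:
        parser.read_file(config_file)
    section = parser["Features"]
    # getboolean accepts true/false, yes/no, on/off and 1/0
    return Features(
        enable_step1=section.getboolean("enable_step1"),
        enable_step2=section.getboolean("enable_step2"),
        enable_step3=section.getboolean("enable_step3"),
    )


# Demo: write the example config to a temporary file and load it back.
example_ini = """; features.ini
[Features]
enable_step1 = false
enable_step2 = true
enable_step3 = true
"""

with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = os.path.join(tmp_dir, "features.ini")
    with open(config_path, "w") as config_file:
        config_file.write(example_ini)
    features = load_config(config_path)
```

Two Features instances with the same flags compare equal and hash identically, which is what allows JAX to reuse compiled code for a configuration it has already seen.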

The following snippet shows the program for the pipeline with feature toggle. This pipeline function takes the instance of class Features as an argument.

from functools import partial
from jax import jit

@partial(jit, static_argnums=2)
def pipeline_with_toggle(a , b, features: Features):

    step1_out = 0
    if features.enable_step1:
        step1_out = step1(a, b)

    step2_out = 0
    if features.enable_step2:
        step2_out = step2(a, b)

    step3_out = 0
    if features.enable_step3:
        step3_out = step3(a, b)

    output = step1_out + step2_out + step3_out

    return output

In order to pass arguments to the @jit decorator, we use the partial function from Python’s functools module.

We have already seen that during tracing each argument is replaced with a tracer object of the same shape and dtype. A Python conditional cannot branch on a tracer object, because the tracer has no concrete value. If features were traced like a and b, JAX would raise an error at the if conditions when it tries to convert a tracer object into a boolean. To avoid this, we tell the JIT compiler to treat features as a static argument.

Static arguments are not replaced with tracer objects. The program’s behavior depends on their actual value.
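The difference between a traced and a static argument can be seen in a reduced example. The sketch below is illustrative only: maybe_double is a hypothetical function with a single boolean flag instead of the Features object. It catches the error as a TypeError, on the assumption that the concretization error JAX raises is a subclass of TypeError.

```python
import jax
import jax.numpy as jnp


def maybe_double(x, enabled):
    if enabled:  # needs a concrete Python value to choose a branch
        return x * 2.0
    return x


x = jnp.arange(3, dtype=jnp.float32)

# Without static_argnums, `enabled` is replaced by a tracer and the `if` fails.
traced_flag_failed = False
try:
    jax.jit(maybe_double)(x, True)
except TypeError:  # JAX raises a concretization error, a TypeError subclass
    traced_flag_failed = True

# With static_argnums, `enabled` keeps its Python value during tracing.
maybe_double_jit = jax.jit(maybe_double, static_argnums=1)
doubled = maybe_double_jit(x, True)
unchanged = maybe_double_jit(x, False)
```

Each distinct value of a static argument leads to its own trace and compilation, so static arguments work best for values that change rarely, such as configuration flags.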

Figure 1 shows the different execution flows of the pipeline, depending on the attribute values of the Features instance. During tracing, JAX follows only one of these flows, determined by the value of the static argument features, and generates compiled XLA code for that flow alone.

Figure 1. Execution flows of the example pipeline
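To check which flow was actually traced, the primitives recorded in the jaxpr can be listed for different toggle settings. This is a rough sketch rather than the article's code. It uses a frozen standard-library dataclass as a stand-in for the chex dataclass, a compact un-jitted version of the pipeline, and a hypothetical helper traced_primitives.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)  # stand-in for the chex dataclass; hashable by default
class Features:
    enable_step1: bool
    enable_step2: bool
    enable_step3: bool


def pipeline_with_toggle(a, b, features):
    step1_out = a + b if features.enable_step1 else 0
    step2_out = a * b if features.enable_step2 else 0
    step3_out = a - b if features.enable_step3 else 0
    return step1_out + step2_out + step3_out


a = jnp.ones((4, 4), dtype=jnp.float32)
b = jnp.full((4, 4), 2.0, dtype=jnp.float32)


def traced_primitives(features):
    """Return the names of the primitives recorded for one toggle setting."""
    closed_jaxpr = jax.make_jaxpr(pipeline_with_toggle, static_argnums=2)(
        a, b, features
    )
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


all_steps = traced_primitives(Features(True, True, True))
without_step3 = traced_primitives(Features(True, True, False))
```

With step3 disabled, the subtraction never appears among the recorded primitives, because that branch was never executed during tracing.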

Watching it in action

Let’s run the pipeline_with_toggle function with suitable parameters. We will print jaxpr and compare it with the jaxpr of the pipeline_new function.

from jax import random

rng1, rng2 = jax.random.split(random.PRNGKey(14), 2)
a = jax.random.normal(rng1, (1000, 1000))
b = jax.random.normal(rng2, (1000, 1000))
features = load_config("features.ini") # helper that reads the config file into a Features instance; enable_step1 is False

print(jax.make_jaxpr(pipeline_with_toggle, static_argnums=2)(a, b, features))

The following is the jaxpr of the pipeline_with_toggle function for the given arguments. Since features.enable_step1 is False, the body of that conditional was never traced, so the add operation from step1 is absent from the jaxpr.

{ lambda ; a:f32[1000,1000] b:f32[1000,1000]. let
    c:f32[1000,1000] = pjit[
      name=pipeline_with_toggle
      jaxpr={ lambda ; d:f32[1000,1000] e:f32[1000,1000]. let
          f:f32[1000,1000] = mul d e
          g:f32[1000,1000] = sub d e
          h:f32[1000,1000] = add 0.0 f # NOTE: this operation is missing in jaxpr of "pipeline_new" function
          i:f32[1000,1000] = add h g
        in (i,) }
    ] a b
  in (c,) }

The above jaxpr is the same as the jaxpr of the pipeline_new function, except for one extra add operation (add 0.0 f) that comes from summing all the step outputs. Running timeit (see Figure 2) on both pipeline functions shows that this one extra operation does not affect the overall performance of the pipeline. The differences of a few microseconds can safely be ignored.

Figure 2. Comparison of time between the two pipeline functions
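A timing comparison along these lines could be set up as follows. This is a sketch under stated assumptions, not the benchmark behind Figure 2. Features is a frozen standard-library dataclass standing in for the chex version, the step functions are inlined, and the number of repetitions is arbitrary. block_until_ready() is called so that the measurement includes the actual computation and not only the asynchronous dispatch.

```python
import timeit
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp


@dataclass(frozen=True)  # stand-in for the chex dataclass
class Features:
    enable_step1: bool
    enable_step2: bool
    enable_step3: bool


@jax.jit
def pipeline_new(a, b):
    return a * b + (a - b)


@partial(jax.jit, static_argnums=2)
def pipeline_with_toggle(a, b, features):
    step1_out = a + b if features.enable_step1 else 0
    step2_out = a * b if features.enable_step2 else 0
    step3_out = a - b if features.enable_step3 else 0
    return step1_out + step2_out + step3_out


key1, key2 = jax.random.split(jax.random.PRNGKey(14))
a = jax.random.normal(key1, (1000, 1000))
b = jax.random.normal(key2, (1000, 1000))
features = Features(enable_step1=False, enable_step2=True, enable_step3=True)

# Warm-up calls, so that compilation time is excluded from the timings.
expected = pipeline_new(a, b).block_until_ready()
actual = pipeline_with_toggle(a, b, features).block_until_ready()
outputs_match = bool(jnp.allclose(expected, actual))

repeats = 100  # arbitrary choice for this sketch
time_new = timeit.timeit(
    lambda: pipeline_new(a, b).block_until_ready(), number=repeats
)
time_toggle = timeit.timeit(
    lambda: pipeline_with_toggle(a, b, features).block_until_ready(),
    number=repeats,
)
print(f"pipeline_new:         {time_new:.4f} s for {repeats} calls")
print(f"pipeline_with_toggle: {time_toggle:.4f} s for {repeats} calls")
```

The absolute numbers depend on the device and JAX version, so the two timings are only meaningful relative to each other on the same machine.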

Conclusion

JAX’s JIT compilation uses the information available at runtime to generate low-level compiled XLA code. In this article, we looked at how to take advantage of this compilation step to implement feature toggles at runtime.

References

  1. JAX tracing
  2. Feature Toggle

Maneesh Sutar
Thoughtworks: e4r™ Tech Blogs

Consultant Developer at Thoughtworks India. Working in Engineering for Research (e4r).