Building a hardware accelerated simulation using JAX

Nimalan
Thoughtworks: e4r™ Tech Blogs
Dec 15, 2023

In this article we will look at how to build a hardware-accelerated simulation with JAX.

You can find the accompanying code on GitHub.

JAX

JAX is a high-performance numerical computation library with composable function transformations. It is built on top of XLA, a compiler for linear algebra that can run code on CPUs, GPUs and TPUs. XLA is able to target multiple platforms as it leverages MLIR.

Although the most common use case for JAX is machine learning, it can express linear algebra computations in a form close to the problem domain. This allows us to build high-performance applications that can easily target hardware accelerators.

The Problem Statement

We will be using the code of a finite volume simulation from this article as reference and we will create a hardware accelerated version of it.

Modelling in JAX

The jax.numpy module is designed as a near drop-in replacement for numpy and implements most of the numpy API. The key thing to remember is that JAX requires pure functions: all input data is passed in through the function parameters, and all results are returned through the function's return value. A pure function always returns the same result when invoked with the same inputs.

# Numpy
f_dx = ( np.roll(f,R,axis=0) - np.roll(f,L,axis=0) ) / (2*dx)
f_dy = ( np.roll(f,R,axis=1) - np.roll(f,L,axis=1) ) / (2*dx)

# JAX
f_dx = ( jnp.roll(f,R,axis=0) - jnp.roll(f,L,axis=0) ) / (2*dx)
f_dy = ( jnp.roll(f,R,axis=1) - jnp.roll(f,L,axis=1) ) / (2*dx)

The code from the tutorial already used pure functions, so the only change needed to port it from numpy to JAX was to switch the numpy function calls from np to jnp. This demonstrates how easy it can be to migrate existing Python code to JAX.
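As a self-contained check of the snippet above, the sketch below wraps it in a function and applies it to a sine wave on a periodic grid. The function name get_gradient, the shift values R = -1 and L = 1, and the grid size are assumptions made for illustration; they are not taken from the accompanying code.

```python
import jax.numpy as jnp

R = -1  # assumed convention: a shift of -1 brings the right-hand neighbour into place
L = 1   # assumed convention: a shift of +1 brings the left-hand neighbour into place

def get_gradient(f, dx):
    """Central differences on a periodic grid with square cells of size dx."""
    f_dx = (jnp.roll(f, R, axis=0) - jnp.roll(f, L, axis=0)) / (2 * dx)
    f_dy = (jnp.roll(f, R, axis=1) - jnp.roll(f, L, axis=1)) / (2 * dx)
    return f_dx, f_dy

# Periodic test field: f(x, y) = sin(2*pi*x) on the unit square
n = 64
dx = 1.0 / n
x = (jnp.arange(n) + 0.5) * dx
X, Y = jnp.meshgrid(x, x, indexing="ij")
f = jnp.sin(2 * jnp.pi * X)

f_dx, f_dy = get_gradient(f, dx)

# Compare against the analytic derivative along x
exact = 2 * jnp.pi * jnp.cos(2 * jnp.pi * X)
max_err = float(jnp.max(jnp.abs(f_dx - exact)))
```

Because jnp.roll wraps around the array edges, the periodic boundary condition comes for free, and the function is pure: it only reads its arguments and returns new arrays.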

Keep in mind that numpy executes on the host machine (CPU), whereas JAX functions are compiled and executed on the device (a GPU or TPU, when one is available). We need to be aware of when data is transferred from the host to the device and vice versa. Data transfer between device and host is expensive and will slow down the execution.
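One way to keep transfers under control is to make them explicit. The sketch below is a minimal illustration, assuming a toy update rule called step that is not part of the simulation: the data is copied to the device once with jax.device_put, stays there across repeated calls, and is copied back once with jax.device_get.

```python
import numpy as np
import jax

# Data created with numpy lives in host memory
host_array = np.ones((512, 512), dtype=np.float32)

# One explicit host -> device copy
device_array = jax.device_put(host_array)

@jax.jit
def step(x):
    # Toy update rule, used only for illustration
    return x * 0.5 + 1.0

# The array stays on the device between calls; nothing is copied back here
for _ in range(10):
    device_array = step(device_array)

# One explicit device -> host copy at the end
host_result = jax.device_get(device_array)
```

On a machine without an accelerator the default device is the CPU, so the same code still runs; only the cost of the copies changes.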

Simulations in JAX

JAX operates on pure functions; however, a simulation is an iterative, stateful process.

The State monad from functional programming can be used to solve this problem. Each iteration of the simulation can be represented as a function that takes in a state, performs some computation and returns a new state.

iteration: s -> (s, a)

We create the initial state of the system, pass it through the iteration function a number of times with lax.fori_loop, and get the final state of the system. An advantage of this pattern is that we can save the state of the simulation at a particular iteration and replay from that point.
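The state-passing pattern can be sketched on a toy problem. The decay model, the dictionary-based state and the helper name run below are assumptions chosen for brevity; they stand in for the fluid solver and are not the article's code.

```python
import jax.numpy as jnp
from jax import lax

def iteration(state):
    """Pure step: takes a state and returns a new one (toy decay model)."""
    return {"t": state["t"] + 0.1, "value": state["value"] * 0.9}

def run(state, count):
    # The loop body receives the loop index and the carried state
    return lax.fori_loop(0, count, lambda i, s: iteration(s), state)

initial = {
    "t": jnp.asarray(0.0, dtype=jnp.float32),
    "value": jnp.asarray(1.0, dtype=jnp.float32),
}

# Run 100 iterations in one go
final = run(initial, 100)

# Save the state at iteration 40, then replay the remaining 60 from there
checkpoint = run(initial, 40)
replayed = run(checkpoint, 60)
```

Since each step depends only on the state it is given, resuming from the checkpoint should reach the same final state as the uninterrupted run, up to floating-point rounding.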

Dataclasses

For a physics simulation such as this, there are a lot of variables in the state. To encapsulate them we create immutable dataclasses. Typically, to work with classes in JAX we need to define how the class is flattened and unflattened. chex is a library in the JAX ecosystem with convenient utilities; we can use its dataclass decorator to handle this for us.

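import chex
import jax
import jax.numpy as jnp

# Variables that change with every iteration
@chex.dataclass
class State:
    Mass: jax.Array
    Momx: jax.Array
    Momy: jax.Array
    Energy: jax.Array

# Constants that describe the environment
@chex.dataclass
class Environment:
    courant_fac: jnp.float32
    dx: jnp.float32
    gamma: jnp.float32
    vol: jnp.float32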

In this case the variables that change with each iteration are part of the State dataclass and the constant variables that describe the environment are part of the Environment dataclass.
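To make the flatten and unflatten requirement concrete, here is a sketch of what registering a class by hand might look like with jax.tree_util.register_pytree_node. The class ManualState, its two fields and the function add_energy are hypothetical and reduced for brevity; the chex decorator saves us from writing this boilerplate.

```python
from dataclasses import dataclass, replace

import jax
import jax.numpy as jnp
from jax import tree_util

@dataclass(frozen=True)
class ManualState:  # hypothetical, reduced stand-in for State
    Mass: jax.Array
    Energy: jax.Array

def flatten(s):
    # Return the array leaves plus auxiliary static data (none here)
    return (s.Mass, s.Energy), None

def unflatten(aux, children):
    # Rebuild the object from its leaves
    return ManualState(*children)

tree_util.register_pytree_node(ManualState, flatten, unflatten)

state = ManualState(Mass=jnp.ones((4, 4)), Energy=jnp.zeros((4, 4)))

@jax.jit
def add_energy(s):
    # Returns a new object; the input state is left untouched
    return replace(s, Energy=s.Energy + 2.0)

new_state = add_energy(state)
leaves = tree_util.tree_leaves(new_state)
```

Once the class is registered, JAX transformations such as jit can accept and return instances of it directly.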

Simulation Step

A simulation iteration takes the Environment and the State and returns the next State. We add the jit decorator to tell JAX to perform just-in-time compilation of the function and of all the functions called within it. The iteration function is a good candidate for JIT because its logic does not change between calls, so it is compiled once and reused. Compiling this function speeds up the execution significantly.

from jax import jit, lax

@jit
def iteration(environment: Environment, state: State) -> State:

    # Logic of the simulation
    # (computes the updated Mass, Momx, Momy and Energy)

    # Create a new state for the next iteration
    return state.replace(
        Mass = Mass,
        Momx = Momx,
        Momy = Momy,
        Energy = Energy
    )

@jit
def multi_step(environment: Environment, state: State, count: int):
    # Apply `iteration` count times, carrying the state through the loop
    return lax.fori_loop(0, count, lambda i, s: iteration(environment, s), state)
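A small experiment can show that a jitted loop of this shape is compiled once and then reused. In the sketch below the update rule is a toy stand-in for the simulation step, and trace_count is a hypothetical counter that is incremented only while JAX traces the Python function. Because count is passed as an ordinary argument, calling the function with a different count should reuse the same compiled program.

```python
import jax.numpy as jnp
from jax import jit, lax

trace_count = 0  # hypothetical counter: incremented only during tracing

@jit
def multi_step(state, count):
    global trace_count
    trace_count += 1  # Python side effect, runs at trace time only
    # Toy update rule standing in for the simulation step
    return lax.fori_loop(0, count, lambda i, s: s * 0.5 + 1.0, state)

state = jnp.zeros((8,), dtype=jnp.float32)

a = multi_step(state, 10)  # first call: trace and compile
b = multi_step(state, 20)  # different count, same compiled program
```

After both calls trace_count is still 1. The side effect in the function body is there purely to observe tracing; real simulation code should stay free of side effects.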

Capturing results

We want to capture the results of the simulation periodically; however, the simulation runs on an accelerator, where data transfer to the host is expensive. So we have to minimise the number of times we transfer data.

To do this we periodically store the results in a separate variable on the device, and when the execution is done we send everything back to the host. For simulations with a large number of snapshots, the stored results may not fit in device memory, so we will have to send the data back to the host periodically.

import jax
from jax import tree_util
import numpy as np

def run_simulation(
    environment: Environment,
    state: State,
    count: int,
    nr_snapshots: int = 10
):
    nr_iter_per_snapshot = count // nr_snapshots

    # For large simulations following steps will have to be repeated periodically

    # Snapshots accumulate on the device, starting with the initial state
    results = [state]
    for i in range(nr_snapshots):
        state = multi_step(environment, state, nr_iter_per_snapshot)
        results.append(state)

    # Get data from device
    results = jax.device_get(results)

    # Stack the snapshots field by field into arrays with a leading snapshot axis
    return tree_util.tree_map(lambda *xs: np.stack([np.array(x) for x in xs]), *results)
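The final tree_map call turns a list of snapshots into a single structure whose fields have a leading snapshot axis. The sketch below illustrates this on plain dictionaries with made-up values; the field names Mass and Energy are reused for readability, and the shapes are assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import tree_util

# Three toy snapshots, each holding two 2x2 fields (made-up values)
snapshots = [
    {"Mass": jnp.full((2, 2), float(i)), "Energy": jnp.full((2, 2), 10.0 * i)}
    for i in range(3)
]

# A single transfer brings all snapshots back to the host as numpy arrays
snapshots = jax.device_get(snapshots)

# tree_map calls the lambda once per field, passing that field from every snapshot
stacked = tree_util.tree_map(lambda *xs: np.stack(xs), *snapshots)
```

Here stacked["Mass"] has shape (3, 2, 2): one 2x2 field per snapshot. The same call works on the State dataclass because it is registered as a pytree.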

Running the Simulation

To run the simulation we need to create an Environment instance and a State instance, then call the run_simulation function we created.

environment = Environment(
    courant_fac=courant_fac,
    dx=dx,
    gamma=gamma,
    vol=vol
)

initial_state = State(
    Mass=Mass,
    Momx=Momx,
    Momy=Momy,
    Energy=Energy
)

nr_iterations = 1000
nr_snapshots = 100
results = run_simulation(environment, initial_state, nr_iterations, nr_snapshots)

The execution time depends on the number of snapshots: frequent snapshots slow down the execution and also increase memory consumption. Running the simulation on a Google Colab T4 GPU runtime produced these results.
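When measuring execution time, keep in mind that JAX dispatches work asynchronously, so a timer can stop before the device has finished. The sketch below shows one way to time a jitted function; the function step and the array size are placeholders, not the simulation itself. It runs a warm-up call first so that compilation is not counted, and calls block_until_ready before stopping the clock.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Placeholder workload, used only for illustration
    return jnp.sin(x) + 0.5 * x

x = jnp.ones((256, 256))

# Warm-up call: triggers compilation so it is excluded from the timing
step(x).block_until_ready()

start = time.perf_counter()
y = x
for _ in range(100):
    y = step(y)
# Wait for the device to finish before reading the clock
y.block_until_ready()
elapsed = time.perf_counter() - start

print(f"100 steps took {elapsed:.4f} s")
```

The measured time varies with the hardware and runtime, so no particular number should be expected from this sketch.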

Results of running the simulation

Summary

We saw how JAX can be used to model a finite volume fluid simulation. This pattern can be used to create simulations in other domains as well. Porting existing Python code to JAX is easy, and the result performs well. By using JAX we were also able to target GPUs.

The code can be found on GitHub, and you can also try this simulation out in Google Colab.

References

  1. https://github.com/pmocz/finitevolume-python/blob/master/finitevolume.py
  2. https://github.com/google/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb

Disclaimer: The statements and opinions expressed in this blog are those of the author(s) and do not necessarily reflect the positions of Thoughtworks.

Nimalan
Thoughtworks: e4r™ Tech Blogs

Research Engineer. Working on High Performance Computing and Accelerated Computing.