Writing an RL Environment in JAX
How to run CartPole at 1.25 Billion Steps/Sec
JAX is a relatively new and exciting open-source machine learning framework. Here are some of its great features:
- Compiled using XLA so it can support CPUs, GPUs, and TPUs.
- With the jit function, it can just-in-time compile multiple operations and optimize the computation graph.
- Automatic vectorization via vmap.
- Automatic parallelization with pmap.
- Higher-order differentiation.
- Enforces functional programming.
All these features are useful, but the first three are especially so when writing RL environments. Currently, the most powerful accelerator available for free via Colab or Kaggle notebooks is a TPU. TPUs are often hard to use, yet JAX makes them easy. Accelerating individual operations helps, but any non-trivial RL environment will include many operations per time-step, so optimizing the whole graph via the XLA compiler makes a big difference. And I don’t think I could point to anything that caused me more headaches in deep RL than manually managing the batch dimension! It’s just so much easier to code for a single instance and let JAX take care of vectorization with vmap. For an example of how effective JAX can be at implementing an RL environment, look no further than Brax, a rigid-body simulator written in JAX that speeds up the classic Ant and Humanoid environments by 100x to 1000x compared to the MuJoCo equivalents.
To start with, let’s explore what a simple JAX environment looks like. As an instructive example, we can try to convert the classic CartPole-v1 environment from OpenAI Gym. The probably familiar API looks like this:
env = gym.make("CartPole-v1")
obsv = env.reset()
obsv, reward, done, info = env.step(action)
After creating the environment, we reset it to get the initial state and then apply an action with the step function. The env.step function is stateful because its output depends on the internal state of the environment. The env.reset function is also stateful, as it depends on the internal state of the random number generator (each time we reset we get a different state). But we already said JAX only allows functional code without any state or side-effects 🤔
Thankfully, the excellent JAX documentation has a solution, described in Stateful Computations in JAX. Applying this to a skeleton environment with a gym-like API, we can arrive at a sample implementation.
The environment is designed to automatically reset as this simplifies rollouts on batched environments and doesn’t detract from usability. This is equivalent to using the gym.vector wrappers.
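Here is a minimal sketch of what such a skeleton might look like (the 4-dimensional state, the trivial dynamics, and the value of self.random_limit are placeholders for illustration, not the original Gist):

from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

class SkeletonEnv:
    def __init__(self):
        # Constants only: these are set once and never mutated, so
        # calls to the environment remain purely functional.
        self.random_limit = 0.05

    def _get_obsv(self, state):
        # Map the internal state to an observation (identity here).
        return state

    def _reset(self, key):
        # Use the key to draw a new random state, then split off
        # a fresh key for subsequent random operations.
        new_state = jax.random.uniform(
            key, (4,), minval=-self.random_limit, maxval=self.random_limit)
        new_key = jax.random.split(key)[0]
        return new_state, new_key

    def _maybe_reset(self, env_state, done):
        # jax.lax.cond keeps the branch compatible with jit.
        key = env_state[1]
        return jax.lax.cond(done, self._reset, lambda k: env_state, key)

    @partial(jit, static_argnums=(0,))
    def reset(self, key):
        env_state = self._reset(key)
        return env_state, self._get_obsv(env_state[0])

    @partial(jit, static_argnums=(0,))
    def step(self, env_state, action):
        state, key = env_state
        new_state = state  # placeholder: apply the real dynamics here
        reward, done = 1.0, jnp.array(False)
        env_state = self._maybe_reset((new_state, key), done)
        return env_state, self._get_obsv(env_state[0]), reward, done, {}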
Functions
- __init__ is the constructor, used to set values on the environment instance like self.random_limit. This is OK because these values are never changed and are used as constants rather than state, so calls to the environment are still functional.
- _get_obsv is a private function that maps the state to an observation, since we might want to use different observations than the raw state. For example, we might want to rasterize an image of the environment and keep the compressed internal state hidden.
- _reset is a private function that actually generates a new random state. This is used by both the public reset function and the step function to conditionally reset the environment.
- _maybe_reset is used to conditionally reset the environment. It does this using the jax.lax.cond function, which allows the branching to be compiled with jit (more on that later).
jax.lax.cond(done, self._reset, lambda key: env_state, key)
is equivalent to
if done:
    return self._reset(key)
else:
    return env_state
The challenges of using Python control flow in JAX are well documented, but the tl;dr version is that JAX uses tracing and doesn’t know which path to take when Python control flow is used. This implementation does seem more complicated, but it’s worth it for performance.
After the key has been used, jax.random.split functionally generates a new key, which we return for subsequent random operations.
- reset is a public function that takes in the key and generates a new random state using the _reset function. The state and key are combined into a tuple called env_state, which is the full state of the environment. When I first started this project I thought the state was sufficient, but as I progressed I realized that the key itself is also part of the state, since when the environment reaches a terminal state it resets and a new key is generated. This also means that for a given initial env_state we produce the same rollout even if the environment resets multiple times. The env_state and the initial observation, determined using the _get_obsv function, are returned.
- step is a public function that takes in the env_state and action. It applies a transformation to advance the internal state. Next, it calculates the reward and done values and uses the _maybe_reset function to conditionally reset the env_state based on the done value. Finally, the updated env_state, obsv, reward, and done values are returned, with obsv again being derived from the _get_obsv function.
Usage
The usage can be seen in the Gist above, below the class definition. Let’s walk through it. First, we create a key using the jax.random.PRNGKey function. The reason we need to do this explicitly is that for the random number generation process to be functional (in this case random.uniform), its output must depend on the key, i.e. same key = same random output.
We then instantiate the environment which sets the constants.
The reset function is used with the previously created key to define the initial env_state (containing both the state and the updated key) and get the initial_obsv.
The step function can now be called using the env_state and an arbitrarily chosen action (1) to advance the environment.
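Putting this together (a short sketch, assuming the SkeletonEnv class from the sketch above):

seed = 0
key = jax.random.PRNGKey(seed)   # same key = same random output
env = SkeletonEnv()              # sets the constants
env_state, obsv = env.reset(key)
env_state, obsv, reward, done, info = env.step(env_state, 1)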
The main change we have made to the gym API is that all the functions in the environment now require the state to be externally managed and passed in or returned, to allow the functional paradigm to be maintained. This may seem burdensome, but it has the advantage of making a function’s behavior deterministic for its inputs.
JAX Cartpole
Now we’ve got a skeleton environment, but it’s not very interesting. Converting the CartPole environment to JAX, using the skeleton environment as a template, we can arrive at the following.
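The core of the conversion is the step function. Here is a sketch of what it might look like, following Gym’s CartPole dynamics (the constants such as self.gravity, self.tau, and self.force_mag match Gym’s values; the rest of the class follows the skeleton above):

@partial(jit, static_argnums=(0,))
def step(self, env_state, action):
    state, key = env_state
    x, x_dot, theta, theta_dot = state

    force = self.force_mag * (2 * action - 1)  # branchless action -> force
    costheta = jnp.cos(theta)
    sintheta = jnp.sin(theta)

    # Cart-pole physics as in Gym's CartPole, with jnp ops throughout.
    temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass
    thetaacc = (self.gravity * sintheta - costheta * temp) / (
        self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass))
    xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

    # Euler integration with time-step self.tau.
    x = x + self.tau * x_dot
    x_dot = x_dot + self.tau * xacc
    theta = theta + self.tau * theta_dot
    theta_dot = theta_dot + self.tau * thetaacc
    new_state = jnp.array([x, x_dot, theta, theta_dot])

    done = ((x < -self.x_threshold)
            | (x > self.x_threshold)
            | (theta < -self.theta_threshold_radians)
            | (theta > self.theta_threshold_radians))
    reward = 1.0  # always 1, since the environment auto-resets

    env_state = self._maybe_reset((new_state, key), done)
    return env_state, self._get_obsv(env_state[0]), reward, done, {}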
Let’s break down the changes.
1. We swap Python control flow for an equivalent set of operations where we can, as Python control flow has only limited support:
force = self.force_mag if action == 1 else -self.force_mag
---
force = self.force_mag * (2 * action - 1)
2. We remove side effects from the reset and step functions in order for them to be strictly functional:
# Removed
assert self.action_space.contains(action), err_msg
log.warn and print statements are side-effects because their output is not returned as part of the function.
3. We replace incompatible math operations with jax.numpy equivalents.
costheta = math.cos(theta)
sintheta = math.sin(theta)
---
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
This seems to be required because JAX traces with abstract ShapedArray values, while the math functions apply the float function, which requires a concrete value.
4. Replace Python control flow for calculating done with an equivalent set of operations:
done = bool(x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians)
---
done = ((x < -self.x_threshold)
| (x > self.x_threshold)
| (theta > self.theta_threshold_radians)
| (theta < -self.theta_threshold_radians))
5. Set reward always equal to 1 since the environment auto-resets.
6. jit the reset and step functions using @partial(jit, static_argnums=(0,)). We need to specify static_argnums=(0,) because the first argument, self, is treated as static (constant). The code is runnable without jit, but it is extremely slow; specifically, jax.lax.cond is slow, so without jit Python control flow is faster. Jitting compiles the code to a statically-typed expression language called jaxpr, which is further compiled to an execution graph by XLA.
Using make_jaxpr we can visualize what the compiled operations look like. Below is a sample from the env.step function. For context, the constants come from the self argument.
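Such a sample can be produced with something like the following (a hypothetical invocation; wrapping the bound method in a lambda means self needs no special handling):

# Trace step with example inputs and print the resulting jaxpr.
print(jax.make_jaxpr(lambda s, a: env.step(s, a))(env_state, action))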
y = mul p w
z = sub x y
ba = integer_pow[ y=2 ] p
bb = mul ba 0.10000000149011612
bc = div bb 1.100000023841858
bd = sub 1.3333333730697632 bc
be = mul bd 0.5
bf = div z be
bg = mul bf 0.05000000074505806
This implementation allows us to run a single instance of the environment in the same way as the skeleton environment.
Batched Environment
Throughput is limited with a single instance of the environment, so we want to run many in parallel. One way to achieve that is by using multiprocessing and running each new instance in a new process, but this tends to scale poorly; it’s much more efficient if we can vectorize the implementation. Vectorizing code can be tricky because you need to consider the batch dimension in all the operations you implement. Thankfully, JAX makes this much easier with vmap. Using this function we can automatically vectorize the reset and step functions. Now we can step through many instances of the environment without changing the code!
Here’s what this looks like. First we take the environment functions and wrap them with vmap and jit. The vmapped function benefits from being jitted because that allows XLA to optimize the operation across the batch dimension.
vstep = jit(jax.vmap(env.step, in_axes=((0, 0), 0), out_axes=((0, 0), 0, 0, 0, 0), axis_name="batch_axis"))
vreset = jit(jax.vmap(env.reset, out_axes=((0, 0), 0), axis_name="batch_axis"))
The in_axes and out_axes arguments specify the batch dimension of each of the jnp arrays in the arguments and return values.
Now, instead of passing a single key to reset, we pass as many keys as we want environments.
NUM_ENV = 10
seed = 0
key = jax.random.PRNGKey(seed)
keys = jax.random.split(key, NUM_ENV)
env_state, obsv = vreset(keys)
We then get a vectorized env_state back from reset, which we can use together with a vectorized action; the vectorized step similarly has vectorized return values.
action = jax.random.randint(keys[0], (NUM_ENV,), 0, 2)
new_env_state, obsv, reward, done, info = vstep(env_state, action)
Batched Environment on Multiple Devices
Fully utilizing a TPU or multiple GPUs means dealing with multiple devices, so we need a way to vectorize across the device dimension as well. JAX provides this functionality through pmap. Similarly to vmap, we can just wrap our jitted, vmapped functions with pmap. We don’t need to apply jit again in this case because pmap automatically jits the function.
pvreset = jax.pmap(vreset, out_axes=((0, 0), 0), axis_name="device_axis")
pvstep = jax.pmap(vstep, in_axes=((0, 0), 0), out_axes=((0, 0), 0, 0, 0, 0), axis_name="device_axis")
Something that might seem confusing here is that the arguments to pmap look the same as for vmap, which might make it look like we are using the same dimension as both the device and batch dimension. The way it works is that the pmapped reset and step functions take in jnp.arrays with dimensions [device_axis, batch_axis, …]. The first dimension is the device axis, which is removed by pmap, so vmap only sees the jnp.array [batch_axis, …].
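So the keys, for example, need an extra leading device dimension. A hypothetical setup (the shapes are illustrative):

num_devices = jax.device_count()
# One key per environment, reshaped to [device_axis, batch_axis, key_size].
keys = jax.random.split(key, num_devices * NUM_ENV).reshape(num_devices, NUM_ENV, -1)
env_state, obsv = pvreset(keys)
# Actions likewise need shape [device_axis, batch_axis].
action = jax.random.randint(keys[0, 0], (num_devices, NUM_ENV), 0, 2)
new_env_state, obsv, reward, done, info = pvstep(env_state, action)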
Compiled Rollout
Up to this point, we’ve been compiling the environment step and reset functions. This means we execute each step on the CPU or accelerator before returning control to the Python kernel. There is a latency cost associated with dispatching execution. I noticed that this seemed to be very small for the CPU and sub-millisecond for the GPU, but on the order of several milliseconds on the TPU. This forces very large batch sizes to get good throughput on the accelerators, especially the TPU.
To address this we can try to jit compile the entire rollout. That way the python kernel simply starts a rollout of the environment and gets data for the whole rollout back. This approach should be viable for RL algorithms that have a data-collection phase and a training phase e.g. Proximal Policy Optimization (PPO) or Q-Learning.
For this example we randomly select actions in place of a model, but there should be no issue in using a Flax-based neural network in the same function.
See the example implementation below.
We define a rollout function that resets the environment, steps through a fixed number of iterations, and returns the data. Note that rather than using a Python for loop as is typical, we use the JAX control-flow equivalent jax.lax.fori_loop. It is actually possible to use Python control flow here, but it takes a very long time to compile as it needs to trace through every step in the loop. The main downside I see in using JAX control flow is that it’s harder to read.
jax.lax.fori_loop is well documented, so that should help in understanding the loop implementation. Other things worth noting here are:
- We pre-allocate the obsv, reward, and done arrays and at each step functionally create a new array. This sounds inefficient, but think of it as using functional operations to express what you want to happen and then letting the XLA compiler implement that in a performant way.
- We can directly use pmap & vmap as decorators without arguments because, by default, they assume no static arguments and that the first dimension is the pmapped or vmapped dimension respectively, which is correct for the keys argument. A sketch along these lines follows below.
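A minimal sketch of what such a rollout might look like, assuming the CartPole env from above (NUM_STEPS and the random-action policy are illustrative, and pmap can be layered on top for multiple devices as described earlier):

NUM_STEPS = 1000

@jax.jit
@jax.vmap
def rollout(key):
    # Split the incoming key: one part seeds the environment,
    # the other drives random action selection.
    reset_key, action_key = jax.random.split(key)
    env_state, obsv = env.reset(reset_key)

    # Pre-allocate output arrays; each step functionally "writes" one row.
    obsvs = jnp.zeros((NUM_STEPS,) + obsv.shape)
    rewards = jnp.zeros(NUM_STEPS)
    dones = jnp.zeros(NUM_STEPS, dtype=bool)

    def body(i, carry):
        env_state, action_key, obsvs, rewards, dones = carry
        action_key, subkey = jax.random.split(action_key)
        action = jax.random.randint(subkey, (), 0, 2)  # random policy
        env_state, obsv, reward, done, _ = env.step(env_state, action)
        obsvs = obsvs.at[i].set(obsv)
        rewards = rewards.at[i].set(reward)
        dones = dones.at[i].set(done)
        return env_state, action_key, obsvs, rewards, dones

    carry = (env_state, action_key, obsvs, rewards, dones)
    _, _, obsvs, rewards, dones = jax.lax.fori_loop(0, NUM_STEPS, body, carry)
    return obsvs, rewards, dones

keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENV)
obsvs, rewards, dones = rollout(keys)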
Benchmarks
The chart below shows the number of steps per second achieved for a range of environment counts (NUM_ENV), starting at 1.
The baseline is OpenAI Gym CartPole @ 94.2k steps/sec.
It’s worth highlighting that all benchmarks were run in Colab, including the CPU benchmarks (so only one physical core). As mentioned earlier, gym environments can be run in parallel using multiprocessing or similar, but this only works if you have CPU cycles or cores to spare, which is sadly not the case in Colab.
The GPU was a K80 and the TPU was a v2.
The Colab notebook used to collect the benchmark result is available here.
Discussion
For the equivalent JAX version with a single environment on CPU, we see 95k steps/sec, so basically equivalent to the baseline. On GPU and TPU it’s much slower. Increasing the batch size to 100 saturates the CPU performance @ 4.7M steps/sec. The GPU and TPU can reach much higher throughput (>400M steps/sec) at very large NUM_ENV (>1M).
The results for the compiled rollout are more interesting, with the CPU rollout reaching 1M steps/sec. That’s 10x faster than either the Gym environment or the vectorized JAX environment when stepping manually. The GPU only has a “marginal” improvement of 2x at the same batch size. The TPU performance improves dramatically (400x), which is probably due to the dispatch latency now being a minor factor, whereas it was the dominant factor when stepping the environment. This scales all the way to NUM_ENV=800k across the 8 devices before OOM occurs, reaching an incredible 1.25G steps/sec.
Conclusion
We’ve seen how you might write an RL environment in JAX, and how it can be made both faster serially by compiling the rollout and efficiently parallel using vmap and pmap. JAX has the potential to accelerate deep RL research by speeding up the environments themselves, bringing down iteration time and making accelerators like TPUs easy to use.
Best of luck with your experiments!