Lessons from Implementing 12 Deep RL Algorithms in TF and PyTorch

Sven Mika
Distributed Computing with Ray
Sep 23, 2020

In this blog post, I discuss my experience porting 12 different deep RL (reinforcement learning) algorithms in RLlib from TensorFlow to PyTorch. I focus on the differences in performance and in the APIs that matter to an RL library.

Introduction

The battle between deep learning heavyweights TensorFlow and PyTorch is fully under way. Over the past few months, I have kept myself busy translating all of RLlib’s algorithms from their already existing TensorFlow implementations to respective PyTorch versions. In this blog post, I would like to share the joys and sufferings of doing so with the deep learning community.

And yes, this is all done now: RLlib has declared victory on reaching Torch-vs-TF parity as of release 1.0. Hence, if you still cannot decide which DL framework to use for your next reinforcement learning project after reading this article, don’t worry, you can always change horses later on.

Figure 1: As of Ray version 1.0, RLlib has reached full feature parity for TF and PyTorch. In fact, there are more PyTorch algorithms than TensorFlow due to community contributions.

How to switch between TF and Torch in RLlib

To let users switch easily between TF and Torch in RLlib, we added a new “framework” trainer config. For example, to switch to the PyTorch version of an algorithm, specify {"framework": "torch"}. Internally, this tells RLlib to try to use the Torch version of the algorithm's policy (compare, for example, PPOTFPolicy and PPOTorchPolicy). The same config also toggles between TF eager ("tfe") and graph-based TensorFlow ("tf").
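Below is a minimal sketch that writes the three framework settings as plain config dicts. The values "tf", "tfe", and "torch" are the ones used in the benchmark sweep further down. The `make_config` helper and the CartPole-v0 default are hypothetical and exist only for illustration. Treating `eager_tracing` as the config key that turns on tracing in eager mode is our assumption. Passing the dict to Ray Tune appears only as a comment, so the snippet runs without Ray installed.

```python
# Hypothetical helper (not part of RLlib): build a trainer config dict that
# selects the deep learning framework.
# "tf" = static-graph TensorFlow, "tfe" = TF eager, "torch" = PyTorch.
VALID_FRAMEWORKS = ("tf", "tfe", "torch")


def make_config(framework, env="CartPole-v0", **overrides):
    """Return a config dict for the given framework plus any extra keys."""
    if framework not in VALID_FRAMEWORKS:
        raise ValueError(
            f"unknown framework {framework!r}; expected one of {VALID_FRAMEWORKS}"
        )
    config = {"env": env, "framework": framework}
    config.update(overrides)
    return config


torch_config = make_config("torch")
tf_graph_config = make_config("tf")
tf_eager_config = make_config("tfe")
# Assumed key name for enabling tracing in eager mode (see Figure 4 below).
tf_eager_traced_config = make_config("tfe", eager_tracing=True)

# With Ray installed, a config like this could be handed to Tune, e.g.:
#   from ray import tune
#   tune.run("PPO", config=torch_config)
print(torch_config)  # {'env': 'CartPole-v0', 'framework': 'torch'}
```

Only the "framework" entry changes between the three variants. The rest of the experiment definition stays identical, which is what makes switching frameworks cheap.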

Comparing Torch and TF Performance

One of the biggest questions we had about porting algorithms to PyTorch was whether performance would be competitive with the TensorFlow versions. RLlib gives us a unique opportunity for an apples-to-apples comparison between the frameworks. Most of the distributed logic is shared between both versions of an algorithm, and only the algorithm's core policy differs.

One of the first things I did after porting each algorithm was to benchmark it against the original TensorFlow version. All RLlib algorithms have tuned examples, and we also keep an archive of historical performance. Together, these let us verify that an algorithm still learns and measure its wall-clock performance.

Figure 2: The Atari 2600 SpaceInvaders (left) and MuJoCo HalfCheetah (right) environments, frequently used in reinforcement learning benchmark experiments. In our experiments, we used Atari games for the discrete-action algorithms and MuJoCo for the continuous-control algorithms.

Our initial learning results for TensorFlow and PyTorch were, surprisingly, quite similar right off the bat for many algorithms. For example:

Figure 3: Learning performance of SAC over wall time (in seconds), comparing static-graph TensorFlow 1.14 (orange) with PyTorch 1.4 (blue). Seven seeds were used in a single-worker setup (1 CPU, no GPU). The policy network and the Q-network each used two dense layers of 256 nodes, and the training batch size was 256.

We also compared raw system throughput (ignoring learning) for three basic algorithms (PPO, IMPALA, and DQN) on two environments (Pong-v0 and CartPole-v0). We ran the benchmarks on a single p2.16xlarge machine using the following grid sweep. Each setting used 5 workers, except DQN, which got only 1 worker:

rllib train --run=PPO --env=Pong-v0 --config='{"num_gpus": 1, "num_workers": 5, "num_envs_per_worker": 5, "framework": {"grid_search": ["tf", "tfe", "torch"]}}'
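A `grid_search` entry expands into one trial per listed value, so each algorithm/environment pair yields three trials, one per framework. The sketch below imitates that expansion to show how the full benchmark grid comes together, including the one-worker exception for DQN. The `expand_grid` and `benchmark_trials` helpers are hypothetical and are not Tune's implementation.

```python
# Base config from the sweep above; "framework" is the grid-searched key.
base_config = {
    "num_gpus": 1,
    "num_workers": 5,
    "num_envs_per_worker": 5,
    "framework": {"grid_search": ["tf", "tfe", "torch"]},
}


def expand_grid(config):
    """Hypothetical helper: expand every {"grid_search": [...]} value into
    separate configs (the cartesian product over all grid-searched keys)."""
    variants = [{}]
    for key, value in config.items():
        if isinstance(value, dict) and "grid_search" in value:
            choices = value["grid_search"]
        else:
            choices = [value]
        variants = [{**v, key: c} for v in variants for c in choices]
    return variants


def benchmark_trials(algorithms, envs):
    """Hypothetical helper: list (algorithm, config) pairs for the benchmark."""
    trials = []
    for algo in algorithms:
        for env in envs:
            for cfg in expand_grid(base_config):
                cfg = dict(cfg, env=env)
                if algo == "DQN":
                    cfg["num_workers"] = 1  # DQN was given only one worker
                trials.append((algo, cfg))
    return trials


trials = benchmark_trials(["PPO", "IMPALA", "DQN"], ["Pong-v0", "CartPole-v0"])
print(len(trials))  # 18 = 3 algorithms x 2 environments x 3 frameworks
```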

Figure 4: Throughput (steps/s) for each RLlib benchmark scenario. Note that the x-axis is log-scale. We found TF graph mode to be generally the fastest, with Torch close behind. TF eager with tracing off is the slowest (with tracing on, it is similar to TF graph mode, but is harder to debug).

Overall, as the two figures above show, RLlib/Torch provides similar learning performance at a slight wall-clock penalty compared to TF graph mode. However, it runs much faster than TF eager with tracing disabled, which offers equivalent convenience for debugging custom models.

Where are my Gradients?

The vast majority of deep RL algorithms optimize a policy by computing gradients of a loss with respect to the policy's parameters. Hence, how to calculate gradients, and what to do with them afterwards, is crucial. TensorFlow, TensorFlow Eager, and PyTorch handle this quite differently:

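PyTorch stores the gradients directly on the parameter tensors themselves (`torch.Tensor` objects), in a separate attribute called .grad. The values in .grad are populated once .backward() is called on a loss tensor that was computed from these parameters. Note that .grad starts out as None. Also, .backward() accumulates into .grad rather than overwriting it, which is why you call optimizer.zero_grad() before each backward pass. In PyTorch 1.x this resets the gradients to 0.0; newer versions reset them to None by default. A PyTorch optimizer receives the parameters it will update when it is constructed.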
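Here is a minimal PyTorch sketch of this behavior. The single scalar parameter, the toy loss w², and the SGD learning rate are arbitrary choices for illustration.

```python
import torch

# A single trainable scalar parameter (arbitrary value for illustration).
w = torch.tensor(3.0, requires_grad=True)
# The optimizer receives the parameters it will update at construction time.
optimizer = torch.optim.SGD([w], lr=0.1)
print(w.grad)  # None: no backward pass has happened yet

loss = w * w          # d(loss)/dw = 2 * w = 6.0
loss.backward()       # populates w.grad
print(w.grad.item())  # 6.0

# A second backward pass accumulates into .grad instead of overwriting it.
(w * w).backward()
print(w.grad.item())  # 12.0

# Reset (to zeros in PyTorch 1.x, to None by default in newer versions),
# then compute fresh gradients and take one SGD step.
optimizer.zero_grad()
(w * w).backward()
print(w.grad.item())  # 6.0
optimizer.step()      # w <- w - lr * grad = 3.0 - 0.1 * 6.0
print(w.item())       # approx. 2.4 (float32)
```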

In TF 2.x (eager mode), gradients are not stored on the variables. A tf.GradientTape object returns them as separate tensors. An optimizer can then apply these gradients to the variables. Unlike in PyTorch, a variable is not associated with the optimizer in advance; it is handed to the optimizer together with its gradient, e.g. via optimizer.apply_gradients(zip(grads, variables)).
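The TF 2.x eager counterpart is sketched below. It assumes a TensorFlow 2.x installation with `tf.keras.optimizers.SGD`, and uses an arbitrary scalar variable and the toy loss w².

```python
import tensorflow as tf

# The variable is created independently of any optimizer.
w = tf.Variable(3.0)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# The tape records operations so it can differentiate them afterwards.
with tf.GradientTape() as tape:
    loss = w * w  # d(loss)/dw = 2 * w = 6.0

# Gradients come back as separate tensors; nothing is stored on `w` itself.
grads = tape.gradient(loss, [w])
print(float(grads[0]))  # 6.0

# The optimizer only meets the variable here, via (gradient, variable) pairs.
optimizer.apply_gradients(zip(grads, [w]))
print(float(w.numpy()))  # approx. 2.4 (float32)
```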

As you can see, the frameworks calculate gradients in quite different ways (and we didn’t even look at TF graph mode!). Fortunately, they are similar enough that RLlib can bridge the remaining gaps and seamlessly support both PyTorch and TensorFlow algorithms through a common functional policy API (described in a separate blog post).

Of Models and Modules

RLlib allows users to plug in custom neural network models for the policy. Hence, an important question is how each framework lets us assemble even the craziest models and, after we have made lots of mistakes doing so, how it lets us debug them. Both tf.keras’s “Models” and Torch’s “nn.Modules” APIs are mature, well thought out, and easy to use. The main difference is that Torch modules such as nn.Linear need the input shape at construction time. Keras layers instead have a build() method that lets you defer this until the input shape is known. This leads to slightly different-looking code:

PyTorch
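A minimal PyTorch sketch of a two-layer model follows. It does not use RLlib's custom-model API: the class name `TwoLayerModel`, the `obs_dim`/`num_outputs` parameters, and the layer sizes are hypothetical choices for illustration. Note that the input size is a required constructor argument.

```python
import torch
import torch.nn as nn


class TwoLayerModel(nn.Module):
    """Hypothetical example model (not RLlib's model API)."""

    def __init__(self, obs_dim, num_outputs, hidden=256):
        super().__init__()
        # nn.Linear needs its input size up front, so obs_dim is required here.
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_outputs),
        )

    def forward(self, obs):
        return self.net(obs)


model = TwoLayerModel(obs_dim=4, num_outputs=2)
out = model(torch.zeros(8, 4))  # a batch of 8 observations of size 4
print(out.shape)  # torch.Size([8, 2])
```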

TensorFlow
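A minimal Keras sketch of a two-layer model follows. It does not use RLlib's custom-model API: the class name `TwoLayerModel`, the `obs_dim` attribute, and the layer sizes are hypothetical choices for illustration. Here the input size is never passed in; `build()` receives it on the first call.

```python
import tensorflow as tf


class TwoLayerModel(tf.keras.layers.Layer):
    """Hypothetical example model (not RLlib's model API)."""

    def __init__(self, num_outputs, hidden=256):
        super().__init__()
        # No input size needed: Dense layers create their weights lazily.
        self.hidden_layer = tf.keras.layers.Dense(hidden, activation="relu")
        self.out_layer = tf.keras.layers.Dense(num_outputs)

    def build(self, input_shape):
        # Called automatically on the first call, once the input shape is known.
        self.obs_dim = int(input_shape[-1])
        super().build(input_shape)

    def call(self, obs):
        return self.out_layer(self.hidden_layer(obs))


model = TwoLayerModel(num_outputs=2)
out = model(tf.zeros([8, 4]))  # a batch of 8 observations of size 4
print(out.shape, model.obs_dim)
```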

Again though, the concepts are similar enough that RLlib can offer a unified custom model interface for both frameworks.

How to spend three days on changing one character?

A fun roadblock came up when translating the SAC algorithm from TF to Torch. The Torch version was not learning anything, and I was trying to find out why. One of the algorithm’s loss terms involves sampling from a (“squashed”) normal distribution, and back-propagating through a sampling step is tricky. Consider how TensorFlow and Torch each handle a gradient that flows through a sampled value.
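The Torch side can be sketched with `torch.distributions.Normal`. The toy loss `2 * x + loc` and the `grad_of_loc` helper are hypothetical choices for illustration. The extra `+ loc` term mimics a real loss that has other paths to the parameters: backward() succeeds either way, so a missing gradient through the sample goes unnoticed.

```python
import torch
from torch.distributions import Normal


def grad_of_loc(use_rsample):
    """Hypothetical helper: d(loss)/d(loc) for loss = 2 * x + loc,
    where x is drawn from Normal(loc, 1)."""
    loc = torch.tensor(0.0, requires_grad=True)
    dist = Normal(loc, torch.tensor(1.0))
    # sample() returns a detached value; rsample() uses the reparameterization
    # trick (x = loc + scale * eps, eps ~ N(0, 1)), so gradients flow to loc.
    x = dist.rsample() if use_rsample else dist.sample()
    loss = 2 * x + loc
    loss.backward()
    return loc.grad.item()


# sample(): x is a constant, so only the "+ loc" term contributes.
print(grad_of_loc(use_rsample=False))  # 1.0
# rsample(): the sample adds d(2x)/d(loc) = 2 on top of the "+ loc" term.
print(grad_of_loc(use_rsample=True))  # 3.0
```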

TensorFlow automatically applies the reparameterization trick to differentiate through a sampling step, provided this is mathematically possible. For a loss of 2 times the sampled value, this yields a constant gradient of 2.0 with respect to the distribution’s mean. Note that for some distributions (e.g. Categorical), reparameterization simply does not work. Torch, on the other hand, treats the result of sample() as a constant (no gradients). To get gradients through the sample in your loss, use the rsample() method instead.

I didn’t know that Torch is this explicit here. As a result, the gradient of one of my tensors silently stayed None, and it took me several days to figure out that this was why the agent was not learning.

TF vs Torch

So which framework feels more intuitive, more up to the task, and more developer-friendly? I have to admit that the ease of use PyTorch naturally offers is striking. No longer having to worry about questions like “am I currently in eager mode or not?” or “why did this op end up in a different graph than the rest of my model?” is quite liberating and a clear win for PyTorch.

On the other hand, TensorFlow does seem to be the more mature framework. Operations like `tf.sequence_mask` or `tf.atanh` (needed in some RL loss functions) had no direct counterparts in PyTorch at the time of writing, which made translating code unnecessarily hard. However, as a library author, I found TensorFlow’s different handling of eager vs. graph execution, especially across versions, quite frustrating to deal with.
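If you hit the same gaps, here is a sketch of possible PyTorch stand-ins. `sequence_mask` and `atanh` below are hypothetical helpers written for this example, not RLlib’s own utilities. Recent PyTorch releases ship `torch.atanh`, so the second helper is only needed on older versions.

```python
import torch


def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Hypothetical stand-in for tf.sequence_mask:
    mask[i, j] is True iff j < lengths[i]."""
    if maxlen is None:
        maxlen = int(lengths.max())
    positions = torch.arange(maxlen, device=lengths.device)
    # Broadcast (1, maxlen) against (batch, 1) to get a (batch, maxlen) mask.
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return mask.to(dtype)


def atanh(x):
    """Inverse hyperbolic tangent, atanh(x) = 0.5 * (log1p(x) - log1p(-x)),
    valid for -1 < x < 1."""
    return 0.5 * (torch.log1p(x) - torch.log1p(-x))


# Rows of length 1, 3, and 2, padded to the longest length (3).
print(sequence_mask(torch.tensor([1, 3, 2])))
print(atanh(torch.tensor([0.0, 0.5])))  # values approx. 0.0 and 0.5493
```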

Conclusion

To summarize, RLlib 1.0 lets users switch seamlessly between TensorFlow and PyTorch for their reinforcement learning work. In this blog post, we shared what we learned about the key differences between the two frameworks from the perspective of an RL library.

If you’re interested in trying out RLlib or getting involved, you can check out our documentation, join the #rllib Slack channel, or see how RLlib is being used in industry by joining us at Ray Summit.
