Model-based Reinforcement Learning with Ray RLlib
Authors: Michael Luo, Sven Mika
TL;DR: So far, RLlib has supported model-free reinforcement learning, evolutionary, and planning algorithms. In this blog post, we describe how we extended RLlib to a new class of algorithms: model-based RL (MBRL). In this post, you will learn how MBRL works, how to run MBRL experiments against various environments, and how we used RLlib’s new distributed execution API to develop these new algorithms.
Specifically, we will focus on the following new MBRL algorithms in RLlib:
- MB-MPO (Model-based Meta Policy Optimization): a Dyna-style MBRL algorithm that learns from fake data generated by an ensemble of dynamics models.
- Dreamer: a gradient-based MBRL algorithm that learns by imagining trajectories into the future.
In addition, for each of these algorithms, we present benchmark results and show how the new distributed execution API expresses their complicated data flows as simple, concise execution plans.
Why Model-based Reinforcement Learning?
Training RL agents normally requires millions, if not billions, of environment samples to reach human-level performance. For example, OpenAI Five simulated 900 years of in-game experience using the PPO algorithm before the agent was able to beat the world’s best human players. Doing this in a real-life setting (without a simulator) is unrealistic: a robot, for example, may suffer wear and tear or cause accidents long before it learns proper behaviors. RL agents therefore need to be more sample-efficient. This is where model-based RL comes in: instead of training the agent’s policy network directly on actual environment samples, we use these samples only to train a separate model, in supervised fashion, that predicts the environment’s behavior. This “transition dynamics model” then generates (fake) samples for learning the agent’s policy.
We’ll show how we developed (using RLlib’s new distributed execution API) and then benchmarked two new model-based agents in RLlib: a Dyna-style algorithm, “Model-based Meta Policy Optimization” (MB-MPO), and a gradient-based algorithm, “Dreamer”. Our results are comparable to those reported in the original papers and can be found in our RL Experiments repo. We ran experiments on MuJoCo and the DeepMind Control Suite, both of which are continuous-control environments.
Installing Ray/RLlib
RLlib is part of Ray, a Python library for distributed computing. It is available as a PyPI package and can be installed like this:
pip install "ray[rllib]"
Also, you must pip-install either TensorFlow or PyTorch along with the above. Other ways to install Ray are listed here.
MB-MPO: A Dyna-Style MBRL Algorithm
Figure 1: The general schema for Dyna-style MBRL algorithms. There are three components: 1) the agent, which acts in the real environment and trains on model-generated data, 2) the environment, which sends real samples to the model, and 3) the model, which learns from real environment data and generates fake data for the agent.
Dyna-style algorithms, such as MB-MPO, alternate between sampling from the real environment and sampling from a proxy environment, which is usually represented as a function approximator for the real environment, as illustrated in Figure 1.
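To make this alternation concrete, here is a toy sketch, assuming a 1-D linear system as the "real" environment and a least-squares linear model as the proxy environment. All names are hypothetical and this is not RLlib code; the point is only that real transitions are used to fit the model, while the agent's training data comes from rollouts in the fitted model (the policy update itself is omitted).

```python
import numpy as np

rng = np.random.default_rng(0)
TRUE_A, TRUE_B = 0.9, 0.5  # hidden dynamics of the toy "real" environment


def real_step(x, u):
    """Toy real environment: x' = a*x + b*u plus a little noise."""
    return TRUE_A * x + TRUE_B * u + 0.01 * rng.standard_normal(np.shape(x))


def fit_model(xs, us, next_xs):
    """Supervised step: least-squares fit of x' ~ a*x + b*u from real samples."""
    features = np.stack([xs, us], axis=1)
    (a, b), *_ = np.linalg.lstsq(features, next_xs, rcond=None)
    return a, b


def fake_rollouts(model, policy_gain, num_starts=64, horizon=10):
    """Proxy environment: roll the learned model forward to generate fake transitions."""
    a, b = model
    x = rng.uniform(-1.0, 1.0, size=num_starts)
    transitions = []
    for _ in range(horizon):
        u = policy_gain * x              # simple linear policy u = k * x
        next_x = a * x + b * u           # predicted by the model, not the real env
        transitions.append((x, u, next_x))
        x = next_x
    return transitions


# One Dyna-style iteration: real data -> fitted model -> fake data for the agent.
xs = rng.uniform(-1.0, 1.0, size=200)
us = rng.uniform(-1.0, 1.0, size=200)
model = fit_model(xs, us, real_step(xs, us))
fake_data = fake_rollouts(model, policy_gain=-0.5)
```

In MB-MPO, the single model is replaced by an ensemble, and the agent update on the fake data is a MAML-style meta-update.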
An example of a Dyna-style algorithm is Model-Based Meta Policy Optimization (MB-MPO).
Figure 2: MB-MPO Architecture
Conceptual Overview
MB-MPO uses an ensemble of transition dynamics (TD) models to learn from actual environment samples. In essence, MB-MPO is a meta-learning algorithm that treats each TD-model (and its emulated environment) as a different task. The goal of MB-MPO is to meta-learn a policy that can perform and adapt well across the ensemble.
In principle, MB-MPO is stacked on top of MAML, a gradient-based meta-learning algorithm, and uses the same loss function. Each Ray worker is treated as a single task and always uses the same dynamics model from the ensemble to generate its fake data.
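For reference, the meta-objective can be written in the form used in the MB-MPO paper, where $K$ is the number of models in the ensemble, $J_k$ is the expected return of a policy in the environment emulated by model $k$, $\theta$ are the meta-policy parameters, and $\alpha$ is the inner-adaptation step size (these symbols are introduced here, not taken from the post):

$$
\max_{\theta} \; \frac{1}{K} \sum_{k=1}^{K} J_k\left(\theta'_k\right)
\quad \text{with} \quad
\theta'_k = \theta + \alpha \, \nabla_{\theta} J_k(\theta)
$$

The inner step $\theta \to \theta'_k$ is what each worker performs on its own model's fake data; the outer maximization over $\theta$ is the meta-update performed by the driver.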
As such, the master worker (driver):
- Collects samples from the real environment, appending samples to a replay buffer.
- Trains the TD ensemble jointly (using above replay buffer).
- Aggregates fake data from workers to perform the meta-update step in the MAML computation graph, training the meta-policy network.
While the workers:
- Collect pre-adaptation fake samples from their dynamics model.
- Perform inner adaptation on the pre-adaptation samples.
- Collect post-adaptation fake samples from their dynamics model.
Algorithm: Distributed Execution Plan
RLlib’s recently introduced distributed execution API makes it possible to express a wide range of RL algorithms, including MB-MPO. This matters because meta-learning and model-based algorithms have more complex data flows than typical model-free RL algorithms. The API is flexible enough to let us write an algorithm at a high level and still have Ray execute it in a distributed fashion.
Figure 3: MB-MPO Algorithm
Each iteration of the MB-MPO algorithm spans the interval between two trainings of the dynamics ensemble. Within each iteration, the agent steps through several MAML iterations. Adding further complexity to the data flow, each MAML iteration samples data both before and after adaptation (pre- and post-adaptation).
Despite the complexity of MB-MPO’s algorithm above, its execution plan can be expressed concisely in four lines of code.
To model MB-MPO’s algorithm, which encompasses several iterations of MAML, we split it into two generators: a method and a class.
The first generator method, inner_adaptation_steps, collects fake samples from the workers and performs inner adaptation on the workers. All samples collected so far are stored in a buffer, which is released once the post-adaptation samples have been collected. This method is implicitly the data collection step for MAML.
The transform method acts as a queue for the data to be aggregated and returns the accumulated pre- and post-adaptation data upon completion.
Where the magic happens is the second generator class, MetaUpdate, which takes the pre- and post-adaptation samples from inner_adaptation_steps and applies the meta-update to the policy’s computation graph. MetaUpdate performs MAML’s meta-update step and keeps track of the number of MAML iterations.
The combine method waits for MetaUpdate to return (after N MAML iterations) and then returns logging statistics.
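The following plain-Python sketch mirrors the structure just described using ordinary generators and map. It is not RLlib's code: inner_adaptation_steps, accumulate_until_post (standing in for the transform step), MetaUpdateSketch and execution_plan are hypothetical stand-ins, the callables passed in replace the real workers and policy, and retraining of the dynamics ensemble between outer iterations is left out.

```python
from typing import Callable, Iterator, List, Optional, Tuple

Batch = dict  # hypothetical placeholder for a batch of fake (model-generated) samples


def inner_adaptation_steps(sample_fn: Callable[[], Batch],
                           adapt_fn: Callable[[Batch], None],
                           num_inner_steps: int) -> Iterator[Tuple[str, Batch]]:
    """Endlessly yield MAML data: pre-adaptation batches, then one post-adaptation batch."""
    while True:
        for _ in range(num_inner_steps):
            batch = sample_fn()        # pre-adaptation fake samples
            yield "pre", batch
            adapt_fn(batch)            # inner adaptation on the workers
        yield "post", sample_fn()      # post-adaptation fake samples


def accumulate_until_post(items: Iterator[Tuple[str, Batch]]) -> Iterator[List[Batch]]:
    """Stand-in for the transform step: queue batches, release them once post data arrives."""
    buffer: List[Batch] = []
    for kind, batch in items:
        buffer.append(batch)
        if kind == "post":
            yield buffer
            buffer = []


class MetaUpdateSketch:
    """Applies the meta-update; returns stats only after maml_iters MAML iterations."""

    def __init__(self, meta_update_fn: Callable[[List[Batch]], float], maml_iters: int):
        self.meta_update_fn = meta_update_fn
        self.maml_iters = maml_iters
        self.count = 0
        self.losses: List[float] = []

    def __call__(self, samples: List[Batch]) -> Optional[dict]:
        self.losses.append(self.meta_update_fn(samples))  # one meta-update step
        self.count += 1
        if self.count < self.maml_iters:
            return None
        stats = {"maml_iters": self.count,
                 "mean_meta_loss": sum(self.losses) / len(self.losses)}
        self.count, self.losses = 0, []
        return stats


def execution_plan(sample_fn, adapt_fn, meta_update_fn, num_inner_steps, maml_iters):
    """Compose the pieces lazily; each item of the result is one outer iteration's stats."""
    maml_batches = accumulate_until_post(
        inner_adaptation_steps(sample_fn, adapt_fn, num_inner_steps))
    updates = map(MetaUpdateSketch(meta_update_fn, maml_iters), maml_batches)
    # Stand-in for combine: skip the None placeholders, keep the logging stats.
    return (stats for stats in updates if stats is not None)
```

Because everything is lazy, pulling one item from the plan drives exactly maml_iters rounds of sampling, inner adaptation and meta-updates, which is the same pull-based style the execution API uses.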
Results
We evaluate RLlib’s MB-MPO against the original paper’s implementation on the MuJoCo environments HalfCheetah and Hopper, using an episode horizon of 200 timesteps and training for 100k timesteps. All experiments below were run on a single machine with a Titan X GPU and 32 CPU cores.
Table 1: MuJoCo Benchmarks for MB-MPO
TensorBoard logs are shown in Figure 4. Additional statistics are logged as well: MAMLIter$i$_DynaTrajInner$j$ corresponds to the agent’s performance across all dynamics models during MAML iteration $i$ and inner adaptation step $j$.
Figure 4: MB-MPO Performance for MuJoCo
Running MB-MPO with RLlib
MB-MPO currently supports most MuJoCo environments. We provide a sample command for the reader to try out:
rllib train -f tuned_examples/mbmpo/halfcheetah-mbmpo.yaml
To run MB-MPO with your own environment, we provide two examples here and here.
An important limitation of the algorithm is that the environment class must define an actual reward function in Python, def reward(self, obs, action, next_obs).
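As an illustration of that requirement, a hypothetical environment might expose its reward as a plain function of (possibly batched) transitions, so that the same function can also score the fake transitions produced by the dynamics models. The class, its observation layout and the reward terms below are assumptions made for this sketch; only the reward(self, obs, action, next_obs) signature comes from the text above, and a real environment would additionally implement the usual reset/step interface.

```python
import numpy as np


class PointMassEnv:
    """Hypothetical 2-D point-mass task: obs = (x, y), action = (dx, dy)."""

    def __init__(self, ctrl_cost=0.1):
        self.ctrl_cost = ctrl_cost

    def reward(self, obs, action, next_obs):
        """Reward for a batch of transitions (leading dim = batch) or a single one.

        Uses numpy only, so it can score model-generated next observations
        as well as real ones. obs is unused by this particular reward.
        """
        next_obs = np.asarray(next_obs, dtype=np.float64)
        action = np.asarray(action, dtype=np.float64)
        # Reward getting close to the origin, penalize large actions.
        distance_cost = np.linalg.norm(next_obs, axis=-1)
        control_cost = self.ctrl_cost * np.sum(np.square(action), axis=-1)
        return -distance_cost - control_cost
```

Writing the reward with array operations over the last axis is a design choice: it lets one call score a whole batch of imagined transitions instead of looping over them in Python.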
Our next algorithm, Dreamer, does not have this limitation on the environment. Let’s take a look.
Dreamer: Gradient-Based MBRL in Latent Spaces
Figure 5: Gradient-based MBRL algorithms learn by rolling out trajectories through the dynamics model and policy. This allows for direct differentiation of the RL objective (maximizing rewards).
Like Dyna-style algorithms, gradient-based MBRL algorithms collect data from the real environment to train a transition dynamics (TD) model. The key difference is that Dyna-style algorithms treat each observation as an independent input, whereas gradient-based methods compute (fake) trajectories from each observation. This makes the RL objective directly differentiable and removes the need to approximate its gradient with policy gradients.
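A toy example of what "direct differentiation" means, assuming a known 1-D linear model $x_{t+1} = a x_t + b u_t$, a linear policy $u_t = k x_t$ and a per-step reward $-x_{t+1}^2$ (all chosen purely for illustration). Because the imagined trajectory is a differentiable function of the policy parameter $k$, the gradient of the return can be computed exactly by propagating $dx_t/dk$ through the rollout. This sketch uses forward-mode differentiation by hand; Dreamer instead backpropagates through neural networks, but in both cases no policy-gradient estimator is involved.

```python
def imagined_return_and_grad(k, a=0.9, b=0.5, x0=1.0, horizon=10):
    """Roll out the model with policy u = k*x; return J(k) and dJ/dk."""
    x, dx_dk = x0, 0.0  # the start state does not depend on k
    ret, dret_dk = 0.0, 0.0
    for _ in range(horizon):
        u = k * x
        du_dk = x + k * dx_dk              # product rule on u = k * x
        x_next = a * x + b * u             # differentiable model step
        dx_next_dk = a * dx_dk + b * du_dk
        ret += -x_next ** 2                # reward r_t = -x_{t+1}^2
        dret_dk += -2.0 * x_next * dx_next_dk
        x, dx_dk = x_next, dx_next_dk
    return ret, dret_dk


def finite_difference(k, eps=1e-6, **kw):
    """Central-difference check of the exact gradient."""
    plus, _ = imagined_return_and_grad(k + eps, **kw)
    minus, _ = imagined_return_and_grad(k - eps, **kw)
    return (plus - minus) / (2 * eps)


J, dJ = imagined_return_and_grad(-0.5)
```

The exact gradient matches a finite-difference estimate, which is the property gradient-based MBRL relies on when it optimizes the policy through the learned model.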
Figure 6: Dreamer (fake) trajectories visualized. The top row shows rollouts in the real environments (left: Walker, right: HalfCheetah), the middle row shows the TD model’s predictions, and the bottom row shows the difference between the two. These images are logged automatically to TensorBoard (as GIFs).
Conceptual Overview
Dreamer is an example of a gradient-based RL algorithm. It can solve long-horizon tasks from image (or other) observations via training the actor and critic in a latent space, whose features are learnt automatically by an encoder network.
Figure x: Learning Dynamics (Left), Learning Actor and Critic (Center), Interacting with Environment (Right)
Dynamics Model: Dreamer uses PlaNET, a partially stochastic recurrent dynamics model. Like any dynamics model, PlaNET takes in the image $o_t$ and the previous action $a_{t-1}$, and outputs the latent state $s_t$ and the reward $r_t$.
PlaNET consists of four components:
- Encoder: converts the image $o_t$ into the latent state $s_t$; represents $p(s_t \mid o_t)$.
- Decoder: converts the latent state $s_t$ back into the image $o_t$; represents $p(o_t \mid s_t)$.
- Transition (prior): takes in the previous state $s_{t-1}$ and action $a_{t-1}$ and predicts the next state $s_t$; represents $p(s_t \mid s_{t-1}, a_{t-1})$.
- Representation (posterior): takes in the previous state $s_{t-1}$, the action $a_{t-1}$, and also the current image $o_t$ to predict the next state $s_t$; represents $p(s_t \mid s_{t-1}, a_{t-1}, o_t)$.
During evaluation, the dynamics model encodes the image from the environment into a latent state, which is passed to the actor to compute actions. During training, PlaNET optimizes the following variational lower bound:
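One common way to write this bound, using the component distributions listed above plus a reward predictor $p(r_t \mid s_t)$ and a KL weight $\beta$ (both introduced here, following the reconstruction objective in the Dreamer paper), is:

$$
\mathcal{J}_{\text{model}} = \mathbb{E}\left[ \sum_{t} \Big( \ln p(o_t \mid s_t) + \ln p(r_t \mid s_t) - \beta \, \mathrm{KL}\big( p(s_t \mid s_{t-1}, a_{t-1}, o_t) \,\big\|\, p(s_t \mid s_{t-1}, a_{t-1}) \big) \Big) \right]
$$

Here the expectation is over latent states sampled from the representation (posterior) model. The first two terms train the decoder and reward predictor to reconstruct what was observed; the KL term pulls the transition (prior) model toward the posterior, so that the prior alone can later be used to imagine trajectories without seeing images.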
Actor: The actor takes in latent states generated by the PlaNET dynamics model. For each latent state, Dreamer learns by imagining trajectories of horizon $h$, rolling out the dynamics model and the actor together. Because the whole imagined trajectory is produced by differentiable neural networks, the actor’s loss is obtained by directly differentiating the value target!
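Written out, with $V_\lambda(s_\tau)$ the value target for imagined state $s_\tau$ and $\phi$ the actor's parameters (symbols introduced here; the form follows the Dreamer paper), the actor solves:

$$
\max_{\phi} \; \mathbb{E}\left[ \sum_{\tau=t}^{t+h} V_\lambda(s_\tau) \right]
$$

Gradients of this sum flow back through the value targets, the predicted rewards and the transition model into the actions produced by the actor, which is possible only because every step of the imagined rollout is differentiable.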
We omit the derivation of the value target, as it is a variant of Generalized Advantage Estimation.
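Although the derivation is omitted, the targets themselves are cheap to compute. The sketch below assumes the standard backward recursion for truncated λ-returns, $V_\lambda(s_\tau) = r_\tau + \gamma\big((1-\lambda)\,v(s_{\tau+1}) + \lambda\,V_\lambda(s_{\tau+1})\big)$, bootstrapped with the critic's value $v$ at the end of the imagined horizon. The function name, argument layout and default hyperparameters are illustrative assumptions, not RLlib's implementation.

```python
import numpy as np


def lambda_returns(rewards, values, bootstrap, gamma=0.99, lam=0.95):
    """Compute lambda-return targets for one imagined trajectory.

    rewards[t]: predicted reward at imagined step t (t = 0..H-1)
    values[t]:  critic estimate v(s_t) at imagined step t
    bootstrap:  critic estimate v(s_H) for the state after the last step
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    horizon = len(rewards)
    # v(s_{t+1}) for every step, using the bootstrap value after the last step.
    next_values = np.append(values[1:], bootstrap)
    returns = np.empty(horizon)
    last = bootstrap  # V_lambda(s_H) := v(s_H)
    for t in reversed(range(horizon)):
        # Blend the one-step bootstrap with the longer lambda-return.
        last = rewards[t] + gamma * ((1.0 - lam) * next_values[t] + lam * last)
        returns[t] = last
    return returns
```

Setting lam=0 gives one-step TD targets and lam=1 gives discounted Monte Carlo returns bootstrapped at the horizon; intermediate values trade bias against variance, much as λ does in Generalized Advantage Estimation.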
Critic: The critic takes the imagined states from the actor and computes a simple L2 loss:
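With $v_\psi$ denoting the critic with parameters $\psi$ (symbols introduced here) and the same value targets $V_\lambda$, the L2 loss takes the form used in the Dreamer paper:

$$
\min_{\psi} \; \mathbb{E}\left[ \sum_{\tau=t}^{t+h} \frac{1}{2} \big\lVert v_\psi(s_\tau) - V_\lambda(s_\tau) \big\rVert^2 \right]
$$

The targets are treated as constants (gradients are stopped around them), so the critic simply regresses toward the imagined returns.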
Algorithm: Distributed Execution Plan
Figure 7: Dreamer Algorithm
At first glance, Dreamer’s algorithm looks complicated. However, the dynamics learning and behavior learning in Figure xx can be moved inside the TensorFlow or Torch policy and loss graph. As a result, the execution pseudocode effectively reduces to:
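A minimal sketch of that reduced loop, assuming hypothetical collect_episode and train_on_batch callables (standing in for the rollout worker and for the policy's combined dynamics, actor and critic update) and a simple episode buffer that samples fixed-length chunks. None of these names are RLlib APIs, and details such as prefilling the buffer with random episodes are left out.

```python
import random
from typing import Callable, List, Sequence


class EpisodeBuffer:
    """Stores whole episodes and samples fixed-length chunks for training."""

    def __init__(self, max_episodes: int = 1000, seed: int = 0):
        self.episodes: List[Sequence] = []
        self.max_episodes = max_episodes
        self.rng = random.Random(seed)

    def add(self, episode: Sequence) -> None:
        self.episodes.append(episode)
        if len(self.episodes) > self.max_episodes:
            self.episodes.pop(0)  # drop the oldest episode

    def sample(self, batch_size: int, chunk_len: int) -> List[Sequence]:
        """Sample chunks of chunk_len consecutive steps (episodes must be long enough)."""
        batch = []
        for _ in range(batch_size):
            ep = self.rng.choice(self.episodes)
            start = self.rng.randint(0, len(ep) - chunk_len)
            batch.append(ep[start:start + chunk_len])
        return batch


def dreamer_loop(collect_episode: Callable[[], Sequence],
                 train_on_batch: Callable[[List[Sequence]], float],
                 num_iterations: int, train_iters: int,
                 batch_size: int = 4, chunk_len: int = 5) -> List[float]:
    buffer = EpisodeBuffer()
    losses = []
    for _ in range(num_iterations):
        buffer.add(collect_episode())  # roll out one episode in the real env
        for _ in range(train_iters):   # then train the agent n times
            losses.append(train_on_batch(buffer.sample(batch_size, chunk_len)))
    return losses
```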
In short, the agent alternates between rolling out an episode and training for n iterations. Translating this algorithm into a distributed execution plan is simple. Dreamer’s execution plan handles episode collection automatically in ParallelRollouts, an iterator that yields the most recently collected episode. The replay buffer and agent training are handled in DreamerIteration.
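The same loop can be phrased in the iterator style of an execution plan: a stream of episodes (playing the role of ParallelRollouts) is mapped through a stateful callable (playing the role of DreamerIteration) that owns the replay buffer. The sketch below uses only plain Python; episode_stream and DreamerIterationSketch are hypothetical stand-ins, not RLlib's classes or signatures.

```python
import itertools
import random
from typing import Callable, Iterator, List


def episode_stream(collect_episode: Callable[[], list]) -> Iterator[list]:
    """Endless iterator yielding the most recently collected episode."""
    while True:
        yield collect_episode()


class DreamerIterationSketch:
    """Adds each incoming episode to a buffer, then runs train_iters updates."""

    def __init__(self, train_fn: Callable[[List[list]], float], train_iters: int,
                 batch_size: int, chunk_len: int, seed: int = 0):
        self.train_fn = train_fn
        self.train_iters = train_iters
        self.batch_size = batch_size
        self.chunk_len = chunk_len
        self.buffer: List[list] = []
        self.rng = random.Random(seed)

    def _sample_chunk(self) -> list:
        ep = self.rng.choice(self.buffer)
        start = self.rng.randint(0, len(ep) - self.chunk_len)
        return ep[start:start + self.chunk_len]

    def __call__(self, episode: list) -> dict:
        self.buffer.append(episode)
        losses = [self.train_fn([self._sample_chunk() for _ in range(self.batch_size)])
                  for _ in range(self.train_iters)]
        return {"episodes_in_buffer": len(self.buffer),
                "mean_loss": sum(losses) / len(losses)}


# The "execution plan": episodes flow lazily through the training step.
plan = map(DreamerIterationSketch(lambda batch: float(len(batch)), train_iters=3,
                                  batch_size=2, chunk_len=4),
           episode_stream(lambda: list(range(10))))
first_three = list(itertools.islice(plan, 3))
```

Keeping the buffer inside the callable means the plan itself stays a one-line composition, while all stateful bookkeeping lives in one object.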
Results
Dreamer is evaluated against the original paper’s implementation on the DeepMind Control Suite, using a single seed and measuring performance after both 100k and 1M timesteps. The experiments below were run on a single machine with a Titan X GPU and 32 CPU cores.
Table 2: DeepMind Control Suite Benchmark on Walker-Walk and Cheetah-Run
TensorBoard logs are shown in Figure 8.
Figure 8: Dreamer Performance for DeepMind Control Suite
Running Dreamer with RLlib
Dreamer currently supports all DeepMind Control Suite environments. A sample command to try it out:
rllib train -f tuned_examples/dreamer/dreamer-deepmind-control.yaml
Conclusion
That’s it! In short, we introduced model-based RL algorithms to RLlib. We showed that RLlib’s framework generalizes across different types of MBRL algorithms, including Dyna-style (MB-MPO) and gradient-based (Dreamer) algorithms.
Of course, this is just a small subset of the features RLlib offers, which range from multi-agent training and curiosity-driven exploration (intrinsic motivation) to functional RL! To see more features, please have a look at our full list of examples.
If you’ve been successful in using RLlib, or if you need help understanding its API, please reach out through Discourse or Slack; we would love to hear from you. If you would like to know how RLlib is being used in industry, please consider attending Ray Summit.