MuZero: Model-based RL (Part 1)

Yuki Minai
8 min read · Jun 21, 2024


In previous posts, I introduced various Reinforcement Learning (RL) methods such as Q-learning, Deep Q-learning, and Actor-Critic. These methods do not try to model the dynamics of the environment. In other words, the agent never explicitly learns how the environment changes when it takes a certain action, what reward each state yields, or when an episode terminates.

RL methods that do not learn the environment's dynamics are called model-free RL. They are model-free because the agent does not internally build or use a model of the environment; the only goal is to learn the optimal policy through repeated interaction with the environment. This approach works well when the environment is fully defined and cheap to interact with. Board games such as chess are examples: we know exactly how the board changes with each move and the conditions for winning or losing, so the agent can generate as much experience as it needs.

However, in real-life scenarios, we often do not know the environment’s dynamics and must learn them through interaction. For example, a walking robot must figure out how its state changes when it moves parts of its body in certain ways. This kind of knowledge about the environment is called a “dynamics model” in RL. Methods that learn a dynamics model of the world through interaction with the environment are called model-based reinforcement learning.

In this series of posts, we will learn about one such method, MuZero, developed by DeepMind in 2020. Before MuZero, model-based algorithms generally did not match the success of model-free methods because:

  • Dynamics models are difficult to learn.
  • Errors in the dynamics model propagate over steps. If a dynamics model makes a large prediction error, there will be a significant gap between the actual and predicted agent states, complicating policy learning.
  • Environments with high-dimensional observations (e.g., images in Atari games) are difficult to model. Even small pixel changes produce different environment states, requiring the agent to learn numerous patterns.

MuZero is (as far as I know) the first algorithm to solve these problems. It learns the dynamics model through interaction and uses it to learn the optimal action to take (i.e., policy).

Here is our roadmap to learn MuZero: in this post, we look at how MuZero collects training data through self-play with Monte Carlo Tree Search; in the next post, we will cover the three deep learning models and how they are trained.

MuZero overview

Before diving into the main topic of this post (self-play to collect training data), let’s briefly discuss the overview of MuZero to get the big picture.

As mentioned earlier, MuZero learns the dynamics model of the environment to learn the optimal policy. Specifically, MuZero utilizes three deep learning models:

  • Representation function
  • Dynamics function
  • Prediction function

We will soon talk about what each model does, but I would like to note one important thing first: a key difference between MuZero and traditional model-based methods is the use of latent states. MuZero uses latent states derived from actual environment states to efficiently learn the dynamics and the optimal policy.

The representation function learns the mapping from the actual environment state to a latent state. The goal of this model is to extract a latent state that is helpful for accurately predicting policies, values, and rewards. In other words, this function extracts the key features of the environment relevant to policies, values, and rewards.

Consider an Atari game environment where the input is a grid of pixels. In each frame, many pixel values change, but changes in background pixels are less important than changes in pixels representing the main target or the player. Traditional methods, which use the full input space to learn the dynamics, must account for these irrelevant changes; the model has to learn the dynamics of the entire high-dimensional pixel image. With MuZero, the representation function extracts the key features of the environment state, and learning focuses on the dynamics of those features, making the process more efficient. The representation function receives the raw current state as input and returns the corresponding latent state.
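As a concrete illustration, here is a minimal sketch of what a representation function could look like for a low-dimensional observation such as CartPole's 4-dimensional state. The architecture and layer sizes are my own assumptions for illustration, not the networks used in the paper or in muzero-pytorch.

```python
import torch
import torch.nn as nn

class RepresentationNetwork(nn.Module):
    """Maps a raw observation to a latent state (hypothetical sizes for CartPole)."""
    def __init__(self, obs_dim=4, latent_dim=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(),
            nn.Linear(64, latent_dim),
        )

    def forward(self, observation):
        # observation: (batch, obs_dim) -> latent state: (batch, latent_dim)
        return self.net(observation)
```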

The dynamics function learns how the environment changes, but in terms of the latent state. It receives the current latent state and an action, and returns the predicted next latent state and the predicted immediate reward. Again, the key difference in MuZero is the use of a latent state that encodes the features needed to predict policies, values, and rewards. The predicted next latent state from the dynamics function has no semantic meaning beyond containing useful information for these predictions.
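Continuing the sketch above, a dynamics function could take the current latent state concatenated with a one-hot encoded action and output the predicted next latent state and reward. Again, this architecture is an illustrative assumption, not the paper's exact network.

```python
class DynamicsNetwork(nn.Module):
    """Predicts the next latent state and immediate reward from (latent state, action)."""
    def __init__(self, latent_dim=32, action_space_size=2):
        super().__init__()
        self.action_space_size = action_space_size
        self.trunk = nn.Sequential(
            nn.Linear(latent_dim + action_space_size, 64), nn.ReLU(),
        )
        self.next_state_head = nn.Linear(64, latent_dim)
        self.reward_head = nn.Linear(64, 1)

    def forward(self, latent_state, action):
        # action: (batch,) integer actions -> one-hot encoding
        action_onehot = torch.nn.functional.one_hot(
            action, num_classes=self.action_space_size).float()
        x = self.trunk(torch.cat([latent_state, action_onehot], dim=-1))
        return self.next_state_head(x), self.reward_head(x)
```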

The prediction function predicts the policy and value at a given latent state. It takes the latent state as input and outputs the predicted policy and value. After completing the model learning process, the predicted policy from this function is used when acting in the environment.
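The prediction function can follow the same pattern; a minimal sketch, again with assumed layer sizes:

```python
class PredictionNetwork(nn.Module):
    """Predicts the policy logits and value from a latent state."""
    def __init__(self, latent_dim=32, action_space_size=2):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU())
        self.policy_head = nn.Linear(64, action_space_size)
        self.value_head = nn.Linear(64, 1)

    def forward(self, latent_state):
        x = self.trunk(latent_state)
        return self.policy_head(x), self.value_head(x)
```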

In summary, these three functions work together as follows:

1. The representation function extracts a latent state from the current environment state, retaining the key information needed to predict the value, reward, and policy.

2. Given the current latent state, the dynamics function is used to run a mental simulation to predict what will happen in the future if the agent takes a certain action from the current state.

3. The prediction function predicts the value and the actions to take (policy) at each latent state encountered during the mental simulation. These predictions, in turn, guide which branches of the simulation with the dynamics function are explored.

Remember that when the environment dynamics are known, an agent can use them to run a mental simulation and see what would happen if it took a certain action. These three models make such mental simulation possible even when the environment dynamics are unknown. To fit the three models, we need to collect sample experiences. How should we do this? In this post, we focus on how to collect the training data used to fit them.
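To make the later code concrete, here is a rough sketch of how the three networks could be chained for one step of mental simulation. The NetworkOutput container and the initial_inference / recurrent_inference wrappers follow the structure of the MuZero paper's pseudocode, but their exact implementation is covered in the next post, so treat this as an assumed interface.

```python
import torch
from collections import namedtuple

# Container for the outputs used by MCTS (assumed interface)
NetworkOutput = namedtuple(
    "NetworkOutput", ["value", "reward", "policy_logits", "hidden_state"])

class MuZeroNetwork:
    def __init__(self, representation, dynamics, prediction):
        self.representation = representation
        self.dynamics = dynamics
        self.prediction = prediction

    def initial_inference(self, observation):
        # representation + prediction at the real observation (root of the search)
        hidden = self.representation(observation)
        policy_logits, value = self.prediction(hidden)
        return NetworkOutput(value, torch.zeros_like(value), policy_logits, hidden)

    def recurrent_inference(self, hidden_state, action):
        # dynamics + prediction: one step of "mental simulation" in latent space
        next_hidden, reward = self.dynamics(hidden_state, action)
        policy_logits, value = self.prediction(next_hidden)
        return NetworkOutput(value, reward, policy_logits, next_hidden)
```

With these wrappers, the search only ever touches latent states: initial_inference is called once at the root, and recurrent_inference is called for every simulated step.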

Monte Carlo Tree Search to collect training samples

MuZero uses Monte Carlo Tree Search (MCTS) to collect training samples through self-play. A general MCTS consists of four steps:

  1. Selection: Traverse the current tree from the root node by repeatedly picking children with the Upper Confidence Bound (UCB) score.
  2. Expansion: Add a child node to the leaf node reached by the selection process.
  3. Simulation: Perform a random simulation (rollout) from the expanded child node to a terminating state.
  4. Backpropagation: Update the value estimate of each node along the traversed path using the simulation result.

For a more detailed introduction to MCTS, check this page.

MuZero modifies MCTS to work with latent states. Another difference is that, instead of running a random simulation from the leaf node, MuZero bootstraps: the value predicted by the prediction function at the leaf node is used as the value estimate, similar to TD-learning.

Let’s see the implementation of this process step by step.

Prepare gym environment

In this article, we utilize the CartPole-v0 environment in Gymnasium.

Learn more about the CartPole-v0 environment

The goal of this environment is to balance a pole by applying forces in the left and right directions on the cart. It has a discrete action space:
- 0: Push cart to the left
- 1: Push cart to the right

Upon taking an action, either left or right, an agent observes a 4-dimensional state consisting of:
- Cart Position
- Cart Velocity
- Pole Angle
- Pole Angular Velocity

A reward of +1 is granted to the agent at each step while the pole is kept upright. The maximum reward an agent can earn in a single episode is 200.

The episode ends under the following conditions:
- Termination: Pole Angle is greater than ±12°
- Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
- Truncation: Episode length exceeds 200 steps

In the code below, I provide an example of the agent randomly exploring this environment over 20 time steps.
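A minimal version of that exploration code with Gymnasium could look like this (the episode may end before 20 steps, in which case we simply reset):

```python
import gymnasium as gym

env = gym.make("CartPole-v0")
obs, info = env.reset(seed=0)

for t in range(20):
    action = env.action_space.sample()  # random action: 0 (left) or 1 (right)
    obs, reward, terminated, truncated, info = env.step(action)
    print(f"step={t}, action={action}, obs={obs}, reward={reward}")
    if terminated or truncated:
        obs, info = env.reset()

env.close()
```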

Code overview of the training sample collection process

The code below shows an overview of the sample collection process for playing one game. MuZero runs this routine repeatedly to collect experience over many games. In this post, we will walk through the six components of this code.
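As a stand-in for the original notebook code, the sketch below outlines the structure of playing one game, loosely following the MuZero paper's pseudocode. The helpers (expand_root, add_exploration_noise, run_mcts, select_action) and the config fields are assumptions that the following subsections flesh out.

```python
def play_game(config, env, network):
    game = Game(env, discount=config.discount)   # (1) log of this self-play game

    while not game.terminal() and len(game.action_history) < config.max_moves:
        root = Node(prior=0)                     # (2) root node of the search tree
        min_max_stats = MinMaxStats()            # (3) value normalizer

        # (4) expand the root using the representation + prediction networks
        observation = game.current_observation()
        expand_root(root, game.legal_actions(), network.initial_inference(observation))
        add_exploration_noise(root)

        # (5) run several simulations of MCTS in latent space
        run_mcts(config, root, network, min_max_stats)

        # (6) pick the real action from the root's visit counts and step the env
        action = select_action(config, root, train=True)
        game.apply(action)
        game.store_search_statistics(root)

    return game
```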

(1) Create a Game Object to store game play log

The first step is to create a Game class object to store the logs of self-play. The data observed during self-play, such as the normalized visit counts of the root's children (i.e., visit probabilities) and the expected value of the root node, are stored and later used to train the deep learning networks.
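A minimal sketch of such a Game class is shown below; the attribute and method names are my own assumptions, chosen to match the play_game sketch above.

```python
import torch

class Game:
    """Stores the trajectory and the MCTS statistics of one self-play game."""
    def __init__(self, env, discount):
        self.env = env
        self.discount = discount
        self.observations = [env.reset()[0]]
        self.action_history, self.rewards = [], []
        self.root_visit_probs, self.root_values = [], []   # training targets
        self.done = False

    def terminal(self):
        return self.done

    def legal_actions(self):
        return list(range(self.env.action_space.n))

    def current_observation(self):
        return torch.as_tensor(self.observations[-1], dtype=torch.float32).unsqueeze(0)

    def apply(self, action):
        obs, reward, terminated, truncated, _ = self.env.step(action)
        self.observations.append(obs)
        self.action_history.append(action)
        self.rewards.append(reward)
        self.done = terminated or truncated

    def store_search_statistics(self, root):
        total_visits = sum(child.visit_count for child in root.children.values())
        self.root_visit_probs.append(
            [root.children[a].visit_count / total_visits if a in root.children else 0
             for a in self.legal_actions()])
        self.root_values.append(root.value())
```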

(2) Create a Node Object to store information about the root node

Next, we define a Node class to store information about each node during MCTS. This class object contains key attributes to represent each node, such as the total visit count, total value, and the hidden representation of the node state. Below is the implementation.
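A minimal sketch of the Node class, in the spirit of the paper's pseudocode:

```python
class Node:
    """One node of the MCTS tree, built over latent states."""
    def __init__(self, prior):
        self.prior = prior            # prior probability from the predicted policy
        self.visit_count = 0          # total number of visits during the search
        self.value_sum = 0.0          # sum of backed-up values
        self.reward = 0.0             # predicted immediate reward when entering this node
        self.hidden_state = None      # latent state produced by the networks
        self.children = {}            # action -> child Node

    def expanded(self):
        return len(self.children) > 0

    def value(self):
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0
```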

(3) Create a MinMaxStats Object to normalize values

We also create a MinMaxStats class object to normalize the observed values. This class tracks the maximum and minimum values observed during the search and uses them to rescale values into a range between 0 and 1.
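A minimal sketch of this helper:

```python
class MinMaxStats:
    """Tracks the min and max values seen in the search tree to normalize values to [0, 1]."""
    def __init__(self):
        self.minimum = float("inf")
        self.maximum = -float("inf")

    def update(self, value):
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
```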

(4) Expand the current root node

To start running MCTS, we first expand the root node. The expand_root function calls the initial_inference function, which uses the representation and prediction functions to obtain the latent representation, predicted policy, and predicted value of the current root node. We will learn the detailed implementation of the initial_inference function in the next post.

The obtained results are stored in the Node object representing the current root node. The predicted policy is used as the prior probability of choosing each child node, and MuZero mixes in Dirichlet noise to add some randomness to these priors. This randomness helps to explore different children during MCTS.
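Below is a minimal sketch of these two steps, assuming the NetworkOutput interface sketched earlier. The Dirichlet parameters shown are placeholders; the paper tunes them per environment.

```python
import math
import numpy as np

def expand_root(node, legal_actions, network_output):
    # Store the latent state and create one child per legal action,
    # with priors given by the softmax of the predicted policy logits.
    node.hidden_state = network_output.hidden_state
    logits = network_output.policy_logits.squeeze(0)
    policy = {a: math.exp(float(logits[a])) for a in legal_actions}
    total = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(prior=p / total)

def add_exploration_noise(node, dirichlet_alpha=0.25, exploration_fraction=0.25):
    # Mix Dirichlet noise into the root priors to encourage exploration.
    actions = list(node.children.keys())
    noise = np.random.dirichlet([dirichlet_alpha] * len(actions))
    for a, n in zip(actions, noise):
        child = node.children[a]
        child.prior = child.prior * (1 - exploration_fraction) + n * exploration_fraction
```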

(5) Run MCTS

The run_mcts function is the main function for running MCTS. It mainly repeats three steps (a minimal sketch of the function follows the list):

  1. Starting from the root node, descend the tree by selecting children based on the Upper Confidence Bound (UCB) score until reaching a node that has not been expanded yet (called a leaf node)
  2. Expand the leaf node by using the deep learning networks
  3. Perform backpropagation to update the search statistics of each node up to the root node
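Here is a minimal sketch of run_mcts under the same assumptions, using the helpers described in the following subsections (select_child, expand_node, backpropagate) and config fields such as num_simulations and action_space that are assumed for illustration:

```python
import torch

def run_mcts(config, root, network, min_max_stats):
    for _ in range(config.num_simulations):
        node = root
        search_path = [node]
        actions_taken = []

        # 1. Selection: descend until we hit a node that has not been expanded yet
        while node.expanded():
            action, node = select_child(config, node, min_max_stats)
            actions_taken.append(action)
            search_path.append(node)

        # 2. Expansion: run dynamics + prediction from the parent's latent state
        parent = search_path[-2]
        last_action = torch.tensor([actions_taken[-1]])
        output = network.recurrent_inference(parent.hidden_state, last_action)
        expand_node(node, config.action_space, output)

        # 3. Backpropagation: bootstrap with the predicted value at the leaf
        backpropagate(search_path, float(output.value), config.discount, min_max_stats)
```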

(5)-1. Select child

Until a leaf node is reached, the select_child function chooses the child of the current node with the maximum Upper Confidence Bound (UCB) score. For more detail on the UCB score, please refer to Appendix B of the MuZero paper.
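A minimal sketch of the UCB computation, following the pUCT rule in Appendix B of the paper; the constants pb_c_base and pb_c_init correspond to the paper's c2 and c1 and are assumed to live on the config object:

```python
import math

def ucb_score(config, parent, child, min_max_stats):
    # pUCT rule from Appendix B (the paper uses c1 = 1.25, c2 = 19652)
    pb_c = math.log((parent.visit_count + config.pb_c_base + 1) / config.pb_c_base)
    pb_c += config.pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

    prior_score = pb_c * child.prior
    if child.visit_count > 0:
        value_score = min_max_stats.normalize(
            child.reward + config.discount * child.value())
    else:
        value_score = 0.0
    return prior_score + value_score

def select_child(config, node, min_max_stats):
    # Pick the child with the highest UCB score
    best_action, best_child = max(
        node.children.items(),
        key=lambda item: ucb_score(config, node, item[1], min_max_stats))
    return best_action, best_child
```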

(5)-2. Expand node

After reaching the leaf node, the expand_node function expands it. It uses the recurrent_inference function, which leverages the dynamics and prediction functions to obtain the predicted reward, policy, value, and next latent state. These outputs are used to register the information about the leaf node and its children. We will learn the detailed implementation of the recurrent_inference function in the next post.
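A minimal sketch of expand_node; it mirrors the root expansion above but also registers the predicted reward of entering the node:

```python
import math

def expand_node(node, actions, network_output):
    # Register the leaf's latent state and predicted reward,
    # then create its children with priors from the predicted policy.
    node.hidden_state = network_output.hidden_state
    node.reward = float(network_output.reward)
    logits = network_output.policy_logits.squeeze(0)
    policy = {a: math.exp(float(logits[a])) for a in actions}
    total = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(prior=p / total)
```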

(5)-3. Backpropagation

After expanding the leaf node, MuZero runs backpropagation. In backward order, it updates the node statistics (total visit count and total value) of each node up to the root node. The discount factor is used to compute the discounted value along the path.
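A minimal sketch of the backup step under the same assumptions, accumulating discounted values from the leaf back to the root and updating the MinMaxStats on the way:

```python
def backpropagate(search_path, value, discount, min_max_stats):
    # Walk from the leaf back to the root, updating the statistics of every node on the path.
    for node in reversed(search_path):
        node.value_sum += value
        node.visit_count += 1
        min_max_stats.update(node.reward + discount * node.value())
        # The value seen by the parent includes this node's reward, discounted one step.
        value = node.reward + discount * value
```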

(6) Select and take an action from the root

After performing multiple simulations from the root node with the run_mcts function, MuZero chooses the actual action to take in the environment. The action is chosen based on the visit count of each child: during the network training phase, the action is sampled stochastically from the visit-count distribution, while during the test phase it is chosen deterministically as the most visited child (argmax).
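A minimal sketch of this action selection; the temperature value and its schedule are placeholders:

```python
import numpy as np

def select_action(config, root, train=True, temperature=1.0):
    actions = list(root.children.keys())
    visit_counts = np.array(
        [root.children[a].visit_count for a in actions], dtype=np.float64)

    if train:
        # Sample proportionally to visit counts raised to 1/temperature
        distribution = visit_counts ** (1.0 / temperature)
        distribution /= distribution.sum()
        return int(np.random.choice(actions, p=distribution))
    # At test time, act greedily with respect to visit counts
    return int(actions[np.argmax(visit_counts)])
```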

Summary

In this post, we learned how MuZero collects training samples with self-play using Monte Carlo Tree Search. In the next post, we will learn the details of deep learning models used in MuZero.

Codes

References

  • Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., … & Silver, D. (2020). Mastering Atari, Go, chess and shogi by planning with a learned model. Nature, 588(7839), 604–609.
  • muzero-pytorch by koulanurag (https://github.com/koulanurag/muzero-pytorch/tree/master)
