MuZero: Model-based RL (part 2)

Yuki Minai
Jun 21, 2024


This is a series of blog posts about MuZero, a popular model-based reinforcement learning algorithm.

In part 1, we covered an overview of MuZero, as well as Monte Carlo Tree Search (MCTS), which MuZero uses to collect training samples through self-play. In this post, we look at the deep learning models used in MuZero.

Review of three models

As we learned in part 1, MuZero uses three deep learning models to learn the dynamics of the environment as well as the optimal policy. They are:

  • Representation model: s⁰ = h_θ(oₜ)
    - Input: raw state of the current root
    - Output: latent state of the current root
  • Dynamics model: rᵏ, sᵏ = g_θ(sᵏ⁻¹, aᵏ)
    - Input: latent state, action to take
    - Output: expected immediate reward, next latent state
  • Prediction model: pᵏ, vᵏ = f_θ(sᵏ)
    - Input: latent state
    - Output: policy at the input latent state, expected value at the input latent state

where t indexes the real (past and current) time steps and k indexes the future steps that the model simulates starting from step t.

While the dynamics model and the prediction model in the original MuZero paper each produce two outputs, in this tutorial we split these outputs into separate models to stabilize the training process. More specifically, the code in this blog models each quantity with a separate network:

  • Representation model:
    - Input: raw state of the current root
    - Output: latent state of the current root
  • Dynamics model:
    - Input: latent state, action to take
    - Output: next latent state
  • Reward model:
    - Input: latent state, action to take
    - Output: expected immediate reward
  • Value model:
    - Input: latent state
    - Output: expected value at the input latent state
  • Policy model:
    - Input: latent state
    - Output: policy at the input latent state

Together, the dynamics model and the reward model play the role of the dynamics model g_θ in the original MuZero paper. Likewise, the value model and the policy model together play the role of the original prediction model f_θ.

MuZero trains all of these models jointly, so the loss function is defined as the sum of three errors:

  • Policy loss: the error between the predicted policy pᵏₜ and the search policy π_{t+k} obtained through MCTS
  • Value loss: the error between the predicted value vᵏₜ and the value target zₜ₊ₖ (for Atari in the MuZero paper, an n-step return computed from observed rewards and bootstrapped with the MCTS search value)
  • Reward loss: the error between the predicted immediate reward rᵏₜ and the observed immediate reward uₜ₊ₖ

MuZero sums these three losses over the unrolled steps k = 0, …, K and minimizes the total with gradient descent (via an optimizer), just as in typical deep learning training.
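To make the loss concrete, here is a minimal pure-Python sketch for one training position unrolled for K steps (no deep learning library, so no gradients). It assumes cross-entropy for the policy loss and squared error for the value and reward losses (the MuZero paper uses squared error for board games and a cross-entropy over a categorical support for Atari), skips the reward term at k = 0 because the root has no predicted reward, and leaves out the L2 weight regularization the paper also adds. The function muzero_loss and its argument layout are our own illustrative choices, not the tutorial's code.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def muzero_loss(
    policy_logits: Sequence[Sequence[float]],    # p_t^k as logits, k = 0..K
    values: Sequence[float],                     # v_t^k, k = 0..K
    rewards: Sequence[float],                    # r_t^k, k = 0..K (index 0 unused)
    target_policies: Sequence[Sequence[float]],  # MCTS search policies pi_{t+k}
    target_values: Sequence[float],              # value targets z_{t+k}
    target_rewards: Sequence[float],             # observed rewards u_{t+k}
) -> float:
    total = 0.0
    for k in range(len(values)):
        # Policy loss: cross-entropy between the search policy and the predicted policy.
        log_p = log_softmax(policy_logits[k])
        total += -sum(pi * lp for pi, lp in zip(target_policies[k], log_p))
        # Value loss: squared error against the value target.
        total += (values[k] - target_values[k]) ** 2
        # Reward loss: only for k >= 1, since the root has no predicted reward.
        if k >= 1:
            total += (rewards[k] - target_rewards[k]) ** 2
    return total


# One position unrolled for K = 1 step, with made-up numbers.
loss = muzero_loss(
    policy_logits=[[0.0, 0.0], [0.0, 0.0]],
    values=[1.0, 0.5],
    rewards=[0.0, 1.0],
    target_policies=[[0.5, 0.5], [1.0, 0.0]],
    target_values=[1.0, 1.0],
    target_rewards=[0.0, 1.0],
)
print(round(loss, 4))  # 1.6363
```

In a real implementation, these terms would be computed on tensors so that a single optimizer step updates all of the networks at once.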

Let’s review each model one by one.

As in part 1, we use the CartPole-v0 environment from Gymnasium. The environment has two possible actions (push the cart left or right), and each state is represented by a vector of four values (cart position, cart velocity, pole angle, and pole angular velocity).

Representation network

We first define a representation network. It receives the raw state of the current root node and returns its latent state, so the input size is the state size. In the architecture used in this tutorial, the input is first mapped to a hidden layer whose width is the hidden neuron size, and the hidden-layer outputs are then mapped to a vector whose length is the embedding size; this vector is the output latent state. The hidden neuron size and the embedding size are hyperparameters.
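A minimal sketch of such a network in plain Python is below (forward pass only, no training). The hidden size of 64, the embedding size of 8, and the tanh activation are assumptions for illustration rather than values from the tutorial's code, and the helper names init_linear and linear are hypothetical; in practice you would write this with a library such as PyTorch so the weights can be learned.

```python
import math
import random
from typing import List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]


def init_linear(in_size: int, out_size: int, rng: random.Random) -> Layer:
    """One fully connected layer: small random weights, zero biases."""
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(in_size)] for _ in range(out_size)]
    return weights, [0.0] * out_size


def linear(layer: Layer, x: Sequence[float]) -> List[float]:
    """Compute W x + b."""
    weights, biases = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]


class RepresentationNetwork:
    """s^0 = h(o_t): raw state -> hidden layer (tanh) -> latent state."""

    def __init__(self, state_size: int = 4, hidden_size: int = 64,
                 embedding_size: int = 8, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.hidden = init_linear(state_size, hidden_size, rng)
        self.out = init_linear(hidden_size, embedding_size, rng)

    def __call__(self, observation: Sequence[float]) -> List[float]:
        hidden = [math.tanh(v) for v in linear(self.hidden, observation)]
        return linear(self.out, hidden)


# Example: embed a CartPole state (position, velocity, angle, angular velocity).
representation = RepresentationNetwork()
latent = representation([0.01, -0.02, 0.03, 0.0])
print(len(latent))  # 8
```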

Dynamics network

The dynamics network has an architecture similar to that of the representation network. The main difference is the input size: the dynamics function receives both the latent state and the action to take. In this tutorial, we use one-hot encoding to represent the action. For example, when the cart moves left, the action is represented as [1,0]; when the cart moves right, it is represented as [0,1]. We concatenate this two-dimensional vector with the latent state, so the input has size embedding size + action size.

The output is the next latent state reached by taking the input action at the input latent state.
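The sketch below follows the same plain-Python pattern (forward pass only): the action is one-hot encoded, concatenated with the latent state, and passed through a hidden layer that outputs the next latent state. As before, the layer sizes, the tanh activation, and the helper names are illustrative assumptions, not the tutorial's actual code.

```python
import math
import random
from typing import List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]


def init_linear(in_size: int, out_size: int, rng: random.Random) -> Layer:
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(in_size)] for _ in range(out_size)]
    return weights, [0.0] * out_size


def linear(layer: Layer, x: Sequence[float]) -> List[float]:
    weights, biases = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]


def one_hot(action: int, action_size: int) -> List[float]:
    """For CartPole: 0 (push left) -> [1, 0], 1 (push right) -> [0, 1]."""
    vec = [0.0] * action_size
    vec[action] = 1.0
    return vec


class DynamicsNetwork:
    """s^k = g(s^{k-1}, a^k): [latent state, one-hot action] -> next latent state."""

    def __init__(self, embedding_size: int = 8, action_size: int = 2,
                 hidden_size: int = 64, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.action_size = action_size
        # Input size is embedding size + action size.
        self.hidden = init_linear(embedding_size + action_size, hidden_size, rng)
        self.out = init_linear(hidden_size, embedding_size, rng)

    def __call__(self, latent: Sequence[float], action: int) -> List[float]:
        x = list(latent) + one_hot(action, self.action_size)
        hidden = [math.tanh(v) for v in linear(self.hidden, x)]
        return linear(self.out, hidden)


dynamics = DynamicsNetwork()
next_latent = dynamics([0.1] * 8, action=0)
print(len(next_latent))  # 8
```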

Reward network

The reward network receives the latent state and the action to take as input and returns the predicted immediate reward. In CartPole, the agent receives a reward of +1 at each step while the pole is kept upright, so the predicted immediate reward (the output) is a single scalar.
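As a minimal plain-Python sketch (forward pass only), the reward network can take the same input as the dynamics network, with a single output unit instead of an embedding-sized output. The sizes, activation, and helper names are again illustrative assumptions.

```python
import math
import random
from typing import List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]


def init_linear(in_size: int, out_size: int, rng: random.Random) -> Layer:
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(in_size)] for _ in range(out_size)]
    return weights, [0.0] * out_size


def linear(layer: Layer, x: Sequence[float]) -> List[float]:
    weights, biases = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]


def one_hot(action: int, action_size: int) -> List[float]:
    vec = [0.0] * action_size
    vec[action] = 1.0
    return vec


class RewardNetwork:
    """r^k from (s^{k-1}, a^k): [latent state, one-hot action] -> scalar reward."""

    def __init__(self, embedding_size: int = 8, action_size: int = 2,
                 hidden_size: int = 64, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.action_size = action_size
        self.hidden = init_linear(embedding_size + action_size, hidden_size, rng)
        self.out = init_linear(hidden_size, 1, rng)  # a single output unit

    def __call__(self, latent: Sequence[float], action: int) -> float:
        x = list(latent) + one_hot(action, self.action_size)
        hidden = [math.tanh(v) for v in linear(self.hidden, x)]
        return linear(self.out, hidden)[0]  # unwrap the 1-element vector


reward_net = RewardNetwork()
predicted_reward = reward_net([0.1] * 8, action=1)
```

During training, this prediction would be compared with the observed CartPole reward (1.0 for each step the pole stays upright) in the reward loss.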

Value network

The value network receives the latent state and returns the predicted expected value at that state. Instead of outputting the value directly as a scalar, MuZero (as in the paper's Atari experiments) has the value network output a vector of logits over a discrete set of possible values; the scalar prediction is recovered by taking the expected value under this distribution and then applying the inverse of an invertible transformation. For more detail, please check “Appendix F Network architecture” of the MuZero paper and “Appendix A: Proposition A.2” of this paper.
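The conversion from network output to scalar can be sketched in plain Python as below. It assumes the invertible transform commonly used with MuZero, h(x) = sign(x)(√(|x|+1) − 1) + εx with ε = 0.001, its closed-form inverse, and a support of integers from −S to S; the support size S = 10 and all function names here are hypothetical choices for illustration.

```python
import math
from typing import List, Sequence

EPS = 0.001        # epsilon in h(x)
SUPPORT_SIZE = 10  # hypothetical: the support is the integers -10..10


def sign(x: float) -> float:
    return float((x > 0) - (x < 0))


def h(x: float) -> float:
    """Squash a scalar: h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x."""
    return sign(x) * (math.sqrt(abs(x) + 1) - 1) + EPS * x


def h_inverse(y: float) -> float:
    """Closed-form inverse of h (solve a quadratic in sqrt(|x| + 1))."""
    root = (math.sqrt(1 + 4 * EPS * (abs(y) + 1 + EPS)) - 1) / (2 * EPS)
    return sign(y) * (root ** 2 - 1)


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def support_to_scalar(logits: Sequence[float], support_size: int = SUPPORT_SIZE) -> float:
    """Logits over -S..S -> expected value in h-space -> scalar value."""
    if len(logits) != 2 * support_size + 1:
        raise ValueError("expected one logit per support value")
    probs = softmax(logits)
    support = range(-support_size, support_size + 1)
    expected = sum(p * v for p, v in zip(probs, support))
    return h_inverse(expected)


print(round(h_inverse(h(25.0)), 6))  # 25.0
```

The transform compresses large values so that a small support can still cover a wide range of returns.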

Policy network

Lastly, the policy network receives the latent state and returns the policy at that state as a vector of unnormalized scores (logits), one per action. These outputs are not probabilities; MuZero applies a softmax function to them to get the probability of taking each action.
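For example, a numerically stable softmax turns the logits into action probabilities, which MCTS then uses as priors. In this minimal sketch, the logits are made-up stand-ins for a policy network output.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert policy logits into action probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical policy network output for CartPole: [push left, push right]
probs = softmax([1.0, 0.0])
print([round(p, 3) for p in probs])  # [0.731, 0.269]
```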

Initial inference

In part 1, we skipped the details of two functions, initial_inference and recurrent_inference, which are used to run Monte Carlo Tree Search (MCTS). Now we are ready to cover them. We use the initial_inference function to expand the current root node. This function does the following:

  • Use the representation network to get the latent representation of the current root node
  • Use the value network to get the expected value at the current latent state
  • Use the policy network to get the policy at the current latent state

In the implementation below, the InitialModel class integrates these three steps. The initial_inference function creates an InitialModel object and uses it to return the scalar value (after the inverse transformation), the immediate reward (always 0 for the root state), the policy logits (before applying the softmax function), and the latent representation of the root state.
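The same three steps can be sketched as a plain function. This is not the tutorial's InitialModel class: here the networks are passed in as plain callables, and the NetworkOutput container and the to_scalar argument (a function that turns the value network's output vector into a scalar) are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence

Vector = List[float]


@dataclass
class NetworkOutput:
    value: float           # scalar value estimate
    reward: float          # immediate reward (0 at the root)
    policy_logits: Vector  # policy before the softmax
    hidden_state: Vector   # latent state


def initial_inference(
    observation: Sequence[float],
    representation: Callable[[Sequence[float]], Vector],
    value_net: Callable[[Vector], Vector],
    policy_net: Callable[[Vector], Vector],
    to_scalar: Callable[[Vector], float],
) -> NetworkOutput:
    hidden_state = representation(observation)  # s^0 = h(o_t)
    value = to_scalar(value_net(hidden_state))  # v^0 as a scalar
    policy_logits = policy_net(hidden_state)    # p^0 as logits
    return NetworkOutput(value=value, reward=0.0,
                         policy_logits=policy_logits, hidden_state=hidden_state)


# Toy stand-ins for the trained networks, just to show the data flow.
out = initial_inference(
    [0.01, -0.02, 0.03, 0.0],
    representation=lambda obs: [2.0 * x for x in obs],
    value_net=lambda s: [sum(s)],
    policy_net=lambda s: [0.0, 0.0],
    to_scalar=lambda v: v[0],
)
print(out.reward, out.policy_logits)  # 0.0 [0.0, 0.0]
```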

Recurrent inference

The other function used in MCTS is recurrent_inference. It runs the mental simulation during MCTS by doing the following:

  • Use the dynamics network to get the next latent state when taking the input action at the input state
  • Use the reward network to get the immediate reward when taking the input action at the input state
  • Use the value network to get the expected value at the next latent state
  • Use the policy network to get the policy at the next latent state

In the implementation below, the RecurrentModel class integrates these four steps. The recurrent_inference function creates a RecurrentModel object and uses it to return the scalar value (after the inverse transformation), the immediate reward, the policy logits (before applying the softmax function), and the latent representation of the next state.
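Under the same assumptions (networks passed in as callables; NetworkOutput and to_scalar are hypothetical names, not the tutorial's RecurrentModel class), one recurrent step can be sketched as below. Together, the dynamics and reward calls play the role of g_θ, and the value and policy calls play the role of f_θ.

```python
from dataclasses import dataclass
from typing import Callable, List

Vector = List[float]


@dataclass
class NetworkOutput:
    value: float           # scalar value estimate at the next latent state
    reward: float          # predicted immediate reward for the transition
    policy_logits: Vector  # policy at the next latent state, before the softmax
    hidden_state: Vector   # next latent state


def recurrent_inference(
    hidden_state: Vector,
    action: int,
    dynamics: Callable[[Vector, int], Vector],
    reward_net: Callable[[Vector, int], float],
    value_net: Callable[[Vector], Vector],
    policy_net: Callable[[Vector], Vector],
    to_scalar: Callable[[Vector], float],
) -> NetworkOutput:
    next_hidden = dynamics(hidden_state, action)  # s^k from (s^{k-1}, a^k)
    reward = reward_net(hidden_state, action)     # r^k from (s^{k-1}, a^k)
    value = to_scalar(value_net(next_hidden))     # v^k at s^k
    policy_logits = policy_net(next_hidden)       # p^k at s^k (logits)
    return NetworkOutput(value=value, reward=reward,
                         policy_logits=policy_logits, hidden_state=next_hidden)


# Toy stand-ins for the trained networks, just to show the data flow.
out = recurrent_inference(
    [1.0, 2.0], 1,
    dynamics=lambda s, a: [x + a for x in s],
    reward_net=lambda s, a: 1.0,
    value_net=lambda s: [sum(s)],
    policy_net=lambda s: [0.0, 0.0],
    to_scalar=lambda v: v[0],
)
print(out)
# NetworkOutput(value=5.0, reward=1.0, policy_logits=[0.0, 0.0], hidden_state=[2.0, 3.0])
```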

Summary

In this post, we reviewed the deep neural networks used in MuZero. These networks are trained on the data collected with MCTS, the process we learned in part 1. In the next post, we will combine part 1 and part 2 and add a few more elements to complete the MuZero framework.

Codes

Reference

  • Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., … & Silver, D. (2020). Mastering Atari, Go, chess and shogi by planning with a learned model. Nature, 588(7839), 604–609.
  • muzero-pytorch by koulanurag (https://github.com/koulanurag/muzero-pytorch/tree/master)


Yuki Minai

Ph.D. student in Neural Computation and Machine Learning at Carnegie Mellon University, Personal webpage: http://yukiminai.com