MuZero: Model-based RL (part 3)

Yuki Minai
Jun 21, 2024


In part 1, we learned how Monte Carlo Tree Search (MCTS) is used to collect training data. In part 2, we covered the deep learning models used in MuZero. In this post, we integrate these two main components, along with a few supporting pieces, to complete the full MuZero algorithm.

Environment

First, we define the environment that MuZero interacts with. As in parts 1 and 2, we use the CartPole environment.
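A minimal setup might look like the following (a sketch assuming the classic gym API; newer gymnasium versions return (obs, info) from reset and a five-tuple from step):

```python
import gym

# CartPole-v0 caps episodes at 200 steps, which matches the maximum
# reward referenced later in this post.
env = gym.make('CartPole-v0')

obs = env.reset()                    # 4-dimensional state
action = env.action_space.sample()   # random action (0 or 1)
obs, reward, done, info = env.step(action)
```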

Node class

Next, we define three classes that store key information in MuZero: Node, Game, and ReplayBuffer. The first is the Node class, which stores node statistics during MCTS self-play. This is the same code we covered in part 1.
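For reference, a minimal Node in the spirit of part 1 and the official MuZero pseudocode might look like this (attribute names are illustrative):

```python
class Node:
    """A node in the MCTS search tree."""

    def __init__(self, prior: float):
        self.prior = prior          # prior probability from the policy head
        self.visit_count = 0        # how many simulations passed through this node
        self.value_sum = 0.0        # sum of backed-up values
        self.reward = 0.0           # predicted reward on entering this node
        self.hidden_state = None    # hidden state from the dynamics network
        self.children = {}          # action -> child Node

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> float:
        # Mean backed-up value, used by the pUCT selection rule.
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0
```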

Game class

Next, we define a Game class that stores a single episode of interaction with the environment. Using the store_search_statistics function, MuZero records the search statistics from each MCTS run in a Game object, which is later stored in the replay buffer.

For model training, the make_target function creates the target data (target value, reward, and policy) for each state.
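A simplified sketch of these two methods, loosely following the official MuZero pseudocode (the full implementation handles reward indexing and end-of-episode padding more carefully), could look like this:

```python
class Game:
    """Stores one episode: observations, actions, rewards, and search statistics."""

    def __init__(self, action_space_size: int, discount: float):
        self.action_space_size = action_space_size
        self.discount = discount
        # Filled in during self-play.
        self.observations, self.actions, self.rewards = [], [], []
        self.child_visits, self.root_values = [], []

    def store_search_statistics(self, root):
        # Normalized root visit counts become the policy target.
        total_visits = sum(child.visit_count for child in root.children.values())
        self.child_visits.append([
            root.children[a].visit_count / total_visits if a in root.children else 0.0
            for a in range(self.action_space_size)
        ])
        self.root_values.append(root.value())

    def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int):
        # Build (value, reward, policy) targets for each unrolled step.
        targets = []
        for current_index in range(state_index, state_index + num_unroll_steps + 1):
            # n-step return: discounted rewards plus a bootstrapped root value.
            bootstrap_index = current_index + td_steps
            value = (self.root_values[bootstrap_index] * self.discount ** td_steps
                     if bootstrap_index < len(self.root_values) else 0.0)
            for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
                value += reward * self.discount ** i
            if current_index < len(self.root_values):
                targets.append((value, self.rewards[current_index],
                                self.child_visits[current_index]))
            else:
                # Past the end of the episode: absorbing state with uniform policy.
                targets.append((0.0, 0.0,
                                [1.0 / self.action_space_size] * self.action_space_size))
        return targets
```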

Replay Buffer class

Next, we define the ReplayBuffer class. MuZero stores completed game episodes in this buffer and, when training the model, samples from them with the sample_batch function.
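A minimal version with uniform sampling (the paper also discusses prioritized sampling) might be:

```python
import random

class ReplayBuffer:
    def __init__(self, window_size: int, batch_size: int):
        self.window_size = window_size  # maximum number of stored games
        self.batch_size = batch_size
        self.buffer = []

    def save_game(self, game):
        if len(self.buffer) >= self.window_size:
            self.buffer.pop(0)          # drop the oldest game
        self.buffer.append(game)

    def sample_batch(self, num_unroll_steps: int, td_steps: int):
        games = [random.choice(self.buffer) for _ in range(self.batch_size)]
        # Pick a random start position within each sampled game.
        game_pos = [(g, random.randrange(len(g.actions))) for g in games]
        return [(g.observations[i],
                 g.actions[i:i + num_unroll_steps],
                 g.make_target(i, num_unroll_steps, td_steps))
                for (g, i) in game_pos]
```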

Networks class

The Networks class instantiates the five deep learning models we covered in part 2, using the same architectures. It also provides helper functions such as initial_inference, recurrent_inference, and _value_transform, which are used to run MCTS and train the networks.

The _value_transform function converts the multi-dimensional (categorical) output of the value network into a scalar predicted value, while _scalar_to_support performs the inverse transformation: it converts a scalar target value into a categorical vector used to train the value network.
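To illustrate the idea: MuZero represents values as categorical distributions over an integer support, and scalars are encoded as a "two-hot" vector over neighboring support bins (Appendix F of the paper). A simplified sketch, shown here as standalone functions and omitting the invertible h(x) scaling transform used for Atari, might be:

```python
import torch

def _value_transform(logits: torch.Tensor, support_size: int) -> torch.Tensor:
    """Convert categorical value logits into an expected scalar value."""
    probs = torch.softmax(logits, dim=1)
    support = torch.arange(-support_size, support_size + 1,
                           dtype=torch.float32, device=logits.device)
    return torch.sum(probs * support, dim=1)

def _scalar_to_support(x: torch.Tensor, support_size: int) -> torch.Tensor:
    """Convert a scalar target into a two-hot categorical target."""
    x = x.clamp(-support_size, support_size)
    floor = x.floor()
    prob_upper = x - floor                  # weight assigned to the upper bin
    target = torch.zeros(x.shape[0], 2 * support_size + 1, device=x.device)
    idx = (floor + support_size).long()
    target.scatter_(1, idx.unsqueeze(1), (1 - prob_upper).unsqueeze(1))
    upper_idx = (idx + 1).clamp(max=2 * support_size)
    target.scatter_add_(1, upper_idx.unsqueeze(1), prob_upper.unsqueeze(1))
    return target
```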

Train network

The code below defines several functions for training the deep learning models.

The train_network function trains the networks using batches sampled from the replay buffer. It calls the update_weights function, which performs a full training step. To keep the gradient magnitude roughly comparable across different numbers of unroll steps, MuZero scales gradients with the scale_gradient function; for more details on gradient scaling, please refer to Appendix G of the MuZero paper.
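The gradient-scaling trick itself is a one-liner; a common way to write it in PyTorch (matching the MuZero pseudocode) is:

```python
def scale_gradient(tensor, scale: float):
    """Scale the gradient in the backward pass while leaving the forward value unchanged."""
    return tensor * scale + tensor.detach() * (1.0 - scale)
```

Inside update_weights, the hidden state produced at each unroll step is typically passed through scale_gradient(hidden_state, 0.5), and each step's loss is additionally scaled by 1/num_unroll_steps, as described in Appendix G.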

MCTS class

To run MCTS, we collect the functions covered in part 1, such as run_mcts, into an MCTS class. This class contains the helpers needed to run self-play.
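As a reminder from part 1, the core of run_mcts is a selection-expansion-backup loop over a fixed number of simulations. A simplified sketch (select_child, expand_node, and backpropagate are the part 1 helpers and are not repeated here; the real implementation also adds Dirichlet noise at the root and normalizes values with min-max statistics):

```python
def run_mcts(config, root, action_history, network):
    # The root is assumed to have been expanded with network.initial_inference.
    for _ in range(config.num_simulations):
        node, history = root, list(action_history)
        search_path = [node]

        # 1. Selection: descend the tree following the pUCT score.
        while node.expanded():
            action, node = select_child(config, node)
            history.append(action)
            search_path.append(node)

        # 2. Expansion: apply the dynamics + prediction networks at the leaf.
        parent = search_path[-2]
        output = network.recurrent_inference(parent.hidden_state, history[-1])
        expand_node(node, output)

        # 3. Backup: propagate the predicted value up the search path.
        backpropagate(search_path, output.value, config.discount)
```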

Helper classes to store and plot train/test performance

We also define two optional classes that record the training losses and the test rewards at each epoch for plotting.
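These can be as simple as lightweight containers (class names here are illustrative):

```python
class LossLogger:
    """Records each loss component per epoch for later plotting."""
    def __init__(self):
        self.value_losses, self.reward_losses, self.policy_losses = [], [], []

class RewardLogger:
    """Records the total test reward per epoch."""
    def __init__(self):
        self.test_rewards = []
```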

Main function

Finally, let's define MuZero's main loop (the self-play function) by integrating everything we have defined so far.
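At a high level, the loop alternates between self-play (running MCTS to generate games and storing them in the replay buffer) and network training, followed by a noise-free evaluation episode. A sketch, where play_game, evaluate, and the exact signatures are illustrative stand-ins for the helpers described above:

```python
def muzero(config):
    network = Networks(config)            # the five models from part 2
    replay_buffer = ReplayBuffer(config.window_size, config.batch_size)

    for epoch in range(config.num_epochs):
        # Self-play: generate episodes with MCTS and store them.
        for _ in range(config.games_per_epoch):
            game = play_game(config, network)     # runs MCTS at every step
            replay_buffer.save_game(game)

        # Training: update the networks from sampled batches.
        train_network(config, network, replay_buffer)

        # Evaluation: play without exploration noise to measure progress.
        test_reward = evaluate(config, network)
        print(f"epoch {epoch}: test reward = {test_reward}")
    return network
```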

MuZero config setting

Here, we define MuZero's hyperparameters. These values are not tuned for this environment, so better performance could likely be obtained by adjusting them.
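To illustrate what such a config might contain, here is a plausible (not tuned) set of hyperparameters; the names and values are illustrative:

```python
from dataclasses import dataclass

@dataclass
class MuZeroConfig:
    # Environment / search
    action_space_size: int = 2          # CartPole has two actions
    num_simulations: int = 50           # MCTS simulations per move
    discount: float = 0.997
    root_dirichlet_alpha: float = 0.25  # exploration noise at the root
    root_exploration_fraction: float = 0.25

    # Training
    window_size: int = 1000             # games kept in the replay buffer
    batch_size: int = 128
    num_unroll_steps: int = 5
    td_steps: int = 10
    lr: float = 1e-3

    # Loop sizes
    num_epochs: int = 50
    games_per_epoch: int = 20
```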

Run MuZero

Now we are ready to run MuZero! The code below trains MuZero on the CartPole environment. As training progresses, each loss decreases and the reward increases until it reaches the maximum reward (200). This suggests that MuZero learns the environment dynamics as well as a policy that behaves well in the environment!
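Using the illustrative names from the sketches above, the entry point is just a couple of lines, for example:

```python
config = MuZeroConfig()           # hyperparameters defined above
trained_network = muzero(config)  # runs self-play + training and logs test rewards
```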

Note that the policy loss is larger than the other losses because, in CartPole, multiple actions can be near-optimal in a given state, which makes it hard to pin down a single optimal action that minimizes the loss. Also, the total reward earned during training tends to be lower than the total reward during testing, because action selection during training adds random noise to encourage exploration, whereas no noise is added at test time.

Summary

In this series of blog posts, we learned about MuZero, a model-based reinforcement learning algorithm. Using the CartPole gym environment, we confirmed that MuZero can learn a good policy by learning the environment dynamics from scratch.

While the above implementation works reasonably well, the MuZero paper suggests a few more techniques to improve performance, such as Reanalyze (described in Appendix H). If you are interested in those techniques, I encourage you to check the paper.

Code

Reference

  • Schrittwieser, J., Antonoglou, I., Hubert, T., Simonyan, K., Sifre, L., Schmitt, S., … & Silver, D. (2020). Mastering Atari, Go, chess and shogi by planning with a learned model. Nature, 588(7839), 604–609.
  • muzero-pytorch by koulanurag (https://github.com/koulanurag/muzero-pytorch/tree/master)

