MuZero for dummies!

michelangelo
Sep 6, 2022


We will see how to develop a simple but working implementation of MuZero, a revolutionary AI algorithm developed by DeepMind.


MuZero: idea

We have already seen what the MCTS algorithm is and how we can empower it with neural networks to learn to play (not only) video games with the AlphaZero algorithm.

The key idea of AlphaZero is that, in order to keep the MCTS algorithm from falling short in games with very large state spaces, we can use neural networks to approximate the Value and the Policy of the tree nodes, so that the limits of plain MCTS are overcome.

MuZero builds on the same idea as AlphaZero, but goes a step further. In AlphaZero, we had complete access to game information: for every action we decided to take, we played that move in an exact copy of the actual game, and we could actually see how the environment state changed as a result of that action.
This is a luxury that we can’t always have: can we, for example, in real life, see what the result of taking an action really is? Well, not until we have taken it; but we have learned to predict the results of the most common actions, and we play along in the “game” of life like this.

MuZero builds on the idea that, with neural networks, we can actually learn the dynamics of the game itself, so that when we plan in order to look for a move, we can guess not only how good the move is, but also what will happen as a result of taking it. In the end, we do not need to play along with the game environment, because we will use neural networks to “simulate” it.

This is the key concept of MuZero, and there is one aspect of it which I think is the most incredible one, but also the most difficult for my mind to understand and “accept”: the neural network that approximates the dynamics of the environment will NOT be trained to approximate the environment itself. That is, we are not trying to create an approximated version of the environment by training the network with information about it; we are only trying to reach our objective (i.e. minimize the loss function, which we will see later), giving no constraints on how the network that predicts the dynamics should do so.
The neural network is free to represent its internal state however it wants in order to capture the dynamics of the game, with no “human” restriction on how to do so.
We do not know how the network internally represents the dynamics of the environment, and it is free to do so with something that is not correlated with the real environment at all.

This, I think, is the key powerful idea of MuZero: it gives the neural network the possibility to really flourish by not imposing any “constraint” and letting it “imagine” the dynamics of the environment however it wants. In general, I think this is a fundamental idea for building algorithms that aim towards AGI.

Thank you DeepMind for opening my mind to this kind of idea! As a computer scientist, in the beginning I was very reluctant to understand and accept it, but now I am starting to see things differently!

MuZero: implementation

So we are ready to tackle a simple but working implementation of MuZero. In this case, I have decided to implement the pseudo-code that the authors at DeepMind have kindly shared along with the academic paper of the algorithm itself.

The extra effort of understanding it is really worth it, as the code is very well structured into classes and the flow is clear, easy to read and to implement. Also, it comes from the authors of the algorithm itself, so who better than them to provide a guide for the actual implementation?

As usual, I will try to simplify and explain the basic ideas and concepts of the algorithm, and I will not focus on the details of the (many) optimizations that we will encounter along the way: each of these optimizations and concepts would deserve an article (or more) of its own!

In any case, I assume familiarity with both the MCTS and AlphaZero algorithm. If you need a refresher, you can read the previous articles.

Helpers

Let’s start with defining helper classes that we will use along the way. There are a lot of them, but they are quite simple to understand and will simplify the actual algorithm code later on.

MinMaxStats

The MinMaxStats object keeps track of the minimum and maximum values encountered during the MCTS search. We will use it to normalize node values into the range 0 to 1, so that the Value of a node is on the same scale as its Policy during the search.
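Something along these lines is enough (a minimal sketch, slightly simplified with respect to the DeepMind pseudocode, which also supports pre-known bounds):

```python
class MinMaxStats:
    """Tracks the minimum and maximum node values seen during the MCTS search."""

    def __init__(self):
        self.minimum = float('inf')
        self.maximum = -float('inf')

    def update(self, value: float):
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value: float) -> float:
        # Only normalize once we have actually observed some values.
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
```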

Config

The MuZeroConfig object stores the hyperparameters used by the algorithm for search and training. Some of these parameters are common to all games, some are very specific, and each game has its own configuration. The meaning of each hyperparameter will be explained along the way, when we use it in the actual algorithm.

CartPole config

As usual, our environment of choice is CartPole-v1. It is simple enough to be solved by such a simple implementation, but complex enough to check whether learning is actually happening.

I have found that this implementation of MuZero is quite sensitive to the chosen hyperparameters, but I have managed to find values that work for the CartPole environment. Also, after the MCTS has completed its search, we can add an extra exploration component to the choice of the action by using the visit_softmax_temperature function.
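To give a feeling of what the configuration looks like, here is a minimal sketch of a CartPole config; the numbers below are illustrative placeholders, not necessarily the exact values I used, so treat them as a starting point for your own tuning:

```python
from dataclasses import dataclass
from typing import Callable


def visit_softmax_temperature(num_moves, training_steps):
    # Illustrative schedule: high temperature (more exploration) early in training,
    # lower temperature (more exploitation) once the network is more confident.
    if training_steps < 100:
        return 1.0
    if training_steps < 200:
        return 0.5
    return 0.25


@dataclass
class MuZeroConfig:
    action_space_size: int = 2        # CartPole has two actions: push left, push right
    max_moves: int = 500              # episode limit of CartPole-v1
    discount: float = 0.997
    num_simulations: int = 50         # MCTS simulations per move
    td_steps: int = 10                # bootstrapping horizon for value targets
    num_unroll_steps: int = 5         # how many steps we unroll the dynamics during training
    batch_size: int = 128
    window_size: int = 500            # replay buffer capacity, in games
    lr_init: float = 0.01
    root_dirichlet_alpha: float = 0.25
    root_exploration_fraction: float = 0.25
    pb_c_base: float = 19652
    pb_c_init: float = 1.25
    visit_softmax_temperature_fn: Callable = visit_softmax_temperature
```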

Action and Player

The Action and Player classes do not have anything special in this implementation; they are just wrapper objects.
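For completeness, a sketch of what they can look like (mirroring the pseudocode):

```python
class Action:
    """Thin wrapper around an integer action index."""

    def __init__(self, index: int):
        self.index = index

    def __hash__(self):
        return self.index

    def __eq__(self, other):
        return self.index == other.index


class Player:
    """Placeholder: CartPole is a single-player game."""
    pass
```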

Node

The Node class represents the usual MCTS node. Nothing we have not already seen, but notice the to_play attribute: in two-player games it would change every time we go down a level of the MCTS tree, to take into account that the value of a node for one player is the opposite for the other.
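A minimal sketch of the Node class, with the fields used later by the search:

```python
class Node:
    """A node of the MCTS tree, built on top of MuZero's hidden states."""

    def __init__(self, prior: float):
        self.visit_count = 0
        self.to_play = -1          # whose turn it is (only relevant for two-player games)
        self.prior = prior         # prior probability coming from the Policy network
        self.value_sum = 0.0       # sum of the values backed up through this node
        self.children = {}         # Action -> Node
        self.hidden_state = None   # HiddenState predicted by the networks
        self.reward = 0.0          # reward predicted when entering this node

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> float:
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count
```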

ActionHistory

The ActionHistory is a helper class that stores the history of actions we have taken along an MCTS search.
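Again, nothing fancy; a sketch along the lines of the pseudocode:

```python
class ActionHistory:
    """History of the actions taken along one simulation, cloned for each MCTS rollout."""

    def __init__(self, history, action_space_size: int):
        self.history = list(history)
        self.action_space_size = action_space_size

    def clone(self):
        return ActionHistory(self.history, self.action_space_size)

    def add_action(self, action):
        self.history.append(action)

    def last_action(self):
        return self.history[-1]

    def action_space(self):
        return [Action(i) for i in range(self.action_space_size)]
```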

Environment

The Environment is a wrapper around the gym environment, CartPole-v1 in our case, and defines how to actually act in the environment using the step method of the OpenAI gym interface.
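A sketch of the wrapper, assuming the classic gym API where step returns four values (newer gym/gymnasium versions return five, and reset returns a tuple):

```python
import gym


class Environment:
    """Thin wrapper around the real gym environment."""

    def __init__(self, env_name: str = 'CartPole-v1'):
        self.env = gym.make(env_name)
        self.observation = self.env.reset()
        self.done = False

    def step(self, action) -> float:
        # Act in the REAL environment; this is the only place where we touch it.
        self.observation, reward, self.done, _ = self.env.step(action.index)
        return reward
```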

Game

The Game class represents a single episode of interaction with the environment. In particular, store_search_statistics collects data about visit counts and node values during the MCTS search, and make_target builds the target values for the neural networks to aim for:

  • for the target Value we use a technique called bootstrapping: the target value of a given node is the discounted root value of the MCTS tree N steps into the future (where N is given in our config by the td_steps hyperparameter), plus the discounted sum of all rewards observed until then (notice that in AlphaZero rewards were not taken into account, but now they are). Every time we use an approximated value to estimate another value, like in this case, we say that we are bootstrapping.
  • for the target Policy of the node (i.e. the probability of exploring each action) we use the visit counts, so our target is the MCTS exploration policy.

Also, and this is very important: as we are looking N steps into the future for our estimates, it may happen that these states fall after the end of the actual game. In these cases it is important that our targets are null or zero, so that these states are learned by the neural network and treated as “absorbing states”. In fact, for the neural network there is no such concept as “end of game”, and it will happily continue to imagine how the dynamics unfold forever! By training on absorbing states, the value of such nodes will be lower than that of other states, and this will be enough to avoid them and choose other, better actions in the actual game.
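A sketch of make_target, which follows the pseudocode closely (self.root_values, self.rewards and self.child_visits are the statistics collected during self-play):

```python
def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int):
    """Build (value, reward, policy) targets for positions
    state_index .. state_index + num_unroll_steps."""
    targets = []
    for current_index in range(state_index, state_index + num_unroll_steps + 1):
        bootstrap_index = current_index + td_steps
        if bootstrap_index < len(self.root_values):
            # Bootstrapping: start from the discounted MCTS root value td_steps ahead...
            value = self.root_values[bootstrap_index] * self.discount ** td_steps
        else:
            value = 0.0

        # ...and add the discounted rewards observed until then.
        for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
            value += reward * self.discount ** i

        if current_index < len(self.root_values):
            targets.append((value, self.rewards[current_index],
                            self.child_visits[current_index]))
        else:
            # Absorbing state: past the end of the game, train towards null targets.
            targets.append((0.0, 0.0, []))
    return targets
```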

Replay Buffer

Nothing major to say about the ReplayBuffer, except that for this simple implementation we do not need any priority mechanism (such as PER, Prioritized Experience Replay) for choosing episodes and positions: completely random choices are enough.
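A sketch of the buffer with purely uniform sampling (make_image and history are the Game helpers from the pseudocode):

```python
import random


class ReplayBuffer:
    """Keeps the most recent games and samples training positions uniformly."""

    def __init__(self, config):
        self.window_size = config.window_size
        self.batch_size = config.batch_size
        self.buffer = []

    def save_game(self, game):
        if len(self.buffer) >= self.window_size:
            self.buffer.pop(0)  # drop the oldest game
        self.buffer.append(game)

    def sample_batch(self, num_unroll_steps: int, td_steps: int):
        # No prioritization: pick a random game, then a random position inside it.
        games = [random.choice(self.buffer) for _ in range(self.batch_size)]
        game_pos = [(g, random.randrange(len(g.history))) for g in games]
        return [(g.make_image(i),
                 g.history[i:i + num_unroll_steps],
                 g.make_target(i, num_unroll_steps, td_steps))
                for (g, i) in game_pos]
```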

Network

The Network class represents the brain of MuZero, and it is worth understanding the underlying ideas correctly. We have different neural networks that approximate different things:

  • The Representation network takes us from an actual observation of the environment to an internal state of the neural network, which we will call a HiddenState. Note that the size of the HiddenState is not the same as the size of the actual environment observation because, as we said, the two are completely unrelated!
  • The Dynamics network takes a HiddenState and an Action and predicts the next HiddenState, i.e. it “imagines” how the environment evolves.
  • The Value network approximates the actual value of a HiddenState (how good it is) as a single output number.
  • The Policy network approximates the MCTS exploration policy in the form of probabilities over the actions.
  • The Reward network approximates the reward of taking a given Action in a given HiddenState: it tells us how good it is to take that single action in that state.

The MuZero algorithm builds on the combined usage of these networks:

  • initial_inference is used, for example, at the start of a new game and during the learning process. It takes as input the actual environment observation, transforms it into a HiddenState using the Representation network, and then uses that HiddenState to query the other networks for the predicted Value and Policy (the Reward is assumed to be 0, as there is no previous action taken).
  • recurrent_inference is the function used to transition from one HiddenState to the next, which is what we do to simulate the game in the “mind” of MuZero: given the current HiddenState and an Action, we use the Dynamics network to get the next HiddenState, and then we use the other networks to approximate Value, Policy and Reward.

Using initial_inference at the root of the MCTS search and then recurrent_inference on the intermediate nodes, we have a powerful search mechanism to collect statistics about MCTS values, all inside MuZero’s brain. Wow!
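A structural sketch of how the two entry points fit together; here representation, dynamics and prediction stand for the actual underlying models (e.g. small dense networks), which are injected from outside:

```python
from typing import Dict, List, NamedTuple


class NetworkOutput(NamedTuple):
    value: float
    reward: float
    policy: Dict              # Action -> probability (already softmax-ed in this implementation)
    hidden_state: List[float]


class Network:
    def __init__(self, representation, dynamics, prediction):
        self.representation = representation  # observation -> hidden state
        self.dynamics = dynamics              # (hidden state, action) -> (next hidden state, reward)
        self.prediction = prediction          # hidden state -> (value, policy)

    def initial_inference(self, observation) -> NetworkOutput:
        # Used at the root: real observation -> hidden state -> value and policy.
        hidden_state = self.representation(observation)
        value, policy = self.prediction(hidden_state)
        return NetworkOutput(value, 0.0, policy, hidden_state)  # no previous action: reward = 0

    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
        # Used inside the tree: imagine the next hidden state and its reward,
        # then predict value and policy for it.
        next_hidden_state, reward = self.dynamics(hidden_state, action)
        value, policy = self.prediction(next_hidden_state)
        return NetworkOutput(value, reward, policy, next_hidden_state)
```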

As usual, the better the networks approximate the real Values, Policies and Rewards, the more “expert” our play will be. We will see later on how to train the networks to correctly approximate these quantities.

In the end, the NetworkOutput is just a structure containing the Value, Reward, Policy and HiddenState predicted by the networks. Just a note about this specific implementation: I have used the softmax activation function for the last layer of the Policy network, hence its outputs are not raw logits as the pseudocode suggests, but already actual probabilities.

SharedStorage

The last helper class that we need before the actual algorithm implementation is the SharedStorage.

In theory, this is a shared space where Network objects are stored after each training step, so that jobs (i.e. parallel threads) that finish collecting data about a game can pick up the latest version, train it, and store a new version for other jobs to use.

In practice, in our implementation, we have a single Network object and there are no threads: games are played one at a time, and so is network training.
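So a sketch of it is almost trivial:

```python
class SharedStorage:
    """Stores network checkpoints indexed by training step."""

    def __init__(self, initial_network):
        self._networks = {0: initial_network}

    def latest_network(self):
        return self._networks[max(self._networks.keys())]

    def save_network(self, step: int, network):
        self._networks[step] = network
```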

Self-Play and Training

We are now ready to tackle the main MuZero algorithm. It is basically divided into two parts: the first part plays a game and stores the gathered statistics in the shared ReplayBuffer; the second part trains the neural networks by sampling data from the buffer. While these steps can be executed in parallel by multiple threads, in this simple implementation everything happens sequentially.

Part 1: Self-Play

So the steps are quite simple: grab the most up-to-date network (remember, we only have one in this implementation), play the game, store the statistics.

And we already know how we are going to play games, right? We use the MCTS algorithm to select an action at each step. Notice how we create the root node of the MCTS by calling the initial_inference function: we pass it the current observation of the environment, and the Representation network takes care of transforming it into an internal HiddenState. The observation at the root is the only thing the search will ever know about the real game!
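Putting it together, the per-move loop looks roughly like this (a sketch assuming the helpers described in this article, plus a config.new_game() factory as in the pseudocode):

```python
def play_game(config, network):
    game = config.new_game()
    while not game.terminal() and len(game.history) < config.max_moves:
        # The root is the ONLY node built from a real observation:
        # everything below it lives in MuZero's imagined hidden states.
        root = Node(0)
        observation = game.make_image(-1)
        expand_node(root, game.to_play(), game.legal_actions(),
                    network.initial_inference(observation))
        add_exploration_noise(config, root)

        run_mcts(config, root, game.action_history(), network)
        action = select_action(config, len(game.history), root, network)
        game.apply(action)
        game.store_search_statistics(root)
    return game
```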

Not much to say about the core MCTS algorithm: we have dedicated an entire article to it, and this is just a different implementation. Just notice the use of the recurrent_inference function: we use the Dynamics network to go from one state to another of the imaginary game in the “mind” of MuZero.

Action selection is, as usual, proportional to the MCTS visit counts, with a twist: we use a temperature value to encourage exploration, and we reduce it based on the number of episodes we have trained on, so that in the beginning we favor exploration and later on, when we are confident about the moves, we favor exploitation.
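One possible way to implement this selection (the temperature schedule is the visit_softmax_temperature function from the config, and network.training_steps() counts how many training batches the network has seen):

```python
import numpy as np


def select_action(config, num_moves, node, network):
    actions = list(node.children.keys())
    visit_counts = np.array([child.visit_count for child in node.children.values()],
                            dtype=np.float64)
    t = config.visit_softmax_temperature_fn(num_moves, network.training_steps())
    if t == 0:
        # Pure exploitation: pick the most visited child.
        return actions[int(np.argmax(visit_counts))]
    # Softmax over visit counts: a higher temperature flattens the distribution,
    # giving less visited actions a chance to be played.
    distribution = visit_counts ** (1.0 / t)
    distribution /= distribution.sum()
    return actions[np.random.choice(len(actions), p=distribution)]
```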

The remaining MCTS accessory functions are almost self-explanatory:

  • select_child: chooses the child that maximizes the UCB value.
  • ucb_score: calculates the UCB value for a node (see the sketch after this list). It is slightly different from the one used in the AlphaZero implementation, but the concept is the same: one term comes from the Value of the node (scaled to the range 0 to 1) and the other term comes from the MCTS policy. The objective of the function is to balance exploration and exploitation.
  • backpropagate: at the end of the search, it propagates the gathered information up to the root. Notice how at each level of the tree the Value of the node depends on the current player, so as to also handle two-player games.
  • expand_node: expands the current node, creating its children with the Policy values that come from the network prediction. As in this implementation we already use the softmax activation function in the network, we can use these probabilities directly, so I have commented out the logits approach.
  • add_exploration_noise: last but not least, when we create the root node for the MCTS we add Dirichlet noise to it, again to favor exploration.
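Here is a sketch of ucb_score, essentially the formula from the pseudocode:

```python
import math


def ucb_score(config, parent, child, min_max_stats) -> float:
    # Exploration term: grows with the parent's visit count, shrinks with the child's,
    # and is weighted by the prior probability coming from the Policy network.
    pb_c = math.log((parent.visit_count + config.pb_c_base + 1) / config.pb_c_base)
    pb_c += config.pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    prior_score = pb_c * child.prior

    # Exploitation term: the child's value, normalized to [0, 1] by MinMaxStats.
    value_score = min_max_stats.normalize(child.value()) if child.visit_count > 0 else 0.0
    return prior_score + value_score
```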

And this concludes the first part of the MuZero algorithm.

Part 2: Training

The second part of the algorithm is all about how to train the various neural networks. In particular, we train all of them trying to minimize the following:

Formula for calculating MuZero loss as found in the original paper.

That is, we are trying to minimize the differences between:

  • The predicted Policy with respect to the MCTS policy.
  • The predicted Values with respect to the actual values (i.e. the ones we calculated using the bootstrapping technique).
  • The predicted Rewards with respect to the observed ones.
  • Finally, a regularization term is added in order to penalize big network weights.

All the parameters of the model are trained jointly by minimizing this loss.
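In the notation of the paper, the loss summed over the K unrolled steps looks roughly like this (π is the MCTS policy, z the bootstrapped value target, u the observed reward; p, v, r are the corresponding predictions and c weighs the regularization):

```latex
l_t(\theta) = \sum_{k=0}^{K} \Big[
      \ell^{p}\big(\pi_{t+k},\, p_t^{k}\big)
    + \ell^{v}\big(z_{t+k},\, v_t^{k}\big)
    + \ell^{r}\big(u_{t+k},\, r_t^{k}\big)
  \Big] + c\,\lVert \theta \rVert^2
```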

Main

There is not much to add to the MuZero algorithm; we only have to write some code to actually use it:

We use the usual plot functions to track progress. This time we will also track time, as it takes much longer to train the networks compared to the previous MCTS and AlphaZero implementations. In fact, if for AlphaZero you could go grab a cup of coffee, come back and enjoy the result, for the training of MuZero I would suggest leaving it overnight and getting a good night of sleep, in order to wake up to the following:

Reward over 225 episodes.
Loss over 225 episodes.

And that’s it, we just coded a simple but working implementation of MuZero!

You can write the code on your own (suggested) or find the repository on GitHub.

Besides the implementation, what I think is important to understand are the basic concepts and ideas that we discussed along the way: they certainly gave me a different way of seeing things, and of writing code!
