Curious Agents III: BYOL-Explore

Dries Smit
InstaDeep
Jul 1, 2023

Welcome back to our series on self-supervised learning and curiosity for pre-training agents in open-world environments. This post builds on Part 1 and Part 2, where we implemented an agent that could solve the MountainCar environment without ever receiving external rewards. In this post, we look at some more recent developments in this space, specifically DeepMind’s work on BYOL-Explore, which addresses an issue our current algorithm has with large, noisy observation spaces. All code used in this post is again made freely available here.

The MountainCar environment used in the previous post had only two values in its observation vector, and both were perfectly predictable from the previous observation. However, this is not always the case. In more complex, real-world environments, the observation space can be very large. Furthermore, some aspects of the observation space may be extremely hard to predict, or simply not worth predicting. As Yann LeCun noted in his paper (A Path Towards Autonomous Machine Intelligence), it might not be beneficial to try to predict tree leaves moving in the wind. The agent from the previous post would get stuck trying to predict this noise instead of moving on to other, more “interesting” parts of the observation space.

One idea that many prominent researchers have suggested is to make future predictions directly in the latent space instead of in the observation space. However, this is easier said than done. If we allow the observation encoder to be tunable, the network can learn to output constant vectors, giving perfect future predictions that tell us nothing useful about the observations. To alleviate this, OpenAI proposed Random Network Distillation (RND), where one encoder’s parameters are fixed and a second, tunable encoder is trained to predict the latent state of that fixed encoder. They state, “Namely we predict the output of a fixed randomly initialized neural network on the current observation”. This allows for prediction in the latent space without worrying about the latent values collapsing to a constant vector. The RND loss function is depicted below, where sg represents a stop gradient. The predictor network tries to predict the outputs of the fixed network and typically has the largest loss on new, unexplored observations.

RND loss function. Source.
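To make the idea concrete, here is a minimal sketch of the RND objective in Flax. The architecture and sizes below are illustrative assumptions, not OpenAI’s exact setup: the target network is randomly initialised and never trained, while the predictor is trained to match its output.

import flax.linen as nn
import jax
import jax.numpy as jnp


class SmallMLP(nn.Module):
    out_dim: int = 64  # assumed embedding size

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        return nn.Dense(self.out_dim)(x)


def rnd_loss(predictor_params, target_params, predictor, target, obs):
    # The target network's parameters stay fixed (kept out of the optimiser),
    # so its output acts as a random but stable label for each observation.
    target_out = jax.lax.stop_gradient(target.apply(target_params, obs))
    pred_out = predictor.apply(predictor_params, obs)
    # Squared L2 error; it is largest on observations the agent has rarely seen.
    return jnp.sum(jnp.square(pred_out - target_out), axis=-1).mean()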

However, this method still has drawbacks. Because the target encoder is fixed, what gets encoded can never be updated, so it might encode features that are not relevant for exploration. The authors noted that in Montezuma’s Revenge, the intrinsic reward was not a strong enough signal to solve the first level of the environment.

BYOL-Explore is one approach that tries to improve on this. Its authors investigate whether a method can be designed that allows the latent representation to be updated without collapsing to a constant vector output. Their basic idea is represented in the following diagram.

The BYOL-Explore architecture. Source.

The diagram looks complicated but is actually quite simple to understand intuitively. First, the observation (o_t) is encoded by an encoder f_θ, producing a latent representation of the observation. Then there are closed-loop and open-loop unroll functions. The encoded observation is fed, alongside the previous action, into a closed-loop RNN cell, which computes b_t. A second, open-loop RNN simulates future history representations given only the future actions. This simulated future is encoded in the embeddings (b_{t,k}), which are fed to a final predictor g_θ. A target network also encodes the observations; its parameters are an exponential moving average (EMA) that slowly tracks f_θ.

This architecture therefore encodes the (potentially) high-dimensional observation into a lower-dimensional latent space, after which future latent states are predicted using the open-loop RNN. As can be seen above, the goal of the predictor (g_θ) is to predict the latent representations of future observations as produced by the target encoder. In the BYOL-Explore loss function, sg represents a stop gradient applied to the target network’s outputs: the target encoder receives no gradients, while the online encoder, the RNNs and the predictor g_θ are all updated through this loss.

Why does this process not collapse to outputting constant vectors? The authors state the following:

The intuition behind BYOL-Explore is similar in spirit to the one behind BYOL. In early training, the target network is initialized randomly, and so BYOL-Explore’s online network and the closed-loop RNN are trained to predict random features of the future. This encourages the online observation representation to capture information that is useful to predict the future. This information is then distilled into the target observation encoder network through the EMA slow copy mechanism. In turn, these features become targets for the online network and predicting them can further improve the quality of the online representation.
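The “EMA slow copy” mentioned in the quote is just an exponential moving average over the online encoder’s parameters. A minimal sketch in JAX could look like the following, where the decay value is an illustrative assumption rather than the paper’s setting:

import jax


def update_target_params(online_params, target_params, decay=0.99):
    # target <- decay * target + (1 - decay) * online, applied leaf-wise
    return jax.tree_util.tree_map(
        lambda target, online: decay * target + (1.0 - decay) * online,
        target_params,
        online_params,
    )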

We now implement our own version of BYOL-Explore and test it on a simple 2D environment from the Jumanji library, which contains many JAX environments. We focus on the Maze environment, in which the agent receives a top-down, three-dimensional observation of the maze. The agent’s location is represented by the green block in the image below. It can move left, right, up or down, or do nothing, and its goal is to reach the red target location. This environment is interesting to us because its observation is large, which makes predicting directly in the observation space more challenging. We therefore look to BYOL-Explore as a potential solution.

A random policy on Jumanji’s Maze environment.
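For reference, a random rollout like the one above only takes a few lines of Jumanji code. The “Maze-v0” registration name and the exact reset/step/action_spec signatures below are assumptions based on Jumanji’s standard API, so check the library’s documentation if they have changed:

import jax
import jumanji

env = jumanji.make("Maze-v0")
num_actions = env.action_spec().num_values  # left, right, up, down, no-op

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)

for _ in range(100):
    key, action_key = jax.random.split(key)
    action = jax.random.randint(action_key, shape=(), minval=0, maxval=num_actions)
    state, timestep = env.step(state, action)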

To start, we create an observation encoder network as follows:

# Imports used by the code snippets in this post
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.linen.initializers import constant, orthogonal


class ObservationEncoder(nn.Module):
    latent_size: int
    activation: str = "relu"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh

        # Convolutional layers
        layer_out = x
        for _ in range(3):
            layer_out = nn.Conv(
                features=32,  # increased the number of features
                kernel_size=(3, 3),
                strides=(2, 2),
                padding="SAME",  # added padding
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(layer_out)
            layer_out = activation(layer_out)

        # Flatten the spatial dimensions before the dense layers
        layer_out = layer_out.reshape((layer_out.shape[0], -1))

        for _ in range(2):
            layer_out = nn.Dense(
                128,  # increased the number of features
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(layer_out)
            layer_out = activation(layer_out)

        layer_out = nn.Dense(
            self.latent_size,
            kernel_init=orthogonal(1.0),
            bias_init=constant(0.0),
        )(layer_out)

        # tanh activation on the output for stability
        layer_out = nn.tanh(layer_out)

        return layer_out
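As a quick sanity check, the encoder can be initialised and applied to a dummy batch. The latent size and observation shape below are assumptions for illustration; substitute the Maze environment’s actual observation shape:

encoder = ObservationEncoder(latent_size=64)  # assumed latent size
dummy_obs = jnp.zeros((8, 10, 10, 3))  # (batch, height, width, channels), assumed shape
params = encoder.init(jax.random.PRNGKey(0), dummy_obs)
latents = encoder.apply(params, dummy_obs)
print(latents.shape)  # (8, 64)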

This network takes in a 3D image observation and outputs latent features. We keep the ActorCritic network the same as in the previous post. We now update the WorldModel to take in this latent representation instead of the raw observations, as follows:

class WorldModel(nn.Module):
    action_dim: int
    activation: str = "relu"

    @nn.compact
    def __call__(self, latent_in, action):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh

        # One-hot encode the action and concatenate it with the latent state
        one_hot_action = jax.nn.one_hot(action, self.action_dim)
        inp = jnp.concatenate([latent_in, one_hot_action], axis=-1)

        layer_out = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(inp)
        layer_out = activation(layer_out)
        layer_out = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(layer_out)
        layer_out = activation(layer_out)
        # Predict the next latent state, matching the encoder's output size
        layer_out = nn.Dense(
            latent_in.shape[-1], kernel_init=orthogonal(1.0), bias_init=constant(0.0)
        )(layer_out)
        return layer_out
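Continuing the dummy example from the encoder above (the latent size of 64 and the five Maze actions are again assumptions), we can check that the world model maps a batch of latents and actions to a batch of predicted next latents:

world_model = WorldModel(action_dim=5)  # left, right, up, down, no-op
dummy_latent = jnp.zeros((8, 64))
dummy_action = jnp.zeros((8,), dtype=jnp.int32)
wm_params = world_model.init(jax.random.PRNGKey(1), dummy_latent, dummy_action)
pred_next_latent = world_model.apply(wm_params, dummy_latent, dummy_action)
print(pred_next_latent.shape)  # (8, 64)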

Notice that we don’t use a recurrent network. This is a simplification we can make because the observation contains all the information present in the environment’s state, so the closed-loop RNN is not needed. We also only make one-step predictions into the future and therefore do not need the open-loop RNN either. We now create a BYOL loss function that looks as follows:

def l2_norm_squared(arr, axis=-1):
    return jnp.sum(jnp.square(arr), axis=axis)


def byol_loss(pred_l_t, l_t):
    # The paper normalises both vectors first; we skip that here as we could
    # not get the normalised version to train reliably (see below).
    norm_pred_l_t = pred_l_t  # normalise(pred_l_t)
    norm_l_t = l_t  # normalise(l_t)

    # Squared L2 distance between the predicted and target latents
    return l2_norm_squared(norm_pred_l_t - norm_l_t)

We take in a predicted latent vector and a ground-truth latent observation vector (with a stop gradient already applied to it) and calculate the squared L2 norm of their difference. The only difference between this implementation and the equation above is the vector normalisation, which we could not get to work properly. Feel free to try the code out for yourself and see if you can get it to work. I would much appreciate some feedback on this.
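For completeness, the normalised version implied by the equation above would look roughly like the sketch below, where the small epsilon is an assumption to avoid division by zero. This is the variant we could not get to train stably:

def normalise(arr, axis=-1, eps=1e-8):
    # Scale each vector to unit L2 norm; eps guards against division by zero
    return arr / (jnp.linalg.norm(arr, axis=axis, keepdims=True) + eps)


def byol_loss_normalised(pred_l_t, l_t):
    # Squared L2 distance between the unit vectors
    return l2_norm_squared(normalise(pred_l_t) - normalise(l_t))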

The world model loss can now be calculated using:

# Defined inside the trainer's update step, where `self` and `train_states`
# (holding the target encoder's parameters) are in scope.
def _wm_loss_fn(online_params, world_model_params, traj_batch):
    # Rerun the networks on the stored trajectory
    l_tm1 = self._online_encoder.apply(online_params, traj_batch.obs)
    pred_l_t = self._world_model.apply(world_model_params, l_tm1, traj_batch.action)
    # Stop gradient: only the online encoder and world model receive updates
    l_t = jax.lax.stop_gradient(
        self._target_encoder.apply(train_states.target, traj_batch.next_obs)
    )
    return byol_loss(pred_l_t, l_t).mean()
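The same prediction error is what we hand to the agent as its intrinsic reward, computed per transition rather than averaged over the batch. A rough sketch, reusing the names from the snippet above without the final mean:

def _intrinsic_reward(online_params, world_model_params, traj_batch):
    l_tm1 = self._online_encoder.apply(online_params, traj_batch.obs)
    pred_l_t = self._world_model.apply(world_model_params, l_tm1, traj_batch.action)
    l_t = jax.lax.stop_gradient(
        self._target_encoder.apply(train_states.target, traj_batch.next_obs)
    )
    # One reward per transition: the harder a transition is to predict,
    # the larger the reward the policy receives for reaching it.
    return byol_loss(pred_l_t, l_t)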

This prediction error is the reward we provide to our agent. When we now train the agent on the Maze environment, we again see that it learns to almost solve it. In the Maze environment, a score of 1.0 means the agent always reaches the target before the maximum step limit. As can be seen in the figure below, the agent achieves a score of around 0.8.

Environment episode returns over training steps.

We can also visualise the agent’s performance after training. It is nowhere near perfect, but this result is pretty interesting given that we never provided the agent any external reward!

A trained policy on Jumanji’s Maze environment.

Why does our agent converge to solving the environment at all? This can be explained by looking again at what the world model is trying to predict: future latent representations, from past representations and the action that was taken. Within an episode these are quite easy to predict, as our environment is deterministic. The only thing the world model cannot accurately predict is the transition between episodes, since each episode’s maze is randomly generated. Our policy is therefore incentivised to end episodes as fast as possible to receive the high rewards associated with end-of-episode transitions. Feel free to try the code out yourself.

We have now mostly solved the Maze environment, so it is time to turn our attention to more challenging environments. In more complicated environments, the within-episode transition dynamics are not always deterministic, and we don’t want our policy to get stuck on sequences whose random dynamics the world model cannot possibly model. We will therefore next investigate methods that let our agent learn to ignore pure randomness and instead focus on areas where it can improve. In the next blog post we will look at BYOL-Hindsight, a more advanced curiosity-based algorithm, and attempt to solve a simple 2D Minecraft environment.
