Curious Agents IV: BYOL-Hindsight

Dries Smit
InstaDeep
Jul 11, 2023

Welcome back! In the previous three posts (here, here and here) we looked at methods that allow agents to learn in environments without requiring any external rewards. Our agents learn by simply being curious. In the last post, we used DeepMind’s BYOL-Explore algorithm to solve Jumanji’s Maze environment with a solve rate of over 80%, and this was achieved without the agent ever being told the goal of the environment.

Why did our agent learn to move to the Maze’s end goal in the first place? The policy’s objective, in the BYOL-Explore setup, is to maximise the world model’s loss over observation transitions. It therefore tries to find transitions in an episode that are hard to predict. This seems like a logical thing to do, as we want our policy to visit new areas that the world model has not yet experienced. This setup works well in the Maze environment because the highest-entropy observation transitions occur when the environment resets. When the environment is reset, a new random map is generated that cannot be predicted from the previous observation and action. Therefore the world model will almost always have a high loss value there.
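To make that objective concrete, here is a minimal sketch of a BYOL-Explore-style intrinsic reward (my own simplified version, not the exact code from the previous post): the policy is rewarded with the world model’s prediction error on each observation transition.

import jax.numpy as jnp

def byol_loss(pred_x_t, x_t):
    # Squared distance between L2-normalised latent vectors (a common
    # BYOL-style loss; the exact form used in the previous post may differ).
    pred_norm = pred_x_t / (jnp.linalg.norm(pred_x_t, axis=-1, keepdims=True) + 1e-8)
    target_norm = x_t / (jnp.linalg.norm(x_t, axis=-1, keepdims=True) + 1e-8)
    return jnp.sum((pred_norm - target_norm) ** 2, axis=-1)

def intrinsic_reward(pred_x_t, x_t):
    # The harder a transition is to predict, the larger the policy's reward.
    return byol_loss(pred_x_t, x_t)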

This helps our agent solve the Maze environment, but it will not necessarily work in other environments. As an example, let us look at an open-world environment such as Minecraft. If we let BYOL-Explore run in this environment, it might just decide to kill itself as fast as possible (forcing an unpredictable reset) or latch onto highly random observation transitions in the environment. We certainly do not want that. Therefore we need to refine the objective that we want our agent to achieve. We want our agent to seek out novelty in its environment, but we don’t want it to get distracted by noise. This seems like a tricky thing to encode into a loss function. How can we get an agent to distinguish between novel and noisy transitions?

This is what researchers at DeepMind attempted to solve in their new paper titled “Curiosity in Hindsight”. In their abstract, they state:

Therefore it is important to distinguish between aspects of world dynamics that are inherently predictable (for which errors reflect epistemic uncertainty) and aspects that are inherently unpredictable (for which errors reflect aleatoric uncertainty): The former should constitute a source of intrinsic reward, whereas the latter should not.

The authors provided the following illustration to understand the problem at hand.

Structural Causal Model. Source.

Here solid squares denote deterministic nodes, the top circles denote observable stochastic nodes, and the bottom circles denote unobservable stochastic nodes (W is used to capture any randomness in the agent’s policy). A and X represent the action and a latent representation of the observation, respectively. The variable Z captures any stochastic noise in the environment.

We do not want our agent to get stuck on stochastic noise from W and Z; instead, it should focus on dynamics that it can model. Luckily, W stems from the agent’s own policy, so it is effectively known to the agent and does not have to be inferred. On the other hand, Z is encoded in the environment transition and is not directly provided to the agent. Therefore, if one tries to directly predict X_t+1 from A_t and X_t, as seen below, the prediction error can only ever go as low as the entropy of Z_t+1. This is not ideal and can cause the agent to get stuck.

World model prediction. Source.
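Informally, and in my own notation rather than the paper’s exact statement: for a log-likelihood style world model trained on (X_t, A_t) alone, the expected prediction loss is bounded below by the conditional entropy of the next state, and under the causal model above that residual uncertainty comes entirely from the noise Z_t+1.

\mathbb{E}\big[\mathcal{L}(\hat{X}_{t+1}, X_{t+1})\big] \;\ge\; H(X_{t+1} \mid X_t, A_t) \;>\; 0 \quad \text{whenever } Z_{t+1} \text{ affects } X_{t+1}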

The main claim that the authors make is that it should be possible to infer Z_t+1 in hindsight.

Source.

They propose training a separate generator network to predict Z_t+1 directly from X_t, A_t and X_t+1. We can then feed that latent value for Z_t+1 directly into the world model, which in turn predicts X_t+1. If the generator can, in hindsight, predict Z_t+1 accurately, the world model can predict the next latent representation of X accurately. Therefore the loss function of the world model can tend to zero.

However, as you probably noticed, there is an obvious problem with this setup. How do we train the generator to generate true values for Z_t+1? The main problem we started with was that we do not know what Z_t+1 is. We can try to learn Z_t+1 indirectly by simply allowing the world model’s loss to update the generator’s parameters as well.

Reconstruction loss. Source.

In the BYOL-Hindsight algorithm, this is called the reconstruction loss. There is, however, a second problem with this approach. If the world model loss can be used to adapt the generator’s parameters, it can effectively tune the generator to just leak information about X_t+1 and avoid actually modelling the transition function of the environment. This would defeat the whole purpose of the generator. We want the generator to capture any new stochastic information (Z_t+1) that is not inferable from X_t and A_t. To do this, we effectively want to enforce that Z_t+1 is independent of both X_t and A_t. If this is true, then the world model must still use X_t and A_t (along with Z_t+1) to accurately predict the next state X_t+1.

The authors propose introducing another network called the discriminator. The discriminator is tasked with outputting a high value for matched tuples (X_t, A_t, Z_t+1), where Z_t+1 is the generator’s output for the inputs (X_t, A_t, X_t+1). The loss function below takes the generator’s output for the current input and compares it against Z outputs from other timesteps and episodes.

Invariance loss. Source.

If Z_t+1 is truly independent of X_t and A_t, then the discriminator should not be able to assign a higher value to the generator’s output for that specific transition than to Z values taken from other transitions, given the same X and A.

The entire system can be described as follows. The discriminator attempts to maximise the above invariance loss. In turn, the generator attempts to minimise this invariance loss (keeping Z_t+1 independent of X_t and A_t) while also minimising the reconstruction loss (maximising the stochastic information captured in Z_t+1). The world model also attempts to minimise the reconstruction loss.
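Putting the pieces together in notation (my own summary, not the paper’s exact formulation), with f the world model, g the generator, D the discriminator, sg a stop-gradient on the target and λ a weighting coefficient:

\mathcal{L}_{\text{rec}} = \mathbb{E}\big[\ell\big(f(X_t, A_t, Z_{t+1}), \text{sg}(X_{t+1})\big)\big], \quad Z_{t+1} = g(X_t, A_t, X_{t+1})

\mathcal{L}_{\text{inv}} = \mathbb{E}\Big[\log \frac{D(X_t, A_t, Z_{t+1})}{\frac{1}{N}\sum_{j} D(X_t, A_t, Z_j)}\Big]

\text{Discriminator: } \max_{D} \mathcal{L}_{\text{inv}}, \qquad \text{Generator: } \min_{g} \mathcal{L}_{\text{rec}} + \lambda \mathcal{L}_{\text{inv}}, \qquad \text{World model: } \min_{f} \mathcal{L}_{\text{rec}}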

And that is it! There are many more proofs and technicalities in the paper, which I encourage you to read, but the basic idea is captured above. I think it is quite ingenious. An interesting fact that the authors prove is that both the reconstruction and invariance losses are driven to zero given infinite experience. Therefore the agent’s intrinsic reward for any given observation transition will tend to zero. This result has an important implication for our curiosity algorithm: our agent will eventually get bored of any observation transition in an environment, so it must keep exploring the environment to find new and interesting dynamics if it wants to keep receiving large non-zero rewards.

Let’s test this out in practice. This series was partly inspired by the open-world game Minecraft. Minecraft is quite computationally expensive to run, so let’s create a drastically simplified JAX-based Minecraft environment to play around with. We specifically want to test our agent’s ability to not get stuck on stochastic transitions. Therefore we should not rely on the Maze trick, where the most stochastic transitions happen to coincide with the point of optimal environment returns.

OpenAI blogpost. Source.

The above diagram shows a simple progression chart. A typical Minecraft environment has resources such as wooden logs (from chopping down trees), cobblestone, iron ore and diamond ore. We now adapt the Maze environment in Jumanji: in our Minecraft environment we randomly place Steve (the avatar that the agent controls) and 4 resource blocks on a 5 by 5 map, as shown below.

Our custom 2D Minecraft environment.

For each of the 11 possible levels we reward the agent with a value of 1. The agent levels are specified as follows:

WOODEN_LOG_LEVEL = 1 
WOODEN_PLANK_LEVEL = 2
WOODEN_STICK_LEVEL = 3
WOODEN_PICKAXE_LEVEL = 4
COBBLESTONE_LEVEL = 5
STONE_PICKAXE_LEVEL = 6
IRON_ORE_LEVEL = 7
IRON_IGNOT_LEVEL = 8
IRON_PICKAXE_LEVEL = 9
DIAMOND_ORE_LEVEL = 10
DIAMOND_PICKAXE_LEVEL = 11

The agent starts by moving to the wooden log that is located somewhere on the map. It will then upgrade to level 1 and receive a reward of 1. It then needs to output the wooden plank action, followed by the stick action and then the wooden pickaxe action, to progress further. There are 4 move actions and 7 build actions. The agent gets a certain time limit per level and the environment resets if the agent exceeds that time limit. In the above image you can sometimes see the agent reaching the wooden log and moving up the levels by randomly pressing the correct build actions.
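To make these mechanics concrete, below is a rough sketch of the per-step bookkeeping described above. The function and variable names are hypothetical and purely for illustration; the actual environment code differs.

import jax.numpy as jnp

def step_bookkeeping(old_level, new_level, steps_at_level, time_limit):
    leveled_up = jnp.greater(new_level, old_level)

    # The agent receives a reward of 1 every time it gains a level.
    reward = jnp.where(leveled_up, 1.0, 0.0)

    # The per-level timer resets on a level-up; the episode resets if the
    # agent exceeds the time limit for its current level.
    steps_at_level = jnp.where(leveled_up, 0, steps_at_level + 1)
    done = jnp.greater_equal(steps_at_level, time_limit)
    return reward, steps_at_level, done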

PPO can easily solve this problem when presented with the external rewards. But can our curiosity-based agent reach level 11 without external rewards? A randomly exploring agent never seems to reach level 11, as it is highly unlikely that random actions will get you there. Furthermore, this environment is set up in such a way that it is easy for the agent to trigger environment resets, and those reset observation transitions are highly stochastic. Therefore BYOL-Explore would fail to ever reach level 11, as it would get stuck simply maximising for environment resets. BYOL-Hindsight, however, is theoretically set up to get bored of these random transitions. Let’s see if that holds in practice. We start by adapting our world model to take in an additional variable, z.

from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.initializers import constant, orthogonal


class WorldModel(nn.Module):
    action_dim: Sequence[int]
    activation: str = "relu"

    @nn.compact
    def __call__(self, x_tm1, a_tm1, z_t):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh

        # One-hot encode the action
        one_hot_action = jax.nn.one_hot(a_tm1, self.action_dim)

        # Condition on the previous latent observation, the action and the
        # hindsight noise variable z_t
        inp = jnp.concatenate([z_t, x_tm1, one_hot_action], axis=-1)

        layer_out = nn.Dense(
            64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0)
        )(inp)
        layer_out = activation(layer_out)
        layer_out = nn.Dense(
            64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0)
        )(layer_out)
        layer_out = activation(layer_out)
        # Predict the next latent observation
        layer_out = nn.Dense(
            x_tm1.shape[-1], kernel_init=orthogonal(1.0), bias_init=constant(0.0)
        )(layer_out)
        return layer_out
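As a quick sanity check, the module can be initialised and applied like any other Flax module. The shapes below (a 64-dimensional latent observation and an 8-dimensional z) are assumptions purely for illustration; the 11 actions (4 move plus 7 build) match the environment described later in the post.

rng = jax.random.PRNGKey(0)
world_model = WorldModel(action_dim=11)  # 4 move actions + 7 build actions

# Dummy batch of 2 transitions (shapes are assumptions for illustration)
x_tm1 = jnp.zeros((2, 64))  # previous latent observations
a_tm1 = jnp.array([0, 3])   # actions taken
z_t = jnp.zeros((2, 8))     # hindsight noise variables

params = world_model.init(rng, x_tm1, a_tm1, z_t)
pred_x_t = world_model.apply(params, x_tm1, a_tm1, z_t)
print(pred_x_t.shape)  # (2, 64): predicted next latent observations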

Furthermore we create a generator that can generate this latent variable for us.

class Generator(nn.Module):
    z_dim: Sequence[int]
    action_dim: Sequence[int]
    activation: str = "relu"

    @nn.compact
    def __call__(self, x_tm1, a_tm1, x_t):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh

        # One-hot encode the action
        one_hot_action = jax.nn.one_hot(a_tm1, self.action_dim)

        # The generator sees the full transition, including the next latent
        # observation x_t, so it can infer the noise z_t in hindsight
        inp = jnp.concatenate([x_tm1, x_t, one_hot_action], axis=-1)

        layer_out = nn.Dense(
            64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0)
        )(inp)
        layer_out = activation(layer_out)
        layer_out = nn.Dense(
            64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0)
        )(layer_out)
        layer_out = activation(layer_out)
        layer_out = nn.Dense(
            self.z_dim, kernel_init=orthogonal(1.0), bias_init=constant(0.0)
        )(layer_out)

        # Apply a tanh to the output for stability
        layer_out = nn.tanh(layer_out)

        return layer_out

Lastly we add a discriminator model as follows.

class Discriminator(nn.Module):
    action_dim: Sequence[int]
    activation: str = "relu"

    @nn.compact
    def __call__(self, x_tm1, a_tm1, z_t):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh

        # One-hot encode the action
        one_hot_action = jax.nn.one_hot(a_tm1, self.action_dim)

        inp = jnp.concatenate([z_t, x_tm1, one_hot_action], axis=-1)

        layer_out = nn.Dense(
            64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0)
        )(inp)
        layer_out = activation(layer_out)
        layer_out = nn.Dense(
            64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0)
        )(layer_out)
        layer_out = activation(layer_out)
        layer_out = nn.Dense(
            1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)
        )(layer_out)

        # Bound the output to [-5, 5]
        layer_out = jnp.tanh(layer_out) * 5

        # Exponentiate the bounded output to get a strictly positive score
        layer_out = jnp.exp(layer_out)

        return layer_out

To calculate our policy’s reward, we now also use the generator to generate the latent variable z_t.

# Calculate the distance between the predicted and the actual observation
x_t = self._target_encoder.apply(train_states.target, obs)
z_t = self._generator.apply(train_states.generator.params, x_tm1, a_t, x_t)
pred_x_t = self._world_model.apply(train_states.world_model.params, x_tm1, a_t, z_t)

# Set the reward to be the distance between the predicted and the actual observation
reward = byol_loss(pred_x_t, x_t)

Similarly, our world model loss function also takes as input the generator’s output. The generator’s loss is defined below.

def _generator_loss_fn(generator_params, traj_batch):
    a_tm1 = traj_batch.action
    o_tm1 = traj_batch.obs
    o_t = traj_batch.next_obs
    x_tm1 = jax.lax.stop_gradient(
        self._online_encoder.apply(train_states.online.params, o_tm1)
    )
    x_t = jax.lax.stop_gradient(
        self._target_encoder.apply(train_states.target, o_t)
    )
    z_t = self._generator.apply(generator_params, x_tm1, a_tm1, x_t)

    # Calculate the world model (reconstruction) loss
    pred_x_t = self._world_model.apply(train_states.world_model.params, x_tm1, a_tm1, z_t)
    wm_loss = byol_loss(pred_x_t, x_t).mean()

    # Calculate the discriminator's loss for the generated z values
    disc_loss = calc_disc_loss(train_states.discriminator.params, x_tm1, a_tm1, z_t)

    # Minimise the reconstruction loss while maximising the discriminator's loss
    gen_loss = wm_loss - self._config["DISC_IMP_COEF"] * disc_loss
    return gen_loss

As can be seen, the generator attempts to minimise the world model loss and maximise the discriminator’s loss function. The discriminator’s loss function is defined as:

def calc_disc_loss(params, x_tm1, a_tm1, z_t):
    # Score a single (x, a) pair against every z in the batch
    entry_fn = jax.vmap(self._discriminator.apply, in_axes=(None, None, None, 0))

    # Repeat for every (x, a) pair in the batch
    batch_fn = jax.vmap(entry_fn, in_axes=(None, 0, 0, None))

    scores = batch_fn(params, x_tm1, a_tm1, z_t)

    # The diagonal of the score matrix holds the matched (x, a, z) tuples
    squeezed_scores = jnp.squeeze(scores)
    diag_scores = jnp.diag(squeezed_scores)

    # Compare each matched score against the average score over all z values
    ratios = diag_scores / (jnp.sum(squeezed_scores, axis=-1) / len(scores[0]))

    log_ratios = jnp.log(ratios)

    return -jnp.mean(log_ratios)

The discriminator attempts to maximise its output score for Z values that followed from X and A, relative to Z values observed at other timesteps and in other episodes. To sample these other Z values, we train on a batch of experience. We calculate output scores for each Z value in the batch with respect to every other (X, A) pair. We then calculate the loss by dividing the score of each matched pair by the average score over all Z values in the batch.
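To make the double vmap concrete: for a batch of size B, batch_fn produces a B-by-B grid of scores in which entry (i, j) scores the pairing of (X_i, A_i) with Z_j, so the matched pairs sit on the diagonal. Below is a tiny sketch using a stand-in scoring function (not the actual discriminator).

import jax
import jax.numpy as jnp

def fake_score(params, x, a, z):
    # Stand-in with the same call signature as self._discriminator.apply.
    return jnp.exp(-jnp.sum((x - z) ** 2, keepdims=True))

entry_fn = jax.vmap(fake_score, in_axes=(None, None, None, 0))
batch_fn = jax.vmap(entry_fn, in_axes=(None, 0, 0, None))

B, D = 4, 3
x = jnp.arange(B * D, dtype=jnp.float32).reshape(B, D)
a = jnp.zeros((B,))
z = x + 0.1  # each matched z is close to its own x

scores = batch_fn(None, x, a, z)
print(scores.shape)                   # (4, 4, 1)
print(jnp.diag(jnp.squeeze(scores)))  # matched pairs lie on the diagonal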

Looking at the results below, we can see that BYOL-Hindsight takes about 6 million steps to achieve an average score of around 5.5, which is halfway toward the maximum possible reward. However, if we look at the maximum returns over a batch of experience, we can see that BYOL-Hindsight successfully reaches the maximum reward. The world model therefore keeps providing intrinsic rewards for reaching higher levels, right up until the agent reaches the maximum level. Once the agent has reached the maximum level a certain number of times, the world model becomes good enough at predicting those dynamics, and the agent is less incentivised to focus only on the higher levels.

BYOL-Hindsight results.

We can now also visualise the trained BYOL-Hindsight agent. As the agent progresses in the environment, the latest tool in its inventory can be seen in the bottom left corner of its avatar.

Trained agent on our Minecraft 2D environment

And that is it! We successfully reached the maximum level of our simple 2D Minecraft environment. Feel free to try the code out for yourself! Happy coding!
