Coding PPO from Scratch with PyTorch (Part 3/4)

Eric Yang Yu · Published in Analytics Vidhya · 11 min read · Sep 17, 2020
A roadmap of my 4-part series.

Welcome to Part 3 of our series, where we will finish coding Proximal Policy Optimization (PPO) from scratch with PyTorch. If you haven’t read Part 1 and Part 2, please do so first.

Here is the code: https://github.com/ericyangyu/PPO-for-Beginners

We will be picking up from where we left off: Step 5. Here’s an overview of the pseudocode again, which can be found here as well:

Pseudocode of PPO on OpenAI’s Spinning Up doc.

Before we continue, note that we will be doing multiple epochs per iteration for Steps 6 and 7. We will add the number of epochs, n, as a hyperparameter in _init_hyperparameters later. We do this because there is a k subscript on the parameters Θ and Φ in Steps 5–7, which distinguishes the parameters at the k-th iteration from the parameters at the current moment in training; in other words, each iteration runs its own set of epochs (a rough sketch of this structure follows below). We’ll see this again once we get to those steps and write code for them.
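
To make that concrete, here is a rough, hypothetical outline of the loop structure the k subscript implies (just an illustration; the real hyperparameter, n_updates_per_iteration, is defined properly later):

# Hypothetical outline of the iteration/epoch structure (not the series' actual code).
n_updates_per_iteration = 5            # "n" epochs per iteration

for k in range(100):                   # each PPO iteration k
    # collect a rollout with the frozen parameters theta_k, phi_k,
    # then compute advantages once using V_{phi_k}
    for epoch in range(n_updates_per_iteration):
        # update theta and phi repeatedly against that same, fixed rollout
        pass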

Let’s look at Step 5 now.

Step 5

We will use the advantage function defined here. TL;DR here’s the equation, modified with Vᵩₖ:

Advantage function: A^π(s, a) = Q^π(s, a) − Vᵩₖ(s)

where Q^π(s, a) is the Q-value of the state-action pair (s, a), and Vᵩₖ(s) is the value of observation s as predicted by our critic network with parameters Φ at the k-th iteration.

A modification we made from the formula here was to specify that the predicted value follows parameters Φ at the k-th iteration. This matters because later, in Step 7, we’ll need to recalculate V(s) following parameters Φ at the i-th epoch. However, the Q-values are determined right after each rollout, and Vᵩₖ(s) must be determined before we perform multiple updates to our networks (otherwise Vᵩₖ(s) would change as we update our critic network, which proves unstable and is inconsistent with the pseudocode). So both Vᵩₖ(s) and the advantage must be calculated before our epoch loop.

The way we’ll approach this is by individually calculating our Q-values and predicted values Vᵩₖ(s) with subroutines. We already have our Q-values calculated with compute_rtgs, so all we need to worry about is Vᵩₖ(s).

Let’s create a function evaluate to calculate Vᵩₖ(s).

def evaluate(self, batch_obs):
    # Query critic network for a value V for each obs in batch_obs.
    V = self.critic(batch_obs).squeeze()
    return V

Note that we perform a squeeze operation on the tensor returned by the forward pass through our critic network. If you don’t know what it does, it removes the dimensions of size 1 from a tensor. For example, calling squeeze on [[1], [2], [3]] will return [1, 2, 3]. Since batch_obs has the shape (timesteps per batch, dimension of observation), the tensor returned from passing batch_obs into our critic network has shape (timesteps per batch, 1), whereas the shape we want is just (timesteps per batch); squeeze will do the trick. Here’s some documentation on squeeze in case you want to dig further.
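
A quick sanity check you can run in a Python shell (just an illustration, not part of ppo.py):

import torch

# A toy check of what squeeze does to the critic output shape.
V = torch.tensor([[1.0], [2.0], [3.0]])   # shape (3, 1), like (timesteps per batch, 1)
print(V.shape)                            # torch.Size([3, 1])
print(V.squeeze().shape)                  # torch.Size([3])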

Next, we can simply calculate advantages:

# Calculate V_{phi, k}
V = self.evaluate(batch_obs)
# ALG STEP 5
# Calculate advantage
A_k = batch_rtgs - V.detach()

Note that we call V.detach(), since V requires gradients (it carries the critic’s computation graph). However, the advantage will be reused in each pass of the epoch loop, and the computation graph associated with the advantage at the k-th iteration is not useful across multiple epochs of stochastic gradient ascent.

Now, one of the only tricks we use in this code: advantage normalization. Through trial and error, I found that using the raw advantage makes PPO training highly unstable (yes, even more unstable and higher variance than our graphs in Part 1 depict). Though normalizing the advantage isn’t in the pseudocode, in practice it’s extremely important, since numerical optimization algorithms behave poorly when their inputs differ wildly in scale. I was going to save advantage normalization for Part 4, since it’s technically an optimization, but I found it almost necessary to have in the code in order to maintain a reasonable level of performance with PPO. So, here it is:

# Normalize advantages
A_k = (A_k - A_k.mean()) / (A_k.std() + 1e-10)

Note that we add 1e-10 to the standard deviation of the advantages, just to avoid the possibility of dividing by 0.

Here’s the code so far:
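
Pieced together from the snippets above, the relevant part of learn now looks roughly like this (a sketch, not the exact gist):

# A sketch of learn() at this point, assembled from the snippets above.
while t_so_far < total_timesteps:
    # ALG STEP 3: collect a batch of data with the current policy
    batch_obs, batch_acts, batch_log_probs, batch_rtgs, batch_lens = self.rollout()

    # Calculate V_{phi, k} using the critic as it is before any updates
    V = self.evaluate(batch_obs)

    # ALG STEP 5: calculate and normalize advantages
    A_k = batch_rtgs - V.detach()
    A_k = (A_k - A_k.mean()) / (A_k.std() + 1e-10)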

Let’s look at Step 6 now.

Step 6.

Ah yes, the elephant in the room. The lifeblood of PPO. This formula tells us how we’ll update the parameters Θ of our actor network. Luckily for us, most of what we need is either already calculated or can be calculated with existing subroutines.
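
For reference, here is that clipped-surrogate update written out in Spinning Up’s notation (my reconstruction, with the clip made explicit):

$$
\theta_{k+1} = \arg\max_{\theta} \frac{1}{|\mathcal{D}_k| T} \sum_{\tau \in \mathcal{D}_k} \sum_{t=0}^{T} \min\!\left( \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_k}(a_t \mid s_t)}\, A^{\pi_{\theta_k}}(s_t, a_t),\ \mathrm{clip}\!\left(\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_k}(a_t \mid s_t)},\, 1-\epsilon,\, 1+\epsilon\right) A^{\pi_{\theta_k}}(s_t, a_t) \right)
$$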

First, let’s address the ratio in the left surrogate function.

Ratio of action probabilities with parameters Θ over action probabilities with parameters Θₖ.

Ok, so we’ll need to calculate the log probabilities (log probs) of the actions we took during our most recent rollout. Again, for why we work with log probs instead of raw action probabilities, here is a resource that explains why and here is another.

The bottom set of log probs will be with respect to parameters Θ at the k-th iteration (which we already have with batch_log_probs), while the top is just at the current epoch (the original pseudocode also assumes multiple epochs). Instead of defining a whole new subroutine to calculate log probs, let’s do it in evaluate.

First, let’s fix evaluate to return log probs of actions as well.

def evaluate(self, batch_obs, batch_acts):
    ...
    # Calculate the log probabilities of batch actions using the most
    # recent actor network.
    # This segment of code is similar to that in get_action()
    mean = self.actor(batch_obs)
    dist = MultivariateNormal(mean, self.cov_mat)
    log_probs = dist.log_prob(batch_acts)

    # Return predicted values V and log probs log_probs
    return V, log_probs

Next, let’s fix the evaluate call we made earlier to unpack an extra return value for calculating advantage.

V, _ = self.evaluate(batch_obs, batch_acts)
A_k = batch_rtgs - V.detach()
...

Now, let’s start up our epoch loop to perform multiple updates on our actor and critic networks. The number of epochs is a hyperparameter, so we can add that to _init_hyperparameters as well.

for _ in range(self.n_updates_per_iteration):
    # epoch code
    ...

def _init_hyperparameters(self):
    ...
    self.n_updates_per_iteration = 5

Note that I chose 5 arbitrarily. Now notice again in the ratio formula:

Ratio of action probabilities with parameters Θ over action probabilities with parameters Θₖ.

We already have the bottom log probs (batch_log_probs). We just need to find π_Θ (aₜ | sₜ), which again we can use evaluate for. Note that this is the second time we call evaluate, and this time it will be inside the epoch loop. The evaluate call we made earlier, which only extracted V for calculating the advantage, sits right before this epoch loop.

for _ in range(self.n_updates_per_iteration):
    # Calculate pi_theta(a_t | s_t)
    _, curr_log_probs = self.evaluate(batch_obs, batch_acts)

Now, since both batch_log_probs and curr_log_probs are log probs, we can simply subtract them and exponentiate with e to cancel out the log. Cool little precalculus trick. Note that curr_log_probs is not detached, meaning it has a computation graph associated with it, which we’ll want to include as part of our back propagation when calculating gradients later. This is the start of our computation graph.
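
Written out, the identity we’re relying on is just:

$$
\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_k}(a_t \mid s_t)} = \exp\Big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_k}(a_t \mid s_t)\Big)
$$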

# Calculate ratios
ratios = torch.exp(curr_log_probs - batch_log_probs)

Now, let’s calculate our surrogate losses. Surrogate losses are just the two losses that we will be taking the minimum of in Step 6. The first surrogate loss uses the raw ratios to compute ratios * advantages, whereas the second surrogate loss clips the ratios to make sure we are not stepping too far in any direction during gradient ascent. This should be very easy now that we’ve found all the small parts of our formula.

# Calculate surrogate losses
surr1 = ratios * A_k
surr2 = torch.clamp(ratios, 1 - self.clip, 1 + self.clip) * A_k
...

def _init_hyperparameters(self):
    ...
    self.clip = 0.2 # As recommended by the paper

I use torch.clamp, which bounds its first argument (here, ratios) between its second and third arguments, used as the lower and upper bounds respectively. For some documentation on torch.clamp, here you go.
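
A quick, standalone illustration of how the clipping behaves with clip = 0.2 (not part of ppo.py):

import torch

# Ratios below 0.8 get pulled up to 0.8; ratios above 1.2 get pulled down to 1.2.
ratios = torch.tensor([0.5, 0.9, 1.0, 1.3])
print(torch.clamp(ratios, 1 - 0.2, 1 + 0.2))  # tensor([0.8000, 0.9000, 1.0000, 1.2000])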

Finally, we calculate our entire actor loss.

actor_loss = (-torch.min(surr1, surr2)).mean()

Let’s take a step back. We take the minimum of the two surrogate losses, as per the pseudocode. The negative sign is there because we’re trying to maximize this objective (the performance function) through stochastic gradient ascent, but the optimizer we’ll be using, Adam, minimizes losses; minimizing the negative objective therefore maximizes the objective. We then take the mean to reduce everything to a single scalar loss.

Let’s do back propagation on our actor network. First, we’ll need to define Adam optimizer for our actor parameters. Let’s do that in __init__.

from torch.optim import Adam

class PPO:
    def __init__(self, env):
        ...
        self.actor_optim = Adam(self.actor.parameters(), lr=self.lr)

    def _init_hyperparameters(self):
        ...
        self.lr = 0.005

Again, lr, or learning rate, is arbitrarily defined. Let’s now do our back propagations and perform one epoch on our actor network.

# Calculate gradients and perform backward propagation for actor 
# network
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()

And that’s it for the dreaded Step 6! Here’s the code so far:

__init__
learn
evaluate, _init_hyperparameters

Let’s go on to the penultimate step, Step 7.

Step 7.

This might look really scary at first, but we’re really just updating our critic parameters with the mean squared error between the values predicted at the current epoch, Vᵩ(sₜ), and the rewards-to-go. We’ll use a class from torch.nn that calculates MSE for us, torch.nn.MSELoss; here’s some documentation. If you want, you can also write your own MSE function instead of using torch’s, which shouldn’t be too hard (a quick sketch follows below).
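
In case you do want to roll your own, here’s a minimal sketch of an MSE function that matches nn.MSELoss() with its default 'mean' reduction (just an illustration; the rest of this series sticks with torch’s version):

import torch

def mse_loss(pred, target):
    # Mean of the squared differences, same as nn.MSELoss() with reduction='mean'
    return ((pred - target) ** 2).mean()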

First, let’s define another Adam optimizer for our critic network.

self.critic_optim = Adam(self.critic.parameters(), lr=self.lr)

Then, we need to calculate Vᵩ(sₜ) and rewards-to-go. Luckily for us, rewards-to-go are already calculated with batch_rtgs. Even luckier for us, we can find Vᵩ(sₜ) with a single change to our existing code: where we call evaluate in the epoch loop, just retain the V returned instead of ignoring it with _.

# Calculate V_phi and pi_theta(a_t | s_t)    
V, curr_log_probs = self.evaluate(batch_obs, batch_acts)

Last but not least, calculate the MSE of predicted values and rewards-to-go, and back propagate on the critic network.

# Calculate critic loss; this requires `import torch.nn as nn` at the top of ppo.py
critic_loss = nn.MSELoss()(V, batch_rtgs)

# Calculate gradients and perform backward propagation for critic network
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()

Note that since we’re doing a second backward propagation on our computation graphs, and both the actor and critic loss computation graphs converge a bit up the graph, we’ll need to add a retain_graph=True to backward for either actor or critic (depends on which we back propagate on first). Otherwise, we’ll get an error saying that the buffers have already been freed when trying to run backward through the graph a second time.
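
In code, that just means passing retain_graph=True to whichever backward call runs first; assuming the actor is updated before the critic, as in the snippets above, that looks like this:

# Actor update runs first, so keep the graph around for the critic's backward pass
self.actor_optim.zero_grad()
actor_loss.backward(retain_graph=True)
self.actor_optim.step()

# Critic update runs second and can free the graph as usual
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()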

And that’s it for Step 7! Here’s the code so far:

__init__
learn
Rest of the code, unchanged since Step 6.
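
For orientation, here’s roughly what the full epoch loop inside learn looks like with Steps 6 and 7 in place (a sketch pieced together from the snippets above, not a copy of the gists):

for _ in range(self.n_updates_per_iteration):
    # Calculate V_phi and pi_theta(a_t | s_t) under the current parameters
    V, curr_log_probs = self.evaluate(batch_obs, batch_acts)

    # Ratio pi_theta(a_t | s_t) / pi_theta_k(a_t | s_t)
    ratios = torch.exp(curr_log_probs - batch_log_probs)

    # Surrogate losses and actor loss (Step 6)
    surr1 = ratios * A_k
    surr2 = torch.clamp(ratios, 1 - self.clip, 1 + self.clip) * A_k
    actor_loss = (-torch.min(surr1, surr2)).mean()

    # Critic loss (Step 7)
    critic_loss = nn.MSELoss()(V, batch_rtgs)

    # Actor update (retain_graph so the critic's backward pass still works)
    self.actor_optim.zero_grad()
    actor_loss.backward(retain_graph=True)
    self.actor_optim.step()

    # Critic update
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()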

Now, for the finale… Step 8.

Step 8.

Yes, quite possibly the hardest step up to this point. Just wait until you see Step 9, though; it’s quite the elusive one.

On a more serious note, we can’t forget to increment the t_so_far that we set up at the beginning of Part 2, in order to keep track of how many timesteps we’ve simulated so far and know when to stop training. Let’s do that by taking advantage of the batch_lens returned from rollout.

import numpy as np

...

# ALG STEP 2
while t_so_far < total_timesteps:
    ...
    batch_obs, batch_acts, batch_log_probs, batch_rtgs, batch_lens = self.rollout()

    # Calculate how many timesteps we collected this batch
    t_so_far += np.sum(batch_lens)

And now you have a fully functional bare-bones PPO! To test if everything is ok up to this point, you can run this snippet of code at the bottom of your ppo.py:

import gym
env = gym.make('Pendulum-v0')
model = PPO(env)
model.learn(10000)

If the program runs without any errors (it should take about 10 seconds), you’re golden. If there are any issues or mistakes up to this point, please let me know. You can also cross-reference what you have right now with the main PPO for Beginners code.

Now you might be wondering: Eric, you only wrote ppo.py and network.py, but in Part 1 you also had main.py, arguments.py, and eval_policy.py. You also had so many more features like logging, saving the actor and critic networks, parsing command line arguments, custom hyperparameters, and more. Also, the PPO for Beginners code looks a bit different from the screenshots above.

You are absolutely right. However, to keep things simple, I will not discuss how I coded those parts in this series, as it is irrelevant to learning how to write a bare-bones PPO implementation with PyTorch. I promise that all the extra stuff is sprinkled on top of the bare-bones PPO, which acts as a foundation for the rest of the code.

Instead, I encourage you to explore hands-on how the extra features work in sync with the bare-bones PPO implementation. I highly recommend using pdb, the Python debugger, to step through my code starting from if __name__ == '__main__': in main.py; a quick example follows below. If you don’t know how to use pdb, get started quickly here. If you’re an expert with pdb already, here’s the documentation.
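
For instance, you can launch the whole program under the debugger with python -m pdb main.py, or drop a breakpoint wherever you’re curious (purely illustrative; not part of the repository’s code):

# Place this line anywhere you want execution to pause, e.g. at the top of learn().
# At the (Pdb) prompt: `n` steps to the next line, `s` steps into calls, `c` continues.
import pdb; pdb.set_trace()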

The repository README.md contains instructions on how to run the code under “Usage”. All my files and code are styled in great detail and well documented, so you can explore them to your heart’s content. I structured them in the hopes of making things as simple and modular as possible.

This concludes Part 3 of our series; at this point, you should have completely implemented PPO from the pseudocode and should be able to achieve the kind of performance shown in the graphs in Part 1. I know this was a lot of material to digest, so if you have any questions, don’t hesitate to contact me at eyyu@ucsd.edu or just comment below.

In Part 4, we will explore some of the optimizations we can perform on the bare-bones PPO to increase performance and reduce variance. Hope this series has been helpful to you so far!
