Continuous-Action Reinforcement Learning with Inverted Pendulum Using PyTorch — Part 3

Ankush k Singal · Published in AI Artistry · Sep 20, 2023

Source: Continuous Action Reinforcement Learning

Introduction

Continuous-action reinforcement learning (CARL) is a branch of machine learning that teaches agents to make decisions in environments with continuous action spaces. This matters because many real-world problems, such as autonomous driving, robotics, and financial trading, involve continuous actions.

Example

Imagine an agent that is learning to control a robotic arm. The agent has two continuous actions: the angle of the arm and the force applied to the arm. The agent’s goal is to move an object to a specific location.

The agent starts by randomly exploring the action space. It tries different angles and forces, and observes the results. The agent receives a reward for moving the object closer to the goal location, and a penalty for moving the object further away.

Over time, the agent learns to select actions that lead to higher rewards. It learns that certain angles and forces are more effective than others. Eventually, the agent is able to move the object to the goal location consistently.
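To make the idea of a continuous action space concrete, here is a small illustrative snippet. The bounds are made up for this hypothetical arm and are not part of the article's code:

import numpy as np
from gym.spaces import Box

# Hypothetical 2-D continuous action space for the robotic arm:
# action[0] = joint angle in radians, action[1] = applied force in newtons.
arm_action_space = Box(low=np.array([-3.14, 0.0]), high=np.array([3.14, 10.0]), dtype=np.float32)

print(arm_action_space.sample())  # e.g. array([1.27, 4.8], dtype=float32) — any real values in range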

CARL Algorithms

There are a number of different CARL algorithms that have been developed. Some popular algorithms include:

  • Actor-critic methods: A widely used approach to CARL that combines two components: an actor that selects actions and a critic that evaluates them. The actor updates its policy based on the critic's feedback.
  • Policy gradient methods: These optimize the policy directly, using the gradient of the expected return with respect to the policy parameters to update it (see the sketch after this list).
  • Q-learning: A value-based RL algorithm that can be extended to continuous action spaces using function approximation techniques.
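
To make the policy gradient idea concrete, here is a minimal, self-contained sketch of the score-function (REINFORCE) estimator for a Gaussian policy. The network sizes and the toy returns are illustrative and not part of the article's code:

import torch
from torch.distributions import Normal

# Toy Gaussian policy: maps a 4-dimensional state to the mean of a 1-D action.
policy = torch.nn.Sequential(torch.nn.Linear(4, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
log_std = torch.zeros(1, requires_grad=True)

states = torch.randn(8, 4)            # a small batch of made-up states
action_mean = policy(states)
dist = Normal(action_mean, log_std.exp())
actions = dist.sample()               # sampled continuous actions
returns = torch.randn(8, 1)           # placeholder returns, normally computed from rewards

# Score-function estimator: push up the log-probability of actions in proportion to their return.
loss = -(dist.log_prob(actions) * returns).mean()
loss.backward()                       # gradients flow into the policy parameters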

Implementing the Inverted Pendulum Using PyTorch

!apt-get install -y xvfb

!pip install -q gym==0.23.1 \
pytorch-lightning==1.6 \
pyvirtualdisplay

!pip install -Uq brax==0.0.12 jax==0.3.14 jaxlib==0.3.14+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


import warnings
warnings.filterwarnings('ignore')

Set up the virtual display

from pyvirtualdisplay import Display
Display(visible=False, size=(1400, 900)).start()

Import the necessary libraries

import copy
import torch
import random
import gym
import matplotlib
import functools
import math

import numpy as np
import matplotlib.pyplot as plt

import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from torch.distributions import Normal

from pytorch_lightning import LightningModule, Trainer

import brax
from brax import envs
from brax.envs import to_torch
from brax.io import html

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()

# Allocate a small tensor so PyTorch initializes the device before Brax/JAX claims it.
v = torch.ones(1, device=device)
@torch.no_grad()
def plot_policy(policy):
    # Build a grid of cart positions and pole angles; pair them with small random velocities.
    pos = np.linspace(-4.8, 4.8, 100)
    vel = np.random.random(size=(10000, 1)) * 0.1
    ang = np.linspace(-0.418, 0.418, 100)
    ang_vel = np.random.random(size=(10000, 1)) * 0.1

    g1, g2 = np.meshgrid(pos, ang)
    grid = np.stack((g1, g2), axis=-1)
    grid = grid.reshape(-1, 2)
    grid = np.hstack((grid, vel, ang_vel))

    grid = torch.from_numpy(grid).float()
    loc, _ = policy(grid)

    # Plot the mean action predicted by the policy for each (position, angle) pair.
    plot_vals = loc.numpy()
    plot_vals = plot_vals.reshape(100, 100)[::-1]

    plt.figure(figsize=(8, 8))
    plt.imshow(plot_vals, cmap='coolwarm')
    plt.colorbar()
    plt.clim(-1, 1)
    plt.title("Mean action μ(s)", size=20)
    plt.xlabel("Cart Position", size=14)
    plt.ylabel("Pole angle", size=14)
    plt.xticks(ticks=[0, 50, 100], labels=['-4.8', '0', '4.8'])
    plt.yticks(ticks=[100, 50, 0], labels=['-0.418', '0', '0.418'])


@torch.no_grad()
def create_video(env, episode_length, policy=None):
    # Roll out one episode and record the simulator state at each step.
    qp_array = []
    state = env.reset()
    for i in range(episode_length):
        if policy:
            loc, scale = policy(state)
            sample = torch.normal(loc, scale)
            action = torch.tanh(sample)
        else:
            action = env.action_space.sample()
        state, _, _, _ = env.step(action)
        qp_array.append(env.unwrapped._state.qp)
    return HTML(html.render(env.unwrapped._env.sys, qp_array))


@torch.no_grad()
def test_agent(env, episode_length, policy, episodes=10):
    ep_returns = []
    for ep in range(episodes):
        state = env.reset()
        done = False
        ep_ret = 0.0

        while not done:
            loc, scale = policy(state)
            sample = torch.normal(loc, scale)
            action = torch.tanh(sample)
            state, reward, done, info = env.step(action)
            ep_ret += reward.item()

        ep_returns.append(ep_ret)

    return sum(ep_returns) / episodes

Create the policy

class GradientPolicy(nn.Module):

    def __init__(self, in_features, out_dims, hidden_size=128):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc_mu = nn.Linear(hidden_size, out_dims)
        self.fc_std = nn.Linear(hidden_size, out_dims)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Mean of the Gaussian, squashed to [-1, 1].
        loc = self.fc_mu(x)
        loc = torch.tanh(loc)
        # Standard deviation, kept strictly positive.
        scale = self.fc_std(x)
        scale = F.softplus(scale) + 0.001
        return loc, scale
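
As a quick sanity check (not from the article), the policy can be queried with a random observation. The inverted pendulum used here has 4 observation dimensions and 1 action dimension, matching the plot above:

# A minimal usage sketch of GradientPolicy.
policy = GradientPolicy(in_features=4, out_dims=1)
obs = torch.randn(2, 4)                        # a made-up batch of two observations
loc, scale = policy(obs)                       # mean in [-1, 1], std > 0
action = torch.tanh(torch.normal(loc, scale))  # same sampling scheme used in test_agent above
print(loc.shape, scale.shape, action.shape)    # each is torch.Size([2, 1])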

Create the environment

class RunningMeanStd:
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = torch.zeros(shape, dtype=torch.float32).to(device)
        self.var = torch.ones(shape, dtype=torch.float32).to(device)
        self.count = epsilon

    def update(self, x):
        batch_mean = torch.mean(x, dim=0)
        batch_var = torch.var(x, dim=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )


def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + torch.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count


class NormalizeObservation(gym.core.Wrapper):

    def __init__(self, env, epsilon=1e-8):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape[-1])
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        obs = self.normalize(obs)
        return obs, rews, dones, infos

    def reset(self, **kwargs):
        return_info = kwargs.get("return_info", False)
        if return_info:
            obs, info = self.env.reset(**kwargs)
        else:
            obs = self.env.reset(**kwargs)
        obs = self.normalize(obs)
        if not return_info:
            return obs
        else:
            return obs, info

    def normalize(self, obs):
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / torch.sqrt(self.obs_rms.var + self.epsilon)


class NormalizeReward(gym.core.Wrapper):

    def __init__(self, env, gamma=0.99, epsilon=1e-8):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.return_rms = RunningMeanStd(shape=())
        self.returns = torch.zeros(self.num_envs).to(device)
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        dones = dones.bool()
        self.returns = self.returns * self.gamma + rews
        rews = self.normalize(rews)
        self.returns[dones] = 0.0
        return obs, rews, dones, infos

    def normalize(self, rews):
        self.return_rms.update(self.returns)
        return rews / torch.sqrt(self.return_rms.var + self.epsilon)
entry_point = functools.partial(envs.create_gym_env, env_name='inverted_pendulum')
gym.register('brax-inverted_pendulum-v0', entry_point=entry_point)


def create_env(env_name, num_envs=256, episode_length=1000):
    env = gym.make(env_name, batch_size=num_envs, episode_length=episode_length)
    env = to_torch.JaxToTorchWrapper(env, device=device)
    env = NormalizeObservation(env)
    env = NormalizeReward(env)
    return env


env = gym.make('brax-inverted_pendulum-v0', episode_length=1000)
env = to_torch.JaxToTorchWrapper(env, device=device)
create_video(env, 1000)

env = create_env('brax-inverted_pendulum-v0', num_envs=1)
obs = env.reset()
print("Num envs: ", obs.shape[0], "Obs dimensions: ", obs.shape[1])
env.observation_space
obs, reward, done, info = env.step(env.action_space.sample())

Plot the untrained policy

policy = GradientPolicy(4, 1)
plot_policy(policy)

Create the dataset

class RLDataset(IterableDataset):

    def __init__(self, env, policy, episode_length, gamma):
        self.env = env
        self.policy = policy
        self.episode_length = episode_length
        self.gamma = gamma
        self.obs = self.env.reset()

    @torch.no_grad()
    def __iter__(self):
        transitions = []

        # Collect one fixed-length rollout from the vectorized environment.
        for step in range(self.episode_length):
            loc, scale = self.policy(self.obs)
            action = torch.normal(loc, scale)
            next_obs, reward, done, info = self.env.step(action)
            transitions.append((self.obs, action, reward, done))
            self.obs = next_obs

        obs_b, action_b, reward_b, done_b = map(torch.stack, zip(*transitions))

        # Compute discounted returns backwards: G_t = r_t + gamma * G_{t+1}, reset at episode ends.
        running_return = torch.zeros(self.env.num_envs, dtype=torch.float32, device=device)
        return_b = torch.zeros_like(reward_b)

        for row in range(self.episode_length - 1, -1, -1):
            running_return = reward_b[row] + ~done_b[row] * self.gamma * running_return
            return_b[row] = running_return

        # Flatten (time, env) into a single sample dimension and yield in random order.
        num_samples = self.env.num_envs * self.episode_length
        obs_b = obs_b.view(num_samples, -1)
        action_b = action_b.view(num_samples, -1)
        return_b = return_b.view(num_samples, -1)

        idx = list(range(num_samples))
        random.shuffle(idx)

        for i in idx:
            yield obs_b[i], action_b[i], return_b[i]

Create the REINFORCE (policy gradient) algorithm

class reinforce(LightningModule):

    def __init__(self, env_name, num_envs=256, episode_length=1_000, batch_size=1024,
                 hidden_size=64, policy_lr=1e-4, gamma=0.999, entropy_coef=0.0001, optim=AdamW):

        super().__init__()

        self.env = create_env(env_name, num_envs=num_envs, episode_length=episode_length)
        test_env = gym.make(env_name, episode_length=episode_length)
        test_env = to_torch.JaxToTorchWrapper(test_env, device=device)
        self.test_env = NormalizeObservation(test_env)
        # Share the running observation statistics with the training environment.
        self.test_env.obs_rms = self.env.obs_rms

        obs_size = self.env.observation_space.shape[1]
        action_dims = self.env.action_space.shape[1]

        self.policy = GradientPolicy(obs_size, action_dims, hidden_size)
        self.dataset = RLDataset(self.env, self.policy, episode_length, gamma)

        self.save_hyperparameters()
        self.videos = []

    def configure_optimizers(self):
        return self.hparams.optim(self.policy.parameters(), lr=self.hparams.policy_lr)

    def train_dataloader(self):
        return DataLoader(dataset=self.dataset, batch_size=self.hparams.batch_size)

    # Training step: REINFORCE loss with an entropy bonus.
    def training_step(self, batch, batch_idx):
        obs, action, returns = batch

        loc, scale = self.policy(obs)
        dist = Normal(loc, scale)
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)

        # Push up the log-probability of actions in proportion to their return.
        policy_loss = - log_prob * returns

        # Entropy bonus encourages exploration.
        entropy = dist.entropy().sum(dim=-1, keepdim=True)

        self.log("episode/Policy Loss", policy_loss.mean())
        self.log("episode/Entropy", entropy.mean())

        return torch.mean(policy_loss - self.hparams.entropy_coef * entropy)

    def training_epoch_end(self, training_step_outputs):
        if self.current_epoch % 10 == 0:
            average_return = test_agent(self.test_env, self.hparams.episode_length, self.policy, episodes=1)
            self.log("episode/Average Return", average_return)

        if self.current_epoch % 50 == 0:
            video = create_video(self.test_env, self.hparams.episode_length, policy=self.policy)
            self.videos.append(video)

Purge old logs and run the visualization tool (TensorBoard)

# Start tensorboard.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
%reload_ext tensorboard
%tensorboard --logdir /content/lightning_logs/
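
With logging in place, the model can be trained. The following is a minimal sketch of launching training with the PyTorch Lightning 1.6 Trainer, assuming the reinforce module defined above; the epoch count is illustrative, not from the article:

# A minimal training-launch sketch.
algo = reinforce('brax-inverted_pendulum-v0', num_envs=256, episode_length=1_000)

trainer = Trainer(
    gpus=num_gpus,        # PyTorch Lightning 1.6 accepts the `gpus` argument
    max_epochs=200,       # illustrative value
    log_every_n_steps=1,
)
trainer.fit(algo)

# Render a rollout of the learned policy after training.
create_video(algo.test_env, 1_000, policy=algo.policy)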

Stay connected and support my work through various platforms.

Remember, each “Like”, “Share”, and “Star” greatly contributes to my work and motivates me to continue producing more quality content. Thank you for your support!
