Continuous-Action Reinforcement Learning with Inverted Pendulum Using PyTorch — Part 3
Introduction
Continuous-action reinforcement learning (CARL) is a branch of machine learning that teaches agents to make decisions in environments with continuous action spaces. This matters because many real-world problems, such as autonomous driving, robotics, and financial trading, involve continuous actions.
Example
Imagine an agent that is learning to control a robotic arm. The agent has two continuous actions: the angle of the arm and the force applied to the arm. The agent’s goal is to move an object to a specific location.
The agent starts by randomly exploring the action space. It tries different angles and forces, and observes the results. The agent receives a reward for moving the object closer to the goal location, and a penalty for moving the object further away.
Over time, the agent learns to select actions that lead to higher rewards. It learns that certain angles and forces are more effective than others. Eventually, the agent is able to move the object to the goal location consistently.
CARL Algorithms
A number of CARL algorithms have been developed. Some popular ones include:
- Actor-critic methods: Actor-critic methods are a widely used approach to CARL. They consist of two components: an actor and a critic. The actor selects actions, and the critic evaluates the actions selected by the actor. The actor then updates its policy based on the feedback from the critic.
- Policy gradient methods: Policy gradient methods directly optimize the policy. They use the gradient of the expected reward with respect to the policy parameters to update the policy; a minimal sketch follows this list.
- Q-learning: Q-learning is a value-based RL algorithm that can be extended to continuous action spaces using function approximation techniques.
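Policy gradient methods underpin the REINFORCE agent built later in this article. Here is a minimal sketch of the core update; the function and variable names are illustrative, not from any library. The agent increases the log-probability of the actions it took in proportion to the return that followed them.
import torch
from torch.distributions import Normal

def policy_gradient_step(policy, optimizer, states, actions, returns):
    # policy(states) is assumed to return the mean and standard deviation
    # of a Normal distribution over continuous actions.
    loc, scale = policy(states)
    dist = Normal(loc, scale)
    log_prob = dist.log_prob(actions).sum(dim=-1)
    # Minimizing -log_prob * return performs gradient ascent on expected return.
    loss = -(log_prob * returns).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()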
Implementing Inverted Pendulum Using PyTorch
!apt-get install -y xvfb
!pip install -q gym==0.23.1 \
pytorch-lightning==1.6 \
pyvirtualdisplay
!pip install -Uq brax==0.0.12 jax==0.3.14 jaxlib==0.3.14+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
import warnings
warnings.filterwarnings('ignore')
Set up the virtual display
from pyvirtualdisplay import Display
Display(visible=False, size=(1400, 900)).start()
Import the necessary libraries
import copy
import torch
import random
import gym
import matplotlib
import functools
import math
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW
from torch.distributions import Normal
from pytorch_lightning import LightningModule, Trainer
import brax
from brax import envs
from brax.envs import to_torch
from brax.io import html
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()
v = torch.ones(1, device=device)  # sanity check: allocate a tensor on the selected device
@torch.no_grad()
def plot_policy(policy):
    # Sweep cart position and pole angle over a grid; sample small random
    # velocities for the remaining two observation dimensions.
    pos = np.linspace(-4.8, 4.8, 100)
    vel = np.random.random(size=(10000, 1)) * 0.1
    ang = np.linspace(-0.418, 0.418, 100)
    ang_vel = np.random.random(size=(10000, 1)) * 0.1

    g1, g2 = np.meshgrid(pos, ang)
    grid = np.stack((g1, g2), axis=-1)
    grid = grid.reshape(-1, 2)
    grid = np.hstack((grid, vel, ang_vel))
    grid = torch.from_numpy(grid).float()

    # Plot the mean action predicted by the policy at each grid point.
    loc, _ = policy(grid)
    plot_vals = loc.numpy()
    plot_vals = plot_vals.reshape(100, 100)[::-1]

    plt.figure(figsize=(8, 8))
    plt.imshow(plot_vals, cmap='coolwarm')
    plt.colorbar()
    plt.clim(-1, 1)
    plt.title("Mean action μ(s)", size=20)
    plt.xlabel("Cart position", size=14)
    plt.ylabel("Pole angle", size=14)
    plt.xticks(ticks=[0, 50, 100], labels=['-4.8', '0', '4.8'])
    plt.yticks(ticks=[100, 50, 0], labels=['-0.418', '0', '0.418'])
@torch.no_grad()
def create_video(env, episode_length, policy=None):
    qp_array = []
    state = env.reset()
    for i in range(episode_length):
        if policy:
            loc, scale = policy(state)
            sample = torch.normal(loc, scale)
            action = torch.tanh(sample)
        else:
            action = env.action_space.sample()
        state, _, _, _ = env.step(action)
        qp_array.append(env.unwrapped._state.qp)
    return HTML(html.render(env.unwrapped._env.sys, qp_array))
@torch.no_grad()
def test_agent(env, episode_length, policy, episodes=10):
    ep_returns = []
    for ep in range(episodes):
        state = env.reset()
        done = False
        ep_ret = 0.0
        while not done:
            loc, scale = policy(state)
            sample = torch.normal(loc, scale)
            action = torch.tanh(sample)
            state, reward, done, info = env.step(action)
            ep_ret += reward.item()
        ep_returns.append(ep_ret)
    return sum(ep_returns) / episodes
Create the policy
class GradientPolicy(nn.Module):
    def __init__(self, in_features, out_dims, hidden_size=128):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc_mu = nn.Linear(hidden_size, out_dims)
        self.fc_std = nn.Linear(hidden_size, out_dims)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Mean of the action distribution, squashed into (-1, 1).
        loc = self.fc_mu(x)
        loc = torch.tanh(loc)
        # Standard deviation, kept strictly positive.
        scale = self.fc_std(x)
        scale = F.softplus(scale) + 0.001
        return loc, scale
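To see what the network produces, here is a quick sketch that samples a squashed action the same way the helper functions above do; the input sizes match the inverted pendulum environment used later (4 observation dimensions, 1 action dimension).
policy = GradientPolicy(in_features=4, out_dims=1)
obs = torch.zeros(1, 4)                          # a dummy observation
loc, scale = policy(obs)                         # mean in (-1, 1), scale > 0
action = torch.tanh(torch.normal(loc, scale))    # sampled action, squashed into (-1, 1)
print(loc, scale, action)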
Create the environment
class RunningMeanStd:
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = torch.zeros(shape, dtype=torch.float32).to(device)
        self.var = torch.ones(shape, dtype=torch.float32).to(device)
        self.count = epsilon

    def update(self, x):
        batch_mean = torch.mean(x, dim=0)
        batch_var = torch.var(x, dim=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )


def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    delta = batch_mean - mean
    tot_count = count + batch_count
    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + torch.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count
    return new_mean, new_var, new_count
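As a sanity check, the running statistics should roughly match the batch statistics of all the data seen so far. The tensors below are synthetic, for illustration only:
rms = RunningMeanStd(shape=(3,))
data = torch.randn(1000, 3).to(device) * 2.0 + 5.0   # synthetic observations
rms.update(data[:500])
rms.update(data[500:])
print(rms.mean, data.mean(dim=0))   # both close to 5
print(rms.var, data.var(dim=0))     # both close to 4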
class NormalizeObservation(gym.core.Wrapper):
    def __init__(self, env, epsilon=1e-8):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape[-1])
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        obs = self.normalize(obs)
        return obs, rews, dones, infos

    def reset(self, **kwargs):
        return_info = kwargs.get("return_info", False)
        if return_info:
            obs, info = self.env.reset(**kwargs)
        else:
            obs = self.env.reset(**kwargs)
        obs = self.normalize(obs)
        if not return_info:
            return obs
        else:
            return obs, info

    def normalize(self, obs):
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / torch.sqrt(self.obs_rms.var + self.epsilon)
class NormalizeReward(gym.core.Wrapper):
    def __init__(self, env, gamma=0.99, epsilon=1e-8):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.return_rms = RunningMeanStd(shape=())
        self.returns = torch.zeros(self.num_envs).to(device)
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        dones = dones.bool()
        self.returns = self.returns * self.gamma + rews
        rews = self.normalize(rews)
        self.returns[dones] = 0.0
        return obs, rews, dones, infos

    def normalize(self, rews):
        # Scale rewards by the standard deviation of the running discounted return.
        self.return_rms.update(self.returns)
        return rews / torch.sqrt(self.return_rms.var + self.epsilon)
entry_point = functools.partial(envs.create_gym_env, env_name='inverted_pendulum')
gym.register('brax-inverted_pendulum-v0', entry_point=entry_point)
def create_env(env_name, num_envs=256, episode_length=1000):
    env = gym.make(env_name, batch_size=num_envs, episode_length=episode_length)
    env = to_torch.JaxToTorchWrapper(env, device=device)
    env = NormalizeObservation(env)
    env = NormalizeReward(env)
    return env
env = gym.make('brax-inverted_pendulum-v0', episode_length=1000)
env = to_torch.JaxToTorchWrapper(env, device=device)
create_video(env, 1000)
env = create_env('brax-inverted_pendulum-v0', num_envs=1)
obs = env.reset()
print("Num envs: ", obs.shape[0], "Obs dimentions: ", obs.shape[1])
env.observation_space
obs, reward, done, info = env.step(env.action_space.sample())
Plot the untrained policy
policy = GradientPolicy(4, 1)
plot_policy(policy)
Create the dataset
class RLDataset(IterableDataset):
    def __init__(self, env, policy, episode_length, gamma):
        self.env = env
        self.policy = policy
        self.episode_length = episode_length
        self.gamma = gamma
        self.obs = self.env.reset()

    @torch.no_grad()
    def __iter__(self):
        transitions = []

        # Roll out one episode in all parallel environments.
        for step in range(self.episode_length):
            loc, scale = self.policy(self.obs)
            action = torch.normal(loc, scale)
            next_obs, reward, done, info = self.env.step(action)
            transitions.append((self.obs, action, reward, done))
            self.obs = next_obs

        obs_b, action_b, reward_b, done_b = map(torch.stack, zip(*transitions))

        # Compute discounted returns backwards through time.
        running_return = torch.zeros(self.env.num_envs, dtype=torch.float32, device=device)
        return_b = torch.zeros_like(reward_b)
        for row in range(self.episode_length - 1, -1, -1):
            running_return = reward_b[row] + ~done_b[row] * self.gamma * running_return
            return_b[row] = running_return

        # Flatten the (time, env) dimensions into one sample dimension and shuffle.
        num_samples = self.env.num_envs * self.episode_length
        obs_b = obs_b.view(num_samples, -1)
        action_b = action_b.view(num_samples, -1)
        return_b = return_b.view(num_samples, -1)

        idx = list(range(num_samples))
        random.shuffle(idx)
        for i in idx:
            yield obs_b[i], action_b[i], return_b[i]
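The backward loop in __iter__ implements the recursion G_t = r_t + γ · (1 − done_t) · G_{t+1}. A tiny worked example with γ = 0.9 and illustrative rewards:
rewards = torch.tensor([1.0, 1.0, 1.0])
dones = torch.tensor([False, False, True])
gamma = 0.9
running = 0.0
returns = torch.zeros(3)
for t in range(2, -1, -1):
    running = rewards[t] + (~dones[t]) * gamma * running
    returns[t] = running
print(returns)   # tensor([2.7100, 1.9000, 1.0000])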
Create the REINFORCE algorithm
class Reinforce(LightningModule):
    def __init__(self, env_name, num_envs=256, episode_length=1_000, batch_size=1024,
                 hidden_size=64, policy_lr=1e-4, gamma=0.999, entropy_coef=0.0001, optim=AdamW):
        super().__init__()
        self.env = create_env(env_name, num_envs=num_envs, episode_length=episode_length)

        test_env = gym.make(env_name, episode_length=episode_length)
        test_env = to_torch.JaxToTorchWrapper(test_env, device=device)
        self.test_env = NormalizeObservation(test_env)
        # Share the running observation statistics with the training environment.
        self.test_env.obs_rms = self.env.obs_rms

        obs_size = self.env.observation_space.shape[1]
        action_dims = self.env.action_space.shape[1]
        self.policy = GradientPolicy(obs_size, action_dims, hidden_size)
        self.dataset = RLDataset(self.env, self.policy, episode_length, gamma)

        self.save_hyperparameters()
        self.videos = []

    def configure_optimizers(self):
        return self.hparams.optim(self.policy.parameters(), lr=self.hparams.policy_lr)

    def train_dataloader(self):
        return DataLoader(dataset=self.dataset, batch_size=self.hparams.batch_size)

    def training_step(self, batch, batch_idx):
        obs, action, returns = batch
        loc, scale = self.policy(obs)
        dist = Normal(loc, scale)
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)

        # REINFORCE loss: push up the log-probability of high-return actions,
        # with an entropy bonus to keep the policy exploring.
        policy_loss = -log_prob * returns
        entropy = dist.entropy().sum(dim=-1, keepdim=True)

        self.log("episode/Policy Loss", policy_loss.mean())
        self.log("episode/Entropy", entropy.mean())
        return torch.mean(policy_loss - self.hparams.entropy_coef * entropy)

    def training_epoch_end(self, training_step_outputs):
        if self.current_epoch % 10 == 0:
            average_return = test_agent(self.test_env, self.hparams.episode_length, self.policy, episodes=1)
            self.log("episode/Average Return", average_return)
        if self.current_epoch % 50 == 0:
            video = create_video(self.test_env, self.hparams.episode_length, policy=self.policy)
            self.videos.append(video)
Purge old logs and launch the visualization tool (TensorBoard)
# Start tensorboard.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
%reload_ext tensorboard
%tensorboard --logdir /content/lightning_logs/
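The training launch itself is not shown in this part; with the class above it would presumably look like the following minimal sketch (the epoch count is an arbitrary illustration):
algo = Reinforce('brax-inverted_pendulum-v0')
trainer = Trainer(gpus=num_gpus, max_epochs=250)
trainer.fit(algo)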
Stay connected and support my work through various platforms:
- GitHub: For all my open-source projects and Notebooks, you can visit my GitHub profile at https://github.com/andysingal. If you find my content valuable, don’t hesitate to leave a star.
- Patreon: If you’d like to provide additional support, you can consider becoming a patron on my Patreon page at https://www.patreon.com/AndyShanu.
- Medium: You can read my latest articles and insights on Medium at https://medium.com/@andysingal.
- Kaggle: Check out my Kaggle profile for data science and machine learning projects at https://www.kaggle.com/alphasingal.
- Hugging Face: For natural language processing and AI-related projects, you can explore my Hugging Face profile at https://huggingface.co/Andyrasika.
- YouTube: To watch my video content, visit my YouTube channel at https://www.youtube.com/@andy111007.
- LinkedIn: To stay updated on my latest projects and posts, you can follow me on LinkedIn at https://www.linkedin.com/in/ankushsingal/.
Requests and questions: If you have a project in mind that you’d like me to work on, or if you have any questions about the concepts I’ve explained, don’t hesitate to let me know. I’m always looking for new ideas for future Notebooks, and I love helping to resolve any doubts you might have.
Remember, each “Like”, “Share”, and “Star” greatly contributes to my work and motivates me to continue producing more quality content. Thank you for your support!