Reinforcement Learning — Dots and Lines — Snake — 3/3
In this blog we will learn how to train a snake RL agent and run inference with it.
Below is the code for testenv_1.py, where we check the base environment; it can be used to verify that the code written in the previous part is correct.
# Sanity-check the custom environment against the Gym API contract.
# check_env raises (or warns) if SnakeEnv violates the expected interface.
from stable_baselines3.common.env_checker import check_env

# SnakeEnv is our custom environment from the previous part.
from snakeenv import SnakeEnv

snake_env = SnakeEnv()
check_env(snake_env)
Below is the code for testenv_2.py, where we can check the correctness of each episode's inputs and outputs.
# Smoke-test SnakeEnv by driving it with random actions for a number of
# episodes, printing each sampled action and the reward it produced.
# (The indentation of the loop bodies was lost in the original paste and
# is restored here so the script actually runs.)
from snakeenv import SnakeEnv

env = SnakeEnv()
episodes = 50

for episode in range(episodes):
    done = False
    obs = env.reset()  # fresh observation at the start of every episode
    while not done:
        # Sample a random valid action from the environment's action space.
        random_action = env.action_space.sample()
        print("action", random_action)
        # Old-style Gym API: step returns a 4-tuple (obs, reward, done, info).
        obs, reward, done, info = env.step(random_action)
        print("reward", reward)
We declare the directories where the model and the logs are to be stored, then instantiate the SnakeEnv() class and reset the environment. Using the stable_baselines3 library, we create a PPO model and train it. Below is the code for train.py.
# Train a PPO agent on SnakeEnv, periodically checkpointing the model.
# Fixes over the original: restored the lost loop indentation, replaced the
# exists()/makedirs() race with exist_ok=True, dropped the redundant `iters`
# counter (it always equalled the loop variable), and used os.path.join to
# avoid a double slash in the checkpoint path.
from stable_baselines3 import PPO
import os
from snakeenv import SnakeEnv
import time

# Timestamped directories so repeated runs never overwrite each other.
models_dir = f"models/{int(time.time())}/"
logdir = f"logs/{int(time.time())}/"

os.makedirs(models_dir, exist_ok=True)
os.makedirs(logdir, exist_ok=True)

env = SnakeEnv()
env.reset()

model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=logdir)

# Checkpoint every TIMESTEPS steps; the loop runs effectively forever,
# so stop training manually (Ctrl+C) when satisfied.
TIMESTEPS = 10000
for i in range(1, 10000000):
    # reset_num_timesteps=False keeps a single continuous step counter
    # across learn() calls, so TensorBoard shows one unbroken curve.
    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False,
                tb_log_name="PPO")
    # Save each checkpoint under its cumulative timestep count.
    model.save(os.path.join(models_dir, str(TIMESTEPS * i)))

env.close()
After training, you can find the trained model saved as a zip file under the models folder; it can be loaded and used for inference. Below is the code for inference.py.
# Load a saved PPO checkpoint and run it in SnakeEnv for many episodes,
# printing the per-step reward. Fixes over the original: restored the lost
# loop indentation and removed the unused `import gym` (this script is
# self-contained and never references gym directly).
from stable_baselines3 import PPO
from snakeenv import SnakeEnv

# Path components of the checkpoint produced by train.py — adjust the
# timestamped directory and step count to match your own run.
models_dir = "models/1667882491"
model_path = f"{models_dir}/50000.zip"

env = SnakeEnv()
env.reset()

# Attach the environment so model.predict/step can be used right away.
model = PPO.load(model_path, env=env)

episodes = 500
for ep in range(episodes):
    obs = env.reset()
    done = False
    while not done:
        # predict returns (action, hidden_states); states are only
        # relevant for recurrent policies, so they are ignored here.
        action, _states = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        print(rewards)
The entire code for this blog series can be found here.