Reinforcement Learning — Dots and Lines — Snake — 3/3

Ajith Kumar V
2 min read · Dec 24, 2022

--

In this blog we will train the snake RL agent and run inference with it.

Below is the code for testenv_1.py, where we check the base environment with check_env to verify that the code written in the previous parts is correct.

from stable_baselines3.common.env_checker import check_env
# SnakeEnv is an example custom environment
from snakeenv import SnakeEnv

env = SnakeEnv()
check_env(env)
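check_env essentially verifies the interface contract: reset() must return an observation, and step() must return a four-tuple of (observation, reward, done, info). As a rough sketch of that contract, here is a hypothetical GridEnv (not the real SnakeEnv; observation and action spaces omitted for brevity):

```python
# Minimal stand-in illustrating the reset/step contract that
# check_env verifies (hypothetical GridEnv, not the real SnakeEnv).
class GridEnv:
    def __init__(self, size=5):
        self.size = size
        self.pos = 0

    def reset(self):
        # Return the initial observation.
        self.pos = 0
        return self.pos

    def step(self, action):
        # action: 0 = stay, 1 = move right.
        self.pos = min(self.pos + action, self.size - 1)
        done = self.pos == self.size - 1
        reward = 1 if done else 0
        # Old gym API: (obs, reward, done, info).
        return self.pos, reward, done, {}
```

The real SnakeEnv additionally needs to declare observation_space and action_space for check_env to pass.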

Below is the code for testenv_2.py, where we can check the input and output of each episode for correctness.

# Example SnakeEnv
from snakeenv import SnakeEnv

env = SnakeEnv()
episodes = 50

for episode in range(episodes):
    done = False
    obs = env.reset()
    while not done:
        random_action = env.action_space.sample()
        print("action", random_action)
        obs, reward, done, info = env.step(random_action)
        print("reward", reward)
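The rollout loop above can also be wrapped in a small helper that tallies the episode return, which makes sanity checks easier to read. A sketch, tested here with a trivial stand-in environment rather than SnakeEnv:

```python
def run_episode(env, policy):
    # Roll out one episode; return (total_reward, step_count).
    obs = env.reset()
    done, total, steps = False, 0.0, 0
    while not done:
        obs, reward, done, info = env.step(policy(obs))
        total += reward
        steps += 1
    return total, steps

class CountdownEnv:
    # Toy stand-in for testing: episode ends after 3 steps, reward 1 per step.
    def reset(self):
        self.t = 3
        return self.t

    def step(self, action):
        self.t -= 1
        return self.t, 1.0, self.t == 0, {}
```

With env.action_space.sample as the policy (wrapped to ignore the observation), this reproduces the random rollout above while reporting per-episode totals.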

We declare the folders where the model and logs are to be stored, then create the SnakeEnv() instance and reset the environment. Using the stable_baselines3 library we train a PPO agent with an MLP policy. Below is the code for train.py

from stable_baselines3 import PPO
import os
from snakeenv import SnakeEnv
import time

models_dir = f"models/{int(time.time())}/"
logdir = f"logs/{int(time.time())}/"

if not os.path.exists(models_dir):
    os.makedirs(models_dir)

if not os.path.exists(logdir):
    os.makedirs(logdir)

env = SnakeEnv()
env.reset()

model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=logdir)

TIMESTEPS = 10000
iters = 0
for i in range(1, 10000000):
    iters += 1
    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="PPO")
    model.save(f"{models_dir}/{TIMESTEPS * iters}")

env.close()
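Since train.py passes tensorboard_log=logdir to PPO, the training curves (episode reward, loss, etc.) can be watched live with the standard TensorBoard command, pointed at the logs folder created above (assuming TensorBoard is installed):

```shell
# View PPO training metrics in the browser (default http://localhost:6006)
tensorboard --logdir logs
```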

After training, you can find the trained model saved as a zip file under the models folder; it can then be loaded for inference. Below is the code for inference.py


from stable_baselines3 import PPO
from snakeenv import SnakeEnv

models_dir = "models/1667882491"

env = SnakeEnv()
env.reset()

model_path = f"{models_dir}/50000.zip"
model = PPO.load(model_path, env=env)

episodes = 500

for ep in range(episodes):
    obs = env.reset()
    done = False
    while not done:
        action, _states = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        print(rewards)
Inference of the trained model
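The checkpoint path above is hard-coded (models/1667882491/50000.zip). Since train.py names each checkpoint by its cumulative timestep count, a small helper can pick the newest one automatically. A sketch (latest_checkpoint is a hypothetical helper, not part of stable_baselines3):

```python
import os

def latest_checkpoint(models_dir):
    # Pick the saved model with the highest timestep count,
    # relying on train.py's "<timesteps>.zip" naming scheme.
    zips = [f for f in os.listdir(models_dir) if f.endswith(".zip")]
    steps = sorted(int(f[:-4]) for f in zips)
    return os.path.join(models_dir, f"{steps[-1]}.zip")
```

The result can be passed straight to PPO.load in place of the hard-coded model_path.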

The entire code for this blog series can be found here.
