Reinforcement Learning made easy
Reinforcement learning is one of the most exciting branches of AI right now. It has allowed us to make major progress in areas like autonomous vehicles, robotics and video games. Perhaps its most famous achievement has been beating the world-champion Go player, a feat that many considered impossible before. Today we are going to look into two of the most famous reinforcement learning algorithms, SARSA and Q-learning, and how they can be applied to a simple grid-world maze problem.
Markov Decision Process
To explain the context behind these algorithms, we need to talk in terms of a Markov Decision Process (MDP).
In simple terms, an MDP is just a formal way of modeling a world with a defined set of rules: a set of states, the actions available in each state, a transition function describing where each action leads, and the rewards received along the way. In this world we introduce an entity called an agent. The agent can move around and interact with the map in whatever way the transition function allows. The agent’s only goal is to maximize its total reward. Think of it like humans trying to maximize their happiness in a particular environment (although usually we are not very good at it!). This is exactly the kind of setting in which reinforcement learning thrives, as we will see later on: we can simulate episodes as many times as needed, and our agent can learn how to maximize its reward through experience.
We now need to introduce the concept of Q-value, which is at the core of how SARSA and Q-learning operate. If we want our agent to learn from experience, we need to somehow give it the ability to remember the rewards it received on previous journeys. Not only that, but we would also like to know the value of making a particular move at a particular place, since we want to maximize our reward as efficiently as possible. If a particular move leads to good results, we would like its Q-value to reflect this. In a way, the Q-value is our memory of what happens at each point of our grid: an estimate of the total reward we can expect after making a given move from a given place. We write it as Q(s,a), meaning that it is a function of the current state s and the action a chosen in that state.
Both SARSA and Q-learning can be summarized as an update rule for the Q-value of each state-action pair of the MDP. By generating many episodes, we let our agent move around and explore the world we created. Based on the rewards received during each episode, we use these update rules to adjust the Q-values according to what has happened so far. Doing this over and over again reinforces the agent’s knowledge about the world and its ability to achieve maximum reward efficiently. The difference between SARSA and Q-learning lies only in how these values are updated. Before we explain this more specifically, let’s take a look at how our agent decides what to do during the episodes.
The agent’s action decision process is the well-known ϵ-greedy algorithm. This algorithm simply states that whenever the agent has a decision to make, it should greedily choose its most profitable option most of the time. Some of the time (with probability ϵ) it should instead choose an action at random. Surprising as it may be, this basic idea gets quite good results in some applications. It keeps the agent exploring all possibilities while still exploiting its current knowledge, two concepts that are paramount in RL. Naturally, we want our agent to act ϵ-greedily on the Q-values of each state, since we have already established that the Q-value represents the value of a particular move.
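As a minimal sketch of ϵ-greedy selection over the Q-values of a single state, here is a Python version. Storing those Q-values as a dict mapping each action to its value is an assumption, as is breaking ties between equally good actions at random. Note that the random branch may also land on the greedy action, which is a common convention.

```python
import random


def epsilon_greedy(q_values, epsilon, rng=random):
    """Choose an action from a {action: Q-value} dict epsilon-greedily."""
    if rng.random() < epsilon:
        # Explore: pick any action uniformly at random
        return rng.choice(list(q_values))
    # Exploit: pick the action with the highest Q-value, breaking ties at random
    best = max(q_values.values())
    return rng.choice([a for a, q in q_values.items() if q == best])


q = {"up": 0.0, "down": -1.0, "left": -2.0, "right": 3.0}
action = epsilon_greedy(q, epsilon=0.1, rng=random.Random(42))
```

With epsilon=0 this always returns "right" for the dict above; with epsilon=1 it ignores the Q-values entirely, so ϵ directly trades exploitation for exploration.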
So how does the SARSA algorithm work? The acronym SARSA comes from “State–action–reward–state–action”, which is a pretty good overview of how the algorithm operates. The basic premise is that, given a state S and an action A, we update the Q-value Q(S, A) by incrementing it with (a fraction of) the reward R obtained plus the (discounted) Q-value of the next state-action pair (S′, A′), with A′ chosen ϵ-greedily, minus the current Q-value. Phew, that doesn’t sound very intuitive, does it? You can understand it more easily by picturing it as “every time I make this move from this state, I will update its value to account for the reward I get and the potential reward of my next move”.
You can see how intuitive this rule is. If you decide to take up a job offer, you think about the salary you will be getting, but also about your growth prospects within the company. If you are thinking about going to a casino, you might win some money that night, but you might also become addicted to gambling, so your long-term reward does not seem so appealing. If you are playing chess, capturing one of your opponent’s pawns is good for you, but it is not really worth sacrificing your queen, is it? The SARSA algorithm tries to capture this in the simplest way possible.
In symbols, the SARSA update rule is Q(S, A) ← Q(S, A) + α [R + γ Q(S′, A′) − Q(S, A)]. The variable α represents the learning rate of our agent, which dictates how much we change our Q-value every time we update it. γ is the discount factor, which lets us control the influence of future rewards on the update.
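The SARSA rule translates almost line for line into code. Below is a minimal Python sketch, assuming Q is stored as a nested dict (state → {action: value}). The name sarsa_update and the done flag, which zeroes the next Q-value once the episode has ended, are illustrative choices, not necessarily how the original repo does it.

```python
def sarsa_update(Q, s, a, r, s_next, a_next, alpha=0.5, gamma=1.0, done=False):
    """Apply one SARSA update to Q[s][a] in place and return the new value.

    Q is a nested dict: state -> {action: Q-value}.
    """
    # After a terminal tile there is no next state-action pair to bootstrap from
    next_q = 0.0 if done else Q[s_next][a_next]
    td_target = r + gamma * next_q
    # Move Q(S, A) a fraction alpha of the way towards the target
    Q[s][a] += alpha * (td_target - Q[s][a])
    return Q[s][a]


Q = {(0, 0): {"right": 0.0}, (0, 1): {"right": 2.0}}
new_value = sarsa_update(Q, (0, 0), "right", -1, (0, 1), "right")
```

For example, with α = 0.5 and γ = 1, a step that costs −1 and leads to a state-action pair worth 2 moves a Q-value of 0 halfway towards the target −1 + 2 = 1, i.e. to 0.5.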
The Q-learning algorithm is very similar to the previous one, with one slight modification. Remember how I mentioned that our update rule uses the Q-value of the next state-action pair, chosen ϵ-greedily? Well, it turns out that you can just use the maximum Q-value of the next state every time instead! The Q-learning update rule therefore becomes Q(S, A) ← Q(S, A) + α [R + γ max_a Q(S′, a) − Q(S, A)].
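In code, the only change from a SARSA update is the target: the next action no longer appears at all, and we take the maximum over the next state’s Q-values instead. Again, this is a hypothetical sketch using a nested dict (state → {action: value}) for Q.

```python
def q_learning_update(Q, s, a, r, s_next, alpha=0.5, gamma=1.0, done=False):
    """Apply one Q-learning update to Q[s][a] in place and return the new value."""
    # Use the best Q-value available in the next state, whatever the agent does next
    next_q = 0.0 if done else max(Q[s_next].values())
    Q[s][a] += alpha * (r + gamma * next_q - Q[s][a])
    return Q[s][a]


Q = {"A": {"left": 0.0, "right": 0.0}, "B": {"left": -4.0, "right": 6.0}}
new_value = q_learning_update(Q, "A", "right", -1, "B")
```

With Q(B, ·) = {left: −4, right: 6}, a −1 step from A into B gives the target −1 + 6 = 5, so Q(A, right) moves from 0 to 2.5, no matter which action the agent actually takes next in B.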
Note that this does not mean that our agent always acts greedily during the episodes. It still takes a random action with a small probability ϵ, even though the update always uses the highest Q-value. This is why Q-learning is an example of an off-policy learning algorithm: it updates the Q-values using actions that the agent did not necessarily choose.
Creating an agent
Let’s start with the fun part! For today we are gonna keep our agent very simple. It will only have a position, defined in terms of x and y coordinates, and its current reward value. It will have two methods, one that implements the ϵ-greedy algorithm described above, and another one that moves our agent in a particular direction, accounting for walls and out-of-bounds moves.
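Here is a minimal sketch of such an agent. It assumes the grid is a list of strings with '#' marking walls, and that a move into a wall or off the edge of the grid simply leaves the agent where it is. Both are assumptions, and the class and method names are hypothetical rather than taken from the repo.

```python
import random

# Hypothetical direction table: (dx, dy), with y growing downwards
DIRECTIONS = {"up": (0, -1), "down": (0, 1), "left": (-1, 0), "right": (1, 0)}


class Agent:
    """A minimal agent: a position (x, y) and the reward collected so far."""

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.reward = 0

    def choose_action(self, q_values, epsilon=0.1):
        """Epsilon-greedy choice over a {direction: Q-value} dict."""
        if random.random() < epsilon:
            return random.choice(list(q_values))   # explore
        return max(q_values, key=q_values.get)     # exploit

    def move(self, direction, grid):
        """Move one tile; walls ('#') and the grid edge leave the agent in place."""
        dx, dy = DIRECTIONS[direction]
        nx, ny = self.x + dx, self.y + dy
        inside = 0 <= ny < len(grid) and 0 <= nx < len(grid[0])
        if inside and grid[ny][nx] != "#":
            self.x, self.y = nx, ny
        return self.x, self.y


grid = ["...",
        ".#.",
        "..."]
agent = Agent(0, 0)
agent.move("left", grid)   # off the grid: stays at (0, 0)
agent.move("down", grid)   # moves to (0, 1)
agent.move("right", grid)  # wall at (1, 1): stays at (0, 1)
```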
Creating a world
Simple enough hopefully? Now let’s take a look at our Gridworld class. For simplicity we will assume that the grid will always be square. Tiles can be of several types:
- Walls: tiles our agent is not allowed to enter.
- Snake pit: where the agent receives a penalty and dies. Entering here ends the episode.
- Treasure: where the agent receives a reward. Entering here ends the episode.
- Regular: tiles our agent is allowed to move onto. Each one carries a small penalty to motivate the agent to reach the treasure as fast as possible.
We will set the default values α=0.5 and γ=1.
We will also initialize two matrices in the constructor. The first one holds our initial policy for each tile: a dictionary that maps each direction to the probability of following it. We initially set all probabilities to 0.25 (a uniformly random policy) so that there is no bias in the initial exploration of the world. The second matrix holds the Q-values for each tile. Again, this is a dictionary, mapping each action (direction) to its Q-value for that tile (you can read Q(s,a) as Q(tile, direction)). This is the matrix that both of our algorithms update.
Besides the constructor, the class has the following methods:
- init_agent(): Initializes the agent and makes sure it does not begin the episode on an illegal tile, such as a wall or an out-of-bounds position.
- check_reward(): Returns the reward associated with a certain position
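Putting the constructor and these helpers together gives the hypothetical sketch below. Several details are assumptions made for illustration:

- the layout format ('#' wall, 'S' snake pit, 'T' treasure, '.' regular);
- the treasure and snake-pit rewards (only the −1 step penalty comes from the text);
- the extra is_terminal helper;
- starting the agent only on regular tiles.

Also, init_agent here just returns a starting position rather than building an agent object.

```python
import random

DIRECTIONS = ("up", "down", "left", "right")


class Gridworld:
    # Hypothetical rewards: only the -1 step penalty comes from the text
    REWARDS = {"S": -10, "T": 10, ".": -1}

    def __init__(self, layout, alpha=0.5, gamma=1.0):
        self.layout = layout          # list of strings, one per row
        self.size = len(layout)       # the grid is assumed to be square
        self.alpha, self.gamma = alpha, gamma
        tiles = [(x, y) for y in range(self.size) for x in range(self.size)]
        # Initial policy: every direction equally likely (0.25) on every tile
        self.policy = {t: {d: 0.25 for d in DIRECTIONS} for t in tiles}
        # Q-values: Q(tile, direction), all starting at 0
        self.Q = {t: {d: 0.0 for d in DIRECTIONS} for t in tiles}

    def tile(self, x, y):
        return self.layout[y][x]

    def init_agent(self):
        """Pick a random legal start: inside the grid and on a regular tile."""
        legal = [(x, y) for y in range(self.size) for x in range(self.size)
                 if self.tile(x, y) == "."]
        return random.choice(legal)

    def check_reward(self, x, y):
        return self.REWARDS[self.tile(x, y)]

    def is_terminal(self, x, y):
        return self.tile(x, y) in ("S", "T")


world = Gridworld(["..T",
                   ".#S",
                   "..."])
start = world.init_agent()
```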
The remaining function, generate_episode(), is our main focus, so we will take the time to explain it thoroughly. It follows the standard SARSA pseudocode (Q-learning is analogous), which breaks down as follows:
- First, we initialize the matrix Qmat.
- The outer loop runs over episodes: each iteration is one call to our generate_episode function.
- Our state s is kept by our agent, so we initialize it at the beginning of the episode.
- Next, our agent chooses an action using the ϵ-greedy algorithm described earlier.
- The inner loop iterates over each step of the episode:
- First, our agent moves in the selected direction and observes the reward of the next tile.
- Second, the agent ϵ-greedily selects the move it will make next from its new tile.
- We now have everything we need to update the Q-value! It is at this point that we decide whether to use the SARSA or the Q-learning update. Both update the Q-value of the move the agent just made; they differ only in the target. SARSA uses the Q-value of the next move the agent actually selected, while Q-learning always uses the maximum Q-value available at the next tile. Depending on the chosen algorithm, we apply the corresponding update rule.
- Finally, the new tile and the direction we just selected become the agent’s current state and next move (for SARSA, this is the same direction that was used in the update).
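The steps above can be sketched as one self-contained Python function. This is a hypothetical version with several stand-ins for the real setup:

- its own tiny square layout;
- made-up treasure and snake-pit rewards (only the −1 step penalty and α=0.5, γ=1 come from the text);
- a fixed ϵ=0.1;
- a max_steps cap so that an episode cannot run forever.

Any algorithm value other than "sarsa" selects the Q-learning update.

```python
import random

MOVES = {"up": (0, -1), "down": (0, 1), "left": (-1, 0), "right": (1, 0)}
REWARDS = {".": -1, "T": 10, "S": -10}  # hypothetical, except the -1 step penalty


def step(layout, pos, action):
    """Apply one move; walls ('#') and the grid edge leave the agent in place."""
    x, y = pos[0] + MOVES[action][0], pos[1] + MOVES[action][1]
    if 0 <= y < len(layout) and 0 <= x < len(layout[0]) and layout[y][x] != "#":
        pos = (x, y)
    tile = layout[pos[1]][pos[0]]
    return pos, REWARDS[tile], tile in ("T", "S")


def epsilon_greedy(q, epsilon):
    if random.random() < epsilon:
        return random.choice(list(q))
    return max(q, key=q.get)


def generate_episode(layout, Q, algorithm="sarsa", alpha=0.5, gamma=1.0,
                     epsilon=0.1, max_steps=200):
    # Initialize S: a random regular tile (never a wall or a terminal tile)
    starts = [(x, y) for y, row in enumerate(layout)
              for x, c in enumerate(row) if c == "."]
    s = random.choice(starts)
    a = epsilon_greedy(Q[s], epsilon)                 # choose A from S
    for _ in range(max_steps):                        # steps of the episode
        s_next, r, done = step(layout, s, a)          # take A, observe R and S'
        a_next = epsilon_greedy(Q[s_next], epsilon)   # choose A' from S'
        if done:
            target = r                                # nothing follows a terminal tile
        elif algorithm == "sarsa":
            target = r + gamma * Q[s_next][a_next]    # on-policy: the move we will make
        else:
            target = r + gamma * max(Q[s_next].values())  # off-policy: the best move
        Q[s][a] += alpha * (target - Q[s][a])
        if done:
            break
        s, a = s_next, a_next                         # S <- S', A <- A'


def train(layout, algorithm, episodes=2000, seed=0):
    random.seed(seed)
    Q = {(x, y): {m: 0.0 for m in MOVES}
         for y, row in enumerate(layout) for x in range(len(row))}
    for _ in range(episodes):
        generate_episode(layout, Q, algorithm)
    return Q


def greedy_path(layout, Q, start, max_steps=50):
    """Follow the greedy policy from `start` and return the visited tiles."""
    pos, path = start, [start]
    for _ in range(max_steps):
        pos, _reward, done = step(layout, pos, max(Q[pos], key=Q[pos].get))
        path.append(pos)
        if done:
            break
    return path


layout = ["...T",
          ".#.S",
          "....",
          "...."]
Q = train(layout, "qlearning")
print(greedy_path(layout, Q, start=(0, 2)))
```

In this sketch the only lines that differ between the two algorithms are the two target computations. Everything else is shared: the ϵ-greedy behaviour, the episode loop and the update itself. This mirrors the point that SARSA and Q-learning differ only in how the Q-values are updated.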
That’s it! We have successfully built a working Gridworld and an agent to explore it.
Let’s take a look at the optimal policy that our algorithms arrive at and the corresponding Q-values. The red tile is the snake pit, the green one is the treasure and the black tiles (white in the heat map on the right) are the walls which our agent cannot go through.
We can see that the SARSA algorithm successfully converges to an optimal policy. This is easy to verify, since our MDP is a simple gridworld and it’s fairly easy to imagine what an optimal policy looks like: it completely avoids the snake pit and always goes directly to the treasure, around the walls. The right figure shows the Q-values associated with the optimal policy at each tile. Interesting things to note are:
- The tiles right next to the treasure have Q-values equal to the reward for landing on the treasure tile. This is to be expected: as soon as our agent learns that its goal is to get to the treasure, the Q-value of the move into the treasure starts converging to r, the reward for finding the treasure. Also note that the episode ends when the agent gets there, so there is no next state-action pair and the Q(s’,a’) values make no difference to the Q-values of these tiles.
- The Q-values of the tiles decrease linearly with their distance to the treasure: each extra step decreases the value by 1. This result is also expected, since we set the default penalty for moving to -1, so tiles further away necessarily have lower Q-values than tiles close to the treasure. This also gives us interesting insight into our world: tiles with equal Q-values are exactly the same distance from the agent’s end goal, so the agent is indifferent between moving to one or the other. For example, from the -5 tile (the furthest from the treasure) the agent could go left or down and end up with exactly the same reward, as long as it keeps following the path of increasing Q-values towards the treasure.
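Both observations follow from a short worked sum. It assumes γ = 1 (the default above), a reward of −1 for landing on a regular tile, and a reward r for landing on the treasure, which ends the episode. For a tile k moves away from the treasure along an optimal route, the value of the best move a* is simply the sum of the rewards collected on the way:

```latex
\begin{aligned}
Q^*(s_k, a^*) &= \underbrace{(-1) + (-1) + \cdots + (-1)}_{k-1\ \text{regular tiles}} + r \\
              &= r - (k - 1)
\end{aligned}
```

Setting k = 1 gives r, matching the tiles next to the treasure, and every extra step away from the treasure subtracts exactly 1.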
The Q-value matrix of the converged Q-learning algorithm looks exactly the same as SARSA’s. This is a rather nice result, since we expect the converged Q-values of both algorithms to agree once exploration fades out (for SARSA, this requires ϵ to decrease over time, so that the ϵ-greedy policy it learns about becomes greedy). Even if there are multiple optimal policies, the optimal Q-values of each tile are unique.
Let’s look at a more complex example of the gridworld, introducing something like a maze into it.
We can see how effective our algorithm is in this environment: it again manages to find an optimal policy for our agent. The tiles in the top-left corner of the map have the smallest Q-values, due to their distance from the treasure and their proximity to the snake pit. It is interesting to see that, with a bigger map like this one, the agent figures out that going into the snake pit kamikaze-style is actually preferable to making the journey all the way to the treasure. This is logical: the total return of such a long trip (many step penalties, then the treasure reward) is lower than the return of jumping into the snake pit and ending the episode right there.
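The small sketch below makes this trade-off concrete. It assumes γ = 1, a −1 penalty per regular tile, and illustrative values of +10 for the treasure and −10 for the snake pit (the real map’s values may differ). The function compares the return of heading to the treasure with the return of heading into the pit, from a tile that is d_treasure and d_pit moves away from each. The function name is hypothetical.

```python
def best_exit(d_treasure, d_pit, treasure_reward, pit_penalty, step_penalty=-1):
    """Compare undiscounted returns (gamma = 1) of the two ways to end an episode.

    Every move before the last lands on a regular tile (step_penalty);
    the last move lands on the treasure or on the snake pit.
    """
    via_treasure = (d_treasure - 1) * step_penalty + treasure_reward
    via_pit = (d_pit - 1) * step_penalty + pit_penalty
    if via_treasure >= via_pit:
        return "treasure", via_treasure
    return "pit", via_pit


far = best_exit(30, 2, 10, -10)
near = best_exit(5, 2, 10, -10)
```

With these numbers the pit wins whenever the treasure is more than r + |p| = 20 moves further away than the pit. So best_exit(30, 2, 10, -10) picks the pit (−11 against −19), while best_exit(5, 2, 10, -10) still picks the treasure (6 against −11).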
I made this post with the intention of demystifying some of the ‘magic’ behind reinforcement learning algorithms. It is amazing what a simple update rule can achieve when implemented the right way, and I hope I was able to show exactly that. I did not go into any kind of mathematical depth because the point was to keep it simple enough that people without an AI/maths background could understand and potentially reproduce these algorithms on their own. Hopefully I didn’t make too many technical mistakes, but feel free to correct me! Thank you for reading my first post. If you are interested in learning more about AI (not just reinforcement learning!), please let me know. I’m by no means an expert, but I do like to share the things that I am currently learning, and explaining things to others is the best way to learn yourself.
Link to github repo: https://github.com/filipkny/MediumRare