Reinforcement Learning — 2

Nikhil Badveli
4 min read · May 6, 2022


This is the second part of the RL series. In the previous part, we learnt the basics of RL and implemented a simple Q-learning solution for the cartpole balancing problem.

Now, in this post, we will learn how to solve the same problem using a slightly more advanced approach called Deep Q-learning. This technique was originally published by researchers at DeepMind¹, the Google subsidiary behind the famous AlphaGo project. The authors achieved state-of-the-art results on six of the Atari 2600 arcade games and surpassed a human expert on three of them.

A documentary on AlphaGo. Image taken from DataDrivenInvestor.

(I’m going to assume that you are already familiar with neural networks and how awesome they are. If not, I would highly suggest learning at least the basics first.)

An attempt at introducing Deep Q-learning

The entire premise of this approach is to utilize the power of deep neural networks for an RL problem. There are two common ways of doing this: one uses a Q-network and the other a policy network. The idea is to exploit the universal function approximation property of neural networks to learn the Q-function in the former case and the policy function in the latter.

For the cartpole balancing problem, we’re going to build a policy network using the Keras library in Python. Essentially, the inputs to this network are the state values (4 for this problem) and the outputs are the action values (2 for this problem), with a couple of fully-connected hidden layers in between. The code below builds the network.
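As an illustration, here is a minimal sketch of how such a network might be built with Keras. The layer sizes (two hidden layers of 24 units), the ReLU activations, the linear output layer and the mean-squared-error loss are all assumptions; the notebook’s actual architecture may differ.

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

def build_network(state_size=4, action_size=2, learning_rate=0.001):
    """Fully-connected network mapping state values to action values."""
    model = Sequential([
        Dense(24, activation='relu', input_shape=(state_size,)),  # hidden layer 1
        Dense(24, activation='relu'),                             # hidden layer 2
        Dense(action_size, activation='linear')                   # one output per action
    ])
    model.compile(loss='mse', optimizer=Adam(learning_rate=learning_rate))
    return model
```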

Now, one of the key components of this technique is the replay buffer. It is simply a memory unit that stores past experience: for each time step, the state, the action taken, the reward obtained, the next state and whether the episode ended. You can think of it as a list holding a (state, action, reward, next_state, done) tuple for each element. This matters because the network can only learn from past experience, so we collect it before using it for training.
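A minimal sketch of such a buffer, backed by a collections.deque, might look like this (the class and method names are my own, not necessarily those used in the notebook):

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-size memory of (state, action, reward, next_state, done) tuples."""
    def __init__(self, max_size=2000):
        self.memory = deque(maxlen=max_size)  # oldest entries are dropped automatically

    def add(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size=32):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
```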

The next step is training the network. A couple of questions might arise here: “How frequently should we train the network, and for how many epochs?” and “How do we turn the values in the replay buffer into inputs and targets for the network?”. Read on to find out :)

Since the replay buffer has a fixed size and we cannot train the network on the entire accumulated experience, we randomly sample a mini-batch of fixed size mini_batch_size (say, 32) from the buffer and use it to train the network at every time step. Each time, we train for only one epoch (probably not ideal).

In this training step, a for loop converts each replay-buffer tuple into an input/target pair to be passed to the model.fit() function; in the notebook, the target values are kept between 0 and 1.
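As a concrete illustration, here is a sketch of one such training step. It uses the standard DQN-style Bellman target, reward + gamma * max Q(next_state), rather than whatever scaling the notebook applies to keep its targets between 0 and 1, and it assumes the build_network and ReplayBuffer helpers sketched above:

```python
import numpy as np

def train_step(model, buffer, batch_size=32, gamma=0.95):
    """One training step: sample a mini-batch and fit the network for one epoch."""
    if len(buffer) < batch_size:
        return  # not enough experience collected yet

    minibatch = buffer.sample(batch_size)
    states = np.array([s for (s, a, r, ns, d) in minibatch])
    next_states = np.array([ns for (s, a, r, ns, d) in minibatch])

    # Start from the current predictions, then overwrite the entry for the
    # action actually taken with the Bellman target.
    targets = model.predict(states, verbose=0)
    next_q = model.predict(next_states, verbose=0)

    for i, (state, action, reward, next_state, done) in enumerate(minibatch):
        if done:
            targets[i][action] = reward
        else:
            targets[i][action] = reward + gamma * np.max(next_q[i])

    model.fit(states, targets, epochs=1, verbose=0)
```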

We can put everything so far together in a Python class. You can find it in this Jupyter notebook, along with the rest of the code.

Training the network

Now, using the class created above, we can instantiate a new deep Q-network object, which the agent then uses to predict what action to take in any given state. I’m assuming you already know about episodes and time steps :)

One thing to note is that we are using an epsilon-greedy strategy to balance exploration and exploitation of the state space. The epsilon value is decayed over time to reduce exploration in the later episodes.
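Put together, the training loop might look roughly like the sketch below. It reuses the helpers sketched earlier, assumes the classic Gym API (env.reset() returning the state and env.step() returning four values), and the epsilon schedule (1.0 decayed by 0.995 per episode down to 0.01) is an assumed choice, not necessarily the notebook’s:

```python
import random
import gym
import numpy as np

env = gym.make('CartPole-v0')
model = build_network()          # from the sketch above
buffer = ReplayBuffer()          # from the sketch above

epsilon, epsilon_min, epsilon_decay = 1.0, 0.01, 0.995   # assumed values
n_episodes = 500
scores = []                      # per-episode scores, used for the plot later

for episode in range(n_episodes):
    state = env.reset()
    done, score = False, 0
    while not done:
        # epsilon-greedy action selection
        if random.random() < epsilon:
            action = env.action_space.sample()                       # explore
        else:
            q_values = model.predict(state.reshape(1, -1), verbose=0)
            action = int(np.argmax(q_values[0]))                     # exploit
        next_state, reward, done, _ = env.step(action)
        buffer.add(state, action, reward, next_state, done)
        train_step(model, buffer)            # one epoch on a sampled mini-batch
        state, score = next_state, score + reward
    # decay exploration as training progresses
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    scores.append(score)
    print(f'Episode {episode}: score = {score}')
```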

Testing the network

While testing the network, we always exploit, i.e., we use the knowledge of the state space gathered so far and act on the network’s predictions.

The problem is considered solved if the average score per episode exceeds 195 (for CartPole-v0, the official criterion is an average reward of at least 195 over 100 consecutive episodes). This ensures that the network solved the problem by learning and not through sheer luck.
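A minimal sketch of such a greedy evaluation loop, reusing the names from the training sketch above, could be:

```python
import numpy as np

def evaluate(model, env, n_episodes=100):
    """Run greedy (exploit-only) episodes and return the average score."""
    episode_scores = []
    for _ in range(n_episodes):
        state, done, score = env.reset(), False, 0
        while not done:
            q_values = model.predict(state.reshape(1, -1), verbose=0)
            action = int(np.argmax(q_values[0]))   # always exploit, never explore
            state, reward, done, _ = env.step(action)
            score += reward
        episode_scores.append(score)
    return sum(episode_scores) / len(episode_scores)

avg_score = evaluate(model, env)
print('Solved!' if avg_score > 195 else 'Not solved yet.', 'Average score:', avg_score)
```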

Plotting the results

Finally, a good old plot to understand how the network evolved as it gained more experience by playing a number of episodes. Here’s the code for it.
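Since the original gist isn’t reproduced here, the following is a rough matplotlib sketch that assumes the per-episode scores were collected in the scores list from the training sketch; it also adds the axis labels and legend that the original figure was missing:

```python
import matplotlib.pyplot as plt
import numpy as np

# `scores` is assumed to hold the total reward of every training episode
window = 50
running_avg = np.convolve(scores, np.ones(window) / window, mode='valid')

plt.plot(scores, label='Episode score')
plt.plot(range(window - 1, window - 1 + len(running_avg)), running_avg,
         label=f'{window}-episode average')
plt.xlabel('Episode')
plt.ylabel('Score')
plt.title('Avg. score vs No. of episodes')
plt.legend()
plt.show()
```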

And here’s the actual plot generated after training the policy network created above.

Avg. score vs No. of episodes

Sorry about having no legend, no title and axis labels :(

References

  1. Mnih et al. (2013), Playing Atari with Deep Reinforcement Learning. https://arxiv.org/pdf/1312.5602.pdf

Also, read my previous post on Reinforcement Learning.
