Follow-ups: Cart-Pole Balancing with Q-Network

In my previous post on Cart-Pole Balancing with Q-Learning, I discretized the state space into buckets to construct a Q-table. Here, I replace the Q-table with a neural network called a Q-network.

The Q-network that I made has three fully connected hidden layers (I didn’t draw the connections in the diagram below). The inputs to the network are the four state variables: position (x), velocity (x_dot), angle (theta), and angular velocity (theta_dot). The outputs are the Q-values of the two possible actions: move the cart to the left or move it to the right.
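For illustration, here is a minimal NumPy sketch of a network with this shape: four inputs, three fully connected hidden layers, and two outputs. The hidden-layer width (64), the ReLU activations, and the initialization are my assumptions rather than details of the implementation described here, and `init_q_network` and `q_values` are hypothetical helper names.

```python
import numpy as np

# 4 state variables in, 2 actions out. The hidden width (64) is an
# illustrative assumption, not the size used in the post.
STATE_DIM, N_ACTIONS, HIDDEN = 4, 2, 64


def init_q_network(rng, sizes=(STATE_DIM, HIDDEN, HIDDEN, HIDDEN, N_ACTIONS)):
    """Return a list of (weights, bias) pairs, one per fully connected layer."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        w = rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_in, n_out))
        b = np.zeros(n_out)
        params.append((w, b))
    return params


def q_values(params, state):
    """Forward pass: state(s) of shape (..., 4) -> Q-values of shape (..., 2)."""
    h = np.asarray(state, dtype=float)
    for w, b in params[:-1]:
        h = np.maximum(0.0, h @ w + b)  # ReLU hidden layers (assumed activation)
    w, b = params[-1]
    return h @ w + b  # linear output layer: one Q-value per action


rng = np.random.default_rng(0)
params = init_q_network(rng)
state = [0.01, -0.02, 0.03, 0.04]  # x, x_dot, theta, theta_dot
q = q_values(params, state)
action = int(np.argmax(q))  # assumes Gym's encoding: 0 = left, 1 = right
```

Picking the action with the larger Q-value (`np.argmax`) is what replaces the row lookup in a Q-table.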

Once the Q-network is trained, it does exactly the same thing as the Q-table: given a state, it tells you the Q-value of each action.

To train the Q-network, you repeatedly feed it more accurate Q-values over time. At each time step, you compute a target Q-value for the state-action pair that was just taken, like this:

Q'(state, action) = reward + DISCOUNT_FACTOR * max(Q(next_state))
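The target computation can be sketched in NumPy for a batch of transitions. One detail the formula leaves implicit: when `next_state` ends the episode (the pole has fallen), there is no future reward to bootstrap from, so a common convention, assumed here, is to use just the reward as the target. The `q_targets` helper, the `done` flags, and the 0.99 discount are illustrative assumptions.

```python
import numpy as np

DISCOUNT_FACTOR = 0.99  # assumed value for illustration


def q_targets(rewards, next_q_values, done, discount=DISCOUNT_FACTOR):
    """Q'(s, a) = r + discount * max_a' Q(s', a'), with no bootstrap at terminal states.

    rewards:       shape (batch,)
    next_q_values: shape (batch, n_actions), the network's Q(next_state)
    done:          shape (batch,), True where next_state ended the episode
    """
    rewards = np.asarray(rewards, dtype=float)
    best_next = np.max(np.asarray(next_q_values, dtype=float), axis=1)
    not_done = 1.0 - np.asarray(done, dtype=float)
    return rewards + discount * best_next * not_done


targets = q_targets([1.0, 1.0], [[2.0, 3.0], [5.0, 4.0]], [False, True])
# targets ≈ [3.97, 1.0]: the second transition is terminal, so only its reward counts
```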

This new, “more accurate” Q-value is then used as the training target for the network’s output for that action, and the error between the two is backpropagated to update the weights and biases.
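A framework such as CNTK computes these gradients automatically, but writing the update out by hand shows what happens to the weights. This is a minimal sketch assuming a squared-error loss on the Q-value of the action actually taken, ReLU hidden layers, and plain gradient descent; the helper names and layer sizes are hypothetical, not taken from the post’s implementation.

```python
import numpy as np


def init_params(rng, sizes):
    """(weights, bias) for each fully connected layer, He-style initialization."""
    return [(rng.normal(0.0, np.sqrt(2.0 / m), size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def forward(params, x):
    """Return the Q-values and the input to every layer (needed for backprop)."""
    layer_inputs = []
    h = x
    for i, (w, b) in enumerate(params):
        layer_inputs.append(h)
        z = h @ w + b
        h = z if i == len(params) - 1 else np.maximum(0.0, z)  # ReLU except output
    return h, layer_inputs


def train_step(params, states, actions, targets, lr=1e-3):
    """One gradient step on 0.5 * mean((Q(s, a) - target)^2).

    Only the output for the action actually taken gets an error term.
    Returns the updated parameters and the loss before the update.
    """
    q, layer_inputs = forward(params, states)
    rows = np.arange(len(states))
    error = q[rows, actions] - targets
    grad = np.zeros_like(q)
    grad[rows, actions] = error / len(states)  # dLoss/dQ at the output layer

    new_params = []
    for i in reversed(range(len(params))):
        w, b = params[i]
        x = layer_inputs[i]
        grad_w, grad_b = x.T @ grad, grad.sum(axis=0)
        if i > 0:
            # Backpropagate through this layer's weights and the previous ReLU.
            grad = (grad @ w.T) * (x > 0)
        new_params.append((w - lr * grad_w, b - lr * grad_b))
    new_params.reverse()
    return new_params, 0.5 * np.mean(error ** 2)


# Usage: repeatedly fit a fixed batch of (state, action, target) triples.
rng = np.random.default_rng(0)
params = init_params(rng, (4, 32, 32, 32, 2))
states = rng.normal(size=(64, 4))
actions = rng.integers(0, 2, size=64)
targets = rng.normal(size=64)
losses = []
for _ in range(200):
    params, loss = train_step(params, states, actions, targets, lr=1e-2)
    losses.append(loss)
```

Because only the taken action’s output has an error term, the output weights of the other action get no gradient from that sample, although the shared hidden layers still change.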

One of the advantages of using a Q-network is that you no longer need to discretize the state values to build a Q-table. This is useful when the discretization method is not obvious and when you don’t really know the range of the values that you are getting. In addition, when the state-action space is large (e.g. many features and many possible actions), the Q-table method becomes impractical, because the number of table entries is the product of the bucket counts of every feature times the number of actions.
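To make the scaling argument concrete, this sketch counts Q-table entries and shows what happens to out-of-range values under a simple bucketing scheme. The bucket counts and the velocity range [-1, 1] are arbitrary assumptions for illustration, and `to_bucket` is a hypothetical helper, not the code from the previous post.

```python
from math import prod  # Python 3.8+


def q_table_size(buckets_per_variable, n_actions):
    """Entries in a Q-table: one per (bucket combination, action) pair."""
    return prod(buckets_per_variable) * n_actions


def to_bucket(value, low, high, n_buckets):
    """Hypothetical bucketing helper; values outside [low, high] are clipped."""
    if value <= low:
        return 0
    if value >= high:
        return n_buckets - 1
    width = (high - low) / n_buckets
    return min(int((value - low) // width), n_buckets - 1)


# Four Cart-Pole variables at 10 buckets each, 2 actions:
print(q_table_size([10] * 4, 2))   # 20000
# Eight features at the same resolution, 4 actions:
print(q_table_size([10] * 8, 4))   # 400000000
# With an assumed velocity range of [-1, 1], very different values collide:
print(to_bucket(3.0, -1.0, 1.0, 10), to_bucket(50.0, -1.0, 1.0, 10))  # 9 9
```

The table size multiplies with every added feature, and a badly chosen range silently merges very different states into one edge bucket. A Q-network takes the raw state values as input, so neither issue arises.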

You can find my full implementation and results here. Note that most of what I did was based on the CNTK Reinforcement Learning Basics tutorial, which was actually very well written and easy to follow.

With a three-layer Q-network, it took 493 episodes to solve the problem. For a simple problem like Cart-Pole, the Q-table method was definitely faster. However, the Q-network method has the potential to tackle much harder problems!

One interesting thing that I noticed was that even though the pole stayed relatively upright during the entire episode, the cart was inching towards the right. This suggests that the neural network was also exploiting the fact that our evaluation time window was short enough that the cart couldn’t drift too far horizontally! If we want to keep the cart at the center indefinitely, we will have to penalize horizontal displacement as well.
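One way to add that penalty is to shape the reward before computing the Q-value target. This sketch assumes the OpenAI Gym CartPole environment, where the state is `[x, x_dot, theta, theta_dot]` and the episode ends when |x| exceeds 2.4; the linear penalty form, the weight of 0.5, and the `shaped_reward` helper are hypothetical choices that would need tuning.

```python
X_THRESHOLD = 2.4          # Gym CartPole ends the episode when |x| > 2.4 (assumed env)
DISPLACEMENT_WEIGHT = 0.5  # hypothetical penalty strength; would need tuning


def shaped_reward(reward, x, weight=DISPLACEMENT_WEIGHT, x_threshold=X_THRESHOLD):
    """Subtract a penalty proportional to the cart's distance from the center."""
    return reward - weight * abs(x) / x_threshold


print(shaped_reward(1.0, 0.0))   # 1.0: centered cart, no penalty
print(shaped_reward(1.0, -1.2))  # 0.75: halfway to the edge
```

In the training loop, the shaped value would replace the environment’s reward in the target formula, e.g. `shaped_reward(reward, next_state[0])`, so drifting in either direction lowers the Q-values the network learns.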

Here is the code for this implementation: