I researched model-based learning this summer. Here are the results.
For the last few months I’ve been interning at the Machine Intelligence Research Institute, getting my hands dirty by replicating state-of-the-art AI techniques.
I wanted to start with the dumbest possible problem that hadn’t yet been solved, which seemed to be the OpenAI Gym environment CartPole (above). The goal is to balance a pole on a cart for 200 timesteps. At each timestep we can choose to push the cart left or right. In normal CartPole, the agent observes four numbers:
- position of the cart
- velocity of the cart
- angle of the pole from vertical
- angular velocity of the pole
The game ends when the pole tilts more than 12 degrees from vertical, the cart moves more than 2.4 units from the center, or 200 timesteps have elapsed. In the last case the agent has beaten that episode.
CartPole can be solved using standard model-free reinforcement learning techniques (described later) in about 30 seconds, but applying those techniques to images (aka pixels) of CartPole (as shown above) instead of position, velocity, etc., is much harder. As far as I know, no one has published an agent that consistently beats CartPole learning from images.
When my fellow intern Andrew Schreiber and I tried to solve this from only the rendered images, we couldn’t get an agent to consistently beat the game using standard techniques like the ones DeepMind used to beat Atari. One reason it’s harder: the images are 400x600, which is much larger than Atari’s 210x160 frames. It was an open question whether we could get away with simply downsampling or cropping without losing too much information, and whether other techniques would be necessary to beat this game.
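For readers who haven’t looked inside the environment, here is a minimal pure-Python sketch of classic CartPole, written from my recollection of the equations and constants in OpenAI Gym’s `cartpole.py` (treat the exact constants as an assumption rather than a quote of the library). It shows where the four observed numbers come from, the two failure conditions, and the 200-step cap. `reset`, `step`, and `play_episode` are illustrative names, not Gym’s API.

```python
import math
import random

# Physical constants, assumed to match the classic CartPole definition.
GRAVITY = 9.8
MASS_CART = 1.0
MASS_POLE = 0.1
TOTAL_MASS = MASS_CART + MASS_POLE
HALF_POLE_LENGTH = 0.5
POLE_MASS_LENGTH = MASS_POLE * HALF_POLE_LENGTH
FORCE_MAG = 10.0
TAU = 0.02                            # seconds per timestep
THETA_LIMIT = 12 * 2 * math.pi / 360  # 12 degrees, in radians
X_LIMIT = 2.4                         # cart must stay within [-2.4, 2.4]
MAX_STEPS = 200

LEFT, RIGHT = 0, 1


def reset(rng):
    """Start with all four state variables drawn uniformly from [-0.05, 0.05]."""
    return [rng.uniform(-0.05, 0.05) for _ in range(4)]


def step(state, action):
    """Advance (x, x_dot, theta, theta_dot) by one Euler step; return (new_state, failed)."""
    x, x_dot, theta, theta_dot = state
    force = FORCE_MAG if action == RIGHT else -FORCE_MAG
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    temp = (force + POLE_MASS_LENGTH * theta_dot ** 2 * sin_t) / TOTAL_MASS
    theta_acc = (GRAVITY * sin_t - cos_t * temp) / (
        HALF_POLE_LENGTH * (4.0 / 3.0 - MASS_POLE * cos_t ** 2 / TOTAL_MASS)
    )
    x_acc = temp - POLE_MASS_LENGTH * theta_acc * cos_t / TOTAL_MASS
    # Explicit Euler integration of position and angle.
    x, x_dot = x + TAU * x_dot, x_dot + TAU * x_acc
    theta, theta_dot = theta + TAU * theta_dot, theta_dot + TAU * theta_acc
    failed = abs(x) > X_LIMIT or abs(theta) > THETA_LIMIT
    return [x, x_dot, theta, theta_dot], failed


def play_episode(policy, rng):
    """Return the episode score: +1 per timestep played, capped at MAX_STEPS."""
    state, score = reset(rng), 0
    for _ in range(MAX_STEPS):
        state, failed = step(state, policy(state))
        score += 1
        if failed:
            break
    return score
```

Pushing RIGHT from rest gives the cart positive velocity and the pole negative angular velocity, which is the "pole moves opposite to the cart" behavior that shows up again in the imagination rollouts later in the post.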
Model-based reinforcement learning
Reinforcement learning agents receive observations of their environment and want to choose actions to maximize their rewards. Think of this as seeing the state of a video game and pressing the right controls at the right time to get as high a score as possible. In CartPole, the reinforcement learning problem is seeing a series of 400x600 images and learning to press LEFT or RIGHT appropriately to keep the pole balanced as long as possible.
In reinforcement learning, there are two different styles: model-based and model-free learning. In model-based learning, agents have access to the output of a separately trained model of the world to help them achieve their goals. In CartPole, that might mean an estimate of what the cart’s position and the pole’s angle will be if the agent pushes left versus right. Use of models is also sometimes called imagination, especially when the model’s predictions are fed back into it to predict many steps into the future. In model-free learning, the agent receives only raw observations of the environment, which in the version of CartPole I’m interested in are just images.
In practice, most researchers working on reinforcement learning use model-free learning. The intuition is that if you don’t know the rewards, you don’t know which parts of the environment are important. In CartPole, the environment is extremely simple: there’s only one object and, aside from the random starting state, the agent has complete control over what happens. But consider a self-driving car. Its observations come from cameras and LIDAR, and its rewards come from not running into anything and staying in its lane. The observations are extremely complicated. A world model trained on them might spend a lot of representational power on how trees blow in the wind. That obviously won’t help the car drive better, but a world model trained without rewards has no way of knowing what “drive better” means. For this reason, most researchers use model-free reinforcement learning, training an agent to act based on raw observations and rewards at the same time.
Model-based reinforcement learning particularly caught my interest because I thought it might speed up beating CartPole and thus make the problem tractable to experiment on. Intuitively, if an agent has more concise, relevant information, it should learn faster. The idea (sketched above) is to first train a model that takes the environment state (an image) and the agent’s action (push left or push right) as input, and outputs the predicted next state (another image). We collect a bunch of data (66,000 image/action pairs) by taking random actions and use that to train the model. Training is fast because it’s supervised learning rather than reinforcement learning. In reinforcement learning, the agent has to decide what to do at each timestep, do it, see what happened, and then update itself. In contrast, we can train the model in batch on the data we already collected, with no additional access to the environment. This lets it train in minutes instead of hours.
Given a model, an agent should have some understanding of how its actions affect the environment, since the model tells it, “If you move left, here’s what the next frame will look like. If you move right, it’ll look this way instead.” This should make it easier for the agent to figure out how to maximize its rewards, which in CartPole means keeping the pole balanced as long as possible and eventually beating the game.
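The data-collection and supervised-training steps can be sketched as follows. To keep the example short and runnable, it assumes low-dimensional vector states instead of images and fits a linear model by least squares instead of a neural network; `collect_random_transitions` and `fit_one_step_model` are hypothetical names. What it illustrates is the structure: data is gathered once with random actions, and the model is then fit to the whole batch without touching the environment again.

```python
import numpy as np


def collect_random_transitions(reset, step, n_transitions, rng):
    """Play uniformly random actions and record (state, action, next_state) triples.

    `reset()` returns an initial state vector and `step(state, action)` returns
    (next_state, done); both are hypothetical stand-ins for a real environment.
    """
    states, actions, next_states = [], [], []
    state = reset()
    while len(states) < n_transitions:
        action = int(rng.integers(2))  # 0 = LEFT, 1 = RIGHT
        next_state, done = step(state, action)
        states.append(state)
        actions.append(action)
        next_states.append(next_state)
        state = reset() if done else next_state
    return np.array(states), np.array(actions), np.array(next_states)


def fit_one_step_model(states, actions, next_states):
    """Supervised regression from [state, one-hot action] to next_state.

    The one-hot columns act as a per-action intercept. The whole dataset is
    fit in one batch with least squares.
    """
    inputs = np.hstack([states, np.eye(2)[actions]])
    weights, *_ = np.linalg.lstsq(inputs, next_states, rcond=None)

    def predict(state, action):
        x = np.concatenate([np.asarray(state, dtype=float), np.eye(2)[action]])
        return x @ weights

    return predict
```

Swapping the linear fit for a neural network changes the model class but not the workflow: collect once, then train as ordinary supervised learning.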
Building the model
We hypothesized that a model-based reinforcement learning agent could learn to solve CartPole faster than a similar model-free agent. In the rest of the post, I’ll show some fun pictures and the basic components of the research. I’m not covering the hundreds of experiments I had to run to get things working, since that’d take far too much space!
First, we collected 66,000 random frames. Here’s what they look like all stacked on top of each other after downsampling and converting to grayscale:
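Here is a rough sketch of that preprocessing, assuming frames arrive as 400x600x3 RGB `uint8` arrays. The grayscale weights are the common ITU-R BT.601 luma coefficients and the 5x block-average downsampling (giving 80x120) is an illustrative choice; the post doesn’t specify either. `overlay` averages many frames into a single image, which is one simple way to produce a stacked picture of the dataset.

```python
import numpy as np

LUMA = np.array([0.299, 0.587, 0.114], dtype=np.float32)  # assumed grayscale weights


def preprocess(frame, factor=5):
    """RGB (H, W, 3) uint8 -> grayscale in [0, 1], downsampled by block averaging."""
    gray = frame.astype(np.float32) @ LUMA / 255.0
    h, w = gray.shape
    h, w = h - h % factor, w - w % factor  # crop so both sides divide evenly
    blocks = gray[:h, :w].reshape(h // factor, factor, w // factor, factor)
    return blocks.mean(axis=(1, 3))


def overlay(frames):
    """Average many preprocessed frames into one image to visualize coverage."""
    return np.mean(np.stack(frames), axis=0)
```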
The data looks diverse: the cart moves far to each side. Hopefully that’s enough to train an interesting model.
Since the actual environment state can be uniquely represented by four numbers (position, velocity, angle, angular velocity), I thought an autoencoder-like architecture would be a good fit for the model. The cool thing about an autoencoder is that it’s a form of compression. The input is high-dimensional (an image), and so is the output (another image), but in the middle there’s a bottleneck layer that can be whatever size works well for the application. The bottleneck size determines how strongly the data is compressed. I experimented with a few different sizes to see which produced the most accurate output, but in the end I settled on a 4-dimensional bottleneck (aka 4 numbers) so that I could compare the autoencoder’s encoded representation with the actual variables (position, velocity, etc.), even though the model never had access to them during training. I hoped it would discover them on its own.
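A forward-pass sketch of an action-conditioned, autoencoder-like model with a 4-number bottleneck follows. Every layer choice here is an assumption for illustration: dense layers, a 256-unit hidden layer, and appending a one-hot action to the bottleneck code before decoding. It is not the architecture from the post, which isn’t shown here. Training (for example, minimizing the mean squared error returned by `loss` with an autodiff framework) is left out, so the weights are random.

```python
import numpy as np


class ActionConditionedAutoencoder:
    """Frame + action -> predicted next frame, through a small bottleneck (hypothetical design)."""

    def __init__(self, frame_shape, hidden=256, bottleneck=4, n_actions=2, seed=0):
        rng = np.random.default_rng(seed)
        self.frame_shape = tuple(frame_shape)
        self.n_actions = n_actions
        n_pixels = int(np.prod(frame_shape))

        def dense(n_in, n_out):
            # (weights, bias) with a simple 1/sqrt(fan_in) initialization.
            return rng.normal(0.0, 1.0 / np.sqrt(n_in), (n_in, n_out)), np.zeros(n_out)

        self.enc1 = dense(n_pixels, hidden)
        self.enc2 = dense(hidden, bottleneck)
        self.dec1 = dense(bottleneck + n_actions, hidden)
        self.dec2 = dense(hidden, n_pixels)

    def encode(self, frames):
        """(batch, *frame_shape) -> (batch, bottleneck): the learned 'state' numbers."""
        x = frames.reshape(len(frames), -1)
        h = np.maximum(0.0, x @ self.enc1[0] + self.enc1[1])  # ReLU
        return h @ self.enc2[0] + self.enc2[1]

    def decode(self, codes, actions):
        """Append one-hot actions to the codes, then decode to pixels in [0, 1]."""
        z = np.hstack([codes, np.eye(self.n_actions)[actions]])
        h = np.maximum(0.0, z @ self.dec1[0] + self.dec1[1])
        logits = h @ self.dec2[0] + self.dec2[1]
        pixels = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
        return pixels.reshape(len(codes), *self.frame_shape)

    def predict_next(self, frames, actions):
        return self.decode(self.encode(frames), actions)

    def loss(self, frames, actions, next_frames):
        """Mean squared error between predicted and actual next frames."""
        return float(np.mean((self.predict_next(frames, actions) - next_frames) ** 2))
```

The 4-unit `encode` output is what gets compared against the true CartPole variables; feeding the action in after the bottleneck is just one of several possible placements.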
The first time I trained the model and hooked it up to a reinforcement learning agent, it completely failed to learn. To find out why, I did some visualizations to ensure it was learning what I wanted it to.
Below, each subplot is a scatter plot of one of the actual variables against one of the learned variables for 100 random timesteps in CartPole. Above each subplot is its R value (the Pearson correlation coefficient), which is close to 1 or -1 when the two variables are strongly linearly correlated and close to 0 when they are not. The top-left subplot compares the position of the cart with the first encoded dimension, labeled column a. The correlation is strong, so the model seems to have found a good way to represent position. The next row down compares the encoded representation with the cart’s velocity. Here the second encoded dimension, labeled column b, has the largest R value in magnitude, but at -0.47 it isn’t as strong. Across all four rows, the static variables (position and angle) have stronger R values than the dynamic ones (velocity and angular velocity). That makes sense: dynamic variables can only be inferred from the difference between consecutive frames, whereas static ones can be read off a single image.
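The R values in such a chart can be computed with NumPy’s `np.corrcoef`. This sketch assumes the true variables and the encoded dimensions for the same timesteps are stored row by row in two arrays; `correlation_table` and `best_match` are hypothetical helper names.

```python
import numpy as np


def correlation_table(true_vars, codes):
    """Pearson R between each true variable (rows) and each encoded dimension (columns).

    true_vars: (n_timesteps, 4) array of position, velocity, angle, angular velocity.
    codes:     (n_timesteps, k) array of the model's bottleneck outputs.
    """
    n_true = true_vars.shape[1]
    # rowvar=False treats columns as variables; the result covers both arrays.
    full = np.corrcoef(true_vars, codes, rowvar=False)
    return full[:n_true, n_true:]


def best_match(r_table):
    """For each true variable, the (column index, R) of the encoded dimension with the largest |R|."""
    cols = np.argmax(np.abs(r_table), axis=1)
    return [(int(c), float(r_table[i, c])) for i, c in enumerate(cols)]
```

Ranking by |R| rather than R matters because a dimension that tracks velocity with a flipped sign, like column b above, is just as informative as one with a positive correlation.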
When I first produced this chart, all the R values were much lower, especially for the dynamic variables. Each time I modified the autoencoder architecture, I regenerated the chart and made sure the R values were increasing.
I also wanted to verify that when the model gets a LEFT input, the cart actually moves left. My first model showed the cart moving left regardless of whether I input LEFT or RIGHT, but after a few tries I got imagination rollouts like this:
To generate these, I started with a random frame and fed it into my autoencoder. That gave me one predicted frame, which I fed back into the same autoencoder, and I repeated the process until I had enough frames. Interestingly, in the top rollout there’s a strange artifact where the pole disappears. I think this is because in the training data the pole never falls past a certain angle. Since we keep feeding the model’s own output back in, the most sensible thing it can do is assume the pole must be far away. Still, the rollouts have the property we’re looking for: the cart moves in opposite directions for the two actions, and the pole moves opposite to the cart.
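The rollout procedure amounts to a short loop. This sketch assumes a hypothetical `model(frame, action)` callable that returns the predicted next frame; because each prediction becomes the next input, any error the model makes is carried forward into every later frame.

```python
from typing import Callable, List, TypeVar

Frame = TypeVar("Frame")


def imagine(model: Callable[[Frame, int], Frame], start_frame: Frame, actions: List[int]) -> List[Frame]:
    """Roll a one-step model forward, feeding each prediction back in as the next input.

    Returns the imagined frames, one per action (start_frame is not included).
    """
    frames: List[Frame] = []
    frame = start_frame
    for action in actions:
        frame = model(frame, action)
        frames.append(frame)
    return frames
```

Comparing `imagine(model, frame, [LEFT] * n)` with `imagine(model, frame, [RIGHT] * n)` is a quick way to run the LEFT/RIGHT sanity check on a trained model.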
As with the scatter plots, I checked the imagination rollouts after each change I made to the model to ensure it still looked reasonable. The key fix was to change how the actions were fed into the autoencoder (see source code below for details). Seeing the above, I felt confident I could move on.
Beating the game
Next, I generated baselines to compare against. Originally, I was using the OpenAI Baselines DQN implementation and made very slow progress, since it used only one CPU core at a time. I eventually switched to their A2C implementation, which was much faster because it could use all 8 of my CPU cores, did intelligent batching, and had a better learning rule. It solved classic CartPole (4 numbers) in less than 30 seconds (cyan line below). The environment counts as solved when the mean score over the last 100 episodes is 195 or more, out of a maximum of 200.
Still, I couldn’t get it to solve from images at all, whether model-free or model-based. At some point I realized I could experiment faster with a variation of classic CartPole in which I removed velocity and angular velocity and trained only on position and angle, which I called “static CartPole”. This let me run tens of experiments per day, which showed that increasing the batch size helped a lot. The batch size is how many examples the agent sees each time it updates itself; a bigger batch size tends to lead to more stable learning. Eventually I got it solving the static environment in 5 minutes. From there, I decided to be patient and, using the same batch size and other configuration, let the agent train from images for much longer than I had before. After a few more experiments in which I increased the depth and capacity of the network beyond the standard Atari convolutional neural net, a model-free agent could solve from pixels in a few hours (green line below).
From there, it was pretty simple to get a model-based agent to solve the game from images, too (blue line). I used the same network architecture as for static CartPole (position and angle only), but fed the network 8 numbers per timestep instead of 2: the model’s 4D output for pushing left and its 4D output for pushing right. It beat the game almost twice as fast as the model-free agent!
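The solved criterion translates directly into code. This is a sketch of my own (`SolvedTracker` is a hypothetical name), not Gym’s or Baselines’ implementation; it requires a full window of 100 episodes before declaring success.

```python
from collections import deque


class SolvedTracker:
    """Track episode scores; CartPole counts as solved once the mean of the
    last `window` episode scores reaches `threshold`."""

    def __init__(self, window=100, threshold=195.0):
        self.scores = deque(maxlen=window)  # keeps only the most recent scores
        self.threshold = threshold

    def add(self, score):
        """Record one episode's score and report whether the task is now solved."""
        self.scores.append(score)
        return self.solved()

    def solved(self):
        if len(self.scores) < self.scores.maxlen:
            return False  # not enough episodes yet
        return sum(self.scores) / len(self.scores) >= self.threshold
```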
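Static CartPole just drops two of the four observation entries. A minimal sketch, assuming observations are ordered (position, velocity, angle, angular velocity) as in the list at the top of the post; `to_static` is a hypothetical helper. It indexes the last axis, so it also works on batches, such as the (environments × steps × 4) arrays an A2C-style learner gathers, where the batch per update is the number of parallel environments times the steps collected from each.

```python
import numpy as np

# Keep position (index 0) and angle (index 2); drop velocity (1) and angular velocity (3).
STATIC_INDICES = [0, 2]


def to_static(observation):
    """Reduce classic CartPole observation(s) of shape (..., 4) to shape (..., 2)."""
    return np.asarray(observation, dtype=float)[..., STATIC_INDICES]
```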
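The 8-number input can be assembled as shown here. `encode_next(frame, action)` is a hypothetical function returning the model’s 4-number code for the outcome of `action` (one option is to encode the model’s predicted next frame); the sketch only fixes the concatenation order, LEFT then RIGHT, which must stay consistent between training and acting.

```python
import numpy as np

LEFT, RIGHT = 0, 1


def model_features(encode_next, frame):
    """Concatenate the model's 4-D prediction for LEFT with its 4-D prediction for RIGHT."""
    left = np.asarray(encode_next(frame, LEFT), dtype=float)
    right = np.asarray(encode_next(frame, RIGHT), dtype=float)
    return np.concatenate([left, right])
```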
As far as I know, this is also the first instance of anyone solving CartPole from pixels. I hope others find it useful! I’m sure that with more computational power, a little hyperparameter optimization would make both the model-based and model-free agents train much faster, but I only had my own machine.