Differentiable Programming and Neural ODEs for Accelerating Model Based Reinforcement Learning and Optimal Control
We will explain the theory in detail first. Feel free to jump to the code section.
Abstract
We simplify and accelerate training in model-based reinforcement learning problems by using end-to-end differentiable programming in Julia. We compute policy gradients by differentiating through a continuous time neural ODE consisting of the environment and the neural network agent, though the technique applies to discrete time too. On a single compute instance, we train agents in seconds for the harder swing-up variant of the cartpole problem under different objectives. For comparison, Microsoft Bonsai's demo tackles the easier cartpole balancing problem, where the pole starts upright, using conventional reinforcement learning that does not first fit a differentiable environmental model.
Motivation: (much) faster reinforcement learning
Reinforcement learning (RL) is still a baby in the machine learning family. While computer vision, natural language processing, and recommendation systems touch our lives every day, reinforcement learning is just starting to make an impact. Sure, there are impressive demos such as DeepMind's AlphaGo and AlphaStar and OpenAI Five. However, they required large engineering teams and tons of compute. Even simple games can take dreadful amounts of tuning and millions of training epochs. Part of the problem is that we treat the environment too much as a black box. State-of-the-art policy gradient algorithms (e.g. actor-critic methods such as A2C and PPO) essentially sample this black box and iteratively refine policy and value estimators. In some cases the environment is indeed a black box, e.g. a robot reacting to a novel environment, or a game played against an external adversary.
However, in many scenarios we already own the environmental simulation! Examples: video games, self-driving simulators, control systems, constrained robotics, industrial automation, and process engineering. Why sample around the system dynamics when we already know them?
For model-based RL, we should combine the model and the neural network agent into a single system loop that, as a whole, can be differentiated to yield policy gradients for the agent. (Our policy is deterministic with respect to the state, but a stochastic formulation is possible.) In the past, the derivative program was worked out by hand, e.g. for optimal control in aeronautics. However, recent advances in differentiable programming and neural ODEs can automate the process! We demonstrate an orders-of-magnitude improvement in learning speed on the “Hello World” of reinforcement learning: the cartpole problem.
Toy Problem
In the cartpole problem, the usual goal is to balance an upright pole by moving the base cart, as in the OpenAI Gym environment. This is actually too easy because the system is fairly linear at small angles :) Instead, we’ll start with the pole hanging down and compute the base movements that bring it up and keep it balanced, aka cartpole swing-up. This sweeps through the system’s nonlinearities. We wish to minimize the angle, angular velocity, and cart velocity at the end while staying within a time limit.
The System (Environment) is the cartpole, while the Controller (Agent) dictates how to move the cart. The system state u consists of the cart position, cart velocity, pole angle, and pole angular velocity. Its time derivative f describes the time evolution and is a function of the current state u and the control action, which is the force applied to the cart. We make the controller a neural network g(u, p) with the system state u as input and parameters p as the weights. We specify the system’s initial condition and let it run, generating a trajectory in state space. We construct a loss functional l acting on the trajectory, penalizing deviations from the desired behavior. The goal is to minimize l with respect to p, i.e. to find the optimal weights for the neural controller.
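In symbols (our own notation for the setup just described, with u = (x, ẋ, θ, θ̇), θ = 0 meaning upright), the controller output is the force fed into f, and the loss used later in the code section serves as the concrete example of l, evaluated at the N saved time points t_1, …, t_N:

```latex
\begin{aligned}
\frac{du}{dt} &= f\big(u(t),\, g(u(t), p)\big), \qquad u(0) = u_0, \\
l(p) &= \frac{1}{N}\sum_{i=1}^{N} \theta(t_i)^2 + 4\,\theta(t_N)^2 + \dot{x}(t_N)^2, \\
p^\star &= \arg\min_{p}\; l(p).
\end{aligned}
```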
Magic ingredient: differentiable programming
The key difference from traditional RL is bundling the agent and the environment as one differentiable system. For us, the neural network agent’s action affects the system’s time evolution, i.e. the ODE’s time derivative. This “neural ODE” setup was popularized by the 2018 paper “Neural Ordinary Differential Equations,” which won a best paper award at the prestigious NeurIPS conference (then called NIPS). The original paper used it as a continuous-depth analogue of residual connections in discrete neural networks, to improve parameter efficiency. It has since been used for time series and dynamical systems modeling, going back to the field it came from. We train the “neural” part of the neural ODE to be our agent or controller, while the system evolves as a function of its state variables and the control signal.
You might wonder how on earth we can differentiate this neural network embedded as part of an ODE that is integrated in time to yield a trajectory along which the state is sampled at multiple time points to compute the loss!? Actually, you’ve already answered the question :)
The loss depends on the neural network weights through a chain of functions. The chain rule applied to a neural network is called backpropagation; now we just need to do the same for the ODE. The PDE and inverse problems literature provides the solution: the “adjoint method,” which yields a “dual” adjoint ODE that is integrated backwards in time. Luckily, Julia has a powerful automatic differentiation (AD) package called Zygote which can do the dirty work ;) We simply code up the simulation and compute the gradient with just 1 line of code! In Python, TensorFlow and PyTorch perform AD on graphs built from the predefined neural network building blocks in their libraries, whereas Julia’s Flux (and Zygote underneath) does this for (almost) arbitrary functions written in Julia! Zygote hooks into the compiler and automagically applies the chain rule to intermediate instructions. This results in performant, statically compiled gradient code.
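For reference, here is the continuous adjoint method in the form given by Chen et al. (2018), written in our notation with F(u, p) = f(u, g(u, p)) for the closed-loop dynamics and a loss evaluated at save points t_1 < … < t_N. This is a sketch of the math, not necessarily the exact variant that Zygote and DiffEqFlux run internally. The adjoint a(t) is the sensitivity of the loss to the state; it starts at the final save point, is integrated backward, and picks up a jump at every earlier save point where the loss samples the state:

```latex
\begin{aligned}
a(t_N) &= \frac{\partial l}{\partial u(t_N)}, \\
\frac{da}{dt} &= -\left(\frac{\partial F}{\partial u}\right)^{\!\top} a(t)
  \quad \text{(integrated backward from } t_N \text{ to } t_0\text{)}, \\
a(t_i^-) &= a(t_i^+) + \frac{\partial l}{\partial u(t_i)}
  \quad \text{at each earlier save point } t_i, \\
\frac{dl}{dp} &= \int_{t_0}^{t_N} a(t)^{\top} \frac{\partial F}{\partial p}\, dt,
  \qquad \frac{\partial F}{\partial p} = \frac{\partial f}{\partial\, \mathrm{force}}\; \frac{\partial g}{\partial p}.
\end{aligned}
```

The backward solve also needs u(t): it can either be re-integrated backward alongside a(t) or interpolated from the stored forward solution (DiffEqSensitivity's BacksolveAdjoint and InterpolatingAdjoint are examples of the two choices).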
Julia is a modern general purpose Pythonic language but with easier and more performant syntax for scientific computing and differentiable programming. Think of it as a happy marriage between Python, R, Matlab, and C++.
Code
Complete code in Julia is on GitHub.
We first construct the time derivative of the system. A derivation using Lagrangian mechanics is at https://metr4202.uqcloud.net/tpl/t8-Week13-pendulum.pdf .
# packages used in this post (assumes an older DiffEqFlux release that still
# provides FastChain, FastDense, initial_params and sciml_train)
using DiffEqFlux, OrdinaryDiffEq, Flux, Plots, JSON

# physical params
m = 1 # pole mass kg
M = 2 # cart mass kg
L = 1 # pole length m
g = 9.8 # acceleration constant m/s^2

# map angle to [-pi, pi)
modpi(theta) = mod2pi(theta + pi) - pi

#=
system dynamics derivative
du: du/dt, state vector derivative updated inplace
u: state vector (x, dx, theta, dtheta); theta = 0 is upright, theta = pi hangs down
p: parameter function, here lateral force applied to the cart as a fn of time
t: time
=#
function cartpole(du, u, p, t)
    # position (cart), velocity, pole angle, angular velocity
    x, dx, theta, dtheta = u
    force = p(t)
    du[1] = dx
    du[2] =
        (force + m * sin(theta) * (L * dtheta^2 - g * cos(theta))) /
        (M + m * sin(theta)^2)
    du[3] = dtheta
    du[4] =
        (-force * cos(theta) - m * L * dtheta^2 * sin(theta) * cos(theta) +
         (M + m) * g * sin(theta)) / (L * (M + m * sin(theta)^2))
end
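A quick way to sanity-check the dynamics above is energy conservation: with zero applied force and no friction, the total energy E = ½(M+m)ẋ² + mLẋθ̇cosθ + ½mL²θ̇² + mgL cosθ should stay constant. The sketch below assumes what the formulas imply (θ = 0 upright, a point mass m at the tip of a massless pole of length L). It reimplements the right-hand side out of place as a hypothetical helper `cartpole_rhs` and uses a hand-written fixed-step RK4 integrator in place of the adaptive Tsit5 solver, so it runs in plain Julia without any packages.

```julia
# Sanity check: with zero force, total mechanical energy should be conserved.
# Assumptions: theta = 0 upright, point mass at the pole tip, no friction,
# hand-written RK4 instead of Tsit5.

const P = (m = 1.0, M = 2.0, L = 1.0, g = 9.8)

# out-of-place right-hand side, same formulas as `cartpole` above
function cartpole_rhs(u, force, P)
    x, dx, theta, dtheta = u
    s, c = sin(theta), cos(theta)
    denom = P.M + P.m * s^2
    ddx = (force + P.m * s * (P.L * dtheta^2 - P.g * c)) / denom
    ddtheta = (-force * c - P.m * P.L * dtheta^2 * s * c + (P.M + P.m) * P.g * s) / (P.L * denom)
    return [dx, ddx, dtheta, ddtheta]
end

# kinetic energy of cart and pole tip plus potential energy of the tip
function energy(u, P)
    x, dx, theta, dtheta = u
    kinetic = 0.5 * (P.M + P.m) * dx^2 + P.m * P.L * dx * dtheta * cos(theta) +
              0.5 * P.m * P.L^2 * dtheta^2
    potential = P.m * P.g * P.L * cos(theta)
    return kinetic + potential
end

# classic fourth-order Runge-Kutta step for du/dt = f(u)
function rk4_step(f, u, h)
    k1 = f(u)
    k2 = f(u .+ (h / 2) .* k1)
    k3 = f(u .+ (h / 2) .* k2)
    k4 = f(u .+ h .* k3)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# integrate with zero applied force and return the absolute energy drift
function energy_drift(u0, P; h = 1e-3, nsteps = 1000)
    rhs(v) = cartpole_rhs(v, 0.0, P)
    u = copy(u0)
    for _ in 1:nsteps
        u = rk4_step(rhs, u, h)
    end
    return abs(energy(u, P) - energy(u0, P))
end

u_start = [0.0, 0.0, pi - 0.5, 0.0]   # pole 0.5 rad away from hanging straight down
drift = energy_drift(u_start, P)
println("energy drift after 1 s: ", drift)
```

A drift near machine precision indicates the two acceleration formulas are mutually consistent; a sign error in either one would typically show up as a large drift.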
Next we define our controller neural network as an MLP with 1 hidden layer. If we need a more nonlinear agent, we can use a deeper network.
# neural network controller, here a simple MLP
# inputs: cos(theta), sin(theta), theta_dot
# output: cart force
controller = FastChain((x, p) -> x, FastDense(3, 8, tanh), FastDense(8, 1))

# initial neural network weights
pinit = initial_params(controller)
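To make the controller concrete, here is a plain-Julia sketch of the same 3 → 8 → 1 architecture acting on a flat parameter vector p, which is how the ODE receives the weights. The packing order `[W1; b1; W2; b2]` and the helper name `mlp_controller` are assumptions for illustration; DiffEqFlux's internal layout may differ, so only the shapes and the parameter count (3·8 + 8 + 8·1 + 1 = 41) carry over.

```julia
# Hypothetical hand-written version of the 3 -> 8 -> 1 tanh MLP controller.
# Assumed flat layout: p = [vec(W1); b1; vec(W2); b2], column-major.

function mlp_controller(x, p)
    W1 = reshape(p[1:24], 8, 3)     # hidden-layer weights
    b1 = p[25:32]                   # hidden-layer biases
    W2 = reshape(p[33:40], 1, 8)    # output weights
    b2 = p[41:41]                   # output bias (kept as a 1-element vector)
    h = tanh.(W1 * x .+ b1)         # hidden layer with tanh activation
    return W2 * h .+ b2             # linear output: 1-element vector (cart force)
end

nparams = 3 * 8 + 8 + 8 * 1 + 1
p = collect(range(-0.5, 0.5, length = nparams))   # deterministic toy weights

theta, dtheta = pi, 0.0                           # pole hanging down, at rest
force = mlp_controller([cos(theta), sin(theta), dtheta], p)[1]
println("parameters: ", nparams, ", force: ", force)
```

Feeding cos θ and sin θ rather than θ itself keeps the input continuous when the angle wraps around ±π.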
We now set up the whole neural ODE and define the ODE solver that integrates it forward in time.
#=
system dynamics derivative with the controller included
u[5] integrates the force (impulse) so the force can be recovered from the solution
=#
function cartpole_controlled(du, u, p, t)
    # controller force response
    force = controller([cos(u[3]), sin(u[3]), u[4]], p)[1]
    du[5] = force

    # plug force into system dynamics
    cartpole(du, u[1:4], t -> force, t)
end

# initial condition: cart at rest at the origin, pole hanging down, zero impulse
u0 = [0; 0; pi; 0; 0]
tspan = (0.0, 1.0)
N = 50
tsteps = range(tspan[1], tspan[2], length = N)
dt = (tspan[2] - tspan[1]) / (N - 1) # spacing between the N saved time points

# set up ODE problem
prob = ODEProblem(cartpole_controlled, u0, tspan, pinit)

# wrangles output from ODE solver
function format(pred)
    x = pred[1, :]
    dx = pred[2, :]
    theta = modpi.(pred[3, :])
    dtheta = pred[4, :]

    # take derivative of impulse to get force
    impulse = pred[5, :]
    tmp = (impulse .- circshift(impulse, 1)) / dt
    # tmp[1] wraps around to the last sample, so reuse tmp[2] for the first point
    force = [tmp[2], tmp[2:end]...]
    return x, dx, theta, dtheta, force
end

# solves ODE
function predict_neuralode(p)
    tmp_prob = remake(prob, p = p)
    solve(tmp_prob, Tsit5(), saveat = tsteps)
end
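The `format` function recovers the force by differencing the integrated impulse u[5]. A difference of impulse over one sampling interval, divided by the interval length, is the average force over that interval, and for `range(t0, tf, length = N)` that length is (tf − t0)/(N − 1). The sketch below checks this with a known test force f(t) = 3 + 2t (an assumption standing in for the controller), whose impulse 3t + t² is available in closed form.

```julia
# Recovering force from sampled impulse, with a known linear test force.

t0, tf, N = 0.0, 1.0, 50
tsteps = range(t0, tf, length = N)
dt = step(tsteps)                      # (tf - t0) / (N - 1) for N saved points

f(t) = 3 + 2t                          # hypothetical test force
impulse = [3t + t^2 for t in tsteps]   # its integral, as the solver would save in u[5]

# backward differences: average force over each sampling interval
avg_force = diff(impulse) ./ dt

# for a linear force, the interval average equals the force at the midpoint
midpoints = (tsteps[1:end-1] .+ tsteps[2:end]) ./ 2
println("max error: ", maximum(abs.(avg_force .- f.(midpoints))))
```

Dividing by tf/N instead of tf/(N − 1) would scale every recovered force by (N − 1)/N, about 2% low for N = 50.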
We define our loss function to penalize the final pole angle and the final cart velocity. We add a penalty for the average angular deviation, which encourages the controller to swing up the pole faster. We use least squares penalties, but you can use any function, including log or even discontinuous penalties!
# loss to minimize as a function of neural network parameters p
function loss_neuralode(p)
    pred = predict_neuralode(p)
    x, dx, theta, dtheta, force = format(pred)
    loss = sum(theta .^ 2) / N + 4theta[end]^2 + dx[end]^2
    return loss, pred
end
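Because the loss is just Julia code on the saved arrays, swapping penalty shapes is a one-line change. Here is a sketch of the swing-up loss on plain arrays next to a hypothetical variant that replaces the terminal squared angle with a log penalty (`log1p` keeps it finite and smooth at zero). The function names and the toy trajectory are illustrative, not from the post.

```julia
# The post's swing-up loss on plain arrays, plus a hypothetical log variant.

function swingup_loss(theta, dx)
    N = length(theta)
    return sum(theta .^ 2) / N + 4 * theta[end]^2 + dx[end]^2
end

function swingup_loss_log(theta, dx)
    N = length(theta)
    return sum(theta .^ 2) / N + 4 * log1p(theta[end]^2) + dx[end]^2
end

theta = [pi, pi / 2, 0.0]   # toy trajectory that ends upright
dx = [0.0, 1.0, 0.0]        # and with the cart at rest
println(swingup_loss(theta, dx), " ", swingup_loss_log(theta, dx))
```

Since the toy trajectory ends upright and at rest, both versions reduce to the average-angle term, (π² + π²/4)/3; they differ only when the final angle is nonzero, where the log penalty grows more slowly.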
Finally, we train:
i = 0 # training epoch counter
data = 0 # time series of state vector and control signal

# callback function after each training epoch
callback = function (p, l, pred; doplot = true)
    global i += 1
    global data = format(pred)
    x, dx, theta, dtheta, force = data

    # output every few epochs
    if i % 50 == 0
        println(l)
        display(plot(tsteps, theta))
        display(plot(tsteps, x))
        display(plot(tsteps, force))
    end
    return false # returning true would stop training early
end

result = DiffEqFlux.sciml_train(
    loss_neuralode,
    pinit,
    ADAM(0.05),
    cb = callback,
    maxiters = 1000,
)
p = result.minimizer

# save model and data
open(io -> write(io, json(p)), "model.json", "w")
open(io -> write(io, json(data)), "data.json", "w")
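Conceptually, each iteration of a call like `sciml_train(loss_neuralode, pinit, ADAM(0.05), cb = callback, maxiters = 1000)` evaluates the loss and its gradient, applies an Adam update, and calls the callback. The sketch below spells out that loop in plain Julia. It is not Flux's implementation: a toy quadratic loss with an analytic gradient stands in for the neural ODE loss, the helper names are hypothetical, and β1 = 0.9, β2 = 0.999, ε = 1e-8 are the standard Adam defaults.

```julia
# Sketch of an Adam training loop with a callback that can stop training.

function adam_train(loss_and_grad, p0; eta = 0.05, beta1 = 0.9, beta2 = 0.999,
                    epsilon = 1e-8, maxiters = 1000, callback = (p, l) -> false)
    p = copy(p0)
    m = zero(p)   # running mean of gradients (first moment)
    v = zero(p)   # running mean of squared gradients (second moment)
    for t in 1:maxiters
        l, grad = loss_and_grad(p)
        m .= beta1 .* m .+ (1 - beta1) .* grad
        v .= beta2 .* v .+ (1 - beta2) .* grad .^ 2
        mhat = m ./ (1 - beta1^t)   # bias correction for the zero initialization
        vhat = v ./ (1 - beta2^t)
        p .-= eta .* mhat ./ (sqrt.(vhat) .+ epsilon)
        callback(p, l) && break     # a callback returning true stops training early
    end
    return p
end

target = [1.0, -2.0, 3.0]
quad(p) = (sum((p .- target) .^ 2), 2 .* (p .- target))   # loss and its gradient

iter = Ref(0)
cb = function (p, l)
    iter[] += 1
    iter[] % 200 == 0 && println("iter ", iter[], ": loss = ", l)
    return false
end

p_opt = adam_train(quad, zeros(3); callback = cb)
println(p_opt)
```

Adam's first steps move each parameter by roughly the learning rate regardless of gradient scale, which is why a fairly large rate like 0.05 is workable for network weights.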
And animate :)
gr()
x, dx, theta, dtheta, force = data
anim = Animation()
plt = plot(tsteps, [modpi.(theta .+ 0.01), x, force], title = ["Angle" "Position" "Force"], layout = (3, 1))
display(plt)
savefig(plt, "cartpole_data.png")

for (xc, th) in zip(x, theta)
    # cart as a horizontal segment; pole drawn 10x longer for visibility
    cart = [xc - 1 xc + 1; 0 0]
    pole = [xc xc + 10 * sin(th); 0 10 * cos(th)]
    fig = plot(
        cart[1, :],
        cart[2, :],
        xlim = (-10, 10),
        ylim = (-10, 10),
        title = "Cartpole",
        linewidth = 3,
    )
    plot!(fig, pole[1, :], pole[2, :], linewidth = 6)
    frame(anim)
end
gif(anim, "cartpole_animation.gif", fps = 10)
Results
Training to a decent solution takes less than a minute! The controller (agent) neural network first accelerates the cart rapidly to swing the pole. It lets the pole swing past horizontal on its angular momentum and then accelerates the cart in the opposite direction to continue the pole’s ascent. Finally, it makes a small correction to come to a standstill at the upright position. Perfecto! We didn’t tell it how to move. It learned on its own!
Beauty of the method lies in its versatility. If the objectives and constraints change, we can change the loss function accordingly. If there’s friction and we wish to minimize the energy loss, we simply add a friction term to the model and tag on a frictional power integral in the loss function.
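As a sketch of the friction example, assume (hypothetically) viscous cart friction F_f = −b·ẋ. Its dissipated power is b·ẋ², and the energy lost is the time integral of that power, which can be approximated from the saved samples with the trapezoidal rule. In the model, the friction force would be subtracted from the applied force inside the dynamics (e.g. `force - b * dx` in `cartpole`); the loss would then gain a weighted `friction_energy` term. The helper name and the value of b are illustrative assumptions.

```julia
# Frictional energy loss from sampled cart velocity (viscous friction assumed).

function friction_energy(dx, tsteps, b)
    power = b .* dx .^ 2                                     # dissipated power, W
    h = diff(collect(tsteps))                                # sample spacings, s
    return sum(h .* (power[1:end-1] .+ power[2:end]) ./ 2)   # trapezoidal rule, J
end

tsteps = range(0.0, 1.0, length = 50)
dx = fill(2.0, length(tsteps))   # toy cart velocity: constant 2 m/s
b = 0.5                          # hypothetical friction coefficient, N*s/m
println(friction_energy(dx, tsteps, b))
```

For this constant-velocity toy case the result is b·ẋ²·duration = 0.5 · 4 · 1 = 2 J.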
To illustrate, suppose our mechanical engineer says the motor can’t generate the 60 N peak force demanded by the controller. Actuator limits are common in real life: motors, fans, and pumps all have finite capacity. To reduce the force, we add a max force penalty to the loss function, along with a penalty on the average cart displacement. We also increase the time allotted to 10 s.
tspan = (0.0, 10.0)
...
loss = sum(theta .^ 2) / N + 4theta[end]^2 + dx[end]^2 + .1sum(x .^ 2) / N + .001maximum(force.^2)
Guess what happens upon training?
Ingenious! The AI taps into the natural system resonance like a child on a swing. It periodically accelerates and decelerates to couple energy into the pole, which reduces the peak force required by 10x!
Caveats
Neural ODEs can be prone to local minima or underfitting. This isn’t due to the neural network but rather the natural minima of the system. For example, depending on the loss function, the cart can end up not moving. In the language of the calculus of variations, any perturbation to the starting trajectory may in fact increase the loss. In such cases, one needs to adjust the loss function, perhaps penalizing the system for staying still at the beginning. See https://diffeqflux.sciml.ai/dev/examples/local_minima/
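One way to penalize staying still at the beginning is an extra loss term that is large when the cart barely moves over the first few saved samples. The form w / (δ + mean(ẋ²)) below, the helper name, and the values of k, w, and δ are illustrative assumptions rather than the approach in the linked example; the point is only that a do-nothing trajectory becomes more expensive than one that starts moving.

```julia
# Hypothetical "don't stand still" penalty on early cart velocity samples.

function stillness_penalty(dx; k = 10, w = 1.0, delta = 0.1)
    early = dx[1:min(k, length(dx))]                      # first k samples
    return w / (delta + sum(early .^ 2) / length(early))  # large if early motion is small
end

still = zeros(50)                  # cart never moves
moving = [0.5 * i for i in 1:50]   # cart speeds up from the start
println(stillness_penalty(still), " vs ", stillness_penalty(moving))
```

A penalty like this can be phased out once training has escaped the do-nothing minimum, so it does not bias the final solution.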
Discussion
- Credit goes to automatic differentiation, namely Julia’s Zygote. It saves us from having to hand-code the gradient function of (almost) arbitrary code, just as TensorFlow does for the restricted case of neural network chains.
- The neural ODE is a continuous time model, so we never pick a time step for the dynamics ourselves: the adaptive ODE solver chooses its own internal steps and reports the state at whatever time points the loss needs. Our method is not only well suited to continuous time physical systems; it also works on discrete time systems. AD would work the same on a computer game loop (even faster).
- We can also make the system simulation data driven by replacing the physics (or part thereof) with a system dynamics neural network. We train it on observation data, such as sensor time series from a vehicle or industrial process. Then we plug in the controller neural network to compute the optimal control.
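To illustrate differentiating through a discrete-time game loop, the sketch below implements forward-mode AD by hand with dual numbers and pushes it through a toy loop. Zygote, used in the post, is reverse-mode; forward mode is swapped in here only because it fits in a few lines. The toy dynamics (a pushed pendulum stepped with semi-implicit Euler) and all names are hypothetical.

```julia
# Forward-mode AD through a discrete-time loop, using hand-written dual numbers.

struct Dual
    val::Float64   # primal value
    der::Float64   # derivative with respect to the seeded input
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:+(a::Real, b::Dual) = Dual(a + b.val, b.der)
Base.:+(a::Dual, b::Real) = b + a
Base.:-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
Base.:-(a::Dual, b::Real) = Dual(a.val - b, a.der)
Base.:-(a::Real, b::Dual) = Dual(a - b.val, -b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)  # product rule
Base.:*(a::Real, b::Dual) = Dual(a * b.val, a * b.der)
Base.:*(a::Dual, b::Real) = b * a
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)                         # chain rule

# one frame of the toy game: constant push plus a pendulum-like restoring force
function game_step(pos, vel, push, h)
    acc = push - 9.8 * sin(pos)
    vel = vel + h * acc   # semi-implicit Euler: update velocity first
    pos = pos + h * vel
    return pos, vel
end

# run the loop and score the final position against a target
function episode_loss(push; frames = 100, h = 0.01, target = 1.0)
    pos, vel = 0.0, 0.0
    for _ in 1:frames
        pos, vel = game_step(pos, vel, push, h)
    end
    err = pos - target
    return err * err
end

p = 3.0
ad = episode_loss(Dual(p, 1.0)).der                             # seed d(push)/d(push) = 1
fd = (episode_loss(p + 1e-6) - episode_loss(p - 1e-6)) / 2e-6   # central difference
println("AD: ", ad, "  finite difference: ", fd)
```

The same `episode_loss` runs unchanged on plain floats and on dual numbers; that is the property that lets AD tools differentiate ordinary simulation code.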
Acknowledgment
Thanks to the open source Julia community: Chris Rackauckas for DiffEqFlux, Mike Innes for Zygote, and many others.
References & Tutorials
Neural Ordinary Differential Equations. https://arxiv.org/abs/1806.07366
Neural Ordinary Differential Equations with sciml_train. https://diffeqflux.sciml.ai/dev/examples/neural_ode_sciml/
Forecasting the weather with neural ODEs. https://sebastiancallh.github.io/post/neural-ode-weather-forecast/
A Differentiable Programming System to Bridge Machine Learning and Scientific Computing. https://arxiv.org/abs/1907.07587
Note
This is getting turned into a conference submission. I’d also love to help if you have any other questions, ideas, or problems.
Paul Shen
MS Electrical Engineering, BS Math, Stanford University
https://www.linkedin.com/in/paulxshen/