Neural Ordinary Differential Equations and Dynamics Models
This post was written by Aidan Abdulali.
In this post, we explore the deep connection between ordinary differential equations and residual networks, leading to a new deep learning component, the Neural ODE. We explain the math that unlocks the training of this component and illustrate some of the results. From a bird’s eye perspective, one of the exciting parts of the Neural ODE architecture by Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud is the connection to physics. ODEs are often used to describe the time derivatives of a physical system, referred to as the dynamics. Knowing the dynamics allows us to model how an environment changes, as in a physics simulation: we can take any starting condition and model how it will evolve. With Neural ODEs, we don’t write down explicit ODEs to describe the dynamics, but learn them via ML. This approach removes the need to hand-model hard-to-interpret data. Giving up interpretability is a real cost, but we can think of many situations in which it is more important to have a strong model of what will happen in the future than to oversimplify by modeling only the variables we know. Neural ODEs also lend themselves to modeling irregularly sampled time series data. The standard approach to working with this data is to create time buckets, which leads to problems like empty buckets and multiple observations falling into one bucket. The Neural ODE approach avoids these issues, providing a more natural way to apply ML to irregular time series.
To explain and contextualize Neural ODEs, we first look at their progenitor: the residual network. In a vanilla neural network, the transformation of the hidden state through a network is h(t+1) = f(h(t), 𝛳(t)), where f represents the network, h(t) is the hidden state at layer t (a vector), and 𝛳(t) are the weights at layer t (a matrix). The hidden state transformation within a residual network is similar and can be formalized as h(t+1) = h(t) + f(h(t), 𝛳(t)). The difference is that we add the layer’s input to its output.
Why do residual layers help networks achieve higher accuracies and grow deeper? Firstly, skip connections help information flow through the network by sending the hidden state, h(t), along with the transformation by the layer, f(h(t)), to layer t+1, preventing important information from being discarded by f. As each residual block starts out as an identity function with only the skip connection sending information through, depth can be incrementally introduced to the network via training f after other weights in the network have stabilized. If the network achieves a high enough accuracy without salient weights in f, training can terminate without f influencing the output, demonstrating the emergent property of variable layers.
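To make the two update rules concrete, here is a minimal sketch in plain Python. The tanh dense layer standing in for f, and the names `layer`, `vanilla_step`, and `residual_step`, are our own assumptions for illustration. With all weights at zero, the vanilla layer discards its input while the residual layer acts as the identity.

```python
import math

def layer(h, weights, bias):
    """f(h, theta): a dense layer with tanh activation; theta = (weights, bias)."""
    return [
        math.tanh(sum(w * x for w, x in zip(row, h)) + b)
        for row, b in zip(weights, bias)
    ]

def vanilla_step(h, weights, bias):
    # h(t+1) = f(h(t), theta(t))
    return layer(h, weights, bias)

def residual_step(h, weights, bias):
    # h(t+1) = h(t) + f(h(t), theta(t))
    return [x + fx for x, fx in zip(h, layer(h, weights, bias))]

h = [0.5, -1.0]
zero_w = [[0.0, 0.0], [0.0, 0.0]]
zero_b = [0.0, 0.0]

vanilla_out = vanilla_step(h, zero_w, zero_b)    # the input is wiped out
residual_out = residual_step(h, zero_w, zero_b)  # the input passes through unchanged
```

Here `vanilla_out` is all zeros, while `residual_out` equals the input `h`.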
Secondly, residual layers can be stacked, forming very deep networks. Introducing more layers and parameters allows a network to learn more accurate representations of the data. But why can residual layers be stacked deeper than layers in a vanilla neural network? To answer this question, we recall the backpropagation algorithm. To calculate how the loss function depends on the weights in the network, we repeatedly apply the chain rule to our intermediate gradients, multiplying them along the way. These multiplications can lead to vanishing or exploding gradients, which simply means that the gradient approaches 0 or infinity. Gradient descent relies on following the gradient to a good minimum of the loss function. A zero gradient gives no path to follow, and a massive gradient leads to overshooting the minimum and huge instability.
As introduced above, the transformation h(t+1) = h(t) + f(h(t), 𝛳(t)) can represent variable layer depth, meaning a 34 layer ResNet can perform like a 5 layer network or a 30 layer network. Thus ResNets can learn their optimal depth, starting the training process with a few layers and adding more as weights converge, mitigating gradient problems. Thus the concept of a ResNet is more general than a vanilla NN, and the added depth and richness of information flow increase both training robustness and deployment accuracy.
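The arithmetic behind vanishing and exploding gradients can be sketched with a toy calculation. We assume scalar layers that all share the same derivative, which is a deliberate oversimplification, and the helper `chain_product` is hypothetical. For a residual layer the per-layer factor is 1 + f'(h), so it stays near 1 while f is still close to zero.

```python
def chain_product(factor, depth):
    """Multiply `depth` identical per-layer derivative factors, as the chain rule does."""
    grad = 1.0
    for _ in range(depth):
        grad *= factor
    return grad

depth = 50
vanishing = chain_product(0.5, depth)        # each layer shrinks the gradient
exploding = chain_product(1.5, depth)        # each layer amplifies the gradient
residual = chain_product(1.0 + 0.01, depth)  # residual layer: factor 1 + f'(h), f'(h) small
```

After 50 layers the first product is vanishingly small and the second is enormous, while the residual product stays close to 1.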
However, ResNets still employ many layers of weights and biases requiring much time and data to train. On top of this, the backpropagation algorithm on such a deep network incurs a high memory cost to store intermediate values. ResNets are thus frustrating to train on moderate machines.
Differential Equations & Euler’s Method
The rich connection between ResNets and ODEs is best demonstrated by the equation h(t+1) = h(t) + f(h(t), 𝛳(t)). As stated above, this relationship represents the transformation of the hidden state during a single residual block, but as it is recursive, we can expand it into the sequence below, in which i is the input:
h(1) = i + f(i, 𝛳(0))
h(2) = h(1) + f(h(1), 𝛳(1))
...
h(t+1) = h(t) + f(h(t), 𝛳(t))
To connect the above relationship to ODEs, let’s refresh ourselves on differential equations. They relate an unknown function y to its derivatives. The solution to such an equation is a function which satisfies the relationship. Let’s look at a simple example, where k is a constant:
y’(t) = k·y(t)
This equation states “the first derivative of y is a constant multiple of y,” and the solutions are simply any functions that obey this property! For this example, functions of the form
y(t) = A·e^(kt)
obey this relationship. To solve for the constant A, we need an initial value for y. Let’s say y(0) = 15. Substituting t = 0 gives y(0) = A·e^0 = A, so A = 15.
This sort of problem, consisting of a differential equation and an initial value, is called an initial value problem.
Oftentimes, differential equations are large, relate multiple derivatives, and are practically impossible to solve analytically, as we did in the previous paragraph. Thankfully, for most applications analytic solutions are unnecessary: we need the value of the function y at some time t, but not necessarily an expression for the function itself. Even more conveniently, an initial value problem gives us a starting value y(t(0)), so we can calculate y’(t(0)) at the starting point from the differential equation.
Starting at the initial value of y, we travel along the tangent line to y (with slope given by the ODE) for a small horizontal distance in t, denoted s (the step size). Our value for y at t(0)+s is
y(t(0)+s) ≈ y(t(0)) + s·y’(t(0))
We can repeat this process until we reach the desired time value for our evaluation of y. Writing t(n) = t(0) + n·s and the ODE as y’ = f(y, t), the recursive process is:
y(t(n+1)) = y(t(n)) + s·f(y(t(n)), t(n))
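The stepping procedure just described can be sketched in a few lines of Python. We assume the example ODE y’ = k·y with y(0) = 15, and the value k = 0.5 is a hypothetical choice made only so that there is something to run. The exact solution 15·e^(kt) lets us check how the error shrinks as the step size does.

```python
import math

def euler(f, y0, t0, t1, steps):
    """Approximate y(t1) for y' = f(y, t), y(t0) = y0 with fixed-step Euler."""
    s = (t1 - t0) / steps      # step size
    y, t = y0, t0
    for _ in range(steps):
        y = y + s * f(y, t)    # follow the tangent line for a distance s
        t = t + s
    return y

k = 0.5                        # hypothetical constant for the example ODE

def f(y, t):
    return k * y

exact = 15.0 * math.exp(k * 1.0)
coarse = euler(f, 15.0, 0.0, 1.0, steps=10)
fine = euler(f, 15.0, 0.0, 1.0, steps=1000)
```

With 10 steps the estimate is visibly off from `exact`, and with 1000 steps it is much closer. The price of that accuracy is more evaluations of f.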
Hmmmm, doesn’t that look familiar! This numerical method for solving a differential equation relies upon the same recursive relationship as a ResNet. Let’s look at how Euler’s method corresponds with a ResNet. In Euler’s method we have the ODE relationship y’ = f(y, t), stating that the derivative of y is a function of y and time. Next we have a starting point for y, y(0). How does a ResNet correspond? In a ResNet we also have a starting point, the hidden state at time 0, or the input to the network, h(0). Instead of an ODE relationship, there is a series of layer transformations, f(h(t), 𝛳(t)), where t is the depth of the layer. These transformations depend on the specific parameters of the layer, 𝛳(t). Each layer transformation takes in a hidden state h(t) and outputs h(t+1) = h(t) + f(h(t), 𝛳(t)), the hidden state to be passed on to the next layer. This is analogous to Euler’s method with a step size of 1.
Even though the underlying function to be modeled is continuous, the neural network is only defined at natural numbers t, each corresponding to a layer in the network. In the figure below, this is made clear on the left by the jagged connections modeling an underlying function: the function is defined only at the black evaluation points (layers). On the right, the transformation of the hidden state is smooth and may be evaluated at any point along the trajectory.
Differential equations are defined over a continuous space and do not make the same discretization as a neural network, so we modify our network structure to capture this difference to create an ODENet.
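As a rough sketch of the two structures, the Python below contrasts a ResNet forward pass with an ODENet forward pass solved by Euler’s method. The scalar dynamics `f`, the parameter tuple `theta = (w, b)`, and both function names are assumptions made for illustration, not the paper’s code.

```python
import math

def f(h, t, theta):
    """Hypothetical scalar dynamics; theta = (w, b)."""
    w, b = theta
    return math.tanh(w * h + b)

def resnet_forward(h, thetas):
    # One parameter set per layer; h is only defined at integer depths.
    for t, theta in enumerate(thetas):
        h = h + f(h, t, theta)
    return h

def odenet_forward(h, theta, t0=0.0, t1=1.0, steps=100):
    # One shared parameter set; Euler's method with a step size we choose.
    s = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        h = h + s * f(h, t, theta)
        t += s
    return h

theta = (0.8, -0.1)
# With shared weights and a step size of 1, the two computations coincide.
resnet_out = resnet_forward(0.3, [theta] * 4)
odenet_out = odenet_forward(0.3, theta, t0=0.0, t1=4.0, steps=4)
# A smaller step size evaluates the same dynamics on a finer grid.
odenet_fine = odenet_forward(0.3, theta, t0=0.0, t1=4.0, steps=400)
```

The step count is a solver setting rather than part of the model: `odenet_fine` uses the same two parameters as `odenet_out`.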
The primary difference between these two code blocks is that the ODENet shares its parameters 𝛳 across all layers. Without weights and biases which depend on time, the transformation in the ODENet is defined for all t, giving us a continuous expression for the derivative of the function we are approximating. Another difference is that, because of shared weights, there are fewer parameters in an ODENet than in an ordinary ResNet. For example, a ResNet getting ~0.4% test error on MNIST used 0.6 million parameters while an ODENet with the same accuracy used 0.2 million parameters!
In the ODENet structure, we propagate the hidden state forward in time using Euler’s method on the ODE defined by f(z, t, 𝛳). However, we can expand to other ODE solvers to find better numerical solutions. With over 100 years of research in solving ODEs, there exist adaptive solvers which restrict error below predefined thresholds with intelligent trial and error. These methods modify the step size during execution to account for the size of the derivative. For example, in a t interval on the function where f(z, t, 𝛳) is small or zero, few evaluations are needed as the trajectory of the hidden state is barely changing. But when the derivative f(z, t, 𝛳) is of greater magnitude, it is necessary to have many evaluations within a small window of t to stay within a reasonable error threshold.
There are some interesting interpretations of the number of times d an adaptive solver has to evaluate the derivative. If d is high, it means the ODE learned by our model is very complex and the hidden state is undergoing a cumbersome transformation. Meanwhile if d is low, then the hidden state is changing smoothly without much complexity. Thus, the number of ODE evaluations an adaptive solver needs is correlated to the complexity of the model we are learning. In terms of evaluation time, the greater d is the more time an ODENet takes to run, and therefore the number of evaluations is a proxy for the depth of a network. In adaptive ODE solvers, a user can set the desired accuracy themselves, directly trading off accuracy with evaluation cost, a feature lacking in most architectures.
With adaptive ODE solver packages in most programming languages, solving the initial value problem can be abstracted: we allow a black box ODE solver with an error tolerance to determine the appropriate method and number of evaluation points. The pseudocode is shown on the left.
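To illustrate how an error tolerance drives the number of derivative evaluations, here is a small hand-rolled adaptive solver. It is a sketch, not a production solver: it uses a Heun-Euler embedded pair with a simple step-size rule, and the function name, controller constants, and test problem are all our own assumptions. In practice a library solver would play this role.

```python
import math

def adaptive_solve(f, y0, t0, t1, tol):
    """Integrate y' = f(y, t) from t0 to t1 with an adaptive Heun-Euler pair.

    Returns (approximation of y(t1), number of evaluations of f).
    """
    y, t = y0, t0
    s = (t1 - t0) / 10.0                  # initial guess for the step size
    evals = 0
    while t < t1:
        s = min(s, t1 - t)                # do not step past the end point
        k1 = f(y, t)
        k2 = f(y + s * k1, t + s)
        evals += 2
        err = abs(s * (k2 - k1) / 2.0)    # gap between the Euler and Heun estimates
        if err <= tol:                    # accept the step, advance with Heun
            y = y + s * (k1 + k2) / 2.0
            t = t + s
        # Shrink the step after a large error, grow it after a small one.
        s = s * min(2.0, max(0.2, 0.9 * (tol / (err + 1e-16)) ** 0.5))
    return y, evals

def decay(y, t):
    return -2.0 * y                       # hypothetical test dynamics

exact = math.exp(-4.0)                    # y(2) for y(0) = 1
y_loose, evals_loose = adaptive_solve(decay, 1.0, 0.0, 2.0, tol=1e-2)
y_tight, evals_tight = adaptive_solve(decay, 1.0, 0.0, 2.0, tol=1e-6)
```

Tightening the tolerance buys accuracy at the cost of many more evaluations of f, which is the accuracy versus cost trade-off described above.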
Continuous-depth ODENets are evaluated using black box ODE solvers, but first the parameters of the model must be optimized via gradient descent. To do this, we need to know the gradient of the loss with respect to the parameters, or how the loss function depends on the parameters in the ODENet. In deep learning, backpropagation is the workhorse for finding this gradient, but this algorithm incurs a high memory cost to store the intermediate values of the network. On top of this, the sheer number of chain rule applications produces numerical error. Since an ODENet models a differential equation, these issues can be circumvented using the adjoint method, a sensitivity analysis technique developed for calculating gradients of a loss function with respect to the parameters of the system producing its input. We refer the curious reader to the derivation in the original paper.
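For readers who want a feel for the adjoint idea before reading the derivation, here is a toy sketch on a problem small enough to check by hand. We assume scalar dynamics dz/dt = 𝛳·z and the loss L = 0.5·z(T)², both chosen by us for illustration. The backward pass solves a second ODE for the adjoint a(t) = dL/dz(t), with da/dt = -a·∂f/∂z, and accumulates the gradient with respect to 𝛳 along the way. The fixed-step RK4 helper stands in for a black box solver.

```python
import math

def rk4(f, y, t0, t1, steps=200):
    """Fixed-step RK4 for a list-valued state; works forward or backward in time."""
    s = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = f(y, t)
        k2 = f([yi + 0.5 * s * k for yi, k in zip(y, k1)], t + 0.5 * s)
        k3 = f([yi + 0.5 * s * k for yi, k in zip(y, k2)], t + 0.5 * s)
        k4 = f([yi + s * k for yi, k in zip(y, k3)], t + s)
        y = [yi + s * (p + 2 * q + 2 * r + w) / 6.0
             for yi, p, q, r, w in zip(y, k1, k2, k3, k4)]
        t += s
    return y                      # only the final state is kept

theta, z0, T = 0.7, 1.5, 1.0      # hypothetical parameter, input, and end time

# Forward pass: dz/dt = theta * z, with loss L = 0.5 * z(T)^2.
zT = rk4(lambda y, t: [theta * y[0]], [z0], 0.0, T)[0]

def augmented(y, t):
    z, a, g = y
    return [theta * z,            # dz/dt: re-solve the state backward in time
            -a * theta,           # da/dt = -a * df/dz
            -a * z]               # dg/dt = -a * df/dtheta

# Backward pass from t = T to t = 0, starting at a(T) = dL/dz(T) = z(T).
_, _, grad_adjoint = rk4(augmented, [zT, zT, 0.0], T, 0.0)

# Closed form for this toy problem: dL/dtheta = T * z(T)^2.
grad_exact = T * (z0 * math.exp(theta * T)) ** 2
```

The backward solve never stores the forward trajectory, since it recomputes z(t) alongside the adjoint. That is where the memory savings come from.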
Neural ODEs for Supervised Learning
In the Neural ODE paper, the first example of the method functioning is on the MNIST dataset, one of the most common benchmarks for supervised learning. It contains ten classes of numerals, one for each digit as shown below.
The task is to try to classify a given digit into one of the ten classes. To achieve this, the researchers used a residual network with a few downsampling layers, 6 residual blocks, and a final fully connected layer as a baseline. For the Neural ODE model, they use the same basic setup but replace the six residual layers with an ODE block, trained using the mathematics described in the above section. They also ran a test using the same Neural ODE setup but trained the network by directly backpropagating through the operations in the ODE solver. Along with these modern results they pulled an old classification technique from a paper by Yann LeCun called 1-Layer MLP. The results are very exciting:
Disregarding the dated 1-Layer MLP, the test errors for the remaining three methods are quite similar, hovering between 0.4 and 0.5 percent. The big difference to notice is the number of parameters used by the ODE-based methods, RK-Net and ODE-Net, versus the ResNet. The ResNet uses roughly three times as many parameters yet achieves similar accuracy! This tells us that the ODE-based methods are much more parameter efficient, taking less effort to train and execute yet achieving similar results. The next major difference is between the RK-Net and the ODE-Net. The RK-Net, which backpropagates through the solver’s operations as in standard neural network training, uses memory proportional to L, the number of operations in the ODESolver. This scales quickly with the complexity of the model. However, the ODE-Net, using the adjoint method, does away with such limiting memory costs and takes constant memory! This is exciting because the lower parameter count and constant memory greatly widen the range of compute settings in which this method can be trained compared to other ML techniques. For mobile applications, there is potential to create smaller, accurate networks using the Neural ODE architecture that can run on a smartphone or other space- and compute-restricted devices.
Limitations of Neural ODEs
Above, we demonstrate the power of Neural ODEs for modeling physics in simulation. The results are unsurprising because the language of physics is differential equations. The connection stems from the fact that the world is characterized by smooth transformations working on a plethora of initial conditions, like the continuous transformation of an initial value in a differential equation. Below, we see a graph of the object an ODE represents, a vector field, and the corresponding smoothness in the trajectory of points, or hidden states in the case of Neural ODEs, moving through it:
But what if the map we are trying to model cannot be described by a vector field, i.e. our data does not represent a continuous transformation? In the paper Augmented Neural ODEs out of Oxford, headed by Emilien Dupont, a few examples of intractable data for Neural ODEs are given. Let’s use one of their examples. Let A_1 be a function such that A_1(1) = -1 and A_1(-1) = 1.
Above is a graph which shows the ideal mapping a Neural ODE would learn for A_1, and below is a graph which shows the actual mapping it learns. Both graphs plot time on the x axis and the value of the hidden state on the y axis.
Hmmmm, what is going on here? The trajectories of the hidden states must overlap to reach the correct solution. However, with a Neural ODE this is impossible! ODE trajectories cannot cross each other because ODEs model vector fields. If the paths were to successfully cross, there would have to be two different vectors at one point to send the trajectories in opposing directions! The smooth transformation of the hidden state mandated by Neural ODEs limits the types of functions they can model. Since ResNets also roughly model vector fields, why can they achieve the correct solution for A_1? Below is a graph of the ResNet solution (dotted lines), the underlying vector field arrows (grey arrows), and the trajectory of a continuous transformation (solid curves).
Because ResNets are not continuous transformations, they can jump around the vector field, allowing trajectories to cross each other. But with the continuous transformation, the trajectories cannot cross, as shown by the solid curves on the vector field. Thus Neural ODEs cannot model the simple 1-D function A_1. In fact, any data that is not linearly separable within its own space breaks the architecture. One example is the annulus distribution below, which we will call A_2.
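A small numerical sketch makes the contrast tangible. We pick one hypothetical vector field, f(h) = -2h, for which a single residual step with step size 1 solves A_1 exactly, and then follow the same field continuously with small Euler steps. This shows one field rather than proving the general claim.

```python
def f(h):
    """Hypothetical 1-D vector field chosen so that one residual step flips the sign."""
    return -2.0 * h

def resnet_step(h):
    # Discrete jump: h + f(h) = -h, so 1 -> -1 and -1 -> 1.
    return h + f(h)

def ode_flow(h, t1=1.0, steps=1000):
    # Approximate the continuous flow of dh/dt = f(h) with small Euler steps.
    s = t1 / steps
    for _ in range(steps):
        h = h + s * f(h)
    return h

resnet_out = (resnet_step(1.0), resnet_step(-1.0))
ode_out = (ode_flow(1.0), ode_flow(-1.0))
```

The residual step swaps the two points, while the continuous flow only pulls them toward zero: the point that starts above stays above, so the trajectories never cross.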
In this data distribution, everything radially between the origin and r_1 is one class and everything radially between r_2 and r_3 is another class. The issue with this data is that the two classes are not linearly separable in 2D space. Since a Neural ODE is a continuous transformation which cannot lift data into a higher dimension, it will try to smush around the input data to a point where it is mostly separated. However, this brute force approach often leads to the network learning overly complicated transformations as we see below.
On the left, the plateauing error of the Neural ODE demonstrates its inability to learn the function A_1, while the ResNet quickly converges to a near optimal solution. On the right, a similar situation is observed for A_2. Peering more into the map learned for A_2, below we see the complex squishification of data sampled from the annulus distribution.
Fixing The Problem
The issue pinpointed in the last section is that Neural ODEs model continuous transformations by vector fields, making them unable to handle data that is not easily separated in the dimension of the hidden state. One solution is to increase the dimensionality of the data, a technique standard neural nets often employ. The way to encode this into the Neural ODE architecture is to increase the dimensionality of the space the ODE is solved in. If our hidden state is a vector in ℝ^n, we can add on d extra dimensions and solve the ODE in ℝ^(n+d). The augmented ODE is shown below, where a(t) in ℝ^d holds the extra dimensions and [h; a] denotes concatenation:
d/dt [h(t); a(t)] = f([h(t); a(t)], t), with [h(0); a(0)] = [x; 0]
We are concatenating a vector of 0s to the end of each datapoint x, allowing the network to learn some nontrivial values for the extra dimensions. The data can hopefully be easily massaged into a linearly separable form with the extra freedom, and we can ignore the extra dimensions when using the network.
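To see why the extra room helps, consider A_1 again with one added dimension. The sketch below pads each 1-D input with a zero and follows a hand-picked rotation field in ℝ², which carries 1 to -1 and -1 to 1 along two half circles that never meet. The field is our own illustrative choice (a trained model would learn its own), and `augment` and `rk4` are hypothetical helpers.

```python
import math

def augment(x, extra_dims=1):
    """Append zeros to a data point, giving the augmented initial condition."""
    return list(x) + [0.0] * extra_dims

def rotation_field(z, t):
    # Hypothetical dynamics in R^2: rotate by half a turn per unit of time.
    return [-math.pi * z[1], math.pi * z[0]]

def rk4(f, y, t0, t1, steps=200):
    """Fixed-step RK4 for a list-valued state."""
    s = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = f(y, t)
        k2 = f([yi + 0.5 * s * k for yi, k in zip(y, k1)], t + 0.5 * s)
        k3 = f([yi + 0.5 * s * k for yi, k in zip(y, k2)], t + 0.5 * s)
        k4 = f([yi + s * k for yi, k in zip(y, k3)], t + s)
        y = [yi + s * (p + 2 * q + 2 * r + w) / 6.0
             for yi, p, q, r, w in zip(y, k1, k2, k3, k4)]
        t += s
    return y

# Solve in R^2, then read off the original coordinate and ignore the extra one.
out_pos = rk4(rotation_field, augment([1.0]), 0.0, 1.0)[0]
out_neg = rk4(rotation_field, augment([-1.0]), 0.0, 1.0)[0]
```

Here `out_pos` lands near -1 and `out_neg` near 1, a mapping no continuous flow confined to the line can produce.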
Below is a graphic comparing the number of calls to ODESolve for an Augmented Neural ODE in comparison to a Neural ODE for A_2.
Instead of learning a complicated map in ℝ², the augmented Neural ODE learns a simple map in ℝ³, shown by the near steady number of calls to ODESolve during training. The researchers also found in this experiment that validation error went to ~0 while error remained high for vanilla Neural ODEs. The graphic below shows A_2 initialized randomly with a single extra dimension, and on the right is the basic transformation learned by the augmented Neural ODE.
One criticism of this tweak is that it introduces more parameters, which should in theory increase the ability of the model by default. However, the researchers experimented with a fixed number of parameters for both models, showing that the benefits of ANODEs come from the freedom of higher dimensions. Another criticism is that adding dimensions reduces the interpretability and elegance of the Neural ODE architecture. The appeal of Neural ODEs stems from the smooth transformation of the hidden state within the confines of an experiment, like a physics model. In this case, extra dimensions may be unnecessary and may push a model away from physical interpretability. Thus augmenting the hidden state is not always the best idea. Furthermore, the above examples from the Augmented Neural ODEs paper are adversarial for an ODE-based architecture. Practically, Neural ODEs are unnecessary for such problems and should be used in areas where a smooth transformation improves interpretability and results, potentially areas like physics and irregular time series data.
Conclusions and Future Work
Neural ODEs present a new architecture with much potential for reducing parameter and memory costs, improving the processing of irregular time series data, and for improving physics models. The architecture relies on some cool mathematics to train and overall is a stunning contribution to the ML landscape. In the near future, this post will be updated to include results from some physical modeling tasks in simulation.
 Neural Ordinary Differential Equations, Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. https://arxiv.org/abs/1806.07366
 Augmented Neural ODEs, Emilien Dupont, Arnaud Doucet, Yee Whye Teh. https://arxiv.org/abs/1904.01681