This post continues from my previous blog post on probabilistic programming languages (PPLs), where I introduced PPLs in the context of Bayesian inference: the task of learning underlying model parameters from data. We looked at a simple example in which a probabilistic model describes how a cloud moves according to an underlying hidden variable that we can’t directly observe: the wind speed. We refer to these hidden variables as “latent variables”.
We started with a simple model based on distance = speed × time. We are interested in more complex models that describe physical systems we might want to forecast, such as the weather. These systems often involve solving differential equations to update the variables of interest. In this post, I’ll present an example of how probabilistic programming can be used for Bayesian inference on a differential equation model. We will see how inference can be useful for learning the latent variables that fit the model best, even if the model isn’t perfect. I’ll be using Pyro, a probabilistic programming language built on top of PyTorch in Python.
1D Linear Advection
A simple place to start is the 1D linear advection equation for a travelling wave. Let’s assume we know the height of the wave at initialisation, at time t = 0. Then, we will observe it after some time, t. During this time, it follows the 1D linear advection equation, which we know and have a model for (albeit an imperfect one):

∂h/∂t + c ∂h/∂x = 0

where h(x, t) is the wave height and c is the constant speed.
We want to infer the speed of the wave, a latent variable in our model. Then we can use this to predict the shape and position of the wave at some future time, t, according to our model.
We will start with a grid of width L=10 m, with nx=200 grid points. The initial wave height is a cosine bell shape between 0 and 4 m.
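As a sketch, the grid and initial condition might look like the following (the exact bell formula here is my assumption — the post only specifies a cosine bell between 0 and 4 m):

```python
import torch

# Grid from the post: width L = 10 m with nx = 200 points.
L, nx = 10.0, 200
dx = L / nx                 # grid spacing, 0.05 m
x = torch.arange(nx) * dx

# Assumed cosine-bell initial condition: a smooth bump between
# x = 0 and x = 4 m, zero elsewhere.
h0 = torch.where(x < 4.0,
                 0.5 * (1 - torch.cos(2 * torch.pi * x / 4.0)),
                 torch.zeros_like(x))
```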
Next, let’s set up the numerical model. The spatial grid is indexed by i and separated by Δx, and the timesteps are indexed by t and separated by Δt. We can discretise the above equation using the forward-in-time, backward-in-space (FTBS) finite difference scheme (more info on that here):

h(i, t+1) = h(i, t) − c (Δt/Δx) × (h(i, t) − h(i−1, t))
Note that this discretisation doesn’t give us the exact solution. The scheme is stable but we will see it produces unwanted numerical diffusion. This discretised 1D advection equation is implemented in the code below:
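A minimal sketch of the FTBS update (the periodic boundary via torch.roll is my assumption, as the post doesn’t specify the boundary treatment):

```python
import torch

def advect(h, c, dt, dx, nt):
    """Step the 1D advection equation forward nt times using the
    forward-in-time, backward-in-space (FTBS) scheme:
        h(i, t+1) = h(i, t) - c * dt/dx * (h(i, t) - h(i-1, t))
    torch.roll(h, 1) supplies h(i-1, t) with a periodic boundary."""
    for _ in range(nt):
        h = h - c * (dt / dx) * (h - torch.roll(h, 1))
    return h
```

The scheme is stable for Courant numbers c·Δt/Δx between 0 and 1, but its first-order truncation error shows up as exactly the numerical diffusion described above: the wave amplitude slowly decays as it propagates.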
The following plot highlights how the 3 waves with different speeds propagate over nt timesteps of width Δt.
Notice how the amplitude of the waves decays slightly due to the numerical errors in the model. This isn’t ideal, but it is relevant — in reality, many of our models of real-world systems are imperfect. Can we still use probabilistic programming to infer the latent variables of our imperfect model?
The Probabilistic Programming Approach
We put this into the complete Pyro model below. There are 4 steps:
- Set up the initial conditions for the wave height.
- Sample the latent variable, the wave speed, from our prior distribution.
- Run a deterministic finite difference scheme for the 1D advection equation to step forward in time.
- Sample the observed wave height, including the observation error. The observation error at each point is independent.
This model is more complicated than the simple example in the previous blog post, because we have many iterations of a time-stepping scheme embedded within it. This reflects the type of models we are typically doing inference on.
You might also notice we have an additional loop because we will observe the wave height at multiple points, with independent measurement error on each observation. The pyro.plate("x_axis", nx) statement tells Pyro that these observations are conditionally independent (i.e. independent from each other, but they all depend on the speed, which is sampled outside of this loop).
I’ve also assumed that we don’t know much about the wave speed, so I’ve provided a vague prior — the speed is sampled uniformly between a minimum and maximum value.
The next step is to condition on observations. In this case, we will assume we observe the wave at only some of the points on the grid. We can generate data that corresponds to the exact solution for a wave using

h(x, t) = h₀(x − ct)

where h₀ is the initial condition of the wave. For example, below I’ve generated some data using the analytical equation for a wave travelling at 1.65 m/s (blue). The solution from the discretised model is also shown (orange).
We will assume that we see the exact solution, but we only observe the wave at a few points. Also, our observations have a small measurement error (the black crosses). Hopefully, we can use the discretised model to infer the wave speed from this data.
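Generating this synthetic data might look like the following (the bell formula, the observation spacing and time, and the noise level are my assumptions; the 1.65 m/s speed is from the post):

```python
import torch

torch.manual_seed(0)  # reproducible measurement noise

# Grid from the post: L = 10 m, nx = 200 points.
nx, dx = 200, 0.05
x = torch.arange(nx) * dx

def bell(x):
    # Assumed cosine-bell initial condition between 0 and 4 m.
    return torch.where(x < 4.0,
                       0.5 * (1 - torch.cos(2 * torch.pi * x / 4.0)),
                       torch.zeros_like(x))

def exact(x, c, t):
    # Analytical solution of the advection equation: the initial
    # profile translated by c*t (wrapped periodically onto the grid).
    return bell((x - c * t) % 10.0)

true_speed, t_obs = 1.65, 2.0        # observation time is assumed
obs_idx = torch.arange(0, nx, 20)    # observe only every 20th point
h_obs = exact(x[obs_idx], true_speed, t_obs) \
        + 0.02 * torch.randn(len(obs_idx))  # small measurement error
```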
As with the previous blog post, we will condition on observations with the pyro.condition statement, which fixes the height of the wave at the observed points.
This statement only fixes the wave height at the observed points. If we run this function, we see that the speed is still sampled from the prior distribution and the wave height is calculated based on it. However, Pyro then effectively overwrites the wave height at the observed points with the data. We can see this in the example plot below, where the model has sampled a speed that happens to be too fast and predicts the wave further ahead than the observations. This isn’t particularly useful until we carry out the inference step.
We want to find a function that samples the entire wave based on the data. Following the same method as before, we will infer the latent variable, the speed, using Stochastic Variational Inference (SVI). The following piece of code defines the “guide”, which is an approximation to the posterior distribution for the speed.
In the above code, the posterior distribution for the speed is approximated by a Normal distribution with mean a and standard deviation b. We then use SVI to infer these values by minimising the loss function, which is based on the Evidence Lower Bound (ELBO).
The gif above shows the predicted wave for a few samples during the SVI process. Initially, there is a large variance in the speed, and therefore in the predicted wave position. As the number of iterations increases, the distribution for the speed is refined towards the value that produced the observations. We can see the predicted wave move towards the observed points as the variance between samples is reduced.
Once it looks like the distribution for the speed has converged, we can use it to do useful things, like sampling from the distribution to make predictions about the position of the wave at different time points, in just two lines of code, for example:
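For example, with learned posterior parameters a and b (the values below stand in for SVI’s output, and the grid, bell formula and prediction time are assumptions as before), the two lines are a posterior draw followed by a model run:

```python
import torch

torch.manual_seed(0)

# Stand-in values for the learned posterior parameters; in practice
# these would come from pyro.param after SVI.
a, b = torch.tensor(1.64), torch.tensor(0.03)

nx, dx, dt = 200, 0.05, 0.01
x = torch.arange(nx) * dx
h0 = torch.where(x < 4.0,
                 0.5 * (1 - torch.cos(2 * torch.pi * x / 4.0)),
                 torch.zeros_like(x))

def advect(h, c, nt):
    # FTBS time stepping, as in the earlier snippets.
    for _ in range(nt):
        h = h - c * (dt / dx) * (h - torch.roll(h, 1))
    return h

# The two lines: draw a speed from the approximate posterior, then
# run the model forward to the time of interest (nt = 400 assumed).
speed = torch.distributions.Normal(a, b).sample()
h_future = advect(h0, speed, 400)
```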
This inference method learns a speed of 1.64 ± 0.03 m/s, which is close to the true value used to generate the data (1.65 m/s). We have inferred a latent variable of the model, even though the model is not perfect.
It is important to remember that we’ve learned the latent variable that best fits the model to the given observations. This might not be the true value of the speed, and it might not even produce good predictions when the model is used in different situations, e.g. with a different discretisation or a different wave shape. It all depends on how good our model is. Inference just helps us get as close as possible to the real-world observations with this model.
The great thing about probabilistic programming languages like Pyro is that once you know how to do inference, it doesn’t take long to set it up for a different model. The SVI algorithm is almost identical to the one used in the previous blog post. The main user input is the model, with a sensible choice of prior distribution, and the guide. We should then be able to use the same approach with real data to infer latent variables in more complex physical models, bringing the output of these models as close as possible to real-world observations.
The notebook version of this blog is here: https://github.com/informatics-lab/probabilistic-programming
There are some Computational Fluid Dynamics notebooks with further examples of differential equations in fluid dynamics here: https://lorenabarba.com/blog/cfd-python-12-steps-to-navier-stokes/
The Pyro documentation has lots of useful examples here: https://pyro.ai/examples/intro_part_i.html