An Introduction to Laplace Approximations for Bayesian Deep Learning in Julia

Paving the Way Towards Low-Overhead Uncertainty Calibration

Severin Bratus
Jun 24, 2023
A nice image to attract your attention. The exact inverse Fisher information matrix for an MNIST classifier network (left), its block-diagonal and block-tridiagonal approximations (middle), and the absolute error (right). Source: Martens & Grosse (2015) [10], Figure 6.

This post summarizes a quarter-long second-year BSc coursework project at TU Delft. Our team of five students has made multiple improvements to LaplaceRedux.jl, a Julia package. Originally inspired by its Pythonic counterpart, laplace-torch, this package aims to provide low-overhead Bayesian uncertainty calibration for deep neural networks via Laplace Approximations. We will begin the article by demystifying the technical terms in the last sentence, then explain our contributions to the library, and finally highlight some impressions from the experience.

Bayesian Learning

Uncertainty calibration remains a crucial issue in safety-critical applications of modern AI, such as autonomous driving. You would want your car’s autopilot not only to make accurate predictions, but also to indicate when a prediction is uncertain, so that it can hand control back to the human driver.

A model is well-calibrated if the confidence of its predictions matches their actual accuracy: among predictions made with 80% confidence, about 80% should be correct. Note that you can have well-fit models that are badly calibrated, and vice versa (just like in life, you meet people who are smart, yet annoyingly arrogant).
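To make the definition concrete, here is a minimal sketch of a binned reliability check in plain Python (independent of LaplaceRedux.jl). The function name `binned_calibration` and the toy confidences and labels are hypothetical. Predictions are grouped by confidence, and within each bin the average confidence is compared with the observed accuracy; the sample-weighted average gap is commonly called the expected calibration error (ECE).

```python
def binned_calibration(confidences, correct, n_bins=5):
    """Return (ECE, per-bin stats) for predicted-class confidences in [0, 1].

    confidences: probability the model assigned to its predicted class
    correct:     1 if that prediction was right, else 0
    """
    bins = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        # Equal-width bins over [0, 1]; a confidence of exactly 1.0 goes in the last bin.
        idx = min(int(c * n_bins), n_bins - 1)
        bins[idx].append((c, ok))

    n = len(confidences)
    ece = 0.0
    stats = []
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(ok for _, ok in members) / len(members)
        stats.append((avg_conf, accuracy, len(members)))
        # Confidence/accuracy gap, weighted by the fraction of samples in this bin.
        ece += len(members) / n * abs(avg_conf - accuracy)
    return ece, stats


# Over-confident toy model: claims 95% confidence but is right only 6 times out of 10.
overconfident = binned_calibration([0.95] * 10, [1] * 6 + [0] * 4)
# Calibrated toy model: at 80% confidence it is right 8 times out of 10.
calibrated = binned_calibration([0.8] * 10, [1] * 8 + [0] * 2)
print(round(overconfident[0], 2), round(calibrated[0], 2))  # 0.35 0.0
```

An ECE near zero means the stated confidences can be taken at face value; the first toy model would need its confidences lowered to about 60% to be calibrated.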

The standard deep learning training process, gradient descent, converges to a weight configuration that (locally) minimizes the loss function. The resulting model may be great, yet it is only a point estimate of what the weight parameters should look like.

However, given the sheer size of the weight space, neural networks are typically underspecified by the data, which also makes them prone to overfitting. As neural networks can approximate highly complex functions, there are many weight configurations that yield roughly the same training loss, yet generalize differently outside the training dataset. This is why there are so many regularization methods out there to keep models simpler. One radical yet effective approach is described by LeCun, Denker, and Solla in Optimal Brain Damage (1989) [1]:

… it is possible to take a perfectly reasonable network, delete half (or more) of the weights and wind up with a network that works just as well, or better.

The loss landscape. One can imagine gradient descent as a particle, say a ball or a grain of sand, rolling to the bottom of a pit. For Bayesian learning, we instead imagine a pile of sand poured around that bottom point, thicker where the loss is lower. This proverbial sand pile represents the posterior parameter distribution. Figure due to Amini et al. (2018) [8].

Gradient descent is usually illustrated with a picture like the one above: a curved terrain of the loss function across the parameter space, where each point of the horizontal plane corresponds to some configuration of parameters. Gradient descent seeks the point at the bottom of this terrain, the point with the lowest loss. However, since the loss surface is highly non-convex and high-dimensional, there are many directions in which we could move and still maintain a low loss. Thus, instead of a single point, we would like to specify a probability distribution around that optimal point. Bayesian methods, and Laplace Approximations in particular, allow us to do this!

The Bayesian approach to uncertainty calibration in neural networks is to model the posterior distribution over the parameters using Bayes’ theorem:
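In the notation of this section, with D the training data and θ the network weights, the standard statement of the theorem is:

```latex
p(\theta \mid D) = \frac{p(D \mid \theta)\, p(\theta)}{Z},
\qquad
Z = p(D) = \int p(D \mid \theta)\, p(\theta)\, d\theta
```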

Here p(D | θ) is the likelihood of the data D given the parameters θ.
The prior distribution p(θ) specifies our beliefs about the model parameters before observing the data. Finally, the normalizing constant Z, which is generally intractable, is called the evidence: it characterizes the probability of observing D as a whole, across all possible parameter settings [3].

For models returning a probability distribution (e.g. classifiers), the loss is commonly defined as the negative log-likelihood. Thus, by minimizing the loss, gradient descent maximizes the likelihood, producing the maximum likelihood estimate (MLE), which (assuming a uniform prior) also maximizes the posterior. More generally, a regularizer added to the loss plays the role of a prior: an L2 penalty (weight decay), for example, equals the negative log-density of a zero-mean Gaussian prior, up to a constant. Either way, the optimum maximizes the posterior, which is why we call it the maximum a posteriori estimate, or the MAP. It makes sense to treat this point as the mode of the posterior distribution, which we approximate with a Gaussian.
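As a minimal illustration in plain Python (independent of LaplaceRedux.jl), assuming a toy one-parameter logistic model p(y = 1 | x) = σ(θx) and made-up data, the sketch below runs gradient descent on the negative log-likelihood plus an L2 penalty λθ²/2. With λ = 0 (a flat prior) the result is the MLE; with λ = 1 it is the MAP under a zero-mean Gaussian prior of precision 1. The data, step size, and helper names are hypothetical.

```python
import math

# Hypothetical toy data for a one-parameter logistic model p(y=1 | x) = sigmoid(theta * x).
xs = [-2.0, -1.0, 0.5, 1.0, 2.0]
ys = [0, 1, 0, 1, 1]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def grad(theta, lam):
    """Derivative of NLL(theta) + (lam / 2) * theta**2.

    NLL is the Bernoulli negative log-likelihood; the penalty equals
    -log N(theta; 0, 1/lam) up to a constant, i.e. a Gaussian prior.
    """
    return sum((sigmoid(theta * x) - y) * x for x, y in zip(xs, ys)) + lam * theta

def fit(lam, lr=0.05, steps=5000):
    # Plain gradient descent on the (convex) regularized loss.
    theta = 0.0
    for _ in range(steps):
        theta -= lr * grad(theta, lam)
    return theta

theta_mle = fit(lam=0.0)  # flat prior: the maximum likelihood estimate
theta_map = fit(lam=1.0)  # Gaussian prior with precision 1: the MAP estimate
print(theta_mle, theta_map)  # the prior shrinks the estimate towards zero
```

Letting λ approach zero moves the MAP back towards the MLE, which is the sense in which a flat prior makes the two coincide.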

Laplace Approximations

We do this with a simple yet smart trick introduced back in the late 18th century by Pierre-Simon Laplace, the self-proclaimed “greatest French mathematician of his time”. In general, the Laplace Approximation (LA) aims to find a Gaussian approximation to a probability density (in our case, the posterior) defined over a set of continuous variables (in our case, the weights) [5]. To do so, we approximate the loss (the negative log-likelihood, plus the negative log-prior if a regularizer is used) by its second-order Taylor expansion around the MAP:
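Writing the loss as L(θ) = −log p(D | θ) − log p(θ) and H for its Hessian at θ_MAP, the expansion takes the standard form below; the gradient term is written out only to show which term drops out.

```latex
\mathcal{L}(\theta) \approx \mathcal{L}(\theta_{\mathrm{MAP}})
  + (\theta - \theta_{\mathrm{MAP}})^\top \nabla_\theta \mathcal{L}(\theta_{\mathrm{MAP}})
  + \tfrac{1}{2}\, (\theta - \theta_{\mathrm{MAP}})^\top H \, (\theta - \theta_{\mathrm{MAP}}),
\qquad
H = \nabla^2_\theta \mathcal{L}(\theta_{\mathrm{MAP}})
```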

Note that the first-order Taylor term vanishes at the MAP: it contains the gradient, and the gradient is zero at the MAP, which is by definition an optimum (a maximum of the posterior, and hence a minimum of the loss). What remains is the constant (zeroth-order) term and the second-order term containing the Hessian, the matrix of second-order partial derivatives of the loss with respect to the parameters.

Then from this approximation, we can derive the long-sought multivariate normal distribution with the MAP as the mean, and the inverted Hessian as the covariance:
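With L(θ) = −log p(D | θ) − log p(θ) and H = ∇²L(θ_MAP), exponentiating the negative of the quadratic approximation gives an unnormalized Gaussian in θ, so the approximate posterior takes the standard form:

```latex
p(\theta \mid D) \approx \mathcal{N}\big(\theta;\ \theta_{\mathrm{MAP}},\ \Sigma\big),
\qquad
\Sigma = H^{-1} = \big(\nabla^2_\theta \mathcal{L}(\theta_{\mathrm{MAP}})\big)^{-1}
```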

Since a Gaussian can be normalized analytically, the evidence Z now also has a tractable closed-form approximation, so Bayes’ theorem yields the (approximate) posterior distribution p(θ | D).
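For d parameters, the standard Laplace result (see, e.g., Bishop [5]) is Z ≈ exp(−L(θ_MAP)) · (2π)^{d/2} · det(Σ)^{1/2}. The plain-Python sketch below (independent of LaplaceRedux.jl) checks this in one dimension, assuming a toy logistic model with a standard normal prior and made-up data: it finds the MAP with Newton’s method, takes the Hessian, and compares the Laplace estimate of Z with brute-force numerical integration, which is only feasible because d = 1. All names and data are hypothetical.

```python
import math

# Hypothetical toy data for p(y=1 | x) = sigmoid(theta * x), with prior theta ~ N(0, 1).
xs = [-2.0, -1.0, 0.5, 1.0, 2.0]
ys = [0, 1, 0, 1, 1]
PRIOR_VAR = 1.0

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def loss(theta):
    """L(theta) = -log p(D | theta) - log p(theta), including the prior's normalizer."""
    nll = -sum(y * math.log(sigmoid(theta * x)) + (1 - y) * math.log(1 - sigmoid(theta * x))
               for x, y in zip(xs, ys))
    neg_log_prior = 0.5 * theta ** 2 / PRIOR_VAR + 0.5 * math.log(2 * math.pi * PRIOR_VAR)
    return nll + neg_log_prior

def grad(theta):
    return sum((sigmoid(theta * x) - y) * x for x, y in zip(xs, ys)) + theta / PRIOR_VAR

def hessian(theta):
    # Second derivative of the loss: sum_n p_n (1 - p_n) x_n^2 + 1 / prior variance
    return sum(sigmoid(theta * x) * (1 - sigmoid(theta * x)) * x * x for x in xs) + 1 / PRIOR_VAR

# Newton's method for the MAP (the loss is convex here).
theta_map = 0.0
for _ in range(50):
    theta_map -= grad(theta_map) / hessian(theta_map)

H = hessian(theta_map)
posterior_var = 1.0 / H  # Laplace covariance: inverse Hessian at the MAP

# Laplace estimate of Z = integral of exp(-L(theta)) d theta, here with d = 1.
z_laplace = math.exp(-loss(theta_map)) * math.sqrt(2 * math.pi / H)

# Brute-force reference: midpoint rule on [-10, 10] (only feasible in very low dimension).
step = 1e-3
z_numeric = sum(math.exp(-loss(-10 + (i + 0.5) * step)) * step for i in range(20000))

print(theta_map, posterior_var, z_laplace, z_numeric)
```

On this toy problem the two estimates of Z should be close; the remaining gap reflects how far the true posterior is from a Gaussian.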
We can then express the posterior predictive distribution, which gives the probability of an output y for a new input x* by averaging the model’s prediction f(x*) over the posterior of the parameters.
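Writing f_θ(x*) for the network output under weights θ, the posterior predictive is the standard marginal below. For classifiers, p(y | f_θ(x*)) is the softmax (or sigmoid) of the logits, and the integral usually has to be approximated, for instance by sampling θ from the Laplace posterior.

```latex
p(y \mid x^*, D) = \int p\big(y \mid f_\theta(x^*)\big)\, p(\theta \mid D)\, d\theta
```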

This is what we are really after: instead of giving a single point-estimate prediction ŷ = f(x*), we make the neural network give a distribution over y.
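A minimal plain-Python sketch of the difference (independent of LaplaceRedux.jl), assuming a one-parameter logistic model whose Laplace posterior is N(μ, s²) with hypothetical values μ = 0.5 and s² = 0.3. The predictive probability is estimated by Monte Carlo sampling of θ and compared with the MAP plug-in σ(μx*) and with the probit approximation σ(κμx*), κ = (1 + π s² x*² / 8)^{-1/2}, a standard closed-form approximation for this integral (see Bishop [5]).

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical Laplace posterior over a single weight: theta ~ N(mu, var).
mu, var = 0.5, 0.3

def predictive_mc(x_star, n_samples=200_000, seed=0):
    """Monte Carlo estimate of p(y=1 | x*, D) = E_theta[sigmoid(theta * x*)]."""
    rng = random.Random(seed)
    sd = math.sqrt(var)
    return sum(sigmoid(rng.gauss(mu, sd) * x_star) for _ in range(n_samples)) / n_samples

def predictive_probit(x_star):
    """Probit approximation for the Gaussian logit a = theta * x*: sigmoid(kappa * E[a])."""
    mean_a, var_a = mu * x_star, var * x_star ** 2
    kappa = 1.0 / math.sqrt(1.0 + math.pi * var_a / 8.0)
    return sigmoid(kappa * mean_a)

for x_star in (1.0, 10.0):
    plug_in = sigmoid(mu * x_star)  # point estimate: ignores weight uncertainty
    print(x_star, round(plug_in, 3), round(predictive_mc(x_star), 3), round(predictive_probit(x_star), 3))
```

Far from the data (x* = 10) the plug-in prediction is nearly certain, while the predictive distribution pulls the probability back towards 0.5, which is exactly the calibration effect we are after.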

However, the Hessian is a square matrix with one row and one column per model parameter, and upon inversion it defines the covariance between every pair of parameters. With millions or billions of parameters, computing and storing the Hessian (not to speak of inverting it!) becomes intractable, as its size scales quadratically with the number of parameters. Thus, to apply Laplace approximations to large models, we must make some simplifications, which brings us to…
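A back-of-the-envelope sketch of that quadratic growth in plain Python, assuming dense storage with 4 bytes (32-bit floats) per entry; the parameter counts are illustrative and not taken from any particular model:

```python
def full_hessian_bytes(n_params, bytes_per_entry=4):
    """Memory for a dense Hessian: one entry per pair of parameters."""
    return n_params ** 2 * bytes_per_entry

for n_params in (10_000, 1_000_000, 100_000_000):
    terabytes = full_hessian_bytes(n_params) / 1e12
    print(f"{n_params:>11,} parameters -> {terabytes:,.4f} TB")
# One million parameters already need 4e12 bytes, i.e. 4 TB.
```

Storage is only half the problem: inverting a dense n × n matrix takes on the order of n³ operations.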

Hessian approximations

Multiple techniques to approximate the Hessian have arisen from a field adjacent to, yet distinct from, Bayesian learning: second-order optimization, where Hessians are used to accelerate the convergence of gradient descent.

One such approximation is the Fisher information matrix, or simply the Fisher:
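In the notation of this section, summing over the N training inputs xₙ, the Fisher is commonly written as:

```latex
F = \sum_{n=1}^{N} \mathbb{E}_{\hat{y} \sim p(y \mid f(x_n))}
    \Big[ \nabla_\theta \log p\big(\hat{y} \mid f(x_n)\big)\,
          \nabla_\theta \log p\big(\hat{y} \mid f(x_n)\big)^\top \Big]
```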

Note that if, instead of sampling the prediction ŷ ~ p(y | f(xₙ)) from the model-defined distribution, we take the actual training-set label yₙ, the resulting matrix is called the empirical Fisher. It is distinct from the Fisher, aligns with it only under some conditions, and does not generally capture second-order information. See Kunstner et al. (2019) [7] for an excellent discussion of the distinction.
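To see the distinction numerically, here is a plain-Python sketch (independent of LaplaceRedux.jl) for binary logistic regression with a two-dimensional weight vector; the data and weights are hypothetical. For this model the score is (y − p)x, so the Fisher, which averages over ŷ ~ Bernoulli(p), equals Σ p(1 − p)xxᵀ, the exact Hessian of the negative log-likelihood (checked here by finite differences). The empirical Fisher Σ (yₙ − p)² xxᵀ plugs in the observed labels instead.

```python
import math

# Hypothetical toy data: 2-D inputs, binary labels, and an arbitrary weight vector.
X = [(1.0, 0.5), (-0.5, 1.0), (2.0, -1.0), (0.3, 0.3)]
Y = [1, 0, 1, 0]
theta = (0.4, -0.2)

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def prob(th, x):
    return sigmoid(th[0] * x[0] + th[1] * x[1])

def score(th, x, y):
    # Gradient of log p(y | x, theta) for logistic regression: (y - p) * x
    p = prob(th, x)
    return [(y - p) * xi for xi in x]

def outer_sum(terms):
    """Sum of weight * g g^T over (weight, g) pairs, as a 2x2 nested list."""
    m = [[0.0, 0.0], [0.0, 0.0]]
    for w, g in terms:
        for i in range(2):
            for j in range(2):
                m[i][j] += w * g[i] * g[j]
    return m

# Fisher: expectation over y_hat ~ Bernoulli(p) of the score outer product.
fisher = outer_sum(
    (w, score(theta, x, y_hat))
    for x in X
    for y_hat, w in ((1, prob(theta, x)), (0, 1 - prob(theta, x)))
)

# Empirical Fisher: the observed label y_n replaces the sampled y_hat.
emp_fisher = outer_sum((1.0, score(theta, x, y)) for x, y in zip(X, Y))

def nll_grad(th):
    # Gradient of the negative log-likelihood: sum_n (p_n - y_n) x_n
    g = [0.0, 0.0]
    for x, y in zip(X, Y):
        for i, s in enumerate(score(th, x, y)):
            g[i] -= s
    return g

# Exact Hessian of the NLL via central finite differences of the gradient.
eps = 1e-5
hessian = [[0.0, 0.0], [0.0, 0.0]]
for j in range(2):
    up = list(theta)
    up[j] += eps
    down = list(theta)
    down[j] -= eps
    g_up, g_down = nll_grad(up), nll_grad(down)
    for i in range(2):
        hessian[i][j] = (g_up[i] - g_down[i]) / (2 * eps)

print("Fisher:          ", fisher)
print("Hessian (FD):    ", hessian)
print("Empirical Fisher:", emp_fisher)
```

The Fisher matches the Hessian here because the model outputs the natural parameter (the logit) of a Bernoulli distribution; the empirical Fisher has no such guarantee and visibly differs.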

Instead of the Fisher, one can use the Generalized Gauss-Newton (GGN):
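In the notation of this section, with J(xₙ) the Jacobian of the model output with respect to θ, the GGN is commonly written as:

```latex
G = \sum_{n=1}^{N} J(x_n)^\top \Big( -\nabla^2_{f} \log p(y_n \mid f)\Big|_{f = f(x_n)} \Big) J(x_n),
\qquad
J(x_n) = \frac{\partial f(x_n)}{\partial \theta}
```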

Here J(xₙ) is the Jacobian of the model output with respect to the parameters, and the middle factor is the Hessian of the negative log-likelihood of yₙ with respect to the model output. Note that the model does not necessarily output target probabilities directly: classifiers, for instance, output logits, values that define a probability distribution only after the application of the softmax.

Unlike the Fisher, the GGN does not require the network to define a probabilistic model on its output [9]. For models whose output is the natural parameter of an exponential-family distribution, the two coincide [7]. This covers classifiers that output the logits of a categorical distribution, as well as regression with a Gaussian likelihood (i.e. the squared-error loss), but it does not hold for arbitrary losses or output parameterizations.
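A small plain-Python sketch of what the GGN keeps and drops, assuming a hypothetical two-parameter model f(x) = a·tanh(b·x) with a unit-variance Gaussian likelihood, so the loss is ½Σ(f(xₙ) − yₙ)² and the middle factor is 1. The exact Hessian equals JᵀJ plus a term weighted by the residuals f(xₙ) − yₙ; the GGN keeps only JᵀJ, which makes it positive semi-definite by construction. Data, parameters, and helper names are made up.

```python
import math

# Hypothetical toy regression data and parameters for f(x) = a * tanh(b * x).
xs = [-1.5, -0.5, 0.5, 1.5]
ys = [-1.0, 0.2, -0.1, 1.2]
params = (0.8, 1.3)  # (a, b)

def model(p, x):
    a, b = p
    return a * math.tanh(b * x)

def jacobian_row(p, x):
    # Derivatives of the scalar output f(x) w.r.t. (a, b)
    a, b = p
    t = math.tanh(b * x)
    return [t, a * x * (1 - t * t)]

def loss(p):
    # Negative log-likelihood of a unit-variance Gaussian, up to a constant
    return 0.5 * sum((model(p, x) - y) ** 2 for x, y in zip(xs, ys))

# GGN = sum_n J_n^T (Hessian of the loss w.r.t. the output, here 1) J_n
ggn = [[0.0, 0.0], [0.0, 0.0]]
for x in xs:
    j = jacobian_row(params, x)
    for r in range(2):
        for c in range(2):
            ggn[r][c] += j[r] * j[c]

def exact_hessian(p, h=1e-4):
    # Central finite differences of the scalar loss; includes the residual-weighted term.
    hess = [[0.0, 0.0], [0.0, 0.0]]
    for r in range(2):
        for c in range(2):
            def shifted(dr, dc):
                q = list(p)
                q[r] += dr * h
                q[c] += dc * h
                return loss(q)
            hess[r][c] = (shifted(1, 1) - shifted(1, -1) - shifted(-1, 1) + shifted(-1, -1)) / (4 * h * h)
    return hess

hess = exact_hessian(params)
print("GGN:    ", ggn)
print("Hessian:", hess)
```

The (a, a) entries agree because f is linear in a, whereas the off-diagonal entries differ by the residual-weighted second-derivative term that the GGN discards.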

Since these matrices grow quadratically with the number of parameters, it is infeasible to store them in full for large networks.
The simplest approximation is to model the matrix as diagonal; however, one can easily see how crude this is: for 100 parameters, the diagonal holds only 100 of the 10,000 entries, i.e. 1% of the full Hessian.

A more sophisticated approach, Kronecker-factored Approximate Curvature (KFAC) due to Martens and Grosse (2015) [10], is inspired by the observation that in practice the covariance matrices (i.e. inverted Hessians) of neural networks are dominated by their diagonal blocks. Thus we can effectively model the covariance matrix (and hence the Fisher) as a block-diagonal matrix, where blocks correspond to parameters grouped by layers. Additionally, each block is approximated as the Kronecker product of two much smaller factors, which reduces the storage by several orders of magnitude at the cost of another simplifying assumption.
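A minimal plain-Python sketch of why Kronecker factors help, assuming two small symmetric positive-definite 2×2 matrices A and B as stand-ins for a layer’s two factors (one built from the layer’s inputs, one from the gradients with respect to its outputs); the numbers are made up. The 4×4 block A ⊗ B is stored as two 2×2 factors, and its inverse is A⁻¹ ⊗ B⁻¹, so only the small factors ever need inverting.

```python
def kron(A, B):
    """Kronecker product of two square matrices given as nested lists."""
    n, m = len(A), len(B)
    return [[A[i // m][j // m] * B[i % m][j % m] for j in range(n * m)] for i in range(n * m)]

def inv2(M):
    """Inverse of a 2x2 matrix."""
    (a, b), (c, d) = M
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]

# Hypothetical Kronecker factors for one layer's curvature block.
A = [[2.0, 0.5], [0.5, 1.0]]    # stand-in for the input-side factor
B = [[1.5, -0.3], [-0.3, 0.8]]  # stand-in for the output-gradient-side factor

block = kron(A, B)                  # 4x4 block: 16 entries if stored densely
block_inv = kron(inv2(A), inv2(B))  # inverse assembled from the two 2x2 inverses

identity_check = matmul(block, block_inv)
print(identity_check)  # approximately the 4x4 identity matrix
print("dense entries:", len(block) ** 2, "factored entries:", len(A) ** 2 + len(B) ** 2)
```

The savings grow quickly: ignoring biases, a dense layer with n inputs and m outputs has a block of (nm)² entries but factors of only n² + m² entries, so for n = m = 1000 that is 10¹² versus 2·10⁶.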

Lastly, a novel approach is to sketch a low-rank approximation of the Fisher [11]. Below is a figure with the four Hessian approximation structures:

Various Hessian approximation structures. (a) Full Hessian, intractable for large networks. (b) Low-rank. (c) Kronecker-factored Approximate Curvature (KFAC), a block-diagonal method. (d) Diagonal. Source: Daxberger et al. (2021) [2].
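To put the four structures side by side, here is a plain-Python sketch that counts stored entries for a hypothetical 784-100-10 MLP (an MNIST-style classifier), treating each bias as an extra input of constant 1 and assuming rank 20 for the low-rank variant. Real implementations keep extra bookkeeping, so these are rough counts only.

```python
# Hypothetical MLP layer sizes (input, hidden, output), e.g. an MNIST-style classifier.
layer_sizes = [784, 100, 10]

# (n_in, n_out) for each dense layer; the bias counts as an extra input of constant 1.
layers = [(n_in + 1, n_out) for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
d = sum(n_in * n_out for n_in, n_out in layers)  # total number of parameters

def storage(structure, rank=20):
    """Number of stored entries for each Hessian approximation structure."""
    if structure == "full":
        return d * d
    if structure == "diagonal":
        return d
    if structure == "kfac":
        # One pair of Kronecker factors per layer: (n_in x n_in) and (n_out x n_out).
        return sum(n_in ** 2 + n_out ** 2 for n_in, n_out in layers)
    if structure == "low-rank":
        # A d x rank factor plus rank eigenvalues, e.g. U diag(s) U^T.
        return d * rank + rank
    raise ValueError(structure)

for s in ("full", "low-rank", "kfac", "diagonal"):
    print(f"{s:>9}: {storage(s):>15,} entries")
```

Even for this small network the full Hessian needs over six billion entries, while the other three structures stay in the tens of thousands to low millions.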

It is also possible to cut the costs by treating only a subset of the model parameters (i.e. a subnetwork) probabilistically, fixing the remaining parameters at their MAP-estimated values. One special case of subnetwork Laplace that was found to perform well in practice is the last-layer Laplace, where the selected subnetwork contains only the weights and biases of the last layer.

Weight subsets to be treated probabilistically by Laplace. (a) Full network. (b) Subnetwork Laplace, the general case. (c) Last-layer Laplace, a special case of subnetwork Laplace. Source: Daxberger et al. (2021) [2].
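A minimal plain-Python sketch of last-layer Laplace (independent of the LaplaceRedux.jl API), assuming a hypothetical network whose hidden layer is frozen at its MAP values and reduced to a fixed two-dimensional feature map, followed by a bias-free logistic output layer. Only the two output weights are treated probabilistically, so the Hessian is just 2×2: Σ p(1 − p)φφᵀ plus the prior precision. The feature map, data, and prior precision are all made up.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical frozen feature extractor: its weights stay fixed at their MAP values.
def features(x):
    return [math.tanh(1.2 * x), math.tanh(-0.7 * x + 0.3)]

# Hypothetical binary-classification data and prior precision for the last layer.
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [0, 1, 0, 1, 1]
prior_precision = 1.0

def prob(w, x):
    phi = features(x)
    return sigmoid(w[0] * phi[0] + w[1] * phi[1])

def grad_and_hessian(w):
    """Gradient and Hessian of the regularized NLL w.r.t. the last-layer weights only."""
    g = [prior_precision * w[0], prior_precision * w[1]]
    H = [[prior_precision, 0.0], [0.0, prior_precision]]
    for x, y in zip(xs, ys):
        phi, p = features(x), prob(w, x)
        for i in range(2):
            g[i] += (p - y) * phi[i]
            for j in range(2):
                H[i][j] += p * (1 - p) * phi[i] * phi[j]
    return g, H

def inv2(M):
    (a, b), (c, d) = M
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

# Newton's method for the last-layer MAP estimate.
w = [0.0, 0.0]
for _ in range(30):
    g, H = grad_and_hessian(w)
    H_inv = inv2(H)
    w = [w[i] - sum(H_inv[i][j] * g[j] for j in range(2)) for i in range(2)]

# Laplace posterior over the last layer: N(w, cov), cov = inverse Hessian at the MAP.
_, H_map = grad_and_hessian(w)
cov = inv2(H_map)

def logit_variance(x):
    # Variance of the logit w^T phi(x) under the Laplace posterior: phi^T cov phi
    phi = features(x)
    return sum(phi[i] * cov[i][j] * phi[j] for i in range(2) for j in range(2))

print("MAP weights:", w)
print("logit variance at x = 0 and x = 2:", logit_variance(0.0), logit_variance(2.0))
```

The design choice is the same as in the text: everything before the last layer is a fixed point estimate, so the cost of the Laplace step depends only on the size of the last layer.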

Our contributions to LaplaceRedux.jl

In the scope of the project we have added support for:

  • multi-class classification, in addition to regression and binary classification;
  • GGN, in addition to empirical Fisher;
  • hardware-parallelized batched computation of both the empirical Fisher and the GGN;
  • subnetwork and last-layer Laplace;
  • KFAC for multi-class classification with Fisher; and
  • interfacing to MLJ, a common machine learning framework for Julia.

We have also made quality assurance / quality-of-life additions to the repository, adding:

  • a formatting check in the CI/CD pipeline;
  • an extensive test suite comparing the results of LaplaceRedux.jl against those of its Python counterpart package, laplace-torch; and
  • a benchmark pipeline tracking possible performance regressions.

Methodology

We followed Agile/Scrum practices, with two-week sprints and weekly meetings with our formal client, Patrick Altmeyer. We prioritized the expected requirements with the MoSCoW method into must-, should-, could-, and won’t-haves. This is all fairly standard for BSc software projects at TU Delft. By the end of the project, we had completed all of our self-assigned must-haves and should-haves.

Pain Points

Here we list some obstacles we have encountered along the way:

  • Julia is slow to compile and load dependencies on less powerful machines.
  • Stack traces are sometimes rather obscure, though it seems to be the price to pay for macros.
  • Zygote.jl, the automatic differentiation library, is not self-autodifferentiable: it cannot differentiate its own functions. We would want this, since we apply Zygote.jacobian when making predictions with the LA.
  • There is no accessible tool reporting branch coverage on tests — only line coverage is available.
  • Limited LSP and Unicode support in JupyterLab.
  • Conversion between Flux and ONNX is not yet implemented.
  • There is no extension library for Zygote equivalent to BackPACK or ASDL (both PyTorch libraries) for extracting second-order information.

Highlights

And here is what we found refreshing:

  • Metaprogramming and first-class support for macros are something completely different for students who are used to Java & Python.
  • The Julia standard API, and Flux/Zygote, are fairly straightforward to use, and well-thought-out for the purposes of numerical computing and machine learning.

Conclusion

We have covered some elements of the theory behind Laplace Approximations, laid out our additions to the LaplaceRedux.jl package, and described some difficulties we came across as complete newcomers to Julia. We hope you have enjoyed the tour, and that it has intrigued you enough to look deeper into Bayesian learning and/or Julia, since both are developing at a lively pace. You can check out LaplaceRedux on the JuliaTrustworthyAI GitHub organization. Contributions and comments are welcome!

Acknowledgments

Our team members are Mark Ardman, Severin Bratus, Adelina Cazacu, Andrei Ionescu, and Ivan Makarov. We would like to thank Patrick Altmeyer for the opportunity to work on a project this unique, and for the continuous guidance throughout the development process. We are also grateful to Sebastijan Dumančić, our coach, Sven van der Voort, our TA mentor, and Antony Bartlett, our supporting advisor.

References and Further Reading

[1] LeCun, Y., Denker, J., & Solla, S. (1989). Optimal brain damage. Advances in neural information processing systems, 2.

[2] Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., & Hennig, P. (2021). Laplace Redux — effortless Bayesian deep learning. Advances in Neural Information Processing Systems, 34, 20089–20103.

[3] Altmeyer, P. (2022). Go Deep, but Also … Go Bayesian! https://www.paltmeyer.com/blog//blog/posts/effortsless-bayesian-dl.

[4] Baan, J. (2021). A Comprehensive Introduction to Bayesian Deep Learning. https://jorisbaan.nl/2021/03/02/introduction-to-bayesian-deep-learning.html

[5] Bishop, C. M., & Nasrabadi, N. M. (2006). Pattern recognition and machine learning. New York: Springer.

[6] Huszár, F. (2019). Notes on the Limitations of the Empirical Fisher Approximation. https://www.inference.vc/on-empirical-fisher-information/

[7] Kunstner, F., Hennig, P., & Balles, L. (2019). Limitations of the empirical Fisher approximation for natural gradient descent. Advances in neural information processing systems, 32.

[8] Amini, A., Soleimany, A., Karaman, S., & Rus, D. (2018). Spatial uncertainty sampling for end-to-end control. arXiv preprint arXiv:1805.04829.

[9] Botev, A., Ritter, H., & Barber, D. (2017, July). Practical Gauss-newton optimisation for deep learning. In International Conference on Machine Learning (pp. 557–565). PMLR.

[10] Martens, J., & Grosse, R. (2015, June). Optimizing neural networks with Kronecker-factored approximate curvature. In International conference on machine learning (pp. 2408–2417). PMLR.

[11] Sharma, A., Azizan, N., & Pavone, M. (2021, December). Sketching curvature for efficient out-of-distribution detection for deep neural networks. In Uncertainty in Artificial Intelligence (pp. 1958–1967). PMLR.
