Scalable Bayesian Modeling

An updated benchmark

Sandra Meneses
8 min read · Mar 23, 2023

Bayesian modeling has become widespread in recent years thanks to the development of new sampling algorithms, from Metropolis-Hastings (1970) and gradient-based MCMC samplers like NUTS (2011) to Variational Inference methods, and to the development of Probabilistic Programming Libraries (PPLs), which have made these algorithms easy to apply in science and industry.

With the growth of probabilistic frameworks and high-performance machine learning libraries and compilers, we can now use Bayesian methods in more complex use cases. Implementing a new technique in industry brings challenges that are not present in the academic world. In most industrial use cases, predictions have to be recalculated as new data becomes available, and the model has to be refit or updated in real time. Making sound design decisions therefore requires estimating the computing resources and time each option needs.

In the end, you need to know which resources are needed to fit a model and calculate predictions, and how often it is worth iterating to stay cost-effective while giving value to users. This post introduces a method to answer some of these questions, intended to help Python PPL users find out which configuration is best for their specific model and data.

High-Performance Computing in Python

Just-in-time compilers like Numba generate and optimize array operations, facilitating the parallelization of algorithms and making Python code as efficient as code written in C. Using XLA, JAX offers similar features focused on linear algebra operations and automatic differentiation for Python and NumPy code. These libraries also make it possible to execute the same code on GPUs and TPUs.
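As a toy illustration (not from the template), the same numerical kernel can be JIT-compiled with either library:

```python
import jax
import jax.numpy as jnp
import numba
import numpy as np

@numba.njit
def sum_of_squares_numba(x):
    # Compiled to optimized machine code on the first call
    total = 0.0
    for v in x:
        total += v * v
    return total

@jax.jit
def sum_of_squares_jax(x):
    # Traced once and compiled by XLA; runs unchanged on CPU, GPU or TPU
    return jnp.sum(x * x)

x = np.random.default_rng(0).normal(size=1_000_000)
print(sum_of_squares_numba(x), sum_of_squares_jax(jnp.asarray(x)))
```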

PyMC introduced JAX and Numba as backends for faster sampling in version 4. This was made possible by PyTensor, which creates a representation of the model's logp that can be transpiled to C, Numba or JAX. PyMC currently supports NUTS via NumPyro and Blackjax, allowing PyMC models to run on GPUs.
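For example, the same model can be sampled with the default backend or with a JAX-based sampler. The sketch below assumes PyMC v5, where the JAX samplers live in pymc.sampling.jax (the import path has moved between releases), and requires numpyro and blackjax to be installed:

```python
import numpy as np
import pymc as pm

y = np.random.default_rng(0).normal(loc=1.0, size=100)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=y)

    # Default NUTS, compiled to C through PyTensor
    idata_default = pm.sample()

    # JAX-based NUTS samplers (import path assumed for PyMC v5)
    from pymc.sampling.jax import sample_blackjax_nuts, sample_numpyro_nuts
    idata_numpyro = sample_numpyro_nuts()
    idata_blackjax = sample_blackjax_nuts()
```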

Blackjax is a library of samplers with JAX as backend, and it is meant to be used by PPL developers. NumPyro is a lightweight PPL based on Pyro that also works with JAX. Pyro's goal is to unify Deep Learning and Bayesian Modeling using PyTorch as backend.

Sampling with JAX-based libraries

When sampling multiple chains, vectorization and parallel processing come in handy. JAX can vectorize code, which helps execute batch operations efficiently. JAX also follows a SPMD (single program, multiple data) approach to parallelization, running the same code on multiple devices at the same time.
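A toy example of both mechanisms (not template code):

```python
import jax
import jax.numpy as jnp

def log_density(theta):
    # Standard normal log-density, up to a constant
    return -0.5 * jnp.sum(theta ** 2)

# Vectorization: evaluate a whole batch of points in one call
batched = jax.vmap(log_density)(jnp.ones((8, 3)))  # shape (8,)

# SPMD: the same program replicated across all available devices
n_dev = jax.local_device_count()
replicated = jax.pmap(log_density)(jnp.ones((n_dev, 3)))  # shape (n_dev,)
```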

By default, PyMC and NumPyro set the parameter chain_method to parallel, which is faster when multiple devices are available. Thanks to JAX, the same sampler can be used on CPU and GPU.
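In NumPyro, for instance, the chain method is set directly on the MCMC object (toy model shown; on CPU, numpyro.set_host_device_count is needed so that JAX sees several devices):

```python
import numpyro

# Expose 4 CPU devices so that 4 chains can actually run in parallel
numpyro.set_host_device_count(4)

import jax
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)

mcmc = MCMC(
    NUTS(model),
    num_warmup=1000,
    num_samples=1000,
    num_chains=4,
    # "parallel" uses one device per chain; "vectorized" runs all
    # chains on a single device, e.g. one GPU
    chain_method="parallel",
)
mcmc.run(jax.random.PRNGKey(0), y=jnp.ones(100))
```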

A standard Bayesian workflow to benchmark

Finding the right model for a particular case is an iterative process, which starts with exploratory analysis and continues by comparing different models and their diagnostics to check whether they converged and how efficient the sampling was.

Components

To benchmark different libraries and samplers, this template was designed. It lets you add your own data and models, run the models with different data sizes, validate the results, and compare the resources each configuration uses. You can see how to use it in the README of the repository.

The template contains these components:

  • Data Generation: The function data_generator can create multiple DataFrames from the original dataset based on the given sizes, or by filtering the original data with the parameter filters. You can also add Gaussian noise to selected variables using the parameter include_noise.
  • Sampling: It supports sampling with PyMC (using the default sampler and the JAX-based samplers NumPyro and Blackjax, all of them sampling with NUTS) and with NumPyro's NUTS directly. The goal is to support more libraries; for now, PyMC and NumPyro were prioritized: PyMC because of its popularity and user-friendly API, and NumPyro because it is a modern library (first released in 2019) directly powered by JAX, making it an excellent candidate to validate performance. You can configure the models (written in PyMC or NumPyro) to be compared, the PyMC samplers, the number of samples to draw, the iterations to tune, and the number of chains.
  • Visualization: The template has plots showing the runtime and memory used across models and data sizes. To visualize performance, the Effective Sample Size (ESS) per second (ESS/s) is displayed using the methods mean and tail (see the sketch below). If the samples drawn by the model are highly correlated, the uncertainty of the estimations calculated from the posterior is high and the ESS is low. ESS is therefore an indicator of each sampler's performance, and dividing it by the total runtime gives us an efficiency metric.

And this additional feature:

  • Colab support: The template can be used from Google Colab to benchmark GPUs. By default, if the runtime has a GPU, the samplers will use the chain_method vectorized, as Colab offers only one GPU. For detailed instructions, please see this section in the README.
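As an illustration of the efficiency metric, ESS per second can be computed directly with ArviZ (using a bundled example trace; the runtime value is a placeholder you would measure yourself):

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")  # any InferenceData works
runtime_s = 60.0  # measured wall-clock sampling time (placeholder)

# ESS with the two methods used in the template's plots
ess_mean = az.ess(idata, method="mean")
ess_tail = az.ess(idata, method="tail")

# Efficiency metric: effective samples per second, per variable
ess_per_s = {v: float(ess_mean[v].mean()) / runtime_s for v in ess_mean.data_vars}
print(ess_per_s)
```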

Diagnostics

Comparing different options to evaluate their performance requires following these steps in a way where errors are easy to identify and results can be cross-checked among models. To accomplish this, two steps were included.

  • Check convergence: It detects common convergence issues. This function uses the rank-normalized split R-hat; an R-hat > 1.05 indicates convergence failures. This diagnostic cannot confirm that the model has converged, so if results are not consistent, it's recommended to double-check convergence for that run.
  • Validate results: It checks whether different models are sampling from the same distributions. It works by estimating a range for the mean of each variable, using the MCSE (Monte Carlo standard error) of one of the models as reference.

The percentages displayed indicate how many variables fall within the calculated ±3 sigma range. In theory, the values should be greater than or equal to 95%, following a weaker three-sigma rule (the Vysochanskij–Petunin inequality, which guarantees roughly 95% within ±3 sigma for unimodal distributions). If the models didn't converge, the results of the validation are not valid.
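As a sketch, both checks can be expressed with ArviZ directly (the template wraps logic along these lines; the data and threshold here are illustrative):

```python
import arviz as az
import numpy as np

idata_ref = az.load_arviz_data("centered_eight")  # reference model
idata_new = az.load_arviz_data("centered_eight")  # model under comparison

# 1. Convergence: rank-normalized split R-hat, flagging values above 1.05
rhat = az.rhat(idata_new)
assert all(float(rhat[v].max()) <= 1.05 for v in rhat.data_vars)

# 2. Validation: posterior means of the compared model should fall
#    within mean_ref ± 3 * MCSE(mean_ref) for most variables
mean_ref = idata_ref.posterior.mean(dim=("chain", "draw"))
mean_new = idata_new.posterior.mean(dim=("chain", "draw"))
mcse_ref = az.mcse(idata_ref, method="mean")

within = [
    float(np.mean(np.abs(mean_new[v] - mean_ref[v]) <= 3 * mcse_ref[v]))
    for v in mean_ref.data_vars
]
print(f"{100 * np.mean(within):.2f}% of values within ±3 sigma")
```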

For the diagnostics, ArviZ is used, which integrates with many more PPLs (CmdStanPy, PyStan, Pyro and emcee); this will make it easier to add more libraries in the future.

Let’s try it out

Now, let’s use the previous workflow to show the performance of the currently supported options. This exercise also shows how flexible the template is. As an example, the model used in the blog post MCMC for big datasets: faster sampling with JAX and the GPU was replicated. The goal of the model is to rank tennis players' skills based on their performance in previous matches. For that, the original PyMC model was used and an equivalent version was written in NumPyro. The rest of the work was done using the standard workflow described above.
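As a minimal sketch of such a ranking model (simplified relative to the original post, which uses a more careful prior setup): each player gets a latent skill, and the recorded winner beating the recorded loser is modeled as a Bernoulli draw whose logit is the skill difference.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def tennis_model(winner_ids, loser_ids, n_players):
    # One latent skill per player
    skill = numpyro.sample("skill", dist.Normal(0.0, 1.0).expand([n_players]))
    # Logit of the probability that the recorded winner wins
    logit_p = skill[winner_ids] - skill[loser_ids]
    # Every observed match was won by the recorded winner
    numpyro.sample(
        "obs", dist.Bernoulli(logits=logit_p), obs=jnp.ones(len(winner_ids))
    )

# Example: 3 matches among 4 players; plug into MCMC as in the snippet above
winner_ids = jnp.array([0, 1, 0])
loser_ids = jnp.array([1, 2, 3])
```

Ranking the players then amounts to sorting them by their posterior mean skill.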

The data contains 169,073 tennis matches from 1968 to 2023, indicating the winner's and loser's names for each match. It was obtained by executing the function get_data in the repository of the original blog post. To reproduce the results, you can download the data available online or follow the code in the notebook tennis. Executing the whole notebook on a Mac M1 with 16 GB took around 2.5 hours, of which 122 minutes were spent sampling. Note that these results were obtained on a CPU; with a GPU, sampling with NumPyro and the PyMC JAX samplers can be much faster.

Creating the datasets

Eight different datasets were generated from the full data using the parameter filters in the data generator, each with a filter year >= x, where x is one of the years 2020, 2019, 2015, 2010, 2000, 1990, 1980 and 1968. The last dataset therefore contains the whole data: 169,073 rows.
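Illustratively, each filter slices the match data as below (a hypothetical reconstruction of what the template's data_generator does internally; its exact signature is documented in the README):

```python
import pandas as pd

# Toy stand-in for the full matches DataFrame
data = pd.DataFrame({"year": [1970, 1995, 2021], "winner": ["a", "b", "c"]})

years = [2020, 2019, 2015, 2010, 2000, 1990, 1980, 1968]
datasets = {y: data.query(f"year >= {y}") for y in years}
print({y: len(df) for y, df in datasets.items()})
```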

Validations

  • Check convergence: All models give values between 1.00 and 1.02, which means there are no convergence issues according to the predefined threshold.
  • Validate results: Here, 6 configurations with data sizes of 93,549 or higher didn't meet the 95% standard; however, given that the same models passed the validation for other sizes and that the minimum value was 92.07%, it can be concluded that the PyMC and NumPyro models are equivalent.

Comparison

Figure 1. Resource metrics. Image by author.
  • Resource metrics: Looking at the metrics, the PyMC default sampler is slower (21.91 minutes for the largest dataset, while PyMC with the Blackjax sampler took 7.06) and needs much more RAM.

For the smallest datasets, NumPyro is faster, but for datasets with more than 80,000 rows, sampling with Blackjax turns out to be a better option.

Although the PyMC default sampler consumes more RAM, PyMC with the two JAX samplers (NumPyro and Blackjax) uses less memory than NumPyro itself.

Figure 2. ESS/s for all data sizes. Image by author.
  • ESS/s: From this perspective, the PyMC default sampler is also consistently less efficient. NumPyro has the highest ESS/s for 4 of the 8 datasets; for the other 4, PyMC using either Blackjax or NumPyro as sampler performs best. These last two show similar performance metrics, which is not surprising since only the sampling backend changes.
Table 1. Top performance per data size. Image by author.

The ESS/s for the whole dataset is shown below, where we can see Blackjax has the highest value, followed by NumPyro.

Figure 3. ESS/s means and tails for the complete dataset. Image by author.

Results

Apart from the validations, we can also reproduce the results shown in the original post (Addendum 1), where tennis players are ranked according to their skills.

Table 2. Top 20 tennis players. Image by author.

What can we do next?

The main motivation for writing this blog and creating the template was to support the great work of Bayesian enthusiasts and PPL developers. With a growing ecosystem, we hope to support more libraries and algorithms in the future. It would also be valuable to add comparisons for posterior predictive sampling.

Having the option to compare samplers can help identify the situations in which each performs better, their shortcomings, and potential improvements.

This blog post was written as part of the PyMCon Web Series. If you have any suggestions to improve the code, you can open an issue here and, of course, your code contributions are more than welcome.

Without the valuable comments from anonymous reviewers, this post wouldn't have been the same. Special thanks to Oriol Abril-Pla for his guidance in creating the diagnostics and for reviewing an early draft of the blog, and to Ravin Kumar for his suggestions and code review.


Sandra Meneses

Machine Learning engineer and architect of end-to-end solutions. Currently focused on Graph Learning, NLP and EdTech. https://github.com/symeneses