Predicting with Confidence: Localized Conformal Prediction for Regression with Deep Learning Models

Shubhendu
AI FUSION LABS
Dec 22, 2022

Introduction

As deep learning methods become increasingly embedded in critical decision-making pipelines, such as in healthcare, quantifying the uncertainty of their predictions in a theoretically grounded manner has become essential, yet it remains fundamentally challenging. Quantifying uncertainty matters because it adds a protective layer and serves as a form of risk management. One way to quantify uncertainty is to generate prediction intervals alongside point predictions. A prediction interval specifies a plausible range for the response to a particular input, which permits more honest decision-making.

We expect useful uncertainty information to meet the following desiderata. First, it should be valid, i.e., it should guarantee coverage: the prediction interval should contain the true response with a pre-specified high probability. Second, it should be discriminative: when the expected risk is high, this should be reflected in increased uncertainty (wider prediction intervals). We would also like our prediction intervals to be efficient, with small widths wherever we are confident in our predictions. Two further criteria are particularly important for deep learning models. Any practically useful method should not take an inordinate amount of time to run, so the third criterion is scalability. Lastly, an uncertainty quantification (UQ) method should not decrease the accuracy of the base neural network. The last two criteria make a strong case for post-hoc methods, since these leave the base network and its predictions untouched.

Existing uncertainty quantification methods for deep learning usually satisfy only one or two of the above criteria at once. Approximate Bayesian methods such as probabilistic backpropagation [1], the stochastic gradient Langevin dynamics (SGLD) of Welling and Teh [2], Deep Ensembles [3][4], and Monte-Carlo dropout [5] are popular and widely used. However, the credible intervals given by the posteriors of such methods are not valid in the frequentist sense [6]. Furthermore, these methods modify the base model and its loss function, and can require multiple model runs, which affects both accuracy and scalability. They have also been extended to the time series setting and used with RNNs [7][8], where they suffer from the same issues. We will deal with the cross-sectional time series (panel data) case in a future post.

The purpose of this post is to discuss how conformal prediction methods can be adapted to construct prediction intervals, in the usual point-prediction setting, that meet all of the desiderata above. In brief, we will cover:

  1. Brief background on conformal prediction and problem setup
  2. A locally valid and discriminative method (LVD) for the construction of valid prediction intervals for regression problems. This method, with formal proofs, is described in more detail in our NeurIPS paper [9].
  3. Experiments on synthetic data and molecular property prediction
  4. Conclusions and observations

Problem Setup and Background

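We first discuss the problem setup and fix some notation. In the regression setting, we observe training data $\{(X_i, Y_i)\}_{i=1}^{N}$ drawn i.i.d. from a joint distribution $P$ over inputs and real-valued responses, with $P_X$ denoting the marginal distribution of the inputs. A neural network is trained on these data, and we denote its point prediction at input $x$ by $\hat{y}(x)$. For a new test pair $(X_{N+1}, Y_{N+1})$ drawn from the same distribution, of which only $X_{N+1}$ is observed, we want to construct a prediction interval (PI) $\hat{C}(X_{N+1}) \subseteq \mathbb{R}$ that contains $Y_{N+1}$.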

For the prediction interval specified above, as mentioned in the introduction, we say it is valid if, for some specified probability 1-α, we expect it to cover the test example at least 1-α of the time. Further, we say it is discriminative if we expect it to be narrower for test examples for which we are confident. An ideal prediction interval should be both valid and discriminative. This is illustrated in the figure below:

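We now elaborate on the notions of validity and discriminability; the use of "conditionally valid" in the figure is explained later in the post. The type of validity specified above is called marginal validity, meaning that it holds on average across examples. More formally,

$$\mathbb{P}\left\{ Y_{N+1} \in \hat{C}(X_{N+1}) \right\} \ge 1 - \alpha,$$

where the probability is taken over both the data used to build the interval and the test pair $(X_{N+1}, Y_{N+1})$.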

Conformal Prediction for Marginal Validity

We now take a brief digression and introduce the basic idea of conformal prediction, keeping the marginal validity above in mind. Conformal prediction is a powerful framework originally developed by Vovk, Gammerman, and Shafer [10], which provides tools for constructing prediction intervals with provable coverage guarantees. The framework makes minimal assumptions about the model and the data distribution, and its flexibility has recently attracted widespread attention in machine learning, both in academia and in industry. Indeed, it has also been deployed in clinical and financial-industrial settings, as well as in large-scale language modelling. While there are many variations of the framework, we illustrate the basic idea in its simplest form below, which can be used to obtain the marginal validity described above.

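To obtain prediction intervals with marginal validity, we set aside a separate, held-out calibration set from our training data. On this calibration set only, we make predictions with our deep learning model and collect the absolute residuals $R_i = |Y_i - \hat{y}(X_i)|$, forming the set $\mathcal{R} = \{R_i\}$.

Then, for the point prediction $\hat{y}(X_{N+1})$ on a new test point, we can construct a prediction interval simply as

$$\hat{C}(X_{N+1}) = \left[\hat{y}(X_{N+1}) - Q\left(1-\alpha;\ \mathcal{R} \cup \{\infty\}\right),\ \hat{y}(X_{N+1}) + Q\left(1-\alpha;\ \mathcal{R} \cup \{\infty\}\right)\right].$$

In the above, $Q(\beta; A)$ refers to the β-quantile of the set $A$; for $\beta = 1-\alpha$ it is the $\lceil (n+1)(1-\alpha) \rceil$-th smallest element of $\mathcal{R} \cup \{\infty\}$, where $n$ is the size of the calibration set. Under the i.i.d. assumption stated at the beginning, a PI constructed in this manner is guaranteed to cover the corresponding Y with probability at least the target coverage level 1-α. While we only describe the simplest use of the conformal method, it captures the essence of many of its variations. We now return to our motivation of capturing discriminative information.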
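A minimal NumPy sketch of this split-conformal recipe follows. The function name `split_conformal_interval` and the toy linear "model" are illustrative assumptions, not code from [9]; the predictions of any trained regressor could be passed in instead.

```python
import numpy as np


def split_conformal_interval(y_hat_test, y_cal, y_hat_cal, alpha=0.1):
    """Symmetric split-conformal PI built from absolute calibration residuals."""
    residuals = np.abs(np.asarray(y_cal, dtype=float) - np.asarray(y_hat_cal, dtype=float))
    n = residuals.size
    # Q(1 - alpha; R ∪ {inf}) is the ceil((n + 1)(1 - alpha))-th smallest element of the
    # residuals augmented with +inf (it is +inf when alpha is too small for this n).
    k = int(np.ceil((n + 1) * (1 - alpha)))
    q = np.sort(np.append(residuals, np.inf))[k - 1]
    y_hat_test = np.asarray(y_hat_test, dtype=float)
    return y_hat_test - q, y_hat_test + q


def predict(inputs):
    # Hypothetical stand-in for a trained network: a fixed linear function.
    return 2.0 * inputs


# Toy usage: 2000 calibration points, 5000 test points.
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=7000)
y = 2.0 * x + rng.normal(0.0, 0.5, size=7000)
lo, hi = split_conformal_interval(predict(x[2000:]), y[:2000], predict(x[:2000]), alpha=0.1)
coverage = np.mean((y[2000:] >= lo) & (y[2000:] <= hi))
print(f"empirical coverage: {coverage:.3f} (target 1 - alpha = 0.9)")
```

Every test point receives an interval of the same width, which is exactly why this basic version is valid only marginally and carries no discriminative information.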

Local Validity and Discriminability

Marginal validity is clearly not sufficient in many settings of interest. In a clinical setting, for example, it would only mean that our prediction intervals cover an average patient: they could entirely miss an important minority subgroup of patients (such as those with rare diseases) while still covering most patients on average. What we desire instead is conditional validity, i.e., validity conditioned on a particular example. Formally,

$$\mathbb{P}\left\{ Y_{N+1} \in \hat{C}(X_{N+1}) \mid X_{N+1} = x \right\} \ge 1 - \alpha \quad \text{for all } x.$$
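To see concretely how marginal validity can hide a failing subgroup, the following simulation uses a setup of our own choosing (an assumption, not an experiment from the post): 10% of the population has five times noisier responses, and an oracle point predictor is used so that all miscoverage comes from the interval itself.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample(n, minority_prob=0.1):
    """Hypothetical population: a 10% minority subgroup has 5x noisier responses."""
    minority = rng.random(n) < minority_prob
    y = rng.normal(0.0, np.where(minority, 5.0, 1.0))
    return y, minority


# Oracle point predictor: always predicts the true mean 0.
alpha = 0.1
y_cal, _ = sample(2000)
k = int(np.ceil((y_cal.size + 1) * (1 - alpha)))
q = np.sort(np.append(np.abs(y_cal), np.inf))[k - 1]  # split-conformal radius

y_test, minority_test = sample(20000)
covered = np.abs(y_test) <= q
marginal_coverage = covered.mean()
minority_coverage = covered[minority_test].mean()
majority_coverage = covered[~minority_test].mean()
print(f"marginal {marginal_coverage:.3f} | minority {minority_coverage:.3f} | majority {majority_coverage:.3f}")
```

The marginal coverage lands near the 90% target, but that average is achieved by over-covering the majority and badly under-covering the minority subgroup.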

However, for prediction intervals to be useful with deep learning models, we also want to make minimal assumptions about the model: we want to treat it as a black box. Further, we want to make no distributional assumptions about the data. In such a general setting, it is known that constructing prediction intervals with conditional validity is impossible in general (see, for example, [10, 13]). A reasonable workaround is to consider approximately conditional validity, where we relax strict conditional validity and only require validity to hold in a close neighbourhood of any point X. We use the following formulation from the important works of Tibshirani et al. [11] and Leying Guan [12], which also serve as our main inspiration and starting point. Let $K$ be a kernel function that takes two points as input and outputs a non-negative real value measuring their similarity. Fixing an anchor point $x'$, we obtain the following kernel-based relaxation, which requires coverage when the test input is drawn from the input distribution $P_X$ reweighted by the kernel around $x'$:

$$\mathbb{P}\left\{ Y_{N+1} \in \hat{C}(X_{N+1}) \right\} \ge 1 - \alpha \quad \text{when } X_{N+1} \sim \tilde{P}_{x'}, \qquad d\tilde{P}_{x'}(x) \propto K(x', x)\, dP_X(x).$$

Now, instead of fixing the anchor point beforehand, we set it to be the new test point $X_{N+1}$ and, after reweighting by the kernel, fold the integral over the anchor into the probability as described in [12]. This gives the following notion of local (approximately conditional) coverage:

$$\mathbb{P}\left\{ \tilde{Y}_{N+1} \in \hat{C}(X_{N+1}) \right\} \ge 1 - \alpha.$$

Note the change in Y above. This is because we now deal with a distribution that is reweighted using the kernel around the anchor. The tilde indicates that the points are sampled from this reweighted distribution:

$$X_{N+1} \sim P_X, \qquad \tilde{X}_{N+1} \mid X_{N+1} \sim \tilde{P}_{X_{N+1}}, \qquad \tilde{Y}_{N+1} \mid \tilde{X}_{N+1} \sim P_{Y \mid X = \tilde{X}_{N+1}}.$$
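One way to build intuition for kernel-reweighted coverage is to estimate it from samples. With test points drawn from $P_X$, coverage under the distribution reweighted by $K(x', \cdot)$ around an anchor $x'$ can be estimated by a self-normalized weighted average of the coverage indicators, using weights $K(x', X_j)$. The sketch below assumes a Gaussian kernel and a heteroscedastic toy problem of our own choosing. It shows that a constant-width split-conformal interval can be marginally valid while over-covering near one anchor and under-covering near another.

```python
import numpy as np


def gaussian_kernel(anchor, points, bandwidth):
    """Assumed kernel choice: K(a, x) = exp(-||a - x||^2 / (2 * bandwidth^2))."""
    sq_dist = np.sum((points - anchor) ** 2, axis=-1)
    return np.exp(-sq_dist / (2.0 * bandwidth ** 2))


def localized_coverage(anchor, x_test, covered, bandwidth):
    """Self-normalized estimate of coverage under K(anchor, .) dP_X from samples of P_X."""
    w = gaussian_kernel(anchor, x_test, bandwidth)
    return float(np.sum(w * covered) / np.sum(w))


rng = np.random.default_rng(1)


def sample(n):
    # Toy heteroscedastic data: noise standard deviation grows from 0.2 to 2.2 with x.
    x = rng.uniform(0.0, 1.0, size=(n, 1))
    y = rng.normal(0.0, 0.2 + 2.0 * x[:, 0])
    return x, y


x_cal, y_cal = sample(2000)
k = int(np.ceil((len(y_cal) + 1) * 0.9))
q = np.sort(np.append(np.abs(y_cal), np.inf))[k - 1]  # constant radius, oracle mean 0

x_test, y_test = sample(20000)
covered = np.abs(y_test) <= q
near_zero = localized_coverage(np.array([0.05]), x_test, covered, bandwidth=0.05)
near_one = localized_coverage(np.array([0.95]), x_test, covered, bandwidth=0.05)
print(f"marginal {covered.mean():.3f} | near x=0.05 {near_zero:.3f} | near x=0.95 {near_one:.3f}")
```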

We now have a more satisfactory notion of local validity. We also discussed discriminability above, which we now formalize briefly. The notion is simple: if we let W denote a measure of the width of a PI and 𝓁 a loss function (such as the squared error), we can write it as

$$\mathbb{E}\left[\ell(\hat{y}(X), Y) \mid X = x_1\right] \le \mathbb{E}\left[\ell(\hat{y}(X), Y) \mid X = x_2\right] \implies W(\hat{C}(x_1)) \le W(\hat{C}(x_2)),$$

where $\hat{y}(x)$ is the point prediction at $x$.

As mentioned above, this property simply means that we would like our prediction interval to be shorter or tighter if the expected risk is low and vice-versa.

Given this background, it might be useful to reiterate our goal: We want to construct prediction intervals that are discriminative but also have some kind of (approximately) conditional or local validity since we do not want to ignore distinct subgroups. We will work with the notion we have defined above and in the next section show how such intervals can be constructed. We will further ensure discriminability by using a learnt kernel.

Locally Valid and Discriminative Prediction Intervals for Regression

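Using the above concepts, we now sketch our approach, which yields prediction intervals that are discriminative and provably locally valid. We first split our training data into two sets: an embedding set, used to learn the kernel, and a calibration set, used to collect residuals in the same manner as described above. Next, we take our trained neural network and extract its final-layer embeddings f. Using these embeddings, we train our kernel function on the embedding set. More precisely, we perform leave-one-out Nadaraya-Watson kernel regression, predicting each data point from the others while optimizing the kernel's distance function. We only consider linear projections of the embeddings, and these projections are learnt by backpropagation. Once the kernel is trained, we can obtain prediction intervals for a test point using the calibration set. The procedure is almost the same as described earlier, with one fundamental difference: we now weight the calibration residuals $R_i = |Y_i - \hat{y}(X_i)|$ by the learnt kernel $K$ (applied to the projected embeddings) and obtain PIs as

$$\hat{C}(x) = \left[\hat{y}(x) - Q\left(1-\alpha;\ \sum_{i} w_i(x)\, \delta_{R_i} + w_{\infty}(x)\, \delta_{\infty}\right),\ \hat{y}(x) + Q\left(1-\alpha;\ \sum_{i} w_i(x)\, \delta_{R_i} + w_{\infty}(x)\, \delta_{\infty}\right)\right],$$

$$w_i(x) = \frac{K(x, X_i)}{\sum_{j} K(x, X_j) + K(x, x)}, \qquad w_{\infty}(x) = \frac{K(x, x)}{\sum_{j} K(x, X_j) + K(x, x)},$$

where the sums run over the calibration set and $Q(\beta; \cdot)$ now denotes the β-quantile of a weighted discrete distribution.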
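The kernel-learning step relies on leave-one-out Nadaraya-Watson regression. The post learns a general linear projection of the embeddings by backpropagation. As a dependency-free simplification, the sketch below restricts the projection to a scaled identity $sI$ and picks $s$ from a small grid by minimizing the same leave-one-out loss. The function names, the kernel $\exp(-\lVert A(f_i - f_j) \rVert^2)$ and the toy embeddings are assumptions for illustration.

```python
import numpy as np


def loo_nadaraya_watson_mse(emb, y, projection):
    """Leave-one-out Nadaraya-Watson MSE with kernel exp(-||A (e_i - e_j)||^2), A = projection."""
    z = emb @ projection.T
    logits = -np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    np.fill_diagonal(logits, -np.inf)  # exclude each point from its own prediction
    logits -= logits.max(axis=1, keepdims=True)  # stabilize before exponentiating
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)
    y_loo = weights @ y
    return float(np.mean((y - y_loo) ** 2))


def fit_scaled_identity_projection(emb, y, scales=(0.1, 0.3, 1.0, 3.0, 10.0)):
    """Grid search over A = s * I, a simplified stand-in for learning A by backpropagation."""
    d = emb.shape[1]
    losses = {s: loo_nadaraya_watson_mse(emb, y, s * np.eye(d)) for s in scales}
    best = min(losses, key=losses.get)
    return best * np.eye(d), losses


# Toy "embeddings": the response depends on the first coordinate only.
rng = np.random.default_rng(2)
emb = rng.normal(size=(500, 2))
y = np.sin(3.0 * emb[:, 0]) + 0.1 * rng.normal(size=500)
projection, losses = fit_scaled_identity_projection(emb, y)
print(losses)
```

A scaled identity cannot down-weight the irrelevant second coordinate. A full learnt projection, as in the post, can, which is one motivation for learning it by gradient descent rather than by a grid search.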

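This simple methodology provably ensures local coverage. Please see our paper [9] and the paper of Leying Guan [12] (which, however, does not learn a kernel and is not specific to deep learning) for more details and proofs of validity. Moreover, since we learn the kernel function, we can adapt to the dataset and thus obtain prediction intervals that are discriminative. The approach sketched above is stated step by step below:

  1. Split the training data into an embedding set and a calibration set.
  2. Extract the final-layer embeddings f from the trained network.
  3. On the embedding set, learn a linear projection of the embeddings by minimizing the leave-one-out Nadaraya-Watson regression loss with backpropagation; this defines the kernel K.
  4. On the calibration set, compute the residuals R_i = |Y_i - ŷ(X_i)|.
  5. For a test point x, weight each calibration residual by K(x, X_i), place weight K(x, x) on +∞, and return ŷ(x) ± the weighted (1-α)-quantile.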
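The final, weighted-residual step can be sketched in NumPy as follows. To keep the sketch self-contained, it assumes a fixed Gaussian kernel on the raw inputs instead of a kernel learnt on projected network embeddings. It also uses an oracle point predictor on a toy heteroscedastic problem. The helpers `weighted_quantile_with_inf` and `localized_interval` are hypothetical names, not the authors' code.

```python
import numpy as np


def gaussian_kernel(x, points, bandwidth=0.05):
    """Assumed fixed kernel K(x, X_i) = exp(-||x - X_i||^2 / (2 * bandwidth^2))."""
    return np.exp(-np.sum((points - x) ** 2, axis=-1) / (2.0 * bandwidth ** 2))


def weighted_quantile_with_inf(values, weights, inf_weight, beta):
    """beta-quantile of sum_i w_i * delta_{v_i} + w_inf * delta_{+inf}, weights normalized."""
    order = np.argsort(values)
    v = np.append(np.asarray(values, dtype=float)[order], np.inf)
    w = np.append(np.asarray(weights, dtype=float)[order], inf_weight)
    cdf = np.cumsum(w) / np.sum(w)
    # Smallest value whose cumulative weight reaches beta (clipped for rounding safety).
    idx = min(int(np.searchsorted(cdf, beta, side="left")), len(v) - 1)
    return v[idx]


def localized_interval(x_test, y_hat_test, x_cal, residuals_cal, alpha=0.1, kernel=gaussian_kernel):
    """For each test point, weight calibration residuals by K(x, X_i) and put K(x, x) on +inf."""
    lows, highs = [], []
    for x, y_hat in zip(x_test, y_hat_test):
        w = kernel(x, x_cal)
        w_self = kernel(x, x[None, :])[0]
        q = weighted_quantile_with_inf(residuals_cal, w, w_self, 1.0 - alpha)
        lows.append(y_hat - q)
        highs.append(y_hat + q)
    return np.array(lows), np.array(highs)


rng = np.random.default_rng(3)


def sample(n):
    # Toy heteroscedastic data; the true conditional mean is 0 everywhere.
    x = rng.uniform(0.0, 1.0, size=(n, 1))
    y = rng.normal(0.0, 0.2 + 2.0 * x[:, 0])
    return x, y


x_cal, y_cal = sample(2000)
residuals = np.abs(y_cal)  # oracle predictor y_hat(x) = 0

x_test, y_test = sample(2000)
lo, hi = localized_interval(x_test, np.zeros(len(y_test)), x_cal, residuals)
coverage = np.mean((y_test >= lo) & (y_test <= hi))
edge_lo, edge_hi = localized_interval(np.array([[0.05], [0.95]]), np.zeros(2), x_cal, residuals)
widths = edge_hi - edge_lo
print(f"coverage {coverage:.3f} | width at x=0.05: {widths[0]:.2f} | at x=0.95: {widths[1]:.2f}")
```

Unlike the constant-width interval, the width now tracks the local noise level, while overall coverage stays near the target.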

Experiments

We now show some experimental results that test our method and confirm its claims of validity and discriminability. We compare against a range of state-of-the-art methods: Discriminative Jackknife (DJ) [14], Deep Ensemble [3], Monte-Carlo Dropout (MCDP) [5], Probabilistic Backpropagation (PBP) [1], Conformalized Quantile Regression (CQR) [15], and MAD-Normalized Split Conformal (MADSplit) [16].

As an illustrative example, we first consider a synthetic dataset that is standard in the UQ literature. It is generated as 𝘺 = 𝘹³ + ε, where ε is sampled from a Gaussian with zero mean and a standard deviation of 16. With probability 0.9, 𝘹 is sampled uniformly from [-1, 1], and with probability 0.1 it is sampled from the half-normal distribution on [1, ∞) with σ = 1. The results, along with a qualitative comparison, are illustrated below:
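A sketch of a generator for this synthetic dataset follows. The function name is hypothetical, and reading "half normal on [1, ∞) with σ = 1" as 1 + |N(0, 1)| is our assumption.

```python
import numpy as np


def make_cubic_dataset(n, rng, noise_sd=16.0, tail_prob=0.1):
    """y = x^3 + eps, with x from the uniform / half-normal mixture described above."""
    in_tail = rng.random(n) < tail_prob
    uniform_part = rng.uniform(-1.0, 1.0, size=n)
    # Assumed reading of "half normal on [1, inf) with sigma = 1": 1 + |N(0, 1)|.
    tail_part = 1.0 + np.abs(rng.normal(0.0, 1.0, size=n))
    x = np.where(in_tail, tail_part, uniform_part)
    y = x ** 3 + rng.normal(0.0, noise_sd, size=n)
    return x, y


rng = np.random.default_rng(0)
x, y = make_cubic_dataset(10_000, rng)
print(f"fraction of x > 1: {np.mean(x > 1):.3f}")
```

The sparse tail on [1, ∞) is what makes this dataset useful: a discriminative method should widen its intervals there.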

Next, we consider results on some UCI datasets which are also standard in the UQ literature. We also consider the molecular property prediction datasets QM8 and QM9, with QM9 being an example of a large dataset. More extensive results are in the paper.

For evaluating performance, we consider the following metrics. To check validity, we use the marginal coverage rate (MCR) and the tail coverage rate (TCR), which is the coverage rate on the data whose response Y falls in the top or bottom 10%. To check discriminability, we use the AUROC of the prediction interval width as a predictor of whether the absolute residual is in the top half of all residuals. Since the AUROC alone can be misleading, we also report the mean absolute deviation (MAD). All experiments are repeated 10 times, and the standard deviations are reported.
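The coverage and AUROC metrics can be sketched as below. Three choices are our assumptions: the tail cut-offs are computed from the responses in the evaluation set, "top half" means above the median absolute residual, and AUROC is computed with the rank-based (Mann-Whitney) formula. MAD is omitted because its exact definition is not spelled out here.

```python
import numpy as np


def coverage_rate(y, lo, hi):
    """Fraction of responses inside their PIs (the MCR when applied to all test data)."""
    return float(np.mean((y >= lo) & (y <= hi)))


def tail_coverage_rate(y, lo, hi, tail=0.1):
    """Coverage restricted to points whose response lies in the bottom or top `tail` fraction."""
    low_cut, high_cut = np.quantile(y, [tail, 1.0 - tail])
    mask = (y <= low_cut) | (y >= high_cut)
    return coverage_rate(y[mask], lo[mask], hi[mask])


def auroc(scores, labels):
    """Rank-based (Mann-Whitney U) AUROC; tied scores receive their average rank."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    _, inverse, counts = np.unique(scores, return_inverse=True, return_counts=True)
    avg_rank = np.cumsum(counts) - (counts - 1) / 2.0
    ranks = avg_rank[inverse]
    n_pos, n_neg = labels.sum(), (~labels).sum()
    return float((ranks[labels].sum() - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg))


def width_auroc(lo, hi, abs_residual):
    """AUROC of PI width for predicting whether the absolute residual is in the top half."""
    return auroc(hi - lo, abs_residual > np.median(abs_residual))
```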

For the measures of validity, we see the following results on some of the datasets considered:

For discriminability, we observe the following outcomes:

Observations and Conclusions

In the synthetic data experiment, we see that the conformal methods LVD, CQR, MADSplit, and DJ all achieve nearly 90% coverage, so they are all valid in the marginal sense. LVD, however, also gives discriminative prediction intervals: for example, its intervals are wider at the boundaries, as we would expect since the data are sparser there. This is not the case for the other methods. CQR and MADSplit can also be discriminative, but they are still only marginally valid, and their prediction intervals tend to become narrower at the boundaries, which is not desirable.

On the UCI datasets, QM8, and QM9, we again see that the conformal methods (LVD, CQR, MADSplit, and DJ) all achieve marginal coverage (as measured by MCR), while the approximate Bayesian methods do not achieve meaningful coverage. However, for the tail coverage rate (TCR), LVD is the only method that consistently achieves the target (90%) coverage level. In the discriminability results, LVD is usually among the top two methods while maintaining a low mean absolute deviation (MAD); when other methods achieve high discriminability, their MAD also increases.

At the beginning of the post, we also highlighted scalability and accuracy as key requirements. Unlike the approximate Bayesian methods considered in our experiments, LVD is post-hoc, so no model retraining is required: we simply take the neural network embeddings and train a kernel on top of them using a small held-out dataset. We generally observe that this training cost is minimal (more details about wall-clock times can be found in our paper). Other methods such as MADSplit and CQR require training additional predictors (for CQR, a new pair of quantile regressors for every α), which quickly becomes expensive. Finally, on all datasets, the overall accuracy of the base neural network is maintained, and even improved (owing to the trained kernel), when using LVD.

References:

[1] J. M. Hernández-Lobato and R. P. Adams. Probabilistic backpropagation for scalable learning of Bayesian neural networks. In Proceedings of the 32nd International Conference on Machine Learning, ICML 2015.

[2] M. Welling and Y. W. Teh. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning, ICML 2011.

[3] B. Lakshminarayanan, A. Pritzel, and C. Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, 2017.

[4] A. G. Wilson and P. Izmailov. Bayesian deep learning and a probabilistic perspective of generalization. In Advances in Neural Information Processing Systems, 2020.

[5] Y. Gal and Z. Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In Proceedings of the 33rd International Conference on Machine Learning, ICML 2016.

[6] M. J. Bayarri and J. O. Berger. The Interplay of Bayesian and Frequentist Analysis. Statistical Science, 19(1):58–80, 2004.

[7] J. Caceres, D. Gonzalez, T. Zhou, and E. L. Droguett. A probabilistic Bayesian recurrent neural network for remaining useful life prognostics considering epistemic and aleatory uncertainties. Structural Control and Health Monitoring, 28(10):e2811, 2021.

[8] M. Fortunato, C. Blundell, and O. Vinyals. Bayesian recurrent neural networks. CoRR, abs/1704.02798, 2017.

[9] Z. Lin, S. Trivedi, and J. Sun. Locally Valid and Discriminative Prediction Intervals for Deep Learning Models. In Advances in Neural Information Processing Systems 34, 2021.

[10] V. Vovk, A. Gammerman, and G. Shafer. Algorithmic learning in a random world. Springer US, 2005.

[11] R. J. Tibshirani, R. Foygel Barber, E. J. Candes, and A. Ramdas. Conformal prediction under covariate shift. In Advances in Neural Information Processing Systems, 2019.

[12] L. Guan. Conformal prediction with localization. arXiv, abs/1908.08558, 2020.

[13] J. Lei and L. Wasserman. Distribution-free prediction bands for non-parametric regression. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 76(1):71–96, 2014.

[14] A. Alaa and M. van der Schaar. Discriminative Jackknife: Quantifying Uncertainty in Deep Learning via Higher-Order Influence Functions. In Proceedings of the 37th International Conference on Machine Learning, ICML 2020.

[15] Y. Romano, E. Patterson, and E. Candes. Conformalized quantile regression. In Advances in Neural Information Processing Systems, 2019.

[16] A. Bellotti. Constructing normalized nonconformity measures based on maximizing predictive efficiency. In Proceedings of the Ninth Symposium on Conformal and Probabilistic Prediction and Applications, PMLR, Sep 2020.
