# Neural Processes for Uncertainty-Aware Continual Learning

With this blog post, I intend to explain our NeurIPS paper [4] (going by the same title) in simple terms. This blog does not cover all the aspects of our paper. In particular, I leave out the experimental results and, of course, the equations to keep the content focused on our methodology.

# Motivation

Let us assume an image classification model that has been trained on millions of images of cats and dogs, and can discriminate well between these. Let us also consider the scenario where this classifier has been deployed in the real world. Here, the model is continuously fed with streaming images of cats and dogs. Further, with time, we assume that our streaming data contains new images of koalas. Now, to extend our cat vs. dog classifier to the new class, we must train it on the images of koalas alongside cats and dogs. However, the problem we face is that we no longer have access to those millions of cat and dog images that we originally trained our classifier on (due to privacy limits on pet images). Voilà! We have stumbled upon the need for the continual learning setting.

**Continual learning (CL)** aims to train deep neural networks (like our image classifier) efficiently on streaming data (koalas) while retaining the previously acquired knowledge (cats vs. dogs). To achieve this, CL agents aim to alleviate the **catastrophic forgetting** issue under restricted computational and memory budgets (imagine retaining only a small subset of the original cats vs. dogs dataset).

Now, let us assume an even more realistic scenario where we know that our streaming data is bound to contain new (koala) classes over time. Here, we desire that upon observing a koala image, our cat vs. dog classifier reliably signals this in its predictions. One way to achieve this is by looking at the model’s output probabilities. Namely, if our classifier is confident about an input image x containing a cat or a dog, then it should assign a high probability (p) to its predicted class (e.g., p(cat | x) = 95%, p(dog | x) = 5%). Otherwise, the assigned probability distribution over classes should have a high entropy (e.g., p(cat | x) = 45%, p(dog | x) = 55%).
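To make the entropy intuition concrete, here is a minimal numpy sketch (not code from our paper) that scores a predicted class distribution: the confident prediction gets a low entropy, the near-uniform one a high entropy.

```python
import numpy as np

def predictive_entropy(probs):
    """Shannon entropy (in nats) of a predicted class distribution.

    Low entropy -> confident prediction; high entropy -> uncertain prediction.
    """
    probs = np.asarray(probs, dtype=float)
    return float(-np.sum(probs * np.log(probs + 1e-12)))  # eps avoids log(0)

confident = predictive_entropy([0.95, 0.05])  # p(cat|x)=95%, p(dog|x)=5%
uncertain = predictive_entropy([0.45, 0.55])  # near-uniform distribution
assert confident < uncertain
```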

Despite sounding trivial at first, the aforesaid **uncertainty-awareness** property has been the *Achilles heel* of most CL methods based on deterministic networks. One reason for this is that deterministic models usually ignore the generative nature of the inputs, i.e., they do not model the uncertainty inherent to the process that generated the input.

Put together, we have arrived at the key motivation for our work: designing continual learning methods that are uncertainty-aware, and can thus be deployed reliably in (safety-critical) real-world scenarios.

# What are Neural Processes?

To achieve our aforesaid goal, we turn to probabilistic modeling of CL. A popular way to go probabilistic is the Bayesian inference paradigm, where we capture our existing belief about a real-world phenomenon using a prior distribution. Upon observing more data, we update the prior (our belief) to a posterior distribution. The CL setting, as such, is a natural fit for a Bayesian inference framework. That is, the posterior we obtain after learning on a task's data becomes the prior for the newly arriving task [2]. We will revisit this elegant property when designing NPs for CL next.

Neural Processes (NPs) are one such class of probabilistic *Bayesian* meta-learners that rely on data-driven priors. Too much jargon already? Well, let us break it down:

**#1** NPs are Bayesian because they use prior distributions derived from partially observed data, a.k.a., the context points. Upon observing data points in addition to the context set, a.k.a., the target points, the priors get updated into posteriors.

**#2** NPs are meta-learners because their Bayesian inference framework meta-learns to predict/complete the outputs for a set of unlabelled target data points based on the labelled context data points. For example, NPs can learn to in-paint all the pixels of an image upon having seen only a subset of these.

**#3** Finally, NPs can be probabilistic by sampling latent variables from a distribution conditioned on the set-based inputs (see the figure below). Such a latent variable captures a summary of the input data points. For more details on NPs, we refer the reader to [1].

## How to adapt Neural Processes for Continual Learning?

A hurdle in applying existing NPs to CL is that they are based on a single latent variable capturing a global summary of the inputs. Since our inputs in CL can span multiple tasks, relying on a global summary alone is sub-optimal. To adapt NPs for CL, we propose preserving this global latent variable and adding another layer of task-specific latent variables conditioned on it. This turns our framework into a hierarchical latent variable model.

The first layer of our hierarchical framework consists of the global latent variable, which facilitates cross-task knowledge transfer among CL tasks. This helps achieve the forward and backward transfer of knowledge. The second layer comprises the task-specific latent variables that maintain fine-grained knowledge per task for superior task-specific performance. However, instead of deriving the task-specific latent variables from the task-specific data points alone, we additionally condition them on the global latent variable. This conditioning endows our task-specific distributions with cross-task knowledge.
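The two-level sampling scheme can be sketched as follows. This is a toy numpy illustration under stated assumptions: the mean-pooling "encoders", the diagonal-Gaussian parameterization, and the way the task summary is mixed with the global latent are all hypothetical stand-ins for the learned networks in NPCL.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4  # toy feature / latent dimensionality

def sample_latent(mu, log_sigma):
    """Reparameterized sample from a diagonal Gaussian."""
    return mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)

def encode_global(all_tasks):
    """Global latent distribution: summarize data pooled over all tasks."""
    summary = np.mean(np.concatenate(all_tasks, axis=0), axis=0)
    return summary, np.full(D, -1.0)  # (mu, log_sigma), toy parameterization

def encode_task(task_data, z_global):
    """Task-specific latent distribution: summarize one task's data,
    conditioned on a sample of the global latent (toy conditioning)."""
    summary = np.mean(task_data, axis=0)
    mu = 0.5 * (summary + z_global)
    return mu, np.full(D, -1.0)

tasks = [rng.standard_normal((10, D)) for _ in range(3)]  # 3 toy tasks
z_g = sample_latent(*encode_global(tasks))                # layer 1: global
task_latents = [sample_latent(*encode_task(t, z_g)) for t in tasks]  # layer 2
assert len(task_latents) == 3 and task_latents[0].shape == (D,)
```

The key structural point is that every task-specific latent receives the same global sample `z_g`, which is how cross-task knowledge reaches the per-task distributions.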

Now, let us return to leveraging the Bayesian framework of NPs for CL. To do so, we propose incorporating **global** and **task-specific** regularization terms into the learning objective of the NPCL. The global regularization term uses the global distribution from the past task as the prior, and encourages the current global distribution (the posterior upon seeing new task data) to match this prior. Similarly, the task-specific regularization terms use the past task-specific head distributions as priors; each acts only on the data of its past task, whose posterior it regularizes.
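Assuming the latent distributions are diagonal Gaussians (as is common for NPs), such a regularization term can be written as a KL divergence between the current posterior and the stored prior. The sketch below is generic and illustrative; the exact loss in our paper may be weighted and combined differently.

```python
import numpy as np

def kl_diag_gaussians(mu_q, log_sig_q, mu_p, log_sig_p):
    """KL(q || p) between two diagonal Gaussians, summed over dimensions."""
    var_q, var_p = np.exp(2 * log_sig_q), np.exp(2 * log_sig_p)
    return float(np.sum(
        log_sig_p - log_sig_q + (var_q + (mu_q - mu_p) ** 2) / (2 * var_p) - 0.5
    ))

# Global term: keep the current global posterior close to the one learned
# on the previous task (the prior).
mu_prev, ls_prev = np.zeros(4), np.zeros(4)       # prior (after past task)
mu_curr, ls_curr = 0.1 * np.ones(4), np.zeros(4)  # posterior (current task)
global_reg = kl_diag_gaussians(mu_curr, ls_curr, mu_prev, ls_prev)
assert global_reg >= 0.0  # KL is non-negative; zero iff the two match
```

The task-specific terms would apply the same KL penalty per head, computed only on replayed data from the corresponding past task.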

The resulting framework is shown below. We end up with one global encoder and several task-specific encoders, where each task-specific encoder is instantiated and trained upon the arrival of its task.

## Task-head inference in NPCL

Task-head inference is a fundamental problem in class-incremental learning, a challenging variant of CL where, at test time, we have no task IDs for a given input. As a result, our model must learn to infer the correct task component for a given input point. Now, since NPCL maintains task-specific heads, each of these is capable of producing a unique set of predictions. At training time, we can use the task IDs of inputs to select the right task head. However, at test time, we do not have such task IDs and are therefore left with the problem of inferring predictions from the correct task head.

Fortunately, uncertainty offers an elegant solution to this fundamental CL issue. As shown in the figure above, at test-time decoding, we simply choose the task head that produces output probabilities with the lowest entropy. The resulting model gives us accuracy scores that are on par with several state-of-the-art deterministic models (see Table 1 in the main paper).
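The selection rule itself is a one-liner. Here is a toy Python sketch with three hypothetical task heads; the head probabilities are made up for illustration.

```python
import numpy as np

def entropy(p):
    """Shannon entropy (nats) of a class distribution."""
    return float(-np.sum(p * np.log(p + 1e-12)))

def infer_task_head(per_head_probs):
    """Pick the task head whose predictive distribution has the least entropy."""
    return int(np.argmin([entropy(p) for p in per_head_probs]))

# Hypothetical outputs of three task heads for one test input:
heads = [
    np.array([0.34, 0.33, 0.33]),  # head 0: near-uniform (unsure)
    np.array([0.90, 0.05, 0.05]),  # head 1: confident
    np.array([0.50, 0.30, 0.20]),  # head 2: moderately unsure
]
assert infer_task_head(heads) == 1  # the confident head is selected
```

Intuitively, the head that was trained on the input's true task tends to be the most confident, so minimizing entropy recovers the task ID without supervision.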

## Anything else that “uncertainty” could help us with?

Besides task-head inference, we study the scope of uncertainty in CL for a range of tasks. Below is a brief summary of each:

- Low model calibration errors: Compared to the SOTA deterministic CL methods, the predicted output probabilities of NPCL are more aligned with the actual probabilities of the ground-truth label distribution (Table 2, main paper).
- Few-shot replay settings: NPCL retains better accuracy in CL settings with extremely small rehearsal memory sizes, e.g., a total replay buffer size of 5 on CIFAR-100 (Table 12, Appendix).
- Novel data identification: The variance scores of NPCL’s output predictions differ markedly between in-domain and out-of-distribution inputs, thus offering a key to identifying novel data (Table 4, main paper). Note that due to their deterministic nature, the state-of-the-art CL methods cannot produce such output variances.
- Instance-level model confidence evaluation: Finally, we show that the confidence evaluation framework of Han et al. [3] is applicable to NPCL too. In a nutshell, this means that the variances of NPCL’s predicted class labels are smaller when its predictions are correct (Table 5, main paper). It is worth noting that, as a precursor test to this framework, we first observe that the difference of the top-two predicted probability scores of NPCL is normally distributed (see the GIF below).

## Conclusion

Our work proposes Neural Processes for Continual Learning (NPCL), a hierarchical latent variable setup designed to jointly model the task-agnostic and task-specific data-generating functions in continual learning. We study the potential forgetting aspects of NPCL and propose to regularize the previously learned distributions at a global and a per-task granularity. We further demonstrate that using entropy as an uncertainty quantification metric helps NPCL infer the correct task heads and boosts the performance of baseline experience replay to even surpass state-of-the-art deterministic models on several CL settings. Finally, we show out-of-the-box applications of the uncertainty estimation capabilities of NPCL for a range of downstream tasks.

## References:

[1] Jha, S., Gong, D., Wang, X., Turner, R.E., & Yao, L. (2022). The Neural Process Family: Survey, Applications and Perspectives. *ArXiv, abs/2209.00517*.

[2] Nguyen, C.V., Li, Y., Bui, T.D., & Turner, R.E. (2017). Variational Continual Learning. *ArXiv, abs/1710.10628*.

[3] Han, X., Zheng, H., & Zhou, M. (2022). CARD: Classification and Regression Diffusion Models. *ArXiv, abs/2206.07275*.

[4] Jha, S., Gong, D., Zhao, H., & Yao, L. (2023). NPCL: Neural Processes for Uncertainty-Aware Continual Learning. *ArXiv, abs/2310.19272.*