Top-down Domain Inference for Lifelong Learning

Romain Mouret
Published in ContinualAI
Aug 15, 2021

Content of this article

  • Making the case for prior identification of domains/contexts in lifelong learning systems.
  • Showing that detecting discrepancies between expected macro state and actual macro state allows lifelong learning systems to infer domains with the promise of achieving arbitrarily high accuracy with very large models.
  • Avoiding the need to detect domain boundaries during training.
  • Advocating for lifelong learning systems that exhibit the property “the larger the model, the less it forgets.”

Catastrophic forgetting in a nutshell

Machine learning models excel at fitting data: if you feed them data from a certain distribution D[t], they will do well on data sampled from D[t], but they cannot be expected to perform well on any other distribution, including the distributions D[0], …, D[t-1] that the model has already seen in previous cycles.

In short, as D[t] comes in, D[0], …, D[t-1] risk being forgotten. This phenomenon is called catastrophic forgetting.

There is a parallel here with speciation: woolly mammoths are not well adapted to the environment of their ancestors; the new species has “forgotten” how to survive in warmer climates.

The two most common groups of methods for overcoming catastrophic forgetting are (1) replay and (2) regularization-based methods. Replay and generative replay circumvent the problem by arranging the data into a single distribution D[0…t], often at a great cost in terms of memory and compute, while regularization-based methods (e.g. EWC) discourage the model from changing too much when it is fitting a new distribution D[t].

The method I will present belongs to a less popular category: (3) domain inference combined with a domain-aware or task-aware architecture.

Domain inference

For now, we will be working within a domain incremental learning setting, a.k.a. domain IL. Think, for example, of a system that categorizes cars under different lighting and weather conditions, observing the domains (sunny day, night, overcast sky) one at a time.

In a domain IL setting, the set of classes is fixed. This contrasts with class incremental scenarios in which new classes are observed on the fly.

Accurate domain inference completely eliminates catastrophic forgetting in domain IL scenarios.

As you can instantiate a separate model for each domain without any sort of interference between them, you can readily build a robust continual learning system. Such a system would make predictions in two steps:

1) Infer the domain of the current input.

2) Route the input to the adequate domain-specialized model.

The second step implements the simplest form of domain awareness. In the more general case, domain-specialized models might share some parameters and might even be completely meshed together. In such a case, it may not be possible to completely eliminate catastrophic forgetting. Nevertheless, catastrophic forgetting can be greatly mitigated by conditioning the main model on the domain or the task identifier, as demonstrated here and there.
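As an illustration, here is a minimal sketch of this two-step prediction, with hypothetical names of my own (`domain_score` and `domain_model` are placeholders, not from the article):

```python
# Minimal sketch of the two-step prediction loop (illustrative names).
# `domain_score[d]` is assumed to return how well domain d explains input x
# (higher = better); `domain_model[d]` is the model specialized for domain d.
# TDDI's actual scoring rule, based on macro-state discrepancies, is
# described in the next section.

def predict(x, domain_score, domain_model):
    # Step 1: infer the domain of the current input.
    best_domain = max(domain_score, key=lambda d: domain_score[d](x))
    # Step 2: route the input to the adequate domain-specialized model.
    return domain_model[best_domain](x)
```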

Top-down Architecture

The proposed architecture, dubbed TDDI for Top-Down Domain Inference, is composed of three modules:

  • Feature Model: in its simplest form, it’s an ensemble of N processors predicting the current macro state from the sensory inputs, where N is the number of domains.
  • Center: a feed-forward neural net predicting the next macro state from the current macro state.
  • ActionModel: a neural net mapping a sequence of macro states to an output sequence.
[Figure: the TDDI architecture; “Center” is depicted in green.]

The macro states predicted by Center act like priors on the sensory inputs. They dictate how the observations must be interpreted.

When the fully trained system is deployed, we pick the domain hypothesis that minimizes the discrepancy between Center’s macro state and the actual macro state computed via the sensory input processors.
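As an illustration, here is a minimal PyTorch-style sketch of this inference step, under simplifying assumptions of my own (dimensions, module shapes, and variable names are illustrative; ActionModel, which maps macro states to outputs, is omitted):

```python
import torch
import torch.nn as nn

MACRO_DIM = 48   # width of the macro-state vectors (the "bottleneck")
N_DOMAINS = 7    # one sensory processor per domain
INPUT_DIM = 100  # dimensionality of the raw sensory input (illustrative)

# Feature Model: an ensemble of N processors, one per domain, each mapping
# sensory inputs to the current macro state.
processors = nn.ModuleList(
    [nn.Linear(INPUT_DIM, MACRO_DIM) for _ in range(N_DOMAINS)]
)

# Center: a feed-forward net predicting the next macro state from the current one.
center = nn.Sequential(
    nn.Linear(MACRO_DIM, MACRO_DIM), nn.ReLU(), nn.Linear(MACRO_DIM, MACRO_DIM)
)

def infer_domain(prev_macro_state: torch.Tensor, x: torch.Tensor) -> int:
    """Pick the domain hypothesis whose processor yields the macro state
    closest to the one expected (predicted) by Center."""
    expected = center(prev_macro_state)  # top-down prior on the observation
    discrepancies = [torch.norm(expected - proc(x)) for proc in processors]
    return int(torch.argmin(torch.stack(discrepancies)))
```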

Macro states are represented by high-dimensional vectors. As we widen them, the system becomes increasingly precise at detecting discrepancies: accidental false positives, i.e. wrong domain hypotheses that happen to match the expected macro state, become rarer. This is why the system exhibits the desired property “the larger the system, the less it forgets”.
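A back-of-the-envelope way to see this, under a simplifying assumption of my own (not from the article): suppose the per-dimension discrepancies produced by a wrong domain hypothesis are roughly independent, bounded in [0, 1], and have mean μ, while the correct hypothesis stays below a threshold τ < μ. A Hoeffding-style bound then gives

\Pr\left[\Delta_d \le \tau\right] \;\le\; \exp\!\left(-2\, d\, (\mu - \tau)^2\right)

where Δ_d is the mean per-dimension discrepancy over a macro state of width d, so the probability of an accidental false positive decays exponentially as the macro state widens.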

Experiment

The architecture was evaluated on a dataset of cooking recipes over 7 domains. The words of the recipes are randomly mapped to integers with a different random seed for each domain. This is essentially the NLP equivalent of Permuted MNIST.
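For concreteness, here is a minimal sketch of how such permuted-vocabulary domains could be constructed (my own illustration of the setup; the vocabulary size and function names are assumptions, not the code used for the experiment):

```python
import random

VOCAB_SIZE = 10_000  # illustrative vocabulary size
N_DOMAINS = 7

def make_domain_mapping(seed: int, vocab_size: int = VOCAB_SIZE) -> dict:
    """Map each word id to a random integer, with a different seed per domain
    (the NLP analogue of permuting pixels in Permuted MNIST)."""
    rng = random.Random(seed)
    targets = list(range(vocab_size))
    rng.shuffle(targets)
    return dict(enumerate(targets))

domain_mappings = [make_domain_mapping(seed=d) for d in range(N_DOMAINS)]

def encode(recipe_word_ids: list, domain: int) -> list:
    """Re-encode a recipe's word ids for the given domain."""
    mapping = domain_mappings[domain]
    return [mapping[w] for w in recipe_word_ids]
```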

As shown in the figure, the domain identification error rate drops quickly as we increase the size of the macro states, i.e. the size of the bottleneck layer. From size 16 to size 32, it is divided by 3.3. From size 32 to size 48, it is further divided by 2.9. At size 48, the average domain identification accuracy is 99.23%, nearly eliminating catastrophic forgetting for non-overlapping neural networks on 7 domains or fewer.

Domain boundaries during training

In general, we cannot expect domain boundaries to be known at training. Take for instance an autonomous robot roaming in the street. Passersby might point at cars and educate the robot about car brands. Ideally, the robot should be given a clear description of the current context too, such as the weather conditions, but passersby aren’t likely to make the extra effort. With the coming of the first snow, if the robot cannot strike a good balance between plasticity and stability, it might not remember how to distinguish cars in snowy conditions, even though it saw snowy streets the year before.

Unlike some other continual learning algorithms, TDDI is not naturally equipped to cope with the lack of domain boundaries during the training phase. While a discrepancy threshold can be tuned to detect new domains, TDDI cannot easily guess that two observations belong to the same new domain, a prerequisite to instantiating one processor for each new domain.

To solve this problem, it is tempting to keep all the recent observations in a buffer and run a clustering algorithm once the buffer is large enough, but this raises new challenges, such as choosing a distance function between observations. The distance module would have to be continually learned as well.

[Figure: infinite regress, a common problem when building continual learning systems by assembling modules; in “CF-immune”, CF stands for Catastrophic Forgetting.]

For TDDI-like architectures, I suggest that we simply give up on detecting domain boundaries. Instead of creating a new processor for each domain, we would do so for each new “context”.

Contexts are created by lumping together any new observations that occur within the same timeframe, e.g. 24 hours, regardless of whether they belong to the same domain. One could also skip any observation that is already correctly processed.
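A rough sketch of that policy, where the time window, the discrepancy threshold, and all names are assumptions of my own rather than values from the article:

```python
import time

CONTEXT_WINDOW = 24 * 3600   # lump new observations arriving within 24 hours
DISCREPANCY_THRESHOLD = 1.0  # above this, the observation counts as "new"

contexts = []  # each context eventually gets its own sensory processor

def assign_context(observation, discrepancy: float, now: float = None):
    """Skip observations already handled well; otherwise attach them to the
    current time-window context, opening a fresh one when the window expires."""
    if discrepancy < DISCREPANCY_THRESHOLD:
        return None  # already correctly processed
    now = time.time() if now is None else now
    if not contexts or now - contexts[-1]["start"] > CONTEXT_WINDOW:
        contexts.append({"start": now, "observations": []})  # new context
    contexts[-1]["observations"].append(observation)
    return len(contexts) - 1  # index of the context the observation joined
```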

Domains carry with them a lot of abstract baggage whereas contexts are free of any high-level concepts, thereby steering clear of the homunculus fallacy.

[Figure: training setup; domain 0 is used to train ActionModel and Center.]

This setup was tested with 105 contexts spanning 7 domains. The results show trends in line with those of the first experiment, albeit less pronounced. Refer to the GitHub page to reproduce this experiment.

Regularization and other approaches

Regularization-based methods have their own merits in a variety of continual learning settings, but they are arguably ill-equipped for large time scales. Regularization only delays the inevitable: sooner or later, the model will have to erase useful knowledge to make room for new knowledge. The model has to sacrifice stability for plasticity, or vice versa.

This leads me to my main point: if we want to reach a stability-plasticity tradeoff suitable for lifelong learning, we need to create a virtuous relationship between the tradeoff and something that machine learning engineers have control over.

To that end, I suggest that we work on making the tradeoff vary with the size of the models. Tackling complex tasks with large models is also an objective of ours, so I believe this is a sound long-term plan for the algorithms that are held back by the stability-plasticity dilemma. After all, this worked out well for deep learning: Scaling up models unlocked performance that could not be achieved with shallow models alone.

I proposed an architecture that satisfies the desired criteria, but I should point out that other approaches may also fit the bill. For instance, it has been theorized that sparsity can help with mitigating forgetting, and, conveniently, sparsity also provides greater benefits with larger models, since large models generally have proportionally more room to spare.

For domain inference, one-class domain classification could be in contention, but it faces an uphill battle against irrelevant attributes and other problems plaguing anomaly detection on high-dimensional data.
