Continual Learning with Node-wise Importance Regularization

Published in SNU AIIS Blog, Apr 1, 2022

By Seyeon An

Continual learning (lifelong learning), a long-standing open problem in machine learning, refers to the ability of a model to learn continually from a stream of data, accommodating new knowledge while retaining previously learned experience.

The goal is to learn a model on a large number of tasks sequentially without forgetting the knowledge obtained from preceding tasks, even though the data from old tasks are no longer available when training on new ones.

Continual Learning and the Plasticity-Stability Dilemma

The main challenge of continual learning is the plasticity-stability dilemma: if the model focuses too much on stability, it suffers from poor forward transfer to new tasks, and if it focuses too much on plasticity, it suffers from catastrophic forgetting of past tasks.

  • Forward Transfer : Forward transfer is the influence that learning a task has on the performance on a future task. Parameters should be “plastic” to learn new tasks/concepts, and thus excessive stability in continual learning may make neural networks suffer from poor forward transfer.
  • Catastrophic Forgetting : Neural networks suffer from catastrophic forgetting, in which learning each new task causes the network to forget what it learned for previous tasks. This makes neural networks hard to adapt to continual learning, since the data of one stage cannot be used in the next stage.

Regularization-based Continual Learning

To address this dilemma, research on neural network-based continual learning has broadly followed three categories: regularization-based, dynamic architecture-based, and replay memory-based methods, as illustrated in the images below.

Common approaches for task-incremental learning: Regularization, Dynamic Architecture, Memory Replay [From Left to Right]
  • Regularization-based Methods : Retrain the whole network x(t) while regularizing it to prevent catastrophic forgetting of the previously learned tasks x(t-1)
  • Dynamic Architecture-based Methods : Selectively train the network x(t) and expand it if necessary to represent new tasks
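  • Memory Replay-based Methods : Store a subset of data from past tasks (or a generative model of it) and replay it while learning new tasks; in reinforcement learning, the stored data are the [state, action, reward, next_state] transitions the agent observes, which are later fed into action-value calculations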

We focus on regularization-based methods, since they aim to use a fixed-capacity neural network as efficiently as possible, which may also allow them to be combined with other approaches.

These methods typically identify important learned weights for previous tasks and heavily penalize their deviations while learning new tasks.

A typical loss function for learning task t adds such a regularization term to the task-specific loss.

Intuitively, the loss is the task-specific loss plus, summed over all parameters, the product of an adaptive importance weight and the (typically squared) difference between the parameter's current value and its value learned up to the previous task.
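Written out, this typical form reads as follows. It is our reconstruction rather than a copy of the original figure, assuming the notation of the AGS-CL paper: L_TS,t is the task-specific loss, Ω_i^{t-1} the importance of parameter i after task t-1, θ̂_i^{t-1} its value learned up to task t-1, and λ a global hyperparameter. The squared difference is the common choice in methods such as EWC and MAS.

```latex
\mathcal{L}_t(\theta) = \mathcal{L}_{\text{TS},t}(\theta)
  + \lambda \sum_{i} \Omega_i^{t-1} \left( \theta_i - \hat{\theta}_i^{t-1} \right)^2
```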

Since training minimizes this loss, a larger gap between a weight's value learned up to the previous task and its value for the current task incurs a larger penalty, especially for weights deemed important.
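A minimal sketch of this penalty and of the trade-off it controls, on a toy one-parameter problem made up for illustration: task A has loss (w - 2)² and task B has loss (w + 1)². The function names (regularized_loss, train_task_b) and all numbers are hypothetical, not from the paper; reg_strength plays the role of λΩ for the single parameter.

```python
def regularized_loss(task_loss, theta, theta_prev, omega, lam):
    """Task loss plus lam * sum_i omega_i * (theta_i - theta_prev_i)**2."""
    penalty = sum(o * (t - p) ** 2 for t, p, o in zip(theta, theta_prev, omega))
    return task_loss + lam * penalty


def train_task_b(w_init, w_prev, reg_strength, lr=0.1, steps=500):
    """Gradient descent on the toy task-B loss (w + 1)**2 plus the penalty
    reg_strength * (w - w_prev)**2, which anchors w to the task-A solution."""
    w = w_init
    for _ in range(steps):
        grad = 2.0 * (w + 1.0) + 2.0 * reg_strength * (w - w_prev)
        w -= lr * grad
    return w


# Start from the task-A optimum w = 2 and learn task B.
w_free = train_task_b(2.0, 2.0, reg_strength=0.0)      # no regularization
w_anchored = train_task_b(2.0, 2.0, reg_strength=1.0)  # with regularization
```

Without the penalty, w converges to about -1, where task A's loss is 9 (task A is forgotten). With reg_strength = 1.0 it settles near 0.5, where both task losses equal 2.25: the regularizer buys stability at the cost of plasticity, which is the dilemma above in miniature.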

Such methods, however, suffer from a large memory cost, since they must store a regularization parameter (an importance value) for every weight of the network.

Adaptive Group Sparsity-based Continual Learning (AGS-CL)

As a solution to the two problems above, (i) the plasticity-stability dilemma and (ii) the large memory cost, we consider node-based regularization.

Instead of measuring importance and deviation for each weight, we measure them for each node.

Weight : Learnable parameter of a machine learning model that controls the signal (the strength of the connection) between two neurons.
Node : Computational unit that has one or more weighted input connections, a transfer function that combines the inputs in some way, and an output connection.

Such focus on the node-level importance could lead to a more efficient representation of the model and achieve better compression than focusing on the weight-wise importance.

Model drift refers to a phenomenon in which the neural network forgets past training data
Negative transfer refers to the phenomenon in which unimportant data is also learned

(i) Regarding the plasticity-stability dilemma, node-based regularization freezes the incoming weights of important nodes, which prevents model drift (the phenomenon in which the neural network forgets what it learned on past tasks), while leaving the remaining nodes free to learn, which avoids the poor forward transfer caused by excessive stability.

It addresses catastrophic forgetting by pruning the outgoing weights of unimportant nodes, which prevents negative transfer (the phenomenon in which unimportant information is also learned and interferes with past tasks, causing catastrophic forgetting).

(ii) Since the regularization strength (importance) is stored for each node rather than for each parameter, the memory cost is greatly reduced.
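A back-of-the-envelope sketch of the memory saving, assuming a plain fully connected network with biases, one stored importance value per weight and bias in the weight-wise scheme, and one per hidden or output node in the node-wise scheme. The function name and the example layer sizes are hypothetical.

```python
def count_reg_params(layer_sizes):
    """Return (weight-wise, node-wise) counts of stored importance values
    for a fully connected network with biases.

    layer_sizes: [input_dim, hidden_1, ..., output_dim]
    """
    pairs = list(zip(layer_sizes[:-1], layer_sizes[1:]))
    # Weight-wise: one importance value per weight and per bias.
    per_weight = sum(n_in * n_out + n_out for n_in, n_out in pairs)
    # Node-wise: one importance value per non-input node.
    per_node = sum(n_out for _, n_out in pairs)
    return per_weight, per_node


per_weight, per_node = count_reg_params([784, 100, 100, 10])
```

For layer sizes [784, 100, 100, 10], this returns (89610, 210): the weight-wise scheme stores 89,610 importance values, the node-wise scheme only 210.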

Motivation for AGS-CL

The concept of regularization-based continual learning schemes naturally connects with a separate line of research: the model compression of neural networks.

In order to obtain a compact model, typical model compression methods measure the importance of each node or weight in a given neural network and prune the unimportant ones, a principle similar to that of regularization-based continual learning.

Several representative model compression methods use group Lasso-like penalties, which treat the incoming (or outgoing) weights of each node as a group and thereby achieve structured sparsity within a neural network.
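For concreteness, a common form of such a penalty, in notation we assume for this post: θ_{n_l} is the vector of incoming weights of node n_l, G_l the set of nodes in layer l, and L the number of layers. Some variants additionally scale each term by the square root of the group size.

```latex
\mathcal{R}_{\text{GL}}(\theta) = \sum_{l=1}^{L} \sum_{n_l \in \mathcal{G}_l} \left\| \theta_{n_l} \right\|_2
```

Because each group's ℓ2 norm is not squared, the penalty can drive an entire group to exactly zero, which removes the corresponding node.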

This emphasis on the node-level rather than the weight-level allows a more efficient representation of the model and helps achieve better compression.

Motivated by this, we adopt group-sparsity norms for continual learning, in which the group Lasso term promotes group sparsity.

AGS-CL Algorithm / Loss Function

To recap, AGS-CL is based on node-level regularization using group-sparsity norms.

Now, let’s take a look at the loss function of AGS-CL.

The loss function consists of three main parts:

  • Task-Specific loss
  • (a) Unimportant Nodes : Use as few nodes as possible
  • (b) Important Nodes : Adaptively freeze important nodes
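For reference, here is the AGS-CL loss as we reconstruct it from the paper's notation. μ and λ are hyperparameters, G_0^{t-1} denotes the nodes whose importance after task t-1 is zero (the unimportant nodes), and the other symbols are as in the typical regularization loss. Treat the exact form as our assumption and check it against the paper.

```latex
\mathcal{L}_t(\theta) = \mathcal{L}_{\text{TS},t}(\theta)
  + \underbrace{\mu \sum_{l=1}^{L} \sum_{n_l \in \mathcal{G}_0^{t-1}} \left\| \theta_{n_l} \right\|_2}_{\text{(a) unimportant nodes}}
  + \underbrace{\lambda \sum_{l=1}^{L} \sum_{n_l \in \mathcal{G}_l \setminus \mathcal{G}_0^{t-1}} \Omega_{n_l}^{t-1} \left\| \theta_{n_l} - \hat{\theta}_{n_l}^{t-1} \right\|_2}_{\text{(b) important nodes}}
```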

The red part of the function is dedicated to the exact pruning of unimportant nodes: it makes the network prune (i.e., leave out) unimportant nodes by keeping training as sparse as possible.

The green part of the function selects the important nodes, i.e., the nodes to which the freezing term applies.

The blue part represents the importance Ω of each node. The larger Ω is, the more strongly the network freezes (i.e., holds onto) what the node has learned; the smaller Ω is, the more freely the node can change.

The purple part represents the node weights learned up to the previous task (t-1), and the green part on the right, the group norm of the difference between the current and previous weights, enables more exact freezing as Ω grows.
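A minimal sketch of the two regularization terms, assuming that each node's group is its vector of incoming weights and that a node counts as unimportant exactly when its importance Ω is zero (our understanding of the paper). mu and lam stand for the two penalty strengths; the function names and data layout are hypothetical.

```python
import math


def group_norm(v):
    """Euclidean norm of one node's incoming-weight vector."""
    return math.sqrt(sum(x * x for x in v))


def ags_cl_penalty(theta, theta_prev, omega, mu, lam):
    """Regularization part of the loss (task loss excluded).

    theta[n]     : current incoming-weight vector of node n
    theta_prev[n]: the same vector as learned up to task t-1
    omega[n]     : importance of node n after task t-1 (0 => unimportant)
    """
    prune_term = 0.0
    freeze_term = 0.0
    for w, w_prev, o in zip(theta, theta_prev, omega):
        if o == 0.0:
            # (a) unimportant node: group Lasso pushes the whole vector to zero
            prune_term += group_norm(w)
        else:
            # (b) important node: Omega-weighted group norm of the deviation
            freeze_term += o * group_norm([a - b for a, b in zip(w, w_prev)])
    return mu * prune_term + lam * freeze_term
```

An important node that has not moved contributes nothing, while any use of an unimportant node is charged in proportion to the norm of its weights.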

Proximal Gradient Descent for Learning

The above regularization promotes sparsity, but it makes the loss function non-differentiable. We get around this problem with proximal gradient descent.

The proximal gradient descent (PGD) method can handle objectives that contain non-differentiable terms such as group norms. When used with group norms, it allows exact freezing of important nodes and exact pruning of unimportant ones.

In outline, the PGD method works as follows:

  1. Split the objective f into two parts: g (the differentiable part) and h (the non-differentiable part).
  2. Take a gradient descent step on g.
  3. Apply the proximal operator of h, which reduces h while staying close to the point obtained in (2).
  4. Repeat (2) and (3) until convergence.

Because the proximal operator can be applied to each node's parameter vector separately (and has a closed form for group norms), PGD needs none of the additional thresholds or heuristics that vanilla SGD variants would require to obtain exact pruning or exact freezing.
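A minimal sketch of one PGD step for this kind of objective, in plain Python. It assumes that, for each node, the non-differentiable part is either μ||θ||₂ (unimportant node, Ω = 0) or λΩ||θ - θ̂||₂ (important node). Both have a closed-form proximal operator: block soft-thresholding, shifted to the previous weights in the second case. The function names and data layout are our own, not the authors' code.

```python
import math


def l2_norm(v):
    """Euclidean (group) norm of one node's weight vector."""
    return math.sqrt(sum(x * x for x in v))


def prox_group_lasso(v, tau):
    """Proximal operator of tau * ||v||_2 (block soft-thresholding).

    Shrinks the whole vector toward zero and returns exact zeros when its
    norm is at most tau, i.e. the node is pruned.
    """
    norm = l2_norm(v)
    if norm <= tau:
        return [0.0] * len(v)
    scale = 1.0 - tau / norm
    return [scale * x for x in v]


def prox_freeze(v, v_prev, tau):
    """Proximal operator of tau * ||v - v_prev||_2.

    Same shrinkage, but toward the previous weights: small deviations are
    snapped back exactly, i.e. the node is frozen.
    """
    diff = [a - b for a, b in zip(v, v_prev)]
    shrunk = prox_group_lasso(diff, tau)
    return [b + d for b, d in zip(v_prev, shrunk)]


def pgd_step(theta, grad, lr, mu, lam, omega, theta_prev):
    """One PGD step over per-node incoming-weight vectors.

    1. Gradient step on the differentiable task loss (grad is its gradient).
    2. Node-wise proximal step on the non-differentiable group-norm terms.
    """
    new_theta = []
    for w, g, w_prev, o in zip(theta, grad, theta_prev, omega):
        v = [wi - lr * gi for wi, gi in zip(w, g)]
        if o == 0.0:
            v = prox_group_lasso(v, lr * mu)  # unimportant: prune toward zero
        else:
            v = prox_freeze(v, w_prev, lr * lam * o)  # important: pull back
        new_theta.append(v)
    return new_theta
```

For example, with a learning rate of 1.0 and μ = 0.1, an unimportant node with incoming weights [0.01, 0.01] becomes exactly [0.0, 0.0] after one step, whereas a subgradient step would generally leave small nonzero values.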

AGS-CL: Re-initialization

After learning each task, AGS-CL re-initializes the unimportant nodes. This is done only for the pruned nodes.

The re-initialization process of the AGS-CL

Re-initialization has two steps: zero initialization and random initialization.

  • [Zero - init] : The zero-initialization process prevents negative transfer. It sets the outgoing weights of the unimportant nodes found by the loss function to 0.
  • [Rand - init] : The random-initialization process enables efficient learning. It randomly initializes the incoming weights of the unimportant node.
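A minimal sketch of the two re-initialization steps for one hidden layer, with weights stored as plain nested lists: w_in[j] holds the incoming weights of hidden node j, and w_out[k][j] the weight from hidden node j to node k of the next layer. The function name, the uniform initializer, and its scale are assumptions for illustration; the paper's initializer may differ.

```python
import random


def reinit_unimportant_nodes(w_in, w_out, unimportant, scale=0.1, seed=0):
    """Re-initialize the unimportant (pruned) nodes of one hidden layer in place.

    w_in[j]     : list of incoming weights of hidden node j
    w_out[k][j] : weight from hidden node j to node k of the next layer
    """
    rng = random.Random(seed)
    for j in unimportant:
        # Rand-init: fresh random incoming weights so the node can learn the next task.
        w_in[j] = [rng.uniform(-scale, scale) for _ in w_in[j]]
        # Zero-init: cut the node's outgoing connections so its new activity
        # does not leak into downstream nodes used by past tasks.
        for row in w_out:
            row[j] = 0.0
    return w_in, w_out


w_in = [[0.0, 0.0], [1.0, 1.0]]
w_out = [[5.0, 6.0], [7.0, 8.0]]
reinit_unimportant_nodes(w_in, w_out, unimportant=[0])
```

After the call, column 0 of w_out is zero, node 1 (important) is untouched, and node 0 has small random incoming weights.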

AGS-CL: Learning

How, then, does AGS-CL compute and update Ω, the importance of each node?

Ω is updated using the average ReLU (Rectified Linear Unit) activation of each node, since the average ReLU activation is a good measure of a node's importance.
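Written out, the update reads as follows. This is our reconstruction, assuming the paper's notation: a_{n_l}(x) is the ReLU output of node n_l for input x, x_i^{(t)} the i-th of the N_t training samples of task t, and η a decay hyperparameter. The paper may apply further normalization, so treat this as a sketch rather than the exact update.

```latex
\Omega_{n_l}^{t} = \eta\, \Omega_{n_l}^{t-1} + \frac{1}{N_t} \sum_{i=1}^{N_t} a_{n_l}\!\left( x_i^{(t)} \right)
```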

Area Over Prediction Curves for Ω

The figure above shows the Area Over Prediction Curves (AOPC) of our importance measure Ω, in which nodes are pruned in random order (dotted), in order of highest Ω first (solid), or in order of lowest Ω first (dashed), after learning task 1 (blue line with star) and after learning all tasks (orange line with circle). The clear gaps between the solid line and the dashed/dotted lines corroborate the validity of using the average ReLU activation for Ω.

ReLU is a piecewise linear function, ReLU(x) = max(0, x), that outputs its input directly if it is positive and outputs zero otherwise. It is widely used in many types of deep neural networks because models that use it tend to train easily and perform well.

Here, a(x) is the ReLU activation of a node for an input x, and the sum accumulates these activations over the task's data. The result is added to ηΩ, which carries over the node's importance from past tasks with exponential decay.
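A minimal sketch of this update in plain Python, assuming it takes the form Ω_new = η·Ω_old + (mean ReLU activation of the node over the current task's data). The function names and the pre-activation input layout are hypothetical.

```python
def relu(x):
    """Rectified linear unit: max(0, x)."""
    return max(0.0, x)


def update_importance(omega_prev, pre_activations, eta):
    """Return updated per-node importances.

    omega_prev[n]         : importance of node n after the previous task
    pre_activations[i][n] : pre-activation of node n for the i-th sample of the
                            current task (ReLU is applied here)
    eta                   : decay factor applied to the previous importance
    """
    num_samples = len(pre_activations)
    new_omega = []
    for n, o in enumerate(omega_prev):
        mean_act = sum(relu(row[n]) for row in pre_activations) / num_samples
        new_omega.append(eta * o + mean_act)
    return new_omega


omega = update_importance([1.0, 0.0], [[2.0, -1.0], [0.0, -3.0]], eta=0.5)
```

Here omega becomes [1.5, 0.0]. Node 1 never fires on the task's data, so its importance stays zero and it remains free (unimportant) for the next task.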

Experimental Results

Supervised learning on vision datasets

Average accuracy results on CIFAR-100, CIFAR-10/100, Omniglot, CUB200, and the sequence of 8 datasets

We compare AGS-CL with existing regularization-based methods: EWC, SI, RWALK, MAS, and HAT. We used multi-headed outputs for all experiments, and all results are averaged over 5 runs with different random seeds (which also shuffle the task order, except for Omniglot).

We tested our method on several vision datasets to show its effectiveness:

  • CIFAR-100 : Split 100 classes into 10 tasks with 10 classes per task
  • CIFAR-10/100 : Use CIFAR-10 for pre-training before learning tasks from CIFAR-100 / used as a standard benchmark with a smaller number of tasks
  • CUB200 : Split 200 classes into 10 tasks with 20 classes per task / uses a pre-trained network
  • Omniglot : Treat each alphabet as a single task and use all 50 alphabets / used to compare performance for a large number of tasks
  • Sequence of 8 different datasets : CIFAR-10 / CIFAR-100 / MNIST / SVHN / Fashion-MNIST / Traffic-Signs / FaceScrub / NotMNIST / used to check the learning capability across different visual domains

From these experiments, we draw three main observations:

  1. AGS-CL consistently outperforms the other baselines on all datasets for most tasks. This is especially notable since AGS-CL uses much less memory to store its regularization parameters than the others.
  2. Among the other baselines, there is no clear winner; MAS tends to excel on the first three datasets, while it is the worst on CUB200 and the sequence of 8 vision datasets.
  3. As the Omniglot results show, SI and RWALK, which are based on the path integral of the gradient vector field, have large performance variance when the number of tasks is large.

Reinforcement learning on Atari tasks

Prior Atari experiments with representative regularization-based methods such as EWC let agents learn multiple tasks in a recurring fashion, whereas we evaluate AGS-CL in the pure continual learning setting, in which past tasks cannot be learned again.

The Atari 2600 Games task

The Atari 2600 Games task (and dataset) involves training an agent to achieve high game scores. These tasks were frequently used in representative past works on continual learning, but their settings allowed the agent to learn past tasks again in a recurring fashion.

Difference between the settings on Atari Games task experiment of (Left) previous work and (Right) AGS-CL

In contrast, we consider the pure continual learning setting: past tasks cannot be learned again, and after learning each task, the average rewards are evaluated on all tasks learned so far. We randomly selected eight Atari tasks and compared AGS-CL with three baselines: EWC, MAS, and fine-tuning.

Normalized Accumulated Performance

The figure above shows the normalized accumulated performance, in which each evaluated reward is normalized by the maximum reward obtained by fine-tuning on that task. AGS-CL achieves a much higher accumulated reward at the end of the 8 tasks than EWC, MAS, and fine-tuning: 3x higher than the closest competitor.
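A small sketch of how such a curve can be computed under our reading of the description: after each task, sum the normalized rewards over all tasks learned so far. The function name and data layout are hypothetical, and the paper may aggregate differently (e.g., by averaging instead of summing).

```python
def normalized_accumulated_performance(rewards, finetune_max):
    """Return one accumulated score per learning stage.

    rewards[k][j]   : reward on task j evaluated right after learning task k (j <= k)
    finetune_max[j] : maximum reward obtained by fine-tuning on task j
    """
    scores = []
    for k, row in enumerate(rewards):
        if len(row) != k + 1:
            raise ValueError("stage k must evaluate exactly tasks 0..k")
        scores.append(sum(r / finetune_max[j] for j, r in enumerate(row)))
    return scores


curve = normalized_accumulated_performance([[10.0], [5.0, 40.0]], [10.0, 40.0])
```

With these made-up rewards the curve is [1.0, 1.5]: after the second task, the first task has dropped to half of its fine-tuning maximum while the second task matches it.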

Conclusion

We propose AGS-CL, a new continual learning method based on node-wise importance regularization.

By combining a new loss function based on group-sparsity norms, the PGD optimization technique, and re-initialization tricks, AGS-CL enables exact freezing and pruning of nodes. Our experiments demonstrate that AGS-CL achieves state-of-the-art performance in both supervised learning and reinforcement learning.

Furthermore, since AGS-CL stores only one importance value per node, its efficient memory usage can also make it suitable for memory-limited environments such as mobile devices.

Acknowledgements

We thank Sangwon Jung, Hongjoon Ahn, and the co-authors of the paper “Continual Learning with Node-Importance based Adaptive Group Sparse Regularization” for their contributions and discussions in preparing this blog. The views and opinions expressed in this blog are solely those of the authors.

This post is based on the following paper:

  • Continual Learning with Node-Importance based Adaptive Group Sparse Regularization, Sangwon Jung, Hongjoon Ahn, Sungmin Cha, Taesup Moon, 34th Conference on Neural Information Processing Systems (NeurIPS 2020), arXiv.

This post was originally posted on our Notion blog on July 20, 2021.


AIIS is an intercollegiate institution of Seoul National University, committed to integrating and supporting AI-related research at Seoul National University.