Humans can continuously learn new skills without forgetting old
knowledge, but neural networks often "forget" when learning new tasks —
this is catastrophic forgetting. How can models learn like humans
throughout their lifetime, remembering the first task after mastering
100 tasks? Continual Learning provides the answer.
This article systematically explains the principles and
implementations of four major approaches — regularization, dynamic
architectures, memory replay, and meta-learning — starting from the
mathematical mechanisms of catastrophic forgetting. We analyze parameter
importance estimation, inter-task knowledge transfer, and the
forgetting-stability trade-off, and provide complete code (250+ lines)
for implementing EWC from scratch.
Problem Definition of Continual Learning
Task Sequence
Continual learning processes a sequence of tasks arriving over time:

$\mathcal{T} = (T_1, T_2, \ldots, T_n)$

Each task $T_t$ is defined as a learning problem $T_t = (\mathcal{D}_t, \mathcal{L}_t, f_{\theta_t})$, where:

- $\mathcal{D}_t = \{(x_i, y_i)\}_{i=1}^{N_t}$ is the task data
- $\mathcal{L}_t$ is the task loss function
- $f_{\theta_t}$ is the model for task $t$ (with parameters $\theta_t$)

Key constraint: when learning task $T_t$, previous task data $\mathcal{D}_1, \ldots, \mathcal{D}_{t-1}$ cannot be accessed.
Catastrophic Forgetting
Phenomenon: after training on task $T_t$, model performance on previous tasks $T_1, \ldots, T_{t-1}$ drops dramatically.

Formalized as:

$F_t = \frac{1}{t-1} \sum_{i=1}^{t-1} \left( a_{i,i} - a_{t,i} \right)$

where:

- $a_{i,i}$: accuracy on task $i$ immediately after learning task $i$
- $a_{t,i}$: accuracy on task $i$ after learning task $t$

Goal: minimize forgetting while maintaining learning capability for new tasks.
Three Scenarios of Continual Learning
Task-Incremental Learning (Task-IL):

- Task ID known at inference
- Each task has an independent output head
- Example: MNIST → FashionMNIST → CIFAR10

Domain-Incremental Learning (Domain-IL):

- Same task, different data distributions
- Example: Sunny images → Rainy images → Night images

Class-Incremental Learning (Class-IL):

- Task ID unknown at inference
- Must predict from all learned classes
- Most challenging scenario
Evaluation Metrics
Average Accuracy:

$ACC = \frac{1}{T} \sum_{i=1}^{T} a_{T,i}$

Average Forgetting:

$F = \frac{1}{T-1} \sum_{i=1}^{T-1} \left( \max_{t < T} a_{t,i} - a_{T,i} \right)$

Backward Transfer:

$BWT = \frac{1}{T-1} \sum_{i=1}^{T-1} \left( a_{T,i} - a_{i,i} \right)$

- $BWT < 0$: Forgetting
- $BWT > 0$: Positive transfer

Forward Transfer:

$FWT = \frac{1}{T-1} \sum_{i=2}^{T} \left( a_{i-1,i} - b_i \right)$

- $a_{i-1,i}$: Zero-shot performance (accuracy on task $i$ before learning task $i$)
- $b_i$: accuracy of a randomly initialized model on task $i$
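These metrics can be computed directly from the lower-triangular accuracy matrix $a_{t,i}$ (row $t$ holds the accuracy on each seen task after training task $t$). A minimal numpy sketch; the matrix values below are made up purely for illustration:

```python
import numpy as np

def continual_metrics(acc):
    """acc[t, i]: accuracy on task i after training task t (lower triangle filled)."""
    T = acc.shape[0]
    # Average accuracy after the final task
    avg_acc = acc[T - 1, :T].mean()
    # Average forgetting: best accuracy ever seen on task i minus final accuracy
    forgetting = np.mean([acc[:T, i].max() - acc[T - 1, i] for i in range(T - 1)])
    # Backward transfer: final accuracy minus accuracy right after learning task i
    bwt = np.mean([acc[T - 1, i] - acc[i, i] for i in range(T - 1)])
    return avg_acc, forgetting, bwt

acc = np.array([[95.0,  0.0,  0.0],
                [90.0, 93.0,  0.0],
                [85.0, 88.0, 94.0]])
avg, f, bwt = continual_metrics(acc)
print(avg, f, bwt)  # 89.0 7.5 -7.5
```

Negative BWT here matches the average forgetting, as expected when no task improves after later training.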
Mathematical Mechanisms of Catastrophic Forgetting
Loss Landscape Perspective
Neural network loss functions are defined in high-dimensional parameter space:

$\mathcal{L}(\theta), \quad \theta \in \mathbb{R}^d$

Training is optimization to find local optima on the loss landscape:

$\theta^* = \arg\min_\theta \mathcal{L}(\theta)$

Problems:

- Task 1's optimum $\theta_1^*$ and Task 2's optimum $\theta_2^*$ are typically far apart
- Optimizing from $\theta_1^*$ to $\theta_2^*$ leaves Task 1's low-loss region
- Gradient descent updates parameters globally; it is difficult to adjust them locally

Define gradient conflict:

$\cos(g_1, g_2) = \frac{g_1^\top g_2}{\|g_1\| \, \|g_2\|} < 0$

where $g_t = \nabla_\theta \mathcal{L}_t(\theta)$ is the gradient of task $t$.
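Gradient conflict can be checked numerically by flattening both task gradients and computing their cosine. A small sketch; the gradient vectors are made up for illustration:

```python
import numpy as np

def gradient_conflict(g1, g2):
    """Cosine between two task gradients; a negative value indicates conflict."""
    g1, g2 = np.ravel(g1), np.ravel(g2)
    return float(g1 @ g2 / (np.linalg.norm(g1) * np.linalg.norm(g2)))

g_task1 = np.array([1.0, 2.0, -1.0])
g_task2 = np.array([-1.0, 0.5, 2.0])
cos = gradient_conflict(g_task1, g_task2)
print(cos < 0)  # True: these two task gradients conflict
```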
Weight Importance
Not all parameters are equally important for old tasks. Define the importance of parameter $\theta_i$ for task $t$ as the sensitivity of the task loss to that parameter:

$\Omega_i^{(t)} = \mathbb{E}_{(x,y) \sim \mathcal{D}_t} \left[ \left( \frac{\partial \mathcal{L}_t}{\partial \theta_i} \right)^2 \right]$

Insight: protect important parameters (large $\Omega_i$), allow unimportant ones to change.
Fisher Information Matrix
The Fisher information matrix measures parameter sensitivity to the loss:

$F = \mathbb{E}_{x \sim \mathcal{D}} \left[ \nabla_\theta \log p(y \mid x; \theta) \, \nabla_\theta \log p(y \mid x; \theta)^\top \right]$

The diagonal element $F_i$ represents the importance of parameter $\theta_i$:

$F_i = \mathbb{E} \left[ \left( \frac{\partial \log p(y \mid x; \theta)}{\partial \theta_i} \right)^2 \right]$

Properties:

- Large $F_i$: parameter $\theta_i$ has significant impact on predictions and should be protected
- Small $F_i$: parameter $\theta_i$ has minimal impact on predictions and can be modified
Regularization Methods
Elastic Weight Consolidation (EWC)
Core Idea of EWC
EWC [1] applies regularization constraints to important parameters when learning new tasks, preventing them from deviating from old task optima.

Objective function:

$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left( \theta_i - \theta_{A,i}^* \right)^2$

where:

- $\mathcal{L}_B(\theta)$: new task B's loss
- $\theta_A^*$: old task A's optimal parameters
- $F_i$: Fisher information (importance) of parameter $\theta_i$
- $\lambda$: regularization strength

Intuition: constrain changes in important parameters to protect old task knowledge.
Computing Fisher Information
For classification tasks, the diagonal elements of the Fisher information matrix are

$F_i = \mathbb{E}_{x \sim \mathcal{D}_A} \left[ \left( \frac{\partial \log p(y \mid x; \theta)}{\partial \theta_i} \right)^2 \right]$

In practice, compute on task A's data:

1. Forward pass to get the prediction $\hat{y} = f_\theta(x)$
2. Compute the log-likelihood $\log p(\hat{y} \mid x; \theta)$
3. Backward pass to get $g_i = \frac{\partial \log p(\hat{y} \mid x; \theta)}{\partial \theta_i}$
4. Accumulate $F_i = \mathbb{E}[g_i^2]$
Multi-Task Extension
When learning a task sequence, the EWC objective becomes

$\mathcal{L}(\theta) = \mathcal{L}_t(\theta) + \sum_{k < t} \sum_i \frac{\lambda}{2} F_i^{(k)} \left( \theta_i - \theta_{k,i}^* \right)^2$

Problem: accumulating Fisher information makes parameters increasingly "rigid".

Improvement: Online EWC [2] maintains only the current Fisher information and parameters, avoiding accumulation:

$\tilde{F}_i \leftarrow \gamma \tilde{F}_i + F_i^{(t)}$

where $\gamma$ is a decay factor (e.g., 0.9).
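The online-EWC update is just an exponential moving accumulation of per-parameter Fisher values. A sketch with a hypothetical one-parameter model (`"w"`) and illustrative Fisher values:

```python
def online_ewc_update(running_fisher, new_fisher, gamma=0.9):
    """Decay the running Fisher information and add the new task's, per parameter."""
    return {name: gamma * running_fisher.get(name, 0.0) + new_fisher[name]
            for name in new_fisher}

F = {}
F = online_ewc_update(F, {"w": 1.0})   # after task 1: w -> 1.0
F = online_ewc_update(F, {"w": 2.0})   # after task 2: w -> 0.9 * 1.0 + 2.0
print(F["w"])  # 2.9
```

Only one Fisher dictionary is kept regardless of the number of tasks, which is exactly what prevents the "rigidity" of naive multi-task EWC.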
Memory Aware Synapses (MAS)
MAS Improvements
MAS [3] removes a key dependency of EWC: Fisher information is computed from the loss and therefore requires labels, whereas MAS measures importance from the model output alone.

MAS parameter importance definition:

$\Omega_i = \frac{1}{N} \sum_{n=1}^{N} \left| \frac{\partial \, \| f_\theta(x_n) \|_2^2}{\partial \theta_i} \right|$

Note this is the gradient of the output with respect to the parameters, not of the loss with respect to the parameters.

Objective function:

$\mathcal{L}(\theta) = \mathcal{L}_t(\theta) + \frac{\lambda}{2} \sum_i \Omega_i \left( \theta_i - \theta_i^* \right)^2$
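Because the MAS importance is a gradient of the output norm, it needs only unlabeled inputs. A minimal PyTorch sketch; the linear model and random data are placeholders for illustration:

```python
import torch
import torch.nn as nn

def mas_importance(model, inputs):
    """Omega_i = mean |d ||f(x)||^2 / d theta_i| over the data -- no labels needed."""
    omega = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for x in inputs:
        model.zero_grad()
        out = model(x.unsqueeze(0))
        out.pow(2).sum().backward()          # gradient of the squared output norm
        for n, p in model.named_parameters():
            omega[n] += p.grad.abs()
    return {n: w / len(inputs) for n, w in omega.items()}

torch.manual_seed(0)
model = nn.Linear(4, 3)
data = torch.randn(8, 4)
omega = mas_importance(model, data)
print(omega["weight"].shape)  # torch.Size([3, 4])
```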
MAS vs EWC
| Dimension | EWC | MAS |
|---|---|---|
| Importance measure | Fisher information (squared gradient) | Output sensitivity (absolute gradient) |
| Computation dependency | Requires labels | No labels needed |
| Applicable scenarios | Supervised learning | Unsupervised/self-supervised |
| Computational complexity | Moderate | Low |
Synaptic Intelligence (SI)
SI's Online Update
SI [4] computes parameter importance online during training, rather than after task completion.

Parameter importance:

$\Omega_i = \frac{\omega_i}{(\Delta \theta_i)^2 + \xi}, \qquad \omega_i = \sum_t -g_i(t) \, \Delta\theta_i(t)$

where:

- $\Delta\theta_i(t)$: parameter update at step $t$
- $g_i(t)$: gradient at step $t$
- $\Delta\theta_i$: total parameter change over the task
- $\xi$: small constant to prevent division by zero

Intuition: a longer "path length" of parameter movement during training, together with more loss reduction, indicates higher importance.

Objective function:

$\mathcal{L}(\theta) = \mathcal{L}_t(\theta) + c \sum_i \Omega_i \left( \theta_i - \theta_i^* \right)^2$
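The SI accumulation fits naturally into a training loop: record $-g_i(t)\,\Delta\theta_i(t)$ at every step, then normalize by the squared total displacement at task end. A numpy sketch with hypothetical per-step gradients and updates (a real loop would read these from the optimizer):

```python
import numpy as np

class SITracker:
    """Accumulate the SI path integral omega; convert to importance Omega at task end."""
    def __init__(self, theta0, xi=1e-3):
        self.theta0 = np.array(theta0, dtype=float)
        self.omega = np.zeros_like(self.theta0)
        self.xi = xi

    def step(self, grad, delta):
        # Loss decrease attributed to each parameter's movement this step
        self.omega += -np.asarray(grad) * np.asarray(delta)

    def finish(self, theta_final):
        total = np.asarray(theta_final, dtype=float) - self.theta0
        return self.omega / (total ** 2 + self.xi)

tracker = SITracker([0.0, 0.0])
tracker.step(grad=[-1.0, 0.1], delta=[0.1, -0.01])     # SGD-style updates
tracker.step(grad=[-0.5, 0.05], delta=[0.05, -0.005])
Omega = tracker.finish([0.15, -0.015])
print(Omega[0] > Omega[1])  # the first parameter contributed more loss reduction
```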
Learning without Forgetting (LwF)
Application of Knowledge Distillation
LwF [5] uses knowledge distillation to maintain old task output distributions.

The loss function has two parts:

1. New task loss: $\mathcal{L}_{new} = \mathcal{L}_{CE}(y, \hat{y})$
2. Distillation loss (maintain old task outputs): $\mathcal{L}_{old} = -\sum_k \hat{y}_k^{(old)} \log \hat{y}_k$, computed on temperature-softened probabilities, where $\hat{y}^{(old)}$ are the old model's outputs on the new data

Total loss:

$\mathcal{L} = \mathcal{L}_{new} + \lambda_o \mathcal{L}_{old}$

Advantage: no need to save old task data, only the old model's predictions.

Disadvantage: a copy of the old model must be kept (storage overhead).
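The LwF loss can be sketched as cross-entropy on the new task plus a temperature-softened KL term against the old model's logits. The temperature `T=2` and weight `lambda_o=1.0` below are common choices, not values fixed by the text:

```python
import torch
import torch.nn.functional as F

def lwf_loss(new_logits, targets, old_logits, lambda_o=1.0, T=2.0):
    """New-task CE plus KL distillation toward the old model's softened outputs."""
    ce = F.cross_entropy(new_logits, targets)
    distill = F.kl_div(F.log_softmax(new_logits / T, dim=1),
                       F.softmax(old_logits / T, dim=1),
                       reduction="batchmean") * (T * T)
    return ce + lambda_o * distill

torch.manual_seed(0)
logits = torch.randn(4, 5, requires_grad=True)
old_logits = torch.randn(4, 5)
loss = lwf_loss(logits, torch.tensor([0, 1, 2, 3]), old_logits)
loss.backward()   # distillation gradient flows into the new model only
print(loss.item() > 0)
```

When the new model's logits equal the old model's, the distillation term vanishes and only the new-task cross-entropy remains.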
Dynamic Architecture Methods
Progressive Neural Networks
Progressive Expansion
Progressive Networks [6] add a new network column for each new task:

$h_i^{(t)} = \sigma \left( W_i^{(t)} h_{i-1}^{(t)} + \sum_{k < t} U_i^{(k \to t)} h_{i-1}^{(k)} \right)$

where:

- $h_i^{(t)}$: activation of task $t$ at layer $i$
- $W_i^{(t)}$: weights of task $t$ (trainable)
- $U_i^{(k \to t)}$: lateral connections from task $k$ to task $t$ (trainable)
- Old task parameters ($W^{(k)}, k < t$) are completely frozen

Advantages:

- Completely avoids catastrophic forgetting (old parameters unchanged)
- Supports forward transfer (lateral connections leverage old knowledge)

Disadvantages:

- Model size grows linearly: $T$ tasks require $T$ network columns
- Inference overhead increases with task count
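One layer of a new column combines its own weights with frozen lateral inputs from earlier columns. A minimal sketch of the per-layer forward; the dimensions and ReLU nonlinearity are illustrative:

```python
import torch
import torch.nn as nn

class ProgressiveLayer(nn.Module):
    """h_t = relu(W_t h + sum_k U_{k->t} h_k), with old columns frozen."""
    def __init__(self, in_dim, out_dim, num_prev_columns):
        super().__init__()
        self.own = nn.Linear(in_dim, out_dim)
        self.lateral = nn.ModuleList(
            nn.Linear(in_dim, out_dim, bias=False) for _ in range(num_prev_columns))

    def forward(self, h, prev_activations):
        out = self.own(h)
        for U, h_prev in zip(self.lateral, prev_activations):
            out = out + U(h_prev.detach())  # detach: old columns never receive gradients
        return torch.relu(out)

layer = ProgressiveLayer(8, 16, num_prev_columns=2)
h = torch.randn(4, 8)
prev = [torch.randn(4, 8), torch.randn(4, 8)]
print(layer(h, prev).shape)  # torch.Size([4, 16])
```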
Dynamically Expandable Networks (DEN)
Dynamic Expansion Strategy
DEN [7] dynamically expands network capacity as needed:

1. Selective Retraining:
- Freeze important parameters
- Only fine-tune unimportant parameters

2. Dynamic Expansion:
- If existing capacity is insufficient, add new neurons
- Decision criterion: validation loss stops decreasing

3. Network Split/Duplication:
- Duplicate neurons with added noise
- Increase model capacity without breaking old knowledge

Expansion algorithm: for layer $l$, add $k$ new neurons:

$W^l \to \left[ W^l ; W^l_{new} \right]$

where $W^l_{new}$ are the new neuron weights.

Sparse regularization: to avoid excessive expansion, DEN uses $\ell_1$ regularization:

$\min_{W_{new}} \; \mathcal{L}(W; \mathcal{D}_t) + \mu \| W_{new} \|_1 + \lambda \| W - W^{(t-1)} \|_2^2$

The first regularization term encourages sparsity; the second protects old knowledge.
PackNet
Clever Binary Mask Design
PackNet [8] assigns a different parameter subset to each task via binary masks:

$\theta_t = \theta \odot M_t$

where $M_t \in \{0, 1\}^{|\theta|}$ is the mask for task $t$.

Key constraint: masks for different tasks don't overlap:

$M_i \odot M_j = 0 \quad \forall i \neq j$

Training process:

1. When task $t$ arrives, freeze the parameters already used by previous tasks ($\theta \odot \bigvee_{k < t} M_k$)
2. Train task $t$ on the remaining parameters
3. After training, determine task $t$'s mask $M_t$ via pruning:
- Keep important parameters (e.g., the top fraction by weight magnitude)
- Remaining parameters stay available for future tasks

Advantages: high parameter reuse, fixed model size, completely avoids forgetting.

Disadvantages: available parameters gradually decrease, so later task performance is limited.
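The PackNet mechanics reduce to elementwise masks: prune each task's weights by magnitude among the still-free parameters, record a binary mask, and freeze the claimed entries. A numpy sketch of the mask-assignment step; the 50% keep ratio and toy weights are illustrative:

```python
import numpy as np

def packnet_assign(weights, used_mask, keep_ratio=0.5):
    """Among parameters not yet claimed, keep the largest-magnitude fraction for this task."""
    free = ~used_mask
    magnitudes = np.abs(weights) * free              # only consider free parameters
    k = int(free.sum() * keep_ratio)
    threshold = np.sort(magnitudes[free])[-k]        # k-th largest free magnitude
    return (magnitudes >= threshold) & free

w = np.array([0.9, -0.1, 0.5, -0.7, 0.05, 0.3])
used = np.zeros(6, dtype=bool)
m1 = packnet_assign(w, used)    # task 1 claims the 3 largest: 0.9, -0.7, 0.5
used |= m1
m2 = packnet_assign(w, used)    # task 2 picks only from the remaining 3
print((m1 & m2).any())          # False: masks never overlap
```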
Memory Replay Methods
Gradient Episodic Memory (GEM)
Constrained Optimization Perspective
GEM [9] models continual learning as constrained optimization:

$\min_\theta \mathcal{L}_t(\theta) \quad \text{s.t.} \quad \langle g, g_k \rangle \geq 0 \;\; \forall k < t$

where $g$ is the current task gradient and $g_k$ is an old task gradient.

Intuition: the new task gradient cannot conflict with old task gradients (negative inner product).
Gradient Projection
If gradient $g$ violates the constraints, project it to the feasible region:

$\min_{\tilde{g}} \frac{1}{2} \| g - \tilde{g} \|_2^2 \quad \text{s.t.} \quad \langle \tilde{g}, g_k \rangle \geq 0 \;\; \forall k < t$

This is a quadratic programming problem, solvable with existing solvers.

Simplified version: if there is only one old task ($t = 2$), the projection formula is

$\tilde{g} = g - \frac{g^\top g_1}{g_1^\top g_1} \, g_1$
Memory Buffer
GEM saves a few samples per task (e.g., 100 per task): $\mathcal{M}_k \subset \mathcal{D}_k$. The old task gradient $g_k$ is computed on the memory buffer.
Averaged GEM (A-GEM)
Computational Efficiency Improvement
A-GEM [10] simplifies GEM's constraints: the gradient need not have a non-negative inner product with every old task gradient, only with the average gradient over the memory buffer.

Average gradient:

$g_{ref} = \frac{1}{|\mathcal{M}|} \sum_{(x, y) \in \mathcal{M}} \nabla_\theta \mathcal{L}(f_\theta(x), y)$

Constraint: $g^\top g_{ref} \geq 0$

Projection formula:

$\tilde{g} = g - \frac{g^\top g_{ref}}{g_{ref}^\top g_{ref}} \, g_{ref}$

Advantages:

- Computational complexity reduced from $O(T)$ to $O(1)$ ($T$ is the number of tasks)
- No need to solve a quadratic program
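The A-GEM projection is a single inner product and subtraction. A numpy sketch with made-up gradient vectors:

```python
import numpy as np

def agem_project(g, g_ref):
    """Project gradient g so it no longer points against the reference gradient."""
    dot = g @ g_ref
    if dot >= 0:                  # no conflict: leave the gradient unchanged
        return g
    return g - (dot / (g_ref @ g_ref)) * g_ref

g = np.array([1.0, -2.0])
g_ref = np.array([0.0, 1.0])      # memory-buffer gradient
g_tilde = agem_project(g, g_ref)
print(g_tilde, g_tilde @ g_ref)   # [1. 0.] 0.0 -- conflict removed
```

After projection the update is orthogonal to the reference direction, so the step no longer increases the average memory-buffer loss to first order.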
Loss function:

$\mathcal{L} = \mathcal{L}_t(\theta) + \lambda_1 \mathcal{L}_{distill} + \lambda_2 \| \theta - \phi \|_2^2$

The first term is the new task loss, the second is a distillation loss (as in LwF), and the third is a meta-regularization term that pulls the parameters toward the meta-parameters.

Intuition: $\phi$ is the set of "universal parameters" obtained from meta-learning; task-specific parameters $\theta_t$ should stay close to $\phi$.
Complete Code Implementation: EWC from Scratch
Below is a complete EWC framework implementation including Fisher
information computation, multi-task training, forgetting evaluation, and
visualization.
""" EWC from Scratch: Elastic Weight Consolidation Includes: Fisher information computation, multi-task training, forgetting evaluation """
import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader, TensorDataset import numpy as np import matplotlib.pyplot as plt from typing importDict, List, Tuple from copy import deepcopy
# Set random seed torch.manual_seed(42) np.random.seed(42)
# ============================================================================ # Simple MLP Model # ============================================================================
classEWC: """ Elastic Weight Consolidation """ def__init__(self, model: nn.Module, dataloader: DataLoader, device: str = 'cpu'): self.model = model self.dataloader = dataloader self.device = device # Store Fisher information and parameters for each task self.fisher_dict: Dict[int, Dict[str, torch.Tensor]] = {} self.optpar_dict: Dict[int, Dict[str, torch.Tensor]] = {} defcompute_fisher(self, task_id: int): """ Compute diagonal elements of Fisher information matrix """ self.model.eval() # Initialize Fisher information fisher = {} for name, param in self.model.named_parameters(): fisher[name] = torch.zeros_like(param) # Accumulate gradient squares on data num_samples = 0 for inputs, targets in self.dataloader: inputs = inputs.to(self.device) targets = targets.to(self.device) self.model.zero_grad() # Forward pass outputs = self.model(inputs) # Compute negative log-likelihood loss = F.cross_entropy(outputs, targets) # Backward pass loss.backward() # Accumulate gradient squares for name, param in self.model.named_parameters(): if param.grad isnotNone: fisher[name] += param.grad.data ** 2 num_samples += inputs.size(0) # Average Fisher information for name in fisher: fisher[name] /= num_samples # Store Fisher information self.fisher_dict[task_id] = fisher # Store current parameters optpar = {} for name, param in self.model.named_parameters(): optpar[name] = param.data.clone() self.optpar_dict[task_id] = optpar print(f"Fisher information computed for task {task_id}") defpenalty(self) -> torch.Tensor: """ Compute EWC penalty term """ loss = 0.0 for task_id in self.fisher_dict: for name, param in self.model.named_parameters(): fisher = self.fisher_dict[task_id][name] optpar = self.optpar_dict[task_id][name] loss += (fisher * (param - optpar) ** 2).sum() return loss
# ============================================================================ # Training Function # ============================================================================
defplot_continual_learning_results(baseline_acc: np.ndarray, ewc_acc: np.ndarray): """ Plot continual learning results """ num_tasks = baseline_acc.shape[0] fig, axes = plt.subplots(1, 3, figsize=(18, 5)) # 1. Accuracy heatmap (Baseline) im1 = axes[0].imshow(baseline_acc, cmap='YlGnBu', vmin=0, vmax=100) axes[0].set_xlabel('Evaluated Task', fontsize=12) axes[0].set_ylabel('After Training Task', fontsize=12) axes[0].set_title('Baseline: Accuracy Heatmap (%)', fontsize=14, fontweight='bold') axes[0].set_xticks(range(num_tasks)) axes[0].set_yticks(range(num_tasks)) axes[0].set_xticklabels([f'T{i+1}'for i inrange(num_tasks)]) axes[0].set_yticklabels([f'T{i+1}'for i inrange(num_tasks)]) # Add numerical annotations for i inrange(num_tasks): for j inrange(i + 1): text = axes[0].text(j, i, f'{baseline_acc[i, j]:.1f}', ha="center", va="center", color="black", fontsize=10) plt.colorbar(im1, ax=axes[0]) # 2. Accuracy heatmap (EWC) im2 = axes[1].imshow(ewc_acc, cmap='YlGnBu', vmin=0, vmax=100) axes[1].set_xlabel('Evaluated Task', fontsize=12) axes[1].set_ylabel('After Training Task', fontsize=12) axes[1].set_title('EWC: Accuracy Heatmap (%)', fontsize=14, fontweight='bold') axes[1].set_xticks(range(num_tasks)) axes[1].set_yticks(range(num_tasks)) axes[1].set_xticklabels([f'T{i+1}'for i inrange(num_tasks)]) axes[1].set_yticklabels([f'T{i+1}'for i inrange(num_tasks)]) # Add numerical annotations for i inrange(num_tasks): for j inrange(i + 1): text = axes[1].text(j, i, f'{ewc_acc[i, j]:.1f}', ha="center", va="center", color="black", fontsize=10) plt.colorbar(im2, ax=axes[1]) # 3. 
Average accuracy and forgetting comparison avg_acc_baseline = [baseline_acc[i, :i+1].mean() for i inrange(num_tasks)] avg_acc_ewc = [ewc_acc[i, :i+1].mean() for i inrange(num_tasks)] # Forgetting: accuracy drop on first task forgetting_baseline = [baseline_acc[0, 0] - baseline_acc[i, 0] for i inrange(num_tasks)] forgetting_ewc = [ewc_acc[0, 0] - ewc_acc[i, 0] for i inrange(num_tasks)] x = np.arange(1, num_tasks + 1) width = 0.35 ax3_1 = axes[2] ax3_1.plot(x, avg_acc_baseline, marker='o', label='Baseline - Avg Acc', linewidth=2, color='tab:blue') ax3_1.plot(x, avg_acc_ewc, marker='s', label='EWC - Avg Acc', linewidth=2, color='tab:green') ax3_1.set_xlabel('Number of Tasks Trained', fontsize=12) ax3_1.set_ylabel('Average Accuracy (%)', fontsize=12, color='tab:blue') ax3_1.tick_params(axis='y', labelcolor='tab:blue') ax3_1.legend(loc='upper left') ax3_1.grid(True, alpha=0.3) ax3_2 = ax3_1.twinx() ax3_2.plot(x, forgetting_baseline, marker='o', label='Baseline - Forgetting', linewidth=2, linestyle='--', color='tab:red') ax3_2.plot(x, forgetting_ewc, marker='s', label='EWC - Forgetting', linewidth=2, linestyle='--', color='tab:orange') ax3_2.set_ylabel('Forgetting on Task 1 (%)', fontsize=12, color='tab:red') ax3_2.tick_params(axis='y', labelcolor='tab:red') ax3_2.legend(loc='upper right') axes[2].set_title('Average Accuracy & Forgetting', fontsize=14, fontweight='bold') plt.tight_layout() plt.savefig('ewc_continual_learning.png', dpi=150, bbox_inches='tight') plt.close() print("\nResults saved to ewc_continual_learning.png")
# ============================================================================ # Main Function # ============================================================================
Disadvantage: Need to store gradients of all old
tasks.
Continual Backprop
Continual Backprop [20] modifies the backpropagation algorithm to only update parameters important for the current task.

Update rule:

$\theta_i \leftarrow \theta_i - \eta \, m_i^{(t)} \frac{\partial \mathcal{L}_t}{\partial \theta_i}$

where $m^{(t)}$ is the parameter mask for task $t$, learned automatically.
Supermasks in Superposition (SupSup)
SupSup [21] learns a binary mask for each task; all tasks share the same parameters:

$f_t(x) = f(x; \theta \odot M_t)$

During training, only the mask $M_t$ is optimized; the parameters $\theta$ stay fixed after random initialization.

Surprising finding: randomly initialized networks can achieve multi-task learning performance through different masks alone!
Benchmarks and Evaluation
Standard Benchmarks
Permuted MNIST: Random pixel permutation of
MNIST
Split CIFAR: CIFAR-10 split by classes into
multiple tasks
CORe50: 50 objects with images from different
scenes
Continual Reinforcement Learning: Atari game
sequences
Evaluation Protocol
Standard evaluation includes:
Within-task performance: Upper bound performance
when each task trained separately
Average accuracy: Average performance across all
tasks
Backward transfer: Change in old task performance
after learning new tasks
Forward transfer: Benefit of old tasks to new
tasks
Parameter efficiency: Growth rate of model size
with task count
Frequently Asked Questions
Q1: How to choose EWC's $\lambda$?

Empirical rules: the appropriate $\lambda$ scales with the task, e.g.:

- Small tasks (e.g., Permuted MNIST)
- Medium tasks (e.g., Split CIFAR)
- Large tasks (e.g., ImageNet subsets)

Tuning strategy:

1. Start with a smaller $\lambda$ (e.g., 100) for testing
2. Observe forgetting: if forgetting is severe, increase $\lambda$
3. Validate the chosen $\lambda$ on a validation set
Q2: When to compute Fisher information?
Recommended: Immediately after each task finishes
training.
Reason: the model has reached its optimum on that task, so the Fisher information is most accurate.
Note: if the task dataset is large, Fisher information can be computed on a subset (e.g., 100 samples per class).
Q3: How to choose between EWC, MAS, SI?
| Scenario | Recommended Method |
|---|---|
| Supervised learning | EWC |
| Unsupervised learning | MAS |
| Online learning (update while training) | SI |
| Need label-free importance computation | MAS |
Performance comparison: EWC ≈ MAS > SI (on most
benchmarks)
Q4: How large should the memory buffer be?
Empirical values:
Per task: 50-200 samples
Total buffer: 500-2000 samples
Trade-off:

- Larger buffer → less forgetting, but higher storage and computation overhead
- Smaller buffer → more efficient, but potentially severe forgetting
Optimal strategy: Dynamically adjust based on
available memory and task count.
Q5: Which is better, GEM or A-GEM?
| Dimension | GEM | A-GEM |
|---|---|---|
| Forgetting control | Stricter | Looser |
| Computational complexity | High (quadratic programming) | Low (linear projection) |
| Scalability | Poor (slow with many tasks) | Good |
| Implementation difficulty | Difficult | Simple |
Recommendation: Unless extremely sensitive to
forgetting, prefer A-GEM (similar performance but much faster).
Q6: What are the disadvantages of dynamic architecture methods?
Progressive Networks:

- Model size grows linearly: $T$ tasks require $T$ times the parameters
- Inference time grows linearly
- Difficult to deploy (model too large)

DEN:

- High training complexity (needs dynamic expansion decisions)
- Sensitive to hyperparameters (expansion threshold, sparsity coefficient)
- May overexpand

PackNet:

- Later task performance limited (available parameters decrease)
- Pruning strategy has significant impact
- Upper limit on task count (parameters eventually exhausted)
Trade-off: Dynamic architectures completely avoid
forgetting, but sacrifice efficiency and scalability.
Q7: Role of meta-learning in continual learning?
Advantages:

- Learns general representations, reducing task-specific parameters
- Supports fast adaptation to new tasks
- Theoretically elegant (minimizes meta-loss across all tasks)

Disadvantages:

- Needs a meta-training phase (task distribution known)
- High computational complexity (second-order derivatives)
- Limited performance improvement in practice

Applicable scenarios: high task similarity and a need for fast adaptation (e.g., few-shot learning).
Q8: Difference between continual learning and multi-task learning?
| Dimension | Continual Learning | Multi-Task Learning |
|---|---|---|
| Task visibility | Sequential arrival | Simultaneously visible |
| Data access | Cannot access old data | All data available |
| Main challenge | Catastrophic forgetting | Task balancing |
| Goal | Maintain old task performance | Average performance across all tasks |
Connection: Continual learning can be seen as
data-constrained multi-task learning.
Q9: How to handle Class-Incremental Learning (Class-IL)?
Class-IL is the most difficult scenario and requires special handling:

1. Output layer expansion: add new class output neurons for each new task
2. Bias correction: new class outputs are typically smaller (not fully trained) and need correction
3. Knowledge distillation: maintain old class output distributions
Continual learning enables models to have lifelong learning
capabilities, an important foundation for artificial general
intelligence. In the next chapter, we will explore cross-lingual
transfer and see how models can seamlessly transfer knowledge across
different languages.
References
1. Kirkpatrick, J., Pascanu, R., Rabinowitz, N., et al. (2017). Overcoming catastrophic forgetting in neural networks. PNAS.
2. Schwarz, J., Czarnecki, W., Luketina, J., et al. (2018). Progress & compress: A scalable framework for continual learning. ICML.
3. Aljundi, R., Babiloni, F., Elhoseiny, M., et al. (2018). Memory aware synapses: Learning what (not) to forget. ECCV.
4. Zenke, F., Poole, B., & Ganguli, S. (2017). Continual learning through synaptic intelligence. ICML.
5. Li, Z., & Hoiem, D. (2017). Learning without forgetting. TPAMI.
6. Rusu, A. A., Rabinowitz, N. C., Desjardins, G., et al. (2016). Progressive neural networks. arXiv:1606.04671.
7. Yoon, J., Yang, E., Lee, J., & Hwang, S. J. (2018). Lifelong learning with dynamically expandable networks. ICLR.
8. Mallya, A., & Lazebnik, S. (2018). PackNet: Adding multiple tasks to a single network by iterative pruning. CVPR.
9. Lopez-Paz, D., & Ranzato, M. (2017). Gradient episodic memory for continual learning. NeurIPS.
10. Chaudhry, A., Ranzato, M., Rohrbach, M., & Elhoseiny, M. (2019). Efficient lifelong learning with A-GEM. ICLR.
11. Robins, A. (1995). Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science.
12. Buzzega, P., Boschini, M., Porrello, A., et al. (2020). Dark experience for general continual learning: A strong, simple baseline. NeurIPS.
13. Finn, C., Abbeel, P., & Levine, S. (2017). Model-agnostic meta-learning for fast adaptation of deep networks. ICML.
14. Riemer, M., Cases, I., Ajemian, R., et al. (2019). Learning to learn without forgetting by maximizing transfer and minimizing interference. ICLR.
15. Javed, K., & White, M. (2019). Meta-learning representations for continual learning. NeurIPS.
16. Beaulieu, S., Frati, L., Miconi, T., et al. (2020). Learning to continually learn rapidly from few and noisy data. arXiv:2006.10220.
17. Abraham, W. C., & Robins, A. (2005). Memory retention – the synaptic stability versus plasticity dilemma. Trends in Neurosciences.
18. French, R. M. (1999). Catastrophic forgetting in connectionist networks. Trends in Cognitive Sciences.
19. Farajtabar, M., Azizan, N., Mott, A., & Li, A. (2020). Orthogonal gradient descent for continual learning. AISTATS.
20. Golkar, S., Kagan, M., & Cho, K. (2019). Continual learning via neural pruning. arXiv:1903.04476.
21. Wortsman, M., Ramanujan, V., Liu, R., et al. (2020). Supermasks in superposition. NeurIPS.
22. Rebuffi, S. A., Kolesnikov, A., Sperl, G., & Lampert, C. H. (2017). iCaRL: Incremental classifier and representation learning. CVPR.
23. Hou, S., Pan, X., Loy, C. C., et al. (2019). Learning a unified classifier incrementally via rebalancing. CVPR.
24. Farquhar, S., & Gal, Y. (2018). Towards robust evaluations of continual learning. arXiv:1805.09733.
Post title: Transfer Learning (10): Continual Learning
Post author: Chen Kai
Create time: 2024-12-27 15:00:00
Post link: https://www.chenk.top/transfer-learning-10-continual-learning/
Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.