Transfer Learning (10): Continual Learning
Chen Kai BOSS

Humans can continuously learn new skills without forgetting old knowledge, but neural networks often "forget" when learning new tasks — this is catastrophic forgetting. How can models learn like humans throughout their lifetime, remembering the first task after mastering 100 tasks? Continual Learning provides the answer.

This article systematically explains the principles and implementations of four major approaches — regularization, dynamic architectures, memory replay, and meta-learning — starting from the mathematical mechanisms of catastrophic forgetting. We analyze parameter importance estimation, inter-task knowledge transfer, and the forgetting-stability trade-off, and provide complete code (250+ lines) for implementing EWC from scratch.

Problem Definition of Continual Learning

Task Sequence

Continual learning processes a sequence of tasks arriving over time: $\mathcal{T}_1, \mathcal{T}_2, \ldots, \mathcal{T}_T$.

Each task $\mathcal{T}_t$ is defined as a learning problem $\mathcal{T}_t = (\mathcal{D}_t, \mathcal{L}_t, f_{\theta_t})$, where:

  • $\mathcal{D}_t = \{(x_i, y_i)\}_{i=1}^{N_t}$ is the task data
  • $\mathcal{L}_t$ is the task loss function
  • $f_{\theta_t}$ is the model for task $t$ (with parameters $\theta_t$)

Key constraint: when learning task $\mathcal{T}_t$, the data of previous tasks $\mathcal{D}_1, \ldots, \mathcal{D}_{t-1}$ cannot be accessed.

Catastrophic Forgetting

Phenomenon: after training on task $\mathcal{T}_t$, model performance on previous tasks $\mathcal{T}_1, \ldots, \mathcal{T}_{t-1}$ drops dramatically.

Formalized as $F_k = a_{k,k} - a_{T,k}$, where:

  • $a_{k,k}$: accuracy on task $k$ immediately after learning task $k$
  • $a_{T,k}$: accuracy on task $k$ after learning task $T$

Goal: minimize forgetting while maintaining learning capability for new tasks.

Three Scenarios of Continual Learning

  1. Task-Incremental Learning:
    • Task ID known at inference
    • Each task has independent output head
    • Example: MNIST → FashionMNIST → CIFAR10
  2. Domain-Incremental Learning:
    • Same task, different data distributions
    • Example: Sunny images → Rainy images → Night images
  3. Class-Incremental Learning:
    • Task ID unknown at inference
    • Must predict from all learned classes
    • Most challenging scenario

Evaluation Metrics

  1. Average Accuracy: $\text{ACC} = \frac{1}{T} \sum_{k=1}^{T} a_{T,k}$

  2. Average Forgetting: $F = \frac{1}{T-1} \sum_{k=1}^{T-1} \left( \max_{t \le T-1} a_{t,k} - a_{T,k} \right)$

  3. Backward Transfer: $\text{BWT} = \frac{1}{T-1} \sum_{k=1}^{T-1} \left( a_{T,k} - a_{k,k} \right)$

  • $\text{BWT} < 0$: Forgetting
  • $\text{BWT} > 0$: Positive transfer

  4. Forward Transfer: $\text{FWT} = \frac{1}{T-1} \sum_{k=2}^{T} \left( a_{k-1,k} - \bar{b}_k \right)$

  • $\bar{b}_k$: Zero-shot performance (accuracy on task $k$ before learning task $k$, e.g., at random initialization)
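Under the accuracy-matrix convention above ($a_{t,k}$ = accuracy on task $k$ after training through task $t$), all four metrics can be computed in a few lines. A minimal sketch; the `continual_metrics` helper name is illustrative:

```python
import numpy as np

def continual_metrics(acc, baseline=None):
    """Compute standard continual-learning metrics from an accuracy matrix.

    acc[t, k] = accuracy on task k after training on tasks 0..t.
    baseline[k] = accuracy on task k before any training (for forward transfer).
    """
    T = acc.shape[0]
    # Average accuracy after the final task
    avg_acc = acc[T - 1].mean()
    # Average forgetting: best-ever accuracy minus final accuracy, over old tasks
    forgetting = np.mean([acc[:T - 1, k].max() - acc[T - 1, k] for k in range(T - 1)])
    # Backward transfer: final accuracy minus just-after-learning accuracy
    bwt = np.mean([acc[T - 1, k] - acc[k, k] for k in range(T - 1)])
    fwt = None
    if baseline is not None:
        # Forward transfer: accuracy on task k just before learning it vs. untrained baseline
        fwt = np.mean([acc[k - 1, k] - baseline[k] for k in range(1, T)])
    return avg_acc, forgetting, bwt, fwt
```

For example, a run where Task 1 starts at 90% and ends at 80% yields negative backward transfer and positive forgetting.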

Mathematical Mechanisms of Catastrophic Forgetting

Loss Landscape Perspective

Neural network loss functions are defined over a high-dimensional parameter space: $\mathcal{L}: \mathbb{R}^d \to \mathbb{R}$. Training is optimization to find a local optimum on this loss landscape: $\theta^* = \arg\min_\theta \mathcal{L}(\theta)$.

Problems:

  1. Task 1's optimum $\theta_1^*$ and Task 2's optimum $\theta_2^*$ are typically far apart
  2. Optimizing from $\theta_1^*$ toward $\theta_2^*$ leaves Task 1's low-loss region
  3. Gradient descent updates parameters globally; it is difficult to adjust them locally

Gradient Interference

Gradients of Task 1 and Task 2 may conflict: $g_1 \cdot g_2 < 0$.

Intuition: parameter updates that improve Task 2 worsen Task 1 performance.

Define the gradient conflict rate over $T$ tasks: $\text{Conflict} = \frac{\left| \{ (i,j) : g_i \cdot g_j < 0 \} \right|}{T(T-1)/2}$, where $g_i$ is the gradient of task $i$.
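The conflict rate can be estimated directly from per-task gradient vectors. A minimal sketch; the `conflict_fraction` helper is illustrative, not from any library:

```python
import numpy as np

def conflict_fraction(grads):
    """Fraction of task pairs whose gradients have a negative inner product,
    matching the Conflict definition above.

    grads: list of flattened per-task gradient vectors.
    """
    T = len(grads)
    conflicts = sum(
        1
        for i in range(T)
        for j in range(i + 1, T)
        if np.dot(grads[i], grads[j]) < 0
    )
    return conflicts / (T * (T - 1) / 2)
```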

Weight Importance

Not all parameters are equally important for old tasks. Define the importance $\Omega_i$ of parameter $\theta_i$ for task $t$ by how strongly it affects the task loss, e.g. $\Omega_i \approx \left| \frac{\partial \mathcal{L}_t}{\partial \theta_i} \right|$.

Insight: protect important parameters (large $\Omega_i$) and allow unimportant ones to change.

Fisher Information Matrix

The Fisher information matrix measures parameter sensitivity: $F = \mathbb{E}_{x \sim \mathcal{D}} \left[ \nabla_\theta \log p(y|x;\theta)\, \nabla_\theta \log p(y|x;\theta)^\top \right]$. The diagonal element $F_i$ represents the importance of parameter $\theta_i$: $F_i = \mathbb{E} \left[ \left( \frac{\partial \log p(y|x;\theta)}{\partial \theta_i} \right)^2 \right]$.

Properties:

  • Large $F_i$: parameter $\theta_i$ has significant impact on predictions and should be protected
  • Small $F_i$: parameter $\theta_i$ has minimal impact on predictions and can be modified

Regularization Methods

Elastic Weight Consolidation (EWC)

Core Idea of EWC

EWC [1] applies regularization constraints to important parameters when learning new tasks, preventing them from deviating from the old task's optimum.

Objective function: $\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left( \theta_i - \theta_{A,i}^* \right)^2$, where:

  • $\mathcal{L}_B(\theta)$: new task B's loss
  • $\theta_A^*$: old task A's optimal parameters
  • $F_i$: Fisher information (importance) of parameter $\theta_i$
  • $\lambda$: regularization strength

Intuition: constrain changes in important parameters to protect old task knowledge.

Computing Fisher Information

For classification tasks, the diagonal elements of the Fisher information matrix are $F_i = \mathbb{E}_{x \sim \mathcal{D}_A} \left[ \left( \frac{\partial \log p(y|x;\theta_A^*)}{\partial \theta_i} \right)^2 \right]$. In practice, compute on task A's data:

  1. Forward pass to get the prediction $\hat{y} = f_\theta(x)$
  2. Compute the log-likelihood $\log p(\hat{y} \mid x; \theta)$
  3. Backward pass to get the gradient $g_i = \frac{\partial \log p(\hat{y} \mid x; \theta)}{\partial \theta_i}$
  4. Accumulate $F_i = \mathbb{E}\left[ g_i^2 \right]$

Multi-Task Extension

When learning a task sequence, the EWC objective becomes $\mathcal{L}(\theta) = \mathcal{L}_t(\theta) + \sum_{k < t} \sum_i \frac{\lambda}{2} F_i^{(k)} \left( \theta_i - \theta_{k,i}^* \right)^2$.

Problem: accumulating Fisher information makes the parameters increasingly "rigid".

Improvement: Online EWC [2] maintains only a single running Fisher estimate and anchor parameters, avoiding accumulation: $\tilde{F}_i \leftarrow \gamma \tilde{F}_i + F_i^{(t)}$, where $\gamma$ is a decay factor (e.g., 0.9).
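The Online EWC update is a one-line recurrence over the stored importance estimates. A minimal sketch, assuming importances are kept as name-to-value dicts (helper name is illustrative):

```python
def online_ewc_update(running_fisher, new_fisher, gamma=0.9):
    """Online EWC: keep one decayed Fisher estimate instead of one per task.

    running_fisher, new_fisher: dicts mapping parameter name -> importance value/array.
    gamma: decay factor controlling how quickly old-task importance fades.
    """
    return {
        name: gamma * running_fisher.get(name, 0.0) + new_fisher[name]
        for name in new_fisher
    }
```

This keeps memory cost constant in the number of tasks, at the price of gradually down-weighting the oldest tasks.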

Memory Aware Synapses (MAS)

MAS Improvements

MAS [3] addresses a limitation of EWC: Fisher information is tied to the loss and therefore requires labeled data, whereas parameter importance can be measured from the model's outputs alone.

MAS parameter importance definition: $\Omega_i = \frac{1}{N} \sum_{n=1}^{N} \left| \frac{\partial \left\| f_\theta(x_n) \right\|_2^2}{\partial \theta_i} \right|$. Note this is the gradient of the output with respect to the parameters, not the loss with respect to the parameters.

Objective function: $\mathcal{L}(\theta) = \mathcal{L}_t(\theta) + \frac{\lambda}{2} \sum_i \Omega_i \left( \theta_i - \theta_i^* \right)^2$
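A minimal PyTorch sketch of the MAS importance computation. Note it never touches labels, only the squared L2 norm of the outputs; the `mas_importance` helper name is illustrative:

```python
import torch

def mas_importance(model, dataloader):
    """MAS: importance = mean |d ||f(x)||^2 / d theta| over the data (no labels)."""
    importance = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    n_batches = 0
    for inputs, _ in dataloader:  # labels are ignored
        model.zero_grad()
        out = model(inputs)
        # Sensitivity of the squared L2 norm of the outputs w.r.t. parameters
        out.pow(2).sum().backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                importance[n] += p.grad.abs()
        n_batches += 1
    return {n: v / n_batches for n, v in importance.items()}
```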

MAS vs EWC

| Dimension | EWC | MAS |
| --- | --- | --- |
| Importance measure | Fisher information (squared gradient) | Output sensitivity (absolute gradient) |
| Computation dependency | Requires labels | No labels needed |
| Applicable scenarios | Supervised learning | Unsupervised/self-supervised |
| Computational complexity | Moderate | Low |

Synaptic Intelligence (SI)

SI's Online Update

SI [4] computes parameter importance online during training, rather than after task completion.

Parameter importance: $\Omega_i = \frac{\omega_i}{(\Delta\theta_i)^2 + \xi}$, with the path integral $\omega_i = -\sum_t g_i(t)\, \Delta\theta_i(t)$, where:

  • $\Delta\theta_i(t)$: parameter update at step $t$
  • $g_i(t)$: gradient at step $t$
  • $\Delta\theta_i$: total parameter change over the task
  • $\xi$: small constant to prevent division by zero

Intuition: a longer "path" of parameter movement during training, with more loss reduction along the way, indicates higher importance.

Objective function: $\mathcal{L}(\theta) = \mathcal{L}_t(\theta) + \lambda \sum_i \Omega_i \left( \theta_i - \theta_i^* \right)^2$
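The online accumulation above can be sketched as a small PyTorch helper, called after each optimizer step and at each task boundary. This is a minimal sketch; the class and method names are illustrative:

```python
import torch

class SynapticIntelligence:
    """Minimal SI sketch: accumulate per-parameter path integrals online."""

    def __init__(self, model, xi=0.1):
        self.model = model
        self.xi = xi
        self.omega = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
        self.prev = {n: p.detach().clone() for n, p in model.named_parameters()}
        self.start = {n: p.detach().clone() for n, p in model.named_parameters()}
        self.importance = {n: torch.zeros_like(p) for n, p in model.named_parameters()}

    def accumulate(self):
        """Call after each optimizer.step(): omega_i += -g_i(t) * delta_theta_i(t)."""
        for n, p in self.model.named_parameters():
            if p.grad is not None:
                delta = p.detach() - self.prev[n]
                self.omega[n] += -p.grad.detach() * delta
            self.prev[n] = p.detach().clone()

    def consolidate(self):
        """Call at task end: turn path integrals into importances and reset."""
        for n, p in self.model.named_parameters():
            total_change = p.detach() - self.start[n]
            self.importance[n] += self.omega[n] / (total_change ** 2 + self.xi)
            self.omega[n].zero_()
            self.start[n] = p.detach().clone()
```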

Learning without Forgetting (LwF)

Application of Knowledge Distillation

LwF [5] uses knowledge distillation to maintain the old tasks' output distributions.

The loss function has two parts:

  1. New task loss: $\mathcal{L}_{\text{new}} = \mathcal{L}_{\text{CE}}\left( y, f_\theta(x) \right)$

  2. Distillation loss (maintain old task outputs): $\mathcal{L}_{\text{old}} = \mathcal{L}_{\text{KD}}\left( f_{\theta_{\text{old}}}(x), f_\theta(x) \right)$

Total loss: $\mathcal{L} = \mathcal{L}_{\text{new}} + \lambda \mathcal{L}_{\text{old}}$

Advantage: No need to save old task data, only old model predictions.

Disadvantage: Need to save copy of old model (storage overhead).
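A minimal sketch of the distillation term $\mathcal{L}_{\text{KD}}$, using the common temperature-softened cross-entropy formulation (the temperature `T` and the `T**2` scaling follow standard knowledge-distillation practice and are assumptions here, not taken verbatim from the LwF paper):

```python
import torch
import torch.nn.functional as F

def lwf_distillation_loss(new_logits, old_logits, T=2.0):
    """Soften old-model and new-model outputs with temperature T, then match
    them with cross-entropy (equal to KL divergence up to a constant)."""
    old_probs = F.softmax(old_logits / T, dim=1)       # frozen teacher targets
    new_log_probs = F.log_softmax(new_logits / T, dim=1)
    return -(old_probs * new_log_probs).sum(dim=1).mean() * (T ** 2)
```

The loss is minimized exactly when the new model reproduces the old model's softened output distribution.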

Dynamic Architecture Methods

Progressive Neural Networks

Progressive Expansion

Progressive Networks [6] add a new network column for each new task: $h_t^{(l)} = \sigma\left( W_t^{(l)} h_t^{(l-1)} + \sum_{k < t} U_{k \to t}^{(l)} h_k^{(l-1)} \right)$, where:

  • $h_t^{(l)}$: activation of task $t$ at layer $l$
  • $W_t^{(l)}$: weights of task $t$ (trainable)
  • $U_{k \to t}^{(l)}$: lateral connections from task $k$ to task $t$ (trainable)
  • Old task parameters ($W_k^{(l)}$ for $k < t$) are completely frozen

Advantages: - Completely avoids catastrophic forgetting (old parameters unchanged) - Supports forward transfer (lateral connections leverage old knowledge)

Disadvantages: - Model size grows linearly: $T$ tasks require $T$ network columns - Inference overhead increases with task count
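A toy single-layer sketch of a progressive column with lateral connections and frozen predecessors. The class name and structure are illustrative; the real architecture stacks many such layers:

```python
import torch
import torch.nn as nn

class ProgressiveColumn(nn.Module):
    """One new column with lateral connections to frozen earlier columns."""

    def __init__(self, in_dim, hidden_dim, prev_columns):
        super().__init__()
        self.layer = nn.Linear(in_dim, hidden_dim)  # W_t: trainable
        # One lateral adapter U_{k->t} per previous column
        self.laterals = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in prev_columns]
        )
        # Kept in a plain list so frozen columns are not registered as submodules
        self.prev_columns = prev_columns
        for col in prev_columns:  # old columns are completely frozen
            for p in col.parameters():
                p.requires_grad_(False)

    def forward(self, x):
        h = self.layer(x)
        for col, lat in zip(self.prev_columns, self.laterals):
            with torch.no_grad():  # no gradients flow into old columns
                h_prev = col.layer(x)
            h = h + lat(h_prev)
        return torch.relu(h)
```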

Dynamically Expandable Networks (DEN)

Dynamic Expansion Strategy

DEN [7] dynamically expands network capacity as needed:

  1. Selective Retraining:
    • Freeze important parameters
    • Only fine-tune unimportant parameters
  2. Dynamic Expansion:
    • If existing capacity insufficient, add new neurons
    • Decision criterion: Validation loss stops decreasing
  3. Network Split/Duplication:
    • Duplicate neurons with added noise
    • Increase model capacity without breaking old knowledge

Expansion algorithm:

For layer $l$, add $k$ new neurons: $W^{(l)} \to \begin{bmatrix} W^{(l)} \\ W_{\text{new}}^{(l)} \end{bmatrix}$, where $W_{\text{new}}^{(l)}$ are the new neuron weights.

Sparse regularization:

To avoid excessive expansion, DEN uses $\ell_1$ regularization: $\mathcal{L} = \mathcal{L}_t(\theta) + \mu \left\| W_{\text{new}} \right\|_1 + \lambda \left\| \theta - \theta^{(t-1)} \right\|_2^2$. The first regularization term encourages sparsity; the second protects old knowledge.

PackNet

Clever Binary Mask Design

PackNet [8] assigns a different parameter subset to each task via binary masks: $\theta_t = \theta \odot M_t$, where $M_t \in \{0, 1\}^{|\theta|}$ is the mask for task $t$.

Key constraint: masks for different tasks do not overlap: $M_i \odot M_j = 0$ for $i \neq j$.

Training process:

  1. When task $t$ arrives, freeze the parameters already used by previous tasks: $\theta \odot \left( \sum_{k < t} M_k \right)$ is held fixed
  2. Train task $t$ on the remaining free parameters
  3. After training, determine task $t$'s mask $M_t$ via pruning:
    • Keep the important parameters (e.g., the top fraction by weight magnitude)
    • The remaining parameters stay available for future tasks

Advantages: - High parameter reuse - Fixed model size - Completely avoids forgetting

Disadvantages: - Available parameters gradually decrease - Later task performance limited
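The pruning step that produces a task's mask can be sketched as follows, assuming a flat parameter tensor and a boolean mask of weights already claimed by earlier tasks. The helper name and `keep_frac` parameter are illustrative:

```python
import torch

def packnet_prune(param, used_mask, keep_frac=0.5):
    """After training task t, keep the top `keep_frac` of *free* weights by
    magnitude. Returns the new task's mask, disjoint from `used_mask`.

    used_mask: bool tensor marking weights already claimed by earlier tasks.
    """
    free = ~used_mask
    scores = param.abs() * free  # only free weights compete for the mask
    k = int(free.sum().item() * keep_frac)
    if k == 0:
        return torch.zeros_like(used_mask)
    threshold = scores[free].topk(k).values.min()
    return (scores >= threshold) & free
```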

Memory Replay Methods

Gradient Episodic Memory (GEM)

Constrained Optimization Perspective

GEM [9] models continual learning as constrained optimization: $\min_\theta \mathcal{L}_t(\theta)$ subject to $\langle g, g_k \rangle \geq 0$ for all $k < t$, where $g$ is the current task gradient and $g_k$ is the gradient on old task $k$.

Intuition: the new task's gradient must not conflict with any old task's gradient (no negative inner products).

Gradient Projection

If the gradient $g$ violates the constraints, project it onto the feasible region: $\min_{\tilde{g}} \frac{1}{2} \left\| g - \tilde{g} \right\|_2^2 \ \text{s.t.}\ \langle \tilde{g}, g_k \rangle \geq 0 \ \forall k < t$. This is a quadratic program, solvable with existing solvers.

Simplified version: with only one old task ($t = 2$), the projection formula is $\tilde{g} = g - \frac{g^\top g_1}{g_1^\top g_1} g_1$ (applied when $g^\top g_1 < 0$).

Memory Buffer

GEM saves a few samples per task (e.g., 100 per task) in a memory buffer $\mathcal{M}_k$. The old task gradient $g_k$ is computed on the memory buffer.

Averaged GEM (A-GEM)

Computational Efficiency Improvement

A-GEM [10] simplifies GEM's constraints: rather than requiring a non-negative inner product with every old task gradient, it only constrains the inner product with the average gradient.

Average gradient: $g_{\text{ref}} = \frac{1}{|\mathcal{M}|} \sum_{(x, y) \in \mathcal{M}} \nabla_\theta \mathcal{L}\left( y, f_\theta(x) \right)$, computed on a batch drawn from the joint memory buffer. Constraint: $\langle g, g_{\text{ref}} \rangle \geq 0$.

Projection formula: $\tilde{g} = g - \frac{g^\top g_{\text{ref}}}{g_{\text{ref}}^\top g_{\text{ref}}} g_{\text{ref}}$ (applied when $g^\top g_{\text{ref}} < 0$)

Advantages: - Computational complexity reduced from $O(T)$ to $O(1)$ ($T$ is the number of tasks) - No need to solve a quadratic program

Disadvantages: - Looser constraints, potentially higher forgetting
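The A-GEM projection is a single dot-product test plus a rank-one correction. A minimal sketch on flattened gradient vectors (helper name is illustrative):

```python
import numpy as np

def agem_project(g, g_ref):
    """A-GEM: if the new-task gradient g conflicts with the average memory
    gradient g_ref, project out the conflicting component."""
    dot = np.dot(g, g_ref)
    if dot >= 0:
        return g  # no conflict: use g unchanged
    return g - (dot / np.dot(g_ref, g_ref)) * g_ref
```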

Experience Replay (ER)

Simplest Replay

Experience Replay [11] mixes memory samples from old tasks into training on the new task: $\mathcal{L} = \mathcal{L}_t(\theta) + \beta\, \mathbb{E}_{(x, y) \sim \mathcal{M}} \left[ \mathcal{L}\left( y, f_\theta(x) \right) \right]$, where $\mathcal{M}$ is a memory buffer sampled from all old tasks.

Sampling strategies:

  1. Uniform sampling: Equal samples per task
  2. Performance-based sampling: More samples from tasks with severe forgetting
  3. Time decay: Higher weight for recent tasks

Memory buffer management:

  • Reservoir Sampling: Equal probability of retaining all seen samples
  • Ring Buffer: Fixed size, new samples replace old ones
  • Herding: Select samples closest to class centers
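Reservoir sampling, the first buffer-management strategy above, can be sketched in a few lines; every sample in the stream ends up in the buffer with equal probability `capacity / n_seen`:

```python
import random

class ReservoirBuffer:
    """Reservoir sampling: every sample ever seen has equal probability of
    being in the buffer, without knowing the stream length in advance."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []
        self.n_seen = 0

    def add(self, sample):
        self.n_seen += 1
        if len(self.data) < self.capacity:
            self.data.append(sample)  # buffer not full yet: always keep
        else:
            # Replace a random slot with probability capacity / n_seen
            j = random.randrange(self.n_seen)
            if j < self.capacity:
                self.data[j] = sample
```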

Dark Experience Replay (DER)

Combining Knowledge Distillation and Replay

DER [12] saves in the memory buffer not only the samples $(x, y)$ but also the model's output logits $z$.

Loss function: $\mathcal{L} = \mathcal{L}_t(\theta) + \beta\, \mathbb{E}_{(x, y) \sim \mathcal{M}} \left[ \mathcal{L}_{\text{CE}}\left( y, f_\theta(x) \right) \right] + \alpha\, \mathbb{E}_{(x, z) \sim \mathcal{M}} \left[ \left\| z - f_\theta(x) \right\|_2^2 \right]$. The second term is a classification loss on the true labels of memory samples; the third term is a distillation loss that maintains the old model's outputs.

Advantage: Distillation loss mitigates overfitting to memory samples.

Meta-Learning Methods

Model-Agnostic Meta-Learning for Continual Learning

Application of MAML

MAML [13] finds a good initialization $\theta$ through meta-learning, enabling fast adaptation from $\theta$ to any task.

In continual learning, MAML can be used as:

  1. Inner loop: fast adaptation on the current task: $\theta_t' = \theta - \alpha \nabla_\theta \mathcal{L}_t(\theta)$

  2. Outer loop: update the meta-parameters $\theta$ to perform well on all tasks: $\theta \leftarrow \theta - \beta \nabla_\theta \sum_t \mathcal{L}_t\left( \theta_t' \right)$

Problem: the outer loop needs all old task data (conflicting with continual learning's no-old-data assumption).

Improvement: Meta-Experience Replay [14] performs outer-loop updates only on the memory buffer.

Online Meta-Learning (OML)

Online Meta-Learning

OML [15] updates meta-parameters online during continual learning.

The learner is split into two parts:

  • Representation network with parameters $\phi$ (slow update)
  • Prediction head with parameters $w$ (fast update)

Update strategy:

  1. When a task arrives: fast-adapt the prediction head: $w' = w - \alpha \nabla_w \mathcal{L}_t(\phi, w)$

  2. After the task ends: slowly update the representation: $\phi \leftarrow \phi - \beta \nabla_\phi \mathcal{L}_t\left( \phi, w' \right)$

Advantage: the representation network $\phi$ learns general features, while the prediction head $w$ learns task-specific knowledge.

Learning to Learn without Forgetting (Meta-LwF)

Meta-Learning Regularization

Meta-LwF [16] combines meta-learning and LwF.

Loss function: $\mathcal{L} = \mathcal{L}_t(\theta) + \lambda_1 \mathcal{L}_{\text{KD}} + \lambda_2 \left\| \theta - \theta_{\text{meta}} \right\|_2^2$. The first term is the new task loss, the second is the distillation loss (LwF), and the third is a meta-regularizer pulling $\theta$ toward the meta-parameters.

Intuition: $\theta_{\text{meta}}$ is a set of "universal parameters" from meta-learning; the task-specific parameters $\theta$ should stay close to $\theta_{\text{meta}}$.

Complete Code Implementation: EWC from Scratch

Below is a complete EWC framework implementation including Fisher information computation, multi-task training, forgetting evaluation, and visualization.

"""
EWC from Scratch: Elastic Weight Consolidation
Includes: Fisher information computation, multi-task training, forgetting evaluation
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
from copy import deepcopy

# Set random seed
torch.manual_seed(42)
np.random.seed(42)

# ============================================================================
# Simple MLP Model
# ============================================================================

class SimpleMLP(nn.Module):
    """
    Simple multi-layer perceptron
    """
    def __init__(self, input_dim: int = 784, hidden_dims: List[int] = [256, 256], output_dim: int = 10):
        super().__init__()

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, output_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

# ============================================================================
# EWC Implementation
# ============================================================================

class EWC:
    """
    Elastic Weight Consolidation
    """
    def __init__(self, model: nn.Module, dataloader: DataLoader, device: str = 'cpu'):
        self.model = model
        self.dataloader = dataloader
        self.device = device

        # Store Fisher information and parameters for each task
        self.fisher_dict: Dict[int, Dict[str, torch.Tensor]] = {}
        self.optpar_dict: Dict[int, Dict[str, torch.Tensor]] = {}

    def compute_fisher(self, task_id: int):
        """
        Compute diagonal elements of the Fisher information matrix
        (empirical Fisher: squared gradients of the loss on the true labels)
        """
        self.model.eval()

        # Initialize Fisher information
        fisher = {}
        for name, param in self.model.named_parameters():
            fisher[name] = torch.zeros_like(param)

        # Accumulate squared gradients over the data
        num_samples = 0
        for inputs, targets in self.dataloader:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            self.model.zero_grad()

            # Forward pass
            outputs = self.model(inputs)

            # Negative log-likelihood
            loss = F.cross_entropy(outputs, targets)

            # Backward pass
            loss.backward()

            # Accumulate squared gradients
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    fisher[name] += param.grad.data ** 2

            num_samples += inputs.size(0)

        # Average Fisher information
        for name in fisher:
            fisher[name] /= num_samples

        # Store Fisher information
        self.fisher_dict[task_id] = fisher

        # Store current parameters as the anchor for the quadratic penalty
        optpar = {}
        for name, param in self.model.named_parameters():
            optpar[name] = param.data.clone()
        self.optpar_dict[task_id] = optpar

        print(f"Fisher information computed for task {task_id}")

    def penalty(self) -> torch.Tensor:
        """
        Compute the EWC penalty term over all stored tasks
        """
        loss = 0.0

        for task_id in self.fisher_dict:
            for name, param in self.model.named_parameters():
                fisher = self.fisher_dict[task_id][name]
                optpar = self.optpar_dict[task_id][name]
                loss += (fisher * (param - optpar) ** 2).sum()

        return loss

# ============================================================================
# Training Function
# ============================================================================

def train_task(
    model: nn.Module,
    dataloader: DataLoader,
    ewc: EWC,
    ewc_lambda: float,
    optimizer: optim.Optimizer,
    device: str,
    num_epochs: int = 10
) -> List[float]:
    """
    Train on a single task (with EWC regularization)
    """
    model.train()
    losses = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0

        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(inputs)

            # Task loss
            task_loss = F.cross_entropy(outputs, targets)

            # EWC penalty (zero until Fisher info exists for at least one task)
            ewc_loss = ewc.penalty() if len(ewc.fisher_dict) > 0 else 0.0

            # Total loss
            loss = task_loss + ewc_lambda * ewc_loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)

        if (epoch + 1) % 2 == 0:
            print(f"  Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    return losses

def evaluate_task(model: nn.Module, dataloader: DataLoader, device: str) -> float:
    """
    Evaluate task accuracy
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            correct += (predicted == targets).sum().item()
            total += targets.size(0)

    accuracy = 100 * correct / total
    return accuracy

# ============================================================================
# Generate Multi-Task Dataset (Permuted MNIST)
# ============================================================================

def create_permuted_mnist_tasks(num_tasks: int = 5, num_samples: int = 1000) -> List[Tuple[DataLoader, DataLoader]]:
    """
    Create a Permuted MNIST task sequence.
    Each task is a random pixel permutation of MNIST.
    """
    tasks = []

    for task_id in range(num_tasks):
        # Generate random data standing in for MNIST (simplified version)
        X_train = torch.randn(num_samples, 784)
        y_train = torch.randint(0, 10, (num_samples,))

        X_test = torch.randn(200, 784)
        y_test = torch.randint(0, 10, (200,))

        # Apply a random permutation (simulating Permuted MNIST)
        if task_id > 0:
            perm = torch.randperm(784)
            X_train = X_train[:, perm]
            X_test = X_test[:, perm]

        # Create DataLoaders
        train_dataset = TensorDataset(X_train, y_train)
        test_dataset = TensorDataset(X_test, y_test)

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

        tasks.append((train_loader, test_loader))

    return tasks

# ============================================================================
# Main Experiment: Compare Baseline and EWC
# ============================================================================

def run_continual_learning_experiment(
    num_tasks: int = 5,
    ewc_lambda: float = 5000.0,
    num_epochs: int = 10,
    device: str = 'cpu'
):
    """
    Run the continual learning experiment
    """
    print("="*70)
    print("Continual Learning Experiment: Baseline vs EWC")
    print("="*70)

    # Create task sequence
    print(f"\nCreating {num_tasks} permuted MNIST tasks...")
    tasks = create_permuted_mnist_tasks(num_tasks=num_tasks)

    # ========================================================================
    # Method 1: Baseline (No Regularization)
    # ========================================================================
    print("\n" + "="*70)
    print("Method 1: Baseline (No Regularization)")
    print("="*70)

    model_baseline = SimpleMLP().to(device)
    baseline_accuracies = np.zeros((num_tasks, num_tasks))

    for task_id in range(num_tasks):
        print(f"\n--- Training Task {task_id+1} ---")

        train_loader, _ = tasks[task_id]
        optimizer = optim.SGD(model_baseline.parameters(), lr=0.01, momentum=0.9)

        # A fresh EWC object with ewc_lambda=0.0 means no regularization is applied
        train_task(model_baseline, train_loader, EWC(model_baseline, train_loader, device),
                   ewc_lambda=0.0, optimizer=optimizer, device=device, num_epochs=num_epochs)

        # Evaluate all tasks seen so far
        print(f"\nEvaluating all tasks after training Task {task_id+1}:")
        for eval_task_id in range(task_id + 1):
            _, test_loader = tasks[eval_task_id]
            acc = evaluate_task(model_baseline, test_loader, device)
            baseline_accuracies[task_id, eval_task_id] = acc
            print(f"  Task {eval_task_id+1}: {acc:.2f}%")

    # ========================================================================
    # Method 2: EWC
    # ========================================================================
    print("\n" + "="*70)
    print("Method 2: EWC")
    print("="*70)

    model_ewc = SimpleMLP().to(device)
    ewc = EWC(model_ewc, tasks[0][0], device)  # Initialize EWC
    ewc_accuracies = np.zeros((num_tasks, num_tasks))

    for task_id in range(num_tasks):
        print(f"\n--- Training Task {task_id+1} ---")

        train_loader, _ = tasks[task_id]
        optimizer = optim.SGD(model_ewc.parameters(), lr=0.01, momentum=0.9)

        train_task(model_ewc, train_loader, ewc, ewc_lambda=ewc_lambda,
                   optimizer=optimizer, device=device, num_epochs=num_epochs)

        # Compute Fisher information on the task just finished
        ewc.dataloader = train_loader
        ewc.compute_fisher(task_id)

        # Evaluate all tasks seen so far
        print(f"\nEvaluating all tasks after training Task {task_id+1}:")
        for eval_task_id in range(task_id + 1):
            _, test_loader = tasks[eval_task_id]
            acc = evaluate_task(model_ewc, test_loader, device)
            ewc_accuracies[task_id, eval_task_id] = acc
            print(f"  Task {eval_task_id+1}: {acc:.2f}%")

    return baseline_accuracies, ewc_accuracies

# ============================================================================
# Visualization
# ============================================================================

def plot_continual_learning_results(baseline_acc: np.ndarray, ewc_acc: np.ndarray):
    """
    Plot continual learning results
    """
    num_tasks = baseline_acc.shape[0]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # 1. Accuracy heatmap (Baseline)
    im1 = axes[0].imshow(baseline_acc, cmap='YlGnBu', vmin=0, vmax=100)
    axes[0].set_xlabel('Evaluated Task', fontsize=12)
    axes[0].set_ylabel('After Training Task', fontsize=12)
    axes[0].set_title('Baseline: Accuracy Heatmap (%)', fontsize=14, fontweight='bold')
    axes[0].set_xticks(range(num_tasks))
    axes[0].set_yticks(range(num_tasks))
    axes[0].set_xticklabels([f'T{i+1}' for i in range(num_tasks)])
    axes[0].set_yticklabels([f'T{i+1}' for i in range(num_tasks)])

    # Add numerical annotations
    for i in range(num_tasks):
        for j in range(i + 1):
            axes[0].text(j, i, f'{baseline_acc[i, j]:.1f}',
                         ha="center", va="center", color="black", fontsize=10)

    plt.colorbar(im1, ax=axes[0])

    # 2. Accuracy heatmap (EWC)
    im2 = axes[1].imshow(ewc_acc, cmap='YlGnBu', vmin=0, vmax=100)
    axes[1].set_xlabel('Evaluated Task', fontsize=12)
    axes[1].set_ylabel('After Training Task', fontsize=12)
    axes[1].set_title('EWC: Accuracy Heatmap (%)', fontsize=14, fontweight='bold')
    axes[1].set_xticks(range(num_tasks))
    axes[1].set_yticks(range(num_tasks))
    axes[1].set_xticklabels([f'T{i+1}' for i in range(num_tasks)])
    axes[1].set_yticklabels([f'T{i+1}' for i in range(num_tasks)])

    # Add numerical annotations
    for i in range(num_tasks):
        for j in range(i + 1):
            axes[1].text(j, i, f'{ewc_acc[i, j]:.1f}',
                         ha="center", va="center", color="black", fontsize=10)

    plt.colorbar(im2, ax=axes[1])

    # 3. Average accuracy and forgetting comparison
    avg_acc_baseline = [baseline_acc[i, :i+1].mean() for i in range(num_tasks)]
    avg_acc_ewc = [ewc_acc[i, :i+1].mean() for i in range(num_tasks)]

    # Forgetting: accuracy drop on the first task
    forgetting_baseline = [baseline_acc[0, 0] - baseline_acc[i, 0] for i in range(num_tasks)]
    forgetting_ewc = [ewc_acc[0, 0] - ewc_acc[i, 0] for i in range(num_tasks)]

    x = np.arange(1, num_tasks + 1)

    ax3_1 = axes[2]
    ax3_1.plot(x, avg_acc_baseline, marker='o', label='Baseline - Avg Acc',
               linewidth=2, color='tab:blue')
    ax3_1.plot(x, avg_acc_ewc, marker='s', label='EWC - Avg Acc',
               linewidth=2, color='tab:green')
    ax3_1.set_xlabel('Number of Tasks Trained', fontsize=12)
    ax3_1.set_ylabel('Average Accuracy (%)', fontsize=12, color='tab:blue')
    ax3_1.tick_params(axis='y', labelcolor='tab:blue')
    ax3_1.legend(loc='upper left')
    ax3_1.grid(True, alpha=0.3)

    ax3_2 = ax3_1.twinx()
    ax3_2.plot(x, forgetting_baseline, marker='o', label='Baseline - Forgetting',
               linewidth=2, linestyle='--', color='tab:red')
    ax3_2.plot(x, forgetting_ewc, marker='s', label='EWC - Forgetting',
               linewidth=2, linestyle='--', color='tab:orange')
    ax3_2.set_ylabel('Forgetting on Task 1 (%)', fontsize=12, color='tab:red')
    ax3_2.tick_params(axis='y', labelcolor='tab:red')
    ax3_2.legend(loc='upper right')

    axes[2].set_title('Average Accuracy & Forgetting', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.savefig('ewc_continual_learning.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("\nResults saved to ewc_continual_learning.png")

# ============================================================================
# Main Function
# ============================================================================

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Run experiment
    baseline_acc, ewc_acc = run_continual_learning_experiment(
        num_tasks=5,
        ewc_lambda=5000.0,
        num_epochs=10,
        device=device
    )

    # Compute final metrics
    print("\n" + "="*70)
    print("Final Results")
    print("="*70)

    # Average accuracy
    avg_acc_baseline = baseline_acc[-1, :].mean()
    avg_acc_ewc = ewc_acc[-1, :].mean()
    print("\nAverage Accuracy (after all tasks):")
    print(f"  Baseline: {avg_acc_baseline:.2f}%")
    print(f"  EWC: {avg_acc_ewc:.2f}%")
    print(f"  Improvement: {avg_acc_ewc - avg_acc_baseline:.2f}%")

    # Forgetting (using the first task as an example)
    forgetting_baseline = baseline_acc[0, 0] - baseline_acc[-1, 0]
    forgetting_ewc = ewc_acc[0, 0] - ewc_acc[-1, 0]
    print("\nForgetting on Task 1:")
    print(f"  Baseline: {forgetting_baseline:.2f}%")
    print(f"  EWC: {forgetting_ewc:.2f}%")
    print(f"  Reduction: {forgetting_baseline - forgetting_ewc:.2f}%")

    # Plot
    plot_continual_learning_results(baseline_acc, ewc_acc)

    print("\n" + "="*70)
    print("Experiment completed!")
    print("="*70)

if __name__ == "__main__":
    main()

Code Explanation

Core Components:

  1. EWC class:
    • compute_fisher(): Compute Fisher information matrix
    • penalty(): Compute EWC regularization term
  2. Training process:
    • Task sequence: Permuted MNIST (each task is random pixel permutation of MNIST)
    • Compare Baseline (no regularization) and EWC
  3. Evaluation metrics:
    • Accuracy heatmap: Show performance of each task at different time points
    • Average accuracy: Average performance across all tasks
    • Forgetting: Accuracy drop on first task

Key Details:

  • Fisher information computed after each task ends
  • EWC penalty accumulates constraints from all old tasks
  • Learning rate and EWC strength $\lambda$ need tuning

Frontiers of Continual Learning

Theoretical Analysis

Stability-Plasticity Dilemma

Continual learning faces a fundamental dilemma [17]:

  • Stability: maintain old knowledge → reduce forgetting
  • Plasticity: learn new knowledge → adapt to new tasks

A trade-off exists: a regularization strength that increases stability typically decreases plasticity. The optimal balance depends on task similarity, data distribution drift, and related factors.

Memory Capacity Analysis

Network memory capacity is the maximum number of tasks learnable without catastrophic forgetting [18].

For a network with $N$ parameters, the memory capacity upper bound grows at most linearly with $N$.

Intuition: each parameter can "remember" only a bounded number of bits of information on average.

Latest Methods

Orthogonal Gradient Descent (OGD)

OGD [19] projects new task gradients onto the subspace orthogonal to old task gradients: $\tilde{g} = g - \sum_k \frac{\langle g, v_k \rangle}{\langle v_k, v_k \rangle} v_k$, where $\{v_k\}$ is an orthogonalized basis of the stored old task gradients.

Advantage: Completely eliminates gradient conflicts.

Disadvantage: Need to store gradients of all old tasks.
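A minimal sketch of the OGD projection: orthonormalize the stored old-task gradients, then subtract the new gradient's components along them. The helper name is illustrative; practical OGD keeps only a limited gradient basis:

```python
import numpy as np

def ogd_project(g, old_grads):
    """Project gradient g onto the subspace orthogonal to old-task gradients."""
    g = np.asarray(g, dtype=float).copy()
    basis = []
    for v in old_grads:  # Gram-Schmidt orthonormalization of stored gradients
        v = np.asarray(v, dtype=float).copy()
        for b in basis:
            v = v - np.dot(v, b) * b
        norm = np.linalg.norm(v)
        if norm > 1e-12:
            basis.append(v / norm)
    for b in basis:  # remove components along old-task directions
        g = g - np.dot(g, b) * b
    return g
```

By construction, the projected gradient has zero inner product with every stored old-task gradient, eliminating gradient conflicts.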

Continual Backprop

Continual Backprop [20] modifies the backpropagation update so that only parameters important for the current task are updated.

Update rule: $\theta \leftarrow \theta - \eta\, m_t \odot \nabla_\theta \mathcal{L}_t(\theta)$, where $m_t$ is a parameter mask for task $t$, learned automatically.

Supermasks in Superposition (SupSup)

SupSup [21] learns a binary mask for each task while all tasks share the same parameters: $f\left( x; \theta \odot M_t \right)$. During training, only the mask $M_t$ is optimized; the parameters $\theta$ stay fixed after random initialization.

Surprising finding: a randomly initialized network can match multi-task learning performance purely through different masks!

Benchmarks and Evaluation

Standard Benchmarks

  1. Permuted MNIST: Random pixel permutation of MNIST
  2. Split CIFAR: CIFAR-10 split by classes into multiple tasks
  3. CORe50: 50 objects with images from different scenes
  4. Continual Reinforcement Learning: Atari game sequences

Evaluation Protocol

Standard evaluation includes:

  1. Within-task performance: Upper bound performance when each task trained separately
  2. Average accuracy: Average performance across all tasks
  3. Backward transfer: Change in old task performance after learning new tasks
  4. Forward transfer: Benefit of old tasks to new tasks
  5. Parameter efficiency: Growth rate of model size with task count

Frequently Asked Questions

Q1: How to choose EWC's $\lambda$?

Empirical rules: the appropriate $\lambda$ is benchmark-dependent; small tasks (e.g., Permuted MNIST), medium tasks (e.g., Split CIFAR), and large tasks (e.g., ImageNet subsets) each need their own value. The experiment above uses $\lambda = 5000$ for Permuted MNIST.

Tuning strategy:

  1. Start with a smaller $\lambda$ (e.g., 100) for testing
  2. Observe forgetting: if forgetting is severe, increase $\lambda$
  3. Select the final $\lambda$ on a validation set

Q2: When to compute Fisher information?

Recommended: Immediately after each task finishes training.

Reason: Model reaches optimum on that task, Fisher information most accurate.

Note: If task data volume is large, can compute Fisher information only on subset (e.g., 100 samples per class).

Q3: How to choose between EWC, MAS, SI?

| Scenario | Recommended Method |
| --- | --- |
| Supervised learning | EWC |
| Unsupervised learning | MAS |
| Online learning (update while training) | SI |
| Need label-free importance computation | MAS |

Performance comparison: EWC ≈ MAS > SI (on most benchmarks)

Q4: How large should memory buffer be?

Empirical values:

  • Per task: 50-200 samples
  • Total buffer: 500-2000 samples

Trade-off: - Larger buffer → Less forgetting, but high storage and computation overhead - Smaller buffer → More efficient, but potentially severe forgetting

Optimal strategy: Dynamically adjust based on available memory and task count.

Q5: Which is better, GEM or A-GEM?

| Dimension | GEM | A-GEM |
| --- | --- | --- |
| Forgetting control | Stricter | Looser |
| Computational complexity | High (quadratic programming) | Low (linear projection) |
| Scalability | Poor (slow with many tasks) | Good |
| Implementation difficulty | Difficult | Simple |

Recommendation: Unless extremely sensitive to forgetting, prefer A-GEM (similar performance but much faster).

Q6: What are disadvantages of dynamic architecture methods?

Progressive Networks: - Model size grows linearly: $T$ tasks require $T$ times the parameters - Inference time grows linearly - Difficult to deploy (model too large)

DEN: - High training complexity (needs dynamic expansion decisions) - Sensitive to hyperparameters (expansion threshold, sparsity coefficient) - May overexpand

PackNet: - Later task performance limited (available parameters decrease) - Pruning strategy has significant impact - Upper limit on task count (parameters exhausted)

Trade-off: Dynamic architectures completely avoid forgetting, but sacrifice efficiency and scalability.

Q7: Role of meta-learning in continual learning?

Advantages: - Learn general representations, reduce task-specific parameters - Support fast adaptation to new tasks - Theoretically elegant (minimize meta-loss across all tasks)

Disadvantages: - Need meta-training phase (task distribution known) - High computational complexity (second-order derivatives) - Limited performance improvement in practice

Applicable scenarios: High task similarity, need fast adaptation scenarios (e.g., few-shot learning).

Q8: Difference between continual learning and multi-task learning?

| Dimension | Continual Learning | Multi-Task Learning |
| --- | --- | --- |
| Task visibility | Tasks arrive sequentially | All tasks visible simultaneously |
| Data access | Cannot access old data | All data available |
| Main challenge | Catastrophic forgetting | Task balancing |
| Goal | Maintain old-task performance | Average performance across all tasks |

Connection: Continual learning can be seen as data-constrained multi-task learning.

Q9: How to handle Class-Incremental Learning (Class-IL)?

Class-IL is the most difficult scenario, requiring special handling:

  1. Output layer expansion: Add new class output neurons for each new task
  2. Bias correction: new-class logits are typically smaller (the new head is not yet fully trained) and need rescaling
  3. Knowledge distillation: Maintain old class output distributions
  4. Memory replay: Mix samples from old classes

Recommended methods: iCaRL [22], LUCIR [23].
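Steps 1 and 3 above can be sketched in PyTorch. The `IncrementalClassifier` and `distill_loss` names are illustrative (not from the article's code), and the sketch assumes a frozen copy of the previous model supplies `old_logits`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class IncrementalClassifier(nn.Module):
    """Linear head whose output layer is expanded when new classes arrive;
    old-class weights are copied over so existing logits are preserved."""

    def __init__(self, feat_dim, n_classes):
        super().__init__()
        self.fc = nn.Linear(feat_dim, n_classes)

    def expand(self, n_new):
        old = self.fc
        self.fc = nn.Linear(old.in_features, old.out_features + n_new)
        with torch.no_grad():  # keep old-class logits intact
            self.fc.weight[:old.out_features] = old.weight
            self.fc.bias[:old.out_features] = old.bias

    def forward(self, x):
        return self.fc(x)

def distill_loss(new_logits, old_logits, T=2.0):
    """Knowledge-distillation term (step 3): keep the old-class output
    distribution close to that of the frozen previous model."""
    n_old = old_logits.size(1)
    p_old = F.softmax(old_logits / T, dim=1)
    log_p_new = F.log_softmax(new_logits[:, :n_old] / T, dim=1)
    return F.kl_div(log_p_new, p_old, reduction="batchmean") * T * T
```

In training, the total loss would combine cross-entropy on new classes with `distill_loss` on the old-class slice, optionally mixed with replayed samples (step 4).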

Q10: Can continual learning be used in production?

Challenges:

  1. Forgetting unacceptable: Production requires strict performance guarantees
  2. Inference latency: Dynamic architecture methods slow inference
  3. Model update frequency: New tasks arrive quickly

Practical strategies:

  1. Hybrid methods: EWC + small amount of memory replay
  2. Periodic full fine-tuning: every N tasks, run a full fine-tune using the memory buffer
  3. A/B testing: Continual learning model runs in parallel with old model, compare performance
  4. Fallback mechanism: If new task learning fails, rollback to old model

Success cases: Incremental updates for recommendation systems, speech recognition, image classification.
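Strategy 1 (EWC plus a small replay buffer) amounts to summing the two objectives. A minimal sketch, assuming `fisher` and `old_params` come from an EWC implementation like the one earlier in the article and `replay_loss` is a loss computed on a small memory batch; the function name is illustrative:

```python
import torch

def hybrid_loss(ce_new, model, fisher, old_params, replay_loss=None,
                ewc_lambda=100.0, replay_weight=0.5):
    """Hybrid objective: new-task loss + EWC quadratic penalty
    + (optionally) a replay loss on a small memory batch."""
    penalty = 0.0
    for name, p in model.named_parameters():
        if name in fisher:
            # Penalize deviation from old parameters, weighted by Fisher info
            penalty = penalty + (fisher[name] * (p - old_params[name]).pow(2)).sum()
    loss = ce_new + 0.5 * ewc_lambda * penalty
    if replay_loss is not None:
        loss = loss + replay_weight * replay_loss
    return loss
```

The regularizer protects parameters the old tasks deemed important, while the replay term anchors the decision boundary directly on stored samples; in practice the two mechanisms are complementary.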

Q11: How to debug continual learning models?

Diagnostic steps:

  1. Check single-task performance: train each task separately to confirm baseline performance
  2. Check gradient conflicts: compute inner products between different tasks' gradients and look for negative values
  3. Visualize Fisher information: see which parameters are marked as important
  4. Monitor forgetting curves: plot each task's accuracy over time
  5. Ablation experiments: remove regularization/memory replay and measure the performance drop
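Diagnostic step 2 can be sketched as follows; `gradient_conflict` is an illustrative helper that returns the cosine similarity between the gradients of two task losses:

```python
import torch

def gradient_conflict(model, loss_a, loss_b):
    """Cosine similarity between the gradients of two task losses.
    A negative value means the tasks pull shared parameters in opposing
    directions, a common precursor of forgetting."""
    params = [p for p in model.parameters() if p.requires_grad]
    g_a = torch.autograd.grad(loss_a, params, retain_graph=True)
    g_b = torch.autograd.grad(loss_b, params, retain_graph=True)
    flat_a = torch.cat([g.flatten() for g in g_a])
    flat_b = torch.cat([g.flatten() for g in g_b])
    return torch.dot(flat_a, flat_b) / (flat_a.norm() * flat_b.norm() + 1e-12)
```

Logging this value across training is a cheap way to see whether a new task is actively interfering with old ones before any accuracy drop becomes visible.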

Q12: What are theoretical limits of continual learning?

Information-theoretic limits [24]:

For a network with $N$ parameters, learning $T$ tasks without forgetting requires roughly

$$N \gtrsim \sum_{t=1}^{T} I(\mathcal{T}_t)$$

where $I(\mathcal{T}_t)$ is the mutual information of task $\mathcal{T}_t$.

Intuition: Network capacity is limited, too many tasks inevitably cause forgetting.

Breakthrough directions:

  • Leverage task similarity (shared representations)
  • Compress old-task knowledge (knowledge distillation)
  • Dynamically expand capacity (architecture search)

Summary

This article comprehensively introduced continual learning techniques:

  1. Problem definition: Mathematical mechanisms of catastrophic forgetting and evaluation metrics
  2. Regularization methods: Principles and comparisons of EWC, MAS, SI, LwF
  3. Dynamic architectures: Designs of Progressive Networks, DEN, PackNet
  4. Memory replay: Strategies of GEM, A-GEM, ER, DER
  5. Meta-learning: Applications of MAML, OML in continual learning
  6. Complete code: 250+ lines of production-level code implementing EWC from scratch
  7. Frontiers: Stability-plasticity dilemma, memory capacity theory, latest methods

Continual learning enables models to have lifelong learning capabilities, an important foundation for artificial general intelligence. In the next chapter, we will explore cross-lingual transfer and see how models can seamlessly transfer knowledge across different languages.

References


  1. Kirkpatrick, J., Pascanu, R., Rabinowitz, N., et al. (2017). Overcoming catastrophic forgetting in neural networks. PNAS.↩︎

  2. Schwarz, J., Czarnecki, W., Luketina, J., et al. (2018). Progress & compress: A scalable framework for continual learning. ICML.↩︎

  3. Aljundi, R., Babiloni, F., Elhoseiny, M., et al. (2018). Memory aware synapses: Learning what (not) to forget. ECCV.↩︎

  4. Zenke, F., Poole, B., & Ganguli, S. (2017). Continual learning through synaptic intelligence. ICML.↩︎

  5. Li, Z., & Hoiem, D. (2017). Learning without forgetting. TPAMI.↩︎

  6. Rusu, A. A., Rabinowitz, N. C., Desjardins, G., et al. (2016). Progressive neural networks. arXiv:1606.04671.↩︎

  7. Yoon, J., Yang, E., Lee, J., & Hwang, S. J. (2018). Lifelong learning with dynamically expandable networks. ICLR.↩︎

  8. Mallya, A., & Lazebnik, S. (2018). PackNet: Adding multiple tasks to a single network by iterative pruning. CVPR.↩︎

  9. Lopez-Paz, D., & Ranzato, M. (2017). Gradient episodic memory for continual learning. NeurIPS.↩︎

  10. Chaudhry, A., Ranzato, M., Rohrbach, M., & Elhoseiny, M. (2019). Efficient lifelong learning with A-GEM. ICLR.↩︎

  11. Robins, A. (1995). Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science.↩︎

  12. Buzzega, P., Boschini, M., Porrello, A., et al. (2020). Dark experience for general continual learning: A strong, simple baseline. NeurIPS.↩︎

  13. Finn, C., Abbeel, P., & Levine, S. (2017). Model-agnostic meta-learning for fast adaptation of deep networks. ICML.↩︎

  14. Riemer, M., Cases, I., Ajemian, R., et al. (2019). Learning to learn without forgetting by maximizing transfer and minimizing interference. ICLR.↩︎

  15. Javed, K., & White, M. (2019). Meta-learning representations for continual learning. NeurIPS.↩︎

  16. Beaulieu, S., Frati, L., Miconi, T., et al. (2020). Learning to continually learn rapidly from few and noisy data. arXiv:2006.10220.↩︎

  17. Abraham, W. C., & Robins, A. (2005). Memory retention – the synaptic stability versus plasticity dilemma. Trends in Neurosciences.↩︎

  18. French, R. M. (1999). Catastrophic forgetting in connectionist networks. Trends in Cognitive Sciences.↩︎

  19. Farajtabar, M., Azizan, N., Mott, A., & Li, A. (2020). Orthogonal gradient descent for continual learning. AISTATS.↩︎

  20. Golkar, S., Kagan, M., & Cho, K. (2019). Continual learning via neural pruning. arXiv:1903.04476.↩︎

  21. Wortsman, M., Ramanujan, V., Liu, R., et al. (2020). Supermasks in superposition. NeurIPS.↩︎

  22. Rebuffi, S. A., Kolesnikov, A., Sperl, G., & Lampert, C. H. (2017). iCaRL: Incremental classifier and representation learning. CVPR.↩︎

  23. Hou, S., Pan, X., Loy, C. C., et al. (2019). Learning a unified classifier incrementally via rebalancing. CVPR.↩︎

  24. Farquhar, S., & Gal, Y. (2018). Towards robust evaluations of continual learning. arXiv:1805.09733.↩︎

  • Post title: Transfer Learning (10): Continual Learning
  • Post author: Chen Kai
  • Create time: 2024-12-27 15:00:00
  • Post link: https://www.chenk.top/transfer-learning-10-continual-learning/
  • Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.