迁移学习(十)—— 持续学习
Chen Kai BOSS

人类可以不断学习新技能而不忘记旧知识,但神经网络在学习新任务时却常常"健忘"——这就是灾难性遗忘( Catastrophic Forgetting)。如何让模型像人一样终身学习,在掌握 100 个任务后依然记得第 1 个任务?持续学习( Continual Learning)给出了答案。

本文从灾难性遗忘的数学机理出发,系统讲解正则化、动态架构、记忆重放、元学习四大类方法的原理与实现,深入分析参数重要性估计、任务间知识迁移与遗忘-稳定性权衡,并提供从零实现 EWC 的完整代码( 250+行)。

持续学习的问题定义

任务序列

持续学习处理按时间顺序到达的任务序列:

每个任务 定义为一个学习问题:

其中:

  • 是任务数据
  • 是任务损失函数
  • 是任务 的模型(参数为

关键约束:学习任务 时,无法访问之前任务的数据

灾难性遗忘

现象:在任务 上训练后,模型在之前任务 上的性能急剧下降。

形式化为:

其中:

  • :学习任务 后立即在任务 上的准确率
  • :学习任务 后在任务 上的准确率

目标:最小化遗忘,同时保持对新任务的学习能力。

持续学习的三种场景

  1. 任务增量学习( Task-Incremental Learning)
    • 推理时已知任务 ID
    • 每个任务有独立的输出头
    • 示例: MNIST → FashionMNIST → CIFAR10
  2. 领域增量学习( Domain-Incremental Learning)
    • 同一任务,不同数据分布
    • 示例:晴天图像 → 雨天图像 → 夜间图像
  3. 类别增量学习( Class-Incremental Learning)
    • 推理时未知任务 ID
    • 需要从所有学过的类别中预测
    • 最困难的场景

评估指标

  1. 平均准确率( Average Accuracy)

  2. 平均遗忘度( Average Forgetting)

  3. 后向迁移( Backward Transfer)

    • :遗忘
    • :正向迁移
  4. 前向迁移( Forward Transfer)

    • :零样本性能(学习任务 前在任务 上的准确率)

灾难性遗忘的数学机理

损失曲面视角

神经网络的损失函数定义在高维参数空间

训练过程是优化 找到损失曲面的局部最优:

问题

  1. 任务 1 的最优点 任务 2 的最优点 通常相距甚远
  2. 优化到 会离开任务 1 的低损失区域
  3. 梯度下降对参数的更新具有全局性,难以局部调整

梯度干扰

任务 1 和任务 2 的梯度可能相互冲突:

直觉:改善任务 2 的参数更新会恶化任务 1 的性能。

定义梯度冲突度

Extra close brace or missing open brace\text{Conflict} = \frac{|\{(i,j): g_i \cdot g_j < 0} |}{T(T-1)/2}

其中 Missing open brace for subscript g_i = __i() 是任务 的梯度。

权重重要性

并非所有参数对旧任务都同等重要。定义参数 对任务 重要性

洞察:保护重要参数( 大),允许不重要参数改变。

Fisher 信息矩阵

Fisher 信息矩阵衡量参数对损失的敏感度:

$$

F = _{(x,y) } $$

对角元素 表示参数 的重要性:

$$

F_{ii} = $$

性质: - 大:参数 对预测影响大,应该保护 - 小:参数 对预测影响小,可以修改

正则化方法

Elastic Weight Consolidation (EWC)

EWC 的核心思想

EWC1在学习新任务时,对重要参数施加正则化约束,防止它们偏离旧任务的最优值。

目标函数:

其中: - :新任务 B 的损失 - :旧任务 A 的最优参数 - :参数 的 Fisher 信息(重要性) - :正则化强度

直觉:限制重要参数的变化,保护旧任务的知识。

Fisher 信息的计算

对于分类任务, Fisher 信息矩阵的对角元素为:

$$

F_i = _{n=1}^{N} ( )^2 $$

实践中,在任务 A 的数据上计算:

  1. 前向传播得到预测 2. 计算对数似然 3. 反向传播得到梯度 $ g_i = F_i = [g_i^2]$

多任务扩展

学习任务序列 时, EWC 的目标函数为:

问题: Fisher 信息的累积导致参数越来越"僵化"。

改进: Online EWC2,只保留当前 Fisher 信息和参数,避免累积:

其中:

是衰减因子(如 0.9)。

Memory Aware Synapses (MAS)

MAS 的改进

MAS3指出 EWC 的局限: Fisher 信息只考虑最后一层的梯度,忽略了中间层的重要性。

MAS 的参数重要性定义为:

注意这里是输出对参数的梯度,而非损失对参数的梯度。

目标函数:

MAS vs EWC

维度 EWC MAS
重要性度量 Fisher 信息(梯度平方) 输出敏感度(梯度绝对值)
计算依赖 需要标签 无需标签
适用场景 监督学习 无监督/自监督
计算复杂度 中等

Synaptic Intelligence (SI)

SI 的在线更新

SI4在训练过程中在线计算参数重要性,而非在任务结束后。

参数重要性:

其中: - :第 步的参数更新 - :第 步的梯度 - :参数的总变化量 - :防止除零的小常数

直觉:参数在训练过程中移动的"路径长度"越长且对应的损失降低越多,该参数越重要。

目标函数:

Learning without Forgetting (LwF)

知识蒸馏的应用

LwF5利用知识蒸馏保持旧任务的输出分布。

损失函数包含两部分:

  1. 新任务损失2. 蒸馏损失(保持旧任务的输出):

总损失:

优势:无需保存旧任务数据,只需旧模型的预测。

劣势:需要保存旧模型的副本(存储开销)。

动态架构方法

Progressive Neural Networks

渐进式扩展

Progressive Networks6为每个新任务添加新的网络列:

$$

h_t^{(l)} = f( W_t^{(l)} h_t^{(l-1)} + {i<t} U{i t}^{(l)} h_i^{(l-1)} ) $$

其中: - :任务 在第 层的激活 - :任务 的权重(可训练) - :任务 到任务 的横向连接(可训练) - 旧任务的参数 )完全冻结

优势: - 完全避免灾难性遗忘(旧参数不变) - 支持前向迁移(横向连接利用旧知识)

劣势: - 模型大小线性增长: 个任务需要 个网络列 - 推理开销随任务数增加

Dynamically Expandable Networks (DEN)

动态扩展策略

DEN7根据需要动态扩展网络容量:

  1. 选择性重训练( Selective Retraining)
    • 冻结重要参数
    • 只微调不重要参数
  2. 动态扩展( Dynamic Expansion)
    • 如果现有容量不足,添加新神经元
    • 决策标准:验证损失不再下降
  3. 网络分裂( Network Split/Duplication)
    • 复制神经元并加入噪声
    • 增加模型容量而不破坏旧知识

扩展算法

对于第 层,添加 个新神经元:

$$

W^{(l)}_{} =

$$

其中 是新神经元的权重。

稀疏正则化

为了避免过度扩展, DEN 使用 正则化:

第一项鼓励稀疏,第二项保护旧知识。

PackNet

二值掩码的巧妙设计

PackNet8通过二值掩码为每个任务分配不同的参数子集:

其中 Extra close brace or missing open braceM_t \in \{0, 1} ^{|\theta|} 是任务 的掩码。

关键约束:不同任务的掩码不重叠:

$$

M_i M_j = 0, i j $$

训练流程

  1. 任务 到达时,冻结已被之前任务使用的参数: $$

M_{} = M_1 M_2 M_{t-1}

M_{} = 1 - M_{} $$3. 训练后通过剪枝确定任务 的掩码 : - 保留重要参数(如权重绝对值最大的前 ) - 其余参数可供未来任务使用

优势: - 参数复用率高 - 模型大小固定 - 完全避免遗忘

劣势: - 可用参数逐渐减少 - 后期任务性能受限

记忆重放方法

Gradient Episodic Memory (GEM)

约束优化视角

GEM9将持续学习建模为约束优化问题:

其中 $ g_t = _t() g_i = _i()$ 是旧任务的梯度。

直觉:新任务的梯度不能与旧任务的梯度冲突(负内积)。

梯度投影

如果梯度 违反约束,将其投影到可行域:

$$

g_t' = _{g'} |g' - g_t|^2 g', g_i , i < t $$

这是一个二次规划问题,可以用现成求解器求解。

简化版:如果只有一个旧任务(),投影公式为:

$$

g_t' = g_t - g_1 g_t, g_1 < 0 $$

记忆缓冲

GEM 为每个任务保存少量样本(如每任务 100 个):

Extra close brace or missing open brace\mathcal{M} = \{(x_i, y_i)} _{i=1}^{M}

旧任务的梯度 在记忆缓冲 上计算。

Averaged GEM (A-GEM)

计算效率的改进

A-GEM10简化 GEM 的约束:不要求与所有旧任务梯度非负内积,只要求与平均梯度非负内积。

平均梯度:

约束:

投影公式

$$

g_t' = g_t - {g} g_t, {g} < 0 $$

优势: - 计算复杂度从 降到 是任务数) - 无需求解二次规划

劣势: - 约束更宽松,遗忘可能略高

Experience Replay (ER)

最简单的重放

Experience Replay11在训练新任务时,混合旧任务的记忆样本:

其中 是从所有旧任务中采样的记忆缓冲。

采样策略

  1. 均匀采样:每个任务的样本数相等
  2. 按性能采样:对遗忘严重的任务多采样
  3. 按时间衰减:近期任务的样本权重更高

记忆缓冲的管理

  • Reservoir Sampling:等概率保留所有见过的样本
  • Ring Buffer:固定大小,新样本替换旧样本
  • Herding:选择最接近类别中心的样本

Dark Experience Replay (DER)

知识蒸馏与重放的结合

DER12在记忆缓冲中不仅保存样本 ,还保存模型的输出

损失函数:

第二项是分类损失(记忆样本的真实标签),第三项是蒸馏损失(保持旧模型的输出)。

优势:蒸馏损失缓解了记忆样本的过拟合。

元学习方法

Model-Agnostic Meta-Learning for Continual Learning

MAML 的应用

MAML13通过元学习找到一个良好的初始化 ,使得从 快速适配任何任务。

在持续学习中, MAML 可以这样使用:

  1. 内循环:在当前任务上快速适配: 2. 外循环:更新元参数 ,使其对所有任务都表现良好:

问题:需要保留所有旧任务的数据(与持续学习的无数据假设冲突)。

改进: Meta-Experience Replay14,只在记忆缓冲上执行外循环更新。

Online Meta-Learning (OML)

在线元学习

OML15在持续学习中在线更新元参数:

表示学习器分为两部分: - 表示网络( Representation),参数 (慢更新) - 预测头( Prediction Head),参数 (快更新)

更新策略:

  1. 任务到达时:快速适配预测头: 2. 任务结束后:慢更新表示:

优势:表示网络 学习通用特征,预测头 学习任务特定知识。

Learning to Learn without Forgetting (Meta-LwF)

元学习的正则化

Meta-LwF16结合元学习和 LwF:

损失函数:

第一项是新任务损失,第二项是蒸馏损失( LwF),第三项是元正则化(拉向元参数 )。

直觉 是元学习得到的"通用参数",任务特定参数 应接近

完整代码实现:从零实现 EWC

下面实现一个完整的 EWC 框架,包括 Fisher 信息计算、多任务训练、遗忘度评估与可视化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
"""
从零实现 EWC: Elastic Weight Consolidation
包含: Fisher 信息计算、多任务训练、遗忘度评估
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
from copy import deepcopy

# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)

# ============================================================================
# 简单的 MLP 模型
# ============================================================================

class SimpleMLP(nn.Module):
"""
简单的多层感知机
"""
def __init__(self, input_dim: int = 784, hidden_dims: List[int] = [256, 256], output_dim: int = 10):
super().__init__()

layers = []
prev_dim = input_dim
for hidden_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
prev_dim = hidden_dim

layers.append(nn.Linear(prev_dim, output_dim))

self.network = nn.Sequential(*layers)

def forward(self, x):
return self.network(x)

# ============================================================================
# EWC 实现
# ============================================================================

class EWC:
"""
Elastic Weight Consolidation
"""
def __init__(self, model: nn.Module, dataloader: DataLoader, device: str = 'cpu'):
self.model = model
self.dataloader = dataloader
self.device = device

# 存储每个任务的 Fisher 信息和参数
self.fisher_dict: Dict[int, Dict[str, torch.Tensor]] = {}
self.optpar_dict: Dict[int, Dict[str, torch.Tensor]] = {}

def compute_fisher(self, task_id: int):
"""
计算 Fisher 信息矩阵的对角元素
"""
self.model.eval()

# 初始化 Fisher 信息
fisher = {}
for name, param in self.model.named_parameters():
fisher[name] = torch.zeros_like(param)

# 在数据上累积梯度平方
num_samples = 0
for inputs, targets in self.dataloader:
inputs = inputs.to(self.device)
targets = targets.to(self.device)

self.model.zero_grad()

# 前向传播
outputs = self.model(inputs)

# 计算负对数似然
loss = F.cross_entropy(outputs, targets)

# 反向传播
loss.backward()

# 累积梯度平方
for name, param in self.model.named_parameters():
if param.grad is not None:
fisher[name] += param.grad.data ** 2

num_samples += inputs.size(0)

# 平均 Fisher 信息
for name in fisher:
fisher[name] /= num_samples

# 存储 Fisher 信息
self.fisher_dict[task_id] = fisher

# 存储当前参数
optpar = {}
for name, param in self.model.named_parameters():
optpar[name] = param.data.clone()
self.optpar_dict[task_id] = optpar

print(f"Fisher information computed for task {task_id}")

def penalty(self) -> torch.Tensor:
"""
计算 EWC 惩罚项
"""
loss = 0.0

for task_id in self.fisher_dict:
for name, param in self.model.named_parameters():
fisher = self.fisher_dict[task_id][name]
optpar = self.optpar_dict[task_id][name]
loss += (fisher * (param - optpar) ** 2).sum()

return loss

# ============================================================================
# 训练函数
# ============================================================================

def train_task(
model: nn.Module,
dataloader: DataLoader,
ewc: EWC,
ewc_lambda: float,
optimizer: optim.Optimizer,
device: str,
num_epochs: int = 10
) -> List[float]:
"""
在单个任务上训练(带 EWC 正则化)
"""
model.train()
losses = []

for epoch in range(num_epochs):
epoch_loss = 0.0

for inputs, targets in dataloader:
inputs = inputs.to(device)
targets = targets.to(device)

# 前向传播
outputs = model(inputs)

# 任务损失
task_loss = F.cross_entropy(outputs, targets)

# EWC 惩罚项
ewc_loss = ewc.penalty() if len(ewc.fisher_dict) > 0 else 0.0

# 总损失
loss = task_loss + ewc_lambda * ewc_loss

# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch_loss += loss.item()

avg_loss = epoch_loss / len(dataloader)
losses.append(avg_loss)

if (epoch + 1) % 2 == 0:
print(f" Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

return losses

def evaluate_task(model: nn.Module, dataloader: DataLoader, device: str) -> float:
"""
评估任务准确率
"""
model.eval()
correct = 0
total = 0

with torch.no_grad():
for inputs, targets in dataloader:
inputs = inputs.to(device)
targets = targets.to(device)

outputs = model(inputs)
_, predicted = torch.max(outputs, 1)

correct += (predicted == targets).sum().item()
total += targets.size(0)

accuracy = 100 * correct / total
return accuracy

# ============================================================================
# 生成多任务数据集( Permuted MNIST)
# ============================================================================

def create_permuted_mnist_tasks(num_tasks: int = 5, num_samples: int = 1000) -> List[Tuple[DataLoader, DataLoader]]:
"""
创建 Permuted MNIST 任务序列
每个任务是 MNIST 的一个随机像素排列
"""
tasks = []

# 生成随机 MNIST 数据(简化版)
for task_id in range(num_tasks):
# 生成随机数据(模拟 MNIST)
X_train = torch.randn(num_samples, 784)
y_train = torch.randint(0, 10, (num_samples,))

X_test = torch.randn(200, 784)
y_test = torch.randint(0, 10, (200,))

# 应用随机排列(模拟 Permuted MNIST)
if task_id > 0:
perm = torch.randperm(784)
X_train = X_train[:, perm]
X_test = X_test[:, perm]

# 创建 DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

tasks.append((train_loader, test_loader))

return tasks

# ============================================================================
# 主实验:对比 Baseline 和 EWC
# ============================================================================

def run_continual_learning_experiment(
num_tasks: int = 5,
ewc_lambda: float = 5000.0,
num_epochs: int = 10,
device: str = 'cpu'
):
"""
运行持续学习实验
"""
print("="*70)
print("Continual Learning Experiment: Baseline vs EWC")
print("="*70)

# 创建任务序列
print(f"\nCreating {num_tasks} permuted MNIST tasks...")
tasks = create_permuted_mnist_tasks(num_tasks=num_tasks)

# ========================================================================
# 方法 1: Baseline(无正则化)
# ========================================================================
print("\n" + "="*70)
print("Method 1: Baseline (No Regularization)")
print("="*70)

model_baseline = SimpleMLP().to(device)
baseline_accuracies = np.zeros((num_tasks, num_tasks))

for task_id in range(num_tasks):
print(f"\n--- Training Task {task_id+1} ---")

train_loader, _ = tasks[task_id]
optimizer = optim.SGD(model_baseline.parameters(), lr=0.01, momentum=0.9)

train_task(model_baseline, train_loader, EWC(model_baseline, train_loader, device),
ewc_lambda=0.0, optimizer=optimizer, device=device, num_epochs=num_epochs)

# 评估所有任务
print(f"\nEvaluating all tasks after training Task {task_id+1}:")
for eval_task_id in range(task_id + 1):
_, test_loader = tasks[eval_task_id]
acc = evaluate_task(model_baseline, test_loader, device)
baseline_accuracies[task_id, eval_task_id] = acc
print(f" Task {eval_task_id+1}: {acc:.2f}%")

# ========================================================================
# 方法 2: EWC
# ========================================================================
print("\n" + "="*70)
print("Method 2: EWC")
print("="*70)

model_ewc = SimpleMLP().to(device)
ewc = EWC(model_ewc, tasks[0][0], device) # 初始化 EWC
ewc_accuracies = np.zeros((num_tasks, num_tasks))

for task_id in range(num_tasks):
print(f"\n--- Training Task {task_id+1} ---")

train_loader, _ = tasks[task_id]
optimizer = optim.SGD(model_ewc.parameters(), lr=0.01, momentum=0.9)

train_task(model_ewc, train_loader, ewc, ewc_lambda=ewc_lambda,
optimizer=optimizer, device=device, num_epochs=num_epochs)

# 计算 Fisher 信息
ewc.dataloader = train_loader
ewc.compute_fisher(task_id)

# 评估所有任务
print(f"\nEvaluating all tasks after training Task {task_id+1}:")
for eval_task_id in range(task_id + 1):
_, test_loader = tasks[eval_task_id]
acc = evaluate_task(model_ewc, test_loader, device)
ewc_accuracies[task_id, eval_task_id] = acc
print(f" Task {eval_task_id+1}: {acc:.2f}%")

return baseline_accuracies, ewc_accuracies

# ============================================================================
# 可视化
# ============================================================================

def plot_continual_learning_results(baseline_acc: np.ndarray, ewc_acc: np.ndarray):
"""
绘制持续学习结果
"""
num_tasks = baseline_acc.shape[0]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. 准确率热图( Baseline)
im1 = axes[0].imshow(baseline_acc, cmap='YlGnBu', vmin=0, vmax=100)
axes[0].set_xlabel('Evaluated Task', fontsize=12)
axes[0].set_ylabel('After Training Task', fontsize=12)
axes[0].set_title('Baseline: Accuracy Heatmap (%)', fontsize=14, fontweight='bold')
axes[0].set_xticks(range(num_tasks))
axes[0].set_yticks(range(num_tasks))
axes[0].set_xticklabels([f'T{i+1}' for i in range(num_tasks)])
axes[0].set_yticklabels([f'T{i+1}' for i in range(num_tasks)])

# 添加数值标注
for i in range(num_tasks):
for j in range(i + 1):
text = axes[0].text(j, i, f'{baseline_acc[i, j]:.1f}',
ha="center", va="center", color="black", fontsize=10)

plt.colorbar(im1, ax=axes[0])

# 2. 准确率热图( EWC)
im2 = axes[1].imshow(ewc_acc, cmap='YlGnBu', vmin=0, vmax=100)
axes[1].set_xlabel('Evaluated Task', fontsize=12)
axes[1].set_ylabel('After Training Task', fontsize=12)
axes[1].set_title('EWC: Accuracy Heatmap (%)', fontsize=14, fontweight='bold')
axes[1].set_xticks(range(num_tasks))
axes[1].set_yticks(range(num_tasks))
axes[1].set_xticklabels([f'T{i+1}' for i in range(num_tasks)])
axes[1].set_yticklabels([f'T{i+1}' for i in range(num_tasks)])

# 添加数值标注
for i in range(num_tasks):
for j in range(i + 1):
text = axes[1].text(j, i, f'{ewc_acc[i, j]:.1f}',
ha="center", va="center", color="black", fontsize=10)

plt.colorbar(im2, ax=axes[1])

# 3. 平均准确率和遗忘度对比
avg_acc_baseline = [baseline_acc[i, :i+1].mean() for i in range(num_tasks)]
avg_acc_ewc = [ewc_acc[i, :i+1].mean() for i in range(num_tasks)]

# 遗忘度:第一个任务的准确率下降
forgetting_baseline = [baseline_acc[0, 0] - baseline_acc[i, 0] for i in range(num_tasks)]
forgetting_ewc = [ewc_acc[0, 0] - ewc_acc[i, 0] for i in range(num_tasks)]

x = np.arange(1, num_tasks + 1)
width = 0.35

ax3_1 = axes[2]
ax3_1.plot(x, avg_acc_baseline, marker='o', label='Baseline - Avg Acc',
linewidth=2, color='tab:blue')
ax3_1.plot(x, avg_acc_ewc, marker='s', label='EWC - Avg Acc',
linewidth=2, color='tab:green')
ax3_1.set_xlabel('Number of Tasks Trained', fontsize=12)
ax3_1.set_ylabel('Average Accuracy (%)', fontsize=12, color='tab:blue')
ax3_1.tick_params(axis='y', labelcolor='tab:blue')
ax3_1.legend(loc='upper left')
ax3_1.grid(True, alpha=0.3)

ax3_2 = ax3_1.twinx()
ax3_2.plot(x, forgetting_baseline, marker='o', label='Baseline - Forgetting',
linewidth=2, linestyle='--', color='tab:red')
ax3_2.plot(x, forgetting_ewc, marker='s', label='EWC - Forgetting',
linewidth=2, linestyle='--', color='tab:orange')
ax3_2.set_ylabel('Forgetting on Task 1 (%)', fontsize=12, color='tab:red')
ax3_2.tick_params(axis='y', labelcolor='tab:red')
ax3_2.legend(loc='upper right')

axes[2].set_title('Average Accuracy & Forgetting', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig('ewc_continual_learning.png', dpi=150, bbox_inches='tight')
plt.close()
print("\nResults saved to ewc_continual_learning.png")

# ============================================================================
# 主函数
# ============================================================================

def main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# 运行实验
baseline_acc, ewc_acc = run_continual_learning_experiment(
num_tasks=5,
ewc_lambda=5000.0,
num_epochs=10,
device=device
)

# 计算最终指标
print("\n" + "="*70)
print("Final Results")
print("="*70)

num_tasks = baseline_acc.shape[0]

# 平均准确率
avg_acc_baseline = baseline_acc[-1, :].mean()
avg_acc_ewc = ewc_acc[-1, :].mean()
print(f"\nAverage Accuracy (after all tasks):")
print(f" Baseline: {avg_acc_baseline:.2f}%")
print(f" EWC: {avg_acc_ewc:.2f}%")
print(f" Improvement: {avg_acc_ewc - avg_acc_baseline:.2f}%")

# 遗忘度(以第一个任务为例)
forgetting_baseline = baseline_acc[0, 0] - baseline_acc[-1, 0]
forgetting_ewc = ewc_acc[0, 0] - ewc_acc[-1, 0]
print(f"\nForgetting on Task 1:")
print(f" Baseline: {forgetting_baseline:.2f}%")
print(f" EWC: {forgetting_ewc:.2f}%")
print(f" Reduction: {forgetting_baseline - forgetting_ewc:.2f}%")

# 绘图
plot_continual_learning_results(baseline_acc, ewc_acc)

print("\n" + "="*70)
print("Experiment completed!")
print("="*70)

if __name__ == "__main__":
main()

代码说明

核心组件

  1. EWC 类
    • compute_fisher():计算 Fisher 信息矩阵
    • penalty():计算 EWC 正则化项
  2. 训练流程
    • 任务序列: Permuted MNIST(每个任务是 MNIST 的随机像素排列)
    • 对比 Baseline(无正则化)和 EWC
  3. 评估指标
    • 准确率热图:展示每个任务在不同时间点的性能
    • 平均准确率:所有任务的平均性能
    • 遗忘度:第一个任务的准确率下降

关键细节

  • Fisher 信息在每个任务结束后计算
  • EWC 惩罚项累积所有旧任务的约束
  • 学习率、 EWC 强度 需要调优

持续学习的前沿进展

理论分析

稳定性-可塑性困境

持续学习面临根本困境17

  • 稳定性( Stability):保持旧知识 → 减少遗忘
  • 可塑性( Plasticity):学习新知识 → 适应新任务

两者存在权衡:

最优 取决于任务相似度、数据分布漂移等因素。

记忆容量分析

网络的记忆容量是指不发生灾难性遗忘的最大任务数18

对于 个参数的网络,记忆容量上界为:

$$

C $$

直觉:每个参数平均需要"记住" 比特信息。

最新方法

Orthogonal Gradient Descent (OGD)

OGD19将新任务的梯度投影到与旧任务梯度正交的子空间:

$$

g_t' = g_t - _{i<t} g_i $$

优势:完全消除梯度冲突。

劣势:需要存储所有旧任务的梯度。

Continual Backprop

Continual Backprop20修改反向传播算法,只更新对当前任务重要的参数。

更新规则:

其中 是任务 的参数掩码,自动学习。

Supermasks in Superposition (SupSup)

SupSup21为每个任务学习一个二值掩码,所有任务共享同一组参数:

训练时只优化掩码 ,参数 随机初始化后固定。

惊人发现:随机初始化的网络通过不同掩码可以达到多任务学习的性能!

基准与评估

标准基准

  1. Permuted MNIST: MNIST 的像素随机排列
  2. Split CIFAR: CIFAR-10 按类别分成多个任务
  3. CORe50: 50 个物体在不同场景下的图像
  4. Continual Reinforcement Learning: Atari 游戏序列

评估协议

标准评估包括:

  1. 任务内性能:每个任务单独训练的性能上界
  2. 平均准确率:所有任务的平均性能
  3. 后向迁移:学习新任务后旧任务性能的变化
  4. 前向迁移:旧任务对新任务的帮助
  5. 参数效率:模型大小随任务数的增长率

常见问题解答

Q1: EWC 的 如何选择?

经验规则:

  • 小任务(如 Permuted MNIST):
  • 中等任务(如 Split CIFAR):
  • 大任务(如 ImageNet 子集): 调优策略
  1. 先用较小的 (如 100)测试
  2. 观察遗忘度:如果遗忘严重,增大 4. 在验证集上网格搜索最优

Q2: Fisher 信息何时计算?

推荐:在每个任务训练结束后立即计算。

原因:此时模型在该任务上达到最优, Fisher 信息最准确。

注意:如果任务数据量大,可以只在子集上计算 Fisher 信息(如每类 100 个样本)。

Q3: EWC 、 MAS 、 SI 如何选择?

场景 推荐方法
监督学习 EWC
无监督学习 MAS
在线学习(边训练边更新) SI
需要无标签计算重要性 MAS

性能对比: EWC ≈ MAS > SI(在大多数基准上)

Q4: 记忆缓冲应该多大?

经验值:

  • 每任务: 50-200 个样本
  • 总缓冲: 500-2000 个样本

权衡: - 更大的缓冲 → 更少遗忘,但存储和计算开销大 - 更小的缓冲 → 更高效,但可能遗忘严重

最优策略:根据可用内存和任务数动态调整。

Q5: GEM 与 A-GEM 哪个更好?

维度 GEM A-GEM
遗忘控制 更严格 更宽松
计算复杂度 高(二次规划) 低(线性投影)
可扩展性 差(任务数多时慢)
实现难度 简单

推荐:除非对遗忘极度敏感,优先使用 A-GEM(性能相近但快很多)。

Q6: 动态架构方法的劣势是什么?

Progressive Networks: - 模型大小线性增长: 个任务需要 倍参数 - 推理时间线性增长 - 部署困难(模型太大)

DEN: - 训练复杂度高(需要动态扩展决策) - 超参数敏感(扩展阈值、稀疏系数) - 可能过度扩展

PackNet: - 后期任务性能受限(可用参数减少) - 剪枝策略影响大 - 任务数有上限(参数用完)

权衡:动态架构完全避免遗忘,但牺牲效率和可扩展性。

Q7: 元学习在持续学习中的作用?

优势: - 学习通用表示,减少任务特定参数 - 支持快速适配新任务 - 理论上优雅(最小化所有任务的元损失)

劣势: - 需要元训练阶段(任务分布已知) - 计算复杂度高(二阶导数) - 实践中性能提升有限

适用场景:任务相似度高、需要快速适配的场景(如少样本学习)。

Q8: 持续学习与多任务学习的区别?

维度 持续学习 多任务学习
任务可见性 序列到达 同时可见
数据访问 无法访问旧数据 所有数据可用
主要挑战 灾难性遗忘 任务平衡
目标 保持旧任务性能 所有任务平均性能

联系:持续学习可以看作是数据受限的多任务学习。

Q9: 如何处理类别增量学习( Class-IL)?

Class-IL 是最难的场景,需要特殊处理:

  1. 输出层扩展:每个新任务添加新类的输出神经元
  2. 偏置校正:新类的输出通常偏小(因为未经充分训练),需要校正
  3. 知识蒸馏:保持旧类的输出分布
  4. 记忆重放:混合旧类的样本

推荐方法: iCaRL22、 LUCIR23

Q10: 持续学习能用在生产环境吗?

挑战

  1. 遗忘不可接受:生产环境要求严格的性能保证
  2. 推理延迟:动态架构方法推理慢
  3. 模型更新频率:新任务到达速度快

实用策略

  1. 混合方法: EWC + 少量记忆重放
  2. 周期性全量微调:每 个任务后,用记忆缓冲全量微调
  3. A/B 测试:持续学习模型与旧模型并行,比较性能
  4. 降级机制:如果新任务学习失败,回滚到旧模型

成功案例:推荐系统、语音识别、图像分类的增量更新。

Q11: 如何调试持续学习模型?

诊断步骤:

  1. 检查单任务性能:每个任务单独训练,确认基线性能
  2. 检查梯度冲突:计算不同任务梯度的内积,看是否存在负值
  3. 可视化 Fisher 信息:查看哪些参数被标记为重要
  4. 监控遗忘曲线:画出每个任务的准确率随时间的变化
  5. 消融实验:移除正则化/记忆重放,看性能下降多少

Q12: 持续学习的理论极限是什么?

信息论极限24

对于 个参数的网络,不发生遗忘地学习 个任务,需要:

$$

T $$

其中 是任务的互信息。

直觉:网络容量有限,任务数太多必然遗忘。

突破方向: - 利用任务相似性(共享表示) - 压缩旧任务知识(知识蒸馏) - 动态扩展容量(架构搜索)

小结

本文全面介绍了持续学习技术:

  1. 问题定义:灾难性遗忘的数学机理与评估指标
  2. 正则化方法: EWC 、 MAS 、 SI 、 LwF 的原理与对比
  3. 动态架构: Progressive Networks 、 DEN 、 PackNet 的设计
  4. 记忆重放: GEM 、 A-GEM 、 ER 、 DER 的策略
  5. 元学习: MAML 、 OML 在持续学习中的应用
  6. 完整代码:从零实现 EWC 的 250+行工程级代码
  7. 前沿进展:稳定性-可塑性困境、记忆容量理论、最新方法

持续学习让模型具备终身学习能力,是通用人工智能的重要基石。下一章我们将探讨跨语言迁移,看如何让模型在不同语言间无缝迁移知识。

参考文献


  1. Kirkpatrick, J., Pascanu, R., Rabinowitz, N., et al. (2017). Overcoming catastrophic forgetting in neural networks. PNAS.↩︎

  2. Schwarz, J., Czarnecki, W., Luketina, J., et al. (2018). Progress & compress: A scalable framework for continual learning. ICML.↩︎

  3. Aljundi, R., Babiloni, F., Elhoseiny, M., et al. (2018). Memory aware synapses: Learning what (not) to forget. ECCV.↩︎

  4. Zenke, F., Poole, B., & Ganguli, S. (2017). Continual learning through synaptic intelligence. ICML.↩︎

  5. Li, Z., & Hoiem, D. (2017). Learning without forgetting. TPAMI.↩︎

  6. Rusu, A. A., Rabinowitz, N. C., Desjardins, G., et al. (2016). Progressive neural networks. arXiv:1606.04671.↩︎

  7. Yoon, J., Yang, E., Lee, J., & Hwang, S. J. (2018). Lifelong learning with dynamically expandable networks. ICLR.↩︎

  8. Mallya, A., & Lazebnik, S. (2018). PackNet: Adding multiple tasks to a single network by iterative pruning. CVPR.↩︎

  9. Lopez-Paz, D., & Ranzato, M. (2017). Gradient episodic memory for continual learning. NeurIPS.↩︎

  10. Chaudhry, A., Ranzato, M., Rohrbach, M., & Elhoseiny, M. (2019). Efficient lifelong learning with A-GEM. ICLR.↩︎

  11. Robins, A. (1995). Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science.↩︎

  12. Buzzega, P., Boschini, M., Porrello, A., et al. (2020). Dark experience for general continual learning: A strong, simple baseline. NeurIPS.↩︎

  13. Finn, C., Abbeel, P., & Levine, S. (2017). Model-agnostic meta-learning for fast adaptation of deep networks. ICML.↩︎

  14. Riemer, M., Cases, I., Ajemian, R., et al. (2019). Learning to learn without forgetting by maximizing transfer and minimizing interference. ICLR.↩︎

  15. Javed, K., & White, M. (2019). Meta-learning representations for continual learning. NeurIPS.↩︎

  16. Beaulieu, S., Frati, L., Miconi, T., et al. (2020). Learning to continually learn rapidly from few and noisy data. arXiv:2006.10220.↩︎

  17. Abraham, W. C., & Robins, A. (2005). Memory retention – the synaptic stability versus plasticity dilemma. Trends in Neurosciences.↩︎

  18. French, R. M. (1999). Catastrophic forgetting in connectionist networks. Trends in Cognitive Sciences.↩︎

  19. Farajtabar, M., Azizan, N., Mott, A., & Li, A. (2020). Orthogonal gradient descent for continual learning. AISTATS.↩︎

  20. Golkar, S., Kagan, M., & Cho, K. (2019). Continual learning via neural pruning. arXiv:1903.04476.↩︎

  21. Wortsman, M., Ramanujan, V., Liu, R., et al. (2020). Supermasks in superposition. NeurIPS.↩︎

  22. Rebuffi, S. A., Kolesnikov, A., Sperl, G., & Lampert, C. H. (2017). iCaRL: Incremental classifier and representation learning. CVPR.↩︎

  23. Hou, S., Pan, X., Loy, C. C., et al. (2019). Learning a unified classifier incrementally via rebalancing. CVPR.↩︎

  24. Farquhar, S., & Gal, Y. (2018). Towards robust evaluations of continual learning. arXiv:1805.09733.↩︎

  • 本文标题:迁移学习(十)—— 持续学习
  • 本文作者:Chen Kai
  • 创建时间:2024-12-27 15:00:00
  • 本文链接:https://www.chenk.top/%E8%BF%81%E7%A7%BB%E5%AD%A6%E4%B9%A0%EF%BC%88%E5%8D%81%EF%BC%89%E2%80%94%E2%80%94-%E6%8C%81%E7%BB%AD%E5%AD%A6%E4%B9%A0/
  • 版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
 评论