人类可以不断学习新技能而不忘记旧知识,但神经网络在学习新任务时却常常"健忘"——这就是灾难性遗忘( Catastrophic Forgetting)。如何让模型像人一样终身学习,在掌握 100 个任务后依然记得第 1 个任务?持续学习( Continual Learning)给出了答案。
本文从灾难性遗忘的数学机理出发,系统讲解正则化、动态架构、记忆重放、元学习四大类方法的原理与实现,深入分析参数重要性估计、任务间知识迁移与遗忘-稳定性权衡,并提供从零实现 EWC 的完整代码( 250+行)。
持续学习的问题定义
任务序列
持续学习处理按时间顺序到达的任务序列:
每个任务
其中:
是任务数据 是任务损失函数 是任务 的模型(参数为 )
关键约束:学习任务
灾难性遗忘
现象:在任务
形式化为:
其中:
:学习任务 后立即在任务 上的准确率 :学习任务 后在任务 上的准确率
目标:最小化遗忘,同时保持对新任务的学习能力。
持续学习的三种场景
- 任务增量学习( Task-Incremental Learning):
- 推理时已知任务 ID
- 每个任务有独立的输出头
- 示例: MNIST → FashionMNIST → CIFAR10
- 领域增量学习( Domain-Incremental Learning):
- 同一任务,不同数据分布
- 示例:晴天图像 → 雨天图像 → 夜间图像
- 类别增量学习( Class-Incremental Learning):
- 推理时未知任务 ID
- 需要从所有学过的类别中预测
- 最困难的场景
评估指标
平均准确率( Average Accuracy):
平均遗忘度( Average Forgetting):
后向迁移( Backward Transfer):
:遗忘 :正向迁移
前向迁移( Forward Transfer):
:零样本性能(学习任务 前在任务 上的准确率)
灾难性遗忘的数学机理
损失曲面视角
神经网络的损失函数定义在高维参数空间
训练过程是优化
问题:
- 任务 1 的最优点
与 任务 2 的最优点 通常相距甚远 - 从
优化到 会离开任务 1 的低损失区域 - 梯度下降对参数的更新具有全局性,难以局部调整
梯度干扰
任务 1 和任务 2 的梯度可能相互冲突:
直觉:改善任务 2 的参数更新会恶化任务 1 的性能。
定义梯度冲突度:
其中
权重重要性
并非所有参数对旧任务都同等重要。定义参数
洞察:保护重要参数(
Fisher 信息矩阵
Fisher 信息矩阵衡量参数对损失的敏感度:
$$
F = _{(x,y) } $$
对角元素
$$
F_{ii} = $$
性质: -
正则化方法
Elastic Weight Consolidation (EWC)
EWC 的核心思想
EWC1在学习新任务时,对重要参数施加正则化约束,防止它们偏离旧任务的最优值。
目标函数:
其中: -
直觉:限制重要参数的变化,保护旧任务的知识。
Fisher 信息的计算
对于分类任务, Fisher 信息矩阵的对角元素为:
$$
F_i = _{n=1}^{N} ( )^2 $$
实践中,在任务 A 的数据上计算:
- 前向传播得到预测
2. 计算对数似然 3. 反向传播得到梯度 $ g_i = F_i = [g_i^2]$信 息 为 梯 度 平 方 的 期 望 :
多任务扩展
学习任务序列
问题: Fisher 信息的累积导致参数越来越"僵化"。
改进: Online EWC2,只保留当前 Fisher 信息和参数,避免累积:
其中:
Memory Aware Synapses (MAS)
MAS 的改进
MAS3指出 EWC 的局限: Fisher 信息只考虑最后一层的梯度,忽略了中间层的重要性。
MAS 的参数重要性定义为:
注意这里是输出对参数的梯度,而非损失对参数的梯度。
目标函数:
MAS vs EWC
| 维度 | EWC | MAS |
|---|---|---|
| 重要性度量 | Fisher 信息(梯度平方) | 输出敏感度(梯度绝对值) |
| 计算依赖 | 需要标签 | 无需标签 |
| 适用场景 | 监督学习 | 无监督/自监督 |
| 计算复杂度 | 中等 | 低 |
Synaptic Intelligence (SI)
SI 的在线更新
SI4在训练过程中在线计算参数重要性,而非在任务结束后。
参数重要性:
其中: -
直觉:参数在训练过程中移动的"路径长度"越长且对应的损失降低越多,该参数越重要。
目标函数:
Learning without Forgetting (LwF)
知识蒸馏的应用
LwF5利用知识蒸馏保持旧任务的输出分布。
损失函数包含两部分:
- 新任务损失:
2. 蒸馏损失(保持旧任务的输出):
总损失:
优势:无需保存旧任务数据,只需旧模型的预测。
劣势:需要保存旧模型的副本(存储开销)。
动态架构方法
Progressive Neural Networks
渐进式扩展
Progressive Networks6为每个新任务添加新的网络列:
$$
h_t^{(l)} = f( W_t^{(l)} h_t^{(l-1)} + {i<t} U{i t}^{(l)} h_i^{(l-1)} ) $$
其中: -
优势: - 完全避免灾难性遗忘(旧参数不变) - 支持前向迁移(横向连接利用旧知识)
劣势: - 模型大小线性增长:
Dynamically Expandable Networks (DEN)
动态扩展策略
DEN7根据需要动态扩展网络容量:
- 选择性重训练( Selective Retraining):
- 冻结重要参数
- 只微调不重要参数
- 动态扩展( Dynamic Expansion):
- 如果现有容量不足,添加新神经元
- 决策标准:验证损失不再下降
- 网络分裂( Network Split/Duplication):
- 复制神经元并加入噪声
- 增加模型容量而不破坏旧知识
扩展算法:
对于第
$$
W^{(l)}_{} =$$
其中
稀疏正则化:
为了避免过度扩展, DEN 使用
第一项鼓励稀疏,第二项保护旧知识。
PackNet
二值掩码的巧妙设计
PackNet8通过二值掩码为每个任务分配不同的参数子集:
其中
关键约束:不同任务的掩码不重叠:
$$
M_i M_j = 0, i j $$
训练流程:
- 任务
到达时,冻结已被之前任务使用的参数: $$
M_{} = M_1 M_2 M_{t-1}
M_{} = 1 - M_{} $$3. 训练后通过剪枝确定任务
优势: - 参数复用率高 - 模型大小固定 - 完全避免遗忘
劣势: - 可用参数逐渐减少 - 后期任务性能受限
记忆重放方法
Gradient Episodic Memory (GEM)
约束优化视角
GEM9将持续学习建模为约束优化问题:
其中 $ g_t = _t()
直觉:新任务的梯度不能与旧任务的梯度冲突(负内积)。
梯度投影
如果梯度
$$
g_t' = _{g'} |g' - g_t|^2 g', g_i , i < t $$
这是一个二次规划问题,可以用现成求解器求解。
简化版:如果只有一个旧任务(
$$
g_t' = g_t - g_1 g_t, g_1 < 0 $$
记忆缓冲
GEM 为每个任务保存少量样本(如每任务 100 个):
旧任务的梯度
Averaged GEM (A-GEM)
计算效率的改进
A-GEM10简化 GEM 的约束:不要求与所有旧任务梯度非负内积,只要求与平均梯度非负内积。
平均梯度:
约束:
投影公式:
$$
g_t' = g_t - {g} g_t, {g} < 0 $$
优势: - 计算复杂度从
劣势: - 约束更宽松,遗忘可能略高
Experience Replay (ER)
最简单的重放
Experience Replay11在训练新任务时,混合旧任务的记忆样本:
其中
采样策略:
- 均匀采样:每个任务的样本数相等
- 按性能采样:对遗忘严重的任务多采样
- 按时间衰减:近期任务的样本权重更高
记忆缓冲的管理:
- Reservoir Sampling:等概率保留所有见过的样本
- Ring Buffer:固定大小,新样本替换旧样本
- Herding:选择最接近类别中心的样本
Dark Experience Replay (DER)
知识蒸馏与重放的结合
DER12在记忆缓冲中不仅保存样本
损失函数:
第二项是分类损失(记忆样本的真实标签),第三项是蒸馏损失(保持旧模型的输出)。
优势:蒸馏损失缓解了记忆样本的过拟合。
元学习方法
Model-Agnostic Meta-Learning for Continual Learning
MAML 的应用
MAML13通过元学习找到一个良好的初始化
在持续学习中, MAML 可以这样使用:
- 内循环:在当前任务上快速适配:
2. 外循环:更新元参数 ,使其对所有任务都表现良好:
问题:需要保留所有旧任务的数据(与持续学习的无数据假设冲突)。
改进: Meta-Experience Replay14,只在记忆缓冲上执行外循环更新。
Online Meta-Learning (OML)
在线元学习
OML15在持续学习中在线更新元参数:
表示学习器分为两部分: - 表示网络(
Representation):
更新策略:
- 任务到达时:快速适配预测头:
2. 任务结束后:慢更新表示:
优势:表示网络
Learning to Learn without Forgetting (Meta-LwF)
元学习的正则化
Meta-LwF16结合元学习和 LwF:
损失函数:
第一项是新任务损失,第二项是蒸馏损失(
LwF),第三项是元正则化(拉向元参数
直觉:
完整代码实现:从零实现 EWC
下面实现一个完整的 EWC 框架,包括 Fisher 信息计算、多任务训练、遗忘度评估与可视化。
1 | """ |
代码说明
核心组件:
- EWC 类:
compute_fisher():计算 Fisher 信息矩阵penalty():计算 EWC 正则化项
- 训练流程:
- 任务序列: Permuted MNIST(每个任务是 MNIST 的随机像素排列)
- 对比 Baseline(无正则化)和 EWC
- 评估指标:
- 准确率热图:展示每个任务在不同时间点的性能
- 平均准确率:所有任务的平均性能
- 遗忘度:第一个任务的准确率下降
关键细节:
- Fisher 信息在每个任务结束后计算
- EWC 惩罚项累积所有旧任务的约束
- 学习率、 EWC 强度
需要调优
持续学习的前沿进展
理论分析
稳定性-可塑性困境
持续学习面临根本困境17:
- 稳定性( Stability):保持旧知识 → 减少遗忘
- 可塑性( Plasticity):学习新知识 → 适应新任务
两者存在权衡:
最优
记忆容量分析
网络的记忆容量是指不发生灾难性遗忘的最大任务数18。
对于
$$
C $$
直觉:每个参数平均需要"记住"
最新方法
Orthogonal Gradient Descent (OGD)
OGD19将新任务的梯度投影到与旧任务梯度正交的子空间:
$$
g_t' = g_t - _{i<t} g_i $$
优势:完全消除梯度冲突。
劣势:需要存储所有旧任务的梯度。
Continual Backprop
Continual Backprop20修改反向传播算法,只更新对当前任务重要的参数。
更新规则:
其中
Supermasks in Superposition (SupSup)
SupSup21为每个任务学习一个二值掩码,所有任务共享同一组参数:
训练时只优化掩码
惊人发现:随机初始化的网络通过不同掩码可以达到多任务学习的性能!
基准与评估
标准基准
- Permuted MNIST: MNIST 的像素随机排列
- Split CIFAR: CIFAR-10 按类别分成多个任务
- CORe50: 50 个物体在不同场景下的图像
- Continual Reinforcement Learning: Atari 游戏序列
评估协议
标准评估包括:
- 任务内性能:每个任务单独训练的性能上界
- 平均准确率:所有任务的平均性能
- 后向迁移:学习新任务后旧任务性能的变化
- 前向迁移:旧任务对新任务的帮助
- 参数效率:模型大小随任务数的增长率
常见问题解答
Q1: EWC 的 如何选择?
经验规则:
- 小任务(如 Permuted MNIST):
- 中等任务(如 Split CIFAR):
- 大任务(如 ImageNet 子集):
调优策略:
- 先用较小的
(如 100)测试 - 观察遗忘度:如果遗忘严重,增大
4. 在验证集上网格搜索最优观 察 新 任 务 性 能 : 如 果 新 任 务 学 不 好 , 减 小
Q2: Fisher 信息何时计算?
推荐:在每个任务训练结束后立即计算。
原因:此时模型在该任务上达到最优, Fisher 信息最准确。
注意:如果任务数据量大,可以只在子集上计算 Fisher 信息(如每类 100 个样本)。
Q3: EWC 、 MAS 、 SI 如何选择?
| 场景 | 推荐方法 |
|---|---|
| 监督学习 | EWC |
| 无监督学习 | MAS |
| 在线学习(边训练边更新) | SI |
| 需要无标签计算重要性 | MAS |
性能对比: EWC ≈ MAS > SI(在大多数基准上)
Q4: 记忆缓冲应该多大?
经验值:
- 每任务: 50-200 个样本
- 总缓冲: 500-2000 个样本
权衡: - 更大的缓冲 → 更少遗忘,但存储和计算开销大 - 更小的缓冲 → 更高效,但可能遗忘严重
最优策略:根据可用内存和任务数动态调整。
Q5: GEM 与 A-GEM 哪个更好?
| 维度 | GEM | A-GEM |
|---|---|---|
| 遗忘控制 | 更严格 | 更宽松 |
| 计算复杂度 | 高(二次规划) | 低(线性投影) |
| 可扩展性 | 差(任务数多时慢) | 好 |
| 实现难度 | 难 | 简单 |
推荐:除非对遗忘极度敏感,优先使用 A-GEM(性能相近但快很多)。
Q6: 动态架构方法的劣势是什么?
Progressive Networks: - 模型大小线性增长:
DEN: - 训练复杂度高(需要动态扩展决策) - 超参数敏感(扩展阈值、稀疏系数) - 可能过度扩展
PackNet: - 后期任务性能受限(可用参数减少) - 剪枝策略影响大 - 任务数有上限(参数用完)
权衡:动态架构完全避免遗忘,但牺牲效率和可扩展性。
Q7: 元学习在持续学习中的作用?
优势: - 学习通用表示,减少任务特定参数 - 支持快速适配新任务 - 理论上优雅(最小化所有任务的元损失)
劣势: - 需要元训练阶段(任务分布已知) - 计算复杂度高(二阶导数) - 实践中性能提升有限
适用场景:任务相似度高、需要快速适配的场景(如少样本学习)。
Q8: 持续学习与多任务学习的区别?
| 维度 | 持续学习 | 多任务学习 |
|---|---|---|
| 任务可见性 | 序列到达 | 同时可见 |
| 数据访问 | 无法访问旧数据 | 所有数据可用 |
| 主要挑战 | 灾难性遗忘 | 任务平衡 |
| 目标 | 保持旧任务性能 | 所有任务平均性能 |
联系:持续学习可以看作是数据受限的多任务学习。
Q9: 如何处理类别增量学习( Class-IL)?
Class-IL 是最难的场景,需要特殊处理:
- 输出层扩展:每个新任务添加新类的输出神经元
- 偏置校正:新类的输出通常偏小(因为未经充分训练),需要校正
- 知识蒸馏:保持旧类的输出分布
- 记忆重放:混合旧类的样本
Q10: 持续学习能用在生产环境吗?
挑战:
- 遗忘不可接受:生产环境要求严格的性能保证
- 推理延迟:动态架构方法推理慢
- 模型更新频率:新任务到达速度快
实用策略:
- 混合方法: EWC + 少量记忆重放
- 周期性全量微调:每
个任务后,用记忆缓冲全量微调 - A/B 测试:持续学习模型与旧模型并行,比较性能
- 降级机制:如果新任务学习失败,回滚到旧模型
成功案例:推荐系统、语音识别、图像分类的增量更新。
Q11: 如何调试持续学习模型?
诊断步骤:
- 检查单任务性能:每个任务单独训练,确认基线性能
- 检查梯度冲突:计算不同任务梯度的内积,看是否存在负值
- 可视化 Fisher 信息:查看哪些参数被标记为重要
- 监控遗忘曲线:画出每个任务的准确率随时间的变化
- 消融实验:移除正则化/记忆重放,看性能下降多少
Q12: 持续学习的理论极限是什么?
信息论极限24:
对于
$$
T $$
其中
直觉:网络容量有限,任务数太多必然遗忘。
突破方向: - 利用任务相似性(共享表示) - 压缩旧任务知识(知识蒸馏) - 动态扩展容量(架构搜索)
小结
本文全面介绍了持续学习技术:
- 问题定义:灾难性遗忘的数学机理与评估指标
- 正则化方法: EWC 、 MAS 、 SI 、 LwF 的原理与对比
- 动态架构: Progressive Networks 、 DEN 、 PackNet 的设计
- 记忆重放: GEM 、 A-GEM 、 ER 、 DER 的策略
- 元学习: MAML 、 OML 在持续学习中的应用
- 完整代码:从零实现 EWC 的 250+行工程级代码
- 前沿进展:稳定性-可塑性困境、记忆容量理论、最新方法
持续学习让模型具备终身学习能力,是通用人工智能的重要基石。下一章我们将探讨跨语言迁移,看如何让模型在不同语言间无缝迁移知识。
参考文献
Kirkpatrick, J., Pascanu, R., Rabinowitz, N., et al. (2017). Overcoming catastrophic forgetting in neural networks. PNAS.↩︎
Schwarz, J., Czarnecki, W., Luketina, J., et al. (2018). Progress & compress: A scalable framework for continual learning. ICML.↩︎
Aljundi, R., Babiloni, F., Elhoseiny, M., et al. (2018). Memory aware synapses: Learning what (not) to forget. ECCV.↩︎
Zenke, F., Poole, B., & Ganguli, S. (2017). Continual learning through synaptic intelligence. ICML.↩︎
Li, Z., & Hoiem, D. (2017). Learning without forgetting. TPAMI.↩︎
Rusu, A. A., Rabinowitz, N. C., Desjardins, G., et al. (2016). Progressive neural networks. arXiv:1606.04671.↩︎
Yoon, J., Yang, E., Lee, J., & Hwang, S. J. (2018). Lifelong learning with dynamically expandable networks. ICLR.↩︎
Mallya, A., & Lazebnik, S. (2018). PackNet: Adding multiple tasks to a single network by iterative pruning. CVPR.↩︎
Lopez-Paz, D., & Ranzato, M. (2017). Gradient episodic memory for continual learning. NeurIPS.↩︎
Chaudhry, A., Ranzato, M., Rohrbach, M., & Elhoseiny, M. (2019). Efficient lifelong learning with A-GEM. ICLR.↩︎
Robins, A. (1995). Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science.↩︎
Buzzega, P., Boschini, M., Porrello, A., et al. (2020). Dark experience for general continual learning: A strong, simple baseline. NeurIPS.↩︎
Finn, C., Abbeel, P., & Levine, S. (2017). Model-agnostic meta-learning for fast adaptation of deep networks. ICML.↩︎
Riemer, M., Cases, I., Ajemian, R., et al. (2019). Learning to learn without forgetting by maximizing transfer and minimizing interference. ICLR.↩︎
Javed, K., & White, M. (2019). Meta-learning representations for continual learning. NeurIPS.↩︎
Beaulieu, S., Frati, L., Miconi, T., et al. (2020). Learning to continually learn rapidly from few and noisy data. arXiv:2006.10220.↩︎
Abraham, W. C., & Robins, A. (2005). Memory retention – the synaptic stability versus plasticity dilemma. Trends in Neurosciences.↩︎
French, R. M. (1999). Catastrophic forgetting in connectionist networks. Trends in Cognitive Sciences.↩︎
Farajtabar, M., Azizan, N., Mott, A., & Li, A. (2020). Orthogonal gradient descent for continual learning. AISTATS.↩︎
Golkar, S., Kagan, M., & Cho, K. (2019). Continual learning via neural pruning. arXiv:1903.04476.↩︎
Wortsman, M., Ramanujan, V., Liu, R., et al. (2020). Supermasks in superposition. NeurIPS.↩︎
Rebuffi, S. A., Kolesnikov, A., Sperl, G., & Lampert, C. H. (2017). iCaRL: Incremental classifier and representation learning. CVPR.↩︎
Hou, S., Pan, X., Loy, C. C., et al. (2019). Learning a unified classifier incrementally via rebalancing. CVPR.↩︎
Farquhar, S., & Gal, Y. (2018). Towards robust evaluations of continual learning. arXiv:1805.09733.↩︎
- 本文标题:迁移学习(十)—— 持续学习
- 本文作者:Chen Kai
- 创建时间:2024-12-27 15:00:00
- 本文链接:https://www.chenk.top/%E8%BF%81%E7%A7%BB%E5%AD%A6%E4%B9%A0%EF%BC%88%E5%8D%81%EF%BC%89%E2%80%94%E2%80%94-%E6%8C%81%E7%BB%AD%E5%AD%A6%E4%B9%A0/
- 版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!