迁移学习(五)—— 知识蒸馏
Chen Kai BOSS

知识蒸馏( Knowledge Distillation, KD)是一种模型压缩与迁移学习技术,通过让小模型(学生)学习大模型(教师)的知识,在显著减少参数量和计算量的同时保持接近教师模型的性能。 2015 年 Hinton 等人提出的经典论文"Distilling the Knowledge in a Neural Network"开启了这一领域的研究热潮。但知识蒸馏不仅仅是简单的"软标签"训练——背后涉及温度参数的调节、不同层次知识的提取、学生教师架构的匹配等诸多技术细节。

本文将从第一性原理出发,推导知识蒸馏的数学基础,解析软标签为什么包含比硬标签更多的信息,详细讲解响应式蒸馏、特征蒸馏、关系蒸馏的实现细节,介绍自蒸馏、相互学习、在线蒸馏等无需预训练教师的方法,并探讨量化、剪枝与蒸馏的协同优化。我们会看到,蒸馏本质上是一种知识的"压缩编码"——将教师模型隐式学到的暗知识( dark knowledge)显式地传递给学生模型。

知识蒸馏的动机:为什么需要蒸馏

从模型压缩到知识迁移

深度神经网络通常需要大量参数才能达到最优性能。但在实际部署中,大模型面临诸多挑战:

  • 移动端部署:手机、 IoT 设备的内存和计算能力有限,无法运行百亿参数模型
  • 推理延迟:自动驾驶、工业控制等实时系统要求毫秒级响应
  • 能耗限制:边缘设备需要长时间电池供电,大模型功耗过高
  • 成本优化:云端服务每天处理数十亿请求,模型越小成本越低

传统的模型压缩方法(剪枝、量化、低秩分解)直接操作模型结构或参数,往往导致明显的性能下降。知识蒸馏的核心思想是用小模型学习大模型的输出分布,而不是简单地拟合硬标签

暗知识:软标签的信息优势

考虑一个图像分类任务,假设真实标签是"猫"(硬标签是 one-hot 向量 )。一个训练好的教师模型可能输出如下概率分布:

$$

p_T = [{}, {}, {}, {}, ] $$

虽然教师预测的最高概率是"猫",但其他类别的概率也包含了有价值的信息:

  • "老虎"概率较高:说明这只猫和老虎有某些视觉相似性(体型、花纹等)
  • "狗"概率低但非零:表明猫狗有一定的共同特征(毛茸茸、四条腿)
  • "汽车"概率极低:说明猫和汽车的视觉特征完全不同

这些非零的"错误"概率就是 Hinton 所说的暗知识( dark knowledge)——它们揭示了类别之间的相似性结构,是教师模型在训练过程中学到的泛化能力的体现。

从信息论角度,硬标签的信息熵为 0(确定性的 one-hot 向量),而软标签的信息熵更高:

$$

H(p_T) = -_{i=1}^C p_T^{(i)} p_T^{(i)} > 0 $$

软标签提供了更丰富的监督信号,帮助学生模型学习类别之间的关系。

蒸馏的数学视角:分布匹配

设教师模型参数为 ,学生模型参数为 ,输入为 ,输出 logits 为 。标准分类训练最小化交叉熵:

$$

L_{} = -_{i=1}^C y_i (z_S)_i $$

其中 是硬标签( one-hot), 是 softmax 函数。

知识蒸馏则让学生模型匹配教师的输出分布:

$$

L_{} = -_{i=1}^C (z_T)_i (z_S)_i $$

这是两个分布的交叉熵,也等价于最小化 KL 散度(因为 是常数):

$$

L_{} = ((z_T) | (z_S)) + $$

从优化角度看,蒸馏是在让学生模型的输出分布 逼近教师的输出分布

温度参数:软化概率分布

直接使用 softmax 输出的问题是:概率分布往往过于"尖锐"( peaked),最大类的概率接近 1,其他类的概率接近 0,暗知识被抑制了。

Hinton 引入温度参数( temperature) 来软化分布:

$$

q_i = $$

时,概率分布变得更平滑:

  • :所有类的概率趋于均匀分布
  • :标准 softmax
  • :分布退化为 one-hot( argmax)

直觉例子:考虑 logits

  • (第 3 类信息几乎丢失)
  • (第 3 类信息被保留)

蒸馏损失在温度 下定义为:

$$

L_{}(T) = -_{i=1}^C q_T^{(i)} q_S^{(i)} $$

其中 都用温度 计算。

理论推导:为什么高温下梯度更稳定?

对 logit 求导(省略 normalization 项):

增大时,梯度的幅度按 缩放。但由于损失本身也随 变化,最终的梯度缩放因子是 (详见 Hinton 论文附录)。因此,实际训练时蒸馏损失需要乘以 来平衡梯度尺度:

$$

L_{} = T^2 L_{} (T) + (1 - ) L_{} $$

其中 是平衡系数, 是对硬标签的标准交叉熵损失。

响应式蒸馏:输出层的知识传递

响应式蒸馏( Response-based Distillation)是最经典的蒸馏方法,只利用模型最后一层的输出( logits 或概率)进行知识传递。

Hinton 的原始蒸馏算法

算法流程

  1. 训练教师模型:在完整数据集上训练一个高容量模型$ f_T x q_T = (f_T(x) / T)$

L = T^2 (q_T | q_S) + (1 - ) (y, q_S) $使T = 1$(标准 softmax)

超参数选择

  • 温度:通常取,任务相关。分类任务一般用
  • 平衡系数:通常取 越大,越依赖教师知识
  • 学生容量:一般是教师的 参数量

实验观察( ImageNet 实验):

  • ResNet-34 教师( 73.3%准确率)蒸馏到 ResNet-18 学生
  • 直接训练 ResNet-18: 69.8%
  • 蒸馏训练 ResNet-18: 71.4%
  • 提升 1.6%,但仍有 1.9%的 gap

为什么温度参数有效:信息论分析

从信息论角度,温度 控制了软标签的信息量。定义条件熵:

$$

H(Q_T) = -_{i=1}^C q_T^{(i)} q_T^{(i)} $$

可以证明 单调递增。更高的温度意味着更高的熵,即更多的不确定性和更丰富的信息。

具体地,当 很大时, softmax 可以泰勒展开:

$$

q_i (1 + {TC}) $$

其中。此时:

$$

q_i - q_j $$

这表明高温下, softmax 输出的相对差异直接反映了 logits 的相对差异,不受 exp 函数的非线性影响。学生模型可以更准确地学习类别之间的相对关系。

蒸馏与标签平滑的联系

标签平滑( Label Smoothing)是一种正则化技术,将硬标签 替换为:

$$

y_{} ^{(i)} = (1 - ) y^{(i)} + $$

其中 是平滑系数(通常取 0.1)。

可以证明,知识蒸馏在某种意义上是数据依赖的标签平滑

  • 标签平滑对所有样本使用相同的平滑分布(均匀分布)
  • 知识蒸馏对每个样本使用不同的平滑分布(教师的输出)

实验表明,蒸馏的效果通常优于标签平滑,因为教师的输出分布包含了样本特定的信息(例如某张猫的图片更像老虎)。

逐层蒸馏:多阶段知识传递

对于非常深的网络(如 ResNet-152),可以将蒸馏分解为多个阶段:

  1. 浅层蒸馏:用教师的前几层蒸馏学生的前几层
  2. 中层蒸馏:用教师的中间层蒸馏学生的中间层
  3. 深层蒸馏:用教师的最后几层蒸馏学生的最后几层

损失函数变为多项和:

$$

L = _{k=1}^K k L{} ^{(k)} + (1 - _k k) L{} $$

其中 是第 个蒸馏点的损失, 是权重。

优点:更细粒度的知识传递,适合教师和学生架构差异较大的情况。

缺点:需要手动设计蒸馏点位置和权重,超参数空间增大。

特征蒸馏:中间层的知识传递

特征蒸馏( Feature-based Distillation)不仅利用输出层,还利用中间层的特征图进行知识传递。

FitNets:提示学习

FitNets( Fitnets: Hints for Thin Deep Nets)是最早的特征蒸馏方法之一,由 Romero 等人于 2015 年提出。

核心思想:让学生的中间层特征图匹配教师的中间层特征图。

设教师在第 层的特征为,学生在第 层的特征为。由于维度可能不同,引入一个可学习的投影层

$$

L_{} = | _T - _r _S |_F^2 $$

其中 是 Frobenius 范数。

训练策略(两阶段):

  1. 阶段 1:冻结教师,只训练学生的前 层和投影层,最小化$L_{} l_SL_{} + L_{} $ 提示层位置选择
  • 浅层提示:学生学习低级特征(边缘、纹理),适合学生很小的情况
  • 深层提示:学生学习高级语义特征,适合学生容量接近教师的情况

实验发现,单个提示层效果有限,多个提示层效果更好(但增加计算成本)。

Attention Transfer:注意力图蒸馏

Zagoruyko 和 Komodakis 于 2017 年提出 Attention Transfer( AT),利用特征图的激活统计量作为"注意力图"进行蒸馏。

激活注意力( Activation-based Attention)

对特征图,定义注意力图:

其中$ p^{} $ 表示每个空间位置的激活强度。

损失函数为:

$$

L_{} ^{} = | - |_2^2 $$

归一化确保了尺度不变性。

梯度注意力( Gradient-based Attention)

除了激活,还可以用梯度作为注意力:

梯度注意力反映了哪些位置对损失贡献最大,捕获了模型的决策过程。

多层注意力传递

$$

L_{} = _{l } l L{} ^{(l)} $$

其中 是选定的层集合, 是权重。

实验结果( CIFAR-10):

  • ResNet-110 教师( 93.5%)ResNet-20 学生
  • 基线 ResNet-20: 91.3%
  • 响应式蒸馏: 91.8%
  • 注意力传递: 92.4%

注意力蒸馏比响应式蒸馏提升 0.6%,说明中间层知识传递的有效性。

PKT:概率知识传递

Lopez-Paz 等人于 2017 年提出 Probabilistic Knowledge Transfer( PKT),不匹配单个样本的特征,而是匹配特征分布的统计量。

核心思想:用样本对之间的相似性来表示知识。

对一批样本,计算特征相似性矩阵:

其中 是核函数(如高斯核)。损失函数为:

$$

L_{} = | _T - _S |_F^2 $$

这相当于匹配样本对之间的关系结构,而不是单个样本的特征值。

优点

  • 对特征维度的差异不敏感(不需要投影层)
  • 捕获了样本之间的语义关系

缺点

  • 计算复杂度, batch size 不能太大
  • 需要合理选择核函数和带宽

NST:神经风格迁移启发的蒸馏

Huang 和 Wang 于 2017 年受神经风格迁移( Neural Style Transfer)启发,提出用 Gram 矩阵进行特征蒸馏。

对特征图,将其 reshape 为(其中),定义 Gram 矩阵:

Gram 矩阵的元素 表示通道 和通道 的相关性。损失函数为:

$$

L_{} = _{l} | _T^{(l)} - _S^{(l)} |_F^2 $$

直觉: Gram 矩阵捕获了特征的二阶统计量(协方差),反映了不同通道之间的关系(例如"边缘检测器"和"纹理检测器"的共现模式)。

实验:在 CIFAR-100 上, NST 比 FitNets 和 AT 都有进一步提升(约 0.5%-1%)。

关系蒸馏:样本间关系的传递

关系蒸馏( Relation-based Distillation)不仅考虑单个样本的输出或特征,还考虑样本之间的关系。

RKD:关系知识蒸馏

Park 等人于 2019 年提出 Relational Knowledge Distillation( RKD),定义了两种关系:

距离关系( Distance-wise Relation)

对一对样本,定义归一化的欧氏距离:

其中 是归一化因子。

损失函数为:

$$

L_{} = _{(i,j) } ( _T(x_i, x_j) - _S(x_i, x_j) )^2 $$

其中 是采样的样本对集合。

角度关系( Angle-wise Relation)

对三元组,定义向量夹角:

其中

损失函数为:

$$

L_{} = _{(i,j,k) } ( _T(x_i, x_j, x_k) - _S(x_i, x_j, x_k) )^2 $$

直觉

  • 距离关系保证了样本对的相对距离保持一致(例如"猫"和"狗"的距离比"猫"和"汽车"的距离小)
  • 角度关系保证了样本的相对位置关系(例如"波斯猫"相对于"猫"和"狗"的方向)

总损失

$$

L_{} = D L{} + A L{} $$

实验表明,角度关系比距离关系更重要()。

CRD:对比表示蒸馏

Tian 等人于 2020 年提出 Contrastive Representation Distillation( CRD),将对比学习引入蒸馏框架。

核心思想:用对比学习的方式最大化学生和教师特征的互信息。

对一个正样本对(同一样本在教师和学生中的表示)和 个负样本(其他样本),定义 InfoNCE 损失:

$$

L_{} = - $$

关键区别

  • 传统蒸馏:用 MSE 或 KL 散度匹配特征
  • CRD:用对比学习匹配特征,更关注样本间的区分性

实验结果( CIFAR-100):

  • ResNet-32x4 教师( 79.4%)ResNet-8x4 学生
  • 响应式蒸馏: 73.3%
  • CRD: 75.5%

CRD 在小学生模型上尤其有效(提升 2%以上)。

SP:相似性保持蒸馏

Tung 和 Mori 于 2019 年提出 Similarity-Preserving Distillation( SP),要求学生的特征相似性矩阵和教师一致。

对一批样本,定义相似性矩阵:

损失函数为:

$$

L_{} = | _T - _S |_F^2 $$

与 PKT 的区别: SP 用余弦相似度, PKT 用核相似度。

自蒸馏:无教师的知识传递

自蒸馏( Self-Distillation)是一种无需预训练教师的蒸馏方法,模型从自己的早期版本或不同分支学习知识。

Born-Again Networks:迭代自蒸馏

Furlanello 等人于 2018 年提出 Born-Again Networks( BAN),通过迭代蒸馏提升模型性能。

算法流程

  1. 训练第 1 代模型:标准训练得到$M_1M_1$ 作为教师蒸馏 架构相同)
  2. 训练第 3 代模型:用 作为教师蒸馏4. 重复直到性能饱和

惊人发现:即使教师和学生架构完全相同,蒸馏仍然能提升性能!

理论解释

  • 蒸馏提供了更平滑的监督信号(软标签),减少了过拟合
  • 迭代蒸馏是一种集成学习的隐式形式
  • 每一代模型探索了损失曲面的不同区域

实验( CIFAR-100):

  • 第 1 代 DenseNet: 74.3%
  • 第 2 代( BAN): 75.2%
  • 第 3 代: 75.4%
  • 第 4 代: 75.5%(饱和)

Deep Mutual Learning:相互学习

Zhang 等人于 2018 年提出 Deep Mutual Learning( DML),让多个学生模型同时训练,互相作为教师。

算法流程

个学生模型,每个模型的损失包含两部分:

$$

L_k = L_{} (y, q_k) + _{j k} (q_j | q_k) $$

其中 是模型 的输出分布。

关键特点

  • 无需预训练教师:所有模型从头开始训练
  • 对称性:每个模型既是学生又是教师
  • 在线学习:模型实时学习彼此的知识

理论直觉

  • 每个模型在训练过程中会犯不同的错误
  • 相互学习让模型避免彼此的错误,类似于集成学习
  • 最终每个模型都比单独训练更好

实验( CIFAR-100):

  • 单独训练 ResNet-32: 70.2%
  • 2 个 ResNet-32 相互学习: 72.1%
  • 4 个 ResNet-32 相互学习: 72.8%

Online Distillation:在线蒸馏

在线蒸馏将多个学生模型的知识聚合为一个虚拟教师,避免了预训练教师的开销。

ONE( Online Network Ensemble)

Lan 等人于 2018 年提出,用多个分支的加权平均作为教师:

$$

q_{} = _{k=1}^K q_k $$

每个分支的损失为:

$$

L_k = L_{} (y, q_k) + (q_{} | q_k) $$

KDCL( Knowledge Distillation via Collaborative Learning)

Song 和 Chai 于 2018 年提出,除了分支间蒸馏,还在不同深度进行蒸馏:

$$

L = {d } {k=1}^K $$

其中 是选定的深度集合(如每隔 4 层)。

优势

  • 单次训练完成,节省时间
  • 最终可以使用任一分支,或集成多个分支

量化与剪枝的协同蒸馏

知识蒸馏常与量化、剪枝等压缩技术结合,实现更高的压缩比。

Quantization-aware Distillation:量化感知蒸馏

量化将浮点参数映射到低比特整数(如 8-bit 或 4-bit),但会导致精度下降。蒸馏可以缓解这一问题。

算法流程

  1. 训练全精度教师:标准 FP32 训练
  2. 量化学生初始化:将教师参数量化为 INT8 作为学生初始化
  3. 蒸馏微调:用教师的软标签微调量化学生

损失函数:

$$

L = L_{} (q_T, q_S^{} ) + (1 - ) L_{} (y, q_S^{} ) $$

其中 是量化后的学生输出。

量化细节

对权重,量化公式为:

$$

W^{} = ( ) s $$

其中 是缩放因子:

$$

s = $$ 是比特数。

实验( ResNet-18 on ImageNet):

  • FP32 基线: 69.8%
  • INT8 量化(无蒸馏): 68.5%(-1.3%)
  • INT8 量化(有蒸馏): 69.2%(-0.6%)

蒸馏将量化损失减少了一半。

Pruning-aware Distillation:剪枝感知蒸馏

剪枝移除不重要的神经元或连接,蒸馏可以帮助剪枝后的模型恢复性能。

算法流程

  1. 训练全模型教师
  2. 结构化剪枝:移除重要性低的通道或层(如用 L1-norm 判断)
  3. 蒸馏恢复:用教师软标签微调剪枝后的学生

重要性评估

对卷积层的通道,定义重要性:

$$

I_c = {k,i,j} |W{c,k,i,j}| $$

移除重要性最低的$ p% p = 50$)。

损失函数(多层蒸馏):

$$

L = L_{} ^{} + {l} l L{} ^{(l)} + L{} $$

其中 是中间层的特征蒸馏。

实验( VGG-16 on CIFAR-10):

  • 原始 VGG-16: 93.5%( 14.7M 参数)
  • 剪枝 70%(无蒸馏): 92.1%( 4.4M 参数)
  • 剪枝 70%(有蒸馏): 93.0%( 4.4M 参数)

蒸馏使剪枝后的模型几乎恢复到原始性能。

NAS + Distillation:神经架构搜索与蒸馏

神经架构搜索( NAS)可以自动找到高效的学生架构,结合蒸馏进一步提升性能。

MetaDistiller

Liu 等人于 2020 年提出,用强化学习搜索最优的蒸馏策略:

  • 哪些层进行蒸馏
  • 每层的损失权重
  • 温度参数 搜索空间大小为,其中 是层数, 是权重候选数。

用强化学习(如 PPO)优化搜索策略,奖励函数为学生在验证集上的准确率。

实验:在 CIFAR-100 上, MetaDistiller 找到的策略比手动设计的策略提升 1%-2%。

完整代码实现:多策略知识蒸馏

下面提供一个完整的知识蒸馏实现,包含响应式蒸馏、特征蒸馏、注意力传递等多种方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
from typing import List, Tuple, Dict
import copy

# ============== 蒸馏损失函数 ==============

class KLDivergenceLoss(nn.Module):
"""响应式蒸馏: KL 散度损失"""
def __init__(self, temperature: float = 4.0, alpha: float = 0.9):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.kl_div = nn.KLDivLoss(reduction='batchmean')

def forward(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor,
labels: torch.Tensor) -> torch.Tensor:
# 软标签蒸馏损失
student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
kd_loss = self.kl_div(student_soft, teacher_soft) * (self.temperature ** 2)

# 硬标签分类损失
ce_loss = F.cross_entropy(student_logits, labels)

# 加权组合
total_loss = self.alpha * kd_loss + (1 - self.alpha) * ce_loss
return total_loss


class FeatureDistillationLoss(nn.Module):
"""特征蒸馏:中间层特征匹配"""
def __init__(self, student_channels: int, teacher_channels: int):
super().__init__()
# 投影层:将学生特征映射到教师特征空间
self.projector = nn.Conv2d(student_channels, teacher_channels,
kernel_size=1, bias=False)

def forward(self, student_feat: torch.Tensor, teacher_feat: torch.Tensor) -> torch.Tensor:
# 投影学生特征
student_proj = self.projector(student_feat)

# MSE 损失
loss = F.mse_loss(student_proj, teacher_feat)
return loss


class AttentionTransferLoss(nn.Module):
"""注意力传递:激活注意力图蒸馏"""
def __init__(self, p: float = 2.0):
super().__init__()
self.p = p

def compute_attention_map(self, feature: torch.Tensor) -> torch.Tensor:
"""计算注意力图:对通道维度求 L^p 范数"""
# feature: [B, C, H, W]
attention = torch.sum(torch.abs(feature) ** self.p, dim=1, keepdim=True)
# 归一化
attention = attention / (torch.sum(attention, dim=[2, 3], keepdim=True) + 1e-8)
return attention

def forward(self, student_feat: torch.Tensor, teacher_feat: torch.Tensor) -> torch.Tensor:
student_attn = self.compute_attention_map(student_feat)
teacher_attn = self.compute_attention_map(teacher_feat)

loss = F.mse_loss(student_attn, teacher_attn)
return loss


class RelationalDistillationLoss(nn.Module):
"""关系蒸馏:样本间距离和角度关系"""
def __init__(self, lambda_distance: float = 1.0, lambda_angle: float = 2.0):
super().__init__()
self.lambda_distance = lambda_distance
self.lambda_angle = lambda_angle

def compute_distance_relation(self, features: torch.Tensor) -> torch.Tensor:
"""计算样本对之间的归一化欧氏距离"""
# features: [B, D]
B = features.size(0)
# 计算所有样本对的距离矩阵
feat_norm = features / (torch.norm(features, p=2, dim=1, keepdim=True) + 1e-8)
distance_matrix = torch.cdist(feat_norm, feat_norm, p=2)
return distance_matrix

def compute_angle_relation(self, features: torch.Tensor) -> torch.Tensor:
"""计算样本三元组的角度关系"""
# features: [B, D]
B = features.size(0)
if B < 3:
return torch.tensor(0.0, device=features.device)

# 标准化特征
feat_norm = features / (torch.norm(features, p=2, dim=1, keepdim=True) + 1e-8)

# 计算余弦相似度矩阵
cos_sim = torch.mm(feat_norm, feat_norm.t())

# 随机采样三元组(简化实现)
indices = torch.randperm(B)[:min(B, 10)]
sampled_cos = cos_sim[indices][:, indices]

return sampled_cos

def forward(self, student_feat: torch.Tensor, teacher_feat: torch.Tensor) -> torch.Tensor:
# 距离关系损失
student_dist = self.compute_distance_relation(student_feat)
teacher_dist = self.compute_distance_relation(teacher_feat)
dist_loss = F.mse_loss(student_dist, teacher_dist)

# 角度关系损失
student_angle = self.compute_angle_relation(student_feat)
teacher_angle = self.compute_angle_relation(teacher_feat)
angle_loss = F.mse_loss(student_angle, teacher_angle)

total_loss = self.lambda_distance * dist_loss + self.lambda_angle * angle_loss
return total_loss


# ============== 模型定义 ==============

class TeacherResNet(nn.Module):
"""教师模型: ResNet-34"""
def __init__(self, num_classes: int = 10):
super().__init__()
self.model = torchvision.models.resnet34(pretrained=False)
self.model.fc = nn.Linear(512, num_classes)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
# 提取中间层特征
features = []

x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)

x = self.model.layer1(x)
features.append(x) # 特征 1

x = self.model.layer2(x)
features.append(x) # 特征 2

x = self.model.layer3(x)
features.append(x) # 特征 3

x = self.model.layer4(x)
features.append(x) # 特征 4

x = self.model.avgpool(x)
x = torch.flatten(x, 1)
logits = self.model.fc(x)

return logits, features


class StudentResNet(nn.Module):
"""学生模型: ResNet-18"""
def __init__(self, num_classes: int = 10):
super().__init__()
self.model = torchvision.models.resnet18(pretrained=False)
self.model.fc = nn.Linear(512, num_classes)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
features = []

x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)

x = self.model.layer1(x)
features.append(x)

x = self.model.layer2(x)
features.append(x)

x = self.model.layer3(x)
features.append(x)

x = self.model.layer4(x)
features.append(x)

x = self.model.avgpool(x)
x = torch.flatten(x, 1)
logits = self.model.fc(x)

return logits, features


# ============== 蒸馏训练器 ==============

class DistillationTrainer:
"""知识蒸馏训练器"""
def __init__(
self,
teacher: nn.Module,
student: nn.Module,
device: str = 'cuda',
distill_type: str = 'response', # 'response', 'feature', 'attention', 'relation', 'combined'
temperature: float = 4.0,
alpha: float = 0.9,
):
self.teacher = teacher.to(device)
self.student = student.to(device)
self.device = device
self.distill_type = distill_type

# 冻结教师模型
self.teacher.eval()
for param in self.teacher.parameters():
param.requires_grad = False

# 初始化损失函数
self.kd_loss = KLDivergenceLoss(temperature, alpha)

if distill_type in ['feature', 'combined']:
# ResNet-34 和 ResNet-18 的通道数
teacher_channels = [64, 128, 256, 512]
student_channels = [64, 128, 256, 512]
self.feat_losses = nn.ModuleList([
FeatureDistillationLoss(s_ch, t_ch).to(device)
for s_ch, t_ch in zip(student_channels, teacher_channels)
])

if distill_type in ['attention', 'combined']:
self.attn_loss = AttentionTransferLoss()

if distill_type in ['relation', 'combined']:
self.rel_loss = RelationalDistillationLoss()

def compute_loss(
self,
student_logits: torch.Tensor,
student_features: List[torch.Tensor],
teacher_logits: torch.Tensor,
teacher_features: List[torch.Tensor],
labels: torch.Tensor,
) -> Dict[str, torch.Tensor]:
"""计算总损失"""
losses = {}

# 响应式蒸馏损失
kd_loss = self.kd_loss(student_logits, teacher_logits, labels)
losses['kd'] = kd_loss
total_loss = kd_loss

# 特征蒸馏损失
if self.distill_type in ['feature', 'combined']:
feat_loss = 0
for i, (s_feat, t_feat, feat_loss_fn) in enumerate(
zip(student_features, teacher_features, self.feat_losses)
):
feat_loss += feat_loss_fn(s_feat, t_feat)
feat_loss /= len(student_features)
losses['feature'] = feat_loss
total_loss += 0.5 * feat_loss

# 注意力传递损失
if self.distill_type in ['attention', 'combined']:
attn_loss = 0
for s_feat, t_feat in zip(student_features, teacher_features):
attn_loss += self.attn_loss(s_feat, t_feat)
attn_loss /= len(student_features)
losses['attention'] = attn_loss
total_loss += 0.3 * attn_loss

# 关系蒸馏损失(在最后一层特征上)
if self.distill_type in ['relation', 'combined']:
s_feat_flat = torch.flatten(student_features[-1], 1)
t_feat_flat = torch.flatten(teacher_features[-1], 1)
rel_loss = self.rel_loss(s_feat_flat, t_feat_flat)
losses['relation'] = rel_loss
total_loss += 0.2 * rel_loss

losses['total'] = total_loss
return losses

def train_epoch(
self,
train_loader: DataLoader,
optimizer: optim.Optimizer,
epoch: int,
) -> Dict[str, float]:
"""训练一个 epoch"""
self.student.train()

total_losses = {key: 0.0 for key in ['kd', 'feature', 'attention', 'relation', 'total']}
correct = 0
total = 0

for batch_idx, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(self.device), labels.to(self.device)

# 前向传播
with torch.no_grad():
teacher_logits, teacher_features = self.teacher(inputs)

student_logits, student_features = self.student(inputs)

# 计算损失
losses = self.compute_loss(
student_logits, student_features,
teacher_logits, teacher_features,
labels
)

# 反向传播
optimizer.zero_grad()
losses['total'].backward()
optimizer.step()

# 统计
for key, value in losses.items():
if key in total_losses:
total_losses[key] += value.item()

_, predicted = student_logits.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

if batch_idx % 50 == 0:
print(f'Epoch {epoch} [{batch_idx}/{len(train_loader)}] '
f'Loss: {losses["total"]:.4f} '
f'Acc: {100. * correct / total:.2f}%')

# 计算平均损失
for key in total_losses:
total_losses[key] /= len(train_loader)

accuracy = 100. * correct / total
return {**total_losses, 'accuracy': accuracy}

@torch.no_grad()
def evaluate(self, test_loader: DataLoader) -> Tuple[float, float]:
"""评估学生模型"""
self.student.eval()

correct = 0
total = 0
test_loss = 0

for inputs, labels in test_loader:
inputs, labels = inputs.to(self.device), labels.to(self.device)

logits, _ = self.student(inputs)
loss = F.cross_entropy(logits, labels)

test_loss += loss.item()
_, predicted = logits.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

accuracy = 100. * correct / total
avg_loss = test_loss / len(test_loader)

return accuracy, avg_loss


# ============== 自蒸馏与相互学习 ==============

class SelfDistillationTrainer:
"""自蒸馏训练器: Born-Again Networks"""
def __init__(
self,
model_class: type,
num_classes: int = 10,
device: str = 'cuda',
temperature: float = 4.0,
):
self.model_class = model_class
self.num_classes = num_classes
self.device = device
self.temperature = temperature
self.generations = []

def train_generation(
self,
train_loader: DataLoader,
test_loader: DataLoader,
num_epochs: int = 10,
teacher_model: nn.Module = None,
) -> nn.Module:
"""训练一代模型"""
model = self.model_class(self.num_classes).to(self.device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
model.train()
for inputs, labels in train_loader:
inputs, labels = inputs.to(self.device), labels.to(self.device)

logits, _ = model(inputs)

# 硬标签损失
ce_loss = F.cross_entropy(logits, labels)

# 如果有教师,添加蒸馏损失
if teacher_model is not None:
with torch.no_grad():
teacher_logits, _ = teacher_model(inputs)

student_soft = F.log_softmax(logits / self.temperature, dim=1)
teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
kd_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
kd_loss *= (self.temperature ** 2)

total_loss = 0.1 * ce_loss + 0.9 * kd_loss
else:
total_loss = ce_loss

optimizer.zero_grad()
total_loss.backward()
optimizer.step()

scheduler.step()

# 评估
accuracy, _ = self.evaluate(model, test_loader)
print(f'Generation {len(self.generations)} Epoch {epoch}: Acc = {accuracy:.2f}%')

return model

@torch.no_grad()
def evaluate(self, model: nn.Module, test_loader: DataLoader) -> Tuple[float, float]:
model.eval()
correct = 0
total = 0

for inputs, labels in test_loader:
inputs, labels = inputs.to(self.device), labels.to(self.device)
logits, _ = model(inputs)
_, predicted = logits.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

accuracy = 100. * correct / total
return accuracy, 0.0

def train_multiple_generations(
self,
train_loader: DataLoader,
test_loader: DataLoader,
num_generations: int = 3,
num_epochs_per_gen: int = 10,
) -> List[nn.Module]:
"""训练多代模型"""
print("Training Generation 1 (no teacher)...")
gen1 = self.train_generation(train_loader, test_loader, num_epochs_per_gen, teacher_model=None)
self.generations.append(gen1)

for i in range(2, num_generations + 1):
print(f"\nTraining Generation {i} (teacher = Gen {i-1})...")
teacher = self.generations[-1]
teacher.eval()
for param in teacher.parameters():
param.requires_grad = False

gen_i = self.train_generation(train_loader, test_loader, num_epochs_per_gen, teacher_model=teacher)
self.generations.append(gen_i)

return self.generations


# ============== 主函数 ==============

def main():
# 超参数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs = 20
batch_size = 128
num_classes = 10

# 数据加载
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# ========== 实验 1:标准知识蒸馏 ==========
print("\n" + "="*50)
print("实验 1:标准知识蒸馏 (ResNet-34 -> ResNet-18)")
print("="*50)

# 训练教师模型(或加载预训练)
teacher = TeacherResNet(num_classes).to(device)
print("Training teacher model...")
# 这里省略教师训练代码,假设已有预训练模型
# train_teacher(teacher, trainloader, testloader, num_epochs)

# 蒸馏训练学生模型
student = StudentResNet(num_classes).to(device)
trainer = DistillationTrainer(
teacher=teacher,
student=student,
device=device,
distill_type='response', # 响应式蒸馏
temperature=4.0,
alpha=0.9,
)

optimizer = optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

best_acc = 0
for epoch in range(num_epochs):
train_metrics = trainer.train_epoch(trainloader, optimizer, epoch)
test_acc, test_loss = trainer.evaluate(testloader)
scheduler.step()

print(f'Epoch {epoch}: Train Acc = {train_metrics["accuracy"]:.2f}%, '
f'Test Acc = {test_acc:.2f}%, Test Loss = {test_loss:.4f}')

if test_acc > best_acc:
best_acc = test_acc

print(f'Best Test Accuracy: {best_acc:.2f}%')

# ========== 实验 2:组合蒸馏 ==========
print("\n" + "="*50)
print("实验 2:组合蒸馏 (Response + Feature + Attention + Relation)")
print("="*50)

student_combined = StudentResNet(num_classes).to(device)
trainer_combined = DistillationTrainer(
teacher=teacher,
student=student_combined,
device=device,
distill_type='combined',
temperature=4.0,
alpha=0.7, # 降低 alpha 以平衡多个损失
)

optimizer = optim.SGD(student_combined.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

best_acc_combined = 0
for epoch in range(num_epochs):
train_metrics = trainer_combined.train_epoch(trainloader, optimizer, epoch)
test_acc, test_loss = trainer_combined.evaluate(testloader)
scheduler.step()

print(f'Epoch {epoch}: Test Acc = {test_acc:.2f}%')

if test_acc > best_acc_combined:
best_acc_combined = test_acc

print(f'Best Test Accuracy (Combined): {best_acc_combined:.2f}%')

# ========== 实验 3:自蒸馏 ==========
print("\n" + "="*50)
print("实验 3:自蒸馏 (Born-Again Networks)")
print("="*50)

self_distiller = SelfDistillationTrainer(
model_class=StudentResNet,
num_classes=num_classes,
device=device,
temperature=4.0,
)

generations = self_distiller.train_multiple_generations(
trainloader, testloader,
num_generations=3,
num_epochs_per_gen=10,
)

for i, model in enumerate(generations):
acc, _ = self_distiller.evaluate(model, testloader)
print(f'Generation {i+1} Test Accuracy: {acc:.2f}%')


if __name__ == '__main__':
main()

代码说明

  1. 损失函数模块
    • KLDivergenceLoss:响应式蒸馏,包含温度参数和 缩放
    • FeatureDistillationLoss: FitNets 风格的特征匹配,带投影层
    • AttentionTransferLoss:计算激活注意力图并匹配
    • RelationalDistillationLoss: RKD 风格的距离和角度关系
  2. 模型定义
    • TeacherResNetStudentResNet都返回 logits 和中间特征
    • 特征提取点选在每个 block 之后
  3. 训练器
    • DistillationTrainer支持多种蒸馏策略( response/feature/attention/relation/combined)
    • 自动冻结教师模型
    • compute_loss方法灵活组合多个损失
  4. 自蒸馏
    • SelfDistillationTrainer实现 Born-Again Networks
    • 迭代训练多代模型,每代用前一代作为教师

Q&A:常见问题解答

Q1:温度参数 如何选择?

A:温度参数的选择取决于任务和数据:

  • 分类任务,通常从 开始尝试
  • 回归任务:温度的作用较小,可以不用或用较低的
  • 调优策略:在验证集上网格搜索 原理:温度需要平衡两个因素:
  • 太小:软标签退化为硬标签,暗知识丢失
  • 太大:所有类的概率趋于均匀,信号变弱

经验上,类别数越多、类别相似度越高,需要越高的温度。

Q2:平衡系数 如何设置?

A 控制蒸馏损失和分类损失的权重:

  • (如 0.9):更依赖教师知识,适合教师很强、数据很少的情况
  • (如 0.5):更依赖硬标签,适合教师和学生容量接近的情况

调优建议: - 如果学生容量是教师的 1/10 以下: - 如果学生容量是教师的 1/4-1/2: - 如果使用多个蒸馏损失( combined):降低 到 0.5-0.7

Q3:学生模型应该多小?

A:学生容量取决于部署约束和性能要求:

  • 移动端:通常压缩到教师的 1/10 参数量,接受 2-5%的精度损失
  • 边缘设备:压缩到 1/20-1/50,可能损失 5-10%
  • 服务器优化:压缩到 1/2-1/4,损失<1%

重要发现:蒸馏的效果在学生很小时尤为明显。当学生容量接近教师时,蒸馏的收益递减。

Q4:蒸馏为什么对小模型特别有效?

A:理论上有几种解释:

  1. 容量限制下的知识压缩:小模型无法拟合所有训练数据,软标签提供了哪些知识最重要的信号
  2. 正则化效应:软标签的熵更高,防止小模型在有限数据上过拟合
  3. 优化曲面平滑:软标签的梯度更平滑,帮助小模型找到更好的局部最优

实验证据:当学生容量极小(参数量<1%教师)时,蒸馏可以带来 10-20%的相对提升。

Q5:特征蒸馏的层如何选择?

A:层的选择影响蒸馏效果:

  • 浅层:低级特征(边缘、纹理),对所有任务都有用,但信息量有限
  • 深层:高级语义特征,任务相关性强,但可能过拟合

推荐策略: - 选择教师的中间层(如 ResNet 的 layer2 和 layer3) - 避免选择第一层(信息太基础)和最后一层(已被响应式蒸馏覆盖) - 多层蒸馏时,权重应该递增(深层权重更大)

自动化方法:用 NAS 或强化学习搜索最优层组合(如 MetaDistiller)。

Q6:自蒸馏为什么有效?

A:自蒸馏( Self-Distillation)看似悖论:学生和教师架构相同,凭什么蒸馏能提升性能?

解释: 1. 正则化:软标签提供了更平滑的监督信号,减少过拟合 2. 集成效应:每一代模型探索了损失曲面的不同区域,相当于隐式集成 3. 暗知识提炼:即使同架构,教师也学到了类别关系等暗知识

实验证据: - Born-Again Networks 在 CIFAR-100 上提升 1-2% - 提升在数据量较小时更明显

Q7:蒸馏和剪枝/量化如何结合?

A:蒸馏可以显著缓解剪枝和量化的性能损失:

剪枝 + 蒸馏: 1. 训练全模型教师 2. 剪枝教师得到初始学生 3. 用教师软标签微调学生 4. 效果:通常能恢复 50-80%的剪枝损失

量化 + 蒸馏: 1. FP32 教师训练 2. 量化初始化学生( INT8 或 INT4) 3. 蒸馏微调量化学生 4. 效果: INT8 几乎无损, INT4 损失<1%

同时应用:先剪枝再量化,配合蒸馏,可以达到 10-20 倍压缩比。

Q8:蒸馏在 NLP 和 CV 中有区别吗?

A:蒸馏的核心原理相同,但具体实现有差异:

CV 特点: - 特征图是空间结构( 2D),可以用注意力图、 Gram 矩阵等 - 通常在多个卷积层进行特征蒸馏 - 数据增强(如 MixUp)可以进一步提升蒸馏效果

NLP 特点: - 特征是序列( 1D),用序列对齐或池化方法 - BERT 等模型的中间层蒸馏(如 DistilBERT)很有效 - 预训练+蒸馏是主流范式(先预训练大模型,再蒸馏到小模型)

共性:响应式蒸馏在两个领域都有效,是 baseline 方法。

Q9:多个教师模型可以吗?

A:可以,称为多教师蒸馏( Multi-Teacher Distillation)

平均集成

$$

q_{} = _{k=1}^K q_k $$

学生学习多个教师的平均分布。

加权集成

$$

q_{} = _{k=1}^K w_k q_k, _k w_k = 1 $$

权重 可以是固定的(如按教师准确率),也可以是可学习的。

优点:集成多个教师的知识,鲁棒性更好。

缺点:需要训练多个教师,成本高。

Q10:蒸馏在小数据集上效果如何?

A:蒸馏在小数据集上尤其有效:

原因: - 小数据上容易过拟合,软标签提供了强正则化 - 教师模型在更大的数据集(如 ImageNet)上预训练,迁移了先验知识

实验(医疗影像分类, 1000 张训练图): - 从头训练小模型: 65% accuracy - 用硬标签微调预训练模型: 72% - 用大模型蒸馏小模型: 75%

蒸馏比直接微调提升 3%,说明软标签在小数据上的价值。

Q11:蒸馏的计算开销如何?

A:蒸馏的额外开销包括:

  1. 教师推理:训练时需要教师前向传播,增加约 50%计算量
  2. 特征存储(特征蒸馏):需要存储中间特征,增加内存
  3. 多个损失计算:额外的 KL 散度、 MSE 等,开销很小

优化策略: - 离线蒸馏:预先计算教师的软标签并保存,训练时直接加载(节省教师推理) - 在线蒸馏:动态更新教师,但计算开销大 - 选择性蒸馏:只在困难样本上进行蒸馏

推理阶段:学生模型独立部署,无额外开销。

Q12:蒸馏和迁移学习的关系?

A:蒸馏是一种特殊的迁移学习:

共性: - 都是从一个模型(源)迁移知识到另一个模型(目标) - 都利用了先验知识减少目标任务的数据需求

差异: - 迁移学习:通常改变任务(如 ImageNet 医疗影像) - 蒸馏:通常保持任务,改变模型容量

结合:跨任务蒸馏( Cross-Task Distillation)同时改变任务和容量,是一个活跃的研究方向。

论文推荐

经典论文

  1. Hinton et al., "Distilling the Knowledge in a Neural Network", NIPS 2014 Workshop
    • 提出知识蒸馏、温度参数、软标签的概念
    • 奠定了蒸馏研究的基础
    • arXiv:1503.02531
  2. Romero et al., "FitNets: Hints for Thin Deep Nets", ICLR 2015
    • 首次提出特征蒸馏( feature-based distillation)
    • 引入提示学习( hint learning)概念
    • 两阶段训练策略
  3. Zagoruyko & Komodakis, "Paying More Attention to Attention", ICLR 2017
    • 提出注意力传递( Attention Transfer)
    • 激活注意力和梯度注意力
    • 在多个数据集上验证有效性
  4. Tung & Mori, "Similarity-Preserving Knowledge Distillation", ICCV 2019
    • 相似性保持蒸馏( SP)
    • 样本对关系的传递
    • 理论分析相似性保持的重要性

关系蒸馏

  1. Park et al., "Relational Knowledge Distillation", CVPR 2019
    • 关系知识蒸馏( RKD)
    • 距离关系和角度关系
    • 三元组采样策略
  2. Tian et al., "Contrastive Representation Distillation", ICLR 2020
    • 对比表示蒸馏( CRD)
    • 用对比学习框架进行蒸馏
    • 最大化学生-教师特征的互信息

自蒸馏与相互学习

  1. Furlanello et al., "Born-Again Neural Networks", ICML 2018
    • 自蒸馏( Self-Distillation)
    • 迭代蒸馏提升同架构模型
    • 理论分析为何自蒸馏有效
  2. Zhang et al., "Deep Mutual Learning", CVPR 2018
    • 相互学习( Mutual Learning)
    • 多个学生同时训练、互相监督
    • 无需预训练教师

NLP 中的蒸馏

  1. Sanh et al., "DistilBERT, a distilled version of BERT", NeurIPS 2019 Workshop
    • 将 BERT-base 蒸馏到更小的模型
    • 保留 97%性能,减少 40%参数
    • 广泛应用于工业界
  2. Jiao et al., "TinyBERT", Findings of EMNLP 2020
    • 两阶段蒸馏:预训练蒸馏 + 任务蒸馏
    • 嵌入层、注意力、隐层全方位蒸馏
    • 达到 7.5 倍压缩比

量化与蒸馏

  1. Mishra & Marr, "Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy", ICLR 2018
    • 量化感知蒸馏
    • 用 FP32 教师帮助 INT8 学生
    • 缓解量化带来的精度损失
  2. Liu et al., "MetaDistiller: Network Self-Boosting via Meta-Learned Top-Down Distillation", ECCV 2020
    • 用 NAS 搜索蒸馏策略
    • 自动选择蒸馏层和权重
    • 在多个任务上达到 SOTA

知识蒸馏是一个简单而强大的思想:让小模型学习大模型的"思维方式"而非简单模仿输出。通过软标签、温度参数、特征匹配、关系保持等技术,蒸馏能在显著减少模型大小的同时保持接近原始模型的性能。从 Hinton 的开创性工作到近年来的 CRD 、 TinyBERT 等方法,蒸馏技术不断进化,成为模型压缩和迁移学习的核心工具。无论是移动端部署、边缘计算,还是大模型的民主化,知识蒸馏都将发挥关键作用。

  • 本文标题:迁移学习(五)—— 知识蒸馏
  • 本文作者:Chen Kai
  • 创建时间:2024-11-27 09:30:00
  • 本文链接:https://www.chenk.top/%E8%BF%81%E7%A7%BB%E5%AD%A6%E4%B9%A0%EF%BC%88%E4%BA%94%EF%BC%89%E2%80%94%E2%80%94-%E7%9F%A5%E8%AF%86%E8%92%B8%E9%A6%8F/
  • 版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
 评论