Transfer Learning (5): Knowledge Distillation
Chen Kai

Knowledge Distillation (KD) is a model compression and transfer learning technique that enables small models (students) to learn from large models (teachers), maintaining performance close to teacher models while significantly reducing parameters and computation. Hinton et al.'s seminal 2015 paper "Distilling the Knowledge in a Neural Network" sparked a research wave in this field. But knowledge distillation is far more than simple "soft label" training — it involves temperature parameter tuning, extracting knowledge at different levels, matching student-teacher architectures, and numerous technical details.

This article derives the mathematical foundations of knowledge distillation from first principles, explains why soft labels contain more information than hard labels, details implementation of response-based, feature-based, and relation-based distillation, introduces methods like self-distillation, mutual learning, and online distillation that don't require pre-trained teachers, and explores synergistic optimization of quantization, pruning, and distillation. We'll see that distillation is essentially "compression encoding" of knowledge — explicitly transferring dark knowledge implicitly learned by teacher models to student models.

Motivation for Knowledge Distillation

From Model Compression to Knowledge Transfer

Deep neural networks typically require massive parameters to achieve optimal performance. However, large models face numerous deployment challenges:

  • Mobile Deployment: Phones and IoT devices have limited memory and computation, cannot run models with billions of parameters
  • Inference Latency: Real-time systems like autonomous driving and industrial control require millisecond-level response
  • Energy Constraints: Edge devices need long battery life, large models consume too much power
  • Cost Optimization: Cloud services handle billions of requests daily, smaller models reduce costs

Traditional model compression methods (pruning, quantization, low-rank factorization) directly manipulate model structure or parameters, often causing significant performance degradation. Knowledge distillation's core idea is to have small models learn large models' output distributions, not simply fit hard labels.

Dark Knowledge: Information Advantage of Soft Labels

Consider an image classification task where the true label is "cat" (the hard label is a one-hot vector). A trained teacher model might instead output a probability distribution such as $P(\text{cat}) = 0.85$, $P(\text{tiger}) = 0.10$, $P(\text{dog}) = 0.04$, $P(\text{car}) = 0.01$. Although the teacher predicts the highest probability for "cat", the other class probabilities also contain valuable information:

  • High probability for "tiger": indicates this cat shares visual similarities with tigers (body shape, patterns, etc.)
  • Low but non-zero probability for "dog": shows cats and dogs have common features (furry, four legs)
  • Extremely low probability for "car": indicates cats and cars have completely different visual features

These non-zero "error" probabilities are what Hinton calls dark knowledge — they reveal the similarity structure between classes and embody the generalization ability the teacher model learned during training.

From an information theory perspective, hard labels have entropy 0 (deterministic one-hot vectors), while soft labels have positive entropy $H(p) = -\sum_i p_i \log p_i > 0$. Soft labels therefore provide richer supervisory signals, helping student models learn relationships between classes.
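This entropy gap is easy to check numerically. A minimal sketch (the soft distribution below is illustrative, not a real teacher output):

```python
import math

def entropy(p):
    """Shannon entropy in nats; the 0 * log(0) terms are treated as 0."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

hard = [0.0, 1.0, 0.0, 0.0]        # one-hot label for "cat"
soft = [0.04, 0.85, 0.10, 0.01]    # illustrative teacher distribution
print(entropy(hard))               # zero: no extra information
print(entropy(soft))               # positive: encodes class similarities
```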

Mathematical Perspective on Distillation: Distribution Matching

Let the teacher parameters be $\theta_T$, the student parameters $\theta_S$, the input $x$, and the output logits $z_T = f_T(x; \theta_T)$ and $z_S = f_S(x; \theta_S)$. Standard classification training minimizes the cross-entropy $\mathcal{L}_{CE} = -\sum_i y_i \log \sigma_i(z_S)$, where $y$ is the hard label (one-hot) and $\sigma$ is the softmax function.

Knowledge distillation instead has the student match the teacher's output distribution: $\mathcal{L}_{KD} = -\sum_i \sigma_i(z_T) \log \sigma_i(z_S)$. This is the cross-entropy between the two distributions, equivalent to minimizing the KL divergence $D_{KL}(\sigma(z_T) \,\|\, \sigma(z_S))$, since the teacher entropy $H(\sigma(z_T))$ is constant with respect to $\theta_S$. From an optimization perspective, distillation makes the student's output distribution $\sigma(z_S)$ approximate the teacher's.

Temperature Parameter: Softening Probability Distributions

The problem with using softmax outputs directly is that the probability distributions are often too "peaked" — the maximum class probability approaches 1 while the others approach 0, suppressing the dark knowledge.

Hinton introduced the temperature parameter $T$ to soften distributions: $\sigma_i(z; T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$. When $T > 1$, the probability distribution becomes smoother:

  • $T \to \infty$: all class probabilities approach the uniform distribution
  • $T = 1$: standard softmax
  • $T \to 0$: the distribution degenerates to one-hot (argmax)

Intuitive Example: Consider logits $z = (10, 3, 1)$. At $T = 1$, softmax gives approximately $(0.999, 0.001, 0.0001)$ — the class-3 information is almost lost. At $T = 5$, it gives approximately $(0.71, 0.17, 0.12)$ — the class-3 information is preserved.
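The softening effect is easy to reproduce. A small sketch using the logits above:

```python
import math

def softmax_T(logits, T):
    """Softmax with temperature: exp(z_i / T) / sum_j exp(z_j / T)."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

z = [10.0, 3.0, 1.0]
print(softmax_T(z, 1.0))   # sharply peaked on the first class
print(softmax_T(z, 5.0))   # smoother: the small logits survive
```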

The distillation loss at temperature $T$ is defined as $\mathcal{L}_{KD}(T) = -\sum_i \sigma_i(z_T; T) \log \sigma_i(z_S; T)$, where both $\sigma(z_T; T)$ and $\sigma(z_S; T)$ are computed with temperature $T$.

Theoretical Derivation: Why are gradients more stable at high temperatures?

Taking the derivative with respect to student logit $z_i$ (omitting the normalization term): $\frac{\partial \mathcal{L}_{KD}}{\partial z_i} = \frac{1}{T}\left(\sigma_i(z_S; T) - \sigma_i(z_T; T)\right)$. As $T$ increases, the $1/T$ factor shrinks the gradient; and since the gap between the softened probabilities itself also scales roughly as $1/T$, the final gradient scales as $1/T^2$ (see the appendix of Hinton's paper). Therefore, the distillation loss needs to be multiplied by $T^2$ to balance gradient scales during training: $\mathcal{L} = \alpha T^2 \mathcal{L}_{KD} + (1 - \alpha) \mathcal{L}_{CE}$, where $\alpha$ is the balance coefficient and $\mathcal{L}_{CE}$ is the standard cross-entropy loss on hard labels.
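The $1/T^2$ scaling claim can be verified empirically with autograd. A sketch (the logits are arbitrary test values):

```python
import torch
import torch.nn.functional as F

def kd_grad_norm(T, scale_by_T2):
    """Gradient norm of the KD loss w.r.t. student logits at temperature T."""
    z_s = torch.tensor([2.0, 0.5, -1.0], requires_grad=True)
    z_t = torch.tensor([3.0, 0.0, -2.0])  # fixed teacher logits
    loss = F.kl_div(F.log_softmax(z_s / T, dim=0),
                    F.softmax(z_t / T, dim=0), reduction='sum')
    if scale_by_T2:
        loss = loss * T ** 2
    loss.backward()
    return z_s.grad.norm().item()

# Without T^2 the gradient collapses at high temperature;
# with T^2 it stays on a comparable scale.
print(kd_grad_norm(1.0, False), kd_grad_norm(8.0, False))
print(kd_grad_norm(1.0, True), kd_grad_norm(8.0, True))
```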

Response-Based Distillation: Knowledge Transfer at Output Layer

Response-based distillation is the most classic distillation method, using only the model's final layer outputs (logits or probabilities) for knowledge transfer.

Hinton's Original Distillation Algorithm

Algorithm Workflow:

  1. Train Teacher Model: Train a high-capacity model $f_T$ on the complete dataset

  2. Generate Soft Labels: For each sample $x_i$ in the training set, compute the teacher's soft label $p_i^T = \sigma(z_T(x_i); T)$

  3. Train Student Model: Minimize the joint loss $\mathcal{L} = \alpha T^2 \mathcal{L}_{KD} + (1 - \alpha) \mathcal{L}_{CE}$

  4. Inference Stage: The student model uses $T = 1$ (standard softmax)

Hyperparameter Selection:

  • Temperature $T$: typically between 2 and 10, task-dependent; classification usually uses $T = 4$
  • Balance coefficient $\alpha$: typically around 0.9; a higher $\alpha$ means more reliance on teacher knowledge
  • Student capacity: generally a small fraction of the teacher's parameter count

Experimental Observations (ImageNet experiments):

  • ResNet-34 teacher (73.3% accuracy) distilled into a ResNet-18 student
  • ResNet-18 trained directly: 69.8%
  • ResNet-18 trained with distillation: 71.4%
  • A 1.6% improvement, though still a 1.9% gap to the teacher

Why Temperature Parameter Works: Information-Theoretic Analysis

From an information theory perspective, the temperature $T$ controls the soft labels' information content. Define the conditional entropy $H(T) = \mathbb{E}_x\left[-\sum_i \sigma_i(z(x); T) \log \sigma_i(z(x); T)\right]$. It can be proven that $H(T)$ increases monotonically with $T$. Higher temperature means higher entropy — more uncertainty and richer information.

Specifically, when $T$ is large, the softmax can be Taylor expanded: $\sigma_i(z; T) \approx \frac{1 + z_i / T}{N + \sum_j z_j / T}$, where $N$ is the number of classes. For zero-mean logits this gives $\sigma_i \approx \frac{1}{N} + \frac{z_i}{N T}$. This shows that at high temperatures, relative differences in softmax outputs directly reflect relative differences in logits, unaffected by the exp function's non-linearity. Student models can thus more accurately learn the relative relationships between classes.

Connection Between Distillation and Label Smoothing

Label smoothing is a regularization technique that replaces the hard label $y$ with $y' = (1 - \epsilon) y + \frac{\epsilon}{K}$, where $\epsilon$ is the smoothing coefficient (typically 0.1) and $K$ is the number of classes.
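The smoothed target is a one-liner; a sketch for $K = 4$ classes and $\epsilon = 0.1$:

```python
def smooth_labels(one_hot, eps=0.1):
    """Replace a one-hot target y with (1 - eps) * y + eps / K."""
    K = len(one_hot)
    return [(1 - eps) * y + eps / K for y in one_hot]

# Every sample gets the SAME uniform mass, unlike distillation's
# per-sample teacher distributions.
print(smooth_labels([0, 1, 0, 0]))  # [0.025, 0.925, 0.025, 0.025]
```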

It can be proven that knowledge distillation is data-dependent label smoothing in a sense:

  • Label smoothing uses the same smoothed distribution (uniform) for all samples
  • Knowledge distillation uses different smoothed distributions (teacher's outputs) for each sample

Experiments show distillation typically outperforms label smoothing because teacher output distributions contain sample-specific information (e.g., some cat images look more like tigers).

Layer-wise Distillation: Multi-stage Knowledge Transfer

For very deep networks (like ResNet-152), distillation can be decomposed into multiple stages:

  1. Shallow Layer Distillation: Distill teacher's first few layers to student's first few layers
  2. Middle Layer Distillation: Distill teacher's middle layers to student's middle layers
  3. Deep Layer Distillation: Distill teacher's last few layers to student's last few layers

The loss function becomes a multi-term sum: $\mathcal{L} = \sum_k \lambda_k \mathcal{L}^{(k)}$, where $\mathcal{L}^{(k)}$ is the loss at the $k$-th distillation point and $\lambda_k$ is its weight.

Advantages: Finer-grained knowledge transfer, suitable when teacher and student architectures differ significantly.

Disadvantages: Requires manual design of distillation point locations and weights, expanding hyperparameter space.

Feature-Based Distillation: Knowledge Transfer at Intermediate Layers

Feature-based distillation utilizes not only output layers but also intermediate layer feature maps for knowledge transfer.

FitNets: Hint Learning

FitNets (Hints for Thin Deep Nets) was one of the earliest feature distillation methods, proposed by Romero et al. in 2015.

Core Idea: Have student's intermediate layer features match teacher's intermediate layer features.

Let the teacher's features at the chosen "hint" layer be $F_T \in \mathbb{R}^{C_T \times H \times W}$ and the student's features at the corresponding "guided" layer be $F_S \in \mathbb{R}^{C_S \times H \times W}$. Since the dimensions may differ, introduce a learnable projection layer $r(\cdot)$: $\mathcal{L}_{hint} = \| F_T - r(F_S) \|_F^2$, where $\|\cdot\|_F$ is the Frobenius norm.

Training Strategy (Two-stage):

  1. Stage 1: Freeze the teacher; train only the student's layers up to the guided layer plus the projection layer, minimizing $\mathcal{L}_{hint}$

  2. Stage 2: Fix those layers; train the student's remaining layers, minimizing the distillation loss $\mathcal{L}_{KD}$

Hint Layer Position Selection:

  • Shallow Hints: Student learns low-level features (edges, textures), suitable when student is very small
  • Deep Hints: Student learns high-level semantic features, suitable when student capacity approaches teacher

Experiments find that a single hint layer has limited effectiveness; multiple hint layers work better (but increase computational cost).

Attention Transfer: Attention Map Distillation

Zagoruyko and Komodakis proposed Attention Transfer (AT) in 2017, using activation statistics of feature maps as "attention maps" for distillation.

Activation-based Attention:

For a feature map $F \in \mathbb{R}^{C \times H \times W}$, define the attention map $A = \sum_{c=1}^{C} |F_c|^p$, where $p$ is typically 1 or 2. $A \in \mathbb{R}^{H \times W}$ represents the activation intensity at each spatial location.

The loss function is $\mathcal{L}_{AT} = \left\| \frac{\mathrm{vec}(A_S)}{\|\mathrm{vec}(A_S)\|_2} - \frac{\mathrm{vec}(A_T)}{\|\mathrm{vec}(A_T)\|_2} \right\|_2$; the normalization ensures scale invariance.

Gradient-based Attention:

Besides activations, gradients can also serve as attention: $A = \left| \frac{\partial \mathcal{L}}{\partial F} \right|$. Gradient attention reflects which locations contribute most to the loss, capturing the model's decision process.

Multi-layer Attention Transfer: $\mathcal{L} = \mathcal{L}_{CE} + \beta \sum_{l \in \mathcal{I}} \mathcal{L}_{AT}^{(l)}$, where $\mathcal{I}$ is the selected layer set and $\beta$ is the weight.

Experimental Results (CIFAR-10):

  • ResNet-110 teacher (93.5%) → ResNet-20 student
  • Baseline ResNet-20: 91.3%
  • Response-based distillation: 91.8%
  • Attention transfer: 92.4%

Attention transfer provides a 0.6% improvement over response-based distillation, demonstrating the effectiveness of intermediate-layer knowledge transfer.

PKT: Probabilistic Knowledge Transfer

Passalis and Tefas proposed Probabilistic Knowledge Transfer (PKT) in 2018, matching statistical properties of feature distributions rather than individual sample features.

Core Idea: Represent knowledge using pairwise sample similarities.

For a batch of $B$ samples, compute the feature similarity matrix $K_{ij} = k(f_i, f_j)$, where $k$ is a kernel function (like a Gaussian kernel), and row-normalize it into conditional probabilities $p_{ij} = K_{ij} / \sum_{j'} K_{ij'}$. The loss function is the KL divergence between the teacher's and student's similarity distributions: $\mathcal{L}_{PKT} = \sum_{i,j} p_{ij}^T \log \frac{p_{ij}^T}{p_{ij}^S}$. This matches the relational structure between sample pairs rather than individual sample feature values.

Advantages:

  • Insensitive to feature dimension differences (no projection layer needed)
  • Captures semantic relationships between samples

Disadvantages:

  • $O(B^2)$ computational complexity in the batch size $B$, so batches cannot be too large
  • Requires reasonable selection of kernel function and bandwidth

NST: Neural Style Transfer-Inspired Distillation

Huang and Wang proposed NST in 2017. Inspired by Neural Style Transfer, it uses Gram matrices for feature distillation.

For a feature map $F \in \mathbb{R}^{C \times H \times W}$, reshape it to $F \in \mathbb{R}^{C \times N}$ (where $N = HW$), and define the Gram matrix $G = F F^\top \in \mathbb{R}^{C \times C}$. The Gram matrix element $G_{ij}$ represents the correlation between channel $i$ and channel $j$. The loss function is $\mathcal{L}_{NST} = \| G_S - G_T \|_F^2$.

Intuition: Gram matrices capture second-order statistics (covariance) of features, reflecting relationships between different channels (e.g., co-occurrence patterns of "edge detector" and "texture detector").
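A sketch of this Gram-matrix loss (dividing by the number of spatial positions, a common normalization choice not fixed by the text):

```python
import torch

def gram_loss(feat_s, feat_t):
    """Match channel-correlation (Gram) matrices of two feature maps."""
    def gram(f):                                  # f: [B, C, H, W]
        b, c, h, w = f.shape
        f = f.reshape(b, c, h * w)                # C x N with N = H*W
        return f @ f.transpose(1, 2) / (h * w)    # [B, C, C]
    return torch.mean((gram(feat_s) - gram(feat_t)) ** 2)

torch.manual_seed(0)
f_t = torch.randn(2, 16, 8, 8)
print(gram_loss(f_t, f_t))                  # identical maps give zero loss
print(gram_loss(torch.randn(2, 16, 8, 8), f_t))
```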

Experiments: On CIFAR-100, NST shows further improvement over FitNets and AT (about 0.5%–1%).

Relation-Based Distillation: Transferring Inter-Sample Relations

Relation-based distillation considers not only individual sample outputs or features, but also relationships between samples.

RKD: Relational Knowledge Distillation

Park et al. proposed Relational Knowledge Distillation (RKD) in 2019, defining two types of relations:

Distance-wise Relation:

For a pair of samples $(x_i, x_j)$ with features $(f_i, f_j)$, define the normalized Euclidean distance $\psi_D(x_i, x_j) = \frac{1}{\mu} \| f_i - f_j \|_2$, where $\mu$ is the normalization factor (the mean pairwise distance in the batch).

The loss function is $\mathcal{L}_{RKD\text{-}D} = \sum_{(i,j) \in \mathcal{P}} \ell\left(\psi_D^S(x_i, x_j), \psi_D^T(x_i, x_j)\right)$, where $\mathcal{P}$ is the sampled pair set and $\ell$ is the Huber loss.

Angle-wise Relation:

For a triplet $(x_i, x_j, x_k)$, define the vector angle $\psi_A(x_i, x_j, x_k) = \cos \angle f_i f_j f_k = \langle e^{ij}, e^{kj} \rangle$, where $e^{ij} = \frac{f_i - f_j}{\| f_i - f_j \|_2}$.

The loss function is $\mathcal{L}_{RKD\text{-}A} = \sum_{(i,j,k)} \ell\left(\psi_A^S(x_i, x_j, x_k), \psi_A^T(x_i, x_j, x_k)\right)$.

Intuition:

  • Distance relation ensures relative distances between sample pairs remain consistent (e.g., "cat" and "dog" distance is smaller than "cat" and "car" distance)
  • Angle relation ensures relative positional relationships of samples (e.g., "Persian cat" direction relative to "cat" and "dog")

Total Loss: $\mathcal{L} = \mathcal{L}_{CE} + \lambda_D \mathcal{L}_{RKD\text{-}D} + \lambda_A \mathcal{L}_{RKD\text{-}A}$. Experiments show angle relations are more important than distance relations (typically $\lambda_A = 2\lambda_D$).

CRD: Contrastive Representation Distillation

Tian et al. proposed Contrastive Representation Distillation (CRD) in 2020, introducing contrastive learning into the distillation framework.

Core Idea: Use contrastive learning to maximize mutual information between student and teacher features.

For a positive pair $(f_S(x_i), f_T(x_i))$ (the same sample's representations in student and teacher) and $N$ negative samples $f_T(x_j), j \neq i$ (other samples), define the InfoNCE loss $\mathcal{L}_{CRD} = -\log \frac{\exp\left(f_S(x_i)^\top f_T(x_i) / \tau\right)}{\sum_j \exp\left(f_S(x_i)^\top f_T(x_j) / \tau\right)}$, where $\tau$ is a temperature hyperparameter.
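An in-batch sketch of this objective, using the other samples in the batch as negatives (CRD itself draws negatives from a memory buffer; `tau` is an assumed hyperparameter):

```python
import torch
import torch.nn.functional as F

def infonce_loss(f_s, f_t, tau=0.1):
    """Positive pair = (student_i, teacher_i); negatives = all other pairs."""
    f_s = F.normalize(f_s, dim=1)
    f_t = F.normalize(f_t, dim=1)
    logits = f_s @ f_t.t() / tau               # [B, B] similarity scores
    targets = torch.arange(f_s.size(0))        # positives on the diagonal
    return F.cross_entropy(logits, targets)

torch.manual_seed(0)
f_t = torch.randn(16, 128)
f_s = f_t + 0.05 * torch.randn_like(f_t)       # a student close to the teacher
print(infonce_loss(f_s, f_t))                  # small when pairs are aligned
print(infonce_loss(torch.randn(16, 128), f_t)) # large for unrelated features
```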

Key Difference:

  • Traditional distillation: Use MSE or KL divergence to match features
  • CRD: Use contrastive learning to match features, more focused on inter-sample discrimination

Experimental Results (CIFAR-100):

  • ResNet-32x4 teacher (79.4%) → ResNet-8x4 student
  • Response-based distillation: 73.3%
  • CRD: 75.5%

CRD is particularly effective on small student models (2%+ improvement).

SP: Similarity-Preserving Distillation

Tung and Mori proposed Similarity-Preserving Distillation (SP) in 2019, requiring student feature similarity matrices to match teacher's.

For a batch of $B$ samples with flattened features $F \in \mathbb{R}^{B \times D}$, define the similarity matrix $G$ by computing $\tilde{G} = F F^\top$ and L2-normalizing each row. The loss function is $\mathcal{L}_{SP} = \frac{1}{B^2} \| G_S - G_T \|_F^2$. Difference from PKT: SP uses inner-product (cosine-style) similarity, while PKT uses kernel similarity.
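A sketch of the SP loss, following the row-normalized $F F^\top$ construction above:

```python
import torch
import torch.nn.functional as F

def sp_loss(f_s, f_t):
    """Match row-normalized batch similarity matrices G = norm(F F^T)."""
    def sim(f):
        f = f.reshape(f.size(0), -1)           # flatten to [B, D]
        g = f @ f.t()                          # [B, B] inner products
        return F.normalize(g, p=2, dim=1)      # row-wise L2 normalization
    b = f_s.size(0)
    return torch.sum((sim(f_s) - sim(f_t)) ** 2) / (b * b)

torch.manual_seed(0)
f_t = torch.randn(8, 64)
print(sp_loss(f_t, f_t))                 # identical features give zero
print(sp_loss(torch.randn(8, 32), f_t))  # feature dimensions may differ
```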

Self-Distillation: Teacher-Free Knowledge Transfer

Self-distillation is a distillation method that doesn't require pre-trained teachers, where models learn knowledge from their own earlier versions or different branches.

Born-Again Networks: Iterative Self-Distillation

Furlanello et al. proposed Born-Again Networks (BAN) in 2018, improving model performance through iterative distillation.

Algorithm Workflow:

  1. Train 1st Generation Model: Standard training yields $M_1$

  2. Train 2nd Generation Model: Use $M_1$ as teacher to distill $M_2$ ($M_1$ and $M_2$ have the same architecture)

  3. Train 3rd Generation Model: Use $M_2$ as teacher to distill $M_3$

  4. Repeat until performance saturates

Surprising Finding: Even when teacher and student have identical architectures, distillation still improves performance!

Theoretical Explanation:

  • Distillation provides smoother supervision signals (soft labels), reducing overfitting
  • Iterative distillation is an implicit form of ensemble learning
  • Each generation explores different regions of the loss landscape

Experiments (CIFAR-100):

  • 1st generation DenseNet: 74.3%
  • 2nd generation (BAN): 75.2%
  • 3rd generation: 75.4%
  • 4th generation: 75.5% (saturation)

Deep Mutual Learning (DML)

Zhang et al. proposed Deep Mutual Learning (DML) in 2018, training multiple student models simultaneously with mutual supervision.

Algorithm Workflow:

For $K$ student models $\Theta_1, \dots, \Theta_K$, each model's loss contains two parts: $\mathcal{L}_k = \mathcal{L}_{CE}(y, p_k) + \frac{1}{K - 1} \sum_{l \neq k} D_{KL}(p_l \,\|\, p_k)$, where $p_k$ is model $k$'s output distribution.

Key Features:

  • No pre-trained teacher needed: All models train from scratch
  • Symmetry: Each model is both student and teacher
  • Online learning: Models learn each other's knowledge in real-time
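The per-model DML loss can be sketched as follows (detaching peers so each acts as a fixed teacher for the current model, one common implementation choice):

```python
import torch
import torch.nn.functional as F

def dml_loss(logits_list, k, labels):
    """DML loss for model k: CE + mean KL toward each peer's distribution."""
    ce = F.cross_entropy(logits_list[k], labels)
    log_p_k = F.log_softmax(logits_list[k], dim=1)
    kl = 0.0
    for l, z in enumerate(logits_list):
        if l == k:
            continue
        kl = kl + F.kl_div(log_p_k, F.softmax(z.detach(), dim=1),
                           reduction='batchmean')
    return ce + kl / (len(logits_list) - 1)

torch.manual_seed(0)
logits = [torch.randn(4, 10) for _ in range(3)]  # 3 peer models, batch of 4
labels = torch.tensor([0, 1, 2, 3])
print(dml_loss(logits, 0, labels))
```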

Theoretical Intuition:

  • Each model makes different errors during training
  • Mutual learning helps models avoid each other's errors, similar to ensemble learning
  • Eventually each model performs better than training alone

Experiments (CIFAR-100):

  • Single ResNet-32 training: 70.2%
  • 2 ResNet-32s mutual learning: 72.1%
  • 4 ResNet-32s mutual learning: 72.8%

Online Distillation

Online distillation aggregates knowledge from multiple student models into a virtual teacher, avoiding pre-trained teacher overhead.

ONE (Online Network Ensemble):

Lan et al. proposed ONE in 2018, using a weighted combination of multiple branches' logits as the teacher: $z_e = \sum_{m=1}^{M} g_m z_m$, where $g_m$ are learned gating weights. Each branch's loss is $\mathcal{L}_m = \mathcal{L}_{CE}(y, \sigma(z_m)) + T^2 D_{KL}\left(\sigma(z_e / T) \,\|\, \sigma(z_m / T)\right)$.

KDCL (Knowledge Distillation via Collaborative Learning):

Song and Chai proposed KDCL in 2018; in addition to branch-wise distillation, it also distills at different depths: $\mathcal{L} = \mathcal{L}_{CE} + \sum_{d \in \mathcal{D}} \lambda_d \mathcal{L}_{KD}^{(d)}$, where $\mathcal{D}$ is the selected depth set (e.g., every 4 layers).

Advantages:

  • A single training pass completes everything, saving time
  • At inference, any single branch can be used, or multiple branches ensembled

Synergistic Distillation with Quantization and Pruning

Knowledge distillation is often combined with quantization, pruning and other compression techniques to achieve higher compression ratios.

Quantization-Aware Distillation

Quantization maps floating-point parameters to low-bit integers (like 8-bit or 4-bit), but causes accuracy degradation. Distillation can alleviate this problem.

Algorithm Workflow:

  1. Train Full-Precision Teacher: Standard FP32 training
  2. Quantized Student Initialization: Quantize teacher parameters to INT8 as student initialization
  3. Distillation Fine-Tuning: Fine-tune quantized student using teacher's soft labels

Loss function: $\mathcal{L} = \alpha T^2 D_{KL}\left(\sigma(z_T / T) \,\|\, \sigma(z_Q / T)\right) + (1 - \alpha) \mathcal{L}_{CE}$, where $z_Q$ is the quantized student's output.

Quantization Details:

For a weight tensor $w$, the quantization formula is $\hat{w} = s \cdot \mathrm{round}(w / s)$, where $s = \frac{\max |w|}{2^{b-1} - 1}$ is the scaling factor and $b$ is the number of bits.
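The symmetric quantizer above in a few lines (per-tensor scaling, a simplifying assumption; real INT8 pipelines often use per-channel scales):

```python
import torch

def quantize(w, bits=8):
    """Symmetric uniform quantization: w_hat = s * round(w / s)."""
    qmax = 2 ** (bits - 1) - 1
    s = w.abs().max() / qmax            # scaling factor
    q = torch.round(w / s).clamp(-qmax, qmax)
    return q * s

torch.manual_seed(0)
w = torch.randn(64, 64)
print((quantize(w, 8) - w).abs().max())   # round-off shrinks with more bits
print((quantize(w, 4) - w).abs().max())
```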

Experiments (ResNet-18 on ImageNet):

  • FP32 baseline: 69.8%
  • INT8 quantization (no distillation): 68.5% (-1.3%)
  • INT8 quantization (with distillation): 69.2% (-0.6%)

Distillation reduces quantization loss by half.

Pruning-Aware Distillation

Pruning removes unimportant neurons or connections, distillation can help pruned models recover performance.

Algorithm Workflow:

  1. Train Full Model Teacher
  2. Structured Pruning: Remove channels or layers with low importance (e.g., judged by L1-norm)
  3. Distillation Recovery: Fine-tune pruned student using teacher soft labels

Importance Evaluation:

For channel $c$ of a convolutional layer with weights $W_c$, define the importance $I_c = \| W_c \|_1$. Remove the lowest-importance $p\%$ of channels (e.g., $p = 70$).
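The importance ranking can be sketched on a Conv2d weight tensor of shape `[out_channels, in_channels, k, k]` (names are illustrative):

```python
import torch

def channels_to_prune(conv_weight, prune_ratio=0.7):
    """Score output channels by L1 norm; return indices of the weakest ones."""
    importance = conv_weight.abs().sum(dim=(1, 2, 3))   # one score per channel
    n_prune = int(prune_ratio * conv_weight.size(0))
    return torch.argsort(importance)[:n_prune]

torch.manual_seed(0)
w = torch.randn(32, 16, 3, 3)
w[5] = 0.0                      # a channel with zero weights: clearly prunable
idx = channels_to_prune(w, 0.7)
print(len(idx), 5 in idx.tolist())
```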

Loss Function (multi-layer distillation): $\mathcal{L} = \mathcal{L}_{CE} + \alpha \mathcal{L}_{KD} + \beta \mathcal{L}_{feat}$, where $\mathcal{L}_{feat}$ is intermediate-layer feature distillation.

Experiments (VGG-16 on CIFAR-10):

  • Original VGG-16: 93.5% (14.7M parameters)
  • 70% pruning (no distillation): 92.1% (4.4M parameters)
  • 70% pruning (with distillation): 93.0% (4.4M parameters)

Distillation allows pruned models to nearly recover original performance.

NAS + Distillation: Neural Architecture Search and Distillation

Neural Architecture Search (NAS) can automatically find efficient student architectures, combined with distillation to further improve performance.

MetaDistiller:

Liu et al. proposed in 2020, using reinforcement learning to search for optimal distillation strategies:

  • Which layers to distill
  • The loss weight $\lambda_l$ for each layer
  • The temperature parameter $T$

The search space grows exponentially with the number of layers $L$ and the number of candidate weights $W$.

Use reinforcement learning (like PPO) to optimize search strategy, reward function is student accuracy on validation set.

Experiments: On CIFAR-100, the strategies found by MetaDistiller improve 1%–2% over manually designed strategies.

Complete Code Implementation: Multi-Strategy Knowledge Distillation

Below is a complete knowledge distillation implementation including response-based, feature-based, attention transfer and other methods.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
from typing import List, Tuple, Dict
import copy

# ============== Distillation Loss Functions ==============

class KLDivergenceLoss(nn.Module):
    """Response-based distillation: KL divergence loss"""
    def __init__(self, temperature: float = 4.0, alpha: float = 0.9):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        # Soft label distillation loss
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
        kd_loss = self.kl_div(student_soft, teacher_soft) * (self.temperature ** 2)

        # Hard label classification loss
        ce_loss = F.cross_entropy(student_logits, labels)

        # Weighted combination
        total_loss = self.alpha * kd_loss + (1 - self.alpha) * ce_loss
        return total_loss


class FeatureDistillationLoss(nn.Module):
    """Feature distillation: intermediate layer feature matching"""
    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # Projection layer: map student features to teacher feature space
        self.projector = nn.Conv2d(student_channels, teacher_channels,
                                   kernel_size=1, bias=False)

    def forward(self, student_feat: torch.Tensor, teacher_feat: torch.Tensor) -> torch.Tensor:
        # Project student features
        student_proj = self.projector(student_feat)

        # MSE loss
        loss = F.mse_loss(student_proj, teacher_feat)
        return loss


class AttentionTransferLoss(nn.Module):
    """Attention transfer: activation attention map distillation"""
    def __init__(self, p: float = 2.0):
        super().__init__()
        self.p = p

    def compute_attention_map(self, feature: torch.Tensor) -> torch.Tensor:
        """Compute attention map: L^p norm over channel dimension"""
        # feature: [B, C, H, W]
        attention = torch.sum(torch.abs(feature) ** self.p, dim=1, keepdim=True)
        # Normalize
        attention = attention / (torch.sum(attention, dim=[2, 3], keepdim=True) + 1e-8)
        return attention

    def forward(self, student_feat: torch.Tensor, teacher_feat: torch.Tensor) -> torch.Tensor:
        student_attn = self.compute_attention_map(student_feat)
        teacher_attn = self.compute_attention_map(teacher_feat)

        loss = F.mse_loss(student_attn, teacher_attn)
        return loss


class RelationalDistillationLoss(nn.Module):
    """Relational distillation: inter-sample distance and angle relations"""
    def __init__(self, lambda_distance: float = 1.0, lambda_angle: float = 2.0):
        super().__init__()
        self.lambda_distance = lambda_distance
        self.lambda_angle = lambda_angle

    def compute_distance_relation(self, features: torch.Tensor) -> torch.Tensor:
        """Compute normalized Euclidean distance between sample pairs"""
        # features: [B, D]
        feat_norm = features / (torch.norm(features, p=2, dim=1, keepdim=True) + 1e-8)
        distance_matrix = torch.cdist(feat_norm, feat_norm, p=2)
        return distance_matrix

    def compute_angle_relation(self, features: torch.Tensor,
                               indices: torch.Tensor) -> torch.Tensor:
        """Compute angle relations on a given subset of samples"""
        # features: [B, D]
        if features.size(0) < 3:
            return torch.tensor(0.0, device=features.device)

        # Normalize features
        feat_norm = features / (torch.norm(features, p=2, dim=1, keepdim=True) + 1e-8)

        # Compute cosine similarity matrix
        cos_sim = torch.mm(feat_norm, feat_norm.t())

        # Restrict to the sampled subset (simplified implementation)
        sampled_cos = cos_sim[indices][:, indices]
        return sampled_cos

    def forward(self, student_feat: torch.Tensor, teacher_feat: torch.Tensor) -> torch.Tensor:
        # Distance relation loss
        student_dist = self.compute_distance_relation(student_feat)
        teacher_dist = self.compute_distance_relation(teacher_feat)
        dist_loss = F.mse_loss(student_dist, teacher_dist)

        # Angle relation loss: sample the subset ONCE so the student and
        # teacher are compared on the same triplets
        B = student_feat.size(0)
        indices = torch.randperm(B, device=student_feat.device)[:min(B, 10)]
        student_angle = self.compute_angle_relation(student_feat, indices)
        teacher_angle = self.compute_angle_relation(teacher_feat, indices)
        angle_loss = F.mse_loss(student_angle, teacher_angle)

        total_loss = self.lambda_distance * dist_loss + self.lambda_angle * angle_loss
        return total_loss


# ============== Model Definitions ==============

class TeacherResNet(nn.Module):
    """Teacher model: ResNet-34"""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.model = torchvision.models.resnet34(pretrained=False)
        self.model.fc = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        # Extract intermediate features
        features = []

        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        features.append(x)  # Feature 1

        x = self.model.layer2(x)
        features.append(x)  # Feature 2

        x = self.model.layer3(x)
        features.append(x)  # Feature 3

        x = self.model.layer4(x)
        features.append(x)  # Feature 4

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.model.fc(x)

        return logits, features


class StudentResNet(nn.Module):
    """Student model: ResNet-18"""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.model = torchvision.models.resnet18(pretrained=False)
        self.model.fc = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        features = []

        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        features.append(x)

        x = self.model.layer2(x)
        features.append(x)

        x = self.model.layer3(x)
        features.append(x)

        x = self.model.layer4(x)
        features.append(x)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.model.fc(x)

        return logits, features


# ============== Distillation Trainer ==============

class DistillationTrainer:
    """Knowledge distillation trainer"""
    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        device: str = 'cuda',
        distill_type: str = 'response',  # 'response', 'feature', 'attention', 'relation', 'combined'
        temperature: float = 4.0,
        alpha: float = 0.9,
    ):
        self.teacher = teacher.to(device)
        self.student = student.to(device)
        self.device = device
        self.distill_type = distill_type

        # Freeze teacher model
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False

        # Initialize loss functions
        self.kd_loss = KLDivergenceLoss(temperature, alpha)

        if distill_type in ['feature', 'combined']:
            # Channel counts for ResNet-34 and ResNet-18
            teacher_channels = [64, 128, 256, 512]
            student_channels = [64, 128, 256, 512]
            self.feat_losses = nn.ModuleList([
                FeatureDistillationLoss(s_ch, t_ch).to(device)
                for s_ch, t_ch in zip(student_channels, teacher_channels)
            ])

        if distill_type in ['attention', 'combined']:
            self.attn_loss = AttentionTransferLoss()

        if distill_type in ['relation', 'combined']:
            self.rel_loss = RelationalDistillationLoss()

    def compute_loss(
        self,
        student_logits: torch.Tensor,
        student_features: List[torch.Tensor],
        teacher_logits: torch.Tensor,
        teacher_features: List[torch.Tensor],
        labels: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Compute total loss"""
        losses = {}

        # Response-based distillation loss
        kd_loss = self.kd_loss(student_logits, teacher_logits, labels)
        losses['kd'] = kd_loss
        total_loss = kd_loss

        # Feature distillation loss
        if self.distill_type in ['feature', 'combined']:
            feat_loss = 0
            for s_feat, t_feat, feat_loss_fn in zip(
                student_features, teacher_features, self.feat_losses
            ):
                feat_loss += feat_loss_fn(s_feat, t_feat)
            feat_loss /= len(student_features)
            losses['feature'] = feat_loss
            total_loss += 0.5 * feat_loss

        # Attention transfer loss
        if self.distill_type in ['attention', 'combined']:
            attn_loss = 0
            for s_feat, t_feat in zip(student_features, teacher_features):
                attn_loss += self.attn_loss(s_feat, t_feat)
            attn_loss /= len(student_features)
            losses['attention'] = attn_loss
            total_loss += 0.3 * attn_loss

        # Relational distillation loss (on last layer features)
        if self.distill_type in ['relation', 'combined']:
            s_feat_flat = torch.flatten(student_features[-1], 1)
            t_feat_flat = torch.flatten(teacher_features[-1], 1)
            rel_loss = self.rel_loss(s_feat_flat, t_feat_flat)
            losses['relation'] = rel_loss
            total_loss += 0.2 * rel_loss

        losses['total'] = total_loss
        return losses

    def train_epoch(
        self,
        train_loader: DataLoader,
        optimizer: optim.Optimizer,
        epoch: int,
    ) -> Dict[str, float]:
        """Train one epoch"""
        self.student.train()

        total_losses = {key: 0.0 for key in ['kd', 'feature', 'attention', 'relation', 'total']}
        correct = 0
        total = 0

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            # Forward pass
            with torch.no_grad():
                teacher_logits, teacher_features = self.teacher(inputs)

            student_logits, student_features = self.student(inputs)

            # Compute loss
            losses = self.compute_loss(
                student_logits, student_features,
                teacher_logits, teacher_features,
                labels
            )

            # Backward pass
            optimizer.zero_grad()
            losses['total'].backward()
            optimizer.step()

            # Statistics
            for key, value in losses.items():
                if key in total_losses:
                    total_losses[key] += value.item()

            _, predicted = student_logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if batch_idx % 50 == 0:
                print(f'Epoch {epoch} [{batch_idx}/{len(train_loader)}] '
                      f'Loss: {losses["total"]:.4f} '
                      f'Acc: {100. * correct / total:.2f}%')

        # Compute average loss
        for key in total_losses:
            total_losses[key] /= len(train_loader)

        accuracy = 100. * correct / total
        return {**total_losses, 'accuracy': accuracy}

    @torch.no_grad()
    def evaluate(self, test_loader: DataLoader) -> Tuple[float, float]:
        """Evaluate student model"""
        self.student.eval()

        correct = 0
        total = 0
        test_loss = 0

        for inputs, labels in test_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            logits, _ = self.student(inputs)
            loss = F.cross_entropy(logits, labels)

            test_loss += loss.item()
            _, predicted = logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        accuracy = 100. * correct / total
        avg_loss = test_loss / len(test_loader)

        return accuracy, avg_loss


# ============== Self-Distillation and Mutual Learning ==============

class SelfDistillationTrainer:
    """Self-distillation trainer: Born-Again Networks"""

    def __init__(
        self,
        model_class: type,
        num_classes: int = 10,
        device: str = 'cuda',
        temperature: float = 4.0,
    ):
        self.model_class = model_class
        self.num_classes = num_classes
        self.device = device
        self.temperature = temperature
        self.generations = []

    def train_generation(
        self,
        train_loader: DataLoader,
        test_loader: DataLoader,
        num_epochs: int = 10,
        teacher_model: nn.Module = None,
    ) -> nn.Module:
        """Train one generation model"""
        model = self.model_class(self.num_classes).to(self.device)
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        for epoch in range(num_epochs):
            model.train()
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                logits, _ = model(inputs)

                # Hard label loss
                ce_loss = F.cross_entropy(logits, labels)

                # If there's a teacher, add distillation loss
                if teacher_model is not None:
                    with torch.no_grad():
                        teacher_logits, _ = teacher_model(inputs)

                    student_soft = F.log_softmax(logits / self.temperature, dim=1)
                    teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
                    kd_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
                    kd_loss *= (self.temperature ** 2)

                    total_loss = 0.1 * ce_loss + 0.9 * kd_loss
                else:
                    total_loss = ce_loss

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

            scheduler.step()

            # Evaluate
            accuracy, _ = self.evaluate(model, test_loader)
            print(f'Generation {len(self.generations)} Epoch {epoch}: Acc = {accuracy:.2f}%')

        return model

    @torch.no_grad()
    def evaluate(self, model: nn.Module, test_loader: DataLoader) -> Tuple[float, float]:
        model.eval()
        correct = 0
        total = 0

        for inputs, labels in test_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            logits, _ = model(inputs)
            _, predicted = logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        accuracy = 100. * correct / total
        return accuracy, 0.0

    def train_multiple_generations(
        self,
        train_loader: DataLoader,
        test_loader: DataLoader,
        num_generations: int = 3,
        num_epochs_per_gen: int = 10,
    ) -> List[nn.Module]:
        """Train multiple generation models"""
        print("Training Generation 1 (no teacher)...")
        gen1 = self.train_generation(train_loader, test_loader, num_epochs_per_gen, teacher_model=None)
        self.generations.append(gen1)

        for i in range(2, num_generations + 1):
            print(f"\nTraining Generation {i} (teacher = Gen {i-1})...")
            teacher = self.generations[-1]
            teacher.eval()
            for param in teacher.parameters():
                param.requires_grad = False

            gen_i = self.train_generation(train_loader, test_loader, num_epochs_per_gen, teacher_model=teacher)
            self.generations.append(gen_i)

        return self.generations


# ============== Main Function ==============

def main():
    # Hyperparameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_epochs = 20
    batch_size = 128
    num_classes = 10

    # Data loading
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    # ========== Experiment 1: Standard Knowledge Distillation ==========
    print("\n" + "="*50)
    print("Experiment 1: Standard Knowledge Distillation (ResNet-34 -> ResNet-18)")
    print("="*50)

    # Train teacher model (or load pretrained)
    teacher = TeacherResNet(num_classes).to(device)
    print("Training teacher model...")
    # Teacher training code omitted here; assume a pretrained model is available
    # train_teacher(teacher, trainloader, testloader, num_epochs)

    # Distillation training for student model
    student = StudentResNet(num_classes).to(device)
    trainer = DistillationTrainer(
        teacher=teacher,
        student=student,
        device=device,
        distill_type='response',  # Response-based distillation
        temperature=4.0,
        alpha=0.9,
    )

    optimizer = optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_acc = 0
    for epoch in range(num_epochs):
        train_metrics = trainer.train_epoch(trainloader, optimizer, epoch)
        test_acc, test_loss = trainer.evaluate(testloader)
        scheduler.step()

        print(f'Epoch {epoch}: Train Acc = {train_metrics["accuracy"]:.2f}%, '
              f'Test Acc = {test_acc:.2f}%, Test Loss = {test_loss:.4f}')

        if test_acc > best_acc:
            best_acc = test_acc

    print(f'Best Test Accuracy: {best_acc:.2f}%')

    # ========== Experiment 2: Combined Distillation ==========
    print("\n" + "="*50)
    print("Experiment 2: Combined Distillation (Response + Feature + Attention + Relation)")
    print("="*50)

    student_combined = StudentResNet(num_classes).to(device)
    trainer_combined = DistillationTrainer(
        teacher=teacher,
        student=student_combined,
        device=device,
        distill_type='combined',
        temperature=4.0,
        alpha=0.7,  # Lower alpha to balance multiple losses
    )

    optimizer = optim.SGD(student_combined.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_acc_combined = 0
    for epoch in range(num_epochs):
        train_metrics = trainer_combined.train_epoch(trainloader, optimizer, epoch)
        test_acc, test_loss = trainer_combined.evaluate(testloader)
        scheduler.step()

        print(f'Epoch {epoch}: Test Acc = {test_acc:.2f}%')

        if test_acc > best_acc_combined:
            best_acc_combined = test_acc

    print(f'Best Test Accuracy (Combined): {best_acc_combined:.2f}%')

    # ========== Experiment 3: Self-Distillation ==========
    print("\n" + "="*50)
    print("Experiment 3: Self-Distillation (Born-Again Networks)")
    print("="*50)

    self_distiller = SelfDistillationTrainer(
        model_class=StudentResNet,
        num_classes=num_classes,
        device=device,
        temperature=4.0,
    )

    generations = self_distiller.train_multiple_generations(
        trainloader, testloader,
        num_generations=3,
        num_epochs_per_gen=10,
    )

    for i, model in enumerate(generations):
        acc, _ = self_distiller.evaluate(model, testloader)
        print(f'Generation {i+1} Test Accuracy: {acc:.2f}%')


if __name__ == '__main__':
    main()

Code Explanation

  1. Loss Function Modules:
    • KLDivergenceLoss: Response-based distillation, includes the temperature parameter and T² scaling
    • FeatureDistillationLoss: FitNets-style feature matching with projection layer
    • AttentionTransferLoss: Computes and matches activation attention maps
    • RelationalDistillationLoss: RKD-style distance and angle relations
  2. Model Definitions:
    • Both TeacherResNet and StudentResNet return logits and intermediate features
    • Feature extraction points are after each block
  3. Trainer:
    • DistillationTrainer supports multiple distillation strategies (response/feature/attention/relation/combined)
    • Automatically freezes teacher model
    • compute_loss method flexibly combines multiple losses
  4. Self-Distillation:
    • SelfDistillationTrainer implements Born-Again Networks
    • Iteratively trains multiple generation models, each using previous generation as teacher

Comprehensive Q&A

Q1: How to choose temperature parameter?

A: Temperature parameter selection depends on task and data:

  • Classification tasks: common values lie roughly in the range 2-10; T = 4 (used in the code above) is a typical starting point
  • Regression tasks: temperature has less effect and can be skipped or set low
  • Tuning strategy: grid search for T on a validation set. Principle: temperature must balance two factors — if T is too small, soft labels degenerate to hard labels and dark knowledge is lost; if T is too large, all class probabilities approach uniform and the signal weakens

Empirically, more classes and higher class similarity require higher temperature.
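The trade-off above can be seen directly by softening one set of logits at several temperatures — a minimal sketch (the logits are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn.functional as F

# How the temperature T reshapes a softmax distribution: higher T flattens
# the distribution and exposes relative similarity between non-target classes.
logits = torch.tensor([6.0, 2.0, 1.0, -1.0])

for T in [1.0, 4.0, 20.0]:
    probs = F.softmax(logits / T, dim=0)
    print(f'T={T:>4}: {[round(p, 3) for p in probs.tolist()]}')
```

At T = 1 nearly all probability mass sits on the top class; at T = 20 the distribution approaches uniform, which is exactly the "signal weakens" failure mode described above.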

Q2: How to set the balance coefficient α?

A: α controls the weight between the distillation loss and the classification loss:

  • High α (like 0.9): more reliance on teacher knowledge, suitable when the teacher is strong and data is scarce
  • Low α (like 0.5): more reliance on hard labels, suitable when teacher and student capacities are similar

Tuning Advice:
  • If student capacity is below 1/10 of the teacher's: use a high α (around 0.9)
  • If student capacity is 1/4-1/2 of the teacher's: use a moderate α (around 0.5)
  • If using multiple distillation losses (combined): lower α to 0.5-0.7
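As a concrete sketch of this α-weighting (the function name is ours; the T and α defaults mirror the trainer code above):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9):
    # alpha * T^2 * KL(soft student || soft teacher) + (1 - alpha) * CE(hard labels)
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
    ) * (T ** 2)  # T^2 keeps gradient magnitudes comparable across temperatures
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1 - alpha) * ce

# Usage on random logits
torch.manual_seed(0)
s = torch.randn(8, 10)
t = torch.randn(8, 10)
y = torch.randint(0, 10, (8,))
loss = distillation_loss(s, t, y, T=4.0, alpha=0.9)
```

Setting α = 0 recovers plain cross-entropy training; α = 1 trains on teacher soft labels only.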

Q3: How small should the student model be?

A: Student capacity depends on deployment constraints and performance requirements:

  • Mobile: Typically compress to 1/10 of teacher parameters, accept 2-5% accuracy loss
  • Edge devices: Compress to 1/20-1/50, may lose 5-10%
  • Server optimization: Compress to 1/2-1/4, loss <1%

Important Finding: Distillation effectiveness is especially significant when student is very small. When student capacity approaches teacher, distillation returns diminish.

Q4: Why is distillation particularly effective for small models?

A: Several theoretical explanations:

  1. Knowledge Compression Under Capacity Constraints: small models cannot fit all training data; soft labels signal which knowledge is most important
  2. Regularization Effect: Soft labels have higher entropy, preventing small models from overfitting on limited data
  3. Smooth Optimization Landscape: Soft label gradients are smoother, helping small models find better local optima

Experimental Evidence: When student capacity is extremely small (parameter count <1% of teacher), distillation can bring 10-20% relative improvement.

Q5: How to select layers for feature distillation?

A: Layer selection affects distillation effectiveness:

  • Shallow Layers: Low-level features (edges, textures), useful for all tasks but limited information
  • Deep Layers: High-level semantic features, strong task relevance but may overfit

Recommended Strategy:
  • Choose the teacher's middle layers (like ResNet's layer2 and layer3)
  • Avoid the first layer (information too basic) and the last layer (already covered by response-based distillation)
  • For multi-layer distillation, weights should increase progressively (deeper layers get higher weights)

Automated Methods: Use NAS or reinforcement learning to search optimal layer combinations (like MetaDistiller).
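The progressively increasing layer weights can be sketched as follows (the feature shapes and the linear weighting scheme are illustrative; in practice a projection layer is needed when student and teacher channel counts differ):

```python
import torch
import torch.nn.functional as F

def multilayer_feature_loss(student_feats, teacher_feats):
    # Deeper layers receive progressively higher weights, e.g. [1/3, 2/3, 1]
    n = len(student_feats)
    weights = [(i + 1) / n for i in range(n)]
    loss = torch.tensor(0.0)
    for w, s, t in zip(weights, student_feats, teacher_feats):
        loss = loss + w * F.mse_loss(s, t)
    return loss

# Three fake feature maps standing in for outputs of three blocks
s_feats = [torch.randn(2, 16, 8, 8) for _ in range(3)]
t_feats = [f + 0.1 * torch.randn_like(f) for f in s_feats]
loss = multilayer_feature_loss(s_feats, t_feats)
```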

Q6: Why is self-distillation effective?

A: Self-distillation seems paradoxical: the student and teacher share the same architecture, so how can distillation improve performance?

Explanations:
  1. Regularization: soft labels provide smoother supervision signals, reducing overfitting
  2. Ensemble Effect: each generation explores different regions of the loss landscape, yielding an implicit ensemble
  3. Dark Knowledge Refinement: even with the same architecture, the teacher has learned dark knowledge such as class relationships

Experimental Evidence:
  • Born-Again Networks improve 1-2% on CIFAR-100
  • The improvement is more significant when the data volume is smaller

Q7: How to combine distillation with pruning/quantization?

A: Distillation can significantly alleviate performance loss from pruning and quantization:

Pruning + Distillation:
  1. Train a full model as the teacher
  2. Prune the teacher to get the initial student
  3. Fine-tune the student with the teacher's soft labels
  4. Effect: typically recovers 50-80% of the pruning loss

Quantization + Distillation:
  1. Train the teacher in FP32
  2. Initialize a quantized student (INT8 or INT4)
  3. Distillation fine-tuning of the quantized student
  4. Effect: INT8 is nearly lossless, INT4 loss <1%

Simultaneous Application: pruning first, then quantizing, with distillation throughout, can achieve a 10-20x compression ratio.
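Steps 1-3 of the pruning pipeline can be sketched with torch.nn.utils.prune (the tiny model and 50% sparsity are illustrative stand-ins; step 3's fine-tuning with soft labels is indicated but not run here):

```python
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Step 1: a trained full model serves as the teacher (untrained stand-in here)
teacher = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

# Step 2: prune a copy of the teacher to obtain the initial student
student = copy.deepcopy(teacher)
for module in student.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.5)  # 50% sparsity
        prune.remove(module, 'weight')  # make the pruning permanent

sparsity = (student[0].weight == 0).float().mean().item()
print(f'student layer-0 sparsity: {sparsity:.2f}')

# Step 3 (not run here): fine-tune the pruned student against the teacher's
# soft labels with the usual alpha-weighted KD loss to recover accuracy.
```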

Q8: Are there differences in distillation between NLP and CV?

A: Distillation core principles are same, but specific implementations differ:

CV Characteristics:
  • Feature maps have spatial structure (2D), so attention maps, Gram matrices, etc. can be used
  • Feature distillation is typically performed at multiple convolutional layers
  • Data augmentation (like MixUp) can further improve distillation effectiveness

NLP Characteristics:
  • Features are sequences (1D); use sequence alignment or pooling methods
  • Intermediate-layer distillation of BERT-style models (like DistilBERT) is very effective
  • Pre-training + distillation is the mainstream paradigm (first pre-train a large model, then distill it into a small one)

Commonality: response-based distillation is effective in both domains and serves as the baseline method.

Q9: Can we use multiple teacher models?

A: Yes, called Multi-Teacher Distillation:

Average Ensemble: the student learns the average of the teachers' softened distributions, p_ens = (1/K) Σ_k p_k.

Weighted Ensemble: p_ens = Σ_k w_k p_k, where the weights w_k can be fixed (e.g., by teacher accuracy) or learnable.

Advantages: Ensemble knowledge from multiple teachers, more robust.

Disadvantages: Need to train multiple teachers, high cost.
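Both ensembles amount to a weighted sum of the teachers' softened distributions — a minimal sketch (the function name is ours; random logits stand in for real teachers):

```python
import torch
import torch.nn.functional as F

def multi_teacher_soft_targets(teacher_logits_list, weights=None, T=4.0):
    # Soften each teacher's logits, then combine;
    # weights=None gives the average ensemble
    probs = [F.softmax(logits / T, dim=1) for logits in teacher_logits_list]
    if weights is None:
        weights = [1.0 / len(probs)] * len(probs)
    return sum(w * p for w, p in zip(weights, probs))

logits_a = torch.randn(4, 10)
logits_b = torch.randn(4, 10)
avg_target = multi_teacher_soft_targets([logits_a, logits_b])
weighted_target = multi_teacher_soft_targets([logits_a, logits_b], weights=[0.7, 0.3])
```

As long as the weights sum to 1, the combined target remains a valid probability distribution the student can match with KL divergence.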

Q10: How effective is distillation on small datasets?

A: Distillation especially effective on small datasets:

Reasons:
  • Small datasets are prone to overfitting; soft labels provide strong regularization
  • Teacher models pre-trained on larger datasets (like ImageNet) transfer prior knowledge

Experiments (medical image classification, 1000 training images):
  • Train a small model from scratch: 65% accuracy
  • Fine-tune a pretrained model with hard labels: 72%
  • Distill the small model from a large model: 75%

Distillation provides a 3% improvement over direct fine-tuning, demonstrating the value of soft labels on small data.

Q11: What is the computational overhead of distillation?

A: Additional overhead from distillation includes:

  1. Teacher Inference: Training requires teacher forward pass, increases computation by about 50%
  2. Feature Storage (feature distillation): Need to store intermediate features, increases memory
  3. Multiple Loss Computations: Additional KL divergence, MSE, etc., very small overhead

Optimization Strategies:
  • Offline Distillation: pre-compute and save the teacher's soft labels, then load them directly during training (saves teacher inference)
  • Online Distillation: update the teacher dynamically, but at higher computational cost
  • Selective Distillation: only distill on difficult samples

Inference Stage: Student model deployed independently, no additional overhead.
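Offline distillation boils down to a one-time teacher pass whose logits are cached to disk — a minimal sketch (the model, data shapes, and cache path are illustrative stand-ins):

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.nn.functional as F

teacher = nn.Linear(32, 10)     # stand-in for a trained teacher
dataset = torch.randn(100, 32)  # stand-in for the training set

# One-time teacher pass: compute and cache the logits
with torch.no_grad():
    cached_logits = teacher(dataset)
cache_path = os.path.join(tempfile.gettempdir(), 'teacher_logits.pt')
torch.save(cached_logits, cache_path)

# During student training: load the cache instead of running the teacher
soft_targets = F.softmax(torch.load(cache_path) / 4.0, dim=1)
print(soft_targets.shape)  # one soft distribution per cached sample
```

Caching raw logits rather than softened probabilities lets the temperature be varied later without rerunning the teacher.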

Q12: Relationship between distillation and transfer learning?

A: Distillation is a special type of transfer learning:

Commonalities:
  • Both transfer knowledge from one model (source) to another (target)
  • Both leverage prior knowledge to reduce the target task's data requirements

Differences:
  • Transfer Learning: typically changes the task (like ImageNet → medical imaging)
  • Distillation: typically keeps the task and changes the model capacity

Combination: Cross-task distillation simultaneously changes tasks and capacity, an active research direction.

Q13: How does distillation relate to knowledge transfer in humans?

A: Distillation mimics human knowledge transfer process:

Human Learning:
  • Experts teach students not just "correct answers" but also "thinking processes"
  • Students learn from experts' confidence levels and the similarities between concepts
  • Learning is iterative: students may become teachers to the next generation

Distillation Analogy:
  • Soft labels are like experts' "confidence" and the "similarities between concepts"
  • The temperature parameter controls knowledge granularity
  • Self-distillation is like iterative refinement of knowledge

This analogy inspires new distillation methods, like meta-knowledge distillation, curriculum distillation, etc.

Q14: Can distillation improve robustness?

A: Distillation can improve model robustness to some extent:

Adversarial Robustness:
  • Soft labels provide smoother supervision signals, reducing sensitivity to adversarial perturbations
  • Experiments show distilled models have slightly higher adversarial accuracy
  • Combining adversarial training with distillation further improves robustness

Distribution Shift Robustness:
  • Teachers trained on broader data distributions can transfer robustness to students
  • Cross-domain distillation helps student models generalize to different distributions

Noise Robustness:
  • The regularizing effect of soft labels makes models less sensitive to label noise
  • Particularly effective on small, noisy datasets

Q15: What is the future direction of distillation research?

A: Several promising research directions:

  1. Automated Distillation Strategy Search:
    • Use NAS/RL to automatically find optimal distillation layers, weights, temperatures
    • Reduce manual tuning effort
  2. Cross-Modal Distillation:
    • Transfer knowledge from one modality to another (like vision → language)
    • Multimodal large model distillation
  3. Few-Shot Distillation:
    • How to effectively distill with very few samples
    • Combine meta-learning and distillation
  4. Self-Supervised Distillation:
    • Use self-supervised learning objectives for distillation
    • Reduce reliance on labeled data
  5. Lifelong Distillation:
    • Continually distill knowledge as new tasks arrive
    • Avoid catastrophic forgetting

Classic Papers

  1. Hinton et al., "Distilling the Knowledge in a Neural Network", NIPS 2014 Workshop
    • Proposed knowledge distillation, temperature parameter, soft label concepts
    • Laid foundation for distillation research
    • arXiv:1503.02531
  2. Romero et al., "FitNets: Hints for Thin Deep Nets", ICLR 2015
    • First proposed feature-based distillation
    • Introduced hint learning concept
    • Two-stage training strategy
    • arXiv:1412.6550
  3. Zagoruyko & Komodakis, "Paying More Attention to Attention", ICLR 2017
    • Proposed Attention Transfer
    • Activation attention and gradient attention
    • Validated effectiveness on multiple datasets
    • arXiv:1612.03928
  4. Tung & Mori, "Similarity-Preserving Knowledge Distillation", ICCV 2019
    • Similarity-Preserving distillation (SP)
    • Transfer of sample pair relations
    • Theoretical analysis of similarity preservation importance
    • arXiv:1907.09682

Relational Distillation

  1. Park et al., "Relational Knowledge Distillation", CVPR 2019
    • Relational Knowledge Distillation (RKD)
    • Distance relations and angle relations
    • Triplet sampling strategy
    • arXiv:1904.05068
  2. Tian et al., "Contrastive Representation Distillation", ICLR 2020
    • Contrastive Representation Distillation (CRD)
    • Use contrastive learning framework for distillation
    • Maximize student-teacher feature mutual information
    • arXiv:1910.10699

Self-Distillation and Mutual Learning

  1. Furlanello et al., "Born-Again Neural Networks", ICML 2018
    • Self-distillation
    • Iterative distillation improves same-architecture models
    • Theoretical analysis of why self-distillation works
    • arXiv:1805.04770
  2. Zhang et al., "Deep Mutual Learning", CVPR 2018
    • Mutual Learning
    • Multiple students train simultaneously, mutually supervise
    • No pre-trained teacher needed
    • arXiv:1706.00384

Distillation in NLP

  1. Sanh et al., "DistilBERT, a distilled version of BERT", NeurIPS 2019 Workshop
    • Distill BERT-base to smaller model
    • Retains 97% of BERT's performance with 40% fewer parameters
    • Widely applied in industry
    • arXiv:1910.01108
  2. Jiao et al., "TinyBERT", Findings of EMNLP 2020
    • Two-stage distillation: pre-training distillation + task distillation
    • Comprehensive distillation of embedding layer, attention, hidden layers
    • Achieves 7.5x compression ratio
    • arXiv:1909.10351

Quantization and Distillation

  1. Mishra & Marr, "Apprentice: Using Knowledge Distillation Techniques To Improve Low-Precision Network Accuracy", ICLR 2018
    • Quantization-aware distillation
    • FP32 teacher helps INT8 student
    • Alleviates accuracy loss from quantization
    • arXiv:1711.05852
  2. Liu et al., "MetaDistiller: Network Self-Boosting via Meta-Learned Top-Down Distillation", ECCV 2020
    • Use NAS to search distillation strategies
    • Automatically select distillation layers and weights
    • Achieves SOTA on multiple tasks
    • arXiv:1910.03444

Summary

Knowledge distillation is a simple yet powerful idea: have small models learn large models' "thinking patterns" rather than simply imitating outputs. Through soft labels, temperature parameters, feature matching, relation preservation and other techniques, distillation can significantly reduce model size while maintaining performance close to original models. From Hinton's pioneering work to recent methods like CRD and TinyBERT, distillation techniques continue to evolve, becoming a core tool for model compression and transfer learning. Whether for mobile deployment, edge computing, or democratization of large models, knowledge distillation will play a crucial role.

  • Post title:Transfer Learning (5): Knowledge Distillation
  • Post author:Chen Kai
  • Create time:2024-11-27 09:30:00
  • Post link:https://www.chenk.top/transfer-learning-5-knowledge-distillation/
  • Copyright Notice:All articles in this blog are licensed under BY-NC-SA unless stating additionally.