迁移学习(四)—— Few-Shot Learning
Chen Kai BOSS

Few-Shot Learning(小样本学习)是机器学习中最具挑战性的问题之一。人类可以从极少样本中快速学习新概念:看过几张图片就能识别新物种,听过几个例子就能理解新语言。但传统深度学习模型需要大量标注数据才能训练,在数据稀缺场景下表现糟糕。

Few-Shot Learning 的目标是:从每类只有少量样本(通常 1-10 个)的情况下学习分类器。这需要模型具备强大的泛化能力和迁移能力,从已知类别中学习"如何学习"的能力,然后快速适应新类别。本文将从第一性原理出发,推导度量学习和元学习的数学基础,详解 Siamese 网络、 Prototypical 网络、 MAML 等经典方法,并提供完整的 Prototypical 网络实现。

Few-Shot Learning 的挑战

问题定义

Few-Shot Learning 通常采用N-way K-shot设定:

  • N-way:有 个类别需要分类
  • K-shot:每个类别只有 个标注样本

例如 5-way 1-shot 表示从 5 个类别中识别,每个类别只有 1 个训练样本。

形式化地,设:

  • 支持集( Support Set) :训练样本
  • 查询集( Query Set) :测试样本

目标是训练一个模型 ,使得在支持集 上训练后,能在查询集 上取得高准确率。

为什么困难?

  1. 数据稀缺 个样本远不足以学习一个复杂分类器
  2. 过拟合风险:模型容易记住支持集的具体样本,而非学到可泛化的特征
  3. 类间相似:新类别可能与已知类别非常相似,难以区分

传统方法的失败

标准的经验风险最小化( ERM):

很小时会严重过拟合。即使加上正则化:

仍然不够,因为正则化只能防止参数过大,无法提供足够的归纳偏置( inductive bias)。

Few-Shot Learning 的核心思想

要在少量样本下学习,需要利用先验知识。 Few-Shot Learning 的核心是:

  1. 从已知类别学习先验:在大量已知类别( base classes)上训练
  2. 快速适应新类别:用学到的先验在新类别( novel classes)上快速适应

这等价于学习一个元学习器( meta-learner)

度量学习:基于相似度的分类

度量学习( Metric Learning)的思想是:学习一个嵌入空间,使得同类样本距离近、异类样本距离远。分类时,将查询样本与支持集样本比较距离,选择最近的类别。

Siamese 网络:孪生网络

Siamese 网络是最早的度量学习方法之一,通过对比损失( contrastive loss)学习嵌入空间。

架构

Siamese 网络包含两个权重共享的编码器

然后计算嵌入之间的距离:

$$

d(x_1, x_2) = |z_1 - z_2|_2 $$

对比损失

对比损失( Contrastive Loss)定义为:

$$

L = y d^2 + (1 - y) (0, m - d)^2 $$

其中:

  • :正样本对(同类),损失为 ,希望距离小
  • :负样本对(异类),损失为 ,希望距离大于 margin

直觉解释

  • 正样本对:拉近距离
  • 负样本对:如果距离小于 ,推开至少 的距离;如果已经大于 ,不再惩罚

Few-Shot 分类

给定支持集Extra close brace or missing open brace\mathcal{S} = \{(x_i, y_i)} _{i=1}^{NK} 和查询样本,预测为:

即选择支持集中距离最近的样本的类别。

Prototypical 网络:原型网络

Prototypical 网络是度量学习的改进版本,通过学习类别原型( prototype)来分类。

类别原型

给定类别 的支持集样本,类别原型定义为支持集样本嵌入的均值:

$$

p_c = _{x_i c} f(x_i) $$

直觉:原型是该类别在嵌入空间中的"中心",代表该类别的典型特征。

距离度量

Prototypical 网络使用欧氏距离度量查询样本与原型的距离:

$$

d(x_q, p_c) = |f_(x_q) - p_c|_2^2 $$

也可以使用余弦距离:

$$

d_{} (x_q, p_c) = 1 - $$

分类与损失

分类概率通过 softmax 计算:

$$

P(y = c | x_q) = $$

损失函数为负对数似然:

$$

L = -P(y = y_q | x_q) $$

Prototypical 网络的理论

Prototypical 网络可以看作是最近质心分类器( Nearest Centroid Classifier)在嵌入空间中的实现。在线性可分的情况下, Prototypical 网络等价于线性分类器

定理:在嵌入空间中,如果类别原型线性可分,则 Prototypical 网络的决策边界是线性的。

证明:查询样本 属于类别 的充要条件是:

$$

d(x_q, p_c) < d(x_q, p_{c'}), c' c $$

即:

展开:

简化:

这是 的线性不等式,决策边界是超平面。

匹配网络( Matching Networks)

匹配网络引入注意力机制记忆增强,进一步提升 Few-Shot Learning 性能。

注意力核

匹配网络使用注意力核( attention kernel)计算查询样本与支持集样本的相似度:

$$

a(x_q, x_i) = $$

其中 分别是查询集和支持集的编码器(可以不同)。

预测

查询样本的类别预测为支持集标签的加权和:

直觉:与查询样本相似度高的支持集样本对预测贡献更大。

Full Context Embeddings

匹配网络使用双向 LSTM对支持集进行编码,使每个样本的嵌入包含整个支持集的上下文信息:

$$

g(x_i) = ({x_1, , x_{NK}} , i) $$

这让模型能考虑支持集样本之间的关系。

关系网络( Relation Networks)

关系网络不使用固定的距离度量(如欧氏距离),而是学习一个度量函数

架构

关系网络包含两个模块:

  1. 嵌入模块 Missing superscript or subscript argument f_:将样本映射到嵌入空间
  2. 关系模块 Missing superscript or subscript argument g_:学习嵌入之间的相似度

给定查询样本 和支持集样本,计算:

$$

r_{q,i} = g_((f_(x_q), f_(x_i))) $$

其中 是学到的相似度。

损失函数

关系网络使用 MSE 损失:

$$

L = {(x_q, y_q) } {c=1}^N (r_{q,c} - _{y_q = c})^2 $$

其中 是查询样本与类别 的原型的相似度。

为什么学习度量?

固定距离(如欧氏距离)假设嵌入空间是各向同性的,但实际上不同维度可能有不同重要性。学习度量可以自适应地调整距离计算。

元学习:学会学习

元学习( Meta-Learning)的核心思想是:在多个任务上学习如何快速适应新任务

元学习的形式化

设有 个训练任务${_1, , _T} _i^{} _i^{} $。

元学习的目标是学习一个元参数,使得对任意新任务,用$^{} ^{} $ 上表现好:

其中 是在任务 上用适配后的参数:

MAML:模型无关元学习

Model-Agnostic Meta-Learning (MAML) 是最经典的元学习算法,通过学习一个好的初始化参数,使得模型能快速适应新任务。

MAML 算法

给定任务分布, MAML 优化:

即: 1. 在任务 的训练集上做一步(或多步)梯度下降: 2. 在任务 的测试集上计算损失: $$

L_{} ^{} ({} ') - _{} [L_{} ^{} (_{} ')] $$

MAML 的梯度计算

MAML 的关键是计算二阶梯度:

使用链式法则:

其中:

因此:

其中 是 Hessian 矩阵。

计算复杂度:计算 Hessian 需要 时间和空间, 是参数维度。实践中可以用一阶近似( First-Order MAML, FOMAML)

忽略 Hessian 项,计算复杂度降为

MAML 的直觉

MAML 学习的 位于损失曲面的"平坦"区域,使得沿任意方向(任意任务)的梯度下降都能快速降低损失。

类比 是一个"万能起点",从这个起点出发,只需几步就能到达任意任务的最优解。

Reptile:一阶元学习

Reptile 是 MAML 的简化版本,只使用一阶梯度,计算更高效。

Reptile 算法

  1. 采样任务$ k$ 步 SGD: 3. 更新元参数:

直觉: Reptile 将元参数朝任务特定参数移动。多次迭代后, 会位于所有任务特定参数的"中心"。

Reptile vs MAML

方法 梯度阶数 计算复杂度 性能
MAML 二阶 高(需要 Hessian) 最优
FOMAML 一阶(近似) 中等 接近 MAML
Reptile 一阶 略逊于 MAML

Reptile 在实践中与 FOMAML 性能相近,但实现更简单。

元学习的理论

元学习可以从贝叶斯视角理解。设任务参数 服从先验分布,则 MAML 等价于最大化后验:

其中:

$$

p({} ^{} | {} ^{} , ) = p({} ^{} | ) p(| {} ^{} , ) d $$ 是先验参数, 是先验分布。元学习学习一个好的先验。

Episode 训练:模拟 Few-Shot 场景

Few-Shot Learning 的训练采用episode 训练( episodic training),每个 episode 模拟一个 Few-Shot 任务。

Episode 采样

每个 episode 包含: 1. 从 base classes 中随机采样 个类别 2. 从每个类别中随机采样 个样本作为支持集 3. 从每个类别中随机采样 个样本作为查询集

形式化地,一个 episode 为:

其中:

Extra close brace or missing open brace\begin{aligned} \mathcal{S} &= \{(x_i^{(c,k)}, c) : c \in \{1, \ldots, N} , k \in \{1, \ldots, K} } \\ \mathcal{Q} &= \{(x_j^{(c,q)}, c) : c \in \{1, \ldots, N} , q \in \{1, \ldots, Q} } \end{aligned}

Episode 训练流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
for epoch in range(num_epochs):
for episode in range(episodes_per_epoch):
# 采样 episode
classes = sample(base_classes, N)
support = sample_from_classes(classes, K)
query = sample_from_classes(classes, Q)

# 前向传播
prototypes = compute_prototypes(support)
logits = compute_distances(query, prototypes)

# 计算损失
loss = cross_entropy(logits, query_labels)

# 反向传播
loss.backward()
optimizer.step()

Episode 训练的直觉

Episode 训练让模型在训练时就面临 Few-Shot 场景,强迫模型学习如何从少量样本中泛化。这是一种课程学习( curriculum learning):训练时的困难度与测试时相同。

完整实现: Prototypical 网络

下面提供一个完整的 Prototypical 网络实现,包含 episode 采样、距离计算、支持集与查询集划分等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score


class ConvBlock(nn.Module):
"""卷积块"""

def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(2)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.pool(x)
return x


class ProtoNetEncoder(nn.Module):
"""Prototypical 网络编码器"""

def __init__(self, input_channels=3, hidden_dim=64):
super().__init__()
self.conv1 = ConvBlock(input_channels, hidden_dim)
self.conv2 = ConvBlock(hidden_dim, hidden_dim)
self.conv3 = ConvBlock(hidden_dim, hidden_dim)
self.conv4 = ConvBlock(hidden_dim, hidden_dim)

def forward(self, x):
x = self.conv1(x) # 84x84 -> 42x42
x = self.conv2(x) # 42x42 -> 21x21
x = self.conv3(x) # 21x21 -> 10x10
x = self.conv4(x) # 10x10 -> 5x5
x = x.view(x.size(0), -1) # Flatten
return x


class PrototypicalNetwork(nn.Module):
"""Prototypical 网络"""

def __init__(self, encoder):
super().__init__()
self.encoder = encoder

def compute_prototypes(self, support_embeddings, support_labels, n_way):
"""
计算类别原型

Args:
support_embeddings: (n_way * n_support, embedding_dim)
support_labels: (n_way * n_support,)
n_way: 类别数

Returns:
prototypes: (n_way, embedding_dim)
"""
prototypes = []
for c in range(n_way):
# 获取类别 c 的所有样本
class_mask = (support_labels == c)
class_embeddings = support_embeddings[class_mask]
# 计算均值作为原型
prototype = class_embeddings.mean(dim=0)
prototypes.append(prototype)

prototypes = torch.stack(prototypes)
return prototypes

def compute_distances(self, query_embeddings, prototypes):
"""
计算查询样本与原型的欧氏距离

Args:
query_embeddings: (n_query, embedding_dim)
prototypes: (n_way, embedding_dim)

Returns:
distances: (n_query, n_way)
"""
# 使用广播计算欧氏距离
# (n_query, 1, embedding_dim) - (1, n_way, embedding_dim)
# -> (n_query, n_way, embedding_dim)
distances = torch.cdist(query_embeddings, prototypes, p=2)
return distances

def forward(self, support_images, support_labels, query_images, n_way, n_support):
"""
前向传播

Args:
support_images: (n_way * n_support, C, H, W)
support_labels: (n_way * n_support,)
query_images: (n_query, C, H, W)
n_way: 类别数
n_support: 每类支持样本数

Returns:
logits: (n_query, n_way)
"""
# 编码
support_embeddings = self.encoder(support_images)
query_embeddings = self.encoder(query_images)

# 计算原型
prototypes = self.compute_prototypes(support_embeddings, support_labels, n_way)

# 计算距离
distances = self.compute_distances(query_embeddings, prototypes)

# 负距离作为 logits(距离越小, logits 越大)
logits = -distances

return logits


class FewShotDataset(Dataset):
"""Few-Shot 数据集"""

def __init__(self, data, labels):
"""
Args:
data: (N, C, H, W) 所有图像
labels: (N,) 所有标签
"""
self.data = data
self.labels = labels

# 按类别组织数据
self.classes = np.unique(labels)
self.class_to_indices = {}
for c in self.classes:
self.class_to_indices[c] = np.where(labels == c)[0]

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
return self.data[idx], self.labels[idx]


class EpisodeSampler:
"""Episode 采样器"""

def __init__(self, dataset, n_way, n_support, n_query, n_episodes):
"""
Args:
dataset: FewShotDataset
n_way: 每个 episode 的类别数
n_support: 每类支持样本数
n_query: 每类查询样本数
n_episodes: episode 总数
"""
self.dataset = dataset
self.n_way = n_way
self.n_support = n_support
self.n_query = n_query
self.n_episodes = n_episodes

def __iter__(self):
for _ in range(self.n_episodes):
yield self.sample_episode()

def sample_episode(self):
"""采样一个 episode"""
# 随机选择 n_way 个类别
selected_classes = np.random.choice(
self.dataset.classes,
size=self.n_way,
replace=False
)

support_images = []
support_labels = []
query_images = []
query_labels = []

for i, c in enumerate(selected_classes):
# 获取类别 c 的所有样本索引
class_indices = self.dataset.class_to_indices[c]

# 随机选择 n_support + n_query 个样本
selected_indices = np.random.choice(
class_indices,
size=self.n_support + self.n_query,
replace=False
)

# 划分支持集和查询集
support_indices = selected_indices[:self.n_support]
query_indices = selected_indices[self.n_support:]

# 添加到支持集
for idx in support_indices:
support_images.append(self.dataset.data[idx])
support_labels.append(i) # 使用相对标签 0, 1, ..., n_way-1

# 添加到查询集
for idx in query_indices:
query_images.append(self.dataset.data[idx])
query_labels.append(i)

# 转换为 tensor
support_images = torch.stack([torch.FloatTensor(img) for img in support_images])
support_labels = torch.LongTensor(support_labels)
query_images = torch.stack([torch.FloatTensor(img) for img in query_images])
query_labels = torch.LongTensor(query_labels)

return support_images, support_labels, query_images, query_labels


class ProtoNetTrainer:
"""Prototypical 网络训练器"""

def __init__(
self,
model,
train_dataset,
val_dataset,
n_way=5,
n_support=5,
n_query=15,
n_episodes=100,
learning_rate=1e-3,
device='cuda'
):
self.model = model.to(device)
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.n_way = n_way
self.n_support = n_support
self.n_query = n_query
self.n_episodes = n_episodes
self.device = device

self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
self.criterion = nn.CrossEntropyLoss()

def train_epoch(self):
"""训练一个 epoch"""
self.model.train()

sampler = EpisodeSampler(
self.train_dataset,
self.n_way,
self.n_support,
self.n_query,
self.n_episodes
)

total_loss = 0
total_acc = 0

progress_bar = tqdm(sampler, desc='Training')

for support_images, support_labels, query_images, query_labels in progress_bar:
support_images = support_images.to(self.device)
support_labels = support_labels.to(self.device)
query_images = query_images.to(self.device)
query_labels = query_labels.to(self.device)

# 前向传播
logits = self.model(
support_images,
support_labels,
query_images,
self.n_way,
self.n_support
)

# 计算损失
loss = self.criterion(logits, query_labels)

# 反向传播
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

# 计算准确率
preds = torch.argmax(logits, dim=1)
acc = (preds == query_labels).float().mean().item()

total_loss += loss.item()
total_acc += acc

progress_bar.set_postfix({
'loss': loss.item(),
'acc': acc
})

avg_loss = total_loss / self.n_episodes
avg_acc = total_acc / self.n_episodes

return avg_loss, avg_acc

def evaluate(self, n_eval_episodes=100):
"""评估模型"""
self.model.eval()

sampler = EpisodeSampler(
self.val_dataset,
self.n_way,
self.n_support,
self.n_query,
n_eval_episodes
)

total_loss = 0
total_acc = 0

with torch.no_grad():
for support_images, support_labels, query_images, query_labels in tqdm(sampler, desc='Evaluating'):
support_images = support_images.to(self.device)
support_labels = support_labels.to(self.device)
query_images = query_images.to(self.device)
query_labels = query_labels.to(self.device)

logits = self.model(
support_images,
support_labels,
query_images,
self.n_way,
self.n_support
)

loss = self.criterion(logits, query_labels)
preds = torch.argmax(logits, dim=1)
acc = (preds == query_labels).float().mean().item()

total_loss += loss.item()
total_acc += acc

avg_loss = total_loss / n_eval_episodes
avg_acc = total_acc / n_eval_episodes

return avg_loss, avg_acc

def train(self, num_epochs=100):
"""完整训练流程"""
best_val_acc = 0.0

for epoch in range(num_epochs):
print(f"\nEpoch {epoch + 1}/{num_epochs}")

# 训练
train_loss, train_acc = self.train_epoch()
print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

# 评估
val_loss, val_acc = self.evaluate()
print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# 保存最佳模型
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(self.model.state_dict(), 'best_protonet.pt')
print(f"Saved best model with accuracy {best_val_acc:.4f}")


# 使用示例
def main():
# 模拟数据: Omniglot 或 miniImageNet
# 这里用随机数据演示
num_classes = 64 # base classes 数量
samples_per_class = 600
image_size = 84

# 生成模拟数据
all_data = []
all_labels = []

for c in range(num_classes):
class_data = torch.randn(samples_per_class, 3, image_size, image_size)
class_labels = torch.full((samples_per_class,), c)
all_data.append(class_data)
all_labels.append(class_labels)

all_data = torch.cat(all_data, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# 划分训练集和验证集
train_classes = num_classes * 4 // 5
train_mask = all_labels < train_classes
val_mask = all_labels >= train_classes

train_dataset = FewShotDataset(
all_data[train_mask].numpy(),
all_labels[train_mask].numpy()
)

val_dataset = FewShotDataset(
all_data[val_mask].numpy(),
all_labels[val_mask].numpy()
)

# 模型
encoder = ProtoNetEncoder(input_channels=3, hidden_dim=64)
model = PrototypicalNetwork(encoder)

# 训练器
trainer = ProtoNetTrainer(
model=model,
train_dataset=train_dataset,
val_dataset=val_dataset,
n_way=5,
n_support=5,
n_query=15,
n_episodes=100,
learning_rate=1e-3
)

# 训练
trainer.train(num_epochs=50)


if __name__ == '__main__':
main()

代码详解

Episode 采样

EpisodeSampler实现了 Few-Shot Learning 的核心采样逻辑:

1
2
3
4
5
6
7
8
9
10
11
def sample_episode(self):
# 随机选择 n_way 个类别
selected_classes = np.random.choice(classes, n_way, replace=False)

for c in selected_classes:
# 从类别 c 中采样 n_support + n_query 个样本
selected_indices = np.random.choice(class_indices, n_support + n_query, replace=False)

# 划分支持集和查询集
support_indices = selected_indices[:n_support]
query_indices = selected_indices[n_support:]

原型计算

compute_prototypes计算每个类别的原型(均值):

1
2
3
4
5
for c in range(n_way):
class_mask = (support_labels == c)
class_embeddings = support_embeddings[class_mask]
prototype = class_embeddings.mean(dim=0)
prototypes.append(prototype)

距离计算

使用torch.cdist高效计算欧氏距离:

1
2
distances = torch.cdist(query_embeddings, prototypes, p=2)
logits = -distances # 负距离作为 logits

深度 Q&A

Q1: Few-Shot Learning 与 Transfer Learning 有什么区别?

联系:都是利用已知知识学习新任务

区别

维度 Transfer Learning Few-Shot Learning
数据量 目标任务有较多标注数据 目标任务只有极少标注数据( 1-10 个)
适配方式 微调预训练模型 基于度量或元学习快速适配
训练范式 标准监督学习 Episode 训练

Few-Shot Learning 可以看作是 Transfer Learning 的极端情况:目标任务数据极度稀缺。

Q2: 为什么 Prototypical 网络使用均值作为原型?有理论支持吗?

理论支持:在高斯分布假设下,类别原型是最优贝叶斯分类器。

证明:假设类别 的样本服从高斯分布,则后验概率为:

$$

P(y = c | x) (-(x - _c){-1} (x - _c)) $$

取对数:

(各向同性)时,这等价于欧氏距离:

因此,使用均值作为原型并基于欧氏距离分类是贝叶斯最优的(在高斯假设下)。

Q3: MAML 为什么需要二阶梯度?能否避免?

MAML 需要二阶梯度是因为要对适配后的参数 关于元参数 求导:

需要计算,这是二阶导数。

避免方法

  1. FOMAML:忽略二阶项,只用一阶梯度
  2. Reptile:直接朝适配后参数移动,无需二阶梯度

实验表明 FOMAML 和 Reptile 性能与 MAML 接近,但计算效率高得多。

Q4: Episode 训练和普通训练有什么本质区别?

普通训练:每个 batch 包含多个类别的样本,模型学习所有类别的判别边界

Episode 训练:每个 episode 只包含 N 个类别,模型学习"如何从 N 个类别的少量样本中学习"

本质区别: - 普通训练学习任务特定知识(哪些特征区分哪些类别) - Episode 训练学习元知识(如何快速学习新任务)

类比: - 普通训练像"学习特定科目"(学数学、学物理) - Episode 训练像"学习如何学习"(学习方法论)

Q5: 为什么 Few-Shot Learning 需要大量 base classes?

虽然目标任务( novel classes)只有少量样本,但要学会"如何学习"需要在多个任务上训练。

数据需求: - Base classes 数量:通常需要几十到上百个类别 - 每个 base class 样本数:通常几百个

直觉:就像人类虽然能从少量样本学习新概念,但这种能力是通过一生的经验积累的。 Few-Shot Learning 模型需要在大量 base classes 上学习这种能力。

实验证据: - Omniglot: 1200+ base classes - miniImageNet: 64 base classes - tieredImageNet: 351 base classes

Base classes 越多, Few-Shot Learning 性能越好。

Q6: Prototypical 网络能否用于回归任务?

可以,但需要修改。分类任务中,原型是离散的(每个类别一个),回归任务中,需要连续的原型空间

方法 1:核回归

将原型看作核中心,预测为加权平均:

方法 2:条件神经过程( Conditional Neural Process, CNP)

学习一个函数分布,给定支持集预测查询点的分布:

$$

p(y_q | x_q, ) = ((x_q, ), ^2(x_q, )) $$

Q7: 如何选择 Few-Shot Learning 方法?

决策树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
1. 数据特点?
├─ 图像数据 → Prototypical 网络、 Matching 网络
├─ 时序数据 → MAML + LSTM
└─ 图数据 → Graph Neural Network + Meta-Learning

2. 计算资源?
├─ 资源充足 → MAML(二阶梯度)
└─ 资源有限 → Prototypical 网络、 Reptile

3. 任务多样性?
├─ 任务相似 → 度量学习( Prototypical)
└─ 任务多样 → 元学习( MAML)

4. 是否需要可解释性?
├─ 需要 → Prototypical 网络(原型可视化)
└─ 不需要 → Relation 网络、 MAML

Q8: Few-Shot Learning 在实际应用中的挑战是什么?

  1. 域偏移: Base classes 和 novel classes 分布不同
    • 解决:域适应 + Few-Shot Learning( Cross-Domain Few-Shot Learning)
  2. 类不平衡: Novel classes 样本数可能不同
    • 解决:加权损失、重采样
  3. 标注噪声:少量样本中的标注错误影响大
    • 解决:鲁棒损失函数、去噪方法
  4. 计算效率: Episode 训练比普通训练慢
    • 解决:预训练 + 少量 episode 微调
  5. 泛化能力:模型可能过拟合 base classes
    • 解决:增大 base classes 多样性、正则化

Q9: Prototypical 网络和 k-NN 有什么区别?

Prototypical 网络可以看作是学习嵌入空间的 k-NN

方法 距离度量 嵌入空间 原型
k-NN 固定(欧氏、余弦) 原始特征空间 每个样本
Prototypical 学习的 学习的嵌入空间 类别均值

关键区别: 1. 嵌入学习: Prototypical 网络学习一个嵌入函数Missing superscript or subscript argument f_,使得嵌入空间更适合 Few-Shot Learning 2. 原型聚合:使用类别均值而非每个样本,更鲁棒

实验:在相同嵌入空间下, Prototypical 网络略优于 k-NN,但差异不大。主要优势来自嵌入学习。

Q10: MAML 的初始化为什么重要?

MAML 学习的初始化位于损失曲面的平坦区域,使得:

  1. 快速适配:沿任意方向梯度下降都能快速降低损失
  2. 泛化能力强:平坦区域对应更好的泛化( Sharp Minima vs Flat Minima)

数学上, MAML 等价于最小化损失的二阶泰勒展开:

$$

L(') L() + L()^(' - ) + (' - )^H (' - ) $$

MAML 希望 Hessian 的特征值都较小(平坦),这样沿任意方向移动损失增长都慢。

Q11: Few-Shot Learning 能否用于强化学习?

可以! Few-Shot Reinforcement Learning 是一个活跃的研究方向。

挑战: 1. 样本效率更低(需要交互) 2. 奖励稀疏 3. 探索-利用权衡

方法: 1. MAML for RL:在多个任务上元学习策略 2. Meta-RL with Context:学习任务表示,条件化策略 3. Model-Based Meta-RL:学习动力学模型,规划

应用: - 机器人快速适应新任务 - 游戏 AI 快速学习新游戏 - 推荐系统快速适应新用户

Q12: 如何评估 Few-Shot Learning 模型?

标准评估协议:

  1. 数据划分

    • Base classes:训练
    • Val classes:验证超参数
    • Novel classes:最终测试
  2. 评估指标

    • 准确率(主要)
    • 95%置信区间(报告不确定性)
    • 每类准确率(检查类不平衡)
  3. 评估步骤

    1
    2
    3
    4
    5
    for episode in test_episodes:
    sample N-way K-shot task from novel classes
    compute accuracy on query set

    report: mean ± 95% confidence interval

  4. 标准基准

    • Omniglot: 20-way 1-shot, 20-way 5-shot
    • miniImageNet: 5-way 1-shot, 5-way 5-shot
    • tieredImageNet: 5-way 1-shot, 5-way 5-shot

注意:必须报告置信区间,因为 Few-Shot Learning 方差较大。

相关论文

  1. Siamese Neural Networks for One-shot Image Recognition
    Koch et al., ICML Deep Learning Workshop 2015
    https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf

  2. Prototypical Networks for Few-shot Learning
    Snell et al., NeurIPS 2017
    https://arxiv.org/abs/1703.05175

  3. Matching Networks for One Shot Learning
    Vinyals et al., NeurIPS 2016
    https://arxiv.org/abs/1606.04080

  4. Learning to Compare: Relation Network for Few-Shot Learning
    Sung et al., CVPR 2018
    https://arxiv.org/abs/1711.06025

  5. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (MAML)
    Finn et al., ICML 2017
    https://arxiv.org/abs/1703.03400

  6. On First-Order Meta-Learning Algorithms (Reptile)
    Nichol et al., arXiv 2018
    https://arxiv.org/abs/1803.02999

  7. A Closer Look at Few-shot Classification
    Chen et al., ICLR 2019
    https://arxiv.org/abs/1904.04232

  8. Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples
    Triantafillou et al., ICLR 2020
    https://arxiv.org/abs/1903.03096

  9. Learning to Learn with Conditional Class Dependencies
    Bertinetto et al., ICLR 2019
    https://arxiv.org/abs/1806.03961

  10. TADAM: Task dependent adaptive metric for improved few-shot learning
    Oreshkin et al., NeurIPS 2018
    https://arxiv.org/abs/1805.10123

  11. Meta-Learning with Differentiable Convex Optimization
    Lee et al., CVPR 2019
    https://arxiv.org/abs/1904.03758

  12. Generalizing from a Few Examples: A Survey on Few-Shot Learning
    Wang et al., ACM Computing Surveys 2020
    https://arxiv.org/abs/1904.05046

总结

Few-Shot Learning 解决了深度学习最大的瓶颈之一:数据稀缺。本文从第一性原理出发,推导了度量学习( Siamese 、 Prototypical 、 Matching 、 Relation Networks)和元学习( MAML 、 Reptile)的数学基础,详细解析了它们的架构、损失函数、优化方法。

我们看到, Few-Shot Learning 的核心是利用先验知识:度量学习通过学习嵌入空间使得度量可迁移,元学习通过学习初始化或优化器使得适配快速。 Episode 训练是关键,它让模型在训练时就面临 Few-Shot 场景,学会"如何学习"。

完整的 Prototypical 网络实现展示了 episode 采样、原型计算、距离度量等核心技术。下一章我们将探讨知识蒸馏,研究如何将大模型的知识迁移到小模型。

  • 本文标题:迁移学习(四)—— Few-Shot Learning
  • 本文作者:Chen Kai
  • 创建时间:2024-11-21 15:45:00
  • 本文链接:https://www.chenk.top/%E8%BF%81%E7%A7%BB%E5%AD%A6%E4%B9%A0%EF%BC%88%E5%9B%9B%EF%BC%89%E2%80%94%E2%80%94-Few-Shot-Learning/
  • 版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
 评论