Few-Shot
Learning(小样本学习)是机器学习中最具挑战性的问题之一。人类可以从极少样本中快速学习新概念:看过几张图片就能识别新物种,听过几个例子就能理解新语言。但传统深度学习模型需要大量标注数据才能训练,在数据稀缺场景下表现糟糕。
Few-Shot Learning 的目标是:从每类只有少量样本(通常 1-10
个)的情况下学习分类器 。这需要模型具备强大的泛化能力和迁移能力,从已知类别中学习"如何学习"的能力,然后快速适应新类别。本文将从第一性原理出发,推导度量学习和元学习的数学基础,详解
Siamese 网络、 Prototypical 网络、 MAML 等经典方法,并提供完整的
Prototypical 网络实现。
Few-Shot Learning 的挑战
问题定义
Few-Shot Learning 通常采用N-way K-shot 设定:
N-way :有
个类别需要分类
K-shot :每个类别只有 个标注样本
例如 5-way 1-shot 表示从 5 个类别中识别,每个类别只有 1
个训练样本。
形式化地,设:
支持集( Support Set) :训练样本
查询集( Query Set) :测试样本
目标是训练一个模型 ,使得在支持集 上训练后,能在查询集 上取得高准确率。
为什么困难?
数据稀缺 :
个样本远不足以学习一个复杂分类器
过拟合风险 :模型容易记住支持集的具体样本,而非学到可泛化的特征
类间相似 :新类别可能与已知类别非常相似,难以区分
传统方法的失败
标准的经验风险最小化( ERM):
在
很小时会严重过拟合。即使加上正则化:
仍然不够,因为正则化只能防止参数过大,无法提供足够的归纳偏置(
inductive bias)。
Few-Shot Learning 的核心思想
要在少量样本下学习,需要利用先验知识 。 Few-Shot
Learning 的核心是:
从已知类别学习先验 :在大量已知类别( base
classes)上训练
快速适应新类别 :用学到的先验在新类别( novel
classes)上快速适应
这等价于学习一个元学习器( meta-learner) : 。
度量学习:基于相似度的分类
度量学习( Metric
Learning)的思想是:学习一个嵌入空间,使得同类样本距离近、异类样本距离远 。分类时,将查询样本与支持集样本比较距离,选择最近的类别。
Siamese 网络:孪生网络
Siamese 网络是最早的度量学习方法之一,通过对比损失( contrastive
loss)学习嵌入空间。
架构
Siamese 网络包含两个权重共享的编码器 :
然后计算嵌入之间的距离:
$$
d(x_1, x_2) = |z_1 - z_2|_2 $$
对比损失
对比损失( Contrastive Loss)定义为:
$$
L = y d^2 + (1 - y) (0, m - d)^2 $$
其中:
:正样本对(同类),损失为
,希望距离小
:负样本对(异类),损失为
,希望距离大于
margin
直觉解释 :
正样本对:拉近距离
负样本对:如果距离小于 ,推开至少 的距离;如果已经大于 ,不再惩罚
Few-Shot 分类
给定支持集Extra close brace or missing open brace \mathcal{S} = \{(x_i, y_i)} _{i=1}^{NK} 和查询样本 ,预测为:
即选择支持集中距离最近的样本的类别。
Prototypical 网络:原型网络
Prototypical 网络是度量学习的改进版本,通过学习类别原型(
prototype) 来分类。
类别原型
给定类别 的支持集样本 ,类别原型定义为支持集样本嵌入的均值:
$$
p_c = _{x_i c} f (x_i) $$
直觉 :原型是该类别在嵌入空间中的"中心",代表该类别的典型特征。
距离度量
Prototypical 网络使用欧氏距离度量查询样本与原型的距离:
$$
d(x_q, p_c) = |f_(x_q) - p_c|_2^2 $$
也可以使用余弦距离:
$$
d_{} (x_q, p_c) = 1 - $$
分类与损失
分类概率通过 softmax 计算:
$$
P(y = c | x_q) = $$
损失函数为负对数似然:
$$
L = -P(y = y_q | x_q) $$
Prototypical 网络的理论
Prototypical 网络可以看作是最近质心分类器( Nearest Centroid
Classifier) 在嵌入空间中的实现。在线性可分的情况下,
Prototypical 网络等价于线性分类器 。
定理 :在嵌入空间中,如果类别原型线性可分,则
Prototypical 网络的决策边界是线性的。
证明 :查询样本 属于类别 的充要条件是:
$$
d(x_q, p_c) < d(x_q, p_{c'}), c' c $$
即:
展开:
简化:
这是 的线性不等式,决策边界是超平面。
匹配网络( Matching Networks)
匹配网络引入注意力机制 和记忆增强 ,进一步提升
Few-Shot Learning 性能。
注意力核
匹配网络使用注意力核( attention
kernel)计算查询样本与支持集样本的相似度:
$$
a(x_q, x_i) = $$
其中 和 分别是查询集和支持集的编码器(可以不同)。
预测
查询样本的类别预测为支持集标签的加权和:
直觉 :与查询样本相似度高的支持集样本对预测贡献更大。
Full Context Embeddings
匹配网络使用双向
LSTM 对支持集进行编码,使每个样本的嵌入包含整个支持集的上下文信息:
$$
g(x_i) = ({x_1, , x_{NK}} , i) $$
这让模型能考虑支持集样本之间的关系。
关系网络( Relation Networks)
关系网络不使用固定的距离度量(如欧氏距离),而是学习一个度量函数 。
架构
关系网络包含两个模块:
嵌入模块 Missing superscript or subscript argument f_ :将样本映射到嵌入空间
关系模块 Missing superscript or subscript argument g_ :学习嵌入之间的相似度
给定查询样本 和支持集样本 ,计算:
$$
r_{q,i} = g_((f_(x_q), f_(x_i))) $$
其中 是学到的相似度。
损失函数
关系网络使用 MSE 损失:
$$
L = {(x_q, y_q) } {c=1}^N (r_{q,c} - _{y_q = c})^2 $$
其中 是查询样本与类别 的原型的相似度。
为什么学习度量?
固定距离(如欧氏距离)假设嵌入空间是各向同性的,但实际上不同维度可能有不同重要性。学习度量可以自适应地调整距离计算。
元学习:学会学习
元学习(
Meta-Learning)的核心思想是:在多个任务上学习如何快速适应新任务 。
元学习的形式化
设有 个训练任务${_1, , _T}
, 每 个 任 务 包 含 训 练 集 _i^{} 和 测 试 集 _i^{} $。
元学习的目标是学习一个元参数 ,使得对任意新任务 ,用$^{} 适 配 后 在 ^{} $ 上表现好:
其中 是在任务 上用 适配后的参数:
MAML:模型无关元学习
Model-Agnostic Meta-Learning (MAML)
是最经典的元学习算法,通过学习一个好的初始化参数,使得模型能快速适应新任务。
MAML 算法
给定任务分布 , MAML 优化:
即: 1. 在任务
的训练集上做一步(或多步)梯度下降: 2. 在任务
的测试集上计算损失: $$
L_{} ^{} ({} ') 对 所 有 任 务 的 测 试 损 失 求 平 均 , 更 新 元 参 数 : - _{} [L_{} ^{} (_{} ')] $$
MAML 的梯度计算
MAML 的关键是计算二阶梯度:
使用链式法则:
其中:
因此:
其中 是 Hessian 矩阵。
计算复杂度 :计算 Hessian 需要 时间和空间,
是参数维度。实践中可以用一阶近似( First-Order MAML,
FOMAML) :
忽略 Hessian 项,计算复杂度降为 。
MAML 的直觉
MAML 学习的
位于损失曲面的"平坦"区域,使得沿任意方向(任意任务)的梯度下降都能快速降低损失。
类比 :
是一个"万能起点",从这个起点出发,只需几步就能到达任意任务的最优解。
Reptile:一阶元学习
Reptile 是 MAML 的简化版本,只使用一阶梯度,计算更高效。
Reptile 算法
采样任务$在 任 务 上 做 k$ 步
SGD: 3. 更新元参数:
直觉 : Reptile
将元参数朝任务特定参数移动。多次迭代后,
会位于所有任务特定参数的"中心"。
Reptile vs MAML
方法
梯度阶数
计算复杂度
性能
MAML
二阶
高(需要 Hessian)
最优
FOMAML
一阶(近似)
中等
接近 MAML
Reptile
一阶
低
略逊于 MAML
Reptile 在实践中与 FOMAML 性能相近,但实现更简单。
元学习的理论
元学习可以从贝叶斯视角理解。设任务参数 服从先验分布 ,则
MAML 等价于最大化后验:
其中:
$$
p({} ^{} | {} ^{} , ) = p({} ^{} | ) p(| {} ^{} , )
d $$ 是先验参数,
是先验分布。元学习学习一个好的先验。
Episode 训练:模拟 Few-Shot
场景
Few-Shot Learning 的训练采用episode 训练( episodic
training) ,每个 episode 模拟一个 Few-Shot 任务。
Episode 采样
每个 episode 包含: 1. 从 base classes 中随机采样 个类别 2. 从每个类别中随机采样 个样本作为支持集 3.
从每个类别中随机采样
个样本作为查询集
形式化地,一个 episode 为:
其中:
Extra close brace or missing open brace \begin{aligned} \mathcal{S} &= \{(x_i^{(c,k)}, c) : c \in \{1, \ldots, N} , k \in \{1, \ldots, K} } \\ \mathcal{Q} &= \{(x_j^{(c,q)}, c) : c \in \{1, \ldots, N} , q \in \{1, \ldots, Q} } \end{aligned}
Episode 训练流程
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 for epoch in range(num_epochs): for episode in range(episodes_per_epoch): # 采样 episode classes = sample(base_classes, N) support = sample_from_classes(classes, K) query = sample_from_classes(classes, Q) # 前向传播 prototypes = compute_prototypes(support) logits = compute_distances(query, prototypes) # 计算损失 loss = cross_entropy(logits, query_labels) # 反向传播 loss.backward() optimizer.step()
Episode 训练的直觉
Episode 训练让模型在训练时就面临 Few-Shot
场景,强迫模型学习如何从少量样本中泛化。这是一种课程学习(
curriculum learning) :训练时的困难度与测试时相同。
完整实现: Prototypical 网络
下面提供一个完整的 Prototypical 网络实现,包含 episode
采样、距离计算、支持集与查询集划分等。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch.utils.data import Dataset, DataLoaderimport numpy as npfrom tqdm import tqdmfrom sklearn.metrics import accuracy_scoreclass ConvBlock (nn.Module): """卷积块""" def __init__ (self, in_channels, out_channels ): super ().__init__() self.conv = nn.Conv2d(in_channels, out_channels, 3 , padding=1 ) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True ) self.pool = nn.MaxPool2d(2 ) def forward (self, x ): x = self.conv(x) x = self.bn(x) x = self.relu(x) x = self.pool(x) return x class ProtoNetEncoder (nn.Module): """Prototypical 网络编码器""" def __init__ (self, input_channels=3 , hidden_dim=64 ): super ().__init__() self.conv1 = ConvBlock(input_channels, hidden_dim) self.conv2 = ConvBlock(hidden_dim, hidden_dim) self.conv3 = ConvBlock(hidden_dim, hidden_dim) self.conv4 = ConvBlock(hidden_dim, hidden_dim) def forward (self, x ): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.conv4(x) x = x.view(x.size(0 ), -1 ) return x class PrototypicalNetwork (nn.Module): """Prototypical 网络""" def __init__ (self, encoder ): super ().__init__() self.encoder = encoder def compute_prototypes (self, support_embeddings, support_labels, n_way ): """ 计算类别原型 Args: support_embeddings: (n_way * n_support, embedding_dim) support_labels: (n_way * n_support,) n_way: 类别数 Returns: prototypes: (n_way, embedding_dim) """ prototypes = [] for c in range (n_way): class_mask = (support_labels == c) class_embeddings = support_embeddings[class_mask] prototype = class_embeddings.mean(dim=0 ) prototypes.append(prototype) prototypes = torch.stack(prototypes) return prototypes def compute_distances (self, query_embeddings, prototypes ): """ 计算查询样本与原型的欧氏距离 Args: query_embeddings: (n_query, embedding_dim) prototypes: (n_way, embedding_dim) Returns: distances: (n_query, n_way) """ distances = torch.cdist(query_embeddings, prototypes, p=2 ) return distances def forward (self, support_images, support_labels, query_images, n_way, n_support ): """ 前向传播 Args: support_images: (n_way * n_support, C, H, W) support_labels: (n_way * n_support,) query_images: (n_query, C, H, W) n_way: 类别数 n_support: 每类支持样本数 Returns: logits: (n_query, n_way) """ support_embeddings = self.encoder(support_images) query_embeddings = self.encoder(query_images) prototypes = self.compute_prototypes(support_embeddings, support_labels, n_way) distances = self.compute_distances(query_embeddings, prototypes) logits = -distances return logits class FewShotDataset (Dataset ): """Few-Shot 数据集""" def __init__ (self, data, labels ): """ Args: data: (N, C, H, W) 所有图像 labels: (N,) 所有标签 """ self.data = data self.labels = labels self.classes = np.unique(labels) self.class_to_indices = {} for c in self.classes: self.class_to_indices[c] = np.where(labels == c)[0 ] def __len__ (self ): return len (self.data) def __getitem__ (self, idx ): return self.data[idx], self.labels[idx] class EpisodeSampler : """Episode 采样器""" def __init__ (self, dataset, n_way, n_support, n_query, n_episodes ): """ Args: dataset: FewShotDataset n_way: 每个 episode 的类别数 n_support: 每类支持样本数 n_query: 每类查询样本数 n_episodes: episode 总数 """ self.dataset = dataset self.n_way = n_way self.n_support = n_support self.n_query = n_query self.n_episodes = n_episodes def __iter__ (self ): for _ in range (self.n_episodes): yield self.sample_episode() def sample_episode (self ): """采样一个 episode""" selected_classes = np.random.choice( self.dataset.classes, size=self.n_way, replace=False ) support_images = [] support_labels = [] query_images = [] query_labels = [] for i, c in enumerate (selected_classes): class_indices = self.dataset.class_to_indices[c] selected_indices = np.random.choice( class_indices, size=self.n_support + self.n_query, replace=False ) support_indices = selected_indices[:self.n_support] query_indices = selected_indices[self.n_support:] for idx in support_indices: support_images.append(self.dataset.data[idx]) support_labels.append(i) for idx in query_indices: query_images.append(self.dataset.data[idx]) query_labels.append(i) support_images = torch.stack([torch.FloatTensor(img) for img in support_images]) support_labels = torch.LongTensor(support_labels) query_images = torch.stack([torch.FloatTensor(img) for img in query_images]) query_labels = torch.LongTensor(query_labels) return support_images, support_labels, query_images, query_labels class ProtoNetTrainer : """Prototypical 网络训练器""" def __init__ ( self, model, train_dataset, val_dataset, n_way=5 , n_support=5 , n_query=15 , n_episodes=100 , learning_rate=1e-3 , device='cuda' ): self.model = model.to(device) self.train_dataset = train_dataset self.val_dataset = val_dataset self.n_way = n_way self.n_support = n_support self.n_query = n_query self.n_episodes = n_episodes self.device = device self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) self.criterion = nn.CrossEntropyLoss() def train_epoch (self ): """训练一个 epoch""" self.model.train() sampler = EpisodeSampler( self.train_dataset, self.n_way, self.n_support, self.n_query, self.n_episodes ) total_loss = 0 total_acc = 0 progress_bar = tqdm(sampler, desc='Training' ) for support_images, support_labels, query_images, query_labels in progress_bar: support_images = support_images.to(self.device) support_labels = support_labels.to(self.device) query_images = query_images.to(self.device) query_labels = query_labels.to(self.device) logits = self.model( support_images, support_labels, query_images, self.n_way, self.n_support ) loss = self.criterion(logits, query_labels) self.optimizer.zero_grad() loss.backward() self.optimizer.step() preds = torch.argmax(logits, dim=1 ) acc = (preds == query_labels).float ().mean().item() total_loss += loss.item() total_acc += acc progress_bar.set_postfix({ 'loss' : loss.item(), 'acc' : acc }) avg_loss = total_loss / self.n_episodes avg_acc = total_acc / self.n_episodes return avg_loss, avg_acc def evaluate (self, n_eval_episodes=100 ): """评估模型""" self.model.eval () sampler = EpisodeSampler( self.val_dataset, self.n_way, self.n_support, self.n_query, n_eval_episodes ) total_loss = 0 total_acc = 0 with torch.no_grad(): for support_images, support_labels, query_images, query_labels in tqdm(sampler, desc='Evaluating' ): support_images = support_images.to(self.device) support_labels = support_labels.to(self.device) query_images = query_images.to(self.device) query_labels = query_labels.to(self.device) logits = self.model( support_images, support_labels, query_images, self.n_way, self.n_support ) loss = self.criterion(logits, query_labels) preds = torch.argmax(logits, dim=1 ) acc = (preds == query_labels).float ().mean().item() total_loss += loss.item() total_acc += acc avg_loss = total_loss / n_eval_episodes avg_acc = total_acc / n_eval_episodes return avg_loss, avg_acc def train (self, num_epochs=100 ): """完整训练流程""" best_val_acc = 0.0 for epoch in range (num_epochs): print (f"\nEpoch {epoch + 1 } /{num_epochs} " ) train_loss, train_acc = self.train_epoch() print (f"Train Loss: {train_loss:.4 f} , Train Acc: {train_acc:.4 f} " ) val_loss, val_acc = self.evaluate() print (f"Val Loss: {val_loss:.4 f} , Val Acc: {val_acc:.4 f} " ) if val_acc > best_val_acc: best_val_acc = val_acc torch.save(self.model.state_dict(), 'best_protonet.pt' ) print (f"Saved best model with accuracy {best_val_acc:.4 f} " ) def main (): num_classes = 64 samples_per_class = 600 image_size = 84 all_data = [] all_labels = [] for c in range (num_classes): class_data = torch.randn(samples_per_class, 3 , image_size, image_size) class_labels = torch.full((samples_per_class,), c) all_data.append(class_data) all_labels.append(class_labels) all_data = torch.cat(all_data, dim=0 ) all_labels = torch.cat(all_labels, dim=0 ) train_classes = num_classes * 4 // 5 train_mask = all_labels < train_classes val_mask = all_labels >= train_classes train_dataset = FewShotDataset( all_data[train_mask].numpy(), all_labels[train_mask].numpy() ) val_dataset = FewShotDataset( all_data[val_mask].numpy(), all_labels[val_mask].numpy() ) encoder = ProtoNetEncoder(input_channels=3 , hidden_dim=64 ) model = PrototypicalNetwork(encoder) trainer = ProtoNetTrainer( model=model, train_dataset=train_dataset, val_dataset=val_dataset, n_way=5 , n_support=5 , n_query=15 , n_episodes=100 , learning_rate=1e-3 ) trainer.train(num_epochs=50 ) if __name__ == '__main__' : main()
代码详解
Episode 采样
EpisodeSampler实现了 Few-Shot Learning
的核心采样逻辑:
1 2 3 4 5 6 7 8 9 10 11 def sample_episode (self ): selected_classes = np.random.choice(classes, n_way, replace=False ) for c in selected_classes: selected_indices = np.random.choice(class_indices, n_support + n_query, replace=False ) support_indices = selected_indices[:n_support] query_indices = selected_indices[n_support:]
原型计算
compute_prototypes计算每个类别的原型(均值):
1 2 3 4 5 for c in range (n_way): class_mask = (support_labels == c) class_embeddings = support_embeddings[class_mask] prototype = class_embeddings.mean(dim=0 ) prototypes.append(prototype)
距离计算
使用torch.cdist高效计算欧氏距离:
1 2 distances = torch.cdist(query_embeddings, prototypes, p=2 ) logits = -distances
深度 Q&A
Q1:
Few-Shot Learning 与 Transfer Learning 有什么区别?
联系 :都是利用已知知识学习新任务
区别 :
维度
Transfer Learning
Few-Shot Learning
数据量
目标任务有较多标注数据
目标任务只有极少标注数据( 1-10 个)
适配方式
微调预训练模型
基于度量或元学习快速适配
训练范式
标准监督学习
Episode 训练
Few-Shot Learning 可以看作是 Transfer Learning
的极端情况 :目标任务数据极度稀缺。
Q2:
为什么 Prototypical 网络使用均值作为原型?有理论支持吗?
理论支持 :在高斯分布假设下,类别原型是最优贝叶斯分类器。
证明 :假设类别 的样本服从高斯分布 ,则后验概率为:
$$
P(y = c | x) (-(x - _c) {-1} (x - _c)) $$
取对数:
当 (各向同性)时,这等价于欧氏距离:
因此,使用均值作为原型并基于欧氏距离分类是贝叶斯最优的(在高斯假设下)。
Q3: MAML
为什么需要二阶梯度?能否避免?
MAML 需要二阶梯度是因为要对适配后的参数 关于元参数 求导:
需要计算 ,这是二阶导数。
避免方法 :
FOMAML :忽略二阶项,只用一阶梯度
Reptile :直接朝适配后参数移动,无需二阶梯度
实验表明 FOMAML 和 Reptile 性能与 MAML 接近,但计算效率高得多。
Q4: Episode
训练和普通训练有什么本质区别?
普通训练 :每个 batch
包含多个类别的样本,模型学习所有类别的判别边界
Episode 训练 :每个 episode 只包含 N
个类别,模型学习"如何从 N 个类别的少量样本中学习"
本质区别 : -
普通训练学习任务特定 知识(哪些特征区分哪些类别) -
Episode 训练学习元知识 (如何快速学习新任务)
类比 : - 普通训练像"学习特定科目"(学数学、学物理)
- Episode 训练像"学习如何学习"(学习方法论)
Q5: 为什么
Few-Shot Learning 需要大量 base classes?
虽然目标任务( novel
classes)只有少量样本,但要学会"如何学习"需要在多个任务 上训练。
数据需求 : - Base classes
数量:通常需要几十到上百个类别 - 每个 base class 样本数:通常几百个
直觉 :就像人类虽然能从少量样本学习新概念,但这种能力是通过一生的经验积累的。
Few-Shot Learning 模型需要在大量 base classes 上学习这种能力。
实验证据 : - Omniglot: 1200+ base classes -
miniImageNet: 64 base classes - tieredImageNet: 351 base classes
Base classes 越多, Few-Shot Learning 性能越好。
Q6: Prototypical
网络能否用于回归任务?
可以,但需要修改。分类任务中,原型是离散的(每个类别一个),回归任务中,需要连续的原型空间 。
方法 1:核回归
将原型看作核中心,预测为加权平均:
方法 2:条件神经过程( Conditional Neural Process,
CNP)
学习一个函数分布 ,给定支持集预测查询点的分布:
$$
p(y_q | x_q, ) = ((x_q, ), ^2(x_q, )) $$
Q7: 如何选择 Few-Shot
Learning 方法?
决策树:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 1. 数据特点? ├─ 图像数据 → Prototypical 网络、 Matching 网络 ├─ 时序数据 → MAML + LSTM └─ 图数据 → Graph Neural Network + Meta-Learning 2. 计算资源? ├─ 资源充足 → MAML(二阶梯度) └─ 资源有限 → Prototypical 网络、 Reptile 3. 任务多样性? ├─ 任务相似 → 度量学习( Prototypical) └─ 任务多样 → 元学习( MAML) 4. 是否需要可解释性? ├─ 需要 → Prototypical 网络(原型可视化) └─ 不需要 → Relation 网络、 MAML
Q8: Few-Shot
Learning 在实际应用中的挑战是什么?
域偏移 : Base classes 和 novel classes 分布不同
解决:域适应 + Few-Shot Learning( Cross-Domain Few-Shot
Learning)
类不平衡 : Novel classes 样本数可能不同
标注噪声 :少量样本中的标注错误影响大
计算效率 : Episode 训练比普通训练慢
泛化能力 :模型可能过拟合 base classes
解决:增大 base classes 多样性、正则化
Q9: Prototypical 网络和
k-NN 有什么区别?
Prototypical 网络可以看作是学习嵌入空间的 k-NN 。
方法
距离度量
嵌入空间
原型
k-NN
固定(欧氏、余弦)
原始特征空间
每个样本
Prototypical
学习的
学习的嵌入空间
类别均值
关键区别 : 1. 嵌入学习 :
Prototypical 网络学习一个嵌入函数Missing superscript or subscript argument f_ ,使得嵌入空间更适合 Few-Shot
Learning 2.
原型聚合 :使用类别均值而非每个样本,更鲁棒
实验 :在相同嵌入空间下, Prototypical 网络略优于
k-NN,但差异不大。主要优势来自嵌入学习。
Q10: MAML 的初始化为什么重要?
MAML 学习的初始化位于损失曲面的平坦区域 ,使得:
快速适配 :沿任意方向梯度下降都能快速降低损失
泛化能力强 :平坦区域对应更好的泛化( Sharp Minima
vs Flat Minima)
数学上 , MAML 等价于最小化损失的二阶泰勒展开:
$$
L(') L() + L()^(' - ) + (' - )^H (' - ) $$
MAML 希望 Hessian
的特征值都较小(平坦),这样沿任意方向移动损失增长都慢。
Q11: Few-Shot Learning
能否用于强化学习?
可以! Few-Shot Reinforcement Learning 是一个活跃的研究方向。
挑战 : 1. 样本效率更低(需要交互) 2. 奖励稀疏 3.
探索-利用权衡
方法 : 1. MAML for
RL :在多个任务上元学习策略 2. Meta-RL with
Context :学习任务表示,条件化策略 3. Model-Based
Meta-RL :学习动力学模型,规划
应用 : - 机器人快速适应新任务 - 游戏 AI
快速学习新游戏 - 推荐系统快速适应新用户
Q12: 如何评估 Few-Shot
Learning 模型?
标准评估协议:
数据划分 :
Base classes:训练
Val classes:验证超参数
Novel classes:最终测试
评估指标 :
准确率(主要)
95%置信区间(报告不确定性)
每类准确率(检查类不平衡)
评估步骤 : 1 2 3 4 5 for episode in test_episodes: sample N-way K-shot task from novel classes compute accuracy on query set report: mean ± 95% confidence interval
标准基准 :
Omniglot: 20-way 1-shot, 20-way 5-shot
miniImageNet: 5-way 1-shot, 5-way 5-shot
tieredImageNet: 5-way 1-shot, 5-way 5-shot
注意 :必须报告置信区间,因为 Few-Shot Learning
方差较大。
相关论文
Siamese Neural Networks for One-shot Image
Recognition
Koch et al., ICML Deep Learning Workshop 2015
https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf
Prototypical Networks for Few-shot
Learning
Snell et al., NeurIPS 2017
https://arxiv.org/abs/1703.05175
Matching Networks for One Shot Learning
Vinyals et al., NeurIPS 2016
https://arxiv.org/abs/1606.04080
Learning to Compare: Relation Network for Few-Shot
Learning
Sung et al., CVPR 2018
https://arxiv.org/abs/1711.06025
Model-Agnostic Meta-Learning for Fast Adaptation of Deep
Networks (MAML)
Finn et al., ICML 2017
https://arxiv.org/abs/1703.03400
On First-Order Meta-Learning Algorithms
(Reptile)
Nichol et al., arXiv 2018
https://arxiv.org/abs/1803.02999
A Closer Look at Few-shot Classification
Chen et al., ICLR 2019
https://arxiv.org/abs/1904.04232
Meta-Dataset: A Dataset of Datasets for Learning to Learn
from Few Examples
Triantafillou et al., ICLR 2020
https://arxiv.org/abs/1903.03096
Learning to Learn with Conditional Class
Dependencies
Bertinetto et al., ICLR 2019
https://arxiv.org/abs/1806.03961
TADAM: Task dependent adaptive metric for improved
few-shot learning
Oreshkin et al., NeurIPS 2018
https://arxiv.org/abs/1805.10123
Meta-Learning with Differentiable Convex
Optimization
Lee et al., CVPR 2019
https://arxiv.org/abs/1904.03758
Generalizing from a Few Examples: A Survey on Few-Shot
Learning
Wang et al., ACM Computing Surveys 2020
https://arxiv.org/abs/1904.05046
总结
Few-Shot Learning
解决了深度学习最大的瓶颈之一:数据稀缺。本文从第一性原理出发,推导了度量学习(
Siamese 、 Prototypical 、 Matching 、 Relation Networks)和元学习(
MAML 、
Reptile)的数学基础,详细解析了它们的架构、损失函数、优化方法。
我们看到, Few-Shot Learning
的核心是利用先验知识:度量学习通过学习嵌入空间使得度量可迁移,元学习通过学习初始化或优化器使得适配快速。
Episode 训练是关键,它让模型在训练时就面临 Few-Shot
场景,学会"如何学习"。
完整的 Prototypical 网络实现展示了 episode
采样、原型计算、距离度量等核心技术。下一章我们将探讨知识蒸馏 ,研究如何将大模型的知识迁移到小模型。