Transfer Learning (4): Few-Shot Learning
Chen Kai

Few-shot learning represents one of the most challenging problems in machine learning. Humans can rapidly learn new concepts from minimal examples - recognizing new species after seeing just a few images, or understanding new linguistic patterns from a handful of instances. Traditional deep learning models, however, require massive amounts of labeled data to train effectively and perform poorly in data-scarce scenarios.

The goal of few-shot learning is to learn classifiers from only a few examples per class (typically 1-10 samples). This requires models with powerful generalization and transfer capabilities - the ability to learn "how to learn" from known classes and quickly adapt to novel classes. This article derives the mathematical foundations of metric learning and meta-learning from first principles, explains classic methods like Siamese networks, Prototypical networks, and MAML in detail, and provides a complete Prototypical network implementation.

The Few-Shot Learning Challenge

Problem Definition

Few-shot learning typically adopts an N-way K-shot setting:

  - N-way: classify among $N$ classes
  - K-shot: only $K$ labeled samples per class

For example, 5-way 1-shot means classifying among 5 classes with only 1 training sample per class.

Formally, we define:

  - Support Set: $\mathcal{S} = \{(x_i, y_i)\}_{i=1}^{N \times K}$, the labeled training samples
  - Query Set: $\mathcal{Q} = \{(x_j, y_j)\}_{j=1}^{N \times Q}$, the test samples

The goal is to train a model $f_\theta$ that, after training on the support set $\mathcal{S}$, achieves high accuracy on the query set $\mathcal{Q}$.

Why Is It Difficult?

  1. Data scarcity: $K$ samples per class are far from sufficient to learn a complex classifier
  2. Overfitting risk: models easily memorize the specific support set samples rather than learning generalizable features
  3. Inter-class similarity: novel classes may be very similar to known classes, making discrimination difficult

Failure of Traditional Methods

Standard empirical risk minimization (ERM):

$$\hat{\theta} = \arg\min_{\theta} \frac{1}{NK} \sum_{(x_i, y_i) \in \mathcal{S}} \mathcal{L}(f_\theta(x_i), y_i)$$

severely overfits when $K$ is small. Even with regularization:

$$\hat{\theta} = \arg\min_{\theta} \frac{1}{NK} \sum_{(x_i, y_i) \in \mathcal{S}} \mathcal{L}(f_\theta(x_i), y_i) + \lambda \|\theta\|^2$$

this remains insufficient, because regularization only prevents large parameters and cannot provide an adequate inductive bias.

Core Ideas of Few-Shot Learning

To learn from few samples requires leveraging prior knowledge. Few-shot learning's core approach:

  1. Learn priors from known classes: Train on numerous base classes
  2. Rapidly adapt to novel classes: Use learned priors to quickly adapt on novel classes

This is equivalent to learning a meta-learner: "Learning to Learn".

Metric Learning: Similarity-Based Classification

Metric learning's idea is to learn an embedding space where same-class samples are close and different-class samples are distant. During classification, query samples are compared with support set samples by distance, selecting the nearest class.

Siamese Networks: Twin Networks

Siamese networks are among the earliest metric learning methods, learning embedding spaces through contrastive loss.

Architecture

Siamese networks contain two weight-shared encoders:

$$z_1 = f_\theta(x_1), \quad z_2 = f_\theta(x_2)$$

Then compute the distance between embeddings:

$$d(x_1, x_2) = \|z_1 - z_2\|_2$$

Contrastive Loss

The contrastive loss is defined as:

$$\mathcal{L} = y \cdot d^2 + (1 - y) \cdot \max(0, m - d)^2$$

where:

  - $y = 1$: positive pair (same class), loss is $d^2$, encouraging small distance
  - $y = 0$: negative pair (different class), loss is $\max(0, m - d)^2$, encouraging distance greater than the margin $m$

Intuition:

  - Positive pairs: pull together
  - Negative pairs: if distance is less than $m$, push apart toward at least $m$; if already greater than $m$, no penalty
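The contrastive loss above can be written directly in PyTorch. This is a minimal illustrative sketch (the function name and margin value are my own choices, not from the original):

```python
import torch

def contrastive_loss(z1, z2, same_class, margin=1.0):
    """Contrastive loss: pull positive pairs together, push negative pairs
    past the margin. same_class is 1.0 for positive pairs, 0.0 for negative."""
    d = torch.norm(z1 - z2, p=2, dim=1)                              # pairwise Euclidean distance
    pos = same_class * d.pow(2)                                      # positive pairs: d^2
    neg = (1 - same_class) * torch.clamp(margin - d, min=0).pow(2)   # negatives: max(0, m - d)^2
    return (pos + neg).mean()

# sanity check: identical embeddings in a positive pair give zero loss
z = torch.randn(4, 16)
loss_pos = contrastive_loss(z, z, torch.ones(4))
```

Note that a negative pair already separated by more than the margin also contributes zero loss, matching the intuition above.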

Few-Shot Classification

Given support set $\mathcal{S}$ and query sample $x_q$, the prediction is:

$$\hat{y} = y_{i^*}, \quad i^* = \arg\min_{i} d(x_q, x_i)$$

selecting the class of the nearest support set sample.

Prototypical Networks

Prototypical networks improve metric learning by learning class prototypes for classification.

Class Prototypes

Given class $k$'s support set samples $S_k$, the class prototype is defined as the mean of the support sample embeddings:

$$c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\theta(x_i)$$

Intuition: The prototype is the class's "center" in embedding space, representing typical features of that class.

Distance Metric

Prototypical networks use squared Euclidean distance to measure the distance between query samples and prototypes:

$$d(f_\theta(x_q), c_k) = \|f_\theta(x_q) - c_k\|^2$$

Cosine distance can also be used:

$$d(f_\theta(x_q), c_k) = 1 - \frac{f_\theta(x_q)^\top c_k}{\|f_\theta(x_q)\| \, \|c_k\|}$$

Classification and Loss

Classification probability is computed via a softmax over negative distances:

$$p(y = k \mid x_q) = \frac{\exp(-d(f_\theta(x_q), c_k))}{\sum_{k'} \exp(-d(f_\theta(x_q), c_{k'}))}$$

The loss function is the negative log-likelihood:

$$\mathcal{L} = -\log p(y = k^* \mid x_q)$$

where $k^*$ is the true class of $x_q$.

Theory of Prototypical Networks

Prototypical networks can be viewed as implementing nearest centroid classification in embedding space. Under linear separability, Prototypical networks are equivalent to linear classifiers.

Theorem: In embedding space, if class prototypes are linearly separable, the decision boundary of Prototypical networks is linear.

Proof: Let $z = f_\theta(x_q)$ be the query embedding. The query belongs to class $k$ if and only if, for all $j \neq k$:

$$\|z - c_k\|^2 \leq \|z - c_j\|^2$$

Expanding:

$$\|z\|^2 - 2 c_k^\top z + \|c_k\|^2 \leq \|z\|^2 - 2 c_j^\top z + \|c_j\|^2$$

Simplifying:

$$2 (c_j - c_k)^\top z \leq \|c_j\|^2 - \|c_k\|^2$$

This is a linear inequality in $z$, so the decision boundary is a hyperplane.

Matching Networks

Matching networks introduce attention mechanisms and memory augmentation to further improve few-shot learning performance.

Attention Kernel

Matching networks use an attention kernel to compute the similarity between query samples and support samples:

$$a(x_q, x_i) = \frac{\exp(\cos(f(x_q), g(x_i)))}{\sum_{j=1}^{NK} \exp(\cos(f(x_q), g(x_j)))}$$

where $f$ and $g$ are encoders for the query and support sets respectively (they can be different).

Prediction

The predicted class distribution for a query sample is an attention-weighted sum of the support labels:

$$\hat{y} = \sum_{i=1}^{NK} a(x_q, x_i) \, y_i$$

where $y_i$ is the one-hot label of support sample $x_i$.

Intuition: Support samples with higher similarity to the query contribute more to prediction.
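The attention-weighted prediction can be sketched in a few lines of PyTorch. This is an illustrative sketch assuming pre-computed embeddings and a shared cosine-similarity kernel (function and variable names are my own):

```python
import torch
import torch.nn.functional as F

def matching_predict(query_emb, support_emb, support_labels, n_way):
    """Attention over support samples: softmax of cosine similarities;
    the prediction is the attention-weighted sum of one-hot support labels."""
    sims = F.cosine_similarity(query_emb.unsqueeze(1),
                               support_emb.unsqueeze(0), dim=2)   # (n_query, n_support)
    attn = F.softmax(sims, dim=1)                                 # attention kernel a(x_q, x_i)
    one_hot = F.one_hot(support_labels, n_way).float()            # (n_support, n_way)
    return attn @ one_hot                                         # (n_query, n_way) class probabilities

# sanity check: a query identical to a class-0 support sample prefers class 0
support = torch.eye(3)                  # three orthogonal support embeddings
labels = torch.tensor([0, 1, 2])
probs = matching_predict(support[:1], support, labels, n_way=3)
```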

Full Context Embeddings

Matching networks use bidirectional LSTMs to encode the support set, making each sample's embedding contain contextual information from the entire support set:

$$g(x_i) = \text{BiLSTM}(\{x_1, \ldots, x_{NK}\}, i)$$

This allows the model to consider relationships between support set samples.

Relation Networks

Relation networks don't use fixed distance metrics (like Euclidean distance), but instead learn a metric function.

Architecture

Relation networks contain two modules:

  1. Embedding module: maps samples to embedding space
  2. Relation module: learns similarity between embeddings

Given query sample $x_q$ and support sample $x_i$, compute:

$$r_{q,i} = g_\phi\big([f_\theta(x_q); f_\theta(x_i)]\big)$$

where $g_\phi$ is the learned similarity function and $[\cdot\,;\cdot]$ denotes concatenation.

Loss Function

Relation networks use MSE loss:

$$\mathcal{L} = \sum_{q} \sum_{c=1}^{N} \big(r_{q,c} - \mathbb{1}[y_q = c]\big)^2$$

where $r_{q,c}$ is the similarity between query sample $x_q$ and class $c$'s prototype.

Why Learn the Metric?

Fixed distances (like Euclidean) assume the embedding space is isotropic, but different dimensions may have different importance. Learning the metric allows adaptive adjustment of distance computation.
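A minimal relation-module sketch in PyTorch, scoring each (query, prototype) pair with a small learned MLP. The architecture details (hidden size, MLP depth) are illustrative assumptions, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

class RelationModule(nn.Module):
    """Learned similarity g_phi: scores a concatenated (query, prototype) pair in [0, 1]."""
    def __init__(self, emb_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, query_emb, proto_emb):
        # build all (query, prototype) pairs: (n_query, n_way, 2 * emb_dim)
        n_q, n_w = query_emb.size(0), proto_emb.size(0)
        pairs = torch.cat([
            query_emb.unsqueeze(1).expand(n_q, n_w, -1),
            proto_emb.unsqueeze(0).expand(n_q, n_w, -1),
        ], dim=2)
        return self.net(pairs).squeeze(-1)      # (n_query, n_way) relation scores

relation = RelationModule(emb_dim=16)
scores = relation(torch.randn(10, 16), torch.randn(5, 16))
# MSE loss against one-hot targets, matching the loss above
targets = torch.nn.functional.one_hot(torch.randint(0, 5, (10,)), 5).float()
loss = torch.nn.functional.mse_loss(scores, targets)
```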

Meta-Learning: Learning to Learn

Meta-learning's core idea is: learn across multiple tasks how to rapidly adapt to new tasks.

Formalization of Meta-Learning

Given $T$ training tasks $\{\mathcal{T}_1, \ldots, \mathcal{T}_T\}$, each task $\mathcal{T}_i$ has a training set $\mathcal{D}_i^{\text{tr}}$ and a test set $\mathcal{D}_i^{\text{test}}$.

The goal of meta-learning is to learn meta-parameters $\theta$ such that for any new task $\mathcal{T}_i$, after adapting with $\mathcal{D}_i^{\text{tr}}$, performance is good on $\mathcal{D}_i^{\text{test}}$:

$$\min_\theta \sum_{i=1}^{T} \mathcal{L}\big(\phi_i, \mathcal{D}_i^{\text{test}}\big)$$

where $\phi_i$ are the parameters adapted on task $\mathcal{T}_i$ using $\mathcal{D}_i^{\text{tr}}$:

$$\phi_i = \theta - \alpha \nabla_\theta \mathcal{L}\big(\theta, \mathcal{D}_i^{\text{tr}}\big)$$

MAML: Model-Agnostic Meta-Learning

Model-Agnostic Meta-Learning (MAML) is the most classic meta-learning algorithm, learning good initialization parameters so models can rapidly adapt to new tasks.

MAML Algorithm

Given task distribution $p(\mathcal{T})$, MAML optimizes:

$$\min_\theta \mathbb{E}_{\mathcal{T}_i \sim p(\mathcal{T})} \Big[ \mathcal{L}_{\mathcal{T}_i}\big(\theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\theta)\big) \Big]$$

That is:

  1. On task $\mathcal{T}_i$'s training set, take one (or multiple) gradient descent step(s): $\phi_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\theta)$
  2. Compute the loss on task $\mathcal{T}_i$'s test set: $\mathcal{L}_{\mathcal{T}_i}(\phi_i)$
  3. Average the test loss across all tasks and update the meta-parameters: $\theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}_{\mathcal{T}_i}(\phi_i)$

MAML Gradient Computation

MAML's key is computing second-order gradients:

$$\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\phi_i)$$

Using the chain rule:

$$\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\phi_i) = \frac{\partial \phi_i}{\partial \theta} \, \nabla_{\phi_i} \mathcal{L}_{\mathcal{T}_i}(\phi_i)$$

where:

$$\frac{\partial \phi_i}{\partial \theta} = I - \alpha \nabla_\theta^2 \mathcal{L}_{\mathcal{T}_i}(\theta)$$

Therefore:

$$\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\phi_i) = \big(I - \alpha H\big) \nabla_{\phi_i} \mathcal{L}_{\mathcal{T}_i}(\phi_i)$$

where $H = \nabla_\theta^2 \mathcal{L}_{\mathcal{T}_i}(\theta)$ is the Hessian matrix.

Computational Complexity: computing the Hessian requires $O(d^2)$ time and space, where $d$ is the parameter dimension. In practice, a first-order approximation (First-Order MAML, FOMAML) can be used:

$$\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\phi_i) \approx \nabla_{\phi_i} \mathcal{L}_{\mathcal{T}_i}(\phi_i)$$

Ignoring the Hessian term reduces the complexity to $O(d)$.
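The inner/outer loop structure can be sketched on a toy task family. In this illustrative sketch (the task family and learning rates are my own choices), `create_graph=True` keeps the inner gradient differentiable, giving the second-order MAML update; setting it to `False` would yield FOMAML:

```python
import torch

# MAML inner/outer loop sketch on a toy family of scalar regression tasks.
theta = torch.tensor([0.5], requires_grad=True)   # meta-parameter
alpha, beta = 0.1, 0.01                           # inner / outer learning rates
opt = torch.optim.SGD([theta], lr=beta)

def task_loss(w, target):
    return (w - target).pow(2).sum()

for step in range(100):
    opt.zero_grad()
    meta_loss = 0.0
    for target in [torch.tensor([1.0]), torch.tensor([-1.0])]:   # two toy tasks
        # inner step: differentiable gradient (create_graph=True -> second order)
        inner_grad, = torch.autograd.grad(task_loss(theta, target), theta,
                                          create_graph=True)
        phi = theta - alpha * inner_grad          # adapted parameters phi_i
        meta_loss = meta_loss + task_loss(phi, target)
    meta_loss.backward()                          # backprop through the inner step
    opt.step()
```

For this symmetric pair of tasks the meta-optimum is near zero, midway between the two task optima, so that one inner step moves toward either target.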

MAML Intuition

MAML learns a $\theta$ located in a "flat" region of the loss surface, such that gradient descent in any direction (for any task) can rapidly reduce the loss.

Analogy: $\theta$ is a "universal starting point" from which only a few steps are needed to reach the optimal solution for any task.

Reptile: First-Order Meta-Learning

Reptile is a simplified version of MAML that uses only first-order gradients, making computation more efficient.

Reptile Algorithm

  1. Sample a task $\mathcal{T}_i$
  2. Run $k$ steps of SGD on $\mathcal{T}_i$, obtaining adapted parameters $\phi_i$
  3. Move the meta-parameters toward the adapted parameters: $\theta \leftarrow \theta + \epsilon (\phi_i - \theta)$

Intuition: Reptile moves the meta-parameters toward each task's adapted parameters. After many iterations, $\theta$ settles at the "center" of all task-specific parameters.
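One Reptile meta-step can be sketched as follows. This is an illustrative sketch (function names, toy model, and hyperparameters are my own assumptions); note that only plain first-order SGD is used in the inner loop:

```python
import copy
import torch

def reptile_step(model, task_batch, loss_fn, inner_lr=0.01, inner_steps=5, epsilon=0.1):
    """One Reptile meta-update: adapt a copy of the model on one task with
    plain SGD, then move the meta-parameters a fraction epsilon toward it."""
    adapted = copy.deepcopy(model)
    opt = torch.optim.SGD(adapted.parameters(), lr=inner_lr)
    x, y = task_batch
    for _ in range(inner_steps):                  # inner loop: first-order SGD only
        opt.zero_grad()
        loss_fn(adapted(x), y).backward()
        opt.step()
    with torch.no_grad():                         # outer update: theta += eps * (phi - theta)
        for p, p_adapted in zip(model.parameters(), adapted.parameters()):
            p.add_(epsilon * (p_adapted - p))

model = torch.nn.Linear(2, 1)
x, y = torch.randn(16, 2), torch.randn(16, 1)     # one toy "task"
before = [p.clone() for p in model.parameters()]
reptile_step(model, (x, y), torch.nn.functional.mse_loss)
```

No gradient ever flows through the inner loop, which is why Reptile avoids the Hessian entirely.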

Reptile vs MAML

Method Gradient Order Computational Complexity Performance
MAML Second-order High (requires Hessian) Optimal
FOMAML First-order (approximation) Medium Close to MAML
Reptile First-order Low Slightly below MAML

Reptile performs similarly to FOMAML in practice but with simpler implementation.

Theory of Meta-Learning

Meta-learning can be understood from a Bayesian perspective. Let the task parameters $\phi_i$ follow a prior distribution $p(\phi \mid \theta)$; then MAML is equivalent to maximizing the posterior:

$$\max_\theta \prod_i p\big(\mathcal{D}_i^{\text{test}} \mid \theta\big) = \max_\theta \prod_i \int p\big(\mathcal{D}_i^{\text{test}} \mid \phi_i\big) \, p(\phi_i \mid \theta) \, d\phi_i$$

where $\theta$ is the prior parameter and $p(\phi \mid \theta)$ is the prior distribution. Meta-learning learns a good prior.

Episode Training: Simulating Few-Shot Scenarios

Few-shot learning training adopts episodic training, where each episode simulates a few-shot task.

Episode Sampling

Each episode is constructed as follows:

  1. Randomly sample $N$ classes from the base classes
  2. Randomly sample $K$ samples per class as the support set
  3. Randomly sample $Q$ samples per class as the query set

Formally, an episode is $(\mathcal{S}, \mathcal{Q})$, where:

$$\begin{aligned} \mathcal{S} &= \{(x_i^{(c,k)}, c) : c \in \{1, \ldots, N\},\ k \in \{1, \ldots, K\}\} \\ \mathcal{Q} &= \{(x_j^{(c,q)}, c) : c \in \{1, \ldots, N\},\ q \in \{1, \ldots, Q\}\} \end{aligned}$$

Episode Training Workflow

for epoch in range(num_epochs):
    for episode in range(episodes_per_epoch):
        # Sample episode
        classes = sample(base_classes, N)
        support = sample_from_classes(classes, K)
        query = sample_from_classes(classes, Q)

        # Forward pass
        prototypes = compute_prototypes(support)
        logits = compute_distances(query, prototypes)

        # Compute loss
        loss = cross_entropy(logits, query_labels)

        # Backward pass
        loss.backward()
        optimizer.step()

Intuition Behind Episode Training

Episode training exposes the model to few-shot scenarios during training, forcing it to learn how to generalize from few samples. This is a form of curriculum learning: training difficulty matches testing difficulty.

Complete Implementation: Prototypical Networks

Below is a complete Prototypical network implementation including episode sampling, distance computation, and support/query set partitioning.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm


class ConvBlock(nn.Module):
    """Convolutional block: conv -> batch norm -> ReLU -> max-pool"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pool(x)
        return x


class ProtoNetEncoder(nn.Module):
    """Prototypical network encoder"""

    def __init__(self, input_channels=3, hidden_dim=64):
        super().__init__()
        self.conv1 = ConvBlock(input_channels, hidden_dim)
        self.conv2 = ConvBlock(hidden_dim, hidden_dim)
        self.conv3 = ConvBlock(hidden_dim, hidden_dim)
        self.conv4 = ConvBlock(hidden_dim, hidden_dim)

    def forward(self, x):
        x = self.conv1(x)           # 84x84 -> 42x42
        x = self.conv2(x)           # 42x42 -> 21x21
        x = self.conv3(x)           # 21x21 -> 10x10
        x = self.conv4(x)           # 10x10 -> 5x5
        x = x.view(x.size(0), -1)   # Flatten
        return x


class PrototypicalNetwork(nn.Module):
    """Prototypical network"""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def compute_prototypes(self, support_embeddings, support_labels, n_way):
        """
        Compute class prototypes

        Args:
            support_embeddings: (n_way * n_support, embedding_dim)
            support_labels: (n_way * n_support,)
            n_way: number of classes

        Returns:
            prototypes: (n_way, embedding_dim)
        """
        prototypes = []
        for c in range(n_way):
            # Get all samples of class c
            class_mask = (support_labels == c)
            class_embeddings = support_embeddings[class_mask]
            # Compute mean as prototype
            prototype = class_embeddings.mean(dim=0)
            prototypes.append(prototype)

        prototypes = torch.stack(prototypes)
        return prototypes

    def compute_distances(self, query_embeddings, prototypes):
        """
        Compute Euclidean distance between query samples and prototypes

        Args:
            query_embeddings: (n_query, embedding_dim)
            prototypes: (n_way, embedding_dim)

        Returns:
            distances: (n_query, n_way)
        """
        # torch.cdist computes all pairwise Euclidean distances
        distances = torch.cdist(query_embeddings, prototypes, p=2)
        return distances

    def forward(self, support_images, support_labels, query_images, n_way, n_support):
        """
        Forward pass

        Args:
            support_images: (n_way * n_support, C, H, W)
            support_labels: (n_way * n_support,)
            query_images: (n_query, C, H, W)
            n_way: number of classes
            n_support: number of support samples per class

        Returns:
            logits: (n_query, n_way)
        """
        # Encode
        support_embeddings = self.encoder(support_images)
        query_embeddings = self.encoder(query_images)

        # Compute prototypes
        prototypes = self.compute_prototypes(support_embeddings, support_labels, n_way)

        # Compute distances
        distances = self.compute_distances(query_embeddings, prototypes)

        # Negative distance as logits (smaller distance = higher logit)
        logits = -distances

        return logits


class FewShotDataset(Dataset):
    """Few-shot dataset"""

    def __init__(self, data, labels):
        """
        Args:
            data: (N, C, H, W) all images
            labels: (N,) all labels
        """
        self.data = data
        self.labels = labels

        # Organize data by class
        self.classes = np.unique(labels)
        self.class_to_indices = {}
        for c in self.classes:
            self.class_to_indices[c] = np.where(labels == c)[0]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class EpisodeSampler:
    """Episode sampler"""

    def __init__(self, dataset, n_way, n_support, n_query, n_episodes):
        """
        Args:
            dataset: FewShotDataset
            n_way: number of classes per episode
            n_support: number of support samples per class
            n_query: number of query samples per class
            n_episodes: total number of episodes
        """
        self.dataset = dataset
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.n_episodes = n_episodes

    def __iter__(self):
        for _ in range(self.n_episodes):
            yield self.sample_episode()

    def sample_episode(self):
        """Sample one episode"""
        # Randomly select n_way classes
        selected_classes = np.random.choice(
            self.dataset.classes,
            size=self.n_way,
            replace=False
        )

        support_images = []
        support_labels = []
        query_images = []
        query_labels = []

        for i, c in enumerate(selected_classes):
            # Get all sample indices for class c
            class_indices = self.dataset.class_to_indices[c]

            # Randomly select n_support + n_query samples
            selected_indices = np.random.choice(
                class_indices,
                size=self.n_support + self.n_query,
                replace=False
            )

            # Split into support and query sets
            support_indices = selected_indices[:self.n_support]
            query_indices = selected_indices[self.n_support:]

            # Add to support set
            for idx in support_indices:
                support_images.append(self.dataset.data[idx])
                support_labels.append(i)  # Use relative labels 0, 1, ..., n_way-1

            # Add to query set
            for idx in query_indices:
                query_images.append(self.dataset.data[idx])
                query_labels.append(i)

        # Convert to tensors
        support_images = torch.stack([torch.FloatTensor(img) for img in support_images])
        support_labels = torch.LongTensor(support_labels)
        query_images = torch.stack([torch.FloatTensor(img) for img in query_images])
        query_labels = torch.LongTensor(query_labels)

        return support_images, support_labels, query_images, query_labels


class ProtoNetTrainer:
    """Prototypical network trainer"""

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        n_way=5,
        n_support=5,
        n_query=15,
        n_episodes=100,
        learning_rate=1e-3,
        device='cuda'
    ):
        self.model = model.to(device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.n_episodes = n_episodes
        self.device = device

        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()

    def train_epoch(self):
        """Train one epoch"""
        self.model.train()

        sampler = EpisodeSampler(
            self.train_dataset,
            self.n_way,
            self.n_support,
            self.n_query,
            self.n_episodes
        )

        total_loss = 0
        total_acc = 0

        progress_bar = tqdm(sampler, desc='Training')

        for support_images, support_labels, query_images, query_labels in progress_bar:
            support_images = support_images.to(self.device)
            support_labels = support_labels.to(self.device)
            query_images = query_images.to(self.device)
            query_labels = query_labels.to(self.device)

            # Forward pass
            logits = self.model(
                support_images,
                support_labels,
                query_images,
                self.n_way,
                self.n_support
            )

            # Compute loss
            loss = self.criterion(logits, query_labels)

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Compute accuracy
            preds = torch.argmax(logits, dim=1)
            acc = (preds == query_labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc

            progress_bar.set_postfix({
                'loss': loss.item(),
                'acc': acc
            })

        avg_loss = total_loss / self.n_episodes
        avg_acc = total_acc / self.n_episodes

        return avg_loss, avg_acc

    def evaluate(self, n_eval_episodes=100):
        """Evaluate model"""
        self.model.eval()

        sampler = EpisodeSampler(
            self.val_dataset,
            self.n_way,
            self.n_support,
            self.n_query,
            n_eval_episodes
        )

        total_loss = 0
        total_acc = 0

        with torch.no_grad():
            for support_images, support_labels, query_images, query_labels in tqdm(sampler, desc='Evaluating'):
                support_images = support_images.to(self.device)
                support_labels = support_labels.to(self.device)
                query_images = query_images.to(self.device)
                query_labels = query_labels.to(self.device)

                logits = self.model(
                    support_images,
                    support_labels,
                    query_images,
                    self.n_way,
                    self.n_support
                )

                loss = self.criterion(logits, query_labels)
                preds = torch.argmax(logits, dim=1)
                acc = (preds == query_labels).float().mean().item()

                total_loss += loss.item()
                total_acc += acc

        avg_loss = total_loss / n_eval_episodes
        avg_acc = total_acc / n_eval_episodes

        return avg_loss, avg_acc

    def train(self, num_epochs=100):
        """Complete training workflow"""
        best_val_acc = 0.0

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            # Train
            train_loss, train_acc = self.train_epoch()
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

            # Evaluate
            val_loss, val_acc = self.evaluate()
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(self.model.state_dict(), 'best_protonet.pt')
                print(f"Saved best model with accuracy {best_val_acc:.4f}")


# Usage example
def main():
    # Simulate data: Omniglot or miniImageNet
    # Using random data for demonstration
    num_classes = 64          # Number of base classes
    samples_per_class = 600
    image_size = 84

    # Generate simulated data
    all_data = []
    all_labels = []

    for c in range(num_classes):
        class_data = torch.randn(samples_per_class, 3, image_size, image_size)
        class_labels = torch.full((samples_per_class,), c)
        all_data.append(class_data)
        all_labels.append(class_labels)

    all_data = torch.cat(all_data, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # Split train and validation sets by class
    train_classes = num_classes * 4 // 5
    train_mask = all_labels < train_classes
    val_mask = all_labels >= train_classes

    train_dataset = FewShotDataset(
        all_data[train_mask].numpy(),
        all_labels[train_mask].numpy()
    )

    val_dataset = FewShotDataset(
        all_data[val_mask].numpy(),
        all_labels[val_mask].numpy()
    )

    # Model
    encoder = ProtoNetEncoder(input_channels=3, hidden_dim=64)
    model = PrototypicalNetwork(encoder)

    # Trainer
    trainer = ProtoNetTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        n_way=5,
        n_support=5,
        n_query=15,
        n_episodes=100,
        learning_rate=1e-3
    )

    # Train
    trainer.train(num_epochs=50)


if __name__ == '__main__':
    main()

Code Breakdown

Episode Sampling

EpisodeSampler implements the core sampling logic for few-shot learning:

def sample_episode(self):
    # Randomly select n_way classes
    selected_classes = np.random.choice(classes, n_way, replace=False)

    for c in selected_classes:
        # Sample n_support + n_query samples from class c
        selected_indices = np.random.choice(class_indices, n_support + n_query, replace=False)

        # Split into support and query sets
        support_indices = selected_indices[:n_support]
        query_indices = selected_indices[n_support:]

Prototype Computation

compute_prototypes computes the prototype (mean) for each class:

for c in range(n_way):
    class_mask = (support_labels == c)
    class_embeddings = support_embeddings[class_mask]
    prototype = class_embeddings.mean(dim=0)
    prototypes.append(prototype)

Distance Computation

Using torch.cdist for efficient Euclidean distance computation:

distances = torch.cdist(query_embeddings, prototypes, p=2)
logits = -distances # Negative distance as logits

Advanced Extensions

Transductive Prototypical Networks

Standard Prototypical networks use only support set to compute prototypes. Transductive Prototypical Networks leverage query set information through semi-supervised learning.

Soft k-Means

Iteratively refine prototypes using query predictions:

  1. Initialize prototypes from support set
  2. Compute soft query assignments: $P(y = c \mid x_q)$
  3. Update each prototype as the weighted mean of support embeddings (hard labels) and query embeddings (soft assignments)
  4. Repeat steps 2-3 until convergence

This is equivalent to applying soft k-means clustering in embedding space.
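The refinement loop can be sketched as follows. This is an illustrative sketch of the soft k-means idea (function name and iteration count are my own assumptions):

```python
import torch

def soft_kmeans_refine(prototypes, support_emb, support_labels, query_emb, n_iters=5):
    """Transductive refinement: each prototype becomes the weighted mean of
    support embeddings (hard, one-hot labels) and query embeddings (soft
    assignments from the current prototypes)."""
    n_way = prototypes.size(0)
    one_hot = torch.nn.functional.one_hot(support_labels, n_way).float()  # (n_support, n_way)
    for _ in range(n_iters):
        dists = torch.cdist(query_emb, prototypes)       # (n_query, n_way)
        soft = torch.softmax(-dists, dim=1)              # soft assignments P(y=c|x_q)
        weights = torch.cat([one_hot, soft], dim=0)      # (n_support + n_query, n_way)
        embs = torch.cat([support_emb, query_emb], dim=0)
        # weighted mean per class: (n_way, emb_dim)
        prototypes = (weights.t() @ embs) / weights.sum(dim=0, keepdim=True).t()
    return prototypes

protos = torch.randn(5, 8)
refined = soft_kmeans_refine(protos, torch.randn(25, 8),
                             torch.randint(0, 5, (25,)), torch.randn(30, 8))
```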

Task-Dependent Adaptive Metric (TADAM)

TADAM conditions the metric on task context, making it adaptive to different task characteristics.

Task Embedding

Compute a task representation from the support set:

$$\tau = \text{pool}\big(\{f_\theta(x_i) : x_i \in \mathcal{S}\}\big)$$

where $\text{pool}$ can be mean pooling, attention, or a set encoder.

Task-Conditioned Feature Extraction

Modulate the feature extractor using the task embedding via Feature-wise Linear Modulation (FiLM):

$$h' = \gamma(\tau) \odot h + \beta(\tau)$$

where $\gamma$ and $\beta$ are predicted from the task embedding.

This allows the network to adapt its features to task-specific characteristics.
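A FiLM layer is only a few lines. This sketch (layer and variable names are my own) predicts the per-feature scale and shift from a task embedding and applies them to a feature tensor:

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: h' = gamma(tau) * h + beta(tau),
    with gamma and beta predicted from the task embedding tau."""
    def __init__(self, task_dim, feature_dim):
        super().__init__()
        self.to_gamma = nn.Linear(task_dim, feature_dim)
        self.to_beta = nn.Linear(task_dim, feature_dim)

    def forward(self, h, task_emb):
        gamma = self.to_gamma(task_emb)   # (1, feature_dim), broadcast over the batch
        beta = self.to_beta(task_emb)
        return gamma * h + beta

film = FiLM(task_dim=32, feature_dim=64)
task_emb = torch.randn(1, 32)             # e.g. mean-pooled support embeddings
h = torch.randn(10, 64)                   # features for 10 samples
out = film(h, task_emb)
```

For convolutional features the same idea applies channel-wise, with `gamma` and `beta` broadcast over the spatial dimensions.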

Meta-Learning with Latent Embedding Optimization (LEO)

LEO learns in a lower-dimensional latent space for better generalization.

Architecture

  1. Encoder: map data to a latent code: $z = \text{Enc}(x)$

  2. Relation Network: model dependencies between support samples

  3. Decoder: generate task-specific parameters: $w = \text{Dec}(z)$

  4. Classifier: $f_w(x)$

Training

In latent space, perform gradient-based adaptation on support set, then evaluate on query set. This reduces overfitting by constraining adaptation to a low-dimensional space.

Comprehensive Q&A

Q1: How does Few-Shot Learning differ from Transfer Learning?

Connection: Both leverage existing knowledge for new tasks.

Differences:

Dimension Transfer Learning Few-Shot Learning
Data Volume Target task has substantial labeled data Target task has minimal labeled data (1-10 samples)
Adaptation Method Fine-tune pre-trained model Rapid adaptation via metrics or meta-learning
Training Paradigm Standard supervised learning Episodic training

Few-shot learning can be viewed as an extreme case of transfer learning where target task data is extremely scarce.

Q2: Why do Prototypical Networks use mean as prototype? Is there theoretical support?

Theoretical Support: Under Gaussian distribution assumptions, class prototypes are optimal Bayesian classifiers.

Proof: Assume class $k$'s samples follow a Gaussian distribution $x \mid y = k \sim \mathcal{N}(\mu_k, \Sigma)$ with shared covariance and equal class priors. Then the posterior probability is:

$$p(y = k \mid x) \propto \exp\Big(-\tfrac{1}{2} (x - \mu_k)^\top \Sigma^{-1} (x - \mu_k)\Big)$$

Taking the logarithm:

$$\log p(y = k \mid x) = -\tfrac{1}{2} (x - \mu_k)^\top \Sigma^{-1} (x - \mu_k) + \text{const}$$

When $\Sigma = \sigma^2 I$ (isotropic), this is equivalent to Euclidean distance:

$$\log p(y = k \mid x) = -\frac{1}{2\sigma^2} \|x - \mu_k\|^2 + \text{const}$$

Therefore, using the mean as prototype (the empirical estimate of $\mu_k$) and classifying by Euclidean distance is Bayes-optimal (under the Gaussian assumption).

Q3: Why does MAML require second-order gradients? Can it be avoided?

MAML requires second-order gradients because it differentiates the adapted parameters $\phi_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(\theta)$ with respect to the meta-parameters $\theta$:

$$\frac{\partial \phi_i}{\partial \theta} = I - \alpha \nabla_\theta^2 \mathcal{L}_{\mathcal{T}_i}(\theta)$$

This requires computing $\nabla_\theta^2 \mathcal{L}_{\mathcal{T}_i}(\theta)$, which is a second derivative.

Avoidance Methods:

  1. FOMAML: Ignore second-order terms, use only first-order gradients
  2. Reptile: Directly move toward adapted parameters, no second-order gradients needed

Experiments show FOMAML and Reptile perform similarly to MAML but with much higher computational efficiency.

Q4: What's the fundamental difference between episode training and standard training?

Standard Training: Each batch contains samples from multiple classes; model learns discriminative boundaries for all classes.

Episode Training: Each episode contains only $N$ classes; the model learns how to learn from few samples of those $N$ classes.

Fundamental Difference:

  - Standard training learns task-specific knowledge (which features distinguish which classes)
  - Episode training learns meta-knowledge (how to rapidly learn new tasks)

Analogy:

  - Standard training is like learning specific subjects (math, physics)
  - Episode training is like learning how to learn (methodologies)

Q5: Why does Few-Shot Learning require many base classes?

Although target tasks (novel classes) have few samples, learning "how to learn" requires training on many tasks.

Data Requirements:

  - Number of base classes: typically dozens to hundreds
  - Samples per base class: typically hundreds

Intuition: Just as humans can learn new concepts from few examples because of accumulated life experience, few-shot learning models need to learn this capability on many base classes.

Experimental Evidence:

  - Omniglot: 1200+ base classes
  - miniImageNet: 64 base classes
  - tieredImageNet: 351 base classes

More base classes lead to better few-shot learning performance.

Q6: Can Prototypical Networks be used for regression tasks?

Yes, but modifications are needed. In classification, prototypes are discrete (one per class); in regression, a continuous prototype space is needed.

Method 1: Kernel Regression

View prototypes as kernel centers and predict via a kernel-weighted average:

$$\hat{y}_q = \frac{\sum_i k(x_q, x_i) \, y_i}{\sum_i k(x_q, x_i)}$$
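This kernel-weighted average (the Nadaraya-Watson estimator) can be sketched as follows; the Gaussian kernel and bandwidth are illustrative assumptions:

```python
import torch

def kernel_regression(x_query, x_support, y_support, bandwidth=1.0):
    """Nadaraya-Watson estimate: the prediction is the kernel-weighted
    average of the support targets, with a Gaussian kernel."""
    d2 = torch.cdist(x_query, x_support).pow(2)           # squared distances (n_query, n_support)
    w = torch.softmax(-d2 / (2 * bandwidth ** 2), dim=1)  # normalized Gaussian kernel weights
    return w @ y_support                                  # (n_query, output_dim)

# toy check: support targets 0, 1, 2 at x = 0, 1, 2; query at x = 1
x_s = torch.tensor([[0.0], [1.0], [2.0]])
y_s = torch.tensor([[0.0], [1.0], [2.0]])
pred = kernel_regression(torch.tensor([[1.0]]), x_s, y_s, bandwidth=0.1)
```

With a small bandwidth the prediction concentrates on the nearest support points; here the symmetric neighbors average out and the prediction is 1.0.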

Method 2: Conditional Neural Processes (CNP)

Learn a distribution over functions; given the support set, predict a distribution at query points:

$$p\big(y_q \mid x_q, \mathcal{S}\big)$$

Q7: How to choose a Few-Shot Learning method?

Decision tree:

1. Data characteristics?
├─ Image data → Prototypical Networks, Matching Networks
├─ Sequential data → MAML + LSTM
└─ Graph data → Graph Neural Network + Meta-Learning

2. Computational resources?
├─ Abundant resources → MAML (second-order gradients)
└─ Limited resources → Prototypical Networks, Reptile

3. Task diversity?
├─ Similar tasks → Metric learning (Prototypical)
└─ Diverse tasks → Meta-learning (MAML)

4. Need interpretability?
├─ Yes → Prototypical Networks (visualize prototypes)
└─ No → Relation Networks, MAML

Q8: What are the challenges of Few-Shot Learning in real applications?

  1. Domain Shift: Base classes and novel classes have different distributions
    • Solution: Domain adaptation + Few-Shot Learning (Cross-Domain Few-Shot Learning)
  2. Class Imbalance: Novel classes may have different sample counts
    • Solution: Weighted loss, resampling
  3. Label Noise: Annotation errors in few samples have large impact
    • Solution: Robust loss functions, denoising methods
  4. Computational Efficiency: Episode training is slower than standard training
    • Solution: Pre-training + limited episode fine-tuning
  5. Generalization: Model may overfit base classes
    • Solution: Increase base class diversity, regularization

Q9: How do Prototypical Networks differ from k-NN?

Prototypical networks can be viewed as k-NN with learned embedding space.

Method Distance Metric Embedding Space Prototype
k-NN Fixed (Euclidean, cosine) Original feature space Each sample
Prototypical Learned Learned embedding space Class mean

Key Differences:

  1. Embedding Learning: Prototypical networks learn an embedding function $f_\theta$, making the embedding space more suitable for few-shot learning
  2. Prototype Aggregation: class means are used rather than individual samples, which is more robust to noise

Experiments: In the same embedding space, Prototypical networks slightly outperform k-NN, but difference is small. Main advantage comes from embedding learning.

Q10: Why is MAML's initialization important?

MAML learns initialization located in a flat region of the loss surface, enabling:

  1. Rapid Adaptation: Gradient descent in any direction can quickly reduce loss
  2. Strong Generalization: Flat regions correspond to better generalization (Sharp Minima vs Flat Minima)

Mathematically, MAML can be related to a second-order Taylor expansion of the adapted loss:

$$\mathcal{L}(\theta - \alpha \nabla \mathcal{L}) \approx \mathcal{L}(\theta) - \alpha \|\nabla \mathcal{L}\|^2 + \frac{\alpha^2}{2} \nabla \mathcal{L}^\top H \, \nabla \mathcal{L}$$

MAML favors solutions where the eigenvalues of the Hessian $H$ are small (flat regions), so moving in any direction increases the loss only slowly.

Q11: Can Few-Shot Learning be used in Reinforcement Learning?

Yes! Few-Shot Reinforcement Learning is an active research area.

Challenges:

  1. Even lower sample efficiency (interaction with the environment is required)
  2. Sparse rewards
  3. Exploration-exploitation tradeoff

Methods:

  1. MAML for RL: meta-learn policies across multiple tasks
  2. Meta-RL with Context: learn task representations and condition the policy on them
  3. Model-Based Meta-RL: learn dynamics models and plan

Applications:

  - Robots rapidly adapting to new tasks
  - Game AI quickly learning new games
  - Recommendation systems adapting to new users

Q12: How to evaluate Few-Shot Learning models?

Standard evaluation protocol:

  1. Data Split:

    • Base classes: training
    • Val classes: hyperparameter validation
    • Novel classes: final testing
  2. Evaluation Metrics:

    • Accuracy (primary)
    • 95% confidence interval (report uncertainty)
    • Per-class accuracy (check class imbalance)
  3. Evaluation Steps:

    for episode in test_episodes:
        sample N-way K-shot task from novel classes
        compute accuracy on query set

    report: mean ± 95% confidence interval

  4. Standard Benchmarks:

    • Omniglot: 20-way 1-shot, 20-way 5-shot
    • miniImageNet: 5-way 1-shot, 5-way 5-shot
    • tieredImageNet: 5-way 1-shot, 5-way 5-shot

Note: Must report confidence intervals because few-shot learning has high variance.
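Computing the reported mean and 95% confidence interval over test episodes is straightforward; this sketch uses the normal-approximation interval (z = 1.96), with simulated per-episode accuracies standing in for real results:

```python
import numpy as np

def mean_confidence_interval(accuracies, z=1.96):
    """Mean accuracy with a 95% confidence half-width (z = 1.96) over episodes."""
    accs = np.asarray(accuracies)
    mean = accs.mean()
    half_width = z * accs.std(ddof=1) / np.sqrt(len(accs))
    return mean, half_width

# simulated per-episode accuracies from 600 test episodes (illustrative only)
rng = np.random.default_rng(0)
accs = rng.normal(0.65, 0.1, size=600).clip(0, 1)
mean, ci = mean_confidence_interval(accs)
print(f"accuracy: {mean:.4f} ± {ci:.4f}")
```

With 600 episodes the half-width here is under one percentage point, which is why benchmarks typically evaluate on hundreds of episodes.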

Q13: How does gradient-based meta-learning relate to pre-training?

Both learn transferable representations, but with different mechanisms:

Pre-training:

  - Learns a fixed feature extractor on a large dataset
  - Transfers via fine-tuning all or part of the parameters
  - Adaptation: standard gradient descent

MAML:

  - Learns an initialization optimized for rapid adaptation
  - Transfers via a few gradient steps from that initialization
  - Adaptation: few-step gradient descent from the learned initialization

Connection: Both can be viewed as learning good priors in Bayesian framework. Pre-training learns features (prior on function space), MAML learns initialization (prior on parameter space).

Q14: What is the relationship between Few-Shot Learning and Zero-Shot Learning?

Few-Shot Learning: Learn from K labeled examples per class (typically K = 1-10).

Zero-Shot Learning: Learn from zero labeled examples, relying instead on semantic information (attributes, descriptions).

Unified View - Meta-Learning Spectrum:

  • Zero-shot: No labeled examples, only semantic information
  • One-shot: 1 labeled example per class
  • Few-shot: 2-10 labeled examples per class
  • Standard learning: Many labeled examples

Zero-shot can be viewed as extreme few-shot where "support set" is semantic descriptions rather than labeled examples.
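This view can be made concrete with a minimal sketch: if each class is described by an attribute vector, zero-shot classification is nearest-neighbor search in attribute space. The attribute table and class names here are hypothetical illustrations:

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Hypothetical semantic descriptions: (has_stripes, has_wings, is_large)
class_attributes = {
    "zebra":   [1.0, 0.0, 1.0],
    "sparrow": [0.0, 1.0, 0.0],
}

def zero_shot_classify(predicted_attributes):
    # The "support set" is the attribute table, not labeled images:
    # pick the class whose semantic vector is closest.
    return max(class_attributes,
               key=lambda c: cosine(predicted_attributes, class_attributes[c]))

print(zero_shot_classify([0.9, 0.1, 0.8]))  # zebra
```

Replacing the attribute table with embedded support examples recovers exactly the nearest-prototype rule used by Prototypical networks, which is the sense in which the two settings sit on one spectrum.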

Q15: How to handle distribution shift between base and novel classes?

Problem: Base classes (e.g., cats, dogs) and novel classes (e.g., birds) may have very different distributions, hurting transfer.

Solutions:

  1. Domain-Adversarial Meta-Learning
    • Add domain discriminator to learn domain-invariant features
    • Minimize domain classification loss while maximizing task performance
  2. Feature-wise Transformation
    • Learn affine transformations to align base and novel class features
    • Use task embedding to predict transformation parameters
  3. Self-Supervised Pre-training
    • Pre-train on large unlabeled dataset covering both base and novel class distributions
    • Helps learn more general features
  4. Data Augmentation
    • Augment base classes to simulate novel class characteristics
    • Mixup, CutMix, domain randomization
  5. Cross-Domain Few-Shot Learning Benchmarks
    • Train on miniImageNet, test on CUB birds
    • Evaluate robustness to domain shift
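Of the solutions above, mixup (option 4) is compact enough to sketch. It augments base-class data with convex combinations of example pairs and their one-hot labels, with the mixing weight drawn from a Beta distribution:

```python
import random

def mixup(x1, y1, x2, y2, alpha=0.2):
    """Mixup augmentation: convex combination of two examples and
    their one-hot labels, with lam ~ Beta(alpha, alpha)."""
    lam = random.betavariate(alpha, alpha)
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y

# Mix a 2-feature example of class 0 with one of class 1
random.seed(1)
x, y = mixup([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0])
```

The soft labels discourage the network from committing to sharp class boundaries learned from base classes alone, which is the property that helps under distribution shift.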
References

  1. Siamese Neural Networks for One-shot Image Recognition
    Koch et al., ICML Deep Learning Workshop 2015
    https://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf

  2. Prototypical Networks for Few-shot Learning
    Snell et al., NeurIPS 2017
    https://arxiv.org/abs/1703.05175

  3. Matching Networks for One Shot Learning
    Vinyals et al., NeurIPS 2016
    https://arxiv.org/abs/1606.04080

  4. Learning to Compare: Relation Network for Few-Shot Learning
    Sung et al., CVPR 2018
    https://arxiv.org/abs/1711.06025

  5. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (MAML)
    Finn et al., ICML 2017
    https://arxiv.org/abs/1703.03400

  6. On First-Order Meta-Learning Algorithms (Reptile)
    Nichol et al., arXiv 2018
    https://arxiv.org/abs/1803.02999

  7. A Closer Look at Few-shot Classification
    Chen et al., ICLR 2019
    https://arxiv.org/abs/1904.04232

  8. Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples
    Triantafillou et al., ICLR 2020
    https://arxiv.org/abs/1903.03096

  9. Learning to Learn with Conditional Class Dependencies
    Bertinetto et al., ICLR 2019
    https://arxiv.org/abs/1806.03961

  10. TADAM: Task dependent adaptive metric for improved few-shot learning
    Oreshkin et al., NeurIPS 2018
    https://arxiv.org/abs/1805.10123

  11. Meta-Learning with Differentiable Convex Optimization
    Lee et al., CVPR 2019
    https://arxiv.org/abs/1904.03758

  12. Generalizing from a Few Examples: A Survey on Few-Shot Learning
    Wang et al., ACM Computing Surveys 2020
    https://arxiv.org/abs/1904.05046

  13. Latent Embedding Optimization for Few-Shot Learning (LEO)
    Rusu et al., ICLR 2019
    https://arxiv.org/abs/1807.05960

  14. Transductive Propagation Network for Few-shot Learning
    Liu et al., arXiv 2019
    https://arxiv.org/abs/1805.10002

Summary

Few-shot learning addresses one of deep learning's biggest bottlenecks: data scarcity. This article derived the mathematical foundations of metric learning (Siamese, Prototypical, Matching, Relation Networks) and meta-learning (MAML, Reptile) from first principles, providing detailed analysis of their architectures, loss functions, and optimization methods.

We saw that few-shot learning's core is leveraging prior knowledge: metric learning makes metrics transferable by learning embedding spaces, while meta-learning makes adaptation rapid by learning initialization or optimizers. Episode training is crucial - it exposes models to few-shot scenarios during training, teaching them "how to learn".

The complete Prototypical network implementation demonstrates core techniques including episode sampling, prototype computation, and distance metrics. Next chapter we'll explore knowledge distillation, studying how to transfer knowledge from large models to small models.

  • Post title: Transfer Learning (4): Few-Shot Learning
  • Post author: Chen Kai
  • Create time: 2024-11-21 15:45:00
  • Post link: https://www.chenk.top/transfer-learning-4-few-shot-learning/
  • Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.