Transfer Learning (1): Fundamentals and Core Concepts
Chen Kai

Why can a model trained on ImageNet quickly achieve usable performance on medical imaging? Why can BERT learn text classification from just hundreds of samples after pretraining? The essence of these phenomena is transfer learning — enabling models to transfer existing knowledge to new problems rather than starting from scratch every time.

In the deep learning era, transfer learning has become standard practice rather than an option. This article systematically covers the mathematical formalization, core concepts, taxonomy, feasibility analysis, and negative transfer issues, along with a complete 200+ line implementation of feature transfer with MMD domain adaptation.

Why We Need Transfer Learning

The Dilemma of Training From Scratch

Suppose we want to train a medical image diagnosis model. The traditional supervised learning paradigm requires:

  1. Massive labeled data: Deep neural networks typically need tens of thousands to millions of labeled samples to achieve good generalization
  2. Enormous computational resources: Training large models from random initialization requires hundreds to thousands of GPU hours
  3. Difficulty reusing domain knowledge: Even similar tasks (e.g., X-ray classification vs. CT image classification) require independent model training

However, real-world scenarios often face:

  • Data scarcity: Some rare diseases have only a few hundred cases
  • Expensive annotation: Medical image annotation requires professional doctors at extremely high costs
  • Time urgency: Rapid model deployment needed during new disease outbreaks

These contradictions gave birth to transfer learning: Can we leverage models trained on large-scale data to quickly adapt to new tasks with scarce data?

The Intuition Behind Transfer Learning

Human learning naturally possesses transfer capability:

  • People who can ride bicycles learn motorcycles faster
  • Programmers who know Python don't start from zero when learning Java
  • People who have seen cats can recognize "this is a feline" when first encountering a lion

This ability stems from shared underlying knowledge structures. Similarly, low-level features in deep neural networks (edges, textures) are highly reusable across different visual tasks, and high-level features (semantic concepts) also exhibit certain similarities. Transfer learning exploits this similarity.

Core Idea of Transfer Learning

Given a source domain $\mathcal{D}_S$ and a target domain $\mathcal{D}_T$, where the source domain has abundant labeled data and the target domain has scarce data, the goal of transfer learning is:

Leverage knowledge from $\mathcal{D}_S$ and its task $\mathcal{T}_S$ to improve model performance on $\mathcal{D}_T$, especially when target domain data is limited.

Key assumption: There exists some correlation between source and target domains (though not required to be identical), making knowledge transfer possible.

Formal Definitions of Core Concepts

Domain

A domain is defined as a tuple $\mathcal{D} = \{\mathcal{X}, P(X)\}$, where:

  • $\mathcal{X}$ is the feature space, e.g., $\mathbb{R}^d$ represents the $d$-dimensional real vector space
  • $P(X)$ is the marginal probability distribution over the feature space

Example:

  • Source domain: Natural images (ImageNet); the feature space is RGB pixel space, and the distribution $P_S(X)$ captures natural scene statistics
  • Target domain: Medical CT images; the feature space is grayscale pixel space, and the distribution $P_T(X)$ differs significantly from that of natural images

Task

A task is defined as $\mathcal{T} = \{\mathcal{Y}, f(\cdot)\}$, where:

  • $\mathcal{Y}$ is the label space
  • $f: \mathcal{X} \to \mathcal{Y}$ is the prediction function (to be learned)

For supervised learning, the task also includes the conditional probability distribution $P(Y|X)$.

Example:

  • Task 1: ImageNet 1000-class classification, $\mathcal{Y} = \{1, 2, \ldots, 1000\}$
  • Task 2: Pneumonia binary classification, $\mathcal{Y} = \{0, 1\}$

Source Domain and Target Domain

Transfer learning setup:

  • Source domain: $\mathcal{D}_S = \{\mathcal{X}_S, P_S(X)\}$, with source task $\mathcal{T}_S$
  • Target domain: $\mathcal{D}_T = \{\mathcal{X}_T, P_T(X)\}$, with target task $\mathcal{T}_T$

Key differences may include:

  • $\mathcal{X}_S \neq \mathcal{X}_T$ (different feature spaces)
  • $P_S(X) \neq P_T(X)$ (different marginal distributions)
  • $\mathcal{Y}_S \neq \mathcal{Y}_T$ (different label spaces)
  • $P_S(Y|X) \neq P_T(Y|X)$ (different conditional distributions)

Transfer learning does not require source and target domains to be identical; this is precisely its value.

Mathematical Definition of Transfer Learning

According to the seminal survey by Pan and Yang (2010)1, transfer learning is formally defined as:

Given a source domain $\mathcal{D}_S$ and learning task $\mathcal{T}_S$, and a target domain $\mathcal{D}_T$ and learning task $\mathcal{T}_T$, transfer learning aims to leverage the knowledge in $\mathcal{D}_S$ and $\mathcal{T}_S$ to improve the learning of the target prediction function $f_T(\cdot)$, where $\mathcal{D}_S \neq \mathcal{D}_T$ or $\mathcal{T}_S \neq \mathcal{T}_T$.

Quantitative definition of performance improvement: let $\epsilon_T^{\text{scratch}}$ be the target domain error rate without transfer learning, and $\epsilon_T^{\text{transfer}}$ be the error rate with transfer learning. We require:

$$\epsilon_T^{\text{transfer}} < \epsilon_T^{\text{scratch}}$$

Or equivalently, using fewer target domain labeled samples $n_T$ to achieve the same performance.

Taxonomy of Transfer Learning

Based on the availability of labeled data in source and target domains, transfer learning is categorized into three types2:

Inductive Transfer Learning

Definition: Source and target tasks are different ($\mathcal{T}_S \neq \mathcal{T}_T$), with a small amount of labeled data in the target domain.

Mathematical description:

  • Source domain: $\{(x_i^S, y_i^S)\}_{i=1}^{n_S}$ with abundant labels
  • Target domain: $\{(x_j^T, y_j^T)\}_{j=1}^{n_T}$ with scarce labels ($n_T \ll n_S$)
  • Different conditional distributions: $P_S(Y|X) \neq P_T(Y|X)$

Typical methods:

  1. Multi-task learning: Optimize source and target tasks simultaneously
  2. Self-training: Use pseudo-labels from unlabeled target domain data

Application scenarios:

  • ImageNet pretrained model → Medical image classification (different tasks)
  • General language model → Sentiment analysis (different tasks)

Transductive Transfer Learning

Definition: Source and target tasks are the same ($\mathcal{T}_S = \mathcal{T}_T$), but the domains differ ($\mathcal{D}_S \neq \mathcal{D}_T$), with no labeled data in the target domain.

Mathematical description:

  • Source domain: $\{(x_i^S, y_i^S)\}_{i=1}^{n_S}$ with abundant labels
  • Target domain: Only $\{x_j^T\}_{j=1}^{n_T}$, no labels
  • Different marginal distributions: $P_S(X) \neq P_T(X)$ or $\mathcal{X}_S \neq \mathcal{X}_T$

Typical methods:

  1. Domain adaptation: Align feature distributions between source and target domains
  2. Sample reweighting: Adjust source sample weights to approximate the target distribution

Application scenarios:

  • Synthetic data → Real data (e.g., GTA V game scenes → Real street views)
  • Sentiment analysis for product reviews: Book reviews → Electronics reviews

Unsupervised Transfer Learning

Definition: Both source and target domains lack labeled data; transfer focuses on intrinsic data structure.

Mathematical description:

  • Both source and target domains have only $\{x_i\}$, no labels
  • Goal: Learn feature representations or clustering structures

Typical methods:

  1. Self-supervised learning: Learn general features through proxy tasks (rotation prediction, contrastive learning)
  2. Deep clustering: Transfer clustering structures

Application scenarios:

  • Word vector transfer in NLP (Word2Vec trained on general corpora, applied to specific domains)
  • Self-supervised pretraining for images (MoCo, SimCLR)

Core Assumptions of Transfer Learning

Relatedness Assumption

Transfer learning is predicated on some correlation existing between source and target domains, formalized as:

$$\text{sim}(\mathcal{D}_S, \mathcal{D}_T) > \delta$$

where $\text{sim}(\cdot, \cdot)$ is a similarity measure and $\delta$ is a threshold.

Common similarity measures:

  1. Feature space similarity: Degree of overlap between $\mathcal{X}_S$ and $\mathcal{X}_T$
  2. Distribution divergence: KL divergence, Maximum Mean Discrepancy (MMD), Wasserstein distance
  3. Task relatedness: Semantic similarity of label spaces

Shared Representation Assumption

There exists a shared feature representation $\phi: \mathcal{X} \to \mathcal{Z}$ such that in the latent space:

$$P_S(\phi(X)) \approx P_T(\phi(X))$$

Or at least:

$$d\big(P_S(\phi(X)), P_T(\phi(X))\big) < d\big(P_S(X), P_T(X)\big)$$

That is, the latent representation reduces distributional differences between domains. This is the theoretical foundation for most deep transfer learning methods.

The Problem of Negative Transfer

Definition of Negative Transfer

When source domain knowledge not only fails to improve but actually harms target domain performance, it's called negative transfer3:

$$\epsilon_T^{\text{transfer}} > \epsilon_T^{\text{scratch}}$$

The error rate with transfer learning is higher than that of training from scratch.

Causes of Negative Transfer

1. Excessive Domain Divergence

If the distribution difference between source and target domains exceeds a threshold:

$$d(\mathcal{D}_S, \mathcal{D}_T) > \delta_{\max}$$

then source domain prior knowledge may be misleading.

Example: Transferring a natural image-trained model to hand-drawn sketches — due to completely different textures and colors, pretrained features may be entirely ineffective.

2. Task Conflict

Source and target tasks have inherent conflicts. Let the optimal solution for the source task be $\theta_S^*$ and for the target task be $\theta_T^*$. If:

$$\|\theta_S^* - \theta_T^*\| > r$$

where $r$ is a tolerable radius in parameter space, initializing from $\theta_S^*$ may lead to worse local optima.

Example: Transferring an English sentiment analysis model to Chinese, where Chinese negation expressions and irony differ vastly from English.

3. Overfitting to Source Domain

The model overfits to the source domain, learning noise patterns specific to the source rather than common knowledge. Formalized as:

$$f_S = f_{\text{common}} + f_{\text{noise}}^S$$

If $f_{\text{noise}}^S$ is transferred along with $f_{\text{common}}$, it introduces bias on the target domain.

Avoiding Negative Transfer

  1. Measure domain similarity: Use MMD, A-distance, or other metrics to estimate transfer feasibility
  2. Selective transfer: Only transfer low-level general features; retrain high-level task-specific layers
  3. Regularization constraints: Add regularization during target domain fine-tuning to limit parameter deviation from source domain
  4. Ensemble methods: Combine predictions from scratch training and transfer learning to reduce single-strategy risks
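Point 3 (regularization constraints) is often realized as an L2-SP-style penalty that keeps the fine-tuned weights close to the pretrained starting point. Below is a minimal PyTorch sketch; the toy `nn.Linear` model and the `beta` weight are placeholders for illustration, not part of this article's experiment.

```python
import torch
import torch.nn as nn

def l2_sp_penalty(model, source_params, beta=0.01):
    """Sum of squared deviations from the pretrained (source) weights.

    source_params: dict of parameter name -> tensor snapshot taken
    right after loading the pretrained weights.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in source_params:
            penalty = penalty + ((param - source_params[name]) ** 2).sum()
    return beta * penalty

# Usage sketch: snapshot the pretrained weights, then add the penalty
# to the ordinary task loss during target-domain fine-tuning.
model = nn.Linear(4, 2)  # stand-in for a pretrained network
source_params = {n: p.detach().clone() for n, p in model.named_parameters()}

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = nn.CrossEntropyLoss()(model(x), y) + l2_sp_penalty(model, source_params)
loss.backward()
```

The penalty is zero at initialization and grows as fine-tuning drifts away from the source weights, which limits how far the target task can pull the parameters.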

Quantitative Analysis of Transfer Feasibility

Ben-David Bound

Ben-David et al.4 theoretically analyzed domain adaptation generalization bounds. Let $\mathcal{H}$ be the hypothesis space and $d_{\mathcal{H}\Delta\mathcal{H}}$ the divergence induced by the hypothesis space. Then for any $h \in \mathcal{H}$, the target domain error rate satisfies:

$$\epsilon_T(h) \leq \epsilon_S(h) + \frac{1}{2} d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) + \lambda$$

where:

  • $\epsilon_S(h)$ is the hypothesis's error rate on the source domain
  • $d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T)$ is the $\mathcal{H}\Delta\mathcal{H}$-divergence between source and target domains
  • $\lambda = \min_{h \in \mathcal{H}} \left[\epsilon_S(h) + \epsilon_T(h)\right]$ is the ideal joint error rate

Interpretation: Target domain error is controlled by three terms:

  1. Source domain error (reducible through training)
  2. Domain divergence (reducible through domain adaptation)
  3. Task relatedness $\lambda$ (determined by the nature of the problem, unchangeable)

If $d_{\mathcal{H}\Delta\mathcal{H}}$ or $\lambda$ is too large, transfer learning may fail.

Maximum Mean Discrepancy (MMD)

MMD is a common metric for measuring distributional differences5:

$$\text{MMD}^2(P, Q) = \left\| \mathbb{E}_{x \sim P}[\phi(x)] - \mathbb{E}_{y \sim Q}[\phi(y)] \right\|_{\mathcal{H}}^2$$

where $\phi$ is a feature map and $\mathcal{H}$ is a Reproducing Kernel Hilbert Space (RKHS).

In practice, MMD can serve as a domain adaptation loss:

$$\mathcal{L} = \mathcal{L}_{\text{cls}} + \lambda \cdot \text{MMD}^2\big(\phi(X_S), \phi(X_T)\big)$$

By minimizing MMD, the feature distributions of source and target domains are aligned.
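Given finite samples $X_S = \{x_i^S\}_{i=1}^{n_S}$ and $X_T = \{x_j^T\}_{j=1}^{n_T}$ and a kernel $k(\cdot, \cdot)$, the squared MMD is typically computed with the standard biased empirical estimator, which is the quantity the `compute_mmd` function later in this article evaluates:

$$\widehat{\text{MMD}}^2(X_S, X_T) = \frac{1}{n_S^2} \sum_{i,j} k(x_i^S, x_j^S) + \frac{1}{n_T^2} \sum_{i,j} k(x_i^T, x_j^T) - \frac{2}{n_S n_T} \sum_{i,j} k(x_i^S, x_j^T)$$

with, for example, the RBF kernel $k(x, y) = \exp(-\gamma \|x - y\|^2)$.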

Transfer Learning vs. Multi-Task Learning

Dimension | Transfer Learning | Multi-Task Learning
Goal | Optimize target domain performance | Optimize all tasks simultaneously
Training | Sequential (source then target) | Parallel (simultaneous)
Data distribution | Source and target can differ | Typically assumes related tasks
Typical use | Pretrain-finetune | Multi-head network with shared encoder

Transfer Learning vs. Meta-Learning

Dimension | Transfer Learning | Meta-Learning
Learning goal | Transfer specific knowledge | Learn how to learn
Training data | Single or few source domains | Many diverse tasks
Adaptation speed | Requires some fine-tuning | Fast adaptation (few-shot)
Theoretical framework | Statistical learning | Optimization theory

Transfer Learning vs. Domain Generalization

Dimension | Transfer Learning | Domain Generalization
Test-time info | Access to target domain data | Target domain unknown
Methodology | Domain adaptation (use target data) | Learn domain-invariant features
Challenge | Domain alignment | Generalize to unseen domains

Complete Implementation: Feature Transfer Example

Below is a complete example demonstrating the basic transfer learning workflow: domain adaptation from a source domain to a shifted target domain. To keep the code self-contained, an Office-31-style scenario (e.g., transferring an Amazon domain model to the Webcam domain) is simulated with synthetic 2D data.

Experimental Setup

"""
Complete Transfer Learning Experiment: Feature Transfer and Domain Adaptation
Data: Synthetic 2D source/target domains (a rotation + shift simulates the domain gap)
Method: Feature extraction + MMD alignment + Fine-tuning
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.svm import SVC
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# ============================================================================
# Data Generation: Simulate source and target domains
# ============================================================================

def generate_source_domain(n_samples=1000):
    """
    Generate source domain data: 2D features, 2 classes
    Source distribution: Two Gaussian clusters, well-separated
    """
    # Class 0: centered at (-2, -2)
    X0 = np.random.randn(n_samples // 2, 2) * 0.5 + np.array([-2, -2])
    y0 = np.zeros(n_samples // 2)

    # Class 1: centered at (2, 2)
    X1 = np.random.randn(n_samples // 2, 2) * 0.5 + np.array([2, 2])
    y1 = np.ones(n_samples // 2)

    X = np.vstack([X0, X1])
    y = np.hstack([y0, y1])

    return X, y

def generate_target_domain(n_samples=200):
    """
    Generate target domain data: same feature space but different distribution
    Target distribution: rotated and shifted
    """
    # Rotation matrix: 45 degrees
    theta = np.pi / 4
    rotation = np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])

    # Class 0: centered at (-1, -1), rotated
    X0 = np.random.randn(n_samples // 2, 2) * 0.6 + np.array([-1, -1])
    X0 = X0 @ rotation.T
    y0 = np.zeros(n_samples // 2)

    # Class 1: centered at (1, 1), rotated
    X1 = np.random.randn(n_samples // 2, 2) * 0.6 + np.array([1, 1])
    X1 = X1 @ rotation.T
    y1 = np.ones(n_samples // 2)

    X = np.vstack([X0, X1])
    y = np.hstack([y0, y1])

    return X, y

# Generate data
X_source, y_source = generate_source_domain(1000)
X_target_train, y_target_train = generate_target_domain(50) # Few labeled
X_target_test, y_target_test = generate_target_domain(200) # Test set

print(f"Source domain data: {X_source.shape}, labels: {y_source.shape}")
print(f"Target domain train: {X_target_train.shape}")
print(f"Target domain test: {X_target_test.shape}")

# ============================================================================
# Visualize data distributions
# ============================================================================

def visualize_domains(X_source, y_source, X_target, y_target, title="Domain Comparison"):
    """Visualize source and target domain data distributions"""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Source domain
    axes[0].scatter(X_source[y_source==0, 0], X_source[y_source==0, 1],
                    c='blue', alpha=0.6, label='Class 0')
    axes[0].scatter(X_source[y_source==1, 0], X_source[y_source==1, 1],
                    c='red', alpha=0.6, label='Class 1')
    axes[0].set_title('Source Domain')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Target domain
    axes[1].scatter(X_target[y_target==0, 0], X_target[y_target==0, 1],
                    c='blue', alpha=0.6, label='Class 0')
    axes[1].scatter(X_target[y_target==1, 0], X_target[y_target==1, 1],
                    c='red', alpha=0.6, label='Class 1')
    axes[1].set_title('Target Domain')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig('domain_comparison.png', dpi=150, bbox_inches='tight')
    plt.close()

visualize_domains(X_source, y_source, X_target_test, y_target_test)

# ============================================================================
# Method 1: No transfer (train from scratch)
# ============================================================================

def train_from_scratch(X_train, y_train, X_test, y_test):
    """Train classifier from scratch on target domain"""
    clf = SVC(kernel='rbf', gamma='auto')
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    return acc, clf

acc_scratch, clf_scratch = train_from_scratch(
X_target_train, y_target_train, X_target_test, y_target_test
)
print(f"\n[No Transfer] Target domain test accuracy: {acc_scratch:.4f}")

# ============================================================================
# Method 2: Direct transfer (train on source, test on target)
# ============================================================================

def direct_transfer(X_source, y_source, X_test, y_test):
    """Train on source domain, test directly on target (no fine-tuning)"""
    clf = SVC(kernel='rbf', gamma='auto')
    clf.fit(X_source, y_source)
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    return acc, clf

acc_direct, clf_direct = direct_transfer(
X_source, y_source, X_target_test, y_target_test
)
print(f"[Direct Transfer] Target domain test accuracy: {acc_direct:.4f}")

# ============================================================================
# Method 3: Feature transfer + Fine-tuning
# ============================================================================

class FeatureExtractor(nn.Module):
    """Feature extractor: 2-layer MLP"""
    def __init__(self, input_dim=2, hidden_dim=32, output_dim=16):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.encoder(x)

class Classifier(nn.Module):
    """Classifier"""
    def __init__(self, input_dim=16, num_classes=2):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

def compute_mmd(x_source, x_target, kernel='rbf', gamma=1.0):
    """
    Compute Maximum Mean Discrepancy (MMD)
    """
    n_source = x_source.size(0)
    n_target = x_target.size(0)

    # Compute kernel matrices
    xx = torch.sum(x_source ** 2, dim=1, keepdim=True)
    yy = torch.sum(x_target ** 2, dim=1, keepdim=True)

    # Source domain kernel
    K_ss = torch.exp(-gamma * (xx + xx.t() - 2 * x_source @ x_source.t()))
    # Target domain kernel
    K_tt = torch.exp(-gamma * (yy + yy.t() - 2 * x_target @ x_target.t()))
    # Cross-domain kernel
    K_st = torch.exp(-gamma * (xx + yy.t() - 2 * x_source @ x_target.t()))

    mmd = K_ss.sum() / (n_source ** 2) + K_tt.sum() / (n_target ** 2) - 2 * K_st.sum() / (n_source * n_target)
    return mmd

def train_with_mmd(X_source, y_source, X_target_unlabeled,
                   X_target_labeled, y_target_labeled,
                   epochs=100, lambda_mmd=0.1):
    """
    Domain adaptation using MMD
    """
    # Convert to PyTorch tensors
    X_s = torch.FloatTensor(X_source)
    y_s = torch.LongTensor(y_source.astype(int))
    X_t_unlabeled = torch.FloatTensor(X_target_unlabeled)
    X_t_labeled = torch.FloatTensor(X_target_labeled)
    y_t_labeled = torch.LongTensor(y_target_labeled.astype(int))

    # Create DataLoader
    source_dataset = TensorDataset(X_s, y_s)
    source_loader = DataLoader(source_dataset, batch_size=32, shuffle=True)

    # Initialize models
    feature_extractor = FeatureExtractor()
    classifier = Classifier()

    # Optimizer
    optimizer = optim.Adam(
        list(feature_extractor.parameters()) + list(classifier.parameters()),
        lr=0.001
    )
    criterion = nn.CrossEntropyLoss()

    # Training
    losses = []
    for epoch in range(epochs):
        epoch_loss = 0
        for X_batch, y_batch in source_loader:
            optimizer.zero_grad()

            # Source domain classification loss
            features_s = feature_extractor(X_batch)
            logits_s = classifier(features_s)
            loss_cls = criterion(logits_s, y_batch)

            # MMD alignment loss (using unlabeled target data)
            features_t = feature_extractor(X_t_unlabeled)
            loss_mmd = compute_mmd(features_s, features_t)

            # Total loss
            loss = loss_cls + lambda_mmd * loss_mmd

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        losses.append(epoch_loss / len(source_loader))

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {losses[-1]:.4f}")

    # Fine-tuning: use few labeled target samples
    print("\nStarting fine-tuning...")
    for epoch in range(50):
        optimizer.zero_grad()
        features = feature_extractor(X_t_labeled)
        logits = classifier(features)
        loss = criterion(logits, y_t_labeled)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f"Fine-tune Epoch {epoch+1}/50, Loss: {loss.item():.4f}")

    return feature_extractor, classifier, losses

# Training
feature_extractor, classifier, losses = train_with_mmd(
X_source, y_source,
X_target_test, # Target unlabeled (for MMD alignment)
X_target_train, y_target_train, # Target labeled (for fine-tuning)
epochs=100,
lambda_mmd=0.5
)

# Testing
feature_extractor.eval()
classifier.eval()
with torch.no_grad():
    X_test_tensor = torch.FloatTensor(X_target_test)
    features_test = feature_extractor(X_test_tensor)
    logits_test = classifier(features_test)
    y_pred = torch.argmax(logits_test, dim=1).numpy()

acc_transfer = accuracy_score(y_target_test, y_pred)
print(f"\n[Feature Transfer + MMD] Target domain test accuracy: {acc_transfer:.4f}")

# ============================================================================
# Performance comparison visualization
# ============================================================================

def plot_performance_comparison():
    """Plot performance comparison of three methods"""
    methods = ['From Scratch\n(50 samples)', 'Direct Transfer\n(No adaptation)', 'Feature Transfer+MMD\n(Domain adaptation)']
    accuracies = [acc_scratch, acc_direct, acc_transfer]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(methods, accuracies, color=['#ff6b6b', '#4ecdc4', '#45b7d1'], alpha=0.8)

    # Add value labels
    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{acc:.3f}',
                ha='center', va='bottom', fontsize=12, fontweight='bold')

    ax.set_ylabel('Accuracy', fontsize=12)
    ax.set_title('Transfer Learning Methods Performance Comparison', fontsize=14, fontweight='bold')
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig('performance_comparison.png', dpi=150, bbox_inches='tight')
    plt.close()

plot_performance_comparison()

# ============================================================================
# Feature space visualization
# ============================================================================

def visualize_feature_space():
    """Visualize feature space before and after transfer"""
    feature_extractor.eval()

    with torch.no_grad():
        # Extract source features
        X_s_tensor = torch.FloatTensor(X_source)
        features_s = feature_extractor(X_s_tensor).numpy()

        # Extract target features
        X_t_tensor = torch.FloatTensor(X_target_test)
        features_t = feature_extractor(X_t_tensor).numpy()

    # Use t-SNE for 2D visualization
    features_all = np.vstack([features_s, features_t])
    tsne = TSNE(n_components=2, random_state=42)
    features_2d = tsne.fit_transform(features_all)

    features_s_2d = features_2d[:len(features_s)]
    features_t_2d = features_2d[len(features_s):]

    # Plot
    fig, ax = plt.subplots(figsize=(10, 8))

    # Source domain
    ax.scatter(features_s_2d[y_source==0, 0], features_s_2d[y_source==0, 1],
               c='blue', marker='o', alpha=0.5, s=30, label='Source Class 0')
    ax.scatter(features_s_2d[y_source==1, 0], features_s_2d[y_source==1, 1],
               c='red', marker='o', alpha=0.5, s=30, label='Source Class 1')

    # Target domain
    ax.scatter(features_t_2d[y_target_test==0, 0], features_t_2d[y_target_test==0, 1],
               c='blue', marker='^', alpha=0.8, s=50, label='Target Class 0')
    ax.scatter(features_t_2d[y_target_test==1, 0], features_t_2d[y_target_test==1, 1],
               c='red', marker='^', alpha=0.8, s=50, label='Target Class 1')

    ax.set_title('Feature Space Visualization (t-SNE)', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('feature_space_tsne.png', dpi=150, bbox_inches='tight')
    plt.close()

visualize_feature_space()

# ============================================================================
# Confusion matrices
# ============================================================================

def plot_confusion_matrices():
    """Plot confusion matrices for three methods"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Method 1: From scratch
    y_pred_scratch = clf_scratch.predict(X_target_test)
    cm1 = confusion_matrix(y_target_test, y_pred_scratch)
    sns.heatmap(cm1, annot=True, fmt='d', cmap='Blues', ax=axes[0])
    axes[0].set_title(f'From Scratch\nAcc: {acc_scratch:.3f}')
    axes[0].set_xlabel('Predicted Label')
    axes[0].set_ylabel('True Label')

    # Method 2: Direct transfer
    y_pred_direct = clf_direct.predict(X_target_test)
    cm2 = confusion_matrix(y_target_test, y_pred_direct)
    sns.heatmap(cm2, annot=True, fmt='d', cmap='Greens', ax=axes[1])
    axes[1].set_title(f'Direct Transfer\nAcc: {acc_direct:.3f}')
    axes[1].set_xlabel('Predicted Label')
    axes[1].set_ylabel('True Label')

    # Method 3: Feature transfer
    cm3 = confusion_matrix(y_target_test, y_pred)
    sns.heatmap(cm3, annot=True, fmt='d', cmap='Oranges', ax=axes[2])
    axes[2].set_title(f'Feature Transfer+MMD\nAcc: {acc_transfer:.3f}')
    axes[2].set_xlabel('Predicted Label')
    axes[2].set_ylabel('True Label')

    plt.tight_layout()
    plt.savefig('confusion_matrices.png', dpi=150, bbox_inches='tight')
    plt.close()

plot_confusion_matrices()

print("\n" + "="*60)
print("Experiment Summary:")
print("="*60)
print(f"1. Train from scratch (50 target samples): {acc_scratch:.4f}")
print(f"2. Direct transfer (no adaptation): {acc_direct:.4f}")
print(f"3. Feature transfer + MMD (adaptation): {acc_transfer:.4f}")
print(f"\nImprovement: {(acc_transfer - acc_scratch) / acc_scratch * 100:.1f}%")
print("="*60)

Code Explanation

Core components:

  1. Data generation: Simulate distribution shift between source and target (rotation + translation)
  2. MMD computation: Calculate domain distance using an RBF kernel
  3. Two-stage training: Stage 1, source classification + MMD alignment; Stage 2, target domain fine-tuning
  4. Visualization: t-SNE feature space, performance comparison, confusion matrices

Key parameters:

  • lambda_mmd=0.5: Controls domain adaptation strength
  • epochs=100: Pretraining epochs
  • Fine-tuning epochs: 50 (using the few labeled target samples)

Frequently Asked Questions

Q1: Is transfer learning always better than training from scratch?

Not necessarily. Transfer learning effectiveness depends on:

  1. Domain relatedness: The more similar the source and target domains, the better the results
  2. Data quantity: With extremely few target samples (<100), transfer learning has clear advantages; with abundant data (>100K samples), training from scratch may be better
  3. Task relatedness: Excessive task differences lead to negative transfer

Rule of thumb: Consider transfer learning when target domain data < 10% of source domain data.

Q2: How to select a source domain?

Selection criteria:

  1. Domain similarity: Choose ImageNet pretraining for vision, BERT/GPT for NLP
  2. Data scale: Larger source data is better (million-scale or more)
  3. Task relatedness: Classification transfers to classification, detection to detection

Visualization tools: Use t-SNE to compare source and target feature distributions; MMD < 0.1 typically indicates transferability.

Q3: Which layers of a pretrained model should be frozen?

General guidelines:

  • CV models: Freeze the first 3-4 layers (edge, texture features), fine-tune later layers
  • NLP models: Usually fine-tune all layers but with a reduced learning rate (1/10 of the source-domain lr)
  • Small-data scenarios: Only fine-tune the last 1-2 layers to avoid overfitting

Experimental strategy: Progressive Unfreezing — gradually unfreeze more layers starting from the top.
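The freezing and reduced-learning-rate guidelines above can be sketched in PyTorch. The tiny `backbone` and `head` modules below are stand-ins for a real pretrained network; the specific learning rates are illustrative assumptions.

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in for a pretrained backbone plus a freshly initialized task head.
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())
head = nn.Linear(32, 2)

# Option 1: freeze the backbone entirely (pure feature-extractor mode).
for param in backbone.parameters():
    param.requires_grad = False

# Option 2: discriminative learning rates -- unfreeze the backbone but train
# it with a much smaller lr than the new head.
for param in backbone.parameters():
    param.requires_grad = True
optimizer = optim.Adam([
    {"params": backbone.parameters(), "lr": 1e-4},  # ~1/10 of the head lr
    {"params": head.parameters(), "lr": 1e-3},
])
```

Progressive unfreezing then amounts to starting in option 1 and moving layers into option 2 from the top of the network downward as training stabilizes.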

Q4: How to detect negative transfer?

Detection methods:

  1. Baseline comparison: Transfer learning accuracy < training-from-scratch accuracy
  2. Loss curves: Loss increases instead of decreasing during fine-tuning
  3. Domain distance: MMD or A-distance too large (> 0.5)

Remedies:

  • Only transfer shallow features
  • Increase target domain data weight
  • Use adversarial domain adaptation

Q5: How does transfer learning handle different label spaces?

Three strategies:

  1. Zero-shot transfer: Use semantic embeddings (e.g., Word2Vec) to map labels to a shared space
  2. Partial transfer: Only transfer shared classes, ignore source-specific classes
  3. Open-set transfer: Introduce an "unknown" class to identify new target classes

Example: ImageNet (1000 classes) → Medical imaging (5 classes): keep the 4096-dimensional penultimate feature layer and replace the final softmax layer.

Q6: How to evaluate transfer learning effectiveness?

Evaluation metrics:

  1. Accuracy improvement: $\Delta_{\text{acc}} = \text{acc}_{\text{transfer}} - \text{acc}_{\text{scratch}}$
  2. Sample efficiency: Ratio of target samples needed to achieve the same accuracy
  3. Training efficiency: Convergence speed (transfer is typically 10x faster)
  4. A-distance: $d_A = 2(1 - 2\epsilon)$, where $\epsilon$ is the domain classifier's error rate
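The A-distance is usually estimated with a proxy: train a classifier to distinguish source samples from target samples and plug its held-out error rate $\epsilon$ into $d_A = 2(1 - 2\epsilon)$. A minimal scikit-learn sketch on synthetic data, using a logistic-regression domain classifier as an assumed choice:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def proxy_a_distance(X_source, X_target):
    """Estimate the A-distance via a domain classifier.

    Label source samples 0 and target samples 1, train a classifier,
    and plug its held-out error rate into d_A = 2 * (1 - 2 * err).
    """
    X = np.vstack([X_source, X_target])
    d = np.hstack([np.zeros(len(X_source)), np.ones(len(X_target))])
    X_tr, X_te, d_tr, d_te = train_test_split(X, d, test_size=0.3, random_state=0)
    clf = LogisticRegression(max_iter=1000).fit(X_tr, d_tr)
    err = 1.0 - clf.score(X_te, d_te)
    return 2.0 * (1.0 - 2.0 * err)

rng = np.random.default_rng(0)
# Identical distributions: the domain classifier is near chance, so the
# estimate is near zero (it can even go slightly negative).
same = proxy_a_distance(rng.normal(0, 1, (500, 2)), rng.normal(0, 1, (500, 2)))
# A clear mean shift: the domains are easy to tell apart, so the estimate
# approaches the maximum value of 2.
shifted = proxy_a_distance(rng.normal(0, 1, (500, 2)), rng.normal(4, 1, (500, 2)))
```

A small estimate suggests the domains overlap heavily; an estimate near 2 signals a large gap and a higher risk of negative transfer.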

Q7: What's the difference between transfer learning and data augmentation?

Dimension | Transfer Learning | Data Augmentation
Knowledge source | External source domain | Current dataset
Method | Model initialization/feature alignment | Sample transformation
Applicable scenarios | Data scarcity | Improve generalization
Computational cost | Requires pretraining | Real-time generation

Both can be combined: First use transfer learning for good initialization, then use data augmentation for improved robustness.

Q8: How to implement cross-modal transfer (e.g., image → text)?

Key techniques:

  1. Shared embedding space: Map images and text to the same vector space (CLIP)
  2. Contrastive learning: Maximize matched-pair similarity, minimize unmatched-pair similarity
  3. Generative models: Use VAE/GAN to learn cross-modal mappings

Loss function (InfoNCE form):

$$\mathcal{L} = -\log \frac{\exp\big(\text{sim}(v, t)/\tau\big)}{\sum_{j} \exp\big(\text{sim}(v, t_j)/\tau\big)}$$

where $v$ is the image feature, $t$ is the matched text feature, and $\tau$ is a temperature parameter.
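A minimal PyTorch sketch of the symmetric, batch-wise form of this loss used in CLIP-style training; the embedding dimension and temperature below are arbitrary illustrative choices.

```python
import torch
import torch.nn.functional as F

def clip_style_loss(image_emb, text_emb, tau=0.07):
    """Symmetric InfoNCE over a batch of matched (image, text) pairs.

    image_emb, text_emb: (B, D) tensors; row i of each is a matched pair,
    so the correct "class" for every row is its own index (the diagonal).
    """
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / tau        # (B, B) cosine similarities
    targets = torch.arange(len(logits))            # matched pairs on the diagonal
    # Average the image-to-text and text-to-image directions.
    return (F.cross_entropy(logits, targets) +
            F.cross_entropy(logits.t(), targets)) / 2

loss = clip_style_loss(torch.randn(8, 64), torch.randn(8, 64))
```

When matched pairs are far more similar than mismatched ones, the loss approaches zero; random embeddings give a loss near $\log B$.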

Q9: How to transfer learning to small devices (e.g., phones)?

Model compression + transfer learning:

  1. Knowledge distillation: Large model (teacher) → Small model (student)
  2. Pruning: Remove redundant parameters
  3. Quantization: FP32 → INT8
  4. Lightweight architectures: MobileNet, EfficientNet

Workflow: Large model pretraining → Distill to small model → Target domain fine-tuning
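The distillation step typically combines a KL term toward the teacher's temperature-softened distribution with the ordinary hard-label loss. A minimal sketch, with `T` and `alpha` as assumed hyperparameters:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Soft-target knowledge distillation (Hinton-style).

    Mixes KL divergence to the teacher's softened distribution with the
    ordinary cross-entropy on the hard labels.
    """
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # rescale so gradients match the hard-loss magnitude
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard
```

If the student already matches the teacher exactly, the soft term vanishes and only the hard-label loss remains.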

Q10: What theoretical guarantees exist for transfer learning?

The Ben-David bound provides generalization guarantees:

$$\epsilon_T(h) \leq \epsilon_S(h) + \frac{1}{2} d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) + \lambda$$

Interpretation:

  • First term: Source domain error (optimizable)
  • Second term: Domain divergence (reducible via domain adaptation)
  • Third term: Ideal joint error (determined by the nature of the problem)

Theoretical insight: Successful transfer requires good source training + small domain gap + high task relatedness.

Q11: How to address catastrophic forgetting?

When a model forgets source knowledge during target fine-tuning, it's called catastrophic forgetting. Solutions:

  1. Elastic Weight Consolidation (EWC): Add regularization protecting important parameters
  2. Progressive fine-tuning: Gradually unfreeze layers, preserving low-level features
  3. Memory replay: Mix a small amount of source data during training
  4. Knowledge distillation: Keep the pretrained model as a teacher

Q12: How to do semi-supervised transfer learning?

Leveraging unlabeled data:

  1. Pseudo-labeling: Use the source model to label unlabeled target data
  2. Consistency regularization: Augmented samples should receive consistent predictions
  3. Self-training: Iterative self-improvement

Loss function:

$$\mathcal{L} = \mathcal{L}_{\text{sup}} + \lambda \cdot \mathcal{L}_{\text{unsup}}$$

where $\mathcal{L}_{\text{sup}}$ is the supervised loss on labeled samples and $\mathcal{L}_{\text{unsup}}$ is the pseudo-label or consistency loss on unlabeled samples.
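A minimal scikit-learn sketch of one pseudo-labeling round (strategy 1), using a confidence threshold as an assumed filter; the toy Gaussian data stands in for real source/target domains.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def pseudo_label_round(clf, X_target_unlabeled, threshold=0.9):
    """One self-training round: keep only the confident predictions."""
    proba = clf.predict_proba(X_target_unlabeled)
    confident = proba.max(axis=1) >= threshold
    return X_target_unlabeled[confident], proba[confident].argmax(axis=1)

# Toy setup: a source-trained model and shifted, unlabeled target data.
rng = np.random.default_rng(0)
X_src = np.vstack([rng.normal(-2, 0.5, (100, 2)), rng.normal(2, 0.5, (100, 2))])
y_src = np.hstack([np.zeros(100), np.ones(100)])
clf = LogisticRegression().fit(X_src, y_src)

X_tgt = np.vstack([rng.normal(-1.5, 0.5, (50, 2)), rng.normal(1.5, 0.5, (50, 2))])
X_conf, y_pseudo = pseudo_label_round(clf, X_tgt)

# Retrain on source data plus the confident pseudo-labeled target data.
clf = LogisticRegression().fit(np.vstack([X_src, X_conf]),
                               np.hstack([y_src, y_pseudo]))
```

In full self-training this round is repeated, optionally lowering the threshold as the model adapts to the target distribution.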

Summary

This article systematically introduced the fundamentals and core concepts of transfer learning:

  1. Motivation: Addressing data scarcity and expensive training
  2. Core concepts: Formal definitions of domain, task, source/target domains
  3. Taxonomy: Inductive, transductive, and unsupervised transfer
  4. Negative transfer: Causes, detection, and avoidance
  5. Theoretical analysis: Ben-David bound, MMD, and other feasibility criteria
  6. Practical code: Complete feature transfer + MMD domain adaptation implementation

Transfer learning is not a silver bullet, but in scenarios with data scarcity, computational constraints, and rapid deployment needs, it's one of the most effective technical approaches. The next chapter will delve into pretraining and fine-tuning techniques, covering classic paradigms from ImageNet to BERT.

References


  1. Pan, S. J., & Yang, Q. (2010). A survey on transfer learning. IEEE Transactions on Knowledge and Data Engineering, 22(10), 1345-1359.↩︎

  2. Weiss, K., Khoshgoftaar, T. M., & Wang, D. (2016). A survey of transfer learning. Journal of Big Data, 3(1), 1-40.↩︎

  3. Rosenstein, M. T., Marx, Z., Kaelbling, L. P., & Dietterich, T. G. (2005). To transfer or not to transfer. NIPS Workshop on Transfer Learning.↩︎

  4. Ben-David, S., Blitzer, J., Crammer, K., Kulesza, A., Pereira, F., & Vaughan, J. W. (2010). A theory of learning from different domains. Machine Learning, 79(1), 151-175.↩︎

  5. Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). A kernel two-sample test. Journal of Machine Learning Research, 13, 723-773.↩︎

  • Post title: Transfer Learning (1): Fundamentals and Core Concepts
  • Post author: Chen Kai
  • Create time: 2024-11-03 09:00:00
  • Post link: https://www.chenk.top/transfer-learning-1-fundamentals-and-core-concepts/
  • Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.