Transfer Learning (3): Domain Adaptation Methods
Chen Kai

Domain Adaptation is one of the most challenging problems in transfer learning. In practical applications, training data (source domain) and test data (target domain) often come from different distributions: medical images transferred from one hospital to another, recommendation systems transferred from one country to another, autonomous driving transferred from sunny to rainy conditions. This distribution shift can lead to significant performance degradation.

The core goal of domain adaptation is: to learn a model that performs well on the target domain when the source domain has labeled data but the target domain has no labels (or few labels). This requires aligning source and target domain feature distributions while maintaining discriminative power. This article derives the mathematical characterization of distribution shift from first principles, presents the theoretical foundation of unsupervised domain adaptation, explains classic methods like DANN and MMD in detail, and provides a complete DANN implementation.

Domain Shift Problems: Covariate Shift and Label Shift

Formal Definition of Domains

A Domain $\mathcal{D} = \{\mathcal{X}, P(X)\}$ consists of two components:

- Feature Space $\mathcal{X}$: the value space of the input variables
- Marginal Distribution $P(X)$: the probability distribution of the input variables

A Task $\mathcal{T} = \{\mathcal{Y}, P(Y|X)\}$ also consists of two components:

- Label Space $\mathcal{Y}$: the value space of the output variables
- Conditional Distribution $P(Y|X)$: the probability distribution of outputs given inputs

Domain adaptation setting:

- Source Domain $\mathcal{D}_S$, with labeled data $\{(x_i^s, y_i^s)\}_{i=1}^{n_s}$
- Target Domain $\mathcal{D}_T$, with unlabeled data $\{x_j^t\}_{j=1}^{n_t}$

Goal: learn a model $f: \mathcal{X} \to \mathcal{Y}$ that performs well on the target domain.

Covariate Shift

Definition: feature distributions differ but conditional distributions are the same:

$$P_S(X) \ne P_T(X), \qquad P_S(Y|X) = P_T(Y|X)$$

Intuitive Examples:

- Spam Email Classification: training data from 2020, test data from 2026. The email topic distribution changes ($P(X)$ changes), but the rules for determining spam given content remain unchanged ($P(Y|X)$ unchanged)
- Medical Imaging: training data from a Siemens CT, test data from a GE CT. The devices have different imaging characteristics ($P(X)$ changes), but the diagnostic criteria for lesions remain constant ($P(Y|X)$ unchanged)

Importance Weighting

Covariate shift can be corrected through importance weighting. Empirical Risk Minimization (ERM) on the source domain optimizes:

$$R_S(f) = \mathbb{E}_{(x,y)\sim P_S}[\ell(f(x), y)]$$

But what we really care about is the target domain risk:

$$R_T(f) = \mathbb{E}_{(x,y)\sim P_T}[\ell(f(x), y)]$$

Using importance sampling (and the covariate-shift assumption $P_S(Y|X) = P_T(Y|X)$), the target domain risk can be rewritten as:

$$R_T(f) = \mathbb{E}_{(x,y)\sim P_S}\left[\frac{P_T(x)}{P_S(x)}\,\ell(f(x), y)\right]$$

where $w(x) = P_T(x)/P_S(x)$ is the importance weight. In practice, we optimize the weighted loss:

$$\min_f \frac{1}{n_s}\sum_{i=1}^{n_s} w(x_i^s)\,\ell(f(x_i^s), y_i^s)$$
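The weighted objective can be sketched in a few lines of NumPy. This is a minimal illustration, not a full estimator: the logistic loss, the synthetic data, and the placeholder weights (a fixed upweighting of one feature, standing in for an estimated density ratio) are all assumptions for the demo.

```python
import numpy as np

def weighted_erm_loss(X, y, w_imp, theta):
    """Importance-weighted loss: (1/n) * sum_i w(x_i) * l(f(x_i), y_i)."""
    logits = X @ theta
    # Per-sample logistic loss for labels y in {0, 1}
    losses = np.log1p(np.exp(-logits)) + (1 - y) * logits
    return np.mean(w_imp * losses)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = (X[:, 0] > 0).astype(float)
theta = np.zeros(3)
# w(x) = p_T(x)/p_S(x); here a placeholder that upweights large x[0],
# normalized so the weights average to 1
w = np.exp(0.5 * X[:, 0])
w /= w.mean()
print(weighted_erm_loss(X, y, w, theta))
```

Replacing the placeholder `w` with a KLIEP-style density-ratio estimate (next section) gives the full covariate-shift correction.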

Density Ratio Estimation

The importance weight $w(x) = p_T(x)/p_S(x)$ requires estimating the ratio of two densities. Direct density estimation is difficult (curse of dimensionality), but we can estimate the density ratio directly.

Kullback-Leibler Importance Estimation Procedure (KLIEP):

Minimize the KL divergence between the true target density and the model $\hat{p}_T(x) = w(x)\,p_S(x)$:

$$\min_w \mathrm{KL}\big(p_T \,\|\, w \cdot p_S\big)$$

subject to the constraint $\int w(x)\,p_S(x)\,dx = 1$. The practical optimization is:

$$\max_w \frac{1}{n_t}\sum_{j=1}^{n_t} \log w(x_j^t) \quad \text{s.t.} \quad \frac{1}{n_s}\sum_{i=1}^{n_s} w(x_i^s) = 1,\; w \ge 0$$

Typically we parameterize $w(x) = \sum_{l} \alpha_l \varphi_l(x)$, where the $\varphi_l$ are basis functions (like Gaussian kernels).

Label Shift

Definition: label distributions differ but class-conditional distributions are the same:

$$P_S(Y) \ne P_T(Y), \qquad P_S(X|Y) = P_T(X|Y)$$

Intuitive Examples:

- Medical Diagnosis: training data from hospitalized patients (high disease proportion), test data from outpatients (high healthy proportion). Disease prevalence differs ($P(Y)$ changes), but given the disease, the symptom distribution is the same ($P(X|Y)$ unchanged)
- Recommendation System: training data from active users (more young users), test data from all users (more elderly users). The user age distribution differs ($P(Y)$ changes), but given the user type, behavior patterns are the same ($P(X|Y)$ unchanged)

Label Shift Correction

Using Bayes' theorem, the target domain posterior is:

$$P_T(y|x) = \frac{P_T(x|y)\,P_T(y)}{P_T(x)} = \frac{P_S(x|y)\,P_T(y)}{P_T(x)}$$

which can be further rewritten as:

$$P_T(y|x) \propto P_S(y|x)\,\frac{P_T(y)}{P_S(y)}$$

Therefore, we just reweight the source model's outputs with:

$$w(y) = \frac{P_T(y)}{P_S(y)}$$

Label Distribution Estimation: $P_S(y)$ can be estimated directly from the labeled source data. $P_T(y)$ must be estimated from unlabeled target data. One method is Expectation Maximization (EM):

E-step: use the current model to predict soft labels $P_T(y|x_j^t)$ for the target samples.

M-step: update the label distribution:

$$\hat{P}_T(y) = \frac{1}{n_t}\sum_{j=1}^{n_t} P_T(y|x_j^t)$$

Iterate until convergence.
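The EM loop above can be sketched in NumPy. This assumes well-calibrated source posteriors; the two-class Gaussian setup (class-conditionals $\mathcal{N}(\pm 1, 1)$, source prior 50/50, shifted target prior) is synthetic and only serves to show that the loop recovers the target label distribution:

```python
import numpy as np

def estimate_target_prior(probs, p_s, n_iter=200):
    """EM estimate of P_T(y) from source posteriors P_S(y|x) on target samples."""
    p_t = p_s.copy()
    for _ in range(n_iter):
        # E-step: adjust source posteriors with current weights P_T(y)/P_S(y)
        w = probs * (p_t / p_s)
        w /= w.sum(axis=1, keepdims=True)
        # M-step: new prior is the average adjusted posterior
        p_t = w.mean(axis=0)
    return p_t

rng = np.random.default_rng(0)
p_s = np.array([0.5, 0.5])        # source label distribution
p_t_true = np.array([0.2, 0.8])   # shifted target label distribution
# Class-conditionals N(-1, 1) and N(+1, 1) are shared across domains
y = rng.choice(2, size=20000, p=p_t_true)
x = rng.normal(loc=2.0 * y - 1.0, scale=1.0)
# Exact, calibrated source posteriors P_S(y|x) under the 50/50 source prior
log_lik = -0.5 * (x[:, None] - np.array([-1.0, 1.0])) ** 2
probs = np.exp(log_lik) * p_s
probs /= probs.sum(axis=1, keepdims=True)
print(estimate_target_prior(probs, p_s))
```

With miscalibrated posteriors the estimate is biased, which is why calibration (or the black-box estimator of Lipton et al., reference 12) matters in practice.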

Concept Shift

Definition: conditional distributions differ:

$$P_S(Y|X) \ne P_T(Y|X)$$

This is the most difficult case, because even with the same features, the decision boundaries may differ.

Examples: - Sentiment Classification: Training data from movie reviews, test data from product reviews. Same words may have different sentiment tendencies in different domains - Autonomous Driving: Training data from US (right-hand traffic), test data from UK (left-hand traffic). Driving rules completely different

Concept shift usually requires a small amount of target domain labeled data for adaptation.

Unsupervised Domain Adaptation: Aligning Feature Distributions

Unsupervised Domain Adaptation (UDA) assumes source domain has labeled data, target domain is completely unlabeled. Core idea: learn a feature extractor that aligns source and target domain feature distributions.

Theoretical Foundation: Ben-David Theory

Ben-David et al. gave a theoretical upper bound for domain adaptation. Let $\mathcal{H}$ be the hypothesis space and $h \in \mathcal{H}$ a hypothesis; the target domain risk can be bounded as:

$$R_T(h) \le R_S(h) + \frac{1}{2}\,d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) + \lambda^*$$

where:

- $R_S(h)$: source domain risk
- $d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T)$: $\mathcal{H}\Delta\mathcal{H}$-distance between source and target domains
- $\lambda^* = \min_{h' \in \mathcal{H}}\big[R_S(h') + R_T(h')\big]$: ideal joint hypothesis risk

Intuitive Explanation:

1. $R_S(h)$ small: classify well on the source domain
2. $d_{\mathcal{H}\Delta\mathcal{H}}$ small: source and target domains are close
3. $\lambda^*$ small: source and target tasks are similar (small ideal joint risk)

Balancing these three terms is the core of domain adaptation.

The $\mathcal{H}\Delta\mathcal{H}$-Distance

The $\mathcal{H}\Delta\mathcal{H}$-distance is defined as:

$$d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) = 2 \sup_{h, h' \in \mathcal{H}} \left| \Pr_{x \sim \mathcal{D}_S}[h(x) \ne h'(x)] - \Pr_{x \sim \mathcal{D}_T}[h(x) \ne h'(x)] \right|$$

Intuitive understanding: if the source and target domains are close, any two hypotheses $h, h'$ should have similar disagreement rates on both domains.

In practice, we can approximate this distance using a domain discriminator: train a classifier to distinguish source and target domain samples. If it cannot distinguish (classifier accuracy close to 50%), the two domains are close.
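This proxy can be sketched with scikit-learn, using logistic regression as the domain discriminator and the common proxy $\mathcal{A}$-distance formula $\hat{d}_{\mathcal{A}} = 2(1 - 2\epsilon)$. The Gaussian features are synthetic stand-ins for learned features, and `proxy_a_distance` is an illustrative helper, not a library function:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def proxy_a_distance(feat_s, feat_t):
    """Proxy A-distance: d_A = 2 * (1 - 2 * err) of a source-vs-target classifier."""
    X = np.vstack([feat_s, feat_t])
    d = np.concatenate([np.zeros(len(feat_s)), np.ones(len(feat_t))])
    X_tr, X_te, d_tr, d_te = train_test_split(X, d, test_size=0.5, random_state=0)
    clf = LogisticRegression(max_iter=1000).fit(X_tr, d_tr)
    err = 1.0 - clf.score(X_te, d_te)  # held-out domain-classification error
    return 2.0 * (1.0 - 2.0 * err)

rng = np.random.default_rng(0)
# Identical distributions: discriminator is near chance, distance near 0
aligned = proxy_a_distance(rng.normal(size=(500, 8)),
                           rng.normal(size=(500, 8)))
# Shifted distributions: discriminator is near perfect, distance near 2
shifted = proxy_a_distance(rng.normal(size=(500, 8)),
                           rng.normal(loc=3.0, size=(500, 8)))
print(aligned, shifted)
```

The same measurement is useful as an evaluation metric after adaptation (see Q7 below it is listed among the evaluation options).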

DANN: Domain-Adversarial Neural Network

Domain-Adversarial Neural Network (DANN) is the most classic domain adaptation method, aligning source and target domain feature distributions through adversarial training.

DANN Architecture

DANN contains three components:

1. Feature Extractor $G_f$: maps the input to feature space
2. Label Predictor $G_y$: predicts labels based on the features (classifier)
3. Domain Discriminator $G_d$: determines which domain the features come from

Loss function has three terms:

  1. Source Domain Classification Loss:

$$\mathcal{L}_y = \frac{1}{n_s}\sum_{i=1}^{n_s} \ell_{\mathrm{CE}}\big(G_y(G_f(x_i^s)),\, y_i^s\big)$$

  2. Domain Discrimination Loss (binary cross-entropy of the domain classifier):

$$\mathcal{L}_d = -\frac{1}{n_s}\sum_{i=1}^{n_s} \log G_d(G_f(x_i^s)) - \frac{1}{n_t}\sum_{j=1}^{n_t} \log\big(1 - G_d(G_f(x_j^t))\big)$$

  3. Total Loss (adversarial):

$$\mathcal{L} = \mathcal{L}_y - \lambda\,\mathcal{L}_d$$

Adversarial Training:

- The domain discriminator $G_d$ minimizes $\mathcal{L}_d$ (distinguishes source and target domains)
- The feature extractor $G_f$ minimizes $\mathcal{L}_y$ while maximizing $\mathcal{L}_d$ (makes $G_d$ unable to distinguish the domains)

This is equivalent to the saddle-point problem:

$$\min_{G_f,\, G_y}\;\max_{G_d}\;\big[\mathcal{L}_y(G_f, G_y) - \lambda\,\mathcal{L}_d(G_f, G_d)\big]$$

Gradient Reversal Layer (GRL)

DANN's clever part is the Gradient Reversal Layer (GRL), which is the identity mapping in the forward pass and flips gradients in the backward pass:

$$R(x) = x, \qquad \frac{\partial R}{\partial x} = -\lambda I$$

This way, the domain discriminator's gradients are multiplied by $-\lambda$ before passing back to the feature extractor, implementing the adversarial training.

Adaptive adversarial weight: DANN uses a weight $\lambda_p$ that increases with training progress:

$$\lambda_p = \frac{2}{1 + \exp(-\gamma p)} - 1$$

where $p \in [0, 1]$ is the training progress (current step / total steps) and $\gamma$ is a hyperparameter (typically 10).

Intuition: Early in training, first learn source domain classification well, then gradually enhance domain alignment.

Theoretical Explanation of DANN

DANN minimizes a weighted sum of the source domain risk $R_S$ and the domain discrepancy. The domain discriminator loss is a proxy for the domain discrepancy: if the domain discriminator cannot distinguish source and target domains (error rate close to 50%), the feature distributions are aligned.

More rigorously, when the discriminator is optimal, DANN is equivalent to minimizing the Jensen-Shannon divergence between the source and target feature distributions:

$$\mathrm{JS}\big(P_S^f \,\|\, P_T^f\big) = \frac{1}{2}\mathrm{KL}\big(P_S^f \,\|\, M\big) + \frac{1}{2}\mathrm{KL}\big(P_T^f \,\|\, M\big)$$

where $M = \frac{1}{2}(P_S^f + P_T^f)$ is the mixture distribution.

MMD: Maximum Mean Discrepancy

Maximum Mean Discrepancy (MMD) is another method to measure distribution difference, measuring distance by comparing means of two distributions in Reproducing Kernel Hilbert Space (RKHS).

MMD Definition

Let $\mathcal{H}$ be an RKHS with corresponding kernel function $k(x, y) = \langle \phi(x), \phi(y) \rangle_{\mathcal{H}}$. MMD is defined as:

$$\mathrm{MMD}(P, Q) = \big\| \mathbb{E}_{x \sim P}[\phi(x)] - \mathbb{E}_{y \sim Q}[\phi(y)] \big\|_{\mathcal{H}}$$

where $\phi: \mathcal{X} \to \mathcal{H}$ is the kernel mapping.

Expanding the square:

$$\mathrm{MMD}^2(P, Q) = \mathbb{E}_{x, x' \sim P}[k(x, x')] - 2\,\mathbb{E}_{x \sim P,\, y \sim Q}[k(x, y)] + \mathbb{E}_{y, y' \sim Q}[k(y, y')]$$

Empirical Estimation:

$$\widehat{\mathrm{MMD}}^2 = \frac{1}{n^2}\sum_{i, i'} k(x_i, x_{i'}) - \frac{2}{nm}\sum_{i, j} k(x_i, y_j) + \frac{1}{m^2}\sum_{j, j'} k(y_j, y_{j'})$$
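The empirical estimator translates directly into NumPy. This is the biased (V-statistic) version with a single Gaussian kernel; the Gaussian samples are synthetic:

```python
import numpy as np

def gaussian_kernel(a, b, sigma=1.0):
    """k(x, y) = exp(-||x - y||^2 / (2 sigma^2)) for all pairs of rows."""
    sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2.0 * sigma ** 2))

def mmd2(x, y, sigma=1.0):
    """Biased empirical MMD^2: mean k(x,x') - 2 mean k(x,y) + mean k(y,y')."""
    return (gaussian_kernel(x, x, sigma).mean()
            - 2.0 * gaussian_kernel(x, y, sigma).mean()
            + gaussian_kernel(y, y, sigma).mean())

rng = np.random.default_rng(0)
same = mmd2(rng.normal(size=(200, 2)), rng.normal(size=(200, 2)))
diff = mmd2(rng.normal(size=(200, 2)), rng.normal(loc=2.0, size=(200, 2)))
print(same, diff)  # near zero for matching distributions, large for shifted ones
```

Used as a training loss (as in DDC/DAN below), `mmd2` would be computed on mini-batch features with a differentiable backend instead of NumPy.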

MMD in Deep Domain Adaptation

Deep Domain Confusion (DDC) and Deep Adaptation Network (DAN) embed MMD into deep networks:

$$\mathcal{L} = \mathcal{L}_{\mathrm{cls}} + \lambda\,\mathrm{MMD}^2\big(G_f(X_S),\, G_f(X_T)\big)$$

where $G_f$ is the feature extractor and $X_S, X_T$ are source and target domain samples.

Multi-Kernel MMD:

A single kernel may not sufficiently capture the distribution differences. Instead, use a linear combination of multiple kernels:

$$k = \sum_{u=1}^{m} \beta_u k_u$$

where $\beta_u \ge 0$ and $\sum_{u=1}^{m} \beta_u = 1$.

Kernel Selection for MMD

Common kernel functions:

  1. Gaussian Kernel (RBF kernel):

$$k(x, y) = \exp\!\left(-\frac{\|x - y\|^2}{2\sigma^2}\right)$$

  2. Polynomial Kernel:

$$k(x, y) = (\langle x, y \rangle + c)^d$$

  3. Laplacian Kernel:

$$k(x, y) = \exp\!\left(-\frac{\|x - y\|_1}{\sigma}\right)$$

In practice, one typically uses multiple Gaussian kernels with different bandwidths.

CORAL: Correlation Alignment

Correlation Alignment (CORAL) reduces distribution difference by aligning second-order statistics (covariance) between source and target domains.

CORAL Loss

Let the source domain features be $F_S \in \mathbb{R}^{n_s \times d}$ and the target domain features be $F_T \in \mathbb{R}^{n_t \times d}$. The CORAL loss is:

$$\mathcal{L}_{\mathrm{CORAL}} = \frac{1}{4d^2}\,\|C_S - C_T\|_F^2$$

where $C_S, C_T$ are the covariance matrices of the source and target domain features:

$$C_S = \frac{1}{n_s - 1}\left(F_S^\top F_S - \frac{1}{n_s}\,(\mathbf{1}^\top F_S)^\top(\mathbf{1}^\top F_S)\right)$$

and $\mathbf{1}^\top F_S / n_s$ is the feature mean vector.

Deep CORAL:

Deep CORAL adds the CORAL loss to deep network training:

$$\mathcal{L} = \mathcal{L}_{\mathrm{cls}} + \lambda\,\mathcal{L}_{\mathrm{CORAL}}$$
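The loss itself is a few lines of NumPy (`np.cov` applies the same $1/(n-1)$ normalization as the formula above); the feature matrices here are synthetic stand-ins for network activations:

```python
import numpy as np

def coral_loss(f_s, f_t):
    """CORAL loss: ||C_S - C_T||_F^2 / (4 d^2), with C the feature covariance."""
    d = f_s.shape[1]
    c_s = np.cov(f_s, rowvar=False)  # (d, d), rows are samples
    c_t = np.cov(f_t, rowvar=False)
    return np.sum((c_s - c_t) ** 2) / (4.0 * d ** 2)

rng = np.random.default_rng(0)
f_s = rng.normal(size=(1000, 4))
f_t_same = rng.normal(size=(1000, 4))          # same second-order statistics
f_t_scaled = 3.0 * rng.normal(size=(1000, 4))  # 9x larger covariance
print(coral_loss(f_s, f_t_same), coral_loss(f_s, f_t_scaled))
```

In Deep CORAL proper, the same computation runs on mini-batch activations in the autodiff framework so its gradient flows into the feature extractor.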

Intuition of CORAL

Covariance matrix captures correlations between features. Aligning covariance matrices eliminates linear transformation differences between domains. CORAL can be seen as whitening + recoloring:

  1. Whiten the source domain features: $\tilde{F}_S = F_S\,C_S^{-1/2}$
  2. Recolor them with the target covariance: $\hat{F}_S = \tilde{F}_S\,C_T^{1/2}$

After this transformation, the source features have covariance $C_T$.

Adversarial Domain Adaptation: GAN-based Methods

Generative Adversarial Network (GAN) ideas can also be used for domain adaptation: using generators to transform source domain samples to target domain style while preserving semantics.

CycleGAN: Cycle-Consistent Adversarial Networks

CycleGAN implements unpaired domain transformation through cycle consistency loss.

CycleGAN Architecture

Contains two generators and two discriminators:

- $G_{S \to T}$ (source to target domain)
- $G_{T \to S}$ (target to source domain)
- $D_T$: discriminates target domain real/fake
- $D_S$: discriminates source domain real/fake

Loss Functions

  1. Adversarial Loss (and symmetrically for $G_{T \to S}, D_S$):

$$\mathcal{L}_{\mathrm{GAN}}(G_{S \to T}, D_T) = \mathbb{E}_{y \sim P_T}[\log D_T(y)] + \mathbb{E}_{x \sim P_S}\big[\log\big(1 - D_T(G_{S \to T}(x))\big)\big]$$

  2. Cycle Consistency Loss:

$$\mathcal{L}_{\mathrm{cyc}} = \mathbb{E}_{x \sim P_S}\big[\|G_{T \to S}(G_{S \to T}(x)) - x\|_1\big] + \mathbb{E}_{y \sim P_T}\big[\|G_{S \to T}(G_{T \to S}(y)) - y\|_1\big]$$

  3. Total Loss:

$$\mathcal{L} = \mathcal{L}_{\mathrm{GAN}}(G_{S \to T}, D_T) + \mathcal{L}_{\mathrm{GAN}}(G_{T \to S}, D_S) + \lambda\,\mathcal{L}_{\mathrm{cyc}}$$

Intuition:

- The adversarial loss makes generated samples look like the target domain
- The cycle consistency loss ensures the semantics are unchanged ($G_{T \to S}(G_{S \to T}(x)) \approx x$)

Domain Adaptation Application

CycleGAN can transform source domain images to target domain style, then train classifier with transformed images:

  1. Use CycleGAN to learn $G_{S \to T}: \mathcal{X}_S \to \mathcal{X}_T$
  2. Transform the labeled source data into pseudo target-domain data $\{(G_{S \to T}(x_i^s),\, y_i^s)\}_{i=1}^{n_s}$
  3. Train a classifier on the pseudo target-domain data

Pixel-level Domain Adaptation

For vision tasks, domain adaptation can be performed at pixel level.

ADDA: Adversarial Discriminative Domain Adaptation

Adversarial Discriminative Domain Adaptation (ADDA) proceeds in three steps:

  1. Pre-training: train a source feature extractor $M_S$ and classifier $C$ on the source domain

  2. Adversarial Adaptation: fix the classifier, learn a target domain feature extractor $M_T$ adversarially against a domain discriminator, so that $M_T(X_T)$ is indistinguishable from $M_S(X_S)$

  3. Testing: on the target domain, use $C(M_T(x))$ for prediction

ADDA's advantage is that the source and target feature extractors can differ, which makes the method more flexible.

Adaptive Batch Normalization

Batch Normalization (BN) behaves differently during training and testing: training uses batch statistics, testing uses global statistics. This causes problems in domain adaptation.

BN Domain Shift Problem

The BN layer computes:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta$$

During training, $\mu$ and $\sigma$ are the mean and standard deviation of the current batch; during testing, the global mean and standard deviation (running mean/std) from the training set are used.

Problem: If test data (target domain) distribution differs from training data (source domain), global statistics are inaccurate.

Adaptive BN (AdaBN)

AdaBN idea is simple: recompute BN statistics on target domain.

AdaBN Algorithm

  1. Train model normally on source domain (including BN layers)
  2. Run the model on target domain data, computing the mean and variance of each BN layer's inputs $z_j^{(l)}$:

$$\mu_T^{(l)} = \frac{1}{n_t}\sum_{j} z_j^{(l)}, \qquad \big(\sigma_T^{(l)}\big)^2 = \frac{1}{n_t}\sum_{j} \big(z_j^{(l)} - \mu_T^{(l)}\big)^2$$

  3. During testing, use $\mu_T^{(l)}, \sigma_T^{(l)}$ to replace the original source statistics $\mu_S^{(l)}, \sigma_S^{(l)}$

Why Effective?

BN statistics capture low-order statistical properties of data. Aligning these statistics can partially eliminate domain shift. Experiments show AdaBN is especially effective for covariate shift.
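The algorithm above is a short loop in PyTorch. This is a sketch assuming a model whose BatchNorm layers track running statistics; `adapt_bn_statistics` is an illustrative helper name, and setting `momentum=None` makes PyTorch accumulate an exact cumulative average instead of an exponential one:

```python
import torch
import torch.nn as nn

def adapt_bn_statistics(model, target_batches):
    """AdaBN sketch: re-estimate BatchNorm running statistics on target data."""
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None  # None => cumulative (exact) moving average
    model.train()          # BN updates running stats only in train mode
    with torch.no_grad():  # forward passes only, no gradient updates
        for x in target_batches:
            model(x)
    model.eval()
    return model

# Toy check: a BN layer adapted on shifted data picks up the new mean
torch.manual_seed(0)
model = nn.Sequential(nn.BatchNorm1d(4))
target_batches = [torch.randn(256, 4) + 5.0 for _ in range(20)]
adapt_bn_statistics(model, target_batches)
print(model[0].running_mean)
```

Note the trainable affine parameters $\gamma, \beta$ are untouched, which is exactly the split TransNorm (next subsection) makes explicit.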

TransNorm: Transferable Normalization

TransNorm further decomposes BN into a task-related part and a domain-related part:

- $\mu, \sigma$: domain-related statistics, recomputed on the target domain
- $\gamma, \beta$: task-related parameters, kept unchanged

This both adapts to target domain distribution and preserves task knowledge learned from source domain.

Complete Implementation: DANN Domain Adaptation

Below is a complete DANN implementation covering the key components: the gradient reversal layer, the domain discriminator, and adversarial training.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Function
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score


class GradientReversalFunction(Function):
    """Gradient reversal function"""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.lambda_, None


class GradientReversalLayer(nn.Module):
    """Gradient reversal layer"""

    def __init__(self):
        super().__init__()
        self.lambda_ = 1.0

    def set_lambda(self, lambda_):
        self.lambda_ = lambda_

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)


class FeatureExtractor(nn.Module):
    """Feature extractor (MLP over flattened input)"""

    def __init__(self, input_dim=28*28, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        return x


class LabelPredictor(nn.Module):
    """Label predictor (classifier)"""

    def __init__(self, feature_dim=256, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


class DomainDiscriminator(nn.Module):
    """Domain discriminator"""

    def __init__(self, feature_dim=256, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(feature_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc3(x))
        return x


class DANN(nn.Module):
    """Domain-adversarial neural network"""

    def __init__(self, input_dim=28*28, hidden_dim=256, num_classes=10):
        super().__init__()
        self.feature_extractor = FeatureExtractor(input_dim, hidden_dim)
        self.label_predictor = LabelPredictor(hidden_dim, num_classes)
        self.domain_discriminator = DomainDiscriminator(hidden_dim, hidden_dim)
        self.grl = GradientReversalLayer()

    def forward(self, x, alpha=1.0):
        # Extract features
        features = self.feature_extractor(x)

        # Label prediction
        class_logits = self.label_predictor(features)

        # Domain discrimination (with gradient reversal)
        self.grl.set_lambda(alpha)
        reversed_features = self.grl(features)
        domain_logits = self.domain_discriminator(reversed_features)

        return class_logits, domain_logits


class DANNTrainer:
    """DANN trainer"""

    def __init__(
        self,
        model,
        source_loader,
        target_loader,
        test_loader,
        num_epochs=100,
        learning_rate=1e-3,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        gamma=10.0
    ):
        self.model = model.to(device)
        self.source_loader = source_loader
        self.target_loader = target_loader
        self.test_loader = test_loader
        self.num_epochs = num_epochs
        self.device = device
        self.gamma = gamma

        # Optimizer
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate
        )

        # Loss functions
        self.class_criterion = nn.CrossEntropyLoss()
        self.domain_criterion = nn.BCELoss()

        # History
        self.train_losses = []
        self.test_accuracies = []

    def compute_lambda(self, epoch, total_epochs):
        """Compute adaptive adversarial weight"""
        p = epoch / total_epochs
        lambda_p = 2.0 / (1.0 + np.exp(-self.gamma * p)) - 1.0
        return lambda_p

    def train_epoch(self, epoch):
        """Train one epoch"""
        self.model.train()

        # Create iterators
        source_iter = iter(self.source_loader)
        target_iter = iter(self.target_loader)

        num_batches = min(len(self.source_loader), len(self.target_loader))

        total_loss = 0
        total_class_loss = 0
        total_domain_loss = 0

        progress_bar = tqdm(range(num_batches), desc=f'Epoch {epoch+1}/{self.num_epochs}')

        for _ in progress_bar:
            # Get source and target batches
            try:
                source_data, source_labels = next(source_iter)
            except StopIteration:
                source_iter = iter(self.source_loader)
                source_data, source_labels = next(source_iter)

            try:
                target_data, _ = next(target_iter)
            except StopIteration:
                target_iter = iter(self.target_loader)
                target_data, _ = next(target_iter)

            source_data = source_data.to(self.device)
            source_labels = source_labels.to(self.device)
            target_data = target_data.to(self.device)

            batch_size = source_data.size(0)

            # Compute adversarial weight
            lambda_p = self.compute_lambda(epoch, self.num_epochs)

            # Forward: source domain
            source_class_logits, source_domain_logits = self.model(source_data, lambda_p)

            # Forward: target domain
            _, target_domain_logits = self.model(target_data, lambda_p)

            # Source domain classification loss
            class_loss = self.class_criterion(source_class_logits, source_labels)

            # Domain discrimination loss (source = 1, target = 0)
            source_domain_labels = torch.ones(batch_size, 1).to(self.device)
            target_domain_labels = torch.zeros(target_data.size(0), 1).to(self.device)

            source_domain_loss = self.domain_criterion(source_domain_logits, source_domain_labels)
            target_domain_loss = self.domain_criterion(target_domain_logits, target_domain_labels)
            domain_loss = source_domain_loss + target_domain_loss

            # Total loss (GRL handles the adversarial sign flip)
            loss = class_loss + domain_loss

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Record
            total_loss += loss.item()
            total_class_loss += class_loss.item()
            total_domain_loss += domain_loss.item()

            progress_bar.set_postfix({
                'loss': loss.item(),
                'class': class_loss.item(),
                'domain': domain_loss.item(),
                'lambda': lambda_p
            })

        avg_loss = total_loss / num_batches
        avg_class_loss = total_class_loss / num_batches
        avg_domain_loss = total_domain_loss / num_batches

        return avg_loss, avg_class_loss, avg_domain_loss

    def evaluate(self):
        """Evaluate model"""
        self.model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for data, labels in self.test_loader:
                data = data.to(self.device)
                class_logits, _ = self.model(data, alpha=0.0)
                preds = torch.argmax(class_logits, dim=1).cpu().numpy()
                all_preds.extend(preds)
                all_labels.extend(labels.numpy())

        accuracy = accuracy_score(all_labels, all_preds)
        return accuracy

    def train(self):
        """Complete training workflow"""
        print(f"Starting DANN training for {self.num_epochs} epochs")

        best_acc = 0.0

        for epoch in range(self.num_epochs):
            # Train
            loss, class_loss, domain_loss = self.train_epoch(epoch)
            self.train_losses.append(loss)

            # Evaluate
            acc = self.evaluate()
            self.test_accuracies.append(acc)

            print(f"Epoch {epoch+1}/{self.num_epochs}")
            print(f"  Loss: {loss:.4f} (Class: {class_loss:.4f}, Domain: {domain_loss:.4f})")
            print(f"  Test Accuracy: {acc:.4f}")

            # Save best model
            if acc > best_acc:
                best_acc = acc
                torch.save(self.model.state_dict(), 'best_dann_model.pt')
                print(f"  Saved best model with accuracy {best_acc:.4f}")

        return self.train_losses, self.test_accuracies


# Usage example
def main():
    # Hyperparameters
    INPUT_DIM = 28 * 28
    HIDDEN_DIM = 256
    NUM_CLASSES = 10
    BATCH_SIZE = 128
    NUM_EPOCHS = 100
    LEARNING_RATE = 1e-3

    # Simulated data: source and target domains with a distribution shift
    # Source: random noise standing in for MNIST
    source_data = torch.randn(10000, 1, 28, 28)
    source_labels = torch.randint(0, NUM_CLASSES, (10000,))

    # Target: stands in for MNIST-M (MNIST with color and background),
    # simulated here as shifted Gaussian noise
    target_data = torch.randn(10000, 1, 28, 28) + 0.5
    target_labels = torch.randint(0, NUM_CLASSES, (10000,))

    # Test data (target domain)
    test_data = torch.randn(2000, 1, 28, 28) + 0.5
    test_labels = torch.randint(0, NUM_CLASSES, (2000,))

    # Dataloaders
    source_dataset = TensorDataset(source_data, source_labels)
    target_dataset = TensorDataset(target_data, target_labels)
    test_dataset = TensorDataset(test_data, test_labels)

    source_loader = DataLoader(source_dataset, batch_size=BATCH_SIZE, shuffle=True)
    target_loader = DataLoader(target_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # Model
    model = DANN(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES)

    # Trainer
    trainer = DANNTrainer(
        model=model,
        source_loader=source_loader,
        target_loader=target_loader,
        test_loader=test_loader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        gamma=10.0
    )

    # Train
    train_losses, test_accuracies = trainer.train()

    print(f"\nFinal Test Accuracy: {test_accuracies[-1]:.4f}")


if __name__ == '__main__':
    main()

Code Explanation

Gradient Reversal Layer

GradientReversalFunction inherits from torch.autograd.Function, customizing forward and backward passes:

@staticmethod
def forward(ctx, x, lambda_):
    ctx.lambda_ = lambda_
    return x.view_as(x)  # Forward pass is identity

@staticmethod
def backward(ctx, grad_output):
    return grad_output.neg() * ctx.lambda_, None  # Backward pass flips gradient

Adaptive Adversarial Weight

The compute_lambda method dynamically adjusts the adversarial weight based on training progress:

$$\lambda_p = \frac{2}{1 + \exp(-\gamma p)} - 1$$

Early in training $\lambda_p$ is close to 0 (focus on classification); later it approaches 1 (enhance domain alignment).

Adversarial Training

Training simultaneously optimizes classification loss and domain discrimination loss:

# Source domain classification loss
class_loss = self.class_criterion(source_class_logits, source_labels)

# Domain discrimination loss
domain_loss = source_domain_loss + target_domain_loss

# Total loss (gradient reversal layer automatically handles adversarial)
loss = class_loss + domain_loss
loss.backward()

Gradient reversal layer ensures domain discriminator gradients update feature extractor in opposite direction.

Deep Q&A

Q1: Why align feature distributions instead of directly training on source domain?

Models trained on source domain may fail on target domain because decision boundaries are in regions with low target domain data density.

Mathematical Explanation: let the decision boundary be $\{x : f(x) = 0\}$. Where the source data $p_S(x)$ has high density around the boundary, the model learns a precise boundary. But the target data $p_T(x)$ may concentrate in regions the source barely covers; there the boundary was never constrained by training data, so it is unreliable.

Intuitive Example: the training data consists of daytime images and the test data of nighttime images. Even for the same objects, the nighttime feature distribution differs greatly, so the decision boundary needs relearning.

Aligning feature distributions allows source and target domain data to mix in feature space, making boundary reliable in both domains.

Q2: Why does DANN use gradient reversal instead of directly maximizing domain discrimination loss?

Directly maximizing domain discrimination loss requires alternating optimization: first fix feature extractor to optimize domain discriminator, then fix domain discriminator to optimize feature extractor. This is equivalent to GAN training, easily unstable.

Gradient reversal layer allows single-step joint optimization: one backward pass simultaneously updates all parameters. Domain discriminator gradients are automatically flipped before passing to feature extractor, implementing adversarial.

Mathematically, gradient reversal is equivalent to optimizing the single objective

$$\mathcal{L}_y - \lambda\,\mathcal{L}_d$$

where the domain term $\mathcal{L}_d$ effectively carries a negative sign for the feature extractor (adversarial) and a positive sign for the domain discriminator (discrimination).

Q3: What's the difference between MMD and DANN? Which scenarios for each?

| Method | Distance Metric | Optimization | Advantages | Disadvantages |
|---|---|---|---|---|
| MMD | Kernel distance in RKHS | Direct minimization | Strong theory, stable | Kernel selection needed, high computation |
| DANN | Jensen-Shannon divergence | Adversarial training | Strong expressiveness, good adaptation | Unstable training, needs hyperparameter tuning |

Application Scenarios:

- MMD: small domain difference, limited data, stability needed
- DANN: large domain difference, abundant data, pursuing optimal performance

In practice, try MMD first (more stable); if the results are insufficient, switch to DANN.

Q4: Why is adaptive BN (AdaBN) effective? What problem does it solve?

BN layer statistics (mean, variance) capture low-order statistical properties of data. Even if source and target domains have same semantics, low-order statistics may differ.

Examples:

- Images: source domain images are brighter (higher mean), target domain images darker (lower mean)
- Sensors: source data from device A (small variance), target data from device B (large variance)

AdaBN replaces source domain statistics with target domain statistics, eliminating low-order statistical differences, letting model focus on high-level semantics.

Theoretical Explanation: AdaBN is equivalent to whitening target domain data to source domain statistical distribution, eliminating covariate shift.

Q5: How to choose domain adaptation method? Is there a decision tree?

Yes, following workflow:

1. Have target domain labeled data?
├─ Yes → Supervised domain adaptation (fine-tuning, importance weighting)
└─ No → 2

2. Where's main difference between source and target?
├─ Feature distribution ($P(X)$) → 3
├─ Label distribution ($P(Y)$) → Label shift correction
└─ Conditional distribution ($P(Y|X)$) → Need some target domain labels

3. Difference magnitude?
├─ Small (covariate shift) → AdaBN, CORAL
├─ Medium → MMD, DANN
└─ Large (cross-modal) → CycleGAN, ADDA

4. Data and computation resources?
├─ Limited data/resources → AdaBN, CORAL
└─ Abundant data/resources → DANN, ADDA

Q6: Can domain adaptation harm source domain performance?

Yes, this is called negative transfer. The reason is that domain alignment may harm discriminativeness: to make source and target features close, the model may lose information useful for classification.

Ben-David theory tells us the target domain risk is affected by three terms:

$$R_T(h) \le R_S(h) + \frac{1}{2}\,d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) + \lambda^*$$

Excessive alignment may decrease $d_{\mathcal{H}\Delta\mathcal{H}}$ but increase $R_S(h)$ and $\lambda^*$.

Solutions:

1. Add a source validation set: monitor source performance, stop if it degrades
2. Adjust the adversarial weight: don't make it too large
3. Conditional alignment: only align same-class samples (conditional domain adaptation)

Q7: How to evaluate domain adaptation effectiveness? Besides target domain accuracy?

Besides accuracy, one can evaluate:

  1. Proxy $\mathcal{A}$-distance: measures the domain difference as

$$\hat{d}_{\mathcal{A}} = 2\,(1 - 2\epsilon)$$

where $\epsilon$ is the error rate of a classifier trained to distinguish the domains. The smaller $\hat{d}_{\mathcal{A}}$, the more aligned the domains.

  2. MMD: Directly compute MMD in feature space

  3. t-SNE Visualization: Reduce source and target features to 2D, observe if mixed

  4. Inter/Intra-class Distance Ratio: the ratio of average inter-class distance to average intra-class distance in feature space; larger is better (classes separated but domains mixed)

  5. Per-class Accuracy: Check if all classes improve (avoid some classes degrading)

Q8: How to adjust DANN adversarial weight? Automatic adjustment methods?

DANN uses the adaptive weight:

$$\lambda_p = \frac{2}{1 + \exp(-\gamma p)} - 1$$

where $p$ is the training progress and $\gamma$ controls the growth speed.

Hyperparameter $\gamma$ Selection:

- $\gamma$ too small (like 1): the adversarial weight grows too slowly, giving insufficient domain alignment
- $\gamma$ too large (like 100): the adversarial weight grows too fast, destroying classification learning
- Recommended: $\gamma = 10$

Automatic Adjustment: validation-set performance can be used to dynamically adjust $\lambda$, for example reducing it when source accuracy starts to degrade.

Q9: How to handle inconsistent classes between source and target (partial/open-set domain adaptation)?

Standard domain adaptation assumes source and target classes are the same. But in practice:

  1. Partial DA: target classes are a subset of source classes ($\mathcal{Y}_T \subset \mathcal{Y}_S$)
  2. Open-set DA: target has classes the source doesn't have ($\mathcal{Y}_T \setminus \mathcal{Y}_S \ne \emptyset$)

Solutions:

Partial DA:

- Only align the source classes that also exist in the target
- Use class weights: give small weights to source classes not present in the target

Open-set DA:

- Add an "unknown" class to detect new classes in the target
- Use an open-world classifier: reject classification when prediction confidence is low

Q10: Why does CycleGAN's cycle consistency loss ensure semantics unchanged?

Cycle consistency loss:

$$\mathcal{L}_{\mathrm{cyc}} = \mathbb{E}_{x \sim P_S}\big[\|G_{T \to S}(G_{S \to T}(x)) - x\|_1\big] + \mathbb{E}_{y \sim P_T}\big[\|G_{S \to T}(G_{T \to S}(y)) - y\|_1\big]$$

Intuition: if $G_{S \to T}$ transforms source to target and $G_{T \to S}$ transforms target to source, then $G_{T \to S} \circ G_{S \to T}$ should be close to the identity mapping.

Mathematically, cycle consistency is equivalent to requiring $G_{S \to T}$ and $G_{T \to S}$ to be (approximate) inverse mappings:

$$G_{T \to S} \circ G_{S \to T} \approx \mathrm{id}, \qquad G_{S \to T} \circ G_{T \to S} \approx \mathrm{id}$$

This ensures semantic information is not lost: $x$ can be recovered after applying $G_{S \to T}$ and then $G_{T \to S}$.

But note: cycle consistency cannot completely guarantee that semantics are unchanged. For example, the pair $(G_{S \to T}, G_{T \to S})$ could consistently permute content (mapping each source class to the wrong target class and back again), and the cycle constraint would still be satisfied while the semantics are obviously wrong. Therefore other constraints (like perceptual loss or identity loss) are usually needed.

Q11: How to select MMD kernel function and bandwidth?

Kernel Function Selection:

  1. Gaussian Kernel (most common):

$$k(x, y) = \exp\!\left(-\frac{\|x - y\|^2}{2\sigma^2}\right)$$

  2. Multi-kernel combination: use multiple Gaussian kernels with different bandwidths, $k = \sum_u \beta_u\, k_{\sigma_u}$

Bandwidth $\sigma$ Selection:

Empirical rule (median heuristic):

$$\sigma = \mathrm{median}\,\{\|x_i - x_j\| : i \ne j\}$$

i.e., the median of all pairwise sample distances. Intuition: this $\sigma$ makes the kernel neither too local ($\sigma$ too small) nor too global ($\sigma$ too large).

Multi-kernel MMD: use a family of bandwidths spread around the median value, e.g. $\{2^k \sigma_0 : k = -2, \dots, 2\}$, where $\sigma_0$ is the median-heuristic bandwidth.
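The median heuristic is a one-liner over the pooled pairwise distances; this sketch (with synthetic Gaussian features, and a factor-of-two bandwidth family as one common choice) shows the computation:

```python
import numpy as np

def median_heuristic(x, y):
    """Median-heuristic bandwidth: median pairwise distance in the pooled sample."""
    z = np.vstack([x, y])
    dists = np.sqrt(((z[:, None, :] - z[None, :, :]) ** 2).sum(-1))
    iu = np.triu_indices(len(z), k=1)  # keep each pair i < j once
    return np.median(dists[iu])

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 3))
y = rng.normal(loc=1.0, size=(100, 3))
sigma0 = median_heuristic(x, y)
# Bandwidth family around the median value for multi-kernel MMD
bandwidths = [sigma0 * 2.0 ** k for k in range(-2, 3)]
print(sigma0, bandwidths)
```

Each bandwidth in the family defines one Gaussian kernel; the multi-kernel MMD sums (or weights) the per-kernel estimates.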

Q12: Where is domain adaptation most valuable in practical applications?

  1. Medical Imaging: Different hospitals, different devices have different data distributions
    • CT from Siemens to GE
    • MRI from 1.5T to 3T
  2. Autonomous Driving: Different weather, lighting, cities have different data distributions
    • Sunny to rainy
    • San Francisco to New York
  3. Recommendation Systems: Different countries, different time periods have different user behaviors
    • US to China
    • 2020 to 2026
  4. Sentiment Analysis: Different domains have different sentiment expressions
    • Movie reviews to product reviews
    • Formal text to social media
  5. Object Detection/Segmentation: Synthetic data to real data
    • GTA5 game scenes to real street scenes

Domain adaptation is especially suitable for scenarios where target-domain annotation is difficult but source data is abundant.

References

  1. Domain-Adversarial Training of Neural Networks (DANN)
    Ganin et al., JMLR 2016
    https://arxiv.org/abs/1505.07818

  2. Learning Transferable Features with Deep Adaptation Networks (DAN)
    Long et al., ICML 2015
    https://arxiv.org/abs/1502.02791

  3. Deep CORAL: Correlation Alignment for Deep Domain Adaptation
    Sun and Saenko, ECCV 2016
    https://arxiv.org/abs/1607.01719

  4. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks (CycleGAN)
    Zhu et al., ICCV 2017
    https://arxiv.org/abs/1703.10593

  5. Adversarial Discriminative Domain Adaptation (ADDA)
    Tzeng et al., CVPR 2017
    https://arxiv.org/abs/1702.05464

  6. A Theory of Learning from Different Domains (Ben-David Theory)
    Ben-David et al., Machine Learning 2010
    https://link.springer.com/article/10.1007/s10994-009-5152-4

  7. Revisiting Batch Normalization For Practical Domain Adaptation (AdaBN)
    Li et al., ICLR Workshop 2017
    https://arxiv.org/abs/1603.04779

  8. Maximum Mean Discrepancy
    Gretton et al., JMLR 2012
    https://jmlr.org/papers/v13/gretton12a.html

  9. Conditional Adversarial Domain Adaptation (CDAN)
    Long et al., NeurIPS 2018
    https://arxiv.org/abs/1705.10667

  10. Universal Domain Adaptation
    You et al., CVPR 2019
    https://arxiv.org/abs/1902.06906

  11. Covariate Shift Adaptation by Importance Weighted Cross-Validation
    Sugiyama et al., JMLR 2007
    http://www.jmlr.org/papers/v8/sugiyama07a.html

  12. Detecting and Correcting for Label Shift with Black Box Predictors
    Lipton et al., ICML 2018
    https://arxiv.org/abs/1802.03916

Summary

Domain adaptation is the most challenging but also most practical problem in transfer learning. This article started from the mathematical characterization of distribution shift (covariate shift, label shift, concept shift), derived the theoretical foundation of unsupervised domain adaptation (the Ben-David theory), and explained classic methods like DANN, MMD, and CORAL in detail.

We see that the core of domain adaptation is finding a balance between aligning feature distributions and maintaining discriminativeness. DANN implicitly minimizes domain difference through adversarial training, MMD explicitly measures distribution distance through kernel functions, and AdaBN eliminates low-order differences by adjusting statistics. Each method has its advantages and limitations, requiring selection based on specific application scenarios.

The complete DANN implementation demonstrates key techniques like gradient reversal layers, domain discriminators, and adaptive adversarial weights. In the next chapter, we'll explore Few-Shot Learning, studying how to learn new categories from extremely few samples.

  • Post title:Transfer Learning (3): Domain Adaptation Methods
  • Post author:Chen Kai
  • Create time:2024-11-15 10:15:00
  • Post link:https://www.chenk.top/transfer-learning-3-domain-adaptation/
  • Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.