Transfer Learning (6): Multi-Task Learning
Chen Kai

Multi-Task Learning (MTL) is a machine learning paradigm that improves model generalization by simultaneously learning multiple related tasks. Rich Caruana's pioneering 1997 paper "Multitask Learning" demonstrated how shared representations help models learn more robust features. In modern deep learning, multi-task learning has achieved tremendous success in computer vision (simultaneous detection, segmentation, depth estimation), natural language processing (joint entity recognition and relation extraction), and recommendation systems (simultaneous CTR and CVR prediction). But multi-task learning is far more than simply summing multiple loss functions — how to design shared structures, how to balance learning across different tasks, and how to handle negative transfer between tasks are all questions requiring deep investigation.

This article derives the mathematical foundations of multi-task learning from first principles, analyzes the pros and cons of hard vs soft parameter sharing, explains task relationship learning and task clustering methods in detail, deeply analyzes gradient conflict problems and solutions (PCGrad, GradNorm, CAGrad, etc.), introduces auxiliary task design principles, and provides a complete multi-task network implementation (including dynamic weight adjustment, gradient projection, task balancing and other industrial-grade techniques). We'll see that multi-task learning essentially seeks a Pareto optimal solution satisfying multiple optimization objectives.

Motivation for Multi-Task Learning

From Single-Task to Multi-Task: Sharing Inductive Bias

Single-task learning trains an independent model for each task, while multi-task learning has all tasks share some parameters or representations. The core assumption behind this is: related tasks share common underlying structures.

Intuitive Example: Consider three tasks for image scene understanding:

  • Object Detection: identify object locations and categories in images
  • Semantic Segmentation: assign a category label to each pixel
  • Depth Estimation: predict a depth value for each pixel

All three tasks require understanding spatial structure, object boundaries, texture, and other low-level features. Rather than learning these features independently for each task, the tasks should share a feature extractor, keeping only the high-level task-specific heads separate.

Mathematical Perspective on Multi-Task Learning: Regularization Effect

From an optimization perspective, multi-task learning introduces implicit regularization. Given $T$ tasks where task $t$'s loss is $L_t$, with $\theta_{sh}$ being shared parameters and $\theta_t$ being task-specific parameters, the multi-task optimization objective is:

$$\min_{\theta_{sh}, \theta_1, \ldots, \theta_T} \sum_{t=1}^{T} w_t \, L_t(\theta_{sh}, \theta_t)$$

where $w_t \ge 0$ are task weights.

Key Insight: Shared parameters $\theta_{sh}$ must be effective for all tasks simultaneously, which constrains their representational space and provides regularization. Formally, multi-task learning is equivalent to adding an implicit constraint to the single-task objective:

$$\min_{\theta} L_{\text{main}}(\theta) \quad \text{s.t.} \quad L_{\text{aux}}(\theta) \le \epsilon$$

where $L_{\text{aux}}$ is the auxiliary task loss and $\epsilon$ is a tolerance. This constraint prevents the model from overfitting to the main task's training data.

Data Augmentation Perspective: Auxiliary Tasks Provide Additional Signals

Multi-task learning can be viewed as a data augmentation strategy. When the main task has limited labeled data, auxiliary tasks can provide additional supervisory signals.

Example: In low-resource machine translation:

  • Main Task: English → Swahili translation (only 100K sentence pairs)
  • Auxiliary Task: English → French translation (10M sentence pairs)

Although French and Swahili are different, the English encoder can learn better English representations from abundant English-French data, thereby helping English-Swahili translation.

Experiments show that introducing auxiliary tasks can improve main task performance by 5-20% (depending on task relatedness and data volume).

Computational Efficiency: Parameter Sharing Reduces Redundancy

From an engineering perspective, multi-task learning significantly reduces model parameters and computation through parameter sharing:

  • Single-Task: $T$ tasks each have their own ResNet-50 encoder (~25M parameters), for a total of $T \times 25\text{M} = 125\text{M}$ parameters (assuming $T = 5$)
  • Multi-Task: $T$ tasks share one encoder, for a total of $25\text{M} + 5 \times 2\text{M} = 35\text{M}$ parameters (each task head has 2M parameters)

Parameters reduced by about 70%, and inference requires only one forward pass to obtain all task outputs, dramatically improving efficiency.
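As a sanity check, the parameter accounting can be reproduced with a few lines of arithmetic. The figures below are the rough assumptions used above: a ~25M-parameter ResNet-50 encoder (per-task classifier included), ~2M-parameter task heads, and 5 tasks.

```python
# Back-of-envelope parameter accounting for hard parameter sharing.
ENCODER_PARAMS = 25_000_000   # rough ResNet-50 size, classifier included
HEAD_PARAMS = 2_000_000       # rough size of one task-specific head
T = 5                         # number of tasks

single_task = T * ENCODER_PARAMS               # one full model per task
multi_task = ENCODER_PARAMS + T * HEAD_PARAMS  # one shared encoder + T heads

reduction = 1 - multi_task / single_task
print(f"single-task: {single_task / 1e6:.0f}M, "
      f"multi-task: {multi_task / 1e6:.0f}M, reduction: {reduction:.0%}")
```

With these assumptions the reduction works out to about 72%, consistent with the "about 70%" figure.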

Negative Transfer: The Risk of Multi-Task Learning

However, multi-task learning is not always beneficial. When tasks are unrelated or even conflicting, negative transfer may occur: joint training performance lower than separate training.

Example:

  • Task A: face recognition (requires learning fine-grained facial features)
  • Task B: scene classification (requires learning global layout and context)

These two tasks have very different feature requirements; forcing parameter sharing may cause mutual interference.

Experimental Data (CIFAR-100):

  • Train Task A separately: 82% accuracy
  • Train Task B separately: 78% accuracy
  • Joint training (naive MTL): 79% and 74% (both tasks decline)

Therefore, how to design shared structures, select related tasks, and balance task weights are the keys to multi-task learning success.

Parameter Sharing Strategies: Hard Sharing vs Soft Sharing

The core of multi-task learning is how to share information between tasks. There are two main paradigms: hard parameter sharing and soft parameter sharing.

Hard Parameter Sharing

Hard parameter sharing is the most common multi-task learning architecture, proposed by Caruana in 1993.

Architecture Design:

  • Shared Layers: all tasks share the same bottom-level network (e.g., convolutional layers, Transformer layers)
  • Task-Specific Layers: each task has an independent output head (e.g., fully connected layers, decoders)

Formally, for input $x$:

$$\hat{y}_t = h_t(f_{sh}(x))$$

where $f_{sh}$ is the shared feature extractor and $h_t$ is task $t$'s prediction head.

Advantages:

  1. Strong Regularization: shared parameters are constrained by multiple tasks, reducing overfitting risk
  2. Parameter Efficiency: most parameters are shared, keeping the model compact
  3. Simple and Direct: easy to implement and train

Disadvantages:

  1. Poor Flexibility: all tasks must use the same shared representation, unsuitable for highly divergent tasks
  2. High Negative Transfer Risk: conflicting tasks interfere with each other

Empirical Design Principles:

  • Shared layers should learn general features (e.g., low CNN layers learning edges and textures)
  • Task-specific layers should have sufficient capacity to handle task-specific patterns
  • Typically share the first 70-80% of layers and keep the last 20-30% independent

Soft Parameter Sharing

Soft parameter sharing was proposed by Duong et al. in 2015, allowing each task to have its own parameters but encouraging parameter similarity through regularization.

Basic Form: Each task $t$ has an independent model with parameters $\theta_t$, and a parameter-similarity constraint is added:

$$\min_{\{\theta_t\}} \sum_{t=1}^{T} L_t(\theta_t) + \lambda \sum_{s < t} \|\theta_s - \theta_t\|_2^2$$

The second term is $\ell_2$ regularization, penalizing differences between task parameters.
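A minimal sketch of this coupled objective on two synthetic linear-regression tasks (all data and hyperparameters here are illustrative): gradient descent on the joint loss pulls the two tasks' parameter vectors toward each other.

```python
import numpy as np

# Soft parameter sharing on two toy, closely related regression tasks.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
w_true = np.ones(5)
y1 = X @ w_true          # task 1 target
y2 = X @ (w_true * 1.2)  # task 2 target: related but not identical

def train(lam, steps=500, lr=0.05):
    """Jointly fit both tasks; lam scales the L2 coupling penalty."""
    t1, t2 = np.zeros(5), np.zeros(5)
    for _ in range(steps):
        g1 = 2 * X.T @ (X @ t1 - y1) / len(X) + 2 * lam * (t1 - t2)
        g2 = 2 * X.T @ (X @ t2 - y2) / len(X) + 2 * lam * (t2 - t1)
        t1, t2 = t1 - lr * g1, t2 - lr * g2
    return t1, t2

gap_uncoupled = np.linalg.norm(np.subtract(*train(lam=0.0)))
gap_coupled = np.linalg.norm(np.subtract(*train(lam=1.0)))
print(f"parameter gap without coupling: {gap_uncoupled:.3f}, "
      f"with coupling: {gap_coupled:.3f}")
```

With `lam=0` the two solutions converge to their independent optima; with `lam=1` the penalty shrinks the gap between them, which is exactly the soft-sharing effect.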

Cross-Stitch Networks:

Misra et al. proposed Cross-Stitch Networks in 2016, allowing tasks to exchange information at multiple levels.

Given two tasks' layer-$l$ activations $x_A^l$ and $x_B^l$, the cross-stitch unit computes:

$$\begin{bmatrix} \tilde{x}_A^l \\ \tilde{x}_B^l \end{bmatrix} = \begin{bmatrix} \alpha_{AA} & \alpha_{AB} \\ \alpha_{BA} & \alpha_{BB} \end{bmatrix} \begin{bmatrix} x_A^l \\ x_B^l \end{bmatrix}$$

where the $\alpha$ values are learnable parameters. The magnitude of $\alpha_{AB}$ reflects how much information task $A$ borrows from task $B$.
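A toy cross-stitch unit takes a few lines of numpy (the alpha values and activations below are illustrative, not learned):

```python
import numpy as np

# A cross-stitch unit: each task's new activation is a learnable linear
# combination of both tasks' activations at the same layer.
alpha = np.array([[0.9, 0.1],
                  [0.2, 0.8]])      # mostly keep own features, borrow a little

x_A = np.array([1.0, 2.0, 3.0])    # task A activations at layer l
x_B = np.array([4.0, 5.0, 6.0])    # task B activations at layer l

stacked = np.stack([x_A, x_B])     # shape (2, d)
mixed = alpha @ stacked            # apply the 2x2 mixing matrix per feature

x_A_new, x_B_new = mixed
print(x_A_new)   # 0.9 * x_A + 0.1 * x_B
```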

Multi-Task Attention Network (MTAN):

Liu et al. proposed Multi-Task Attention Network in 2019, using attention mechanisms to dynamically select which features to share.

For shared features $f$ and task $t$, define task-specific attention weights:

$$a_t = \sigma(g_t(f)), \qquad f_t = a_t \odot f$$

where $\sigma$ is the sigmoid function and $\odot$ is element-wise multiplication. Each task "softly" selects useful shared features through attention.
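The gating idea can be sketched in numpy; the attention scores here are fixed placeholders standing in for the output of a trained attention network $g_t$:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

shared_features = np.array([0.5, -1.0, 2.0, 0.0])
task_scores = np.array([4.0, -4.0, 0.0, 2.0])  # placeholder for g_t(f)

attention = sigmoid(task_scores)               # per-feature gate in (0, 1)
task_features = attention * shared_features    # element-wise soft selection

print(attention.round(2))   # high-score features pass through almost unchanged
```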

Advantages:

  1. High Flexibility: each task can have different parameters, giving strong adaptability
  2. Low Negative Transfer Risk: tasks can selectively ignore irrelevant information

Disadvantages:

  1. Large Parameter Count: each task has independent parameters, inflating the model
  2. Complex Training: requires careful tuning of the regularization strength

Dynamic Network Architectures: Conditional Computation

Recently, dynamic networks allow models to dynamically adjust computation paths based on input or task.

Routing Networks:

Rosenbaum et al. proposed Routing Networks in 2018, using routing functions to decide which subnetworks each task uses.

Given $K$ subnetworks $\{E_1, \ldots, E_K\}$ and task $t$'s routing weights $r_t \in \mathbb{R}^K$, task $t$'s output is:

$$y_t = \sum_{k=1}^{K} r_{t,k} \, E_k(x)$$

Routing weights can be fixed (discrete selection) or learnable (soft routing).

Task-Conditional Adapters:

Rebuffi et al. proposed task-conditional adapters in 2017, inserting task-specific adapter modules at each layer of a pre-trained model.

For task $t$ and layer $l$, the adapted representation is defined as:

$$h^{l+1} = h^{l} + \text{Adapter}_t(h^{l})$$

Adapters are typically small bottleneck networks, $\text{Adapter}(h) = W_{\text{up}} \, \phi(W_{\text{down}} h)$ with $W_{\text{down}} \in \mathbb{R}^{d/r \times d}$ and $W_{\text{up}} \in \mathbb{R}^{d \times d/r}$, where $r$ is the reduction ratio (e.g., 16). Only the adapter parameters are task-specific; the remaining parameters are shared.

Advantage: When adding new tasks to a pre-trained model, only the adapters need training, which is efficient and avoids catastrophic forgetting.
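A shape-level sketch of a bottleneck adapter (dimensions and random weights below are illustrative):

```python
import numpy as np

# Bottleneck adapter: project d -> d/r -> d with a residual connection.
d, r = 64, 16
rng = np.random.default_rng(0)
W_down = rng.normal(scale=0.01, size=(d // r, d))  # 64 -> 4
W_up = rng.normal(scale=0.01, size=(d, d // r))    # 4 -> 64

def adapter(h):
    # residual connection + ReLU bottleneck
    return h + W_up @ np.maximum(W_down @ h, 0.0)

h = rng.normal(size=d)
out = adapter(h)
extra_params = W_down.size + W_up.size
print(f"adapter params: {extra_params} vs full d x d layer: {d * d}")
```

The per-task cost (512 parameters here) is a small fraction of a full $d \times d$ layer, which is why adapters scale well to many tasks.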

Task Relationship Learning: Discovering Correlations

Multi-task learning effectiveness largely depends on inter-task correlations. How to quantify and utilize task relationships is an important research direction.

Task Affinity Matrix

The task affinity matrix $A \in \mathbb{R}^{T \times T}$ quantifies pairwise task correlations, where $A_{ij}$ represents the similarity between task $i$ and task $j$.

Computation Method 1: Performance Correlation

Zamir et al. proposed Taskonomy in 2018, measuring task affinity through transfer learning experiments:

  1. Train a model on task $i$
  2. Transfer it to task $j$ and measure performance $P_{ij}$
  3. Define the affinity $A_{ij} = P_{ij} - P_{0}$, where $P_{0}$ is the baseline performance from random initialization

Computation Method 2: Gradient Correlation

Yu et al. proposed in 2020 to base affinity on gradient cosine similarity:

$$A_{ij} = \cos(g_i, g_j) = \frac{g_i \cdot g_j}{\|g_i\| \, \|g_j\|}$$

High positive correlation means the tasks update parameters in the same direction; low or negative correlation means conflict.
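A small numpy sketch of this affinity computation, on made-up gradient vectors for three hypothetical tasks:

```python
import numpy as np

# Gradient-cosine task affinity: positive cosine suggests tasks reinforce
# each other, negative cosine suggests conflict.
def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

grads = {
    "detection":    np.array([1.0, 0.5, 0.2]),
    "segmentation": np.array([0.9, 0.6, 0.1]),    # similar direction
    "face_id":      np.array([-1.0, 0.3, -0.5]),  # mostly opposing
}

names = list(grads)
affinity = np.array([[cosine(grads[a], grads[b]) for b in names] for a in names])
print(affinity.round(2))
```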

Computation Method 3: Feature Representation Similarity

Compute the CKA (Centered Kernel Alignment) of the feature representations learned during training:

$$A_{ij} = \text{CKA}(K_i, K_j)$$

where $K_i$ is task $i$'s feature kernel matrix.

Task Clustering: Grouped Sharing

When the task count is large, tasks can first be clustered, and then tasks within the same group share parameters.

Hierarchical Multi-Task Learning:

Assume hierarchical relationships between tasks, for example:

  • Coarse-grained Task: scene classification (indoor vs outdoor)
  • Fine-grained Task: specific scene categories (bedroom, kitchen, street, park)

A hierarchical network can be designed:

  1. Shared layers extract general features
  2. Middle layers serve the coarse-grained task
  3. Top layers serve the fine-grained tasks, depending on the middle layers' output

Loss function:

$$L = w_{\text{coarse}} L_{\text{coarse}} + w_{\text{fine}} L_{\text{fine}}$$

Adaptive Task Grouping:

Standley et al. asked "Which Tasks Should Be Learned Together?" in 2020, automatically searching for the optimal task grouping.

Algorithm workflow:

  1. Initialize by training each task independently
  2. Use a policy network to sample task grouping schemes
  3. Train a multi-task model according to the grouping and evaluate validation performance
  4. Use the performance as a reward to update the policy network
  5. Repeat until the optimal grouping is found

Experiments show that automated grouping is more effective than manual design or global sharing.

Task Selection: Choosing Primary and Auxiliary Tasks

When there's one primary task and multiple candidate auxiliary tasks, how to select the most helpful auxiliary tasks?

Greedy Selection Strategy:

  1. Train the primary task alone, record performance $P_0$
  2. For each candidate auxiliary task $t$, jointly train the primary task with $t$ and record the primary task's performance $P_t$
  3. Compute the gain $\Delta P_t = P_t - P_0$
  4. Select the $K$ auxiliary tasks with the highest gains
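The greedy procedure can be sketched in a few lines; the task names and performance numbers below are placeholders standing in for real validation results:

```python
# Greedy auxiliary-task selection with placeholder performance numbers.
P0 = 0.80  # primary task trained alone

joint_performance = {      # primary-task performance under joint training
    "pos_tagging": 0.83,
    "parsing": 0.82,
    "language_modeling": 0.81,
    "sentiment": 0.78,     # hurts the primary task: negative transfer
}

gains = {t: p - P0 for t, p in joint_performance.items()}
K = 2
selected = sorted(gains, key=gains.get, reverse=True)[:K]
print(selected)   # the K auxiliary tasks with the highest gains
```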

Meta-Learning Based Selection:

Du et al. proposed Automated Auxiliary Learning in 2020, using meta-learning to predict auxiliary task effectiveness:

  • Rapidly train model on small data
  • Use meta-model to predict each auxiliary task's help to primary task
  • Select auxiliary tasks with highest predicted benefits

Advantage: Avoids overhead of fully training all candidate tasks.

Gradient Conflicts and Task Balancing

One of the biggest challenges in multi-task learning is gradient conflict: different tasks' gradients may point in different directions, causing training instability or performance degradation.

Problem Analysis: What is Gradient Conflict

Given two tasks' gradients $g_1 = \nabla_\theta L_1$ and $g_2 = \nabla_\theta L_2$, naive multi-task optimization uses the gradient sum:

$$g = g_1 + g_2$$

Problem: If $g_1 \cdot g_2 < 0$ (negative cosine similarity), the two gradients point in opposing directions, and the combined gradient may hurt one of the tasks.

Example (illustrative values):

  • Task 1 gradient: $g_1 = (1, 0)$
  • Task 2 gradient: $g_2 = (-3, 1)$
  • Average gradient: $g = (g_1 + g_2)/2 = (-1, 0.5)$

The average gradient's inner product with $g_1$ is $g \cdot g_1 = -1 < 0$, meaning the update increases task 1's loss!

Formally, gradient conflict is defined as:

$$\text{conflict}(i, j) = \mathbb{1}[g_i \cdot g_j < 0]$$

Experiments show the gradient conflict ratio can reach 30-50% during multi-task training, seriously affecting convergence.
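A quick numerical check of such a conflict, using two illustrative 2-D gradient vectors:

```python
import numpy as np

# Two conflicting task gradients (illustrative values).
g1 = np.array([1.0, 0.0])
g2 = np.array([-3.0, 1.0])

avg = (g1 + g2) / 2
cos = g1 @ g2 / (np.linalg.norm(g1) * np.linalg.norm(g2))

print(f"cos(g1, g2) = {cos:.2f}")    # negative: the tasks conflict
print(f"avg . g1 = {avg @ g1:.2f}")  # negative: the averaged step increases task 1's loss
```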

Static Weight Methods: Manual Tuning

Simplest method is manually setting task weights, but requires extensive experimentation.

Uniform Weights:

$$w_t = \frac{1}{T} \quad \text{for all } t$$

Simple but often suboptimal, as different tasks' loss scales may vary greatly (e.g., classification cross-entropy vs. regression MSE).

Uncertainty Weighting:

Kendall et al. proposed uncertainty weighting in 2018, using task uncertainty to automatically adjust weights.

Assume task $t$'s output follows a Gaussian distribution with observation noise $\sigma_t$; the negative log-likelihood then gives:

$$L_t' = \frac{1}{2\sigma_t^2} L_t + \log \sigma_t$$

where $\sigma_t$ is a learnable task-uncertainty parameter. The joint loss is:

$$L = \sum_{t=1}^{T} \left( \frac{1}{2\sigma_t^2} L_t + \log \sigma_t \right)$$

Intuition:

  • If task $t$'s $\sigma_t$ is large (high uncertainty), that task's effective weight $\frac{1}{2\sigma_t^2}$ is small
  • The $\log \sigma_t$ term prevents $\sigma_t \to \infty$ (a degenerate solution)

Experiments show uncertainty weighting improves 2-5% over uniform weights.
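The weighting behavior can be illustrated with fixed sigma values (in practice the sigmas are learned jointly with the model; here they are hand-picked placeholders):

```python
import numpy as np

# Uncertainty weighting: each task's loss is scaled by 1/(2*sigma^2),
# with a log(sigma) penalty preventing sigma from growing unboundedly.
sigmas = np.array([2.0, 0.5])   # task 1 noisy, task 2 confident
losses = np.array([2.0, 0.5])   # illustrative per-task loss values

weights = 1 / (2 * sigmas**2)
joint = np.sum(weights * losses + np.log(sigmas))

print(f"effective weights: {weights}")  # high-uncertainty task gets a small weight
print(f"joint loss: {joint:.3f}")
```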

GradNorm: Gradient Magnitude Normalization

Chen et al. proposed GradNorm in 2018, balancing tasks by adjusting weights to make gradient magnitudes balanced.

Core Idea: Each task's gradient magnitude should be proportional to its training speed.

Given task $t$'s loss at training step $s$ as $L_t(s)$, define the loss ratio $\tilde{L}_t(s) = L_t(s) / L_t(0)$ and the relative inverse training rate:

$$r_t(s) = \frac{\tilde{L}_t(s)}{\frac{1}{T} \sum_{t'} \tilde{L}_{t'}(s)}$$

$r_t > 1$ means task $t$ is training slowly; $r_t < 1$ means it is training fast.

Objective: Adjust the weights $w_t$ so that each task's gradient norm $G_t$ approaches:

$$G_t \rightarrow \bar{G} \cdot r_t^{\alpha}$$

where $\bar{G}$ is the average gradient magnitude across all tasks and $\alpha$ is a hyperparameter (typically 1.5).

Algorithm:

  1. Forward pass, compute the weighted loss $L = \sum_t w_t L_t$
  2. Compute each task's gradient norm $G_t = \|\nabla_W (w_t L_t)\|$ with respect to the last shared layer $W$
  3. Minimize $L_{\text{grad}} = \sum_t |G_t - \bar{G} \cdot r_t^{\alpha}|$ with respect to the weights $w_t$ (treating the target as a constant)
  4. Renormalize the weights so that $\sum_t w_t = T$

Effect: GradNorm shows significant improvement (3-8%) over uniform weights and uncertainty weighting on multiple datasets.

PCGrad: Projecting Conflicting Gradients

Yu et al. proposed Projecting Conflicting Gradients (PCGrad) in 2020, directly eliminating gradient conflicts.

Core Idea: When two tasks' gradients conflict, project one gradient onto the other gradient's normal plane.

For tasks $i$ and $j$, if $g_i \cdot g_j < 0$, replace $g_i$ with:

$$g_i' = g_i - \frac{g_i \cdot g_j}{\|g_j\|^2} g_j$$

This is the projection of $g_i$ onto the orthogonal complement of $g_j$, guaranteeing $g_i' \cdot g_j = 0$ (no conflict).

Algorithm (for $T$ tasks):

for each task i:
    g_i = compute gradient of task i
    for each other task j != i:
        if g_i . g_j < 0:
            g_i = g_i - (g_i . g_j / ||g_j||^2) * g_j
    store modified gradient g_i

final_gradient = mean of all modified g_i
update parameters with final_gradient

Theoretical Guarantee: In the two-task case, PCGrad guarantees that the final update direction $g$ satisfies $g \cdot g_j \ge 0$ for both tasks, meaning the update at least does not increase either task's loss to first order.
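The projection step can be verified numerically on two illustrative gradients:

```python
import numpy as np

def pcgrad_project(g_i, g_j):
    """If g_i conflicts with g_j, remove the component of g_i along g_j."""
    dot = g_i @ g_j
    if dot < 0:
        g_i = g_i - dot / (g_j @ g_j) * g_j
    return g_i

g1 = np.array([1.0, 0.0])
g2 = np.array([-3.0, 1.0])

g1_proj = pcgrad_project(g1, g2)
print(g1_proj, g1_proj @ g2)   # the projected gradient no longer conflicts with g2
```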

Experiments (NYUv2 dataset, semantic segmentation + depth estimation + surface normals):

  • Uniform weights: mIoU 40.2%, depth error 0.61
  • PCGrad: mIoU 42.7%, depth error 0.58

PCGrad significantly alleviates gradient conflicts, improving all tasks' performance.

CAGrad: Conflict-Averse Gradient Descent

Liu et al. proposed Conflict-Averse Gradient descent (CAGrad) in 2021, seeking Pareto optimal gradient direction.

Pareto Optimality: A solution is Pareto optimal if and only if no other solution improves at least one task without degrading another.

CAGrad models the choice of update direction as an optimization problem:

$$\max_{d} \ \min_{t} \ \langle g_t, d \rangle \quad \text{s.t.} \quad \|d - g_0\| \le c \, \|g_0\|$$

where $g_0$ is the average gradient and $c$ controls how far the update may deviate from it. That is, find the direction that maximizes the worst-case loss decrease across tasks while staying close to the average gradient.

This is a quadratic programming (QP) problem, solvable efficiently with existing solvers (like CVXPY).

Experiments: CAGrad achieves best Pareto front on multiple datasets, superior to PCGrad and GradNorm.

MGDA: Multi-Objective Gradient Descent Algorithm

The Multi-Objective Gradient Descent Algorithm (MGDA), proposed by Désidéri in 2012, seeks a common descent direction for all tasks.

Core Idea: Find a direction $d$ whose inner product with every task gradient is non-negative (a descent direction for all tasks).

Formalized as the min-norm point in the convex hull of the task gradients:

$$\min_{\lambda \in \Delta^{T}} \left\| \sum_{t=1}^{T} \lambda_t g_t \right\|^2, \qquad d = \sum_{t=1}^{T} \lambda_t g_t$$

where $\Delta^T$ is the probability simplex. This is also a convex optimization problem, solvable with the Frank-Wolfe algorithm.
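For two tasks the min-norm problem has a closed form, which makes for a compact sketch (Frank-Wolfe is only needed for more tasks; the gradient values are illustrative):

```python
import numpy as np

def mgda_two_tasks(g1, g2):
    """Min-norm point in the convex hull of {g1, g2}: d = lam*g1 + (1-lam)*g2."""
    lam = ((g2 - g1) @ g2) / ((g1 - g2) @ (g1 - g2))  # unconstrained optimum
    lam = float(np.clip(lam, 0.0, 1.0))               # project onto [0, 1]
    return lam * g1 + (1 - lam) * g2

g1 = np.array([1.0, 0.0])
g2 = np.array([0.0, 1.0])

d = mgda_two_tasks(g1, g2)
print(d, d @ g1, d @ g2)   # common descent direction: both inner products >= 0
```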

Comparison with PCGrad:

  • PCGrad handles conflicts pairwise: computation is simple but potentially suboptimal
  • MGDA optimizes over all tasks globally: theoretically better, but with higher per-step complexity

Auxiliary Task Design: How to Choose Auxiliary Tasks

Auxiliary task selection and design are crucial to multi-task learning success.

Self-Supervised Auxiliary Tasks

Self-supervised learning tasks can serve as universal auxiliary tasks, requiring no additional annotation.

Rotation Prediction:

Gidaris et al. proposed in 2018, rotating images by 0/90/180/270 degrees and having model predict rotation angle.

Loss function:

$$L_{\text{rot}} = -\sum_{r \in \{0°, 90°, 180°, 270°\}} \log P(r \mid \text{rotate}(x, r))$$

This task forces the model to learn object orientation and structural information.
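Because the labels come for free, building a rotation-prediction batch takes only a few lines. A minimal numpy sketch (a real pipeline would do this on image tensors inside the data loader):

```python
import numpy as np

def make_rotation_batch(image):
    """Return (rotated_images, labels) for 0/90/180/270 degrees."""
    rotations = [np.rot90(image, k=k) for k in range(4)]
    labels = np.arange(4)            # class k <=> rotation by k * 90 degrees
    return np.stack(rotations), labels

image = np.arange(16.0).reshape(4, 4)   # stand-in for a real image
batch, labels = make_rotation_batch(image)
print(batch.shape, labels)
```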

Jigsaw Puzzles:

Noroozi and Favaro proposed in 2016, dividing image into 9 patches and shuffling them, having model predict correct arrangement.

This task makes model learn spatial relationships and object part positions.

Contrastive Learning:

SimCLR, MoCo and other contrastive learning methods can also serve as auxiliary tasks. For a sample $x$ and its augmented version $x^+$ with projected representations $z$ and $z^+$, the InfoNCE loss is:

$$L_{\text{con}} = -\log \frac{\exp(\text{sim}(z, z^+)/\tau)}{\sum_{k=1}^{N} \exp(\text{sim}(z, z_k)/\tau)}$$

where $\tau$ is a temperature and the $z_k$ include negative samples. Contrastive learning helps the model learn robust representations.

Domain-Specific Auxiliary Tasks

Design targeted auxiliary tasks based on primary task characteristics.

Computer Vision:

  • Primary Task: object detection
  • Auxiliary Tasks: edge detection, depth estimation, surface normal prediction

Edge detection helps model better localize object boundaries, depth estimation provides 3D geometric information.

Natural Language Processing:

  • Primary Task: Named Entity Recognition (NER)
  • Auxiliary Tasks: Part-of-Speech (POS) tagging, syntactic dependency parsing

POS tagging provides grammatical information about each word, and dependency parsing provides sentence structure; both are helpful for NER.

Recommendation Systems:

  • Primary Task: Click-Through Rate (CTR) prediction
  • Auxiliary Tasks: Conversion Rate (CVR) prediction, dwell time prediction

User clicks, conversions, and dwell time reflect different levels of interest; joint modeling learns a more comprehensive user representation.

Curriculum Learning: Task Sequence

Sometimes the order in which auxiliary tasks are introduced matters; this is where curriculum learning comes in.

Simple to Complex:

Start with simple auxiliary tasks, gradually introduce complex tasks.

For example, in image classification:

  1. First pre-train with self-supervised tasks (rotation prediction)
  2. Then introduce coarse-grained classification (large categories)
  3. Finally perform fine-grained classification (small categories)

Task Switching Strategy:

Graves et al. proposed Automated Curriculum Learning in 2017, using reinforcement learning to dynamically decide when to switch tasks, based on:

  • Current task's learning progress (loss descent speed)
  • Inter-task correlations
  • Primary task's validation performance

Learn optimal task switching timing through policy network.

Complete Code Implementation: Multi-Task Learning Framework

Below is a complete multi-task learning implementation including hard parameter sharing, gradient surgery (PCGrad), dynamic weight adjustment (GradNorm) and other methods.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
from typing import List, Tuple, Dict, Optional
import copy

# ============== Multi-Task Network Architecture ==============

class SharedEncoder(nn.Module):
    """Shared encoder: first few stages of ResNet-18"""
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=False)
        # Use ResNet's stem plus the first 3 residual stages
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        return x


class TaskHead(nn.Module):
    """Task-specific head"""
    def __init__(self, in_channels: int, num_classes: int, task_type: str = 'classification'):
        super().__init__()
        self.task_type = task_type

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(in_channels, 256)
        self.dropout = nn.Dropout(0.5)

        if task_type == 'classification':
            self.fc2 = nn.Linear(256, num_classes)
        elif task_type == 'regression':
            self.fc2 = nn.Linear(256, num_classes)  # num_classes = output dimension
        else:
            raise ValueError(f"Unknown task type: {task_type}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class MultiTaskNetwork(nn.Module):
    """Multi-task network: hard parameter sharing architecture"""
    def __init__(self, task_configs: List[Dict]):
        """
        task_configs: list of task configurations, each element a dict:
            {
                'name': task name,
                'num_classes': number of classes,
                'type': 'classification' or 'regression'
            }
        """
        super().__init__()
        self.task_names = [cfg['name'] for cfg in task_configs]
        self.num_tasks = len(task_configs)

        # Shared encoder
        self.shared_encoder = SharedEncoder()

        # Task-specific heads
        self.task_heads = nn.ModuleDict({
            cfg['name']: TaskHead(
                in_channels=256,  # ResNet-18 layer3 output channels
                num_classes=cfg['num_classes'],
                task_type=cfg['type']
            )
            for cfg in task_configs
        })

    def forward(self, x: torch.Tensor, task_name: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass.
        If task_name is specified, only that task's head runs; otherwise all heads run.
        """
        shared_features = self.shared_encoder(x)

        outputs = {}
        if task_name is not None:
            outputs[task_name] = self.task_heads[task_name](shared_features)
        else:
            for name in self.task_names:
                outputs[name] = self.task_heads[name](shared_features)

        return outputs


# ============== Multi-Task Loss Functions ==============

class MultiTaskLoss(nn.Module):
    """Multi-task loss: supports different task types"""
    def __init__(self, task_configs: List[Dict], loss_weights: Optional[Dict[str, float]] = None):
        super().__init__()
        self.task_configs = {cfg['name']: cfg for cfg in task_configs}

        # If weights are not specified, use uniform weights
        if loss_weights is None:
            self.loss_weights = {cfg['name']: 1.0 for cfg in task_configs}
        else:
            self.loss_weights = loss_weights

    def compute_loss(self, outputs: Dict[str, torch.Tensor],
                     targets: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Compute each task's loss"""
        losses = {}

        for task_name, output in outputs.items():
            target = targets[task_name]
            task_type = self.task_configs[task_name]['type']

            if task_type == 'classification':
                loss = F.cross_entropy(output, target)
            elif task_type == 'regression':
                loss = F.mse_loss(output, target)
            else:
                raise ValueError(f"Unknown task type: {task_type}")

            losses[task_name] = loss

        return losses

    def forward(self, outputs: Dict[str, torch.Tensor],
                targets: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the weighted total loss"""
        losses = self.compute_loss(outputs, targets)

        total_loss = sum(self.loss_weights[name] * loss
                         for name, loss in losses.items())

        return total_loss, losses


# ============== Gradient Surgery: PCGrad ==============

class PCGrad:
    """Projecting Conflicting Gradients"""
    def __init__(self, optimizer: optim.Optimizer, task_names: List[str]):
        self.optimizer = optimizer
        self.task_names = task_names
        self.num_tasks = len(task_names)

    @staticmethod
    def _project_conflicting(grad_i: torch.Tensor, grad_j: torch.Tensor) -> torch.Tensor:
        """Project grad_i onto the normal plane of grad_j if they conflict"""
        inner_product = torch.dot(grad_i, grad_j)
        if inner_product < 0:
            # Conflict: remove the component of grad_i along grad_j
            proj = inner_product / (torch.norm(grad_j) ** 2 + 1e-8)
            grad_i = grad_i - proj * grad_j
        return grad_i

    def step(self, losses: Dict[str, torch.Tensor]):
        """PCGrad optimization step"""
        params = [p for group in self.optimizer.param_groups for p in group['params']]

        # 1. Compute each task's flattened gradient (zeros for parameters a task
        #    does not touch, so all task gradients share the same layout)
        task_gradients = {}
        for task_name in self.task_names:
            self.optimizer.zero_grad()
            losses[task_name].backward(retain_graph=True)

            grads = []
            for param in params:
                if param.grad is not None:
                    grads.append(param.grad.clone().flatten())
                else:
                    grads.append(torch.zeros(param.numel(), device=param.device))
            task_gradients[task_name] = torch.cat(grads)

        # 2. For each task, project out components conflicting with other tasks
        modified_gradients = {}
        for i, task_i in enumerate(self.task_names):
            grad_i = task_gradients[task_i].clone()
            for j, task_j in enumerate(self.task_names):
                if i != j:
                    grad_i = self._project_conflicting(grad_i, task_gradients[task_j])
            modified_gradients[task_i] = grad_i

        # 3. Average the modified gradients
        avg_gradient = sum(modified_gradients.values()) / self.num_tasks

        # 4. Write the averaged gradient back into param.grad
        #    (do not zero_grad here: that would set .grad to None)
        idx = 0
        for param in params:
            param_size = param.numel()
            param.grad = avg_gradient[idx:idx + param_size].view_as(param)
            idx += param_size

        # 5. Update parameters
        self.optimizer.step()


# ============== Dynamic Weight Adjustment: GradNorm ==============

class GradNorm:
    """Gradient Normalization for Adaptive Loss Balancing"""
    def __init__(
        self,
        model: nn.Module,
        task_names: List[str],
        alpha: float = 1.5,
        lr_weights: float = 0.025
    ):
        self.model = model
        self.task_names = task_names
        self.num_tasks = len(task_names)
        self.alpha = alpha

        # Task weights (learnable parameters)
        self.task_weights = nn.Parameter(torch.ones(self.num_tasks))
        self.weight_optimizer = optim.Adam([self.task_weights], lr=lr_weights)

        # Record initial losses
        self.initial_losses = None

    def compute_grad_norm(self, loss: torch.Tensor,
                          parameters: List[nn.Parameter]) -> torch.Tensor:
        """Compute the (differentiable) gradient norm of loss w.r.t. parameters"""
        grads = torch.autograd.grad(loss, parameters, retain_graph=True,
                                    create_graph=True, allow_unused=True)
        grads = [g for g in grads if g is not None]
        return torch.norm(torch.cat([g.flatten() for g in grads]))

    def step(self, losses: Dict[str, torch.Tensor], epoch: int):
        """GradNorm update step"""
        # Record initial losses (first call)
        if self.initial_losses is None:
            self.initial_losses = {name: loss.item() for name, loss in losses.items()}

        # Weighted per-task losses (weights participate in the graph here)
        weighted_losses = [
            self.task_weights[i] * losses[task_name]
            for i, task_name in enumerate(self.task_names)
        ]

        # Only compute gradient norms for the shared parameters
        shared_params = list(self.model.shared_encoder.parameters())
        grad_norms = [self.compute_grad_norm(wl, shared_params)
                      for wl in weighted_losses]

        # Average gradient norm, treated as a constant target
        avg_grad_norm = (sum(grad_norms) / self.num_tasks).detach()

        # Relative inverse training rates r_i = (L_i / L_i(0)) / mean ratio
        loss_ratios = [
            losses[t].item() / (self.initial_losses[t] + 1e-8)
            for t in self.task_names
        ]
        avg_loss_ratio = sum(loss_ratios) / self.num_tasks
        relative_inverse_train_rates = [r / (avg_loss_ratio + 1e-8)
                                        for r in loss_ratios]

        # GradNorm loss: push each gradient norm toward its (constant) target
        grad_norm_loss = 0
        for i in range(self.num_tasks):
            target = avg_grad_norm * (relative_inverse_train_rates[i] ** self.alpha)
            grad_norm_loss = grad_norm_loss + torch.abs(grad_norms[i] - target)

        # Update task weights (retain the graph: the trainer still
        # backpropagates the task loss afterwards)
        self.weight_optimizer.zero_grad()
        grad_norm_loss.backward(retain_graph=True)
        self.weight_optimizer.step()

        # Renormalize weights to sum to num_tasks
        with torch.no_grad():
            self.task_weights.data = (self.task_weights.data * self.num_tasks
                                      / self.task_weights.data.sum())

        # Total loss for the model update (weights detached so only
        # model parameters receive gradients from it)
        total_loss = sum(
            self.task_weights[i].detach() * losses[name]
            for i, name in enumerate(self.task_names)
        )

        return total_loss, {name: self.task_weights[i].item()
                            for i, name in enumerate(self.task_names)}


# ============== Multi-Task Trainer ==============

class MultiTaskTrainer:
"""Multi-task learning trainer"""
def __init__(
self,
model: MultiTaskNetwork,
task_configs: List[Dict],
device: str = 'cuda',
optimization_method: str = 'uniform', # 'uniform', 'pcgrad', 'gradnorm'
initial_weights: Optional[Dict[str, float]] = None
):
self.model = model.to(device)
self.device = device
self.task_configs = task_configs
self.task_names = [cfg['name'] for cfg in task_configs]
self.optimization_method = optimization_method

# Loss function
self.criterion = MultiTaskLoss(task_configs, initial_weights)

# Optimizer
self.optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Gradient surgery
if optimization_method == 'pcgrad':
self.pcgrad = PCGrad(self.optimizer, self.task_names)
elif optimization_method == 'gradnorm':
self.gradnorm = GradNorm(
model,
self.task_names,
alpha=1.5,
lr_weights=0.025
)

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train one epoch"""
        self.model.train()

        total_losses = {name: 0.0 for name in self.task_names}
        total_losses['total'] = 0.0
        num_batches = 0

        for batch_data in train_loader:
            inputs = batch_data['input'].to(self.device)
            targets = {name: batch_data[name].to(self.device) for name in self.task_names}

            # Forward pass
            outputs = self.model(inputs)

            # Compute losses
            if self.optimization_method == 'uniform':
                # Standard multi-task optimization
                total_loss, losses = self.criterion(outputs, targets)

                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

            elif self.optimization_method == 'pcgrad':
                # PCGrad optimization
                _, losses = self.criterion(outputs, targets)
                self.pcgrad.step(losses)
                total_loss = sum(losses.values())

            elif self.optimization_method == 'gradnorm':
                # GradNorm optimization
                _, losses = self.criterion(outputs, targets)
                total_loss, task_weights = self.gradnorm.step(losses, epoch)

                # Update criterion weights
                self.criterion.loss_weights = task_weights

                # Standard gradient update
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

            # Statistics
            for name, loss in losses.items():
                total_losses[name] += loss.item()
            total_losses['total'] += total_loss.item()
            num_batches += 1

        # Average losses
        avg_losses = {name: total / num_batches for name, total in total_losses.items()}
        return avg_losses

    @torch.no_grad()
    def evaluate(self, test_loader: DataLoader) -> Dict[str, float]:
        """Evaluate model"""
        self.model.eval()

        total_losses = {name: 0.0 for name in self.task_names}
        total_correct = {name: 0 for name in self.task_names}
        total_samples = {name: 0 for name in self.task_names}
        num_batches = 0

        # task_configs is a list, so build a name -> type lookup for indexing by task name
        task_types = {cfg['name']: cfg['type'] for cfg in self.task_configs}

        for batch_data in test_loader:
            inputs = batch_data['input'].to(self.device)
            targets = {name: batch_data[name].to(self.device) for name in self.task_names}

            outputs = self.model(inputs)
            _, losses = self.criterion(outputs, targets)

            for name, loss in losses.items():
                total_losses[name] += loss.item()

                # Compute accuracy for classification tasks
                if task_types[name] == 'classification':
                    preds = outputs[name].argmax(dim=1)
                    total_correct[name] += (preds == targets[name]).sum().item()
                    total_samples[name] += targets[name].size(0)

            num_batches += 1

        # Average losses and accuracies
        metrics = {}
        for name in self.task_names:
            metrics[f'{name}_loss'] = total_losses[name] / num_batches
            if task_types[name] == 'classification':
                metrics[f'{name}_acc'] = 100.0 * total_correct[name] / total_samples[name]

        return metrics

    def train(self, train_loader: DataLoader, test_loader: DataLoader, num_epochs: int = 50):
        """Complete training workflow"""
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            # Train
            train_metrics = self.train_epoch(train_loader, epoch)
            print("Train - ", end="")
            for name, value in train_metrics.items():
                print(f"{name}: {value:.4f} ", end="")
            print()

            # Evaluate
            test_metrics = self.evaluate(test_loader)
            print("Test - ", end="")
            for name, value in test_metrics.items():
                print(f"{name}: {value:.4f} ", end="")
            print()


# ============== Usage Example ==============

def create_dummy_multi_task_dataset(num_samples=1000, batch_size=32):
    """Create dummy dataset for demonstration"""
    class DummyDataset(torch.utils.data.Dataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            # Random image
            image = torch.randn(3, 32, 32)

            # Task 1: 10-class classification
            task1_label = torch.randint(0, 10, (1,)).item()

            # Task 2: 5-class classification
            task2_label = torch.randint(0, 5, (1,)).item()

            return {
                'input': image,
                'task1': task1_label,
                'task2': task2_label
            }

    dataset = DummyDataset(num_samples)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader


def main():
    # Task configurations
    task_configs = [
        {'name': 'task1', 'num_classes': 10, 'type': 'classification'},
        {'name': 'task2', 'num_classes': 5, 'type': 'classification'}
    ]

    # Create datasets
    train_loader = create_dummy_multi_task_dataset(num_samples=1000, batch_size=32)
    test_loader = create_dummy_multi_task_dataset(num_samples=200, batch_size=32)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ========== Experiment 1: Uniform Weighting ==========
    print("\n" + "=" * 60)
    print("Experiment 1: Uniform Weighting")
    print("=" * 60)

    trainer_uniform = MultiTaskTrainer(
        model=MultiTaskNetwork(task_configs),
        task_configs=task_configs,
        device=device,
        optimization_method='uniform'
    )
    trainer_uniform.train(train_loader, test_loader, num_epochs=10)

    # ========== Experiment 2: PCGrad ==========
    print("\n" + "=" * 60)
    print("Experiment 2: PCGrad")
    print("=" * 60)

    trainer_pcgrad = MultiTaskTrainer(
        model=MultiTaskNetwork(task_configs),
        task_configs=task_configs,
        device=device,
        optimization_method='pcgrad'
    )
    trainer_pcgrad.train(train_loader, test_loader, num_epochs=10)

    # ========== Experiment 3: GradNorm ==========
    print("\n" + "=" * 60)
    print("Experiment 3: GradNorm")
    print("=" * 60)

    trainer_gradnorm = MultiTaskTrainer(
        model=MultiTaskNetwork(task_configs),
        task_configs=task_configs,
        device=device,
        optimization_method='gradnorm'
    )
    trainer_gradnorm.train(train_loader, test_loader, num_epochs=10)


if __name__ == '__main__':
    main()

Code Explanation

  1. Network Architecture:
    • SharedEncoder: Uses first 3 ResNet blocks as shared feature extractor
    • TaskHead: Task-specific heads supporting both classification and regression
    • MultiTaskNetwork: Hard parameter sharing architecture combining shared encoder and task heads
  2. Loss Functions:
    • MultiTaskLoss: Supports multiple task types (classification, regression)
    • Automatically selects appropriate loss functions based on task types
    • Supports custom task weights
  3. Gradient Optimization Methods:
    • Uniform Weighting: Standard multi-task optimization with uniform or custom weights
    • PCGrad: Projects conflicting gradients to eliminate inter-task conflicts
    • GradNorm: Dynamically adjusts task weights to balance gradient magnitudes
  4. Trainer:
    • MultiTaskTrainer: Unified interface supporting multiple optimization methods
    • Automatically handles forward pass, loss computation, gradient optimization
    • Provides evaluation metrics for each task

Comprehensive Q&A

Q1: When should I use multi-task learning?

A: Multi-task learning is suitable when:

  1. Tasks are related: Share common underlying features or structures
  2. Data is scarce: Auxiliary tasks provide additional supervision signals
  3. Computational constraints: Parameter sharing reduces model size
  4. Need simultaneous predictions: Multiple outputs required at inference

Not suitable when:

  1. Tasks are completely unrelated or mutually conflicting
  2. The single task has abundant data and already performs well on its own
  3. Interpretability is critical (it is hard to explain how tasks interact)

Q2: How do I measure whether tasks are related?


A: Several methods to measure task relatedness:

  1. Transfer Learning Experiments:
    • Train on task A, transfer to task B
    • If transfer improves over random initialization, tasks are related
  2. Gradient Correlation:
    • Compute cosine similarity of task gradients
    • High positive correlation indicates relatedness
  3. Feature Representation Similarity:
    • Use CKA or other metrics to measure learned feature similarity
    • High CKA score indicates shared features
  4. Empirical Testing:
    • Try multi-task learning; if both tasks improve, they're related
    • If one task degrades, may be conflicting
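The gradient-correlation check in item 2 is easy to implement directly. The sketch below (with an illustrative `gradient_cosine` helper and a toy shared layer, both assumptions rather than code from this article) flattens each task's gradients with respect to the shared parameters and compares them:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gradient_cosine(model: nn.Module, loss_a: torch.Tensor, loss_b: torch.Tensor) -> float:
    """Cosine similarity between two tasks' gradients w.r.t. shared parameters."""
    params = [p for p in model.parameters() if p.requires_grad]
    grads_a = torch.autograd.grad(loss_a, params, retain_graph=True, allow_unused=True)
    grads_b = torch.autograd.grad(loss_b, params, retain_graph=True, allow_unused=True)
    # Substitute zeros for parameters unused by a task, then flatten and compare
    flat_a = torch.cat([(g if g is not None else torch.zeros_like(p)).flatten()
                        for g, p in zip(grads_a, params)])
    flat_b = torch.cat([(g if g is not None else torch.zeros_like(p)).flatten()
                        for g, p in zip(grads_b, params)])
    return F.cosine_similarity(flat_a, flat_b, dim=0).item()

# Toy usage: one shared linear layer, two task losses pulling in different directions
shared = nn.Linear(8, 4)
x = torch.randn(16, 8)
feats = shared(x)
loss_a = feats.pow(2).mean()          # task A: pull features toward zero
loss_b = (feats - 1.0).pow(2).mean()  # task B: pull features toward one
sim = gradient_cosine(shared, loss_a, loss_b)
print(f"gradient cosine similarity: {sim:.3f}")
```

In practice this statistic is averaged over many batches; a persistently negative value is a strong signal of gradient conflict.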

Q3: Hard sharing vs soft sharing - which to choose?

A: Choice depends on task characteristics:

Hard Sharing:

  • Suitable for: highly related tasks (e.g. multiple NLP tasks, multiple vision tasks)
  • Advantages: most parameter efficient, strongest regularization
  • Disadvantages: least flexible, higher negative transfer risk

Soft Sharing:

  • Suitable for: moderately related tasks, or tasks requiring different features at different layers
  • Advantages: more flexible, lower negative transfer risk
  • Disadvantages: more parameters, more hyperparameters to tune

Recommendation: Start with hard sharing; if negative transfer occurs, try soft sharing or adaptive methods (like MTAN).

Q4: How to handle large differences in task loss scales?

A: Loss scale differences are a common multi-task learning challenge. Solutions:

  1. Loss Normalization: Divide each task's loss by its initial (or running-average) loss so all tasks start at a comparable scale

  2. Uncertainty Weighting: Learn a per-task uncertainty σᵢ and weight each loss by 1/(2σᵢ²), with a log σᵢ regularizer to keep the weights from collapsing to zero

  3. Gradient Magnitude Balancing (GradNorm): Dynamically adjust task weights to balance gradient magnitudes

  4. Manual Scaling: Find appropriate task weights experimentally through grid search
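Item 2 (uncertainty weighting, in the style of Kendall et al. 2018) is compact enough to sketch; the `UncertaintyWeighting` module below is a hypothetical minimal version that learns one log-variance per task, not code from this article's implementation:

```python
import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    """Each task loss is scaled by exp(-s_i) with a +s_i regularizer,
    where s_i = log(sigma_i^2) is a learnable per-task log-variance.
    Tasks with large losses learn large s_i, shrinking their weight."""
    def __init__(self, num_tasks: int):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        total = 0.0
        for i, loss in enumerate(losses):
            total = total + torch.exp(-self.log_vars[i]) * loss + self.log_vars[i]
        return total

# Usage: combine two raw task losses of very different scales
weighting = UncertaintyWeighting(num_tasks=2)
losses = [torch.tensor(0.5), torch.tensor(50.0)]
total = weighting(losses)
print(total.item())  # 50.5 at initialization (log_vars start at 0, so weights are 1)
```

During training, `log_vars` is optimized jointly with the model parameters, so the weighting adapts as loss scales change.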

Q5: How many auxiliary tasks should I add?

A: More not always better. Considerations:

  1. Task Relevance: Only add tasks related to primary task
  2. Computational Budget: Each task adds computation; balance cost and benefit
  3. Diminishing Returns: Beyond certain number, additional tasks provide little benefit

Empirical Guidelines:

  • Start with the 1-2 most related auxiliary tasks
  • Gradually add tasks while monitoring primary-task performance
  • 2-4 auxiliary tasks are typically sufficient
  • More than 10 tasks may require task clustering or hierarchical structures

Q6: What to do when gradient conflicts are severe?

A: Severe gradient conflicts require:

  1. Use PCGrad or CAGrad:
    • Directly eliminate gradient conflicts via projection
    • Significant improvements often seen
  2. Adjust Task Weights:
    • Reduce conflicting task weights
    • Or remove most conflicting tasks
  3. Change Sharing Strategy:
    • Hard sharing → soft sharing
    • Or reduce shared layer count
  4. Task Grouping:
    • Cluster tasks, different groups with separate shared parameters
  5. Sequential Training:
    • If conflicts unsolvable, train tasks sequentially rather than jointly
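The core of item 1 (PCGrad) can be illustrated for a single pair of flattened gradients; `project_conflict` below is a hypothetical helper for exposition, not the `PCGrad` class used in the implementation above:

```python
import torch

def project_conflict(g_i: torch.Tensor, g_j: torch.Tensor) -> torch.Tensor:
    """PCGrad pairwise step: if g_i conflicts with g_j (negative dot
    product), project g_i onto the normal plane of g_j, removing the
    conflicting component; otherwise leave g_i unchanged."""
    dot = torch.dot(g_i, g_j)
    if dot < 0:
        return g_i - (dot / g_j.pow(2).sum()) * g_j
    return g_i

# Two conflicting gradients: dot(g1, g2) = -1 < 0
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-1.0, 1.0])
g1_proj = project_conflict(g1, g2)
print(g1_proj)                         # tensor([0.5000, 0.5000])
print(torch.dot(g1_proj, g2).item())   # 0.0 -- conflict with g2 eliminated
```

The full algorithm applies this projection for each task against every other task in random order, then sums the surgically modified gradients.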

Q7: How to evaluate multi-task learning models?

A: Multi-task evaluation is more complex than single-task evaluation:

  1. Individual Task Performance:
    • Evaluate each task's metrics (accuracy, F1, etc.)
    • Compare with single-task baselines
  2. Average Performance:
    • Average the (normalized) per-task metrics into a single summary score
  3. Pareto Front:
    • Plot the performance of each task and observe the trade-offs
    • A good multi-task model should lie on or near the Pareto front
  4. Task-Specific Improvement:
    • Compute Δᵢ = (MTLᵢ - STLᵢ) / STLᵢ, where MTLᵢ is multi-task performance and STLᵢ is single-task performance on task i
  5. Computational Efficiency:
    • Compare total parameters and inference time
    • The multi-task model should be more efficient than separate single-task models
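The task-specific improvement in item 4 is a one-liner; the sketch below uses hypothetical accuracy numbers purely for illustration:

```python
def task_improvement(mtl: dict, stl: dict) -> dict:
    """Relative per-task improvement: (MTL_i - STL_i) / STL_i, assuming
    higher-is-better metrics such as accuracy. A negative value flags
    negative transfer on that task."""
    return {name: (mtl[name] - stl[name]) / stl[name] for name in stl}

# Hypothetical accuracies for two tasks
stl_acc = {'task1': 0.80, 'task2': 0.70}
mtl_acc = {'task1': 0.84, 'task2': 0.63}
deltas = task_improvement(mtl_acc, stl_acc)
print(deltas)  # task1 improves by 5%; task2 degrades by 10% (negative transfer)
```

For lower-is-better metrics (e.g. depth-estimation error) the sign of the formula flips; many papers report the mean of these per-task deltas as a single summary number.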

Q8: Can multi-task learning alleviate overfitting?

A: Yes, multi-task learning has strong regularization effects:

  1. Shared Parameter Constraints:
    • Shared parameters must satisfy multiple tasks, limiting overfitting on single task
  2. Data Augmentation Effect:
    • Auxiliary tasks provide additional training signals
    • Equivalent to expanding training set
  3. Experimental Evidence:
    • Numerous studies show multi-task learning improves generalization on small datasets
    • Gains are especially significant in low-data regimes, where a good auxiliary signal can be worth a 2-10x expansion of the training set

But Note:

  • If the auxiliary tasks are unrelated, they may increase rather than reduce overfitting
  • Task weights still need proper tuning

Q9: How to apply multi-task learning to pre-trained models?

A: Several strategies for adding multi-task learning to pre-trained models:

  1. Adapter Method:
    • Insert task-specific adapter modules in pre-trained model
    • Only train adapters, freeze other parameters
    • Parameter efficient, avoids catastrophic forgetting
  2. Fine-tuning + Multi-Task Heads:
    • Add multiple task heads to pre-trained model
    • Fine-tune entire model or only specific layers
  3. Progressive Training:
    • First fine-tune on primary task
    • Then add auxiliary tasks for joint training
  4. Prompt-based Multi-Task Learning:
    • Use different prompts to distinguish tasks
    • Particularly effective in NLP (like T5, GPT)
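The adapter method in item 1 can be sketched as a bottleneck module with a residual connection (a minimal illustration in the style of Houlsby et al.; the `Adapter` class, backbone, and dimensions are assumptions, not a specific library API):

```python
import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Bottleneck adapter: down-project, nonlinearity, up-project, plus a
    residual connection. One adapter is inserted per task; only the adapter
    parameters are trained while the backbone stays frozen."""
    def __init__(self, dim: int, bottleneck: int = 16):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.up = nn.Linear(bottleneck, dim)

    def forward(self, x):
        return x + self.up(torch.relu(self.down(x)))

# A stand-in for a pre-trained backbone, frozen
backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
for p in backbone.parameters():
    p.requires_grad = False  # freeze pre-trained weights

# One lightweight adapter per task on top of the shared frozen features
adapters = {name: Adapter(64) for name in ['task1', 'task2']}

x = torch.randn(4, 32)
h = backbone(x)
out_task1 = adapters['task1'](h)
trainable = sum(p.numel() for a in adapters.values() for p in a.parameters())
print(trainable)  # only the adapters' parameters are trainable
```

Because each adapter is tiny relative to the backbone, adding a new task costs only a few thousand parameters and cannot interfere with previously trained tasks.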

Q10: What are future directions for multi-task learning research?

A: Several promising research directions:

  1. Automated Task Selection and Weighting:
    • Use meta-learning or RL to automatically find optimal task combinations and weights
    • Reduce manual tuning effort
  2. Continual Multi-Task Learning:
    • How to continuously add new tasks without forgetting old ones
    • Combine multi-task learning with continual learning
  3. Few-Shot Multi-Task Learning:
    • How to effectively share knowledge when tasks have very few samples
    • Combine meta-learning with multi-task learning
  4. Cross-Modal Multi-Task Learning:
    • Jointly train tasks across different modalities (vision, language, audio)
    • Learn universal multimodal representations
  5. Theory and Understanding:
    • Why does multi-task learning work? When does it work?
    • Theoretical analysis of task relatedness, negative transfer, gradient conflicts

Classic Papers

  1. Caruana, R., "Multitask Learning", Machine Learning 1997
    • Pioneering work on multi-task learning
    • Demonstrated regularization effects of shared representations
  2. Ruder, S., "An Overview of Multi-Task Learning in Deep Neural Networks", arXiv 2017
    • Comprehensive survey of multi-task learning methods
    • Systematically summarizes different architectures and optimization methods
    • arXiv:1706.05098

Parameter Sharing

  1. Misra, I. et al., "Cross-Stitch Networks for Multi-task Learning", CVPR 2016
    • Proposed cross-stitch units allowing multi-layer information exchange
    • Soft parameter sharing method
    • arXiv:1604.03539
  2. Liu, S. et al., "End-to-End Multi-Task Learning with Attention", CVPR 2019
    • Multi-Task Attention Network (MTAN)
    • Dynamically select shared features via attention
    • arXiv:1803.10704

Gradient Conflicts and Balancing

  1. Chen, Z. et al., "GradNorm: Gradient Normalization for Adaptive Loss Balancing", ICML 2018
    • Proposed GradNorm algorithm
    • Dynamically adjust task weights to balance gradient magnitudes
    • arXiv:1711.02257
  2. Yu, T. et al., "Gradient Surgery for Multi-Task Learning", NeurIPS 2020
    • Proposed PCGrad algorithm
    • Eliminate gradient conflicts via projection
    • arXiv:2001.06782
  3. Liu, B. et al., "Conflict-Averse Gradient Descent for Multi-task Learning", NeurIPS 2021
    • Proposed CAGrad algorithm
    • Find Pareto optimal gradient direction via QP
    • arXiv:2110.14048

Task Relationship Learning

  1. Zamir, A. R. et al., "Taskonomy: Disentangling Task Transfer Learning", CVPR 2018
    • Large-scale study of task relationships
    • Constructed task affinity matrix
    • arXiv:1804.08328
  2. Standley, T. et al., "Which Tasks Should Be Learned Together in Multi-task Learning?", ICML 2020
    • Systematic empirical study of which tasks benefit from joint training
    • Proposed an efficient search over task groupings to find good combinations
    • arXiv:1905.07553

Uncertainty Weighting

  1. Kendall, A. et al., "Multi-Task Learning Using Uncertainty to Weigh Losses", CVPR 2018
    • Proposed uncertainty-based automatic weighting
    • Learn task uncertainty parameters
    • arXiv:1705.07115

Applications

  1. He, K. et al., "Mask R-CNN", ICCV 2017
    • Multi-task learning in object detection
    • Simultaneous detection, segmentation, keypoint detection
    • arXiv:1703.06870
  2. Liu, S. et al., "Multi-Task Deep Neural Networks for Natural Language Understanding", ACL 2019
    • Multi-task learning in NLP
    • Jointly train multiple language understanding tasks
    • arXiv:1901.11504

Summary

Multi-task learning is a powerful paradigm that improves model generalization and computational efficiency by simultaneously learning multiple related tasks. This article derived multi-task learning's mathematical foundations from first principles, analyzed hard vs soft parameter sharing strategies in detail, deeply explained gradient conflict problems and solutions (PCGrad, GradNorm, CAGrad), introduced auxiliary task design principles, and provided complete multi-task network implementations.

We saw that multi-task learning's core is finding Pareto optimal solutions satisfying multiple optimization objectives. Through proper task selection, architecture design, and gradient optimization, multi-task learning can significantly improve model performance while reducing parameters and computation. From computer vision to natural language processing to recommendation systems, multi-task learning has become an indispensable tool in modern machine learning.

Next chapter we'll explore zero-shot learning, investigating how models can recognize unseen classes without any labeled examples.

  • Post title: Transfer Learning (6): Multi-Task Learning
  • Post author: Chen Kai
  • Create time: 2024-12-03 14:00:00
  • Post link: https://www.chenk.top/transfer-learning-6-multi-task-learning/
  • Copyright notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.