Multi-Task Learning (MTL) is a machine learning paradigm that improves model generalization by simultaneously learning multiple related tasks. Rich Caruana's pioneering 1997 paper "Multitask Learning" demonstrated how shared representations help models learn more robust features. In modern deep learning, multi-task learning has achieved tremendous success in computer vision (simultaneous detection, segmentation, depth estimation), natural language processing (joint entity recognition and relation extraction), and recommendation systems (simultaneous CTR and CVR prediction). But multi-task learning is far more than simply summing multiple loss functions — how to design shared structures, how to balance learning across different tasks, and how to handle negative transfer between tasks are all questions requiring deep investigation.
This article derives the mathematical foundations of multi-task learning from first principles, analyzes the pros and cons of hard vs soft parameter sharing, explains task relationship learning and task clustering methods in detail, deeply analyzes gradient conflict problems and solutions (PCGrad, GradNorm, CAGrad, etc.), introduces auxiliary task design principles, and provides a complete multi-task network implementation (including dynamic weight adjustment, gradient projection, task balancing and other industrial-grade techniques). We'll see that multi-task learning essentially seeks a Pareto optimal solution satisfying multiple optimization objectives.
Motivation for Multi-Task Learning
From Single-Task to Multi-Task: Sharing Inductive Bias
Single-task learning trains an independent model for each task, while multi-task learning has all tasks share some parameters or representations. The core assumption behind this is: related tasks share common underlying structures.
Intuitive Example: Consider three tasks for image scene understanding: - Object Detection: Identify object locations and categories in images - Semantic Segmentation: Assign category labels to each pixel - Depth Estimation: Predict depth value for each pixel
All three tasks require understanding spatial structure, object boundaries, texture information and other low-level features. Rather than independently learning these features for each task, they should share a feature extractor with only high-level task-specific heads.
Mathematical Perspective on Multi-Task Learning: Regularization Effect
From an optimization perspective, multi-task learning introduces implicit regularization. Given $T$ tasks with losses $L_t$, shared parameters $\theta_{sh}$, and task-specific parameters $\theta_t$, the joint objective is $\min_{\theta_{sh}, \theta_1, \dots, \theta_T} \sum_{t=1}^{T} w_t L_t(\theta_{sh}, \theta_t)$.
Key Insight: Shared parameters $\theta_{sh}$ must fit all tasks simultaneously, which shrinks the effective hypothesis space and acts as a data-dependent regularizer.
Data Augmentation Perspective: Auxiliary Tasks Provide Additional Signals
Multi-task learning can be viewed as a data augmentation strategy. When the main task has limited labeled data, auxiliary tasks can provide additional supervisory signals.
Example: In low-resource language machine translation: - Main Task: English→Swahili translation (scarce parallel data) - Auxiliary Task: English→French translation (abundant parallel data)
Although French and Swahili are different, the English encoder can learn better English representations from abundant English-French data, thereby helping English-Swahili translation.
Experiments show that introducing auxiliary tasks can improve main task performance by 5-20% (depending on task relatedness and data volume).
Computational Efficiency: Parameter Sharing Reduces Redundancy
From an engineering perspective, multi-task learning significantly reduces model parameters and computation through parameter sharing:
- Single-Task: $T$ tasks each have a ResNet-50 encoder (~25M parameters) plus a 2M-parameter head; total parameters $T \times 27\text{M} = 135\text{M}$ (assuming $T = 5$) - Multi-Task: $T$ tasks share one encoder; total parameters $25\text{M} + T \times 2\text{M} = 35\text{M}$ (each task head has 2M parameters)
Parameters reduced by about 70%, and inference requires only one forward pass to obtain all task outputs, dramatically improving efficiency.
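This arithmetic is easy to sanity-check in code (the 25M encoder and 2M head figures are the assumed sizes from the example above):

```python
# Parameter-count comparison for T tasks (figures assumed from the example above).
ENCODER_PARAMS = 25_000_000   # one ResNet-50-scale encoder
HEAD_PARAMS = 2_000_000       # one task-specific head

def single_task_params(T: int) -> int:
    # Each task trains its own full encoder + head.
    return T * (ENCODER_PARAMS + HEAD_PARAMS)

def multi_task_params(T: int) -> int:
    # One shared encoder, T lightweight heads.
    return ENCODER_PARAMS + T * HEAD_PARAMS

T = 5
st, mt = single_task_params(T), multi_task_params(T)
print(st, mt, 1 - mt / st)  # 135000000 35000000 ~0.74
```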
Negative Transfer: The Risk of Multi-Task Learning
However, multi-task learning is not always beneficial. When tasks are unrelated or even conflicting, negative transfer may occur: joint training yields lower performance than training each task separately.
Example: - Task A: Face recognition (requires learning fine-grained facial features) - Task B: Scene classification (requires learning global layout and context)
These two tasks have very different feature requirements; forcing parameter sharing may cause mutual interference.
Experimental Data (CIFAR-100): - Train Task A separately: 82% accuracy - Train Task B separately: 78% accuracy - Joint training (naive MTL): 79% and 74% (both tasks decline)
Therefore, how to design shared structures, select related tasks, and balance task weights are the keys to multi-task learning success.
Parameter Sharing Strategies: Hard Sharing vs Soft Sharing
The core of multi-task learning is how to share information between tasks. There are two main paradigms: hard parameter sharing and soft parameter sharing.
Hard Parameter Sharing
Hard parameter sharing is the most common multi-task learning architecture, proposed by Caruana in 1993.
Architecture Design: - Shared Layers: All tasks share the same bottom-level network (like convolutional layers, Transformer layers) - Task-Specific Layers: Each task has independent output heads (like fully connected layers, decoders)
Formally, for input $x$, the shared encoder $f_{\theta_{sh}}$ produces a representation $h = f_{\theta_{sh}}(x)$, and each task head $g_{\theta_t}$ outputs a prediction $\hat{y}_t = g_{\theta_t}(h)$.
Advantages: 1. Strong Regularization: Shared parameters constrained by multiple tasks, reducing overfitting risk 2. Parameter Efficiency: Most parameters shared, compact model 3. Simple and Direct: Easy to implement and train
Disadvantages: 1. Poor Flexibility: All tasks must use same shared representation, unsuitable for highly divergent tasks 2. High Negative Transfer Risk: Conflicting tasks interfere with each other
Empirical Design Principles: - Shared layers should learn general features (like CNN low layers learning edges, textures) - Task-specific layers should have sufficient capacity to handle task-specific patterns - Typically share first 70-80% of layers, keep last 20-30% independent
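As a minimal sketch of hard parameter sharing (the dimensions, task names, and tiny linear trunk are illustrative assumptions standing in for a real CNN):

```python
import numpy as np

# Hard parameter sharing: one shared trunk, independent heads per task.
# All dimensions and task names below are illustrative assumptions.
rng = np.random.default_rng(0)
D_IN, D_SHARED = 16, 8

W_shared = rng.normal(size=(D_IN, D_SHARED))            # shared bottom layers
task_heads = {
    "classification": rng.normal(size=(D_SHARED, 10)),  # 10-way classifier head
    "depth": rng.normal(size=(D_SHARED, 1)),            # scalar regression head
}

def forward(x):
    h = np.maximum(x @ W_shared, 0.0)   # shared representation (ReLU)
    return {name: h @ W for name, W in task_heads.items()}

out = forward(rng.normal(size=(4, D_IN)))  # one forward pass serves all tasks
print(out["classification"].shape, out["depth"].shape)  # (4, 10) (4, 1)
```

Note how a single forward pass through the trunk feeds every head, which is where the inference-time savings come from.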
Soft Parameter Sharing
Soft parameter sharing was proposed by Duong et al. in 2015, allowing each task to have its own parameters but encouraging parameter similarity through regularization.
Basic Form: Each task has an independent model with parameters $\theta_t$, and a regularization term encourages the parameters to stay close, e.g. $L = \sum_t L_t(\theta_t) + \lambda \sum_{t < t'} \|\theta_t - \theta_{t'}\|^2$.
Cross-Stitch Networks:
Misra et al. proposed Cross-Stitch Networks in 2016, allowing tasks to exchange information at multiple levels.
Given two tasks' layer activations $x_A$ and $x_B$, the cross-stitch unit outputs learned linear combinations: $\begin{bmatrix} \tilde{x}_A \\ \tilde{x}_B \end{bmatrix} = \begin{bmatrix} \alpha_{AA} & \alpha_{AB} \\ \alpha_{BA} & \alpha_{BB} \end{bmatrix} \begin{bmatrix} x_A \\ x_B \end{bmatrix}$, where the $\alpha$ coefficients are learned jointly with the network.
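A few lines of NumPy make the cross-stitch mixing concrete (the $\alpha$ values here are illustrative, not learned):

```python
import numpy as np

# Cross-stitch unit (Misra et al. 2016): a learned 2x2 mixing matrix
# combines two tasks' activations at the same layer.
# The alpha values below are illustrative, not learned.
alpha = np.array([[0.9, 0.1],   # task A keeps mostly its own features
                  [0.2, 0.8]])  # task B borrows a little from task A

def cross_stitch(x_a, x_b, alpha):
    x = np.stack([x_a, x_b])    # shape (2, feature_dim)
    mixed = alpha @ x           # linear combination per feature
    return mixed[0], mixed[1]

new_a, new_b = cross_stitch(np.ones(4), np.zeros(4), alpha)
print(new_a, new_b)
```

When $\alpha$ is close to the identity matrix, the unit degenerates to no sharing; off-diagonal mass controls how much information flows between the tasks.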
Multi-Task Attention Network (MTAN):
Liu et al. proposed Multi-Task Attention Network in 2019, using attention mechanisms to dynamically select which features to share.
For the shared feature $f^{(l)}$ at layer $l$, each task $t$ learns an attention mask $a_t^{(l)} \in [0, 1]$ and takes $\hat{f}_t^{(l)} = a_t^{(l)} \odot f^{(l)}$, so each task softly selects which shared features to use.
Advantages: 1. High Flexibility: Each task can have different parameters, strong adaptability 2. Low Negative Transfer Risk: Tasks can selectively ignore irrelevant information
Disadvantages: 1. Large Parameter Count: Each task has independent parameters, model inflation 2. Complex Training: Requires careful tuning of regularization strength
Dynamic Network Architectures: Conditional Computation
Recently, dynamic networks allow models to dynamically adjust computation paths based on input or task.
Routing Networks:
Rosenbaum et al. proposed Routing Networks in 2018, using routing functions to decide which subnetworks each task uses.
Given an input and a task identifier, a router selects at each layer which function block to apply, so different tasks can take different computation paths through a shared pool of modules.
Task-Conditional Adapters:
Rebuffi et al. proposed residual adapters in 2017, inserting task-specific adapter modules at each layer of a pre-trained model.
For task $t$, a small adapter module $A_t$ is inserted after each pre-trained layer: $h' = h + A_t(h)$; only the adapter parameters are trained, while the backbone stays frozen.
Advantage: When adding new tasks to pre-trained models, only need to train adapters, efficient and avoids catastrophic forgetting.
Task Relationship Learning: Discovering Correlations
Multi-task learning effectiveness largely depends on inter-task correlations. How to quantify and utilize task relationships is an important research direction.
Task Affinity Matrix
The task affinity matrix $A \in \mathbb{R}^{T \times T}$ stores in $A_{ij}$ a measure of how much task $i$ helps task $j$.
Computation Method 1: Performance Correlation
Zamir et al. proposed Taskonomy in 2018, measuring task affinity through transfer learning experiments:
- Train a model on task $i$, transfer it to task $j$, and record performance $P_{ij}$ - Define affinity $A_{ij} = P_{ij} - P_{\text{scratch}}$, where $P_{\text{scratch}}$ is the performance from random initialization
Computation Method 2: Gradient Correlation
Yu et al. proposed in 2020 an affinity based on gradient cosine similarity: $A_{ij} = \cos(g_i, g_j) = \frac{g_i \cdot g_j}{\|g_i\| \, \|g_j\|}$, where $g_i = \nabla_\theta L_i$ on the shared parameters.
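A gradient-affinity matrix along these lines can be sketched as follows (task names and gradient vectors are invented for illustration):

```python
import numpy as np

def gradient_affinity(grads):
    # grads: dict mapping task name -> flattened gradient on shared params.
    # Returns task names and the matrix of pairwise cosine similarities.
    names = list(grads)
    G = np.stack([grads[n] / np.linalg.norm(grads[n]) for n in names])
    return names, G @ G.T

grads = {
    "segmentation": np.array([1.0, 1.0, 0.0]),
    "depth":        np.array([1.0, 0.5, 0.0]),   # mostly aligned with segmentation
    "face_id":      np.array([-1.0, 0.0, 1.0]),  # conflicting direction
}
names, A = gradient_affinity(grads)
print(np.round(A, 2))  # diagonal is 1; negative entries flag conflicts
```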
Computation Method 3: Feature Representation Similarity
Compute the CKA (Centered Kernel Alignment) of feature representations learned during training: for centered feature matrices $X$ and $Y$, linear CKA is $\text{CKA}(X, Y) = \frac{\|X^\top Y\|_F^2}{\|X^\top X\|_F \, \|Y^\top Y\|_F}$; a high score indicates the tasks learn similar representations.
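A minimal linear-CKA implementation, with random placeholder features standing in for two tasks' learned representations:

```python
import numpy as np

def linear_cka(X, Y):
    # X, Y: (n_samples, features) representations of the same inputs
    # from two networks. Linear CKA (Kornblith et al. 2019).
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    num = np.linalg.norm(X.T @ Y, "fro") ** 2
    den = np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")
    return num / den

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 32))   # placeholder features for task A
Y = rng.normal(size=(100, 32))   # unrelated placeholder features for task B
print(linear_cka(X, X))          # identical representations -> 1.0
print(linear_cka(X, Y))          # unrelated representations -> low score
```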
Task Clustering: Grouped Sharing
When the number of tasks is large, tasks can first be clustered, and tasks within the same group then share parameters.
Hierarchical Multi-Task Learning:
Assume hierarchical relationships between tasks, like: - Coarse-grained Task: Scene classification (indoor vs outdoor) - Fine-grained Task: Specific scene categories (bedroom, kitchen, street, park)
Can design hierarchical network: 1. Shared layer extracts general features 2. Middle layer for coarse-grained task 3. Top layer for fine-grained tasks, depends on middle layer output
Loss function: $L = \lambda_{\text{coarse}} L_{\text{coarse}} + \lambda_{\text{fine}} L_{\text{fine}}$, where the fine-grained head consumes the coarse-grained layer's output.
Adaptive Task Grouping:
Standley et al. asked "Which Tasks Should Be Learned Together?" in 2020, automatically searching for the best task grouping.
Algorithm workflow: 1. Train networks for individual tasks and for small task combinations (e.g. pairs) 2. Use the measured pairwise results to predict the performance of larger candidate groupings 3. Search the groupings for the one that maximizes predicted validation performance under a compute budget 4. Train the final multi-task networks according to the selected grouping
Experiments show automated grouping is more effective than manual design or global sharing.
Task Selection: Choosing Primary and Auxiliary Tasks
When there's one primary task and multiple candidate auxiliary tasks, how to select the most helpful auxiliary tasks?
Greedy Selection Strategy:
- Train the primary task alone, record performance $P_0$ - For each candidate auxiliary task $t$, jointly train the primary task with $t$ - Record the primary task's performance $P_t$ - Compute the gain $\Delta_t = P_t - P_0$ - Select the $K$ auxiliary tasks with the highest gains
Meta-Learning Based Selection:
Du et al. proposed Automated Auxiliary Learning in 2020, using meta-learning to predict auxiliary task effectiveness:
- Rapidly train model on small data
- Use meta-model to predict each auxiliary task's help to primary task
- Select auxiliary tasks with highest predicted benefits
Advantage: Avoids overhead of fully training all candidate tasks.
Gradient Conflicts and Task Balancing
One of the biggest challenges in multi-task learning is gradient conflict: different tasks' gradients may point in different directions, causing training instability or performance degradation.
Problem Analysis: What is Gradient Conflict
Given two tasks' gradients $g_1 = \nabla_\theta L_1$ and $g_2 = \nabla_\theta L_2$ on the shared parameters, naive training follows the summed direction $g = g_1 + g_2$.
Problem: If $g_1 \cdot g_2 < 0$, the two gradients partially cancel, and the update can increase one task's loss.
Example: - Task 1 gradient: $g_1 = (1, 0)$ - Task 2 gradient: $g_2 = (-1, 0.1)$ - Sum: $g = (0, 0.1)$, which makes almost no progress on either task
Formally, gradient conflict is defined as $\cos(g_1, g_2) < 0$, i.e. the angle between the gradients exceeds 90°.
Static Weight Methods: Manual Tuning
The simplest method is manually setting task weights $w_t$ in $L = \sum_t w_t L_t$.
Uniform Weights: $w_t = 1/T$ for all tasks — the default baseline, but it ignores differences in loss scale and task difficulty.
Uncertainty Weighting:
Kendall et al. proposed in 2018, using task uncertainty to automatically adjust weights.
Assume task $t$'s predictions have observation noise $\sigma_t$. Maximizing the Gaussian likelihood yields the weighted loss $L = \sum_t \frac{1}{2\sigma_t^2} L_t + \log \sigma_t$, where each $\sigma_t$ is learned jointly with the network.
Intuition: - If task $t$ is noisy or hard (large $\sigma_t$), its weight $1/(2\sigma_t^2)$ automatically shrinks - The $\log \sigma_t$ term penalizes inflating $\sigma_t$ indefinitely
Experiments show uncertainty weighting improves 2-5% over uniform weights.
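A sketch of the weighted loss, using the common log-variance parameterization $s_t = \log \sigma_t^2$ for numerical stability (a variant of the formula above with the 1/2 factor absorbed):

```python
import numpy as np

# Uncertainty weighting (Kendall et al. 2018), sketched with the
# log-variance parameterization s_t = log(sigma_t^2).
def uncertainty_weighted_loss(losses, log_vars):
    total = 0.0
    for L, s in zip(losses, log_vars):
        total += np.exp(-s) * L + s   # (1/sigma^2) * L + log(sigma^2)
    return total

losses = [2.0, 2.0]
v1 = uncertainty_weighted_loss(losses, [0.0, 0.0])  # equal confidence
v2 = uncertainty_weighted_loss(losses, [0.0, 2.0])  # task 2 deemed noisy
print(v1, v2)  # in v2 task 2's loss term is down-weighted by exp(-2)
```

In practice the `log_vars` are trainable parameters optimized together with the network weights, so the balance adapts during training.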
GradNorm: Gradient Magnitude Normalization
Chen et al. proposed GradNorm in 2018, balancing tasks by adjusting weights to make gradient magnitudes balanced.
Core Idea: Each task's gradient magnitude should be proportional to its training speed.
Given task $t$'s weighted gradient norm $G_t = \|\nabla_W (w_t L_t)\|$ (measured at the last shared layer $W$), the mean norm $\bar{G} = \frac{1}{T}\sum_t G_t$, and the relative inverse training rate $r_t = \tilde{L}_t / \frac{1}{T}\sum_i \tilde{L}_i$ with $\tilde{L}_t = L_t / L_t(0)$:
Objective: Adjust weights $w_t$ so that $G_t \to \bar{G} \cdot r_t^\alpha$, i.e. minimize $L_{\text{grad}} = \sum_t \left| G_t - \bar{G} \cdot r_t^\alpha \right|$.
Algorithm: - Forward pass, compute the weighted loss $L = \sum_t w_t L_t$ - Compute each $G_t$ and the target $\bar{G} \cdot r_t^\alpha$ - Update each $w_t$ by gradient descent on $L_{\text{grad}}$ (targets treated as constants) - Renormalize so that $\sum_i w_i = T$
Effect: GradNorm shows significant improvement (3-8%) over uniform weights and uncertainty weighting on multiple datasets.
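A simplified single-step sketch of the GradNorm weight update (real GradNorm differentiates $L_{\text{grad}}$ through the network; here gradient norms are passed in directly and targets are treated as constants):

```python
import numpy as np

# One GradNorm-style weight update (Chen et al. 2018) -- a simplified sketch.
# g_norms: per-task unweighted gradient norms; loss_ratios: L_t(now)/L_t(0).
def gradnorm_step(weights, g_norms, loss_ratios, alpha=1.5, lr=0.1):
    weights = np.asarray(weights, dtype=float)
    g = np.asarray(g_norms) * weights                   # weighted gradient magnitudes
    r = np.asarray(loss_ratios) / np.mean(loss_ratios)  # relative inverse training rate
    target = g.mean() * r ** alpha                      # desired magnitude per task
    # Subgradient of sum_t |g_t - target_t| w.r.t. w_t (targets held constant):
    grad_w = np.sign(g - target) * np.asarray(g_norms)
    weights = np.maximum(weights - lr * grad_w, 1e-3)
    return weights * len(weights) / weights.sum()       # renormalize: sum w_t = T

w = gradnorm_step([1.0, 1.0], g_norms=[5.0, 1.0], loss_ratios=[1.0, 1.0])
print(w)  # the task with the larger gradient gets a smaller weight
```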
PCGrad: Projecting Conflicting Gradients
Yu et al. proposed Projecting Conflicting Gradients (PCGrad) in 2020, directly eliminating gradient conflicts.
Core Idea: When two tasks' gradients conflict, project one gradient onto the other gradient's normal plane.
For tasks $i$ and $j$ with gradients $g_i$ and $g_j$: if $g_i \cdot g_j < 0$, replace $g_i$ with its projection onto the normal plane of $g_j$: $g_i' = g_i - \frac{g_i \cdot g_j}{\|g_j\|^2} g_j$.
Algorithm (for each training step): 1. Compute each task's gradient $g_i = \nabla_\theta L_i$ 2. For each task $i$, iterate over the other tasks $j$ in random order; whenever $g_i \cdot g_j < 0$, project $g_i$ as above 3. Apply the update using the sum of the projected gradients
Theoretical Guarantee: After projection against $g_j$, the modified gradient satisfies $g_i' \cdot g_j \ge 0$, so no pairwise conflict remains in the applied update.
Experiments (NYUv2 dataset, semantic segmentation + depth estimation + surface normals): - Uniform weights: mIoU 40.2%, depth error 0.61 - PCGrad: mIoU 42.7%, depth error 0.58
PCGrad significantly alleviates gradient conflicts, improving all tasks' performance.
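The projection rule fits in a few lines of NumPy; in the worked example, two conflicting gradients become mutually non-conflicting:

```python
import numpy as np

def pcgrad(grads, seed=0):
    # PCGrad (Yu et al. 2020): for each task gradient, project away the
    # component that conflicts with other tasks' gradients.
    rng = np.random.default_rng(seed)
    projected = [g.astype(float).copy() for g in grads]
    for i, g_i in enumerate(projected):
        for j in rng.permutation(len(grads)):   # random task order
            if j == i:
                continue
            g_j = grads[j]
            dot = g_i @ g_j
            if dot < 0:  # conflict: remove the component along g_j
                g_i -= dot / (g_j @ g_j) * g_j
    return np.sum(projected, axis=0)

g1 = np.array([1.0, 0.0])
g2 = np.array([-1.0, 1.0])
print(pcgrad([g1, g2]))  # [0.5 1.5]
```

Here $g_1$ becomes $(0.5, 0.5)$ (orthogonal to $g_2$) and $g_2$ becomes $(0, 1)$ (orthogonal to $g_1$), so the summed update no longer fights either task.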
CAGrad: Conflict-Averse Gradient Descent
Liu et al. proposed Conflict-Averse Gradient descent (CAGrad) in 2021, seeking Pareto optimal gradient direction.
Pareto Optimality: A solution is Pareto optimal if and only if no other solution improves at least one task without degrading another.
CAGrad models gradient selection as an optimization problem: $\max_d \min_t \langle g_t, d \rangle \quad \text{s.t.} \quad \|d - g_0\| \le c \|g_0\|$, where $g_0 = \frac{1}{T}\sum_t g_t$ is the average gradient and $c$ controls how far the update may deviate from it.
This is a quadratic programming (QP) problem, solvable efficiently with existing solvers (like CVXPY).
Experiments: CAGrad achieves best Pareto front on multiple datasets, superior to PCGrad and GradNorm.
MGDA: Multi-Objective Gradient Descent Algorithm
The Multi-Objective Gradient Descent Algorithm (MGDA), proposed by Désidéri in 2012, seeks a common descent direction for all tasks.
Core Idea: Find the minimum-norm vector in the convex hull of the task gradients; following it decreases (to first order) every task's loss.
Formalized as: $\min_{\alpha \in \Delta^T} \left\| \sum_t \alpha_t g_t \right\|^2$, where $\Delta^T$ is the probability simplex; the update direction is $d = \sum_t \alpha_t^* g_t$.
Comparison with PCGrad: - PCGrad handles conflicts pairwise, which is computationally simple but potentially suboptimal - MGDA optimizes globally, which is theoretically better but requires solving a QP over all $T$ tasks at every step
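For two tasks, the MGDA subproblem has a closed-form solution, which can be sketched directly:

```python
import numpy as np

def mgda_two_tasks(g1, g2):
    # MGDA for two tasks: the minimum-norm point in the convex hull
    # {alpha*g1 + (1-alpha)*g2} has a closed form.
    diff = g1 - g2
    denom = diff @ diff
    if denom == 0:
        return g1.copy()
    alpha = np.clip((g2 - g1) @ g2 / denom, 0.0, 1.0)
    return alpha * g1 + (1 - alpha) * g2

g1 = np.array([1.0, 0.0])
g2 = np.array([0.0, 1.0])
d = mgda_two_tasks(g1, g2)
print(d, d @ g1, d @ g2)  # equal positive alignment with both gradients
```

For orthogonal unit gradients the result is $d = (0.5, 0.5)$: a direction that makes identical first-order progress on both tasks.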
Auxiliary Task Design: How to Choose Auxiliary Tasks
Auxiliary task selection and design are crucial to multi-task learning success.
Self-Supervised Auxiliary Tasks
Self-supervised learning tasks can serve as universal auxiliary tasks, requiring no additional annotation.
Rotation Prediction:
Gidaris et al. proposed in 2018, rotating images by 0/90/180/270 degrees and having model predict rotation angle.
Loss function: a standard cross-entropy over the four rotation classes, $L_{\text{rot}} = -\log p(r \mid \text{rotate}(x, r))$, averaged over rotations $r \in \{0°, 90°, 180°, 270°\}$.
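Generating the rotation labels is straightforward (a sketch with a toy 4x4 "image"):

```python
import numpy as np

# Self-supervised rotation labels (Gidaris et al. 2018):
# each image yields 4 training pairs, one per rotation class.
def make_rotation_batch(images):
    rotated, labels = [], []
    for img in images:           # img: (H, W) or (H, W, C) array
        for k in range(4):       # 0, 90, 180, 270 degrees
            rotated.append(np.rot90(img, k))
            labels.append(k)
    return rotated, labels

imgs = [np.arange(16).reshape(4, 4)]
rot, y = make_rotation_batch(imgs)
print(len(rot), y)  # 4 [0, 1, 2, 3]
```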
Jigsaw Puzzles:
Noroozi and Favaro proposed in 2016, dividing image into 9 patches and shuffling them, having model predict correct arrangement.
This task makes model learn spatial relationships and object part positions.
Contrastive Learning:
SimCLR, MoCo and other contrastive learning methods can also serve as auxiliary tasks. For a sample $x$, two augmented views are generated; the InfoNCE loss pulls their representations together while pushing them away from other samples' representations.
Domain-Specific Auxiliary Tasks
Design targeted auxiliary tasks based on primary task characteristics.
Computer Vision: - Primary Task: Object detection - Auxiliary Tasks: Edge detection, depth estimation, surface normal prediction
Edge detection helps model better localize object boundaries, depth estimation provides 3D geometric information.
Natural Language Processing: - Primary Task: Named Entity Recognition (NER) - Auxiliary Tasks: Part-of-Speech tagging (POS), syntactic dependency parsing
POS tagging provides grammatical information about words, dependency parsing provides sentence structure, both helpful for NER.
Recommendation Systems: - Primary Task: Click-Through Rate (CTR) prediction - Auxiliary Tasks: Conversion Rate (CVR) prediction, dwell time prediction
User click behavior, conversion behavior, dwell time reflect different levels of interest, joint modeling learns more comprehensive user representations.
Curriculum Learning: Task Sequence
Sometimes the order in which auxiliary tasks are introduced matters; this is the territory of curriculum learning.
Simple to Complex:
Start with simple auxiliary tasks, gradually introduce complex tasks.
For example, in image classification: 1. First pre-train with self-supervised tasks (rotation prediction) 2. Then introduce coarse-grained classification tasks (large categories) 3. Finally perform fine-grained classification tasks (small categories)
Task Switching Strategy:
Graves et al. proposed Automated Curriculum Learning in 2017, using a multi-armed bandit to dynamically decide which task to train on next, based on signals such as:
- Current task's learning progress (loss descent speed)
- Inter-task correlations
- Primary task's validation performance
The bandit learns the optimal task-switching schedule from these reward signals.
Complete Code Implementation: Multi-Task Learning Framework
Below is a complete multi-task learning implementation including hard parameter sharing, gradient surgery (PCGrad), dynamic weight adjustment (GradNorm) and other methods.
```python
import torch
```
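A condensed sketch of the three architecture classes the Code Explanation below describes (`SharedEncoder`, `TaskHead`, `MultiTaskNetwork`); the MLP trunk stands in for the ResNet blocks, and all layer sizes are illustrative assumptions:

```python
import torch
import torch.nn as nn

class SharedEncoder(nn.Module):
    # Stand-in for the ResNet-based shared trunk (simplified to an MLP).
    def __init__(self, d_in=64, d_hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_in, d_hidden), nn.ReLU(),
                                 nn.Linear(d_hidden, d_hidden), nn.ReLU())
    def forward(self, x):
        return self.net(x)

class TaskHead(nn.Module):
    # Task-specific head: out_dim = n_classes for classification, 1 for regression.
    def __init__(self, d_hidden, out_dim):
        super().__init__()
        self.fc = nn.Linear(d_hidden, out_dim)
    def forward(self, h):
        return self.fc(h)

class MultiTaskNetwork(nn.Module):
    # Hard parameter sharing: one encoder, a dict of per-task heads.
    def __init__(self, task_dims, d_in=64, d_hidden=128):
        super().__init__()
        self.encoder = SharedEncoder(d_in, d_hidden)
        self.heads = nn.ModuleDict(
            {name: TaskHead(d_hidden, dim) for name, dim in task_dims.items()})
    def forward(self, x):
        h = self.encoder(x)                     # shared representation
        return {name: head(h) for name, head in self.heads.items()}

net = MultiTaskNetwork({"classify": 10, "depth": 1})
out = net(torch.randn(4, 64))
print(out["classify"].shape, out["depth"].shape)  # torch.Size([4, 10]) torch.Size([4, 1])
```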
Code Explanation
- Network Architecture:
- SharedEncoder: Uses first 3 ResNet blocks as shared feature extractor - TaskHead: Task-specific heads supporting both classification and regression - MultiTaskNetwork: Hard parameter sharing architecture combining shared encoder and task heads
- Loss Functions:
- MultiTaskLoss: Supports multiple task types (classification, regression) - Automatically selects appropriate loss functions based on task type
- Supports custom task weights
- Gradient Optimization Methods:
- Uniform Weighting: Standard multi-task optimization with uniform or custom weights
- PCGrad: Projects conflicting gradients to eliminate inter-task conflicts
- GradNorm: Dynamically adjusts task weights to balance gradient magnitudes
- Trainer:
- MultiTaskTrainer: Unified interface supporting multiple optimization methods - Automatically handles forward pass, loss computation, gradient optimization
- Provides evaluation metrics for each task
Comprehensive Q&A
Q1: When should I use multi-task learning?
A: Multi-task learning is suitable when:
- Tasks are related: Share common underlying features or structures
- Data is scarce: Auxiliary tasks provide additional supervision signals
- Computational constraints: Parameter sharing reduces model size
- Need simultaneous predictions: Multiple outputs required at inference
Not suitable when: - Tasks are completely unrelated or conflicting - Single task has abundant data and excellent standalone performance - Interpretability critical (hard to explain how tasks interact)
Q2: How to determine if tasks are related?
A: Several methods to measure task relatedness:
- Transfer Learning Experiments:
- Train on task A, transfer to task B
- If transfer improves over random initialization, tasks are related
- Gradient Correlation:
- Compute cosine similarity of task gradients
- High positive correlation indicates relatedness
- Feature Representation Similarity:
- Use CKA or other metrics to measure learned feature similarity
- High CKA score indicates shared features
- Empirical Testing:
- Try multi-task learning; if both tasks improve, they're related
- If one task degrades, may be conflicting
Q3: Hard sharing vs soft sharing - which to choose?
A: Choice depends on task characteristics:
Hard Sharing: - Suitable for: Highly related tasks (like multiple NLP tasks, multiple vision tasks) - Advantages: Most parameter efficient, strongest regularization - Disadvantages: Poor flexibility, high negative transfer risk
Soft Sharing: - Suitable for: Moderately related tasks, tasks requiring different features at different layers - Advantages: More flexible, lower negative transfer risk - Disadvantages: More parameters, more hyperparameters to tune
Recommendation: Start with hard sharing; if negative transfer occurs, try soft sharing or adaptive methods (like MTAN).
Q4: How to handle large differences in task loss scales?
A: Loss scale differences are common multi-task learning challenges. Solutions:
Loss Normalization: Normalize each loss by its initial value, $\tilde{L}_t = L_t / L_t(0)$
Uncertainty Weighting: Learn task uncertainty $\sigma_t$, weight each loss by $1/(2\sigma_t^2)$
Gradient Magnitude Balancing (GradNorm): Dynamically adjust weights to balance gradient magnitudes
Manual Scaling: Experimentally find appropriate task weights through grid search
Q5: How many auxiliary tasks should I add?
A: More is not always better. Considerations:
- Task Relevance: Only add tasks related to primary task
- Computational Budget: Each task adds computation; balance cost and benefit
- Diminishing Returns: Beyond certain number, additional tasks provide little benefit
Empirical Guidelines: - Start with 1-2 most related auxiliary tasks - Gradually add tasks, observe primary task performance - Typically 2-4 auxiliary tasks sufficient - More than 10 tasks may require task clustering or hierarchical structures
Q6: What to do when gradient conflicts are severe?
A: Severe gradient conflicts require:
- Use PCGrad or CAGrad:
- Directly eliminate gradient conflicts via projection
- Significant improvements often seen
- Adjust Task Weights:
- Reduce conflicting task weights
- Or remove most conflicting tasks
- Change Sharing Strategy:
- Hard sharing → soft sharing
- Or reduce shared layer count
- Task Grouping:
- Cluster tasks, different groups with separate shared parameters
- Sequential Training:
- If conflicts unsolvable, train tasks sequentially rather than jointly
Q7: How to evaluate multi-task learning models?
A: Multi-task evaluation is more complex than single-task evaluation:
Individual Task Performance:
- Evaluate each task's metrics (accuracy, F1, etc.)
- Compare with single-task baselines
Average Performance: $\bar{\Delta} = \frac{1}{T} \sum_t \Delta_t$, the mean of the per-task improvements over single-task baselines
Pareto Front:
- Plot performance of each task, observe trade-offs
- Good multi-task model should be on or near Pareto front
Task-Specific Improvement: $\Delta_t = \frac{P_t^{\text{MTL}} - P_t^{\text{STL}}}{P_t^{\text{STL}}} \times 100\%$, where MTL is multi-task performance and STL is single-task performance
Computational Efficiency:
- Compare total parameters, inference time
- Multi-task should be more efficient
Q8: Can multi-task learning alleviate overfitting?
A: Yes, multi-task learning has strong regularization effects:
- Shared Parameter Constraints:
- Shared parameters must satisfy multiple tasks, limiting overfitting on single task
- Data Augmentation Effect:
- Auxiliary tasks provide additional training signals
- Equivalent to expanding training set
- Experimental Evidence:
- Numerous studies show multi-task learning improves generalization on small datasets
- The gains are especially pronounced in small-data regimes, sometimes comparable to a 2-10x increase in training data
But Note: - If auxiliary tasks unrelated, may increase overfitting - Task weights need proper tuning
Q9: How to apply multi-task learning to pre-trained models?
A: Several strategies for adding multi-task learning to pre-trained models:
- Adapter Method:
- Insert task-specific adapter modules in pre-trained model
- Only train adapters, freeze other parameters
- Parameter efficient, avoids catastrophic forgetting
- Fine-tuning + Multi-Task Heads:
- Add multiple task heads to pre-trained model
- Fine-tune entire model or only specific layers
- Progressive Training:
- First fine-tune on primary task
- Then add auxiliary tasks for joint training
- Prompt-based Multi-Task Learning:
- Use different prompts to distinguish tasks
- Particularly effective in NLP (like T5, GPT)
Q10: What are future directions for multi-task learning research?
A: Several promising research directions:
- Automated Task Selection and Weighting:
- Use meta-learning or RL to automatically find optimal task combinations and weights
- Reduce manual tuning effort
- Continual Multi-Task Learning:
- How to continuously add new tasks without forgetting old ones
- Combine multi-task learning with continual learning
- Few-Shot Multi-Task Learning:
- How to effectively share knowledge when tasks have very few samples
- Combine meta-learning with multi-task learning
- Cross-Modal Multi-Task Learning:
- Jointly train tasks across different modalities (vision, language, audio)
- Learn universal multimodal representations
- Theory and Understanding:
- Why does multi-task learning work? When does it work?
- Theoretical analysis of task relatedness, negative transfer, gradient conflicts
Related Papers
Classic Papers
- Caruana, R., "Multitask Learning", Machine Learning
1997
- Pioneering work on multi-task learning
- Demonstrated regularization effects of shared representations
- Link
- Ruder, S., "An Overview of Multi-Task Learning in Deep
Neural Networks", arXiv 2017
- Comprehensive survey of multi-task learning methods
- Systematically summarizes different architectures and optimization methods
- arXiv:1706.05098
Parameter Sharing
- Misra, I. et al., "Cross-Stitch Networks for Multi-task
Learning", CVPR 2016
- Proposed cross-stitch units allowing multi-layer information exchange
- Soft parameter sharing method
- arXiv:1604.03539
- Liu, S. et al., "End-to-End Multi-Task Learning with
Attention", CVPR 2019
- Multi-Task Attention Network (MTAN)
- Dynamically select shared features via attention
- arXiv:1803.10704
Gradient Conflicts and Balancing
- Chen, Z. et al., "GradNorm: Gradient Normalization for
Adaptive Loss Balancing", ICML 2018
- Proposed GradNorm algorithm
- Dynamically adjust task weights to balance gradient magnitudes
- arXiv:1711.02257
- Yu, T. et al., "Gradient Surgery for Multi-Task Learning",
NeurIPS 2020
- Proposed PCGrad algorithm
- Eliminate gradient conflicts via projection
- arXiv:2001.06782
- Liu, B. et al., "Conflict-Averse Gradient Descent for
Multi-task Learning", NeurIPS 2021
- Proposed CAGrad algorithm
- Find Pareto optimal gradient direction via QP
- arXiv:2110.14048
Task Relationship Learning
- Zamir, A. R. et al., "Taskonomy: Disentangling Task Transfer
Learning", CVPR 2018
- Large-scale study of task relationships
- Constructed task affinity matrix
- arXiv:1804.08328
- Standley, T. et al., "Which Tasks Should Be Learned Together
in Multi-task Learning?", ICML 2020
- Automated search for which tasks to group together
- Found optimal task combinations
- arXiv:1905.07553
Uncertainty Weighting
- Kendall, A. et al., "Multi-Task Learning Using Uncertainty
to Weigh Losses", CVPR 2018
- Proposed uncertainty-based automatic weighting
- Learn task uncertainty parameters
- arXiv:1705.07115
Applications
- He, K. et al., "Mask R-CNN", ICCV 2017
- Multi-task learning in object detection
- Simultaneous detection, segmentation, keypoint detection
- arXiv:1703.06870
- Liu, X. et al., "Multi-Task Deep Neural Networks for Natural
Language Understanding", ACL 2019
- Multi-task learning in NLP
- Jointly train multiple language understanding tasks
- arXiv:1901.11504
Summary
Multi-task learning is a powerful paradigm that improves model generalization and computational efficiency by simultaneously learning multiple related tasks. This article derived multi-task learning's mathematical foundations from first principles, analyzed hard vs soft parameter sharing strategies in detail, deeply explained gradient conflict problems and solutions (PCGrad, GradNorm, CAGrad), introduced auxiliary task design principles, and provided complete multi-task network implementations.
We saw that multi-task learning's core is finding Pareto optimal solutions satisfying multiple optimization objectives. Through proper task selection, architecture design, and gradient optimization, multi-task learning can significantly improve model performance while reducing parameters and computation. From computer vision to natural language processing to recommendation systems, multi-task learning has become an indispensable tool in modern machine learning.
Next chapter we'll explore zero-shot learning, investigating how models can recognize unseen classes without any labeled examples.
- Post title:Transfer Learning (6): Multi-Task Learning
- Post author:Chen Kai
- Create time:2024-12-03 14:00:00
- Post link:https://www.chenk.top/transfer-learning-6-multi-task-learning/
- Copyright Notice:All articles in this blog are licensed under BY-NC-SA unless stated otherwise.