Recommendation Systems (9): Multi-Task Learning and Multi-Objective Optimization
Chen Kai BOSS

In real-world recommendation systems, optimizing for a single objective is rarely sufficient. When you browse an e-commerce platform, the system needs to predict not just whether you'll click on a product, but also whether you'll add it to cart, make a purchase, return it, or write a review. Each of these actions represents a different task with distinct patterns, yet they're all interconnected — a user who clicks is more likely to purchase, and someone who purchases is more likely to return. Multi-task learning (MTL) provides a powerful framework for jointly optimizing multiple objectives by sharing representations across related tasks, leading to improved performance on each individual task while reducing computational overhead.

Multi-task learning has become a cornerstone of modern recommendation systems, from Google's MMoE (Multi-gate Mixture-of-Experts) that handles conflicting objectives, to Alibaba's ESMM (Entire Space Multi-Task Model) that addresses sample selection bias in conversion prediction, to Tencent's PLE (Progressive Layered Extraction) that explicitly separates shared and task-specific knowledge. These architectures have demonstrated significant improvements over single-task models by leveraging the commonalities between tasks while preserving task-specific nuances.

This article provides a comprehensive exploration of multi-task learning for recommendation systems, covering foundational architectures (Shared-Bottom, ESMM, MMoE, PLE, STEM-Net), task relationship modeling techniques, loss balancing strategies, industrial applications and case studies, implementation details with 10+ code examples, and detailed Q&A sections addressing common challenges and best practices.

Why Multi-Task Learning for Recommendation?

The Multi-Objective Nature of Recommendation

Recommendation systems in production face multiple, often conflicting objectives:

E-commerce Platforms:
  • Click-through rate (CTR): Maximize user engagement
  • Conversion rate (CVR): Maximize purchases
  • Revenue per user: Maximize GMV (Gross Merchandise Value)
  • Return rate: Minimize returns
  • Review quality: Encourage positive reviews

Content Platforms:
  • Click-through rate: User engagement
  • Watch time: Content consumption depth
  • Like/share rate: Content virality
  • Comment rate: Community engagement
  • Subscription rate: Long-term retention

Advertising Systems:
  • Click-through rate: Ad relevance
  • Conversion rate: Purchase actions
  • Cost per acquisition (CPA): Efficiency
  • User lifetime value (LTV): Long-term value

These objectives are interrelated but distinct: optimizing for clicks might increase low-quality interactions, while optimizing solely for conversions might reduce overall engagement. Multi-task learning allows us to balance these objectives by learning shared representations that capture common patterns while maintaining task-specific heads that capture unique characteristics.

Sample Selection Bias Problem

A critical challenge in recommendation is sample selection bias. Consider conversion prediction:

  • Training data: Only contains samples where users clicked (positive samples for CTR)
  • Test data: Contains all samples (both clicked and non-clicked)
  • Problem: The model is trained on a biased distribution, leading to poor generalization

Traditional single-task models trained on clicked samples only cannot properly estimate the true conversion rate in the entire space. ESMM addresses this by modeling the entire space through the chain rule: pCTCVR = pCTR × pCVR, i.e. P(click & conversion | impression) = P(click | impression) × P(conversion | click, impression).
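To make the bias concrete, here is a small simulation (toy data; the logistic click model is an assumption for illustration). Because clicks are more likely for certain feature values, the clicked subset has a shifted feature distribution compared with the full impression space a CVR model faces at serving time:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# One toy feature; clicks are more likely for larger feature values
x = rng.normal(size=n)
p_click = 1.0 / (1.0 + np.exp(-2.0 * x))
clicked = rng.random(n) < p_click

# A CVR model trained only on the clicked subset sees a very
# different input distribution than the full impression space
print(f"mean feature, all impressions: {x.mean():+.3f}")
print(f"mean feature, clicked subset:  {x[clicked].mean():+.3f}")
```

The clicked-subset mean is strongly positive while the impression-space mean is near zero, which is exactly the train/serve distribution mismatch ESMM is designed to remove.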

Benefits of Multi-Task Learning

Multi-task learning offers several advantages:

  1. Data Efficiency: Shared representations allow tasks with limited data to benefit from tasks with abundant data
  2. Regularization: Learning shared features acts as implicit regularization, reducing overfitting
  3. Transfer Learning: Knowledge learned from one task transfers to related tasks
  4. Computational Efficiency: Shared bottom layers reduce computation compared to training separate models
  5. Better Generalization: Joint optimization often leads to more robust representations

Challenges in Multi-Task Learning

However, MTL also introduces challenges:

  1. Task Conflicts: Tasks may have conflicting gradients, causing negative transfer
  2. Loss Balancing: Different tasks have different scales and importance
  3. Task Relationships: Understanding which tasks benefit from sharing vs. separation
  4. Architecture Design: Balancing shared vs. task-specific components
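A minimal toy illustration of the first challenge: two task losses that pull a shared parameter toward opposite optima produce gradients with cosine similarity -1, so a step that helps one task directly hurts the other:

```python
import torch
import torch.nn.functional as F

# A single shared parameter pulled toward different optima by two tasks
w = torch.tensor([1.0], requires_grad=True)

loss_a = (w - 2.0) ** 2   # task A wants w -> +2
loss_b = (w + 2.0) ** 2   # task B wants w -> -2

grad_a = torch.autograd.grad(loss_a, w)[0]
grad_b = torch.autograd.grad(loss_b, w)[0]

cos = F.cosine_similarity(grad_a, grad_b, dim=0)
print(f"grad A: {grad_a.item():+.1f}, grad B: {grad_b.item():+.1f}, cosine: {cos.item():+.1f}")
# prints: grad A: -2.0, grad B: +6.0, cosine: -1.0
```

The gradient-based methods later in this article measure exactly this cosine similarity, at the scale of full model parameter vectors.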

Shared-Bottom Architecture

Basic Structure

The Shared-Bottom architecture is the simplest multi-task learning approach:

Input Features
      ↓
Shared Bottom Layers (shared across all tasks)
      ↓
Task-Specific Towers (one per task)
      ↓
Task Outputs

Mathematical Formulation:

For K tasks, the shared-bottom model computes ŷ_k = h_k(f(x)), where x is the input feature vector, f is the shared bottom network, h_k is the task-specific tower for task k, and ŷ_k is the prediction for task k.

Implementation Example

import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedBottomMTL(nn.Module):
    """
    Shared-Bottom Multi-Task Learning Model

    All tasks share the same bottom layers and have
    separate task-specific towers on top.
    """
    def __init__(self, input_dim, shared_hidden_dims,
                 task_hidden_dims, num_tasks, task_types,
                 num_classes=None):
        """
        Args:
            input_dim: Dimension of input features
            shared_hidden_dims: List of hidden dimensions for shared layers
            task_hidden_dims: List of hidden dimensions for task-specific towers
            num_tasks: Number of tasks
            task_types: List of task types ('binary', 'regression', 'multiclass')
            num_classes: Number of classes for 'multiclass' tasks
        """
        super(SharedBottomMTL, self).__init__()
        self.num_tasks = num_tasks
        self.task_types = task_types

        # Shared bottom layers
        shared_layers = []
        prev_dim = input_dim
        for hidden_dim in shared_hidden_dims:
            shared_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2)
            ])
            prev_dim = hidden_dim

        self.shared_bottom = nn.Sequential(*shared_layers)

        # Task-specific towers
        self.task_towers = nn.ModuleList()
        for task_idx in range(num_tasks):
            tower_layers = []
            prev_dim = shared_hidden_dims[-1]
            for hidden_dim in task_hidden_dims:
                tower_layers.extend([
                    nn.Linear(prev_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1)
                ])
                prev_dim = hidden_dim

            # Output layer based on task type
            if task_types[task_idx] == 'binary':
                tower_layers.append(nn.Linear(prev_dim, 1))
                tower_layers.append(nn.Sigmoid())
            elif task_types[task_idx] == 'regression':
                tower_layers.append(nn.Linear(prev_dim, 1))
            elif task_types[task_idx] == 'multiclass':
                tower_layers.append(nn.Linear(prev_dim, num_classes))
                tower_layers.append(nn.Softmax(dim=1))

            self.task_towers.append(nn.Sequential(*tower_layers))

    def forward(self, x):
        """
        Forward pass through shared bottom and task-specific towers

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            List of task predictions
        """
        # Shared representation
        shared_repr = self.shared_bottom(x)

        # Task-specific predictions
        task_outputs = []
        for tower in self.task_towers:
            output = tower(shared_repr)
            task_outputs.append(output)

        return task_outputs

# Example usage
model = SharedBottomMTL(
    input_dim=128,
    shared_hidden_dims=[256, 128, 64],
    task_hidden_dims=[32, 16],
    num_tasks=3,
    task_types=['binary', 'binary', 'regression']
)

# Forward pass
batch_size = 32
x = torch.randn(batch_size, 128)
outputs = model(x)
print(f"Task 1 (CTR) output shape: {outputs[0].shape}")
print(f"Task 2 (CVR) output shape: {outputs[1].shape}")
print(f"Task 3 (Revenue) output shape: {outputs[2].shape}")

Limitations of Shared-Bottom

The shared-bottom architecture has a fundamental limitation: all tasks share the same representation, which can lead to:

  1. Negative Transfer: Conflicting tasks interfere with each other
  2. Insufficient Capacity: Shared layers may not capture task-specific patterns
  3. Gradient Conflicts: Tasks with conflicting objectives create opposing gradients

This motivates more sophisticated architectures like MMoE and PLE that allow selective sharing.

ESMM: Entire Space Multi-Task Model

Problem Formulation

ESMM addresses the sample selection bias problem in conversion prediction:

  • Traditional approach: Train CVR model only on clicked samples
  • Problem: Test on entire space (clicked + non-clicked), causing distribution mismatch
  • Solution: Model the entire space using the chain rule

Mathematical Foundation:

P(click & conversion | impression) = P(click | impression) × P(conversion | click, impression)

i.e. pCTCVR = pCTR × pCVR. This decomposition allows us to:
  1. Train the CTR model on all impressions
  2. Train the CVR model on clicked samples only
  3. Combine them to obtain the conversion probability in the entire space

Architecture

ESMM consists of three components:

  1. CTR Tower: Predicts pCTR = P(click | impression)

  2. CVR Tower: Predicts pCVR = P(conversion | click)

  3. CTCVR Tower: Predicts pCTCVR = P(click & conversion | impression)

The key insight: CTCVR = CTR × CVR, ensuring consistency.

Implementation Example

class ESMM(nn.Module):
    """
    Entire Space Multi-Task Model (ESMM)

    Addresses sample selection bias by modeling:
    - CTR: P(click | impression) on all samples
    - CVR: P(conversion | click) on clicked samples
    - CTCVR: P(click & conversion | impression) = CTR × CVR
    """
    def __init__(self, input_dim, hidden_dims):
        super(ESMM, self).__init__()

        # Shared embedding layer
        self.embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # CTR tower
        ctr_layers = []
        prev_dim = hidden_dims[0]
        for hidden_dim in hidden_dims[1:]:
            ctr_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2)
            ])
            prev_dim = hidden_dim
        ctr_layers.append(nn.Linear(prev_dim, 1))
        ctr_layers.append(nn.Sigmoid())
        self.ctr_tower = nn.Sequential(*ctr_layers)

        # CVR tower (same architecture)
        cvr_layers = []
        prev_dim = hidden_dims[0]
        for hidden_dim in hidden_dims[1:]:
            cvr_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2)
            ])
            prev_dim = hidden_dim
        cvr_layers.append(nn.Linear(prev_dim, 1))
        cvr_layers.append(nn.Sigmoid())
        self.cvr_tower = nn.Sequential(*cvr_layers)

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input features (batch_size, input_dim)

        Returns:
            ctr: P(click | impression)
            cvr: P(conversion | click)
            ctcvr: P(click & conversion | impression) = CTR × CVR
        """
        # Shared embedding
        shared_repr = self.embedding(x)

        # Task-specific predictions
        ctr = self.ctr_tower(shared_repr)
        cvr = self.cvr_tower(shared_repr)

        # CTCVR = CTR × CVR (chain rule)
        ctcvr = ctr * cvr

        return ctr, cvr, ctcvr

# Loss function for ESMM
def esmm_loss(ctr_pred, cvr_pred, ctcvr_pred,
              ctr_label, cvr_label, ctcvr_label,
              click_mask):
    """
    ESMM loss function

    Note: the original ESMM paper supervises only the CTR and CTCVR
    losses; the direct CVR term on clicked samples below is an
    optional auxiliary loss.

    Args:
        ctr_pred: Predicted CTR, shape (batch_size, 1)
        cvr_pred: Predicted CVR, shape (batch_size, 1)
        ctcvr_pred: Predicted CTCVR, shape (batch_size, 1)
        ctr_label: True CTR label (click indicator), shape (batch_size,)
        cvr_label: True CVR label (conversion indicator, only valid for clicked)
        ctcvr_label: True CTCVR label (click & conversion indicator)
        click_mask: Boolean mask indicating which samples were clicked

    Returns:
        Total loss and the individual task losses
    """
    # Align prediction shapes with label shapes
    ctr_pred = ctr_pred.squeeze(-1)
    cvr_pred = cvr_pred.squeeze(-1)
    ctcvr_pred = ctcvr_pred.squeeze(-1)

    # CTR loss: computed on all samples
    ctr_loss = F.binary_cross_entropy(ctr_pred, ctr_label.float())

    # CVR loss: computed only on clicked samples (guard against empty batches)
    if click_mask.any():
        cvr_loss = F.binary_cross_entropy(
            cvr_pred[click_mask],
            cvr_label[click_mask].float()
        )
    else:
        cvr_loss = torch.tensor(0.0)

    # CTCVR loss: computed on all samples
    ctcvr_loss = F.binary_cross_entropy(ctcvr_pred, ctcvr_label.float())

    # Total loss
    total_loss = ctr_loss + cvr_loss + ctcvr_loss

    return total_loss, ctr_loss, cvr_loss, ctcvr_loss

# Training example
model = ESMM(input_dim=128, hidden_dims=[256, 128, 64])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Simulated data
batch_size = 32
x = torch.randn(batch_size, 128)
ctr_label = torch.randint(0, 2, (batch_size,))
cvr_label = torch.randint(0, 2, (batch_size,))
ctcvr_label = ctr_label * cvr_label  # Conversion only if clicked
click_mask = ctr_label.bool()

# Forward pass
ctr_pred, cvr_pred, ctcvr_pred = model(x)

# Compute loss
loss, ctr_loss, cvr_loss, ctcvr_loss = esmm_loss(
    ctr_pred, cvr_pred, ctcvr_pred,
    ctr_label, cvr_label, ctcvr_label,
    click_mask
)

# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"CTR Loss: {ctr_loss.item():.4f}")
print(f"CVR Loss: {cvr_loss.item():.4f}")
print(f"CTCVR Loss: {ctcvr_loss.item():.4f}")
print(f"Total Loss: {loss.item():.4f}")

Key Advantages

  1. Eliminates Sample Selection Bias: Models entire space through chain rule
  2. Consistent Predictions: CTCVR = CTR × CVR ensures mathematical consistency
  3. Data Efficiency: CTR model benefits from all samples, CVR from clicked samples
  4. Practical: Simple architecture, easy to deploy

MMoE: Multi-gate Mixture-of-Experts

Motivation

MMoE addresses the negative transfer problem in shared-bottom architectures. When tasks conflict, forcing them to share the same representation hurts performance. MMoE introduces:

  1. Multiple Experts: Several expert networks that can learn different patterns
  2. Gating Networks: Task-specific gates that selectively combine experts
  3. Adaptive Sharing: Tasks can share experts when beneficial, use different experts when conflicting

Architecture

MMoE consists of:

  1. Expert Networks: f_1, ..., f_n that learn different feature transformations
  2. Gating Networks: g^k for each task k that learn to weight experts
  3. Task Towers: Task-specific networks on top of expert outputs

Mathematical Formulation:

For task k:

  y^k = h^k( Σ_{i=1}^{n} g^k(x)_i · f_i(x) )

where f_i(x) is the output of expert i, g^k(x)_i is the gate weight for task k and expert i, and Σ_i g^k(x)_i = 1 (softmax normalization).

Implementation Example

class Expert(nn.Module):
    """Single expert network"""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Expert, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.network(x)

class Gate(nn.Module):
    """Gating network for a single task"""
    def __init__(self, input_dim, num_experts):
        super(Gate, self).__init__()
        self.gate_network = nn.Sequential(
            nn.Linear(input_dim, num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.gate_network(x)

class MMoE(nn.Module):
    """
    Multi-gate Mixture-of-Experts (MMoE)

    Allows tasks to selectively share experts through
    task-specific gating networks.
    """
    def __init__(self, input_dim, num_experts, expert_hidden_dim,
                 expert_output_dim, num_tasks, task_hidden_dims, task_types):
        """
        Args:
            input_dim: Input feature dimension
            num_experts: Number of expert networks
            expert_hidden_dim: Hidden dimension for experts
            expert_output_dim: Output dimension for experts
            num_tasks: Number of tasks
            task_hidden_dims: Hidden dimensions for task towers
            task_types: List of task types
        """
        super(MMoE, self).__init__()
        self.num_experts = num_experts
        self.num_tasks = num_tasks
        self.expert_output_dim = expert_output_dim

        # Expert networks
        self.experts = nn.ModuleList([
            Expert(input_dim, expert_hidden_dim, expert_output_dim)
            for _ in range(num_experts)
        ])

        # Gating networks (one per task)
        self.gates = nn.ModuleList([
            Gate(input_dim, num_experts)
            for _ in range(num_tasks)
        ])

        # Task-specific towers
        self.task_towers = nn.ModuleList()
        for task_idx in range(num_tasks):
            tower_layers = []
            prev_dim = expert_output_dim
            for hidden_dim in task_hidden_dims:
                tower_layers.extend([
                    nn.Linear(prev_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1)
                ])
                prev_dim = hidden_dim

            # Output layer
            if task_types[task_idx] == 'binary':
                tower_layers.append(nn.Linear(prev_dim, 1))
                tower_layers.append(nn.Sigmoid())
            elif task_types[task_idx] == 'regression':
                tower_layers.append(nn.Linear(prev_dim, 1))

            self.task_towers.append(nn.Sequential(*tower_layers))

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input tensor (batch_size, input_dim)

        Returns:
            List of task predictions
            List of gate weights (for analysis)
        """
        # Compute expert outputs
        expert_outputs = []
        for expert in self.experts:
            expert_outputs.append(expert(x))
        expert_outputs = torch.stack(expert_outputs, dim=1)  # (batch_size, num_experts, expert_output_dim)

        # Compute task-specific representations
        task_outputs = []
        gate_weights_list = []

        for task_idx in range(self.num_tasks):
            # Get gate weights for this task
            gate_weights = self.gates[task_idx](x)  # (batch_size, num_experts)
            gate_weights_list.append(gate_weights)

            # Weighted combination of experts
            gate_weights_expanded = gate_weights.unsqueeze(2)  # (batch_size, num_experts, 1)
            task_repr = (expert_outputs * gate_weights_expanded).sum(dim=1)  # (batch_size, expert_output_dim)

            # Task-specific tower
            task_output = self.task_towers[task_idx](task_repr)
            task_outputs.append(task_output)

        return task_outputs, gate_weights_list

# Example usage
model = MMoE(
    input_dim=128,
    num_experts=4,
    expert_hidden_dim=64,
    expert_output_dim=32,
    num_tasks=3,
    task_hidden_dims=[16],
    task_types=['binary', 'binary', 'regression']
)

# Forward pass
x = torch.randn(32, 128)
outputs, gate_weights = model(x)

print(f"Task outputs: {[out.shape for out in outputs]}")
print(f"Gate weights shape: {gate_weights[0].shape}")
print(f"Sample gate weights for task 1: {gate_weights[0][0]}")

Advantages of MMoE

  1. Adaptive Sharing: Tasks can selectively use experts based on their needs
  2. Conflict Handling: Conflicting tasks can use different experts
  3. Scalability: Easy to add new tasks or experts
  4. Interpretability: Gate weights show which experts each task uses

PLE: Progressive Layered Extraction

Motivation

PLE (Progressive Layered Extraction) extends MMoE by explicitly separating:

  1. Shared Experts: Capture common patterns across all tasks
  2. Task-Specific Experts: Capture task-unique patterns
  3. Progressive Extraction: Gradually extract shared knowledge layer by layer

This explicit separation helps when tasks have both shared and conflicting patterns.

Architecture

PLE uses a multi-layer structure:

Layer Structure:
  • Shared Experts: E_s^(l) at layer l
  • Task-Specific Experts: E_k^(l) for task k at layer l
  • Gating Networks: Combine shared and task-specific experts

Progressive Extraction:
  • Lower layers: more task-specific, less sharing
  • Higher layers: more shared knowledge, less task-specific

Implementation Example

class PLELayer(nn.Module):
    """Single PLE layer with shared and task-specific experts"""
    def __init__(self, input_dim, num_shared_experts, num_task_experts,
                 expert_hidden_dim, expert_output_dim, num_tasks):
        super(PLELayer, self).__init__()
        self.num_shared_experts = num_shared_experts
        self.num_task_experts = num_task_experts
        self.num_tasks = num_tasks
        self.expert_output_dim = expert_output_dim

        # Shared experts
        self.shared_experts = nn.ModuleList([
            Expert(input_dim, expert_hidden_dim, expert_output_dim)
            for _ in range(num_shared_experts)
        ])

        # Task-specific experts
        self.task_experts = nn.ModuleList([
            nn.ModuleList([
                Expert(input_dim, expert_hidden_dim, expert_output_dim)
                for _ in range(num_task_experts)
            ])
            for _ in range(num_tasks)
        ])

        # Gating networks
        self.shared_gates = nn.ModuleList([
            Gate(input_dim, num_shared_experts)
            for _ in range(num_tasks)
        ])

        self.task_gates = nn.ModuleList([
            Gate(input_dim, num_task_experts)
            for _ in range(num_tasks)
        ])

    def forward(self, x):
        """
        Forward pass through PLE layer

        Returns:
            task_outputs: List of task representations
            gate_info: Gate weights for analysis
        """
        # Shared expert outputs
        shared_expert_outputs = []
        for expert in self.shared_experts:
            shared_expert_outputs.append(expert(x))
        shared_expert_outputs = torch.stack(shared_expert_outputs, dim=1)  # (batch_size, num_shared_experts, expert_output_dim)

        # Task-specific expert outputs
        task_expert_outputs_list = []
        for task_idx in range(self.num_tasks):
            task_expert_outputs = []
            for expert in self.task_experts[task_idx]:
                task_expert_outputs.append(expert(x))
            task_expert_outputs = torch.stack(task_expert_outputs, dim=1)  # (batch_size, num_task_experts, expert_output_dim)
            task_expert_outputs_list.append(task_expert_outputs)

        # Combine shared and task-specific experts for each task
        task_outputs = []
        gate_info = []

        for task_idx in range(self.num_tasks):
            # Gate weights
            shared_gate_weights = self.shared_gates[task_idx](x)  # (batch_size, num_shared_experts)
            task_gate_weights = self.task_gates[task_idx](x)      # (batch_size, num_task_experts)

            # Weighted combination
            shared_weighted = (shared_expert_outputs * shared_gate_weights.unsqueeze(2)).sum(dim=1)
            task_weighted = (task_expert_outputs_list[task_idx] * task_gate_weights.unsqueeze(2)).sum(dim=1)

            # Combine shared and task-specific (add; concatenation is another option)
            task_output = shared_weighted + task_weighted
            task_outputs.append(task_output)

            gate_info.append({
                'shared_gate': shared_gate_weights,
                'task_gate': task_gate_weights
            })

        return task_outputs, gate_info

class PLE(nn.Module):
    """
    Progressive Layered Extraction (PLE)

    Explicitly separates shared and task-specific experts
    with progressive knowledge extraction across layers.
    """
    def __init__(self, input_dim, num_layers, num_shared_experts,
                 num_task_experts, expert_hidden_dim, expert_output_dim,
                 num_tasks, task_hidden_dims, task_types):
        super(PLE, self).__init__()
        self.num_layers = num_layers
        self.num_tasks = num_tasks

        # PLE layers
        self.ple_layers = nn.ModuleList()
        prev_dim = input_dim
        for layer_idx in range(num_layers):
            layer = PLELayer(
                prev_dim, num_shared_experts, num_task_experts,
                expert_hidden_dim, expert_output_dim, num_tasks
            )
            self.ple_layers.append(layer)
            prev_dim = expert_output_dim

        # Task-specific towers
        self.task_towers = nn.ModuleList()
        for task_idx in range(num_tasks):
            tower_layers = []
            prev_dim = expert_output_dim
            for hidden_dim in task_hidden_dims:
                tower_layers.extend([
                    nn.Linear(prev_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1)
                ])
                prev_dim = hidden_dim

            if task_types[task_idx] == 'binary':
                tower_layers.append(nn.Linear(prev_dim, 1))
                tower_layers.append(nn.Sigmoid())
            elif task_types[task_idx] == 'regression':
                tower_layers.append(nn.Linear(prev_dim, 1))

            self.task_towers.append(nn.Sequential(*tower_layers))

    def forward(self, x):
        """
        Forward pass through PLE layers

        Returns:
            final_outputs: Final task predictions
            all_gate_info: Gate information from all layers
        """
        current_input = x
        current_input_list = [x] * self.num_tasks
        all_gate_info = []

        # Progressive extraction through layers
        for layer in self.ple_layers:
            current_input_list, gate_info = layer(current_input)
            all_gate_info.append(gate_info)
            # Simplified routing: use the first task's output as input to the
            # next layer. (Full PLE maintains a separate shared path and one
            # path per task between layers.)
            current_input = current_input_list[0]

        # Final task-specific towers
        final_outputs = []
        for task_idx in range(self.num_tasks):
            output = self.task_towers[task_idx](current_input_list[task_idx])
            final_outputs.append(output)

        return final_outputs, all_gate_info

# Example usage
model = PLE(
    input_dim=128,
    num_layers=2,
    num_shared_experts=2,
    num_task_experts=2,
    expert_hidden_dim=64,
    expert_output_dim=32,
    num_tasks=3,
    task_hidden_dims=[16],
    task_types=['binary', 'binary', 'regression']
)

x = torch.randn(32, 128)
outputs, gate_info = model(x)
print(f"Number of layers: {len(gate_info)}")
print(f"Task outputs: {[out.shape for out in outputs]}")

STEM-Net: Search-based Task Embedding for Multi-Task Learning

Motivation

STEM-Net introduces task embeddings to explicitly model task relationships. Instead of hard-coding which tasks share experts, STEM-Net learns task representations that guide expert selection.

Key Ideas

  1. Task Embeddings: Learnable representations for each task
  2. Search-Based Architecture: Use task embeddings to search for relevant experts
  3. Dynamic Expert Selection: Experts are selected based on task-task and task-expert similarity

Implementation Example

class STEMNet(nn.Module):
    """
    STEM-Net: Search-based Task Embedding for Multi-Task Learning

    Learns task embeddings to guide expert selection dynamically.
    """
    def __init__(self, input_dim, num_experts, expert_hidden_dim,
                 expert_output_dim, num_tasks, task_embed_dim,
                 task_hidden_dims, task_types):
        super(STEMNet, self).__init__()
        self.num_experts = num_experts
        self.num_tasks = num_tasks
        self.task_embed_dim = task_embed_dim

        # Task embeddings (learnable)
        self.task_embeddings = nn.Parameter(
            torch.randn(num_tasks, task_embed_dim)
        )

        # Expert networks
        self.experts = nn.ModuleList([
            Expert(input_dim, expert_hidden_dim, expert_output_dim)
            for _ in range(num_experts)
        ])

        # Expert embeddings (learnable)
        self.expert_embeddings = nn.Parameter(
            torch.randn(num_experts, task_embed_dim)
        )

        # Attention mechanism for expert selection
        self.attention = nn.MultiheadAttention(
            embed_dim=task_embed_dim,
            num_heads=4,
            batch_first=True
        )

        # Task-specific towers
        self.task_towers = nn.ModuleList()
        for task_idx in range(num_tasks):
            tower_layers = []
            prev_dim = expert_output_dim
            for hidden_dim in task_hidden_dims:
                tower_layers.extend([
                    nn.Linear(prev_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1)
                ])
                prev_dim = hidden_dim

            if task_types[task_idx] == 'binary':
                tower_layers.append(nn.Linear(prev_dim, 1))
                tower_layers.append(nn.Sigmoid())
            elif task_types[task_idx] == 'regression':
                tower_layers.append(nn.Linear(prev_dim, 1))

            self.task_towers.append(nn.Sequential(*tower_layers))

    def forward(self, x):
        """
        Forward pass with task-embedding-guided expert selection

        Returns:
            task_outputs: Task predictions
            attention_weights: Attention weights for interpretability
        """
        batch_size = x.size(0)

        # Compute expert outputs
        expert_outputs = []
        for expert in self.experts:
            expert_outputs.append(expert(x))
        expert_outputs = torch.stack(expert_outputs, dim=1)  # (batch_size, num_experts, expert_output_dim)

        # Expand task and expert embeddings for the batch
        task_embeds = self.task_embeddings.unsqueeze(0).expand(batch_size, -1, -1)    # (batch_size, num_tasks, task_embed_dim)
        expert_embeds = self.expert_embeddings.unsqueeze(0).expand(batch_size, -1, -1)  # (batch_size, num_experts, task_embed_dim)

        # Compute attention: tasks attend to experts
        attended_experts, attention_weights = self.attention(
            query=task_embeds,   # (batch_size, num_tasks, task_embed_dim)
            key=expert_embeds,   # (batch_size, num_experts, task_embed_dim)
            value=expert_embeds  # (batch_size, num_experts, task_embed_dim)
        )

        # Attention weights double as expert selection weights
        # attention_weights: (batch_size, num_tasks, num_experts), already normalized
        expert_weights = attention_weights

        # Weighted combination of experts for each task
        task_outputs = []
        for task_idx in range(self.num_tasks):
            # Get weights for this task
            task_expert_weights = expert_weights[:, task_idx, :]  # (batch_size, num_experts)

            # Weighted combination
            task_expert_weights_expanded = task_expert_weights.unsqueeze(2)  # (batch_size, num_experts, 1)
            task_repr = (expert_outputs * task_expert_weights_expanded).sum(dim=1)  # (batch_size, expert_output_dim)

            # Task-specific tower
            task_output = self.task_towers[task_idx](task_repr)
            task_outputs.append(task_output)

        return task_outputs, attention_weights

# Example usage
model = STEMNet(
    input_dim=128,
    num_experts=4,
    expert_hidden_dim=64,
    expert_output_dim=32,
    num_tasks=3,
    task_embed_dim=16,
    task_hidden_dims=[16],
    task_types=['binary', 'binary', 'regression']
)

x = torch.randn(32, 128)
outputs, attention_weights = model(x)
print(f"Task outputs: {[out.shape for out in outputs]}")
print(f"Attention weights shape: {attention_weights.shape}")

Task Relationship Modeling

Understanding Task Relationships

Tasks in recommendation systems have different relationships:

  1. Complementary Tasks: Benefit from sharing (e.g., CTR and CVR)
  2. Conflicting Tasks: Hurt each other when sharing (e.g., engagement vs. revenue)
  3. Hierarchical Tasks: One task is a prerequisite for another (e.g., click → conversion)
  4. Independent Tasks: No clear relationship

Methods for Modeling Task Relationships

Correlation-Based Methods

def compute_task_correlation(task_predictions):
    """
    Compute the Pearson correlation matrix between task predictions

    Args:
        task_predictions: List of prediction tensors, one per task

    Returns:
        Correlation matrix (num_tasks, num_tasks)
    """
    num_tasks = len(task_predictions)
    correlations = torch.zeros(num_tasks, num_tasks)

    for i in range(num_tasks):
        for j in range(num_tasks):
            # Compute Pearson correlation
            pred_i = task_predictions[i].flatten()
            pred_j = task_predictions[j].flatten()

            mean_i = pred_i.mean()
            mean_j = pred_j.mean()

            numerator = ((pred_i - mean_i) * (pred_j - mean_j)).sum()
            denominator = torch.sqrt(
                ((pred_i - mean_i) ** 2).sum() *
                ((pred_j - mean_j) ** 2).sum()
            )

            if denominator > 0:
                correlations[i, j] = numerator / denominator

    return correlations

Gradient-Based Methods

def compute_gradient_similarity(model, loss_fn, x, task_labels):
    """
    Compute gradient cosine similarity between tasks

    Similar gradients indicate tasks benefit from sharing;
    negative similarity signals conflicting tasks.
    """
    num_tasks = len(task_labels)
    gradients_list = []

    for task_idx in range(num_tasks):
        # Forward pass (recomputed per task, so each backward uses a fresh graph)
        outputs = model(x)
        loss = loss_fn(outputs[task_idx], task_labels[task_idx])

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Collect gradients, substituting zeros for parameters this task
        # does not touch so all gradient vectors have the same length
        task_gradients = []
        for param in model.parameters():
            if param.grad is not None:
                task_gradients.append(param.grad.flatten().clone())
            else:
                task_gradients.append(torch.zeros(param.numel()))

        gradients_list.append(torch.cat(task_gradients))

    # Compute cosine similarity between task gradient vectors
    similarity_matrix = torch.zeros(num_tasks, num_tasks)
    for i in range(num_tasks):
        for j in range(num_tasks):
            cos_sim = F.cosine_similarity(
                gradients_list[i].unsqueeze(0),
                gradients_list[j].unsqueeze(0)
            )
            similarity_matrix[i, j] = cos_sim.item()

    return similarity_matrix

Task Embedding Methods

class TaskRelationshipModel(nn.Module):
    """
    Learn task relationships through embeddings
    """
    def __init__(self, num_tasks, embed_dim):
        super(TaskRelationshipModel, self).__init__()
        self.task_embeddings = nn.Embedding(num_tasks, embed_dim)

    def compute_task_similarity(self):
        """Compute similarity matrix from task embeddings"""
        embeddings = self.task_embeddings.weight  # (num_tasks, embed_dim)
        similarity = F.cosine_similarity(
            embeddings.unsqueeze(1),
            embeddings.unsqueeze(0),
            dim=2
        )
        return similarity

    def forward(self, task_indices):
        """Get embeddings for given tasks"""
        return self.task_embeddings(task_indices)

Loss Balancing Strategies

The Challenge

Different tasks have:

- Different scales (CTR: 0.01-0.1, Revenue: 0-1000)
- Different importance (business priorities)
- Different difficulty (some tasks are easier to learn)

Simply summing losses can lead to one task dominating training.
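To make the scale problem concrete, here is a small sketch (illustrative numbers, not measured losses) showing how a naive sum lets a large-scale regression loss drown out a small-scale classification loss:

```python
# Typical per-batch loss magnitudes (illustrative):
ctr_loss = 0.05        # binary cross-entropy on a rare-click task
revenue_loss = 2500.0  # MSE on raw revenue in the 0-1000 range

total = ctr_loss + revenue_loss
# The CTR task contributes a vanishing fraction of the summed loss,
# so its gradients are effectively ignored during training.
ctr_share = ctr_loss / total
print(f"CTR share of total loss: {ctr_share:.6f}")  # → 0.000020

# Rescaling (e.g., normalizing revenue labels to [0, 1]) restores balance.
# MSE scales quadratically with the label scale:
revenue_loss_scaled = revenue_loss / 1000.0 ** 2
print(ctr_loss / (ctr_loss + revenue_loss_scaled))  # ≈ 0.95
```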

Uniform Weighting

def uniform_loss(task_losses):
    """Simple uniform weighting"""
    return sum(task_losses)

Uncertainty Weighting

class UncertaintyWeighting(nn.Module):
    """
    Learn task-specific uncertainty weights

    Kendall et al., "Multi-Task Learning Using Uncertainty
    to Weigh Losses for Scene Geometry and Semantics", CVPR 2018
    """
    def __init__(self, num_tasks):
        super(UncertaintyWeighting, self).__init__()
        # Learnable log variance for each task
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, task_losses):
        """
        Weighted loss with uncertainty.

        Implements the common simplification of the paper's objective:
        Loss = sum_k exp(-s_k) * L_k + s_k, where s_k = log(sigma_k^2)
        """
        weighted_losses = []
        for task_idx, loss in enumerate(task_losses):
            precision = torch.exp(-self.log_vars[task_idx])
            weighted_loss = precision * loss + self.log_vars[task_idx]
            weighted_losses.append(weighted_loss)

        return sum(weighted_losses)

GradNorm

class GradNorm(nn.Module):
    """
    Gradient Normalization for Adaptive Loss Balancing

    Chen et al., "GradNorm: Gradient Normalization for
    Adaptive Loss Balancing in Deep Multitask Networks", ICML 2018
    """
    def __init__(self, model, num_tasks, alpha=0.12):
        super(GradNorm, self).__init__()
        self.model = model
        self.num_tasks = num_tasks
        self.alpha = alpha

        # Learnable task weights
        self.task_weights = nn.Parameter(torch.ones(num_tasks))

        # Track initial losses for relative loss computation
        self.register_buffer('initial_losses', None)

    def compute_gradnorm_loss(self, task_losses):
        """
        Compute GradNorm loss

        Balances gradients across tasks by adjusting task weights.
        """
        if self.initial_losses is None:
            self.initial_losses = torch.stack(task_losses).detach()

        # Relative losses (treated as constants when balancing)
        relative_losses = torch.stack(task_losses).detach() / self.initial_losses

        # Gradients are measured w.r.t. shared parameters only
        shared_params = [p for name, p in self.model.named_parameters()
                         if 'shared' in name or 'expert' in name]
        if len(shared_params) == 0:
            # Fallback: use all parameters
            shared_params = list(self.model.parameters())

        # Gradient norm of each weighted task loss; create_graph=True so
        # the GradNorm loss can backpropagate into task_weights
        task_grads = []
        for i, loss in enumerate(task_losses):
            grads = torch.autograd.grad(
                self.task_weights[i] * loss, shared_params,
                retain_graph=True, create_graph=True, allow_unused=True
            )
            grad_norm = torch.norm(
                torch.cat([g.flatten() for g in grads if g is not None])
            )
            task_grads.append(grad_norm)

        task_grads = torch.stack(task_grads)

        # Target gradient norm (average, held constant)
        mean_grad_norm = task_grads.mean().detach()

        # Relative inverse training rates
        relative_inverse_rates = relative_losses / relative_losses.mean()

        # Target gradients are constants: detach so only task_weights move
        target_grads = (mean_grad_norm * relative_inverse_rates ** self.alpha).detach()

        # GradNorm loss: L2 distance between actual and target gradient norms
        gradnorm_loss = F.mse_loss(task_grads, target_grads)

        return gradnorm_loss, self.task_weights

Dynamic Weight Average (DWA)

class DynamicWeightAverage(nn.Module):
    """
    Dynamic Weight Average for Multi-Task Learning

    Liu et al., "End-to-End Multi-Task Learning with Attention", CVPR 2019
    """
    def __init__(self, num_tasks, temperature=2.0):
        super(DynamicWeightAverage, self).__init__()
        self.num_tasks = num_tasks
        self.temperature = temperature
        self.register_buffer('loss_history', torch.zeros(num_tasks))

    def forward(self, task_losses):
        """
        Compute dynamic weights from the rate of loss decrease.

        As in the original DWA formulation, tasks whose loss is decreasing
        more slowly (current/previous ratio closer to or above 1) receive
        higher weight.
        """
        current_losses = torch.stack(task_losses)

        if self.loss_history.sum() == 0:
            # First iteration: uniform weights
            weights = torch.full_like(current_losses, 1.0 / self.num_tasks)
        else:
            # Ratio of current to previous loss per task
            loss_ratio = current_losses.detach() / (self.loss_history + 1e-8)

            # Softmax to get weights (slower decrease -> higher weight)
            weights = F.softmax(loss_ratio / self.temperature, dim=0)

        # Update history
        self.loss_history = current_losses.detach()

        # Weighted loss
        weighted_loss = sum(w * loss for w, loss in zip(weights, task_losses))

        return weighted_loss, weights

Pareto Optimal Solutions

class ParetoMTL(nn.Module):
    """
    Pareto Multi-Task Learning

    Uses multiple gradient descent to find Pareto optimal solutions
    """
    def __init__(self, num_tasks):
        super(ParetoMTL, self).__init__()
        self.num_tasks = num_tasks

    def compute_pareto_weights(self, task_losses, model):
        """
        Compute Pareto optimal weights using multiple gradient descent

        Sener & Koltun, "Multi-Task Learning as Multi-Objective Optimization", NeurIPS 2018
        """
        # Initialize weights as a leaf tensor so the optimizer can update them
        weights = torch.full((self.num_tasks,), 1.0 / self.num_tasks,
                             requires_grad=True)

        # Compute gradients for each task
        task_grads = []
        for loss in task_losses:
            model.zero_grad()
            loss.backward(retain_graph=True)
            grad = torch.cat([p.grad.flatten() for p in model.parameters()
                              if p.grad is not None])
            task_grads.append(grad.detach())

        task_grads = torch.stack(task_grads)  # (num_tasks, grad_dim)

        # Find the convex combination that minimizes gradient conflict
        # (gradient descent on the weights, standing in for the paper's
        # Frank-Wolfe solver)
        optimizer = torch.optim.SGD([weights], lr=0.1)

        for _ in range(10):  # Few iterations
            optimizer.zero_grad()

            # Weighted combination of gradients
            weighted_grad = (weights.unsqueeze(1) * task_grads).sum(dim=0)

            # Minimize gradient norm (reduces conflict)
            norm_loss = weighted_grad.norm()
            norm_loss.backward()
            optimizer.step()

            # Project back onto the probability simplex
            weights.data = F.softmax(weights.data, dim=0)

        return weights.detach()

Industrial Applications and Case Studies

Alibaba: ESMM for E-commerce

Problem: Predict conversion rate in entire impression space, not just clicked samples.

Solution: ESMM models CTR and CVR separately, then combines via chain rule.

Results:

- 2.6% improvement in CVR prediction accuracy
- Eliminated sample selection bias
- Improved ranking quality

Key Insights:

- Modeling the entire space is crucial for production systems
- Chain rule ensures mathematical consistency
- Simple architecture enables easy deployment
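The chain-rule decomposition at the heart of ESMM can be sketched numerically (illustrative probabilities, not actual model outputs):

```python
# ESMM predicts pCTR and pCVR with two towers over a shared embedding,
# and supervises pCTCVR = pCTR * pCVR over the entire impression space.
p_ctr = 0.10             # P(click | impression)
p_cvr = 0.05             # P(conversion | click) -- the quantity we want
p_ctcvr = p_ctr * p_cvr  # P(click & conversion | impression)

# Training supervises p_ctr (click labels) and p_ctcvr (conversion labels),
# both defined over all impressions, so p_cvr is learned implicitly without
# ever training on the biased clicked-only sample.
print(p_ctcvr)  # → 0.005 (up to float rounding)
```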

Google: MMoE for YouTube Recommendations

Problem: Multiple conflicting objectives (watch time, engagement, revenue).

Solution: MMoE with multiple experts and task-specific gates.

Results:

- Significant improvement over shared-bottom baseline
- Better handling of task conflicts
- Scalable to many tasks

Key Insights:

- Gating mechanism is crucial for handling conflicts
- Expert diversity is important for performance
- Architecture scales well with the number of tasks

Tencent: PLE for Video Recommendations

Problem: Balance multiple objectives with both shared and conflicting patterns.

Solution: PLE with explicit shared/task-specific expert separation.

Results:

- 12% improvement in overall engagement
- Better balance across objectives
- Improved long-term user satisfaction

Key Insights:

- Explicit separation helps when tasks have mixed relationships
- Progressive extraction captures hierarchical patterns
- Multi-layer structure enables complex modeling

Amazon: Multi-Task Learning for Product Recommendations

Tasks: Click prediction, add-to-cart, purchase, review rating.

Architecture: Custom MMoE variant with domain-specific features.

Results:

- 8% improvement in purchase rate
- Better handling of cold-start users
- Improved recommendation diversity

Netflix: Multi-Objective Optimization for Content Recommendations

Objectives: Watch time, completion rate, user satisfaction, content diversity.

Approach: Pareto-optimal multi-task learning with dynamic weighting.

Results:

- Better balance across objectives
- Improved user retention
- Higher content diversity

Complete Training Example

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

class RecommendationDataset(Dataset):
    """Dataset for multi-task recommendation"""
    def __init__(self, features, labels_dict):
        """
        Args:
            features: Input features (num_samples, feature_dim)
            labels_dict: Dictionary of task labels
                {'ctr': [...], 'cvr': [...], 'revenue': [...]}
        """
        self.features = torch.FloatTensor(features)
        self.labels = {k: torch.FloatTensor(v) for k, v in labels_dict.items()}
        self.task_names = list(labels_dict.keys())

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        sample = {'features': self.features[idx]}
        for task_name in self.task_names:
            sample[task_name] = self.labels[task_name][idx]
        return sample

class MultiTaskTrainer:
    """Complete training pipeline for multi-task learning"""
    def __init__(self, model, device, loss_balancer=None):
        self.model = model.to(device)
        self.device = device
        # Move balancer parameters (e.g., learnable log variances) to the same device
        self.loss_balancer = loss_balancer.to(device) if loss_balancer else None

    def train_epoch(self, dataloader, optimizer, task_names):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        task_losses = {name: 0 for name in task_names}

        for batch in dataloader:
            # Move to device
            features = batch['features'].to(self.device)
            labels = {name: batch[name].to(self.device) for name in task_names}

            # Forward pass
            outputs = self.model(features)

            # Compute losses
            batch_task_losses = []
            for idx, task_name in enumerate(task_names):
                if task_name in ['ctr', 'cvr']:
                    # Binary classification
                    loss = F.binary_cross_entropy(
                        outputs[idx].squeeze(),
                        labels[task_name]
                    )
                else:
                    # Regression
                    loss = F.mse_loss(
                        outputs[idx].squeeze(),
                        labels[task_name]
                    )
                batch_task_losses.append(loss)
                task_losses[task_name] += loss.item()

            # Combine losses; balancers may return a loss or a (loss, weights) pair
            if self.loss_balancer:
                balanced = self.loss_balancer(batch_task_losses)
                total_batch_loss = balanced[0] if isinstance(balanced, tuple) else balanced
            else:
                total_batch_loss = sum(batch_task_losses)

            # Backward pass
            optimizer.zero_grad()
            total_batch_loss.backward()
            optimizer.step()

            total_loss += total_batch_loss.item()

        # Average losses
        num_batches = len(dataloader)
        avg_loss = total_loss / num_batches
        avg_task_losses = {name: task_losses[name] / num_batches
                           for name in task_names}

        return avg_loss, avg_task_losses

    def evaluate(self, dataloader, task_names):
        """Evaluate model"""
        self.model.eval()
        task_metrics = {name: {'predictions': [], 'labels': []}
                        for name in task_names}

        with torch.no_grad():
            for batch in dataloader:
                features = batch['features'].to(self.device)
                labels = {name: batch[name].to(self.device) for name in task_names}

                outputs = self.model(features)

                for idx, task_name in enumerate(task_names):
                    task_metrics[task_name]['predictions'].append(
                        outputs[idx].cpu().numpy()
                    )
                    task_metrics[task_name]['labels'].append(
                        labels[task_name].cpu().numpy()
                    )

        # Concatenate and compute metrics
        results = {}
        for task_name in task_names:
            preds = np.concatenate(task_metrics[task_name]['predictions'])
            labels = np.concatenate(task_metrics[task_name]['labels'])

            if task_name in ['ctr', 'cvr']:
                # Binary classification metrics
                preds_binary = (preds > 0.5).astype(int)
                accuracy = (preds_binary.flatten() == labels).mean()
                results[task_name] = {'accuracy': accuracy}
            else:
                # Regression metrics
                mse = ((preds.flatten() - labels) ** 2).mean()
                mae = np.abs(preds.flatten() - labels).mean()
                results[task_name] = {'mse': mse, 'mae': mae}

        return results

# Example usage
def main():
    # Create synthetic data
    num_samples = 10000
    feature_dim = 128

    features = np.random.randn(num_samples, feature_dim)
    labels_dict = {
        'ctr': np.random.randint(0, 2, num_samples).astype(float),
        'cvr': np.random.randint(0, 2, num_samples).astype(float),
        'revenue': np.random.rand(num_samples) * 100
    }

    # Create dataset
    dataset = RecommendationDataset(features, labels_dict)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # Create model (MMoE as defined earlier in this article)
    model = MMoE(
        input_dim=feature_dim,
        num_experts=4,
        expert_hidden_dim=64,
        expert_output_dim=32,
        num_tasks=3,
        task_hidden_dims=[16],
        task_types=['binary', 'binary', 'regression']
    )

    # Loss balancer
    loss_balancer = UncertaintyWeighting(num_tasks=3)

    # Trainer (fall back to CPU when no GPU is available)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    trainer = MultiTaskTrainer(model, device=device, loss_balancer=loss_balancer)

    # Optimizer
    optimizer = optim.Adam(list(model.parameters()) + list(loss_balancer.parameters()),
                           lr=0.001)

    # Training loop
    num_epochs = 10
    task_names = ['ctr', 'cvr', 'revenue']

    for epoch in range(num_epochs):
        train_loss, train_task_losses = trainer.train_epoch(
            train_loader, optimizer, task_names
        )

        val_results = trainer.evaluate(val_loader, task_names)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        for task_name in task_names:
            print(f"  {task_name}: {train_task_losses[task_name]:.4f}")
        print("Validation Results:")
        for task_name, metrics in val_results.items():
            print(f"  {task_name}: {metrics}")
        print()

if __name__ == '__main__':
    main()

Advanced Techniques

Cross-Stitch Networks

class CrossStitchUnit(nn.Module):
    """
    Cross-Stitch Unit for sharing representations between tasks

    Misra et al., "Cross-stitch Networks for Multi-task Learning", CVPR 2016
    """
    def __init__(self, feature_dim):
        super(CrossStitchUnit, self).__init__()
        # Learnable 2x2 combination matrix over the two task streams
        self.alpha = nn.Parameter(torch.ones(2, 2) * 0.5)
        # Initialize so rows sum to 1
        self.alpha.data = F.softmax(self.alpha.data, dim=1)

    def forward(self, x1, x2):
        """
        Combine task-specific features

        Args:
            x1: Features from task 1 (batch_size, feature_dim)
            x2: Features from task 2 (batch_size, feature_dim)

        Returns:
            Combined features for both tasks
        """
        # Stack features along a task dimension
        stacked = torch.stack([x1, x2], dim=1)  # (batch_size, 2, feature_dim)

        # Mix the two task streams featurewise:
        # combined[b, i, d] = sum_j alpha[i, j] * stacked[b, j, d]
        combined = torch.einsum('ij,bjd->bid', self.alpha, stacked)

        # Split back
        x1_combined = combined[:, 0, :]
        x2_combined = combined[:, 1, :]

        return x1_combined, x2_combined

Task Routing

class TaskRouter(nn.Module):
    """
    Dynamic task routing based on input characteristics
    """
    def __init__(self, input_dim, num_tasks, num_experts):
        super(TaskRouter, self).__init__()
        self.num_tasks = num_tasks
        self.num_experts = num_experts
        self.router = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_tasks * num_experts)
        )

    def forward(self, x):
        """
        Route input to experts based on task-specific routing

        Returns:
            routing_weights: (batch_size, num_tasks, num_experts)
        """
        routing_logits = self.router(x)
        routing_logits = routing_logits.view(-1, self.num_tasks, self.num_experts)
        # Normalize over experts separately for each task
        routing_weights = F.softmax(routing_logits, dim=-1)
        return routing_weights

Q&A: Common Questions and Answers

Q1: When should I use multi-task learning vs. separate models?

A: Use multi-task learning when:

- Tasks are related and share underlying patterns
- You have limited data for some tasks
- Tasks benefit from shared representations
- You want to reduce computational overhead

Use separate models when:

- Tasks are completely independent
- Tasks have strong conflicts that can't be resolved
- You have abundant data for each task
- Tasks require very different architectures

Q2: How do I choose between Shared-Bottom, MMoE, and PLE?

A:

- Shared-Bottom: Start here for simple cases with complementary tasks
- MMoE: Use when tasks may conflict but you're unsure which ones
- PLE: Use when you know tasks have both shared and conflicting patterns

Generally, start simple and move to more complex architectures if needed.

Q3: How many experts should I use in MMoE/PLE?

A: Common choices:

- 2-4 experts: For 2-3 tasks
- 4-8 experts: For 4-6 tasks
- 8+ experts: For many tasks or complex scenarios

Start with fewer experts and increase if needed. Too many experts can lead to overfitting.

Q4: How do I balance losses across tasks?

A: Several approaches:

1. Uniform: Simple sum (works if tasks have similar scales)
2. Uncertainty Weighting: Learn task-specific weights automatically
3. GradNorm: Balance gradients instead of losses
4. DWA: Dynamic weighting based on loss decrease rates
5. Manual: Set weights based on business priorities

Start with uniform or uncertainty weighting, then try more sophisticated methods if needed.

Q5: What if tasks have very different scales?

A: Options:

1. Normalize labels: Scale all labels to similar ranges (e.g., [0, 1])
2. Use appropriate loss functions: MSE for regression, BCE for classification
3. Learn task-specific scales: Use uncertainty weighting or learnable scales
4. Separate normalization: Normalize each task's labels independently
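As a sketch of option 1, per-task min-max normalization (the label values here are made up for illustration; keep the scaling constants so predictions can be mapped back at serving time):

```python
import numpy as np

def minmax_normalize(labels, eps=1e-8):
    """Scale one task's labels to [0, 1]; return (lo, hi) to invert later."""
    lo, hi = labels.min(), labels.max()
    return (labels - lo) / (hi - lo + eps), (lo, hi)

# Hypothetical raw revenue labels on a 0-1000 scale
revenue = np.array([0.0, 50.0, 1000.0])
scaled, (lo, hi) = minmax_normalize(revenue)
# scaled is now [0.0, 0.05, ~1.0], comparable in scale to CTR/CVR labels
```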

Q6: How do I handle missing labels for some tasks?

A: Strategies:

1. Masked loss: Only compute loss for available labels
2. Imputation: Fill missing labels with default values
3. Separate sampling: Sample batches to ensure all tasks have labels
4. Task-specific data loaders: Handle missing labels in data loading

Most common: masked loss computation.
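A minimal sketch of masked loss computation, assuming a 0/1 mask tensor marking which samples actually carry a label for the task (the example values are illustrative):

```python
import torch
import torch.nn.functional as F

def masked_bce(predictions, labels, mask):
    """BCE averaged only over samples whose task label is observed."""
    per_sample = F.binary_cross_entropy(predictions, labels, reduction='none')
    denom = mask.sum().clamp(min=1.0)  # avoid division by zero if all masked
    return (per_sample * mask).sum() / denom

preds = torch.tensor([0.9, 0.2, 0.7])
labels = torch.tensor([1.0, 0.0, 0.0])  # third label is a placeholder
mask = torch.tensor([1.0, 1.0, 0.0])    # third sample has no label for this task
loss = masked_bce(preds, labels, mask)  # averages over the two labeled samples
```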

Q7: Can I use multi-task learning with different input modalities?

A: Yes! Common approaches:

1. Shared encoders: Encode each modality separately, then combine
2. Cross-modal attention: Attend across modalities
3. Modality-specific experts: Separate experts for each modality

Example: Text + image recommendations can share high-level representations while keeping modality-specific encoders.

Q8: How do I evaluate multi-task models?

A: Evaluate both:

1. Individual task performance: Metrics for each task (AUC, MSE, etc.)
2. Joint performance: Combined metrics (e.g., weighted average)
3. Trade-off analysis: Pareto frontier if objectives conflict

Don't just look at average performance — understand task-specific improvements.
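For the binary tasks, per-task AUC can be computed without extra dependencies via the rank-sum (Mann-Whitney) formulation, which equals the probability a random positive is scored above a random negative (this sketch ignores tied scores):

```python
import numpy as np

def auc(labels, scores):
    """Mann-Whitney AUC from ranks; assumes no tied scores."""
    order = np.argsort(scores)
    ranks = np.empty(len(scores))
    ranks[order] = np.arange(1, len(scores) + 1)  # ranks start at 1
    n_pos = labels.sum()
    n_neg = len(labels) - n_pos
    return (ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

y = np.array([0, 0, 1, 1])
s = np.array([0.1, 0.4, 0.35, 0.8])
print(auc(y, s))  # → 0.75: one of the four positive/negative pairs is mis-ranked
```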

Q9: What are common pitfalls in multi-task learning?

A: Common mistakes:

1. Ignoring task conflicts: Forcing conflicting tasks to share
2. Poor loss balancing: One task dominates training
3. Insufficient capacity: Shared layers too small
4. Over-sharing: Sharing when tasks should be separate
5. Under-sharing: Not sharing when tasks benefit from it

Q10: How do I debug multi-task models?

A: Debugging strategies:

1. Check individual task losses: Ensure all tasks are learning
2. Visualize gate weights: Understand expert usage in MMoE/PLE
3. Monitor gradients: Check for gradient conflicts
4. Ablation studies: Remove tasks/experts to understand contributions
5. Compare to single-task baselines: Ensure MTL actually helps

Q11: Can I add new tasks to an existing multi-task model?

A: Yes, but consider:

1. Architecture: MMoE/PLE make this easier than Shared-Bottom
2. Fine-tuning: Retrain or fine-tune when adding tasks
3. Task relationships: New tasks may change expert usage patterns
4. Loss balancing: May need to rebalance weights

MMoE and PLE are more flexible for adding tasks.

Q12: How do I handle tasks with different update frequencies?

A: Options:

1. Separate optimizers: Different learning rates per task
2. Alternating updates: Update tasks in rounds
3. Gradient accumulation: Accumulate gradients across tasks
4. Task-specific schedules: Different update frequencies

Most common: separate optimizers or alternating updates.
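Option 1 can be sketched with PyTorch parameter groups; the module names and learning rates below are hypothetical, chosen only to show a slower-updating task tower:

```python
import torch
import torch.nn as nn

# Hypothetical model: a shared trunk and two task towers
model = nn.ModuleDict({
    'shared': nn.Linear(16, 8),
    'ctr_tower': nn.Linear(8, 1),
    'ltv_tower': nn.Linear(8, 1),
})

# One optimizer, different learning rates per parameter group:
optimizer = torch.optim.Adam([
    {'params': model['shared'].parameters(), 'lr': 1e-3},
    {'params': model['ctr_tower'].parameters(), 'lr': 1e-3},
    {'params': model['ltv_tower'].parameters(), 'lr': 1e-4},  # slower task
])
```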

Q13: What's the relationship between multi-task learning and transfer learning?

A:

- Multi-task learning: Train multiple tasks simultaneously with shared representations
- Transfer learning: Pre-train on one task, fine-tune on another

MTL can be seen as simultaneous transfer learning across tasks. Both leverage shared knowledge.

Q14: How do I choose task-specific architectures?

A: Consider:

1. Task type: Classification vs. regression vs. ranking
2. Task complexity: Simple tasks need simpler towers
3. Data availability: More data allows more complex architectures
4. Computational budget: Balance capacity with efficiency

Start with simple towers and increase complexity if needed.

Q15: Can multi-task learning help with cold-start problems?

A: Yes! MTL can help by:

1. Transfer from warm tasks: Warm tasks help cold tasks through shared representations
2. Auxiliary tasks: Use easy-to-collect signals (clicks) to help hard tasks (purchases)
3. Shared embeddings: Cold users/items benefit from shared patterns

This is one of the key advantages of MTL in recommendation systems.

Conclusion

Multi-task learning has become essential for modern recommendation systems, enabling us to optimize multiple objectives simultaneously while leveraging shared patterns across tasks. From the simple Shared-Bottom architecture to sophisticated models like MMoE and PLE, MTL provides a principled framework for handling the complex, multi-objective nature of real-world recommendation problems.

Key takeaways:

1. Start simple: Begin with Shared-Bottom, move to MMoE/PLE if needed
2. Balance losses: Use appropriate weighting strategies
3. Understand task relationships: Know which tasks benefit from sharing
4. Monitor all tasks: Don't optimize for average performance alone
5. Consider business priorities: Align technical choices with business goals

As recommendation systems continue to evolve, multi-task learning will remain a crucial tool for building systems that balance multiple objectives while delivering great user experiences.

- Post title: Recommendation Systems (9): Multi-Task Learning and Multi-Objective Optimization
- Post author: Chen Kai
- Create time: 2024-06-11 09:45:00
- Post link: https://www.chenk.top/en/recommendation-systems-9-multi-task-learning/
- Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless otherwise stated.