多任务学习(Multi-Task Learning,
MTL)是一种通过同时学习多个相关任务来提升模型泛化能力的机器学习范式。
1997 年 Rich Caruana 的开创性论文"Multitask
Learning"展示了共享表示如何帮助模型学习更鲁棒的特征。现代深度学习中,多任务学习在计算机视觉(同时检测、分割、深度估计)、自然语言处理(联合实体识别与关系抽取)、推荐系统(同时预测点击率与转化率)等领域取得了巨大成功。但多任务学习并非简单地将多个损失函数相加——如何设计共享结构、如何平衡不同任务的学习、如何处理任务间的负迁移,都是需要深入研究的问题。
本文将从第一性原理出发,推导多任务学习的数学基础,解析硬参数共享与软参数共享的优劣,详细讲解任务关系学习与任务聚类方法,深入剖析梯度冲突问题及其解决方案(PCGrad
、 GradNorm 、 CAGrad
等),介绍辅助任务的设计原则,并提供一个完整的多任务网络实现(包含动态权重调整、梯度投影、任务平衡等工业级技巧)。我们会看到,多任务学习本质上是在寻找一个能够满足多个优化目标的帕累托最优解。
多任务学习的动机:为什么要多任务学习
从单任务到多任务:归纳偏置的共享
单任务学习(Single-Task
Learning)为每个任务独立训练一个模型,而多任务学习让所有任务共享部分参数或表示。这背后的核心假设是:相关任务之间存在共同的底层结构 。
直觉例子 :考虑图像场景理解的三个任务:
目标检测 :识别图像中的物体位置和类别
语义分割 :为每个像素分配类别标签
深度估计 :预测每个像素的深度值
这三个任务都需要理解图像的空间结构、物体边界、纹理信息等底层特征。与其为每个任务独立学习这些特征,不如让它们共享一个特征提取器,只在高层用任务特定的头部(task-specific
heads)。
多任务学习的数学视角:正则化效应
从优化角度看,多任务学习引入了一种隐式正则化 。设有
个任务,第 个任务的损失为 ,其中 是共享参数,
是任务特定参数。多任务优化目标为:
其中 是任务权重。
关键洞察:共享参数
必须同时对所有任务有效,这限制了其表示空间,起到了正则化作用 。形式化地,多任务学习等价于在单任务损失上添加了一个隐式约束:
其中
是辅助任务损失,
是容忍度。这种约束防止模型过度拟合主任务的训练数据。
数据增强视角:辅助任务提供额外信号
多任务学习可以视为一种数据增强 策略。当主任务的标注数据有限时,辅助任务可以提供额外的监督信号。
例子 :在低资源语言的机器翻译中:
主任务 :英语
斯瓦希里语翻译(只有 10 万句对)
辅助任务 :英语 法语翻译(1000 万句对)
尽管法语和斯瓦希里语不同,但英语编码器可以从大量英法数据中学习更好的英语表示,从而帮助英斯翻译。
实验表明,引入辅助任务后,主任务的性能可以提升
5-20%(取决于任务相关性和数据量)。
计算效率:参数共享减少冗余
从工程角度,多任务学习通过参数共享显著减少了模型参数和计算量:
单任务 :
个任务各自有一个 ResNet-50 编码器,总参数量 (设 )
多任务 :
个任务共享一个编码器,总参数量 (每个任务头 2M 参数)
参数量减少了约
70%,推理时也只需一次前向传播即可得到所有任务的输出,大幅提升了效率。
负迁移:多任务学习的风险
然而,多任务学习并非总是有益。当任务之间不相关甚至冲突时,可能发生负迁移(Negative
Transfer) :联合训练的性能低于单独训练。
例子 :
任务 A :人脸识别(需要学习细粒度的人脸特征)
任务 B :场景分类(需要学习全局的布局和上下文)
这两个任务的特征需求差异很大,强行共享参数可能导致相互干扰。
实验数据(CIFAR-100):
单独训练任务 A:82%准确率
单独训练任务 B:78%准确率
联合训练(naive MTL):79%和 74%(两个任务都下降)
因此,如何设计共享结构、如何选择相关任务、如何平衡任务权重 ,是多任务学习成功的关键。
参数共享策略:硬共享与软共享
多任务学习的核心是如何在任务之间共享信息。主要有两种范式:硬参数共享和软参数共享。
硬参数共享(Hard Parameter
Sharing)
硬参数共享是最常见的多任务学习架构,由 Caruana 于 1993 年提出。
架构设计 :
共享层 :所有任务共享相同的底层网络(如卷积层、
Transformer 层)
任务特定层 :每个任务有独立的输出头部(如全连接层、解码器)
形式化地,对输入 :
$$
h = f_{}(x; _{}) $$
$$
y_t = f_t(h; _t), t = 1, , T $$
其中 是共享特征提取器, 是任务 的预测头。
优点 :
强正则化 :共享参数被多个任务约束,过拟合风险降低
参数效率 :大部分参数被共享,模型紧凑
简单直接 :易于实现和训练
缺点 :
灵活性差 :所有任务必须使用相同的共享表示,不适合差异大的任务
负迁移风险高 :冲突任务会相互干扰
经验设计原则 :
共享层应该学习通用特征(如 CNN 的低层学习边缘、纹理)
任务特定层应该有足够的容量来处理任务特有的模式
通常在网络的前 70-80%层共享,后 20-30%层独立
软参数共享(Soft Parameter
Sharing)
软参数共享由 Duong 等人于 2015
年提出,允许每个任务有自己的参数,但通过正则化鼓励参数相似。
基础形式 :每个任务有独立的模型 ,添加参数相似性约束:
$$
L = _{t=1}^T L_t(t) + {t < t'} |t - {t'}|_2^2
$$
第二项是
正则化,惩罚任务参数之间的差异。
跨任务层归一化(Cross-Stitch Networks) :
Misra 等人于 2016 年提出 Cross-Stitch
Networks,允许任务在多个层次交换信息。
设两个任务的第 层激活为 和 ,交叉缝合单元(cross-stitch unit)计算:
其中
是可学习参数。
的大小反映了任务 从任务 借用了多少信息。
多任务注意力网络(MTAN) :
Liu 等人于 2019 年提出 Multi-Task Attention
Network,用注意力机制动态选择共享哪些特征。
对共享特征 和任务 ,定义任务特定的注意力权重:
$$
a_t = (W_t h + b_t) $$
其中 是 sigmoid 函数,
是逐元素乘法。每个任务通过注意力机制"软"地选择有用的共享特征。
优点 :
灵活性高 :每个任务可以有不同的参数,适应性强
负迁移风险低 :任务可以选择性地忽略不相关的信息
缺点 :
参数量大 :每个任务都有独立参数,模型膨胀
训练复杂 :需要仔细调节正则化强度
动态网络架构:条件计算
近年来,动态网络(Dynamic
Networks)允许模型根据输入或任务动态调整计算路径。
多任务路由网络(Routing Networks) :
Rosenbaum 等人于 2018
年提出,用路由函数决定每个任务使用哪些子网络。
设有 个子网络Extra close brace or missing open brace \{f_k} _{k=1}^K ,任务 的路由权重为$
w_t = [w_{t,1}, , w_{t,K}]则 任 务
t$ 的输出为:
$$
y_t = {k=1}^K w {t,k} f_k(x) $$
路由权重可以是固定的(离散选择)或可学习的(软路由)。
任务条件适配器(Task-Conditional Adapters) :
Rebuffi 等人于 2017
年提出,在预训练模型的每层插入任务特定的适配器(adapter)模块。
对任务 和层 ,适配器定义为:
$$
h^{(l+1)} = f{(l)}(h {(l)}) + _t(f{(l)}(h {(l)}))
$$
适配器通常是一个小的瓶颈网络(bottleneck):$ d d/r d其 中 r$ 是缩减比(如
16)。只有适配器参数是任务特定的,其余参数共享。
优势 :在预训练模型上添加新任务时,只需训练适配器,高效且避免灾难性遗忘。
任务关系学习:发现相关性
多任务学习的效果很大程度上取决于任务之间的相关性。如何量化和利用任务关系是一个重要研究方向。
任务亲和性矩阵(Task Affinity
Matrix)
任务亲和性矩阵 量化了任务对之间的相关性,其中 表示任务 和任务
的相似度。
计算方法 1:性能相关性
Fifty 等人于 2021 年提出
Taskonomy,通过迁移学习实验测量任务亲和性:
在任务 上训练模型
将模型迁移到任务 ,测量性能
定义亲和性
其中 是随机初始化的基线性能。
计算方法 2:梯度相关性
Yu 等人于 2020 年提出基于梯度余弦相似度的亲和性:
$$
A_{ij} = _{(x,y) } $$
高正相关意味着任务在相同方向更新参数,低或负相关意味着冲突。
计算方法 3:特征表示相似性
计算任务在训练过程中学到的特征表示的 CKA(Centered Kernel
Alignment):
其中 是任务 的特征核矩阵。
任务聚类:分组共享
当任务数量很大时,可以先将任务聚类,同一组内的任务共享参数。
层次化多任务学习(Hierarchical MTL) :
假设任务之间有层次关系,如:
粗粒度任务 :场景分类(室内 vs 室外)
细粒度任务 :具体场景类别(卧室、厨房、街道、公园)
可以设计层次化网络:
共享层提取通用特征
中间层用于粗粒度任务
顶层用于细粒度任务,依赖中间层输出
损失函数为:
$$
L = L_{} + _{i } L_i $$
自适应任务分组 :
Standley 等人于 2020 年提出 Which Tasks Should Be Learned
Together?,用强化学习自动搜索最优任务分组。
算法流程:
初始化每个任务独立训练
用策略网络(policy network)采样任务分组方案
按分组方案训练多任务模型,评估验证集性能
用性能作为奖励,更新策略网络
重复直到找到最优分组
实验表明,自动分组比人工设计或全局共享都更有效。
任务选择:主辅任务的选择
当有一个主任务(primary
task)和多个候选辅助任务时,如何选择最有帮助的辅助任务?
贪心选择策略 :
单独训练主任务,记录性能
对每个候选辅助任务 :
联合训练主任务和辅助任务
记录主任务性能
计算增益
选择增益最大的 个辅助任务
基于元学习的选择 :
Du 等人于 2020 年提出 Automated Auxiliary
Learning,用元学习预测辅助任务的有效性:
在少量数据上快速训练模型
用元模型(meta-model)预测每个辅助任务对主任务的帮助
选择预测收益最高的辅助任务
优势:避免了全量训练所有候选任务的开销。
梯度冲突与任务平衡
多任务学习最大的挑战之一是梯度冲突(Gradient
Conflict) :不同任务的梯度可能指向不同方向,导致训练不稳定或性能下降。
问题分析:什么是梯度冲突
设两个任务的梯度为$ g_1 = L_1和 g_2 = L_2$,naive
多任务优化使用梯度和:
$$
g = _1 g_1 + _2 g_2 $$
问题 :如果Double subscripts: use braces to clarify g_1^g_2 < 0 (余弦相似度为负),两个梯度指向相反方向,平均梯度可能降低某个任务的性能。
例子 :
任务 1 梯度:
任务 2 梯度:
平均梯度: 平均梯度与 的内积: ,意味着更新会增加任务 1 的损失!
形式化地,梯度冲突定义为:
𝟙
实验表明,在多任务训练中,梯度冲突的比例可达 30-50%,严重影响收敛。
静态权重方法:手动调节
最简单的方法是手动设置任务权重 ,但这需要大量实验。
均匀权重 :
$$
L = _{t=1}^T L_t $$
简单但往往次优,因为不同任务的损失尺度可能相差很大(如分类损失 ,回归损失 )。
不确定性加权 :
Kendall 等人于 2018
年提出,用任务的不确定性(uncertainty)自动调节权重。
假设任务 的输出是高斯分布 ,则负对数似然为:
$$
L_t = |y_t - f_t(x)|^2 + _t $$
其中
是可学习的任务不确定性参数。联合损失为:
$$
L = _{t=1}^T ( L_t + _t ) $$
直觉 :
如果任务 的
大(高不确定性),该任务的权重 就小
项防止 (退化解)
实验表明,不确定性加权比均匀权重提升 2-5%。
GradNorm:梯度幅度归一化
Chen 等人于 2018 年提出 GradNorm,通过调节任务权重使梯度幅度平衡。
核心思想 :各任务的梯度幅度应该与其训练速度成正比。
设任务 在时刻
的损失为 ,定义相对逆训练速度:
$$
r_t() = $$
意味着任务 训练较慢,
意味着训练较快。
目标 :调节权重 ,使得:
其中
是所有任务的平均梯度幅度,
是超参数(通常取 1.5)。
算法 :
前向传播,计算加权损失
计算每个任务的梯度幅度
更新权重:最小化
对权重 做梯度下降,然后归一化
效果 :GradNorm
在多个数据集上比均匀权重和不确定性加权都有明显提升(3-8%)。
PCGrad:投影冲突梯度
Yu 等人于 2020 年提出 Projecting Conflicting Gradients
(PCGrad),直接消除梯度冲突。
核心思想 :当两个任务的梯度冲突时,将一个梯度投影到另一个梯度的法平面(normal
plane)上。
对任务 和 ,如果 ,将 替换为:
这是 在 的正交补空间上的投影,保证 (无冲突)。
算法(对
个任务) :
1 2 3 4 5 6 7 8 9 for each task i: g_i = compute gradient of task i for each other task j != i: if g_i . g_j < 0: g_i = g_i - (g_i . g_j / ||g_j||^2) * g_j store modified gradient g_i final_gradient = mean of all modified g_i update parameters with final_gradient
理论保证 :PCGrad 保证对所有任务$ t ^g_t
$,即更新方向至少不增加任何任务的损失。
实验 (NYUv2 数据集,语义分割+深度估计+表面法向量):
均匀权重:mIoU 40.2%, 深度误差 0.61
PCGrad:mIoU 42.7%, 深度误差 0.58
PCGrad 显著缓解了梯度冲突,提升了所有任务的性能。
CAGrad:冲突避免梯度下降
Liu 等人于 2021 年提出 Conflict-Averse Gradient descent
(CAGrad),寻找帕累托最优的梯度方向。
帕累托最优 :一个解是帕累托最优的,当且仅当不存在另一个解能在不降低某个任务性能的情况下提升其他任务。
CAGrad 将梯度选择建模为优化问题:
即寻找最小范数的梯度,同时不与任何任务梯度冲突。
这是一个二次规划(QP)问题,可以用现有求解器(如 CVXPY)高效求解。
实验 :CAGrad
在多个数据集上达到了最佳的帕累托前沿(Pareto front),优于 PCGrad 和
GradNorm 。
MGDA:多目标梯度下降
多目标梯度下降(Multi-Objective Gradient Descent Algorithm, MGDA)由 D
é sid é ri 于 2012 年提出,寻找所有任务共同的下降方向。
核心思想 :寻找梯度 ,使得其与所有任务梯度的内积都为正(即对所有任务都是下降方向)。
形式化为:
这也是一个凸优化问题,可以用 Frank-Wolfe 算法求解。
与 PCGrad 的比较 :
PCGrad 逐对处理冲突,计算简单但可能次优
MGDA 全局优化,理论更优但计算复杂度
辅助任务设计:如何选择辅助任务
辅助任务的选择和设计对多任务学习的成功至关重要。
自监督辅助任务
自监督学习任务可以作为通用的辅助任务,无需额外标注。
旋转预测 :
Gidaris 等人于 2018 年提出,将图像旋转 0/90/180/270
度,让模型预测旋转角度。
损失函数:
这个任务强迫模型学习物体的方向和结构信息。
拼图问题(Jigsaw Puzzles) :
Noroozi 和 Favaro 于 2016 年提出,将图像分成 9
块并打乱,让模型预测正确的排列。
这个任务让模型学习空间关系和物体部件的位置。
对比学习 :
SimCLR 、 MoCo 等对比学习方法也可以作为辅助任务。对一个样本
和其增强版本 :
$$
L_{} = - $$
对比学习帮助模型学习鲁棒的表示。
特定领域的辅助任务
根据主任务的特点,设计针对性的辅助任务。
计算机视觉 :
主任务 :目标检测
辅助任务 :边缘检测、深度估计、表面法向量预测
边缘检测帮助模型更好地定位物体边界,深度估计提供 3D 几何信息。
自然语言处理 :
主任务 :命名实体识别(NER)
辅助任务 :词性标注(POS)、句法依存分析
词性标注提供了词的语法信息,依存分析提供了句子结构,都对 NER
有帮助。
推荐系统 :
主任务 :点击率预测(CTR)
辅助任务 :转化率预测(CVR)、停留时长预测
用户的点击行为、转化行为、停留时长反映了不同层次的兴趣,联合建模可以学习更全面的用户表示。
课程学习:任务的顺序
有时,辅助任务的引入顺序很重要,这涉及课程学习(Curriculum
Learning) 。
简单到复杂 :
从简单的辅助任务开始,逐步引入复杂的任务。
例如,在图像分类中:
先用自监督任务(旋转预测)预训练
再引入粗粒度分类任务(大类别)
最后进行细粒度分类任务(小类别)
任务切换策略 :
Graves 等人于 2017 年提出 Automated Curriculum
Learning,用强化学习动态决定何时切换任务:
当前任务的学习进度(loss 下降速度)
任务之间的相关性
主任务的验证性能
通过策略网络学习最优的任务切换时机。
完整代码实现:多任务学习框架
下面提供一个完整的多任务学习实现,包含硬参数共享、梯度手术(PCGrad)、动态权重调整(GradNorm)等方法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torch.utils.data import DataLoaderimport torchvisionimport torchvision.transforms as transformsimport numpy as npfrom typing import List , Tuple , Dict , Optional import copyimport randomclass SharedEncoder (nn.Module): """共享编码器:ResNet-18 的前几层""" def __init__ (self ): super ().__init__() resnet = torchvision.models.resnet18(pretrained=False ) self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 def forward (self, x: torch.Tensor ) -> torch.Tensor: x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) return x class TaskHead (nn.Module): """任务特定头部""" def __init__ (self, in_channels: int , num_classes: int , task_type: str = 'classification' ): super ().__init__() self.task_type = task_type self.avgpool = nn.AdaptiveAvgPool2d((1 , 1 )) self.fc1 = nn.Linear(in_channels, 256 ) self.dropout = nn.Dropout(0.5 ) if task_type == 'classification' : self.fc2 = nn.Linear(256 , num_classes) elif task_type == 'regression' : self.fc2 = nn.Linear(256 , num_classes) else : raise ValueError(f"Unknown task type: {task_type} " ) def forward (self, x: torch.Tensor ) -> torch.Tensor: x = self.avgpool(x) x = torch.flatten(x, 1 ) x = F.relu(self.fc1(x)) x = self.dropout(x) x = self.fc2(x) return x class MultiTaskNetwork (nn.Module): """多任务网络:硬参数共享架构""" def __init__ (self, task_configs: List [Dict ] ): """ task_configs: 任务配置列表,每个元素是字典 { 'name': 任务名称, 'num_classes': 类别数, 'type': 'classification' 或 'regression' } """ super ().__init__() self.task_names = [cfg['name' ] for cfg in task_configs] self.num_tasks = len (task_configs) self.shared_encoder = SharedEncoder() self.task_heads = nn.ModuleDict({ cfg['name' ]: TaskHead( in_channels=256 , num_classes=cfg['num_classes' ], task_type=cfg['type' ] ) for cfg in task_configs }) def forward (self, x: torch.Tensor, task_name: Optional [str ] = None ) -> Dict [str , torch.Tensor]: """ 前向传播 如果指定 task_name,只计算该任务;否则计算所有任务 """ shared_features = self.shared_encoder(x) outputs = {} if task_name is not None : outputs[task_name] = self.task_heads[task_name](shared_features) else : for name in self.task_names: outputs[name] = self.task_heads[name](shared_features) return outputs class MultiTaskLoss (nn.Module): """多任务损失:支持不同任务类型""" def __init__ (self, task_configs: List [Dict ], loss_weights: Optional [Dict [str , float ]] = None ): super ().__init__() self.task_configs = {cfg['name' ]: cfg for cfg in task_configs} if loss_weights is None : self.loss_weights = {cfg['name' ]: 1.0 for cfg in task_configs} else : self.loss_weights = loss_weights def compute_loss (self, outputs: Dict [str , torch.Tensor], targets: Dict [str , torch.Tensor] ) -> Dict [str , torch.Tensor]: """计算每个任务的损失""" losses = {} for task_name, output in outputs.items(): target = targets[task_name] task_type = self.task_configs[task_name]['type' ] if task_type == 'classification' : loss = F.cross_entropy(output, target) elif task_type == 'regression' : loss = F.mse_loss(output, target) else : raise ValueError(f"Unknown task type: {task_type} " ) losses[task_name] = loss return losses def forward (self, outputs: Dict [str , torch.Tensor], targets: Dict [str , torch.Tensor] ) -> Tuple [torch.Tensor, Dict [str , torch.Tensor]]: """计算加权总损失""" losses = self.compute_loss(outputs, targets) total_loss = sum (self.loss_weights[name] * loss for name, loss in losses.items()) return total_loss, losses class PCGrad : """Projecting Conflicting Gradients""" def __init__ (self, optimizer: optim.Optimizer, task_names: List [str ] ): self.optimizer = optimizer self.task_names = task_names self.num_tasks = len (task_names) @staticmethod def _project_conflicting (grad_i: torch.Tensor, grad_j: torch.Tensor ) -> torch.Tensor: """将 grad_i 投影到 grad_j 的法平面""" inner_product = torch.dot(grad_i, grad_j) if inner_product < 0 : proj = inner_product / (torch.norm(grad_j) ** 2 + 1e-8 ) grad_i = grad_i - proj * grad_j return grad_i def step (self, losses: Dict [str , torch.Tensor] ): """PCGrad 优化步骤""" task_gradients = {} for task_name in self.task_names: self.optimizer.zero_grad() losses[task_name].backward(retain_graph=True ) grads = [] for param in self.optimizer.param_groups[0 ]['params' ]: if param.grad is not None : grads.append(param.grad.clone().flatten()) task_gradients[task_name] = torch.cat(grads) modified_gradients = {} for i, task_i in enumerate (self.task_names): grad_i = task_gradients[task_i].clone() for j, task_j in enumerate (self.task_names): if i != j: grad_j = task_gradients[task_j] grad_i = self._project_conflicting(grad_i, grad_j) modified_gradients[task_i] = grad_i avg_gradient = sum (modified_gradients.values()) / self.num_tasks self.optimizer.zero_grad() idx = 0 for param in self.optimizer.param_groups[0 ]['params' ]: if param.grad is not None : param_size = param.numel() param.grad = avg_gradient[idx:idx+param_size].view_as(param) idx += param_size self.optimizer.step() class GradNorm : """Gradient Normalization for Adaptive Loss Balancing""" def __init__ ( self, model: nn.Module, task_names: List [str ], alpha: float = 1.5 , lr_weights: float = 0.025 ): self.model = model self.task_names = task_names self.num_tasks = len (task_names) self.alpha = alpha self.task_weights = nn.Parameter(torch.ones(self.num_tasks)) self.weight_optimizer = optim.Adam([self.task_weights], lr=lr_weights) self.initial_losses = None def compute_grad_norm (self, loss: torch.Tensor, parameters: List [torch.nn.Parameter] ) -> float : """计算损失关于参数的梯度范数""" grads = torch.autograd.grad(loss, parameters, retain_graph=True , create_graph=True ) grad_norm = torch.norm(torch.cat([g.flatten() for g in grads])) return grad_norm def step (self, losses: Dict [str , torch.Tensor], epoch: int ): """GradNorm 更新步骤""" if self.initial_losses is None : self.initial_losses = {name: loss.item() for name, loss in losses.items()} weighted_losses = [] for i, task_name in enumerate (self.task_names): weighted_losses.append(self.task_weights[i] * losses[task_name]) total_loss = sum (weighted_losses) shared_params = list (self.model.shared_encoder.parameters()) grad_norms = [] for weighted_loss in weighted_losses: grad_norm = self.compute_grad_norm(weighted_loss, shared_params) grad_norms.append(grad_norm) avg_grad_norm = sum (grad_norms) / self.num_tasks relative_inverse_train_rates = [] for i, task_name in enumerate (self.task_names): current_loss = losses[task_name].item() initial_loss = self.initial_losses[task_name] loss_ratio = current_loss / (initial_loss + 1e-8 ) avg_loss_ratio = sum ( losses[t].item() / (self.initial_losses[t] + 1e-8 ) for t in self.task_names ) / self.num_tasks r_i = loss_ratio / (avg_loss_ratio + 1e-8 ) relative_inverse_train_rates.append(r_i) grad_norm_loss = 0 for i in range (self.num_tasks): target_grad_norm = avg_grad_norm * (relative_inverse_train_rates[i] ** self.alpha) grad_norm_loss += torch.abs (grad_norms[i] - target_grad_norm) self.weight_optimizer.zero_grad() grad_norm_loss.backward() self.weight_optimizer.step() with torch.no_grad(): self.task_weights.data = self.task_weights.data * self.num_tasks / self.task_weights.sum () return total_loss, {name: self.task_weights[i].item() for i, name in enumerate (self.task_names)} class MultiTaskTrainer : """多任务学习训练器""" def __init__ ( self, model: MultiTaskNetwork, task_configs: List [Dict ], device: str = 'cuda' , optimization_method: str = 'uniform' , initial_weights: Optional [Dict [str , float ]] = None ): self.model = model.to(device) self.device = device self.task_configs = task_configs self.task_names = [cfg['name' ] for cfg in task_configs] self.optimization_method = optimization_method self.criterion = MultiTaskLoss(task_configs, initial_weights) self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3 , weight_decay=1e-4 ) if optimization_method == 'pcgrad' : self.pcgrad = PCGrad(self.optimizer, self.task_names) elif optimization_method == 'gradnorm' : self.gradnorm = GradNorm(self.model, self.task_names) def train_epoch (self, train_loader: DataLoader, epoch: int ) -> Dict [str , float ]: """训练一个 epoch""" self.model.train() epoch_losses = {name: 0.0 for name in self.task_names} epoch_total_loss = 0.0 num_batches = 0 for batch_idx, (inputs, targets) in enumerate (train_loader): inputs = inputs.to(self.device) targets = {name: targets[name].to(self.device) for name in self.task_names} outputs = self.model(inputs) if self.optimization_method == 'uniform' : total_loss, losses = self.criterion(outputs, targets) self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() elif self.optimization_method == 'pcgrad' : _, losses = self.criterion(outputs, targets) self.pcgrad.step(losses) total_loss = sum (losses.values()) elif self.optimization_method == 'gradnorm' : _, losses = self.criterion(outputs, targets) total_loss, weights = self.gradnorm.step(losses, epoch) if batch_idx % 50 == 0 : print (f" Batch {batch_idx} : Task weights = {weights} " ) epoch_total_loss += total_loss.item() for name, loss in losses.items(): epoch_losses[name] += loss.item() num_batches += 1 if batch_idx % 100 == 0 : print (f'Epoch {epoch} [{batch_idx} /{len (train_loader)} ] ' f'Total Loss: {total_loss.item():.4 f} ' ) avg_losses = {name: loss / num_batches for name, loss in epoch_losses.items()} avg_losses['total' ] = epoch_total_loss / num_batches return avg_losses @torch.no_grad() def evaluate (self, test_loader: DataLoader ) -> Dict [str , float ]: """评估模型""" self.model.eval () metrics = {} for cfg in self.task_configs: if cfg['type' ] == 'classification' : metrics[cfg['name' ]] = {'correct' : 0 , 'total' : 0 } elif cfg['type' ] == 'regression' : metrics[cfg['name' ]] = {'mse' : 0.0 , 'count' : 0 } for inputs, targets in test_loader: inputs = inputs.to(self.device) targets = {name: targets[name].to(self.device) for name in self.task_names} outputs = self.model(inputs) for name in self.task_names: task_type = next (cfg['type' ] for cfg in self.task_configs if cfg['name' ] == name) if task_type == 'classification' : _, predicted = outputs[name].max (1 ) metrics[name]['total' ] += targets[name].size(0 ) metrics[name]['correct' ] += predicted.eq(targets[name]).sum ().item() elif task_type == 'regression' : mse = F.mse_loss(outputs[name], targets[name], reduction='sum' ) metrics[name]['mse' ] += mse.item() metrics[name]['count' ] += targets[name].size(0 ) results = {} for name in self.task_names: task_type = next (cfg['type' ] for cfg in self.task_configs if cfg['name' ] == name) if task_type == 'classification' : accuracy = 100. * metrics[name]['correct' ] / metrics[name]['total' ] results[name] = accuracy elif task_type == 'regression' : rmse = np.sqrt(metrics[name]['mse' ] / metrics[name]['count' ]) results[name] = rmse return results class MultiTaskCIFAR10 (torch.utils.data.Dataset): """ 将 CIFAR-10 转换为多任务数据集 任务 1:粗粒度分类(动物 vs 交通工具) 任务 2:细粒度分类(10 类) """ def __init__ (self, root: str , train: bool , transform=None ): self.cifar10 = torchvision.datasets.CIFAR10( root=root, train=train, download=True , transform=transform ) self.coarse_mapping = { 0 : 1 , 1 : 1 , 2 : 0 , 3 : 0 , 4 : 0 , 5 : 0 , 6 : 0 , 7 : 0 , 8 : 1 , 9 : 1 , } def __len__ (self ): return len (self.cifar10) def __getitem__ (self, idx ): image, fine_label = self.cifar10[idx] coarse_label = self.coarse_mapping[fine_label] targets = { 'coarse' : torch.tensor(coarse_label, dtype=torch.long), 'fine' : torch.tensor(fine_label, dtype=torch.long) } return image, targets def main (): device = 'cuda' if torch.cuda.is_available() else 'cpu' num_epochs = 50 batch_size = 128 transform_train = transforms.Compose([ transforms.RandomCrop(32 , padding=4 ), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914 , 0.4822 , 0.4465 ), (0.2023 , 0.1994 , 0.2010 )), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914 , 0.4822 , 0.4465 ), (0.2023 , 0.1994 , 0.2010 )), ]) trainset = MultiTaskCIFAR10(root='./data' , train=True , transform=transform_train) trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True , num_workers=2 ) testset = MultiTaskCIFAR10(root='./data' , train=False , transform=transform_test) testloader = DataLoader(testset, batch_size=batch_size, shuffle=False , num_workers=2 ) task_configs = [ {'name' : 'coarse' , 'num_classes' : 2 , 'type' : 'classification' }, {'name' : 'fine' , 'num_classes' : 10 , 'type' : 'classification' }, ] methods = ['uniform' , 'pcgrad' , 'gradnorm' ] for method in methods: print (f"\n{'=' *60 } " ) print (f"Training with optimization method: {method} " ) print (f"{'=' *60 } \n" ) model = MultiTaskNetwork(task_configs) trainer = MultiTaskTrainer( model=model, task_configs=task_configs, device=device, optimization_method=method ) best_results = {} for epoch in range (num_epochs): print (f"\nEpoch {epoch+1 } /{num_epochs} " ) train_losses = trainer.train_epoch(trainloader, epoch) print (f"Train losses: {train_losses} " ) if (epoch + 1 ) % 5 == 0 : test_results = trainer.evaluate(testloader) print (f"Test results: {test_results} " ) if not best_results or test_results['fine' ] > best_results.get('fine' , 0 ): best_results = test_results.copy() print (f"\nBest results for {method} : {best_results} " ) if __name__ == '__main__' : main()
代码说明
网络架构 :
SharedEncoder:共享的 ResNet 编码器
TaskHead:任务特定的分类/回归头部
MultiTaskNetwork:组合共享编码器和多个任务头部
损失函数 :
MultiTaskLoss:支持分类和回归任务的多任务损失
可以指定任务权重
优化方法 :
Uniform :均匀权重,标准反向传播
PCGrad :梯度投影,消除任务间冲突
GradNorm :动态调整任务权重,平衡梯度幅度
训练器 :
MultiTaskTrainer:统一的训练接口
支持多种优化方法
自动处理不同任务类型的评估指标
数据集 :
MultiTaskCIFAR10:将 CIFAR-10 转换为多任务数据集
包含粗粒度和细粒度两个分类任务
Q&A:常见问题解答
Q1:如何判断任务是否适合多任务学习?
A :任务适合多任务学习需要满足:
相关性标准 :
输入域相同或相似 :如都是图像、都是文本
需要相似的底层特征 :如边缘检测、词嵌入
正相关性 :在一个任务上表现好的模型在另一个任务上也倾向于表现好
定量评估 :
计算任务间的迁移学习增益:先在任务 A 上训练,再迁移到任务
B,看是否比从头训练 B 更好
计算梯度余弦相似度:如果平均相似度>0.3,通常适合多任务学习
反例 :
图像分类和文本情感分析:输入模态完全不同
人脸识别和场景分类:需要的特征(局部细节 vs 全局布局)差异很大
Q2:如何设置任务权重?
A :任务权重设置有多种策略:
静态权重 :
均匀权重 : ,简单但可能次优
按数据量加权 :$_t 其 中 N_t是 任 务 t$ 的样本数
按损失尺度加权 : ,归一化初始损失
动态权重 :
不确定性加权 : ,自动调节
GradNorm :根据梯度幅度和训练进度动态调整
强化学习 :用 RL 学习权重调度策略
调优建议 :
先用均匀权重或不确定性加权作为基线
如果某个任务明显欠拟合,增加其权重
使用 GradNorm 可以自动平衡,减少手动调节
Q3:什么时候使用硬共享,什么时候使用软共享?
A :选择取决于任务相关性和资源约束:
硬共享(Hard Parameter Sharing) :
适用 :任务高度相关,需要强正则化,参数预算紧张
例子 :同一个图像的多个视觉任务(检测+分割)
软共享(Soft Parameter Sharing) :
适用 :任务部分相关,需要灵活性,参数预算充足
例子 :不同领域的文本分类(新闻+医疗+法律)
混合策略 :
前几层硬共享(学习通用特征)
后几层软共享(学习任务特定模式)
使用注意力机制动态选择共享
实验对比 (NYUv2 数据集):
硬共享:参数 25M,平均性能 85%
软共享:参数 60M,平均性能 87%
混合:参数 35M,平均性能 86.5%
Q4:如何处理任务数量不平衡?
A :任务数量不平衡指不同任务的样本数差异很大:
采样策略 :
按任务采样 :每个 batch
随机选择一个任务,然后从该任务采样数据
按数据量采样 :每个 batch
按任务数据量比例混合采样
温度采样 :概率 ,其中 是温度(如
0.7)
损失加权 :
为小样本任务增加权重: ,其中 课程学习 :
先在大样本任务上预训练
再逐步引入小样本任务
实验 (数据量比例 1:10:100 的三个任务):
均匀采样:小任务性能很差(40%)
按任务采样:所有任务性能均衡(75%, 73%, 77%)
温度采样( ):略优于按任务采样(76%, 74%,
77%)
Q5:PCGrad 和 GradNorm
哪个更好?
A :两者解决不同的问题,可以结合使用:
PCGrad :
目标 :消除梯度冲突,保证每个任务都不受损
优点 :理论保证,对所有任务都是下降方向
缺点 :不考虑任务的相对重要性或训练进度
GradNorm :
目标 :平衡任务的训练速度,防止某些任务训练过快
优点 :动态调整权重,适应训练动态
缺点 :不显式处理梯度冲突
组合策略 :
用 GradNorm 计算任务权重$ w_t用 处 理 加 权 梯 度 w_t L_t$ 的冲突
实验对比 (Cityscapes 数据集):
PCGrad:mIoU 76.3%, depth error 0.012
GradNorm:mIoU 75.8%, depth error 0.011
PCGrad + GradNorm:mIoU 77.1%, depth error 0.010
组合方法在大多数情况下效果最好。
Q6:负迁移如何检测和缓解?
A :负迁移是多任务学习的主要风险:
检测方法 :
性能对比 :多任务性能 < 单任务性能
梯度分析 :计算任务间梯度余弦相似度,负相关说明冲突
消融实验 :逐个移除任务,看主任务性能是否提升
缓解策略 :
任务分组 :将冲突任务分到不同组,组内共享
软共享 :用注意力机制让任务选择性地共享
梯度手术 :用 PCGrad 或 CAGrad 消除冲突
任务选择 :移除对主任务有害的辅助任务
案例 (CelebA 数据集,40 个属性预测):
全部 40 个任务联合训练:平均准确率 83%
移除 10 个负迁移任务:平均准确率 86%
任务聚类(分 5 组):平均准确率 87%
Q7:多任务学习和集成学习的区别?
A :两者都涉及多个任务/模型,但目标不同:
多任务学习 :
目标 :通过共享表示提升每个任务的泛化能力
参数 :大部分参数共享,模型紧凑
训练 :联合训练,相互辅助
推理 :可以只计算感兴趣的任务
集成学习 :
目标 :通过组合多个模型的预测提升单个任务的性能
参数 :每个模型独立,参数量大
训练 :独立训练或顺序训练(Boosting)
推理 :必须计算所有模型然后组合
结合 :可以多任务学习的每个任务头部作为集成的基模型,结合两者优势。
Q8:如何在预训练模型上进行多任务学习?
A :在预训练模型(如 BERT 、
ResNet)上添加多任务学习:
策略 1:冻结预训练参数 :
冻结预训练编码器
只训练任务头部
优点:快速,避免灾难性遗忘
缺点:可能无法充分适应新任务
策略 2:全模型微调 :
所有参数都参与多任务训练
优点:充分适应新任务
缺点:可能遗忘预训练知识
策略 3:适配器(Adapter) :
在预训练模型的每层插入任务特定的适配器
只训练适配器和任务头部
优点:参数效率高,保留预训练知识
缺点:略微增加推理时间
实验 (BERT 在 GLUE 多任务):
冻结 BERT:平均得分 78.5
全模型微调:平均得分 82.1
适配器:平均得分 81.3,参数量仅 3%
适配器是最佳平衡。
Q9:多任务学习如何处理任务相关性随时间变化?
A :在训练过程中,任务之间的相关性可能发生变化:
动态架构 :
用路由网络(Routing Network)动态决定哪些任务共享哪些层
路由权重随训练过程更新
元学习 :
用元学习(Meta-Learning)周期性地重新评估任务关系
每隔 个
epoch,在验证集上测试不同任务组合的性能
注意力机制 :
用跨任务注意力(Cross-Task Attention)动态选择借用哪些任务的信息
注意力权重会随训练自动调整
实验 :在长期训练(>100
epochs)中,动态方法比静态方法平均提升 2-3%。
Q10:多任务学习如何应用于在线学习场景?
A :在线学习(Online Learning)中,数据流式到达:
挑战 :
任务到达顺序 :新任务不断加入
灾难性遗忘 :学习新任务时忘记旧任务
计算约束 :实时更新,不能重新训练
解决方案 :
渐进式网络(Progressive
Networks) :为每个新任务添加新列,保留旧任务参数
弹性权重巩固(EWC) :对重要参数添加正则化,防止大幅改变
记忆重放(Memory
Replay) :保留部分旧任务数据,混合训练
代码框架 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 for task_t in task_stream: if task_t is new: add_new_task_head(task_t) for batch in task_t_data: loss = compute_loss(task_t, batch) loss += ewc_penalty(old_tasks) if memory_buffer: replay_batch = memory_buffer.sample() loss += compute_loss(replay_batch) update_parameters(loss) memory_buffer.add(task_t_data.sample())
Q11:多任务学习中如何设计辅助任务?
A :辅助任务设计遵循以下原则:
相关性原则 :
辅助任务应与主任务共享底层结构
例如:目标检测(主) + 边缘检测(辅)
互补性原则 :
辅助任务提供主任务缺失的信息
例如:RGB 图像分类(主) + 深度估计(辅)
简单性原则 :
辅助任务应相对简单,易于学习
避免辅助任务太难导致主任务受干扰
自监督任务 :
旋转预测、拼图、对比学习等无需额外标注
通用性强,可广泛应用
领域知识 :
利用领域专家知识设计任务
例如:医疗影像中,器官分割可辅助病变检测
实验验证 :
Q12:多任务学习的可解释性如何提升?
A :多任务学习的可解释性有助于理解任务间关系:
任务关系可视化 :
绘制任务亲和性矩阵热图
显示哪些任务相互帮助,哪些冲突
特征重要性分析 :
计算每个任务对共享特征的依赖程度
用注意力权重或梯度幅度量化
消融研究 :
逐个移除任务,观察对其他任务的影响
构建任务依赖图
案例研究 :
选择具体样本,分析各任务的预测和注意力
显示任务间如何相互影响
工具 :
Grad-CAM for multi-task:可视化每个任务关注的区域
SHAP for multi-task:解释每个任务的特征贡献
论文推荐
经典论文
Caruana, "Multitask Learning", Machine Learning
1997
多任务学习的开创性论文
提出硬参数共享架构
理论分析 MTL 的泛化能力
Ruder, "An Overview of Multi-Task Learning in Deep Neural
Networks", 2017
架构设计
Misra et al., "Cross-Stitch Networks for Multi-task
Learning", CVPR 2016
提出交叉缝合网络
允许任务在多个层次交换信息
可学习的任务间连接
Liu et al., "End-to-End Multi-Task Learning with
Attention", CVPR 2019
多任务注意力网络(MTAN)
用注意力机制动态选择共享特征
在视觉任务上效果显著
Rebuffi et al., "Learning multiple visual domains with
residual adapters", NeurIPS 2017
提出适配器(Adapter)架构
在预训练模型上高效添加新任务
避免灾难性遗忘
任务关系学习
Zamir et al., "Taskonomy: Disentangling Task Transfer
Learning", CVPR 2018
大规模任务关系研究
构建任务亲和性矩阵
发现视觉任务的层次结构
Standley et al., "Which Tasks Should Be Learned Together
in Multi-task Learning?", ICML 2020
自动任务分组
用强化学习搜索最优分组方案
实验验证分组的重要性
梯度优化
Chen et al., "GradNorm: Gradient Normalization for
Adaptive Loss Balancing in Deep Multitask Networks", ICML
2018
提出 GradNorm 算法
动态调整任务权重
平衡梯度幅度和训练速度
Yu et al., "Gradient Surgery for Multi-Task Learning",
NeurIPS 2020
提出 PCGrad 算法
投影冲突梯度
理论保证和实验验证
Liu et al., "Conflict-Averse Gradient Descent for
Multi-task Learning", NeurIPS 2021
提出 CAGrad 算法
寻找帕累托最优梯度
优于 PCGrad 的理论性质
不确定性与权重
Kendall et al., "Multi-Task Learning Using Uncertainty to
Weigh Losses for Scene Geometry and Semantics", CVPR 2018
用不确定性自动调节任务权重
理论推导基于贝叶斯原理
在自动驾驶场景中应用
Sener & Koltun, "Multi-Task Learning as
Multi-Objective Optimization", NeurIPS 2018
将 MTL 建模为多目标优化
用 MGDA 算法寻找帕累托最优
理论严谨,实验充分
多任务学习是一个强大的学习范式,通过共享表示和联合优化,使模型能够同时掌握多个相关技能。从硬参数共享到软参数共享,从静态权重到动态平衡,从梯度冲突到帕累托最优,多任务学习涉及架构设计、优化策略、任务关系分析等多个层面。成功的多任务学习需要仔细选择相关任务、设计合理的共享结构、平衡不同任务的学习进度,并警惕负迁移的风险。随着计算资源的增长和理论的完善,多任务学习将在更多领域发挥重要作用,成为构建通用智能系统的关键技术之一。