在时间序列里,很多关键信息并不在“最近一步”:可能是周期中的某个相位、某个突发后的回落,或是跨很长间隔的相似模式。
Attention
的好处是它不需要按时间一步步把信息“传”过来,而是直接学会“该看历史里的哪几段、看多少权重”,从而更擅长处理长距离依赖与不规则相关性。本文会把自注意力的计算流程按公式拆开( 、缩放点积、 softmax
权重、加权求和),并结合代码层面的实现细节说明:这些矩阵运算到底在做什么、复杂度与序列长度的关系是什么,以及在时间序列任务里如何组织输入、如何解释注意力权重带来的可解释性。
数学原理
自注意力机制通过计算输入序列中每个位置与其他位置之间的相似度来生成新的表示。具体步骤如下:
输入表示 :假设输入序列为 ,每个 是一个向量。
线性变换 :通过学习的权重矩阵 将输入序列 转换为查询( Query)、键( Key)和值(
Value)向量:
$$
Q = XW^Q, K = XW^K, V = XW^V $$
计算注意力得分 :通过点积计算查询和键之间的相似度,并使用缩放因子
进行缩放:
归一化注意力得分 :使用 softmax
函数对注意力得分进行归一化,得到注意力权重:
加权求和 :将注意力权重应用于值向量,得到最终的注意力输出:
代码实现
缩放点积注意力:自注意力机制的核心计算
问题背景 :传统 RNN/LSTM
通过递归传递隐藏状态处理序列,存在两个问题:
1)长距离依赖难以捕捉(梯度消失),
2)每个时间步只能看到之前的信息,无法并行计算。自注意力机制通过"查询-键-值"(
Q/K/V)框架,让每个位置直接关注序列中所有位置,从而解决长距离依赖问题,且支持并行计算。
解决思路 :自注意力的核心是"相似度加权"——对于每个查询位置
,计算它与所有键位置的相似度(点积),通过
softmax 归一化为注意力权重,然后用这些权重对值向量加权求和。缩放因子
防止点积值过大导致
softmax 梯度消失。整个过程可以表示为: 。
设计考虑 :
Q/K/V 的含义 : Query(查询)表示"我想找什么",
Key(键)表示"我是什么", Value(值)表示"我的内容"。自注意力中
Q=K=V(都是输入序列的线性变换)
缩放的必要性 :点积 的方差随 增长,导致 softmax
饱和(梯度消失)。除以
保持方差为 1
掩码机制 :通过 mask 屏蔽无效位置(如 padding
、未来信息),设置为 使得
softmax 后权重接近 0
计算复杂度 : ,其中
是序列长度,
是特征维度。对于长序列,需要使用稀疏注意力或线性注意力
以下是一个简单的自注意力机制的实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 import numpy as npdef scaled_dot_product_attention (Q, K, V, mask=None ): """ 缩放点积注意力( Scaled Dot-Product Attention):自注意力机制的核心计算 核心思想:通过查询-键相似度计算注意力权重,对值向量加权求和 数学表达: Attention(Q, K, V) = softmax(QK^T / √ d_k) · V 其中: - Q:查询矩阵 (batch_size, seq_len_q, d_k) - K:键矩阵 (batch_size, seq_len_k, d_k) - V:值矩阵 (batch_size, seq_len_k, d_v) - d_k:键/查询的维度(缩放因子) 计算流程: 1. 计算相似度得分: scores = Q · K^T / √ d_k 2. 应用掩码(可选): scores[mask==0] = -∞ 3. 归一化权重: weights = softmax(scores) 4. 加权求和: output = weights · V Parameters: ----------- Q : numpy.ndarray, shape (batch_size, seq_len_q, d_k) 查询矩阵( Query) - 每个位置表示"我想关注什么" - 自注意力中: Q = X · W_Q(输入序列的线性变换) - 维度: d_k 是查询/键的维度(通常 64 、 128 、 256 等) K : numpy.ndarray, shape (batch_size, seq_len_k, d_k) 键矩阵( Key) - 每个位置表示"我是谁"(用于匹配查询) - 自注意力中: K = X · W_K(输入序列的线性变换) - 维度: d_k 必须与 Q 的最后一维相同 V : numpy.ndarray, shape (batch_size, seq_len_k, d_v) 值矩阵( Value) - 每个位置表示"我的内容"(实际被加权的信息) - 自注意力中: V = X · W_V(输入序列的线性变换) - 维度: d_v 可以与 d_k 不同(通常 d_v = d_k) mask : numpy.ndarray, optional, shape (batch_size, seq_len_q, seq_len_k) 掩码矩阵,用于屏蔽无效位置 - mask[i, j] = 1:位置 i 可以关注位置 j - mask[i, j] = 0:位置 i 不能关注位置 j(会被设为-∞) - 常见用途: * Padding 掩码:屏蔽 padding 位置 * Causal 掩码:屏蔽未来信息(解码器自注意力) * 自定义掩码:屏蔽特定位置 Returns: -------- output : numpy.ndarray, shape (batch_size, seq_len_q, d_v) 注意力输出:每个查询位置的加权值向量 - output[i, j] = Σ(attention_weights[i, j, k] * V[i, k]) - 表示位置 j 的查询关注所有键位置后的加权结果 attention_weights : numpy.ndarray, shape (batch_size, seq_len_q, seq_len_k) 注意力权重矩阵(归一化后的相似度) - attention_weights[i, j, k]:位置 j 的查询对位置 k 的键的注意力权重 - 每行和为 1( softmax 归一化) - 可用于可视化:哪些位置被关注最多 Notes: ------ - 缩放因子√ d_k 防止点积值过大导致 softmax 饱和(梯度消失) - 复杂度: O(seq_len_q × seq_len_k × d_k),对于长序列可能成为瓶颈 - 自注意力中通常 seq_len_q = seq_len_k = n(序列长度) - 掩码位置设为-1e9 而非-∞,因为 softmax(-1e9) ≈ 0(数值稳定) Example: -------- >>> # 自注意力: Q=K=V(都是输入序列的变换) >>> X = np.random.rand(1, 10, 64) # (batch=1, seq_len=10, d_model=64) >>> W_Q, W_K, W_V = np.random.rand(64, 64), np.random.rand(64, 64), np.random.rand(64, 64) >>> Q, K, V = X @ W_Q, X @ W_K, X @ W_V >>> output, weights = scaled_dot_product_attention(Q, K, V) >>> # output 形状:(1, 10, 64), weights 形状:(1, 10, 10) """ d_k = Q.shape[-1 ] scores = np.matmul(Q, K.transpose(-2 , -1 )) scores = scores / np.sqrt(d_k) if mask is not None : scores = np.where(mask == 0 , -1e9 , scores) attention_weights = np.exp(scores - np.max (scores, axis=-1 , keepdims=True )) attention_weights = attention_weights / np.sum (attention_weights, axis=-1 , keepdims=True ) output = np.matmul(attention_weights, V) return output, attention_weights batch_size = 1 seq_len = 10 d_k = 64 d_v = 64 Q = np.random.rand(batch_size, seq_len, d_k) K = np.random.rand(batch_size, seq_len, d_k) V = np.random.rand(batch_size, seq_len, d_v) output, attention_weights = scaled_dot_product_attention(Q, K, V) print (f"输入形状: Q={Q.shape} , K={K.shape} , V={V.shape} " )print (f"输出形状: output={output.shape} , attention_weights={attention_weights.shape} " )print (f"注意力权重每行和(应为 1): {attention_weights.sum (axis=-1 )} " )import matplotlib.pyplot as pltplt.figure(figsize=(8 , 6 )) plt.imshow(attention_weights[0 ], cmap='viridis' , aspect='auto' ) plt.colorbar(label='注意力权重' ) plt.xlabel('键位置(被关注的位置)' ) plt.ylabel('查询位置(关注的位置)' ) plt.title('自注意力权重矩阵可视化' ) plt.show()
关键点解读 :
缩放因子的重要性 :缩放因子 看似简单,但至关重要。当
较大(如 512)时,点积
的值可能很大(如 100),导致 softmax 输入很大, softmax 输出接近 one-hot
分布(几乎只关注一个位置),梯度接近 0 。除以 后,点积的方差保持为 1,
softmax 输入在合理范围(如[-2, 2]),梯度正常流动。这是 Transformer
成功的关键设计之一。
Q/K/V 的语义解释 :在自注意力中, Q/K/V
都来自同一输入序列
的线性变换,但语义不同。 Query 表示"我想关注什么特征", Key
表示"我提供什么特征用于匹配", Value 表示"我的实际内容"。注意力权重 表示"位置 的查询对位置 的键的匹配程度",然后用这个权重对值向量加权: 。这允许每个位置直接关注序列中所有位置,无需递归传递。
掩码机制的应用 :掩码在 Transformer 中至关重要。
Padding 掩码屏蔽 padding 位置(避免关注无效信息), Causal
掩码屏蔽未来信息(解码器自注意力中,位置 只能关注位置 )。掩码通过将无效位置的得分设为 实现, softmax 后这些位置的权重接近
0,不影响计算。
常见问题 :
问题
解答
Q: 为什么需要缩放因子√ d_k?
A: 防止点积值过大导致 softmax 饱和。点积 的方差与 成正比,当
大时,点积值可能很大,导致 softmax 输出接近
one-hot(梯度消失)。除以
保持方差为 1 。
Q: Q/K/V 必须来自同一输入吗?
A: 不一定。自注意力中 Q=K=V(都来自输入 X),但交叉注意力(
Cross-Attention)中 Q 来自解码器, K/V
来自编码器。这允许解码器关注编码器的所有位置。
Q: 注意力机制的复杂度是多少?
A: ,其中
是序列长度, 是特征维度。对于长序列(如$ n=10000) , 这 成 为 瓶 颈 。 解 决 方 案 : 稀 疏 注 意 力 ( 只 关 注 部 分 位 置 ) 、 线 性 注 意 力 ( 近 似 计 算 ) 、 ( 内 存 优 化 ) 。 如 何 解 释 注 意 力 权 重 ? 注 意 力 权 重 矩 阵 可 视 化 : 横 轴 是 被 关 注 的 位 置 , 纵 轴 是 关 注 的 位 置 。 颜 色 深 浅 表 示 权 重 大 小 。 可 以 用 于 可 解 释 性 分 析 : 哪 些 历 史 位 置 对 当 前 预 测 最 重 要 。 自 注 意 力 能 替 代 吗 ? 在 中 , 自 注 意 力 完 全 替 代 了 。 优 势 : 并 行 计 算 、 长 距 离 依 赖 、 可 解 释 性 。 劣 势 : O(n^2)$
复杂度、位置信息需要额外编码(位置编码)。
使用示例 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 import torchimport torch.nn as nnimport torch.nn.functional as Fclass ScaledDotProductAttention (nn.Module): """PyTorch 版本的缩放点积注意力""" def __init__ (self, d_k, dropout=0.1 ): super (ScaledDotProductAttention, self).__init__() self.d_k = d_k self.dropout = nn.Dropout(dropout) def forward (self, Q, K, V, mask=None ): """ Q, K, V: (batch_size, seq_len, d_k/d_v) mask: (batch_size, seq_len_q, seq_len_k) 或 (batch_size, 1, seq_len_k) """ scores = torch.matmul(Q, K.transpose(-2 , -1 )) / np.sqrt(self.d_k) if mask is not None : scores = scores.masked_fill(mask == 0 , -1e9 ) attention_weights = F.softmax(scores, dim=-1 ) attention_weights = self.dropout(attention_weights) output = torch.matmul(attention_weights, V) return output, attention_weights batch_size = 32 seq_len = 100 d_model = 128 X = torch.randn(batch_size, seq_len, d_model) W_Q = nn.Linear(d_model, d_model) W_K = nn.Linear(d_model, d_model) W_V = nn.Linear(d_model, d_model) Q = W_Q(X) K = W_K(X) V = W_V(X) attention = ScaledDotProductAttention(d_k=d_model) output, weights = attention(Q, K, V) print (f"输入形状: {X.shape} " )print (f"输出形状: {output.shape} " )print (f"注意力权重形状: {weights.shape} " )import matplotlib.pyplot as pltplt.figure(figsize=(10 , 8 )) plt.imshow(weights[0 ].detach().numpy(), cmap='viridis' , aspect='auto' ) plt.colorbar(label='注意力权重' ) plt.xlabel('键位置' ) plt.ylabel('查询位置' ) plt.title('自注意力权重矩阵(第一个样本)' ) plt.show()
Seq2Seq with Attention
数学原理
带有注意力机制的 Seq2Seq
模型通过动态调整解码器对编码器隐藏状态的关注来提高模型性能。以下是其核心原理:
编码器 :将输入序列 通过 RNN(如 LSTM 或
GRU)处理,生成隐藏状态序列 。
注意力权重 :在解码器的每个时间步 ,计算解码器隐藏状态 与编码器隐藏状态
之间的相似度,得到注意力权重 :
其中, ,通常采用点积、双线性或 MLP
作为得分函数。
上下文向量 :根据注意力权重对编码器隐藏状态加权求和,得到上下文向量
$ c_t: $
c_t = {i=1}^{n} {t,i} h_i $$
解码器 :将上下文向量
与解码器的输入和隐藏状态结合,生成当前时间步的输出。
代码实现
Seq2Seq
with Attention:编码器-解码器架构的注意力增强
问题背景 :传统 Seq2Seq
模型(编码器-解码器)存在信息瓶颈问题:编码器将整个输入序列压缩为固定长度的上下文向量(通常是最后一个隐藏状态),解码器只能基于这个向量生成输出。对于长序列,固定长度向量无法承载所有信息,导致性能下降。注意力机制通过让解码器在每个时间步动态关注编码器的所有位置,解决了信息瓶颈问题。
解决思路 : Seq2Seq with Attention
的核心是"动态上下文向量"——解码器在每个时间步$ t, 计 算 当 前 隐 藏 状 态 s_t$
与编码器所有隐藏状态 的相似度,得到注意力权重 ,然后加权求和得到上下文向量$
c_t = i {t,i}
h_i$。这个上下文向量包含"当前时刻最需要关注的编码器信息",与解码器输入和隐藏状态结合生成输出。这样,解码器可以根据当前状态动态选择关注编码器的不同位置。
设计考虑 : 1.
注意力得分函数 :这里使用 MLP(多层感知机)计算得分:$
e_{t,i} = v^T (W [s_t; h_i]), 其 中 W是 线 性 层 , v$
是可学习向量。也可以使用点积、双线性等。 2.
上下文向量的使用 :上下文向量 在两个地方使用:
1)与解码器输入拼接作为 LSTM 输入,
2)与解码器输出拼接后通过线性层生成最终输出。 3. Teacher Forcing
vs Free Running :训练时使用 Teacher
Forcing(使用真实目标序列),推理时使用 Free
Running(使用模型自身输出)。这里实现的是训练模式。 4.
批处理支持 :所有操作支持批处理,提高训练效率。
以下是一个带有注意力机制的 Seq2Seq 模型的实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 import torchimport torch.nn as nnimport torch.optim as optimclass Attention (nn.Module): """ Bahdanau 注意力机制:用于 Seq2Seq 模型 核心思想:计算解码器隐藏状态与编码器所有隐藏状态的相似度,得到注意力权重 数学表达: e_{t,i} = v^T · tanh(W · [s_t; h_i]) α_{t,i} = softmax(e_{t,i}) c_t = Σ(α_{t,i} · h_i) 其中: - s_t:解码器在时间 t 的隐藏状态(查询) - h_i:编码器在位置 i 的隐藏状态(键和值) - c_t:上下文向量(加权后的编码器信息) Parameters: ----------- hidden_dim : int 隐藏状态维度(编码器和解码器的隐藏维度必须相同) """ def __init__ (self, hidden_dim ): super (Attention, self).__init__() self.attn = nn.Linear(hidden_dim * 2 , hidden_dim) self.v = nn.Parameter(torch.rand(hidden_dim)) def forward (self, hidden, encoder_outputs ): """ 计算注意力权重 Parameters: ----------- hidden : torch.Tensor, shape (batch_size, hidden_dim) 解码器当前时间步的隐藏状态(查询) - 这是解码器 LSTM 的隐藏状态,表示"当前解码状态" encoder_outputs : torch.Tensor, shape (batch_size, seq_len, hidden_dim) 编码器所有时间步的隐藏状态序列(键和值) - encoder_outputs[:, i, :]:编码器在位置 i 的隐藏状态 Returns: -------- attention_weights : torch.Tensor, shape (batch_size, seq_len) 注意力权重:每个编码器位置的权重(归一化后,每行和为 1) - attention_weights[i, j]:样本 i 的解码器对编码器位置 j 的注意力权重 """ timestep = encoder_outputs.size(1 ) h = hidden.repeat(timestep, 1 , 1 ).transpose(0 , 1 ) energy = torch.tanh(self.attn(torch.cat((h, encoder_outputs), 2 ))) energy = energy.transpose(2 , 1 ) v = self.v.repeat(encoder_outputs.size(0 ), 1 ).unsqueeze(1 ) attention_weights = torch.bmm(v, energy).squeeze(1 ) return torch.softmax(attention_weights, dim=1 ) class Seq2SeqWithAttention (nn.Module): """ Seq2Seq 模型 + 注意力机制:用于序列到序列任务(如时间序列预测、机器翻译) 架构: 1. 编码器( Encoder): LSTM 处理输入序列,输出所有时间步的隐藏状态 2. 注意力机制( Attention):计算解码器隐藏状态与编码器隐藏状态的注意力权重 3. 解码器( Decoder): LSTM 生成输出序列,每个时间步使用注意力加权的上下文向量 数学流程: - 编码: h_i = LSTM_encoder(x_i, h_{i-1}), i=1,...,n - 解码(时间 t): 1. α_{t,i} = Attention(s_{t-1}, {h_i}) 2. c_t = Σ(α_{t,i} · h_i) 3. s_t = LSTM_decoder([y_{t-1}; c_t], s_{t-1}) 4. y_t = Linear([s_t; c_t]) Parameters: ----------- input_dim : int 输入特征维度(编码器输入) 例如:时间序列的每个时间步的特征数 hidden_dim : int 隐藏状态维度(编码器和解码器共享) 典型值: 64, 128, 256, 512 output_dim : int 输出特征维度(解码器输出) 例如:预测序列的每个时间步的特征数 """ def __init__ (self, input_dim, hidden_dim, output_dim ): super (Seq2SeqWithAttention, self).__init__() self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True ) self.decoder = nn.LSTM(hidden_dim + output_dim, hidden_dim, batch_first=True ) self.attention = Attention(hidden_dim) self.fc = nn.Linear(hidden_dim * 2 , output_dim) def forward (self, src, trg ): """ 前向传播:编码-注意力-解码 Parameters: ----------- src : torch.Tensor, shape (batch_size, src_seq_len, input_dim) 源序列(输入序列) 例如:历史时间序列数据 trg : torch.Tensor, shape (batch_size, trg_seq_len, output_dim) 目标序列(用于 Teacher Forcing 训练) 例如:未来时间序列数据(训练时)或初始值(推理时) Returns: -------- outputs : torch.Tensor, shape (batch_size, trg_seq_len, output_dim) 预测的输出序列 - outputs[:, 0, :]:通常是 0 或初始值(不使用) - outputs[:, 1:, :]:实际的预测序列 """ encoder_outputs, (hidden, cell) = self.encoder(src) hidden = hidden.squeeze(0 ) cell = cell.squeeze(0 ) outputs = torch.zeros(trg.size(0 ), trg.size(1 ), trg.size(2 )).to(trg.device) input = trg[:, 0 , :] for t in range (1 , trg.size(1 )): attention_weights = self.attention(hidden, encoder_outputs) context = attention_weights.unsqueeze(1 ).bmm(encoder_outputs).squeeze(1 ) rnn_input = torch.cat((input , context), dim=1 ).unsqueeze(1 ) output, (hidden, cell) = self.decoder(rnn_input, (hidden.unsqueeze(0 ), cell.unsqueeze(0 ))) hidden = hidden.squeeze(0 ) cell = cell.squeeze(0 ) output = self.fc(torch.cat((output.squeeze(1 ), context), dim=1 )) outputs[:, t, :] = output input = trg[:, t, :] return outputs input_dim = 10 hidden_dim = 64 output_dim = 10 batch_size = 32 src_seq_len = 15 trg_seq_len = 20 src = torch.rand(batch_size, src_seq_len, input_dim) trg = torch.rand(batch_size, trg_seq_len, output_dim) model = Seq2SeqWithAttention(input_dim, hidden_dim, output_dim) outputs = model(src, trg) print (f"输入形状: src={src.shape} " )print (f"目标形状: trg={trg.shape} " )print (f"输出形状: outputs={outputs.shape} " )criterion = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=0.001 ) optimizer.zero_grad() outputs = model(src, trg) loss = criterion(outputs[:, 1 :, :], trg[:, 1 :, :]) loss.backward() optimizer.step() print (f"训练损失: {loss.item():.4 f} " )
关键点解读 :
注意力机制解决信息瓶颈 :传统 Seq2Seq
模型将整个输入序列压缩为固定长度的上下文向量(通常是编码器最后一个隐藏状态),这导致长序列信息丢失。注意力机制通过让解码器在每个时间步动态关注编码器的所有位置,解决了这个问题。上下文向量$
c_t = i {t,i} h_i$
包含"当前时刻最需要的信息",而不是固定的压缩表示。
Teacher Forcing vs Free Running :训练时使用
Teacher
Forcing(使用真实目标序列作为解码器输入),这加速训练并提高稳定性。推理时使用
Free
Running(使用模型自身输出),这更接近实际应用场景。代码中input = trg[:, t, :]是
Teacher Forcing 模式,推理时应改为input = output。
上下文向量的双重使用 :上下文向量
在两个地方使用: 1)与解码器输入拼接作为 LSTM
输入([y_{t-1}; c_t]),让解码器知道"应该关注编码器的哪些信息";
2)与解码器输出拼接后通过线性层生成最终输出([s_t; c_t]),让输出包含"注意力加权的编码器信息"。这种设计使得模型能够充分利用注意力机制。
常见问题 :
问题
解答
Q: 为什么需要注意力机制?
A: 传统 Seq2Seq
模型存在信息瓶颈:固定长度上下文向量无法承载长序列的所有信息。注意力机制让解码器动态关注编码器的所有位置,解决了这个问题。
Q: Bahdanau Attention vs Luong Attention?
A: Bahdanau(这里实现)使用 MLP 计算得分, Luong 使用点积或双线性。
Bahdanau 更灵活但计算更慢, Luong
更快但表达能力稍弱。实际应用中两者性能相近。
Q: 如何可视化注意力权重?
A:
注意力权重矩阵可视化:横轴是编码器位置,纵轴是解码器时间步。颜色深浅表示权重大小,可以用于可解释性分析。
Q: 训练和推理有什么区别?
A: 训练时使用 Teacher Forcing(真实目标序列),推理时使用 Free
Running(模型自身输出)。推理时第一个时间步需要提供初始值或使用特殊标记。
Q: 如何处理变长序列?
A: 使用 padding 和 masking 。 padding 到相同长度, masking 屏蔽
padding 位置的注意力权重(设为-∞)。 PyTorch 的 pack_padded_sequence
可以处理变长序列。
使用示例 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 import torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import DataLoader, TensorDatasetmodel = Seq2SeqWithAttention(input_dim=5 , hidden_dim=128 , output_dim=5 ) n_samples = 1000 src_data = torch.randn(n_samples, 20 , 5 ) trg_data = torch.randn(n_samples, 10 , 5 ) dataset = TensorDataset(src_data, trg_data) dataloader = DataLoader(dataset, batch_size=32 , shuffle=True ) criterion = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=0.001 ) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min' , factor=0.5 , patience=10 ) model.train() for epoch in range (100 ): total_loss = 0 for src, trg in dataloader: optimizer.zero_grad() trg_input = torch.cat([torch.zeros(trg.size(0 ), 1 , trg.size(2 )), trg], dim=1 ) outputs = model(src, trg_input) loss = criterion(outputs[:, 1 :, :], trg) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0 ) optimizer.step() total_loss += loss.item() avg_loss = total_loss / len (dataloader) scheduler.step(avg_loss) if (epoch + 1 ) % 10 == 0 : print (f"Epoch {epoch+1 } , Loss: {avg_loss:.4 f} , LR: {optimizer.param_groups[0 ]['lr' ]:.6 f} " ) model.eval () with torch.no_grad(): src_sample = src_data[0 :1 ] trg_init = torch.zeros(1 , 1 , 5 ) trg_pred = trg_init.clone() encoder_outputs, (hidden, cell) = model.encoder(src_sample) hidden = hidden.squeeze(0 ) cell = cell.squeeze(0 ) predictions = [] input = trg_init[:, 0 , :] for t in range (10 ): attention_weights = model.attention(hidden, encoder_outputs) context = attention_weights.unsqueeze(1 ).bmm(encoder_outputs).squeeze(1 ) rnn_input = torch.cat((input , context), dim=1 ).unsqueeze(1 ) output, (hidden, cell) = model.decoder(rnn_input, (hidden.unsqueeze(0 ), cell.unsqueeze(0 ))) hidden = hidden.squeeze(0 ) cell = cell.squeeze(0 ) output = model.fc(torch.cat((output.squeeze(1 ), context), dim=1 )) predictions.append(output) input = output predictions = torch.stack(predictions, dim=1 ) print (f"预测形状: {predictions.shape} " )
❓ Q&A: Attention 常见疑问
位置编码( Positional
Encoding)深度解析
为什么需要位置编码?
核心问题 :自注意力机制是排列不变的 (
Permutation Invariant)
自注意力只计算词与词之间的相似度,完全不考虑位置信息。这意味着:
"我爱你" 和 "你爱我" 会被视为相同
"猫吃鱼" 和 "鱼吃猫" 会被视为相同
这在自然语言中是不可接受的,因为词序决定语义 。
Q1:什么是位置编码(
Positional Encoding),为什么需要它?
核心问题 :自注意力机制是排列不变的 (
Permutation Invariant)
想象一下,如果你把句子"我爱你"打乱成"爱你我"或"你我爱",自注意力会给出完全相同的输出 !因为它只计算词与词之间的相似度,不关心词的位置顺序 。
正弦/余弦位置编码( Sinusoidal PE) :
为什么选择正弦/余弦?
固定长度 :不需要训练,可以外推到更长序列
相对位置信息 : 可以表示为 的线性组合
Python 实现 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 import torchimport torch.nn as nnimport mathclass PositionalEncoding (nn.Module): def __init__ (self, d_model, max_len=5000 ): super ().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0 , max_len, dtype=torch.float ).unsqueeze(1 ) div_term = torch.exp(torch.arange(0 , d_model, 2 ).float () * (-math.log(10000.0 ) / d_model)) pe[:, 0 ::2 ] = torch.sin(position * div_term) pe[:, 1 ::2 ] = torch.cos(position * div_term) pe = pe.unsqueeze(0 ) self.register_buffer('pe' , pe) def forward (self, x ): return x + self.pe[:, :x.size(1 ), :] pos_encoder = PositionalEncoding(d_model=512 ) x = torch.randn(32 , 100 , 512 ) x_encoded = pos_encoder(x)
可学习位置编码( Learned Positional Encoding) :
1 2 3 4 5 6 7 8 9 class LearnedPositionalEncoding (nn.Module): def __init__ (self, d_model, max_len=5000 ): super ().__init__() self.pos_embedding = nn.Embedding(max_len, d_model) def forward (self, x ): seq_len = x.size(1 ) positions = torch.arange(seq_len, device=x.device).unsqueeze(0 ) return x + self.pos_embedding(positions)
位置编码对比 :
类型
优点
缺点
适用场景
正弦/余弦
可外推,无需训练
固定模式
大多数场景
可学习
自适应
无法外推
固定长度序列
正弦位置编码的数学性质 :
正弦位置编码的一个重要性质是相对位置可表示性 :
$$
PE_{pos+k} = PE_{pos} M_k $$
其中
是一个旋转矩阵。这意味着:
位置 的编码可以表示为位置 编码的线性变换
模型可以学习相对位置关系,而不仅仅是绝对位置
位置编码的维度选择 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 def analyze_positional_encoding (d_model=512 , max_len=5000 ): """分析位置编码的性质""" pe = PositionalEncoding(d_model, max_len) fig, axes = plt.subplots(2 , 2 , figsize=(15 , 10 )) pos_encoding = pe.pe[0 , :100 , :].numpy() axes[0 , 0 ].imshow(pos_encoding.T, cmap='coolwarm' , aspect='auto' ) axes[0 , 0 ].set_title('位置编码热力图(前 100 位置)' ) axes[0 , 0 ].set_xlabel('位置' ) axes[0 , 0 ].set_ylabel('维度' ) axes[0 , 1 ].plot(pos_encoding[:, :32 ]) axes[0 , 1 ].set_title('不同频率的正弦/余弦编码' ) axes[0 , 1 ].set_xlabel('位置' ) axes[0 , 1 ].set_ylabel('编码值' ) pos_0 = pe.pe[0 , 0 , :].numpy() pos_10 = pe.pe[0 , 10 , :].numpy() pos_20 = pe.pe[0 , 20 , :].numpy() axes[1 , 0 ].plot(pos_0[:64 ], label='位置 0' ) axes[1 , 0 ].plot(pos_10[:64 ], label='位置 10' ) axes[1 , 0 ].plot(pos_20[:64 ], label='位置 20' ) axes[1 , 0 ].set_title('不同位置的编码模式' ) axes[1 , 0 ].legend() similarity = torch.matmul(pe.pe[0 , :50 , :], pe.pe[0 , :50 , :].T) axes[1 , 1 ].imshow(similarity.numpy(), cmap='YlOrRd' ) axes[1 , 1 ].set_title('位置编码相似度矩阵' ) axes[1 , 1 ].set_xlabel('位置 i' ) axes[1 , 1 ].set_ylabel('位置 j' ) plt.tight_layout() plt.show() analyze_positional_encoding(d_model=512 , max_len=5000 )
可学习位置编码的改进 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 class LearnedPositionalEncoding (nn.Module): """可学习位置编码(带初始化策略)""" def __init__ (self, d_model, max_len=5000 ): super ().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0 , max_len, dtype=torch.float ).unsqueeze(1 ) div_term = torch.exp(torch.arange(0 , d_model, 2 ).float () * (-math.log(10000.0 ) / d_model)) pe[:, 0 ::2 ] = torch.sin(position * div_term) pe[:, 1 ::2 ] = torch.cos(position * div_term) self.pos_embedding = nn.Embedding(max_len, d_model) self.pos_embedding.weight.data = pe.unsqueeze(0 ) self.pos_embedding.weight.requires_grad = True def forward (self, x ): seq_len = x.size(1 ) positions = torch.arange(seq_len, device=x.device).unsqueeze(0 ) return x + self.pos_embedding(positions)
相对位置编码( Relative Position Encoding) :
除了绝对位置,还可以编码相对位置:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 class RelativePositionalEncoding (nn.Module): """相对位置编码""" def __init__ (self, d_model, max_rel_pos=128 ): super ().__init__() self.max_rel_pos = max_rel_pos self.rel_pos_embedding = nn.Embedding(2 * max_rel_pos + 1 , d_model) def forward (self, q, k ): """ q: [batch, seq_len_q, d_model] k: [batch, seq_len_k, d_model] """ seq_len_q = q.size(1 ) seq_len_k = k.size(1 ) rel_pos = torch.arange(seq_len_k, device=q.device).unsqueeze(0 ) - \ torch.arange(seq_len_q, device=q.device).unsqueeze(1 ) rel_pos = torch.clamp(rel_pos, -self.max_rel_pos, self.max_rel_pos) rel_pos = rel_pos + self.max_rel_pos rel_pos_emb = self.rel_pos_embedding(rel_pos) return rel_pos_emb
Multi-Head Attention 详解
为什么需要多头?
单头注意力只能学习一种表示模式,而多头注意力允许模型同时关注不同类型的依赖关系。就像人类理解句子时,会同时关注语法结构、语义关系、情感色彩等多个维度。
多头注意力的数学原理
完整公式 :
其中:
是头的数量
是第 个头的投影矩阵
是输出投影矩阵
是每个头的维度
关键设计 :
参数共享 :所有头共享输入 ,但使用不同的投影矩阵
独立计算 :每个头独立计算注意力,学习不同的表示子空间
拼接融合 :所有头的输出拼接后通过 投影回原始维度
每个头学习什么?
通过可视化不同头的注意力权重,可以发现:
头类型
关注模式
典型应用
局部头
关注相邻位置(对角线附近权重高)
捕捉局部依赖、短语结构
全局头
均匀关注所有位置(权重分布均匀)
捕捉全局语义、文档级信息
特定位置头
只关注少数关键位置(权重高度集中)
捕捉关键词、重要实体
句法头
关注句法相关位置(如主谓宾关系)
理解语法结构
语义头
关注语义相似的位置
理解同义词、语义关系
实验证据 :
在 BERT 等模型中,研究发现:
头 1-2:主要关注局部依赖(相邻词)
头 3-4:关注句法结构(主谓宾)
头 5-6:关注语义关系(同义词、反义词)
头 7-8:关注全局信息(文档级特征)
Q2:在多头注意力机制(
Multi-Head Attention)中,每个头( head)是如何独立工作的?
核心思想 :不同的头关注不同的特征
多头注意力的优势 :
每个头独立学习不同的表示子空间
头 1 可能关注局部依赖 (相邻词)
头 2 可能关注长距离依赖 (句子头尾)
头 3 可能关注句法结构 (主谓宾)
数学公式 :
Python 实现 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 import torchimport torch.nn as nnimport mathclass MultiHeadAttention (nn.Module): def __init__ (self, d_model, num_heads ): super ().__init__() assert d_model % num_heads == 0 self.d_model = d_model self.num_heads = num_heads self.d_k = d_model // num_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) def forward (self, Q, K, V, mask=None ): batch_size = Q.size(0 ) Q = self.W_q(Q).view(batch_size, -1 , self.num_heads, self.d_k).transpose(1 , 2 ) K = self.W_k(K).view(batch_size, -1 , self.num_heads, self.d_k).transpose(1 , 2 ) V = self.W_v(V).view(batch_size, -1 , self.num_heads, self.d_k).transpose(1 , 2 ) scores = torch.matmul(Q, K.transpose(-2 , -1 )) / math.sqrt(self.d_k) if mask is not None : scores = scores.masked_fill(mask == 0 , -1e9 ) attention_weights = torch.softmax(scores, dim=-1 ) head_outputs = torch.matmul(attention_weights, V) concat = head_outputs.transpose(1 , 2 ).contiguous().view( batch_size, -1 , self.d_model ) output = self.W_o(concat) return output, attention_weights mha = MultiHeadAttention(d_model=512 , num_heads=8 ) x = torch.randn(32 , 100 , 512 ) output, attn_weights = mha(x, x, x)
可视化不同头的注意力模式 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 import matplotlib.pyplot as pltimport seaborn as snsdef visualize_multi_head_attention (attention_weights, head_idx=0 ): """可视化特定头的注意力权重""" attn = attention_weights[0 , head_idx].detach().numpy() plt.figure(figsize=(10 , 8 )) sns.heatmap(attn, cmap='YlOrRd' , annot=False ) plt.title(f'Head {head_idx} 的注意力权重' ) plt.xlabel('Key 位置' ) plt.ylabel('Query 位置' ) plt.show() for i in range (8 ): visualize_multi_head_attention(attn_weights, head_idx=i)
头数选择指南 :
模型大小
推荐头数
原因
小模型 ( d_model=128)
4-8
平衡表达能力和计算成本
中等模型 ( d_model=512)
8-16
标准配置
大模型 ( d_model=1024)
16-32
需要更多表示子空间
Q3:如何使用掩码(
Mask)来处理变长序列?
掩码的三种类型 :
1. 填充掩码( Padding Mask) :
2. 因果掩码( Causal Mask / Look-Ahead Mask) :
作用 :防止解码器在生成第 个 token 时看到未来的
token
3. 组合掩码 :
编码器:只用填充掩码
解码器:填充掩码 + 因果掩码
Python 实现 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 import torchimport torch.nn as nnimport numpy as npdef create_padding_mask (seq, pad_token=0 ): """创建填充掩码""" mask = (seq != pad_token).unsqueeze(1 ).unsqueeze(2 ) return mask.float () def create_causal_mask (seq_len ): """创建因果掩码(下三角矩阵)""" mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1 ) mask = mask.masked_fill(mask == 1 , float ('-inf' )) mask = mask.masked_fill(mask == 0 , float (0.0 )) return mask def create_combined_mask (seq, pad_token=0 ): """组合填充掩码和因果掩码""" seq_len = seq.size(1 ) padding_mask = create_padding_mask(seq, pad_token) causal_mask = create_causal_mask(seq_len) combined_mask = padding_mask + causal_mask.unsqueeze(0 ).unsqueeze(0 ) combined_mask = combined_mask.masked_fill(combined_mask > 0 , float ('-inf' )) return combined_mask seq = torch.tensor([[1 , 2 , 3 , 0 , 0 ], [4 , 5 , 0 , 0 , 0 ]]) mask = create_combined_mask(seq, pad_token=0 ) scores = torch.matmul(Q, K.transpose(-2 , -1 )) / math.sqrt(d_k) scores = scores + mask attention_weights = torch.softmax(scores, dim=-1 )
掩码可视化 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 import matplotlib.pyplot as pltdef visualize_mask (mask, title='Mask' ): """可视化掩码矩阵""" plt.figure(figsize=(8 , 8 )) plt.imshow(mask[0 , 0 ].detach().numpy(), cmap='gray' ) plt.title(title) plt.xlabel('Key 位置' ) plt.ylabel('Query 位置' ) plt.colorbar() plt.show() padding_mask = create_padding_mask(seq) causal_mask = create_causal_mask(seq_len=10 ) combined_mask = create_combined_mask(seq) visualize_mask(padding_mask, '填充掩码' ) visualize_mask(causal_mask.unsqueeze(0 ).unsqueeze(0 ), '因果掩码' ) visualize_mask(combined_mask, '组合掩码' )
维度
RNN/LSTM/GRU
Transformer
并行计算
❌ 顺序计算
✅ 全并行
长距离依赖
⚠️ 梯度消失/爆炸
✅ 直接连接( O(1) 路径长度)
训练速度
慢(序列越长越慢)
快(序列长度不影响并行度)
内存占用
中等
高( 注意力矩阵)
可解释性
差(隐藏状态黑盒)
✅ 好(注意力权重可视化)
性能对比实验 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 import timeimport torchimport torch.nn as nndef benchmark_training (model, data_loader, n_epochs=10 ): model.train() optimizer = torch.optim.Adam(model.parameters()) criterion = nn.MSELoss() start_time = time.time() for epoch in range (n_epochs): for batch in data_loader: optimizer.zero_grad() output = model(batch[0 ]) loss = criterion(output, batch[1 ]) loss.backward() optimizer.step() elapsed = time.time() - start_time return elapsed / n_epochs rnn_model = nn.LSTM(input_size=10 , hidden_size=128 , num_layers=2 ) rnn_time = benchmark_training(rnn_model, data_loader) transformer_model = TransformerModel(...) transformer_time = benchmark_training(transformer_model, data_loader) print (f'RNN 训练时间: {rnn_time:.2 f} s/epoch' )print (f'Transformer 训练时间: {transformer_time:.2 f} s/epoch' )print (f'加速比: {rnn_time/transformer_time:.2 f} x' )
Q5: Attention
机制在时间序列预测中如何应用?
时间序列 Attention 的特殊性 :
时间序列数据与 NLP 不同:
时间顺序很重要 :需要因果掩码
周期性模式 :某些时间步可能更重要
多变量 :需要处理多个特征
时间序列 Transformer 实现 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 import torchimport torch.nn as nnimport mathclass TimeSeriesTransformer (nn.Module): def __init__ (self, input_size, d_model, nhead, num_layers, output_size ): super ().__init__() self.input_projection = nn.Linear(input_size, d_model) self.pos_encoder = PositionalEncoding(d_model) encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4 , dropout=0.1 , batch_first=True ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.output_projection = nn.Linear(d_model, output_size) def forward (self, x, mask=None ): x = self.input_projection(x) x = self.pos_encoder(x) if mask is None : seq_len = x.size(1 ) mask = self.generate_causal_mask(seq_len, x.device) x = self.transformer(x, mask=mask) x = self.output_projection(x[:, -1 , :]) return x def generate_causal_mask (self, sz, device ): """生成因果掩码""" mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1 ) mask = mask.masked_fill(mask == 1 , float ('-inf' )) return mask model = TimeSeriesTransformer( input_size=10 , d_model=128 , nhead=8 , num_layers=3 , output_size=1 ) x = torch.randn(32 , 50 , 10 ) pred = model(x)
Informer:针对长序列的改进 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 class ProbSparseAttention (nn.Module): """ProbSparse Attention:只计算重要的注意力对""" def __init__ (self, d_model, nhead, factor=5 ): super ().__init__() self.d_model = d_model self.nhead = nhead self.factor = factor def forward (self, Q, K, V ): B, L_Q, _ = Q.size() L_K = K.size(1 ) sample_size = max (L_Q // self.factor, 1 ) sample_indices = self.sample_queries(Q, K, sample_size) Q_sampled = Q[:, sample_indices, :] scores = torch.matmul(Q_sampled, K.transpose(-2 , -1 )) / math.sqrt(self.d_model) attention_weights = torch.softmax(scores, dim=-1 ) output = torch.matmul(attention_weights, V) return output, attention_weights def sample_queries (self, Q, K, sample_size ): """采样重要的 Query""" scores = torch.matmul(Q, K.transpose(-2 , -1 )) importance = scores.max (dim=-1 )[0 ] - scores.mean(dim=-1 ) _, indices = torch.topk(importance, sample_size, dim=-1 ) return indices
时间序列 Attention 可视化 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 import matplotlib.pyplot as pltimport seaborn as snsdef visualize_timeseries_attention (model, x, timestamps ): """可视化时间序列的注意力权重""" model.eval () with torch.no_grad(): _, attn_weights = model.get_attention_weights(x) attn = attn_weights[0 , 0 ].detach().numpy() plt.figure(figsize=(15 , 10 )) sns.heatmap(attn, cmap='YlOrRd' , xticklabels=timestamps, yticklabels=timestamps) plt.title('时间序列注意力权重热力图' ) plt.xlabel('Key 时间步' ) plt.ylabel('Query 时间步' ) plt.show() attention_sums = attn.sum (axis=0 ) important_timesteps = np.argsort(attention_sums)[-10 :] print (f'最重要的时间步: {important_timesteps} ' )
问题分析 :
标准自注意力的计算复杂度:
时间复杂度 : ,其中 是序列长度, 是特征维度
空间复杂度 : (注意力矩阵)
当序列长度 时,注意力矩阵大小为 个元素 当 时,矩阵大小为
个元素!
解决方案 1:稀疏注意力( Sparse Attention) :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 class SparseAttention (nn.Module): """只计算局部窗口内的注意力""" def __init__ (self, d_model, nhead, window_size=50 ): super ().__init__() self.window_size = window_size self.attention = nn.MultiheadAttention(d_model, nhead) def forward (self, x ): batch_size, seq_len, d_model = x.size() num_windows = (seq_len + self.window_size - 1 ) // self.window_size outputs = [] for i in range (num_windows): start = i * self.window_size end = min ((i + 1 ) * self.window_size, seq_len) window_x = x[:, start:end, :] window_out, _ = self.attention(window_x, window_x, window_x) outputs.append(window_out) return torch.cat(outputs, dim=1 )
解决方案 2:线性注意力( Linear Attention) :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 class LinearAttention (nn.Module): """线性复杂度注意力: O(n) 而非 O(n^2)""" def __init__ (self, d_model ): super ().__init__() self.d_model = d_model self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) def forward (self, x ): Q = self.W_q(x) K = self.W_k(x) V = self.W_v(x) K = torch.nn.functional.elu(K) + 1 Q = torch.nn.functional.elu(Q) + 1 KV = torch.einsum('bnd,bne->bde' , K, V) QKV = torch.einsum('bnd,bde->bne' , Q, KV) normalizer = torch.einsum('bnd,bd->bn' , Q, K.sum (dim=1 )) output = QKV / (normalizer.unsqueeze(-1 ) + 1e-6 ) return output
解决方案 3: Performer(随机特征) :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 class PerformerAttention (nn.Module): """Performer:使用随机特征近似注意力""" def __init__ (self, d_model, n_random_features=64 ): super ().__init__() self.d_model = d_model self.n_random_features = n_random_features self.random_features = nn.Parameter( torch.randn(n_random_features, d_model) ) def forward (self, Q, K, V ): batch_size, seq_len, _ = Q.size() Q_features = self.random_feature_map(Q) K_features = self.random_feature_map(K) QK_features = torch.einsum('bnm,bkm->bnk' , Q_features, K_features) attention_weights = QK_features / math.sqrt(self.d_model) output = torch.einsum('bnk,bkd->bnd' , attention_weights, V) return output def random_feature_map (self, x ): """随机特征映射""" projection = torch.matmul(x, self.random_features.t()) return torch.nn.functional.relu(projection)
性能对比 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 import timedef benchmark_attention (attention_module, x, n_iter=100 ): """对比不同注意力机制的速度""" attention_module.eval () start = time.time() with torch.no_grad(): for _ in range (n_iter): _ = attention_module(x) elapsed = time.time() - start return elapsed / n_iter * 1000 x = torch.randn(1 , 1000 , 128 ) standard_attn = StandardAttention(128 ) standard_time = benchmark_attention(standard_attn, x) sparse_attn = SparseAttention(128 , window_size=100 ) sparse_time = benchmark_attention(sparse_attn, x) linear_attn = LinearAttention(128 ) linear_time = benchmark_attention(linear_attn, x) print (f'标准注意力: {standard_time:.2 f} ms' )print (f'稀疏注意力: {sparse_time:.2 f} ms (加速 {standard_time/sparse_time:.2 f} x)' )print (f'线性注意力: {linear_time:.2 f} ms (加速 {standard_time/linear_time:.2 f} x)' )
Q7: Attention
机制的可解释性如何利用?
1. 注意力权重可视化 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 import matplotlib.pyplot as pltimport seaborn as snsdef visualize_attention_weights (attention_weights, input_tokens=None , head_idx=0 ): """可视化注意力权重矩阵""" attn = attention_weights[0 , head_idx].detach().numpy() plt.figure(figsize=(12 , 10 )) sns.heatmap( attn, cmap='YlOrRd' , xticklabels=input_tokens, yticklabels=input_tokens, annot=False , fmt='.2f' ) plt.title(f'Attention Head {head_idx} 的权重分布' ) plt.xlabel('Key 位置' ) plt.ylabel('Query 位置' ) plt.tight_layout() plt.show() model.eval () with torch.no_grad(): output, attn_weights = model(x, return_attention=True ) visualize_attention_weights(attn_weights, input_tokens=token_names)
2. 注意力头分析 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def analyze_attention_heads (attention_weights ): """分析不同头的关注模式""" n_heads = attention_weights.size(1 ) fig, axes = plt.subplots(2 , n_heads // 2 , figsize=(20 , 8 )) axes = axes.flatten() for head_idx in range (n_heads): attn = attention_weights[0 , head_idx].detach().numpy() avg_attention = attn.mean(axis=0 ) axes[head_idx].bar(range (len (avg_attention)), avg_attention) axes[head_idx].set_title(f'Head {head_idx} ' ) axes[head_idx].set_xlabel('位置' ) axes[head_idx].set_ylabel('平均注意力权重' ) plt.tight_layout() plt.show()
3. 特征重要性分析 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 def compute_feature_importance (model, x, target_idx ): """计算每个输入特征对预测的贡献""" model.eval () _, attn_weights = model.get_attention_weights(x) last_query_attention = attn_weights[0 , :, -1 , :].mean(dim=0 ) plt.figure(figsize=(12 , 6 )) plt.bar(range (len (last_query_attention)), last_query_attention.detach().numpy()) plt.xlabel('输入时间步' ) plt.ylabel('注意力权重' ) plt.title('预测时各时间步的重要性' ) plt.show() return last_query_attention
4. 注意力模式分类 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 def classify_attention_patterns (attention_weights ): """识别常见的注意力模式""" patterns = { '局部关注' : [], '全局关注' : [], '特定位置' : [], '周期性' : [] } for head_idx in range (attention_weights.size(1 )): attn = attention_weights[0 , head_idx].detach().numpy() local_weight = np.trace(attn) + np.trace(np.roll(attn, 1 , axis=0 )) entropy = -np.sum (attn * np.log(attn + 1e-10 )) max_entropy = np.log(attn.shape[0 ]) uniformity = entropy / max_entropy if local_weight > 0.5 : patterns['局部关注' ].append(head_idx) elif uniformity > 0.8 : patterns['全局关注' ].append(head_idx) elif attn.max () > 0.7 : patterns['特定位置' ].append(head_idx) else : patterns['周期性' ].append(head_idx) return patterns
5. 注意力引导的特征选择 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 def attention_based_feature_selection (model, x, top_k=10 ): """基于注意力权重选择最重要的特征""" model.eval () with torch.no_grad(): _, attn_weights = model.get_attention_weights(x) feature_importance = attn_weights.mean(dim=(0 , 1 , 3 )) top_indices = torch.topk(feature_importance, k=top_k).indices return top_indices, feature_importance important_indices, importance_scores = attention_based_feature_selection(model, x, top_k=10 ) print (f'最重要的时间步: {important_indices} ' )
Self-Attention vs
Cross-Attention 深度对比
核心区别
Self-Attention(自注意力) :
都来自同一个输入序列
计算序列内部的关系
用于编码器( Encoder)或单序列任务
Cross-Attention(交叉注意力) :
来自一个序列, 来自另一个序列
计算两个序列之间的关系
用于解码器( Decoder)或多模态任务
数学公式对比
Self-Attention :
Cross-Attention :
其中 来自序列 A, 来自序列 B 。
代码实现对比
Self-Attention 实现 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 class SelfAttention (nn.Module): """自注意力: Q, K, V 都来自同一个输入""" def __init__ (self, d_model, nhead ): super ().__init__() self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True ) def forward (self, x ): """ x: [batch, seq_len, d_model] 返回: [batch, seq_len, d_model] """ out, attn_weights = self.attention(x, x, x) return out, attn_weights text = torch.randn(32 , 100 , 512 ) self_attn = SelfAttention(d_model=512 , nhead=8 ) encoded, attn = self_attn(text)
Cross-Attention 实现 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 class CrossAttention (nn.Module): """交叉注意力: Q 来自一个序列, K 和 V 来自另一个序列""" def __init__ (self, d_model, nhead ): super ().__init__() self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True ) def forward (self, query_seq, key_value_seq ): """ query_seq: [batch, seq_len_q, d_model] - 查询序列 key_value_seq: [batch, seq_len_kv, d_model] - 键值序列 返回: [batch, seq_len_q, d_model] """ out, attn_weights = self.attention( query_seq, key_value_seq, key_value_seq ) return out, attn_weights encoder_output = torch.randn(32 , 50 , 512 ) decoder_input = torch.randn(32 , 30 , 512 ) cross_attn = CrossAttention(d_model=512 , nhead=8 ) decoded, attn = cross_attn(decoder_input, encoder_output)
应用场景对比
维度
Self-Attention
Cross-Attention
典型应用
BERT 编码器、 GPT 、图像分类
机器翻译解码器、图像描述生成
输入数量
1 个序列
2 个序列(查询序列+键值序列)
计算关系
序列内部关系
序列间关系
注意力模式
对称矩阵($ n n) 非 对 称 矩 阵 ( n_q n_{kv}$)
可解释性
理解序列内部依赖
理解跨序列对齐关系
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 class TransformerDecoderLayer (nn.Module): """Transformer 解码器层:包含 Self-Attention 和 Cross-Attention""" def __init__ (self, d_model, nhead, dim_feedforward=2048 , dropout=0.1 ): super ().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True ) self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True ) self.ffn = nn.Sequential( nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), nn.Dropout(dropout) ) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward (self, tgt, memory, tgt_mask=None , memory_mask=None ): """ tgt: [batch, tgt_len, d_model] - 目标序列(解码器输入) memory: [batch, src_len, d_model] - 源序列(编码器输出) """ tgt2, self_attn_weights = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask) tgt = self.norm1(tgt + self.dropout(tgt2)) tgt2, cross_attn_weights = self.cross_attn( tgt, memory, memory, attn_mask=memory_mask ) tgt = self.norm2(tgt + self.dropout(tgt2)) tgt2 = self.ffn(tgt) tgt = self.norm3(tgt + tgt2) return tgt, self_attn_weights, cross_attn_weights encoder_output = torch.randn(32 , 50 , 512 ) decoder_input = torch.randn(32 , 30 , 512 ) decoder_layer = TransformerDecoderLayer(d_model=512 , nhead=8 ) output, self_attn, cross_attn = decoder_layer(decoder_input, encoder_output)
可视化对比
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 def visualize_attention_comparison (self_attn_weights, cross_attn_weights ): """可视化 Self-Attention 和 Cross-Attention 的区别""" fig, axes = plt.subplots(1 , 2 , figsize=(20 , 8 )) axes[0 ].imshow(self_attn_weights[0 , 0 ].detach().numpy(), cmap='YlOrRd' ) axes[0 ].set_title('Self-Attention 权重(对称)' ) axes[0 ].set_xlabel('Key 位置(目标序列)' ) axes[0 ].set_ylabel('Query 位置(目标序列)' ) axes[1 ].imshow(cross_attn_weights[0 , 0 ].detach().numpy(), cmap='YlOrRd' ) axes[1 ].set_title('Cross-Attention 权重(非对称)' ) axes[1 ].set_xlabel('Key 位置(源序列)' ) axes[1 ].set_ylabel('Query 位置(目标序列)' ) plt.tight_layout() plt.show() visualize_attention_comparison(self_attn, cross_attn)
选择指南
使用 Self-Attention :
✅ 单序列任务(文本分类、语言模型)
✅ 需要理解序列内部关系
✅ 编码器架构
使用 Cross-Attention :
✅ 序列到序列任务(机器翻译、摘要生成)
✅ 多模态任务(图像描述、视频理解)
✅ 解码器架构
✅ 需要跨序列对齐
Q8:如何将 Attention
机制与 LSTM/GRU 结合?
1. Attention-LSTM 架构 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 import torchimport torch.nn as nnclass AttentionLSTM (nn.Module): """LSTM + Attention 混合模型""" def __init__ (self, input_size, hidden_size, num_layers, nhead ): super ().__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True ) self.attention = nn.MultiheadAttention(hidden_size, nhead, batch_first=True ) self.fc = nn.Linear(hidden_size, 1 ) def forward (self, x ): lstm_out, _ = self.lstm(x) attn_out, attn_weights = self.attention( lstm_out, lstm_out, lstm_out ) combined = lstm_out + attn_out output = self.fc(combined[:, -1 , :]) return output, attn_weights model = AttentionLSTM(input_size=10 , hidden_size=64 , num_layers=2 , nhead=4 ) x = torch.randn(32 , 50 , 10 ) pred, attn = model(x)
2. 双向 Attention-GRU :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 class BidirectionalAttentionGRU (nn.Module): """双向 GRU + Attention""" def __init__ (self, input_size, hidden_size, num_layers ): super ().__init__() self.gru = nn.GRU( input_size, hidden_size, num_layers, batch_first=True , bidirectional=True ) self.attention = nn.MultiheadAttention( hidden_size * 2 , num_heads=8 , batch_first=True ) self.fc = nn.Linear(hidden_size * 2 , 1 ) def forward (self, x ): gru_out, _ = self.gru(x) attn_out, attn_weights = self.attention(gru_out, gru_out, gru_out) weights = torch.softmax(attn_out.sum (dim=-1 ), dim=1 ).unsqueeze(-1 ) weighted_out = (attn_out * weights).sum (dim=1 ) output = self.fc(weighted_out) return output, attn_weights
3. Hierarchical Attention(层次注意力) :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 class HierarchicalAttention (nn.Module): """层次注意力:先局部后全局""" def __init__ (self, input_size, hidden_size ): super ().__init__() self.local_attention = nn.MultiheadAttention( input_size, num_heads=4 , batch_first=True ) self.gru = nn.GRU(input_size, hidden_size, batch_first=True ) self.global_attention = nn.MultiheadAttention( hidden_size, num_heads=8 , batch_first=True ) self.fc = nn.Linear(hidden_size, 1 ) def forward (self, x, window_size=10 ): batch_size, seq_len, _ = x.size() local_outputs = [] for i in range (0 , seq_len, window_size): window = x[:, i:i+window_size, :] local_out, _ = self.local_attention(window, window, window) local_outputs.append(local_out) local_combined = torch.cat(local_outputs, dim=1 ) gru_out, _ = self.gru(local_combined) global_out, attn_weights = self.global_attention( gru_out, gru_out, gru_out ) output = self.fc(global_out[:, -1 , :]) return output, attn_weights
4. 对比实验 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 def compare_models (models_dict, X_test, y_test ): """对比不同架构的性能""" results = {} for name, model in models_dict.items(): model.eval () with torch.no_grad(): pred = model(X_test) mae = torch.mean(torch.abs (pred - y_test)).item() rmse = torch.sqrt(torch.mean((pred - y_test) ** 2 )).item() results[name] = {'MAE' : mae, 'RMSE' : rmse} return results models = { 'LSTM' : LSTMOnlyModel(...), 'GRU' : GRUOnlyModel(...), 'LSTM+Attention' : AttentionLSTM(...), 'GRU+Attention' : BidirectionalAttentionGRU(...), 'Transformer' : TransformerModel(...) } results = compare_models(models, X_test, y_test) for name, metrics in results.items(): print (f'{name} : MAE={metrics["MAE" ]:.4 f} , RMSE={metrics["RMSE" ]:.4 f} ' )
Q9: Attention
机制在时间序列异常检测中的应用?
异常检测的 Attention 模式 :
异常点通常表现出: 1.
注意力权重异常 :与其他时间步的关联度低 2.
重构误差大 : Attention 无法很好地重构异常点 3.
注意力分布异常 :注意力权重分布与正常模式不同
基于 Attention 的异常检测模型 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 class AttentionAnomalyDetector (nn.Module): """使用 Attention 进行异常检测""" def __init__ (self, input_size, d_model, nhead ): super ().__init__() self.encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer(d_model, nhead, batch_first=True ), num_layers=3 ) self.decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model, nhead, batch_first=True ), num_layers=3 ) self.input_proj = nn.Linear(input_size, d_model) self.output_proj = nn.Linear(d_model, input_size) def forward (self, x ): x_proj = self.input_proj(x) encoded = self.encoder(x_proj) decoded = self.decoder(x_proj, encoded) reconstructed = self.output_proj(decoded) reconstruction_error = torch.mean((x - reconstructed) ** 2 , dim=-1 ) return reconstructed, reconstruction_error, encoded model = AttentionAnomalyDetector(input_size=10 , d_model=128 , nhead=8 ) criterion = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters()) for epoch in range (100 ): for batch in train_loader: reconstructed, error, _ = model(batch) loss = criterion(reconstructed, batch) loss.backward() optimizer.step()
异常评分 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 def compute_anomaly_score (model, x, threshold_percentile=95 ): """计算异常分数""" model.eval () with torch.no_grad(): reconstructed, reconstruction_error, encoded = model(x) error_scores = reconstruction_error.mean(dim=1 ) threshold = np.percentile(error_scores.numpy(), threshold_percentile) anomalies_by_error = error_scores > threshold attention_entropy = compute_attention_entropy(model, x) entropy_threshold = np.percentile(attention_entropy.numpy(), threshold_percentile) anomalies_by_attention = attention_entropy < entropy_threshold final_anomalies = anomalies_by_error | anomalies_by_attention return { 'anomalies' : final_anomalies, 'error_scores' : error_scores, 'attention_entropy' : attention_entropy } def compute_attention_entropy (model, x ): """计算注意力权重的熵(低熵 = 异常)""" attn_weights = model.get_attention_weights(x) entropy = -torch.sum ( attn_weights * torch.log(attn_weights + 1e-10 ), dim=-1 ).mean(dim=(1 , 2 )) return entropy
可视化异常检测结果 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 def visualize_anomalies (data, anomalies, attention_weights ): """可视化异常检测结果""" fig, axes = plt.subplots(3 , 1 , figsize=(15 , 10 )) axes[0 ].plot(data, label='原始数据' ) anomaly_indices = np.where(anomalies)[0 ] axes[0 ].scatter(anomaly_indices, data[anomaly_indices], color='red' , s=100 , label='异常点' , zorder=5 ) axes[0 ].set_title('异常检测结果' ) axes[0 ].legend() axes[1 ].plot(reconstruction_error, label='重构误差' ) axes[1 ].axhline(y=threshold, color='r' , linestyle='--' , label='阈值' ) axes[1 ].set_title('重构误差' ) axes[1 ].legend() sns.heatmap(attention_weights[0 , 0 ].detach().numpy(), ax=axes[2 ], cmap='YlOrRd' ) axes[2 ].set_title('注意力权重(异常区域通常权重异常)' ) plt.tight_layout() plt.show()
Q10:如何优化 Attention
机制的计算效率?
1. Flash Attention(内存高效) :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 class FlashAttention (nn.Module): """内存高效的注意力计算""" def __init__ (self, d_model, block_size=64 ): super ().__init__() self.d_model = d_model self.block_size = block_size def forward (self, Q, K, V ): batch_size, seq_len, _ = Q.size() output = torch.zeros_like(V) for i in range (0 , seq_len, self.block_size): Q_block = Q[:, i:i+self.block_size, :] scores_block = torch.matmul(Q_block, K.transpose(-2 , -1 )) attn_block = torch.softmax(scores_block, dim=-1 ) output[:, i:i+self.block_size, :] = torch.matmul(attn_block, V) return output
2. 低秩近似( Low-Rank Approximation) :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 class LowRankAttention (nn.Module): """使用低秩矩阵近似注意力""" def __init__ (self, d_model, rank=32 ): super ().__init__() self.rank = rank self.W_q = nn.Linear(d_model, rank) self.W_k = nn.Linear(d_model, rank) self.W_v = nn.Linear(d_model, d_model) def forward (self, Q, K, V ): Q_low = self.W_q(Q) K_low = self.W_k(K) attn_low = torch.matmul(Q_low, K_low.transpose(-2 , -1 )) attn_weights = torch.softmax(attn_low, dim=-1 ) output = torch.matmul(attn_weights, self.W_v(V)) return output
3. 局部敏感哈希( LSH) Attention :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 class LSHAttention (nn.Module): """使用 LSH 快速找到相似的 Query-Key 对""" def __init__ (self, d_model, num_hashes=4 , bucket_size=64 ): super ().__init__() self.num_hashes = num_hashes self.bucket_size = bucket_size self.random_rotations = nn.Parameter( torch.randn(num_hashes, d_model, d_model) ) def hash_vectors (self, x ): """将向量哈希到桶中""" rotated = torch.matmul(x, self.random_rotations) hashes = rotated.argmax(dim=-1 ) return hashes def forward (self, Q, K, V ): Q_hashes = self.hash_vectors(Q) K_hashes = self.hash_vectors(K) output = self.sparse_attention(Q, K, V, Q_hashes, K_hashes) return output
4. 性能优化技巧总结 :
方法
原理
加速比
适用场景
Flash Attention
分块计算,减少内存
2-4x
长序列(>1000)
稀疏注意力
只计算局部窗口
5-10x
局部依赖强
线性注意力
改变计算顺序
3-5x
中等长度序列
低秩近似
降维投影
2-3x
特征维度大
LSH Attention
哈希快速匹配
10-20x
超长序列(>10000)
实际应用建议 :
1 2 3 4 5 6 7 8 9 10 def get_optimal_attention (seq_len, d_model ): if seq_len < 100 : return StandardAttention(d_model) elif seq_len < 1000 : return SparseAttention(d_model, window_size=100 ) elif seq_len < 10000 : return LinearAttention(d_model) else : return LSHAttention(d_model)
实战技巧与性能优化
Attention 机制的超参数调优
1. 注意力头数( Num Heads)选择
模型维度
推荐头数
原因
d_model = 128
4-8 头
平衡表达能力和计算成本
d_model = 256
8-16 头
标准配置
d_model = 512
8-16 头
Transformer 标准
d_model = 1024
16-32 头
大模型需要更多表示子空间
选择原则 :
必须能被头数整除
每个头的维度 通常设为 64 或 128
头数过多可能导致过拟合,头数过少表达能力不足
1 2 3 4 5 6 7 8 9 10 11 def get_optimal_heads (d_model ): """根据模型维度选择最优头数""" possible_heads = [4 , 8 , 16 , 32 ] for n_heads in possible_heads: if d_model % n_heads == 0 : d_k = d_model // n_heads if 32 <= d_k <= 128 : return n_heads return 8
2. 缩放因子
的重要性
为什么需要缩放?防止点积值过大导致 softmax 梯度消失:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 def demonstrate_scaling_importance (): """演示缩放的重要性""" d_k = 64 Q = torch.randn(10 , d_k) K = torch.randn(10 , d_k) scores_unscaled = torch.matmul(Q, K.t()) print (f"未缩放分数范围: [{scores_unscaled.min ():.2 f} , {scores_unscaled.max ():.2 f} ]" ) scores_scaled = scores_unscaled / np.sqrt(d_k) print (f"缩放后分数范围: [{scores_scaled.min ():.2 f} , {scores_scaled.max ():.2 f} ]" ) probs_unscaled = F.softmax(scores_unscaled, dim=-1 ) probs_scaled = F.softmax(scores_scaled, dim=-1 ) print (f"未缩放 softmax 熵: {-torch.sum (probs_unscaled * torch.log(probs_unscaled + 1e-10 )):.4 f} " ) print (f"缩放后 softmax 熵: {-torch.sum (probs_scaled * torch.log(probs_scaled + 1e-10 )):.4 f} " )
3. Dropout 在 Attention 中的应用
1 2 3 4 5 6 7 8 9 10 11 12 class AttentionWithDropout (nn.Module): """带 Dropout 的注意力机制""" def __init__ (self, d_model, nhead, dropout=0.1 ): super ().__init__() self.attention = nn.MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=True ) def forward (self, x ): out, attn_weights = self.attention(x, x, x) return out, attn_weights
Dropout 选择指南 :
小数据集: 0.2-0.3
中等数据集: 0.1-0.2
大数据集: 0.05-0.1
时间序列 Attention
的特殊优化
1. 因果掩码的高效实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 def create_causal_mask_efficient (seq_len, device ): """高效创建因果掩码(避免存储完整矩阵)""" mask = torch.tril(torch.ones(seq_len, seq_len, device=device)) return mask.bool () def causal_attention (Q, K, V ): """直接应用因果掩码的注意力""" seq_len = Q.size(1 ) scores = torch.matmul(Q, K.transpose(-2 , -1 )) / np.sqrt(Q.size(-1 )) mask = torch.tril(torch.ones(seq_len, seq_len, device=Q.device)) scores = scores.masked_fill(mask == 0 , float ('-inf' )) attn_weights = F.softmax(scores, dim=-1 ) output = torch.matmul(attn_weights, V) return output, attn_weights
2. 局部注意力( Local Attention)
对于长序列,可以使用局部注意力减少计算量:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 class LocalAttention (nn.Module): """局部注意力:只关注窗口内的位置""" def __init__ (self, d_model, nhead, window_size=50 ): super ().__init__() self.window_size = window_size self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True ) def forward (self, x ): batch_size, seq_len, d_model = x.size() if seq_len <= self.window_size: return self.attention(x, x, x) outputs = [] for i in range (0 , seq_len, self.window_size): end = min (i + self.window_size, seq_len) window_x = x[:, i:end, :] window_out, _ = self.attention(window_x, window_x, window_x) outputs.append(window_out) return torch.cat(outputs, dim=1 ), None
3. 稀疏注意力模式
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 class SparseAttentionPatterns : """常见的稀疏注意力模式""" @staticmethod def strided_pattern (seq_len, stride=2 ): """步长模式:每隔 stride 个位置关注一次""" mask = torch.zeros(seq_len, seq_len) for i in range (seq_len): for j in range (0 , i+1 , stride): mask[i, j] = 1 return mask.bool () @staticmethod def dilated_pattern (seq_len, dilation=2 ): """膨胀模式:关注 dilation 倍数的位置""" mask = torch.zeros(seq_len, seq_len) for i in range (seq_len): for j in range (0 , i+1 ): if j % dilation == 0 or j == i: mask[i, j] = 1 return mask.bool () @staticmethod def random_pattern (seq_len, sparsity=0.5 ): """随机模式:随机选择 sparsity 比例的位置""" mask = torch.rand(seq_len, seq_len) > sparsity mask = mask & torch.tril(torch.ones(seq_len, seq_len)).bool () return mask
常见问题排查
问题 1: Attention 权重过于均匀(没有聚焦)
可能原因:
解决方案: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def init_attention_weights (m ): if isinstance (m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None : nn.init.constant_(m.bias, 0 ) model.apply(init_attention_weights) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3 ) from scipy import signaldenoised_data = signal.savgol_filter(data, window_length=5 , polyorder=2 )
问题 2: Attention 计算 OOM(内存不足)
解决方案: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 from torch.utils.checkpoint import checkpointclass MemoryEfficientAttention (nn.Module): def forward (self, x ): return checkpoint(self.attention, x, x, x) def chunked_attention (Q, K, V, chunk_size=32 ): """分块计算注意力,减少内存占用""" batch_size, seq_len, d_model = Q.size() outputs = [] for i in range (0 , seq_len, chunk_size): Q_chunk = Q[:, i:i+chunk_size, :] scores = torch.matmul(Q_chunk, K.transpose(-2 , -1 )) / np.sqrt(d_model) attn_weights = F.softmax(scores, dim=-1 ) output_chunk = torch.matmul(attn_weights, V) outputs.append(output_chunk) return torch.cat(outputs, dim=1 )
问题 3:训练不稳定( Loss 震荡)
可能原因:
解决方案: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0 ) def get_warmup_lr (step, warmup_steps, d_model ): if step < warmup_steps: return step / warmup_steps else : return (d_model ** -0.5 ) * min (step ** -0.5 , step * (warmup_steps ** -1.5 )) def monitor_attention_distribution (attn_weights ): """监控注意力权重分布""" entropy = -torch.sum (attn_weights * torch.log(attn_weights + 1e-10 ), dim=-1 ) print (f"注意力熵: {entropy.mean():.4 f} ± {entropy.std():.4 f} " ) if entropy.mean() < 1.0 : print ("警告:注意力权重过于集中!" )
Attention 机制的性能基准测试
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 import timeimport torchdef benchmark_attention (attention_module, x, n_iter=100 ): """性能基准测试""" attention_module.eval () with torch.no_grad(): _ = attention_module(x) torch.cuda.synchronize() if torch.cuda.is_available() else None start = time.time() with torch.no_grad(): for _ in range (n_iter): _ = attention_module(x) torch.cuda.synchronize() if torch.cuda.is_available() else None elapsed = time.time() - start batch_size, seq_len, d_model = x.size() flops = n_iter * batch_size * seq_len * seq_len * d_model * 2 gflops = flops / 1e9 / elapsed return elapsed / n_iter * 1000 , gflops x = torch.randn(32 , 100 , 128 ) configs = [ ("标准 Attention" , nn.MultiheadAttention(128 , 8 , batch_first=True )), ("稀疏 Attention" , LocalAttention(128 , 8 , window_size=50 )), ("线性 Attention" , LinearAttention(128 )), ] for name, model in configs: time_ms, gflops = benchmark_attention(model, x) print (f"{name} : {time_ms:.2 f} ms, {gflops:.2 f} GFLOPS" )
Attention 机制的可视化工具
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 import matplotlib.pyplot as pltimport seaborn as snsdef visualize_attention_patterns (attn_weights, save_path=None ): """可视化注意力模式""" fig, axes = plt.subplots(2 , 4 , figsize=(20 , 10 )) axes = axes.flatten() for head_idx in range (min (8 , attn_weights.size(1 ))): attn = attn_weights[0 , head_idx].detach().cpu().numpy() sns.heatmap(attn, ax=axes[head_idx], cmap='YlOrRd' , cbar=True ) axes[head_idx].set_title(f'Head {head_idx} ' ) axes[head_idx].set_xlabel('Key Position' ) axes[head_idx].set_ylabel('Query Position' ) plt.tight_layout() if save_path: plt.savefig(save_path) plt.show() def analyze_attention_statistics (attn_weights ): """分析注意力统计信息""" avg_attention = attn_weights.mean(dim=(0 , 1 )) top_k_positions = torch.topk(avg_attention.sum (dim=0 ), k=10 ) print ("最受关注的 10 个位置:" ) for idx, score in zip (top_k_positions.indices, top_k_positions.values): print (f"位置 {idx} : {score:.4 f} " ) entropy = -torch.sum (attn_weights * torch.log(attn_weights + 1e-10 ), dim=-1 ) print (f"\n 注意力熵: {entropy.mean():.4 f} ± {entropy.std():.4 f} " ) return avg_attention, top_k_positions
Attention 可视化与解释性分析
注意力权重可视化工具
1. 热力图可视化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 import matplotlib.pyplot as pltimport seaborn as snsimport numpy as npdef plot_attention_heatmap (attention_weights, tokens=None , head_idx=0 , save_path=None , figsize=(12 , 10 ) ): """ 绘制注意力权重热力图 Parameters: ----------- attention_weights : torch.Tensor 形状为 [batch, nhead, seq_len_q, seq_len_k] 的注意力权重 tokens : list, optional Token 列表,用于标注坐标轴 head_idx : int 要可视化的头索引 save_path : str, optional 保存路径 """ attn = attention_weights[0 , head_idx].detach().cpu().numpy() plt.figure(figsize=figsize) sns.heatmap( attn, cmap='YlOrRd' , xticklabels=tokens if tokens else False , yticklabels=tokens if tokens else False , annot=False , fmt='.2f' , cbar_kws={'label' : '注意力权重' } ) plt.title(f'Attention Head {head_idx} 权重分布' ) plt.xlabel('Key 位置' ) plt.ylabel('Query 位置' ) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300 , bbox_inches='tight' ) plt.show()
2. 多头注意力对比可视化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 def plot_multi_head_attention (attention_weights, tokens=None , n_heads_to_show=8 , figsize=(20 , 15 ) ): """ 可视化多个头的注意力模式 Parameters: ----------- attention_weights : torch.Tensor 形状为 [batch, nhead, seq_len_q, seq_len_k] tokens : list, optional Token 列表 n_heads_to_show : int 要显示的头数 """ n_heads = min (attention_weights.size(1 ), n_heads_to_show) n_cols = 4 n_rows = (n_heads + n_cols - 1 ) // n_cols fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) axes = axes.flatten() if n_heads > 1 else [axes] for head_idx in range (n_heads): attn = attention_weights[0 , head_idx].detach().cpu().numpy() sns.heatmap( attn, ax=axes[head_idx], cmap='YlOrRd' , xticklabels=False , yticklabels=False , cbar=True if head_idx == 0 else False ) axes[head_idx].set_title(f'Head {head_idx} ' ) axes[head_idx].set_xlabel('Key 位置' ) axes[head_idx].set_ylabel('Query 位置' ) for idx in range (n_heads, len (axes)): axes[idx].axis('off' ) plt.tight_layout() plt.show()
3. 注意力权重统计分析
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 def analyze_attention_statistics (attention_weights ): """ 分析注意力权重的统计特性 Returns: -------- dict: 包含各种统计信息的字典 """ attn = attention_weights.detach().cpu().numpy() batch_size, n_heads, seq_len_q, seq_len_k = attn.shape stats = { 'entropy' : [], 'max_weight' : [], 'sparsity' : [], 'local_focus' : [] } for head_idx in range (n_heads): head_attn = attn[0 , head_idx] entropy_per_query = [] for q_idx in range (seq_len_q): dist = head_attn[q_idx, :] entropy = -np.sum (dist * np.log(dist + 1e-10 )) entropy_per_query.append(entropy) stats['entropy' ].append(np.mean(entropy_per_query)) stats['max_weight' ].append(head_attn.max ()) stats['sparsity' ].append(np.mean(head_attn < 0.01 )) local_mask = np.abs (np.arange(seq_len_q)[:, None ] - np.arange(seq_len_k)[None , :]) <= 3 stats['local_focus' ].append(np.mean(head_attn[local_mask])) return stats def plot_attention_statistics (stats ): """可视化注意力统计信息""" fig, axes = plt.subplots(2 , 2 , figsize=(15 , 10 )) n_heads = len (stats['entropy' ]) axes[0 , 0 ].bar(range (n_heads), stats['entropy' ]) axes[0 , 0 ].set_title('注意力熵(越高越分散)' ) axes[0 , 0 ].set_xlabel('头索引' ) axes[0 , 0 ].set_ylabel('平均熵' ) axes[0 , 0 ].axhline(y=np.log(n_heads), color='r' , linestyle='--' , label='最大熵(均匀分布)' ) axes[0 , 0 ].legend() axes[0 , 1 ].bar(range (n_heads), stats['max_weight' ]) axes[0 , 1 ].set_title('最大注意力权重' ) axes[0 , 1 ].set_xlabel('头索引' ) axes[0 , 1 ].set_ylabel('最大权重值' ) axes[1 , 0 ].bar(range (n_heads), stats['sparsity' ]) axes[1 , 0 ].set_title('注意力稀疏度(权重<0.01 的比例)' ) axes[1 , 0 ].set_xlabel('头索引' ) axes[1 , 0 ].set_ylabel('稀疏度' ) axes[1 , 1 ].bar(range (n_heads), stats['local_focus' ]) axes[1 , 1 ].set_title('局部关注度(相邻 3 个位置)' ) axes[1 , 1 ].set_xlabel('头索引' ) axes[1 , 1 ].set_ylabel('局部权重比例' ) plt.tight_layout() plt.show()
4. 注意力权重模式分类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 def classify_attention_patterns (attention_weights, threshold=0.1 ): """ 自动分类注意力模式 Returns: -------- dict: 每个头的模式分类 """ attn = attention_weights.detach().cpu().numpy() n_heads = attn.shape[1 ] patterns = { 'local' : [], 'global' : [], 'specific' : [], 'periodic' : [] } for head_idx in range (n_heads): head_attn = attn[0 , head_idx] seq_len = head_attn.shape[0 ] local_mask = np.abs (np.arange(seq_len)[:, None ] - np.arange(seq_len)[None , :]) <= 5 local_ratio = np.mean(head_attn[local_mask]) entropy = -np.sum (head_attn * np.log(head_attn + 1e-10 )) max_entropy = np.log(seq_len) uniformity = entropy / max_entropy max_weight = head_attn.max () if local_ratio > 0.5 : patterns['local' ].append(head_idx) elif uniformity > 0.8 : patterns['global' ].append(head_idx) elif max_weight > 0.7 : patterns['specific' ].append(head_idx) else : patterns['periodic' ].append(head_idx) return patterns
5. 时间序列注意力可视化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 def plot_timeseries_attention (attention_weights, timestamps, values=None , head_idx=0 , top_k=5 ): """ 可视化时间序列的注意力权重 Parameters: ----------- attention_weights : torch.Tensor 注意力权重 [batch, nhead, seq_len_q, seq_len_k] timestamps : array-like 时间戳列表 values : array-like, optional 时间序列值(用于叠加显示) head_idx : int 要可视化的头索引 top_k : int 显示 top-k 个最受关注的时间步 """ attn = attention_weights[0 , head_idx].detach().cpu().numpy() seq_len = attn.shape[0 ] attention_sums = attn.sum (axis=0 ) top_indices = np.argsort(attention_sums)[-top_k:][::-1 ] fig, axes = plt.subplots(2 , 1 , figsize=(15 , 10 )) if values is not None : axes[0 ].plot(timestamps, values, 'b-' , label='时间序列' , linewidth=2 ) axes[0 ].scatter( timestamps[top_indices], values[top_indices], color='red' , s=200 , zorder=5 , label=f'Top-{top_k} 关注点' ) axes[0 ].set_title('时间序列与注意力关注点' ) axes[0 ].set_xlabel('时间' ) axes[0 ].set_ylabel('值' ) axes[0 ].legend() axes[0 ].grid(True , alpha=0.3 ) im = axes[1 ].imshow(attn, cmap='YlOrRd' , aspect='auto' ) axes[1 ].set_title(f'Attention Head {head_idx} 权重热力图' ) axes[1 ].set_xlabel('Key 时间步' ) axes[1 ].set_ylabel('Query 时间步' ) plt.colorbar(im, ax=axes[1 ], label='注意力权重' ) for idx in top_indices: axes[1 ].axvline(x=idx, color='blue' , linestyle='--' , alpha=0.5 ) plt.tight_layout() plt.show() print (f"Top-{top_k} 最受关注的时间步:" ) for i, idx in enumerate (top_indices): print (f" {i+1 } . 时间步 {idx} (权重: {attention_sums[idx]:.4 f} )" )
注意力解释性分析
1. 特征重要性分析
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 def compute_feature_importance (attention_weights, input_features ): """ 计算每个输入特征对预测的贡献 Parameters: ----------- attention_weights : torch.Tensor 注意力权重 [batch, nhead, seq_len_q, seq_len_k] input_features : torch.Tensor 输入特征 [batch, seq_len, d_model] Returns: -------- importance : torch.Tensor 特征重要性 [batch, seq_len] """ last_query_attention = attention_weights[:, :, -1 , :].mean(dim=1 ) importance = last_query_attention.unsqueeze(-1 ) * input_features importance = importance.abs ().sum (dim=-1 ) return importance
2. 注意力引导的特征选择
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def attention_based_feature_selection (attention_weights, top_k=10 ): """ 基于注意力权重选择最重要的特征 Returns: -------- selected_indices : list 选中的特征索引 """ avg_attention = attention_weights.mean(dim=(0 , 1 , 2 )) top_indices = torch.topk(avg_attention, k=top_k).indices.tolist() return top_indices
🎓 总结: Attention 核心要点
自注意力计算流程 : 线 性 变 换 缩 放 点 积 归 一 化 加 权 求 和
记忆口诀 : > Q 问 K 答计算分数,缩放
softmax 归一权重,权重乘 V 得到输出,多头并行捕捉特征!
实战优化清单 :