时间序列模型(四)—— Attention 机制
Chen Kai BOSS

在时间序列里,很多关键信息并不在“最近一步”:可能是周期中的某个相位、某个突发后的回落,或是跨很长间隔的相似模式。 Attention 的好处是它不需要按时间一步步把信息“传”过来,而是直接学会“该看历史里的哪几段、看多少权重”,从而更擅长处理长距离依赖与不规则相关性。本文会把自注意力的计算流程按公式拆开(、缩放点积、 softmax 权重、加权求和),并结合代码层面的实现细节说明:这些矩阵运算到底在做什么、复杂度与序列长度的关系是什么,以及在时间序列任务里如何组织输入、如何解释注意力权重带来的可解释性。

数学原理

自注意力机制通过计算输入序列中每个位置与其他位置之间的相似度来生成新的表示。具体步骤如下:

输入表示:假设输入序列为 ,每个 是一个向量。

线性变换:通过学习的权重矩阵 将输入序列 转换为查询( Query)、键( Key)和值( Value)向量:

$$

Q = XW^Q, K = XW^K, V = XW^V $$

计算注意力得分:通过点积计算查询和键之间的相似度,并使用缩放因子 进行缩放:

归一化注意力得分:使用 softmax 函数对注意力得分进行归一化,得到注意力权重:

加权求和:将注意力权重应用于值向量,得到最终的注意力输出:

代码实现

缩放点积注意力:自注意力机制的核心计算

问题背景:传统 RNN/LSTM 通过递归传递隐藏状态处理序列,存在两个问题: 1)长距离依赖难以捕捉(梯度消失), 2)每个时间步只能看到之前的信息,无法并行计算。自注意力机制通过"查询-键-值"( Q/K/V)框架,让每个位置直接关注序列中所有位置,从而解决长距离依赖问题,且支持并行计算。

解决思路:自注意力的核心是"相似度加权"——对于每个查询位置 ,计算它与所有键位置的相似度(点积),通过 softmax 归一化为注意力权重,然后用这些权重对值向量加权求和。缩放因子 防止点积值过大导致 softmax 梯度消失。整个过程可以表示为:

设计考虑

  1. Q/K/V 的含义: Query(查询)表示"我想找什么", Key(键)表示"我是什么", Value(值)表示"我的内容"。自注意力中 Q=K=V(都是输入序列的线性变换)
  2. 缩放的必要性:点积 的方差随 增长,导致 softmax 饱和(梯度消失)。除以 保持方差为 1
  3. 掩码机制:通过 mask 屏蔽无效位置(如 padding 、未来信息),设置为 使得 softmax 后权重接近 0
  4. 计算复杂度,其中 是序列长度, 是特征维度。对于长序列,需要使用稀疏注意力或线性注意力

以下是一个简单的自注意力机制的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
"""
缩放点积注意力( Scaled Dot-Product Attention):自注意力机制的核心计算

核心思想:通过查询-键相似度计算注意力权重,对值向量加权求和
数学表达: Attention(Q, K, V) = softmax(QK^T / √ d_k) · V
其中:
- Q:查询矩阵 (batch_size, seq_len_q, d_k)
- K:键矩阵 (batch_size, seq_len_k, d_k)
- V:值矩阵 (batch_size, seq_len_k, d_v)
- d_k:键/查询的维度(缩放因子)

计算流程:
1. 计算相似度得分: scores = Q · K^T / √ d_k
2. 应用掩码(可选): scores[mask==0] = -∞
3. 归一化权重: weights = softmax(scores)
4. 加权求和: output = weights · V

Parameters:
-----------
Q : numpy.ndarray, shape (batch_size, seq_len_q, d_k)
查询矩阵( Query)
- 每个位置表示"我想关注什么"
- 自注意力中: Q = X · W_Q(输入序列的线性变换)
- 维度: d_k 是查询/键的维度(通常 64 、 128 、 256 等)
K : numpy.ndarray, shape (batch_size, seq_len_k, d_k)
键矩阵( Key)
- 每个位置表示"我是谁"(用于匹配查询)
- 自注意力中: K = X · W_K(输入序列的线性变换)
- 维度: d_k 必须与 Q 的最后一维相同
V : numpy.ndarray, shape (batch_size, seq_len_k, d_v)
值矩阵( Value)
- 每个位置表示"我的内容"(实际被加权的信息)
- 自注意力中: V = X · W_V(输入序列的线性变换)
- 维度: d_v 可以与 d_k 不同(通常 d_v = d_k)
mask : numpy.ndarray, optional, shape (batch_size, seq_len_q, seq_len_k)
掩码矩阵,用于屏蔽无效位置
- mask[i, j] = 1:位置 i 可以关注位置 j
- mask[i, j] = 0:位置 i 不能关注位置 j(会被设为-∞)
- 常见用途:
* Padding 掩码:屏蔽 padding 位置
* Causal 掩码:屏蔽未来信息(解码器自注意力)
* 自定义掩码:屏蔽特定位置

Returns:
--------
output : numpy.ndarray, shape (batch_size, seq_len_q, d_v)
注意力输出:每个查询位置的加权值向量
- output[i, j] = Σ(attention_weights[i, j, k] * V[i, k])
- 表示位置 j 的查询关注所有键位置后的加权结果
attention_weights : numpy.ndarray, shape (batch_size, seq_len_q, seq_len_k)
注意力权重矩阵(归一化后的相似度)
- attention_weights[i, j, k]:位置 j 的查询对位置 k 的键的注意力权重
- 每行和为 1( softmax 归一化)
- 可用于可视化:哪些位置被关注最多

Notes:
------
- 缩放因子√ d_k 防止点积值过大导致 softmax 饱和(梯度消失)
- 复杂度: O(seq_len_q × seq_len_k × d_k),对于长序列可能成为瓶颈
- 自注意力中通常 seq_len_q = seq_len_k = n(序列长度)
- 掩码位置设为-1e9 而非-∞,因为 softmax(-1e9) ≈ 0(数值稳定)

Example:
--------
>>> # 自注意力: Q=K=V(都是输入序列的变换)
>>> X = np.random.rand(1, 10, 64) # (batch=1, seq_len=10, d_model=64)
>>> W_Q, W_K, W_V = np.random.rand(64, 64), np.random.rand(64, 64), np.random.rand(64, 64)
>>> Q, K, V = X @ W_Q, X @ W_K, X @ W_V
>>> output, weights = scaled_dot_product_attention(Q, K, V)
>>> # output 形状:(1, 10, 64), weights 形状:(1, 10, 10)
"""
# 步骤 1:获取键/查询的维度 d_k(用于缩放)
# d_k 是 Q 和 K 的最后一个维度,表示查询和键的特征维度
d_k = Q.shape[-1]

# 步骤 2:计算查询-键相似度得分
# Q · K^T:计算每个查询位置与每个键位置的相似度(点积)
# 形状:(batch_size, seq_len_q, d_k) @ (batch_size, d_k, seq_len_k)
# → (batch_size, seq_len_q, seq_len_k)
# 含义: scores[i, j, k] = Q[i, j] · K[i, k](位置 j 的查询与位置 k 的键的相似度)
scores = np.matmul(Q, K.transpose(-2, -1))

# 步骤 3:缩放(关键步骤!)
# 除以√ d_k:防止点积值过大导致 softmax 饱和
# 原因:点积 Q · K^T 的方差与 d_k 成正比,当 d_k 大时,点积值可能很大
# 这会导致 softmax 的输入很大, softmax 输出接近 one-hot(梯度消失)
# 缩放后:方差保持为 1, softmax 输入在合理范围内
scores = scores / np.sqrt(d_k)

# 步骤 4:应用掩码(如果提供)
# 掩码用于屏蔽无效位置(如 padding 、未来信息)
if mask is not None:
# 将掩码为 0 的位置设为-1e9(非常大的负数)
# softmax(-1e9) ≈ 0,使得这些位置的注意力权重接近 0
# 注意: numpy 没有 masked_fill,这里需要手动实现
scores = np.where(mask == 0, -1e9, scores)

# 步骤 5:归一化注意力权重( softmax)
# softmax(scores, axis=-1):对最后一个维度(键维度)归一化
# 形状:(batch_size, seq_len_q, seq_len_k)
# 含义: attention_weights[i, j, :]表示位置 j 的查询对所有键位置的注意力权重(和为 1)
attention_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True)) # 数值稳定
attention_weights = attention_weights / np.sum(attention_weights, axis=-1, keepdims=True)

# 步骤 6:加权求和(应用注意力权重到值向量)
# attention_weights @ V:对值向量加权求和
# 形状:(batch_size, seq_len_q, seq_len_k) @ (batch_size, seq_len_k, d_v)
# → (batch_size, seq_len_q, d_v)
# 含义: output[i, j] = Σ(attention_weights[i, j, k] * V[i, k])
# 表示位置 j 的查询关注所有键位置后的加权值向量
output = np.matmul(attention_weights, V)

return output, attention_weights

# 使用示例:自注意力机制
# 生成示例输入(模拟时间序列数据)
batch_size = 1
seq_len = 10 # 序列长度(时间步数)
d_k = 64 # 查询/键的维度
d_v = 64 # 值的维度(通常 d_v = d_k)

# 模拟输入序列的 Q/K/V(实际应用中,这些是通过线性变换从输入 X 得到的)
# 形状:(batch_size, seq_len, d_k/d_v)
Q = np.random.rand(batch_size, seq_len, d_k) # 查询矩阵
K = np.random.rand(batch_size, seq_len, d_k) # 键矩阵
V = np.random.rand(batch_size, seq_len, d_v) # 值矩阵

# 计算自注意力
output, attention_weights = scaled_dot_product_attention(Q, K, V)

print(f"输入形状: Q={Q.shape}, K={K.shape}, V={V.shape}")
print(f"输出形状: output={output.shape}, attention_weights={attention_weights.shape}")
print(f"注意力权重每行和(应为 1): {attention_weights.sum(axis=-1)}")

# 可视化注意力权重(查看哪些位置被关注最多)
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 6))
plt.imshow(attention_weights[0], cmap='viridis', aspect='auto')
plt.colorbar(label='注意力权重')
plt.xlabel('键位置(被关注的位置)')
plt.ylabel('查询位置(关注的位置)')
plt.title('自注意力权重矩阵可视化')
plt.show()

关键点解读

  1. 缩放因子的重要性:缩放因子 看似简单,但至关重要。当 较大(如 512)时,点积 的值可能很大(如 100),导致 softmax 输入很大, softmax 输出接近 one-hot 分布(几乎只关注一个位置),梯度接近 0 。除以 后,点积的方差保持为 1, softmax 输入在合理范围(如[-2, 2]),梯度正常流动。这是 Transformer 成功的关键设计之一。

  2. Q/K/V 的语义解释:在自注意力中, Q/K/V 都来自同一输入序列 的线性变换,但语义不同。 Query 表示"我想关注什么特征", Key 表示"我提供什么特征用于匹配", Value 表示"我的实际内容"。注意力权重 表示"位置 的查询对位置 的键的匹配程度",然后用这个权重对值向量加权:。这允许每个位置直接关注序列中所有位置,无需递归传递。

  3. 掩码机制的应用:掩码在 Transformer 中至关重要。 Padding 掩码屏蔽 padding 位置(避免关注无效信息), Causal 掩码屏蔽未来信息(解码器自注意力中,位置 只能关注位置)。掩码通过将无效位置的得分设为 实现, softmax 后这些位置的权重接近 0,不影响计算。

常见问题

问题 解答
Q: 为什么需要缩放因子√ d_k? A: 防止点积值过大导致 softmax 饱和。点积 的方差与 成正比,当 大时,点积值可能很大,导致 softmax 输出接近 one-hot(梯度消失)。除以 保持方差为 1 。
Q: Q/K/V 必须来自同一输入吗? A: 不一定。自注意力中 Q=K=V(都来自输入 X),但交叉注意力( Cross-Attention)中 Q 来自解码器, K/V 来自编码器。这允许解码器关注编码器的所有位置。
Q: 注意力机制的复杂度是多少? A: ,其中 是序列长度, 是特征维度。对于长序列(如$ n=10000线O(n^2)$ 复杂度、位置信息需要额外编码(位置编码)。

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# 实际应用: PyTorch 实现(带梯度支持)
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
"""PyTorch 版本的缩放点积注意力"""
def __init__(self, d_k, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_k = d_k
self.dropout = nn.Dropout(dropout) # 注意力 dropout(正则化)

def forward(self, Q, K, V, mask=None):
"""
Q, K, V: (batch_size, seq_len, d_k/d_v)
mask: (batch_size, seq_len_q, seq_len_k) 或 (batch_size, 1, seq_len_k)
"""
# 计算相似度得分
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)

# 应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)

# Softmax 归一化
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights) # Dropout 正则化

# 加权求和
output = torch.matmul(attention_weights, V)

return output, attention_weights

# 使用示例:时间序列预测
batch_size = 32
seq_len = 100
d_model = 128

# 模拟时间序列数据
X = torch.randn(batch_size, seq_len, d_model) # 输入序列

# 线性变换得到 Q/K/V
W_Q = nn.Linear(d_model, d_model)
W_K = nn.Linear(d_model, d_model)
W_V = nn.Linear(d_model, d_model)

Q = W_Q(X) # (32, 100, 128)
K = W_K(X) # (32, 100, 128)
V = W_V(X) # (32, 100, 128)

# 计算自注意力
attention = ScaledDotProductAttention(d_k=d_model)
output, weights = attention(Q, K, V)

print(f"输入形状: {X.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")

# 可视化第一个样本的注意力权重
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 8))
plt.imshow(weights[0].detach().numpy(), cmap='viridis', aspect='auto')
plt.colorbar(label='注意力权重')
plt.xlabel('键位置')
plt.ylabel('查询位置')
plt.title('自注意力权重矩阵(第一个样本)')
plt.show()

Seq2Seq with Attention

数学原理

带有注意力机制的 Seq2Seq 模型通过动态调整解码器对编码器隐藏状态的关注来提高模型性能。以下是其核心原理:

编码器:将输入序列 通过 RNN(如 LSTM 或 GRU)处理,生成隐藏状态序列

注意力权重:在解码器的每个时间步 ,计算解码器隐藏状态 与编码器隐藏状态 之间的相似度,得到注意力权重

其中,,通常采用点积、双线性或 MLP 作为得分函数。

上下文向量:根据注意力权重对编码器隐藏状态加权求和,得到上下文向量 $ c_t$

c_t = {i=1}^{n} {t,i} h_i $$

解码器:将上下文向量 与解码器的输入和隐藏状态结合,生成当前时间步的输出。

代码实现

Seq2Seq with Attention:编码器-解码器架构的注意力增强

问题背景:传统 Seq2Seq 模型(编码器-解码器)存在信息瓶颈问题:编码器将整个输入序列压缩为固定长度的上下文向量(通常是最后一个隐藏状态),解码器只能基于这个向量生成输出。对于长序列,固定长度向量无法承载所有信息,导致性能下降。注意力机制通过让解码器在每个时间步动态关注编码器的所有位置,解决了信息瓶颈问题。

解决思路: Seq2Seq with Attention 的核心是"动态上下文向量"——解码器在每个时间步$ t s_t$ 与编码器所有隐藏状态 的相似度,得到注意力权重,然后加权求和得到上下文向量$ c_t = i {t,i} h_i$。这个上下文向量包含"当前时刻最需要关注的编码器信息",与解码器输入和隐藏状态结合生成输出。这样,解码器可以根据当前状态动态选择关注编码器的不同位置。

设计考虑: 1. 注意力得分函数:这里使用 MLP(多层感知机)计算得分:$ e_{t,i} = v^T (W [s_t; h_i])W线 v$ 是可学习向量。也可以使用点积、双线性等。 2. 上下文向量的使用:上下文向量 在两个地方使用: 1)与解码器输入拼接作为 LSTM 输入, 2)与解码器输出拼接后通过线性层生成最终输出。 3. Teacher Forcing vs Free Running:训练时使用 Teacher Forcing(使用真实目标序列),推理时使用 Free Running(使用模型自身输出)。这里实现的是训练模式。 4. 批处理支持:所有操作支持批处理,提高训练效率。

以下是一个带有注意力机制的 Seq2Seq 模型的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import torch
import torch.nn as nn
import torch.optim as optim

# 定义注意力机制的类( Bahdanau Attention)
class Attention(nn.Module):
"""
Bahdanau 注意力机制:用于 Seq2Seq 模型

核心思想:计算解码器隐藏状态与编码器所有隐藏状态的相似度,得到注意力权重
数学表达: e_{t,i} = v^T · tanh(W · [s_t; h_i])
α_{t,i} = softmax(e_{t,i})
c_t = Σ(α_{t,i} · h_i)

其中:
- s_t:解码器在时间 t 的隐藏状态(查询)
- h_i:编码器在位置 i 的隐藏状态(键和值)
- c_t:上下文向量(加权后的编码器信息)

Parameters:
-----------
hidden_dim : int
隐藏状态维度(编码器和解码器的隐藏维度必须相同)
"""
def __init__(self, hidden_dim):
super(Attention, self).__init__()
# 线性变换层:将[s_t; h_i](维度 2*hidden_dim)映射到 hidden_dim
# 输入:解码器隐藏状态 s_t 和编码器隐藏状态 h_i 的拼接
# 输出:能量分数( energy score)
self.attn = nn.Linear(hidden_dim * 2, hidden_dim)

# 可学习的参数向量 v:用于将能量分数转换为标量得分
# 维度: hidden_dim
# 作用: v^T · energy 得到标量得分,用于计算注意力权重
self.v = nn.Parameter(torch.rand(hidden_dim))

def forward(self, hidden, encoder_outputs):
"""
计算注意力权重

Parameters:
-----------
hidden : torch.Tensor, shape (batch_size, hidden_dim)
解码器当前时间步的隐藏状态(查询)
- 这是解码器 LSTM 的隐藏状态,表示"当前解码状态"
encoder_outputs : torch.Tensor, shape (batch_size, seq_len, hidden_dim)
编码器所有时间步的隐藏状态序列(键和值)
- encoder_outputs[:, i, :]:编码器在位置 i 的隐藏状态

Returns:
--------
attention_weights : torch.Tensor, shape (batch_size, seq_len)
注意力权重:每个编码器位置的权重(归一化后,每行和为 1)
- attention_weights[i, j]:样本 i 的解码器对编码器位置 j 的注意力权重
"""
# 获取编码器序列长度
timestep = encoder_outputs.size(1) # seq_len

# 扩展解码器隐藏状态:从(batch_size, hidden_dim)扩展到(batch_size, seq_len, hidden_dim)
# 目的:使 hidden 与 encoder_outputs 的每个位置配对
# hidden.repeat(timestep, 1, 1):在时间维度重复 seq_len 次
# transpose(0, 1):调整维度顺序,得到(batch_size, seq_len, hidden_dim)
h = hidden.repeat(timestep, 1, 1).transpose(0, 1)

# 计算能量分数: e_{t,i} = v^T · tanh(W · [s_t; h_i])
# torch.cat((h, encoder_outputs), 2):拼接解码器隐藏状态和编码器隐藏状态
# 形状:(batch_size, seq_len, 2*hidden_dim)
# self.attn(...):线性变换,输出(batch_size, seq_len, hidden_dim)
# torch.tanh(...):激活函数,输出(batch_size, seq_len, hidden_dim)
energy = torch.tanh(self.attn(torch.cat((h, encoder_outputs), 2)))

# 转置 energy:从(batch_size, seq_len, hidden_dim)到(batch_size, hidden_dim, seq_len)
# 目的:为后续的 batch 矩阵乘法准备
energy = energy.transpose(2, 1)

# 扩展参数向量 v:从(hidden_dim,)到(batch_size, 1, hidden_dim)
# 目的:与 energy 进行 batch 矩阵乘法
v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)

# 计算注意力得分: v^T · energy
# torch.bmm(v, energy): batch 矩阵乘法
# v: (batch_size, 1, hidden_dim)
# energy: (batch_size, hidden_dim, seq_len)
# 结果: (batch_size, 1, seq_len)
# squeeze(1):移除维度 1,得到(batch_size, seq_len)
attention_weights = torch.bmm(v, energy).squeeze(1)

# Softmax 归一化:将得分转换为概率分布(注意力权重)
# 每行的和为 1,表示对编码器所有位置的注意力权重分布
return torch.softmax(attention_weights, dim=1)

# 定义带有注意力机制的序列到序列模型
class Seq2SeqWithAttention(nn.Module):
"""
Seq2Seq 模型 + 注意力机制:用于序列到序列任务(如时间序列预测、机器翻译)

架构:
1. 编码器( Encoder): LSTM 处理输入序列,输出所有时间步的隐藏状态
2. 注意力机制( Attention):计算解码器隐藏状态与编码器隐藏状态的注意力权重
3. 解码器( Decoder): LSTM 生成输出序列,每个时间步使用注意力加权的上下文向量

数学流程:
- 编码: h_i = LSTM_encoder(x_i, h_{i-1}), i=1,...,n
- 解码(时间 t):
1. α_{t,i} = Attention(s_{t-1}, {h_i})
2. c_t = Σ(α_{t,i} · h_i)
3. s_t = LSTM_decoder([y_{t-1}; c_t], s_{t-1})
4. y_t = Linear([s_t; c_t])

Parameters:
-----------
input_dim : int
输入特征维度(编码器输入)
例如:时间序列的每个时间步的特征数
hidden_dim : int
隐藏状态维度(编码器和解码器共享)
典型值: 64, 128, 256, 512
output_dim : int
输出特征维度(解码器输出)
例如:预测序列的每个时间步的特征数
"""
def __init__(self, input_dim, hidden_dim, output_dim):
super(Seq2SeqWithAttention, self).__init__()
# 编码器: LSTM 处理输入序列
# 输入维度: input_dim
# 隐藏维度: hidden_dim
# batch_first=True:输入形状为(batch_size, seq_len, input_dim)
self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)

# 解码器: LSTM 生成输出序列
# 输入维度: hidden_dim + output_dim
# - hidden_dim:上下文向量维度
# - output_dim:上一时间步的输出维度(或目标序列维度)
# 隐藏维度: hidden_dim(与编码器相同)
self.decoder = nn.LSTM(hidden_dim + output_dim, hidden_dim, batch_first=True)

# 注意力机制:计算解码器对编码器的注意力权重
self.attention = Attention(hidden_dim)

# 输出层:将解码器隐藏状态和上下文向量映射到输出维度
# 输入维度: hidden_dim * 2(解码器隐藏状态 + 上下文向量)
# 输出维度: output_dim
self.fc = nn.Linear(hidden_dim * 2, output_dim)

def forward(self, src, trg):
"""
前向传播:编码-注意力-解码

Parameters:
-----------
src : torch.Tensor, shape (batch_size, src_seq_len, input_dim)
源序列(输入序列)
例如:历史时间序列数据
trg : torch.Tensor, shape (batch_size, trg_seq_len, output_dim)
目标序列(用于 Teacher Forcing 训练)
例如:未来时间序列数据(训练时)或初始值(推理时)

Returns:
--------
outputs : torch.Tensor, shape (batch_size, trg_seq_len, output_dim)
预测的输出序列
- outputs[:, 0, :]:通常是 0 或初始值(不使用)
- outputs[:, 1:, :]:实际的预测序列
"""
# ========== 编码阶段 ==========
# 编码器处理输入序列,得到所有时间步的隐藏状态
# encoder_outputs: (batch_size, src_seq_len, hidden_dim)
# - 包含编码器每个时间步的隐藏状态
# hidden: (1, batch_size, hidden_dim) - 最后一个时间步的隐藏状态
# cell: (1, batch_size, hidden_dim) - 最后一个时间步的细胞状态
encoder_outputs, (hidden, cell) = self.encoder(src)

# 调整 hidden 和 cell 的形状:从(1, batch_size, hidden_dim)到(batch_size, hidden_dim)
# 用于后续的注意力计算和解码器初始化
hidden = hidden.squeeze(0) # (batch_size, hidden_dim)
cell = cell.squeeze(0) # (batch_size, hidden_dim)

# 初始化输出张量:存储解码器每个时间步的输出
# 形状:(batch_size, trg_seq_len, output_dim)
outputs = torch.zeros(trg.size(0), trg.size(1), trg.size(2)).to(trg.device)

# 解码器初始输入:使用目标序列的第一个时间步( Teacher Forcing)
# 形状:(batch_size, output_dim)
input = trg[:, 0, :]

# ========== 解码阶段(逐个时间步生成)==========
# 从时间步 1 开始(时间步 0 通常是初始值或 padding)
for t in range(1, trg.size(1)):
# 步骤 1:计算注意力权重
# attention_weights: (batch_size, src_seq_len)
# 表示解码器当前隐藏状态对编码器每个位置的注意力权重
attention_weights = self.attention(hidden, encoder_outputs)

# 步骤 2:计算上下文向量(加权求和)
# attention_weights.unsqueeze(1): (batch_size, 1, src_seq_len)
# encoder_outputs: (batch_size, src_seq_len, hidden_dim)
# torch.bmm(...): batch 矩阵乘法,得到(batch_size, 1, hidden_dim)
# squeeze(1): 移除维度 1,得到(batch_size, hidden_dim)
# 上下文向量 c_t = Σ(α_{t,i} · h_i):编码器隐藏状态的加权平均
context = attention_weights.unsqueeze(1).bmm(encoder_outputs).squeeze(1)

# 步骤 3:准备解码器输入
# 将当前输入和上下文向量拼接:[y_{t-1}; c_t]
# torch.cat((input, context), dim=1): (batch_size, output_dim + hidden_dim)
# unsqueeze(1): 添加时间维度,得到(batch_size, 1, output_dim + hidden_dim)
rnn_input = torch.cat((input, context), dim=1).unsqueeze(1)

# 步骤 4:解码器前向传播
# output: (batch_size, 1, hidden_dim) - 当前时间步的解码器输出
# hidden: (1, batch_size, hidden_dim) - 更新后的隐藏状态
# cell: (1, batch_size, hidden_dim) - 更新后的细胞状态
output, (hidden, cell) = self.decoder(rnn_input, (hidden.unsqueeze(0), cell.unsqueeze(0)))

# 调整 hidden 和 cell 的形状:从(1, batch_size, hidden_dim)到(batch_size, hidden_dim)
hidden = hidden.squeeze(0)
cell = cell.squeeze(0)

# 步骤 5:生成最终输出
# 将解码器输出和上下文向量拼接:[s_t; c_t]
# output.squeeze(1): (batch_size, hidden_dim)
# context: (batch_size, hidden_dim)
# torch.cat(...): (batch_size, 2*hidden_dim)
# self.fc(...): 线性变换,得到(batch_size, output_dim)
output = self.fc(torch.cat((output.squeeze(1), context), dim=1))

# 保存当前时间步的输出
outputs[:, t, :] = output

# 步骤 6:准备下一个时间步的输入( Teacher Forcing)
# 训练时:使用真实目标序列( trg[:, t, :])
# 推理时:使用模型自身输出( output)
input = trg[:, t, :] # Teacher Forcing(训练模式)
# 如果要使用 Free Running(推理模式),改为: input = output

return outputs

# 使用示例:时间序列预测
# 设置模型参数
input_dim = 10 # 输入特征维度(如: 10 个特征)
hidden_dim = 64 # 隐藏状态维度
output_dim = 10 # 输出特征维度(预测 10 个特征)

# 生成示例数据
batch_size = 32
src_seq_len = 15 # 输入序列长度(历史 15 个时间步)
trg_seq_len = 20 # 输出序列长度(预测未来 20 个时间步)

# 模拟输入序列和目标序列
src = torch.rand(batch_size, src_seq_len, input_dim) # 历史数据
trg = torch.rand(batch_size, trg_seq_len, output_dim) # 未来数据(训练时)

# 创建模型实例
model = Seq2SeqWithAttention(input_dim, hidden_dim, output_dim)

# 前向传播
outputs = model(src, trg)

print(f"输入形状: src={src.shape}")
print(f"目标形状: trg={trg.shape}")
print(f"输出形状: outputs={outputs.shape}")

# 训练示例(简化版)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练一个 batch
optimizer.zero_grad()
outputs = model(src, trg)
loss = criterion(outputs[:, 1:, :], trg[:, 1:, :]) # 从时间步 1 开始计算损失
loss.backward()
optimizer.step()

print(f"训练损失: {loss.item():.4f}")

关键点解读

  1. 注意力机制解决信息瓶颈:传统 Seq2Seq 模型将整个输入序列压缩为固定长度的上下文向量(通常是编码器最后一个隐藏状态),这导致长序列信息丢失。注意力机制通过让解码器在每个时间步动态关注编码器的所有位置,解决了这个问题。上下文向量$ c_t = i {t,i} h_i$ 包含"当前时刻最需要的信息",而不是固定的压缩表示。

  2. Teacher Forcing vs Free Running:训练时使用 Teacher Forcing(使用真实目标序列作为解码器输入),这加速训练并提高稳定性。推理时使用 Free Running(使用模型自身输出),这更接近实际应用场景。代码中input = trg[:, t, :]是 Teacher Forcing 模式,推理时应改为input = output

  3. 上下文向量的双重使用:上下文向量 在两个地方使用: 1)与解码器输入拼接作为 LSTM 输入([y_{t-1}; c_t]),让解码器知道"应该关注编码器的哪些信息"; 2)与解码器输出拼接后通过线性层生成最终输出([s_t; c_t]),让输出包含"注意力加权的编码器信息"。这种设计使得模型能够充分利用注意力机制。

常见问题

问题 解答
Q: 为什么需要注意力机制? A: 传统 Seq2Seq 模型存在信息瓶颈:固定长度上下文向量无法承载长序列的所有信息。注意力机制让解码器动态关注编码器的所有位置,解决了这个问题。
Q: Bahdanau Attention vs Luong Attention? A: Bahdanau(这里实现)使用 MLP 计算得分, Luong 使用点积或双线性。 Bahdanau 更灵活但计算更慢, Luong 更快但表达能力稍弱。实际应用中两者性能相近。
Q: 如何可视化注意力权重? A: 注意力权重矩阵可视化:横轴是编码器位置,纵轴是解码器时间步。颜色深浅表示权重大小,可以用于可解释性分析。
Q: 训练和推理有什么区别? A: 训练时使用 Teacher Forcing(真实目标序列),推理时使用 Free Running(模型自身输出)。推理时第一个时间步需要提供初始值或使用特殊标记。
Q: 如何处理变长序列? A: 使用 padding 和 masking 。 padding 到相同长度, masking 屏蔽 padding 位置的注意力权重(设为-∞)。 PyTorch 的 pack_padded_sequence 可以处理变长序列。

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# 实际应用:时间序列预测(完整训练循环)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 创建模型
model = Seq2SeqWithAttention(input_dim=5, hidden_dim=128, output_dim=5)

# 准备数据(示例)
# 假设有 1000 个样本,每个样本有历史 20 步和未来 10 步
n_samples = 1000
src_data = torch.randn(n_samples, 20, 5) # 历史数据
trg_data = torch.randn(n_samples, 10, 5) # 未来数据

# 创建数据加载器
dataset = TensorDataset(src_data, trg_data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 训练设置
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

# 训练循环
model.train()
for epoch in range(100):
total_loss = 0
for src, trg in dataloader:
optimizer.zero_grad()

# 前向传播( Teacher Forcing)
# 注意: trg 需要包含初始值(第一个时间步)
trg_input = torch.cat([torch.zeros(trg.size(0), 1, trg.size(2)), trg], dim=1)
outputs = model(src, trg_input)

# 计算损失(从时间步 1 开始,因为时间步 0 通常是初始值)
loss = criterion(outputs[:, 1:, :], trg)

# 反向传播
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 梯度裁剪
optimizer.step()

total_loss += loss.item()

avg_loss = total_loss / len(dataloader)
scheduler.step(avg_loss)

if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")

# 推理示例( Free Running)
model.eval()
with torch.no_grad():
# 单个样本推理
src_sample = src_data[0:1] # (1, 20, 5)

# 初始化目标序列(第一个时间步使用 0 或特殊值)
trg_init = torch.zeros(1, 1, 5) # (1, 1, 5)
trg_pred = trg_init.clone()

# 编码
encoder_outputs, (hidden, cell) = model.encoder(src_sample)
hidden = hidden.squeeze(0)
cell = cell.squeeze(0)

# 逐步解码( Free Running)
predictions = []
input = trg_init[:, 0, :] # 初始输入

for t in range(10): # 预测 10 步
# 注意力
attention_weights = model.attention(hidden, encoder_outputs)
context = attention_weights.unsqueeze(1).bmm(encoder_outputs).squeeze(1)

# 解码
rnn_input = torch.cat((input, context), dim=1).unsqueeze(1)
output, (hidden, cell) = model.decoder(rnn_input, (hidden.unsqueeze(0), cell.unsqueeze(0)))
hidden = hidden.squeeze(0)
cell = cell.squeeze(0)

# 生成输出
output = model.fc(torch.cat((output.squeeze(1), context), dim=1))
predictions.append(output)

# 使用自身输出作为下一时间步的输入( Free Running)
input = output

predictions = torch.stack(predictions, dim=1) # (1, 10, 5)
print(f"预测形状: {predictions.shape}")

❓ Q&A: Attention 常见疑问

位置编码( Positional Encoding)深度解析

为什么需要位置编码?

核心问题:自注意力机制是排列不变的( Permutation Invariant)

自注意力只计算词与词之间的相似度,完全不考虑位置信息。这意味着:

  • "我爱你" 和 "你爱我" 会被视为相同
  • "猫吃鱼" 和 "鱼吃猫" 会被视为相同

这在自然语言中是不可接受的,因为词序决定语义

Q1:什么是位置编码( Positional Encoding),为什么需要它?

核心问题:自注意力机制是排列不变的( Permutation Invariant)

想象一下,如果你把句子"我爱你"打乱成"爱你我"或"你我爱",自注意力会给出完全相同的输出!因为它只计算词与词之间的相似度,不关心词的位置顺序

正弦/余弦位置编码( Sinusoidal PE)

为什么选择正弦/余弦?

  1. 固定长度:不需要训练,可以外推到更长序列
  2. 相对位置信息 可以表示为 的线性组合

Python 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)

def forward(self, x):
# x: (batch, seq_len, d_model)
return x + self.pe[:, :x.size(1), :]

# 使用示例
pos_encoder = PositionalEncoding(d_model=512)
x = torch.randn(32, 100, 512) # (batch, seq_len, d_model)
x_encoded = pos_encoder(x) # 添加位置编码

可学习位置编码( Learned Positional Encoding)

1
2
3
4
5
6
7
8
9
class LearnedPositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
self.pos_embedding = nn.Embedding(max_len, d_model)

def forward(self, x):
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
return x + self.pos_embedding(positions)

位置编码对比

类型 优点 缺点 适用场景
正弦/余弦 可外推,无需训练 固定模式 大多数场景
可学习 自适应 无法外推 固定长度序列

正弦位置编码的数学性质

正弦位置编码的一个重要性质是相对位置可表示性

$$

PE_{pos+k} = PE_{pos} M_k $$

其中 是一个旋转矩阵。这意味着:

  • 位置 的编码可以表示为位置 编码的线性变换
  • 模型可以学习相对位置关系,而不仅仅是绝对位置

位置编码的维度选择

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def analyze_positional_encoding(d_model=512, max_len=5000):
"""分析位置编码的性质"""
pe = PositionalEncoding(d_model, max_len)

# 可视化不同频率的编码
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 不同位置的编码(前 100 个位置)
pos_encoding = pe.pe[0, :100, :].numpy()
axes[0, 0].imshow(pos_encoding.T, cmap='coolwarm', aspect='auto')
axes[0, 0].set_title('位置编码热力图(前 100 位置)')
axes[0, 0].set_xlabel('位置')
axes[0, 0].set_ylabel('维度')

# 不同频率的编码(前 32 个维度)
axes[0, 1].plot(pos_encoding[:, :32])
axes[0, 1].set_title('不同频率的正弦/余弦编码')
axes[0, 1].set_xlabel('位置')
axes[0, 1].set_ylabel('编码值')

# 相对位置关系
pos_0 = pe.pe[0, 0, :].numpy()
pos_10 = pe.pe[0, 10, :].numpy()
pos_20 = pe.pe[0, 20, :].numpy()

axes[1, 0].plot(pos_0[:64], label='位置 0')
axes[1, 0].plot(pos_10[:64], label='位置 10')
axes[1, 0].plot(pos_20[:64], label='位置 20')
axes[1, 0].set_title('不同位置的编码模式')
axes[1, 0].legend()

# 编码的相似度矩阵
similarity = torch.matmul(pe.pe[0, :50, :], pe.pe[0, :50, :].T)
axes[1, 1].imshow(similarity.numpy(), cmap='YlOrRd')
axes[1, 1].set_title('位置编码相似度矩阵')
axes[1, 1].set_xlabel('位置 i')
axes[1, 1].set_ylabel('位置 j')

plt.tight_layout()
plt.show()

# 使用
analyze_positional_encoding(d_model=512, max_len=5000)

可学习位置编码的改进

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class LearnedPositionalEncoding(nn.Module):
"""可学习位置编码(带初始化策略)"""
def __init__(self, d_model, max_len=5000):
super().__init__()
# 使用正弦位置编码初始化(更好的起点)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# 初始化为正弦编码,但允许学习调整
self.pos_embedding = nn.Embedding(max_len, d_model)
self.pos_embedding.weight.data = pe.unsqueeze(0)
self.pos_embedding.weight.requires_grad = True # 允许学习

def forward(self, x):
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
return x + self.pos_embedding(positions)

相对位置编码( Relative Position Encoding)

除了绝对位置,还可以编码相对位置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class RelativePositionalEncoding(nn.Module):
"""相对位置编码"""
def __init__(self, d_model, max_rel_pos=128):
super().__init__()
self.max_rel_pos = max_rel_pos
# 相对位置嵌入:从-max_rel_pos 到 max_rel_pos
self.rel_pos_embedding = nn.Embedding(2 * max_rel_pos + 1, d_model)

def forward(self, q, k):
"""
q: [batch, seq_len_q, d_model]
k: [batch, seq_len_k, d_model]
"""
seq_len_q = q.size(1)
seq_len_k = k.size(1)

# 计算相对位置矩阵
rel_pos = torch.arange(seq_len_k, device=q.device).unsqueeze(0) - \
torch.arange(seq_len_q, device=q.device).unsqueeze(1)
# 限制在[-max_rel_pos, max_rel_pos]范围内
rel_pos = torch.clamp(rel_pos, -self.max_rel_pos, self.max_rel_pos)
rel_pos = rel_pos + self.max_rel_pos # 转换为索引[0, 2*max_rel_pos]

# 获取相对位置编码
rel_pos_emb = self.rel_pos_embedding(rel_pos) # [seq_len_q, seq_len_k, d_model]

# 应用到注意力计算中
# 这里简化处理,实际需要修改注意力计算
return rel_pos_emb

Multi-Head Attention 详解

为什么需要多头?

单头注意力只能学习一种表示模式,而多头注意力允许模型同时关注不同类型的依赖关系。就像人类理解句子时,会同时关注语法结构、语义关系、情感色彩等多个维度。

多头注意力的数学原理

完整公式

其中:

  • 是头的数量
  • 是第 个头的投影矩阵
  • 是输出投影矩阵
  • 是每个头的维度

关键设计

  • 参数共享:所有头共享输入 ,但使用不同的投影矩阵
  • 独立计算:每个头独立计算注意力,学习不同的表示子空间
  • 拼接融合:所有头的输出拼接后通过 投影回原始维度

每个头学习什么?

通过可视化不同头的注意力权重,可以发现:

头类型 关注模式 典型应用
局部头 关注相邻位置(对角线附近权重高) 捕捉局部依赖、短语结构
全局头 均匀关注所有位置(权重分布均匀) 捕捉全局语义、文档级信息
特定位置头 只关注少数关键位置(权重高度集中) 捕捉关键词、重要实体
句法头 关注句法相关位置(如主谓宾关系) 理解语法结构
语义头 关注语义相似的位置 理解同义词、语义关系

实验证据

在 BERT 等模型中,研究发现:

  • 头 1-2:主要关注局部依赖(相邻词)
  • 头 3-4:关注句法结构(主谓宾)
  • 头 5-6:关注语义关系(同义词、反义词)
  • 头 7-8:关注全局信息(文档级特征)

Q2:在多头注意力机制( Multi-Head Attention)中,每个头( head)是如何独立工作的?

核心思想不同的头关注不同的特征

多头注意力的优势

  • 每个头独立学习不同的表示子空间
  • 头 1 可能关注局部依赖(相邻词)
  • 头 2 可能关注长距离依赖(句子头尾)
  • 头 3 可能关注句法结构(主谓宾)

数学公式

Python 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0

self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads

# 为每个头创建独立的 Q, K, V 投影
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)

def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)

# 1. 线性变换并分割为多个头
Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

# 2. 每个头独立计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)

attention_weights = torch.softmax(scores, dim=-1)
head_outputs = torch.matmul(attention_weights, V)

# 3. 拼接所有头
concat = head_outputs.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)

# 4. 输出投影
output = self.W_o(concat)
return output, attention_weights

# 使用示例
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(32, 100, 512)
output, attn_weights = mha(x, x, x)
# output: (32, 100, 512)
# attn_weights: (32, 8, 100, 100) - 8 个头,每个头都有注意力权重

可视化不同头的注意力模式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_multi_head_attention(attention_weights, head_idx=0):
"""可视化特定头的注意力权重"""
# attention_weights: (batch, num_heads, seq_len, seq_len)
attn = attention_weights[0, head_idx].detach().numpy() # 第一个样本,指定头

plt.figure(figsize=(10, 8))
sns.heatmap(attn, cmap='YlOrRd', annot=False)
plt.title(f'Head {head_idx} 的注意力权重')
plt.xlabel('Key 位置')
plt.ylabel('Query 位置')
plt.show()

# 可视化所有头
for i in range(8):
visualize_multi_head_attention(attn_weights, head_idx=i)

头数选择指南

模型大小 推荐头数 原因
小模型( d_model=128) 4-8 平衡表达能力和计算成本
中等模型( d_model=512) 8-16 标准配置
大模型( d_model=1024) 16-32 需要更多表示子空间

Q3:如何使用掩码( Mask)来处理变长序列?

掩码的三种类型

1. 填充掩码( Padding Mask)

  • 作用:遮挡序列末尾的填充位置(通常是 0)

2. 因果掩码( Causal Mask / Look-Ahead Mask)

  • 作用:防止解码器在生成第 个 token 时看到未来的 token

3. 组合掩码

  • 编码器:只用填充掩码
  • 解码器:填充掩码 + 因果掩码

Python 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
import numpy as np

def create_padding_mask(seq, pad_token=0):
"""创建填充掩码"""
# seq: (batch, seq_len)
mask = (seq != pad_token).unsqueeze(1).unsqueeze(2) # (batch, 1, 1, seq_len)
return mask.float()

def create_causal_mask(seq_len):
"""创建因果掩码(下三角矩阵)"""
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
mask = mask.masked_fill(mask == 0, float(0.0))
return mask # (seq_len, seq_len)

def create_combined_mask(seq, pad_token=0):
"""组合填充掩码和因果掩码"""
seq_len = seq.size(1)

# 填充掩码
padding_mask = create_padding_mask(seq, pad_token) # (batch, 1, 1, seq_len)

# 因果掩码
causal_mask = create_causal_mask(seq_len) # (seq_len, seq_len)

# 组合(广播相加)
combined_mask = padding_mask + causal_mask.unsqueeze(0).unsqueeze(0)
combined_mask = combined_mask.masked_fill(combined_mask > 0, float('-inf'))

return combined_mask

# 使用示例
seq = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]]) # 填充后的序列
mask = create_combined_mask(seq, pad_token=0)

# 在注意力计算中使用
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores + mask # 应用掩码
attention_weights = torch.softmax(scores, dim=-1)

掩码可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import matplotlib.pyplot as plt

def visualize_mask(mask, title='Mask'):
"""可视化掩码矩阵"""
plt.figure(figsize=(8, 8))
plt.imshow(mask[0, 0].detach().numpy(), cmap='gray')
plt.title(title)
plt.xlabel('Key 位置')
plt.ylabel('Query 位置')
plt.colorbar()
plt.show()

# 可视化不同类型的掩码
padding_mask = create_padding_mask(seq)
causal_mask = create_causal_mask(seq_len=10)
combined_mask = create_combined_mask(seq)

visualize_mask(padding_mask, '填充掩码')
visualize_mask(causal_mask.unsqueeze(0).unsqueeze(0), '因果掩码')
visualize_mask(combined_mask, '组合掩码')

Q4: Transformer 模型相比传统 RNN 模型有哪些优势?

维度 RNN/LSTM/GRU Transformer
并行计算 ❌ 顺序计算 ✅ 全并行
长距离依赖 ⚠️ 梯度消失/爆炸 ✅ 直接连接( O(1) 路径长度)
训练速度 慢(序列越长越慢) 快(序列长度不影响并行度)
内存占用 中等 高( 注意力矩阵)
可解释性 差(隐藏状态黑盒) ✅ 好(注意力权重可视化)

性能对比实验

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import time
import torch
import torch.nn as nn

# 对比训练速度
def benchmark_training(model, data_loader, n_epochs=10):
model.train()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.MSELoss()

start_time = time.time()
for epoch in range(n_epochs):
for batch in data_loader:
optimizer.zero_grad()
output = model(batch[0])
loss = criterion(output, batch[1])
loss.backward()
optimizer.step()
elapsed = time.time() - start_time

return elapsed / n_epochs

# RNN/LSTM/GRU(顺序计算)
rnn_model = nn.LSTM(input_size=10, hidden_size=128, num_layers=2)
rnn_time = benchmark_training(rnn_model, data_loader)

# Transformer(并行计算)
transformer_model = TransformerModel(...)
transformer_time = benchmark_training(transformer_model, data_loader)

print(f'RNN 训练时间: {rnn_time:.2f}s/epoch')
print(f'Transformer 训练时间: {transformer_time:.2f}s/epoch')
print(f'加速比: {rnn_time/transformer_time:.2f}x')

Q5: Attention 机制在时间序列预测中如何应用?

时间序列 Attention 的特殊性

时间序列数据与 NLP 不同:

  • 时间顺序很重要:需要因果掩码
  • 周期性模式:某些时间步可能更重要
  • 多变量:需要处理多个特征

时间序列 Transformer 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import math

class TimeSeriesTransformer(nn.Module):
def __init__(self, input_size, d_model, nhead, num_layers, output_size):
super().__init__()
self.input_projection = nn.Linear(input_size, d_model)
self.pos_encoder = PositionalEncoding(d_model)

encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=d_model * 4,
dropout=0.1,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

self.output_projection = nn.Linear(d_model, output_size)

def forward(self, x, mask=None):
# x: (batch, seq_len, features)
x = self.input_projection(x)
x = self.pos_encoder(x)

# 创建因果掩码(时间序列只能看过去)
if mask is None:
seq_len = x.size(1)
mask = self.generate_causal_mask(seq_len, x.device)

x = self.transformer(x, mask=mask)
x = self.output_projection(x[:, -1, :]) # 只用最后时间步
return x

def generate_causal_mask(self, sz, device):
"""生成因果掩码"""
mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask

# 使用示例
model = TimeSeriesTransformer(
input_size=10, # 10 个特征
d_model=128,
nhead=8,
num_layers=3,
output_size=1
)

x = torch.randn(32, 50, 10) # (batch, seq_len, features)
pred = model(x) # (32, 1)

Informer:针对长序列的改进

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class ProbSparseAttention(nn.Module):
"""ProbSparse Attention:只计算重要的注意力对"""
def __init__(self, d_model, nhead, factor=5):
super().__init__()
self.d_model = d_model
self.nhead = nhead
self.factor = factor # 采样因子

def forward(self, Q, K, V):
# 只采样 top-k 个 Query(基于与 Key 的相似度)
# 减少计算复杂度从 O(n^2) 到 O(n log n)
B, L_Q, _ = Q.size()
L_K = K.size(1)

# 计算采样索引
sample_size = max(L_Q // self.factor, 1)
sample_indices = self.sample_queries(Q, K, sample_size)

# 只计算采样 Query 的注意力
Q_sampled = Q[:, sample_indices, :]
scores = torch.matmul(Q_sampled, K.transpose(-2, -1)) / math.sqrt(self.d_model)
attention_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)

return output, attention_weights

def sample_queries(self, Q, K, sample_size):
"""采样重要的 Query"""
# 使用 M(Q, K) = max_j QK^T - mean_j QK^T 作为重要性指标
scores = torch.matmul(Q, K.transpose(-2, -1))
importance = scores.max(dim=-1)[0] - scores.mean(dim=-1)
_, indices = torch.topk(importance, sample_size, dim=-1)
return indices

时间序列 Attention 可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_timeseries_attention(model, x, timestamps):
"""可视化时间序列的注意力权重"""
model.eval()
with torch.no_grad():
# 提取注意力权重(需要修改模型以返回权重)
_, attn_weights = model.get_attention_weights(x)

# attn_weights: (batch, nhead, seq_len, seq_len)
attn = attn_weights[0, 0].detach().numpy() # 第一个样本,第一个头

plt.figure(figsize=(15, 10))
sns.heatmap(attn, cmap='YlOrRd', xticklabels=timestamps, yticklabels=timestamps)
plt.title('时间序列注意力权重热力图')
plt.xlabel('Key 时间步')
plt.ylabel('Query 时间步')
plt.show()

# 分析:哪些时间步被重点关注
attention_sums = attn.sum(axis=0) # 每个时间步被关注的总权重
important_timesteps = np.argsort(attention_sums)[-10:] # Top 10
print(f'最重要的时间步: {important_timesteps}')

Q6:如何解决 Transformer 的 计算复杂度问题?

问题分析

标准自注意力的计算复杂度:

  • 时间复杂度,其中 是序列长度, 是特征维度
  • 空间复杂度(注意力矩阵)

当序列长度 时,注意力矩阵大小为 个元素 当 时,矩阵大小为 个元素!

解决方案 1:稀疏注意力( Sparse Attention)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class SparseAttention(nn.Module):
"""只计算局部窗口内的注意力"""
def __init__(self, d_model, nhead, window_size=50):
super().__init__()
self.window_size = window_size
self.attention = nn.MultiheadAttention(d_model, nhead)

def forward(self, x):
# 将序列分成多个窗口,只在窗口内计算注意力
batch_size, seq_len, d_model = x.size()
num_windows = (seq_len + self.window_size - 1) // self.window_size

outputs = []
for i in range(num_windows):
start = i * self.window_size
end = min((i + 1) * self.window_size, seq_len)
window_x = x[:, start:end, :]

# 只在窗口内计算注意力
window_out, _ = self.attention(window_x, window_x, window_x)
outputs.append(window_out)

return torch.cat(outputs, dim=1)

解决方案 2:线性注意力( Linear Attention)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class LinearAttention(nn.Module):
"""线性复杂度注意力: O(n) 而非 O(n^2)"""
def __init__(self, d_model):
super().__init__()
self.d_model = d_model
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)

def forward(self, x):
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)

# 使用特征映射将复杂度从 O(n^2) 降到 O(n)
# 关键:先计算 KV^T,再与 Q 相乘
# 标准: QK^T V → O(n^2 d)
# 线性: Q (K^T V) → O(n d^2)

# 使用 elu + 1 作为特征映射(保证正定)
K = torch.nn.functional.elu(K) + 1
Q = torch.nn.functional.elu(Q) + 1

# 线性复杂度计算
KV = torch.einsum('bnd,bne->bde', K, V) # (batch, d, d)
QKV = torch.einsum('bnd,bde->bne', Q, KV) # (batch, n, d)

# 归一化
normalizer = torch.einsum('bnd,bd->bn', Q, K.sum(dim=1)) # (batch, n)
output = QKV / (normalizer.unsqueeze(-1) + 1e-6)

return output

解决方案 3: Performer(随机特征)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class PerformerAttention(nn.Module):
"""Performer:使用随机特征近似注意力"""
def __init__(self, d_model, n_random_features=64):
super().__init__()
self.d_model = d_model
self.n_random_features = n_random_features

# 随机特征矩阵(固定或可学习)
self.random_features = nn.Parameter(
torch.randn(n_random_features, d_model)
)

def forward(self, Q, K, V):
# 使用随机特征近似 softmax(QK^T)
# 复杂度从 O(n^2) 降到 O(n m),其中 m << n
batch_size, seq_len, _ = Q.size()

# 随机特征映射
Q_features = self.random_feature_map(Q) # (batch, seq_len, m)
K_features = self.random_feature_map(K) # (batch, seq_len, m)

# 线性复杂度计算
QK_features = torch.einsum('bnm,bkm->bnk', Q_features, K_features)
attention_weights = QK_features / math.sqrt(self.d_model)
output = torch.einsum('bnk,bkd->bnd', attention_weights, V)

return output

def random_feature_map(self, x):
"""随机特征映射"""
# 使用正随机特征( Positive Random Features)
projection = torch.matmul(x, self.random_features.t()) # (batch, n, m)
return torch.nn.functional.relu(projection)

性能对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import time

def benchmark_attention(attention_module, x, n_iter=100):
"""对比不同注意力机制的速度"""
attention_module.eval()
start = time.time()
with torch.no_grad():
for _ in range(n_iter):
_ = attention_module(x)
elapsed = time.time() - start
return elapsed / n_iter * 1000 # 毫秒

x = torch.randn(1, 1000, 128) # 长序列

# 标准注意力
standard_attn = StandardAttention(128)
standard_time = benchmark_attention(standard_attn, x)

# 稀疏注意力
sparse_attn = SparseAttention(128, window_size=100)
sparse_time = benchmark_attention(sparse_attn, x)

# 线性注意力
linear_attn = LinearAttention(128)
linear_time = benchmark_attention(linear_attn, x)

print(f'标准注意力: {standard_time:.2f} ms')
print(f'稀疏注意力: {sparse_time:.2f} ms (加速 {standard_time/sparse_time:.2f}x)')
print(f'线性注意力: {linear_time:.2f} ms (加速 {standard_time/linear_time:.2f}x)')

Q7: Attention 机制的可解释性如何利用?

1. 注意力权重可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention_weights(attention_weights, input_tokens=None, head_idx=0):
"""可视化注意力权重矩阵"""
# attention_weights: (batch, nhead, seq_len, seq_len)
attn = attention_weights[0, head_idx].detach().numpy()

plt.figure(figsize=(12, 10))
sns.heatmap(
attn,
cmap='YlOrRd',
xticklabels=input_tokens,
yticklabels=input_tokens,
annot=False,
fmt='.2f'
)
plt.title(f'Attention Head {head_idx} 的权重分布')
plt.xlabel('Key 位置')
plt.ylabel('Query 位置')
plt.tight_layout()
plt.show()

# 使用
model.eval()
with torch.no_grad():
output, attn_weights = model(x, return_attention=True)
visualize_attention_weights(attn_weights, input_tokens=token_names)

2. 注意力头分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def analyze_attention_heads(attention_weights):
"""分析不同头的关注模式"""
n_heads = attention_weights.size(1)

fig, axes = plt.subplots(2, n_heads // 2, figsize=(20, 8))
axes = axes.flatten()

for head_idx in range(n_heads):
attn = attention_weights[0, head_idx].detach().numpy()

# 计算每个位置的平均注意力
avg_attention = attn.mean(axis=0)

axes[head_idx].bar(range(len(avg_attention)), avg_attention)
axes[head_idx].set_title(f'Head {head_idx}')
axes[head_idx].set_xlabel('位置')
axes[head_idx].set_ylabel('平均注意力权重')

plt.tight_layout()
plt.show()

3. 特征重要性分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def compute_feature_importance(model, x, target_idx):
"""计算每个输入特征对预测的贡献"""
model.eval()

# 获取注意力权重
_, attn_weights = model.get_attention_weights(x)

# 对最后一个时间步的预测,计算每个输入位置的贡献
# attn_weights: (batch, nhead, seq_len, seq_len)
# 最后一个 Query 位置对所有 Key 位置的注意力
last_query_attention = attn_weights[0, :, -1, :].mean(dim=0) # 平均所有头

# 可视化
plt.figure(figsize=(12, 6))
plt.bar(range(len(last_query_attention)), last_query_attention.detach().numpy())
plt.xlabel('输入时间步')
plt.ylabel('注意力权重')
plt.title('预测时各时间步的重要性')
plt.show()

return last_query_attention

4. 注意力模式分类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def classify_attention_patterns(attention_weights):
"""识别常见的注意力模式"""
patterns = {
'局部关注': [], # 关注相邻位置
'全局关注': [], # 均匀关注所有位置
'特定位置': [], # 只关注少数位置
'周期性': [] # 周期性关注模式
}

for head_idx in range(attention_weights.size(1)):
attn = attention_weights[0, head_idx].detach().numpy()

# 计算局部性(对角线附近的权重)
local_weight = np.trace(attn) + np.trace(np.roll(attn, 1, axis=0))

# 计算均匀性(熵)
entropy = -np.sum(attn * np.log(attn + 1e-10))
max_entropy = np.log(attn.shape[0])
uniformity = entropy / max_entropy

# 分类
if local_weight > 0.5:
patterns['局部关注'].append(head_idx)
elif uniformity > 0.8:
patterns['全局关注'].append(head_idx)
elif attn.max() > 0.7:
patterns['特定位置'].append(head_idx)
else:
patterns['周期性'].append(head_idx)

return patterns

5. 注意力引导的特征选择

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def attention_based_feature_selection(model, x, top_k=10):
"""基于注意力权重选择最重要的特征"""
model.eval()
with torch.no_grad():
_, attn_weights = model.get_attention_weights(x)

# 计算每个特征的平均注意力
# 假设 x: (batch, seq_len, features)
feature_importance = attn_weights.mean(dim=(0, 1, 3)) # 平均 batch, head, key
# feature_importance: (seq_len,) - 每个时间步的重要性

# 选择 top-k 最重要的时间步
top_indices = torch.topk(feature_importance, k=top_k).indices

return top_indices, feature_importance

# 使用
important_indices, importance_scores = attention_based_feature_selection(model, x, top_k=10)
print(f'最重要的时间步: {important_indices}')

Self-Attention vs Cross-Attention 深度对比

核心区别

Self-Attention(自注意力)

  • 都来自同一个输入序列
  • 计算序列内部的关系
  • 用于编码器( Encoder)或单序列任务

Cross-Attention(交叉注意力)

  • 来自一个序列, 来自另一个序列
  • 计算两个序列之间的关系
  • 用于解码器( Decoder)或多模态任务

数学公式对比

Self-Attention

Cross-Attention

其中 来自序列 A, 来自序列 B 。

代码实现对比

Self-Attention 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class SelfAttention(nn.Module):
"""自注意力: Q, K, V 都来自同一个输入"""
def __init__(self, d_model, nhead):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)

def forward(self, x):
"""
x: [batch, seq_len, d_model]
返回: [batch, seq_len, d_model]
"""
# Q, K, V 都是 x
out, attn_weights = self.attention(x, x, x)
return out, attn_weights

# 使用示例:文本编码
text = torch.randn(32, 100, 512) # [batch, seq_len, d_model]
self_attn = SelfAttention(d_model=512, nhead=8)
encoded, attn = self_attn(text) # 编码文本内部关系

Cross-Attention 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class CrossAttention(nn.Module):
"""交叉注意力: Q 来自一个序列, K 和 V 来自另一个序列"""
def __init__(self, d_model, nhead):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)

def forward(self, query_seq, key_value_seq):
"""
query_seq: [batch, seq_len_q, d_model] - 查询序列
key_value_seq: [batch, seq_len_kv, d_model] - 键值序列
返回: [batch, seq_len_q, d_model]
"""
# Q 来自 query_seq, K 和 V 来自 key_value_seq
out, attn_weights = self.attention(
query_seq, # Q
key_value_seq, # K
key_value_seq # V
)
return out, attn_weights

# 使用示例:机器翻译(解码器关注编码器输出)
encoder_output = torch.randn(32, 50, 512) # 源语言编码
decoder_input = torch.randn(32, 30, 512) # 目标语言查询
cross_attn = CrossAttention(d_model=512, nhead=8)
decoded, attn = cross_attn(decoder_input, encoder_output)
# decoded: 目标语言序列,关注源语言信息

应用场景对比

维度 Self-Attention Cross-Attention
典型应用 BERT 编码器、 GPT 、图像分类 机器翻译解码器、图像描述生成
输入数量 1 个序列 2 个序列(查询序列+键值序列)
计算关系 序列内部关系 序列间关系
注意力模式 对称矩阵($ n n n_q n_{kv}$)
可解释性 理解序列内部依赖 理解跨序列对齐关系

完整 Transformer 解码器示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class TransformerDecoderLayer(nn.Module):
"""Transformer 解码器层:包含 Self-Attention 和 Cross-Attention"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super().__init__()
# Self-Attention:目标序列内部关系
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

# Cross-Attention:目标序列关注源序列
self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

# 前馈网络
self.ffn = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
nn.Dropout(dropout)
)

self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
"""
tgt: [batch, tgt_len, d_model] - 目标序列(解码器输入)
memory: [batch, src_len, d_model] - 源序列(编码器输出)
"""
# 1. Self-Attention:目标序列内部关系
tgt2, self_attn_weights = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
tgt = self.norm1(tgt + self.dropout(tgt2))

# 2. Cross-Attention:目标序列关注源序列
tgt2, cross_attn_weights = self.cross_attn(
tgt, # Q: 来自目标序列
memory, # K: 来自源序列
memory, # V: 来自源序列
attn_mask=memory_mask
)
tgt = self.norm2(tgt + self.dropout(tgt2))

# 3. 前馈网络
tgt2 = self.ffn(tgt)
tgt = self.norm3(tgt + tgt2)

return tgt, self_attn_weights, cross_attn_weights

# 使用示例:机器翻译
encoder_output = torch.randn(32, 50, 512) # 源语言:"I love you"
decoder_input = torch.randn(32, 30, 512) # 目标语言:"我 爱 你"
decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)

# Self-Attention:理解"我 爱 你"的内部关系
# Cross-Attention:将"我"对齐到"I","爱"对齐到"love","你"对齐到"you"
output, self_attn, cross_attn = decoder_layer(decoder_input, encoder_output)

可视化对比

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def visualize_attention_comparison(self_attn_weights, cross_attn_weights):
"""可视化 Self-Attention 和 Cross-Attention 的区别"""
fig, axes = plt.subplots(1, 2, figsize=(20, 8))

# Self-Attention:对称矩阵
axes[0].imshow(self_attn_weights[0, 0].detach().numpy(), cmap='YlOrRd')
axes[0].set_title('Self-Attention 权重(对称)')
axes[0].set_xlabel('Key 位置(目标序列)')
axes[0].set_ylabel('Query 位置(目标序列)')

# Cross-Attention:非对称矩阵
axes[1].imshow(cross_attn_weights[0, 0].detach().numpy(), cmap='YlOrRd')
axes[1].set_title('Cross-Attention 权重(非对称)')
axes[1].set_xlabel('Key 位置(源序列)')
axes[1].set_ylabel('Query 位置(目标序列)')

plt.tight_layout()
plt.show()

# 使用
visualize_attention_comparison(self_attn, cross_attn)

选择指南

使用 Self-Attention

  • ✅ 单序列任务(文本分类、语言模型)
  • ✅ 需要理解序列内部关系
  • ✅ 编码器架构

使用 Cross-Attention

  • ✅ 序列到序列任务(机器翻译、摘要生成)
  • ✅ 多模态任务(图像描述、视频理解)
  • ✅ 解码器架构
  • ✅ 需要跨序列对齐

Q8:如何将 Attention 机制与 LSTM/GRU 结合?

1. Attention-LSTM 架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn

class AttentionLSTM(nn.Module):
"""LSTM + Attention 混合模型"""
def __init__(self, input_size, hidden_size, num_layers, nhead):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.attention = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)
self.fc = nn.Linear(hidden_size, 1)

def forward(self, x):
# LSTM 编码
lstm_out, _ = self.lstm(x) # (batch, seq_len, hidden_size)

# Attention 增强
attn_out, attn_weights = self.attention(
lstm_out, lstm_out, lstm_out
) # (batch, seq_len, hidden_size)

# 融合 LSTM 和 Attention 输出
combined = lstm_out + attn_out # 残差连接

# 预测
output = self.fc(combined[:, -1, :])
return output, attn_weights

# 使用
model = AttentionLSTM(input_size=10, hidden_size=64, num_layers=2, nhead=4)
x = torch.randn(32, 50, 10)
pred, attn = model(x)

2. 双向 Attention-GRU

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class BidirectionalAttentionGRU(nn.Module):
"""双向 GRU + Attention"""
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.gru = nn.GRU(
input_size,
hidden_size,
num_layers,
batch_first=True,
bidirectional=True
)

# Attention 层
self.attention = nn.MultiheadAttention(
hidden_size * 2, # 双向所以 *2
num_heads=8,
batch_first=True
)

self.fc = nn.Linear(hidden_size * 2, 1)

def forward(self, x):
# 双向 GRU
gru_out, _ = self.gru(x)

# Attention
attn_out, attn_weights = self.attention(gru_out, gru_out, gru_out)

# 加权平均(而不是只用最后时间步)
weights = torch.softmax(attn_out.sum(dim=-1), dim=1).unsqueeze(-1)
weighted_out = (attn_out * weights).sum(dim=1)

output = self.fc(weighted_out)
return output, attn_weights

3. Hierarchical Attention(层次注意力)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class HierarchicalAttention(nn.Module):
"""层次注意力:先局部后全局"""
def __init__(self, input_size, hidden_size):
super().__init__()
# 第一层:局部 Attention(窗口内)
self.local_attention = nn.MultiheadAttention(
input_size, num_heads=4, batch_first=True
)

# GRU 处理局部特征
self.gru = nn.GRU(input_size, hidden_size, batch_first=True)

# 第二层:全局 Attention(跨窗口)
self.global_attention = nn.MultiheadAttention(
hidden_size, num_heads=8, batch_first=True
)

self.fc = nn.Linear(hidden_size, 1)

def forward(self, x, window_size=10):
batch_size, seq_len, _ = x.size()

# 第一层:局部 Attention
local_outputs = []
for i in range(0, seq_len, window_size):
window = x[:, i:i+window_size, :]
local_out, _ = self.local_attention(window, window, window)
local_outputs.append(local_out)

local_combined = torch.cat(local_outputs, dim=1)

# GRU 处理
gru_out, _ = self.gru(local_combined)

# 第二层:全局 Attention
global_out, attn_weights = self.global_attention(
gru_out, gru_out, gru_out
)

output = self.fc(global_out[:, -1, :])
return output, attn_weights

4. 对比实验

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def compare_models(models_dict, X_test, y_test):
"""对比不同架构的性能"""
results = {}

for name, model in models_dict.items():
model.eval()
with torch.no_grad():
pred = model(X_test)
mae = torch.mean(torch.abs(pred - y_test)).item()
rmse = torch.sqrt(torch.mean((pred - y_test) ** 2)).item()

results[name] = {'MAE': mae, 'RMSE': rmse}

return results

# 对比
models = {
'LSTM': LSTMOnlyModel(...),
'GRU': GRUOnlyModel(...),
'LSTM+Attention': AttentionLSTM(...),
'GRU+Attention': BidirectionalAttentionGRU(...),
'Transformer': TransformerModel(...)
}

results = compare_models(models, X_test, y_test)
for name, metrics in results.items():
print(f'{name}: MAE={metrics["MAE"]:.4f}, RMSE={metrics["RMSE"]:.4f}')

Q9: Attention 机制在时间序列异常检测中的应用?

异常检测的 Attention 模式

异常点通常表现出: 1. 注意力权重异常:与其他时间步的关联度低 2. 重构误差大: Attention 无法很好地重构异常点 3. 注意力分布异常:注意力权重分布与正常模式不同

基于 Attention 的异常检测模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class AttentionAnomalyDetector(nn.Module):
"""使用 Attention 进行异常检测"""
def __init__(self, input_size, d_model, nhead):
super().__init__()
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model, nhead, batch_first=True),
num_layers=3
)
self.decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model, nhead, batch_first=True),
num_layers=3
)
self.input_proj = nn.Linear(input_size, d_model)
self.output_proj = nn.Linear(d_model, input_size)

def forward(self, x):
# 编码
x_proj = self.input_proj(x)
encoded = self.encoder(x_proj)

# 解码(自编码器)
decoded = self.decoder(x_proj, encoded)
reconstructed = self.output_proj(decoded)

# 计算重构误差
reconstruction_error = torch.mean((x - reconstructed) ** 2, dim=-1)

return reconstructed, reconstruction_error, encoded

# 训练
model = AttentionAnomalyDetector(input_size=10, d_model=128, nhead=8)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(100):
for batch in train_loader:
reconstructed, error, _ = model(batch)
loss = criterion(reconstructed, batch)
loss.backward()
optimizer.step()

异常评分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def compute_anomaly_score(model, x, threshold_percentile=95):
"""计算异常分数"""
model.eval()
with torch.no_grad():
reconstructed, reconstruction_error, encoded = model(x)

# 方法 1:基于重构误差
error_scores = reconstruction_error.mean(dim=1) # (batch,)
threshold = np.percentile(error_scores.numpy(), threshold_percentile)
anomalies_by_error = error_scores > threshold

# 方法 2:基于注意力权重异常
# 获取注意力权重(需要修改模型返回)
attention_entropy = compute_attention_entropy(model, x)
entropy_threshold = np.percentile(attention_entropy.numpy(), threshold_percentile)
anomalies_by_attention = attention_entropy < entropy_threshold # 低熵 = 异常

# 组合两种方法
final_anomalies = anomalies_by_error | anomalies_by_attention

return {
'anomalies': final_anomalies,
'error_scores': error_scores,
'attention_entropy': attention_entropy
}

def compute_attention_entropy(model, x):
"""计算注意力权重的熵(低熵 = 异常)"""
# 获取注意力权重
attn_weights = model.get_attention_weights(x)

# 计算每个时间步的注意力熵
# attn_weights: (batch, nhead, seq_len, seq_len)
entropy = -torch.sum(
attn_weights * torch.log(attn_weights + 1e-10),
dim=-1
).mean(dim=(1, 2)) # 平均 head 和 query

return entropy

可视化异常检测结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def visualize_anomalies(data, anomalies, attention_weights):
"""可视化异常检测结果"""
fig, axes = plt.subplots(3, 1, figsize=(15, 10))

# 1. 原始数据 + 异常标记
axes[0].plot(data, label='原始数据')
anomaly_indices = np.where(anomalies)[0]
axes[0].scatter(anomaly_indices, data[anomaly_indices],
color='red', s=100, label='异常点', zorder=5)
axes[0].set_title('异常检测结果')
axes[0].legend()

# 2. 重构误差
axes[1].plot(reconstruction_error, label='重构误差')
axes[1].axhline(y=threshold, color='r', linestyle='--', label='阈值')
axes[1].set_title('重构误差')
axes[1].legend()

# 3. 注意力权重热力图(突出异常区域)
sns.heatmap(attention_weights[0, 0].detach().numpy(),
ax=axes[2], cmap='YlOrRd')
axes[2].set_title('注意力权重(异常区域通常权重异常)')

plt.tight_layout()
plt.show()

Q10:如何优化 Attention 机制的计算效率?

1. Flash Attention(内存高效)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Flash Attention 通过分块计算减少内存占用
# 标准实现需要 O(n^2) 内存存储注意力矩阵
# Flash Attention 只需要 O(n) 内存

# 注意: Flash Attention 通常需要 CUDA 实现
# 这里提供简化版本说明原理

class FlashAttention(nn.Module):
"""内存高效的注意力计算"""
def __init__(self, d_model, block_size=64):
super().__init__()
self.d_model = d_model
self.block_size = block_size # 分块大小

def forward(self, Q, K, V):
batch_size, seq_len, _ = Q.size()

# 分块计算,避免存储完整的注意力矩阵
output = torch.zeros_like(V)

for i in range(0, seq_len, self.block_size):
Q_block = Q[:, i:i+self.block_size, :]

# 对每个 Q 块,计算与所有 K 的注意力
# 但只存储输出,不存储中间注意力矩阵
scores_block = torch.matmul(Q_block, K.transpose(-2, -1))
attn_block = torch.softmax(scores_block, dim=-1)
output[:, i:i+self.block_size, :] = torch.matmul(attn_block, V)

return output

2. 低秩近似( Low-Rank Approximation)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class LowRankAttention(nn.Module):
"""使用低秩矩阵近似注意力"""
def __init__(self, d_model, rank=32):
super().__init__()
self.rank = rank
self.W_q = nn.Linear(d_model, rank)
self.W_k = nn.Linear(d_model, rank)
self.W_v = nn.Linear(d_model, d_model)

def forward(self, Q, K, V):
# 将 Q, K 投影到低维空间
Q_low = self.W_q(Q) # (batch, n, rank)
K_low = self.W_k(K) # (batch, n, rank)

# 低秩注意力矩阵: O(n * rank) 而非 O(n^2)
attn_low = torch.matmul(Q_low, K_low.transpose(-2, -1))
attn_weights = torch.softmax(attn_low, dim=-1)

# 应用到 V
output = torch.matmul(attn_weights, self.W_v(V))
return output

3. 局部敏感哈希( LSH) Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class LSHAttention(nn.Module):
"""使用 LSH 快速找到相似的 Query-Key 对"""
def __init__(self, d_model, num_hashes=4, bucket_size=64):
super().__init__()
self.num_hashes = num_hashes
self.bucket_size = bucket_size
self.random_rotations = nn.Parameter(
torch.randn(num_hashes, d_model, d_model)
)

def hash_vectors(self, x):
"""将向量哈希到桶中"""
# 使用随机旋转 + argmax 作为哈希函数
rotated = torch.matmul(x, self.random_rotations)
hashes = rotated.argmax(dim=-1)
return hashes

def forward(self, Q, K, V):
# 哈希 Q 和 K
Q_hashes = self.hash_vectors(Q)
K_hashes = self.hash_vectors(K)

# 只计算相同桶内的注意力
# 复杂度从 O(n^2) 降到 O(n * bucket_size)
output = self.sparse_attention(Q, K, V, Q_hashes, K_hashes)
return output

4. 性能优化技巧总结

方法 原理 加速比 适用场景
Flash Attention 分块计算,减少内存 2-4x 长序列(>1000)
稀疏注意力 只计算局部窗口 5-10x 局部依赖强
线性注意力 改变计算顺序 3-5x 中等长度序列
低秩近似 降维投影 2-3x 特征维度大
LSH Attention 哈希快速匹配 10-20x 超长序列(>10000)

实际应用建议

1
2
3
4
5
6
7
8
9
10
# 根据序列长度选择优化方法
def get_optimal_attention(seq_len, d_model):
if seq_len < 100:
return StandardAttention(d_model) # 标准注意力足够快
elif seq_len < 1000:
return SparseAttention(d_model, window_size=100) # 稀疏注意力
elif seq_len < 10000:
return LinearAttention(d_model) # 线性注意力
else:
return LSHAttention(d_model) # LSH 注意力

实战技巧与性能优化

Attention 机制的超参数调优

1. 注意力头数( Num Heads)选择

模型维度 推荐头数 原因
d_model = 128 4-8 头 平衡表达能力和计算成本
d_model = 256 8-16 头 标准配置
d_model = 512 8-16 头 Transformer 标准
d_model = 1024 16-32 头 大模型需要更多表示子空间

选择原则

  • 必须能被头数整除
  • 每个头的维度 通常设为 64 或 128
  • 头数过多可能导致过拟合,头数过少表达能力不足
1
2
3
4
5
6
7
8
9
10
11
def get_optimal_heads(d_model):
"""根据模型维度选择最优头数"""
# 常见的头数选择
possible_heads = [4, 8, 16, 32]
# 选择能整除 d_model 且 d_k 在合理范围内的头数
for n_heads in possible_heads:
if d_model % n_heads == 0:
d_k = d_model // n_heads
if 32 <= d_k <= 128: # d_k 在合理范围
return n_heads
return 8 # 默认值

2. 缩放因子 的重要性

为什么需要缩放?防止点积值过大导致 softmax 梯度消失:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def demonstrate_scaling_importance():
"""演示缩放的重要性"""
d_k = 64
Q = torch.randn(10, d_k)
K = torch.randn(10, d_k)

# 未缩放的点积
scores_unscaled = torch.matmul(Q, K.t())
print(f"未缩放分数范围: [{scores_unscaled.min():.2f}, {scores_unscaled.max():.2f}]")

# 缩放后的点积
scores_scaled = scores_unscaled / np.sqrt(d_k)
print(f"缩放后分数范围: [{scores_scaled.min():.2f}, {scores_scaled.max():.2f}]")

# Softmax 后的分布
probs_unscaled = F.softmax(scores_unscaled, dim=-1)
probs_scaled = F.softmax(scores_scaled, dim=-1)

# 未缩放的 softmax 可能过于尖锐(接近 one-hot)
print(f"未缩放 softmax 熵: {-torch.sum(probs_unscaled * torch.log(probs_unscaled + 1e-10)):.4f}")
print(f"缩放后 softmax 熵: {-torch.sum(probs_scaled * torch.log(probs_scaled + 1e-10)):.4f}")

3. Dropout 在 Attention 中的应用

1
2
3
4
5
6
7
8
9
10
11
12
class AttentionWithDropout(nn.Module):
"""带 Dropout 的注意力机制"""
def __init__(self, d_model, nhead, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=True
)

def forward(self, x):
# Attention 内部已经应用了 Dropout
out, attn_weights = self.attention(x, x, x)
return out, attn_weights

Dropout 选择指南

  • 小数据集: 0.2-0.3
  • 中等数据集: 0.1-0.2
  • 大数据集: 0.05-0.1

时间序列 Attention 的特殊优化

1. 因果掩码的高效实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def create_causal_mask_efficient(seq_len, device):
"""高效创建因果掩码(避免存储完整矩阵)"""
# 方法 1:使用 torch.tril(下三角矩阵)
mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
return mask.bool()

# 方法 2:在计算时直接应用(更节省内存)
def causal_attention(Q, K, V):
"""直接应用因果掩码的注意力"""
seq_len = Q.size(1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(Q.size(-1))

# 创建下三角掩码
mask = torch.tril(torch.ones(seq_len, seq_len, device=Q.device))
scores = scores.masked_fill(mask == 0, float('-inf'))

attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output, attn_weights

2. 局部注意力( Local Attention)

对于长序列,可以使用局部注意力减少计算量:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class LocalAttention(nn.Module):
"""局部注意力:只关注窗口内的位置"""
def __init__(self, d_model, nhead, window_size=50):
super().__init__()
self.window_size = window_size
self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)

def forward(self, x):
batch_size, seq_len, d_model = x.size()

# 如果序列长度小于窗口大小,使用标准注意力
if seq_len <= self.window_size:
return self.attention(x, x, x)

# 否则,分块处理
outputs = []
for i in range(0, seq_len, self.window_size):
end = min(i + self.window_size, seq_len)
window_x = x[:, i:end, :]
window_out, _ = self.attention(window_x, window_x, window_x)
outputs.append(window_out)

return torch.cat(outputs, dim=1), None

3. 稀疏注意力模式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class SparseAttentionPatterns:
"""常见的稀疏注意力模式"""

@staticmethod
def strided_pattern(seq_len, stride=2):
"""步长模式:每隔 stride 个位置关注一次"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
for j in range(0, i+1, stride):
mask[i, j] = 1
return mask.bool()

@staticmethod
def dilated_pattern(seq_len, dilation=2):
"""膨胀模式:关注 dilation 倍数的位置"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
for j in range(0, i+1):
if j % dilation == 0 or j == i:
mask[i, j] = 1
return mask.bool()

@staticmethod
def random_pattern(seq_len, sparsity=0.5):
"""随机模式:随机选择 sparsity 比例的位置"""
mask = torch.rand(seq_len, seq_len) > sparsity
# 确保因果性(下三角)
mask = mask & torch.tril(torch.ones(seq_len, seq_len)).bool()
return mask

常见问题排查

问题 1: Attention 权重过于均匀(没有聚焦)

可能原因:

  • 初始化不当
  • 学习率太小
  • 数据噪声太大

解决方案:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 1. 使用 Xavier 初始化
def init_attention_weights(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

model.apply(init_attention_weights)

# 2. 增加学习率
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # 从 1e-4 增加到 1e-3

# 3. 数据去噪
from scipy import signal
denoised_data = signal.savgol_filter(data, window_length=5, polyorder=2)

问题 2: Attention 计算 OOM(内存不足)

解决方案:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 1. 使用梯度检查点
from torch.utils.checkpoint import checkpoint

class MemoryEfficientAttention(nn.Module):
def forward(self, x):
return checkpoint(self.attention, x, x, x)

# 2. 分块计算注意力
def chunked_attention(Q, K, V, chunk_size=32):
"""分块计算注意力,减少内存占用"""
batch_size, seq_len, d_model = Q.size()
outputs = []

for i in range(0, seq_len, chunk_size):
Q_chunk = Q[:, i:i+chunk_size, :]
scores = torch.matmul(Q_chunk, K.transpose(-2, -1)) / np.sqrt(d_model)
attn_weights = F.softmax(scores, dim=-1)
output_chunk = torch.matmul(attn_weights, V)
outputs.append(output_chunk)

return torch.cat(outputs, dim=1)

# 3. 使用线性注意力(降低复杂度)
# 见前面的 LinearAttention 实现

问题 3:训练不稳定( Loss 震荡)

可能原因:

  • 梯度爆炸
  • 学习率过大
  • 注意力权重分布异常

解决方案:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 1. 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 2. 学习率 warm-up
def get_warmup_lr(step, warmup_steps, d_model):
if step < warmup_steps:
return step / warmup_steps
else:
return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))

# 3. 监控注意力权重分布
def monitor_attention_distribution(attn_weights):
"""监控注意力权重分布"""
entropy = -torch.sum(attn_weights * torch.log(attn_weights + 1e-10), dim=-1)
print(f"注意力熵: {entropy.mean():.4f} ± {entropy.std():.4f}")

# 如果熵过低(接近 0),说明注意力过于集中,可能有问题
if entropy.mean() < 1.0:
print("警告:注意力权重过于集中!")

Attention 机制的性能基准测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import time
import torch

def benchmark_attention(attention_module, x, n_iter=100):
"""性能基准测试"""
attention_module.eval()

# 预热
with torch.no_grad():
_ = attention_module(x)

# 测试
torch.cuda.synchronize() if torch.cuda.is_available() else None
start = time.time()

with torch.no_grad():
for _ in range(n_iter):
_ = attention_module(x)

torch.cuda.synchronize() if torch.cuda.is_available() else None
elapsed = time.time() - start

# 计算 FLOPS(简化版本)
batch_size, seq_len, d_model = x.size()
flops = n_iter * batch_size * seq_len * seq_len * d_model * 2 # 简化的 FLOPS 计算
gflops = flops / 1e9 / elapsed

return elapsed / n_iter * 1000, gflops # 返回毫秒和 GFLOPS

# 测试不同配置
x = torch.randn(32, 100, 128)

configs = [
("标准 Attention", nn.MultiheadAttention(128, 8, batch_first=True)),
("稀疏 Attention", LocalAttention(128, 8, window_size=50)),
("线性 Attention", LinearAttention(128)),
]

for name, model in configs:
time_ms, gflops = benchmark_attention(model, x)
print(f"{name}: {time_ms:.2f}ms, {gflops:.2f} GFLOPS")

Attention 机制的可视化工具

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention_patterns(attn_weights, save_path=None):
"""可视化注意力模式"""
# attn_weights: (batch, nhead, seq_len, seq_len)
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()

# 可视化每个头的注意力模式
for head_idx in range(min(8, attn_weights.size(1))):
attn = attn_weights[0, head_idx].detach().cpu().numpy()

sns.heatmap(attn, ax=axes[head_idx], cmap='YlOrRd', cbar=True)
axes[head_idx].set_title(f'Head {head_idx}')
axes[head_idx].set_xlabel('Key Position')
axes[head_idx].set_ylabel('Query Position')

plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.show()

def analyze_attention_statistics(attn_weights):
"""分析注意力统计信息"""
# 计算每个位置的平均注意力
avg_attention = attn_weights.mean(dim=(0, 1)) # 平均 batch 和 head

# 找出最受关注的位置
top_k_positions = torch.topk(avg_attention.sum(dim=0), k=10)

print("最受关注的 10 个位置:")
for idx, score in zip(top_k_positions.indices, top_k_positions.values):
print(f"位置 {idx}: {score:.4f}")

# 计算注意力熵(衡量注意力集中程度)
entropy = -torch.sum(attn_weights * torch.log(attn_weights + 1e-10), dim=-1)
print(f"\n 注意力熵: {entropy.mean():.4f} ± {entropy.std():.4f}")

return avg_attention, top_k_positions

Attention 可视化与解释性分析

注意力权重可视化工具

1. 热力图可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_attention_heatmap(attention_weights, tokens=None, head_idx=0,
save_path=None, figsize=(12, 10)):
"""
绘制注意力权重热力图

Parameters:
-----------
attention_weights : torch.Tensor
形状为 [batch, nhead, seq_len_q, seq_len_k] 的注意力权重
tokens : list, optional
Token 列表,用于标注坐标轴
head_idx : int
要可视化的头索引
save_path : str, optional
保存路径
"""
# 提取指定头的注意力权重
attn = attention_weights[0, head_idx].detach().cpu().numpy()

plt.figure(figsize=figsize)
sns.heatmap(
attn,
cmap='YlOrRd',
xticklabels=tokens if tokens else False,
yticklabels=tokens if tokens else False,
annot=False, # 如果序列短可以设为 True 显示数值
fmt='.2f',
cbar_kws={'label': '注意力权重'}
)
plt.title(f'Attention Head {head_idx} 权重分布')
plt.xlabel('Key 位置')
plt.ylabel('Query 位置')
plt.tight_layout()

if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()

# 使用示例
# attn_weights: [1, 8, 100, 100]
# tokens = ['我', '爱', '你', ...]
# plot_attention_heatmap(attn_weights, tokens=tokens, head_idx=0)

2. 多头注意力对比可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def plot_multi_head_attention(attention_weights, tokens=None, 
n_heads_to_show=8, figsize=(20, 15)):
"""
可视化多个头的注意力模式

Parameters:
-----------
attention_weights : torch.Tensor
形状为 [batch, nhead, seq_len_q, seq_len_k]
tokens : list, optional
Token 列表
n_heads_to_show : int
要显示的头数
"""
n_heads = min(attention_weights.size(1), n_heads_to_show)
n_cols = 4
n_rows = (n_heads + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
axes = axes.flatten() if n_heads > 1 else [axes]

for head_idx in range(n_heads):
attn = attention_weights[0, head_idx].detach().cpu().numpy()

sns.heatmap(
attn,
ax=axes[head_idx],
cmap='YlOrRd',
xticklabels=False,
yticklabels=False,
cbar=True if head_idx == 0 else False
)
axes[head_idx].set_title(f'Head {head_idx}')
axes[head_idx].set_xlabel('Key 位置')
axes[head_idx].set_ylabel('Query 位置')

# 隐藏多余的子图
for idx in range(n_heads, len(axes)):
axes[idx].axis('off')

plt.tight_layout()
plt.show()

# 使用示例
# plot_multi_head_attention(attn_weights, tokens=tokens)

3. 注意力权重统计分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def analyze_attention_statistics(attention_weights):
"""
分析注意力权重的统计特性

Returns:
--------
dict: 包含各种统计信息的字典
"""
attn = attention_weights.detach().cpu().numpy()
batch_size, n_heads, seq_len_q, seq_len_k = attn.shape

stats = {
'entropy': [], # 注意力熵(衡量集中程度)
'max_weight': [], # 最大权重值
'sparsity': [], # 稀疏度(接近 0 的权重比例)
'local_focus': [] # 局部关注度(对角线附近的权重)
}

for head_idx in range(n_heads):
head_attn = attn[0, head_idx]

# 计算熵(每个 Query 位置的注意力分布熵)
entropy_per_query = []
for q_idx in range(seq_len_q):
dist = head_attn[q_idx, :]
entropy = -np.sum(dist * np.log(dist + 1e-10))
entropy_per_query.append(entropy)
stats['entropy'].append(np.mean(entropy_per_query))

# 最大权重
stats['max_weight'].append(head_attn.max())

# 稀疏度(权重小于 0.01 的比例)
stats['sparsity'].append(np.mean(head_attn < 0.01))

# 局部关注度(关注相邻 3 个位置的比例)
local_mask = np.abs(np.arange(seq_len_q)[:, None] -
np.arange(seq_len_k)[None, :]) <= 3
stats['local_focus'].append(np.mean(head_attn[local_mask]))

return stats

def plot_attention_statistics(stats):
"""可视化注意力统计信息"""
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

n_heads = len(stats['entropy'])

# 熵分布
axes[0, 0].bar(range(n_heads), stats['entropy'])
axes[0, 0].set_title('注意力熵(越高越分散)')
axes[0, 0].set_xlabel('头索引')
axes[0, 0].set_ylabel('平均熵')
axes[0, 0].axhline(y=np.log(n_heads), color='r', linestyle='--',
label='最大熵(均匀分布)')
axes[0, 0].legend()

# 最大权重
axes[0, 1].bar(range(n_heads), stats['max_weight'])
axes[0, 1].set_title('最大注意力权重')
axes[0, 1].set_xlabel('头索引')
axes[0, 1].set_ylabel('最大权重值')

# 稀疏度
axes[1, 0].bar(range(n_heads), stats['sparsity'])
axes[1, 0].set_title('注意力稀疏度(权重<0.01 的比例)')
axes[1, 0].set_xlabel('头索引')
axes[1, 0].set_ylabel('稀疏度')

# 局部关注度
axes[1, 1].bar(range(n_heads), stats['local_focus'])
axes[1, 1].set_title('局部关注度(相邻 3 个位置)')
axes[1, 1].set_xlabel('头索引')
axes[1, 1].set_ylabel('局部权重比例')

plt.tight_layout()
plt.show()

# 使用示例
# stats = analyze_attention_statistics(attn_weights)
# plot_attention_statistics(stats)

4. 注意力权重模式分类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def classify_attention_patterns(attention_weights, threshold=0.1):
"""
自动分类注意力模式

Returns:
--------
dict: 每个头的模式分类
"""
attn = attention_weights.detach().cpu().numpy()
n_heads = attn.shape[1]

patterns = {
'local': [], # 局部关注(对角线附近)
'global': [], # 全局关注(均匀分布)
'specific': [], # 特定位置(高度集中)
'periodic': [] # 周期性模式
}

for head_idx in range(n_heads):
head_attn = attn[0, head_idx]
seq_len = head_attn.shape[0]

# 计算局部性(对角线附近权重)
local_mask = np.abs(np.arange(seq_len)[:, None] -
np.arange(seq_len)[None, :]) <= 5
local_ratio = np.mean(head_attn[local_mask])

# 计算均匀性(熵)
entropy = -np.sum(head_attn * np.log(head_attn + 1e-10))
max_entropy = np.log(seq_len)
uniformity = entropy / max_entropy

# 计算集中度(最大权重)
max_weight = head_attn.max()

# 分类
if local_ratio > 0.5:
patterns['local'].append(head_idx)
elif uniformity > 0.8:
patterns['global'].append(head_idx)
elif max_weight > 0.7:
patterns['specific'].append(head_idx)
else:
patterns['periodic'].append(head_idx)

return patterns

# 使用示例
# patterns = classify_attention_patterns(attn_weights)
# print("局部关注头:", patterns['local'])
# print("全局关注头:", patterns['global'])
# print("特定位置头:", patterns['specific'])

5. 时间序列注意力可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def plot_timeseries_attention(attention_weights, timestamps, values=None,
head_idx=0, top_k=5):
"""
可视化时间序列的注意力权重

Parameters:
-----------
attention_weights : torch.Tensor
注意力权重 [batch, nhead, seq_len_q, seq_len_k]
timestamps : array-like
时间戳列表
values : array-like, optional
时间序列值(用于叠加显示)
head_idx : int
要可视化的头索引
top_k : int
显示 top-k 个最受关注的时间步
"""
attn = attention_weights[0, head_idx].detach().cpu().numpy()
seq_len = attn.shape[0]

# 计算每个时间步被关注的总权重
attention_sums = attn.sum(axis=0) # [seq_len]
top_indices = np.argsort(attention_sums)[-top_k:][::-1]

fig, axes = plt.subplots(2, 1, figsize=(15, 10))

# 上图:时间序列 + 注意力权重
if values is not None:
axes[0].plot(timestamps, values, 'b-', label='时间序列', linewidth=2)
axes[0].scatter(
timestamps[top_indices],
values[top_indices],
color='red', s=200, zorder=5, label=f'Top-{top_k}关注点'
)
axes[0].set_title('时间序列与注意力关注点')
axes[0].set_xlabel('时间')
axes[0].set_ylabel('值')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# 下图:注意力权重热力图
im = axes[1].imshow(attn, cmap='YlOrRd', aspect='auto')
axes[1].set_title(f'Attention Head {head_idx} 权重热力图')
axes[1].set_xlabel('Key 时间步')
axes[1].set_ylabel('Query 时间步')
plt.colorbar(im, ax=axes[1], label='注意力权重')

# 标注 top-k 位置
for idx in top_indices:
axes[1].axvline(x=idx, color='blue', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

print(f"Top-{top_k}最受关注的时间步:")
for i, idx in enumerate(top_indices):
print(f" {i+1}. 时间步 {idx} (权重: {attention_sums[idx]:.4f})")

# 使用示例
# timestamps = pd.date_range('2024-01-01', periods=100, freq='D')
# values = np.random.randn(100).cumsum()
# plot_timeseries_attention(attn_weights, timestamps, values, head_idx=0)

注意力解释性分析

1. 特征重要性分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def compute_feature_importance(attention_weights, input_features):
"""
计算每个输入特征对预测的贡献

Parameters:
-----------
attention_weights : torch.Tensor
注意力权重 [batch, nhead, seq_len_q, seq_len_k]
input_features : torch.Tensor
输入特征 [batch, seq_len, d_model]

Returns:
--------
importance : torch.Tensor
特征重要性 [batch, seq_len]
"""
# 对最后一个 Query 位置(预测位置)的注意力权重
last_query_attention = attention_weights[:, :, -1, :].mean(dim=1) # [batch, seq_len_k]

# 加权特征重要性
importance = last_query_attention.unsqueeze(-1) * input_features # [batch, seq_len, d_model]
importance = importance.abs().sum(dim=-1) # [batch, seq_len]

return importance

# 使用示例
# importance = compute_feature_importance(attn_weights, input_features)
# plt.bar(range(len(importance[0])), importance[0].detach().numpy())
# plt.xlabel('时间步')
# plt.ylabel('重要性')
# plt.title('各时间步对预测的贡献')
# plt.show()

2. 注意力引导的特征选择

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def attention_based_feature_selection(attention_weights, top_k=10):
"""
基于注意力权重选择最重要的特征

Returns:
--------
selected_indices : list
选中的特征索引
"""
# 平均所有头和 Query 位置
avg_attention = attention_weights.mean(dim=(0, 1, 2)) # [seq_len_k]

# 选择 top-k
top_indices = torch.topk(avg_attention, k=top_k).indices.tolist()

return top_indices

# 使用示例
# selected = attention_based_feature_selection(attn_weights, top_k=10)
# print(f"最重要的 10 个时间步: {selected}")

🎓 总结: Attention 核心要点

自注意力计算流程线

记忆口诀: > Q 问 K 答计算分数,缩放 softmax 归一权重,权重乘 V 得到输出,多头并行捕捉特征!

实战优化清单

  • 本文标题:时间序列模型(四)—— Attention 机制
  • 本文作者:Chen Kai
  • 创建时间:2020-05-02 15:00:00
  • 本文链接:https://www.chenk.top/%E6%97%B6%E9%97%B4%E5%BA%8F%E5%88%97%E6%A8%A1%E5%9E%8B%EF%BC%88%E5%9B%9B%EF%BC%89%E2%80%94%E2%80%94-Attention%E6%9C%BA%E5%88%B6/
  • 版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
 评论