在时间序列预测领域,大多数深度学习模型都像"黑盒"一样工作——我们能看到预测结果,却很难理解模型到底学到了什么。
N-BEATS(Neural Basis Expansion Analysis for Time Series)的出现改变了这一现状。这个在 M4 时间序列竞赛中夺冠的模型,不仅预测精度高,更重要的是它通过可解释的架构设计,让我们能够清晰地看到模型是如何分解时间序列的趋势和季节性成分的。
N-BEATS 的核心创新在于"基函数展开"(Basis Expansion)的思想。就像傅里叶变换将信号分解为正弦波和余弦波的组合一样,N-BEATS 通过神经网络学习基函数,将时间序列分解为趋势和季节性成分。更巧妙的是,它通过"双残差堆叠"(Double Residual Stacking)的设计,让每个块都能专注于提取特定尺度的模式,最终实现多层次的模式识别。
下面深入解析 N-BEATS 的架构设计:从可解释性架构到通用架构,从趋势块到季节性块,并附上完整的 PyTorch 实现和两个实战案例。无论你是时间序列预测的初学者,还是希望理解 M4 竞赛冠军方案的技术细节,都能找到清晰的路径。
N-BEATS 核心思想:基函数展开
从傅里叶变换到神经网络
在信号处理中,傅里叶变换告诉我们:任何周期信号都可以表示为正弦波和余弦波的线性组合。
N-BEATS 借鉴了这一思想,但用神经网络替代了固定的三角函数基。
传统的时间序列分解通常假设:
趋势是多项式函数(如线性、二次)
季节性成分是固定周期的正弦波
但现实中的时间序列往往更复杂:
趋势可能是非线性的,甚至在不同时间段有不同的斜率
季节性可能不是严格的周期,而是缓慢变化的模式
N-BEATS 的解决方案是:让神经网络学习基函数。每个"块"(Block)负责学习一组基函数,这些基函数可以灵活地表示趋势或季节性模式。
基函数展开的数学形式
假设我们要预测未来 $H$ 个时间步,给定历史 $T$ 个时间步的观测值 $y_1, y_2, \dots, y_T$。
N-BEATS 将预测分解为:
$$
\hat{y}_{T+h} = \sum_{k=1}^{K} \theta_k b_k(h)
$$
其中:
$b_k(h)$ 是第 $k$ 个基函数,输入是未来时间步 $h$
$\theta_k$ 是基函数的系数,由神经网络从历史数据中学习
对于趋势块,基函数通常是多项式:
$$
b_k^{\text{trend}}(h) = h^{k-1}
$$
对于季节性块,基函数是正弦/余弦函数:
$$
b_{2k-1}^{\text{seas}}(h) = \sin\!\left(\frac{2\pi k h}{P}\right), \qquad
b_{2k}^{\text{seas}}(h) = \cos\!\left(\frac{2\pi k h}{P}\right)
$$
其中 $P$ 是周期长度。
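基函数展开本身只是"系数 × 基函数"的线性组合。下面用一个最小的 NumPy 示例把上面的公式算一遍(theta 取手动设定的示例值,实际中由神经网络输出;H=8、P=4 等数字均为演示假设):

```python
import numpy as np

# 预测 H=8 个时间步,3 个多项式基 + 1 对傅里叶基(周期 P=4)
H, P = 8, 4
h = np.arange(H)

# 趋势基:1, h, h^2
trend_basis = np.stack([h ** k for k in range(3)], axis=0)       # [3, H]
theta_trend = np.array([10.0, 2.0, 0.1])                          # 示例系数

# 季节性基:1 个谐波的 sin/cos
seas_basis = np.stack([np.sin(2 * np.pi * h / P),
                       np.cos(2 * np.pi * h / P)], axis=0)        # [2, H]
theta_seas = np.array([3.0, -1.0])

# 预测 = 系数与基函数的线性组合
forecast = theta_trend @ trend_basis + theta_seas @ seas_basis
print(forecast.shape)  # (8,)
```

h=0 处趋势项为 10,季节项为 -1(只有 cos 非零),因此 forecast[0] = 9,可以手工验证展开公式。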
可解释性架构 vs 通用架构
N-BEATS 提供了两种架构模式:可解释性架构(Interpretable)和通用架构(Generic)。
可解释性架构
在可解释性架构中,每个块被明确指定为"趋势块"或"季节性块":
```python
blocks = [
    TrendBlock(...),
    TrendBlock(...),
    SeasonalityBlock(...),
    SeasonalityBlock(...),
]
```
优点 :
预测结果可以明确分解为趋势和季节性成分
便于业务理解和调试
符合传统时间序列分析的习惯
缺点 :
需要预先知道时间序列的特性(是否有季节性、周期是多少)
灵活性较低
通用架构
在通用架构中,所有块都是相同的通用块(Generic Block),不强制指定功能:
```python
blocks = [
    GenericBlock(...),
    GenericBlock(...),
    GenericBlock(...),
    GenericBlock(...),
]
```
优点 :
完全由数据驱动,不需要先验知识
可以学习任意复杂的模式
在 M4 竞赛中表现更好
缺点 :
预测结果无法直接分解为趋势和季节性成分,可解释性差
块的分工完全由训练决定,难以对照业务含义进行调试
如何选择?
业务场景需要可解释性 :选择可解释性架构
追求最高精度 :选择通用架构
不确定数据特性 :先用通用架构,再根据结果调整
基函数展开与残差连接
单个块的结构
每个 N-BEATS 块包含以下组件:
全连接层(FC Layers) :从历史数据中提取特征
前向扩展(Forward Expansion) :生成基函数系数 $\theta$
后向扩展(Backward Expansion) :生成回看窗口的拟合值
基函数层(Basis Layer) :应用基函数展开
让我们看一个简化的实现:
```python
class NBeatsBlock(nn.Module):
    def __init__(self, input_size, theta_size, basis_function):
        super().__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, theta_size)
        )
        self.basis_function = basis_function

    def forward(self, x):
        theta = self.fc_layers(x)
        forecast = self.basis_function.forward_expansion(theta)
        backcast = self.basis_function.backward_expansion(theta)
        return forecast, backcast
```
残差连接的作用
残差连接是 N-BEATS 的关键设计。每个块不仅预测未来,还拟合历史数据。拟合的残差会传递给下一个块:
```python
def forward_through_stacks(self, x):
    residuals = x
    total_forecast = 0  # 原文缺少初始化,这里补上
    for block in self.blocks:
        forecast, backcast = block(residuals)
        total_forecast = total_forecast + forecast
        residuals = residuals - backcast
    return total_forecast, residuals
```
这样设计的好处是:
渐进式分解 :每个块专注于提取特定尺度的模式
避免信息丢失 :残差连接确保信息不会在深层网络中丢失
多尺度特征 :不同块可以捕获不同时间尺度的模式
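残差机制可以脱离神经网络,用一个手工构造的小例子体会:"块 1"只拟合均值,"块 2"在残差上拟合线性趋势,每一步残差的范数都应下降(两个"块"均为示意,并非 N-BEATS 的真实块):

```python
import numpy as np

rng = np.random.default_rng(0)
t = np.arange(48)
x = 5.0 + 0.3 * t + rng.normal(0, 0.5, 48)   # 均值 + 线性趋势 + 噪声

# "块 1":只拟合常数(均值),backcast 即均值
backcast1 = np.full_like(x, x.mean())
res1 = x - backcast1

# "块 2":在残差上拟合线性趋势
coef = np.polyfit(t, res1, deg=1)
backcast2 = np.polyval(coef, t)
res2 = res1 - backcast2

# 残差范数逐级下降:||x|| > ||res1|| > ||res2||
print(np.linalg.norm(x), np.linalg.norm(res1), np.linalg.norm(res2))
```

每个"块"移除一种尺度的模式,留给下一个"块"的信息越来越少,这正是"渐进式分解"的含义。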
趋势块( Trend Block)详解
多项式基函数
趋势块使用多项式基函数来建模趋势。对于 $K$ 个基函数:
$$
b_k^{\text{trend}}(h) = h^{k-1}, \quad k = 1, 2, \dots, K
$$
这意味着:
$k=1$:常数项($h^0 = 1$)
$k=2$:线性项($h^1 = h$)
$k=3$:二次项($h^2$)
$k=4$:三次项($h^3$)
...
预测公式为:
$$
\hat{y}^{\text{trend}}(h) = \sum_{k=1}^{K} \theta_k h^{k-1}
$$
完整实现
```python
class TrendBasis(nn.Module):
    """趋势基函数:多项式展开"""
    def __init__(self, degree=4):
        super().__init__()
        self.degree = degree

    def forward_expansion(self, theta, forecast_length):
        """
        theta: [batch_size, degree]
        forecast_length: 预测长度 H
        返回: [batch_size, forecast_length]
        """
        batch_size = theta.shape[0]
        t = torch.arange(forecast_length, dtype=theta.dtype, device=theta.device)
        t = t.unsqueeze(0).expand(batch_size, -1)
        basis = torch.zeros(batch_size, forecast_length, self.degree,
                            dtype=theta.dtype, device=theta.device)
        for k in range(self.degree):
            basis[:, :, k] = t ** k
        forecast = torch.sum(basis * theta.unsqueeze(1), dim=2)
        return forecast

    def backward_expansion(self, theta, backcast_length):
        """拟合历史数据"""
        return self.forward_expansion(theta, backcast_length)


class TrendBlock(nn.Module):
    """趋势块"""
    def __init__(self, input_size, forecast_length, backcast_length,
                 hidden_size=512, num_layers=4, degree=4):
        super().__init__()
        self.forecast_length = forecast_length
        self.backcast_length = backcast_length
        layers = [nn.Linear(input_size, hidden_size)]
        for _ in range(num_layers - 2):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden_size, hidden_size))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size, degree))
        self.fc_stack = nn.Sequential(*layers)
        self.basis = TrendBasis(degree)

    def forward(self, x):
        """
        x: [batch_size, backcast_length]
        """
        x_flat = x.mean(dim=1).unsqueeze(1)
        theta = self.fc_stack(x_flat)
        forecast = self.basis.forward_expansion(theta, self.forecast_length)
        backcast = self.basis.backward_expansion(theta, self.backcast_length)
        return forecast, backcast
```
趋势块的特点
可解释性强 :系数 $\theta_k$ 直接对应多项式的各项
外推能力强 :多项式可以很好地外推到未来
适合长期预测 :对于趋势明显的序列,多项式能很好地捕捉长期变化
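可以用一个小实验验证多项式基的外推能力:在回看窗口上用最小二乘求出多项式系数(相当于趋势块输出的 theta,这里用闭式解代替网络学习),再代入未来时间步:

```python
import numpy as np

backcast_len, horizon = 30, 10
t_back = np.arange(backcast_len)
y_back = 100 + 1.5 * t_back + 0.02 * t_back ** 2   # 无噪声的二次趋势

# 拟合 degree=2 的多项式系数(相当于趋势块的 theta)
theta = np.polyfit(t_back, y_back, deg=2)

# 外推到未来 horizon 个时间步
t_future = np.arange(backcast_len, backcast_len + horizon)
forecast = np.polyval(theta, t_future)

truth = 100 + 1.5 * t_future + 0.02 * t_future ** 2
print(np.max(np.abs(forecast - truth)))  # 接近 0
```

对于严格多项式的趋势,外推误差只剩数值误差;真实数据有噪声时误差会更大,但多项式基仍给出平滑的趋势外推。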
季节性块(Seasonality Block)详解
傅里叶基函数
季节性块使用傅里叶基函数(正弦和余弦)来建模周期性模式:
$$
b_{2k-1}^{\text{seas}}(h) = \sin\!\left(\frac{2\pi k h}{P}\right)
$$

$$
b_{2k}^{\text{seas}}(h) = \cos\!\left(\frac{2\pi k h}{P}\right)
$$
其中:
$P$ 是周期长度(例如周季节性的日数据 $P=7$,年季节性的月数据 $P=12$)
$k$ 是谐波阶数($k = 1, 2, \dots, K$,$K$ 为谐波数 harmonics)
预测公式为:
$$
\hat{y}^{\text{seas}}(h) = \sum_{k=1}^{K} \left[ \theta_{2k-1} \sin\!\left(\frac{2\pi k h}{P}\right) + \theta_{2k} \cos\!\left(\frac{2\pi k h}{P}\right) \right]
$$
完整实现
```python
class SeasonalityBasis(nn.Module):
    """季节性基函数:傅里叶展开"""
    def __init__(self, harmonics=10, period=None):
        super().__init__()
        self.harmonics = harmonics
        self.period = period

    def forward_expansion(self, theta, forecast_length, period=None):
        """
        theta: [batch_size, 2 * harmonics]
        forecast_length: 预测长度 H
        period: 周期长度(如果为 None,使用 self.period)
        """
        if period is None:
            period = self.period if self.period is not None else forecast_length
        batch_size = theta.shape[0]
        t = torch.arange(forecast_length, dtype=theta.dtype, device=theta.device)
        t = t.unsqueeze(0).expand(batch_size, -1)
        forecast = torch.zeros(batch_size, forecast_length,
                               dtype=theta.dtype, device=theta.device)
        for k in range(1, self.harmonics + 1):
            sin_coef = theta[:, 2 * k - 2]
            cos_coef = theta[:, 2 * k - 1]
            sin_basis = torch.sin(2 * np.pi * k * t / period)
            cos_basis = torch.cos(2 * np.pi * k * t / period)
            forecast += sin_coef.unsqueeze(1) * sin_basis
            forecast += cos_coef.unsqueeze(1) * cos_basis
        return forecast

    def backward_expansion(self, theta, backcast_length, period=None):
        """拟合历史数据"""
        return self.forward_expansion(theta, backcast_length, period)


class SeasonalityBlock(nn.Module):
    """季节性块"""
    def __init__(self, input_size, forecast_length, backcast_length,
                 hidden_size=512, num_layers=4, harmonics=10, period=None):
        super().__init__()
        self.forecast_length = forecast_length
        self.backcast_length = backcast_length
        self.period = period
        layers = [nn.Linear(input_size, hidden_size)]
        for _ in range(num_layers - 2):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden_size, hidden_size))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size, 2 * harmonics))
        self.fc_stack = nn.Sequential(*layers)
        self.basis = SeasonalityBasis(harmonics, period)

    def forward(self, x):
        """
        x: [batch_size, backcast_length]
        """
        x_flat = x.mean(dim=1).unsqueeze(1)
        theta = self.fc_stack(x_flat)
        period = self.period
        if period is None:
            period = self.backcast_length
        forecast = self.basis.forward_expansion(theta, self.forecast_length, period)
        backcast = self.basis.backward_expansion(theta, self.backcast_length, period)
        return forecast, backcast
```
季节性块的特点
周期性强 :能很好地捕捉周期性模式
灵活性高 :通过多个谐波可以表示复杂的周期性
需要周期先验 :最好预先知道周期长度,但也可以通过数据学习
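同样可以用闭式最小二乘代替网络,验证少量谐波的傅里叶基就能从带噪数据中恢复周期模式(数据与谐波数均为演示假设):

```python
import numpy as np

rng = np.random.default_rng(1)
T, P, K = 140, 7, 3                      # 20 个完整周期,3 个谐波
t = np.arange(T)
signal = 2 * np.sin(2 * np.pi * t / P) + 0.5 * np.cos(4 * np.pi * t / P)
y = signal + rng.normal(0, 0.3, T)       # 加噪观测

# 构造 [T, 2K] 的傅里叶设计矩阵
cols = []
for k in range(1, K + 1):
    cols.append(np.sin(2 * np.pi * k * t / P))
    cols.append(np.cos(2 * np.pi * k * t / P))
B = np.stack(cols, axis=1)

# 最小二乘求系数(相当于季节性块的 theta)
theta, *_ = np.linalg.lstsq(B, y, rcond=None)
recovered = B @ theta
rmse = np.sqrt(np.mean((recovered - signal) ** 2))
print(rmse)  # 远小于噪声标准差 0.3
```

theta[0] 对应一阶正弦谐波的系数,应接近真实幅值 2,这正是季节性块"系数可解释"的含义。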
双残差堆叠(Double Residual Stacking)
什么是双残差堆叠?
N-BEATS 的核心创新之一是"双残差堆叠"机制。每个块产生两个输出:
Forecast(预测) :对未来时间步的预测
Backcast(回看) :对历史时间步的拟合
这两个输出通过不同的路径传播:
Forecast 向上累加,形成最终预测
Backcast 用于计算残差,传递给下一个块
可视化理解
```text
输入序列: [y1, y2, ..., yT]
        |
        v
    [Block 1]
     /      \
Forecast1  Backcast1
   |           |
   |           v
   |   残差1 = 输入 - Backcast1
   |           |
   |           v
   |       [Block 2]
   |        /      \
   |  Forecast2   Backcast2
   |     |            |
   |     |            v
   |     |    残差2 = 残差1 - Backcast2
   |     |            |
   |     |            v
   |     |        [Block 3]
   |     |           ...
   v     v
最终预测 = Forecast1 + Forecast2 + Forecast3 + ...
```
代码实现
```python
class NBeatsStack(nn.Module):
    """N-BEATS 堆叠:多个块的组合"""
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """
        x: [batch_size, backcast_length]
        返回: forecast, backcast
        """
        forecast = torch.zeros(x.shape[0], self.blocks[0].forecast_length,
                               dtype=x.dtype, device=x.device)
        residuals = x
        for block in self.blocks:
            block_forecast, block_backcast = block(residuals)
            forecast = forecast + block_forecast
            residuals = residuals - block_backcast
        return forecast, residuals


class NBeatsModel(nn.Module):
    """完整的 N-BEATS 模型"""
    def __init__(self, stacks):
        super().__init__()
        self.stacks = nn.ModuleList(stacks)

    def forward(self, x):
        """
        x: [batch_size, backcast_length]
        """
        total_forecast = torch.zeros(x.shape[0],
                                     self.stacks[0].blocks[0].forecast_length,
                                     dtype=x.dtype, device=x.device)
        for stack in self.stacks:
            stack_forecast, _ = stack(x)
            total_forecast = total_forecast + stack_forecast
        return total_forecast
```
为什么这样设计有效?
多尺度特征提取 :不同块可以专注于不同时间尺度的模式
渐进式细化 :每个块处理前一个块的残差,逐步提取更细粒度的模式
信息保留 :Forecast 的累加确保所有块学到的信息都被保留
训练稳定性 :残差连接有助于梯度流动,使深层网络更容易训练
M4 竞赛冠军方案深度分析
M4 竞赛背景
M4(Makridakis 4)是时间序列预测领域最权威的竞赛之一,包含 100,000 个时间序列,涵盖年度、季度、月度、周度、日度和小时数据。
N-BEATS 的配置
在 M4 竞赛中, N-BEATS 使用了以下配置:
通用架构 :所有块都是 Generic Block
30 个块 :分为 3 个堆叠,每个堆叠 10 个块
输入长度 :根据数据频率自适应选择
集成学习 :训练多个模型并集成
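原文未展开集成的具体方式。N-BEATS 的常见做法是对多个模型(不同回看长度、不同随机种子)的预测逐点取中位数,下面是一个与具体模型无关的聚合示意(median 聚合与示例数据均为假设):

```python
import numpy as np

def ensemble_median(predictions: list) -> np.ndarray:
    """对多个模型的预测逐点取中位数,比均值对离群预测更稳健"""
    return np.median(np.stack(predictions, axis=0), axis=0)

# 三个"模型"的预测,其中一个明显偏离
preds = [np.ones(5) * 10.0, np.ones(5) * 10.4, np.ones(5) * 30.0]
out = ensemble_median(preds)
print(out)  # 每个点取 [10.0, 10.4, 30.0] 的中位数,全为 10.4
```

中位数聚合使单个跑偏的模型不会把整体预测拉走,这是均值聚合做不到的。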
关键技巧
1. 输入长度选择
```python
def get_input_length(frequency):
    """根据数据频率选择输入长度"""
    input_lengths = {
        'Yearly': 6,
        'Quarterly': 8,
        'Monthly': 24,
        'Weekly': 13,
        'Daily': 14,
        'Hourly': 48
    }
    return input_lengths.get(frequency, 24)
```
2. 数据归一化
```python
class Normalizer:
    """数据归一化"""
    def __init__(self):
        self.scale = None

    def fit_transform(self, x):
        """拟合并转换"""
        self.scale = torch.abs(x[:, -1:]) + 1e-8
        return x / self.scale

    def inverse_transform(self, x):
        """逆转换"""
        return x * self.scale
```
3. 损失函数
N-BEATS 使用 MAPE(Mean Absolute Percentage Error)和 sMAPE 的组合:
```python
def mape_loss(y_pred, y_true):
    """平均绝对百分比误差"""
    return torch.mean(torch.abs((y_true - y_pred) / (y_true + 1e-8))) * 100

def smape_loss(y_pred, y_true):
    """对称平均绝对百分比误差"""
    numerator = torch.abs(y_pred - y_true)
    denominator = (torch.abs(y_pred) + torch.abs(y_true)) / 2 + 1e-8
    return torch.mean(numerator / denominator) * 100

def combined_loss(y_pred, y_true):
    """组合损失"""
    return mape_loss(y_pred, y_true) + smape_loss(y_pred, y_true)
```
为什么 N-BEATS 能夺冠?
架构优势 :双残差堆叠能够提取多尺度特征
通用性强 :不需要针对不同频率的数据做特殊处理
训练稳定 :残差连接使深层网络训练更稳定
可扩展性 :可以通过增加块的数量来提升性能
PyTorch 完整实现
下面是一个完整的、可运行的 N-BEATS 实现:
```python
import torch
import torch.nn as nn
import numpy as np


class GenericBasis(nn.Module):
    """通用基函数(简化实现:固定正弦基,系数由网络学习)"""
    def __init__(self, basis_size=10):
        super().__init__()
        self.basis_size = basis_size

    def forward_expansion(self, theta, forecast_length):
        """
        theta: [batch_size, basis_size]
        """
        batch_size = theta.shape[0]
        t = torch.linspace(0, 1, forecast_length, device=theta.device)
        t = t.unsqueeze(0).expand(batch_size, -1)
        basis_values = torch.zeros(batch_size, forecast_length, self.basis_size,
                                   device=theta.device)
        for i in range(self.basis_size):
            basis_values[:, :, i] = torch.sin(2 * np.pi * (i + 1) * t)
        forecast = torch.sum(basis_values * theta.unsqueeze(1), dim=2)
        return forecast

    def backward_expansion(self, theta, backcast_length):
        return self.forward_expansion(theta, backcast_length)


class GenericBlock(nn.Module):
    """通用块"""
    def __init__(self, input_size, forecast_length, backcast_length,
                 hidden_size=512, num_layers=4, basis_size=10):
        super().__init__()
        self.forecast_length = forecast_length
        self.backcast_length = backcast_length
        layers = [nn.Linear(input_size, hidden_size)]
        for _ in range(num_layers - 2):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden_size, hidden_size))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size, basis_size))
        self.fc_stack = nn.Sequential(*layers)
        self.basis = GenericBasis(basis_size)

    def forward(self, x):
        """
        x: [batch_size, backcast_length]
        """
        x_flat = x.mean(dim=1).unsqueeze(1)
        theta = self.fc_stack(x_flat)
        forecast = self.basis.forward_expansion(theta, self.forecast_length)
        backcast = self.basis.backward_expansion(theta, self.backcast_length)
        return forecast, backcast


class NBeatsStack(nn.Module):
    """N-BEATS 堆叠"""
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        forecast = torch.zeros(x.shape[0], self.blocks[0].forecast_length,
                               dtype=x.dtype, device=x.device)
        residuals = x
        for block in self.blocks:
            block_forecast, block_backcast = block(residuals)
            forecast = forecast + block_forecast
            residuals = residuals - block_backcast
        return forecast, residuals


class NBeatsModel(nn.Module):
    """完整的 N-BEATS 模型"""
    def __init__(self, forecast_length, backcast_length, num_stacks=3,
                 num_blocks_per_stack=10, hidden_size=512,
                 num_layers=4, basis_size=10):
        super().__init__()
        self.forecast_length = forecast_length
        self.backcast_length = backcast_length
        self.stacks = nn.ModuleList()
        for _ in range(num_stacks):
            blocks = [
                GenericBlock(
                    input_size=1,
                    forecast_length=forecast_length,
                    backcast_length=backcast_length,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    basis_size=basis_size
                )
                for _ in range(num_blocks_per_stack)
            ]
            self.stacks.append(NBeatsStack(blocks))

    def forward(self, x):
        """
        x: [batch_size, backcast_length]
        """
        total_forecast = torch.zeros(x.shape[0], self.forecast_length,
                                     dtype=x.dtype, device=x.device)
        for stack in self.stacks:
            stack_forecast, _ = stack(x)
            total_forecast = total_forecast + stack_forecast
        return total_forecast


if __name__ == "__main__":
    model = NBeatsModel(
        forecast_length=24,
        backcast_length=48,
        num_stacks=3,
        num_blocks_per_stack=10
    )
    batch_size = 32
    x = torch.randn(batch_size, 48)
    forecast = model(x)
    print(f"输入形状: {x.shape}")
    print(f"预测形状: {forecast.shape}")
```
实战案例一:零售销售预测
问题描述
预测某零售店未来 7 天的日销售额。历史数据包含过去 90 天的销售额,有明显的周季节性(周末销售额较高)。
数据准备
```python
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class SalesDataset(Dataset):
    def __init__(self, data, backcast_length=28, forecast_length=7):
        self.data = data
        self.backcast_length = backcast_length
        self.forecast_length = forecast_length

    def __len__(self):
        return len(self.data) - self.backcast_length - self.forecast_length + 1

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.backcast_length]
        y = self.data[idx + self.backcast_length:
                      idx + self.backcast_length + self.forecast_length]
        return torch.FloatTensor(x), torch.FloatTensor(y)

def generate_sales_data(n_days=365):
    """生成带趋势和季节性的销售数据"""
    t = np.arange(n_days)
    trend = 1000 + 2 * t
    seasonality = 200 * np.sin(2 * np.pi * t / 7) + 100 * np.cos(2 * np.pi * t / 7)
    noise = np.random.normal(0, 50, n_days)
    sales = trend + seasonality + noise
    return sales

sales_data = generate_sales_data(365)
dataset = SalesDataset(sales_data, backcast_length=28, forecast_length=7)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
```
模型训练
```python
import torch
import torch.nn as nn
from nbeats import SeasonalityBlock, NBeatsStack, NBeatsModel

blocks_stack1 = [
    SeasonalityBlock(
        input_size=1,
        forecast_length=7,
        backcast_length=28,
        harmonics=5,
        period=7
    )
    for _ in range(5)
]

blocks_stack2 = [
    SeasonalityBlock(
        input_size=1,
        forecast_length=7,
        backcast_length=28,
        harmonics=3,
        period=7
    )
    for _ in range(5)
]

stacks = [
    NBeatsStack(blocks_stack1),
    NBeatsStack(blocks_stack2)
]
model = NBeatsModel(stacks)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

for epoch in range(100):
    total_loss = 0
    for x, y in dataloader:
        optimizer.zero_grad()
        forecast = model(x)
        loss = criterion(forecast, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
```
结果分析
训练完成后,可以分析每个块学到的模式:
```python
import matplotlib.pyplot as plt

model.eval()
with torch.no_grad():
    # 留出最后 7 天作为真实值,输入窗口取其之前的 28 天,避免信息泄漏
    x_test = torch.FloatTensor(sales_data[-35:-7]).unsqueeze(0)
    forecast = model(x_test)

plt.figure(figsize=(12, 6))
plt.plot(range(-28, 0), sales_data[-35:-7], label='历史数据', marker='o')
plt.plot(range(0, 7), forecast[0].numpy(), label='预测', marker='s')
plt.plot(range(0, 7), sales_data[-7:], label='真实值', marker='x')
plt.legend()
plt.title('零售销售预测结果')
plt.xlabel('天数')
plt.ylabel('销售额')
plt.grid(True)
plt.show()
```
实战案例二:电力负荷预测
问题描述
预测未来 24 小时的电力负荷。历史数据包含过去 48 小时的负荷,具有明显的日周期性和周周期性。
数据特点
电力负荷数据通常具有:
日周期性 :白天负荷高,夜间负荷低
周周期性 :工作日和周末的负荷模式不同
趋势性 :长期可能有缓慢变化
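正文没有给出电力负荷数据的构造方法。仿照案例一,可以生成同时含缓慢趋势、日周期(24 小时)和周周期(168 小时)的模拟负荷数据,幅值参数均为示例假设:

```python
import numpy as np

def generate_load_data(n_hours=24 * 7 * 8, seed=42):
    """生成模拟小时级电力负荷:趋势 + 日周期 + 周周期 + 噪声"""
    rng = np.random.default_rng(seed)
    t = np.arange(n_hours)
    trend = 500 + 0.05 * t                        # 缓慢上升的基线负荷
    daily = 120 * np.sin(2 * np.pi * t / 24)      # 日周期
    weekly = 60 * np.sin(2 * np.pi * t / 168)     # 周周期(工作日/周末差异的简化)
    noise = rng.normal(0, 15, n_hours)
    return trend + daily + weekly + noise

load = generate_load_data()
print(load.shape)  # (1344,)
```

生成的序列可直接替换案例一中的 sales_data,配合下文的混合架构做滑窗训练。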
模型设计
对于这种多周期的情况,可以使用混合架构:
```python
trend_blocks = [
    TrendBlock(
        input_size=1,
        forecast_length=24,
        backcast_length=48,
        degree=3
    )
    for _ in range(3)
]

daily_seasonal_blocks = [
    SeasonalityBlock(
        input_size=1,
        forecast_length=24,
        backcast_length=48,
        harmonics=5,
        period=24
    )
    for _ in range(5)
]

weekly_seasonal_blocks = [
    SeasonalityBlock(
        input_size=1,
        forecast_length=24,
        backcast_length=48,
        harmonics=3,
        period=168
    )
    for _ in range(5)
]

trend_stack = NBeatsStack(trend_blocks)
daily_stack = NBeatsStack(daily_seasonal_blocks)
weekly_stack = NBeatsStack(weekly_seasonal_blocks)
model = NBeatsModel([trend_stack, daily_stack, weekly_stack])
```
训练与评估
```python
def train_electricity_model(model, dataloader, num_epochs=100):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            forecast = model(x)
            loss = criterion(forecast, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
    return model

def evaluate_model(model, test_loader):
    model.eval()
    mae_list = []
    mape_list = []
    with torch.no_grad():
        for x, y in test_loader:
            forecast = model(x)
            mae = torch.mean(torch.abs(forecast - y))
            mape = torch.mean(torch.abs((forecast - y) / (y + 1e-8))) * 100
            mae_list.append(mae.item())
            mape_list.append(mape.item())
    print(f"平均 MAE: {np.mean(mae_list):.2f}")
    print(f"平均 MAPE: {np.mean(mape_list):.2f}%")
```
❓ Q&A: N-BEATS 常见问题
Q1: N-BEATS 和传统时间序列模型(如 ARIMA)有什么区别?
A : 主要区别在于:
非线性建模能力 :N-BEATS 使用神经网络,可以学习复杂的非线性模式;ARIMA 是线性模型
可解释性 :N-BEATS 的可解释性架构可以明确分解趋势和季节性;ARIMA 的参数也有统计意义,但解释方式不同
特征工程 :N-BEATS 可以自动学习特征;ARIMA 需要手动选择参数(p, d, q)
数据要求 :ARIMA 需要平稳序列,可能需要差分;N-BEATS 可以直接处理非平稳序列
Q2: 如何选择输入长度(backcast_length)?
A : 输入长度的选择取决于:
数据频率 :
年度数据: 6-10 个时间步
季度数据: 8-12 个时间步
月度数据: 24-36 个时间步
周度数据: 13-26 个时间步
日度数据: 14-30 个时间步
小时数据: 48-168 个时间步
预测长度 :通常输入长度是预测长度的 2-4 倍
数据特性 :如果数据有明显的长期趋势,需要更长的输入窗口
Q3: 通用架构和可解释性架构哪个更好?
A : 这取决于你的目标:
追求最高精度 :选择通用架构,在 M4 竞赛中表现更好
需要业务解释 :选择可解释性架构,可以明确看到趋势和季节性分解
不确定数据特性 :先用通用架构,如果效果不好再尝试可解释性架构
Q4: 如何确定基函数的数量(basis_size 或 harmonics)?
A : 基函数数量的选择:
趋势块 :通常 3-5 个多项式项就足够(degree=3 到 5)
季节性块 :谐波数量(harmonics)取决于周期的复杂性:
简单周期: 3-5 个谐波
复杂周期: 5-10 个谐波
非常复杂的周期: 10-20 个谐波
通用块 :通常 10-20 个基函数
建议 :从较小的数量开始,如果欠拟合再增加。
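"从小到大"的选择过程可以写成一次简单的验证集扫描:用闭式最小二乘近似季节性块的系数学习,在留出的验证段上比较不同谐波数的误差(数据与候选集合均为演示假设):

```python
import numpy as np

def fourier_design(t, P, K):
    """构造 [len(t), 2K] 的傅里叶设计矩阵"""
    cols = []
    for k in range(1, K + 1):
        cols.append(np.sin(2 * np.pi * k * t / P))
        cols.append(np.cos(2 * np.pi * k * t / P))
    return np.stack(cols, axis=1)

def select_harmonics(y, P, candidates=(1, 2, 3, 5, 8), train_ratio=0.8):
    """在前 train_ratio 段拟合,在剩余段上选验证 MSE 最小的谐波数"""
    t = np.arange(len(y))
    cut = int(len(y) * train_ratio)
    best_k, best_err = None, np.inf
    for K in candidates:
        B = fourier_design(t, P, K)
        theta, *_ = np.linalg.lstsq(B[:cut], y[:cut], rcond=None)
        err = np.mean((B[cut:] @ theta - y[cut:]) ** 2)
        if err < best_err:
            best_k, best_err = K, err
    return best_k

rng = np.random.default_rng(0)
t = np.arange(210)
y = np.sin(2 * np.pi * t / 7) + 0.4 * np.sin(4 * np.pi * t / 7) + rng.normal(0, 0.1, 210)
print(select_harmonics(y, P=7))
```

示例数据包含一阶和二阶谐波,因此扫描结果至少应为 2;只有 1 个谐波会漏掉二阶成分,验证误差明显更高。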
Q5: N-BEATS 的训练时间很长,如何加速?
A : 加速训练的方法:
减少块的数量 :从 30 个块减少到 10-15 个
减少隐藏层大小 :从 512 减少到 256 或 128
使用 GPU : N-BEATS 可以很好地利用 GPU 并行计算
批量大小 :增加批量大小可以提高 GPU 利用率
混合精度训练 :使用 torch.cuda.amp 进行半精度训练
```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for x, y in dataloader:
    optimizer.zero_grad()
    with autocast():
        forecast = model(x)
        loss = criterion(forecast, y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
Q6: N-BEATS 如何处理多变量时间序列?
A : N-BEATS 最初设计用于单变量时间序列。对于多变量情况,有几种方法:
独立预测 :对每个变量分别训练一个 N-BEATS 模型
特征融合 :将多个变量作为输入特征:
```python
class MultiVariateNBeatsBlock(nn.Module):
    def __init__(self, input_size, num_variables, hidden_size=512):
        super().__init__()
        # 示意:把 num_variables 个变量的回看窗口展平后共享一个全连接栈
        self.fc_stack = nn.Sequential(
            nn.Linear(input_size * num_variables, hidden_size),
            nn.ReLU(),
        )
```
使用其他模型 :考虑 N-BEATS 的扩展版本,如 N-HiTS 或 Temporal Fusion Transformer
Q7: 如何处理缺失值?
A : 处理缺失值的方法:
前向填充 :用前一个值填充
插值 :线性插值或样条插值
模型处理 :在输入层添加掩码,让模型学习处理缺失值:
```python
class NBeatsBlockWithMasking(nn.Module):
    def forward(self, x, mask):
        # mask: 与 x 同形状,1 表示观测到,0 表示缺失
        x_masked = x * mask
        ...
```
Q8: N-BEATS 的预测结果不稳定怎么办?
A : 提高稳定性的方法:
集成学习 :训练多个模型并平均预测结果
正则化 :添加 L1/L2 正则化或 Dropout
学习率调度 :使用学习率衰减
梯度裁剪 :防止梯度爆炸
数据归一化 :确保输入数据在合理范围内
```python
def ensemble_predict(models, x):
    predictions = []
    for model in models:
        model.eval()
        with torch.no_grad():
            pred = model(x)
        predictions.append(pred)
    return torch.stack(predictions).mean(dim=0)
```
Q9: 如何解释 N-BEATS 学到的趋势和季节性?
A : 对于可解释性架构:
趋势块 :查看多项式系数 $\theta_k$,了解趋势的形状
季节性块 :可视化学到的傅里叶基函数:
```python
def visualize_seasonality(block, period=7):
    # 随机 theta 仅用于演示接口;实际分析应使用模型对真实输入产生的 theta
    theta = torch.randn(1, 2 * block.basis.harmonics)
    seasonal_pattern = block.basis.forward_expansion(theta, period, period)
    plt.plot(seasonal_pattern[0].numpy())
    plt.title('学到的季节性模式')
    plt.show()
```
残差分析 :查看每个块处理后的残差,了解还有哪些模式未被提取
Q10: N-BEATS 适合哪些应用场景?
A : N-BEATS 适合的场景:
单变量时间序列预测 :零售销售、电力负荷、网站流量等
需要可解释性的场景 :业务分析、决策支持
中等长度的预测 :从几小时到几个月
有明显趋势或季节性的数据 :N-BEATS 的优势在于分解这些成分
不适合的场景 :
1. 多变量强相关 :变量间有复杂依赖关系
2. 外部特征重要 :需要结合外部变量(如天气、事件)
3. 极短期预测 :对几秒钟或几分钟级别的任务,N-BEATS 可能过于复杂
4. 实时性要求极高 :N-BEATS 的计算相对较慢
Q11: 如何评估 N-BEATS 模型的性能?
A : 评估 N-BEATS 需要考虑多个指标:
1. 预测精度指标 :
```python
def evaluate_nbeats(y_pred, y_true):
    """计算多个评估指标"""
    metrics = {}
    metrics['MAE'] = torch.mean(torch.abs(y_pred - y_true)).item()
    metrics['RMSE'] = torch.sqrt(torch.mean((y_pred - y_true) ** 2)).item()
    metrics['MAPE'] = (torch.mean(torch.abs((y_true - y_pred) /
                                            (y_true + 1e-8))) * 100).item()
    numerator = torch.abs(y_pred - y_true)
    denominator = (torch.abs(y_pred) + torch.abs(y_true)) / 2 + 1e-8
    metrics['sMAPE'] = (torch.mean(numerator / denominator) * 100).item()
    # MASE:以朴素预测(前一时刻值)的 MAE 作为分母
    naive_forecast = y_true[:-1]
    naive_mae = torch.mean(torch.abs(y_true[1:] - naive_forecast))
    metrics['MASE'] = metrics['MAE'] / (naive_mae.item() + 1e-8)
    return metrics

metrics = evaluate_nbeats(predictions, ground_truth)
for name, value in metrics.items():
    print(f'{name}: {value:.4f}')
```
2. 可解释性评估 :
对于可解释性架构,可以评估趋势和季节性的分解质量:
```python
def evaluate_decomposition(model, x, y_true):
    """评估分解质量"""
    model.eval()
    with torch.no_grad():
        trends, seasonalities = [], []
        residuals = x.clone()
        for stack in model.stacks:
            for block in stack.blocks:
                forecast, backcast = block(residuals)
                if isinstance(block.basis, TrendBasis):
                    trends.append(forecast)
                elif isinstance(block.basis, SeasonalityBasis):
                    seasonalities.append(forecast)
                residuals = residuals - backcast
        total_trend = sum(trends)
        total_seasonality = sum(seasonalities)
        reconstruction = total_trend + total_seasonality
        recon_error = torch.mean((reconstruction - y_true) ** 2)
        trend_smoothness = torch.mean(torch.abs(torch.diff(total_trend)))
        seasonality_fft = torch.fft.fft(total_seasonality)
        periodicity_strength = torch.max(torch.abs(seasonality_fft[1:]))
    return {
        'reconstruction_error': recon_error.item(),
        'trend_smoothness': trend_smoothness.item(),
        'periodicity_strength': periodicity_strength.item()
    }
```
3. 稳定性评估 :
```python
def evaluate_stability(model, test_loader, n_runs=10):
    """评估模型预测的稳定性(对含 Dropout 等随机成分的模型才有区分度)"""
    predictions_list = []
    model.eval()
    with torch.no_grad():
        for _ in range(n_runs):
            predictions = []
            for x, _ in test_loader:
                pred = model(x)
                predictions.append(pred)
            predictions_list.append(torch.cat(predictions))
    predictions_tensor = torch.stack(predictions_list)
    std = predictions_tensor.std(dim=0).mean().item()
    mean_pred = predictions_tensor.mean(dim=0)
    cv = (std / (mean_pred.abs().mean() + 1e-8)).item()
    return {
        'prediction_std': std,
        'coefficient_of_variation': cv
    }
```
Q12: N-BEATS 如何处理非平稳时间序列?
A : N-BEATS 通过以下方式处理非平稳序列:
1. 数据归一化 :
```python
class AdaptiveNormalizer:
    """自适应归一化:处理非平稳序列"""
    def __init__(self, method='last_value'):
        self.method = method
        self.scale = None

    def fit_transform(self, x):
        """根据方法选择归一化策略"""
        if self.method == 'last_value':
            self.scale = torch.abs(x[:, -1:]) + 1e-8
        elif self.method == 'rolling_mean':
            window = min(10, x.size(1) // 2)
            self.scale = x[:, -window:].mean(dim=1, keepdim=True).abs() + 1e-8
        elif self.method == 'min_max':
            self.scale_min = x.min(dim=1, keepdim=True)[0]
            self.scale_max = x.max(dim=1, keepdim=True)[0]
            self.scale = self.scale_max - self.scale_min + 1e-8
            return (x - self.scale_min) / self.scale
        return x / self.scale

    def inverse_transform(self, x):
        """逆变换"""
        if self.method == 'min_max':
            return x * self.scale + self.scale_min
        return x * self.scale
```
2. 差分处理 :
对于有明显趋势的非平稳序列,可以先差分再预测:
```python
def difference_series(x, order=1):
    """时间序列差分"""
    diff = x.clone()
    for _ in range(order):
        diff = diff[:, 1:] - diff[:, :-1]
    return diff

def inverse_difference(diff, original_last_values, order=1):
    """逆差分:original_last_values 形状为 [batch_size, 1]"""
    restored = diff.clone()
    for _ in range(order):
        restored = torch.cat([original_last_values, restored], dim=1)
        restored = restored.cumsum(dim=1)
        original_last_values = restored[:, -1:]
    return restored

original_series = torch.randn(32, 100).cumsum(dim=1)
differenced = difference_series(original_series, order=1)
model = NBeatsModel(...)
forecast_diff = model(differenced[:, -backcast_length:])
last_values = original_series[:, -1:]
forecast_original = inverse_difference(forecast_diff, last_values, order=1)
```
3. 自适应块选择 :
对于非平稳序列,可以使用更多趋势块:
```python
def create_adaptive_architecture(data_complexity='high'):
    """根据数据复杂度创建自适应架构"""
    if data_complexity == 'high':
        blocks = [
            TrendBlock(...) for _ in range(10)
        ] + [
            SeasonalityBlock(...) for _ in range(5)
        ]
    else:
        blocks = [
            TrendBlock(...) for _ in range(5)
        ] + [
            SeasonalityBlock(...) for _ in range(5)
        ]
    return blocks
```
实战技巧与性能优化
N-BEATS 超参数调优完整指南
1. 堆叠和块的数量选择
| 数据复杂度 | 推荐堆叠数 | 每堆叠块数 | 总块数 | 说明 |
| --- | --- | --- | --- | --- |
| 简单模式 | 2 | 5-10 | 10-20 | 单变量、短序列 |
| 中等复杂度 | 3 | 8-10 | 24-30 | 多变量、中等序列 |
| 复杂模式 | 3-4 | 10-15 | 30-60 | 长序列、多尺度依赖 |
M4 竞赛配置 :3 个堆叠,每个堆叠 10 个块,共 30 个块。
选择原则 :
从较小的配置开始( 2 堆叠× 5 块)
如果欠拟合,逐步增加块数
如果过拟合,减少块数或增加 Dropout
```python
def get_optimal_config(data_complexity='medium'):
    """根据数据复杂度推荐配置"""
    configs = {
        'simple': {'num_stacks': 2, 'num_blocks': 5, 'hidden_size': 256},
        'medium': {'num_stacks': 3, 'num_blocks': 10, 'hidden_size': 512},
        'complex': {'num_stacks': 3, 'num_blocks': 15, 'hidden_size': 512},
    }
    return configs.get(data_complexity, configs['medium'])
```
2. 隐藏层大小( Hidden Size)
| 数据规模 | 推荐 hidden_size | 说明 |
| --- | --- | --- |
| < 1,000 样本 | 256 | 避免过拟合 |
| 1,000-10,000 | 512 | 标准配置 |
| > 10,000 | 512-1024 | 充分表达能力 |
3. 基函数数量( Basis Size / Harmonics)
趋势块 :
degree = 3-5(多项式次数)
通常 3-4 次就足够表示大多数趋势
季节性块 :
harmonics = 5-10(谐波数量)
简单周期: 3-5 个谐波
复杂周期: 10-20 个谐波
通用块 :
basis_size = 10-20
M4 竞赛中使用 10 个基函数
4. 输入长度( Backcast Length)选择
```python
def get_input_length(frequency, forecast_length):
    """根据数据频率和预测长度选择输入长度"""
    multipliers = {
        'Yearly': 1.5,
        'Quarterly': 2.0,
        'Monthly': 3.0,
        'Weekly': 2.0,
        'Daily': 2.0,
        'Hourly': 3.5,
    }
    multiplier = multipliers.get(frequency, 2.0)
    return int(forecast_length * multiplier)
```
数据预处理最佳实践
1. 归一化策略
N-BEATS 对归一化很敏感,推荐使用最后值归一化:
```python
from sklearn.preprocessing import MinMaxScaler, StandardScaler

class LastValueNormalizer:
    """最后值归一化(N-BEATS 推荐)"""
    def __init__(self):
        self.scale = None

    def fit_transform(self, x):
        """
        x: [batch_size, seq_len] 或 [seq_len]
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        self.scale = torch.abs(x[:, -1:]) + 1e-8
        return x / self.scale

    def inverse_transform(self, x):
        """逆变换"""
        return x * self.scale

def compare_normalizers(data):
    """对比不同归一化方法的效果(sklearn 的 scaler 需要 numpy 数组输入)"""
    normalizers = {
        'LastValue': LastValueNormalizer(),
        'MinMax': MinMaxScaler(),
        'Standard': StandardScaler(),
    }
    results = {}
    for name, normalizer in normalizers.items():
        normalized = normalizer.fit_transform(data)
        results[name] = {
            'mean': normalized.mean(),
            'std': normalized.std(),
            'range': [normalized.min(), normalized.max()]
        }
    return results
```
2. 缺失值处理
```python
def handle_missing_values(data, method='forward_fill'):
    """处理缺失值(data 为 pandas Series/DataFrame)"""
    if method == 'forward_fill':
        return data.ffill().bfill()
    elif method == 'interpolate':
        return data.interpolate(method='linear')
    elif method == 'mean':
        return data.fillna(data.mean())
    else:
        raise ValueError(f"Unknown method: {method}")
```
3. 异常值检测与处理
```python
def detect_and_handle_outliers(data, method='iqr', factor=1.5):
    """异常值检测和处理"""
    if method == 'iqr':
        Q1 = data.quantile(0.25)
        Q3 = data.quantile(0.75)
        IQR = Q3 - Q1
        lower_bound = Q1 - factor * IQR
        upper_bound = Q3 + factor * IQR
        return data.clip(lower_bound, upper_bound)
    elif method == 'zscore':
        z_scores = np.abs((data - data.mean()) / data.std())
        return data[z_scores < 3]
    else:
        return data
```
训练优化技巧
1. 损失函数选择
N-BEATS 在 M4 竞赛中使用 MAPE 和 sMAPE 的组合:
```python
class NBeatsLoss(nn.Module):
    """N-BEATS 专用损失函数"""
    def __init__(self, mape_weight=0.5, smape_weight=0.5):
        super().__init__()
        self.mape_weight = mape_weight
        self.smape_weight = smape_weight

    def mape(self, y_pred, y_true):
        """平均绝对百分比误差"""
        return torch.mean(torch.abs((y_true - y_pred) / (y_true + 1e-8))) * 100

    def smape(self, y_pred, y_true):
        """对称平均绝对百分比误差"""
        numerator = torch.abs(y_pred - y_true)
        denominator = (torch.abs(y_pred) + torch.abs(y_true)) / 2 + 1e-8
        return torch.mean(numerator / denominator) * 100

    def forward(self, y_pred, y_true):
        mape_loss = self.mape(y_pred, y_true)
        smape_loss = self.smape(y_pred, y_true)
        return self.mape_weight * mape_loss + self.smape_weight * smape_loss
```
2. 学习率调度
```python
def get_nbeats_scheduler(optimizer, num_epochs):
    """N-BEATS 学习率调度(余弦退火)"""
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=1e-6
    )

# 备选:按验证损失自动衰减
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=10
)
```
3. 梯度裁剪
虽然 N-BEATS 通常训练稳定,但对于深层网络仍建议使用梯度裁剪:
```python
def train_nbeats(model, dataloader, optimizer, criterion, num_epochs):
    """N-BEATS 训练循环"""
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            forecast = model(x)
            loss = criterion(forecast, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
```
常见问题排查
问题 1:预测结果过于平滑(欠拟合)
可能原因:模型容量不足(块数、隐藏层宽度或基函数数量偏少)。
解决方案:
```python
# 1. 增加块的数量
model = NBeatsModel(
    num_stacks=3,
    num_blocks_per_stack=15,
    ...
)

# 2. 增大隐藏层
model = NBeatsModel(
    hidden_size=1024,
    ...
)

# 3. 提高基函数的表达能力
trend_block = TrendBlock(degree=5)
seasonal_block = SeasonalityBlock(harmonics=15)
```
问题 2:预测结果震荡(过拟合)
Solutions:

```python
# 1. Add dropout inside the blocks (sketch)
class NBeatsBlockWithDropout(nn.Module):
    def __init__(self, ...):
        ...
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x_flat = x.mean(dim=1).unsqueeze(1)
        x_flat = self.dropout(x_flat)
        theta = self.fc_stack(x_flat)
        ...

# 2. Reduce model capacity
model = NBeatsModel(
    num_blocks_per_stack=5,
    hidden_size=256,
    ...
)

# 3. Add weight decay (L2 regularization)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5
)
```
Problem 3: training is slow
Optimizations:

```python
# 1. Move model and data to the GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 2. Use a larger batch size
dataloader = DataLoader(dataset, batch_size=64)

# 3. Mixed-precision training
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
with autocast():
    forecast = model(x)
    loss = criterion(forecast, y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# 4. Use a smaller model
model = NBeatsModel(num_blocks_per_stack=5)
```
Model Interpretability Analysis
1. Visualizing the trend and seasonality decomposition
```python
import matplotlib.pyplot as plt

def visualize_decomposition(model, x):
    """Visualize the decomposition produced by N-BEATS."""
    model.eval()
    trends = []
    seasonalities = []
    residuals = x.clone()

    with torch.no_grad():
        for stack in model.stacks:
            for block in stack.blocks:
                forecast, backcast = block(residuals)
                if isinstance(block.basis, TrendBasis):
                    trends.append(forecast)
                elif isinstance(block.basis, SeasonalityBasis):
                    seasonalities.append(forecast)
                residuals = residuals - backcast

    fig, axes = plt.subplots(4, 1, figsize=(15, 12))

    axes[0].plot(x[0].numpy(), label='Original data')
    axes[0].set_title('Original time series')
    axes[0].legend()

    total_trend = sum(trends)
    axes[1].plot(total_trend[0].numpy(), label='Trend')
    axes[1].set_title('Trend component')
    axes[1].legend()

    total_seasonality = sum(seasonalities)
    axes[2].plot(total_seasonality[0].numpy(), label='Seasonality')
    axes[2].set_title('Seasonal component')
    axes[2].legend()

    axes[3].plot(residuals[0].numpy(), label='Residual')
    axes[3].set_title('Residual')
    axes[3].legend()

    plt.tight_layout()
    plt.show()
```
2. Block contribution analysis
```python
def analyze_block_contributions(model, x):
    """Measure each block's contribution to the final forecast."""
    model.eval()
    block_contributions = []
    total_forecast = None
    residuals = x  # the first block sees the raw input

    with torch.no_grad():
        for stack_idx, stack in enumerate(model.stacks):
            for block_idx, block in enumerate(stack.blocks):
                forecast, backcast = block(residuals)
                block_contributions.append({
                    'stack': stack_idx,
                    'block': block_idx,
                    'forecast_norm': forecast.norm().item(),
                    'forecast_mean': forecast.mean().item(),
                })
                total_forecast = forecast if total_forecast is None else total_forecast + forecast
                # Double residual: every subsequent block sees what is left over
                residuals = residuals - backcast

    contributions = [c['forecast_norm'] for c in block_contributions]
    plt.figure(figsize=(12, 6))
    plt.bar(range(len(contributions)), contributions)
    plt.xlabel('Block index')
    plt.ylabel('Forecast norm')
    plt.title('Per-block contribution to the forecast')
    plt.show()

    return block_contributions
```
N-BEATS Engineering Practice
Model Versioning and A/B Testing
1. Model version management
```python
import json
from datetime import datetime
import torch

class ModelVersionManager:
    """Model version manager."""
    def __init__(self, model_dir='./models'):
        self.model_dir = model_dir
        self.version_file = f'{model_dir}/versions.json'
        self.versions = self._load_versions()

    def _load_versions(self):
        """Load version metadata from disk."""
        try:
            with open(self.version_file, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    def save_model(self, model, metrics, config, version_name=None):
        """Save a model checkpoint along with its version metadata."""
        if version_name is None:
            version_name = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_path = f'{self.model_dir}/nbeats_{version_name}.pt'
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': config,
            'metrics': metrics,
            'timestamp': datetime.now().isoformat()
        }, model_path)
        self.versions[version_name] = {
            'model_path': model_path,
            'metrics': metrics,
            'config': config,
            'timestamp': datetime.now().isoformat()
        }
        with open(self.version_file, 'w') as f:
            json.dump(self.versions, f, indent=2)
        return version_name

    def load_model(self, version_name):
        """Load a specific model version."""
        if version_name not in self.versions:
            raise ValueError(f'Version {version_name} does not exist')
        checkpoint = torch.load(self.versions[version_name]['model_path'])
        return checkpoint

    def get_best_model(self, metric='sMAPE', lower_is_better=True):
        """Return the best version according to the given metric."""
        best_version = None
        best_score = float('inf') if lower_is_better else float('-inf')
        for version, info in self.versions.items():
            score = info['metrics'].get(metric, float('inf'))
            if (lower_is_better and score < best_score) or \
               (not lower_is_better and score > best_score):
                best_score = score
                best_version = version
        return best_version, self.versions[best_version]

# Usage
version_manager = ModelVersionManager()
version_name = version_manager.save_model(
    model=model,
    metrics={'sMAPE': 12.5, 'MAE': 45.2},
    config={'num_stacks': 3, 'num_blocks': 10}
)
best_version, best_info = version_manager.get_best_model('sMAPE')
print(f'Best model version: {best_version}, sMAPE: {best_info["metrics"]["sMAPE"]}')
```
2. A/B testing framework
```python
import hashlib
import numpy as np
from scipy import stats

class ABTestFramework:
    """A/B testing framework for comparing two models."""
    def __init__(self, model_a, model_b, traffic_split=0.5):
        self.model_a = model_a
        self.model_b = model_b
        self.traffic_split = traffic_split
        self.results = {'A': [], 'B': []}

    def predict(self, x, user_id=None):
        """Route a request to model A or B, by user-ID hash or at random."""
        if user_id is not None:
            # Deterministic assignment so a user always sees the same model
            hash_val = int(hashlib.md5(str(user_id).encode()).hexdigest(), 16)
            use_model_a = (hash_val % 100) < (self.traffic_split * 100)
        else:
            use_model_a = np.random.random() < self.traffic_split

        if use_model_a:
            pred = self.model_a(x)
            group = 'A'
        else:
            pred = self.model_b(x)
            group = 'B'
        return pred, group

    def log_result(self, group, y_pred, y_true):
        """Record the mean absolute error of a prediction."""
        error = torch.abs(y_pred - y_true).mean().item()
        self.results[group].append(error)

    def compare_results(self):
        """Compare the two groups with a two-sample t-test."""
        if len(self.results['A']) == 0 or len(self.results['B']) == 0:
            return None

        mean_a = np.mean(self.results['A'])
        mean_b = np.mean(self.results['B'])
        t_stat, p_value = stats.ttest_ind(self.results['A'], self.results['B'])
        improvement = (mean_a - mean_b) / mean_b * 100

        return {
            'mean_error_A': mean_a,
            'mean_error_B': mean_b,
            'improvement': improvement,
            'p_value': p_value,
            'significant': p_value < 0.05
        }

# Usage
ab_test = ABTestFramework(model_a=model_v1, model_b=model_v2, traffic_split=0.5)
for x, y_true in test_loader:
    pred, group = ab_test.predict(x, user_id=123)
    ab_test.log_result(group, pred, y_true)

comparison = ab_test.compare_results()
print(f"Model B improvement over model A: {comparison['improvement']:.2f}%")
print(f"Statistically significant: {comparison['significant']}")
```
Production Deployment Optimization
1. Model compression and acceleration
```python
import torch.nn as nn

class ProductionNBeats:
    """N-BEATS wrapper optimized for production serving."""
    def __init__(self, model, use_quantization=True, use_jit=True):
        self.model = model
        self.use_quantization = use_quantization
        self.use_jit = use_jit
        if use_quantization:
            self.model = self._quantize_model()
        if use_jit:
            self.model = self._jit_compile()

    def _quantize_model(self):
        """Dynamically quantize the linear layers to int8."""
        import torch.quantization as quantization
        quantized = quantization.quantize_dynamic(
            self.model, {nn.Linear}, dtype=torch.qint8
        )
        return quantized

    def _jit_compile(self):
        """Trace the model with TorchScript for faster inference."""
        self.model.eval()
        example_input = torch.randn(1, 48, 1)
        traced = torch.jit.trace(self.model, example_input)
        return traced

    def predict_batch(self, x, batch_size=32):
        """Batched prediction to bound memory usage."""
        predictions = []
        for i in range(0, len(x), batch_size):
            batch = x[i:i+batch_size]
            with torch.no_grad():
                pred = self.model(batch)
            predictions.append(pred)
        return torch.cat(predictions)

# Usage
production_model = ProductionNBeats(
    model=trained_model,
    use_quantization=True,
    use_jit=True
)
predictions = production_model.predict_batch(test_data, batch_size=64)
```
2. Caching and warm-up
```python
class CachedNBeatsService:
    """N-BEATS serving wrapper with a simple prediction cache."""
    def __init__(self, model, cache_size=1000):
        self.model = model
        self.cache = {}
        self.cache_size = cache_size
        self.hit_count = 0
        self.miss_count = 0
        self._warmup()

    def _warmup(self):
        """Warm up the model to avoid first-inference latency."""
        dummy_input = torch.randn(1, 48, 1)
        with torch.no_grad():
            _ = self.model(dummy_input)

    def _hash_input(self, x):
        """Build a cache key from input statistics (approximate; collisions are possible)."""
        key = (x.mean().item(), x.std().item())
        return key

    def predict(self, x):
        """Predict with caching."""
        cache_key = self._hash_input(x)
        if cache_key in self.cache:
            self.hit_count += 1
            return self.cache[cache_key]

        self.miss_count += 1
        with torch.no_grad():
            pred = self.model(x)

        # Evict the oldest entry once the cache is full (FIFO)
        if len(self.cache) >= self.cache_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        self.cache[cache_key] = pred
        return pred

    def get_cache_stats(self):
        """Return cache hit statistics."""
        total = self.hit_count + self.miss_count
        hit_rate = self.hit_count / total if total > 0 else 0
        return {
            'hit_rate': hit_rate,
            'cache_size': len(self.cache),
            'hits': self.hit_count,
            'misses': self.miss_count
        }

# Usage
service = CachedNBeatsService(model, cache_size=1000)
predictions = service.predict(test_input)
stats = service.get_cache_stats()
print(f"Cache hit rate: {stats['hit_rate']:.2%}")
```
3. Error handling and graceful degradation
```python
import logging

class RobustNBeatsService:
    """Robust N-BEATS service with error handling and graceful degradation."""
    def __init__(self, primary_model, fallback_model=None):
        self.primary_model = primary_model
        self.fallback_model = fallback_model
        self.error_count = 0
        self.fallback_count = 0

    def predict(self, x):
        """Predict with input validation and fallback."""
        try:
            if x is None or x.numel() == 0:
                raise ValueError("Empty input")
            if x.size(1) != self.primary_model.backcast_length:
                raise ValueError(
                    f"Input length mismatch: expected {self.primary_model.backcast_length}, "
                    f"got {x.size(1)}"
                )
            if torch.isnan(x).any() or torch.isinf(x).any():
                raise ValueError("Input contains NaN or Inf")

            with torch.no_grad():
                pred = self.primary_model(x)

            if torch.isnan(pred).any() or torch.isinf(pred).any():
                raise ValueError("Prediction contains NaN or Inf")
            return pred

        except Exception as e:
            self.error_count += 1
            logging.error(f"Prediction error: {e}")

            if self.fallback_model is not None:
                self.fallback_count += 1
                try:
                    return self.fallback_model(x)
                except Exception:
                    pass
            return self._default_predict(x)

    def _default_predict(self, x):
        """Last-resort prediction: repeat the mean of the most recent values."""
        recent_mean = x[:, -5:].mean(dim=1, keepdim=True)
        forecast_len = self.primary_model.forecast_length
        return recent_mean.expand(-1, forecast_len)

    def get_stats(self):
        """Return service error/fallback statistics."""
        return {
            'error_count': self.error_count,
            'fallback_count': self.fallback_count,
            'fallback_rate': self.fallback_count / max(self.error_count, 1)
        }

# Usage
service = RobustNBeatsService(
    primary_model=main_model,
    fallback_model=simple_model
)
predictions = service.predict(test_input)
stats = service.get_stats()
print(f"Fallback rate: {stats['fallback_rate']:.2%}")
```
Deployment Checklist
Model readiness:

- Best version selected via version management (e.g. lowest sMAPE)
- A/B test against the current production model completed

Performance optimization:

- Dynamic quantization and TorchScript tracing applied where appropriate
- Batch size tuned, prediction cache configured, model warmed up

Monitoring and logging:

- Prediction errors logged; cache hit rate and latency tracked

Fault tolerance and recovery:

- Input validation, fallback model, and default prediction strategy in place
Summary
N-BEATS achieved breakthrough results in time series forecasting through innovative architecture design. This article has dissected its core ideas and technical details:
Core innovations
Basis expansion: neural networks learn the basis functions, representing trend and seasonality flexibly
Double residual stacking: forecasts are accumulated while backcast residuals are passed on, enabling multi-scale feature extraction
Interpretable architecture: trend and seasonality blocks are explicitly separated, making results easy to explain to the business
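The double residual stacking summarized above boils down to two running quantities: per-block forecasts that are summed, and a residual that shrinks as each backcast is subtracted. A toy sketch (the "block" here is a hand-written stand-in, not a trained N-BEATS block):

```python
import numpy as np

def toy_block(residual):
    # Toy "block": backcast explains half the residual,
    # forecast is just the residual's mean repeated over the horizon
    backcast = 0.5 * residual
    forecast = np.full(3, residual.mean())
    return forecast, backcast

x = np.array([4.0, 8.0, 12.0])
residual = x.copy()
forecast_total = np.zeros(3)
for _ in range(3):              # three stacked blocks
    f, b = toy_block(residual)
    forecast_total += f         # forecasts are summed across blocks
    residual = residual - b     # backcast is subtracted (double residual)
```

Each pass leaves later blocks a smaller residual to explain, which is exactly what lets each block specialize on patterns at a particular scale.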
Key technical points
Trend block: polynomial basis functions, suited to modeling long-term trends
Seasonality block: Fourier basis functions, suited to modeling periodic patterns
Generic block: fully data-driven, and the best performer in the M4 competition
Residual connections: keep information flowing and training stable
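The polynomial and Fourier bases named above can be written down directly. This NumPy sketch builds both basis matrices for a forecast horizon of H steps (the variable names, degree, and harmonic count are illustrative):

```python
import numpy as np

H = 12                      # forecast horizon
t = np.arange(H) / H        # normalized future time steps in [0, 1)

# Trend basis: columns 1, t, t^2, ... (low-degree polynomials)
degree = 3
trend_basis = np.stack([t ** k for k in range(degree + 1)], axis=1)  # (H, degree+1)

# Seasonality basis: paired cos/sin harmonics over the horizon
harmonics = 3
cols = []
for k in range(1, harmonics + 1):
    cols.append(np.cos(2 * np.pi * k * t))
    cols.append(np.sin(2 * np.pi * k * t))
seasonality_basis = np.stack(cols, axis=1)  # (H, 2*harmonics)

# A block's forecast is basis @ theta, where theta comes from the FC layers
theta = np.ones(trend_basis.shape[1])
forecast = trend_basis @ theta
```

Because the basis matrices are fixed, the network only has to learn the small coefficient vector theta, which is what makes the trend/seasonality decomposition readable afterwards.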
Practical recommendations
Architecture choice: use the interpretable architecture when the business needs explainability, the generic architecture when accuracy matters most
Hyperparameter tuning: start with a small model (2 stacks × 5 blocks) and increase complexity gradually
Data preprocessing: use last-value normalization, and handle missing values and outliers
Loss function: choose MAPE/sMAPE or MSE depending on the task
Training optimization: use gradient clipping, learning-rate scheduling, and mixed-precision training
Ensembling: combining multiple models can further improve performance
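The last-value normalization recommended above is simple to sketch: divide the lookback window by its final observation, forecast in normalized space, then multiply back. A minimal NumPy illustration (the forecast values are made up for demonstration):

```python
import numpy as np

def last_value_normalize(window):
    # Scale the lookback window by its last observation
    # (assumes the last value is non-zero)
    scale = window[-1]
    return window / scale, scale

def denormalize(forecast_norm, scale):
    # Map a normalized forecast back to the original units
    return forecast_norm * scale

window = np.array([80.0, 90.0, 100.0])
norm_window, scale = last_value_normalize(window)   # ends at exactly 1.0
forecast_norm = np.array([1.1, 1.2])                 # model output (illustrative)
forecast = denormalize(forecast_norm, scale)         # back to original units
```

Anchoring every window to its own last value keeps series of very different magnitudes on a comparable scale, which is why it works well for large heterogeneous collections like M4.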
Hands-on optimization checklist

- Forecasts too smooth → add blocks, widen hidden layers, or use richer basis functions
- Forecasts oscillate → add dropout, shrink the model, add weight decay
- Training too slow → move to GPU, raise the batch size, enable mixed precision
Future directions
The success of N-BEATS has inspired follow-up research:
N-HiTS: improved multi-resolution feature extraction
PatchTST: combines Transformer-style attention mechanisms
TimeGPT: applies large language models to time series
Time series forecasting remains a challenging field, and N-BEATS offers a solution that is both powerful and interpretable. Hopefully this article helps you understand N-BEATS in depth and apply it to real forecasting problems in your own projects.