会话推荐的难点在于:只有一段很短的点击序列,没有稳定的长期画像,但用户的“当前意图”又变化很快。 SR-GNN 的关键洞察是把一次会话看成一张有向图:同一个物品可能在会话里重复出现,不同跳转路径会形成不同的局部结构;用图神经网络在这张会话图上传播信息,比单纯把序列喂给 RNN 更能捕捉“短期内的复杂依赖”。本文会围绕它的实现脉络讲清楚:会话图如何构建与归一化、门控式的 GNN 更新在做什么、如何从节点表示得到会话级表示并输出下一跳打分,以及它为什么能在多个基准数据集上稳定超过传统序列模型,方便你在真实场景里复用这套思路。
背景介绍
在会话推荐系统中,用户的点击序列是短期行为的体现,并且我们只使用当前会话中的点击数据来预测用户下一个可能点击的物品。在这种背景下,我们无法依赖于用户的长期历史偏好,而是基于当前会话内的物品之间的关系来进行推荐。
会话推荐问题可以表述为:给定一个物品集合
具体细节
会话图的构建
为了更好地捕捉会话内物品之间的复杂关系, SR-GNN
将会话数据转换为图结构。对于每一个会话序列
- 节点:表示用户在会话中的每一个点击物品;
- 边:表示物品之间的点击顺序。
例如,用户在会话中依次点击了物品
物品嵌入的学习
在构建好会话图之后, SR-GNN 使用图神经网络(
GNN)来学习物品嵌入。 GNN
的优势在于能够在图结构上进行信息的传播和聚合,从而捕捉到图中节点(物品)之间的复杂关系。具体来说,每个节点的嵌入
该公式表示在图神经网络中的信息传播过程。会话图中的每个节点(物品)不仅仅依赖于自身的信息,它还会从它的邻居节点(即与之有边相连的物品)获取信息,并通过加权求和来更新其自身状态。公式中的关键组成部分包括:
:会话图的邻接矩阵。它确定了节点之间的连接方式。对于一个节点 来说,这个矩阵决定了它可以从哪些其他节点获取信息。 :表示会话图中所有节点的前一时间步的嵌入向量。 :一个权重矩阵,控制着如何结合这些信息。 :偏置项,用于调整输出。
$
$$
$$
经过多轮迭代后,节点的最终嵌入
生成会话表示
在每个会话图中的物品节点嵌入学习完成后, SR-GNN 生成整个会话的表示。这是通过结合局部嵌入和全局嵌入完成的:
局部嵌入:直接使用最后一个点击物品的嵌入
来表示当前用户的短期兴趣。 全局嵌入:通过自注意力机制将会话中的所有物品嵌入聚合起来,捕捉用户的长时兴趣。
是一个全局向量,它的作用是提供一个权重机制,用来衡量不同物品(节点)的重要性。这个全局向量是通过训练学习到的,它帮助模型对当前会话中的每个物品节点进行一个权重分配。$ q _n$** 表示会话中的最后一个物品的嵌入。为什么用最后一个物品呢?因为在很多推荐场景中,用户的最后一个点击行为往往反映了用户最当前的兴趣。最后一个物品的嵌入 是一个很重要的信号,它能代表用户对某类物品的偏好。 则表示会话中的第 个物品的嵌入。这个物品是当前会话中可能存在的某个物品。 和 是两个权重矩阵,用来将最后一个物品和当前物品的嵌入向量映射到一个新的空间中。通过这两个权重矩阵,模型能够比较当前物品 与会话中最后一个物品 之间的相似性。这样可以让模型捕捉到用户在会话中是如何从一个物品逐步过渡到最后一个物品的兴趣变化。 - 如果某个物品
和最后一个物品 的相关性很高,那么它的权重 会更大,表示这个物品对当前会话的整体偏好影响更大。
- 最终嵌入:通过将局部嵌入
和全局嵌入 进行线性组合,生成最终的会话表示。
预测与模型训练
在会话嵌入生成之后, SR-GNN 通过计算每个候选物品的得分
模型使用交叉熵损失函数进行训练:
$$
L = -_{i} y_i (_i) $$
其中,$ y_i
在 SR-GNN 模型中,公式
然后,模型将计算得到的分数通过 Softmax 函数转化为概率分布
其中,
损失函数的定义
为了训练模型,使用交叉熵损失函数来衡量模型的预测结果与实际点击物品之间的差异。损失函数的形式如下:
其中,
模型训练
训练过程中,采用反向传播算法( Back-Propagation Through Time, BPTT)来更新模型参数。在会话推荐任务中,大部分会话的长度相对较短,因此建议选择较小的训练步数,以防止过拟合。
这个过程通过不断地调整模型参数,使模型逐步学会捕捉用户的行为模式,从而在新的会话中为用户推荐最有可能点击的物品。
代码示例
模型的实现源代码在 https://github.com/CRIPAC-DIG/SR-GNN/tree/master,下面我将提供一个简化版本的代码进行讲解。
类定义与初始化
1 | class SimplifiedSRGNN: |
- 物品嵌入矩阵:模型使用
embedding变量表示物品的嵌入向量。每个物品都对应一个向量,大小为hidden_size,这些向量是通过训练来更新的。 - 邻接矩阵:
adj_in和adj_out是占位符,用于存储会话图的入度和出度邻接矩阵。这些矩阵用于信息传播,帮助模型了解物品之间的点击顺序。 - 权重矩阵:
W_in和W_out是两个权重矩阵,分别用于对入度和出度邻接矩阵中的物品嵌入进行变换。每个矩阵的大小与hidden_size相同,用于调整信息传播的权重。
图神经网络中的信息传播
1 | def gnn_propagation(self): |
- 物品嵌入:通过
tf.nn.embedding_lookup,我们从嵌入矩阵中获取当前批次中物品的嵌入向量,形状为(batch_size, T, hidden_size),其中T是会话中的物品序列长度。 - 信息传播:我们通过
adj_in和adj_out进行入度和出度邻接矩阵的乘法操作,来更新每个节点的嵌入。这里的操作模拟了会话中的物品点击顺序对信息传播的影响。 - GRU 单元:
GRU是一种循环神经网络单元,用来捕捉序列中的时间依赖性。我们将聚合后的物品嵌入传入 GRU 中,最终得到final_state,它表示了当前批次中物品序列的状态。
训练过程
1 | def train(self): |
- 得分计算:我们通过 GNN 得到的物品序列的状态
final_state和物品嵌入矩阵做内积运算,得到每个物品的推荐得分logits。 - 交叉熵损失函数:我们使用
tf.nn.sparse_softmax_cross_entropy_with_logits来计算目标物品的损失,这个损失衡量了模型对下一个物品预测的准确性。 - 优化器:使用
Adam优化器对模型进行优化,目的是通过最小化损失函数,逐步更新模型参数,提高预测准确性。
训练循环
1 | def train_model(n_items, epochs=10): |
- 生成模拟数据:我们为每个训练批次生成随机的邻接矩阵(
adj_in_batch和adj_out_batch)和物品序列(item_batch)。 - 执行训练:在每个训练批次中,我们运行模型的优化器来最小化损失函数,并输出当前批次的损失值。
- 会话管理:在 TensorFlow 中,通过
tf.Session()来执行计算图,并使用sess.run()来实际执行模型的计算和优化操作。
- 本文标题:SR-GNN —— Session-based Recommendation with Graph Neural Networks
- 本文作者:Chen Kai
- 创建时间:2021-07-15 10:45:00
- 本文链接:https://www.chenk.top/Session-based-Recommendation-with-Graph-Neural-Networks/
- 版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!