Buy This!: Session-based Recommendation Using SR-GNN
This tutorial is an attempt to show how graph machine learning (ML) models can be applied to the session-based recommendation task using the PyG framework. You can find the accompanying Colab notebook here. By Eun Jee Sung, as part of the Stanford CS224W course project.
Table of Contents:
1. Overview
2. Methods
3. Sessions as Graphs
4. Dataset
5. Data Pipeline
6. SR-GNN
7. Result
Overview
Many well-known recommendation methods, such as matrix factorization, assume that it is possible to build and use long-term user profiles for recommendations. In practice, this assumption often fails for several reasons.
- Identifying and tracking users by unique ids may not be an option for small and medium-sized companies.
- Browser fingerprints and cookies may be unstable in different environments.
- We may visit each site with different intentions each time! For example, on Amazon, a user might search for ski items today and notebooks tomorrow.
Session-based recommendation is a task to recommend items solely based on the user’s interactions in the session without referring to the user’s long-term profile. This poses an interesting technical challenge: how can we infer the user’s current implied interest from the relatively short but complex interaction pattern in a given session?
In this tutorial, we will study a graph ML model called Session-based Recommendation with Graph Neural Networks (SR-GNN) proposed for this task. We will study SR-GNN while implementing the model with the PyG framework. We will also train the model using the RetailRocket dataset.
Q) What is the PyG framework?
A) PyG (PyTorch Geometric) is one of the most popular graph deep learning frameworks built on top of PyTorch. It offers convenient building blocks for working with graph datasets and developing graph ML models. This tutorial offers only a sneak peek into the PyG library. If you want to know more about the framework, check this documentation out.
Methods
Existing Methods
Existing methods for session-based recommendation can be summarized into several categories.
Conventional recommendation methods. The most well-known and probably the most general method is matrix factorization, which factorizes a user-item rating matrix into two low-rank matrices that represent the latent factors of users and items, respectively. There are also item-based neighborhood methods that count the co-occurrence of items in the same session. However, neither matrix factorization nor item-based neighborhood methods can account for the sequential order of items. Markov chain methods can account for the sequential nature of data, but they make a strong assumption that the sequence components are independent.
Deep-learning-based methods. With the development of deep learning, Recurrent Neural Networks (RNNs) have also been actively proposed since they showed impressive performance for the tasks dealing with sequential data, e.g. natural language processing and video analysis. Some of the most well-known models include GRU4Rec and NARM. GRU4Rec was one of the first models that applied RNN to the session-based recommendation domain [8]. NARM applied an attention mechanism to RNNs to capture users’ implied sequential behavior and purposes [9]. However, these sequential models have some limitations. First, their predictions often rely on the last interaction. Second, it is hard to take repeated consumption into account.
Graph ML. Recent development in graph ML also inspired several Graph Neural Networks (GNNs) for recommendation tasks. However, according to the SR-GNN paper we study today, many representative algorithms like Node2Vec [11], DeepWalk [12], and LINE [13] targeted undirected graphs, which made it hard to apply the algorithms to session-based recommendation tasks.
SR-GNN
Today, we will study one of the classic models for session-based recommendation: SR-GNN. According to the original paper, by explicitly modeling the switching behavior of users between nodes (items), this model can capture repeated and complex interactions within a session better than traditional sequential networks such as vanilla RNNs [1]. Also, compared to other graph ML models, SR-GNN can model directed session graphs because it uses GRU cells.
The paper claims that it showed better performance than other baselines such as GRU4Rec or NARM for two of the most popular datasets for session-based recommendation [7]. Not only that, this paper was accepted to the 2019 AAAI conference and has been cited 408 times as of December 9th 2021 according to Google Scholar!
Sessions as Graphs
The core elements of graphs are nodes and edges. Graph ML models try to capture the representations of nodes and the edges connecting them (Warning: This is a gross oversimplification. If you are interested in graph ML, check our fantastic course :)).
For our prediction task, each session is represented as a graph and each interacted item in the session is represented as a node. After looking into a session graph, we predict which item the user clicked next.
Let’s look at an example. The example session below represents a user’s interaction sequence with item ids 248676, 8775, 246453, 8775, and 193150, in that order. The graph representation shows that the user was quite interested in 8775 and was comparing several products. Maybe items 8775, 246453, and 193150 were similar in some sense. A graph representation allows us to extract such information from short but complex session sequences.
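To make the conversion concrete, here is a minimal sketch in plain Python (no PyTorch) of how the example session could be turned into a node list and a list of directed edges. The helper name `session_to_graph` and the choice to number nodes by first appearance are assumptions made for illustration, not necessarily what the notebook does.

```python
def session_to_graph(item_sequence):
    """Convert a click sequence into (unique node ids, directed edges).

    Nodes are numbered by first appearance (an assumption of this sketch).
    """
    node_index = {}  # item id -> node index
    for item in item_sequence:
        if item not in node_index:
            node_index[item] = len(node_index)

    indices = [node_index[item] for item in item_sequence]
    # One directed edge per consecutive pair of clicks: sender -> receiver
    senders = indices[:-1]
    receivers = indices[1:]
    nodes = list(node_index)  # dicts preserve insertion order
    return nodes, [senders, receivers]


nodes, edge_index = session_to_graph([248676, 8775, 246453, 8775, 193150])
print(nodes)       # [248676, 8775, 246453, 193150]
print(edge_index)  # [[0, 1, 2, 1], [1, 2, 1, 3]]
```

Item 8775 becomes node 1, which sends two edges and receives two edges. That is exactly the "kept coming back to it" pattern a flat sequence hides.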
The RetailRocket Dataset
The example dataset we are going to use in this tutorial is the RetailRocket dataset [2]. It was released as a Kaggle competition dataset and has been used as a benchmark dataset for several session-based recommendation papers. Retail is one of the domains where session-based recommender systems are most frequently used because often users use the same website with different intentions at each session.
Our raw dataset consists of user event logs during the data collection period. Each row is an interaction between a user and an item. To use this for our graph ML model, we first have to group a series of events that happened in a relatively short amount of time into a session and convert each series of events into a session graph.
For simplicity, we only use the ‘view’ events for our prediction task. We also filter out visitors who visited the website only once or twice. Since we only need information about item ids, we drop the visitorid column as well. Then we group the events from the same user within a 2-hour window into a ‘session.’ The final preprocessed dataset is a list of sessions.
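The grouping step can be sketched in plain Python as follows. This is only an illustration under stated assumptions: the function name `build_sessions` is hypothetical, events are given as `(timestamp_ms, visitorid, event, itemid)` tuples, and "within a 2-hour window" is interpreted as "at most two hours after the first event of the session". The notebook may draw the session boundary differently, and the filter on rarely visiting users is left out for brevity.

```python
from collections import defaultdict

TWO_HOURS_MS = 2 * 60 * 60 * 1000  # RetailRocket timestamps are in milliseconds


def build_sessions(events, window_ms=TWO_HOURS_MS):
    """Group 'view' events per visitor into sessions (lists of item ids).

    Assumption: a session spans at most `window_ms` from its first event.
    """
    views_by_visitor = defaultdict(list)
    for timestamp, visitor, event, item in events:
        if event == "view":  # keep only 'view' events
            views_by_visitor[visitor].append((timestamp, item))

    sessions = []
    for views in views_by_visitor.values():
        views.sort()  # chronological order
        session_start, current = None, []
        for timestamp, item in views:
            # Start a new session once we leave the current window
            if session_start is None or timestamp - session_start > window_ms:
                if current:
                    sessions.append(current)
                session_start, current = timestamp, []
            current.append(item)
        if current:
            sessions.append(current)
    return sessions


events = [
    (0, 1, "view", 10),
    (1000, 1, "view", 11),
    (500, 1, "addtocart", 10),            # ignored: not a 'view'
    (TWO_HOURS_MS + 5000, 1, "view", 12),  # outside the first window
    (2000, 2, "view", 10),
]
print(build_sessions(events))  # [[10, 11], [12], [10]]
```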
You can follow detailed preprocessing steps in this Colab notebook or download the data I preprocessed from here.
Data Pipeline
How do we store each session?
Training a graph ML model requires us to convert each row-format session datapoint into a graph. We will use the torch_geometric.data.Data class to store session graphs.
How do we pass the Data instances to our model?
For this, we will use the PyG InMemoryDataset and DataLoader classes. PyG’s InMemoryDataset offers convenient functionality for handling graph datasets. For one thing, if you try to pass torch_geometric.data.Data instances using the usual torch.utils.data.Dataset and torch.utils.data.DataLoader classes, the loader will immediately fail and you will need a custom collate function to collate torch_geometric.data.Data instances into a batch. The PyG library will handle this for you!
Our custom InMemoryDataset class code is as follows. In the code, edge_index is a 2-D tensor of shape [2, num_edges]: each column is one edge, where the first row holds the sender node and the second row holds the receiver node. y is the target label of each session, which is the last item that was interacted with in the real session. x is the list of original unique item ids.
Note that the Data instances in a mini-batch get collated into one Batch instance, which is basically a single large collated graph (a single Data instance). This detail will come in handy later.
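If the idea of "one large collated graph" feels abstract, the plain-Python sketch below mimics what the collation does conceptually: node lists are concatenated, edge indices are shifted by the number of nodes seen so far, and a `batch` vector remembers which graph each node came from. This is not PyG code; `collate_graphs` and the dict layout are hypothetical stand-ins for illustration.

```python
def collate_graphs(graphs):
    """Merge small graphs into one big disconnected graph.

    Each graph is a dict with 'x' (list of node features) and
    'edge_index' ([senders, receivers]).
    """
    x, senders, receivers, batch = [], [], [], []
    for graph_id, graph in enumerate(graphs):
        offset = len(x)  # number of nodes collated so far
        x.extend(graph["x"])
        senders.extend(s + offset for s in graph["edge_index"][0])
        receivers.extend(r + offset for r in graph["edge_index"][1])
        batch.extend([graph_id] * len(graph["x"]))
    return {"x": x, "edge_index": [senders, receivers], "batch": batch}


g1 = {"x": [248676, 8775, 246453], "edge_index": [[0, 1, 2], [1, 2, 1]]}
g2 = {"x": [193150, 8775], "edge_index": [[0], [1]]}
big = collate_graphs([g1, g2])
print(big["edge_index"])  # [[0, 1, 2, 3], [1, 2, 1, 4]]
print(big["batch"])       # [0, 0, 0, 1, 1]
```

Because no edge crosses from one session to another, message passing on the big graph never mixes sessions, and the `batch` vector lets us recover per-session results afterwards.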
SR-GNN
Now, let’s dive into our main model! Before moving on to the details, take a moment to scan the big picture with the diagram below.
Node Embeddings
We first have to create the initial embeddings of our nodes. We will simply embed the unique item ids into vector-form embeddings. The torch.nn.Embedding layer will take care of this.
Gated Graph Neural Network (GG-NN) Layer
We need to refine these initial embeddings into better representations. To achieve this, SR-GNN uses (1) sessions’ connectivity information and (2) the GG-NN structure.
Building an Adjacency Matrix
We have to use everything we know about the given session to build the best representation. We already used item ids to build node embeddings, but we also know... how they are connected! We will use adjacency matrices to explicitly make our model learn how the nodes interact with each other within a session.
The matrix above is an adjacency matrix for the given directed graph on the left. The left half of the matrix represents ‘out-degrees’ and the right half represents ‘in-degrees’. The adjacency matrix is then transformed to create the activations from edges in both directions.
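As a rough illustration, the sketch below builds such a two-halved matrix in plain Python for the example session from earlier (edges 0→1, 1→2, 2→1, 1→3). The function name `connection_matrix` is hypothetical, and normalizing each half by the node's out-degree or in-degree is an assumption borrowed from common SR-GNN implementations. Check the paper [1] and the notebook for the exact convention they use.

```python
def connection_matrix(edge_index, num_nodes):
    """Return the n x 2n matrix [A_out | A_in] as a list of rows."""
    senders, receivers = edge_index
    counts = [[0.0] * num_nodes for _ in range(num_nodes)]
    for s, r in zip(senders, receivers):
        counts[s][r] += 1.0  # counts[s][r] = number of edges s -> r

    out_degree = [sum(row) for row in counts]
    in_degree = [sum(counts[s][r] for s in range(num_nodes))
                 for r in range(num_nodes)]

    rows = []
    for i in range(num_nodes):
        # Outgoing half: edges i -> j, normalized by the out-degree of i
        a_out = [counts[i][j] / out_degree[i] if out_degree[i] else 0.0
                 for j in range(num_nodes)]
        # Incoming half: edges j -> i, normalized by the in-degree of i
        a_in = [counts[j][i] / in_degree[i] if in_degree[i] else 0.0
                for j in range(num_nodes)]
        rows.append(a_out + a_in)
    return rows


A = connection_matrix([[0, 1, 2, 1], [1, 2, 1, 3]], num_nodes=4)
for row in A:
    print(row)
# [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
# [0.0, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, 0.0]
# [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
# [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
```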
The GRU Cell
The activations are then passed to a cell with these equations:
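For reference, equations (1)-(5) can be written roughly as below. This is a restatement in the notation of the SR-GNN paper [1] as I understand it, so please check the paper for the exact form.

```latex
\begin{aligned}
a_{s,i}^{t} &= A_{s,i:}\,[v_1^{t-1}, \ldots, v_n^{t-1}]^{\top} H + b, \\
z_{s,i}^{t} &= \sigma\left(W_z a_{s,i}^{t} + U_z v_i^{t-1}\right), \\
r_{s,i}^{t} &= \sigma\left(W_r a_{s,i}^{t} + U_r v_i^{t-1}\right), \\
\tilde{v}_i^{t} &= \tanh\left(W_o a_{s,i}^{t} + U_o \left(r_{s,i}^{t} \odot v_i^{t-1}\right)\right), \\
v_i^{t} &= \left(1 - z_{s,i}^{t}\right) \odot v_i^{t-1} + z_{s,i}^{t} \odot \tilde{v}_i^{t}.
\end{aligned}
```

Here A_{s,i:} is the block of the connection matrix that belongs to node i, z and r are the update and reset gates, σ is the sigmoid function, and ⊙ is element-wise multiplication.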
Do these equations look familiar to you? This is just a Gated Recurrent Unit (GRU) layer! SR-GNN uses a variant of GG-NN from the paper “Gated Graph Sequence Neural Networks” [6], which in turn adopts a GRU cell in their architecture.
The GG-NN paper claims that adopting the GRU cell allows the model to learn the changing representations of graph data while processing the data sequentially. It makes sense for our SR-GNN to adopt the GG-NN structure since our model targets sequential datasets.
Using PyG’s MessagePassing Class
Let’s put this into code. The PyG library already offers the torch_geometric.nn.conv.GatedGraphConv class that implements the GG-NN paper, but SR-GNN uses a slightly different gated session graph layer. To be specific, the original GG-NN concatenates a zero matrix to the right of a feature matrix, but we don’t need that for our network. So we will implement our custom gated graph convolution layer using PyG’s MessagePassing class (you can find more details in this documentation).
Again, our gated session graph layer has two main parts. We will put these inside the forward() function.
- Message propagation to create and use an adjacency matrix (self.propagate). For this part, the PyG MessagePassing class offers functionality to implement GNNs conveniently. We only have to correctly define the message() and message_and_aggregate() functions and call propagate().
- The GRU cell (self.gru). For this part, we will reuse the torch.nn.GRUCell class from the PyTorch library.
We often need to define complex message-passing functions and aggregation functions to implement GNNs, but our gated session graph layer is quite simple. Our nodes will pass their node embeddings to their neighboring nodes, so our message() function will simply return the x_j passed to it. Our graphs don’t have any complicated edge weights or aggregation strategies, so the message_and_aggregate() function will simply multiply our adjacency matrix with the node embeddings. Better yet, the torch_geometric.nn.conv.MessagePassing class will internally create the adjacency matrix from the edge_index we pass to self.propagate() without any more work from us!
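To see what this propagation step computes, here is a tiny plain-Python sketch (not PyG code) of identity messages with sum aggregation. The function name `propagate_sum` and the toy embeddings are made up for illustration, and the sketch assumes that messages flow from the sender to the receiver of each edge and are summed.

```python
def propagate_sum(x, edge_index):
    """Sum-aggregate neighbor embeddings along directed edges.

    message(): each edge (j -> i) carries the sender embedding x[j] unchanged.
    aggregate(): every receiver i sums the messages it gets.
    """
    senders, receivers = edge_index
    dim = len(x[0])
    out = [[0.0] * dim for _ in x]
    for j, i in zip(senders, receivers):
        message = x[j]  # identity message, like returning x_j
        out[i] = [acc + m for acc, m in zip(out[i], message)]
    return out


x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [0.5, 0.5]]  # 4 nodes, d = 2
edge_index = [[0, 1, 2, 1], [1, 2, 1, 3]]
print(propagate_sum(x, edge_index))
# [[0.0, 0.0], [3.0, 2.0], [0.0, 1.0], [0.0, 1.0]]
```

The result equals multiplying the transposed (unnormalized) adjacency matrix by the embedding matrix, which is why the whole step can also be expressed as a single matrix product.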
The final GatedSessionGraphConv layer class code looks like this:
We then use this message-passing class inside the SRGNN class like this:
class SRGNN(nn.Module):
    def __init__(self, hidden_size, n_items):
        super(SRGNN, self).__init__()
        ...
        self.gated = GatedSessionGraphConv(self.hidden_size)
        ...

    def forward(self, data):
        ...
        v_i = self.gated(embedding, edge_index)
        ...
Creating Session Embeddings
We now have good representations of each node in our graphs. How can we use this to create the embeddings of each session (graph)? The paper proposes to create the ‘local embedding’ and ‘global embedding’ of a session to create the final ‘hybrid embedding’.
The local embedding is simply the embedding of the last item of a session.
The global embedding is the weighted average of the embeddings of the items in the session. How much each item contributes to the ‘global’ embedding of the session is determined by the attention score.
The final hybrid embedding of a session is created by first concatenating the local and global embeddings and then linearly transforming them.
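Written out, the three embeddings look roughly like this. The notation follows equations (6) and (7) of the SR-GNN paper [1] as I understand them, where v_n is the embedding of the last item, q, W_1, W_2, W_3 and c are learnable parameters, and [ ; ] denotes concatenation.

```latex
\begin{aligned}
s_l &= v_n, \\
\alpha_i &= q^{\top} \sigma\left(W_1 v_n + W_2 v_i + c\right), \\
s_g &= \sum_{i=1}^{n} \alpha_i v_i, \\
s_h &= W_3 \, [s_l ; s_g].
\end{aligned}
```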
The equations are simple, but it can be a little tricky to put this into code and calculate everything in vector form.
Do you remember that PyG’s Data class instances in a batch are collated into a single graph (a Batch instance)? We can map each node back to its individual graph through the .batch attribute of the instance, which maps each node to its respective graph identifier.
So we can implement equations (6) and (7) like this:
# (1)-(5)
v_i = self.gated(embedding, edge_index)

# Divide nodes by session
sections = torch.bincount(batch_map).tolist()
v_i_split = torch.split(v_i, sections)

# v_n: last node embedding of each session
# v_n_repeat: the same embedding repeated once per node, aligned with v_i
v_n, v_n_repeat = [], []
for session in v_i_split:
    v_n.append(session[-1])
    v_n_repeat.append(
        session[-1].view(1, -1).repeat(session.shape[0], 1))
v_n, v_n_repeat = torch.stack(v_n), torch.cat(v_n_repeat, dim=0)

q1 = self.W_1(v_n_repeat)
q2 = self.W_2(v_i)
Computing item scores
The final score of each item is computed as the inner product between the session embedding (1 x d) and that item’s embedding. Multiplying the embedding matrix of all 466867 unique items (466867 x d) by the transposed session embedding gives un-normalized scores of shape (466867 x 1), which the code then transposes to (1 x 466867). This represents the scores of all items for the given session.
# Unnormalized scores
z = torch.mm(self.embedding.weight, s_h.T).T

# We take softmax of the unnormalized scores
# Note that this line is not in the colab, since softmax function is
# already included in the `nn.CrossEntropyLoss()` class.
y_hat = F.softmax(z, dim=1)
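To check the arithmetic on something small, the following plain-Python sketch scores three made-up items against a made-up session embedding and normalizes the scores with softmax. The helper names `score_items` and `softmax` and all numbers are illustrative assumptions; the real model does the same thing with tensors.

```python
import math


def score_items(item_embeddings, session_embedding):
    """Inner product of the session embedding with every item embedding."""
    return [sum(e * s for e, s in zip(item, session_embedding))
            for item in item_embeddings]


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


item_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 items, d = 2
s_h = [2.0, 1.0]                                        # session embedding

z = score_items(item_embeddings, s_h)
y_hat = softmax(z)
best_item = max(range(len(y_hat)), key=y_hat.__getitem__)
print(z)          # [2.0, 1.0, 3.0]
print(best_item)  # 2
```

Softmax does not change the ranking of the items, so it matters for the loss but not for which items end up being recommended.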
Final Model
Our final model looks like this:
Result
Let’s start training! In case you missed it, check the Colab notebook for the training pipeline. This is the training loss graph I got:
This means that, on our validation split, our model accurately predicted the very item that actual users clicked next after a given session, out of roughly 467K items, 22% of the time!
Let’s run this model on the test split. I’ll use Hit@K accuracy as my evaluation metric since it is one of the most popular evaluation metrics for recommendation tasks. The Hit@K metric measures the fraction of cases where the K highest-scored items (the K recommended items) contain the target item. My best model got the below results:
- Hit@1 accuracy: 20.45%
- Hit@10 accuracy: 43.72%
- Hit@20 accuracy: 49.64%
Exciting! Our model could accurately predict the next item the user is going to click with 20% probability. Also, if our model were to recommend 10 items to choose from, the user would like one of the recommended items in 44% of cases!
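In case the Hit@K metric is new to you, here is a minimal plain-Python sketch of how it can be computed. The function name `hit_at_k` and the toy scores are made up for illustration; the notebook may compute the metric differently (for example, directly on tensors).

```python
def hit_at_k(score_rows, targets, k):
    """Fraction of sessions whose target is among the k highest-scored items."""
    hits = 0
    for scores, target in zip(score_rows, targets):
        # Item indices sorted from highest to lowest score
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if target in ranked[:k]:
            hits += 1
    return hits / len(targets)


scores = [
    [0.1, 0.7, 0.2],  # session 0: item 1 is ranked first
    [0.5, 0.3, 0.2],  # session 1: item 0 is ranked first
    [0.2, 0.3, 0.5],  # session 2: item 2 is ranked first
]
targets = [1, 1, 0]
print(hit_at_k(scores, targets, k=1))  # 1 of 3 sessions hit
print(hit_at_k(scores, targets, k=2))  # 2 of 3 sessions hit
```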
I hope this tutorial was helpful for anyone interested in using graph ML techniques for session-based recommendation tasks. If you have any questions, please direct them to ejsung[at]stanford.edu.
While creating this tutorial, I referred to the GitHub repositories by userbehavioranalysis [3] and CRIPAC-DIG [10], but the Colab code, article, and images were entirely written or created by myself for simpler implementation and explanation.
I thank the 2021 Fall CS224W teaching staff for their passionate teaching and constructive feedback on this tutorial.
References
[1] Wu, Shu, et al. “Session-based recommendation with graph neural networks.” Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 33. No. 01. 2019.
[2] RetailRocket dataset. https://www.kaggle.com/retailrocket/ecommerce-dataset
[3] https://github.com/userbehavioranalysis/SR-GNN_PyTorch-Geometric
[4] Stanford 2021 Fall CS224W lecture slide, Week 1, slide 40.
[5] http://colah.github.io/posts/2015-08-Understanding-LSTMs/
[6] Li, Yujia, et al. “Gated graph sequence neural networks.” arXiv preprint arXiv:1511.05493 (2015).
[7] Wu, Shiwen, et al. “Graph neural networks in recommender systems: a survey.” arXiv preprint arXiv:2011.02260 (2020).
[8] Hidasi, Balázs, et al. “Session-based recommendations with recurrent neural networks.” arXiv preprint arXiv:1511.06939 (2015).
[9] Li, Jing, et al. “Neural attentive session-based recommendation.” Proceedings of the 2017 ACM on Conference on Information and Knowledge Management. 2017.
[10] https://github.com/CRIPAC-DIG/SR-GNN
[11] Grover, Aditya, and Jure Leskovec. “node2vec: Scalable feature learning for networks.” Proceedings of the 22nd ACM SIGKDD international conference on Knowledge discovery and data mining. 2016.
[12] Perozzi, Bryan, Rami Al-Rfou, and Steven Skiena. “Deepwalk: Online learning of social representations.” Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining. 2014.
[13] Tang, Jian, et al. “Line: Large-scale information network embedding.” Proceedings of the 24th international conference on world wide web. 2015.