Using PyG to get Book Recommendations

By Priya Mishra, Sanjari Srivastava, and Ana Selvaraj as a part of the Stanford CS 224W course project

Ever felt like this:

[Image Source]

Sometimes the power is out and you need a good book to keep you company. But how do you choose one? Maybe your friends are not into the genre you like. Maybe your local bookstore has an overwhelming number of choices, and your wallet is thin this month because of inflated food prices.

Well, you can just build your own recommender system with PyG, of course!

Image of PyG (torch_geometric)’s GitHub repo

In this tutorial, we will use PyG to build and optimize a recommender system based on a graph neural network (GNN) made of relational graph attention (RGATConv) layers. The full version of our code is in this Colab notebook, where you can train the complete model and evaluate it.

In this tutorial, we will go over:

  1. How to build a heterogeneous graph with HeteroData
  2. How to use RGATConv and HeteroConv to build a recommender system
  3. How to evaluate our model with the help of RandomLinkSplit and interact with our recommender system

Background

For this tutorial, we’ll be using a public dataset of Goodreads books and interactions between books and users found here. Please note that this dataset is for academic use only.

Example rows of User Book Dataset
goodreads_interactions.csv
goodreads_books.csv

In particular, we use the user-book dataset to create our graph. This dataset records user-book interactions in the form of the rating that a user leaves for a book on Goodreads.

If you don’t have access to a machine with high GPU RAM (> 15 GB), we recommend starting with a subset of the data that only deals with books in the genre of poetry. Building a system to process the complete dataset (which has 228,648,342 user-book interactions!) would require a machine with more memory and disk space than the default Colab runtime provided by Google.

Heterogeneous Graphs

We model the Goodreads data as a bipartite heterogeneous graph with two node types: users and books. Because the graph is bipartite, every edge connects a user to a book and represents ‘this user interacted with (read) this book’. Each edge carries one feature: the user’s rating of the book on a scale from 0 to 5.

This graph can easily be constructed with the user book interactions dataset. The user nodes are derived from the list of unique user IDs and the book nodes are derived from the list of unique book IDs. The list of edges comes from this dataset since each row represents a user interacting with a specific book.
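The code later in this post works with consecutive integer node IDs (a mapped_id column) rather than raw Goodreads IDs. The mapping step itself is not shown here; one minimal way to build it with pandas is sketched below on a made-up five-row interactions table. The column names user_id, book_id, and rating, and the use of unique() followed by a left merge, are our assumptions rather than the notebook’s exact code.

```python
import pandas as pd

# Hypothetical stand-in for goodreads_interactions.csv.
interactions = pd.DataFrame({
    "user_id": [10, 10, 42, 7, 42],
    "book_id": [500, 501, 500, 502, 502],
    "rating":  [5, 3, 4, 0, 2],
})

# One row per unique raw ID plus a consecutive, 0-based mapped_id.
unique_user_id = pd.DataFrame({"user_id": interactions["user_id"].unique()})
unique_user_id["mapped_id"] = range(len(unique_user_id))
unique_book_id = pd.DataFrame({"book_id": interactions["book_id"].unique()})
unique_book_id["mapped_id"] = range(len(unique_book_id))

# Attach the mapped ID to every interaction; a left merge keeps the row order.
interactions_user_id = interactions[["user_id"]].merge(unique_user_id, on="user_id", how="left")
interactions_book_id = interactions[["book_id"]].merge(unique_book_id, on="book_id", how="left")

print(interactions_user_id["mapped_id"].tolist())
print(interactions_book_id["mapped_id"].tolist())
```

Because unique() keeps the order of first appearance, user 10 becomes 0, user 42 becomes 1, and user 7 becomes 2.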

Once we model the graph, the recommender system becomes a link prediction task. Link prediction is a network-science problem in which we predict whether an edge exists between a pair of nodes; here, a predicted edge between a user and a book means that the user would likely read (and rate) that book.
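Framed this way, link prediction is binary classification over user-book pairs: observed edges are positives (label 1), and pairs with no edge serve as negatives (label 0). The sketch below uses uniform random negative sampling, a common and simple choice; the helper name make_link_examples is hypothetical, and later in the tutorial PyG’s RandomLinkSplit takes care of this step.

```python
import random

def make_link_examples(edges, num_users, num_books, num_neg, seed=0):
    """Label observed (user, book) edges 1 and randomly sampled non-edges 0."""
    positives = set(edges)
    if num_neg > num_users * num_books - len(positives):
        raise ValueError("not enough unobserved user-book pairs to sample from")
    rng = random.Random(seed)
    negatives = set()
    while len(negatives) < num_neg:
        pair = (rng.randrange(num_users), rng.randrange(num_books))
        if pair not in positives:  # only pairs without an edge are negatives
            negatives.add(pair)
    pairs = sorted(positives) + sorted(negatives)
    labels = [1] * len(positives) + [0] * len(negatives)
    return pairs, labels

pairs, labels = make_link_examples([(0, 0), (0, 1), (1, 1)], num_users=2, num_books=3, num_neg=2)
print(list(zip(pairs, labels)))
```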

Illustration of Graph Structure

HeteroData

To construct a heterogeneous graph to feed into our model, we use torch_geometric’s HeteroData class. HeteroData behaves like a nested dictionary, similar to a regular Data object, and stores node-level and edge-level attributes in a separate (disjoint) storage object for each node type and each edge type.

To create a HeteroData object, we need the following information:

  • A tensor of unique user IDs to initialize the user nodes
  • A tensor of unique book IDs to initialize the book nodes
  • A tensor of unique user-book interactions (edge_index) and any edge-level features

For our dataset, we create the edge_index from the user-book interactions by stacking the mapped user ID and the mapped book ID of each interaction into a tensor of shape [2, num_interactions]: the first row holds the source (user) indices and the second row holds the target (book) indices. This is the COO format that HeteroData expects for edge_index.

import torch

# Extract the mapped ID values for each interaction to obtain the graph edges.
# Note that each user-book interaction is an edge in our graph.
interactions_user_id = torch.from_numpy(interactions_user_id['mapped_id'].values)
interactions_book_id = torch.from_numpy(interactions_book_id['mapped_id'].values)

# Construct the edge index in the COO format: row 0 holds user indices and
# row 1 holds book indices, so edge_index has shape [2, num_interactions].
edge_index = torch.stack([interactions_user_id, interactions_book_id], dim=0)

Given this information, we just need to plug it into a HeteroData object to create a graph. We also make the graph undirected, which adds a reverse edge type: each user-book interaction is then represented by two edges (user reads/rates book, and book is read/rated by user), so messages can flow in both directions during message passing.

import torch
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData

# Initialize a HeteroData object
data = HeteroData()

num_edges = len(user_book_data)

# Use the consecutive mapped IDs for users and books as their node IDs.
data['user'].node_ids = torch.tensor(unique_user_id['mapped_id'].values, dtype=torch.int32)
data['book'].node_ids = torch.tensor(unique_book_id['mapped_id'].values, dtype=torch.int32)

# Use the edge_index to build the edges of the graph, and assign the edge
# attributes. edge_features holds each interaction's rating (prepared earlier
# in the notebook); RGATConv needs it as a float tensor.
data['user', 'reads', 'book'].edge_index = edge_index
data['user', 'reads', 'book'].edge_attr = edge_features

# Make the graph undirected.
# This adds a reverse edge type ('book', 'rev_reads', 'user') to the graph.
data = T.ToUndirected()(data)

HeteroData comes with many built-in methods and properties for inspecting a graph, which we can use to verify the structure of the graph created so far.

print(data.metadata())
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
(['user', 'book'], [('user', 'reads', 'book'), ('book', 'rev_reads', 'user')])
Number of nodes: 414313
Number of edges: 5468700

HeteroData can also extract induced subgraphs, which makes it easy to restrict our recommender system to a specific subset of users or books.

subgraph = {}

# Only consider 100 nodes of each type
subgraph['user'] = torch.tensor(unique_user_id['mapped_id'][:100], dtype=torch.long)
subgraph['book'] = torch.tensor(unique_book_id['mapped_id'][:100], dtype=torch.long)

subgraph = data.subgraph(subgraph)
print(f"Number of nodes in our subgraph: {subgraph.num_nodes}")
Number of nodes in our subgraph: 200

If you want to feed a heterogeneous graph into a model designed for homogeneous graphs, you can convert it with the to_homogeneous function. Together, these utilities make HeteroData convenient for working with heterogeneous graphs.

GNN

A Graph Neural Network is an optimizable transformation on the attributes of a graph. GNNs learn node representations by passing messages along edges, which suits a recommender system: a user’s representation is built from the books they rated and a book’s from the users who rated it, so the model can score user-book pairs that are not yet connected.

In this tutorial, we use a GNN that relies only on user-item interaction information. Specifically, we use a relational variant of GAT (Graph Attention Network), which uses an attention mechanism to learn how much each neighbor influences a node’s representation.

Figure describing message passing layer of a GNN [Reference]
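To make the attention idea concrete, here is a NumPy sketch of single-head attention in the original (non-relational) GAT for one target node i: project every node with W, score each neighbor j as LeakyReLU(a^T [W h_i || W h_j]), normalize the scores with a softmax over the neighborhood, and take the attention-weighted sum of the projected neighbors. The function name and toy sizes are hypothetical, and GAT usually also includes node i itself in its neighborhood, which this sketch omits.

```python
import numpy as np

def gat_attention(h, W, a, i, neighbors, negative_slope=0.2):
    """Single-head GAT attention of node i over its neighbors (a sketch)."""
    z = h @ W.T  # project every node: [N, F']
    # Unnormalized scores e_ij = LeakyReLU(a^T [z_i || z_j]).
    e = np.array([np.concatenate([z[i], z[j]]) @ a for j in neighbors])
    e = np.where(e > 0, e, negative_slope * e)
    # Softmax over the neighborhood (shifted by the max for numerical stability).
    w = np.exp(e - e.max())
    alpha = w / w.sum()
    # New representation of node i: attention-weighted sum of projected neighbors.
    h_i_new = (alpha[:, None] * z[neighbors]).sum(axis=0)
    return alpha, h_i_new

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))  # 4 nodes with 3 input features
W = rng.normal(size=(2, 3))  # projects 3 -> 2 features
a = rng.normal(size=4)       # attention vector over [z_i || z_j]
alpha, h0_new = gat_attention(h, W, a, i=0, neighbors=[1, 2, 3])
print(alpha, h0_new)
```

With an all-zero attention vector every score is equal, so the weights reduce to a plain average over the neighbors.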

RGATConv

For our tutorial, we use the RGATConv operator to take advantage of the edge features (attributes of user-book interactions) for our recommender system. In our case, this will be the user ratings for each book. PyG implements the relational graph attentional operator from this paper.

RGATConv supports several attention mechanisms (the attention_mechanism argument: within-relation or across-relation), attention modes (attention_mode: additive or multiplicative self-attention), and cardinality preservation options (mod). If you’re curious, please consult its documentation to figure out which settings would work best for you.

For the purposes of this tutorial, we are using additive attention and an across-relation attention mechanism.

Additive mechanism

The across-relation attention mechanism computes a single probability distribution over all neighbors j of node i, across every relation type, so neighbors connected through different relations compete within one softmax.

Across-relation attention mechanism
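As a hedged summary of additive, across-relation attention (single head, leaving out the edge-attribute term and bias that PyG’s RGATConv can add), the computation for a target node i can be written as follows, where W_r is the weight matrix of relation r, Q and K are the query and key projections, N_i^{(r)} is the set of neighbors of i under relation r, and R is the set of relations:

```latex
g_i^{(r)} = h_i W_r, \qquad q_i^{(r)} = g_i^{(r)} Q, \qquad k_j^{(r)} = g_j^{(r)} K

E_{ij}^{(r)} = \mathrm{LeakyReLU}\left(q_i^{(r)} + k_j^{(r)}\right)

\alpha_{ij}^{(r)} = \frac{\exp\left(E_{ij}^{(r)}\right)}
  {\sum_{r' \in \mathcal{R}} \sum_{k \in \mathcal{N}_i^{(r')}} \exp\left(E_{ik}^{(r')}\right)}

h_i' = \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_i^{(r)}} \alpha_{ij}^{(r)} \, g_j^{(r)}
```

The double sum in the denominator is what makes the attention across-relation. Note that because HeteroConv passes each RGATConv only the edges of a single edge type, each softmax in our model effectively runs over one relation.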

Our layer will output heads * out_channels features for each node.

To use RGATConv, we also need to give each edge an integer relation ID (edge_type) that identifies which of the two edge types it belongs to (user reads book, book is read by user).

# Assign an integer edge type (relation ID) to every edge of each edge type;
# the RGATConv layers read it from data[edge_type].edge_type.
for i, edge_type in enumerate(data.edge_types):
    data[edge_type].edge_type = torch.full((data[edge_type].num_edges,), i, dtype=torch.long)

HeteroConv

We use torch_geometric’s HeteroConv wrapper for graph convolutions on heterogeneous graphs. HeteroConv propagates messages from source to target nodes separately for each edge type, which lets us apply RGATConv layers to a graph with multiple node and edge types. If multiple edge types share the same target node type, their results are aggregated using the specified aggregation scheme.

HeteroConv takes as input a dictionary specifying the message passing layer to apply to each edge type, so different edge types can use different layers. In our model, we apply an RGATConv layer to both edge types. We use a helper function, generate_convs, to generate this dictionary.

from torch_geometric.nn import RGATConv

# Helper function to create a dict assigning a GNN layer to each edge type,
# used as input to HeteroConv below.
def generate_convs(hetero_data,
                   in_channels,
                   hidden_channels,
                   out_channels=None,
                   first_layer=False):
    convs = {}
    if not out_channels:
        out_channels = hidden_channels
    num_relations = len(hetero_data.edge_types)

    # The first layer maps the learned input embeddings (in_channels) to the
    # hidden size; later layers start from the hidden size.
    layer_in_channels = in_channels if first_layer else hidden_channels

    # We're using the default attention settings for RGATConv. edge_dim=1
    # because each edge carries a single feature (the rating).
    for (src_type, edge_type, dst_type) in hetero_data.edge_types:
        convs[(src_type, edge_type, dst_type)] = RGATConv(
            layer_in_channels, out_channels, num_relations, edge_dim=1)
    return convs

Now we have everything we need to define our graph neural network:

Diagram for our GNN
import torch
from torch import nn, Tensor
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv


# forward_op is not defined in this post; this is a minimal assumed definition
# that applies the per-node-type module from a ModuleDict to each embedding.
def forward_op(x_dict, module_dict):
    return {node_type: module_dict[node_type](x) for node_type, x in x_dict.items()}


class LinkPredictor(nn.Module):
    def __init__(self, hetero_data, in_channels, hidden_size):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_size = hidden_size

        # The Goodreads dataset does not come with embeddings for users and books,
        # so we learn one embedding per node here.
        self.user_emb = nn.Embedding(hetero_data["user"].num_nodes, in_channels)
        self.book_emb = nn.Embedding(hetero_data["book"].num_nodes, in_channels)

        # We use two RGATConv layers and use HeteroConv since we have two types of edges.
        # The generate_convs helper function above creates a dict mapping
        # each edge type in the graph to a GNN layer.
        self.convs1 = HeteroConv(generate_convs(hetero_data, self.in_channels,
                                                self.hidden_size, first_layer=True))
        self.convs2 = HeteroConv(generate_convs(hetero_data, self.hidden_size,
                                                self.hidden_size, first_layer=False))

        # Separate batch-norm and activation modules for each node type.
        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()
        self.relus2 = nn.ModuleDict()

        for node_type in hetero_data.node_types:
            self.bns1[node_type] = nn.BatchNorm1d(self.hidden_size, eps=1)
            self.bns2[node_type] = nn.BatchNorm1d(self.hidden_size, eps=1)
            self.relus1[node_type] = nn.LeakyReLU()
            self.relus2[node_type] = nn.LeakyReLU()

    def forward(self, data: HeteroData) -> Tensor:
        '''
        The RGATConv layers expect the following input:

        x: Node embeddings
        edge_index: Edge connectivity
        edge_type: Edge type, an integer in {0, ..., R-1} for R relations
        edge_attr: Edge attributes (if any) to use during message passing

        To use the HeteroConv layer, we pass these inputs as dictionaries that
        map node types to their embeddings, and edge types to their edge_index,
        edge_type, and edge_attr tensors.
        '''

        x = {
            "user": self.user_emb(data["user"].node_ids),
            "book": self.book_emb(data["book"].node_ids),
        }

        edge_index_dict = {}
        edge_type_dict = {}
        edge_attr_dict = {}

        for edge_type in data.edge_types:
            edge_index_dict[edge_type] = data[edge_type].edge_index
            edge_type_dict[edge_type] = data[edge_type].edge_type
            edge_attr_dict[edge_type] = data[edge_type].edge_attr

        # HeteroConv forwards keyword arguments ending in "_dict" to each edge
        # type's layer (edge_type_dict -> edge_type, edge_attr_dict -> edge_attr).
        x = self.convs1(x, edge_index_dict, edge_type_dict=edge_type_dict,
                        edge_attr_dict=edge_attr_dict)
        x = forward_op(x, self.bns1)
        x = forward_op(x, self.relus1)
        x = self.convs2(x, edge_index_dict, edge_type_dict=edge_type_dict,
                        edge_attr_dict=edge_attr_dict)
        x = forward_op(x, self.bns2)
        x = forward_op(x, self.relus2)

        # For the final link prediction, we score each candidate edge with the
        # dot product of its user and book embeddings.
        edge_label_index = data["user", "reads", "book"].edge_label_index
        edge_scores = (
            x['user'][edge_label_index[0]] * x['book'][edge_label_index[1]]
        ).sum(dim=-1)

        return edge_scores
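The decoder at the end of forward is a dot product between the final user and book embeddings of each candidate edge. A toy sketch with hand-picked, hypothetical two-dimensional embeddings makes the arithmetic explicit and shows how torch.sigmoid turns a raw score into a probability, as the training and evaluation helpers do before thresholding at 0.5.

```python
import torch

# Hypothetical final-layer embeddings: 2 users and 3 books.
x_user = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0]])
x_book = torch.tensor([[2.0, 0.0],
                       [0.0, 3.0],
                       [1.0, 1.0]])

# Candidate edges (user 0, book 0), (user 1, book 1), (user 0, book 2).
edge_label_index = torch.tensor([[0, 1, 0],
                                 [0, 1, 2]])

scores = (x_user[edge_label_index[0]] * x_book[edge_label_index[1]]).sum(dim=-1)
probs = torch.sigmoid(scores)  # probability that each edge exists
print(scores, probs)
```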

Evaluation

We use RandomLinkSplit from PyG to divide our graph edges into training, validation, and test subsets. This is analogous to sklearn.model_selection.train_test_split but here we are splitting links or edges.

The training, validation, and test sets all contain the same nodes but different edges. The split is performed such that the training set does not contain the edges in the validation or test sets, and the validation set does not contain the edges in the test set. The number of edges used for validation and test is set with num_val and num_test; we pass fractions (0.2 each), leaving the remaining 60% of edges for training.

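Since we are splitting a heterogeneous graph, we need to provide the edge type to split and its reverse edge type. Passing rev_edge_types ensures that the reverse copies of validation and test edges are also removed from the training graph, so no information about the edges we evaluate on leaks into training through message passing. To add negative training samples (user-book pairs that are not connected), we set add_negative_train_samples to True; without them, every supervision label would be 1, and the model could never learn to tell interactions apart from non-interactions. neg_sampling_ratio=0.67 samples roughly two negative edges for every three positive ones. We also set disjoint_train_ratio=0.3 so that 30% of the training edges are used only as ground-truth labels for supervision, while the remaining 70% are used for message passing.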

import torch
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score, accuracy_score

args = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'in_channels': 16,
    'hidden_size': 16,
    'lr': 1e-3,
    'epochs': 200,
}

print(f"Device: '{args['device']}'")

# Initialize the model and the optimizer
model = LinkPredictor(data, args['in_channels'], args['hidden_size']).to(args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])

# Split the dataset: 20% of the edges for validation, 20% for test
transform = T.RandomLinkSplit(
    num_val=0.2,
    num_test=0.2,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=0.67,
    add_negative_train_samples=True,
    edge_types=('user', 'reads', 'book'),
    rev_edge_types=('book', 'rev_reads', 'user'),
)

train_data, val_data, test_data = transform(data)

We train the model with binary cross-entropy loss on the edge scores (treated as logits) and evaluate it with the ROC AUC score and accuracy.
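binary_cross_entropy_with_logits applies the sigmoid internally, which is why forward returns raw scores rather than probabilities. A small check on made-up scores confirms that it matches the textbook binary cross-entropy, -(1/N) Σ [y log σ(s) + (1 - y) log(1 - σ(s))], averaged over edges.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw edge scores (logits) and their 0/1 labels.
scores = torch.tensor([2.0, -1.0, 0.5])
labels = torch.tensor([1.0, 0.0, 0.0])

builtin = F.binary_cross_entropy_with_logits(scores, labels)

# The same loss written out by hand from the predicted probabilities.
p = torch.sigmoid(scores)
manual = -(labels * torch.log(p) + (1 - labels) * torch.log(1 - p)).mean()
print(builtin.item(), manual.item())
```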

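import torch.nn.functional as F

# Method of the LinkPredictor class defined above (shown separately here):
# binary cross-entropy computed directly on the raw edge scores (logits).
def loss(self, edge_scores, edge_labels):
    return F.binary_cross_entropy_with_logits(edge_scores, edge_labels)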

# Helper function to run the training loop for the given model and dataset.
# It uses the optimizer created above and returns the per-epoch metrics.
def train(model, train_dataset, args):
    losses, accuracies, roc_scores = [], [], []
    model.train()
    train_dataset = train_dataset.to(args['device'])
    for epoch in range(args['epochs']):
        optimizer.zero_grad()
        edge_scores = model(train_dataset)
        edge_labels = train_dataset["user", "reads", "book"].edge_label
        loss = model.loss(edge_scores, edge_labels)
        loss.backward()
        optimizer.step()

        # Threshold the predicted probabilities at 0.5 to get 0/1 predictions.
        edge_predictions = (torch.sigmoid(edge_scores) >= 0.5).int()

        roc_score = roc_auc_score(edge_labels.cpu(), edge_predictions.detach().cpu())
        acc_score = accuracy_score(edge_labels.cpu(), edge_predictions.detach().cpu())

        # Store a Python float so the autograd graph is not kept alive.
        losses.append(loss.item())
        accuracies.append(acc_score)
        roc_scores.append(roc_score)

        if epoch % 10 == 0:
            print(f"Epoch: {epoch:03d} Loss: {loss.item():.4f} ROC_AUC Score: {roc_score:.4f} Accuracy: {acc_score:.4f}")

    return losses, accuracies, roc_scores
Loss, Accuracy, ROC_AUC curve over epochs
# Helper function to calculate evaluation metrics for the given model and dataset
@torch.no_grad()  # no gradients are needed for evaluation
def evaluate(model, dataset, args):
    model.eval()
    dataset = dataset.to(args['device'])
    edge_scores = model(dataset)
    edge_labels = dataset["user", "reads", "book"].edge_label
    edge_predictions = (torch.sigmoid(edge_scores) >= 0.5).int()

    roc_score = roc_auc_score(edge_labels.cpu(), edge_predictions.cpu())
    acc_score = accuracy_score(edge_labels.cpu(), edge_predictions.cpu())

    print(f"ROC_AUC Score: {roc_score:.4f} Accuracy: {acc_score:.4f}")

Finally, onto evaluation!

# Evaluate the trained model on validation/test data
print('Model Performance on Validation set')
evaluate(model, val_data, args)
print()
print('Model Performance on Test set')
evaluate(model, test_data, args)
Model Performance on Validation set
ROC_AUC Score: 0.7406 Accuracy: 0.7875

Model Performance on Test set
ROC_AUC Score: 0.7265 Accuracy: 0.7767
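One caveat when reading these numbers: roc_auc_score expects continuous scores, but train and evaluate pass it the thresholded 0/1 edge_predictions, which reduces the ROC curve to a single operating point. Passing torch.sigmoid(edge_scores) instead would measure how well the model ranks positive edges above negative ones. A toy sketch with made-up values shows how the two can differ.

```python
from sklearn.metrics import roc_auc_score

# Hypothetical labels and predicted probabilities for four candidate edges.
labels = [0, 1, 1, 1]
probs = [0.2, 0.3, 0.4, 0.9]
preds = [int(p >= 0.5) for p in probs]  # thresholded at 0.5: [0, 0, 0, 1]

auc_from_probs = roc_auc_score(labels, probs)  # ranking of continuous scores
auc_from_preds = roc_auc_score(labels, preds)  # ties between 0s get half credit
print(auc_from_probs, auc_from_preds)
```

Here the probabilities rank every positive above the negative (AUC 1.0), while the thresholded predictions give about 0.67.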

You can inspect edge_predictions inside evaluate to see what your model predicts. A value of 1 means the model predicts that the edge exists, which is a recommendation for the user at one end of the edge (edge_label_index[0]) to read the book at the other end (edge_label_index[1]).
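To interact with the recommender for a single user, one option is to rank every book by the same dot-product score the model uses and keep the top k books the user has not read yet. LinkPredictor.forward only returns scores for edge_label_index, so this sketch assumes you have the final-layer embeddings available (for example, by also returning x from forward); recommend_top_k, user_vec, and book_embs are hypothetical names.

```python
import torch

def recommend_top_k(user_vec, book_embs, already_read, k):
    """Return indices of the k highest-scoring books the user hasn't read."""
    scores = book_embs @ user_vec  # dot-product score for every book
    if already_read:
        read = torch.tensor(sorted(already_read), dtype=torch.long)
        scores[read] = float('-inf')  # never re-recommend books already read
    return torch.topk(scores, k).indices

# Hypothetical final-layer embeddings: one user and four books.
user_vec = torch.tensor([1.0, 0.0])
book_embs = torch.tensor([[0.9, 0.0], [0.1, 0.0], [0.5, 0.0], [0.7, 0.0]])
print(recommend_top_k(user_vec, book_embs, already_read={0}, k=2))
```

Book 0 has the highest score but is masked out because the user already read it, so books 3 and 2 are recommended.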

[Image Source]

Conclusion

In this tutorial, we demonstrated how to train and evaluate a recommender system on a heterogeneous graph. In particular, we built a heterogeneous graph using PyG’s HeteroData object. Our approach can be extended to convert any dataset that does not come with predefined edge indices, user features, and edge features into a heterogeneous graph for training Graph Neural Networks.

Next, we used the RGATConv message passing layer with the HeteroConv wrapper to perform graph convolutions on a heterogeneous graph with multiple node types and edge types. Finally, for training and evaluating the model, we used RandomLinkSplit to obtain training, validation, and test sets. This class enables negative sampling during training and ensures that the training set does not contain any edges in the validation/test sets.

To further improve model performance, we could experiment with adding more features to the graph. For example, we could use an NLP model to generate sentiment scores from users’ reviews of specific books and add them to the graph as an edge attribute. When working with books from multiple genres, we could add a genre feature to the book nodes to give recommendations in the genres a particular user likes. The tutorial can also be extended to train on other datasets and to use different message passing layers for different edge types.

More resources/tutorials to consult:

GNN Intro by Google

LightGCN tutorial for movie recommender system

PyG’s official tutorial for link prediction on heterogeneous graphs
