# Read it, or Reddit? A GNN Approach to Predicting Relationships on Reddit

By: Finn Dayton and Brandon Minsung Kang as part of the Stanford CS224W course project

*Feel free to follow along with the associated **Colab**!*

# Introduction

Imagine: you’re mindlessly zombie scrolling through your Reddit home feed, yearning for a post that’s worth interacting with. Unfortunately, your feed is riddled with irrelevant posts from irrelevant subreddits. You don’t care about posts from *r/mildlyinteresting*… After all, the entire subreddit is simply filled with content that wasn’t interesting enough to make r/interesting! Tragically, that’s the sad reality of Reddit scrolling for many, with countless hours of human productivity lost every day to the depths of un-engaging content.

And that’s where Graph Machine Learning (GraphML) comes in. With the power of graph representation and Graph Neural Networks (GNNs), we propose that the aforementioned problem can be represented as a GraphML task where we can optimize Reddit home pages and feeds to display content from communities that a user is more likely to interact with based on their previous activity.

So follow along on our exploration of how we can implement GNNs to leverage the rich network structure of Reddit and capture relationships between users, posts, and communities in the hopes of delivering a better browsing experience.

**Table of Contents**

We structure our post into 5 broad sections:

- Exploration of Our Dataset
- Explanation of Models
- Walkthrough of Approach
- Discussion of Results
- Conclusion

# Exploration of Our Dataset

**Dataset Structure**

The Reddit dataset [1] is a graph dataset created from Reddit posts made in September 2014. The authors sampled 50 large communities to create a dataset containing 232,965 posts with an average degree of 492. Node labels are constructed from the “subreddit” that a post belongs to. From this, the authors built a post-to-post graph that connects posts where the same user commented on both. Additionally, each post has a length-602 feature vector that describes the average embedding of the post title, the average embedding of all the post’s comments, the post’s score, and the number of comments made on the post. The Reddit graph can be summarized as an undirected graph with no isolated nodes and no self-loops. The graph is characterized by many node labels, one per subreddit, and a single edge type that indicates whether two posts share a commenter. Our code below provides a summary of key statistics that help illustrate the type of information the dataset contains.

```python
from torch_geometric.datasets import Reddit

# Load the Reddit dataset
dataset = Reddit(root='data/Reddit')

# Extract the single graph from the dataset
data = dataset[0]

. . .
```

```
>>> Average node degree: 492
>>> Number of nodes: 232965
>>> Number of edges: 114615892
>>> Number of features per node: 602
>>> Number of features per edge: 0
```
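The elided `. . .` above computes and prints these statistics. As a rough, self-contained sketch of how such numbers fall out of an `edge_index` tensor (using a toy graph so you don’t need the full Reddit download; `graph_summary` is our own illustrative helper, not part of PyG):

```python
import torch

def graph_summary(edge_index: torch.Tensor, num_nodes: int, num_features: int):
    """Summary statistics of a graph stored as a 2 x E edge_index tensor."""
    num_edges = edge_index.size(1)
    return {
        "num_nodes": num_nodes,
        "num_edges": num_edges,
        "avg_degree": num_edges / num_nodes,
        "num_features": num_features,
    }

# Toy triangle graph; undirected edges appear once per direction
edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                           [1, 0, 2, 1, 0, 2]])
print(graph_summary(edge_index, num_nodes=3, num_features=4))
```

On the full Reddit `data` object, the same arithmetic yields the 492 average degree shown above.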

**Using the Dataset**

The task we explore is link prediction. Link prediction describes a task where the goal is to predict missing or future connections between nodes in a network. In the context of our dataset, our link prediction task will involve predicting which posts a user would most likely comment on based on the nodes (posts) that they’ve previously commented on as well as the links that exist between posts that they’ve commented on. Broadly speaking, we can build a model optimized for this task by masking existing connections between nodes and training our model to correctly predict the hidden edges. More details to follow!
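To make the masking idea concrete, here is a hypothetical sketch in plain PyTorch (the `mask_edges` helper is ours, not the exact split we perform later) of hiding a fraction of edges so a model can be trained to recover them:

```python
import torch

def mask_edges(edge_index: torch.Tensor, hide_frac: float = 0.1, seed: int = 0):
    """Hide a fraction of edges; a link predictor is then trained to
    recover the hidden ones from the visible graph."""
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(edge_index.size(1), generator=g)
    num_hidden = int(edge_index.size(1) * hide_frac)
    hidden = edge_index[:, perm[:num_hidden]]    # edges the model must predict
    visible = edge_index[:, perm[num_hidden:]]   # edges used for message passing
    return visible, hidden
```

Training then scores the hidden pairs as positives against sampled non-edges as negatives, which is exactly the setup we build toward below.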

**Data Preprocessing**

To make our dataset more wieldy, we first perform several preprocessing steps. We import the Reddit dataset directly from **torch_geometric**, a library built on PyTorch specifically to help train GNNs, and then use the associated Data object. This avoids the complexity of downloading, re-uploading, and formatting the JSON files provided in the original dataset’s zip. See the code block above if you need a refresher on how to import and extract the Reddit dataset.

As we mentioned before, the Reddit dataset is rich with both a large number of nodes (232,965 nodes) and high connectivity between nodes (114,615,892 edges). This makes it both time consuming and computationally expensive to perform exploration and model building with the RAM and GPU sizes provided by Colab Pro. To solve this issue, we selected a subgraph (a subset of nodes/edges from our original graph) that improves computability while remaining *representative of the original graph*. We do this in the following steps:

- Find representative nodes. Here, we define “representative” nodes to be nodes whose degree equals the average node degree in the graph (492 edges).
- Randomly select a subset of representative nodes.
- Perform K-hop subgraph generation from those representative nodes.
- Create a subgraph object from the subset of edges and nodes generated in the above steps.

The following code demonstrates our implementation of the above steps.

```python
import random

from torch_geometric.utils import degree, k_hop_subgraph

# Calculate the degree of each node in the graph
degrees = degree(data.edge_index[0], num_nodes=data.num_nodes)

# Find node(s) with degree 492
nodes_with_degree_492 = (degrees == 492).nonzero(as_tuple=False).squeeze()
print(nodes_with_degree_492)

# Calculate the number of nodes with degree 492
n = nodes_with_degree_492.shape[0]

# Randomly select a sample of nodes with degree 492;
# number_of_samples can be changed as desired
number_of_samples = 10

# Sampled nodes of the correct degree
nodes = []
for i in range(number_of_samples):
    random_index = random.randint(0, n - 1)
    node_id = nodes_with_degree_492[random_index].item()
    nodes.append(node_id)
```

```python
from torch_geometric.data import Data

# Create a representative subgraph by generating the K-hop neighborhood
# (here K = 1) around the subset of "average" nodes
subset, edge_index, mapping, edge_mask = k_hop_subgraph(
    nodes, 1, edge_index=data.edge_index, relabel_nodes=True)

# Our subgraph - voila!
sub_graph = Data(x=data.x[subset], edge_index=edge_index, y=data.y[subset])
```

Once these steps are done, we can create a NetworkX Graph object of this final subgraph. This helps us create interesting visualizations such as the one below.

Note: Each time we re-run the graph generation code we get a graph with the same number of nodes, but the initial nodes are sampled randomly. This means you will get a slightly different looking graph, though most of its characteristics (i.e. graph size, connectedness, edge direction, subreddit variety) will be the same.
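For reference, the PyG-to-NetworkX hop can be sketched in a few lines. In practice `torch_geometric.utils.to_networkx` does this (and more) for you; `pyg_to_nx` below is our own minimal illustrative version:

```python
import networkx as nx
import torch

def pyg_to_nx(edge_index: torch.Tensor, num_nodes: int) -> nx.Graph:
    """Minimal sketch of converting a PyG edge_index into a NetworkX Graph."""
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    # nx.Graph is undirected, so the reverse-direction duplicates collapse
    G.add_edges_from(edge_index.t().tolist())
    return G

# A triangle stored in both directions, as PyG stores undirected graphs
edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                           [1, 0, 2, 1, 0, 2]])
G = pyg_to_nx(edge_index, num_nodes=3)
# nx.draw(G)  # visualize with matplotlib
```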

**Splitting our Dataset**

One final step! We partition **sub_graph** into training, validation, and test graphs. This is an important step in model development as it allows us to evaluate the performance of our models on data they have not seen during training.

We perform an 80/10/10 Train/Validate/Test split on the edges in our subgraph after shuffling the order of our edge list to add randomization. After getting the respective edge sets, we then add the corresponding nodes, which ensures that the three resulting graphs have no stranded nodes. This technique gives us three NetworkX Graph objects, which we then transform back into PyTorch Geometric objects to facilitate training and testing with our GNN model.
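The edge-splitting step can be sketched as follows (a simplified, framework-free version; `split_edges` is an illustrative helper, not the exact Colab code):

```python
import random

def split_edges(edges, seed=42):
    """Shuffle an edge list and split it 80/10/10 into train/val/test."""
    edges = list(edges)
    random.Random(seed).shuffle(edges)  # randomize the edge order
    n = len(edges)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    train = edges[:n_train]
    val = edges[n_train:n_train + n_val]
    test = edges[n_train + n_val:]
    return train, val, test
```

After splitting, each partition's endpoint nodes are added to its graph so no edge is left dangling.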

# Explanation of Models

Now that we’ve defined our problem space, let’s explore how Graph Neural Networks can help us predict which Reddit posts you might like! First, let’s go over some basic definitions that will help you understand the concepts we introduce:

**What is a Graph Neural Network (GNN)?**

A Graph Neural Network (GNN) is a type of neural network that is designed to work with graph data structures, such as social networks or molecular structures. GNNs can learn to represent nodes and edges in a graph as vectors, allowing them to perform machine learning tasks on graph data. In other words, a GNN learns embeddings for nodes based on both their features and their neighborhood structure in the graph.

GNNs are a general class of model with several implementations, including Graph Attention Networks (GAT) [2] and GraphSAGE [3].

**GraphSAGE**

In our application, GraphSAGE is one of the key components of our GNN architecture. For a given central node $v$ at layer $l$, the GraphSAGE layer applies the following message passing update rule:

$$h_v^{(l)} = W_1 h_v^{(l-1)} + W_2 \cdot \mathrm{AGG}\left(\left\{ h_u^{(l-1)} : u \in N(v) \right\}\right)$$

Here, $W_1$ and $W_2$ are learnable weight matrices that apply linear transformations to the central node embedding and the aggregation output, respectively. The aggregation runs over the nodes $N(v)$ neighboring the central node. In our application, we use mean aggregation for simplicity.

Now wait — you might be wondering, what do message passing and message aggregation even mean? Fair question. We provide basic definitions for both below:

- Message Passing — the passing of data between nodes that are connected by edges. In the context of a GNN, each layer incorporates message passing that allows for a node to update its embedding based on its neighbors.
- Message Aggregation — an aggregation step is applied once a node receives messages from its neighbors. Aggregation determines how the messages are combined. Examples include sum, mean, max aggregation.
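As a quick illustration with hypothetical toy messages, the three common aggregators differ only in how they collapse the stacked neighbor messages:

```python
import torch

# Messages received from three neighbors, each a 2-dimensional vector
msgs = torch.tensor([[1., 2.],
                     [3., 4.],
                     [5., 0.]])

sum_agg = msgs.sum(dim=0)         # tensor([9., 6.])
mean_agg = msgs.mean(dim=0)       # tensor([3., 2.])
max_agg = msgs.max(dim=0).values  # tensor([5., 4.]) (element-wise max)
```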

A brief graphic demonstrating message passing and aggregation is shown below:

Though the math may look convoluted to the uninitiated, fear not! Simply put, the entire point is to learn an embedding (vector) for a given node.
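Putting the pieces together, the update rule can be sketched with plain tensors. This `sage_layer` is our own illustrative toy (mean aggregation, nonlinearity omitted), not the PyG implementation:

```python
import torch

def sage_layer(h, edge_index, W1, W2):
    """One GraphSAGE-style update with mean aggregation:
    h_v' = W1 h_v + W2 * mean({h_u : u in N(v)})."""
    src, dst = edge_index
    # Message passing: sum incoming neighbor embeddings per destination node
    agg = torch.zeros_like(h)
    agg.index_add_(0, dst, h[src])
    # Divide by in-degree to turn the sum into a mean
    deg = torch.zeros(h.size(0)).index_add_(0, dst, torch.ones(src.size(0)))
    agg = agg / deg.clamp(min=1).unsqueeze(-1)
    # Transform self-embedding and aggregated neighbors, then combine
    return h @ W1.T + agg @ W2.T
```

With identity weight matrices on a two-node graph, each node's new embedding is simply its own embedding plus its neighbor's.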

**LinkPredictor Class**

Given two embeddings for nodes A and B, you might ask, how can we predict 0 (unconnected) or 1 (connected)? This question is key to our link prediction task as it involves calculating the probability that an edge between two nodes might exist.

For our LinkPredictor, we decided to use a straightforward pair of linear layers, with ReLU and dropout between them. Essentially, given two node embeddings output by the GNNStack, the LinkPredictor performs either element-wise multiplication or addition (up to you!) to combine them. This approach was inspired by the paper “Link Prediction Based on Graph Neural Networks” (2018) by Zhang and Chen [4].

The entire model, as depicted above, stacks several back-to-back GraphSAGE layers to learn deeper embeddings, with the LinkPredictor class tacked onto the end. Our implementation can be found below:

```python
import torch.nn as nn

class LinkPredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout, strategy="multiply"):
        super(LinkPredictor, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.strategy = strategy
        self.model = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(self.hidden_dim, 1),
            nn.Sigmoid())

    def forward(self, a, b):
        # Combine the two node embeddings, then score the pair
        if self.strategy == "addition":
            return self.model(a + b)
        elif self.strategy == "multiply":
            return self.model(a * b)
```

# Walkthrough of Approach

After splitting our subgraph into three subgraphs (train, validation, and test), there is one more important nuance to call attention to. We need to generate a set of “negative” edges (edges that do not exist in our graphs) equal in number to the positive edges that do exist. This is imperative because we want a model that can differentiate between edges that exist and edges that do not, and providing negative samples is what teaches the model this distinction. We provide a glimpse of this process below:

```python
from torch_geometric.utils import negative_sampling

# args contains dropout and num_layers
class model(nn.Module):
    def __init__(self, args, input_dim=602, hidden_dim=256, output_dim=256,
                 layer=GraphSage, link_predictor_class=LinkPredictor,
                 stacker=GNNStack):
        super(model, self).__init__()
        # The stacked GraphSAGE layers
        self.stacked_layers_model = stacker(
            layer, input_dim, hidden_dim, output_dim, args).to(device)
        self.link_predictor = link_predictor_class(
            output_dim, hidden_dim, dropout=args['dropout']).to(device)

    def forward(self, x, edge_index):
        node_embeddings = self.stacked_layers_model(x, edge_index)
        num_nodes, embedding_dim = node_embeddings.size()
        num_edges = edge_index.size(1)

        # Positive edges are simply the edges present in the graph
        pos_edge_indices = edge_index

        # Sample an equal number of negative edges (node pairs with no edge)
        neg_edge_indices = negative_sampling(
            edge_index, num_nodes, num_neg_samples=int(num_edges),
            method='dense')

        # Feed the endpoint embeddings into the link predictor
        link_predictions_pos = self.link_predictor(
            node_embeddings[pos_edge_indices[0]],
            node_embeddings[pos_edge_indices[1]])
        link_predictions_neg = self.link_predictor(
            node_embeddings[neg_edge_indices[0]],
            node_embeddings[neg_edge_indices[1]])

        return link_predictions_pos, link_predictions_neg
```

Once we have our negative samples, we’re all ready to go! From here, we run the train graph through the entire model, get the predictions, then back-propagate the loss through the model. Every five epochs, we run the validation and test graphs through the model to see how well it generalizes to unseen graphs. The code below performs this. See the linked Colab for definitions of the helper functions.

```python
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.utils.data import TensorDataset
from tqdm import trange

epochs = args['epochs']
epochs_bar = trange(1, epochs + 1, desc='Loss n/a')

# Extract the edge indices and node feature matrices
x_train, edge_index_train = pyg_train_graph.x, pyg_train_graph.edge_index
x_val, edge_index_val = pyg_val_graph.x, pyg_val_graph.edge_index
x_test, edge_index_test = pyg_test_graph.x, pyg_test_graph.edge_index

# Move everything to the GPU
x_train = x_train.to(device)
x_val = x_val.to(device)
x_test = x_test.to(device)
edge_index_train = edge_index_train.to(device)
edge_index_val = edge_index_val.to(device)
edge_index_test = edge_index_test.to(device)

# Instantiate the model and set it to train mode
my_model = model(args)
my_model.train()

# Get the optimizer
optimizer = optim.Adam(my_model.parameters(), lr=.005)

# Wrap the edge_index tensors in a PyTorch Dataset so we can iterate
# over them in train()
edge_index_train = TensorDataset(edge_index_train)
edge_index_val = TensorDataset(edge_index_val)
edge_index_test = TensorDataset(edge_index_test)

# Train the model
losses = []
valid_hits_validation_list = []
valid_hits_test_list = []
for epoch in epochs_bar:
    loss = train(x_train, edge_index_train, args, my_model, optimizer)
    losses.append(loss)
    epochs_bar.set_description(f'Epoch {epoch}, Loss {loss:0.4f}')

    if epoch % 5 == 0:
        valid_hits_validation = test(my_model, x_val, edge_index_val, args, k=20)
        valid_hits_validation_list.append(valid_hits_validation)
        valid_hits_test = test(my_model, x_test, edge_index_test, args, k=20)
        valid_hits_test_list.append(valid_hits_test)
        print(f'Epoch: {epoch}, Train loss: {loss}, '
              f'Validation Hits@20: {valid_hits_validation}, '
              f'Test Hits@20: {valid_hits_test}')
    else:
        # Carry the last metric forward so the lists stay epoch-aligned
        valid_hits_validation_list.append(
            valid_hits_validation_list[-1] if valid_hits_validation_list else 0)
        valid_hits_test_list.append(
            valid_hits_test_list[-1] if valid_hits_test_list else 0)
```

# Discussion of Results

We trained our model for 100 epochs on the train graph, running the validation and test graphs through the model every five epochs of training.

Our evaluation metric for the validation and test models is Hits@K where K is a variable set to 20 in our code. Here, we define Hits@K as the number of the top K recommendations that are actually linked.

In other words, the sigmoid at the end of the LinkPredictor class outputs a number between 0.0 and 1.0 for each given pair. We sort these predictions in descending order. Hits@20 is the number of node pairs in the top 20 (those with the highest predicted probability of a link) that *actually* have a link between them. Therefore, the best score is 20, and the worst score is 0. We can also represent this number as a decimal between 0.0 and 1.0 where the denominator is K and the numerator is the number of “hits”. The code for calculating Hits@K is shown below:

```python
import torch

def calculate_hits_at_k(pos_preds, neg_preds, k=20):
    # Label positives 1 and negatives 0 in a second column
    tensor_ones = torch.ones_like(pos_preds)
    tensor_zeros = torch.zeros_like(neg_preds)
    pos_preds_labeled = torch.cat((pos_preds, tensor_ones), dim=1)
    neg_preds_labeled = torch.cat((neg_preds, tensor_zeros), dim=1)
    combined_preds = torch.cat((pos_preds_labeled, neg_preds_labeled), dim=0)

    # Sort all predictions by score, descending
    _, indices = torch.sort(combined_preds[:, 0], descending=True)
    sorted_combined_preds = combined_preds[indices]

    # Fraction of the top-k predictions that are true links
    hits = torch.sum(sorted_combined_preds[:k, 1])
    return hits / k
```

Overall, our model performed well. Results will vary with the K value for Hits@K and the other hyperparameters. The model’s train loss went down monotonically across the epochs, while the validation Hits@K converged to 100% after only 60 epochs. If we had more time, we might consider improvements such as gradient clipping, learning rate decay, and sampling larger graphs to train over.
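As a sketch of two of those improvements, gradient clipping and step-based learning-rate decay both come built into PyTorch. The tiny `Linear` model below is just a stand-in for our GNN, and the specific hyperparameters are illustrative:

```python
import torch

net = torch.nn.Linear(4, 1)  # stand-in for the GNNStack + LinkPredictor
optimizer = torch.optim.Adam(net.parameters(), lr=0.005)
# Halve the learning rate every 20 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

for epoch in range(3):  # sketch of the loop body
    optimizer.zero_grad()
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    # Cap the global gradient norm before the optimizer step
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
```

Clipping guards against exploding gradients on dense graphs, while the decay schedule lets training settle once the loss plateaus.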

# Conclusion

If you’ve stuck around until this point, thank you! We hope you’ve learned as much as we’ve learned as we built up our project. Our work only touches the surface of what GNNs can do and where they can be applied. But before we wrap things up, let’s summarize everything we’ve learned today.

- We learned about the NetworkX and PyTorch Geometric libraries and explored the tradeoffs between using graph representations in one library over another.
- We explored how to generate representative subgraphs to help improve the computability and explorability of large graphs.
- We discovered how to build a GraphSAGE GNN and utilized the Link Predictor class to help build an effective model for link prediction.

Overall, we’re excited because we believe these insights and explorations can be extended beyond Reddit to any form of media that can be represented by graph structures. We hope you’re proud of what you’ve achieved and that this marks the beginning of further explorations of Graph Neural Networks.

# References

[1] Reddit Dataset. Pytorch Geometric. https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/reddit.html

[2] Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. Graph attention networks, 2018.

[3] William L. Hamilton, Rex Ying, and Jure Leskovec. Inductive representation learning on large graphs, 2018.

[4] Muhan Zhang and Yixin Chen. Link prediction based on graph neural networks, 2018.