# GNN-Based Link Prediction in Drug-Drug Interaction Networks

## An introduction to Graph Machine Learning using PyG

*By Anfal Siddiqui as part of the Stanford CS224W course project.*

# Overview

In recent years, there has been an explosion of interest in **Graph Machine Learning** (GML). Unlike Natural Language Processing or Computer Vision, which deal with regularly structured data such as sequences and grids, GML works with graphs of arbitrary size and topology. The core family of models that has driven this revolution is known as **Graph Neural Networks (GNNs)**.

In this post, we provide a practical guide for beginners on how to apply a GNN to a real-world problem: predicting interactions between different pairs of drugs. We focus primarily on the **GraphSage** model, providing both conceptual guidance and implementation tutorials for the model and a variety of advanced techniques in the popular PyG (PyTorch Geometric) framework. You can follow along in our Google Colab:

# Dataset and Task

We will be working with the ogbl-ddi dataset, one of the datasets made available through the Open Graph Benchmark (OGB). As described in [1], ogbl-ddi is a homogeneous, unweighted, undirected graph that represents the drug-drug interaction network. The feature-less nodes in the graph represent either FDA-approved or experimental drugs. Edges between nodes in the graph represent interactions between the drugs, where the joint effect of taking both drugs is markedly different from the expected effects if either drug were taken independently.

The graph consists of 4,267 nodes and 1,334,889 edges. It has an average node degree of 501 and a clustering coefficient of 0.51, demonstrating that this is a considerably denser graph than those seen in many GML problems.

The task associated with this dataset is link/edge prediction — namely predicting drug-drug interactions given only information from known drug-drug interactions [1]. We make use of the dataset split provided by OGB, which consists of 1,067,911 training edges, 133,489 validation edges, and 133,489 test edges. We will progressively augment the base graph throughout this post in different ways.

In keeping with prior work with ogbl-ddi, we evaluate the performance of our models using Hits@K. Using this metric, the goal is to rank true drug interactions higher than pairs of drugs that do not interact. As originally executed in [1], Hits@K’s calculation can be viewed as a two-step process. First, the model’s predictions for true drug interactions are ranked among a set of approximately 100,000 randomly sampled negative drug interactions. Then, for each positive edge, if the edge ranks among the top K, excluding all other positive edges, `num_pos_in_top_k` is incremented. Hits@K is then calculated as `num_pos_in_top_k/total_num_pos_edges`. Prior research found K=20 to be a good value, so we will primarily report on Hits@20 and provide accuracy as an additional datapoint.
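To make the two-step Hits@K calculation concrete, here is a minimal sketch in plain Python. The function name `hits_at_k` and the toy scores are illustrative assumptions. A positive edge is counted as being in the top K when its score beats the K-th highest negative score, which means at most K-1 negatives rank above it.

```python
def hits_at_k(pos_scores, neg_scores, k):
    """Fraction of positive edges scored above the k-th best negative edge."""
    if len(neg_scores) < k:
        # Fewer than k negatives: every positive is trivially in the top k.
        return 1.0
    kth_neg_score = sorted(neg_scores, reverse=True)[k - 1]
    num_pos_in_top_k = sum(1 for score in pos_scores if score > kth_neg_score)
    return num_pos_in_top_k / len(pos_scores)


# Toy scores, made up for illustration.
pos_scores = [0.9, 0.8, 0.3]
neg_scores = [0.85, 0.5, 0.4, 0.2, 0.1]
print(hits_at_k(pos_scores, neg_scores, k=2))
```

Because each positive edge is compared only against negatives, a single high-scoring negative edge can push many positives out of the top K. This is why a model can have decent accuracy yet a Hits@K near 0.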

# Base Architecture

## Node Embedding Generation Using GraphSage

The base GNN model that we iterate upon throughout this post is GraphSage [2]. GraphSage is one of the classical GNN models, which are characterized by a two-step process of message and aggregation (along with a non-linearity) for creating node embeddings. The intuition behind GNN models like GraphSage is that the resultant embeddings for each node will incorporate aggregated information from the node’s neighbors. Stacking successive GNN layers allows nodes to increase their receptive field, i.e., to incorporate information from neighbors increasingly further away.

Concretely, in GraphSage, the *l*th-layer embedding for node **v** is derived from the following equation, where N(v) denotes the neighbors of **v** and W^{l} is the learnable weight matrix of layer *l*:

$$h_v^{l} = \sigma\left(W^{l} \cdot \mathrm{CONCAT}\left(h_v^{l-1},\ \mathrm{AGG}\left(\{h_u^{l-1} : u \in N(v)\}\right)\right)\right)$$

We can break down this equation into two “stages”. First, we aggregate “messages” from node **v**’s neighbors, which effectively means aggregating their embeddings from the (*l*-1)th layer, with the 0th-layer embeddings just being the initial node features:

$$h_{N(v)}^{l} = \mathrm{AGG}\left(\{h_u^{l-1} : u \in N(v)\}\right)$$

GraphSage, unlike some other GNN models, then has an additional step of aggregating the results from **v**’s neighbors with **v** itself, which means utilizing **v**’s embedding from the (*l*-1)th layer to derive its *l*th-layer embedding:

$$h_v^{l} = \sigma\left(W^{l} \cdot \mathrm{CONCAT}\left(h_v^{l-1},\ h_{N(v)}^{l}\right)\right)$$

Here, σ can be any non-linearity, such as ReLU. AGG can take on various forms, ranging from a simple average or sum over all the messages from the neighbors, to more complex formulations that incorporate a multilayer perceptron or LSTM to act upon neighboring nodes’ embeddings.

Additionally, L2-normalization can be applied to the embeddings at each layer, which may offer a performance boost in some cases. GraphSage can optionally *sample* neighbors to aggregate, which is separately supported by PyG.

## Link Prediction Using Node Embedding

Once we have generated node embeddings using our GNN, we can then use the embeddings for any pair of nodes to predict whether they should be linked. In our base model, the LinkPredictor is simply the dot product between the two node embeddings, followed by a sigmoid non-linearity.
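As a minimal sketch of the dot-product predictor described above, written in plain PyTorch: the argument names `h_u` and `h_v` and the toy embeddings are illustrative assumptions.

```python
import torch
from torch import nn


class DotProductLinkPredictor(nn.Module):
    """Scores a candidate edge (u, v) as sigmoid(h_u . h_v)."""

    def forward(self, h_u, h_v):
        # h_u, h_v: [num_pairs, dim] embeddings of the two endpoints of each pair.
        return torch.sigmoid((h_u * h_v).sum(dim=-1))


predictor = DotProductLinkPredictor()
h_u = torch.tensor([[1.0, 0.0], [1.0, 2.0]])
h_v = torch.tensor([[0.0, 1.0], [1.0, 2.0]])
scores = predictor(h_u, h_v)  # one score in (0, 1) per candidate pair
```

Note that this predictor has no learnable parameters, so all of the modeling burden falls on the node embeddings themselves.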

## PyG Implementation

Implementing our base model (`SAGE` + `DotProductLinkPredictor`) is straightforward using PyG, making use of the out-of-the-box `SAGEConv` layer.

## Baseline Results

In our baseline setup, all nodes have an initial feature of just **1**. As such, we make use of “sum” aggregation. “Sum” aggregation is viewed as a more expressive aggregation scheme than “mean” or “max” as it can better differentiate nodes with different neighborhood structures, especially in our constant node feature setup (e.g., a node with **N** neighbors would get a combined message of **N** with “sum” aggregation, but all nodes would get 1 with “max” or “mean”) [4].
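The following toy example, in plain Python with a made-up four-node star graph, illustrates the point: with a constant feature of 1, only “sum” tells the hub node apart from the leaves.

```python
# Star graph: "b" is the hub, connected to "a", "c", and "d".
neighbors = {"a": ["b"], "b": ["a", "c", "d"], "c": ["b"], "d": ["b"]}
feature = {node: 1.0 for node in neighbors}  # constant initial feature

agg_sum, agg_mean, agg_max = {}, {}, {}
for node, nbrs in neighbors.items():
    messages = [feature[n] for n in nbrs]
    agg_sum[node] = sum(messages)
    agg_mean[node] = sum(messages) / len(messages)
    agg_max[node] = max(messages)

print("sum: ", agg_sum)   # the hub gets 3.0, the leaves get 1.0
print("mean:", agg_mean)  # every node gets 1.0
print("max: ", agg_max)   # every node gets 1.0
```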

After training the model defined above with 7 `SAGEConv` layers, we get an accuracy of roughly 77% and a Hits@20 of 0 on our validation set. Clearly this model is struggling to distinguish positive and negative edges, with many negative edges getting very high scores. However, it provides the perfect starting point to build upon.

# Model Enhancements

We first explore ways to modify the GNN and the LinkPredictor to produce a model that is more performant.

## Skip-Connections

Unlike with some other deep learning models, simply adding more layers to a GNN doesn’t always help. With each additional layer, we expand the receptive field of each node — i.e. the **k**-hop neighborhood that is involved in calculating its embedding — by 1. We can quickly reach a point where the **k**-hop neighborhoods for all nodes are nearly identical, resulting in the final embeddings for each node being extremely similar, or over-smoothed.

However, with a large graph like ogbl-ddi, we suspect there is value in having a larger receptive field that would let our embeddings incorporate information from further away neighbors/drugs. To mitigate the risk of over-smoothing, we can add **skip-connections**, where the input to a layer is added onto the layer’s output. Intuitively, because earlier layers in the GNN are less smoothed, they may do a better job differentiating nodes. Adding a skip-connection easily allows us to increase the impact of earlier layers and gives us a mixture of shallow and deep models, where the model can choose to even skip entire layers if they are not helpful and just rely on the skip-connection [5].

Implementing a basic skip-connection in PyG is straightforward.

After training the model defined above, we actually see a drop in accuracy to ~56%, with the same 0 Hits@20 score. This poor performance using skip-connections suggests that earlier layers in the model have a negative impact, likely because all of our nodes have the same initial feature and earlier embeddings do a poor job differentiating between them. The skip-connection likely makes training harder by requiring the model to work around it, leading to a drop in performance. However, skip-connections can be a helpful technique in other scenarios (like when nodes have different initial features), and readers are encouraged to keep them in mind.

## Post-Processing Layers

Although adding additional convolutional layers can get problematic, a GNN can contain *other* layers, such as pre-processing and post-processing layers. These layers do not pass messages, but instead apply multilayer perceptrons to the node embeddings.

Post-processing layers, in particular, are useful to further prepare node embeddings for additional downstream reasoning tasks — such as use in link prediction [5]!

Adding on post-processing layers to our base GraphSage model is straightforward, requiring only standard PyTorch blocks.
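As a rough sketch, post-processing can be an MLP that is applied to the embeddings coming out of the last convolutional layer. The class name `PostProcessMLP`, the layer widths, and the dropout rate below are assumptions, not the exact Colab configuration.

```python
import torch
from torch import nn


class PostProcessMLP(nn.Module):  # hypothetical name
    """MLP applied to each node embedding independently (no message passing)."""

    def __init__(self, dim, num_layers, dropout=0.5):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers += [nn.Linear(dim, dim), nn.ReLU(), nn.Dropout(dropout)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, h):
        return self.mlp(h)


post = PostProcessMLP(dim=8, num_layers=4).eval()
h = torch.randn(5, 8)  # stand-in for the embeddings produced by the conv layers
out = post(h)
```

Since no `edge_index` is involved, these layers deepen the model without widening each node's receptive field, so they do not contribute to over-smoothing.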

After training the model defined above with 4 layers of post-processing, we see a small gain in accuracy to ~80% and a boost in Hits@20 to 0.119. Just with this simple change, we are already moving in the right direction — emphasizing the value in considering post-processing layers when working on your own GML tasks.

## Neural Link Predictor

While we have so far focused on enhancing the GNN itself to get better embeddings, the `LinkPredictor`, which actually makes the predictions, is also an avenue for experimentation. In our base model, it’s just a dot product between embeddings with a sigmoid. However, as we saw with our post-processing layer experiment, passing our embeddings through MLPs can boost performance. Rather than perform this post-processing as part of the GNN, we can instead shift it to the `LinkPredictor` and turn that into its own small neural network.

Given the embeddings of two nodes, we can pass their element-wise product through a neural network to generate a more expressive prediction of whether they should be linked. Although this seems similar to the post-processing we did previously, the limitation of the prior method was that we had to generate individual node embeddings that were so powerful that they could directly be used for making link predictions for all possible pairs of nodes. By instead having our `LinkPredictor` be an independent neural network that can operate directly on pairs of nodes and take into account their unique interactions, this task becomes easier and the model becomes much more expressive.

Just like with the post-processing layers, implementing our Neural Link Predictor just requires some standard PyTorch.
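A minimal sketch of such a predictor is given here, assuming an MLP over the element-wise product of the two embeddings followed by a sigmoid. The number of layers, hidden size, and dropout rate are illustrative assumptions.

```python
import torch
from torch import nn


class NeuralLinkPredictor(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_layers=3, dropout=0.5):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (num_layers - 1)
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.ReLU(), nn.Dropout(dropout)]
        layers.append(nn.Linear(dims[-1], 1))  # final layer outputs one logit
        self.mlp = nn.Sequential(*layers)

    def forward(self, h_u, h_v):
        # The element-wise product keeps the score symmetric in (u, v).
        return torch.sigmoid(self.mlp(h_u * h_v)).squeeze(-1)


predictor = NeuralLinkPredictor(in_dim=8, hidden_dim=16).eval()
h_u, h_v = torch.randn(4, 8), torch.randn(4, 8)
scores = predictor(h_u, h_v)  # one score in (0, 1) per candidate pair
```

Using the element-wise product rather than a concatenation means swapping the two nodes gives the same score, which suits an undirected graph like ogbl-ddi.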

Training for 100 epochs using the `NeuralLinkPredictor` and our base `SAGE` model gets us to an accuracy of 94% and a Hits@20 of 0.32, a considerable jump that shows the tremendous advantage of making our `LinkPredictor` more expressive!

# Node Feature Augmentation

Up until now, we have focused our efforts on improving our model. We now turn to ways we can augment our initial graph to yield better performance.

## Issue With Constant Node Features

Thus far, we have taken the standard approach of assigning constant initial features to all nodes in our dataset. However, this can introduce difficulties in training. As all the nodes have the same initial feature, the earlier layers in the GNN produce embeddings that are not well-differentiated. Even as we get to higher layers in our GNN, the poor quality of the earlier layers can continue to negatively affect the outputs (due to the dependence on prior layers to compute higher layers), potentially making convergence more difficult.

## Augmenting With Node Features

Rather than use constant features, we can instead create initial features for our nodes that take their individual attributes into account, which can potentially allow for better differentiation of nodes early on and improve our final embeddings [6].

The exact features to choose are highly problem dependent and require experimentation. However, for demonstration purposes, we choose to make use of PageRank, Clustering Coefficient, Betweenness Centrality, and Node Degree, leveraging NetworkX to compute these common node statistics:
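A small sketch of computing these four statistics with NetworkX follows. The helper name `structural_node_features` and the toy graph are assumptions. The sketch only produces the raw statistics; how they are packed into a fixed-size initial feature vector is not shown.

```python
import networkx as nx


def structural_node_features(graph):
    """Return {node: [pagerank, clustering, betweenness, degree]}."""
    pagerank = nx.pagerank(graph)
    clustering = nx.clustering(graph)
    betweenness = nx.betweenness_centrality(graph)  # normalized by default
    degree = dict(graph.degree())
    return {
        node: [pagerank[node], clustering[node], betweenness[node], float(degree[node])]
        for node in graph.nodes
    }


toy_graph = nx.path_graph(3)  # 0 - 1 - 2
features = structural_node_features(toy_graph)
for node, values in features.items():
    print(node, values)
```

On a graph as dense as ogbl-ddi, exact betweenness centrality is the expensive part. `nx.betweenness_centrality` accepts a `k` argument to estimate it from a sample of source nodes, which may be worth considering.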

Notice that we opted to keep a constant dimension in our initial node features, even after augmentation. Intuitively, this would allow the model to still be able to learn as well as in the constant node feature scenario (as it could choose to simply ignore the additional features). Once the embedding is defined, we simply pass it in as the initial node features in our training loop instead of the constant embedding.

## Results

Using our base `SAGE` model (except with only 5 layers instead of 7) and the `NeuralLinkPredictor`, training with our augmented node features yields a model that still achieves an accuracy of ~92% and a comparable Hits@20 of 0.295. Even with minimal feature engineering, node feature augmentation gave us performance similar to before with a shallower network that trains faster, demonstrating its benefits. Depending on your graph and task, you can try more complex node features like cycle counts, graphlet vectors, etc., which may yield even better results!

# Distance Encoding + Incorporating Edge Features

We now delve into incorporating distance encoding through **edge features**, our most advanced technique — both conceptually and programmatically.

## Motivation

The highest performing models on the ogbl-ddi leaderboard have used some form of distance encoding. As described in [3], incorporating distance encoding into link prediction makes intuitive sense: two nodes are more likely to be connected by an edge if the distance between them in the existing graph is small.

However, calculating the shortest path distance between nodes can get expensive, and even incorporating those distances into the `LinkPredictor` can get complex. Ideally, the node embeddings we produce should have notions of distances between the nodes they represent baked into them, such that the `LinkPredictor` can seamlessly act on them.

As demonstrated in prior work [3], one avenue for achieving this is by calculating “edge features” that capture approximated distances between nodes and factoring them into our GNN. Thus far, we have relied on traditional GraphSage formulations that do not incorporate edge features. To change this, we will peel back the layers of the built-in PyG `SAGEConv` layer and learn how we can create our own `MessagePassing` layer! We present a slightly simplified version of the approach taken in the state-of-the-art models to fit within Colab GPU limits and for easier understanding.

## Approximating Shortest-Path Distances

Trying to calculate the shortest-path distances between all pairs of nodes can become computationally intractable. Instead, we can approximate distances by leveraging an anchor set of **K** randomly sampled nodes [3]. We first calculate the distance from each of these **K** nodes to all other nodes in the graph. We then use these distances to get a reasonable approximation of the distance between any two nodes in the graph.

Let **d_k,x** denote the shortest path distance from node **k** to node **x**. We can thus approximate the shortest path distance between nodes **u** and **v** as follows:

$$d_{u,v} \approx \frac{1}{K} \sum_{k=1}^{K} \left(d_{k,u} + d_{k,v}\right)$$

Naturally, larger values of **K** give better estimates, but are computationally more expensive. Prior work has found **K = 200** to be reasonable. We can easily create our matrix of anchor set distances using NetworkX.
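A minimal sketch of building the anchor distance matrix with NetworkX is shown here. The helper names, the fixed random seed, and the fallback value used for unreachable nodes are assumptions made for illustration.

```python
import random

import networkx as nx


def anchor_distance_matrix(graph, num_anchors, seed=0):
    """Return (anchors, spd) with spd[i][k] = hop distance from node i to anchor k."""
    nodes = sorted(graph.nodes)
    anchors = random.Random(seed).sample(nodes, num_anchors)
    unreachable = len(nodes)  # assumed stand-in distance for disconnected pairs
    index = {node: i for i, node in enumerate(nodes)}
    spd = [[unreachable] * num_anchors for _ in nodes]
    for k, anchor in enumerate(anchors):
        # One breadth-first search per anchor.
        for node, dist in nx.single_source_shortest_path_length(graph, anchor).items():
            spd[index[node]][k] = dist
    return anchors, spd


def approx_distance(spd, u, v):
    """Mean over anchors k of d(k, u) + d(k, v)."""
    return sum(a + b for a, b in zip(spd[u], spd[v])) / len(spd[u])


toy_graph = nx.path_graph(5)  # 0 - 1 - 2 - 3 - 4
anchors, spd = anchor_distance_matrix(toy_graph, num_anchors=5)
print(approx_distance(spd, 0, 4), approx_distance(spd, 0, 1))
```

By the triangle inequality, each term d(k, u) + d(k, v) is at least the true distance between u and v, so this estimate never undershoots. It is most useful as a relative signal of closeness rather than an exact distance.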

Here, `spd` is a matrix where the *i*th row holds the distances between node **i** and all the nodes in the anchor set. Thus, we have:

$$d_{u,v} \approx \frac{1}{K} \sum_{k=1}^{K} \left(\mathrm{spd}[u, k] + \mathrm{spd}[v, k]\right)$$

## Using Distances in Message Computation

As a reminder, during GNN computation of the embedding for node **u** in traditional GraphSage, **u** received messages from each of its neighbors, **v**. The *message computation* for the message from **v** so far has just been taking its embedding from the prior layer, h_v^{l-1}. However, we can also incorporate features for the edge **v -> u** in our message computation, which in our case will be the **d_u,v** we calculated using `spd`. Thus, the message sent by **v** will be:

$$m_{v \to u}^{l} = h_v^{l-1} + W_d^{l} \, d_{u,v}$$

where W_d^{l} denotes a learnable linear transformation applied to the approximated distance.

## Building Our Own MessagePassing Layer

With the conceptual underpinning fleshed out, we can begin implementing our own `MessagePassing` layer that makes use of the distance edge features. PyG already has great documentation about creating your own custom MessagePassing/Conv layer. Rather than reproduce that resource, we will instead start by showing our specific implementation.

The most important parts to look at above are the initial lines of `forward` and `message`. `forward` takes in `spd` (the *node* x *K* matrix calculated above) and passes it and the prior layer embeddings `x` along to `propagate`, one of the methods defined on the `MessagePassing` superclass, which initiates the message passing process.

Message computation actually occurs in the `message` function. Notice how we have arguments in `message` defined with `_i` and `_j`. Recall that node **j** passes along a message to node **i** if there is an edge **j -> i**. By specifying `spd_i` and `spd_j` in the argument list, the `MessagePassing` superclass logic takes care of breaking up `spd` using the `edge_index`: the **n**th row of `spd_j` will contain node **j**’s row from `spd`, and the **n**th row of `spd_i` will contain the `spd` row for node **i**, where **j -> i** is the **n**th edge (`x` is similarly broken up into `x_j`). See the visual depiction of this above. With these matrices nicely set up for us by PyG, we can quickly calculate our distances in a vectorized manner with:

`torch.mean(spd_i + spd_j, 1, True)`

We then simply apply a linear transformation to this result and add it onto the prior layer embeddings to compute our messages!

## Results

Using our custom `MessagePassing` layer in conjunction with the `NeuralLinkPredictor`, we are able to train a model that yields an accuracy of 84% and a Hits@20 of 0.11. Though we had to make many compromises, such as reducing our node dimensionality and using a less-precise `MessagePassing` formulation in order to work with Colab’s limited GPU memory (see our Colab for details), distance encoding still produced somewhat reasonable results.

More elaborate implementations of distance encoding take advantage of multiple anchor sets for better approximations, more GPU capacity, higher values of **K**, etc. and have yielded state-of-the-art results on the ogbl-ddi dataset. Though we did not replicate those results, this exercise has now armed you with first-hand knowledge of how to create custom `MessagePassing` layers within PyG to implement sophisticated GML algorithms!

# Conclusion

In this post, you have learned both about GNNs/GraphSage and how to implement numerous techniques that can improve the performance of your models. The `NeuralLinkPredictor` proved to be our most reliable technique, and after 500 epochs of training, it scores 94% accuracy and 0.32 on Hits@20 on the challenging ogbl-ddi test set. Not bad!

You’ve come far on your GML journey, but there is still so much to learn about recommender systems, making graph/node predictions, and beyond! Some additional resources on where you can go next are linked below. Happy learning!

## Additional Resources

[1] Hu, Weihua, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta, and Jure Leskovec. “Open graph benchmark: Datasets for machine learning on graphs.” *arXiv preprint arXiv:2005.00687* (2020).

[2] Hamilton, Will, Zhitao Ying, and Jure Leskovec. “Inductive representation learning on large graphs.” *Advances in neural information processing systems* 30 (2017).

[3] Li, Boning, Yingce Xia, Shufang Xie, Lijun Wu, and Tao Qin. “Distance-Enhanced Graph Neural Network for Link Prediction.” (2021).

[4] Leskovec, J. Lecture 9, Slide 62, Stanford CS224W, Fall 2021.

[5] Leskovec, J. Lecture 7, Slides 65–67, Stanford CS224W, Fall 2021.

[6] Leskovec, J. Lecture 8, Slide 27, Stanford CS224W, Fall 2021.