GNN-Based Link Prediction in Drug-Drug Interaction Networks

An introduction to Graph Machine Learning using PyG

Anfal Siddiqui
Stanford CS224W GraphML Tutorials
Jan 16, 2022


Photo by freestocks on Unsplash

By Anfal Siddiqui as part of the Stanford CS224W course project.

Overview

In recent years, there has been an explosion in Graph Machine Learning (GML). Unlike Natural Language Processing or Computer Vision, which deal with regularly structured data such as token sequences and pixel grids, GML works with messy, irregularly structured graphs. The core family of models driving this revolution is known as Graph Neural Networks (GNNs).

In this post, we provide a practical guide for beginners on applying a GNN to a real-world problem: predicting interactions between pairs of drugs. We focus primarily on the GraphSage model, providing both conceptual guidance and implementation tutorials for the model and for a variety of advanced techniques in the popular PyG (PyTorch Geometric) framework. You can follow along in our Google Colab.

Dataset and Task

We will be working with the ogbl-ddi dataset, one of the datasets made available through the Open Graph Benchmark (OGB). As described in [1], ogbl-ddi is a homogeneous, unweighted, undirected graph that represents a drug-drug interaction network. The featureless nodes in the graph represent either FDA-approved or experimental drugs. Edges between nodes represent interactions between drugs, where the joint effect of taking both drugs together is markedly different from the effect expected if each drug were taken independently.

The graph consists of 4,267 nodes and 1,334,889 edges. It has an average node degree of about 501 (counting the 1,067,911 training edges) and a clustering coefficient of 0.51, making it considerably denser than the graphs seen in many GML problems.

The task associated with this dataset is link (edge) prediction, namely predicting drug-drug interactions given only information from known drug-drug interactions [1]. We use the dataset split provided by OGB, which consists of 1,067,911 training edges, 133,489 validation edges, and 133,489 test edges. Throughout this post, we will progressively augment the base graph in different ways.
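As a sanity check on these numbers, here is a small plain-Python calculation that uses only the figures quoted above. It assumes the usual definition of average degree for an undirected graph, 2|E|/|V|, and shows that an average degree of about 501 corresponds to the training edges alone (the full edge set would give about 626). It also confirms the roughly 80/10/10 split.

```python
# Figures quoted above for ogbl-ddi.
num_nodes = 4_267
num_train, num_valid, num_test = 1_067_911, 133_489, 133_489
num_edges = num_train + num_valid + num_test

# Each undirected edge adds 1 to the degree of both of its endpoints.
avg_degree_train = 2 * num_train / num_nodes
avg_degree_all = 2 * num_edges / num_nodes

# Fraction of all edges that lands in each split.
split_fractions = [n / num_edges for n in (num_train, num_valid, num_test)]

print(num_edges)                               # 1334889
print(round(avg_degree_train, 1))              # 500.5
print(round(avg_degree_all, 1))                # 625.7
print([round(f, 2) for f in split_fractions])  # [0.8, 0.1, 0.1]
```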

A visualization of the entire ogbl-ddi training graph

In keeping with prior work on ogbl-ddi, we evaluate our models using Hits@K. Under this metric, the goal is to rank true drug interactions higher than drug pairs that do not interact. As defined in [1], Hits@K can be computed in two steps. First, the model's score for each true drug interaction is ranked against the scores of a set of approximately 100,000 randomly sampled negative (non-interacting) drug pairs. Then, each positive edge that ranks within the top K when compared only against these negatives (other positive edges are excluded from the ranking) increments num_pos_in_top_k. Hits@K is then num_pos_in_top_k / total_num_pos_edges. Prior work settled on K=20 for this dataset, so we primarily report Hits@20 and provide accuracy as an additional data point.
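A minimal plain-Python sketch of this computation follows. It mirrors the logic of the OGB evaluator for ogbl-ddi as we understand it: find the K-th highest negative score, count the positives scored strictly above it, and return 1.0 if there are fewer than K negatives. The function name hits_at_k and the toy scores are illustrative; in practice you would use the Evaluator class from ogb.linkproppred.

```python
def hits_at_k(pos_scores, neg_scores, k=20):
    """Fraction of positive edges scored strictly above the k-th best negative edge."""
    if len(neg_scores) < k:
        # Too few negatives to define a k-th score; every positive counts as a hit.
        return 1.0
    kth_neg = sorted(neg_scores, reverse=True)[k - 1]
    # Positives are compared only against this negative threshold, never against each other.
    hits = sum(1 for s in pos_scores if s > kth_neg)
    return hits / len(pos_scores)

pos = [0.95, 0.70, 0.40, 0.10]
neg = [0.90, 0.80, 0.60, 0.30, 0.20]
print(hits_at_k(pos, neg, k=2))  # threshold 0.80, only 0.95 clears it -> 0.25
print(hits_at_k(pos, neg, k=3))  # threshold 0.60, 0.95 and 0.70 clear it -> 0.5
```

Because each positive is only compared with the negative threshold, adding more positive edges never pushes an existing positive out of the top K, which is what "excluding all other positive edges" means here.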

Base Architecture

Node Embedding Generation Using GraphSage

The base GNN model that we iterate on throughout this post is GraphSage [2]. GraphSage is one of the classical GNN models, which build node embeddings through a two-step process of message computation and aggregation (followed by a non-linearity). The intuition behind GNN models like GraphSage is that each node's resulting embedding incorporates aggregated information from the node's neighbors. Stacking successive GNN layers increases each node's receptive field, i.e., lets it incorporate information from neighbors increasingly farther away.

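Concretely, in GraphSage, the lth-layer embedding for node v is derived from the following equation:

$$h_v^{(l)} = \sigma\left(W^{(l)} \cdot \mathrm{CONCAT}\left(h_v^{(l-1)},\ \mathrm{AGG}\left(\{h_u^{(l-1)} : u \in N(v)\}\right)\right)\right)$$

where N(v) is the set of v's neighbors and W^{(l)} is the learnable weight matrix of layer l.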

We can break down this equation into two “stages”. First, we aggregate “messages” from node v's neighbors, which effectively means aggregating their embeddings from the (l-1)th layer, where the layer-0 embeddings are just the initial node features:

$$h_{N(v)}^{(l)} = \mathrm{AGG}\left(\{h_u^{(l-1)} : u \in N(v)\}\right), \qquad h_u^{(0)} = x_u$$

Unlike some other GNN models, GraphSage then has an additional step that combines the aggregated result from v's neighbors with v itself, using v's own embedding from the (l-1)th layer to derive its lth-layer embedding:

$$h_v^{(l)} = \sigma\left(W^{(l)} \cdot \mathrm{CONCAT}\left(h_v^{(l-1)},\ h_{N(v)}^{(l)}\right)\right)$$

Here, σ can be any non-linearity, such as ReLU. AGG can take various forms, ranging from a simple average or sum over all the messages from the neighbors to more complex formulations that apply a multilayer perceptron or LSTM to the neighboring nodes' embeddings.

Additionally, L2-normalization can be applied to the embeddings at each layer, which may improve performance in some cases. GraphSage can also sample a subset of neighbors to aggregate instead of using the full neighborhood; PyG supports this separately from the SAGEConv layer itself, at the data-loading level.
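To make the two stages concrete, here is a framework-free sketch of a single GraphSage layer in plain Python (no PyTorch or PyG), supporting sum or mean aggregation and optional L2-normalization. The function and argument names are hypothetical. PyG's SAGEConv computes the same kind of update as one linear map on the aggregated neighbors plus another on the node itself, which is equivalent to applying a single W to the concatenation; in PyG the non-linearity is applied in the model rather than inside the layer.

```python
import math

def matvec(W, x):
    """Multiply a matrix W (given as a list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(x):
    return [max(0.0, xi) for xi in x]

def sage_layer(h, neighbors, W, aggr="sum", normalize=False):
    """One GraphSage layer: h_v <- relu(W @ CONCAT(h_v, AGG({h_u : u in N(v)}))).

    h: dict node -> embedding (list of floats) from the previous layer
    neighbors: dict node -> list of neighbor nodes
    W: weight matrix with 2 * dim columns (self half first, then neighbor half)
    """
    if aggr not in ("sum", "mean"):
        raise ValueError(f"unsupported aggregation: {aggr}")
    dim = len(next(iter(h.values())))
    out = {}
    for v, h_v in h.items():
        msgs = [h[u] for u in neighbors[v]]
        if msgs:
            summed = [sum(col) for col in zip(*msgs)]  # stage 1: aggregate messages
            agg = summed if aggr == "sum" else [s / len(msgs) for s in summed]
        else:
            agg = [0.0] * dim  # isolated node: nothing to aggregate
        z = relu(matvec(W, h_v + agg))  # stage 2: list '+' acts as CONCAT(self, neighbors)
        if normalize:  # optional L2-normalization of the output embedding
            norm = math.sqrt(sum(zi * zi for zi in z)) or 1.0
            z = [zi / norm for zi in z]
        out[v] = z
    return out

# Path graph a - b - c, every node starting from the constant feature [1.0].
h0 = {"a": [1.0], "b": [1.0], "c": [1.0]}
nbrs = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
W = [[1.0, 1.0]]  # 1 output dim, 2 input dims (self + aggregated neighbors)
print(sage_layer(h0, nbrs, W))  # {'a': [2.0], 'b': [3.0], 'c': [2.0]}
```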

Visualization of a 2-layer GraphSage model generating an embedding for node A. Red arrows indicate self-aggregation. Layer 0 is the input features

Link Prediction Using Node Embedding

Once we have generated node embeddings using our GNN, we can then use the embeddings for any pair of nodes to predict whether they should be linked. In our base model, the LinkPredictor is simply the dot product between the two node embeddings, followed by a sigmoid non-linearity.
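A plain-Python sketch of this predictor follows; the function names are hypothetical, and in a PyG model the same computation would act on whole batches of embedding pairs at once. The sigmoid is written in a numerically stable form so that large negative dot products do not overflow math.exp.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def dot_product_link_score(h_u, h_v):
    """Predicted probability that nodes u and v are linked: sigmoid(h_u . h_v)."""
    return sigmoid(sum(a * b for a, b in zip(h_u, h_v)))

print(dot_product_link_score([1.0, 0.0], [0.0, 1.0]))  # orthogonal embeddings -> 0.5
print(dot_product_link_score([2.0, 1.0], [2.0, 1.0]))  # aligned -> sigmoid(5), close to 1
```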

PyG Implementation

Implementing our base model (SAGE + DotProductLinkPredictor) is straightforward using PyG, making use of the out-of-the-box SAGEConv layer.

Baseline Results

In our baseline setup, every node has the constant initial feature 1. We therefore use “sum” aggregation, which is viewed as more expressive than “mean” or “max” because it can better differentiate nodes with different neighborhood structures. This matters especially with constant node features: a node with N neighbors receives a combined message of N under “sum” aggregation, but every node receives 1 under “max” or “mean” [4].
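The difference is easy to see numerically. In this plain-Python sketch (the aggregate helper is hypothetical), all of a node's neighbors send the constant message 1, as in our baseline:

```python
def aggregate(messages, how):
    """Combine a node's incoming messages with the named aggregation."""
    if how == "sum":
        return sum(messages)
    if how == "mean":
        return sum(messages) / len(messages)
    if how == "max":
        return max(messages)
    raise ValueError(f"unknown aggregation: {how}")

for n_neighbors in (1, 5, 500):
    msgs = [1.0] * n_neighbors  # constant feature 1 from every neighbor
    print(n_neighbors, {how: aggregate(msgs, how) for how in ("sum", "mean", "max")})
# 1 {'sum': 1.0, 'mean': 1.0, 'max': 1.0}
# 5 {'sum': 5.0, 'mean': 1.0, 'max': 1.0}
# 500 {'sum': 500.0, 'mean': 1.0, 'max': 1.0}
```

Only "sum" separates a drug with 5 interactions from one with 500; "mean" and "max" collapse both to 1.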

After training the model described above with 7 SAGEConv layers, we get a validation accuracy of roughly 77% and a validation Hits@20 of 0. Clearly, this model struggles to distinguish positive from negative edges, with many negative edges receiving very high scores. However, it provides a useful baseline to build on.

Model Enhancements

We first explore ways to modify the GNN and the LinkPredictor to produce a model that is more performant.

Skip-Connections

Unlike with some other deep learning models, simply adding more layers to a GNN doesn't always help. Each additional layer expands every node's receptive field (the k-hop neighborhood involved in computing its embedding) by one hop. We can quickly reach a point where the k-hop neighborhoods of all nodes are nearly identical, making the final embeddings of all nodes extremely similar, or over-smoothed. In a graph as dense as ogbl-ddi, where an average node is directly connected to roughly 500 of the 4,267 drugs, this can happen after only a few layers.
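Over-smoothing shows up even in a toy setting. The sketch below strips a GNN layer down to a parameter-free mean over each node and its neighbors (no weights and no non-linearity, an intentional simplification; the helper names are hypothetical) and tracks how far apart the node values are after each layer on a four-node path graph:

```python
def mean_smooth(h, adj):
    """Parameter-free layer: each node averages itself with its neighbors."""
    return {v: (h[v] + sum(h[u] for u in adj[v])) / (1 + len(adj[v])) for v in adj}

def spread(h):
    """Gap between the largest and smallest node value."""
    return max(h.values()) - min(h.values())

# Path graph 0 - 1 - 2 - 3 in which only node 3 starts out distinctive.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
h = {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0}
spreads = []
for layer in range(30):
    h = mean_smooth(h, adj)
    spreads.append(spread(h))

# The spread shrinks layer after layer toward 0: the nodes become indistinguishable.
print([round(s, 3) for s in spreads[:3]], spreads[-1])
```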

However, with a large graph like ogbl-ddi, we suspect there is value in having a larger receptive field that would let our embeddings incorporate information from further away neighbors/drugs. To mitigate the risk of over-smoothing, we can add skip-connections, where the input to a layer is added onto the layer’s output. Intuitively, because earlier layers in the GNN are less smoothed, they may do a better job differentiating nodes. Adding a skip-connection easily allows us to increase the impact of earlier layers and gives us a mixture of shallow and deep models, where the model can choose to even skip entire layers if they are not helpful and just rely on the skip-connection [5].

Depiction of skip-connections among GNN layers

Implementing a basic skip-connection in PyG is straightforward.
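In a PyG model, the change usually amounts to replacing x = conv(x, edge_index) with x = conv(x, edge_index) + x in the forward pass, which requires each convolution's input and output dimensions to match. The plain-Python sketch below (hypothetical helper names, with toy functions standing in for GNN layers) shows the same pattern and why a skip-connection lets the model effectively bypass a layer: if a layer learns to output zeros, its input passes through unchanged.

```python
def forward_with_skips(h, layers):
    """Apply each layer and add the layer's input back onto its output.

    h: dict node -> embedding; every layer must keep the embedding size unchanged
    so that the element-wise addition is well defined.
    """
    for layer in layers:
        out = layer(h)
        h = {v: [o + x for o, x in zip(out[v], h[v])] for v in h}
    return h

def zero_layer(h):
    """A layer that has learned to output nothing."""
    return {v: [0.0] * len(x) for v, x in h.items()}

def double_layer(h):
    """A stand-in layer that doubles every embedding."""
    return {v: [2.0 * xi for xi in x] for v, x in h.items()}

h0 = {"a": [1.0, 2.0]}
print(forward_with_skips(h0, [zero_layer, zero_layer]))  # {'a': [1.0, 2.0]}: layers bypassed
print(forward_with_skips(h0, [double_layer]))            # {'a': [3.0, 6.0]}: 2x + x
```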

After training the model defined above, we actually see a drop in accuracy to ~56%, with the same 0 Hits@20 score. This poor performance using skip-connections suggests that earlier layers in the model have a negative impact, likely because all of our nodes have the same initial feature and earlier embeddings do a poor job differentiating between them. The skip-connection likely makes training harder by requiring the model to work around it, leading to a drop in performance. However, skip-connections can be a helpful technique in other scenarios (like when nodes have different initial features), and readers are encouraged to keep them in mind.

Post-Processing Layers

Although adding additional convolutional layers can get problematic, a GNN can contain other layers, such as pre-processing and post-processing layers. These layers do not pass messages, but instead apply multilayer perceptrons to the node embeddings.

GNN Model architecture with both pre- and post- processing layers

Post-processing layers, in particular, are useful for further preparing node embeddings for downstream tasks such as link prediction [5]!

Adding on post-processing layers to our base GraphSage model is straightforward, requiring only standard PyTorch blocks.
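Conceptually, post-processing is just an MLP applied to each node's final embedding. In PyG this would be a stack of torch.nn.Linear layers after the last SAGEConv; the plain-Python sketch below (hypothetical names, hand-picked weights instead of learned ones) shows the computation, with a ReLU between layers and none after the last.

```python
def linear(W, b, x):
    """Affine map W @ x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def relu(x):
    return [max(0.0, xi) for xi in x]

def post_process(h_v, mlp):
    """Run one node embedding through an MLP given as a list of (W, b) layers.

    No neighbor information is used: every node is transformed independently.
    """
    for i, (W, b) in enumerate(mlp):
        h_v = linear(W, b, h_v)
        if i < len(mlp) - 1:  # ReLU between layers, none after the last one
            h_v = relu(h_v)
    return h_v

mlp = [
    ([[1.0, -1.0], [0.0, 1.0]], [0.0, 0.0]),  # 2 -> 2
    ([[1.0, 1.0]], [0.5]),                     # 2 -> 1
]
embeddings = {"a": [2.0, 3.0], "b": [4.0, 1.0]}
print({v: post_process(h, mlp) for v, h in embeddings.items()})  # {'a': [3.5], 'b': [4.5]}
```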

After training the model described above with 4 post-processing layers, we see a small gain in accuracy to ~80% and a boost in Hits@20 to 0.119. This simple change already moves us in the right direction, which shows the value of considering post-processing layers in your own GML tasks.

Neural Link Predictor

While we have so far focused on enhancing the GNN itself to get better embeddings, the LinkPredictor, which actually makes the predictions, is also an avenue for experimentation. In our base model, it's just a dot product between embeddings followed by a sigmoid. However, as we saw with our post-processing layer experiment, passing our embeddings through MLPs can boost performance. Rather than perform this post-processing as part of the GNN, we can instead shift it to the LinkPredictor and turn that into its own small neural network.

Given the embeddings of two nodes, we can pass their element-wise product through a neural network to generate a more expressive prediction of whether they should be linked. Although this seems similar to the post-processing we did previously, the limitation of the prior method was that we had to generate individual node embeddings that were so powerful that they could directly be used for making link predictions for all possible pairs of nodes. By instead having our LinkPredictor be an independent neural network that can operate directly on pairs of nodes and take into account their unique interactions, this task becomes easier and the model becomes much more expressive.

Illustration of how link prediction is performed given two node embeddings with the NeuralLinkPredictor

Just like with the post-processing layers, implementing our Neural Link Predictor just requires some standard PyTorch.
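Here is a plain-Python sketch of such a predictor, with hypothetical names and hand-picked weights: take the element-wise product of the two embeddings, run it through a small MLP with ReLUs between layers, and squash the single output with a sigmoid. In a PyTorch/PyG model the MLP would be torch.nn.Linear layers acting on batches of node pairs. One useful observation is that a single linear layer with all-ones weights and zero bias recovers the dot-product predictor exactly, so this design generalizes the baseline.

```python
import math

def linear(W, b, x):
    """Affine map W @ x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def relu(x):
    return [max(0.0, xi) for xi in x]

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def neural_link_score(h_u, h_v, mlp):
    """Score a node pair: MLP over the element-wise product, then a sigmoid.

    mlp: list of (W, b) layers; the last layer must have a single output.
    """
    z = [a * b for a, b in zip(h_u, h_v)]  # Hadamard product, symmetric in u and v
    for i, (W, b) in enumerate(mlp):
        z = linear(W, b, z)
        if i < len(mlp) - 1:
            z = relu(z)
    return sigmoid(z[0])

# One layer of all-ones weights and zero bias reproduces the dot-product predictor.
dot_like = [([[1.0, 1.0]], [0.0])]
print(neural_link_score([1.0, 2.0], [3.0, 4.0], dot_like) == sigmoid(11.0))  # True

# A deeper predictor can weight and gate individual dimensions of the product.
mlp2 = [([[1.0, 0.0], [0.0, -1.0]], [0.0, 0.0]), ([[1.0, 1.0]], [-1.0])]
print(neural_link_score([1.0, 2.0], [3.0, 4.0], mlp2))  # sigmoid(2.0)
```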

Training for 100 epochs using the NeuralLinkPredictor and our base SAGE model gets us to an accuracy of 94% and a Hits@20 of 0.32, a considerable jump that shows the tremendous advantage of making our LinkPredictor more expressive!

Node Feature Augmentation

Up until now, we have focused our efforts on improving our model. We now turn to ways we can augment our initial graph to yield better performance.

Issue With Constant Node Features

Thus far, we have taken the standard approach of assigning constant initial features to all nodes in our dataset. However, this can introduce difficulties in training. As all the nodes have the same initial feature, the earlier layers in the GNN produce embeddings that are not well-differentiated. Even as we get to higher layers in our GNN, the poor quality of the earlier layers can continue to negatively affect the outputs (due to the dependence on prior layers to compute higher layers), potentially making convergence more difficult.

Augmenting With Node Features

Rather than use constant features, we can instead create initial features for our nodes that take their individual attributes into account, which can potentially allow for better differentiation of nodes early on and improve our final embeddings [6].

The exact features to choose are highly problem-dependent and require experimentation. However, for demonstration purposes, we use PageRank, Clustering Coefficient, Betweenness Centrality, and Node Degree, leveraging NetworkX to compute these common node statistics.

Notice that we kept the initial node features at the same dimension as before, even after augmentation. Intuitively, this lets the model learn at least as well as in the constant-feature scenario, since it could simply choose to ignore the additional features. Once these features are computed, we simply pass them in as the initial node features in our training loop instead of the constant features.
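NetworkX provides these statistics directly (for example nx.pagerank, nx.clustering, nx.betweenness_centrality, and G.degree). To keep the sketch dependency-free, the version below computes only node degree and the local clustering coefficient by hand, and assumes one particular way of keeping the dimension fixed: the statistics fill the first slots of the feature vector and the remaining slots keep the old constant value 1. The log-scaling of degree is also our assumption (degrees in ogbl-ddi run into the hundreds), and all names are hypothetical.

```python
import math
from itertools import combinations

def local_clustering(adj, v):
    """Fraction of pairs of v's neighbors that are themselves connected."""
    nbrs = adj[v]
    k = len(nbrs)
    if k < 2:
        return 0.0
    links = sum(1 for a, b in combinations(nbrs, 2) if b in adj[a])
    return 2.0 * links / (k * (k - 1))

def augmented_features(adj, dim):
    """Hypothetical layout: [log(1 + degree), clustering, 1.0, 1.0, ...] of length dim."""
    feats = {}
    for v in adj:
        stats = [math.log1p(len(adj[v])), local_clustering(adj, v)]
        feats[v] = stats + [1.0] * (dim - len(stats))  # pad with the old constant feature
    return feats

# Triangle a-b-c plus a pendant node d attached to a (adjacency as sets).
adj = {"a": {"b", "c", "d"}, "b": {"a", "c"}, "c": {"a", "b"}, "d": {"a"}}
for v, f in augmented_features(adj, dim=4).items():
    print(v, [round(x, 3) for x in f])
```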

Results

Using our base SAGE model (but with 5 layers instead of 7) and the NeuralLinkPredictor, training with our augmented node features yields a model that still achieves an accuracy of ~92% and a comparable Hits@20 of 0.295. Even with minimal feature engineering, node feature augmentation gave us similar performance with a shallower network that trains faster. Depending on your graph and task, you can try more complex node features such as cycle counts or graphlet degree vectors, which may yield even better results!

Distance Encoding + Incorporating Edge Features

We now delve into incorporating distance encoding through edge features, our most advanced technique both conceptually and programmatically.

Motivation

The highest performing models on the ogbl-ddi leaderboard have used some form of distance encoding. As described in [3], incorporating distance encoding into link prediction makes intuitive sense: two nodes are more likely to be connected by an edge if the distance between them in the existing graph is small.

However, calculating the shortest path distance between nodes can get expensive, and even incorporating those distances into the LinkPredictor can get complex. Ideally, the node embeddings we produce should have notions of distances between the nodes they represent baked into them, such that the LinkPredictor can seamlessly act on them.

As demonstrated in prior work [3], one avenue for achieving this is by calculating “edge features” that capture approximated distances between nodes and factoring them into our GNN. Thus far, we have relied on traditional GraphSage formulations that do not incorporate edge features. To change this, we will peel back the layers of the built-in PyG SAGEConv layer and learn how we can create our own MessagePassing layer! We present a slightly simplified version of the approach taken in the state-of-the-art models to fit within Colab GPU limits and for easier understanding.

Approximating Shortest-Path Distances

Calculating exact shortest-path distances between all pairs of nodes quickly becomes expensive. Instead, we can approximate distances by leveraging an anchor set of K randomly sampled nodes [3]. We first calculate the distance from each of these K nodes to all other nodes in the graph. We then use these distances to get a reasonable approximation of the distance between any two nodes in the graph.

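Let d_{k,x} denote the shortest-path distance from anchor node k to node x. We can then approximate the shortest-path distance between nodes u and v as follows:

$$d_{u,v} \approx \min_{k \in S} \left( d_{k,u} + d_{k,v} \right)$$

where S is the anchor set. By the triangle inequality, each sum d_{k,u} + d_{k,v} is an upper bound on the true distance, so the minimum is the tightest bound the anchors provide.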

Naturally, larger values of K give better estimates, but are computationally more expensive. Prior work has found K = 200 to be reasonable. We can easily create our matrix of anchor set distances using NetworkX.

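This matrix, spd, has one row per node: the ith row holds the distances between node i and each of the K nodes in the anchor set. Thus, we have:

$$\mathrm{spd}_{i,k} = d_{k,i}, \qquad d_{u,v} \approx \min_{k} \left( \mathrm{spd}_{u,k} + \mathrm{spd}_{v,k} \right)$$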
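The anchor-set idea can be sketched without NetworkX (which offers nx.single_source_shortest_path_length for the per-anchor step). The plain-Python version below runs one breadth-first search per anchor to build spd, then approximates d_{u,v} by the best route through a single anchor, min_k (d_{k,u} + d_{k,v}). Treating that triangle-inequality upper bound as the approximation is our assumption, as are all the function names; the graph is a dict of adjacency lists.

```python
import random
from collections import deque

def bfs_distances(adj, source):
    """Unweighted shortest-path distances from source to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist

def anchor_spd(adj, anchors):
    """spd[v][k] = distance from the k-th anchor to node v (inf if unreachable)."""
    per_anchor = [bfs_distances(adj, a) for a in anchors]
    return {v: [d.get(v, float("inf")) for d in per_anchor] for v in adj}

def approx_distance(spd, u, v):
    """Upper bound on d(u, v): the shortest detour through any single anchor."""
    return min(du + dv for du, dv in zip(spd[u], spd[v]))

# Path graph 0 - 1 - 2 - 3 - 4 - 5.
adj = {i: [j for j in (i - 1, i + 1) if 0 <= j <= 5] for i in range(6)}
random.seed(0)
anchors = random.sample(sorted(adj), k=2)  # K = 2 randomly chosen anchor nodes
spd = anchor_spd(adj, anchors)
print(anchors, approx_distance(spd, 1, 4), bfs_distances(adj, 1)[4])
```

Building spd costs one BFS per anchor, so K = 200 anchors means 200 traversals instead of the 4,267 needed for exact all-pairs distances on ogbl-ddi.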

Using Distances in Message Computation

As a reminder, when traditional GraphSage computes the embedding of node u, u receives a message from each of its neighbors v. So far, the message from v has simply been its embedding from the prior layer, h_v^{(l-1)}. However, we can also incorporate features of the edge v -> u into the message computation; in our case, that feature is the approximate distance d_{u,v} calculated from spd. Thus, the message sent by v will be:

$$m_{v \to u}^{(l)} = h_v^{(l-1)} + W_d^{(l)}\, d_{u,v}$$

where W_d^{(l)} is a learned linear transformation of the distance.

Building Our Own MessagePassing Layer

With the conceptual underpinning fleshed out, we can begin implementing our own MessagePassing layer that makes use of the distance edge features. PyG already has great documentation about creating your own custom MessagePassing/Conv layer. Rather than reproduce that resource, we will instead start by showing our specific implementation.

The most important parts of our implementation are the first lines of forward and message. forward takes in spd (the number-of-nodes × K matrix computed earlier) and passes it, along with the prior-layer embeddings x, to propagate, a method of the MessagePassing superclass that starts the message-passing process.

Depiction of how spd is “broken” up into spd_i and spd_j. Here, we have an edge going from Node B to Node A. We can see here how the corresponding rows for the two nodes in spd are replicated in spd_i and spd_j to make operations on them easier.

Message computation actually happens in the message function. Notice that message takes arguments suffixed with _i and _j. Recall that node j passes a message to node i if there is an edge j -> i. By listing spd_i and spd_j as arguments, we let the MessagePassing superclass split up spd using edge_index: for the nth edge j -> i, the nth row of spd_j is node j's row of spd, and the nth row of spd_i is node i's row (x is split into x_j in the same way). See the visual depiction of this above. With these matrices set up for us by PyG, we can compute the distances for all edges at once in a vectorized manner.

We then simply apply a linear transformation to this result and add it onto the prior layer embeddings to compute our messages!
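The full layer lives in our Colab. To illustrate what happens inside it, the plain-Python sketch below imitates the MessagePassing mechanics for one layer: it lifts node rows to edge rows by indexing with edge_index (sources for the _j arguments, targets for the _i arguments, matching PyG's default source-to-target flow), computes an approximate distance per edge, forms each message as the neighbor's embedding plus a linear transformation of that distance, and sum-aggregates at the target nodes. The min-over-anchors distance, the scalar distance feature, sum aggregation, and every name here are assumptions rather than the exact Colab code.

```python
def propagate_with_distances(x, spd, edge_index, w_d, b_d):
    """Sum-aggregate messages m_{j->i} = x_j + (w_d * d_ij + b_d) at each target i.

    x: list of node embeddings (lists of floats), indexed by node id
    spd: list of per-node anchor-distance rows, indexed by node id
    edge_index: (sources, targets); edge n goes from sources[n] to targets[n]
    w_d, b_d: weights and bias of a linear map from the scalar distance to embedding size
    """
    src, dst = edge_index
    # "Lift" node-level rows to edge level, like PyG does for *_j / *_i arguments.
    x_j = [x[j] for j in src]
    spd_j = [spd[j] for j in src]
    spd_i = [spd[i] for i in dst]
    # Approximate distance per edge: min over anchors of d_{k,i} + d_{k,j}.
    dist = [min(a + b for a, b in zip(ri, rj)) for ri, rj in zip(spd_i, spd_j)]
    # Message: neighbor embedding plus a linear transformation of the distance.
    msgs = [[h + w * d + b for h, w, b in zip(xj, w_d, b_d)] for xj, d in zip(x_j, dist)]
    # Aggregate: sum the messages arriving at each target node.
    out = [[0.0] * len(x[0]) for _ in x]
    for n, i in enumerate(dst):
        out[i] = [o + m for o, m in zip(out[i], msgs[n])]
    return out

# Path graph 0 - 1 - 2 stored with both edge directions; a single anchor at node 0.
edge_index = ([0, 1, 1, 2], [1, 0, 2, 1])
spd = [[0], [1], [2]]
x = [[1.0], [1.0], [1.0]]
print(propagate_with_distances(x, spd, edge_index, w_d=[1.0], b_d=[0.0]))  # [[2.0], [6.0], [4.0]]
```

Setting w_d to zero turns this back into plain sum aggregation of neighbor embeddings, which makes it easy to see exactly what the distance term contributes.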

Results

Using our custom MessagePassing layer in conjunction with the NeuralLinkPredictor, we are able to train a model that yields an accuracy of 84% and a Hits@20 of 0.11. Though we had to make many compromises such as reducing our node dimensionality and using a less-precise MessagePassing formulation in order to work with Colab’s limited GPU memory (see our Colab for details), distance encoding still produced somewhat reasonable results.

More elaborate implementations of distance encoding take advantage of multiple anchor sets for better approximations, more GPU capacity, higher values of K, etc. and have yielded state-of-the-art results on the ogbl-ddi dataset. Though we did not replicate those results, this exercise has now armed you with first-hand knowledge of how to create custom MessagePassing layers within PyG to implement sophisticated GML algorithms!

Conclusion

In this post, you have learned both about GNNs/GraphSage and how to implement numerous techniques that can improve the performance of your models. The NeuralLinkPredictor proved to be our most reliable technique, and after 500 epochs of training, it scores 94% accuracy and 0.32 on Hits@20 on the challenging ogbl-ddi test set. Not bad!

You’ve come far on your GML journey, but there is still so much to learn about recommender systems, making graph/node predictions, and beyond! Some additional resources on where you can go next are linked below. Happy learning!

Additional Resources

[1] Hu, Weihua, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta, and Jure Leskovec. “Open Graph Benchmark: Datasets for Machine Learning on Graphs.” arXiv preprint arXiv:2005.00687 (2020).

[2] Hamilton, Will, Zhitao Ying, and Jure Leskovec. “Inductive Representation Learning on Large Graphs.” Advances in Neural Information Processing Systems 30 (2017).

[3] Li, Boning, Yingce Xia, Shufang Xie, Lijun Wu, and Tao Qin. “Distance-Enhanced Graph Neural Network for Link Prediction.” (2021).

[4] Leskovec, Jure. Lecture 9, Slide 62. Stanford CS224W, Fall 2021.

[5] Leskovec, Jure. Lecture 7, Slides 65–67. Stanford CS224W, Fall 2021.

[6] Leskovec, Jure. Lecture 8, Slide 27. Stanford CS224W, Fall 2021.
