Online Link Prediction with Graph Neural Networks

This blog was co-written by Samar Khanna, Sarthak Consul, and Tanish Jain as part of Stanford CS224W, Fall 2021 (and because they all find graph neural networks amazing).

Find the companion Colab notebook here. The complete codebase can be found in the GitHub repository here.

Ever wondered if someone is truly your friend? Or faced the dilemma of not knowing what show to watch next online? Well, graph neural networks have got your back! Machine learning, as it currently stands, has become exceedingly accurate at pattern recognition and predictions. While a lot of focus has been on structured data such as images (a 2-D grid graph) and text (a linear graph), today we’re going to focus on the emerging field of machine learning on graphs.

Graph Neural Networks (GNNs) are neural networks that operate directly on graphs. GNNs can also be used for node-level tasks (such as assigning a score to every website on the internet) and graph-level tasks (such as predicting the properties of a molecule from its structure). Here, though, we're going to talk about using them for link prediction. More specifically, we're going to leverage the inductive power of GraphSAGE to predict interactions between a new drug and an existing collection of medicines.

We’ll be exploring the inductive power of GNNs on online link prediction by using the ogb-ddi (drug-drug interaction) dataset. Specifically, we’ll demonstrate GraphSAGE’s ability to predict new links (drug interactions) as new nodes (drugs) are sequentially added to an initial subset of the graph.

The Dataset

Image source: http://www.accutrend.com/it-still-comes-down-to-garbage-in-garbage-out/

Any ML model is only as good as the dataset it is trained on. For this blog, we will be playing with the ogb-ddi dataset [1]. It was built from real drug-drug interactions and serves as a benchmark, with a public leaderboard, for link-prediction GNNs.

The dataset contains an unweighted, undirected graph with 4,267 nodes and 1,334,889 edges representing the drug-drug interaction network. Each node represents a drug. For instance, a small subset of 3 nodes, corresponding to the drugs Isotretinoin, Doxycycline, and Calcium Carbonate, is visualized below.

A subset of the ogb-ddi graph containing 3 nodes

An edge between a pair of nodes, for instance between Isotretinoin and Doxycycline, means that taking the two drugs together has an effect very different from the expected effect of each drug acting independently. Such interactions can be mild. For example, Calcium Carbonate decreases the effectiveness of Doxycycline. Others can be severe. Taking Isotretinoin and Doxycycline together can increase the pressure of the spinal fluid around the brain, which can result in permanent vision loss!

Nodes that do not share an edge, such as Isotretinoin and Calcium Carbonate, have no known unexpected interaction, so taking them together should be safe.

What does it mean to predict links on this graph, and why does it matter?

Predicting links on this graph boils down to figuring out whether two drugs taken together would behave differently from what we expect of each drug on its own. This information is critical when prescribing medicines, so a good predictor has wide-ranging implications. Potential uses of such a predictor include, but are not limited to:

  • Lowering the time and cost of drug development. Clinical studies are needed to establish interactions between drugs. This matters most for drugs that treat diseases affecting people with comorbidities, who likely need multiple drugs. However, such studies are usually long and expensive, and when a new drug is developed it is difficult to fund many of them in parallel [4]. With a drug-drug interaction predictor, studies can be prioritized based on which interactions the model predicts.
  • Informing prescriptions when dealing with new diseases. When faced with novel diseases, medical practitioners are often forced to prescribe several existing drugs based on each drug's expected action. A recent example is how physicians tried various drugs to treat patients afflicted with COVID-19 [5]. However, drug contraindications have often not been tested for every possible combination of drugs. A drug-drug interaction predictor can help doctors make more informed decisions when considering unusual combinations of drugs for new diseases.

Now that we understand the motivation for this task, let’s get started.

We have a companion Colab notebook that contains our implementation that you can use to follow along. We’ll be making use of PyTorch and the PyTorch Geometric (PyG) library, which has been designed specifically for writing and training GNNs.

Bonjour, GraphSAGE!

We’ll be using GraphSAGE — an iterative algorithm that learns node embeddings — for our task [3]. Aesop probably didn’t know about GraphSAGE, but he was able to articulate the core idea behind it:

A man is known by the company he keeps.

— Aesop

GraphSAGE parrots this “sage” advice: a node is known by the company it keeps (its neighbors). In this algorithm, we iterate over the target node's neighbors and “aggregate” their embeddings to compute the target node's embedding. This will become clearer shortly with a little bit of math.

First, we implement a general GNN model, GNNStack, that stacks several graph convolution (message-passing) layers. Because the stack is generic, we can plug in our own layer implementation, in this case a GraphSAGE layer.
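Here is a minimal sketch of what such a stack could look like. It assumes that every layer is built as `conv_layer(in_dim, out_dim)` and called as `conv(x, edge_index)`. The class name `GNNStack` comes from the text. The constructor arguments, the ReLU activation and the dropout rate are our assumptions, not necessarily the notebook's choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GNNStack(nn.Module):
    """Generic stack of message-passing layers (sketch).

    `conv_layer` is any class constructed as conv_layer(in_dim, out_dim) whose
    forward takes (x, edge_index), e.g. PyG's SAGEConv or a hand-written layer.
    """

    def __init__(self, conv_layer, input_dim, hidden_dim, num_layers, dropout=0.5):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * num_layers
        self.convs = nn.ModuleList(
            [conv_layer(dims[i], dims[i + 1]) for i in range(num_layers)]
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # No activation/dropout after the last layer: its output is the embedding.
            if i < len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x
```

With PyG installed, `GNNStack(SAGEConv, input_dim, hidden_dim, num_layers)` would build a GraphSAGE encoder. This works because PyG's `SAGEConv(in_channels, out_channels)` follows the same constructor and `forward(x, edge_index)` convention.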

The GraphSAGE aggregate approach. Image source: Hamilton et al. Inductive Representation Learning on Large Graphs, NeurIPS 2017.

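Next, we must implement the GraphSAGE layer. For a given central node v, the message passing update rule is as follows:

$$h_v^{(l)} = W_l \cdot h_v^{(l-1)} + W_r \cdot \mathrm{AGG}\left(\{h_u^{(l-1)} : u \in N(v)\}\right)$$

where Wₗ and Wᵣ are learnable weight matrices and the nodes u are the neighbors N(v) of v. Additionally, we use mean aggregation for simplicity:

$$\mathrm{AGG}\left(\{h_u^{(l-1)} : u \in N(v)\}\right) = \frac{1}{|N(v)|} \sum_{u \in N(v)} h_u^{(l-1)}$$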
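To make the update rule concrete, here is a tiny worked example with our own toy numbers. It considers node v = 0, whose neighbors are nodes 1 and 2, and takes Wₗ and Wᵣ to be identity matrices so the arithmetic is easy to follow.

```python
import torch

# Toy embeddings h^(l-1) for three nodes; node 0's neighbors are nodes 1 and 2.
h = torch.tensor([[1.0, 0.0],
                  [0.0, 2.0],
                  [4.0, 4.0]])
neighbors_of_0 = torch.tensor([1, 2])

# Mean aggregation over N(0): (h_1 + h_2) / 2
agg = h[neighbors_of_0].mean(dim=0)

# Update with W_l = W_r = I: h_0 + AGG(...)
W_l = torch.eye(2)
W_r = torch.eye(2)
h0_new = W_l @ h[0] + W_r @ agg
print(agg, h0_new)
```

The mean of the two neighbor embeddings is (2, 3). Adding the node's own embedding (1, 0) gives (3, 3). With learned weight matrices, the two terms are transformed before they are summed.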

We use the following function to begin the message passing. The output here is a matrix of node embeddings returned from the message passing process.

def forward(self, x, edge_index, size=None)

The message function is used to construct messages from neighboring nodes to a central node for each edge in edge_index. The output here is a matrix of neighboring node embeddings that will then be aggregated.

def message(self, x_j)

Finally, the aggregate function aggregates the messages from the neighbors.

def aggregate(self, inputs, index, dim_size=None)
Message passing in GraphSAGE. Image Source: Lecture 10, CS224W at Stanford University.

Putting it all together, the GraphSAGE module looks something like this:
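The sketch below is one way to write such a module. To keep it self-contained, it uses plain PyTorch instead of inheriting from PyG's MessagePassing, but it mirrors the forward/message/aggregate split described above. The class name `GraphSage`, the `bias` and `normalize` options, and the `index_add_`-based mean are our assumptions.

```python
import torch
import torch.nn as nn


class GraphSage(nn.Module):
    """Mean-aggregation GraphSAGE layer in plain PyTorch (sketch)."""

    def __init__(self, in_channels, out_channels, normalize=False, bias=True):
        super().__init__()
        self.lin_l = nn.Linear(in_channels, out_channels, bias=bias)  # W_l: central node
        self.lin_r = nn.Linear(in_channels, out_channels, bias=bias)  # W_r: aggregated neighbors
        self.normalize = normalize

    def forward(self, x, edge_index, size=None):
        # `size` is kept only to match the signature above; it is unused here.
        # edge_index[0] holds source nodes u, edge_index[1] holds target nodes v.
        src, dst = edge_index
        msgs = self.message(x[src])               # one message per edge
        agg = self.aggregate(msgs, dst, dim_size=x.size(0))
        out = self.lin_l(x) + self.lin_r(agg)
        if self.normalize:
            out = nn.functional.normalize(out, p=2, dim=-1)
        return out

    def message(self, x_j):
        # GraphSAGE passes the neighbor embedding through unchanged.
        return x_j

    def aggregate(self, inputs, index, dim_size=None):
        # Sum incoming messages per target node, then divide by the in-degree.
        summed = torch.zeros(dim_size, inputs.size(1), dtype=inputs.dtype, device=inputs.device)
        summed.index_add_(0, index, inputs)
        counts = torch.zeros(dim_size, dtype=inputs.dtype, device=inputs.device)
        counts.index_add_(0, index, torch.ones_like(index, dtype=inputs.dtype))
        # Nodes without neighbors get a zero aggregate instead of dividing by zero.
        return summed / counts.clamp(min=1).unsqueeze(-1)
```

Doing the scatter by hand is only for illustration. In PyG, subclassing MessagePassing with `aggr="mean"` handles the gathering and aggregation for you.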

Thankfully, the folks on the PyG team have got our backs: to use GraphSAGE layers, one can simply invoke PyG's built-in SAGEConv instead.

Now, we can use a GNN to learn the embeddings of each node and use them to predict edges in the graph.

Okay, awesome! So are we done? Can we start training? Well, slow down there. Our GraphSAGE-based GNN computes embeddings for all nodes in the graph, but we want to make predictions on pairs of nodes. Therefore, we need a module that takes in a pair of node embeddings (i.e., a candidate edge) and classifies whether an edge connecting the two drugs exists. To do so, we'll define a simple neural network as follows:

The LinkPredictor takes the element-wise product of the real-valued embedding vectors of two nodes (hᵢ and hⱼ). It passes this product through a multi-layer perceptron to compute the probability that a link exists between the two nodes.
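A minimal sketch of such a predictor follows. The name `LinkPredictor` comes from the text. The number of layers, the hidden size and the dropout are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinkPredictor(nn.Module):
    """MLP that scores a node pair from the element-wise product h_i * h_j (sketch)."""

    def __init__(self, in_channels, hidden_channels, num_layers=2, dropout=0.5):
        super().__init__()
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [1]
        self.lins = nn.ModuleList(
            [nn.Linear(dims[k], dims[k + 1]) for k in range(num_layers)]
        )
        self.dropout = dropout

    def forward(self, h_i, h_j):
        x = h_i * h_j  # Hadamard product: the score is symmetric in i and j
        for lin in self.lins[:-1]:
            x = F.relu(lin(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Sigmoid turns the final logit into a link probability per pair.
        return torch.sigmoid(self.lins[-1](x)).squeeze(-1)
```

Because hᵢ ⊙ hⱼ equals hⱼ ⊙ hᵢ, the predictor gives the same score regardless of edge direction. That property suits an undirected interaction graph.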

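Now, we can learn the node embeddings, the GraphSAGE weights and the link predictor end to end, using the train function in the companion notebook. We use the following loss function, since we want to maximize the predicted probability of true edges (positive predictions) while minimizing the predicted probability of non-existent edges (negative predictions):

$$\mathcal{L} = -\frac{1}{|E^{+}|} \sum_{(i,j) \in E^{+}} \log\left(p_{ij} + \epsilon\right) - \frac{1}{|E^{-}|} \sum_{(i,j) \in E^{-}} \log\left(1 - p_{ij} + \epsilon\right)$$

Here pᵢⱼ is the predicted link probability for the pair (i, j), E⁺ is the set of positive training edges, and E⁻ is the set of sampled negative edges. The additional ϵ = 1e-15 is used for numerical stability.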
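The loss described above translates almost line for line into PyTorch. In this sketch, the function name `link_loss` is ours. It assumes the predictor already outputs probabilities, for example through a final sigmoid.

```python
import torch


def link_loss(pos_out, neg_out, eps=1e-15):
    """Binary cross-entropy over positive and negative edge probabilities (sketch).

    pos_out: predicted probabilities for existing (positive) edges.
    neg_out: predicted probabilities for sampled non-existing (negative) edges.
    eps keeps log() finite when a probability is exactly 0 or 1.
    """
    pos_loss = -torch.log(pos_out + eps).mean()
    neg_loss = -torch.log(1 - neg_out + eps).mean()
    return pos_loss + neg_loss
```

A model that predicts 0.5 for every pair pays 2·ln 2 ≈ 1.386. A perfect model that outputs 1 for positives and 0 for negatives drives the loss to 0.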

Aight. So this is all cool, but what we really care about is our model's performance on unseen test data. To evaluate how well the predictor performs, we'll use the Hits@K metric. It works as follows: we rank each positive edge in the test set against approximately 100,000 randomly-sampled negative edges, and count the ratio of positive edges that are ranked at K-th place or above [1].

Hits@K = fraction of positive (true) edges that score higher than the K-th highest-scoring negative edge
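The metric can be computed in a few lines. This sketch follows the logic of the OGB link-prediction evaluator, which compares each positive score with the K-th highest negative score. The function name `hits_at_k` is ours. In practice you would typically rely on OGB's own `Evaluator` class.

```python
import torch


def hits_at_k(pos_scores, neg_scores, k=20):
    """Fraction of positive edges scoring above the K-th best negative edge (sketch)."""
    if neg_scores.numel() < k:
        return 1.0  # fewer than K negatives: every positive lands in the top K
    kth_neg = torch.topk(neg_scores, k).values[-1]
    return (pos_scores > kth_neg).float().mean().item()
```

For example, take positive scores (0.9, 0.5, 0.1), negative scores (0.8, 0.6, 0.4, 0.2) and K = 2. The threshold is then 0.6, so only one of the three positives counts as a hit and Hits@2 = 1/3.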

Now that we have all our function definitions in place, it's time to run our offline training code! We encourage trying out different hyperparameters and measuring how they affect performance, but you can also use the hyperparameters from our companion notebook.

You should see the training loss decrease smoothly and Hits@20 values similar to ours (roughly 60% Hits@20 on the validation edges).

An Online Link Predictor

Now that we have a working understanding of using GraphSAGE for link prediction, let’s exploit its inductive nature, i.e., use it to predict links between a new node and the existing graph on the fly!

Example of online link prediction on a graph. Edges are predicted for every new node that is added to the graph.

This is especially useful for drug development, since we would like to predict drug-drug interactions as soon as new drugs are introduced. Let's get started.

Building an Online Link Predictor

We will draw upon the offline link predictor to build our online link predictor. We can think of each stage of the online predictor — corresponding to the addition of a new drug (node) — as an offline predictor, which means we can reuse a lot of our code! However, we must perform some preprocessing to be able to simulate this process, as well as understand how our online link predictor scales. In the online setting, we also want to learn the embedding of the new node quickly.
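Because ogb-ddi provides no node features, each drug's input is a learnable embedding, so a new drug needs a fresh embedding row. One way to add that row is sketched below. The helper `add_online_node` is hypothetical and not part of the authors' code.

```python
import torch
import torch.nn as nn


def add_online_node(emb: nn.Embedding) -> nn.Embedding:
    """Return a copy of `emb` with one extra trainable row for a new node (hypothetical helper)."""
    n, d = emb.weight.shape
    # The new last row keeps nn.Embedding's default N(0, 1) initialization.
    new_emb = nn.Embedding(n + 1, d)
    with torch.no_grad():
        new_emb.weight[:n] = emb.weight  # keep the embeddings learned so far
    return new_emb
```

To learn the new embedding quickly, one design option is to freeze the GNN and the predictor and update only the new row, so each online step stays cheap. We have not benchmarked this option.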

We'll define a preprocess function, which creates a dataset consisting of an initial subgraph and a dictionary of online nodes. This will be useful for simulating the online addition of nodes to our graph. Broadly, it generates a pickle file containing the initial graph and the new “online” nodes for which we must predict links.
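The following is a hypothetical sketch of such a preprocessing step in plain Python. The signature, the random arrival order, the default split fractions and the one-negative-per-positive sampling are all assumptions, not the notebook's exact choices. Each online node only sees edges to nodes that arrived before it. Those edges are split into message-passing, training-supervision, validation and test sets.

```python
import random


def preprocess(edges, num_nodes, num_initial, split=(0.4, 0.4, 0.1), seed=0):
    """Hypothetical online split. `split` = (msg, train, val) fractions; test gets the rest."""
    rng = random.Random(seed)
    order = list(range(num_nodes))
    rng.shuffle(order)
    rank = {node: i for i, node in enumerate(order)}  # arrival time of each node

    initial_nodes = set(order[:num_initial])
    initial_edges = [(u, v) for u, v in edges if u in initial_nodes and v in initial_nodes]

    online = {}  # insertion order == arrival order
    for node in order[num_initial:]:
        # Only edges to nodes that arrived earlier are visible when `node` is added.
        nbrs = []
        for u, v in edges:
            if node in (u, v):
                other = v if u == node else u
                if rank[other] < rank[node]:
                    nbrs.append(other)
        rng.shuffle(nbrs)
        n = len(nbrs)
        c1 = int(split[0] * n)
        c2 = c1 + int(split[1] * n)
        c3 = c2 + int(split[2] * n)
        # Negatives: earlier nodes that are not neighbors, one per positive (at least one).
        nbr_set = set(nbrs)
        non_nbrs = [w for w in order[:rank[node]] if w not in nbr_set]
        neg = rng.sample(non_nbrs, min(len(non_nbrs), max(n, 1)))
        online[node] = {"msg": nbrs[:c1], "train": nbrs[c1:c2],
                        "val": nbrs[c2:c3], "test": nbrs[c3:], "neg": neg}
    return initial_edges, online
```

The resulting `(initial_edges, online)` pair could then be written to disk with `pickle.dump`. The nested loop over all edges costs O(nodes × edges), which is fine for a sketch but calls for an adjacency list on the full graph.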

Our new training function is nearly identical to the train function we wrote earlier. The only change is that while we needed to sample negative edges in the earlier case, we now generate negative edges for every node added in an online fashion using the preprocess function and can use those directly.
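A sketch of a single online training step is shown below. The names are hypothetical. It assumes that `encoder(x, edge_index)` returns node embeddings and that `predictor(h_i, h_j)` returns link probabilities, as in the offline setup. The only difference from the offline loop is that the negative partners arrive precomputed instead of being sampled.

```python
import torch


def online_train_step(encoder, predictor, emb, msg_edge_index, node,
                      pos_nbrs, neg_nbrs, optimizer, eps=1e-15):
    """One optimization step for a newly added node (hypothetical sketch)."""
    optimizer.zero_grad()
    h = encoder(emb.weight, msg_edge_index)  # embeddings for all nodes
    # Pair the new node with its positive and pre-generated negative partners.
    pos = predictor(h[node].expand(len(pos_nbrs), -1), h[pos_nbrs])
    neg = predictor(h[node].expand(len(neg_nbrs), -1), h[neg_nbrs])
    loss = -torch.log(pos + eps).mean() - torch.log(1 - neg + eps).mean()
    loss.backward()
    optimizer.step()
    return loss.item()
```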

Here we visualize our model’s predictions on a few online nodes given a (very) small subgraph of 75 nodes. The blue edges represent message passing edges given to the model to generate an embedding vector for the new drug. The green, yellow, and red edges represent unseen test edges that the model classified as true positives, false negatives, and false positives, respectively. Given the very small size of the initial subgraph, there are many more false positives and false negatives than there would be given a larger subgraph for training.
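The color coding in this figure comes down to set arithmetic on each new node's neighbors. The sketch below assumes that a predicted edge is any candidate whose probability reaches a threshold of 0.5. That threshold and the function name are our assumptions.

```python
def classify_edges(true_nbrs, scores, threshold=0.5):
    """Group a new node's candidate edges like the figure's colors (sketch).

    true_nbrs: set of node ids that truly interact with the new node (test edges).
    scores: dict mapping candidate node id -> predicted link probability.
    """
    predicted = {j for j, s in scores.items() if s >= threshold}
    return {
        "true_positive": true_nbrs & predicted,   # green
        "false_negative": true_nbrs - predicted,  # yellow
        "false_positive": predicted - true_nbrs,  # red
    }
```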

The impact of prior knowledge

An interesting thing to think about would be how big our existing dataset (prior knowledge) needs to be for our online predictor to work well. We explore this question in two ways:

  • By varying the number of nodes in the initial graph: This amounts to how many drugs and their interactions amongst themselves are made available to the model.
  • By varying the number of edges used for training: This translates to how much information about the new drug is available to our model.

Here’s a summary of the results we found:

Here, the Train Msg column denotes the fraction of edges used for message passing, Train Sp denotes the fraction of edges used as supervision edges (i.e., used to compute loss) during training, while Val and Test denote the fraction of edges reserved for the validation and test sets respectively.

Unsurprisingly, richer prior knowledge leads to better performance for the GNN, whether it comes from more initial nodes or from more edges used for training. This suggests that, for similar tasks, it's probably a good idea to build up a critical mass of information before applying GNNs.

Wrapping Up

Through this post and its companion notebook, we explored the inductive power of GNNs by building a link predictor for a drug-drug interaction dataset. We also introduced modifications to build an online version of the link predictor, which allows us to make predictions as new nodes are introduced in the graph.

Our GraphSAGE-based method only uses the connections between nodes to predict links. Modifications such as generating features for the links and feeding them to the model as additional inputs can improve the performance of GraphSAGE. In fact, at the time of writing, the leader of the ogb-ddi leaderboard does exactly that to claim the top spot! Now that you have a good idea about GNNs and GraphSAGE in general, the world is your oyster: try it out on other data, and play around with GNNs!

That’s all from us!

Image Source: Wikimedia Commons

References

[1] David S. Wishart, et al. DrugBank 5.0: a major update to the DrugBank database for 2018. Nucleic Acids Research, 2018.

[2] Emre Guney. Reproducible drug repurposing: When similarity does not suffice. Pacific Symposium on Biocomputing, 2017.

[3] William L. Hamilton, Rex Ying, and Jure Leskovec. Inductive representation learning on large graphs. NeurIPS, 2017.

[4] Christopher P Adams and Van V Brantner. Estimating the cost of new drug development: is it really $802 million? Health Affairs, 2006.

[5] Kai Kupferschmidt and Jon Cohen. Race to find covid-19 treatments accelerates. Science, 2020.
