Predicting Drug-Drug Interactions using Graph Neural Networks

By Ananth Agarwal, Meg Richey, and Zeb Mehring as part of the Stanford CS224W Fall 2021 course project.

--

Understanding how pharmaceutical drugs interact with each other is a crucial problem for both medical research and practice. The field of Graph Machine Learning (GraphML) can be used to answer questions about drug-drug interactions with high confidence without first needing to conduct costly and time-intensive randomized controlled trials. This article demonstrates how!

N.B. Follow along with the companion Colab.

Introduction

Modern medicine, especially modern drug development, is incredibly powerful. Advances in chemistry, biology, and engineering have facilitated the creation of drugs that are able to, with astounding efficacy, treat symptoms of a wide variety of ailments, from the acute to the chronic, the mild to the severe.

Such potent tools, however, are not without their downsides. Drugs can have adverse interactions with other drugs for patients with comorbidities or issues that require multiple prescriptions to treat. These effects range from discomfort to blunting of the drug’s effects to severe pain and even death. Thus doctors need to take a patient’s current medication regimen into account when evaluating drug-based treatment options.

Today, medical professionals rely on drug-drug interaction databases when prescribing medication to ensure that any new drugs will not interact with a patient’s existing prescriptions in a negative way. These databases represent decades of codified medical knowledge, but they are woefully incomplete. Drug discovery is a field of constant innovation, and pharmaceutical scientists are capable of generating new drugs much more quickly than the scientific community is capable of setting up randomized controlled trials, undergoing peer review, and publishing research.

Enter modern computer science. It would be a boon to drug development and patient outcomes if the pharmaceutical industry had access to a machine learning model that could predict whether a new drug interacts with a target set of other drugs, to guide medical recommendations and further research. We can formulate the task of predicting the existence of interactions between drugs as a graph machine learning problem. This article will describe how to do so, and how to solve it (hands-on)!

Data

Model

Like other networks (social, dependency, etc.) in computer science, we can represent drug-drug interactions by an undirected graph. We encode drugs as nodes in this graph and link two nodes by an edge if there is a meaningful interaction between the drugs they represent.

Illustration of a sample DDI graph. The DDI graph is undirected, unweighted, and homogeneous (i.e. there is only a single “type” of node and a single “type” of edge). In this visualization, solid lines are known interactions, and dashed lines represent “missing edges”, i.e., edges that represent interactions that we would like to predict.

Open Graph Benchmark (OGB) is a collection of open-source, benchmark datasets for graph machine learning tasks. Our focus in this article is the ogbl-ddi dataset, which consists of a single drug-drug interaction (DDI) network as described above.

We define the DDI graph mathematically as G = (V, E), where V is the set of nodes and E is the set of edges. Each node v ∈ V in the graph represents an FDA-approved or experimental drug. The existence of an edge (u, v) between two nodes u and v indicates that the two drugs interact such that the effect of taking both drugs together is considerably different from the expected effect in which drugs act independently of each other [1]. For example, two drugs that target the same proteins may have a significant joint interaction.

Problem Definition

We now wish to formulate the drug-drug interaction problem in terms of the ogbl-ddi graph.

First, some vocabulary. Positive edges are edges present in the dataset. Each positive edge represents a known significant interaction between the two drugs represented by the edge’s endpoints.

If two drugs do not interact (that is, they have the same effects when taken together versus taken separately from each other), there will not be an edge present in the graph. These “holes” are called negative edges; in other words, edges that don’t exist in the graph.

There are also potentially missing edges in the graph. These edges link drugs that may have a significant joint interaction that is not yet recorded in our dataset.

Our goal is to develop a graph machine learning model to solve the link prediction task: given two drugs as input, we want to predict if the two drugs interact with each other, i.e., if an edge should exist between these two nodes in the graph. This should allow us to complete our dataset by classifying missing edges as either positive or negative edges.

Loading the Data

We promised a hands-on tutorial, didn’t we? Following the example from the OGB website, we can load the DDI dataset into PyTorch Geometric (PyG):

ddi_graph is a torch_geometric.data.Data object containing the DDI graph provided by OGB. The most important thing to note is the edge_index attribute of ddi_graph, which is a torch.Tensor containing every edge in the network represented by a pair of node indices (corresponding to the endpoints). Because PyG stores each undirected edge once in each direction, it is of shape 2 x 2|E|.

Solving the drug-drug interaction problem

Before proceeding, we must first outline how to formulate the link prediction task in the DDI network, which is (as described) equivalent to the drug-drug interaction decision task. We will make use of Graph Neural Networks to compute node representations, and use those representations to make predictions about node interaction.

Towards Graph Neural Networks

To be able to predict if an edge should exist between two nodes in the DDI graph, we consider first learning low-dimensional representations for each node called node embeddings.

Formally, given an input graph, we wish to encode each node v ∈ V into a d-dimensional vector in the embedding space. Embeddings are learned to optimize an important property: nodes that are “similar” in the original network should be embedded close together in the embedding space. Similarity is defined by the objective function, which can be expressed in English as “the likelihood that the two drugs represented by each of the nodes have a meaningful interaction”.

The tendency of similar nodes to cluster together in networks is observed in a wide variety of settings. For example, in social networks, friendship clusters tend to contain people with similar interests and experiences, and a good model would embed nodes within each cluster close together in the embedding space. In the drug-drug interaction setting, if two drugs target the same set of proteins, they are more likely to be connected by an edge since they are more likely to interact with each other. We want our final drug node embeddings to capture any such similarities that can be learned from the input network to improve the performance of our link predictor.

There are several different “traditional” techniques for learning node embeddings, such as node2vec and DeepWalk, but here we will focus on Graph Neural Networks (GNNs).

Graph Neural Networks

A natural question at this point might be “why GNNs?” or, more specifically, “why not just use a matrix representation of the graph and apply an existing ML technique?”

Unlike images and sequence data (for which we usually use CNNs and RNNs, respectively), graphs don’t have a well-defined node order or fixed “reference point”. Unfortunately, an adjacency matrix representation of a graph implies a particular node ordering (the order the nodes are written out in the matrix), which will inhibit a model’s ability to generalize.

We need an architecture that doesn’t “care” about ordering. In other words, however the nodes happen to be ordered, each node should end up with the same embedding; reordering the input nodes should simply reorder the output embeddings in the same way (permutation equivariance).
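To make this concrete, here is a small check in plain PyTorch (our own illustration, not code from the companion Colab): a mean-over-neighbors aggregation applied to a relabeled copy of a random graph returns the same embeddings, only reordered.

```python
import torch


def mean_aggregate(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """One round of mean neighbor aggregation using a dense adjacency matrix."""
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)  # avoid dividing by zero for isolated nodes
    return adj @ x / deg


torch.manual_seed(0)
num_nodes, dim = 5, 3
adj = (torch.rand(num_nodes, num_nodes) > 0.5).float()
adj = ((adj + adj.t()) > 0).float()  # symmetrize: undirected graph
adj.fill_diagonal_(0)                # no self-loops
x = torch.randn(num_nodes, dim)

perm = torch.randperm(num_nodes)     # relabel the nodes in a random order
P = torch.eye(num_nodes)[perm]       # permutation matrix: (P @ x)[i] == x[perm[i]]

out = mean_aggregate(adj, x)
out_perm = mean_aggregate(P @ adj @ P.t(), P @ x)

# Permuting the input only permutes the output rows in the same way.
is_equivariant = torch.allclose(out_perm, out[perm], atol=1e-6)
print(is_equivariant)  # True
```

By contrast, a model that flattened the adjacency matrix into one long feature vector would see a completely different input after the relabeling.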

GNNs provide this feature, and are in fact a general case of other popular ML architectures (e.g. CNNs and Transformers) which do make assumptions about order and structure.

Definition

Graph Neural Networks are a style of ML architecture that utilize each node’s local neighborhood structure in an order-independent manner to iteratively learn an embedding over a series of computation “layers”.

In the following sections, we denote the embedding of node v at layer l of a GNN as h(v, l). Initial embeddings h(v, 0) are typically feature vectors derived from the entities that the nodes represent, but since the DDI graph lacks node features, we will randomly initialize our embedding vectors.

Message Passing — Conceptual Overview

Feeding forward through a GNN can be visualized as the passing of data from each node in the graph to other nodes in the graph along positive edges. Note that this is very similar to the way that data is passed around and transformed through an ordinary deep neural network; the only difference is that in a GNN, the structure through which embedding data is allowed to flow is itself input data (i.e. the graph itself).

Each GNN “layer” is defined by a round of this message passing between each node and its neighbors, the aggregation of the received neighbor messages by each node, and the computation of an updated embedding using the aggregated messages.

Message passing is the mechanism by which nodes are able to incorporate information from their local neighborhood structure to determine their own embedding. It consists of each node producing a message, which is passed to other nodes along the outgoing edges from that node during a round. Often, the message from a node v to a node u as part of layer l is based on node v’s previous layer embedding, h(v, l-1), and/or attributes of the edge (u, v).

2-layer GNN message passing visualization. Colors represent node embeddings.

The visualization above demonstrates message passing in action in a 2-layer GNN. In the above graph, we modeled the initial embedding of each node as a different color. During the first round of message passing, each node passes a message (its color) to its neighbors. The neighbors then incorporate all the messages they receive to update their own color (visualized by a mixing of colors at each node). After the second round, it is clear that neighbors up to two hops away have influenced the embedding of each of the nodes.

To form a GNN, we stack many message passing layers together in order to compute embeddings that synthesize information from embeddings across the graph. As the visualization above helps illustrate, over a K-layer GNN, the set of nodes that determines the embedding of a particular node (functionally, the node’s K-hop neighborhood) is called the receptive field.

One of the biggest advantages of GNNs over other node embedding methods is inductive capability (also known as “generalization” ability). Once we have our trained GNN, if we are given a new drug for which a few interactions with existing drugs in the DDI graph are known, we can generate an embedding for the drug by running it through our K-layer GNN. This embedding can then be used to predict other interactions in the network.

Message Passing — Mathematics

The result of each message passing layer is an update to the embeddings of each node in the graph. The update to the embedding h for the node v in layer l+1 can generically be expressed in the following form:

Equation 1: Node embedding update equation
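In LaTeX, using the notation of the surrounding text (h_v^{(l)} for h(v, l), N(v) for the neighbors of v, and m_u^{(l+1)} for the message node u sends when layer l+1 is computed), our transcription of this generic update is:

```latex
h_v^{(l+1)} = \mathrm{UPDATE}\left( h_v^{(l)},\ \mathrm{AGG}\left( \left\{ m_u^{(l+1)} : u \in N(v) \right\} \right) \right)
```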

Let’s break this down. First note the variable m denotes a message passed along an edge in the graph. We first collect the messages for all the neighbors of the target node v into a set. We then apply an aggregation function AGG to this set.

Possible aggregation functions include max, mean, and sum, in order of increasing expressive power. For the reasons discussed previously, it is important that the aggregation function is permutation invariant: regardless of the order of the inputs, the output should be the same.
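A quick way to see the ordering of expressive power (our own toy example): each aggregator below is permutation invariant, but max cannot tell apart either pair of neighborhoods, mean separates only the second pair, and sum separates both.

```python
import torch

AGGREGATORS = {
    "max": lambda m: m.max(dim=0).values,
    "mean": lambda m: m.mean(dim=0),
    "sum": lambda m: m.sum(dim=0),
}

# Pairs of different neighbor-feature multisets (one row per neighbor message).
PAIRS = {
    "{1} vs {1, 1}": (torch.tensor([[1.0]]), torch.tensor([[1.0], [1.0]])),
    "{1, 3} vs {1, 1, 3}": (torch.tensor([[1.0], [3.0]]), torch.tensor([[1.0], [1.0], [3.0]])),
}

distinguishes = {
    (pair, name): not torch.equal(agg(a), agg(b))
    for pair, (a, b) in PAIRS.items()
    for name, agg in AGGREGATORS.items()
}

for (pair, name), ok in distinguishes.items():
    print(f"{name:>4} on {pair}: {'distinguishes' if ok else 'cannot tell apart'}")
```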

After aggregating the neighbor messages, the message passing layer then performs an UPDATE function, which combines the previous embedding of the node v with the aggregated neighborhood messages. The result of this is the embedding of node v in layer l+1.
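To tie the three pieces together, here is a minimal plain-PyTorch sketch of one generic round (our illustration; the function names are hypothetical): the message and update functions are passed in, and AGG is a mean computed with index_add_. The demo replays the color-mixing picture from earlier on a three-node path.

```python
import torch


def message_passing_round(x, edge_index, message_fn, update_fn):
    """One generic round: MSG per edge, mean AGG per target node, then UPDATE.

    x: [num_nodes, d] embeddings h(., l); edge_index: [2, num_edges] as (source, target) rows.
    """
    src, dst = edge_index
    msgs = message_fn(x[src])  # one message per edge, sent from src to dst
    agg = torch.zeros(x.size(0), msgs.size(1), dtype=msgs.dtype).index_add_(0, dst, msgs)
    deg = torch.bincount(dst, minlength=x.size(0)).clamp(min=1).to(msgs.dtype)
    agg = agg / deg.unsqueeze(-1)  # AGG = mean over received messages
    return update_fn(x, agg)       # UPDATE combines h(v, l) with the aggregate


def send_embedding(h):
    return h  # MSG: send the embedding unchanged


def mix_colors(h, agg):
    return (h + agg) / 2  # UPDATE: average own color with the neighbors' average color


# Path graph 0 - 1 - 2 with both edge directions; one-hot "colors" as initial embeddings.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
h0 = torch.eye(3)
h1 = message_passing_round(h0, edge_index, send_embedding, mix_colors)
h2 = message_passing_round(h1, edge_index, send_embedding, mix_colors)
print(h1[0])  # tensor([0.5000, 0.5000, 0.0000]): node 0 has only heard from node 1
print(h2[0])  # tensor([0.3750, 0.5000, 0.1250]): node 2's color arrives after two rounds
```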

Individual GNN layers vary in the choice of message, aggregation, and update functions. Examples include Graph Convolutional Networks (GCN), GraphSAGE, and Graph Attention Networks (GAT). We will be using GraphSAGE for our analysis of the DDI dataset.

GraphSAGE

It’s time to bring this all together and introduce the GNN architecture we will use for the drug-drug interaction problem: GraphSAGE [2]. GraphSAGE is a particular architecture of GNN. Its message passing layer is defined as:

Equation 2: GraphSAGE layer update equation
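Transcribed into LaTeX from the description that follows (a summation update over two learnable linear transformations, with any nonlinearity applied outside the layer, as in PyG's SAGEConv), the layer reads:

```latex
h_v^{(l+1)} = W_1^{(l)} h_v^{(l)} + W_2^{(l)} \, \mathrm{AGG}\left( \left\{ h_u^{(l)} : u \in N(v) \right\} \right)
```

This is Equation 1 with the message taken to be the neighbor's previous embedding, m_u^{(l+1)} = h_u^{(l)}.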

In other words, it uses:

  • A generic aggregation function (e.g. mean, max, sum)
  • A summation update function

GraphSAGE layers also apply linear transformations (using learnable parameters W) to both the node embeddings and the aggregation outputs during the update step. Note that these parameters exhibit a form of weight sharing much like the kernels of CNNs. Namely, within each layer, the same matrices W are applied to every node in the graph. This keeps the model size in check: each W matrix is of shape d x d, where d is the embedding dimension, for a total of O(K x d x d) parameters in a K-layer GNN, independent of the graph size.

We will implement our first GNN based on the GraphSAGE architecture’s convolutional operator. PyG provides a built-in implementation of each GraphSAGE layer in the SAGEConv class (though we will re-implement this using the above equation later). Using this, our first GNN model will have the following structure:

Back to the code! We can implement forward-propagation through a GraphSAGE GNN as follows:

By default, SAGEConv uses mean as the aggregation function. We initialize our GraphSAGE model as follows:

This should read trivially now. In fact, this looks remarkably like an ordinary deep neural network (all the magic is happening in the conv layers). To forward propagate, we pass each node embedding through a GraphSAGE “layer” that collects messages from neighboring nodes, transforms and aggregates them, and combines them with the target node’s previous embedding. This result is used to update the embedding of each node, which is passed through a nonlinearity and subjected to dropout for good measure.

Running this network with trained parameters should give us the node embeddings we’ve been seeking all along!

Link Prediction Head

Once we have the node embeddings output by the GNN, we need to use them to make link predictions. In addition to the GNN, we need to train a link prediction head: a binary classification model that computes the probability of the existence of an edge between two nodes in the network.

We will use the following neural network structure to predict links:

We’ll represent the node inputs by the embeddings computed by the GNN. The implementation of the link predictor is not the focus of this article, and we will use a very simple deep neural network for link prediction. In practice, more powerful networks may be used.

Suppose we want to predict if edge (u, v) exists, and we have the final node embeddings h(u, K) and h(v, K). We need to combine these node embeddings into an “edge embedding” that we can feed into the link predictor. This can be done in a couple of different ways: concatenation, summation, etc. In our case, we will use an element-wise product.

We will then transform this edge embedding by a few standard linear layers, feeding the final output through a sigmoid function to obtain a probability for the edge.

This code implements a 2-layer link predictor and should be familiar to anyone well-versed in modern machine learning with PyTorch:

Initialize as follows:
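A sketch of such a link predictor together with its initialization, following the structure of the LinkPredictor in OGB's example code (element-wise product, linear layers with ReLU and dropout, sigmoid output); the sizes are illustrative:

```python
import torch
import torch.nn.functional as F


class LinkPredictor(torch.nn.Module):
    """MLP that scores a node pair from the element-wise product of their embeddings."""

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super().__init__()
        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, h_i, h_j):
        x = h_i * h_j  # "edge embedding": element-wise product of the two node embeddings
        for lin in self.lins[:-1]:
            x = F.relu(lin(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return torch.sigmoid(self.lins[-1](x))  # probability that the edge exists


hidden_dim = 256  # must match the GNN's output size (illustrative value)
link_predictor = LinkPredictor(hidden_dim, hidden_dim, 1, num_layers=2, dropout=0.3)
```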

Training

Now that we have defined our full model (GraphSAGE and link predictor neural network), we need to train it. Training a graph neural network on a single input graph is slightly more complex than training other machine learning models, but it can be formalized in much the same way.

Dataset Splitting

Since we are training for the task of link prediction (a binary classification problem), our model inputs will be node pairs in the network. To generate these inputs, we will partition the positive edges in our network into three groups: the training set, validation set, and the test set.

The original sample DDI graph with positive edges partitioned according to some dataset split.

The GNN will be trained using only the edges in the training set: at training time, it doesn’t know anything about the validation or test edges. Intuitively, this can be thought of as removing the validation and test edges from the DDI graph.

Referring back to the code from the “Loading the Data” section, OGB has facilitated this for us by only including training edges when we loaded ddi_graph, so no additional work is needed from our side.

While training, we want our model to learn properties of both positive and negative edges. While we have a fixed set of training positive edges, we don’t have a fixed set of training negative edges. Instead, for each mini-batch of training positive edges, we will randomly sample an equal number of negative edges from the full graph. Note that negative samples in the training set may be positive edges in the validation or test sets.

Fortunately, OGB provides us with a fixed dataset split:
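OGB exposes this split through dataset.get_edge_split(), which returns a dictionary keyed by "train", "valid" and "test". The sketch below summarizes such a dictionary; to stay runnable offline it uses a tiny made-up dictionary with the layout we assume from OGB's examples ("edge" and, for validation and test, "edge_neg", each a [num_edges, 2] tensor). The helper name is hypothetical.

```python
import torch


def describe_split(split_edge):
    """Count positive edges per split, plus the provided negatives where present."""
    summary = {}
    for split in ("train", "valid", "test"):
        summary[split] = {"pos": split_edge[split]["edge"].size(0)}
        if "edge_neg" in split_edge[split]:
            summary[split]["neg"] = split_edge[split]["edge_neg"].size(0)
    return summary


# With the real data: split_edge = dataset.get_edge_split()
toy_split = {
    "train": {"edge": torch.tensor([[0, 1], [1, 2]])},
    "valid": {"edge": torch.tensor([[2, 3]]), "edge_neg": torch.tensor([[0, 3], [1, 3]])},
    "test": {"edge": torch.tensor([[3, 4]]), "edge_neg": torch.tensor([[0, 4]])},
}
print(describe_split(toy_split))
```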

As explained on the OGB website, the edges are split such that drugs in the validation and test sets target different proteins than the drugs in the training set.

This is actually better than a random split, because it mirrors a more realistic use-case: new drug discoveries are likely to target different proteins than the existing set of drugs in the database. We want our model to be able to make accurate predictions about drugs that may be different from every other drug we’ve seen so far. A random split would not train our model to generalize as well since it likely will have seen every drug-protein combo in the training set. One notable downside of this split is that it is more challenging for the model to learn than a random split, but given the advantages it provides in terms of model generalization, this tradeoff is acceptable.

One more technical detail. While we are forced to sample the negative edges for training, validation and test negative edges are given to us in this OGB split. Groovy.

Training Process

Training our full model (GraphSAGE + link predictor neural network) on a single training example edge (node pair) involves:

  1. Conducting message-passing to compute the embeddings of all nodes in the network (GraphSAGE)
  2. Feeding the embeddings of the edge’s nodes through the link predictor
  3. Computing the loss between the link predictor’s output and the edge’s label. Since link prediction is essentially a binary classification problem, we use binary cross-entropy loss.
  4. Backpropagating the loss through the link predictor and the GNN to update the model parameters

In practice, we perform steps 1–3 over a mini-batch of edges to avoid an excessive number of backpropagation computations (which can get computationally expensive).

In code, this looks like:

The initial embeddings are randomly initialized using torch.nn.Embedding as follows:
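A sketch, following the pattern in OGB's reference code for this dataset (one learnable vector per node, Xavier-uniform initialization); the embedding size is illustrative:

```python
import torch

NUM_NODES = 4267  # number of drugs (nodes) in ogbl-ddi; with the real data use ddi_graph.num_nodes
HIDDEN_DIM = 256  # illustrative embedding size; must match the GNN's input size

emb = torch.nn.Embedding(NUM_NODES, HIDDEN_DIM)  # one learnable vector per drug
torch.nn.init.xavier_uniform_(emb.weight)        # random initialization

# emb.weight is the [NUM_NODES, HIDDEN_DIM] matrix of initial embeddings h(v, 0).
# Pass emb.parameters() to the optimizer so these vectors are learned with the model.
```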

Evaluation

While training the model, we will periodically score it against the validation and test data to track its performance. OGB provides an evaluator to score our model’s performance that uses the Hits@K metric, as opposed to raw prediction accuracy. The need to use a metric other than accuracy is underscored by the fact that most real world networks are sparse: they contain far fewer edges than there are possible node pairs. A link prediction model that just predicted “no edge” for every input would therefore be able to achieve high accuracy.

What is important for our model is for it to be able to identify positive edges. It should output a higher predicted probability for positive edges than for negative edges, which is what Hits@K captures.

Given a set of positive edge prediction probabilities and negative edge prediction probabilities (for example, predictions on the test set positive edges and test set negative edges), Hits@K is equal to the ratio of positive edges that have a higher predicted value than the K-th highest negative edge prediction. For example, Hits@3 = 0.75 in the following code because the 3rd highest negative edge prediction is 0.45, and 3 out of the 4 positive edge predictions are greater than 0.45.
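A plain-PyTorch sketch of that definition (the function name and the prediction values are our own illustration):

```python
import torch


def hits_at_k(y_pred_pos: torch.Tensor, y_pred_neg: torch.Tensor, k: int) -> float:
    """Fraction of positive scores strictly above the k-th highest negative score."""
    if y_pred_neg.numel() < k:
        return 1.0  # convention when there are fewer than k negatives
    kth_neg = torch.topk(y_pred_neg, k).values[-1]
    return (y_pred_pos > kth_neg).float().mean().item()


y_pred_pos = torch.tensor([0.95, 0.70, 0.50, 0.30])
y_pred_neg = torch.tensor([0.90, 0.60, 0.45, 0.20, 0.10])
print(hits_at_k(y_pred_pos, y_pred_neg, k=3))  # 0.75: the 3rd highest negative is 0.45
```

Only the ranking of positives against the top negatives matters, so the metric is unaffected by how many easy negatives the model gets right.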

Evaluating our model on the validation and test positive and negative edges involves:

  1. Conducting message-passing to compute the embeddings of all nodes in the network (GraphSAGE)
  2. Feeding the embeddings for each edge’s nodes through the link predictor
  3. Calculating Hits@K for the validation and test predictions

A snippet of the code is as follows:

Results

Hu et al. (2020) (paper, code) demonstrated strong results with GraphSAGE on the DDI data [3]. Running a scaled-down version of this model, as shown via code snippets in this article (due to Colab GPU constraints and for getting quicker results for pedagogical purposes), did not yield the same results, but it does demonstrate the power of the model (and the learning trajectory).

Using the functions we’ve defined above, a training loop looks like:
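A generic sketch of such a loop (our own helper; the train and evaluation steps are passed in as zero-argument callables so the loop itself has no dependencies):

```python
def run_training(train_step, eval_step, epochs, eval_steps):
    """Train every epoch and evaluate every `eval_steps` epochs.

    train_step() -> float loss; eval_step() -> dict like {"valid": hits, "test": hits}.
    """
    history = []
    for epoch in range(1, epochs + 1):
        loss = train_step()
        record = {"epoch": epoch, "loss": loss}
        if epoch % eval_steps == 0:
            record.update(eval_step())
            print(f"Epoch {epoch:03d} | loss {loss:.4f} | "
                  f"valid Hits@20 {record['valid']:.4f} | test Hits@20 {record['test']:.4f}")
        history.append(record)
    return history


# Hypothetical wiring with the pieces sketched earlier:
# history = run_training(
#     lambda: train(model, link_predictor, emb, ddi_graph.edge_index,
#                   split_edge["train"]["edge"], batch_size, optimizer),
#     lambda: evaluate(model, link_predictor, emb, ddi_graph.edge_index,
#                      split_edge, evaluator, batch_size),
#     epochs=100, eval_steps=5)
```

Keeping the per-epoch history makes it easy to plot the loss and Hits@20 curves afterwards.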

Validation and test hits are calculated every eval_steps epochs (5 in the graph below). If you’ve been following along in the Colab (or with the snippets here), you should see results like* those illustrated in the following graph:

As we can see, training loss is decreasing and our Hits@20 metric is increasing for both our validation and test sets. Success!

* Similar to, but not exact due to randomness in data shuffling and embedding initialization.

Custom GraphSAGE Layer — Additional Attributes

So far, we’ve used a “standard” GraphSAGE message passing layer. In this section, we’ll explore the use of a custom layer that takes into account graph features (in our case, edge features).

Node Anchors

GraphSAGE alone is a powerful architecture to apply to the drug-drug interaction problem, but it can be enhanced. In particular, we consider the modified update equation:

Equation 3: Rewrite Equation 2 with a generic message for each node
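In LaTeX, with the same symbols as Equation 2 and m_u^{(l+1)} standing for the message that neighbor u sends, our transcription is:

```latex
h_v^{(l+1)} = W_1^{(l)} h_v^{(l)} + W_2^{(l)} \, \mathrm{AGG}\left( \left\{ m_u^{(l+1)} : u \in N(v) \right\} \right)
```

Setting m_u^{(l+1)} = h_u^{(l)} recovers Equation 2 exactly.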

Where now each node produces an opaque message m for layer l+1 rather than just simply its previous layer l embedding. This allows us to augment our node embeddings with additional features. Consider a set of features p. Define each message as:

Equation 4: Define the message as a combination of the node’s previous layer embedding and a new set of features p
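A LaTeX transcription, assuming (as in the edge-feature setting used below) that p_{uv} is the feature vector attached to the edge (u, v) and W_3^{(l)} is the additional learnable weight matrix; how COMB combines its two arguments is left open here:

```latex
m_u^{(l+1)} = \mathrm{COMB}\left( h_u^{(l)},\ W_3^{(l)} \, p_{uv} \right)
```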

where we’ve defined yet another learnable weight matrix and a combination function COMB.

Now the question remains of how to define features that might improve the model. Lu and Yang (2020) (paper, code), whose GraphSAGE implementation was second on the OGB leaderboard for ogbl-ddi at the time of writing, define an edge feature equal to the average distance from the endpoints to “anchor nodes” [4].

First, a random subset of a nodes in the network (where a is the number of anchors) is selected as anchor nodes. The shortest-path distance is then computed from each node in the network to each of these anchor nodes, producing a matrix of shape |V| x a. In code, we extract these distances as follows (note that we converted ddi_graph to a NetworkX graph nx_ddi_graph to use NetworkX’s shortest-paths utilities):
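A sketch of this step with NetworkX (our own helper, not Lu and Yang's code); it assumes an undirected graph and fills unreachable node/anchor pairs with infinity, which is our choice. With the real data, nx_ddi_graph could be built with torch_geometric.utils.to_networkx(ddi_graph, to_undirected=True).

```python
import random

import networkx as nx
import numpy as np


def anchor_distance_matrix(G: nx.Graph, num_anchors: int, seed: int = 0):
    """Return (anchors, dist) where dist[i, j] is the shortest-path length from node i to anchor j."""
    rng = random.Random(seed)
    nodes = sorted(G.nodes())
    anchors = rng.sample(nodes, num_anchors)           # random subset of `a` anchor nodes
    index = {v: i for i, v in enumerate(nodes)}
    dist = np.full((len(nodes), num_anchors), np.inf)  # inf marks unreachable pairs
    for j, anchor in enumerate(anchors):
        # One BFS per anchor; on an undirected graph d(v, anchor) == d(anchor, v).
        for v, d in nx.single_source_shortest_path_length(G, anchor).items():
            dist[index[v], j] = d
    return anchors, dist


# Hypothetical usage: anchors, dist = anchor_distance_matrix(nx_ddi_graph, num_anchors=...)
```

Running one breadth-first search per anchor costs O(a x (|V| + |E|)), far less than all-pairs shortest paths.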

Next, for each edge (u, v) in the graph, the distance from u to each anchor node is added to the distance from the anchor node to v, and then these summed distances are averaged over the set of anchor nodes, producing “average anchor distance” for each edge. This yields a matrix of shape |E| x 1, which contains our one-dimensional edge features. The authors of the paper scale these features to the range 0 to 1.

The following code illustrates this full process:
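A NumPy sketch of the full process, starting from a precomputed node-to-anchor distance matrix; the text only says the features are scaled to [0, 1], so the min-max scaling used here is our assumption. The toy check mirrors the figure below: the path A-B-C-D-E with C as the only anchor.

```python
import numpy as np


def average_anchor_distance(edges: np.ndarray, dist: np.ndarray) -> np.ndarray:
    """Mean over anchors of d(u, a) + d(a, v) for each edge (u, v), min-max scaled to [0, 1].

    edges: [num_edges, 2] node indices; dist: [num_nodes, num_anchors] anchor distances.
    Returns an array of shape [num_edges, 1].
    """
    u, v = edges[:, 0], edges[:, 1]
    raw = (dist[u] + dist[v]).mean(axis=1, keepdims=True)  # average over the anchor set
    lo, hi = raw.min(), raw.max()
    if hi == lo:  # constant feature: avoid division by zero
        return np.zeros_like(raw)
    return (raw - lo) / (hi - lo)


# Path A-B-C-D-E as nodes 0..4, with C (node 2) as the only anchor.
dist = np.abs(np.arange(5) - 2).astype(float)[:, None]  # d(node, C)
edges = np.array([[0, 1], [1, 2], [2, 3], [3, 4]])
features = average_anchor_distance(edges, dist)
print(features.ravel())  # [1. 0. 0. 1.]
# For the pair (A, E) the unscaled value would be dist[0] + dist[4] = 2 + 2 = 4.
```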

The resulting edge features intuitively represent the importance of each edge in the network. In effect, they measure how far apart the endpoints are when routed through the anchors, a rough proxy for the distance between them if the edge didn’t exist. Higher values suggest that the two endpoints have different local neighborhoods, and that the edge in question serves as a powerful link between these parts of the graph. Lower values suggest the converse.

Example computation of the anchor distance between two nodes A and E, with C as the anchor.

GraphSAGE Layer

As we’ve only updated the message passing implementation details, we can re-use the GraphSAGE model structure from before (with a minor modification to pass the edge_attr parameter — see the accompanying Colab for details). We will, however, implement a custom message passing layer to take our edge attributes into account.

Instead of using PyG’s SAGEConv operator as we did earlier, we extend torch_geometric.nn.conv.MessagePassing to define our own operator that utilizes the computed edge attributes. The code below builds on the PyG base message passing framework to define a GraphSAGE layer that incorporates the provided edge attributes and implements Equations 3 and 4 in the forward and message functions.

As we can see, this follows nicely from the theory previously laid out. Our message generation function combines the embedding of each node with the transformed edge attributes using COMB. Moreover, the forward pass implements the GraphSAGE algorithm: first messages are propagated and aggregated (done by the work of the MessagePassing.propagate method), then they are combined via a summation of linear transformations with the current embedding to produce the updated embedding.

Note that this layer accepts any edge attributes. While node-anchor distance proved to be a useful feature for Lu and Yang, other attributes should be investigated to see if they yield a model with better performance.

Finally, we are ready to instantiate the GraphSAGE model that uses our custom operator:

If you’re interested in the results of adding these edge attributes, see the accompanying Colab.

Conclusion

Aaaand that’s a wrap! It’s been a whirlwind, but this was intended as a brief introduction to the application of Graph Neural Networks (and GraphSAGE in particular) to solve the drug-drug interaction problem.

Summary of core topics covered:

  1. Graph Neural Networks are a particular kind of ML architecture for graph data. They are permutation equivariant models that operate by passing “messages” of data along the structure of the graph.
  2. GraphSAGE is a particular kind of GNN that is both expressive and powerful enough to solve the DDI problem for the ogbl-ddi dataset
  3. GNNs are used in concert with traditional ML architectures (e.g. a deep neural network link predictor) to solve real-world problems
  4. GraphSAGE is a highly customizable and flexible architecture, and there are many augmentations possible beyond the vanilla implementation

We encourage you to also explore other OGB datasets — there are many other highly impactful applications of graph ML techniques!

References

[1] “Link Property Prediction.” Open Graph Benchmark, https://ogb.stanford.edu/docs/linkprop/#ogbl-ddi

[2] Hamilton, William L., et al. “Inductive Representation Learning on Large Graphs.” ArXiv.org, 10 Sept. 2018, https://arxiv.org/abs/1706.02216

[3] Hu, Weihua, et al. “Open Graph Benchmark: Datasets for Machine Learning on Graphs.” ArXiv.org, 2 May 2020, https://arxiv.org/abs/2005.00687

[4] Lu, Shitao, and Yang, Jing. “Link Prediction With Structural Information.” https://github.com/lustoo/OGB_link_prediction/blob/main/Link%20prediction%20with%20structural%20information.pdf

All images are created by the authors.
