Predicting Drug Interactions with Graph Neural Networks

(By Maya Srikanth, Arpita Singhal, Manasi Sharma as part of the Stanford CS224W course project) In this Medium tutorial, you’ll learn the inner workings of advanced graph machine learning techniques and get the chance to apply state-of-the-art graph neural network models to predict drug interactions. By the end of the tutorial, you’ll be equipped to build your own graph neural network models and apply them to important problems in your domain of interest!

Maya Srikanth
Stanford CS224W GraphML Tutorials
15 min read · Jan 18, 2022


Introduction

Drug-drug interaction networks are a great opportunity to use graph deep learning techniques to address the urgent healthcare problem of adverse drug interactions. In the United States alone, around 27.2% of adults have multiple chronic conditions for which they take multiple medications [1]. Chemical interactions between drugs can be quite dangerous; for example, mixing a sleep sedative such as doxylamine with an antihistamine such as diphenhydramine (commercial name Benadryl) can slow reaction times and make it risky to drive. Even if the drug interactions are not fatal, they can cause undesirable side effects or reactions which reduce quality of life. Therefore, it is vital that clinicians are aware of possible drug interactions before prescribing medications to patients.

To investigate drug-drug interactions, we will apply graph ML techniques to the ogbl-ddi dataset [2, 3], a homogeneous, unweighted, and undirected graph representing a drug-drug interaction network. This graph contains 4,267 nodes and 1,334,889 edges. Nodes represent FDA-approved or experimental drugs, and edges represent potential interactions between drugs. Notice that this dataset has no features beyond pure graph structure!

Roadmap

Let’s get more concrete with our task formulation. We want to use graph structure information to predict whether two nodes (drugs) share an edge (interaction). We don’t know anything about the drugs or the nature of their interactions: we only know whether two “anonymous” drugs interact or don’t interact. Somehow, we need to generate meaningful representations for nodes in our graph: after obtaining these representations, we can feed them into a neural network to predict the existence of an edge between any 2 nodes. Ultimately, we’d like to use graph neural network methods to build the following training and prediction pipeline:

Roadmap. 1 corresponds to the drug interaction network, 2a and 2b correspond to preprocessing, 3 corresponds to graph neural network models, and 4 corresponds to the network responsible for link prediction. MLP stands for multi-layer perceptron (under 4)

In the remainder of this blog post, we’ll systematically discuss and implement every step of this architecture using Python and PyTorch Geometric, a library built on top of PyTorch with functionality for deep learning on graphs. We’ll cover the dataset split (1), preprocessing (2a, 2b), basic GNN components (3), and link prediction (4). Finally, we’ll implement different graph neural networks to see whether they can generate meaningful representations for nodes in our graph and thus enable robust link prediction. Note that all code needed to reproduce our results is in this Colab notebook.

Let’s dive in!

Drug Interaction Network

To perform its computation on an input graph, a graph neural network model needs 2 important pieces of information: (1) adjacency matrix information on existing edges in the graph, and (2) any feature information about nodes and edges in the graph.

Let’s start with (1): how do we get information about edges in the graph? Luckily, ogb provides us with an API to easily access our dataset of interest, as well as the train, validation, and test splits. Note that we are dealing with a transductive link prediction task: this means that our training, validation, and test sets will contain all nodes in the graph, but different subsets of edges. Each of the train, validation, and test sets will also have held-out (“supervision”) edges: ultimately, we want to train our model to predict the existence of these supervision edges! The diagram below illustrates this concept.

Transductive Split for Link Prediction. The Training Graph uses a specified subset of edges for training (red) and supervision (orange) edges. The Validation Graph uses the training (red) and supervision (orange) edges from the training graph to predict the validation supervision (yellow) edges. The Test Graph uses the training (red, orange) and validation edges (yellow) to predict the test supervision (blue) edges.
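To make the split in the diagram concrete, here is a minimal, library-free sketch. It assumes a purely random partition of the edges and made-up fractions; the ogbl-ddi dataset ships with its own predefined split, so the function name `transductive_edge_split` and every parameter below are illustrative assumptions rather than the dataset’s actual procedure.

```python
import random


def transductive_edge_split(edges, val_frac=0.1, test_frac=0.1, sup_frac=0.2, seed=0):
    """Partition edges into message / supervision sets for train, valid and test.

    The random partition and all fractions are illustrative assumptions.
    """
    rng = random.Random(seed)
    shuffled = list(edges)
    rng.shuffle(shuffled)

    n = len(shuffled)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    n_sup = int(n * sup_frac)

    test_sup = shuffled[:n_test]                                   # blue edges
    val_sup = shuffled[n_test:n_test + n_val]                      # yellow edges
    train_sup = shuffled[n_test + n_val:n_test + n_val + n_sup]    # orange edges
    train_msg = shuffled[n_test + n_val + n_sup:]                  # red edges

    # Every split sees all nodes; only the visible edges differ.
    return {
        "train": {"message": train_msg, "supervision": train_sup},
        "valid": {"message": train_msg + train_sup, "supervision": val_sup},
        "test": {"message": train_msg + train_sup + val_sup, "supervision": test_sup},
    }


# Toy graph: a 20-node ring plus chords (40 distinct undirected edges).
toy_edges = [(i, (i + 1) % 20) for i in range(20)] + [(i, (i + 5) % 20) for i in range(20)]
split = transductive_edge_split(toy_edges)
```

Notice that the message edges grow from train to validation to test, while each stage gets its own held-out supervision edges.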

We can load all necessary data with the following code:

Now onto (2): where do we retrieve node and edge features? Recall that we have none! All we have is graph structure. In the next section, we talk about our options for initial feature representations.

Data Preprocessing

Random Initialization

If we don’t have any node features, the most obvious way to represent our node embeddings at the start of training is with random initialization. That is, we start with completely randomized embeddings.

Node2Vec

With random embeddings, our GNN models need to work a bit harder to decipher relationships between nodes in the graph. Is there a way to give them a head start?

One possible solution is to use Node2Vec to construct meaningful representations that our GNNs can leverage for predictions. Node2Vec learns a mapping of nodes to a low-dimensional vector space, such that the resulting feature vectors preserve the information of the node’s neighborhood in the network [4].

To that end, Node2Vec uses a biased random walk (as shown in the gif we made below!) parameterized by p and q, where p controls the probability of returning to the previous node and q controls whether to move outward (in a DFS-like manner) versus inward (in a BFS-like manner). The algorithm simulates a fixed number of random walks of fixed length starting from each node, and then optimizes the embeddings using stochastic gradient descent such that nodes with similar network neighbors are close in the feature space [4]. With a combined breadth-first and depth-first strategy, the Node2Vec embeddings can encode both local and global neighborhood network features. This is important for our task, as drugs with similar network neighbors may experience similar interactions.

Node2Vec Random Walk starting at Drug 1 and ending at Drug 9. The unnormalized weight 1/p applies to returning to the previous node, and the weight 1/q applies to moving to a node that is not adjacent to the previous node. Finally, if the next node is a neighbor of the previous node, the weight is 1. These weights are normalized to obtain transition probabilities. The stick figure was taken from [5] :)
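The biased walk can be sketched in a few lines of plain Python. This is a minimal illustration under our own assumptions (a dictionary-of-sets graph, hypothetical helper names `node2vec_weights` and `biased_walk`, and a uniform first step); it is not the implementation used to train the embeddings in this post.

```python
import random


def node2vec_weights(graph, prev, curr, p, q):
    """Unnormalized weights for stepping from `curr`, having arrived from `prev`."""
    weights = {}
    for nxt in graph[curr]:
        if nxt == prev:
            weights[nxt] = 1.0 / p      # return to the previous node
        elif nxt in graph[prev]:
            weights[nxt] = 1.0          # stay close: neighbor of the previous node
        else:
            weights[nxt] = 1.0 / q      # move outward to a new region
    return weights


def biased_walk(graph, start, length, p, q, seed=0):
    """Simulate one walk of `length` nodes; the first step is uniform (assumption)."""
    rng = random.Random(seed)
    walk = [start]
    while len(walk) < length:
        curr = walk[-1]
        neighbors = sorted(graph[curr])
        if not neighbors:
            break
        if len(walk) == 1:
            walk.append(rng.choice(neighbors))
            continue
        w = node2vec_weights(graph, walk[-2], curr, p, q)
        walk.append(rng.choices(neighbors, weights=[w[n] for n in neighbors])[0])
    return walk


# Toy undirected graph stored as adjacency sets.
toy_graph = {1: {2, 3}, 2: {1, 3, 4}, 3: {1, 2}, 4: {2, 5}, 5: {4}}
weights = node2vec_weights(toy_graph, prev=1, curr=2, p=2.0, q=0.5)
walk = biased_walk(toy_graph, start=1, length=6, p=2.0, q=0.5)
```

With p = 2 and q = 0.5, the walker standing at node 2 (having come from node 1) is four times more likely to move outward to node 4 than to go back to node 1.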

We can train and save Node2Vec embeddings as follows.

Right now, we don’t have a supervisory signal telling us how “good” these embeddings are. Let’s plot the tSNE of these node embeddings to get a sense of Node2Vec’s ability to create expressive node embeddings.

Left: tSNE of Node2Vec node embeddings. Center: tSNE of Node2Vec embeddings, overlaid with 2k randomly sampled edges from the train split. Right: tSNE of Node2Vec embeddings, overlaid with 2k randomly sampled edges from the test split.

Ultimately, we’d like to see how embeddings are formed given information about drug interactions between nodes! Given our prediction task, we might expect nodes which share many interactions (edges) in the graph to have fairly similar network neighborhoods, and thus, have similar embeddings. By this logic, if we overlay a subsample of 2k train edges over our node embeddings in black (middle plot), we might expect an expressive node embedding model to result in a tSNE plot where “clusters” of embedding points represent similar nodes, and thus have more edges (darkness). We might expect to observe something similar if we overlay a subsample of 2k test edges over our node embeddings in black (right plot).

As we can see, the majority of Node2Vec’s embeddings seem to be part of the same, large cluster. As we go along, we’ll see what kind of tSNE representations we can produce for more expressive graph neural network models. Note that our tSNE visualization implementation is included in the Colab notebook!

Models

We can turn to graph neural networks to generate more expressive node representations. There are 3 key steps to any graph convolutional layer, and we’ll cover them below: message, aggregate, and update.

GNN Basics: Message, Aggregate, Update

Message: Each node creates a “message”, containing information that it wants to pass to its neighboring nodes. At a given layer (l), the message (m) computed by node u is some linear transformation of its node representation (h) produced by the previous graph convolutional layer (l − 1).

Aggregate: Each node then aggregates the messages from all of its neighbors. Crucially, this step propagates information through the network.

Update: The update function takes the output of the aggregation step and produces the final embedding for a given node u. To preserve information from the previous layer’s embedding, we often concatenate the node’s own message with the aggregation output. Other variations of the update function also apply a non-linear transformation (like ReLU) to the output of ‘CONCAT’.

This gif summarizes the message, aggregate, and update steps!

GNN Basic Functions: Message, Aggregate, and Update.
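The three steps can be sketched without any deep learning library. The snippet below is a toy illustration under stated assumptions: mean aggregation, a CONCAT-then-linear update followed by ReLU, and hand-picked weight matrices `W_msg` and `W_upd` that are hypothetical and not learned.

```python
def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def relu(x):
    return [max(0.0, xi) for xi in x]


def gnn_layer(graph, h, W_msg, W_upd):
    """One message-passing layer: message, mean-aggregate, CONCAT + ReLU update."""
    # Message: each node linearly transforms its previous-layer representation.
    messages = {u: matvec(W_msg, h[u]) for u in graph}

    new_h = {}
    for v in graph:
        neigh = [messages[u] for u in graph[v]]
        dim = len(messages[v])
        # Aggregate: element-wise mean of the neighbors' messages.
        if neigh:
            agg = [sum(m[i] for m in neigh) / len(neigh) for i in range(dim)]
        else:
            agg = [0.0] * dim
        # Update: concatenate the node's own message with the aggregate,
        # apply a linear map, then a ReLU non-linearity.
        new_h[v] = relu(matvec(W_upd, messages[v] + agg))
    return new_h


# Path graph 0 - 1 - 2 with 2-dimensional features.
graph = {0: {1}, 1: {0, 2}, 2: {1}}
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
W_msg = [[1.0, 0.0], [0.0, 1.0]]                       # identity, for readability
W_upd = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]   # adds own message and aggregate
out = gnn_layer(graph, h, W_msg, W_upd)
```

Stacking two such layers lets information travel two hops, which is why the number of layers controls how far each node can “see”.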

Permutation Invariance and Equivariance: Vital GNN Properties

Before diving in, we need to make an important observation about graph neural network layers: unlike many other deep learning models, GNNs provide us with the critical properties of permutation invariance and equivariance. Essentially, we don’t want the order of nodes passed into a neural network model to matter: two isomorphic graphs should be mapped to the same representation regardless of superficial differences like the ordering of nodes in the input adjacency matrix. Permutation invariance guarantees that a graph representation generated by a GNN is the same for any two orderings of the nodes of the same graph. Permutation equivariance extends this notion to node representations: if the input nodes are reordered, the output node representations are reordered in the same way.
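A tiny experiment makes both properties tangible. The sketch below assumes scalar node features and a plain sum over neighbors as the “layer”; the helper names `sum_aggregate` and `relabel` are ours, chosen only for this illustration.

```python
def sum_aggregate(graph, h):
    """Toy GNN layer: each node sums its neighbors' scalar features."""
    return {v: sum(h[u] for u in graph[v]) for v in graph}


def relabel(graph, h, perm):
    """Rename every node v to perm[v] in both the graph and the features."""
    new_graph = {perm[v]: {perm[u] for u in nbrs} for v, nbrs in graph.items()}
    new_h = {perm[v]: x for v, x in h.items()}
    return new_graph, new_h


graph = {0: {1, 2}, 1: {0}, 2: {0}}   # a star centered at node 0
h = {0: 1.0, 1: 2.0, 2: 3.0}
perm = {0: 2, 1: 0, 2: 1}             # an arbitrary reordering of the nodes

out = sum_aggregate(graph, h)
graph_perm, h_perm = relabel(graph, h, perm)
out_perm = sum_aggregate(graph_perm, h_perm)

# Equivariance: node outputs follow the relabeling.
equivariant = all(out_perm[perm[v]] == out[v] for v in graph)
# Invariance: a graph-level readout (here, a sum) does not change at all.
invariant = sum(out.values()) == sum(out_perm.values())
```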

Great! Let’s proceed.

For a moment, let’s pretend that we’ve selected our initial feature representations for all nodes in our training set, and fed them, along with adjacency matrix information, into our graph neural network of choice. Our graph neural network outputs high-quality embeddings for every node in the graph which hopefully capture vital aspects of the ogbl-ddi graph structure. How will we use these embeddings to predict the existence of edges (drug interactions) that we haven’t yet seen? Short answer: the link predictor network.

Link Predictor & Hits@K

We can take two embeddings corresponding to two nodes u, v in the graph and feed their element-wise product through a neural network. This network will apply a series of linear and non-linear transformations to its input, finally using a sigmoid function to output the probability that the edge (u, v) exists in the graph. We can use the class definition provided in an ogbl-ddi leaderboard submission [3]:
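As a rough, library-free illustration of the idea (not the leaderboard class itself), the sketch below scores a node pair with an element-wise product, one hidden ReLU layer, and a sigmoid. The function name `link_probability` and all weights are hypothetical and hand-picked; in practice the predictor would be a trainable PyTorch module.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def link_probability(h_u, h_v, W1, b1, w2, b2):
    """Probability that edge (u, v) exists, from the two node embeddings."""
    # Element-wise product keeps the score symmetric in u and v.
    x = [a * b for a, b in zip(h_u, h_v)]
    # One hidden layer with ReLU.
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(W1, b1)]
    # Final linear layer followed by a sigmoid.
    logit = sum(w * hj for w, hj in zip(w2, hidden)) + b2
    return sigmoid(logit)


h_u, h_v = [1.0, 2.0], [0.5, -1.0]
W1, b1 = [[1.0, 0.0], [0.0, -1.0]], [0.0, 0.0]
w2, b2 = [1.0, 1.0], -2.5

prob_uv = link_probability(h_u, h_v, W1, b1, w2, b2)
prob_vu = link_probability(h_v, h_u, W1, b1, w2, b2)
```

Because the graph is undirected, it is convenient that swapping u and v leaves the predicted probability unchanged.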

So we can now generate a prediction probability for whether an edge exists. But where does this probability factor into training and back propagation?

To formulate a sensible prediction task, we need to understand the concept of positive vs. negative edges. Positive edges are edges that exist in the graph (they indicate a drug-drug interaction), while negative edges are edges that do not exist in the graph. Now, you might be wondering: what loss are we using for backpropagation? To provide intuition, we can look at the hits@k metric, which has the premise that a good model will rank positive (existent) edges above negative (non-existent) edges. Specifically, in the evaluation step, we’ll use our model to rank each true drug interaction among a set of approximately 100,000 randomly-sampled negative drug interactions: the fraction of positive edges that are ranked at the K-th place or above is hits@k. For evaluation purposes, we select k=20.
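To make the metric concrete, here is a minimal pure-Python sketch of hits@k. It assumes the common convention of counting a positive edge as a hit when its score is strictly higher than the k-th highest negative score; the function name `hits_at_k` and the toy scores are made up for illustration.

```python
def hits_at_k(pos_scores, neg_scores, k):
    """Fraction of positive edges scored above the k-th best negative edge."""
    if len(neg_scores) < k:
        # Fewer than k negatives: every positive trivially lands in the top k.
        return 1.0
    threshold = sorted(neg_scores, reverse=True)[k - 1]
    hits = sum(score > threshold for score in pos_scores)
    return hits / len(pos_scores)


pos_scores = [0.9, 0.8, 0.3, 0.1]          # scores for true interactions
neg_scores = [0.85, 0.5, 0.4, 0.2, 0.05]   # scores for sampled non-interactions

hits_at_1 = hits_at_k(pos_scores, neg_scores, k=1)
hits_at_2 = hits_at_k(pos_scores, neg_scores, k=2)
```

Here only one positive edge beats the best negative (hits@1 = 0.25), while two beat the second-best negative (hits@2 = 0.5).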

We can adapt binary classification loss to capture this objective. Let E denote positive edges, E_neg denote negative edges, and let u, v denote two nodes. We will use the following log loss term:
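One standard way to write such a loss, assuming the link predictor outputs a probability \hat{y}_{uv} for each candidate edge (the symbol \hat{y}_{uv} is our notation), is the binary cross-entropy averaged separately over positive and negative edges:

```latex
\mathcal{L} \;=\;
  -\frac{1}{|E|} \sum_{(u,v) \in E} \log \hat{y}_{uv}
  \;-\; \frac{1}{|E_{neg}|} \sum_{(u,v) \in E_{neg}} \log\left(1 - \hat{y}_{uv}\right)
```

The first term is small when positive edges receive probabilities near 1, and the second is small when negative edges receive probabilities near 0.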

Observe that this loss drives positive edge scores higher than negative edge scores: this is indeed what we’re aiming for!

Implementing GNN Models: GCN, GraphSAGE, GIN

Let’s move on to implement specific GNN architectures: GCN, GraphSAGE, and GIN! These models differ in their message, aggregation, and update strategies.

Model Hyperparameters

But first, model hyperparameters. To standardize experiments, we use the same set of hyperparameters for each model, as shown below (these were the settings used by the same ogb leaderboard submission we referenced earlier[3]):

To keep things comparable, each GNN model implementation consists of 2 sequences of GCNConv → ReLU → Dropout, where GCNConv indicates the graph convolutional layer of choice. We train each model for 2 runs of 100 epochs each, resetting parameters between runs, and select the best model based on hits@20 performance on the validation set over both runs. While other implementations on the ogb leaderboard use 10 runs, we chose 2 due to computational resource constraints, but you can certainly increase this number on the Colab notebook we provide to see how metrics change!

GCN

The first model we’ll start with is a GCN (graph convolutional network). GCN’s message, aggregate, and update functions are as follows:

i. Message function:

ii. Aggregation function:

iii. Update function
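As a reference point, one common textbook way to write these three steps for a GCN layer is sketched below, using the same symbols as before (m for messages, h for node representations, l for the layer index). The mean normalization by |N(v)| and the symbols a and \sigma are our assumptions; library implementations often use a symmetric degree normalization with self-loops instead.

```latex
\text{Message:}   \quad m_u^{(l)} = \frac{1}{|N(v)|}\, W^{(l)} h_u^{(l-1)}, \qquad u \in N(v) \\
\text{Aggregate:} \quad a_v^{(l)} = \sum_{u \in N(v)} m_u^{(l)} \\
\text{Update:}    \quad h_v^{(l)} = \sigma\!\left( a_v^{(l)} \right)
```

Here N(v) is the set of neighbors of node v, W^{(l)} is the learnable weight matrix of layer l, and \sigma is a non-linearity such as ReLU.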

In the snippet above, we used the GCN class definition provided in an ogbl-ddi leaderboard submission [3]. Let’s kick off training with both random features and Node2Vec embeddings.

How do these two models do? The best GCN model trained on randomly initialized embeddings achieves train hits@20 = 45.65%, validation hits@20 = 39.54%, and test hits@20 = 23.39%. The GCN model trained on Node2Vec embeddings achieves train hits@20 = 47.536%, validation hits@20 = 41.24%, and test hits@20 = 27.46%. These figures indicate that Node2Vec does indeed boost test set performance for GCN!

As we did with Node2Vec, let’s use tSNE to visualize the node embeddings for GCN with Node2Vec initial features!

Left: tSNE of GCN node embeddings with Node2Vec initial features. Center: tSNE of GCN embeddings, overlaid with 2k randomly sampled edges from the train split. Right: tSNE of GCN embeddings, overlaid with 2k randomly sampled edges from the test split.

If you recall the tSNE plot for Node2Vec, you might notice that GCN’s embeddings exhibit many more distinct clusters than the Node2Vec visualization. This is likely because of the increased expressivity graph neural network models offer over biased random walk approaches.

Note that the code for training GraphSAGE and GIN is very similar to the code snippet above, so we won’t include them here (they’re also in the Colab).

GraphSAGE

Next, let’s try the slightly more expressive GraphSAGE model to see whether it will outperform GCN. GraphSAGE’s message, aggregate, and update functions are as follows:

i. Message:

ii. Aggregate:

For aggregate, GraphSAGE uses an element-wise max-pooling function, where MLP can be an arbitrarily deep multi-layer perceptron (a simple single-layer architecture also suffices). This step is what makes GraphSAGE more powerful than GCN!

[Note: there are other aggregation functions discussed in the paper, such as the LSTM aggregator, but we are focusing on the max-pooling function as it is symmetric and trainable].

iii. Update:
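For reference, the max-pooling variant described above can be sketched as follows, in the spirit of [8]. The exact split into three steps and the symbols a and \sigma are our assumptions.

```latex
\text{Message:}   \quad m_u^{(l)} = \mathrm{MLP}\!\left( h_u^{(l-1)} \right), \qquad u \in N(v) \\
\text{Aggregate:} \quad a_v^{(l)} = \max_{u \in N(v)} m_u^{(l)} \quad \text{(element-wise)} \\
\text{Update:}    \quad h_v^{(l)} = \sigma\!\left( W^{(l)} \cdot \mathrm{CONCAT}\!\left( h_v^{(l-1)},\, a_v^{(l)} \right) \right)
```

Concatenating h_v^{(l-1)} with the aggregate lets the node keep its own previous representation alongside the pooled neighborhood information.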

We can use the GraphSAGE class definition provided in an ogbl-ddi leaderboard submission [3]:

After training, we see that the best GraphSAGE model trained on randomly initialized node embeddings achieves train hits@20 = 73.10%, validation hits@20 = 63.50%, and test hits@20 = 39.31%. The model trained on Node2Vec embeddings achieves train hits@20 = 73.59%, validation hits@20 = 63.79%, and test hits@20 = 54.01%. Again, we see that Node2Vec initial features improve GraphSAGE model performance, and as expected, these figures are higher than those obtained by GCN.

Let’s plot the tSNE!

Left: tSNE of GraphSAGE node embeddings with Node2Vec initial features. Center: tSNE of GraphSAGE embeddings, overlaid with 2k randomly sampled edges from the train split. Right: tSNE of GraphSAGE embeddings, overlaid with 2k randomly sampled edges from the test split.

GraphSAGE also exhibits many more distinct clusters than Node2Vec, similar to GCN. Again, we might attribute this to the increased expressive power of graph neural networks!

GIN

GIN (graph isomorphism network) has a unique architecture, motivated by the observation that in order to design a maximally powerful GNN, the aggregation function needs to be injective. This means that each distinct multiset of input node features (a multiset is a set in which repeating values are allowed) must map to a distinct output.

The issue with the GCN and GraphSAGE mean and max aggregation functions is that they can fail to distinguish multisets in fairly common situations (refer to the figure below!).

A: GCN failure case (mean aggregation). B: GraphSAGE failure case (max aggregation). Note that distinct colors represent distinct node feature types and are associated with distinct embeddings. For instance, in sub-figure A, the yellow color is associated with the one hot vector [1, 0].
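These failure cases are easy to reproduce numerically. The sketch below uses two one-hot “colors” and small hand-built multisets of our own choosing (they illustrate the same phenomenon, but are not necessarily the exact multisets drawn in the figure).

```python
def mean_agg(multiset):
    n = len(multiset)
    return tuple(sum(x[i] for x in multiset) / n for i in range(len(multiset[0])))


def max_agg(multiset):
    return tuple(max(x[i] for x in multiset) for i in range(len(multiset[0])))


def sum_agg(multiset):
    return tuple(sum(x[i] for x in multiset) for i in range(len(multiset[0])))


yellow, blue = (1, 0), (0, 1)   # one-hot node features

s1 = [yellow, blue]
s2 = [yellow, yellow, blue, blue]   # same proportions as s1, twice the size
s3 = [yellow, yellow, blue]         # same distinct elements as s1

mean_confused = mean_agg(s1) == mean_agg(s2)   # mean only sees proportions
max_confused = max_agg(s1) == max_agg(s3)      # max only sees distinct elements
sum_distinguishes = len({sum_agg(s1), sum_agg(s2), sum_agg(s3)}) == 3
```

Sum aggregation tells all three multisets apart on these inputs, which is the intuition behind the aggregator GIN builds on.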

GIN addresses this issue by using an injective aggregation function: according to a theorem in the original paper [9], we can model such a function using two nonlinear functions, one applied to each element of the multiset input S and one applied to the sum of the results. Moreover, the aggregation function can be expressed in terms of a composition of high-dimensional element-wise MLP (multi-layer perceptron) functions as follows:
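A sketch of this construction, following our reading of [9] (the symbols \Phi, f, and c(S) are notation we introduce here):

```latex
c(S) \;=\; \Phi\!\left( \sum_{x \in S} f(x) \right)
\;\approx\; \mathrm{MLP}_{\Phi}\!\left( \sum_{x \in S} \mathrm{MLP}_{f}(x) \right)
```

The inner function f is applied to every element x of the multiset S, the results are summed, and the outer function \Phi is applied to that sum; both functions are approximated by MLPs.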

In practice, a dimensionality of around 100–500 for the MLP hidden layers is sufficient to achieve this injectivity. We can break all this down into the following message, aggregation and update functions!

i. Message:

ii. Aggregate:

iii. Update:

Note that MLP is a combination of the two MLP functions described earlier!
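For reference, one way to write the three steps, based on the GIN layer in [9], is sketched below. The split into message, aggregate, and update, the symbol a, and the treatment of \epsilon as a small scalar (fixed or learnable) are our assumptions.

```latex
\text{Message:}   \quad m_u^{(l)} = h_u^{(l-1)}, \qquad u \in N(v) \\
\text{Aggregate:} \quad a_v^{(l)} = \sum_{u \in N(v)} m_u^{(l)} \\
\text{Update:}    \quad h_v^{(l)} = \mathrm{MLP}^{(l)}\!\left( (1 + \epsilon)\, h_v^{(l-1)} + a_v^{(l)} \right)
```

The (1 + \epsilon) factor lets the layer weight a node’s own representation differently from the sum over its neighbors.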

GIN’s properties allow it to distinguish between a greater variety of multisets, making it the most expressive model out of the 3 we are running! Let’s see how this model performs. We define our GIN class as follows, keeping the same “forward” function from the other classes to maintain some comparability.

After training, we see that the best GIN model trained on randomly initialized node embeddings achieves train hits@20 = 53.26%, validation hits@20 = 46.36%, and test hits@20 = 51.51%. The model trained on Node2Vec embeddings achieves train hits@20 = 53.78%, validation hits@20 = 46.80%, and test hits@20 = 56.14%.

As expected, Node2Vec initial features do improve GIN model performance. Further, the test set figures seem to indicate that GIN’s expressive power leads to performance gains above both GCN and GraphSAGE, even though GIN’s train and validation hits@20 are lower than GraphSAGE’s. We’ll compare the metrics more formally in the next section!

Now, the tSNE plots! GIN also exhibits many more distinct clusters than Node2Vec, similar to GCN and GraphSAGE.

Left: tSNE of GIN node embeddings with Node2Vec initial features. Center: tSNE of GIN embeddings, overlaid with 2k randomly sampled edges from the train split. Right: tSNE of GIN embeddings, overlaid with 2k randomly sampled edges from the test split.

Results & Discussion

Models selected based on best val hits@20 score over 2 runs. Model** indicates that Node2Vec embeddings (dim=256) were used as initial feature representations.

Before we dive in, we need to make an important disclaimer. Due to constraints on compute resources, we report figures after only 2 runs (100 epochs each). However, we observed in our experiments that the hits@20 metric has a significant variance, and to obtain stable results, it’s best to aggregate statistics over at least 5–10 runs. This variance may also explain our fairly low hits@20 scores compared to leaderboard submissions for the same model architectures. Thus, we’ll jump into analysis with the reservation that these figures are not necessarily representative of the full performance potential that these models have.

Now that we cleared the air, let’s dive in! :)

First, we see that models trained on Node2Vec node embeddings outperform models trained on random embeddings. We suspect this is because Node2Vec provides useful information about the relationships between nodes to the GNN models during training. Recall that the Node2Vec parameters p and q characterize a suitable intermediate exploration strategy, allowing us to encode information about both local and global structures in the graph. In our project, the nodes are the drugs themselves while edges are drug interactions. Note that drugs that are similar to each other in chemical composition will most likely interact with similar drugs. For example, if Drug A is chemically similar to Drug B, which has a known drug interaction with Drug C, Drug A is also more likely to have an interaction with Drug C. Node2Vec implicitly assumes that nodes with similar “neighborhoods” are similar, allowing us to exploit these transitive interaction patterns.

Second, we see that GraphSAGE performs better than GCN. The GraphSAGE network has a more expressive aggregation function than the GCN, meaning its ability to discriminate between different multisets of node features during the aggregation step is higher. This is because the inclusion of an MLP within the aggregation function allows GraphSAGE to better capture intricacies in node attributes (recall that a sufficiently wide MLP can approximate any continuous function [10]).

Third, we see that GIN performs slightly better than the other two GNN models on the test set! This is consistent with our hypothesis that GIN’s performance would exceed that of GraphSAGE and GCN, since the “outer” MLP in the aggregator adds more expressive power in place of the max aggregator used in GraphSAGE. We suspect that GIN’s higher representational power compared to other models allows it to create more robust node embeddings, which in turn lead to more accurate link prediction and higher test hits@20 values.

And that concludes our blogpost! Thanks for reading — we hope you feel inspired to start building your own graph neural network models! :)

References

[1] Ward BW, Boersma P, Black LI. Prevalence of multiple chronic conditions among U.S. adults, 2018. Prev Chronic Dis; 17:200130. DOI: https://doi.org/10.5888/pcd17.200130.

[2] David S Wishart, Yannick D Feunang, An C Guo, Elvis J Lo, Ana Marcu, Jason R Grant, Tanvir Sajed, Daniel Johnson, Carin Li, Zinat Sayeeda, et al. DrugBank 5.0: a major update to the DrugBank database for 2018. Nucleic Acids Research, 46(D1):D1074–D1082, 2018.

[3] Weihua Hu, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta, and Jure Leskovec. Open graph benchmark: Datasets for machine learning on graphs. 2020.

[4] Aditya Grover and Jure Leskovec. node2vec: Scalable feature learning for networks. CoRR, abs/1607.00653, 2016.

[5] “Stick Figure.” Tatyana RU, from the Noun Project

[6] Vinod Nair and Geoffrey E. Hinton. Rectified linear units improve restricted boltzmann machines. In Johannes Fürnkranz and Thorsten Joachims, editors, ICML, pages 807–814. Omnipress, 2010.

[7] Thomas N. Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. CoRR, abs/1609.02907, 2016.

[8] William L. Hamilton, Rex Ying, and Jure Leskovec. Inductive representation learning on large graphs. CoRR, abs/1706.02216, 2017.

[9] Keyulu Xu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. How powerful are graph neural networks? CoRR, abs/1810.00826, 2018.

[10] Kurt Hornik, Maxwell Stinchcombe, and Halbert White. Multilayer feedforward networks are universal approximators. Neural Networks, 2(5):359–366, 1989.
