Disease-Gene Interactions with Graph Neural Networks and Graph Autoencoders

Stanford CS224W GraphML Tutorials · Jan 16, 2022

By Kathy Fan, Terence Tam, and Anthony Tzen, as part of the Stanford CS224W course project.

It’s 1am and you’re aimlessly scrolling through your phone. You open the Pinterest app, and the homepage features some luscious monstera plants, just like the ones you pinned to your “Indoor jungle” board yesterday. Then you skim the news and read about how Google is coming out with a new TPU chip designed for better performance per unit of chip area. Shortly after, your med-school friend texts you (with way too many exclamation marks) about exciting advancements in drug discovery.

What do these have in common? They’re all real-world applications that demonstrate the power of graph neural networks (GNNs). In recent years, GNNs have become an increasingly hot subfield of machine learning and are even regarded as the new “frontier” of deep learning.

Courtesy of Stanford CS224W Lecture Slide

This tutorial will walk you through the basics of GNNs and demonstrate how to readily apply an advanced GNN architecture to a real-world dataset. Throughout the tutorial, we will assume that readers are familiar with machine learning concepts. Familiarity with PyTorch is also a plus, since we will be using the pytorch_geometric package in our accompanying Colab notebook.

Here is an outline of what we will cover:

  • Graph and GNN basics
  • Graph convolutional networks (GCNs) as a building block for our Graph Autoencoder (GAE) architecture
  • The GAE architecture and a complete example of its application on disease-gene interaction predictions
  • Enhancing the base GAE model for better predictive power
  • Feature augmentation — adding more features to the input graph data

Let’s get started!

What is a graph?

Before we dive into graph neural networks, let’s first define a graph. Most generally, a graph can be thought of as a collection of nodes (vertices) and edges: G = (V, E). To represent how nodes connect to each other in a graph, we often use an adjacency matrix A, which is a square matrix of dimension |V|x|V| for a graph with |V| nodes. In the simple case of an undirected graph, adjacency matrix entry A_ij = A_ji = 1 if there’s an edge between i and j, and 0 otherwise.

Sample graph, with corresponding adjacency matrix A of size |V|².

In practice, most real-world graphs are sparse, meaning that a lot of the entries in the adjacency matrix end up being 0. Therefore, to represent graphs more efficiently, we prefer to use an edge list. Rather than having |V|² entries, an edge list only specifies the edges that exist in the graph.

The same graph, with edge list representation of size |E|.

Finally, each node can have associated node features, summarized into a vector per node — let’s call it x. Pretty much anything can be a node feature: for example, a numerical constant, a class label, or a word2vec embedding of the node’s description (as you will see later in the data augmentation section of this blog!). We often stack all the feature vectors together into a feature matrix X.
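To make this concrete, here is a tiny, hypothetical example of how such a graph is represented in pytorch_geometric, with an edge list (edge_index) and a feature matrix (x):

import torch
from torch_geometric.data import Data

# Edge list: each column is one directed edge, so the undirected
# edges (0, 1) and (1, 2) are each stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# Node feature matrix X: one feature vector per node (here, a single scalar feature).
x = torch.tensor([[1.0], [2.0], [3.0]])

data = Data(x=x, edge_index=edge_index)  # a 3-node, 2-edge undirected graph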

GNNs and GCNs

Today, many well-known deep learning methods exist for data that comes in the form of sequences (e.g. words in a sentence) or images. In fact, all these types of data can be viewed as special cases of graphs: sequences are graphs that are linear in structure, and images are graphs with a structured lattice of nodes. GNNs extend these prior methods to generalized input graphs of arbitrary shape.

Sequence, image, and arbitrary-shaped graph.

How do graph neural networks work? Although there are many different variations of graph neural network layers, at the heart of each layer are three steps: message passing, aggregation, and update. We will explain these steps in the context of one of the most fundamental GNN layers, the graph convolutional network (GCN).

Recall from deep learning that in a convolutional neural network (CNN), the convolutional layer works by sliding a fixed-size filter across the input image in order to produce a lower-dimensional output feature map. For example, a 3x3 filter would reduce nine pixels in the input image down to a single value. We can interpret convolution as an operation that updates a pixel’s embedding by aggregating the information passed from all the adjacent pixels in the convolutional filter.

Toy example of convolution.

GCNs generalize this concept to irregularly-shaped graphs. Since we don’t have a grid-like structure anymore, we cannot define our convolution with a fixed-size filter; instead, for a node v, we use edge connectivity to define the neighborhood 𝓝(v); node u ∈ 𝓝(v) if there is an edge between u and v.

Thus, a GCN layer can be represented as:

h_v^(l) = σ( W^(l) Σ_{u ∈ 𝓝(v)} h_u^(l−1) / |𝓝(v)| )

GCN layer as message passing, aggregation, and update.

That is, for a node v, each of its neighbors u passes along (red) its representation h_u^(l−1) from the previous layer l−1. These messages are then aggregated (blue) via a summation across all the neighbors, normalized by the neighborhood size |𝓝(v)|. Finally, we apply a linear transformation with W^(l) and a sigmoid nonlinearity in order to obtain an updated representation for node v. We do this for all the nodes of the graph, just like how in a CNN, we slide the convolutional filter across the entire image.

We can further stack GCN layers; in this way, a node’s representation incrementally accumulates information from its 1-hop neighborhood, then 2-hop neighborhood, then 3-hop, and so on for each additional GCN layer.

In pytorch_geometric, a GCN layer can be built with one line of code:
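from torch_geometric.nn import GCNConv

conv = GCNConv(in_channels, out_channels)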

in_channels and out_channels denote the dimensionality of a node’s input and output representations, respectively.

GCNs, though among the simplest of GNNs, work great in practice: variations of GCNs often rank at the top of graph dataset benchmarks. Below is a screenshot of the leaderboard for a dataset from the Open Graph Benchmark; you can view a live version of this leaderboard here.

Building a complete GNN model — GAE

Now that we know how GCNs work, let’s use them as a building block for constructing a complete GNN model: the graph autoencoder (GAE)! If you’re familiar with autoencoders in the general machine learning realm, you’re well-equipped for understanding the concepts behind the GAE. In a GAE, we have an encoder whose job is to map the input graph into a lower-dimensional space, and a decoder that reconstructs the input graph from the lower-dimensional embeddings. That is, we interpret the decoder output as the reconstructed adjacency matrix A*. The goal is to optimize the model such that the reconstruction loss (the difference between A* and the original graph input A) is minimized.

High level diagram of a GAE model.

We will instantiate this abstract GAE model first with a GCN encoder. Specifically, we will define a GCN with two graph convolutional layers, a ReLU, and dropout to help model performance.

Note, though, that you can also plug in other GNNs for the encoder; in fact, you can even use “shallow” (non-neural) encoders! (To learn more about shallow encoders, see section 2.2 of this paper.)

Putting it all together, the GCNEncoder is defined as follows.
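Below is a minimal sketch of this encoder, with the input, hidden, and output dimensions left as constructor arguments (we suggest concrete values in the Training section):

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))                     # first GCN layer + ReLU
        x = F.dropout(x, p=self.dropout, training=self.training)  # dropout for regularization
        return self.conv2(x, edge_index)                          # second GCN layer -> embeddings z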

As for the decoder, the most common choice is a dot product operator. For two nodes that are connected in the graph, we want the dot product of their embedding vectors to be large, modeling the value “1” in the corresponding adjacency matrix entry. Conversely, we want two nodes that are not connected to have a small dot product, modeling the “0” in the adjacency matrix entry. The default decoder in PyG’s base GAE class, InnerProductDecoder, is essentially this dot product operator. Hence, to initialize a concrete instance of the GAE model, all we have to do is specify the encoder we just created above:
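from torch_geometric.nn import GAE

model = GAE(GCNEncoder(in_channels, hidden_channels, out_channels))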

The loss function we described earlier (the reconstruction loss between A and A*) is also pre-defined for the GAE class:

loss = model.recon_loss(z, train_edge_index)

Here, we are passing in the embeddings z (produced by the encoder) and the training edges. With this loss function defined, our GAE model can be iteratively trained with the below train function.
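Here is a minimal sketch of that function, assuming x and train_edge_index hold the node feature matrix and the training-split edges (both constructed in the next section), and using the Adam optimizer:

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(x, train_edge_index)         # encoder -> latent embeddings z
    loss = model.recon_loss(z, train_edge_index)  # reconstruction loss on training edges
    loss.backward()
    optimizer.step()
    return float(loss)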

That’s it! With just a couple lines of PyG code, we now have a working GAE.

Apply it to our dataset — Disease-gene interactions

Let’s see our GAE in action. For our tutorial, we have chosen a biologically-motivated application domain: predicting disease-gene associations.

The task

The disease-gene association prediction problem (associated paper) aims to find new associations between diseases and genes, based on existing data on diseases, genes, and known associations. Finding the linkage between genes and diseases could be crucial for applications in the pharmaceutical industry, where drugs are often developed to target specific proteins encoded by particular genes. More broadly, this knowledge could help scientists understand the functional similarities between genes and the underlying genetic similarities between diseases.

Traditional processes for discovering disease-gene associations can be labor- and data-intensive. For example, genome-wide association studies require recruiting a specific cohort of subjects and using dedicated genetics software to analyze the data. A deep learning approach could bypass these difficulties by leveraging the model’s ability to learn latent features when predicting new associations. In particular, the task can be modeled as a link prediction task on a graph where nodes represent genes and diseases, and edges represent associations between them.

Toy example of a disease-gene association bipartite graph, with 6 disease nodes and 10 gene nodes.

The dataset

To train and evaluate our model on this task, we will use the DG-AssocMiner dataset. This dataset is one of many available through Stanford’s BioSNAP collection. DG-AssocMiner consists of an undirected, bipartite graph where edges only exist between diseases and genes.

The data can be downloaded in the form of a CSV file. Altogether, the network contains 519 disease nodes, 7,294 gene nodes, and 21,357 edges. The original dataset is essentially an edge list mapping disease nodes to gene nodes in text format.

#  Disease ID  Disease Name              Gene ID
0  C0036095    Salivary Gland Neoplasms  1462
1  C0036095    Salivary Gland Neoplasms  1612
2  C0036095    Salivary Gland Neoplasms  182
3  C0036095    Salivary Gland Neoplasms  2011
4  C0036095    Salivary Gland Neoplasms  2019
...

In our Colab, we demonstrate how to load the dataset into a torch_geometric Data object. Since there are a number of built-in PyG datasets for which you won’t have to perform these extra preprocessing steps, we will omit them in this tutorial.

Note that we do not have any node features — yet ;). For now, we will use a dummy feature vector of uniform embeddings x where all the values are 1.
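In code, this is just a matrix of ones with one row per node, where data is the Data object from above (we use a single dummy channel here; the Colab may use a wider vector):

data.x = torch.ones((data.num_nodes, 1))  # uniform "all ones" node features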

We create train, validation, and test sets by using the RandomLinkSplit method in pytorch_geometric. Behind the scenes, this method creates dataset splits by partitioning the edges in the graph. This is done by hiding some edges from the model during each phase. As shown in the diagram, in the training phase, we use the solid edges to predict the hidden edges (dotted). Then, during validation, we unhide a few more edges for the model to predict on, using all the edges it has seen so far. Similarly, during the test phase, we uncover a few more edges and the model is allowed to use all the edges from training and validation.

Link splits for training, validation, and test.
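A sketch of this split in code (the ratios here are illustrative; RandomLinkSplit also samples negative, i.e. non-existent, edges for evaluation):

from torch_geometric.transforms import RandomLinkSplit

transform = RandomLinkSplit(num_val=0.05, num_test=0.10, is_undirected=True)
train_data, val_data, test_data = transform(data)

x = train_data.x                          # node features (all ones for now)
train_edge_index = train_data.edge_index  # edges visible during training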

Before we do any training, let’s visualize how our untrained GAE model would embed the nodes with this dataset. Below is a 2D PCA projection of the GAE model with random initialization weights.

PCA projection of the embeddings when GAE model is untrained. Blue colors are disease nodes and red colors are gene nodes.

As you can see, initially the nodes are somewhat randomly scattered in the two-dimensional PCA projections.
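If you’d like to reproduce this plot, a minimal sketch using scikit-learn’s PCA looks like the following:

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

model.eval()
with torch.no_grad():
    z = model.encode(x, train_edge_index)

z_2d = PCA(n_components=2).fit_transform(z.cpu().numpy())  # project embeddings to 2D
plt.scatter(z_2d[:, 0], z_2d[:, 1], s=3)
plt.show()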

Training

We set up a standard machine learning training pipeline. Our Colab contains code for hyperparameter tuning, if you’d like to play around with different values for the learning rate, hidden dimension, epochs of training, and more. For our tutorial, we suggest trying the following set of parameters:

  • Hidden dimension: 200
  • Output feature size: 20
  • Epochs: 30
  • Optimizer: Adam
  • Learning rate: 0.1
  • Dropout: 0.5

Performance metrics

Since the disease-gene association prediction problem is essentially a classification problem, where the model outputs whether it believes an edge exists or not, we analyze our model’s performance during both training and testing with two standard metrics: area under the ROC curve (ROC-AUC), and average precision (AP).

ROC-AUC

ROC-AUC provides an aggregate measure of performance across all possible classification thresholds. For the link prediction problem, it measures how well the model distinguishes positive edges from negative ones. A ROC-AUC closer to 1 means the model separates the positive and negative edges well.

Let’s plot the ROC curve for our trained model! For the full code, see the Colab notebook. The key is to note that the sklearn library has a built-in function for calculating the ROC-AUC score:
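(In the sketch below, y_true and y_score are placeholders for the edge labels, 1 for real test edges and 0 for sampled negative edges, and the predicted edge scores from the decoder; see the Colab for how they are assembled.)

from sklearn.metrics import roc_auc_score, roc_curve

auc = roc_auc_score(y_true, y_score)       # area under the ROC curve
fpr, tpr, _ = roc_curve(y_true, y_score)   # points for plotting the curve itself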

Here’s what the ROC curve looks like, from our fully trained GAE model:

Indeed, it looks like our model (blue line) was able to learn how to predict disease-gene edges far more successfully than a random classifier (dotted red line in ROC curve). This suggests that the model was able to learn useful representations for the input graph.

Average precision

Our second performance metric is average precision, or AP. Average precision summarizes the precision-recall curve as the weighted mean of the precision at each threshold, with the increase in recall from the previous threshold used as the weight. Intuitively, it is the area under the precision-recall curve.

This performance metric indicates whether a model can identify all positive edges without accidentally marking too many negative edges as positive.
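sklearn has this metric built in as well; here is a sketch using the same placeholder y_true and y_score as above:

from sklearn.metrics import average_precision_score

ap = average_precision_score(y_true, y_score)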

Putting it together, here is what training looks like over the epochs with our GAE model and the disease-gene dataset.
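Here is a sketch of the outer loop, assuming val_pos_edge_index and val_neg_edge_index hold the positive and negative validation edges from the split (PyG’s GAE.test computes both metrics for us):

for epoch in range(1, 31):
    loss = train()
    model.eval()
    with torch.no_grad():
        z = model.encode(x, train_edge_index)
        # val_pos_edge_index / val_neg_edge_index: placeholder names for the
        # positive and negative validation edges produced by the link split.
        auc, ap = model.test(z, val_pos_edge_index, val_neg_edge_index)
    print(f'Epoch {epoch:02d} | Loss {loss:.4f} | AUC {auc:.4f} | AP {ap:.4f}')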

Results and Insights

Let’s visualize how our node embeddings have changed after training is complete, for the GAE model.

PCA projection of the embeddings when GAE model is fully trained. Blue colors are disease nodes and red colors are gene nodes.

As shown, by the end of training, the node embeddings have shifted to form more visible clusters differentiating gene and disease nodes. Remember, we haven’t passed anything to the model that indicates there are two types of nodes (diseases and genes), since we are using a constant node feature for all nodes. The model learns on its own, purely from the graph connectivity, that the graph is bipartite!

To better understand what the final embeddings represent, we can look at specific data points. For instance, the BRCA1 and BRCA2 genes are well-studied in the medical literature — mutations in these genes are strongly linked to breast cancer.

Indeed, if we look at the top edge predictions between cancer diseases (“neoplasms”) and BRCA-related genes, the model scores these pairs highly (i.e. large dot products between the disease nodes’ and gene nodes’ embeddings).

Dot product   Edge (Disease, Gene)
================================================================
42.74   (Prostatic Neoplasms, BRCA1 associated protein 1)
39.79   (Mammary Neoplasms, BRCA1 associated protein 1)
35.82   (Stomach Neoplasms, BRCA1 associated protein 1)
35.42   (Prostatic Neoplasms, BRCA1 interacting protein C-terminal helicase 1)
34.61   (Colorectal Neoplasms, BRCA1 associated protein 1)
33.35   (Mammary Neoplasms, BRCA1 interacting protein C-terminal helicase 1)
33.28   (Neoplasm Metastasis, BRCA1 associated protein 1)
30.73   (Malignant neoplasm of prostate, BRCA1 associated protein 1)
30.45   (Mammary Neoplasms, Experimental, BRCA1 associated protein 1)
29.91   (Stomach Neoplasms, BRCA1 interacting protein C-terminal helicase 1)
29.83   (Lung Neoplasms, BRCA1 associated protein 1)
29.69   (Animal Mammary Neoplasms, BRCA1 associated protein 1)
29.47   (Liver neoplasms, BRCA1 associated protein 1)
29.01   (Bladder Neoplasm, BRCA1 associated protein 1)

The top edge predictions reveal that our model indeed learned a strong association between BRCA1 and mammary neoplasms, a medical term for breast cancer. The model’s prediction of an association between BRCA1 and prostate cancer is also validated in the literature; men with a BRCA1 mutation are at lower risk of breast cancer than women, but their chances of prostate cancer are significantly higher. Interestingly, the BRCA1 gene is also strongly linked with many other types of cancer.

To summarize, here is the performance of our model:

Try running our Colab to see if you can reproduce similar results!

Enhancing the base model with variations: VGAE

Let’s try out a slightly different version of our model: the variational GAE (VGAE). The VGAE is very similar to the GAE. However, instead of encoding each node as a single point in the latent embedding space, the VGAE’s encoder outputs the parameters (mean and variance) of a multivariate Gaussian distribution for each node. The latent embedding passed to the decoder is then sampled from this distribution.

VGAE Architecture

Altogether, this provides more expressiveness and representative power to the model. This is because VGAE can generate a new graph from the original input graph using the learned Gaussian distribution parameters. In contrast, GAE aims to just reconstruct the original adjacency matrix A.

Since the VGAE class is also implemented in pytorch_geometric, we can adapt our model with just a few additional lines of code. Specifically, we need to modify our encoder to have separate GCNConv() layers to generate the mean and variance of the distribution.
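A minimal sketch of this variational encoder, reusing the GCNConv building block from before:

from torch_geometric.nn import VGAE

class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, out_channels)      # mean head
        self.conv_logstd = GCNConv(hidden_channels, out_channels)  # log-std head

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

model = VGAE(VariationalGCNEncoder(in_channels, hidden_channels, out_channels))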

One last change is in the loss function: we now have

loss = (
    model.recon_loss(z, train_pos_edge_index)
    + (1 / train_set.num_nodes) * model.kl_loss()
)

The first term, model.recon_loss(), is the same reconstruction loss as before (the cross-entropy loss between A and A*). The extra model.kl_loss() term is the KL divergence, a regularizer that pushes the encoder’s mean and variance parameters toward those of a standard Gaussian N(0, 1).

For comparability, we use all the same parameters as for the GAE model. Here are the best VGAE results compared to the GAE:

The VGAE had slightly higher performance according to both metrics, but the difference is not that drastic. We think this is because our task and dataset are easy enough such that the GAE is already able to perform well, and the additional expressiveness of the VGAE does not give it a significant advantage in this case.

Advanced: Graph enhancements by adding features

So far, we have been focused on using the graph structure (via passing an edge list) as the model input. We haven’t explored our model’s full capabilities, because instead of giving it meaningful node features, we have been using (arbitrary) uniform vectors for X. Let’s try adding some purposeful node features and see if our model can do even better!

The feature matrix X, alongside the adjacency matrix A, is part of the GAE input that we can take advantage of.

Since our nodes represent real diseases and genes, we have string descriptions corresponding to each node, which we can encode into features. Specifically, each disease node comes with a name, such as “salivary gland neoplasms”. Getting the gene descriptions is a little more roundabout: our dataset comes with NCBI gene IDs, while the GeneSynopsis dataset, which contains gene descriptions, uses ENSEMBL gene IDs. Therefore, we have to first map NCBI IDs to ENSEMBL IDs and then look up the corresponding descriptions. We provide these mapping functions for you in our Colab.

Next, we encode each string description into a bag-of-words one-hot encoding and reduce the size of the embedding space by applying a word2vec model. Again, don’t worry if these NLP techniques are unfamiliar to you — we provide the necessary code and details in our Colab.
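As a rough sketch of the idea using gensim (gene_descriptions here is a placeholder for the description list built in the Colab):

import numpy as np
from gensim.models import Word2Vec

# Tokenize each description (gene_descriptions is a hypothetical list of strings).
sentences = [desc.lower().split() for desc in gene_descriptions]

# Train a small word2vec model, then average the word vectors of each description.
w2v = Word2Vec(sentences, vector_size=25, min_count=1)
gene_features = np.stack([w2v.wv[tokens].mean(axis=0) for tokens in sentences])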

In the Colab, the size of the embedding space is customizable and can be fine-tuned to optimize performance. We found that a 25-dimensional embedding space for gene descriptions and a 15-dimensional embedding space for disease descriptions worked best, for 40 node features in total. For gene nodes, the first 25 features correspond to the gene’s description embedding, while the last 15 are the average of the disease nodes’ embeddings. For disease nodes, the first 25 features are the average of the gene nodes’ embeddings, while the last 15 correspond to the disease’s description embedding. With these new features, we can run the same experiments on the GAE and VGAE models.

Embedding gene descriptions into compact feature vectors.

Interestingly, with these node features, our model actually performed worse, plateauing at around 90% AP. The number of distinct words is on the same order as the number of nodes, and each description comprises only 2–4 words on average, so this result is not too surprising: there was little information shared across node descriptions that could supplement the structural data in the learning process. Indeed, the PCA of the embeddings shows that they fall into a noisier, roughly normal distribution with gene and disease embeddings partially overlapping, in contrast with the two distinct clusters of gene and disease embeddings we saw previously. This implies that our word2vec node features were actually unnecessary for a performant GAE or VGAE model on this dataset. The graph structure itself already held enough information for us to accurately predict top links between genes and diseases!

PCA of learned node embeddings, using description embeddings as initial embeddings. Blue colors are disease nodes and red colors are gene nodes.

Yet, even with slightly worse performance, using these word2vec node features did make it easier for us to interpret some of our results. Since the node features no longer start out uniform, the AP and ROC-AUC performance curves follow a more natural, smooth rise and plateau. In addition, the larger variation and differentiation between embeddings allows for more nuanced patterns among the gene and disease nodes.

Thus, the use of domain knowledge for creating node features can still be useful. With more detailed and organized descriptions of each gene and disease, perhaps we can generate node features that would improve both performance and interpretability.

Conclusion

In this tutorial, we covered the basic concept of GNNs, and in particular, GCNs. We then built a complete GAE model and saw how it learned to perform disease-gene association prediction. In addition to the GAE architecture, we also explored the VGAE. Finally, we tried adding node features to enrich our graph input.

You can review all the code for this tutorial here.

With the knowledge from this tutorial in hand, we hope you feel inspired to explore more ideas and applications with GNNs! For instance, you can try running the models from this tutorial on a different dataset. Here is a drug-gene interaction network that is very similar to the dataset from our tutorial. As mentioned in the intro, recommender systems (like Pinterest, or Amazon) are also a popular application of GNNs. You can try using this dataset to form a user-item interaction graph and test out the GAE model.

Another way to extend what was covered in this tutorial is to adapt the existing code and try a different type of GNN layer in the encoder; GraphSAGE and graph attention network (GAT) are two other popular options, and they’re readily available in PyG as well. (We like this tutorial on GraphSAGE, and this tutorial on the GAT model.)

Finally, in addition to the node features we added, there are other methods to enhance the graph. You can use disease-disease relations to add edges between disease nodes, and similarly, explicitly model relationships between genes. Another idea is to add a “universal” virtual node connecting all genes, and another one connecting all diseases. This could potentially make message-passing more effective, since nodes become closer together. These modifications will mean that your graph is no longer bipartite, but with a few small tweaks to accommodate the new edge types, the model will still work.

If you try any of the above suggestions, drop a comment below! And if you become the inventor of the next GNN-based web application, don’t forget to credit us! :)

References

M. Fey and J. E. Lenssen. Fast graph representation learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019.

W. L. Hamilton, R. Ying, and J. Leskovec. Representation learning on graphs: Methods and applications. CoRR, abs/1709.05584, 2017.

A. T. Marees, H. de Kluiver, S. Stringer, F. Vorspan, E. Curis, C. Marie-Claire, and E. M. Derks. A tutorial on conducting genome-wide association studies: Quality control and statistical analysis. Int J Methods Psychiatr Res, 27, 2018.

F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher, M. Perrot, and E. Duchesnay. Scikit-learn: Machine Learning in Python. Journal of Machine Learning Research, 12, 2011.

V. Singh and P. Liò. Towards probabilistic generative models harnessing graph neural networks for disease-gene prediction. CoRR, abs/1907.05628, 2019. URL http://arxiv.org/abs/1907.05628.

S. M. M. Zitnik, R. Sosic and J. Leskovec. BioSNAP Datasets: Stanford biomedical network dataset collection, 2018. http://snap.stanford.edu/biodata.
