Using GNNs and Protein Expression Networks to Predict Alzheimer’s Disease Diagnosis

By Siddharth Doshi and Olamide Abiose as part of the Stanford CS224W course project.

Background

In this project, we are focusing on how graph neural networks can be used to analyze protein co-expression networks to assist in disease diagnosis. Follow the Colab tutorial here:

Protein co-expression networks are an increasingly popular tool for understanding the biological mechanisms and cell type changes underlying different diseases. These networks are constructed by measuring the levels of hundreds or even thousands of proteins in blood and cerebrospinal fluid. Graphs can then be constructed from adjacency matrices that reflect the similarity in protein expression levels. These co-expression networks add to the information provided by protein-protein interaction networks (which measure the physical interactions between proteins), because they allow researchers to study how protein levels change in dynamic contexts.

In this tutorial, we focus on Alzheimer’s disease (AD), using data from the Alzheimer’s Disease Neuroimaging Initiative (ADNI). AD is the most common form of dementia and one of the leading causes of death for adults age 65 and older. It is characterized by extreme memory loss and cognitive decline; within the brain, it is marked by amyloid plaques (extracellular deposits of the amyloid-β protein), neurofibrillary tangles (composed of hyperphosphorylated tau protein aggregates), and grey matter loss. Protein co-expression networks have been used in recent years to identify proteins and biological pathways important for AD beyond amyloid-β and tau. However, to our knowledge, no studies have attempted to use graph neural networks together with co-expression networks to predict AD risk.

Dataset

To demonstrate how protein expression levels can be used to predict diagnosis, we use ADNI plasma data collected during participants’ baseline visits. ADNI is a decades-long, longitudinal, multi-site study set to enroll nearly 2,000 subjects by its completion. It contains rich phenotypic data from subjects, including information about clinical status, neuroimaging, biomarkers, cognition, and genetic profiles. The data within this tutorial contains expression data for roughly 190 proteins, collected from 58 healthy older adults, 395 persons with mild cognitive impairment (MCI), and 112 Alzheimer’s disease patients (565 subjects total). In our dataset, protein analytes are columns and subjects are rows, with each cell representing the protein expression level for a given subject.

After a series of pre-processing steps (included in the Addendum: Additional Details), we created a matrix reflecting the biweight midcorrelation (a type of correlation that relies on the median instead of the mean, commonly used in proteomic analyses) between proteins. This serves as the weighted adjacency matrix for our graph structures, and is hereafter referred to as the “protein co-expression matrix” or weighted adjacency matrix.
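As a rough illustration of what the biweight midcorrelation computes, the sketch below implements the standard textbook definition in NumPy. It is not the code used to build the matrix in this project; the function name bicor_matrix and the toy data are hypothetical, and the sketch assumes every column has a nonzero median absolute deviation.

```python
import numpy as np

def bicor_matrix(X):
    """Biweight midcorrelation between the columns of X (rows = subjects)."""
    X = np.asarray(X, dtype=float)
    centered = X - np.median(X, axis=0)
    mad = np.median(np.abs(centered), axis=0)          # median absolute deviation
    u = centered / (9.0 * mad)                         # assumes mad > 0 for every column
    weights = (1.0 - u**2) ** 2 * (np.abs(u) < 1.0)    # far outliers get weight 0
    a = centered * weights
    a /= np.sqrt((a**2).sum(axis=0))                   # normalise each column to unit length
    return a.T @ a

rng = np.random.default_rng(0)
toy_expression = rng.normal(size=(100, 5))   # hypothetical: 100 subjects x 5 proteins
adjacency = bicor_matrix(toy_expression)
```

The result is symmetric with ones on the diagonal, so it can be read directly as a weighted adjacency matrix between proteins.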

Visualization of adjacency matrix

Our GNN implementation utilizes this weighted adjacency matrix, a list of protein names, the diagnoses for each subject, and the log-transformed, regressed expression values used to make the adjacency matrix. For the purposes of our tutorial, we binarized diagnoses to be “1” if AD and “0” otherwise. This code can easily be updated to accommodate multi-class classification (and incorporate MCI diagnoses). An example of the structure of the compiled patient-level data (shown by S) and the diagnosis (y) is shown below.

Example of patient-level protein expression data, in matrix S, along with binary diagnostic labels, in vector y.

Our Approach

The key aspect of our approach is to represent each patient as a graph, with proteins as nodes and protein expression levels as node features, and to make a graph-level prediction of their diagnosis. All patient graphs share the same fixed edge structure, defined by the weighted adjacency matrix (the protein co-expression matrix outlined above). What varies among these patient graphs is the node values: each patient graph uses that patient's individual protein expression levels, i.e. the values of one row of the subject matrix S shown above.

Visualisation of data structure for the graph level, or patient level, prediction task

This approach of using a fixed graph across all patients may seem unconventional. However, a guiding assumption is that there is some structure that will emerge within node communities, where perhaps some connected “communities” of proteins play a more important role in Alzheimer’s disease than others. Therefore patients with high expression values for those protein “communities” may be more likely to yield a positive diagnosis.

To make graph level predictions that take into account community structure, we will use hierarchical clustering, as illustrated below.

Overview of our model architecture

We note that our model should not be permutation invariant with respect to node identity, as the characteristics of specific proteins are important. This contrasts with GCNs, which treat nodes identically regardless of their ordering. To give each node an identity, we append a randomly generated shallow encoding vector of length 3 to the features of each node. This shallow encoding stays the same across all subject graphs. The features for each node are therefore the scalar protein expression level (which varies across graphs) and the positional encoding (which is maintained across graphs).
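A minimal sketch of this feature construction is given below, in NumPy for brevity. The normal distribution, the seed, and the names positional_encoding and node_features are assumptions for illustration; the text only specifies a random vector of length 3 per protein that is shared by all subjects.

```python
import numpy as np

NUM_PROTEINS = 51   # number of nodes, as used in the walk-through
ENCODING_DIM = 3    # length of the shallow encoding, per the text

rng = np.random.default_rng(seed=0)   # arbitrary seed
# Drawn once and reused for every subject graph.
positional_encoding = rng.normal(size=(NUM_PROTEINS, ENCODING_DIM))

def node_features(expression_row):
    """Stack one subject's expression levels with the shared encoding."""
    expression = np.asarray(expression_row, dtype=float).reshape(-1, 1)
    return np.hstack([expression, positional_encoding])

# Two hypothetical subjects: first column differs, last three columns match.
subject_a = node_features(rng.normal(size=NUM_PROTEINS))
subject_b = node_features(rng.normal(size=NUM_PROTEINS))
```

Each node ends up with four features: one that varies per subject and three that identify the protein.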

Implementation Theory

We are using a graph convolutional network (GCN) to learn the local network information of each node and generate node embeddings accordingly. GCNs rely on message passing, where nodes exchange information by sending “messages”. These messages are then “aggregated”, or combined in some way, to learn the new embedding of the node. This process is permutation equivariant: if we reorder the nodes, the node embeddings are reordered in the same way but are otherwise unchanged, so the network has no built-in notion of which protein is which. To maintain some encoding of the different proteins, we manually attach a shallow encoding (a random vector) to each protein, which stays the same across all graphs.
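The equivariance property can be checked numerically. The sketch below implements a single GCN propagation step in the form given by Kipf and Welling, using a small random graph; the sizes and weights are arbitrary, and the edge weights are assumed non-negative so that the degree normalisation is well defined.

```python
import numpy as np

def gcn_layer(adj, h, weight):
    """One GCN step: ReLU(D^-1/2 (A + I) D^-1/2 H W)."""
    a_hat = adj + np.eye(adj.shape[0])                 # add self-loops
    deg_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    a_norm = deg_inv_sqrt[:, None] * a_hat * deg_inv_sqrt[None, :]
    return np.maximum(a_norm @ h @ weight, 0.0)

rng = np.random.default_rng(0)
n = 6
adj = rng.uniform(size=(n, n))
adj = (adj + adj.T) / 2            # symmetric, non-negative toy weights
np.fill_diagonal(adj, 0.0)
h = rng.normal(size=(n, 4))        # node features
w = rng.normal(size=(4, 8))        # layer weights

perm = rng.permutation(n)
out = gcn_layer(adj, h, w)
out_perm = gcn_layer(adj[perm][:, perm], h[perm], w)

# Reordering the nodes reorders the embeddings in the same way...
equivariant = np.allclose(out[perm], out_perm)
# ...so a flat mean over nodes cannot tell the two orderings apart.
invariant_after_mean = np.allclose(out.mean(axis=0), out_perm.mean(axis=0))
```

Both checks come out true, which is why a per-protein encoding is needed if the model is to distinguish specific proteins.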

To make a graph-level prediction, we need to aggregate the information from all nodes. However, with naive flat aggregations, such as the mean or max of all node values, significant information about subtleties in the graph is lost. To address this and capture local structures, we use hierarchical pooling, where pooling occurs in several stages. By aggregating information from, say, nodes with similar GCN embeddings, and then applying non-linear transformations, finer information is captured. In practice we do this by training GCN embeddings, using the ASAPool method to cluster, and then repeating this process (Ranjan et al.).

Implementation Walk-Through

Loading and Preprocessing Data:

We first load the adjacency matrix as a NumPy object and convert it into a graph (G) using NetworkX’s nx.from_numpy_matrix command (removed in NetworkX 3.0, where nx.from_numpy_array is the replacement). We also create a string list of protein names and set these as node attributes using nx.set_node_attributes. We then create a 565 x 51 NumPy object from the pre-processed expression values. We convert this, the binarized list of diagnoses, and the adjacency matrix to PyTorch tensors; we also convert the NetworkX graph G to a PyTorch Geometric object using torch_geometric.utils.from_networkx.
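The graph-construction step might look like the sketch below. The 3 x 3 matrix and the protein names are placeholders rather than ADNI values, and nx.from_numpy_array is used because it is the call available in current NetworkX releases.

```python
import numpy as np
import networkx as nx

# Toy symmetric weighted adjacency matrix for three hypothetical proteins.
adjacency = np.array([
    [0.0, 0.8, 0.1],
    [0.8, 0.0, 0.5],
    [0.1, 0.5, 0.0],
])
protein_names = ["protein_a", "protein_b", "protein_c"]  # placeholder names

# Each nonzero entry becomes an edge carrying a "weight" attribute.
G = nx.from_numpy_array(adjacency)

# Attach the protein name to each node, keyed by node index.
nx.set_node_attributes(G, dict(enumerate(protein_names)), name="name")
```

Note that a nonzero diagonal (such as the ones on the diagonal of a correlation matrix) would be turned into self-loops, so it may be worth zeroing the diagonal first if self-loops are not wanted.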

We then create a dictionary (split_idx) documenting the indices that will be used for the train, validation, and test sets, respectively. For every subject, we capture the row vector containing expression levels for all 51 proteins and transpose it into a column vector, storing it as x. We also capture the diagnosis information for each subject, storing it as y. If a subject’s index falls within the list of values corresponding to “train” in our split_idx dictionary, we append a data object describing a homogeneous graph to train_list. This data object contains x as our node feature matrix and y as our ground-truth graph-level label.

Additionally, it captures the edge indices and weights from G and stores these as an edge_index and edge_attr for each of these individual data objects. Put a different way, for every subject, we create a graph (stored in a PyTorch Geometric data object) where nodes are proteins and features are expression levels; the edges and weights are fixed across all subjects, as they are taken from the adjacency matrix (converted to a graph object G). Each graph will be classified into a diagnostic category (1 for AD and 0 otherwise). Subjects’ graphs will be appended to train_list, valid_list, or test_list based on subjects’ indices. We store these respective lists into different PyTorch DataLoaders, allowing us to pass samples in batches of size 32.
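The per-subject graph construction described above can be sketched as follows. To keep the example runnable with NumPy alone, plain dictionaries stand in for PyTorch Geometric data objects; their keys mirror the x, edge_index, edge_attr, and y fields named in the text. The toy sizes, the 60/20/20 split, and the helper build_graphs are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_subjects, num_proteins = 10, 4                    # toy sizes, not the ADNI dimensions
S = rng.normal(size=(num_subjects, num_proteins))     # expression matrix (subjects x proteins)
y = rng.integers(0, 2, size=num_subjects)             # binary diagnoses
A = np.abs(rng.normal(size=(num_proteins, num_proteins)))
A = (A + A.T) / 2                                     # symmetric toy adjacency
np.fill_diagonal(A, 0.0)

# Shared structure, computed once: edge list in COO form plus edge weights.
rows, cols = np.nonzero(A)
edge_index = np.vstack([rows, cols])                  # shape (2, num_edges)
edge_attr = A[rows, cols]

# Hypothetical split by subject index.
split_idx = {
    "train": np.arange(0, 6),
    "valid": np.arange(6, 8),
    "test": np.arange(8, 10),
}

def build_graphs(indices):
    """One graph per subject: node features vary, edges are shared."""
    return [
        {"x": S[i].reshape(-1, 1), "edge_index": edge_index,
         "edge_attr": edge_attr, "y": y[i]}
        for i in indices
    ]

train_list = build_graphs(split_idx["train"])
valid_list = build_graphs(split_idx["valid"])
test_list = build_graphs(split_idx["test"])
```

Only the x and y entries differ between subjects; every graph points at the same edge_index and edge_attr arrays.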

GCN Implementation:

We construct our first graph convolutional network to perform node-level prediction, with a total of five layers, including 256-dimensional hidden layers. Taking our feature vector x and the list of edge indices stored in graph G, we apply batch normalization, the rectified linear unit (ReLU) function as our nonlinearity, and dropout with probability 0.5 to all but the final layer. For our last layer, we apply a softmax function (which converts the output of this layer into a categorical probability distribution), followed by a logarithmic transformation; together these amount to a log-softmax.

We then use the module ASAPool to group nodes into clusters. Briefly, ASAPool initially creates clusters for all nodes based on their 1-hop neighbourhood, where each node is the medoid of its own cluster. Membership of nodes in clusters is determined using a self-attention mechanism, where a query attending on all constituent nodes in a cluster provides attention scores defining the membership strength of a given node in a cluster. Top-performing clusters are then selected using a fitness score based on a graph convolutional method designed to capture local maxima/minima. A fraction of the best-fitting clusters are selected for the pooled graph, and new edge weights are computed by considering whether pooled clusters contain any common nodes, or whether any constituent nodes in the clusters are neighbours in the original graph. Further details can be found in the original paper (Ranjan et al.).

In our model, we use a 50% pooling ratio; in other words, we select for the top 50% of best-fitting clusters in our pooled graph. This is repeated for another cycle of GCN classification and pooling, followed by a final (GCN) layer. Global mean pooling is then used to perform a graph-level aggregation, and a linear transformation is applied to this graph-level prediction, giving us our final model output.
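To make the effect of the pooling ratio and the final aggregation concrete, the sketch below tracks how many clusters survive each pooling cycle and implements a per-graph mean over nodes in NumPy. Rounding the kept count up is an assumption about how the ratio is applied, and both helper names are hypothetical; this is not the ASAPool implementation itself.

```python
import math
import numpy as np

def pooled_sizes(num_nodes, ratio, cycles):
    """Node counts after each pooling cycle, assuming the kept count is rounded up."""
    sizes = [num_nodes]
    for _ in range(cycles):
        sizes.append(math.ceil(ratio * sizes[-1]))
    return sizes

def global_mean_pool(x, batch):
    """Average node embeddings per graph; batch[i] is the graph id of node i."""
    num_graphs = int(batch.max()) + 1
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)                          # accumulate rows per graph
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / counts

sizes = pooled_sizes(51, ratio=0.5, cycles=2)          # 51 proteins, two pooling cycles

node_embeddings = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = np.array([0, 0, 1])                            # two nodes in graph 0, one in graph 1
graph_embeddings = global_mean_pool(node_embeddings, batch)
```

Under these assumptions a 51-node graph shrinks to 26 and then 13 clusters before the mean is taken, and each graph is summarised by a single vector that the final linear layer maps to a prediction.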

These GCN and ASAPool layers comprise our model architecture. As we train the model, we take each batch and set the gradients of the optimizer to zero. We run our model on each batch and calculate the loss using torch.nn.BCEWithLogitsLoss. We then backpropagate the errors and update the model parameters using loss.backward() and optimizer.step(), respectively. As we loop through each batch, we record the true diagnostic labels and our model outputs in two separate lists. We combine these in a dictionary with y_true and y_pred as keys.
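The training step described above follows the usual PyTorch pattern, sketched below. A single linear layer on random tensors stands in for the GCN and ASAPool model and the graph batches, so only the loop structure should be read as reflecting the text; the Adam optimizer, the stand-in data, and the function name train_one_epoch are assumptions.

```python
import torch

torch.manual_seed(0)

# Stand-in data: 64 "subjects" with 51 features each and binary labels.
features = torch.randn(64, 51)
labels = (torch.rand(64) < 0.2).float()
dataset = torch.utils.data.TensorDataset(features, labels)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

model = torch.nn.Linear(51, 1)                 # placeholder for the graph model
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # optimizer choice is assumed
loss_fn = torch.nn.BCEWithLogitsLoss()         # expects raw logits and float targets

def train_one_epoch():
    model.train()
    y_true, y_pred, total_loss = [], [], 0.0
    for x, y in loader:
        optimizer.zero_grad()                  # clear gradients from the previous batch
        out = model(x).squeeze(-1)             # one logit per subject
        loss = loss_fn(out, y)
        loss.backward()                        # backpropagate
        optimizer.step()                       # update parameters
        total_loss += loss.item() * y.size(0)
        y_true.append(y.detach())
        y_pred.append(out.detach())
    results = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_pred)}
    return total_loss / len(dataset), results

epoch_loss, results = train_one_epoch()
```

The returned dictionary has the y_true and y_pred keys that the evaluator consumes.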

Performance

We utilize the ROC-AUC evaluator from the “ogbg-molhiv” dataset to evaluate the performance of our model. This constructs a receiver operating characteristic curve that plots the true positive rate against the false positive rate; the area under this curve measures how well a model discriminates between two diagnostic groups (in our case, AD or otherwise), with 50% corresponding to chance-level ranking. During each epoch, we record the training loss and the ROC-AUC for the training, validation, and test sets (computed by the evaluator from our y_true and y_pred values, and reported here as accuracy). We ran our model for 50 epochs with a 0.001 learning rate; at the end, our training set accuracy is 50.15%, our validation set accuracy is 49.77% and our test set accuracy is 49.77%.
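For intuition about the metric, the area under the ROC curve equals the probability that a randomly chosen positive subject receives a higher score than a randomly chosen negative one, with ties counted as half. The sketch below computes it that way in NumPy; it is an illustration of the metric, not the evaluator's own implementation, and the function name roc_auc is hypothetical.

```python
import numpy as np

def roc_auc(y_true, y_score):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_score[y_true]
    neg = y_score[~y_true]
    diff = pos[:, None] - neg[None, :]                 # every positive vs every negative
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())

labels = [0, 0, 1, 1]
perfect = roc_auc(labels, [0.1, 0.2, 0.8, 0.9])            # 1.0
uninformative = roc_auc(labels, [0.5, 0.5, 0.5, 0.5])      # 0.5
reversed_ranking = roc_auc(labels, [0.9, 0.8, 0.2, 0.1])   # 0.0
```

A model that outputs the same score for everyone lands at 0.5, which is the reference point for reading the figures above.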

We compared this performance to a basic ordinary least squares model, where our X was the pre-processed matrix of protein expression values and y was the list of diagnoses per subject. We created a 52 × 1 weight vector, representing the weight that each protein (i.e., each unique parameter in the regression model) has on the diagnosis outcome. The output of our model is the product of X and this weight vector; we converted this and our y vector to PyTorch tensors, and again calculated the loss using torch.nn.BCEWithLogitsLoss. We measured accuracy using a simple threshold, where predicted values less than 0 were assigned 0 and all others were assigned 1. The loss for this OLS model was 0.7197, and the accuracy was 28% on the test set.
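A least squares baseline of this shape can be sketched with NumPy as below. The data are synthetic, no intercept column or train/test split is included, and the helper threshold_accuracy is hypothetical; only the threshold rule (values below 0 map to 0, all others to 1) is taken from the text.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))                          # toy: 200 subjects x 5 proteins
true_w = rng.normal(size=5)
y = (X @ true_w + rng.normal(scale=0.5, size=200) > 1.0).astype(float)

# Least squares fit: one weight per column of X.
weights, *_ = np.linalg.lstsq(X, y, rcond=None)
predictions = X @ weights

def threshold_accuracy(pred, target, threshold=0.0):
    """Assign 1 when the prediction is at or above the threshold, else 0."""
    assigned = (pred >= threshold).astype(float)
    return float((assigned == target).mean())

accuracy = threshold_accuracy(predictions, y)
```

The threshold is left as a parameter so that the sensitivity of the reported accuracy to this choice can be explored.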

Discussion and Further Improvement

We note that our model scores higher on the test set (~50%) than the least squares model (28%), although the two figures are not directly comparable, since the former is a ROC-AUC and the latter a thresholded accuracy. Not shown is the performance of a single GCN without the hierarchical structuring, which has a similar performance of 25–30% (including when we change the GCN to other models such as GAT). This suggests that the hierarchical model has benefits compared to naive flat models.

However, given the performance of ~50% accuracy on the test, training, and validation datasets, there is significant room for improvement in our model. One option is to use base models other than GCNs for generating node embeddings, including Graph Attention Networks, which may work well on the fully connected graph structures that our protein networks are currently represented by. Graph isomorphism networks (GIN), which are more expressive in terms of structural differentiation, may also be used.

There is also room for improvement in the way the data is pre-processed and stored. For example, we could drop edges in the adjacency matrix whose weights fall below a certain threshold, to avoid the graphs being so densely connected. We could also account more explicitly for the role of edge weights in our processing; they are currently handled only by built-in functions and have not been dealt with explicitly in our implementation.
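One way the edge-dropping idea could be tried is sketched below. The cutoff of 0.3, the use of absolute values (so that strong negative correlations are kept), and the function name threshold_adjacency are all illustrative assumptions.

```python
import numpy as np

def threshold_adjacency(adj, cutoff):
    """Zero out edges whose absolute weight is below the cutoff, and the diagonal."""
    adj = np.asarray(adj, dtype=float)
    sparse = np.where(np.abs(adj) >= cutoff, adj, 0.0)
    np.fill_diagonal(sparse, 0.0)                      # drop self-loops
    return sparse

adj = np.array([
    [1.0, 0.7, 0.1],
    [0.7, 1.0, -0.4],
    [0.1, -0.4, 1.0],
])
sparse = threshold_adjacency(adj, cutoff=0.3)
num_edges = int(np.count_nonzero(np.triu(sparse, k=1)))   # undirected edges kept
```

In this toy matrix the weak 0.1 edge is removed and the 0.7 and -0.4 edges survive.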

Searching the hyperparameter space, or changing model configurations (e.g. degree of clustering, number of clustering cycles), could also be of benefit.

It is interesting to consider whether there are fundamental limitations to our approach of operating on a fixed graph structure across all patients. Given that the only thing changing across graphs is a single scalar feature for each node, there may not be enough expressive ability to differentiate between patients and accurately classify disease. Systematic investigation of this point may be of interest. Perhaps the dataset could be augmented by adding node features for each patient based on other ADNI data, such as neuroimaging data, to enhance the data richness.

References

Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks, https://arxiv.org/abs/1609.02907

Rex Ying, Jiaxuan You, Christopher Morris, Xiang Ren, William L. Hamilton, Jure Leskovec, Hierarchical Graph Representation Learning with Differentiable Pooling, https://arxiv.org/abs/1806.08804

Ekagra Ranjan, Soumya Sanyal, Partha Pratim Talukdar, ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations, https://arxiv.org/abs/1911.07979

Addendum: Additional Details

Data Preparation

We first pre-process the data by removing proteins where a certain proportion of samples are below a least detectable dose (LDD). The least detectable dose is the lowest reliable measurement for each protein (a unique value for every protein). If a large percentage of samples for a given protein are below the LDD, then we lose confidence in the reliability of that protein’s measurement altogether. We used a threshold of 25%, following the method laid out by the ADNI Biomarkers Consortium; if more than 25% of samples for a given protein were below the LDD, we dropped it from further analysis. Otherwise, we replaced values below the LDD with the LDD itself. This left us with 52 proteins. We then log-transformed the protein expression levels so that variation across proteins falls within similar orders of magnitude (see Fig. 1). Our final pre-processing step was to regress out the effects of covariates that may influence protein expression levels. We want our protein co-expression network to capture the relationships between proteins, and these estimates are potentially obscured by the effects of age, gender, and diagnosis.
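The three steps above can be sketched in NumPy as follows. The data are synthetic, the natural logarithm is an assumption (the base is not stated), the covariate columns are random stand-ins for age, gender, and diagnosis, and the function name preprocess is hypothetical.

```python
import numpy as np

def preprocess(raw, ldd, covariates, max_below=0.25):
    """Filter by LDD, floor at LDD, log-transform, then regress out covariates."""
    raw = np.asarray(raw, dtype=float)
    ldd = np.asarray(ldd, dtype=float)
    # Keep proteins where at most 25% of samples fall below the LDD.
    keep = (raw < ldd).mean(axis=0) <= max_below
    # Replace remaining sub-LDD values with the LDD itself.
    floored = np.maximum(raw[:, keep], ldd[keep])
    logged = np.log(floored)                           # natural log assumed
    # Residualise each protein on an intercept plus the covariates.
    design = np.column_stack([np.ones(len(raw)), covariates])
    coef, *_ = np.linalg.lstsq(design, logged, rcond=None)
    return logged - design @ coef, keep

rng = np.random.default_rng(0)
raw = rng.lognormal(size=(50, 4))                      # toy: 50 subjects x 4 proteins
ldd = np.array([0.05, 0.05, 0.05, 5.0])                # last protein is mostly below its LDD
covariates = rng.normal(size=(50, 3))                  # stand-ins for age, gender, diagnosis
residuals, keep = preprocess(raw, ldd, covariates)
```

Because the design matrix includes an intercept, the residuals are centred at zero and uncorrelated with each covariate, which is the form in which they would be passed on to the correlation step.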

Stanford CS224W: Machine Learning with Graphs

Published in Stanford CS224W: Machine Learning with Graphs

Tutorials of machine learning on graphs using PyG, written by Stanford students in CS224W.
