GNNs in neuroscience: graph convolutional networks for fMRI analysis

Sidney Hough
Stanford CS224W GraphML Tutorials
Jan 18, 2022

By Sidney Hough, Julian Quevedo, and Pino Cholsaipant as part of the Stanford CS224W course project.

You can find the full code for this tutorial in Google Colab here.

Overview

The brain is a network. It really is. It’s a bunch of interconnected neurons whose interactions give rise to cognition. It therefore serves neuroscientists well to use machine learning techniques equipped to handle networked or graph-like data. Graph neural networks (GNNs) are here to save the day! We can use GNNs to extract rich information from anatomical and functional connectomes, often obtained via MRI. Before we dive into the details, we will introduce and explain the general framework of a GNN.

What are GNNs?

The issue with standard neural networks

Neural networks have proved powerful in domains such as image recognition and language generation. However, current state-of-the-art models have a fundamental limitation: they learn from data that is structurally regular. An image can be thought of as a 2D lattice, while a sentence in a particular language can be represented as a 1D array of words, for example.

Problems arise when we try to learn over complex structures. As the title of this tutorial suggests, we will be building a model that learns from activity between various remote regions of the brain. How do we represent our brain regions spatially in a matrix, given that intricate interactions occur across arbitrary distances?

Our brains aren’t rectangular.

Standard models also normally rely on a notion of order in the input data to make meaningful predictions. The output we get from passing “do re mi” into a language model differs from the output we get from passing in “mi re do.” How might we assign order to regions of the brain? For the purposes of this tutorial, we only really care about whether two regions coactivate, which they will do just the same regardless of any order we assign them.

Given these points, we need to find an architecture that

  1. Can learn over inputs of arbitrary structure.
  2. Makes the same predictions regardless of ordering.

GNNs

Graph Neural Networks (GNNs) satisfy these requirements! We’ll come back and show that they do after discussing the basic architecture.

One fundamental modeling assumption GNNs make is the following:

Nodes that share edges tend to have similar properties/labels.

So if we have a graph consisting of nodes, each node having a feature vector, we want each node v to obtain information about the features of its neighbors (i.e. nodes that share edges with v). Node v’s features/embedding can then be refined by aggregating information it gets from its neighbors.

Nodes that aggregate information from their neighbors leverage graph structure in addition to features. Since nodes only receive information from neighbors, information only flows across edges of the graph.

LEFT: A graph with nodes 1, 2, 3, 4, 5. RIGHT: the blue arrows show possible directions of information flow through edges.

Each GNN forward pass performs this neighborhood aggregation process. More precisely,

  1. For each node v, we look at all of its neighbors N(v) (nodes that share an edge with v).
  2. Each node u in N(v) prepares a message m(u): this is the information u will pass on about itself to v.
  3. Node v collects the messages from its neighbors, and aggregates them into one summary vector: AGGREGATE({m(u), u in N(v)}).
  4. Node v updates its embedding by combining its current embedding with the aggregated result.

LEFT: step 2. RIGHT: steps 3 and 4.

An example message function might just be a node passing forward its existing representation: m(u) = h_u, where h_u is u’s current embedding. An example aggregation function might be the sum of all m(u)’s. By stacking some number of these layers, we can refine node embeddings, which can then be used to make predictions.
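To make the four steps concrete, here is a minimal NumPy sketch of one layer using exactly the example choices above: m(u) = h_u and sum aggregation. The separate weight matrices w_self and w_neigh and the ReLU are assumptions of this sketch (one common way to combine a node's own embedding with the aggregate), not something this post specifies.

```python
import numpy as np

def message_passing_layer(h, edges, w_self, w_neigh):
    """One message-passing round with m(u) = h_u and sum aggregation.

    h:       (num_nodes, d_in) current node embeddings
    edges:   iterable of undirected (u, v) node-index pairs
    w_self:  (d_in, d_out) weights for the node's own embedding (assumed)
    w_neigh: (d_in, d_out) weights for the aggregated messages (assumed)
    """
    agg = np.zeros_like(h)
    for u, v in edges:
        # Steps 2-3: each endpoint sends its embedding to the other,
        # and the receiver sums everything it gets.
        agg[v] += h[u]
        agg[u] += h[v]
    # Step 4: combine the node's own embedding with the aggregate, then apply ReLU.
    return np.maximum(0.0, h @ w_self + agg @ w_neigh)

# Hypothetical 3-node path graph 0 - 1 - 2 with one-hot features and identity weights.
h = np.eye(3)
out = message_passing_layer(h, [(0, 1), (1, 2)], np.eye(3), np.eye(3))
print(out)
```

With one-hot inputs and identity weights, each row of `out` marks the node itself plus its neighbors. Stacking a second layer would let information travel two hops.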

Let’s check that the architecture we just outlined satisfies the two requirements we mentioned above: (1) the model can ingest inputs of arbitrary structure, and (2) the model makes the same predictions regardless of the ordering.

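(1) is clearly true! Our model takes graphs as inputs after all. (2) is also true, because in message passing and aggregation we do not think about node ordering: we simply collect an orderless set of messages and perform aggregation on this orderless set. More precisely, relabeling the nodes only relabels their embeddings, and any order-independent readout over all nodes (such as a sum or mean) produces the same graph-level prediction.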
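Both properties can be checked numerically. This NumPy sketch uses random weights and a sum-aggregation layer written with the adjacency matrix A, so that row v of A @ H sums v's neighbors. It relabels the nodes with a random permutation matrix P and confirms two things: the node outputs are simply permuted, and a sum over nodes is unchanged.

```python
import numpy as np

def layer(A, H, W_self, W_neigh):
    # Sum aggregation via the adjacency matrix: (A @ H)[v] sums v's neighbors.
    return np.maximum(0.0, H @ W_self + A @ H @ W_neigh)

rng = np.random.default_rng(0)
n, d = 5, 4
A = np.triu(rng.integers(0, 2, size=(n, n)), 1)
A = A + A.T                                  # random undirected graph, no self-loops
H = rng.normal(size=(n, d))
W_self, W_neigh = rng.normal(size=(d, d)), rng.normal(size=(d, d))

perm = rng.permutation(n)
P = np.eye(n)[perm]                          # (P @ X)[i] == X[perm[i]]

out = layer(A, H, W_self, W_neigh)
out_perm = layer(P @ A @ P.T, P @ H, W_self, W_neigh)

# Node embeddings are permuted along with the nodes (equivariance)...
print(np.allclose(out_perm, P @ out))
# ...and a graph-level sum over nodes does not change (invariance).
print(np.allclose(out_perm.sum(axis=0), out.sum(axis=0)))
```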


Back to our tutorial

Now that we understand GNNs, we can return to the main content of this tutorial and the task we are trying to learn.

In this post, we’ll focus on functional connectomes, which encode the many temporal statistical relationships amongst brain regions. Functional magnetic resonance imaging (fMRI) is one useful technique that can help us map out brain activity so that we can discover these statistical relationships. From fMRI data we can derive functional connectivity matrices or graphs where rows/columns (nodes) represent brain regions of interest (ROIs) and entries (edges) represent the strength and sign of the functional connection (e.g. correlation coefficients).

GNNs learn rich features from connectivity matrices because, unlike CNNs, they leverage the full connectome rather than the Euclidean neighborhood of a given ROI. In GNNs, convolution occurs over adjacent nodes rather than adjacent pixels, so an ROI’s neighborhood is defined by functional connections rather than spatial proximity. This matters because temporal patterns in brain activity do not follow spatial layout: ROIs that adjoin spatially do not necessarily co-activate. Each ROI therefore has its own computation graph, defined by its edges, i.e. the ROIs with which it shares a functional connection.

Data preparation

OpenNeuro is an awesome resource with lots of resting-state and task-based fMRI datasets. If you want to run your own experiments, keep an eye out for “BIDS valid” data — this green check means the dataset is formatted according to a common neuroimaging specification that will make processing easier later on.

What you should see before you download an OpenNeuro dataset.

Tasks a GNN could learn include:

  • Predicting whether or not a subject has depression
  • Predicting a subject’s IQ
  • Predicting what genre of music a subject is listening to

In this post we’ll look at “MRI data of 3–12 year old children and adults during viewing of a short animated film” (associated paper here) [1]. The task is to predict the age group of a participant from fMRI scans taken while subjects watched Pixar’s “Partly Cloudy.” Specifically, we want to know whether the viewer is a child or an adult, so the problem is binary classification.

This 155-subject dataset has already been run through a popular preprocessing pipeline known as fMRIPrep. If you’re working with your own data, make sure you run fMRIPrep or another preprocessing pipeline first: the images need robust denoising, normalization, and smoothing to produce reliable results. In fMRI scans, factors such as head movement can generate spurious signal fluctuations that greatly affect data quality.

To get started with machine learning on neuroimaging data, we recommend using Nilearn, a Python package that offers great neuroscience-specific utility functions and compatibility with common formats in the field. Here we use Nilearn to retrieve our dataset, as well as an atlas that parcellates the brain into ROIs.

First we use a Nilearn class called NiftiMapsMasker to extract a brain signal time series for each ROI defined by the atlas. It’s important to specify a confounds file here so that noise is regressed out. This confounds file, data.confounds, was generated by fMRIPrep.
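The retrieval and extraction steps can be sketched as follows. This assumes nilearn 0.9 or newer (where NiftiMapsMasker lives in nilearn.maskers) and uses nilearn's fetch_development_fmri, which serves the Richardson et al. movie-watching dataset described above. The MSDL probabilistic atlas (fetch_atlas_msdl) is an assumption, since the post does not name its atlas. The helper names load_development_fmri and child_adult_labels are hypothetical, as is the reliance on a "Child_Adult" column in the phenotypic table.

```python
import numpy as np

def child_adult_labels(values):
    """Map 'child'/'adult' strings to binary labels (0 = child, 1 = adult)."""
    labels = []
    for v in values:
        if isinstance(v, bytes):  # handle byte strings defensively
            v = v.decode()
        labels.append(1 if v == "adult" else 0)
    return np.array(labels, dtype=np.int64)

def load_development_fmri(n_subjects=None):
    """Fetch the dataset and an atlas, then extract one ROI time series per subject."""
    # Imported inside the function so the label helper above works without nilearn.
    from nilearn import datasets
    from nilearn.maskers import NiftiMapsMasker  # nilearn.input_data in older releases

    data = datasets.fetch_development_fmri(n_subjects=n_subjects)
    atlas = datasets.fetch_atlas_msdl()  # assumed atlas choice

    masker = NiftiMapsMasker(maps_img=atlas.maps, verbose=0)
    # Each element has shape (n_timepoints, n_rois); the fMRIPrep confounds
    # are regressed out of the signal during extraction.
    time_series = [
        masker.fit_transform(func, confounds=confounds)
        for func, confounds in zip(data.func, data.confounds)
    ]
    # Assumes the phenotypic table has a "Child_Adult" column.
    labels = child_adult_labels(data.phenotypic["Child_Adult"])
    return time_series, labels
```

Calling `load_development_fmri()` downloads all subjects on first use. Passing a small `n_subjects` is a quick way to try the pipeline.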

Here, the computed time_series stores Blood Oxygenation Level Dependent (BOLD) signals for each subject. BOLD measurements give us information about how active each ROI is within the brain over time. ROIs are clusters of brain cells, and our bodies send more oxygen-rich blood to these cells when they become active. These local changes in blood-oxygen levels are measured by the MRI machine [2].

BOLD time series corresponding to different ROIs in the brain.

Now we can calculate our connectivity matrices using this time series data! Specifically, we want to determine how correlated the blood-oxygen levels are between each pair of ROIs. The magnitude of these correlations is then used as a proxy for how “connected” two given ROIs are. We consider two connectivity measures here: correlation and partial correlation. Later we’ll use the correlation matrix to define node features and the partial correlation matrix to define edges in our graphs, as in BrainGNN [3].
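Both measures can be computed directly from a subject's (n_timepoints, n_rois) time series. The NumPy sketch below uses plain empirical estimates and the standard precision-matrix formula for partial correlation. In nilearn, ConnectivityMeasure(kind="correlation") and ConnectivityMeasure(kind="partial correlation") do the same job, but by default they use a Ledoit-Wolf shrinkage covariance estimate, so their values will differ somewhat from this sketch.

```python
import numpy as np

def correlation_matrices(ts):
    """Return (correlation, partial correlation) matrices for one subject.

    ts: array of shape (n_timepoints, n_rois).
    """
    corr = np.corrcoef(ts, rowvar=False)
    # Partial correlation of ROIs i and j given all other ROIs, from the
    # precision (inverse covariance) matrix P: pcorr_ij = -P_ij / sqrt(P_ii * P_jj).
    precision = np.linalg.inv(np.cov(ts, rowvar=False))
    d = np.sqrt(np.diag(precision))
    pcorr = -precision / np.outer(d, d)
    np.fill_diagonal(pcorr, 1.0)
    return corr, pcorr

# Toy check: x and y are both driven by z, but not by each other.
rng = np.random.default_rng(0)
z = rng.normal(size=5000)
x = z + 0.1 * rng.normal(size=5000)
y = z + 0.1 * rng.normal(size=5000)
corr, pcorr = correlation_matrices(np.column_stack([x, y, z]))
print(corr[0, 1], pcorr[0, 1])  # strongly correlated, partial correlation near 0
```

The toy check shows how the two measures differ. x and y look strongly connected only because both follow z, and partial correlation removes that shared influence. np.linalg.inv becomes unstable or fails when there are fewer time points than ROIs, which is one reason shrinkage estimators are popular for this step.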

For fun, we can get a sense for what these matrices look like. We use nilearn.plotting to display a connectivity matrix and connectome for the first correlation matrix.

Matrix and view of brain displaying functional connectivity between ROIs.

Now we need to make these matrices ingestible by our GNN. To do this, and to set up our model, we’ll use PyG, a PyTorch library for graph machine learning. Since our dataset is fairly small, we can use a custom InMemoryDataset class to manage our data.

We implement the process method to turn the matrices, our raw data, into PyG Data objects, which represent graphs. First we read the .csv matrix files from our dataset folder with np.loadtxt. Then we use NetworkX graphs as a convenient intermediate format, since NetworkX converts a matrix to a graph directly: from_numpy_matrix interprets each matrix as an adjacency matrix whose entries are edge weights between nodes. This means that any two ROIs with a non-zero partial correlation share an edge. Finally, we convert the NetworkX graphs to Data objects and assign features and labels to them.
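PyG stores a graph's edges as a (2, num_edges) edge_index plus per-edge attributes. The NumPy sketch below (the helper name matrix_to_edges is hypothetical) shows what the NetworkX route amounts to: every non-zero entry of the matrix becomes an edge. Unlike NetworkX, it drops the diagonal, because a partial correlation matrix has 1s there that would otherwise become self-loops. Whether to keep them is a design choice. Note also that from_numpy_matrix was removed in NetworkX 3.0 in favor of from_numpy_array, and PyG's torch_geometric.utils.dense_to_sparse performs this conversion directly on a tensor.

```python
import numpy as np

def matrix_to_edges(adj):
    """Turn a weighted adjacency matrix into a PyG-style edge list.

    Returns edge_index with shape (2, num_edges) and edge_weight with shape
    (num_edges,). Every non-zero off-diagonal entry becomes a directed edge,
    so a symmetric matrix yields each undirected edge in both directions,
    which is the layout PyG expects.
    """
    adj = np.array(adj, dtype=float)  # copy, so the caller's matrix is untouched
    np.fill_diagonal(adj, 0.0)        # drop self-loops
    src, dst = np.nonzero(adj)
    edge_index = np.stack([src, dst])
    return edge_index, adj[src, dst]

# Hypothetical 3-ROI partial correlation matrix.
pcorr = [[1.0, 0.5, 0.0],
         [0.5, 1.0, -0.2],
         [0.0, -0.2, 1.0]]
edge_index, edge_weight = matrix_to_edges(pcorr)
print(edge_index)
print(edge_weight)
```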

As mentioned, the partial correlation matrix is the basis for edge creation. Partial correlation produces sparser graphs, which helps the GNN avoid over-smoothing, an effect in which all nodes end up with nearly the same embedding in densely connected graphs. Correlation coefficients, a different measure of ROI connectivity, serve as node features: each node’s initial feature vector contains that ROI’s correlations to all other ROIs.
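Over-smoothing is easy to see in a toy setting. This sketch strips a GNN layer down to parameter-free mean aggregation with self-loops, which is a simplification, since real layers also apply learned weights and nonlinearities. It compares a fully connected graph with a sparse ring of the same size. On the complete graph, one round of averaging already gives every node the same embedding. On the sparse graph, the nodes remain distinguishable.

```python
import numpy as np

def mean_aggregate(A, H, steps):
    """Apply `steps` rounds of parameter-free mean aggregation with self-loops."""
    A_hat = A + np.eye(len(A))                          # each node keeps its own embedding
    A_norm = A_hat / A_hat.sum(axis=1, keepdims=True)   # row-normalize -> mean
    for _ in range(steps):
        H = A_norm @ H
    return H

def spread(H):
    # Mean distance of node embeddings from their centroid; 0 means all identical.
    return np.linalg.norm(H - H.mean(axis=0), axis=1).mean()

rng = np.random.default_rng(0)
n = 6
H0 = rng.normal(size=(n, 3))
dense = np.ones((n, n)) - np.eye(n)     # every node connected to every other node
ring = np.roll(np.eye(n), 1, axis=1)
ring = ring + ring.T                    # sparse graph: each node has two neighbors

print(spread(mean_aggregate(dense, H0, 1)))  # essentially 0
print(spread(mean_aggregate(ring, H0, 1)))   # clearly above 0
```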

Model definition

The GNN that we’ll learn is a variant of cGCN (connectivity-based graph convolution network) [4].

Diagram from Wang et al. An fMRI time course is split into multiple graphs at the input. Then these graphs are passed through the graph convolutional network. An RNN or pooling method summarizes information from these graphs to generate a prediction.

In the cGCN paper, the authors construct a sequence of graphs for each subject, each corresponding to a different step in their time series. The initial node features in each graph are the BOLD signals of the ROI they represent. The edges are determined by taking the average partial correlation matrix of all subjects and choosing the k most influential connections for each ROI. This means that, while each subject’s connectome will have the same neighborhood structure, their temporal node features will differ based on measured BOLD activity.
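Here is a sketch of this kind of edge selection. It assumes that "most influential" means largest absolute partial correlation, and it symmetrizes the result so that the graph is undirected: an edge is kept if either endpoint selects it. Both choices are assumptions, and the helper knn_graph is hypothetical. The same function covers the shared, group-average graph of cGCN and the per-subject k-NN graphs this post uses later on.

```python
import numpy as np

def knn_graph(matrix, k):
    """Keep, for each ROI, its k strongest connections by absolute value.

    Returns a symmetric 0/1 adjacency matrix (assumed selection rule).
    """
    strength = np.abs(np.asarray(matrix, dtype=float))
    np.fill_diagonal(strength, -np.inf)          # never pick an ROI as its own neighbor
    nbrs = np.argsort(-strength, axis=1)[:, :k]  # k largest entries in each row
    adj = np.zeros(strength.shape, dtype=int)
    rows = np.repeat(np.arange(len(adj)), k)
    adj[rows, nbrs.ravel()] = 1
    return np.maximum(adj, adj.T)                # make the graph undirected

# Hypothetical stack of per-subject partial correlation matrices (5 subjects, 10 ROIs).
rng = np.random.default_rng(0)
pcorrs = rng.uniform(-1, 1, size=(5, 10, 10))
pcorrs = (pcorrs + pcorrs.transpose(0, 2, 1)) / 2
shared_adj = knn_graph(pcorrs.mean(axis=0), k=3)    # cGCN-style: one graph for everyone
subject_adjs = [knn_graph(p, k=3) for p in pcorrs]  # one graph per subject
```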

The proposed cGCN performs message passing on each graph in the sequence separately, using a graph convolutional layer known as EdgeConv. The update rule for any given node i is

$$\mathbf{x}_i' = \max_{j \in \mathcal{N}(i)} h_{\Theta}\left(\mathbf{x}_i \,\Vert\, (\mathbf{x}_j - \mathbf{x}_i)\right)$$

where h_Θ denotes a neural network with parameters Θ, such as a multilayer perceptron (MLP), and ∥ denotes concatenation [5]. The message computed for each neighbor j is node i’s own signal concatenated with the difference between j’s BOLD signal and i’s, which is then fed into h_Θ. Node i’s updated embedding is the element-wise maximum over all of its neighbors’ messages. (Put another way: max is the aggregation function, and the update step simply takes the aggregated result as the new node embedding.)
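The update rule translates almost line for line into code. Below is a NumPy sketch of a single EdgeConv layer with max aggregation. It follows PyG's convention that column (j, i) of edge_index is an edge from source j to target i. The MLP standing in for h is hypothetical (random, untrained weights); in PyG you would instead pass a torch.nn.Sequential to torch_geometric.nn.EdgeConv.

```python
import numpy as np

def edge_conv(x, edge_index, h):
    """EdgeConv with max aggregation: x_i' = max_{j in N(i)} h(x_i || (x_j - x_i)).

    x:          (num_nodes, d_in) node features
    edge_index: (2, num_edges) int array; column (j, i) means j sends a message to i
    h:          callable mapping (num_edges, 2 * d_in) -> (num_edges, d_out)
    """
    src, dst = edge_index
    # One message per edge: the target's features concatenated with the
    # difference between the source's and the target's features.
    messages = h(np.concatenate([x[dst], x[src] - x[dst]], axis=1))
    out = np.full((x.shape[0], messages.shape[1]), -np.inf)
    # Max aggregation: for each target node, keep the element-wise max message.
    np.maximum.at(out, dst, messages)
    # Nodes with no incoming edges received no message; this sketch sets them to 0.
    out[np.isneginf(out)] = 0.0
    return out

# Hypothetical two-layer MLP standing in for h (random, untrained weights).
rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(2 * 3, 8)), rng.normal(size=(8, 4))
def mlp(z):
    return np.maximum(0.0, z @ W1) @ W2

x = rng.normal(size=(4, 3))                        # 4 ROIs, 3 features each
edge_index = np.array([[1, 2, 0, 3], [0, 0, 1, 2]])
print(edge_conv(x, edge_index, mlp).shape)         # (4, 4)
```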

After several layers of message passing, the authors pool over the representations of these multiple graphs to make a final prediction.

For the sake of simplicity we only use one frame, the entire time course. This means we only have to pool information over final node features for one graph. In addition, instead of using raw BOLD signals as our node attributes, we use the correlation coefficients of a given ROI to other ROIs as node attributes as the authors of BrainGNN did.

Additionally, we do not share graph structures between subjects as in cGCN. Instead, we generate k-NN graphs for every subject’s partial correlation matrix so that our data is tailored to each subject’s unique functional connectivity patterns (see DevDataset processing above for this step).

In the model initialization we define the layers of the cGCN. Convolutional layers are the components that collect messages from nodes in the previous layer, aggregate them at each node, and pass forward new representations.

In between convolutional layers we insert batch normalization layers to ensure node features do not become extremely small or large. This is indeed important for our task since correlation coefficients can be tiny. When we don’t use batch normalization the model does not train well.
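For intuition, here is a NumPy sketch of what a batch normalization layer computes in training mode. Each feature is standardized across all nodes in the mini-batch and then rescaled by learnable parameters gamma and beta. At inference time, PyTorch's BatchNorm1d uses running averages of these statistics instead. The toy features below are hypothetical, drawn with a small spread to mimic small correlation values.

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Training-mode batch normalization over all nodes in a mini-batch.

    x: (num_nodes, num_features). Each column is shifted to zero mean and
    scaled to unit variance across nodes, then rescaled by gamma and beta.
    """
    mean = x.mean(axis=0)
    var = x.var(axis=0)  # biased variance, as used for normalization during training
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(scale=0.05, size=(200, 4))   # small, correlation-sized features
y = batch_norm(x, gamma=np.ones(4), beta=np.zeros(4))
print(x.std(axis=0), y.std(axis=0))         # small spread before, roughly 1 after
```

With very small features, eps starts to matter: if the per-feature variance approaches eps, the output spread falls noticeably below 1.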

An unstable loss curve after removing batch normalization layers.

After batch normalization we apply the standard ReLU nonlinearity. Finally, we apply global mean pooling over our node features to obtain a graph-level representation, since we only care about the age group of the subject who generated the connectivity matrix and are not predicting properties of any particular ROI. To do this, we use the batch attribute of the Data object we receive, which tells global_mean_pool which nodes belong to which graphs (in a mini-batch, nodes from different graphs are combined into one large, disconnected graph). We learn a weight matrix over this graph-level representation and apply a softmax to get our final prediction.
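The batching and pooling steps can be sketched in NumPy as follows. collate mimics how PyG's DataLoader merges graphs into one large disconnected graph, offsetting edge indices and recording each node's graph in batch. global_mean_pool then averages node features per graph. The final linear map W is a hypothetical, untrained classifier. We use a log-softmax rather than a plain softmax because it pairs with the negative log likelihood loss used in training.

```python
import numpy as np

def collate(graphs):
    """Merge (x, edge_index) pairs into one disconnected graph plus a batch vector."""
    xs, edges, batch, offset = [], [], [], 0
    for g, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)   # shift node ids past earlier graphs
        batch.append(np.full(len(x), g))    # batch[n] = index of node n's graph
        offset += len(x)
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)

def global_mean_pool(x, batch):
    """Average node features per graph -> (num_graphs, num_features)."""
    num_graphs = batch.max() + 1
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / counts

def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)    # subtract the max for numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

# Two toy graphs: one with 2 nodes, one with a single node; 2 features each.
g1 = (np.array([[1.0, 1.0], [3.0, 3.0]]), np.array([[0, 1], [1, 0]]))
g2 = (np.array([[10.0, 0.0]]), np.zeros((2, 0), dtype=int))
x, edge_index, batch = collate([g1, g2])
pooled = global_mean_pool(x, batch)                # one row per graph
W = np.random.default_rng(0).normal(size=(2, 2))   # hypothetical classifier weights
log_probs = log_softmax(pooled @ W)                # child vs adult log-probabilities
```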

We found empirically that this simplified cGCN performed better than a full cGCN with around 170 input frames, perhaps because of the limited number of training examples. The implementation of the full cGCN is documented and available in our Google Colab, linked above, if you are interested.

Training

At last we train our model on the task. We allow our model 32 hidden features. We use Adam as our optimizer, select a batch size of 32, and choose 0.01 as our learning rate.

Our model converges rapidly. Using the code above, we usually end up with near 100% accuracy on the train set and 97% on the test set. (By contrast, when we use the full cGCN, the model ends up always predicting the majority class. This seems unlikely to be a bug, and performance might improve with different layer choices and hyperparameters.)
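The two quantities we track can be written out as follows, assuming the model outputs log-probabilities as in the classifier head described above. nll_loss matches what PyTorch's F.nll_loss computes with its default mean reduction and no class weights. The helper names and toy numbers are hypothetical.

```python
import numpy as np

def nll_loss(log_probs, labels):
    """Mean negative log-likelihood of the true class for each graph."""
    return -log_probs[np.arange(len(labels)), labels].mean()

def accuracy(log_probs, labels):
    """Fraction of graphs whose highest-probability class matches the label."""
    return (log_probs.argmax(axis=1) == labels).mean()

# Hypothetical predictions for three subjects (columns: child, adult).
log_probs = np.log(np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]))
labels = np.array([0, 1, 1])
print(nll_loss(log_probs, labels), accuracy(log_probs, labels))
```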

Negative log likelihood loss plotted against epochs.

Awesome! We learned about GNNs and implemented a specific variant, cGCN, to try out on an fMRI dataset. We now know enough to experiment with more complicated neuroscience tasks and GNN architectures. There is a lot of low-hanging fruit in the graph learning + neuroscience space: many approaches to fMRI analysis still rely on CNNs or multivariate linear regression, so we encourage you to explore.

References

[1] Richardson, H., Lisandrelli, G., Riobueno-Naylor, A. et al. Development of the social brain from age three to twelve years. Nat Commun 9, 1027 (2018).

[2] https://royalsociety.org/blog/2016/08/qa-what-is-bold/

[3] Li, Xiaoxiao & Duncan, James. (2020). BrainGNN: Interpretable Brain Graph Neural Network for fMRI Analysis. 10.1101/2020.05.16.100057.

[4] Wang L, Li K, Hu XP. Graph convolutional network for fMRI analysis based on connectivity neighborhood. Netw Neurosci. 2021 Feb 1;5(1):83–95. doi: 10.1162/netn_a_00171.

[5] https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.EdgeConv
