Implementation and Understanding of Graph Neural Networks (GNN)

Abdul Rehman
Published in Red Buffer
10 min read · Sep 7, 2022
Photo by GuerrillaBuzz Crypto PR on Unsplash

Neural Networks are good at capturing hidden patterns of Euclidean data (images, text, videos). But what about applications where data is generated from non-Euclidean domains, represented as graphs with complex relationships and inter-dependencies between objects?
That’s where Graph Neural Networks (GNNs) come in, which we’ll explore in this article. We’ll start with graph theory and basic definitions, and at the end we will apply a GNN to images using PyTorch Geometric.

Table of contents

  • What is a Graph?
  • Why Graph Neural Networks?
  • Deep dive into Message Passing
  • Applying GNN on images using PyTorch Geometric
  • Conclusion
  • Resources

What is a Graph?

The most fundamental part of GNN is a Graph. In computer science, a graph is a data structure consisting of two components: nodes (vertices) and edges.
A graph G can be defined as G = (V, E), where V is the set of nodes and E is the set of edges between them. If there are directional dependencies between nodes, the edges are directed. If not, the edges are undirected.

Directed Graph

A graph can represent things like social media networks, or molecules. Think of nodes as users, and edges as connections. A social media graph might look like this:

Social Media Graph

A graph is often represented by an adjacency matrix A. If a graph has n nodes, A has dimension (n × n). Sometimes the nodes have a set of features (for example, a user profile). If each node has f features, then the node feature matrix X has dimension (n × f).
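As a minimal sketch of these two matrices, the snippet below builds A and X for a small social graph. The four users, their connections, and the two features per user are made up for illustration and do not come from the article's figure.

```python
# Hypothetical 4-user social graph; node ids, edges and features are illustrative.
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
n = 4


def adjacency_matrix(num_nodes, edge_list, directed=False):
    """Build an n x n adjacency matrix from a list of (source, target) pairs."""
    A = [[0] * num_nodes for _ in range(num_nodes)]
    for src, dst in edge_list:
        A[src][dst] = 1
        if not directed:
            A[dst][src] = 1  # an undirected edge is stored in both directions
    return A


A = adjacency_matrix(n, edges)                    # undirected: symmetric
A_directed = adjacency_matrix(n, edges, directed=True)

# Node feature matrix X of shape (n x f) with f = 2: [age, number of posts].
X = [
    [25, 10],
    [31, 4],
    [22, 57],
    [40, 1],
]

for row in A:
    print(row)
print("shape of A:", (len(A), len(A[0])))
print("shape of X:", (len(X), len(X[0])))
```

For the undirected graph, A is symmetric. With `directed=True`, only the (source, target) entry is set.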

Graph Neural Networks

A Graph Neural Network is a type of neural network that operates directly on the graph structure and provides an easy way to do node-level, edge-level, and graph-level prediction tasks.
To understand how a GNN works, let's take an example image from the MNIST dataset.

Input image (handwritten digit)

In Convolutional Neural Networks, information is extracted by sliding a filter (with an odd kernel size such as 3x3 or 9x9) over the whole image. Say our filter has a kernel size of 3x3: at each position it combines the center cell with its eight neighboring cells by multiplying each cell by the corresponding filter weight and summing the results.
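To make the filter computation concrete, here is a small sketch that applies a 3x3 filter at a single position of a toy image. The image values and the averaging kernel are arbitrary choices for illustration.

```python
def apply_filter_at(image, kernel, row, col):
    """Weighted sum of the 3x3 patch centred on (row, col)."""
    total = 0.0
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            # Multiply each cell by the matching filter weight and accumulate.
            total += image[row + dr][col + dc] * kernel[dr + 1][dc + 1]
    return total


# Toy 4x4 "image" and an averaging kernel (both illustrative).
image = [
    [0, 0, 0, 0],
    [0, 1, 2, 0],
    [0, 3, 4, 0],
    [0, 0, 0, 0],
]
mean_kernel = [[1 / 9] * 3 for _ in range(3)]

value = apply_filter_at(image, mean_kernel, 1, 1)
print(value)
```

The order of the nine cells matters here because each position has its own filter weight. A graph has no such fixed neighbor order, which is the property a GNN has to work around.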

A common convolutional neural network having several conv layers and few fully connected layers at the end.

Now, to understand the GNN, let's treat each cell as a node and the eight cells surrounding it as its neighbor nodes.

Now let’s say we want to perform a link prediction task using a GNN, for example finding the best match for some person. Each node (person) in the graph (social network) has some features such as age, gender, and occupation. We can use an aggregation mechanism on graphs similar to the weighted sum in a CNN. By relaxing the CNN's assumptions of a fixed grid structure and a fixed neighbor order, we can similarly gain a better understanding of each node by aggregating the messages received from all of its neighboring nodes.

Consider passing information between all nodes. The process of sharing information in a GNN is known as “Message Passing”. Each layer in the network corresponds to a single round of message passing, and with each round, information travels one hop further through the graph. After enough rounds, we obtain a final representation of each node that describes it well in the context of its larger neighborhood. We can also view the rounds of message passing as a series of layers, just like in our convolutional neural network.

Message Passing

Here we go from the input, to round 1 of message passing and then round 2 of message passing. And this produces our final understanding of nodes.
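The effect of stacking rounds can be seen on a toy example. The sketch below assumes a four-node path graph and uses plain averaging as the aggregation, with no learned weights. Both are simplifications chosen for illustration.

```python
# Path graph 0 - 1 - 2 - 3 (illustrative).
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}


def message_passing_round(values, neighbors):
    """One round: every node averages its own value with its neighbors' values."""
    new_values = {}
    for node, nbrs in neighbors.items():
        received = [values[j] for j in nbrs] + [values[node]]
        new_values[node] = sum(received) / len(received)
    return new_values


h0 = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0}   # only node 0 carries a signal at the input
h1 = message_passing_round(h0, neighbors)  # round 1
h2 = message_passing_round(h1, neighbors)  # round 2

print(h1)
print(h2)
```

After round 1 the signal from node 0 has reached node 1 only. After round 2 it has reached node 2, while node 3 (three hops away) is still untouched. Each extra round, or layer, widens the context by one hop.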

Intuitively, people with characteristics similar to a given person should be placed closer to that person, and are therefore stronger candidates for a match. Similarly, in the process of reasoning over a graph, the GNN seeks to summarize all the information about each node into a numerical representation. Nodes that are more ‘similar’ will be closer to each other. We call these representations ‘node embeddings’, and the space containing all the possible representations is known as the ‘embedding space’.

Embeddings Space

How does the neural network know where to place these embeddings in the embedding space?

Neural networks need an appropriate measurable objective to guide them toward learning something useful for our task. In our case, one way is to tell the network that the embeddings of people who spend a lot of time with a given person (or share similar interests) should be closer to that person's embedding.

To tell the network how well it is doing, calculate the loss.
Loss for all the nodes:

Graph Neural Networks learn the final representations of nodes through forward propagation, compute the loss, and then perform backpropagation.

After a number of epochs, when the network has learned embeddings that are good enough to meet our objective, we can simply shortlist the people whose embeddings lie within a certain radius of the target person's embedding and rank them as potential matches by distance.
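The shortlisting step can be sketched in a few lines. The 2-D embeddings, the names, and the radius below are hypothetical. A trained GNN would produce higher-dimensional embeddings, and Euclidean distance is assumed here as the measure of closeness.

```python
import math

# Hypothetical learned embeddings (2-D for readability).
embeddings = {
    "target": (0.0, 0.0),
    "alice": (0.5, 0.5),
    "bob": (3.0, 4.0),
    "carol": (1.0, 0.0),
    "dave": (0.1, 0.2),
}


def rank_matches(embeddings, target, radius):
    """Return (name, distance) pairs within `radius` of the target, nearest first."""
    origin = embeddings[target]
    candidates = []
    for name, point in embeddings.items():
        if name == target:
            continue
        distance = math.dist(origin, point)  # Euclidean distance
        if distance <= radius:
            candidates.append((name, distance))
    return sorted(candidates, key=lambda pair: pair[1])


matches = rank_matches(embeddings, "target", radius=1.5)
for name, distance in matches:
    print(f"{name}: {distance:.3f}")
```

Here "bob" falls outside the radius and is dropped. The remaining candidates are ordered from nearest to farthest.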

That completes our task, link prediction. Common Graph Neural Network tasks include:

  • Link Prediction
  • Node Classification (Yes/NO, etc.)
  • Clustering (Introverts/Extroverts, etc.)
  • Graph Classification

Deep dive into Message Passing

Message passing is the most important component of a GNN. It is a mathematical function f() that updates the receiver node using the messages from each neighboring sender node. Let us take an example to understand the message passing function, also known as the “Aggregation Function”.

Here we have four nodes and we are trying to update the receiver node (x1) using the message passing function. First, we take the product of each node's features (xi), a constant (c), and a weight matrix (W). The constant (c) is determined by the number of neighboring nodes each node (xi) has, and W is the weight matrix that the GNN learns through backpropagation to discover which features are more important. We then sum all of these products. To learn more complex patterns, we apply a non-linearity (sigma) to the result.
Since changing the order (permutation) of the nodes doesn’t change a node’s neighbors, what we really want is a permutation-invariant function that consistently aggregates information from a node’s neighborhood regardless of ordering.
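The update of the receiver node can be sketched as follows. This is an illustration rather than the exact function from the figure: the feature values, the weight matrix W, the constant c = 0.25 for every node, and the choice of ReLU as the non-linearity are all assumptions.

```python
def matvec(x, W):
    """Row vector x (length d) times matrix W (d x d_out)."""
    return [sum(x[k] * W[k][j] for k in range(len(x))) for j in range(len(W[0]))]


def relu(v):
    return [max(0.0, a) for a in v]


def update_receiver(features, coeffs, W):
    """Compute sigma(sum_i c_i * x_i W) over the receiver and its neighbors."""
    total = [0.0] * len(W[0])
    for x, c in zip(features, coeffs):
        message = matvec(x, W)
        total = [t + c * m for t, m in zip(total, message)]
    return relu(total)


# x1 (receiver) followed by x2, x3, x4; all values are illustrative.
features = [[1.0, 2.0], [0.0, 1.0], [3.0, 0.0], [2.0, 2.0]]
coeffs = [0.25, 0.25, 0.25, 0.25]
W = [[1.0, -1.0], [0.5, 2.0]]

h1 = update_receiver(features, coeffs, W)

# Reversing the node order must not change the result.
h1_reversed = update_receiver(features[::-1], coeffs[::-1], W)
print(h1, h1_reversed)
```

Because the messages are combined with a sum, feeding the nodes in a different order gives the same updated representation, which is the permutation invariance described above.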
Now let’s modify the function so that it generalizes to any node, not just the first node (x1).

Explanation:

  • All the neighbors are denoted by ‘j’.
  • Group all these terms (other than receiver node term) and denote the weights as a function.
  • Since we are updating the receiver node, we might want to treat its node features differently from the rest.
  • Denoting the outermost non-linear function as a general learnable function.
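One way to write the generalized update that these bullets describe is shown below. It follows notation commonly used for convolutional GNN layers and is a reconstruction under that assumption, not necessarily the exact formula from the original figure.

```latex
h_i = \phi\Big( x_i,\ \bigoplus_{j \in \mathcal{N}_i} c_{ij}\, \psi(x_j) \Big)
```

Here $\mathcal{N}_i$ is the set of neighbors j of node i, $\psi$ is the learnable function applied to each neighbor's features (it generalizes the weight matrix W), $c_{ij}$ is the fixed constant for the pair (i, j), $\bigoplus$ is a permutation-invariant aggregator such as a sum, and $\phi$ is the general learnable function that combines the receiver's own features $x_i$ with the aggregated messages.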

In the message passing function above, the weights of the neighbors are fixed by the structure of the graph. This is one of the three general flavors of Graph Neural Network layers, called Convolutional.

We can also learn the weights given to each neighbor according to their features, which is another flavor of Graph Neural Network layers called Attentional. In Attentional GNN layers, weights of neighbors are learned based on the interactions of the features between the nodes.

Attentional Message Passing
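A minimal sketch of the attentional idea: the weight given to each neighbor is computed from the features of the receiver and that neighbor, then normalized with a softmax. The dot-product score is a stand-in chosen for simplicity. Attentional layers in practice typically learn the scoring function, and the feature values below are made up.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def attention_aggregate(receiver, neighbor_features):
    """Weight each neighbor by softmax(receiver . neighbor), then sum."""
    scores = [dot(receiver, x_j) for x_j in neighbor_features]
    weights = softmax(scores)
    aggregated = [
        sum(w * x_j[k] for w, x_j in zip(weights, neighbor_features))
        for k in range(len(receiver))
    ]
    return weights, aggregated


receiver = [1.0, 0.0]
neighbor_features = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]

weights, aggregated = attention_aggregate(receiver, neighbor_features)
print(weights)
print(aggregated)
```

The neighbor whose features are most aligned with the receiver gets the largest weight, so the weights depend on feature interactions rather than only on the graph structure.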

But the most general form of message passing is one in which each pair of nodes collaborates to produce a specific message between them.

General Message Passing
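In commonly used notation, the attentional and general message passing flavors can be written as below. This is a sketch of the standard forms and is not taken from the original figures. $x_i$ is the receiver's feature vector, $\mathcal{N}_i$ its set of neighbors, $\bigoplus$ a permutation-invariant aggregator such as a sum, and $\phi$, $\psi$, $a$ are learnable functions.

```latex
\begin{aligned}
\text{Attentional:} \quad & h_i = \phi\Big( x_i,\ \bigoplus_{j \in \mathcal{N}_i} a(x_i, x_j)\, \psi(x_j) \Big) \\
\text{Message passing:} \quad & h_i = \phi\Big( x_i,\ \bigoplus_{j \in \mathcal{N}_i} \psi(x_i, x_j) \Big)
\end{aligned}
```

In the attentional form, the scalar weight $a(x_i, x_j)$ is computed from the features of both nodes. In the general form, the whole message $\psi(x_i, x_j)$ depends on the sender and the receiver together.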

Computing the updates for each node sequentially is a slow process. In practice, we use linear algebra to speed up message passing. We summarize the graph’s edges in a table called an Adjacency Matrix.

Adjacency Matrix

Here we have four nodes, and the matrix records the graph’s edges between them. In our case, we also want to connect each node to itself, so we put 1 on the diagonal as well. We also have two more matrices, the Feature Matrix and the Weight Matrix. The feature matrix contains all the features of the nodes, while the weight matrix contains the learnable weights.

The importance weights for each neighbor are encoded by modifying the entries of the adjacency matrix accordingly.

With this, we can use a highly optimized matrix multiplication library to perform message passing for all the nodes at once.

Here A is the adjacency matrix of dimension (n x n), X is the feature matrix of dimension (n x d), and W is the weight matrix of dimension (d x d).
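Given these shapes, one round of message passing for all nodes can be written as H = sigma(A X W). The sketch below works under several assumptions: the adjacency matrix has 1s on the diagonal, each row of A is divided by its sum as one simple way of encoding neighbor weights, ReLU plays the role of sigma, and the values of X and W are arbitrary. Plain Python lists are used to keep it dependency-free. In practice a tensor library would do the multiplications.

```python
def matmul(P, Q):
    """Multiply matrix P (a x b) by matrix Q (b x c), both given as lists of rows."""
    return [
        [sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
        for i in range(len(P))
    ]


def relu(M):
    return [[max(0.0, v) for v in row] for row in M]


def row_normalize(A):
    """Divide every row by its sum so each node averages over its neighborhood."""
    return [[v / sum(row) for v in row] for row in A]


# Adjacency matrix for four nodes, with self-connections on the diagonal.
A = [
    [1, 1, 1, 0],
    [1, 1, 1, 0],
    [1, 1, 1, 1],
    [0, 0, 1, 1],
]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]  # (n x d) = (4 x 2)
W = [[1.0, -1.0], [0.0, 1.0]]                          # (d x d) = (2 x 2)

A_norm = row_normalize(A)
H = relu(matmul(matmul(A_norm, X), W))                 # (n x d) output

for row in H:
    print(row)
```

Every row of H is the updated representation of one node, computed in a single pass instead of a loop over nodes.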

Implementing Graph Neural Networks on images using PyTorch Geometric

In this part of the article, we will explore the implementation of graph neural networks and investigate what representations these networks learn. Along the way, we’ll see how PyTorch Geometric can help us with constructing and training graph models.

NOTE: We use Google Colab to implement the GNN on images with PyTorch. If you want to run it locally, modify the run commands accordingly.

Preliminaries: PyTorch
We’ll first demonstrate some essential features of PyTorch which we’ll use throughout. PyTorch is a general machine learning library that allows us to dynamically define computation graphs which we’ll use to describe our models and their training processes.
We’ll start by importing everything we need:

Installing required dependencies.

We’ll first download and load the dataset (here the MNIST handwritten digits dataset) through the DataLoader utility:

Downloading the dataset and creating data loaders for training and testing.

Defining the model
The GNNStack is our general framework for a GNN which can handle different types of convolutional layers, and both node and graph classification. The build_conv_model method determines which type of convolutional layer to use for the given task -- here we choose to use a graph convolutional network for node classification, and a graph isomorphism network for graph classification. Note that PyTorch Geometric provides out-of-the-box modules for these layers, which we use here. The model consists of 3 layers of convolution, followed by mean pooling in the case of graph classification, followed by two fully-connected layers. Since our goal here is classification, we use a negative log-likelihood loss function.
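The last two ingredients mentioned above, mean pooling and the negative log-likelihood loss, can be illustrated without any framework. This is a simplified sketch and not the GNNStack code: the node embeddings are made up, and the pooled vector is used directly as class scores, skipping the fully-connected layers.

```python
import math


def mean_pool(node_embeddings):
    """Average all node embeddings into a single graph-level vector."""
    n = len(node_embeddings)
    dim = len(node_embeddings[0])
    return [sum(h[k] for h in node_embeddings) / n for k in range(dim)]


def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the correct class."""
    return -log_probs[target]


# Hypothetical output of the last convolution for a graph with two nodes.
node_embeddings = [[1.0, 0.0, 2.0], [3.0, 0.0, 0.0]]

pooled = mean_pool(node_embeddings)   # graph-level representation
log_probs = log_softmax(pooled)       # treated as scores for 3 classes
print(pooled)
print(nll_loss(log_probs, target=0), nll_loss(log_probs, target=1))
```

The loss is small when the true class has a high predicted probability and grows as that probability drops. For node classification the pooling step is skipped and the loss is computed per node.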

Here pyg_nn.GCNConv and pyg_nn.GINConv are instances of MessagePassing. They define a single layer of graph convolution, which can be decomposed into:

  • Message computation
  • Aggregation
  • Update
  • Pooling
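The decomposition can be mirrored in a small pure-Python class. This is a hypothetical teaching sketch, not PyTorch Geometric's MessagePassing class: the class name, the helper `mean_pool_graph`, and the identity message, sum aggregation, and additive update are all illustrative choices.

```python
class SimpleMessagePassing:
    """Hypothetical pure-Python layer that mirrors the steps listed above."""

    def message(self, x_j):
        # Message computation: the sender's features are sent unchanged.
        return x_j

    def aggregate(self, messages):
        # Aggregation: element-wise sum, which does not depend on message order.
        return [sum(vals) for vals in zip(*messages)]

    def update(self, x_i, aggregated):
        # Update: combine the receiver's own features with the aggregate.
        return [a + b for a, b in zip(x_i, aggregated)]

    def propagate(self, x, edge_list):
        inbox = {i: [] for i in range(len(x))}
        for src, dst in edge_list:
            inbox[dst].append(self.message(x[src]))
        out = []
        for i, x_i in enumerate(x):
            agg = self.aggregate(inbox[i]) if inbox[i] else [0.0] * len(x_i)
            out.append(self.update(x_i, agg))
        return out


def mean_pool_graph(node_features):
    """Pooling: average node features into one vector for the whole graph."""
    n = len(node_features)
    return [sum(h[k] for h in node_features) / n for k in range(len(node_features[0]))]


x = [[1.0], [2.0], [4.0]]
edge_list = [(0, 1), (1, 0), (1, 2), (2, 1)]  # undirected path 0 - 1 - 2

out = SimpleMessagePassing().propagate(x, edge_list)
graph_vector = mean_pool_graph(out)
print(out, graph_vector)
```

Swapping the body of `message`, `aggregate`, or `update` changes the kind of layer without touching the propagation loop, which is the main benefit of the decomposition.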

Training Setup:
We train the model in a standard way here, running it forwards to compute its predicted label distribution and back-propagating the error. Note the task setup in our graph setting: for node classification, we define a subset of nodes to be training nodes and the rest of the nodes to be test nodes and mask out the test nodes during training via batch.train_mask. For graph classification, we use 80% of the graphs for training and the remainder for testing, as in other classification settings.

At test time, for the CiteSeer/Cora node classification task, there is only one graph, so we use masking to determine the validation and test sets.
For graph classification tasks, a subset of the graphs is held out as validation/test graphs.
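The two setups can be sketched as follows. Both helper functions are hypothetical and only illustrate the idea: the log-probabilities, labels, and mask are made up, and shuffling before the 80/20 split is an assumption rather than something the article specifies.

```python
import random


def masked_nll(log_probs, labels, mask):
    """Average negative log-likelihood over the nodes where mask is True."""
    losses = [-lp[y] for lp, y, keep in zip(log_probs, labels, mask) if keep]
    return sum(losses) / len(losses)


def split_graphs(graphs, train_fraction=0.8, seed=0):
    """Shuffle a collection of graphs and split it into train and test lists."""
    shuffled = list(graphs)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


# Node classification: one graph, three nodes, two classes.
log_probs = [[-0.5, -1.0], [-2.0, -0.1], [-0.25, -1.5]]
labels = [0, 1, 0]
train_mask = [True, False, True]   # node 1 is a test node and is ignored

train_loss = masked_nll(log_probs, labels, train_mask)

# Graph classification: ten graphs, represented here by their ids.
train_graphs, test_graphs = split_graphs(range(10))
print(train_loss, len(train_graphs), len(test_graphs))
```

In the node setting every node still takes part in message passing. The mask only controls which nodes contribute to the loss.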

Training the model
Let’s train our model and visualize its progress. First, run this snippet to generate a link to TensorBoardX, which will take you to a page where you can visualize the loss and accuracy curves of the model.
Note: this snippet is intended for Google Colab.

get_ipython().system_raw('tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'.format("./log"))
get_ipython().system_raw('./ngrok http 6006 &')
!curl -s http://localhost:4040/api/tunnels | python3 -c "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

Now run this snippet to start the training. When it’s finished, you should be able to see its training and test performance over time on the TensorBoardX page. If you run the snippet multiple times, you will be able to see multiple training curves and compare them.

Visualize Node embeddings
One great quality of graph neural networks is that, like other deep methods, their hidden layers provide low-dimensional representations of our data. In the case of node classification, we get a low-dimensional representation for each node in our graph. Let’s visualize the output of the last convolutional layer in our node classification GNN via t-SNE, a method for visualizing high-dimensional data in two dimensions. Nodes are colored according to their labels. We see that nodes with similar labels tend to be near each other in the embedding space, a good indication that our model has learned a useful representation.

Visualizing results

Matplotlib (Scatter Plot)

Conclusion

In this article, we applied Graph Neural Networks to images using PyTorch Geometric. Before the implementation, we discussed some basic concepts about graphs, GNNs, and message passing, and looked in depth at how a GNN works and how messages are passed using the message passing function, so that beginners can also implement a GNN after learning these basic concepts.

Resources

https://neptune.ai/blog/graph-neural-network-and-some-of-gnn-applications
