AN OVERVIEW: GRAPH NEURAL NETWORKS

SFU Professional Computer Science
13 min read · Feb 10, 2023

Authors: Dilip Reddy Basireddy, Nagendra Reddy Vippala, Hassan Ahamed Shaik

This blog is written and maintained by students in the Master of Science in Professional Computer Science Program at Simon Fraser University as part of their course credit. To learn more about this unique program, please visit sfu.ca/computing/mpcs.

Many data structures in computer science store information in a structured format; one such data structure is the graph. A graph consists of nodes (also called vertices) connected to each other by edges.

Graph Types

There are multiple types of graphs: directed graphs have edges with a single direction, whereas the edges of undirected graphs are bidirectional. These edges represent the relationships between nodes. Homogeneous graphs have only one node type, whereas heterogeneous graphs have multiple node types.

Applications of Graphs:

The web can be modelled as a directed graph in which the nodes are web pages; if one page contains a hyperlink to another, there is a directed edge between them. Google applies the PageRank algorithm to this graph to rank search results.

Social media sites like Facebook use undirected graphs, where users are the vertices and an edge runs between two users if they are friends. Facebook's friend-suggestion algorithm builds on graph theory.

Molecules and biochemical structures can be represented as graphs.

One more important application of graphs is knowledge graphs, which represent real-world objects and the knowledge related to them in graph form. In a knowledge graph, nodes represent entities and edges represent the relationships between them. For example, to represent the fact "A dog is an animal," the words "dog" and "animal" are placed in nodes, and the "is a" relation is represented by an edge between them.

Computationally, graphs can be easily represented using adjacency matrices. If the element at row A, column B is zero, there is no edge between node A and node B; if it is non-zero, there is an edge between A and B with a weight equal to that value.
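As a small illustration (the graph below is a made-up example), here is how an adjacency matrix can be built for a tiny undirected graph using NumPy:

    import numpy as np

    # A made-up undirected graph with 4 nodes (0..3) and edges (0-1), (0-2), (2-3)
    edges = [(0, 1), (0, 2), (2, 3)]

    A = np.zeros((4, 4))
    for u, v in edges:
        A[u, v] = 1  # mark the edge u -> v
        A[v, u] = 1  # undirected, so also mark v -> u

    print(A)
    # A[0, 1] == 1 means nodes 0 and 1 are connected;
    # A[1, 3] == 0 means nodes 1 and 3 are not.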

Introduction to Graph Neural Networks:

Neural networks are complex non-linear functions that take a set of numbers as input and output predictions; the parameters of these functions are learned during training. Most neural networks work with structured data or with unstructured data types such as text and images.

As discussed before, a lot of data is available to us in graph form, for example, chemical structures, social networks, and knowledge graphs. These are unstructured and cannot easily be forced into a text or image format.

GNNs are a special type of neural network that can read graphs as inputs. They are powerful tools that can be used for node-level, edge-level, and graph-level prediction tasks.

Initially, every node in the graph has its own vector embedding. Vector embeddings are lists of numbers given as input to machine learning models; they represent data about real-world entities in a latent space.

Vector embeddings can be generated with the help of domain knowledge, and there are algorithms such as word2vec, FastText, SBERT, and the Universal Sentence Encoder (from Google) that generate meaningful text embeddings. The embeddings of all nodes in a graph are stored together in a feature matrix.
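As a quick, hedged illustration of generating text embeddings (the checkpoint name below is just one publicly available SBERT model, not something used in this post):

    # Requires: pip install sentence-transformers
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")  # an example SBERT checkpoint
    node_texts = ["dog", "animal", "cat"]            # made-up node descriptions
    embeddings = model.encode(node_texts)            # one vector per text

    print(embeddings.shape)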

The purpose of training GNNs is to improve the embedding of each node in the graph so that it captures neighbourhood and structural information. Each node's embedding is updated by aggregating its neighbours' data with its own embedding.

These final embeddings can then be used for various downstream tasks such as node labelling, link recommendation, and graph clustering. This update of node embeddings is called "graph convolution" or "message passing," and the basic GNN variant is called a "Graph Convolutional Network" (GCN).

Each convolution of the GNN contains:

  • Collection of neighbourhood data.
  • Using a permutation-invariant aggregation function (e.g., mean, max, or sum) to aggregate the collected data with the node's own embedding.
  • Passing the aggregated data through a neural network whose weights are learned during training. The output of the neural network becomes the updated node embedding. As more layers are stacked, information travels between nodes over longer distances, and nodes gather more structural data.

We can perform the convolution for all nodes in the graph at once by multiplying the adjacency matrix with the feature matrix, for example:

In the example, the first matrix is the adjacency matrix of a graph and the second (column) matrix is the feature matrix, i.e., it contains the node embeddings.

Multiplying the adjacency matrix by the feature matrix (A × X) is equivalent to every node collecting data from its neighbours and using sum as the aggregation function. This product is then passed to a neural network, whose output is used to update the embeddings of all nodes in the graph. The whole process is repeated several times so that nodes can receive information from more distant neighbours.
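A tiny NumPy sketch (reusing the made-up 4-node graph from above) makes this equivalence concrete:

    import numpy as np

    # Adjacency matrix of the made-up 4-node graph from earlier
    A = np.array([[0, 1, 1, 0],
                  [1, 0, 0, 0],
                  [1, 0, 0, 1],
                  [0, 0, 1, 0]])

    # Feature matrix: one 2-dimensional embedding per node
    X = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 2.0],
                  [3.0, 1.0]])

    # Row i of A @ X is the sum of the embeddings of node i's neighbours
    print(A @ X)
    # e.g., node 0's neighbours are 1 and 2, so its row is [0,1] + [2,2] = [2,3]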

In equation form, a two-layer GCN of this kind can be written as f(X, A) = softmax(A · ReLU(A · X · W0) · W1): AX is the product of the adjacency matrix and the feature matrix, which is then passed through a neural network with weight matrices W0 and W1 and activation functions such as ReLU. (In practice, A is usually normalised and has self-loops added.)

There are a few libraries in Python, like DGL, PyTorch Geometric, and StellarGraph, that can be used to train graph neural networks.
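As a rough sketch (based on the standard DGL + PyTorch pattern; the class name GCNLayer and other details below are our own illustrative choices), a single graph-convolution layer might look like this:

    import torch
    import torch.nn as nn
    import dgl
    import dgl.function as fn

    class GCNLayer(nn.Module):
        """One round of 'collect neighbours, sum, transform' as described above."""
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.linear = nn.Linear(in_feats, out_feats)

        def forward(self, g, features):
            with g.local_scope():
                g.ndata["h"] = features
                # Each node sends its embedding to its neighbours (message),
                # and every node sums the messages it receives (reduce).
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                h = g.ndata["h"]
                # Pass the aggregated embeddings through a learnable layer.
                return torch.relu(self.linear(h))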

All the mathematics discussed before can be implemented using the DGL and PyTorch libraries. In the sketch above, the forward(g, features) function takes the graph (which encodes the adjacency structure) and the feature matrix as inputs, performs the graph convolution as discussed, and returns the updated embeddings for all nodes.

Heterogeneous graphs:

Most real-world graphs are heterogeneous, meaning they contain multiple node types (which may even have embeddings of different lengths). Heterogeneous graphs also have relations, i.e., each node is connected to another node via a relation, with a different edge type for each relation.

The method discussed above cannot work with this type of heterogeneous graph.

For example, consider a heterogeneous graph with two node types: user and game.

In the DGL library, heterogeneous graphs are defined using a set of node-edge-node triplets, where each relation in the graph is a single triplet (source node type, edge type, destination node type).
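A small sketch of how this looks in DGL (the user/game edges below are made up for illustration):

    import dgl
    import torch

    # Two relations: a user 'plays' a game, and a user 'follows' another user.
    graph_data = {
        ("user", "plays", "game"): (torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1])),
        ("user", "follows", "user"): (torch.tensor([0]), torch.tensor([1])),
    }
    g = dgl.heterograph(graph_data)

    print(g.ntypes)             # the graph's node types: 'user' and 'game'
    print(g.canonical_etypes)   # [(source type, edge type, destination type), ...]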

Relational graph neural networks:

The main difference between R-GCNs and normal GCNs is that R-GCNs can be trained on graphs with multiple relation types, whereas normal GCNs treat all nodes and edges as if they belong to a single type.

In a GCN, one set of weights is shared by all nodes during convolution. In contrast, in an R-GCN, different edge types use different weights, and only edges of the same relation type "r" share the same weight matrix.

Here "h" denotes the embedding of a node, "r" is a relation, and "Wr" is the weight matrix associated with that relation. Neighbourhood embeddings connected through relation "r" are multiplied by "Wr," and all transformed embeddings are aggregated using summation. This output is added to the node's own (self-loop) embedding, after which it is passed through the activation function.
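For reference, the per-node update described above is written in the R-GCN paper roughly as follows, where c_{i,r} is a normalisation constant (e.g., the number of neighbours of node i under relation r) and σ is the activation function:

    h_i^(l+1) = σ( W_0 · h_i^(l) + Σ_{r ∈ R} Σ_{j ∈ N_i^r} (1 / c_{i,r}) · W_r · h_j^(l) )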

One of the problems with using R-GCNs on highly multi-relational data is that when a graph contains many different relations, a separate weight matrix is needed for each of them. This can lead to overfitting, and relations with few edges can be overlooked during training. To address this issue, the R-GCN paper uses basis decomposition (a concept from linear algebra in which each relation's weight matrix is expressed as a combination of a small shared set of basis matrices) to reduce the number of distinct weights that have to be learned.
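DGL ships an R-GCN layer; a hedged sketch of using it with basis decomposition (the sizes, relation count, and random demo graph below are arbitrary choices of ours) could look like this:

    import torch
    import dgl
    from dgl.nn import RelGraphConv

    # 16-dimensional inputs, 8-dimensional outputs, 4 relation types,
    # with the per-relation weights built from only 2 shared basis matrices.
    conv = RelGraphConv(16, 8, num_rels=4, regularizer="basis", num_bases=2)

    g = dgl.rand_graph(10, 30)             # a random homogeneous graph for demonstration
    etypes = torch.randint(0, 4, (30,))    # a relation id for every edge
    feats = torch.randn(10, 16)

    out = conv(g, feats, etypes)           # updated node embeddings, shape (10, 8)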

Edge Features:

For graphs without edge features, a GCN uses only the adjacency and feature matrices, and their multiplication implicitly performs the graph convolution.

For graphs whose edge weights are one-dimensional (plain numbers), we can use a weighted adjacency matrix: if there is an edge between nodes A and B with weight "P," the weighted adjacency matrix entry [A][B] equals P.
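Continuing the earlier made-up NumPy example, edge weights simply replace the 1s in the adjacency matrix:

    import numpy as np

    W = np.zeros((4, 4))
    weighted_edges = [(0, 1, 0.5), (0, 2, 2.0), (2, 3, 1.5)]  # (u, v, weight), made up
    for u, v, w in weighted_edges:
        W[u, v] = w
        W[v, u] = w  # undirected graph

    # W @ X now performs a weighted sum over neighbours instead of a plain sum.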


Graphs can also have edges with higher-dimensional features (vectors rather than single numbers); in such cases, we can use message-passing neural networks (MP-GNNs).

MP-GNN convolution has three phases: transform, aggregate, and update.

Transform: each node collects its neighbours' embeddings and the embeddings of the edges that connect them. Each triplet (hv, hu, euv), i.e., the node's own embedding, a neighbour's embedding, and the connecting edge's embedding, is concatenated and passed through an MLP (M) to obtain the transformed features (mv).

Aggregate and update: the transformed features of all neighbour triplets are aggregated with the node's own embedding; this output is passed through another neural network, and the result becomes the node's updated embedding.
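A minimal PyTorch sketch of one such convolution over an edge list (all names and sizes here are our own, illustrative choices):

    import torch
    import torch.nn as nn

    class SimpleMPNNLayer(nn.Module):
        """Transform each (node, neighbour, edge) triplet, sum the messages, then update."""
        def __init__(self, node_dim, edge_dim, hidden_dim):
            super().__init__()
            self.message_mlp = nn.Sequential(
                nn.Linear(2 * node_dim + edge_dim, hidden_dim), nn.ReLU())
            self.update_mlp = nn.Sequential(
                nn.Linear(node_dim + hidden_dim, node_dim), nn.ReLU())

        def forward(self, h, edge_index, edge_attr):
            # h: (num_nodes, node_dim); edge_index: (2, num_edges); edge_attr: (num_edges, edge_dim)
            src, dst = edge_index
            # Transform: concatenate (h_v, h_u, e_uv) and pass through an MLP.
            msgs = self.message_mlp(torch.cat([h[dst], h[src], edge_attr], dim=-1))
            # Aggregate: sum the incoming messages for every destination node.
            agg = torch.zeros(h.size(0), msgs.size(-1))
            agg.index_add_(0, dst, msgs)
            # Update: combine each node's own embedding with its aggregated messages.
            return self.update_mlp(torch.cat([h, agg], dim=-1))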

Inductive approach with GraphSage:

Normal GNNs are inherently transductive: they do not handle new data, i.e., they expect the entire graph structure to be present during training, and only the embeddings of those nodes are learned. If new nodes are added at inference time, the whole graph has to be retrained.

In contrast, because the GraphSage model employs an inductive approach, it can easily generalise to previously unseen data without retraining on the test data. The GraphSage paper also introduces sampling and mini-batches for working with large graphs.

Instead of training the model just to update node embeddings, the main idea behind the GraphSage algorithm is to train a set of aggregation functions that learn the patterns behind the graph's structure using neural networks; these trained aggregation functions are then used to combine neighbourhood embeddings at inference time.

At every iteration, each node looks at its neighbours up to K hops away, and their embeddings are aggregated using learnable, permutation-invariant aggregator functions. A separate aggregator is learned for each hop distance.

The authors of the GraphSage paper experimented with different aggregator functions, including max-pool, mean aggregation, and even LSTM aggregation (LSTM is a type of recurrent neural network).

Mean aggregator: the mean of the current node's embedding and its neighbours' embeddings is passed through a neural network whose weights are learned during training.

Max-pooling aggregator: each neighbour's vector is passed through a trainable multi-layer perceptron, after which an element-wise max-pooling operation is applied to the outputs. For example, if the vector of node "A" is [10, 2, -5] and the vector of node "B" is [1, 8, 6], the max-pooling output is [10, 8, 6].

After applying the aggregation functions to each node's neighbourhood, the GraphSage model concatenates every node's current embedding with the output of the aggregation function, and this concatenated vector is passed through another multi-layer perceptron, which converts all node embeddings into fixed-size vectors. Each node vector is then normalised (to unit length), which keeps the embedding scale stable and reduces the chance of an exploding gradient problem.
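A rough PyTorch sketch of a GraphSage-style layer with a mean aggregator (the class and variable names are ours):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SAGEMeanLayer(nn.Module):
        """Aggregate neighbours with a mean, concatenate with the node itself, transform, normalise."""
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.linear = nn.Linear(2 * in_dim, out_dim)

        def forward(self, h, neighbour_lists):
            # h: (num_nodes, in_dim); neighbour_lists[i] lists node i's neighbour indices
            agg = torch.stack([h[nbrs].mean(dim=0) for nbrs in neighbour_lists])
            out = F.relu(self.linear(torch.cat([h, agg], dim=-1)))
            return F.normalize(out, p=2, dim=-1)   # unit-length embeddings

    # Example with the made-up 4-node graph used earlier
    h = torch.randn(4, 8)
    neighbours = [[1, 2], [0], [0, 3], [2]]
    layer = SAGEMeanLayer(8, 16)
    print(layer(h, neighbours).shape)              # torch.Size([4, 16])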

During training, along with updating the node embeddings, the model's main focus is learning the weights of all the MLPs used for aggregation and concatenation.

Inference:

As discussed before, one of the main advantages of the GraphSage algorithm is its ability to perform inference on unseen data. At inference time, if new nodes are added to the already-trained graph, we can directly use the trained aggregation functions to compute their embeddings. Once the vectors of these new nodes are computed, we can use them in downstream tasks such as prediction, by passing them through fully connected layers or other machine learning algorithms like random forests or XGBoost.

Another important strategy used during inference is sampling: instead of aggregating all the neighbourhood embeddings, which is expensive for very large graphs, GraphSage randomly samples a fixed number of neighbours at each distance. This reduces the time complexity of inference without much effect on accuracy.
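A tiny illustration of fixed-size neighbour sampling (the adjacency list and sample size are made up):

    import random

    adjacency = {0: [1, 2, 3, 4, 5], 1: [0], 2: [0, 3], 3: [0, 2], 4: [0], 5: [0]}
    SAMPLE_SIZE = 2

    def sample_neighbours(node):
        nbrs = adjacency[node]
        # Sample at most SAMPLE_SIZE neighbours instead of aggregating over all of them.
        return random.sample(nbrs, min(SAMPLE_SIZE, len(nbrs)))

    print(sample_neighbours(0))   # e.g., [2, 5]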

Real-world use case:

One of the downstream tasks for graph neural networks is link prediction. The aim of this task is to predict the probability of a link between a new inference node and the existing nodes in the graph, and the GraphSage algorithm can be used for it. The StellarGraph library in Python provides an easy-to-use framework for training GraphSage models. Link prediction is very useful for building recommendation engines; companies like UberEATS and Pinterest use the GraphSage algorithm.

For example, consider the UberEATS scenario, where a graph is built with users and restaurants as nodes: two restaurants are connected by an edge if they are similar, two users are connected if their orders are similar, and a user and a restaurant are connected if the user has ordered from that restaurant. To recommend restaurants to a new user, we can use link prediction: during inference, if our model predicts a link between the user and a restaurant, we recommend that restaurant to the user.

The first figure gives a brief overview of the link-prediction task, while the second shows how link embeddings are generated from node embeddings and how a fully connected neural network is used for the link-prediction task.

Initial embeddings of users and restaurants are created using domain knowledge. To address the link-prediction task, we train a GraphSage model on a set of labelled negative (no link) and positive (link) user-restaurant node pairs. During training, the GraphSage model learns aggregation functions that improve the node embeddings so that, when we use them for link prediction, we get accurate results.

To generate link embeddings, the node embeddings of the two endpoints are concatenated, as shown in the picture above. Once link embeddings are generated from the positive and negative training samples, the model uses a fully connected layer to perform the prediction. The GraphSage model and the fully connected layers are trained together using a loss function such as binary cross-entropy; this teaches the model to predict whether a link should be formed between two nodes.
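A minimal sketch of such a link-prediction head in PyTorch (everything below, including the embedding size and random inputs, is an illustrative assumption; in practice the node embeddings would come from the GraphSage layers):

    import torch
    import torch.nn as nn

    class LinkPredictor(nn.Module):
        """Concatenate two node embeddings and predict the probability of a link."""
        def __init__(self, embed_dim):
            super().__init__()
            self.classifier = nn.Sequential(
                nn.Linear(2 * embed_dim, 32), nn.ReLU(), nn.Linear(32, 1))

        def forward(self, h_user, h_restaurant):
            link_embedding = torch.cat([h_user, h_restaurant], dim=-1)
            return self.classifier(link_embedding)           # raw logit

    predictor = LinkPredictor(embed_dim=16)
    loss_fn = nn.BCEWithLogitsLoss()

    h_user, h_rest = torch.randn(8, 16), torch.randn(8, 16)  # 8 sampled (user, restaurant) pairs
    labels = torch.randint(0, 2, (8, 1)).float()             # 1 = link (positive), 0 = no link
    loss = loss_fn(predictor(h_user, h_rest), labels)        # binary cross-entropy loss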

During inference, when we want to recommend restaurants to a new user, we add that user to the existing graph and predict links between the new user and the candidate restaurants. If the model predicts a likely link between the user and a restaurant, we can recommend that restaurant to the new user.

Attention in graphs

The incorporation of an attention mechanism has been a significant milestone for neural networks in recent years. Before attention, neural networks performed poorly on long input sequences because they could not recognise the links between different parts of the input. For example, while translating the phrase "How is your day going?" to French, "Comment se passe ta journée ?", a model with an attention mechanism can understand that the word "how" refers to "comment," and "day" refers to "journée." By recognising these kinds of references, the model can learn the inherent patterns of the language. Attention mechanisms are mostly used in a type of neural network called the Transformer.

To briefly explain the mechanism behind "attention": during training, these models learn weights that transform the word vectors into "query" and "key" vectors; every word in the sentence gets a query and a key vector. To find the attention between the words of a sentence, the dot product between a word's query and every other word's key is calculated. A useful property of the dot product is that it is large when two vectors point in similar directions and close to zero when they are unrelated. This helps the model learn the relationships between words, i.e., more relevant words get more attention.

A softmax function is then used to scale the attention scores so that they sum to one. The important point to notice here is that the weights used to generate the query and key vectors are learnable parameters, and they are improved during training.
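A tiny, hedged sketch of this query/key scoring in PyTorch (the matrices here are random stand-ins for learned weights):

    import torch

    torch.manual_seed(0)
    word_vectors = torch.randn(5, 8)        # 5 words, 8-dimensional embeddings
    W_q = torch.randn(8, 8)                 # stand-in for the learned query projection
    W_k = torch.randn(8, 8)                 # stand-in for the learned key projection

    queries, keys = word_vectors @ W_q, word_vectors @ W_k
    scores = queries @ keys.T               # dot product between every query and every key
    attention = torch.softmax(scores, dim=-1)

    print(attention.sum(dim=-1))            # each row sums to 1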

Graph attention networks (GATs) use this attention mechanism to score the relevance of each neighbour with respect to the current node. GATs extend the aggregation function used in normal GNNs.

Before aggregating the neighbour vectors, they are multiplied by attention scores (the alphas in the picture above). Each neighbour's attention is calculated by feeding the source-neighbour pair (query-key, as previously discussed) to a self-attention layer: e_ij = a([W·hi || W·hj]), where "a" is a learnable weight vector and "||" denotes concatenation. The attention scores are then passed through a softmax, so that the attention scores of all of a node's neighbours sum to one. For example, if a node has three neighbours with attention scores of [0.7, 0.2, 0.1], the neighbour with the score 0.7 receives the most attention.

After getting attention scores for all neighbourhoods, normal aggregation functions are used to complete the graph convolution process.
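A hedged, single-head sketch of this GAT-style scoring and aggregation in PyTorch (names and dimensions are illustrative choices of ours):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SingleHeadGATLayer(nn.Module):
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.W = nn.Linear(in_dim, out_dim, bias=False)   # shared projection
            self.a = nn.Linear(2 * out_dim, 1, bias=False)    # learnable attention vector "a"

        def forward(self, h, neighbour_lists):
            z = self.W(h)                                      # project every node
            out = []
            for i, nbrs in enumerate(neighbour_lists):
                zi = z[i].expand(len(nbrs), -1)
                e = F.leaky_relu(self.a(torch.cat([zi, z[nbrs]], dim=-1)))  # e_ij scores
                alpha = torch.softmax(e, dim=0)                # neighbourhood scores sum to 1
                out.append((alpha * z[nbrs]).sum(dim=0))       # attention-weighted aggregation
            return F.elu(torch.stack(out))

    h = torch.randn(4, 8)
    neighbours = [[1, 2], [0], [0, 3], [2]]
    print(SingleHeadGATLayer(8, 16)(h, neighbours).shape)      # torch.Size([4, 16])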

Conclusion:

In recent years, graph neural networks have been trending mainly in the fields of chemistry and biotechnology, where they are used for machine-learning tasks on chemical compounds and protein structures.

Another interesting concept is self-supervised learning, where a model learns from unlabelled datasets. For example, part of the graph is masked during training, and the model's task is to predict the missing part.

If you've made it this far, we hope you have developed a basic understanding of, and some interest in, graph neural networks. With that, we conclude this blog.
