A Journey into the Land of Graph Neural Networks — Part I

Greg Kantor
dunnhumby Science blog
May 13, 2024

About a year ago I started working on GNNs (Graph Neural Networks). They are a fascinating subject in their own right, and I would like to take you on a journey of discovery into how they might enhance the world of machine learning and potentially beyond. I will share some of my experiences from applying them, along with a review of what I have observed in my limited time working with them.

Graphs In a Nutshell

The best place to start is perhaps with a lightning-fast introduction to graphs and graph neural networks. Both are vast topics, and I will not have the space or time to cover everything; as Bilbo Baggins put it:

“I don’t know half of you half as well as I should like; and I like less than half of you half as well as you deserve.”
— J.R.R. Tolkien

In this section I will introduce the basic components of graphs and why they are relevant, and then describe some of the algorithms currently used to tackle problems involving them.

Graphs, while they may sound menacing in theory, are actually something we are all exposed to daily without thinking about it explicitly. Social networks, natural language, images, retail data and many more are all examples of graph data. They are so different from one another; how could this possibly be true?

A graph fundamentally conveys a relationship, or link, between objects of any type. These objects can be anything you can think of: people, words, pixels, customers and products, to bring up the examples from before. More formally, the objects are called nodes and the links are called edges, and every graph is made up of these two components.

Graphs form a web of interconnected nodes.

Now that we have defined the fundamental building blocks of graphs, in good data scientist fashion we should ask where the data will come from, as well as how to build the algorithm. The paradigm shift that graphs introduce compared with regular tabular data is that, while you can still assign features to parts of the graph, the graph structure itself is now part of the data and can be considered a feature (to be revealed soon)!

  • Node features — this would be data attached to the nodes themselves such as: username, pixel colour, word length, customer spending or product information.
  • Edge features — this type of data describes the relationship between the nodes such as: number of messages sent, distance between pixels, semantic similarity between words, purchase quantity etc.
  • Subgraph features — here we are selecting a subset of nodes which form a graph of their own, which is a subgraph of the whole. It is possible to attach features to a subgraph too, such as: number of members of a friend group, length of a sentence in a text, number of transactions at a store.
  • Graph features — at this level we are giving features to the entirety of the graph. This could be information like: total number of people in the graph, total number of words in the text, total number of customers, products or stores in the graph.
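To make the four feature levels concrete, here is a minimal sketch of a container that keeps node, edge, subgraph and graph features side by side. The class name `FeatureGraph`, its methods and the social-network values are hypothetical, chosen only for illustration; real GNN libraries use their own, usually tensor-based, formats.

```python
from dataclasses import dataclass, field


@dataclass
class FeatureGraph:
    """Toy container holding node, edge, subgraph and graph features side by side."""

    node_features: dict = field(default_factory=dict)      # node -> {feature: value}
    edge_features: dict = field(default_factory=dict)      # (u, v) -> {feature: value}
    subgraph_features: dict = field(default_factory=dict)  # name -> (frozenset of nodes, {feature: value})
    graph_features: dict = field(default_factory=dict)     # feature -> value

    def add_node(self, node, **features):
        self.node_features[node] = features

    def add_edge(self, u, v, **features):
        self.edge_features[(u, v)] = features

    def add_subgraph(self, name, nodes, **features):
        self.subgraph_features[name] = (frozenset(nodes), features)

    def neighbours(self, node):
        # Edges are treated as undirected when looking up neighbours.
        result = set()
        for u, v in self.edge_features:
            if u == node:
                result.add(v)
            elif v == node:
                result.add(u)
        return result


# Hypothetical social-network example.
g = FeatureGraph()
for name in ["alice", "bob", "carol", "dave"]:
    g.add_node(name, username=name)                            # node feature
g.add_edge("alice", "bob", messages_sent=12)                   # edge features
g.add_edge("bob", "carol", messages_sent=3)
g.add_edge("carol", "dave", messages_sent=7)
friends = {"alice", "bob", "carol"}
g.add_subgraph("friend_group", friends, members=len(friends))  # subgraph feature
g.graph_features["total_people"] = len(g.node_features)        # graph feature

print(sorted(g.neighbours("bob")))  # ['alice', 'carol']
```

Keeping the edge list as the single source of truth means the neighbourhood, which is the "graph as a feature" part, is always derived from the same data the edge features hang off.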

Here’s a small example of a graph constructed from some basic retail data coming from customers purchasing products.

A table showing two customers purchasing multiple products. Each row records a single purchase in isolation.
It is the same data, but represented as a graph. The interactions between customers and products immediately become visible, and a more global picture begins to form, rather than a focus on individual purchases.
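The table-to-graph conversion can be sketched in a few lines of Python. The transactions below are hypothetical toy rows (not the data from the figure), and merging repeated purchases into a single edge weighted by total quantity is an assumption; keeping one edge per purchase would be equally valid.

```python
from collections import defaultdict

# Hypothetical purchase rows: (customer, product, quantity).
transactions = [
    ("customer_1", "bread", 2),
    ("customer_1", "milk", 1),
    ("customer_2", "milk", 3),
    ("customer_2", "apples", 6),
    ("customer_1", "bread", 1),
]


def build_bipartite_graph(rows):
    """Collapse purchase rows into customer-product edges weighted by total quantity."""
    edge_quantity = defaultdict(int)
    for customer, product, quantity in rows:
        edge_quantity[(customer, product)] += quantity
    customers = sorted({c for c, _ in edge_quantity})
    products = sorted({p for _, p in edge_quantity})
    return customers, products, dict(edge_quantity)


customers, products, edges = build_bipartite_graph(transactions)

# A product bought by several customers is a node with several neighbours,
# the kind of cross-customer link that a row-by-row table hides.
degree = {p: sum(1 for (_, prod) in edges if prod == p) for p in products}
shared_products = [p for p in products if degree[p] > 1]

print(edges[("customer_1", "bread")])  # 3
print(shared_products)                 # ['milk']
```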

Since we already have our data scientist hats on and have considered where we could potentially insert our data, what exactly are we using the data for? There are quite a few tasks one could use graphs for, the most popular being:

  • Node classification — placing nodes into different categories, such as predicting user types, word types (verbs, adjectives, etc.) or categories of research papers.
  • Edge prediction — here we are presented with a graph and try to predict whether certain edges exist between nodes. Examples include friend requests between users and customers purchasing products.
  • Subgraph classification — the task here is to place a subset of the graph (including nodes and edges) into one of a set of given categories. For example: labelling objects given a set of pixels, identifying important functional groups in organic chemistry, or identifying customer/product hierarchies in a market.
  • Graph classification — this is the highest-level task, where our algorithm must place the entire graph into a category. This could be: rating a book, labelling an entire image, categorising a molecule, etc.
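Once a GNN has produced node embeddings, these tasks differ mainly in what is read out of them. The sketch below assumes hypothetical two-dimensional embeddings and a hand-picked linear classifier: node classification scores a single embedding, edge prediction uses a dot-product decoder (one common choice among several), and graph classification first mean-pools all embeddings. Subgraph classification would work like graph classification, pooling over the subgraph's nodes only.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def node_class(embedding, weights):
    """Node classification: one linear score per class, then pick the highest."""
    scores = [dot(row, embedding) for row in weights]
    return max(range(len(scores)), key=scores.__getitem__)


def edge_probability(h_u, h_v):
    """Edge prediction with a dot-product decoder squashed through a sigmoid."""
    return 1.0 / (1.0 + math.exp(-dot(h_u, h_v)))


def mean_pool(embeddings):
    """Average a list of embeddings into a single vector."""
    n = len(embeddings)
    return [sum(column) / n for column in zip(*embeddings)]


def graph_class(embeddings, weights):
    """Graph classification: pool all node embeddings, then classify the pooled vector."""
    return node_class(mean_pool(embeddings), weights)


# Hypothetical embeddings and classifier weights (two classes).
H = {"a": [1.0, 0.0], "b": [0.8, 0.2], "c": [0.0, 1.0]}
W = [[1.0, -1.0], [-1.0, 1.0]]

print(node_class(H["a"], W), node_class(H["c"], W))  # 0 1
print(edge_probability(H["a"], H["c"]))              # 0.5
print(graph_class(list(H.values()), W))              # 0
```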

If you would like to see a more involved, interactive showcase of all of these concepts and applications, I recommend this article [2]. It even contains an example where you can build a molecule and have a model classify it as pungent or not pungent.

A Preliminary Look at Algorithms

Now that we have our fuel and a road to go on, we only have to find a suitable vehicle to take us there. Over the years many excellent papers have been written about GNN algorithms, and a full review of the field's history is well beyond the scope of this article, but I will try to present some models which I found fascinating. For an in-depth review, I recommend this paper [1], from which I will be commandeering some excellent diagrams.

Just a taste of the Universe of Graph Neural Networks [1].

The idea from the diagram above that I would like to highlight is how the models are split into different categories. Each category reflects how a specific part of a GNN is constructed, and how that differentiates it from other algorithms that could be looking at the same graph. Let us look at each of these modules:

  • Propagation module — this part of the network describes how information moves between the nodes. The aim of this module is to capture features from multiple nodes, as well as the topology/layout of the graph. A common way of doing this is aggregating over the neighbours of nodes.
  • Sampling module — this part of the algorithm describes how to deal with a graph when it becomes very large. It is usually infeasible to aggregate information over all nodes, so a representative sample of nodes is selected and used for the calculations instead.
  • Pooling module — in this part we use the representations of the nodes to infer the representations of some of the higher level elements of the graph such as representations of subgraphs or the graph itself.
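One widely used sampling strategy, popularised by GraphSAGE, is to sample a fixed number of neighbours (a "fan-out") at each hop instead of visiting all of them. The sketch below assumes an adjacency-list dictionary, a hypothetical graph and hypothetical fan-outs; it illustrates one sampling scheme, not every approach covered in the review.

```python
import random


def sample_neighbourhood(adjacency, target, fanouts, seed=0):
    """Sample at most fanouts[i] neighbours per node at hop i + 1.

    adjacency: dict mapping node -> list of neighbouring nodes
    Returns a list of node sets; entry 0 is {target}, entry i holds the nodes sampled at hop i.
    """
    rng = random.Random(seed)
    hops = [{target}]
    for fanout in fanouts:
        frontier = set()
        for node in sorted(hops[-1]):  # sorted so a fixed seed gives a repeatable sample
            neighbours = adjacency.get(node, [])
            frontier.update(rng.sample(neighbours, min(fanout, len(neighbours))))
        hops.append(frontier)
    return hops


# Hypothetical graph: a hub (node 0) linked to ten nodes, each of which also
# points to the next node around a ring (one direction only, to keep it small).
adjacency = {0: list(range(1, 11))}
for i in range(1, 11):
    adjacency[i] = [0, i % 10 + 1]

hops = sample_neighbourhood(adjacency, target=0, fanouts=[3, 2])
print(len(hops[1]))  # 3: only three of the hub's ten neighbours are kept
```

Each target's computation then touches only the sampled nodes, so the work per target is bounded by the product of the fan-outs (here at most 3 × 2 nodes at hop 2) rather than by the full neighbourhood.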

The exact nature of these “representations” will be discussed later; they are latent variables, or embeddings, which condense information about the structure of the graph and its features. As promised, I will next point out two algorithms which I found very influential in my research.

The Least but also the Most Convoluted Graph Network

Many of you may be wondering how these aggregation procedures actually work. What are the inputs and outputs of such a step, and what is the final aim?

To answer simply: the inputs (initially) are the features of the objects we are aggregating over, and the output is the representation, or embedding, of the node or object of interest. Embeddings will be tackled later in this review, but for now it suffices to know that they are condensed vectors of information about a specific object. (In the meantime, I can recommend this read.)

In the simplest case, we aggregate over the neighbouring nodes of a target node, and the aggregation can be a simple mean over the neighbours' features. This type of GNN architecture is known as a Graph Convolutional Network (GCN).

The aggregation process of a GCN [3].
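Since the figure itself is not reproduced in this text, here is the mean-aggregation GCN update written out as we understand it from [3]. The symbols are our rendering and may differ slightly from the diagram: x_v is the raw feature vector of node v, N(v) its set of neighbours, W^(k) and B^(k) the learnable parameters of step k, f^(k) a nonlinearity, and PREDICT the final prediction model.

```latex
\begin{aligned}
h_v^{(0)} &= x_v \quad \text{for all } v \in V, \\
h_v^{(k)} &= f^{(k)}\!\left( W^{(k)} \cdot \frac{\sum_{u \in \mathcal{N}(v)} h_u^{(k-1)}}{\lvert \mathcal{N}(v) \rvert} + B^{(k)} \cdot h_v^{(k-1)} \right) \quad \text{for } k = 1, \dots, K, \\
\hat{y}_v &= \mathrm{PREDICT}\big(h_v^{(K)}\big).
\end{aligned}
```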

The equations presented above show mathematically what happens to the input data. Since this acts as a layer of a neural network, it has a set of learnable weights and biases attached to it. The nodes’ features are propagated through this layer, and we arrive at our destination: the node embeddings. Or do we? Some of you may have noticed the superscript (k) appearing throughout the equations. This is because there is no reason to aggregate over only the nearest neighbours. We can start much further away and recursively repeat the step above, each time aggregating over nearest neighbours, until we reach the target node. This gives a tuneable parameter that controls how many edges away from the node of interest the GNN aggregates. Note that the process is recursive: with multiple layers (i.e. starting further away), each iteration uses the output of the previous one instead of the raw node features (which is also why the first iteration is defined as the raw node features).
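The recursive update described above can be sketched in plain Python. This is a minimal sketch assuming a mean over neighbours plus a separate term for the node's own previous state (following the formulation in [3]), ReLU as the nonlinearity, and random matrices standing in for learned weights; all helper names and the toy graph are hypothetical.

```python
import random


def matvec(matrix, vec):
    return [sum(m * x for m, x in zip(row, vec)) for row in matrix]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def gcn_layer(adjacency, h, W, B):
    """One aggregation step: h_v <- ReLU(W . mean_{u in N(v)} h_u + B . h_v)."""
    new_h = {}
    for v, h_v in h.items():
        neighbours = adjacency.get(v, [])
        if neighbours:
            mean = [sum(h[u][i] for u in neighbours) / len(neighbours) for i in range(len(h_v))]
        else:
            mean = [0.0] * len(h_v)
        pre_activation = [a + b for a, b in zip(matvec(W, mean), matvec(B, h_v))]
        new_h[v] = [max(0.0, x) for x in pre_activation]
    return new_h


def gcn(adjacency, features, hidden_dim, num_layers, seed=0):
    """Stack num_layers steps; each step reaches one more hop away from every node."""
    rng = random.Random(seed)
    h = dict(features)  # iteration 0 uses the raw node features
    for _ in range(num_layers):
        in_dim = len(next(iter(h.values())))
        W = random_matrix(hidden_dim, in_dim, rng)  # stand-in for learned neighbour weights
        B = random_matrix(hidden_dim, in_dim, rng)  # stand-in for learned self weights
        h = gcn_layer(adjacency, h, W, B)  # later steps reuse the previous output
    return h


# Hypothetical path graph 0 - 1 - 2 - 3 with two raw features per node.
adjacency = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0], 3: [0.5, 0.5]}
embeddings = gcn(adjacency, features, hidden_dim=4, num_layers=2)
print({v: [round(x, 3) for x in e] for v, e in embeddings.items()})
```

Increasing `num_layers` widens the receptive field by one hop per layer. Keeping separate matrices for the neighbour mean (`W`) and the node's own state (`B`) lets the model treat the two sources of information differently.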

NB: The number of neighbours can grow very quickly, potentially exponentially, as you move further away from the target node, which can significantly slow down your algorithm or fill up your memory.
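To see how quickly the neighbourhood grows, the breadth-first search below counts the nodes within k hops of a target. The tree in which every node has three children is a hypothetical example chosen to make the growth easy to follow: the number of newly reached nodes triples with each extra hop.

```python
from collections import deque


def k_hop_counts(adjacency, target, max_hops):
    """For k = 1..max_hops, count distinct nodes at most k edges from target (target excluded)."""
    distance = {target: 0}
    queue = deque([target])
    while queue:
        node = queue.popleft()
        if distance[node] == max_hops:
            continue  # no need to expand beyond the last hop
        for neighbour in adjacency.get(node, []):
            if neighbour not in distance:
                distance[neighbour] = distance[node] + 1
                queue.append(neighbour)
    return [sum(1 for d in distance.values() if 0 < d <= k) for k in range(1, max_hops + 1)]


def ternary_tree(depth):
    """Hypothetical undirected tree in which every node has three children."""
    adjacency, frontier, next_id = {}, [0], 1
    for _level in range(depth):
        new_frontier = []
        for parent in frontier:
            for _ in range(3):
                child, next_id = next_id, next_id + 1
                adjacency.setdefault(parent, []).append(child)
                adjacency.setdefault(child, []).append(parent)
                new_frontier.append(child)
        frontier = new_frontier
    return adjacency


print(k_hop_counts(ternary_tree(4), target=0, max_hops=4))  # [3, 12, 39, 120]
```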

Once the desired node representation has been aggregated, the network performs a final step: passing the result through an MLP (multi-layer perceptron). This maps the embedding to whatever output the problem at hand requires, such as class scores for node classification. Node embeddings and edge embeddings will be discussed in greater detail later on in the article.
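The final prediction step can be as small as a two-layer perceptron. This sketch assumes a ReLU hidden layer and a softmax over classes for node classification; the embedding and weights are hypothetical, hand-picked values rather than trained ones.

```python
import math


def mlp_head(embedding, W1, b1, W2, b2):
    """Two-layer perceptron mapping an embedding to class probabilities."""
    hidden = [max(0.0, sum(w * x for w, x in zip(row, embedding)) + b) for row, b in zip(W1, b1)]
    logits = [sum(w * x for w, x in zip(row, hidden)) + b for row, b in zip(W2, b2)]
    top = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical 2-d node embedding and hand-picked weights (3 hidden units, 2 classes).
embedding = [1.0, 2.0]
W1, b1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, -1.0]
W2, b2 = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [0.0, 0.0]
probs = mlp_head(embedding, W1, b1, W2, b2)
print([round(p, 3) for p in probs])  # [0.269, 0.731]
```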

Graph Attention is All You Need

Now that we have seen a relatively simple example of how graph neural networks aggregate features from nodes, we can introduce a more sophisticated one that often works better in practice, because it generalises the GCN by assigning a learned, input-dependent weight to each neighbour.

The GAT (Graph Attention Network) works the same way as the GCN in spirit, but instead of simply averaging over the neighbours’ features or embeddings, it uses the attention mechanism to allocate weights to the neighbours and then performs a weighted average. (If you want to brush up on the attention mechanism, here is a great place to start.)

The aggregation process of a GAT [3].
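As the figure from [3] is not reproduced here, the following is the single-head formulation from the original GAT paper by Veličković et al., which may use different notation from the diagram. W is a shared learnable weight matrix, a a learnable attention vector, ∥ denotes concatenation, N(v) the neighbours of v and σ a nonlinearity.

```latex
\begin{aligned}
e_{vu} &= \mathrm{LeakyReLU}\!\left( a^{\top} \left[ W h_v \,\Vert\, W h_u \right] \right), \\
\alpha_{vu} &= \frac{\exp(e_{vu})}{\sum_{w \in \mathcal{N}(v) \cup \{v\}} \exp(e_{vw})}, \\
h_v' &= \sigma\!\left( \sum_{u \in \mathcal{N}(v) \cup \{v\}} \alpha_{vu} \, W h_u \right).
\end{aligned}
```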

The use of attention weights in the aggregation process allows the neural network to learn which nodes to focus on, which can lead to much better absorption of the relevant information. After all, not all nodes are equally important for computing the embedding of a specific target node. Apart from the attention mechanism, everything proceeds the same way as in the GCN.
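A single attention head can be sketched in plain Python. It follows the standard GAT scoring of Veličković et al. (a learned vector applied to the concatenated, transformed features of the target and each neighbour, then LeakyReLU and a softmax over the neighbourhood including the node itself), with ELU as the output nonlinearity. The random weights and the small graph are hypothetical; this illustrates the idea rather than reproducing the exact model from [3].

```python
import math
import random


def matvec(matrix, vec):
    return [sum(m * x for m, x in zip(row, vec)) for row in matrix]


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def elu(x):
    return x if x > 0 else math.exp(x) - 1.0


def gat_layer(adjacency, h, W, a):
    """Single-head graph attention over each node's neighbours plus a self-loop.

    Assumes adjacency contains no self-loops; W maps features to the output size
    and a has twice that length (it scores the concatenation [W h_v || W h_u]).
    """
    z = {v: matvec(W, h_v) for v, h_v in h.items()}  # transformed features W h
    new_h, attention = {}, {}
    for v in h:
        candidates = [v] + list(adjacency.get(v, []))
        scores = [leaky_relu(sum(w * x for w, x in zip(a, z[v] + z[u]))) for u in candidates]
        top = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        alphas = [e / total for e in exps]
        attention[v] = dict(zip(candidates, alphas))
        # Attention-weighted (not uniform) average of the transformed features.
        new_h[v] = [
            elu(sum(alpha * z[u][i] for alpha, u in zip(alphas, candidates)))
            for i in range(len(z[v]))
        ]
    return new_h, attention


rng = random.Random(0)
adjacency = {0: [1, 2], 1: [0], 2: [0]}  # hypothetical star graph
h = {v: [rng.uniform(-1, 1), rng.uniform(-1, 1)] for v in adjacency}
W = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(3)]  # 2 -> 3 features
a = [rng.uniform(-1, 1) for _ in range(6)]                      # scores [W h_v || W h_u]
new_h, attention = gat_layer(adjacency, h, W, a)
print({u: round(w, 3) for u, w in attention[0].items()})  # weights over nodes 0, 1, 2
```

If all attention scores were equal, the weights would collapse to a uniform average over the node and its neighbours, which is essentially the GCN-style mean again; learning the scores is what lets the model favour some neighbours over others.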

Extensions of the Base Algorithms

The algorithms I presented above have been very influential, but in their basic form they are quite rudimentary compared with what they can achieve. Some basic extensions where a user might already see an improvement are:

  • It is a good idea to concatenate the features of the edge connecting each neighbour to the target node with that neighbour’s node features, and use the result as the full feature vector. This is just one direction in which the idea can be expanded.
  • Another possibility arises when considering multiple layers. Instead of just using the embeddings given by the previous layer on their own, concatenate the original node features to the embeddings and pass them through a fully connected layer to make them the same size as the desired input into the next layer.
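Both extensions can be sketched with plain lists. The helper names, the toy customer-product graph (with purchase quantity as the single edge feature) and the random projection weights are hypothetical; the dense layer is shown without a bias term for brevity.

```python
import random


def matvec(matrix, vec):
    return [sum(m * x for m, x in zip(row, vec)) for row in matrix]


def aggregate_with_edge_features(adjacency, h, edge_features, edge_dim):
    """Mean over neighbours u of [h_u || e_uv], so edge data travels with each neighbour.

    edge_features maps (neighbour, target) -> list of edge features of length edge_dim.
    """
    out = {}
    for v in h:
        rows = [h[u] + edge_features[(u, v)] for u in adjacency.get(v, [])]
        if rows:
            out[v] = [sum(column) / len(rows) for column in zip(*rows)]
        else:
            out[v] = [0.0] * (len(h[v]) + edge_dim)
    return out


def concat_and_project(embedding, original_features, W):
    """Append the raw node features to a layer's embedding, then resize with a dense layer W."""
    return matvec(W, embedding + original_features)


# Hypothetical customer (0) linked to two products (1, 2); the edge feature is purchase quantity.
adjacency = {0: [1, 2], 1: [0], 2: [0]}
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
edge_features = {(1, 0): [5.0], (2, 0): [3.0], (0, 1): [5.0], (0, 2): [3.0]}

aggregated = aggregate_with_edge_features(adjacency, h, edge_features, edge_dim=1)
print(aggregated[0])  # [0.5, 1.0, 4.0]

rng = random.Random(0)
W = [[rng.uniform(-1, 1) for _ in range(3 + 2)] for _ in range(2)]  # (3 + 2) inputs -> 2 outputs
next_input = concat_and_project(aggregated[0], h[0], W)
print(len(next_input))  # 2
```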

This is the end of Part I of this article. Here I have tried to present and collect some of the interesting aspects of graphs, graph data and the neural networks which have been developed to absorb information from them. The journey continues in Part II, which will delve into a very different set of GNNs, ones that have to deal with the complexity of time series data.

References:

[1] Zhou, J., Cui, G., Hu, S., Zhang, Z., Yang, C., Liu, Z., Wang, L., Li, C. and Sun, M. (2020). Graph neural networks: A review of methods and applications. AI Open, 1, pp.57–81. doi:10.1016/j.aiopen.2021.01.001.
[2] Sanchez-Lengeling, B., Reif, E., Pearce, A. and Wiltschko, A. (2021). A Gentle Introduction to Graph Neural Networks. Distill, 6(8). doi:10.23915/distill.00033.
[3] Daigavane, A., Ravindran, B. and Aggarwal, G. (2021). Understanding Convolutions on Graphs. Distill, 6(8). doi:10.23915/distill.00032.


Research Data Scientist at dunnhumby | Theoretical Physics PhD