A comprehensive introduction to GNNs — Part 3

Nicolas Raymond · Published in Analytics Vidhya
9 min read · Feb 9, 2022

From the vanilla GNN to the Graph Attention Network (GAT)

Before you start reading

Hi dear reader! If you just hopped into this introduction to GNNs, I encourage you to have a look at the first and second parts, already released, to enjoy this publication to its full potential.

What is coming next?

This series of publications aims to cover a range of GNN topics. This part will start with the concept of node embedding and will end with the presentation of GATs.

Node embedding

Last time we talked about GNNs, I described their architectures using a composition of functions that you can see in the figure below.

GNN architecture seen as a composition of functions.

Let me quickly remind you about their purposes! Let us imagine that we give a graph-structured dataset as an input to an already trained GNN in the context of node classification. Said differently, we have a dataset forming an information network where each node has features (a vector representation) associated with it, and we want to predict the class target of one or a few nodes in the graph.

Here are the steps happening when the graph-structured dataset passes through the trained GNN:

  • (q) Each node of the graph updates its vector representation by aggregating information from its own features as well as its neighbors’ features. At this point, we refer to the vector representations of the nodes as embeddings (z) since they embed rich information about their surroundings.
  • (f) The embeddings (z) are then transformed independently to obtain encodings, which are simply more meaningful (i.e. distilled) versions of the embeddings themselves. You could compare an embedding to a list of ingredients and an encoding to a more complex recipe made out of them. Since the GNN is already trained, the function f should have carefully learnt how to combine the ingredients into a recipe that helps the function g make good classifications. Note that this part is optional; the embeddings could be used directly as encodings for the next step.
  • (g) The encodings are finally passed through a function (e.g. softmax) to get class predictions. A minimal sketch of this composition is shown below.
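To make the three steps a bit more concrete, here is a minimal NumPy sketch of the composition g(f(q(X, A))). The mean aggregation inside q and the ReLU inside f are placeholders chosen for illustration, not the specific functions of any particular architecture.

```python
import numpy as np

def q(X, A):
    """Node embedding: aggregate each node's features with its neighbors' features."""
    A_hat = A + np.eye(A.shape[0])             # include the node's own features
    deg = A_hat.sum(axis=1, keepdims=True)
    return (A_hat @ X) / deg                   # embeddings z (placeholder mean aggregation)

def f(Z, W):
    """Optional transformation of embeddings into encodings."""
    return np.maximum(Z @ W, 0)                # ReLU assumed as the non-linearity

def g(E, W_out):
    """Classification head: softmax over class scores."""
    logits = E @ W_out
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

# predictions = g(f(q(X, A), W), W_out)
```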

The first step, which is the node embedding part (q), is really what distinguishes a GNN from any other feedforward neural network. But how are node embeddings actually created? Let me help you understand the “how” part with a few figures!

Let us consider the graph in the figure below and suppose that, for each node, we would like to create an embedding that captures information from the neighbors within a 2-hop distance.

Node embedding creation concept.

In this example, if we look at node 1, we would expect it to receive information from all the nodes in the graph by the end of the node embedding process, since no node is further than 2 hops away. However, we would also want to proceed in a way that makes the direct neighbors (2 & 5) more impactful in the creation of the new node representation.

Hence, the actual way of proceeding is to iteratively update node states by aggregating their own vector representations with the vector representations of their direct neighbors. At each step, we refer to the nodes’ vector representations as their hidden states (h). For each node, the initial hidden state (h(0)) is the vector with the original features, while the last hidden state is its embedding (h(K) = z). In this particular case, K = 2 since we wanted every node to consider its 2-hop neighborhood.

Iterative hidden states updates with K=2.
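As a rough illustration of this iterative process, here is a hedged Python sketch that assumes a plain mean as the aggregation method (the actual aggregation depends on the architecture, as discussed next):

```python
import numpy as np

def node_embeddings(features, neighbors, K=2):
    """Sketch: iteratively update hidden states for K hops.

    features  : dict {node: feature vector}, the initial hidden states h(0)
    neighbors : dict {node: list of direct neighbors}
    """
    h = {v: np.asarray(x, dtype=float) for v, x in features.items()}
    for _ in range(K):
        h_new = {}
        for v in h:
            # Aggregate the node's own state with its direct neighbors' states
            stacked = np.stack([h[v]] + [h[u] for u in neighbors[v]])
            h_new[v] = stacked.mean(axis=0)   # placeholder aggregation
        h = h_new                             # all nodes update simultaneously
    return h                                  # h(K) = z, the node embeddings
```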

Following this idea, GNN architectures mainly differ in their aggregation methods during the hidden states update.

Aggregation methods seen as black boxes.

The vanilla GNN

Brace yourselves folks, this part might be the trickiest one. It is totally fine if you do not grasp everything from this section. Every concept presented here is quite general and will get a lot clearer in the Graph Convolutional Networks (GCNs) section.

The concept of GNN came with the idea of generalizing the convolution operation, well known for images, to graph-structured data.

Comparison of convolution applied to an image versus a graph.

It first introduced the idea of a function with learnable parameters, shared by all nodes, that allows them to update their hidden states. To update the hidden state of a node, this function, known as the local transition function (or message passing function), needs the current hidden state of the node itself as well as the current hidden states of its direct neighbors.

Definition of the local transition function.
Black box visualization of the local transition function.

Nonetheless, Scarselli et al. added the criterion that the local transition function must be a contraction, meaning that each time the hidden states of any pair of nodes are updated, the distance between them should stay the same or shrink. This criterion ensures some convergence properties shown a little further below.

The contraction criterion of the local transition function.

A second concept introduced was a function called the global transition function. It is a function with learnable parameters that updates all the nodes’ hidden states of a graph at once. To do so, the global transition function needs all the current hidden states and the adjacency matrix of the graph. In the vocabulary of neural networks, we could say that the global transition function is a layer of the GNN. You can see in the figure below that it is basically an efficient way of applying the local transition function to all the nodes in a single step.

Definition of the global transition function.

Since we supposed earlier that the local transition function is a contraction, the Banach fixed point theorem (see the source in the figure below) ensures that each hidden state converges to a fixed value if we apply the global transition function repeatedly. The goal would be to use these stable hidden states as node embeddings. However, in practice, we use an estimate of these stable hidden states obtained by applying the global transition function for a pre-defined number of iterations (K). Note that applying the function K times makes every node embedding carry information from its K-hop neighborhood.

Convergence property of the global transition function.
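Here is a small sketch of this fixed point idea, assuming a `global_transition` callable (a contraction) is already given; the loop simply stops once the hidden states stop moving, or after a maximum number of iterations:

```python
import numpy as np

def fixed_point_embeddings(H0, global_transition, tol=1e-6, max_iter=100):
    """Sketch of the vanilla GNN idea: repeatedly apply the global transition
    function until the hidden states (approximately) stop changing.

    H0                : initial hidden states, one row per node
    global_transition : contraction mapping H -> H' (assumed, not shown here)
    """
    H = H0
    for _ in range(max_iter):
        H_next = global_transition(H)
        if np.linalg.norm(H_next - H) < tol:   # convergence guaranteed by the
            break                              # Banach fixed point theorem
        H = H_next
    return H_next                              # approximate fixed point = embeddings
```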

The composition of functions resulting from this iterative application of the global transition function represents the function q in the vanilla GNN framework.

The vanilla GNN framework.

Graph Convolutional Networks (GCNs)

Building on the idea of the vanilla GNN, Graph Convolutional Networks (GCNs) proposed two major changes:

  • Only execute a low number of hidden state updates (2–3) so that the node embeddings only carry information from a close neighborhood.
  • Use different parameters (W) for each layer to allow the hidden state sizes to vary between layers.

In the figure below, we can see a simple way to update a node’s hidden state. It consists of taking the average of the node’s current hidden state and its direct neighbors’ current hidden states, multiplying this average by the weight parameters matrix (W), and applying a non-linear activation function (sigma) to the result.

Description of the hidden states update by average.
Identification of the local transition function.
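In code, this per-node update could look like the following sketch; the names `H`, `neighbors`, and `W` are illustrative, and ReLU stands in for the generic non-linearity sigma:

```python
import numpy as np

def update_hidden_state(i, H, neighbors, W):
    """Sketch of the average-based update for a single node i.

    H         : array of current hidden states, one row per node
    neighbors : dict {node index: list of direct neighbor indices}
    W         : learnable weight matrix of the layer
    """
    avg = np.mean(np.vstack([H[i]] + [H[j] for j in neighbors[i]]), axis=0)
    return np.maximum(avg @ W, 0)   # ReLU used as the non-linearity sigma
```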

In both figures below, we can visualize the computation steps behind the hidden state update of the node 1 in the example graph.

Hidden state computation represented with neural network architecture.
Black box representation of the hidden state update.

Now you might wonder how to efficiently update all these hidden states at once. Let me show you right away with detailed illustrations! The next figure defines the global transition function associated with the simple local transition function we just explored. There might be a bunch of mathematical symbols that can be confusing, but I will summarize the whole idea in the four following figures.

Definition of the global transition function with the basic approach.

In order to efficiently calculate the hidden states of the nodes in the graph, the first thing you need to do is to modify the original graph a tiny bit! Precisely, you need to add a directed edge from each node to itself. This consists of adding 1s on the diagonal of the original adjacency matrix (A) in order to get a modified one (A~). From this matrix, we can then calculate the inverse of the in-degree matrix (D~).

Addition of self connections to the graph.

With these two steps done, we can now consider the simple weight matrix (W) shown below to execute a single hidden states update.

Initialization of weight matrix.
Matrix multiplications in the parenthesis.
Application of a ReLU to the result.
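The same single update can be sketched in NumPy as follows. The adjacency matrix below is hypothetical (its edges, apart from node 1 being connected to nodes 2 and 5, are assumed purely for illustration), and so are the hidden state and weight matrix sizes:

```python
import numpy as np

A = np.array([[0, 1, 0, 0, 1],     # assumed adjacency matrix of a 5-node example graph
              [1, 0, 1, 0, 1],
              [0, 1, 0, 1, 0],
              [0, 0, 1, 0, 1],
              [1, 1, 0, 1, 0]], dtype=float)

A_tilde = A + np.eye(5)                           # add self-connections
D_tilde_inv = np.diag(1.0 / A_tilde.sum(axis=1))  # inverse of the in-degree matrix

H = np.random.rand(5, 3)                          # current hidden states (5 nodes, 3 features)
W = np.random.rand(3, 4)                          # weight matrix of the layer

H_new = np.maximum(D_tilde_inv @ A_tilde @ H @ W, 0)   # ReLU(D~^-1 A~ H W)
```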

Isn’t it nice and easy? The usage of the modified adjacency matrix and its associated in-degree matrix saved us from an inefficient “for loop”!

Note that there are plenty of other simple ways to aggregate the hidden states of nodes. For example, Kipf and Welling [2017] proposed a weighted average that decreases the importance of direct neighbors that have many direct neighbors themselves!

Local transition function of the main GCN architecture.
Global transition function of the main GCN architecture.
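A sketch of this layer could look like the following; ReLU is assumed here as the non-linearity, while Kipf and Welling leave the activation generic:

```python
import numpy as np

def gcn_layer(A, H, W):
    """Sketch of the Kipf & Welling (2017) propagation rule:
    H' = ReLU(D~^(-1/2) A~ D~^(-1/2) H W), where A~ = A + I.
    Neighbors with many connections of their own get down-weighted."""
    A_tilde = A + np.eye(A.shape[0])
    D_inv_sqrt = np.diag(1.0 / np.sqrt(A_tilde.sum(axis=1)))
    return np.maximum(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W, 0)
```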

Graph Attention Networks (GATs)

Now that I’ve shown you the basics of GCNs, do you have any idea of how we could modify them in order to possibly achieve greater scores without too many additional parameters?

Since most of the aggregation functions are based on a weighted average of the hidden states, Graph Attention Networks (GATs) proposed that each node should be able to learn how to attribute weights (i.e. importance) to its direct neighbors.

As you can notice in the next figure, we will denote the importance of node j according to node i with the coefficient alpha-ij.

Definition of an attention coefficient.

To provide each node with the capability of learning its neighbors’ weights, GATs use an attention mechanism in each of their layers. But how does an attention mechanism actually work?

Well, let us suppose that a node i wants to learn the importance of each of its neighbors. In order to do so, we first need to execute the following steps for each node j in the neighborhood (including the node i itself):

  • Apply the same linear transformation to the hidden states of node i and node j.
  • Concatenate both results.
  • Execute a dot product with the attention vector (a) of the layer and the concatenated results.
  • Apply a Leaky ReLU.
  • Calculate the exponential of the result.

Then, we divide all the calculated coefficients by their total sum to make sure they are all in the range [0, 1] and sum up to 1. Each of these coefficients (alpha-ij) now represents the weight of node j according to node i, as sketched in code after the next figure.

Definition of the GAT layer with the attention mechanism.
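Putting the steps above together, here is a hedged single-head sketch of how the coefficients for one node could be computed; the names are illustrative, and the real GAT layer also uses multi-head attention, which is omitted here:

```python
import numpy as np

def attention_coefficients(i, H, neighbors, W, a, negative_slope=0.2):
    """Sketch of the GAT attention mechanism for one node i (single head).

    H         : current hidden states, one row per node
    neighbors : dict {node index: list of direct neighbor indices}
    W         : shared linear transformation of the layer
    a         : attention vector of the layer
    """
    js = [i] + list(neighbors[i])                     # include node i itself
    scores = []
    for j in js:
        concat = np.concatenate([H[i] @ W, H[j] @ W])  # transform then concatenate
        e = np.dot(a, concat)                          # dot product with attention vector
        e = np.where(e > 0, e, negative_slope * e)     # Leaky ReLU
        scores.append(np.exp(e))
    scores = np.array(scores)
    alpha = scores / scores.sum()                      # normalized coefficients alpha_ij
    return dict(zip(js, alpha))
```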

In the next figures you can visualize the computation required in a simple example where the hidden state of node 1 is updated.

Visualization of the attention mechanism (LR = Leaky ReLU).
Visualization of the hidden state update.

Before you go

Thank you for reading! Feel free to visit my LinkedIn page.

Special thanks

I would like to give a special mention to Stanford’s CS224W: Machine Learning with Graphs lectures, available for free online. They helped me understand GNNs and inspired the design of some of the figures.

References

Articles

  • Kipf, Thomas N., and Max Welling. “Semi-Supervised Classification with Graph Convolutional Networks.” ArXiv:1609.02907 [Cs, Stat], Feb. 2017, http://arxiv.org/abs/1609.02907.
  • Scarselli, Franco, et al. “The Graph Neural Network Model.” IEEE Transactions on Neural Networks, vol. 20, no. 1, Jan. 2009, pp. 61–80, doi:10.1109/TNN.2008.2005605.
  • Shukla, Satish, et al. “A Generalized Banach Fixed Point Theorem.” Bulletin of the Malaysian Mathematical Sciences Society, vol. 39, no. 4, Oct. 2016, pp. 1529–39, doi:10.1007/s40840-015-0255-5.
  • Veličković, Petar, et al. “Graph Attention Networks.” ArXiv:1710.10903 [Cs, Stat], Feb. 2018, http://arxiv.org/abs/1710.10903.
