An Illustrated Guide to Graph Neural Networks

A breakdown of the inner workings of GNNs…

Nickelback approves of Graph Deep Learning 😎

❓ Graph, who?

A graph is a data structure comprising nodes (vertices) connected by edges, representing information with no definite beginning or end. The nodes occupy arbitrary positions in space; when plotted in 2D (or even nD) space, nodes with similar features tend to cluster together.

This is a graph: a bunch of interconnected nodes that represent entities.
Nodes represent users while edges represent the connection/relationship between two entities. Social media graphs are usually a whole lot more enormous and complex!
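
If it helps to see this in code, here is a minimal sketch of a graph as plain Python data structures. The node names, features, and edges are made up purely for illustration:

```python
# A toy social graph: each node carries a feature dictionary...
nodes = {
    "alice": {"age": 29, "country": "US"},
    "bob":   {"age": 34, "country": "UK"},
    "carol": {"age": 25, "country": "US"},
}

# ...and each edge is a pair of node ids (an adjacency list).
edges = [("alice", "bob"), ("bob", "carol"), ("carol", "alice")]

# A node's neighbours are simply the nodes it shares an edge with.
def neighbours(node):
    return [v for u, v in edges if u == node] + [u for u, v in edges if v == node]

print(neighbours("alice"))  # ['bob', 'carol']
```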

📑 What you need to know

I’ll be using a few concepts right off the bat: recurrent units, embedding vector representations, and feed-forward neural networks. It also helps to know a fair bit of graph theory (as in, what a graph is and what it looks like).

🚪 Enter Graph Neural Networks

Each node has a set of features defining it. In the case of a social network graph, these could be age, gender, country of residence, political leaning, and so on. Each edge connects a pair of nodes, often ones with similar features, and represents some kind of interaction or relationship between them.

It’s the same graph from above, with each node now converted into a recurrent unit and each edge into a feed-forward neural network.
The order the nodes are drawn in doesn’t really matter.
The envelopes are simply the one-hot-encoded (embedding) vectors held by each node (now a recurrent unit).
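
If you want to see that concretely, here is a tiny hypothetical sketch of giving each node a one-hot vector as its initial embedding (the names and sizes are my own):

```python
import numpy as np

node_ids = ["alice", "bob", "carol"]

# One-hot "envelope" for each node: row i of the identity matrix.
embeddings = {
    node: np.eye(len(node_ids))[i]
    for i, node in enumerate(node_ids)
}

print(embeddings["bob"])  # [0. 1. 0.]
```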

📮 Message Passing

Once the conversion of nodes and edges is complete, the graph performs Message Passing between the nodes. This process is also called Neighbourhood Aggregation because it involves pushing messages (i.e., the embeddings) from surrounding nodes to a given reference node through the directed edges.

The violet square is a simple feed-forward NN applied to the embeddings (white envelopes) from the neighbouring nodes. The recurrent function (red triangle) is applied to the current embedding (white envelope) and the sum of the edge neural network outputs (black envelopes) to obtain the new embedding (white envelope prime).
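
Here is a minimal PyTorch sketch of one such Message Passing step. The names and sizes are my own assumptions: edge_nn stands in for the violet feed-forward network applied along the edges, and a GRU cell plays the role of the red recurrent function:

```python
import torch
import torch.nn as nn

dim = 8  # embedding size; arbitrary for this sketch

edge_nn = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())  # the violet square
recurrent = nn.GRUCell(dim, dim)                         # the red triangle

def message_passing_step(h, adjacency):
    """One round of Neighbourhood Aggregation.

    h: (num_nodes, dim) tensor of current embeddings (white envelopes)
    adjacency: dict mapping a node index to the indices of its neighbours
    """
    new_h = torch.empty_like(h)
    for v, neighbour_ids in adjacency.items():
        # Push each neighbour's embedding through the edge network,
        # then sum the resulting messages (black envelopes).
        messages = edge_nn(h[neighbour_ids]).sum(dim=0, keepdim=True)
        # Combine the summed messages with v's current embedding to get
        # the updated embedding (white envelope prime).
        new_h[v] = recurrent(messages, h[v].unsqueeze(0)).squeeze(0)
    return new_h

h = torch.randn(3, dim)                        # 3 nodes, random initial embeddings
adjacency = {0: [1, 2], 1: [0, 2], 2: [0, 1]}  # a little triangle graph
h = message_passing_step(h, adjacency)
```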

🤔 What do I do with the final vector representations?

Once you perform the Neighbourhood Aggregation/Message Passing procedure a few times, you obtain a completely new set of embeddings for each nodal recurrent unit.

Here’s the final graph with the fully updated node embedding vectors after n repetitions of Message Passing. You can take the representations of all the nodes and sum them together to get H.
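
Continuing the sketch from the previous section, the repetition and the final sum might look like this (n is just a hyperparameter you choose):

```python
n = 4  # number of Message Passing rounds

# Repeat Neighbourhood Aggregation so information travels up to n hops.
for _ in range(n):
    h = message_passing_step(h, adjacency)

# Readout: sum the final node embeddings into one graph-level vector H.
H = h.sum(dim=0)  # shape: (dim,)
```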

📝 Graph Neural Networks, a summary

GNNs are fairly simple to use. In fact, implementing one involves just four steps (each of which is sketched in code above).

  1. Given a graph, we first convert the nodes to recurrent units and the edges to feed-forward neural networks.
  2. Then we perform Neighbourhood Aggregation (Message Passing, if that sounds better) for all nodes n times.
  3. Then we sum over the embedding vectors of all nodes to get the graph representation H.
  4. Feel free to pass H into higher layers or use it to represent the graph’s unique properties, as in the sketch below!
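
As a hypothetical example of step 4, reusing H and dim from the sketches above, a “higher layer” could be as simple as a linear head that classifies the whole graph:

```python
import torch.nn as nn

num_classes = 2                     # made-up task: binary graph classification
head = nn.Linear(dim, num_classes)  # a "higher layer" on top of H
logits = head(H)                    # per-class scores for the entire graph
```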

🙇🏻‍♂️ Why Graph Neural Networks?

Now that we know how Graph Neural Networks work, why would we want to use them in the first place?

🔩 In a nutshell

Here, we covered the basics of Graph Neural Networks with a bunch of visualisations. Graph DL is really interesting, and I urge you to try coding up your own implementation. There are excellent graph deep learning libraries like the Deep Graph Library (DGL) and PyTorch Geometric to help you get started.
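
To give a flavour, here is a minimal PyTorch Geometric sketch. Note that GCNConv is a graph-convolutional flavour of Message Passing rather than the exact recurrent scheme illustrated above, and all the sizes here are arbitrary:

```python
import torch
from torch_geometric.nn import GCNConv, global_add_pool

# Toy graph: 3 nodes with 4 features each; edges as a (2, num_edges) index tensor.
x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])  # directed edges: 0->1, 1->2, 2->0

conv = GCNConv(4, 8)                    # one round of (convolutional) message passing
h = conv(x, edge_index)

batch = torch.zeros(3, dtype=torch.long)  # all three nodes belong to graph 0
H = global_add_pool(h, batch)             # sum-readout: shape (1, 8)
```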

❤️ Love talking tech?

You’re in luck! I love talking about Artificial Intelligence, Data Science, and the progress of science and technology in general. If you want to chat, you can catch me procrastinating on Twitter and LinkedIn.
