Finding shortest paths with Graph Neural Networks

David Mack
Jan 7, 2019 · 14 min read

In this article we show how a Graph Network with attention read and write can perform shortest path calculations. This network performs this task with 100% accuracy after minimal training.


Here at Octavian we believe that graphs are a powerful medium for representing diverse knowledge (for example BenevolentAI uses them to represent pharmaceutical research and knowledge).

Neural networks are a way to create functions that no human could write. They do this by harnessing the power of large datasets. On problems for which we have capable neural models, we can use example inputs and outputs to train the network to learn a function that transforms those inputs into those outputs, and hopefully generalizes to other unseen inputs.

We need to be able to build neural networks that can learn functions on graphs. Those neural networks need the right inductive biases so that they can reliably learn useful graph functions. With that foundation, we can build powerful neural graph systems.

Here we present a “Graph network with attention read and write”, a simple network that can effectively compute shortest paths. It is an example of how to combine different neural network components to make a system that readily learns a classical graph algorithm.

We present this network both as a novel system in and of itself and, more importantly, as the basis for further investigation into effective neural graph computation.

The code for this system is available in our repository.

Problem statement

Given a question “What is the length of the shortest path between station A and B?” and a graph (as a set of nodes and edges), we want to learn a function that will return an integer answer.

Related work

Machine learning on graphs is a young but growing field. For a survey of the different approaches and their history see Graph Neural Networks: A Review of Methods and Applications or our introduction. We’ve also included a list of surveys at the end of our introduction article.

The classic algorithms for calculating shortest paths are A*, Dijkstra’s and Bellman-Ford. These are robust and widely implemented. Dijkstra’s is the closest match to our use case: finding the shortest path between two specific nodes with no available path cost heuristic.
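For reference, here is a minimal sketch of the classical baseline in plain Python. It is a textbook Dijkstra implementation, not code from our repository; the function name `dijkstra_length`, the `(u, v, weight)` edge format and the toy graph are assumptions made for illustration.

```python
import heapq


def dijkstra_length(edges, source, target):
    """Return the length of the shortest path from source to target, or None."""
    # Build an undirected adjacency list; each edge is (u, v, weight).
    adjacency = {}
    for u, v, w in edges:
        adjacency.setdefault(u, []).append((v, w))
        adjacency.setdefault(v, []).append((u, w))

    best = {source: 0}
    queue = [(0, source)]
    while queue:
        dist, node = heapq.heappop(queue)
        if node == target:
            return dist
        if dist > best.get(node, float("inf")):
            continue  # stale queue entry, a shorter route was already found
        for neighbour, w in adjacency.get(node, []):
            candidate = dist + w
            if candidate < best.get(neighbour, float("inf")):
                best[neighbour] = candidate
                heapq.heappush(queue, (candidate, neighbour))
    return None  # target is unreachable


# Toy network with unit-length connections (hypothetical data).
edges = [("A", "B", 1), ("B", "C", 1), ("C", "D", 1), ("A", "D", 1), ("D", "E", 1)]
print(dijkstra_length(edges, "A", "E"))  # prints 2 (A - D - E)
```

With unit weights, as here, a plain breadth-first search would give the same answer; Dijkstra’s is shown because it is the algorithm named above.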

The earliest work on neural solutions to shortest path was motivated by communications and packet routing, where approximate methods faster than the classical algorithms were desired. These operate quite differently from today’s neural networks: they used iterative back-propagation to solve shortest path on a specific graph. Examples of work in this area include Neural networks for routing communication traffic (1988), A Neural Network for Shortest Path Computation (2000) and Neural Network for Optimization of Routing in Communication Networks (2006).

In this work we seek to build a model that will work on many unseen graphs, in stark contrast to those methods, which solve for a single graph. Furthermore, we seek to offer a foundation for learning more complex graph functions from pairs of input questions and expected outputs.

A recent, groundbreaking approach to the problem was DeepMind’s Differentiable neural computers (PDF), which computed shortest paths across the London Tube map. It did this by taking the graph as a sequence of connection tuples and learning a step-by-step algorithm using a read-write memory. It was provided with a learning curriculum, gradually increasing the graph and question size.

By contrast, our solution performs much better (100% vs 55.3%), on bigger paths (length 9 vs 4), does not require curriculum learning, does not require the training of an LSTM controller, has fewer parameters and is a simpler network with fewer components.

Whilst we haven’t found any other published solutions to this exact problem, there are many instances of similar techniques being used for different problems. A couple of relevant examples:

To our knowledge, ours is the first example of combining attention read and write with a graph network.

The problem we’ll solve

Given the question “How many stations are between station 1 and station 15” and a rail network, we’d like the correct answer, e.g. “6”.

More concretely, we’ll train the network with Graph-Question-Answer tuples. Each tuple contains a unique randomly generated graph, an English language question and the expected answer.

For example:

These are split into non-overlapping training, validation and test sets.

This data setup produces a network that will work on new, never previously seen graphs. That is, it’ll learn a graph algorithm.

We’ll use the CLEVR-Graph dataset to generate the graphs, questions and answers.

An introduction to CLEVR-Graph

When building a machine learning solution, and not achieving high accuracy, it can be hard to know whether the model has a deficiency or if the data has inherent noise and ambiguities.

To remove this uncertainty we’ve employed a synthetic dataset. This is data that we generated, based on our own set of rules. Thanks to the explicit structure of the data, we can be confident a good model can score 100% accuracy. This really helps when comparing different architectures.

CLEVR-graph contains a set of questions and answers about procedurally generated transport network graphs. Here’s what one of its transport networks looks like (it’s modeled on the London Tube) and some example questions and answers:

Learn more about the dataset

Every question in CLEVR-graph comes with an answer and a unique, procedurally generated graph.

CLEVR-graph can generate many different types of questions. For this article we’ll generate just those that pertain to shortest paths. There is one template for this (“How many stations are between A and B?”) and it is combined with a randomly chosen pair of stations from each randomly generated graph to give a graph-question-answer triple.

The graph-question-answer triples are generated as a YAML file, which we then compile into TFRecords.

As there is only one question template, the training data lacks the variety you’d get in a more natural (human) source. This makes the dataset easier to solve. We leave language diversity as a future extension (and would love to see readers’ solutions!).

Solution overview

We’ll build a neural network in TensorFlow to solve our problem. TensorFlow is a popular library for building neural networks and comes with many useful components. The code for this system is available in our repository.

The system we’ll build takes a question, performs multiple iterations of processing, then finally produces an output:

Every step has the question, graph and previous outputs as input — they are hidden here for visual clarity

The structure we’ll use is a recurrent neural network (RNN): the same cell is executed multiple times sequentially, passing its internal state forward to the next execution.

The RNN cell takes the question and graph as inputs, as well as any outputs from earlier executions of the cell. These are transformed, and an output vector and updated node state are generated by the cell.

Inside the RNN cell are two major components: a graph network and an output cell. Their detail is key to understanding how this network works. We’ll cover those in detail in the next sections.

The RNN cell passes forward a hidden state, the “node state”. This is a table of node states, one vector per node in the graph. The network uses this to keep track of on-going calculations about each node.

The RNN cell is executed a fixed number of times (experimentally determined, generally longer than the longest path between two nodes) and then the final cell’s output is used as the overall output of the system.
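The control flow described above can be sketched in a few lines of plain Python. This is only an outline of the loop, not our TensorFlow code; `run_rnn`, the cell signature and `toy_cell` are hypothetical names chosen for illustration, and the sketch assumes at least one iteration.

```python
def run_rnn(cell, question, graph, initial_node_state, iterations):
    """Execute the same cell `iterations` times and return the final output."""
    node_state = initial_node_state
    outputs = []  # outputs of earlier iterations, visible to later ones
    for step in range(iterations):
        # The cell sees the question, the graph, the node state carried
        # forward from the previous step, and all previous outputs.
        output, node_state = cell(step, question, graph, node_state, outputs)
        outputs.append(output)
    return outputs[-1], node_state


# A toy cell (purely illustrative): it counts iterations and leaves
# the node state unchanged.
def toy_cell(step, question, graph, node_state, previous_outputs):
    return len(previous_outputs) + 1, node_state


final_output, final_state = run_rnn(toy_cell, "q", {}, {"A": [0.0]}, iterations=10)
print(final_output)  # prints 10
```

Only the output of the last execution is returned as the system’s answer; the intermediate outputs exist so that later iterations can look back at them.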

That completes a brief survey of the overall structure. The next sections will outline the input to the network, and how the RNN cell works.

Input data

The first part of building the system is to create an input data pipeline. This provides three things:

All of these are pre-processed into TFRecords so they can be efficiently loaded and passed to the model. The code for this process is in the accompanying GitHub repository. You can also download the pre-compiled TFRecords.

The question text

There are three steps to transform the English question into information the model can use:

The graph

The graph is represented by three data-structures in the TFRecord examples:

Each of these is a multi-dimensional tensor.

Names and properties are represented using the same text-encoding scheme (e.g. as integer tokens passed into an embedding) used for the question text.
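As a rough illustration of this kind of encoding, here is a minimal sketch in plain Python that turns text into integer tokens and looks them up in an embedding table. The helper names (`build_vocab`, `encode`, `make_embedding`), the whitespace tokenizer, the reserved unknown token and the random embedding values are all assumptions for illustration; the pipeline in the repository may differ in every one of these details.

```python
import random


def split_words(text):
    # Lower-case and treat the question mark as its own word.
    return text.lower().replace("?", " ?").split()


def build_vocab(texts):
    """Map each distinct word to an integer id; 0 is reserved for unknown words."""
    vocab = {"<unk>": 0}
    for text in texts:
        for word in split_words(text):
            vocab.setdefault(word, len(vocab))
    return vocab


def encode(text, vocab):
    """Turn a string into a list of integer tokens."""
    return [vocab.get(word, 0) for word in split_words(text)]


def make_embedding(vocab_size, width, seed=0):
    """A random embedding table; in a real model these values are learned."""
    rng = random.Random(seed)
    return [[rng.uniform(-1, 1) for _ in range(width)] for _ in range(vocab_size)]


question = "How many stations are between station 1 and station 15?"
vocab = build_vocab([question])
tokens = encode(question, vocab)
embedding = make_embedding(len(vocab), width=4)
vectors = [embedding[t] for t in tokens]  # one vector per question word

print(tokens)  # prints [1, 2, 3, 4, 5, 6, 7, 8, 6, 9, 10]
```

Note that both occurrences of “station” map to the same token (6), so they share an embedding vector.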

The expected answer

The expected answer (for this dataset, always an integer from zero to nine) is represented as a single textual token (i.e. as an integer), using the same encoding scheme as for the question text and node/edge properties.

The expected answer is used during training for loss calculation and back-propagation. During validation and testing it is used to measure model accuracy and to pinpoint failing examples for debugging.

How the RNN works

The heart of the network is an RNN. It consists of an RNN cell, which is repeatedly executed, passing its results forwards.

In our experiments, we used 10 RNN iterations (in general, the number of iterations needs to be greater than or equal to the longest path being tested for).

This RNN cell does four things each iteration:

With just these four steps, the network is capable of readily learning how to calculate shortest paths.

Graph network

The graph network is the key to this model’s capabilities. It enables it to compute functions of the graph’s structure.

In the graph network each node n has a state vector S(n,t) at time t. We used a state vector of width 4. Each iteration, the node states are propagated to the node’s neighbors adj(n):

[Superscript and subscript are used in the formula renderings for ease of comprehension instead of function notation S(n,t)]

The initial states S(n,0) are zero valued vectors.
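To build intuition for why state propagation relates to path length, here is a small sketch in plain Python. The exact update rule is not reproduced in this text, so the sketch assumes the simplest reading: each node’s new state is the element-wise sum of its neighbors’ previous states. The function names, the hand-placed starting signal and the toy graph are likewise assumptions for illustration, not the trained network’s behavior.

```python
def propagate(states, adjacency):
    """One step: each node's new state is the sum of its neighbours' states."""
    width = len(next(iter(states.values())))
    new_states = {}
    for node in states:
        total = [0.0] * width
        for neighbour in adjacency.get(node, []):
            total = [a + b for a, b in zip(total, states[neighbour])]
        new_states[node] = total
    return new_states


def first_arrival(adjacency, source, target, max_steps, width=4):
    """Number of propagation steps before a signal placed on source reaches target."""
    states = {node: [0.0] * width for node in adjacency}  # S(n, 0) = 0
    states[source] = [1.0] + [0.0] * (width - 1)          # hand-placed signal
    for step in range(max_steps + 1):
        if any(value != 0.0 for value in states[target]):
            return step
        states = propagate(states, adjacency)
    return None  # the signal did not arrive within max_steps


# A line of five stations: A - B - C - D - E
adjacency = {"A": ["B"], "B": ["A", "C"], "C": ["B", "D"], "D": ["C", "E"], "E": ["D"]}
print(first_arrival(adjacency, "A", "E", max_steps=10))  # prints 4
print(first_arrival(adjacency, "A", "E", max_steps=3))   # prints None
```

Under this assumption a non-negative signal first reaches a node after exactly as many steps as the shortest path is long, which is also why too few iterations cannot distinguish longer paths.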

Two more pieces are required for this simple state propagation to be capable of shortest path calculations: a node state write and a node state read.

Node state write

The node state write is a mechanism for the model to add a signal vector to the states of particular node(s) in the graph:

The mechanism begins by extracting words from the question, to form the write query q_write. This query will be used to select node states to add the write signal p to.

The write query is generated using attention by index, which calculates which indices in the question words Q should be attended to (as a function of the RNN iteration id r, a one-hot vector), then extracts them as a weighted sum:
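Attention by index can be sketched in plain Python as follows. Because the iteration id is a one-hot vector, a dense layer applied to it amounts to selecting one row of a weight matrix, which is what the sketch does. The name `index_weights` stands in for trained weights and its values are made up; the bias term and any other details of our layer are omitted, so treat this as an approximation.

```python
import math


def softmax(logits):
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def attention_by_index(word_vectors, index_weights, iteration):
    """Weighted sum of word vectors; the weights depend only on the iteration id."""
    logits = index_weights[iteration]  # dense layer applied to a one-hot id
    scores = softmax(logits)           # distribution over question positions
    width = len(word_vectors[0])
    return [sum(s * vec[d] for s, vec in zip(scores, word_vectors))
            for d in range(width)]


# Three question words with 2-dimensional embeddings (hypothetical values).
word_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# One row of logits per RNN iteration: iteration 0 attends to word 0,
# iteration 1 attends to word 2.
index_weights = [[20.0, 0.0, 0.0], [0.0, 0.0, 20.0]]

query = attention_by_index(word_vectors, index_weights, iteration=1)
print(query)  # very close to [1.0, 1.0], the vector of word 2
```

The key property is that the choice of which question words to extract depends on the position and the iteration, not on the content of the words.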

The write signal is calculated by taking the RNN iteration id and applying a dense layer with sigmoid activation.

Next the write signal and the write query are fed into an attention by content layer to determine how the write signal will be added to the node states. Attention by content is simply the standard dot-product attention mechanism, where each item is compared with the query by dot product to produce a set of scores. The scores are then fed through softmax to become a distribution that sums to one:

In this instance, scores are calculated as the dot product of each node state’s associated node id with the write query. Finally, the write signal is added to the node states in proportion to the scores:
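The write step can be sketched in plain Python as below. This follows the description above (dot product of each node id with the query, softmax, then a proportional add), but the function name `write_to_nodes` and all the numbers are assumptions for illustration; in the real network the node ids are embeddings and the query and signal are computed by the layers described earlier.

```python
import math


def softmax(logits):
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def write_to_nodes(node_states, node_ids, write_query, write_signal):
    """Add the write signal to each node state in proportion to its attention score."""
    scores = softmax([dot(node_id, write_query) for node_id in node_ids])
    return [[s + score * p for s, p in zip(state, write_signal)]
            for state, score in zip(node_states, scores)]


# Three nodes with well-separated id vectors (hypothetical values).
node_ids = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
node_states = [[0.0, 0.0] for _ in node_ids]
write_query = [0.0, 1.0, 0.0]   # matches the second node's id
write_signal = [1.0, 0.5]

new_states = write_to_nodes(node_states, node_ids, write_query, write_signal)
print(new_states)  # almost all of the signal lands on the second node
```

Because the scores sum to one, the total amount of signal added across all nodes always equals the write signal itself.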

Node state read

Next, the state is read from the graph, in a similar fashion to how the signal was written. A read query is calculated from the input question words, again using attention by index:

Then the read query is used to extract a value from the node states using attention by content. Like before, the read query is compared to each node’s id to create a score distribution:

The final read out value is then computed using the weighted sum of the node states:
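The read step mirrors the write step, and can be sketched the same way in plain Python. As before, `read_from_nodes` and the example values are assumptions for illustration rather than code from our repository.

```python
import math


def softmax(logits):
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def read_from_nodes(node_states, node_ids, read_query):
    """Return the attention-weighted sum of the node states."""
    scores = softmax([dot(node_id, read_query) for node_id in node_ids])
    width = len(node_states[0])
    return [sum(score * state[d] for score, state in zip(scores, node_states))
            for d in range(width)]


# Three nodes with well-separated id vectors (hypothetical values).
node_ids = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
node_states = [[0.0, 0.0], [3.0, 1.0], [0.0, 2.0]]
read_query = [0.0, 1.0, 0.0]    # matches the second node's id

value = read_from_nodes(node_states, node_ids, read_query)
print(value)  # close to [3.0, 1.0], the second node's state
```

When the query matches one node id strongly, the read is close to a lookup of that node’s state; softer scores blend several nodes together.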

The output cell

The final important piece of the RNN is the output cell. It’s essential to the network’s success (removing the previous output look-back decreases accuracy to 95%).

Here’s an overview of the output cell:

The output cell has two parts:

The output cell can combine outputs from earlier iterations with the current graph network output. This allows the cell to repeatedly combine previous outputs, offering a form of simple recursion. This also helps the network to easily look back to the output of an earlier iteration, regardless of total number of RNN iterations.


Training the network

The network’s hyper-parameters were experimentally determined. The learning rate was identified using the Learning Rate Finder protocol, and other parameters such as node state size, number of graph read/write heads and number of RNN iterations were determined through grid search.

The network achieves 100% test accuracy after 9k training cycles (2 minutes on a MacBook Pro CPU). This fast convergence shows that the network has a strong inductive bias towards solving this problem.

Visualizing how it works

I’ve included a prediction-mode attention visualization to let you see what the network is doing. It shows where the read, write and output attention heads are focussed:

The attention is used in mostly obvious ways:

Efficiency of this approach

With any solution to a problem, it’s worth comparing it to other approaches. Here we compare this model to Differentiable Neural Computers and the standard classical approach.

Compared to the classic approach, Dijkstra’s, this approach (and indeed most neural approaches) is less efficient:

However, our approach has the one major benefit that it has the potential to learn different functions depending on the training examples.

Compared to Differentiable Neural Computers our approach performs much better:

Learning other functions

As part of this work, we explored using a Gated Recurrent Unit (GRU) at each node as the node-state update function. This worked, but the additional parameters increased training effort without improving results, so we ultimately disabled the GRU. We leave as future work using an extension of the presented architecture to learn different graph functions.

Ablation analysis

An important tool in understanding the role of different parts of a neural network is ablation analysis, i.e. removing a piece and seeing how the network performs. A brief analysis was performed for this network (the network itself is the result of removing many pieces from a more complex network). In each case, the comparison is to the benchmarked 100% test accuracy.

Reducing RNN iterations decreases accuracy as the network can no longer discriminate between paths longer than the iteration count (e.g. 3 iterations achieved 40% test accuracy, 7 iterations achieved 69% test accuracy. Note that for practical reasons, the classes are not fully balanced).

Removing output cell look-back to previous RNN outputs reduced test accuracy to 95%.


Thanks to Andrew Jefferson for the encouragement to write this article and for his input and reviews.

Octavian’s research

Octavian’s mission is to develop systems with human-level reasoning capabilities. We believe that graph data and deep learning are key ingredients to making this possible. If you are interested in learning more about our research or contributing, get in touch.


