Finding shortest paths with Graph Neural Networks
In this article we show how a Graph Network with attention read and write can perform shortest path calculations. The network performs this task with 100% test accuracy after minimal training.
Neural networks are a way to create functions that no human could write. They do this by harnessing the power of large datasets. On problems for which we have capable neural models, we can use example inputs and outputs to train the network to learn a function that transforms those inputs into those outputs, and hopefully generalizes to other unseen inputs.
We need to be able to build neural networks that can learn functions on graphs. Those neural networks need the right inductive biases so that they can reliably learn useful graph functions. With that foundation, we can build powerful neural graph systems.
Here we present a “Graph network with attention read and write”, a simple network that can effectively compute shortest paths. It is an example of how different neural network components can be combined into a system that readily learns a classical graph algorithm.
We present this network as a novel system in its own right, but more importantly as the basis for further investigation into effective neural graph computation.
Given a question “What is the length of the shortest path between station A and B?” and a graph (as a set of nodes and edges), we want to learn a function that will return an integer answer.
Machine learning on graphs is a young but growing field. For a survey of the different approaches and their history see Graph Neural Networks: A Review of Methods and Applications or our introduction. We’ve also included a list of surveys at the end of our introduction article.
The classic algorithms for calculating shortest paths are A*, Dijkstra’s and Bellman-Ford. These are robust and widely implemented. Dijkstra’s is the closest match to our use case: finding the shortest path between two specific nodes with no available path-cost heuristic.
The earliest work on neural solutions to shortest path was motivated by communications and packet routing, where approximate methods faster than the classical algorithms were desired. These operate quite differently from today’s neural networks: they used iterative back-propagation to solve shortest path on a single, specific graph. Examples of work in this area include Neural networks for routing communication traffic (1988), A Neural Network for Shortest Path Computation (2000) and Neural Network for Optimization of Routing in Communication Networks (2006).
In this work we seek to build a model that will work on many unseen graphs, in stark contrast to those methods, which solve for a single graph. Furthermore, we seek to offer a foundation for learning more complex graph functions from pairs of input questions and expected outputs.
A recent, groundbreaking approach to the problem was DeepMind’s Differentiable Neural Computer (DNC), which computed shortest paths across the London Tube map. It did this by taking the graph as a sequence of connection tuples and learning a step-by-step algorithm using a read-write memory. It was trained with a learning curriculum that gradually increased the graph and question size.
By contrast, our solution achieves much higher accuracy (100% vs 55.3%) on longer paths (length 9 vs 4), does not require curriculum learning or the training of an LSTM controller, has fewer parameters and is a simpler network with fewer components.
Whilst we haven’t found any other published solutions to this exact problem, there are many instances of similar techniques being used for different problems. A few relevant examples:
- Commonsense Knowledge Aware Conversation Generation with Graph Attention uses attention to read out of a knowledge graph
- Deeply learning molecular structure-property relationships using attention- and gate-augmented graph convolutional network uses a GRU at each graph node and propagates state between nodes using attention
- DeepPath: A Reinforcement Learning Method for Knowledge Graph Reasoning uses a policy network to navigate paths across a graph
To our knowledge, ours is the first example of combining attention read and write with a graph network.
The problem we’ll solve
Given the question “How many stations are between station 1 and station 15?” and a rail network, we’d like the correct answer, e.g. “6”.
More concretely, we’ll train the network with Graph-Question-Answer tuples. Each tuple contains a unique randomly generated graph, an English language question and the expected answer.
These are split into non-overlapping training, validation and test sets.
This data setup produces a network that will work on new, never previously seen graphs. That is, it’ll learn a graph algorithm.
We’ll use the CLEVR-Graph dataset to generate the graphs, questions and answers.
An introduction to CLEVR-Graph
When a machine learning solution fails to achieve high accuracy, it can be hard to know whether the model has a deficiency or the data has inherent noise and ambiguities.
To remove this uncertainty we’ve employed a synthetic dataset. This is data that we generated, based on our own set of rules. Thanks to the explicit structure of the data, we can be confident a good model can score 100% accuracy. This really helps when comparing different architectures.
CLEVR-graph contains a set of questions and answers about procedurally generated transport network graphs. Here’s what one of its transport networks looks like (it’s modeled on the London Tube) and some example questions and answers:
Every question in CLEVR-graph comes with an answer and a unique, procedurally generated graph.
CLEVR-graph can generate many different types of questions. For this article we’ll generate just those that pertain to shortest paths. There is one template for this (“How many stations are between A and B?”) and it is combined with a randomly chosen pair of stations from each randomly generated graph to give a graph-question-answer triple.
The graph-question-answer triples are generated as a YAML file, which we then compile into TFRecords.
As there is only one question template, the training data lacks the variety you’d get from a more natural (human) source. This makes the dataset easier to solve. We leave language diversity as a future extension challenge (and would love to see readers’ solutions!).
We’ll build a neural network in TensorFlow to solve our problem. TensorFlow is a popular library for building neural networks and comes with many useful components. The code for this system is available in our repository.
The system we’ll build takes a question, performs multiple iterations of processing, then finally produces an output.
The structure we’ll use is a recurrent neural network (RNN): in an RNN the same cell is executed multiple times sequentially, passing its internal state forward to the next execution.
The RNN cell takes the question and graph as inputs, as well as any outputs from earlier executions of the cell. These are transformed, and an output vector and updated node state are generated by the cell.
Inside the RNN cell are two major components: a graph network and an output cell. Their details are key to understanding how this network works, and we cover each of them in the next sections.
The RNN cell passes forward a hidden state, the “node state”. This is a table of node states, one vector per node in the graph. The network uses this to keep track of on-going calculations about each node.
The RNN cell is executed a fixed number of times (experimentally determined, generally longer than the longest path between two nodes) and then the final cell’s output is used as the overall output of the system.
That completes a brief survey of the overall structure. The next sections will outline the input to the network, and how the RNN cell works.
The first part of building the system is to create an input data pipeline. This provides three things:
- The input question text “How many stations are between…”
- The input graph to calculate shortest path on
- The expected answer e.g. “6”
All of these are pre-processed into TFRecords so they can be efficiently loaded and passed to the model. The code for this process is in build.py in the accompanying GitHub repository. You can also download the pre-compiled TFRecords.
The question text
There are three steps to transform the English question into information the model can use:
- Split the text into a series of ‘tokens’ (e.g. common words and special characters like ? and spaces)
- Assign each unique token an integer ID and represent each instance of the token as that integer ID
- Embed each token (e.g. word, special character) as a vector. This step is done during model runtime, and for this simple example we’ve used one-hot vectors to encode the integers.
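The three steps above can be sketched in plain Python. This is a minimal illustration, not the code in build.py: the regular expression and the helper names `tokenize`, `build_vocab`, `encode` and `one_hot` are our own hypothetical choices, and the real pipeline may split text differently.

```python
import re


def tokenize(text):
    # Words, single spaces and individual punctuation marks (such as "?")
    # each become one token.
    return re.findall(r"\w+| |[^\w\s]", text)


def build_vocab(texts):
    # Assign each unique token the next free integer ID.
    vocab = {}
    for text in texts:
        for token in tokenize(text):
            vocab.setdefault(token, len(vocab))
    return vocab


def encode(text, vocab):
    # Represent each token instance by its integer ID.
    return [vocab[token] for token in tokenize(text)]


def one_hot(ids, vocab_size):
    # Embed each integer ID as a one-hot vector (done at model runtime).
    return [[1.0 if i == j else 0.0 for j in range(vocab_size)] for i in ids]


question = "How many stations are between A and B?"
vocab = build_vocab([question, "6"])
ids = encode(question, vocab)
vectors = one_hot(ids, len(vocab))
```

Because the answer shares the same vocabulary, `encode("6", vocab)` gives the expected answer as a single integer token.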
The graph
The graph is represented by three data-structures in the TFRecord examples:
- A list of nodes with their id, name and properties
- A list of edges with their source and destination node ids, and any edge properties
- An adjacency matrix, mapping the connection between nodes. It has a 1.0 if two nodes are directly connected and 0.0 otherwise.
Each of these is a multi-dimensional tensor.
Names and properties are represented using the same text-encoding scheme (i.e. as integer tokens passed into an embedding) used for the question text.
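To illustrate how the adjacency matrix relates to the edge list, here is a minimal sketch. The function name, the optional padding to a fixed size and the assumption that edges are undirected are ours, not taken from build.py; the station names are placeholders.

```python
def adjacency_matrix(node_ids, edges, size=None):
    """Build a dense 0.0/1.0 adjacency matrix from (source, destination) edges.

    `size` optionally pads the matrix to a fixed number of nodes, as batched
    tensors usually require; padded rows and columns stay 0.0.
    """
    index = {node_id: i for i, node_id in enumerate(node_ids)}
    n = size if size is not None else len(node_ids)
    adj = [[0.0] * n for _ in range(n)]
    for source, destination in edges:
        i, j = index[source], index[destination]
        # Assumption: rail connections run both ways, so the matrix is symmetric.
        adj[i][j] = 1.0
        adj[j][i] = 1.0
    return adj


nodes = ["Station A", "Station B", "Station C"]
edges = [("Station A", "Station B"), ("Station B", "Station C")]
adj = adjacency_matrix(nodes, edges, size=4)
```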
The expected answer
The expected answer (for this dataset, always an integer from zero to nine) is represented as a single textual token (i.e. as an integer), using the same encoding scheme as for the question text and node/edge properties.
The expected answer is used in training mode for loss calculation and back-propagation; during validation and testing it’s used to measure model accuracy and to pinpoint failing data examples for debugging.
How the RNN works
The heart of the network is an RNN. It consists of an RNN cell, which is repeatedly executed, passing its results forwards.
In our experiments, we used 10 RNN iterations (in general, the number of iterations needs to be greater than or equal to the longest path being tested for).
This RNN cell does four things each iteration:
- Write data into a selected node’s state
- Propagate node-states along the edges in the graph
- Read data from a selected node’s state
- Combine the read data with all previous RNN cell outputs to produce an output for this RNN iteration
With just these four steps, the network is capable of readily learning how to calculate shortest paths.
The graph network is the key to this model’s capabilities: it enables the model to compute functions of the graph’s structure.
In the graph network each node n has a state vector S(n,t) at time t. We used a state vector of width 4. Each iteration, the node states are propagated along the edges to each node’s neighbors adj(n).
The initial states S(n,0) are zero valued vectors.
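A minimal sketch of one propagation step, assuming the simplest form of the rule described above: each node’s new state is the sum of its neighbors’ previous states, S(n,t+1) = Σ_{m ∈ adj(n)} S(m,t). Whether the trained model also retains part of a node’s own previous state is not stated here, so treat the exact update as an assumption. In matrix form the step is the adjacency matrix multiplied by the N × D table of node states.

```python
def propagate(adj, states):
    """One propagation step: S(n, t+1) = sum of S(m, t) over neighbours m of n.

    adj: N x N matrix of 0.0/1.0 entries; states: N x D node-state table.
    This is the matrix product adj @ states written out explicitly.
    """
    n, d = len(states), len(states[0])
    return [
        [sum(adj[i][j] * states[j][k] for j in range(n)) for k in range(d)]
        for i in range(n)
    ]


# A three-station line A - B - C with a signal starting at A.
adj = [
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
]
states = [[1.0], [0.0], [0.0]]
step1 = propagate(adj, states)
step2 = propagate(adj, step1)
```

Since the initial states are all zero, propagation alone would leave them at zero; the write and read mechanisms are what inject and extract a signal.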
Two more pieces are required for this simple state propagation to be capable of shortest path calculations: a node state write and a node state read.
Node state write
The node state write is a mechanism for the model to add a signal vector to the states of particular node(s) in the graph.
The mechanism begins by extracting words from the question to form the write query q_write. This query is used to select the node states to which the write signal p is added.
The write query is generated using attention by index, which calculates which indices in the question words Q should be attended to (as a function of the RNN iteration id r, a one-hot vector), then extracts those words as a weighted sum.
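Attention by index can be sketched as follows, assuming the position scores come from a single dense layer applied to the one-hot iteration id r, so that q = Σ_i softmax(W r)_i Q_i. The weights `W` below are hand-picked toy values rather than trained parameters, and the function names are hypothetical.

```python
import math


def softmax(xs):
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_by_index(question_vectors, iteration_one_hot, weights):
    """Select question words by position, as a function of the RNN iteration id r.

    question_vectors: L x D, one embedded vector per question token (Q)
    iteration_one_hot: length-R one-hot vector for the iteration id r
    weights: L x R matrix standing in for a learned dense layer (hypothetical)
    Returns (query, attention) where query = sum_i attention_i * Q_i.
    """
    # One score per question position: scores = W r
    scores = [sum(w * x for w, x in zip(row, iteration_one_hot)) for row in weights]
    attention = softmax(scores)
    width = len(question_vectors[0])
    query = [
        sum(a * vec[k] for a, vec in zip(attention, question_vectors))
        for k in range(width)
    ]
    return query, attention


Q = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # three toy token vectors
W = [[8.0, 0.0], [0.0, 8.0], [0.0, 0.0]]  # iteration 0 -> token 0, iteration 1 -> token 1
q0, a0 = attention_by_index(Q, [1.0, 0.0], W)
q1, a1 = attention_by_index(Q, [0.0, 1.0], W)
```

Because the scores depend only on the iteration id and not on the words themselves, this lets the network learn fixed habits such as "on every iteration, look at the first station name".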
The write signal is calculated by taking the RNN iteration id and applying a dense layer with sigmoid activation.
Next, the write signal and the write query are fed into an attention by content layer to determine how the write signal will be added to the node states. Attention by content is simply the standard dot-product attention mechanism: each item is compared with the query by dot product to produce a set of scores, which are then fed through a softmax to become a distribution that sums to one.
In this instance, the scores are calculated as the dot product of each node’s id with the write query. Finally, the write signal is added to each node’s state in proportion to its score.
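The write path (a dense sigmoid layer producing the write signal p, then dot-product attention by content over node ids) can be sketched as below. The node id vectors, weights and function names are hypothetical toy values chosen so that the attention clearly selects one node; in the model they would be learned embeddings and parameters.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def write_signal(iteration_one_hot, weights, bias):
    # Dense layer with sigmoid activation on the RNN iteration id: p = sigmoid(W r + b).
    return [
        sigmoid(sum(w * x for w, x in zip(row, iteration_one_hot)) + b)
        for row, b in zip(weights, bias)
    ]


def attention_write(node_states, node_ids, write_query, signal):
    """Add `signal` to every node state, weighted by softmax(node_id . write_query)."""
    scores = [sum(a * b for a, b in zip(node_id, write_query)) for node_id in node_ids]
    attention = softmax(scores)
    return [
        [s + a * p for s, p in zip(state, signal)]
        for state, a in zip(node_states, attention)
    ]


node_ids = [[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0]]  # toy id embeddings
states = [[0.0, 0.0] for _ in node_ids]  # initial states S(n, 0) are zero
p = write_signal([1.0], [[0.0], [2.0]], [0.0, 0.0])
new_states = attention_write(states, node_ids, [0.0, 1.0, 0.0], p)
```

Almost all of the signal lands in node 1’s state, because its id has by far the largest dot product with the write query.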
Node state read
Next, the state is read from the graph, in a similar fashion to how the signal was written. A read query is calculated from the input question words, again using attention by index.
The read query is then used to extract a value from the node states using attention by content. As before, the read query is compared with each node’s id to produce a score distribution.
The final read-out value is the weighted sum of the node states, using those scores as the weights.
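The read mirrors the write: score each node id against the read query, softmax the scores, and take the weighted sum of node states. A minimal sketch with hypothetical toy values and function names:

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_read(node_states, node_ids, read_query):
    """Read = sum over nodes n of softmax(node_id_n . read_query)_n * S(n)."""
    scores = [sum(a * b for a, b in zip(node_id, read_query)) for node_id in node_ids]
    attention = softmax(scores)
    width = len(node_states[0])
    return [
        sum(a * state[k] for a, state in zip(attention, node_states))
        for k in range(width)
    ]


node_ids = [[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0]]  # toy id embeddings
states = [[0.0, 1.0], [3.0, 0.0], [0.0, 0.0]]
value = attention_read(states, node_ids, [1.0, 0.0, 0.0])  # mostly node 0's state
```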
The output cell
The final important piece of the RNN is the output cell. It’s essential to the network’s success (removing the previous output look-back decreases accuracy to 95%).
Here’s an overview of the output cell:
The output cell has two parts:
- Attention by index across the previous outputs and the most recent graph network read (this works the same as the read and write queries in the previous section)
- A basic feed-forward network to transform the attention output into the cell’s output
The output cell can combine outputs from earlier iterations with the current graph network output. This allows the cell to repeatedly combine previous outputs, offering a form of simple recursion. It also lets the network easily look back to the output of an earlier iteration, regardless of the total number of RNN iterations.
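The output cell can be sketched as attention over a list of candidates (all earlier outputs plus the current graph read) followed by one dense layer. This is a simplified sketch under stated assumptions: the attention scores are passed in directly instead of being computed by attention by index on the iteration id, and the "basic feed-forward network" is assumed to be a single dense layer with ReLU. All names are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def output_cell(previous_outputs, current_read, scores, weights, bias):
    """Attend over earlier outputs plus the current read, then apply a dense layer.

    previous_outputs: list of D-vectors from earlier iterations (may be empty)
    current_read: D-vector read from the graph network this iteration
    scores: one score per candidate (assumed given; see lead-in)
    weights, bias: a single dense layer with ReLU (assumed; D_out x D)
    """
    candidates = previous_outputs + [current_read]
    if len(scores) != len(candidates):
        raise ValueError("need one score per candidate")
    attention = softmax(scores)
    width = len(current_read)
    mixed = [sum(a * c[k] for a, c in zip(attention, candidates)) for k in range(width)]
    return [
        max(0.0, sum(w * x for w, x in zip(row, mixed)) + b)
        for row, b in zip(weights, bias)
    ]


identity = [[1.0, 0.0], [0.0, 1.0]]
first = output_cell([], [1.0, 3.0], [0.0], identity, [0.0, 0.0])
later = output_cell([[1.0, 0.0]], [0.0, 2.0], [0.0, 0.0], identity, [0.0, 0.0])
```

On the first iteration the only candidate is the graph read; later iterations can blend in any earlier output, which is the look-back the ablation below shows to matter.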
Training the network
The network’s hyper-parameters were experimentally determined. The learning rate was identified using the Learning Rate Finder protocol, and other parameters such as node state size, number of graph read/write heads and number of RNN iterations were determined through grid search.
The network achieves 100% test accuracy after 9k training cycles (2 minutes on a MacBook Pro CPU). This fast convergence shows that the network has a strong inductive bias towards solving this problem.
Visualizing how it works
I’ve included a prediction-mode attention visualization to let you see what the network is doing. It shows where the read, write and output attention heads are focussed:
The attention is used in mostly obvious ways:
- Every step reads from the first mentioned station’s node state
- Every step writes to the second mentioned station’s node state
- The output cell mostly uses the read value from the network, but regularly combines it (at least partially) with other steps’ output
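To see why this attention pattern is sufficient, here is a hand-wired, idealized version of the cell with hard attention: each iteration writes a fixed signal into the second station, propagates with a neighbor-sum rule (our assumption about the propagation function), and reads the first station. The iteration at which the read first becomes non-zero equals the number of edges on the shortest path. This is our interpretation of the mechanism, not a decoding of the trained network’s weights.

```python
def signal_arrival_step(neighbours, write_node, read_node, iterations=10):
    """Idealised write -> propagate -> read loop with hard attention.

    neighbours: dict mapping each node to a list of adjacent nodes.
    Returns the first 1-based iteration at which the read node's state is
    non-zero, or None if the signal does not arrive within `iterations`.
    Assumes write_node != read_node.
    """
    state = {node: 0.0 for node in neighbours}
    for step in range(1, iterations + 1):
        # Write: add a fixed signal to the second station's state.
        state[write_node] += 1.0
        # Propagate: each node's new state is the sum of its neighbours' states.
        state = {node: sum(state[m] for m in neighbours[node]) for node in neighbours}
        # Read: check whether the signal has reached the first station.
        if state[read_node] > 0.0:
            return step
    return None


line = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
square = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}
```

With `iterations=3`, stations four edges apart never see the signal, which matches the ablation finding that too few RNN iterations prevent the network from discriminating longer paths.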
Efficiency of this approach
With any solution to a problem, it’s worth comparing it to other approaches. Here we compare this model to Differentiable Neural Computers and the standard classical approach.
Compared to the classic approach, Dijkstra’s, this approach (and indeed most neural approaches) is less efficient:
- This model requires a moderate amount of initial training, Dijkstra’s does not
- During prediction mode, this model performs more operations than Dijkstra’s, although as they’re parallel matrix operations, they could possibly have a similar runtime thanks to specialized hardware (e.g. GPUs)
- The two methods have similar run-time: Dijkstra’s scales as O(|E| + |N| log |N|), where E is the set of edges and N the set of nodes, while this method scales as O(|longest path|) iterations, which is at most O(|E|)
However, our approach has one major benefit: it can potentially learn different functions depending on the training examples.
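For reference, the classical baseline can be sketched with Python’s `heapq` module. This sketch uses a binary heap, which gives O((|E| + |N|) log |N|) rather than the Fibonacci-heap bound quoted above, and it treats every connection as unit cost, so it returns the number of edges on a shortest path. How CLEVR-graph turns that into its “stations between” answer is not reproduced here.

```python
import heapq


def dijkstra(neighbours, source, target):
    """Dijkstra's algorithm with a binary heap and unit edge weights.

    neighbours: dict mapping each node to a list of adjacent nodes.
    Returns the number of edges on a shortest path, or None if unreachable.
    """
    dist = {source: 0}
    heap = [(0, source)]
    while heap:
        d, node = heapq.heappop(heap)
        if node == target:
            return d
        if d > dist.get(node, float("inf")):
            continue  # stale heap entry; a shorter route was already found
        for neighbour in neighbours[node]:
            candidate = d + 1
            if candidate < dist.get(neighbour, float("inf")):
                dist[neighbour] = candidate
                heapq.heappush(heap, (candidate, neighbour))
    return None


line = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
square = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}
```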
Compared to Differentiable Neural Computers our approach performs much better:
- This method achieves much higher accuracy and scalability than DNCs — 100% (paths up to length 9) versus 55.3% (paths up to length 4)
- This method does not require the construction and administration of a learning curriculum
- We suspect this method requires much less training resource than the DNC (we get 100% accuracy after 2 minutes on a laptop CPU), although there are no figures published in the DNC paper
- This method is a simpler network with fewer read heads (1 vs 5), smaller memory state (64 elements vs 128) and no LSTM cell
- We suspect this architecture more readily scales to larger graphs since we parallelize the graph-exploration (e.g. a DNC requires more memory, read heads and runtime to deal with larger graphs or higher edge density)
Learning other functions
As part of this work, we explored using a Gated Recurrent Unit (GRU) at each node as the node-state update function. This worked; however, the extra training effort due to the increased number of parameters brought no benefit, so ultimately the GRU was disabled. We leave as future work using an extension of the presented architecture to learn different graph functions.
Ablation analysis
An important tool in understanding the role of different parts of a neural network is ablation analysis, i.e. removing a piece and seeing how the network performs. A brief analysis was performed for this network (the network itself is the result of removing many pieces from a more complex network). In each case, the comparison is to the benchmarked 100% test accuracy.
Reducing the number of RNN iterations decreases accuracy, as the network can no longer discriminate between paths longer than the iteration count (e.g. 3 iterations achieved 40% test accuracy and 7 iterations achieved 69% test accuracy; note that, for practical reasons, the classes are not fully balanced).
Removing output cell look-back to previous RNN outputs reduced test accuracy to 95%.
Thanks to Andrew Jefferson for the encouragement to write this article and for his input and reviews.
Octavian’s mission is to develop systems with human-level reasoning capabilities. We believe that graph data and deep learning are key ingredients to making this possible. If you’re interested in learning more about our research or contributing, get in touch.