Deep Learning with Knowledge Graphs

Last week I gave a talk at Connected Data London on the approach that we have developed at Octavian to use neural networks to perform tasks on knowledge graphs.

Here’s my talk, recorded during a recent Neo4j online meetup:

In this post I will summarize that talk (including most of the slides) and provide links to the papers that have most significantly influenced us.

To find out more about our vision for a new approach to building the next generation of database query engine see our recent article.


What is a graph?

Two functionally identical graph models

We are using a property graph (or attributed graph) model, in which nodes (vertices) and relations (edges) can have properties. In addition, our neural network has a global state that is external to the graph. The slide shows two representations of this model, one from Neo4j and the other from DeepMind (n.b. these are effectively identical).
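As a rough illustration of the model described above, here is a minimal sketch of a property graph with a global state held outside the nodes and edges. The class names, the `neighbours` helper and the example station data are all made up for this post; this is not the Neo4j or DeepMind API.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Node:
    id: str
    properties: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Edge:
    source: str
    target: str
    properties: Dict[str, Any] = field(default_factory=dict)


@dataclass
class PropertyGraph:
    nodes: Dict[str, Node] = field(default_factory=dict)
    edges: List[Edge] = field(default_factory=list)
    # The global state sits next to the graph, not on any node or edge.
    global_state: Dict[str, Any] = field(default_factory=dict)

    def add_node(self, node_id: str, **properties: Any) -> None:
        self.nodes[node_id] = Node(node_id, properties)

    def add_edge(self, source: str, target: str, **properties: Any) -> None:
        self.edges.append(Edge(source, target, properties))

    def neighbours(self, node_id: str) -> List[str]:
        """Ids of nodes one edge away, treating edges as undirected."""
        found = []
        for edge in self.edges:
            if edge.source == node_id:
                found.append(edge.target)
            elif edge.target == node_id:
                found.append(edge.source)
        return found


# Hypothetical example data: two stations joined by one line.
graph = PropertyGraph()
graph.add_node("s1", name="Station One", cleanliness="clean")
graph.add_node("s2", name="Station Two", cleanliness="dirty")
graph.add_edge("s1", "s2", line="Blue")
graph.global_state["question"] = "How clean is Station One?"
```

Nodes and edges each carry a free-form dictionary of properties, which is the part the two representations on the slide have in common.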

Why are we interested in graphs?

All the graphs!

Graphs have a rich history, from Leonhard Euler’s work in the 18th century to the wide range of graphs in use today. Within the field of computer science there are many applications of graphs: graph databases, knowledge graphs, semantic graphs, computation graphs, social networks, transport graphs and many more.

Graphs have played a key role in the rise of Google (their first breakthrough was using PageRank to power searches, today their Knowledge Graph has grown in importance) and Facebook. From politics to low cost international air travel, graph algorithms have had a major impact on our world.

What is Deep Learning?

I’m not sure about A.I.; let’s talk about Deep Learning…

Deep learning is a branch of machine learning centered around training multi-layer (“deep”) neural networks using gradient descent. The basic building block of these neural networks is the dense (or fully connected) layer.

A deep neural network using dense layers
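To make the idea of a dense layer concrete, here is a small sketch in plain Python. Every output is a weighted sum of every input, passed through a non-linearity. The layer sizes, the random initialisation and the choice of `tanh` are arbitrary assumptions for illustration, and no training is shown.

```python
import math
import random


def dense_layer(inputs, weights, biases):
    """outputs[j] = tanh(sum_i inputs[i] * weights[i][j] + biases[j])."""
    outputs = []
    for j in range(len(biases)):
        total = biases[j]
        for i, x in enumerate(inputs):
            total += x * weights[i][j]  # every input feeds every output
        outputs.append(math.tanh(total))
    return outputs


def random_layer(n_in, n_out, rng):
    """Untrained parameters for a layer mapping n_in values to n_out values."""
    weights = [[rng.uniform(-1, 1) for _ in range(n_out)] for _ in range(n_in)]
    biases = [0.0] * n_out
    return weights, biases


# A "deep" network is several dense layers applied one after another.
rng = random.Random(0)
layer_sizes = [4, 8, 8, 2]
layers = [random_layer(a, b, rng) for a, b in zip(layer_sizes, layer_sizes[1:])]


def forward(x):
    for weights, biases in layers:
        x = dense_layer(x, weights, biases)
    return x


output = forward([0.5, -1.0, 0.25, 2.0])
```

In practice the weights and biases would be learned by gradient descent rather than left at their random starting values.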

Using deep learning has allowed us to train computers to tackle a range of previously challenging tasks, from playing Go to image recognition, with superhuman performance.

MacNets and other examples of superhuman image processing neural networks

Machine Learning

In general, machine learning is a simple concept. We create a model of how we think things work, e.g. y = mx + c. This could be:

house_price = m • number_of_bedrooms + c

Machine learning, view from 20,000ft

We train (fit) the parameters of our model (m and c in the example) using the data that we have. Once our training is done we have some learned parameter values and we have a model that we can use to make predictions.
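Here is a minimal sketch of that training step for the house price example, fitting m and c by gradient descent on the mean squared error. The data is invented (prices in thousands, generated from m = 50 and c = 100) and the learning rate and step count are arbitrary choices that happen to converge for this data.

```python
def fit_line(xs, ys, learning_rate=0.05, steps=5000):
    """Fit y = m * x + c by gradient descent on mean squared error."""
    m, c = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        errors = [(m * x + c) - y for x, y in zip(xs, ys)]
        # Gradients of mean((m*x + c - y)^2) with respect to m and c.
        grad_m = 2.0 / n * sum(e * x for e, x in zip(errors, xs))
        grad_c = 2.0 / n * sum(errors)
        m -= learning_rate * grad_m
        c -= learning_rate * grad_c
    return m, c


def predict(m, c, number_of_bedrooms):
    return m * number_of_bedrooms + c


# Invented training data: number of bedrooms and price in thousands.
bedrooms = [1.0, 2.0, 3.0, 4.0]
prices = [150.0, 200.0, 250.0, 300.0]

m, c = fit_line(bedrooms, prices)
five_bed_price = predict(m, c, 5.0)
```

Once `fit_line` has returned, `m` and `c` are the learned parameters and `predict` is the model we use on houses we have not seen.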

Sometimes the parameters are useful by themselves (e.g. when we use a neural network to train a word embedding such as word2vec).

Deep Learning on Graphs

At Octavian one of the questions we asked ourselves is: how would we like machine learning on graphs to look from 20,000ft?

To help answer this question, we compared traditional forms of deep learning to the world of graph learning:

Comparing graph machine learning with other setups

We identified three graph-data tasks which we believe require graph-native implementations: Regression, Classification and Embedding.

Aside: there are other graph-specific tasks such as link prediction that don’t easily fit into the three tasks above.

We observed that many existing techniques for machine learning on graphs have some fundamental limitations:

  • Some do not work on unseen graphs (because they require first training a graph embedding)
  • Some require converting the graph into a table and discarding its structure (e.g. sampling from a graph using random walks)

Existing Work

Performance of DL models on graph problems is not superhuman

Much of the existing work using Deep Learning on graphs focuses on two areas.

  1. Making predictions about molecules (including proteins), their properties and reactions.
  2. Node classification/categorisation in large, static graphs.

Graphs, Neural Networks and Structural Priors

It’s often said that Deep Learning works well with unstructured data — images, free text, reinforcement learning etc.

But our superhuman neural networks are actually dealing with very specifically structured information and the neural network architectures are engineered to match the structure of the information they work well with.

Data structures that work with neural networks

Images are in a sense structured: they have a rigid 2D (or 3D) structure where pixels that are close to each other are more relevant to each other than pixels that are far apart. Sequences (e.g. over time) have a 1D structure where items that are adjacent are more relevant to one another than items that are far apart.

Dense layers make sense for Go, where locations that are far apart on the board can have equal influence on one another

When working with images and sequences, dense layers (i.e. where every input is connected to every output) don’t work well. Neural network layers that reflect and exploit the structure of the input medium achieve the best results.

For sequences, Recurrent Neural Networks (RNNs) are used, and for images, Convolutional Neural Networks (CNNs) are used.

Convolutional Networks structurally encode that pixels which are close to one another are more significant than pixels which are far apart

In a convolutional neural network each pixel in the hidden layer only depends on a group of nearby pixels in the input (compare this to a dense layer where every hidden layer pixel depends on every input pixel).
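The locality described above can be checked with a small sketch. This is a plain 2D convolution with no padding and a fixed 3x3 averaging kernel; a real CNN would learn its kernels, and the 5x5 image here is just made-up numbers. Changing a far-away input pixel leaves a hidden pixel untouched unless that input falls inside its 3x3 window.

```python
def conv2d(image, kernel):
    """Slide the kernel over the image (no padding, stride 1)."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    output = []
    for r in range(out_h):
        row = []
        for col in range(out_w):
            total = 0.0
            # Each hidden pixel only looks at a kh x kw patch of the input.
            for i in range(kh):
                for j in range(kw):
                    total += image[r + i][col + j] * kernel[i][j]
            row.append(total)
        output.append(row)
    return output


image = [[float(r * 5 + c) for c in range(5)] for r in range(5)]
kernel = [[1.0 / 9.0] * 3 for _ in range(3)]  # 3x3 averaging filter
hidden = conv2d(image, kernel)

# Change the bottom-right input pixel and convolve again.
changed = [row[:] for row in image]
changed[4][4] = 1000.0
hidden_changed = conv2d(changed, kernel)

top_left_unaffected = hidden[0][0] == hidden_changed[0][0]
bottom_right_affected = hidden[2][2] != hidden_changed[2][2]
```

With a dense layer, by contrast, the changed input would feed into every hidden value.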

Neither dense nor convolutional networks make sense for a transit graph

Nodes in graphs do not have fixed relations like nearby pixels in an image or adjacent items in a sequence. To make deep learning successful with graphs it’s not enough to convert graphs to matrix representation and put that input into existing Neural Network models. We have to figure out how to create Neural Network models that work well for graphs.

This paper makes the same argument more effectively than I do — check it out!

We aren’t the only people thinking about this. Some very clever people at DeepMind, Google Brain, MIT and University of Edinburgh lay out a similar position in their paper on Relational Inductive Biases. I recommend this paper to anyone interested in deep learning on graphs.

The paper introduces a general algorithm for propagating information through a graph. It argues that, by using neural networks to learn six functions that perform aggregations and transforms within the structure of the graph, one can achieve state-of-the-art performance on a selection of graph tasks.

One algorithm to rule them all?

By propagating information between nodes principally using the graph edges the authors argue they are maintaining the relational inductive biases present in the graph structure.
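The sketch below shows the general shape of one propagation round along graph edges. It is a heavy simplification of the paper’s algorithm, not a reproduction of it: node states are single numbers, the learned functions are replaced by three hand-written ones (`message_fn`, `update_fn`, `global_fn`, all names invented here), and aggregation is a plain sum. The toy task is spreading a “reached” flag outwards from node "a".

```python
def propagate(node_states, edges, global_state, message_fn, update_fn, global_fn):
    """One round: send messages along edges, aggregate, update nodes and global."""
    inbox = {node: [] for node in node_states}
    for sender, receiver in edges:
        inbox[receiver].append(message_fn(node_states[sender], global_state))
    new_states = {}
    for node, state in node_states.items():
        aggregated = sum(inbox[node])  # sum does not depend on message order
        new_states[node] = update_fn(state, aggregated, global_state)
    new_global = global_fn(global_state, sum(new_states.values()))
    return new_states, new_global


# Hand-written stand-ins for functions that would normally be learned.
def message_fn(sender_state, global_state):
    return sender_state


def update_fn(state, aggregated, global_state):
    return 1.0 if (state > 0 or aggregated > 0) else 0.0


def global_fn(global_state, total):
    return total  # here: how many nodes have been reached so far


# A path graph a - b - c - d, with each edge listed in both directions.
edges = [("a", "b"), ("b", "a"), ("b", "c"), ("c", "b"), ("c", "d"), ("d", "c")]
states = {"a": 1.0, "b": 0.0, "c": 0.0, "d": 0.0}
global_state = 1.0

history = []
for _ in range(3):
    states, global_state = propagate(
        states, edges, global_state, message_fn, update_fn, global_fn
    )
    history.append(dict(states))
```

Information only ever moves between nodes that share an edge, so it takes three rounds for node "d" to hear about node "a". That restriction is the relational inductive bias in miniature.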

The MacGraph neural network architecture that we have been developing at Octavian has similarities to the relational inductive biases approach. It employs a global state that exists outside the graph and also propagates information between the graph nodes.

Octavian’s experimental results

Before I can tell you about our results at Octavian I have to mention the task that we used to test our neural graph architecture.

our synthetic benchmark dataset

You can read more about CLEVR-Graph here. It’s a synthetic (procedurally generated) dataset which consists of 10,000 fictional transit networks loosely modelled on the London Underground. For each randomly generated transit network graph we have a single question and correct answer.

Some example questions from the CLEVR-Graph question bank and an example graph

The crucial thing about this task is that each graph used to test the network is one the network has never seen before. Therefore it cannot memorise the answers to the questions but must learn how to extract the answer from new graphs.

At time of writing MacGraph is achieving almost-perfect results on tasks requiring 6 different skills:

MacGraph’s latest results on CLEVR-Graph

I think that one of the most exciting skills is MacGraph’s ability to answer “How many stations are between {station} and {station}?”, because solving that question requires determining the shortest path between the stations (classically done with Dijkstra’s algorithm), which is a complex and graph-specific algorithm.
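For comparison, here is how that question might be answered classically. This is a hand-written baseline, not what MacGraph does internally. It assumes all edges count equally, in which case a breadth-first search finds the same shortest path as Dijkstra’s algorithm, and it assumes “stations between” means the intermediate stations on that path. The five-station network is invented.

```python
from collections import deque


def stations_between(adjacency, start, goal):
    """Number of intermediate stations on a shortest path, or None if unreachable."""
    if start == goal:
        return 0
    distance = {start: 0}
    queue = deque([start])
    while queue:
        current = queue.popleft()
        for neighbour in adjacency.get(current, []):
            if neighbour not in distance:
                distance[neighbour] = distance[current] + 1
                if neighbour == goal:
                    # Edges on the path minus one gives the stations in between.
                    return distance[neighbour] - 1
                queue.append(neighbour)
    return None


# Invented network: a long route A-B-C-D and a shortcut A-E-D.
adjacency = {
    "A": ["B", "E"],
    "B": ["A", "C"],
    "C": ["B", "D"],
    "D": ["C", "E"],
    "E": ["A", "D"],
}

via_shortcut = stations_between(adjacency, "A", "D")
adjacent = stations_between(adjacency, "A", "B")
unknown = stations_between(adjacency, "A", "Z")
```

The interesting part of the MacGraph result is that the network has to discover an equivalent procedure from examples, rather than being given one like this.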

How does MacGraph work?

It’s not sufficient to just propagate information between nodes in the graph using transformation and aggregation functions. To answer natural language questions about a graph with natural language answers, two more steps are needed: the input question must be transformed into a graph state that leads to the correct answer, and the answer information must be extracted from the graph state and transformed into the desired answer.

…almost

Our solution for transforming between natural language and graph state is to use attention. You can read more about how this works here.

an alternative to dense layers

Attention cells are radically different to dense layers. Attention cells work with lists of information, extracting individual elements depending on their content or location.

These properties make attention cells great for selecting from the lists of nodes and edges that make up a graph.
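As an illustration of selecting from a list by content, here is a minimal sketch of soft attention using dot-product scores and a softmax. This is one common formulation and is an assumption on my part for the example; the keys, values and query are hand-picked numbers, whereas in a trained network they would be learned or computed from the input.

```python
import math


def softmax(scores):
    """Turn arbitrary scores into positive weights that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Weight each value by how well its key matches the query."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    width = len(values[0])
    result = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(width)]
    return result, weights


# Three list items, each with a key (what it is) and a value (what it holds).
keys = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
query = [0.0, 1.0, 0.0]  # matches the second key

result, weights = attend(query, keys, values)
selected_index = weights.index(max(weights))
```

The list can be any length, and the same function works whether it holds three nodes or three hundred, which is what makes it a natural fit for graphs of varying size.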

In MacGraph, write attention is used to input a signal to the nodes in the graph based on the query and the properties of the nodes. This signal should then prime the graph message passing to interact with the nodes most relevant to the question.

Prime the graph with the query using write attention

After information has propagated through the graph’s nodes, attention is used to extract the answer from the graph:

Read from the graph using attention

Combining write and read attention with propagation between nodes in the graph using the graph structure, we get the core of MacGraph.
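To show how the three pieces fit together, here is a toy sketch of the write, propagate, read sequence. It is only meant to convey the flow of information and is not MacGraph’s implementation: the function names, the one-hot keys, the single propagation step and the `sharpness` parameter are all assumptions made for this example. The toy question is “which station is next to station a?” on a three-station line a - b - c.

```python
import math


def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def write_attention(query, node_keys):
    """Put a signal on each node according to how well it matches the query."""
    names = list(node_keys)
    weights = softmax([dot(query, node_keys[n]) for n in names])
    return dict(zip(names, weights))


def propagate(signal, adjacency):
    """Each node receives the summed signal of its neighbours."""
    return {n: sum(signal[m] for m in adjacency[n]) for n in adjacency}


def read_attention(signal, sharpness=10.0):
    """Attend over nodes by the signal they now hold."""
    names = list(signal)
    weights = softmax([sharpness * signal[n] for n in names])
    return dict(zip(names, weights))


adjacency = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
node_keys = {"a": [10.0, 0.0, 0.0], "b": [0.0, 10.0, 0.0], "c": [0.0, 0.0, 10.0]}
query = [1.0, 0.0, 0.0]  # stands in for an encoded question about station a

signal = write_attention(query, node_keys)   # prime the graph with the query
signal = propagate(signal, adjacency)        # move the signal along the edges
answer_weights = read_attention(signal)      # read the answer back out
answer = max(answer_weights, key=answer_weights.get)
```

Here every step is fixed by hand. In a trained network the query encoding, the propagation functions and the read-out would all be learned.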

MacGraph architecture

Conclusion

There is a strong case that achieving superhuman results on graph-based tasks requires graph-specific neural network architectures.

We have shown with MacGraph that a neural network can learn to extract properties from nodes within a graph in response to questions and that a neural network can learn to perform graph algorithms (such as finding the shortest path) on graphs that it has never encountered before.