How Does AI Understand Graphs?

time is technically a graph

Mark Cleverley
The Startup
6 min read · Oct 11, 2020

It is difficult to overstate exactly how versatile graphs are. Almost everything that exists can be represented as a graph. I’m not just talking the obvious cases like molecules or social networks, either:

3D modeling is essentially graph structure, with vertices of polygons comprising complex objects.

Fluid dynamics? Absolutely.

Architecture? You bet.

I’m quite certain that even sheet music can be represented by a graph, once someone figures out how to turn a profit from it.

this is about as dense as deriving linear algebra

These ubiquitous and powerful structures are quite difficult to capture, however, when it comes to machine learning. I pondered my frustrations on binary encoding these structures a while back, and concluded that our brains understand graphs by constructing a neuron-synapse graph in our own grey matter.

This is a poetic but fairly useless tautology, so I decided to investigate contemporary solutions before worrying about binary representations.
As it turns out, some extremely clever folks have already figured out how to feed graphs into neural networks to great effect.

I’ll be drawing heavily from Oren Wright’s talk at Carnegie Mellon last year.

Searching for structure

If I had to state the simplest reason graphs work well:
Context matters.

There’s plenty of information about you as a person.
But consider for a moment: how much of that information exists in relation to something else?
How much of you is connected to something or someone else in the world?

Most of you, I’d wager. The world is a series of interconnected systems, and the most crucial information is usually found by examining those connections.

For over a decade, deep AI networks have found great success processing binary, scalar and image data to predict or classify various outputs.
But graph data runs into issues inherent to its structure: irregularity.

Time series data can be thought of as any series of scalars, really. A temporal component can be explicitly read with LSTM or HTM networks, but a vector of numbers is clearly linear.

With images, we can simply flatten the pixels into a single vector. This works well because a 2D image has regular structure: tessellating squares form a larger rectangle, so we can reduce to one dimension while maintaining the relationships between pixels relatively easily.
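To make that concrete, here's a small sketch of why flattening a grid is harmless. The 3x4 "image" and the helper name `flat_index` are my own toy assumptions, not anything from Wright's talk. The point is that a pixel's neighbors always sit at the same fixed offsets in the flat vector:

```python
import numpy as np

# Hypothetical 3x4 grayscale "image"; the values are just 0..11 for readability.
height, width = 3, 4
image = np.arange(height * width).reshape(height, width)

# Row-major flattening: pixel (row, col) lands at index row * width + col.
flat = image.flatten()

def flat_index(row, col, width):
    """Position of pixel (row, col) in the flattened vector."""
    return row * width + col

row, col = 1, 2
center = flat_index(row, col, width)

# The right-hand neighbor is always +1 away, the neighbor below always +width.
right_neighbor = flat[center + 1]
below_neighbor = flat[center + width]

print(flat[center], right_neighbor, below_neighbor)
```

A graph offers no such fixed offsets, since the number and position of a node's neighbors change from node to node.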

Flattening a graph, however, will do away with all of the valuable relational information between nodes.

What’s so good about convolution, anyway?

Convolutional neural networks are at the forefront of image recognition AI, for some very good reasons:

  1. Fixed parameters: relatively low memory footprint
  2. Local kernel: allows construction of low to high hierarchies of information layers
  3. Spatially invariant: features/objects can be learned/predicted regardless of local ‘position’ in image

To accomplish this, they use the convolutional layer, which takes an image (pixel vector) and feeds it through a filter.

A convolutional filter is a small square matrix that you pass over pixels in an image to “boil down” the information into a smaller, condensed image.
Instead of looking at “what precise values are at these exact pixels?”, it looks at “what general values are around this neighborhood of pixels?”:

Using the convolutional filter I (kernel size 3x3) on a pixel in the image K, we can multiply the overlapping matrices element-wise and sum the products: 4 + 1 + 3 = 8.
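Here's a minimal sketch of that multiply-and-sum step in NumPy. The pixel values and filter weights below are made up for illustration (they are not the numbers from the original figure), and `convolve_valid` is a hypothetical helper name. Like most CNN libraries, it slides the filter without flipping it:

```python
import numpy as np

# Hypothetical 4x4 image and 3x3 filter; values chosen only for illustration.
image = np.array([
    [1, 0, 2, 1],
    [0, 1, 0, 0],
    [3, 0, 1, 2],
    [1, 1, 0, 1],
])
kernel = np.array([
    [1, 0, 1],
    [0, 1, 0],
    [1, 0, 1],
])

def convolve_valid(image, kernel):
    """Slide the kernel over every position where it fully overlaps the image."""
    k_rows, k_cols = kernel.shape
    out_rows = image.shape[0] - k_rows + 1
    out_cols = image.shape[1] - k_cols + 1
    out = np.zeros((out_rows, out_cols), dtype=image.dtype)
    for r in range(out_rows):
        for c in range(out_cols):
            patch = image[r:r + k_rows, c:c + k_cols]
            # Element-wise multiply the overlapping entries, then sum the products.
            out[r, c] = np.sum(patch * kernel)
    return out

condensed = convolve_valid(image, kernel)
print(condensed)
```

The 4x4 image boils down to a 2x2 summary, one value per neighborhood.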

This is trickier, but still possible, to do with graphs. We want to aggregate information in local neighborhoods of the graph, like we did with pixels in the image.

Time is really just a graph

Let’s go back to a time series vector. Time series data operates with steps, or shifts, to denote intervals between data measurement.
Let’s say that in the time series vector (1 -> 2 -> 3 -> 4 ->), 1 comes one shift before 2, which precedes 3, and 4 loops back to 1. We can draw a 4x4 shift matrix to represent the relationships between time series inputs.

If that shift matrix looks familiar, that’s because it’s technically also an adjacency matrix — a matrix that represents connected nodes on a graph. [column 1, row 2] = 1 indicates that there’s an edge, or connection, between the first and second nodes.
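As a rough sketch, the 4x4 shift matrix for that looping series can be built like this. The function name `cyclic_shift_matrix` is my own, and I'm following the [column, row] convention from the paragraph above:

```python
import numpy as np

def cyclic_shift_matrix(n):
    """Shift matrix for an n-step series whose last step loops back to the first."""
    S = np.zeros((n, n), dtype=int)
    for step in range(n):
        next_step = (step + 1) % n
        # Column = current step, row = the step it leads to.
        S[next_step, step] = 1
    return S

S = cyclic_shift_matrix(4)
print(S)

# Read as an adjacency matrix: every 1 is a directed edge "column -> row".
edges = [(int(col) + 1, int(row) + 1) for row, col in zip(*np.nonzero(S))]
print(edges)
```

Each row and column holds exactly one 1, which is just another way of saying every step has one predecessor and one successor.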

Multiplying the time series signal (vector representation) by its shift matrix gives us the time-shifted signal, which preserves a good deal of information for neural network operation.
Since graphs have their own signal and adjacency matrix, we can perform the same task to make a graph-shifted signal.
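A quick sketch of both shifts side by side. The time series is the (1, 2, 3, 4) example from above, while the little star-shaped graph is a made-up assumption of mine:

```python
import numpy as np

signal = np.array([1, 2, 3, 4])

# Cyclic shift matrix for the 4-step series (row 2, column 1 holds a 1, and so on).
shift = np.roll(np.eye(4, dtype=int), 1, axis=0)
time_shifted = shift @ signal    # each step now holds the previous step's value

# Hypothetical undirected graph: node 1 is connected to nodes 0, 2 and 3.
adjacency = np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 1],
    [0, 1, 0, 0],
    [0, 1, 0, 0],
])
graph_shifted = adjacency @ signal    # each node now holds the sum of its neighbors' values

print(time_shifted)
print(graph_shifted)
```

Same operation, different matrix: the time shift passes values along a ring, the graph shift passes them along whatever edges exist.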

This works because any linear shift-invariant filter can be represented as a polynomial of shifts. If you’d like to dig further into the math, Henry AI Labs has a great breakdown video that details the use of self-loops and degree matrices.
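To illustrate the "polynomial of shifts" idea, here's a small sketch that builds a filter as h0·I + h1·S + h2·S². The coefficients are arbitrary values I picked (a weighted average of the current and two previous steps), and `polynomial_filter` is a hypothetical helper name:

```python
import numpy as np

def polynomial_filter(shift, coeffs):
    """Build H = coeffs[0]*I + coeffs[1]*S + coeffs[2]*S^2 + ... from a shift matrix S."""
    n = shift.shape[0]
    H = np.zeros((n, n))
    power = np.eye(n)            # S^0
    for h in coeffs:
        H += h * power
        power = power @ shift    # next power of S
    return H

shift = np.roll(np.eye(4), 1, axis=0)    # cyclic time shift from the example above
signal = np.array([1.0, 2.0, 3.0, 4.0])

H = polynomial_filter(shift, coeffs=[0.5, 0.25, 0.25])
filtered = H @ signal
print(filtered)

# Shift invariance: filtering then shifting equals shifting then filtering.
commutes = np.allclose(shift @ (H @ signal), H @ (shift @ signal))
print(commutes)
```

Swap the cyclic shift for a graph's adjacency matrix and the same function gives a graph filter, which is the trick that lets convolution carry over.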

A graph-shifted signal allows us to run convolutional operations similar to an image kernel filter on graph data, through the magic of vector propagation:

Each node in this example graph contains three data points, together forming a data vector. Let’s say we want to figure out the neighborhood information of the pink node above.

To gather “neighborhood average information” in a similar way to convolutional filters, we can blend together (or simply average) neighboring node information to construct a node’s convolved data vector. The pink node’s propagated vector will then be 0.7, 0.3, 0.7.
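Here's a minimal sketch of that averaging step. The graph, the node values, and the choice to include the node's own vector in the average are all assumptions on my part; they are not the numbers from the figure:

```python
import numpy as np

# Hypothetical graph: node 0 is linked to nodes 1, 2 and 3; nodes 1 and 2 are also linked.
adjacency = np.array([
    [0, 1, 1, 1],
    [1, 0, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
], dtype=float)

# Each node carries three data points (made-up values).
features = np.array([
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0],
    [0.5, 0.5, 1.0],
])

def neighborhood_mean(adjacency, features, include_self=True):
    """Average each node's data vector with those of its neighbors."""
    A = adjacency.copy()
    if include_self:
        A = A + np.eye(A.shape[0])    # self-loop so a node keeps part of its own data
    counts = np.maximum(A.sum(axis=1, keepdims=True), 1.0)    # guard against isolated nodes
    return (A @ features) / counts

propagated = neighborhood_mean(adjacency, features)
print(propagated[0])
```

Node 0 ends up with a blend of its own vector and its three neighbors' vectors, much like a pixel picking up the flavor of its surroundings.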

A convolutional layer in a GCN might perform this process for each node before passing the ‘estimated’ graph through to the next layer. It gets much more complicated than this when we dive into multi-layered GCNs, but the same principles apply throughout.
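For a feel of what one such layer might look like, here's a rough NumPy sketch. It assumes self-loops, a symmetric degree normalization, a ReLU activation, and random untrained weights. Those are my choices for illustration, not necessarily the setup Wright describes, and `gcn_layer` is a hypothetical name:

```python
import numpy as np

def gcn_layer(adjacency, features, weights):
    """One graph convolution: aggregate neighbors, mix features with weights, apply ReLU."""
    n = adjacency.shape[0]
    a_hat = adjacency + np.eye(n)               # add self-loops
    degree = a_hat.sum(axis=1)                  # node degrees, counting the self-loop
    d_inv_sqrt = np.diag(degree ** -0.5)
    a_norm = d_inv_sqrt @ a_hat @ d_inv_sqrt    # symmetric degree normalization
    return np.maximum(a_norm @ features @ weights, 0.0)

rng = np.random.default_rng(0)

# Hypothetical 4-node graph with 3 data points per node.
adjacency = np.array([
    [0, 1, 1, 1],
    [1, 0, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
], dtype=float)
features = rng.random((4, 3))

# Untrained weights: 3 input features -> 4 hidden features -> 2 output features.
w1 = rng.standard_normal((3, 4))
w2 = rng.standard_normal((4, 2))

hidden = gcn_layer(adjacency, features, w1)
output = gcn_layer(adjacency, hidden, w2)
print(output.shape)
```

Stacking two layers means each node's output depends on neighbors up to two hops away, which is where the multi-layer complexity starts to creep in.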

Deep graph learning

There are two major tasks that graph convolutional networks (GCNs) perform:

  1. Node classification: based on labeled nodes, predict info about unlabeled nodes
  2. Graph classification: based on labeled graphs, predict info about new graphs

Wright gives an example of node classification GCNs with a citation network, where each node is an academic paper (labeled by field & topic) and each edge between nodes is a formal citation. The topics, information and application of new papers can then be effectively predicted by which papers they cite.

Graph classification is absolutely fascinating, since it’s the equivalent of image classification (or object recognition) that CNNs often perform.

Electroencephalogram (EEG) readings can be modeled as graphs and used to predict human emotion based on brain activity.
After training a graph network on labeled molecules, it can generate new proposed molecules that can aid in drug and material design.

This is quite radical to me, to be honest. Graphs have irregular structure that changes with information content. Like water, if you increase the volume you probably won’t be able to practically maintain the same shape.

Yet there we have it: by reconsidering the idea of structure, we can pass through every node in a graph and gather relational information, eventually making predictions of nodes (“filling in the blanks”) or making predictions of entire graphs (“inferring overall purpose”).

How we can encode this to a binary vector is another case entirely, but it’ll probably involve the adjacency matrix in some regard.

Mark Cleverley
The Startup

data scientist, machine learning engineer. passionate about ecology, biotech and AI. https://www.linkedin.com/in/mark-s-cleverley/