Do I know you? Flexible unsupervised and semi-supervised graph models with Deep Graph Infomax

Huon Wilson
Published in stellargraph, May 27, 2020
Photo by Sebastian Pena Lambarri on Unsplash

Imagine you could feed your data directly into a machine learning model and have it learn, without the need for any manual labelling. Imagine predictions in data-sparse environments could be improved without needing to label more data or even adjust the model structure. Thanks to Deep Graph Infomax — a graph machine learning algorithm that uses the graph structure to understand patterns in the data associated with each node — these are welcome realities.

Deep Graph Infomax is an unsupervised training procedure. A typical supervised task matches input data against labels, to learn patterns in the data that are associated with those labels. Deep Graph Infomax skips the labels, and instead guides a model to learn from how the input data points are connected as a graph, by understanding what should or shouldn’t be linked. It can thus be used both to compute unsupervised representations of these points and to pre-train a model, improving performance on semi-supervised tasks that have little labelled data.

A graph is a collection of nodes and the edges between them, where the edges represent some connection or relationship between the nodes. Graph machine learning takes conventional machine learning models like multilayer perceptrons and long short-term memory networks, which apply to single data points (representing individual nodes), and augments them to use the information provided by each data point’s connections.

The StellarGraph library implements Deep Graph Infomax as a flexible training procedure for most of the models StellarGraph supports, as one of several methods that work without supervision. StellarGraph is an open source, user-friendly library for graph machine learning built on TensorFlow Keras.

Step-by-step: true versus corrupted

The core idea is learning to distinguish between the nodes of two graphs:

  • A true graph, consisting of the real nodes and the edges that connect them, along with feature vectors associated with each node
  • A corrupted graph, where the nodes and edges have been changed in some manner

A Deep Graph Infomax training procedure starts with the graph G and has four components:

  • A corruption procedure C. The corruption procedure changes the true input graph G into the mutated version H = C(G). The paper suggests randomly shuffling the node features among the nodes: H has the same edges as G, but the features associated with each node differ.
  • An encoder E. The encoder takes an input graph and computes an embedding vector v for each node. It is typically some graph machine learning model such as GCN or GraphSAGE.
  • A read-out R. The read-out collapses the separate embedding vectors for each node in a graph into a single summary vector s for the whole graph, like s = R(E(G)). This can be as simple as the sum or average, or something more complicated.
  • A discriminator D. The discriminator compares a node embedding vector against the graph summary vector, like D(v, s), to yield a “score” between 0 and 1 for each node embedding vector.
The Deep Graph Infomax algorithm, as a flow chart (adapted from Figure 1 in the paper). The input data is fed in as a graph G in the top left corner.
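To make the four components concrete, here is a minimal NumPy sketch that defines each one and wires them together on a toy four-node graph. It assumes choices along the lines of the paper’s: a one-layer GCN-style encoder (with ReLU here for simplicity), a read-out that applies a logistic sigmoid to the mean node embedding, and a bilinear discriminator D(v, s) = sigmoid(v^T W s). The weights are random rather than trained, and every function name is hypothetical; this is an illustration, not StellarGraph’s implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def corrupt(features, rng):
    """Corruption C: shuffle feature rows among nodes; the edges are untouched."""
    return features[rng.permutation(features.shape[0])]

def normalise_adjacency(adj):
    """Symmetric GCN normalisation D^-1/2 (A + I) D^-1/2."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def encode(adj_norm, features, weights):
    """Encoder E: one GCN-style layer, relu(A_norm X W); one row per node."""
    return np.maximum(adj_norm @ features @ weights, 0.0)

def read_out(embeddings):
    """Read-out R: sigmoid of the mean node embedding, giving the summary vector s."""
    return sigmoid(embeddings.mean(axis=0))

def discriminate(embeddings, summary, w_disc):
    """Discriminator D: bilinear score sigmoid(v^T W s) for every embedding row v."""
    return sigmoid(embeddings @ w_disc @ summary)

rng = np.random.default_rng(0)
# toy graph: 4 nodes in a path 0-1-2-3, with 3 features per node
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
features = rng.normal(size=(4, 3))
w_enc = rng.normal(size=(3, 8))    # encoder weights (untrained)
w_disc = rng.normal(size=(8, 8))   # discriminator weights (untrained)

adj_norm = normalise_adjacency(adj)
true_emb = encode(adj_norm, features, w_enc)                   # E(G)
corrupt_emb = encode(adj_norm, corrupt(features, rng), w_enc)  # E(H): same edges, shuffled features
s = read_out(true_emb)                                         # s = R(E(G))
true_scores = discriminate(true_emb, s, w_disc)                # D(v, s) for v in E(G)
corrupt_scores = discriminate(corrupt_emb, s, w_disc)          # D(v, s) for v in E(H)
```

Note that the corrupted graph reuses the same normalised adjacency matrix: only the feature rows move, which is exactly what makes the neighbourhoods “wrong”.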

Starting with an input “true” graph G, the forward pass of a Deep Graph Infomax model mirrors those components very closely:

  1. Corrupt G to a new graph: H = C(G)
  2. Encode each node of both of these graphs: E(G) and E(H)
  3. Summarise the true graph into a summary vector: s = R(E(G))
  4. Score the encoded embedding vectors of both G and H, using the discriminator and G’s summary vector s: D(v, s) for v in E(G) and E(H)
  5. Collate all the scores in a loss function (equation (1), in the paper) that tries to maximise D(v, s) if v is the embedding vector of a true node and minimise it if v is the embedding vector of a corrupted node.
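Step 5 can be written as a binary cross-entropy, with target 1 for every true node and target 0 for every corrupted node. Averaged over all N true and M corrupted scores, this is the negative of the paper’s objective (the average of log D(v, s) over true nodes plus log(1 − D(v, s)) over corrupted nodes), so minimising it is the same as maximising that objective. The sketch below assumes the discriminator scores have already been computed; `dgi_loss` is a hypothetical name.

```python
import numpy as np

def dgi_loss(true_scores, corrupt_scores, eps=1e-12):
    """Binary cross-entropy: target 1 for true nodes, 0 for corrupted nodes.

    Minimising this pushes D(v, s) towards 1 for true nodes and towards 0
    for corrupted nodes. Scores are clipped to avoid log(0).
    """
    true_scores = np.clip(np.asarray(true_scores, dtype=float), eps, 1 - eps)
    corrupt_scores = np.clip(np.asarray(corrupt_scores, dtype=float), eps, 1 - eps)
    total = np.log(true_scores).sum() + np.log(1 - corrupt_scores).sum()
    return -total / (true_scores.size + corrupt_scores.size)

# a discriminator that separates the two groups well has a low loss...
good = dgi_loss([0.9, 0.95, 0.8], [0.1, 0.05, 0.2])
# ...while one that cannot tell them apart (every score 0.5) has loss log(2)
chance = dgi_loss([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
```

In practice, libraries usually compute this from raw logits (before the sigmoid) for numerical stability, rather than from probabilities as done here.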

After training, the discriminator gives scores close to 1 for the true nodes and scores close to 0 for the corrupted nodes. Along the way, the weights of the encoder model are trained to be useful for this distinguishing step. Once training has completed, the encoder model can be used independently to compute node embedding vectors directly.

The algorithmic description tells us the individual steps, but it doesn’t give us a great intuition for what’s happening.

Intuition: identifying sensible connections

Deep Graph Infomax learns to tell which nodes should be connected. By shuffling the node features, nodes end up connected in “strange” ways. The model learns to distinguish between nodes that have sensible connections and nodes that have weird or unexpected connections.

Most graph algorithms summarise the neighbourhood of a node to compute a representation or embedding vector. The shuffling will change the shape and content of those neighbourhoods, in a way that the model/encoder E can learn to capture.

A simplified family tree could be a graph of people with edges from (biological) parent to child. The true graph G has patterns associated with the edges like “source age > target age”, while the corrupted graph H does not. Deep Graph Infomax can help a model learn the difference.

For example, in a (biological) family tree with ages included as data, a node (person) will be connected to two older nodes (parents) and zero or more younger nodes (children). If the edges are directed from parent to child, there are rules like parents being older than their children. Even without edge direction, there are still rules like being connected to only two older nodes. There are also softer trends, such as the age gap between parent and child being large enough, but not too large, reflecting the reasonable range of ages at which people have children.

If the features (ages) are shuffled, these rules and patterns are likely to disappear. The model E’s computed embedding vectors summarise each node’s neighbourhood, and they will reflect this change.
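The effect of shuffling on the “parents are older than their children” rule can be checked directly. This plain-Python sketch uses a made-up six-person family (names and ages are hypothetical) and counts the fraction of parent-to-child edges where the parent is older, first with the true ages and then after shuffling the ages among the people while keeping the edges fixed, as the corruption procedure does.

```python
import random

# hypothetical three-generation family; edges point from parent to child
ages = {"ann": 82, "bob": 80, "cat": 55, "dan": 52, "eve": 28, "fay": 25}
edges = [("ann", "cat"), ("bob", "cat"), ("cat", "eve"), ("dan", "eve"),
         ("cat", "fay"), ("dan", "fay")]

def fraction_parent_older(ages, edges):
    """Fraction of parent -> child edges where the parent is older."""
    return sum(ages[parent] > ages[child] for parent, child in edges) / len(edges)

def shuffle_features(ages, rng):
    """Corruption: reassign the same ages to different people; edges are unchanged."""
    people = list(ages)
    values = list(ages.values())
    rng.shuffle(values)
    return dict(zip(people, values))

rng = random.Random(0)
true_fraction = fraction_parent_older(ages, edges)  # 1.0: every parent is older
shuffled = [fraction_parent_older(shuffle_features(ages, rng), edges)
            for _ in range(1000)]
# with distinct ages, each edge becomes a coin flip, so this averages close to 0.5
average_shuffled = sum(shuffled) / len(shuffled)
```

A neighbourhood-summarising encoder sees this same breakdown, which is the signal the discriminator learns to pick up.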

The summary vector s serves as a reference point for the comparison, as the “average” true neighbourhood.

What’s the application?

A model trained with Deep Graph Infomax will yield representations or embedding vectors for each node. These embedding vectors can be used for tasks like clustering, community detection and nearest neighbour searches.
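As an illustration of a nearest neighbour search, this NumPy sketch ranks nodes by cosine similarity to a query node. It assumes the embedding vectors have already been computed as the rows of a matrix; the toy vectors and the `nearest_neighbours` helper are hypothetical.

```python
import numpy as np

def nearest_neighbours(embeddings, query_index, k):
    """Return the indices of the k nodes most cosine-similar to the query node."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)  # normalise each row to length 1
    similarity = unit @ unit[query_index]            # cosine similarity to the query
    similarity[query_index] = -np.inf                # exclude the query node itself
    return np.argsort(-similarity)[:k]

# toy embedding vectors for 5 nodes: nodes 0-2 point one way, nodes 3-4 another
embeddings = np.array([[1.0, 0.1],
                       [0.9, 0.2],
                       [1.0, 0.0],
                       [0.0, 1.0],
                       [0.1, 0.9]])
print(nearest_neighbours(embeddings, query_index=0, k=2))  # prints [2 1]
```

For large graphs, an approximate nearest neighbour index would replace the full similarity computation, but the idea is the same.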

They can also be used as input to downstream tasks, like node classification or regression, where the vectors become inputs to another model. This does require labels for the training nodes, but the embedding vectors already capture useful information, so fewer labels might be necessary.
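To illustrate embedding vectors feeding another model, this sketch fits a nearest-centroid classifier, a deliberately simple stand-in for a real downstream model such as logistic regression. The embeddings and labels are toy values, and only two of the six nodes are labelled; the helper names are hypothetical.

```python
import numpy as np

def fit_centroids(embeddings, labels):
    """Average the embedding vectors of the labelled nodes in each class."""
    classes = np.unique(labels)
    centroids = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
    return classes, centroids

def predict(classes, centroids, embeddings):
    """Assign each node to the class whose centroid is closest (Euclidean distance)."""
    distances = np.linalg.norm(embeddings[:, None, :] - centroids[None, :, :], axis=2)
    return classes[np.argmin(distances, axis=1)]

# toy embeddings for 6 nodes; only nodes 0 and 3 have labels
embeddings = np.array([[0.0, 1.0], [0.1, 0.9], [0.2, 1.1],
                       [1.0, 0.0], [0.9, 0.1], [1.1, 0.2]])
labelled = np.array([0, 3])
labels = np.array(["a", "b"])
classes, centroids = fit_centroids(embeddings[labelled], labels)
# nodes 0-2 are predicted as "a" and nodes 3-5 as "b"
print(predict(classes, centroids, embeddings))
```

This only works because the toy embeddings already separate the two groups; that separation is what a good unsupervised embedding is meant to provide.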

Finally, the model E trained by Deep Graph Infomax can also be further trained and fine-tuned. The model already captures useful information, so again, fewer labels are necessary to get good results. This turns a supervised task into a semi-supervised one, where training with labels benefits from the existing patterns already captured by Deep Graph Infomax.

Node classification is a task that can benefit from pre-training with Deep Graph Infomax. StellarGraph’s demo of GCN, Deep Graph Infomax and fine-tuning works with just eight training examples, with one or two from each of the seven classes of nodes in the Cora dataset. Pre-training with Deep Graph Infomax and then fine-tuning with the training set gives dramatically better accuracy than just training a model directly on the training set. (Accuracies are averaged across 50 runs; the standard error of both is about 0.8 percentage points.)
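For reference, the standard error of a mean accuracy over repeated runs is the sample standard deviation divided by the square root of the number of runs. A minimal sketch, using a few made-up accuracies rather than the demo’s actual results:

```python
import math
import statistics

def standard_error(values):
    """Standard error of the mean: sample standard deviation / sqrt(n)."""
    return statistics.stdev(values) / math.sqrt(len(values))

# hypothetical accuracies (in percent) from four runs; not real results
accuracies = [70.0, 72.0, 74.0, 76.0]
mean = statistics.mean(accuracies)  # 73.0
se = standard_error(accuracies)     # sqrt(20 / 3) / 2, roughly 1.29
```

Because of the square root, averaging over 50 runs shrinks the standard error by a factor of about 7 compared with the spread of a single run.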

How to get started?

In StellarGraph, using Deep Graph Infomax requires three pieces:

  • A base model and its data generator (this is the encoder E)
  • The corrupted data generator to do the feature shuffling (this is the corruption function C)
  • The Deep Graph Infomax model itself which coordinates the whole process including encoding, summarising and discriminating (this includes both the read-out R and the discriminator D).

Given input data in a graph G, each of these points translates directly into a short snippet of code. There’s a narrated demo that walks through the procedure, which we summarise here.

The base model is created in the same way whether it will be trained without supervision via Deep Graph Infomax or with labels in the usual supervised way. For instance, we can construct a one-layer GCN model:

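from stellargraph.mapper import FullBatchNodeGenerator
from stellargraph.layer import GCN

# `graph` is a StellarGraph object holding the nodes, edges and node features
base_generator = FullBatchNodeGenerator(graph)
base_model = GCN(
    layer_sizes=[128],  # a single GCN layer with 128 output units
    activations=["relu"],
    generator=base_generator,
)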

Then, we need to create a data generator that shuffles the node features. This is done with CorruptedGenerator, which automatically knows how to shuffle data for many of StellarGraph’s generators, including the FullBatchNodeGenerator we use for GCN:

from stellargraph.mapper import CorruptedGenerator
corrupted_generator = CorruptedGenerator(base_generator)

Finally, we create the Deep Graph Infomax model itself, using the base GCN model and the corrupted generator. This model’s in_out_tensors method yields the inputs and outputs required to build a TensorFlow Keras model, which can be compiled and then trained with the conventional APIs like fit, using data from the corrupted generator’s flow method:

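import tensorflow as tf
from stellargraph.layer import DeepGraphInfomax
infomax = DeepGraphInfomax(base_model, corrupted_generator)
dgi_in, dgi_out = infomax.in_out_tensors()
dgi_model = tf.keras.Model(inputs=dgi_in, outputs=dgi_out)
dgi_model.compile(loss=tf.nn.sigmoid_cross_entropy_with_logits, optimizer="adam")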

Once the base model is trained via Deep Graph Infomax, it can be used to create a separate Keras model. This model can then be further fine-tuned with fit, or used to compute embedding vectors with predict:

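embedding_in, embedding_out = base_model.in_out_tensors()
embedding_model = tf.keras.Model(
    inputs=embedding_in,
    # full-batch models output a leading batch dimension of size 1; drop it
    outputs=tf.squeeze(embedding_out, axis=0),
)
embeddings = embedding_model.predict(base_generator.flow(graph.nodes()))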

That’s it! With a couple of lines of code, we’ve connected a base GCN model into Deep Graph Infomax so that it can be trained without supervision, and then used the trained model to yield node embedding vectors. This is the same basic process used when fine-tuning the trained model for end-to-end node classification.

In summary —

Deep Graph Infomax is a procedure for training graph machine learning models without supervision. It trains a model to capture patterns in the connections between nodes and their features, by contrasting true nodes against corrupted/shuffled nodes. It does this without requiring any manual labels, but can be beneficial as an initialisation or pre-training procedure when one does have some labelled data.

StellarGraph implements it as a training procedure for many of its graph machine learning models. Run pip install stellargraph, and jump into the demos of Deep Graph Infomax for computing node embeddings and for pre-training for data-scarce environments.

The Deep Graph Infomax algorithm was initially implemented in StellarGraph by Kieran Ricardo. We’re working on generalising our Deep Graph Infomax support to easily work with other algorithms, like Relational GCN, Cluster-GCN and even generalise from nodes to whole graphs.

This work is supported by CSIRO’s Data61, Australia’s leading digital research network.
