Graph Neural Networks for System Interaction Inference

Aljubrmj
Stanford CS224W GraphML Tutorials
Jan 16, 2022
Unsupervised adjacency matrix prediction using graph neural networks.

This blog post was authored by Mohammad (Jabs) Aljubran as part of the Stanford CS224W course project, and is mostly based on T. Kipf, E. Fetaya, K. Wang, M. Welling and R. Zemel, Neural Relational Inference for Interacting Systems (2018). All results can be reproduced using this public Google Colab Notebook.

Overview

Think of a social network: you are given information about its users and how they interact with one another, and you are asked to figure out which links or relationships, if any, exist between them. In other words, your task is to uncover the interplay between the components of a complex dynamical system. Such systems appear in many fields, including sociology, sports, physics, and biology.

In physics, the challenge is that we often do not know 1) the physical laws governing these dynamics, and/or 2) the system's conditions and properties. Instead, we have access to historical measurements and observations from installed sensors and devices. Researchers are often interested in modeling the physical interactions between system components and forecasting their behavior [1].

This can very well be viewed as a graph problem! For example, think of multiple balls bouncing inside a box, some of which are connected by springs. Below is a demonstration of a very simple case with only two springs. Can you guess which balls are linked by springs?

Example trajectories of balls bouncing inside a box, where some are connected by springs. Actually, there exist two springs in this simulation. Can you guess which balls the two springs connect? Find the answer below.

Answer: one spring is between Ball-1 and Ball-5 and the other is between Ball-2 and Ball-3, while Ball-4 is totally isolated. See graph drawing below.

Graphical representation of balls (nodes) bouncing in a box with springs (edges) connecting some of them. In this case, there are two springs connecting (Ball-1, Ball-5) and (Ball-2, Ball-3).
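The ground-truth graph in this example can be written as a 5×5 symmetric adjacency matrix, which is exactly the object both models below try to recover. A minimal NumPy sketch, using 0-based indices (so Ball-1 is index 0); the variable names are illustrative:

```python
import numpy as np

n_balls = 5
springs = [(1, 5), (2, 3)]  # (Ball-i, Ball-j) pairs from the example, 1-based

adj = np.zeros((n_balls, n_balls), dtype=int)
for i, j in springs:
    # Springs are undirected, so set both (i, j) and (j, i).
    adj[i - 1, j - 1] = adj[j - 1, i - 1] = 1

print(adj)
```

Ball-4 (index 3) has an all-zero row and column, which is how "totally isolated" shows up in the matrix.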

Given the observed trajectories of these balls, your task is to find out which balls are connected by springs. I will demonstrate two unsupervised approaches that predict springs (graph edges) by learning to forecast ball (graph node) trajectories. The first approach, “multi-simulation”, trains a model on many simulations, while the second, “single-simulation”, trains a model on a single simulation only. The latter setting is common in real-world applications where physical events cannot be repeated or simulated many times, which is why I introduce it here. In this post, you will learn how to achieve this task using neural relational inference (NRI) models based on graph neural networks (GNNs)!

Chart explaining how multi-simulation and single-simulation models are different during the training and evaluation phases.

Outline

Now we are ready to find out how GNNs can solve this problem! This blog post will cover the following:

  • Dataset: description of how to generate your own dataset by simulating bouncing balls inside a box using physical laws and governing equations.
  • Graph Neural Networks: a brief explanation of the concept behind GNNs.
  • Multi-Simulation Model: description and results of the GNN variational autoencoder as implemented in PyTorch based on [2].
  • Single-Simulation Model: description and results of training a GNN to predict a weighted adjacency matrix, implemented in PyG (disclaimer: I came up with this approach and its PyG implementation).

Dataset

The dataset is generated with a physics-based simulator. Say you have N balls bouncing inside a 2D box, where each pair of balls is randomly connected (or not) by a spring. Each ball is initialized with a location vector sampled from the Gaussian 𝒩(0, 0.5) and a random velocity vector with norm 0.5. Apart from elastic collisions with the box walls, no external forces act on the balls. When a pair of balls is connected by a spring, the force Fᵢⱼ exerted on ball vᵢ by ball vⱼ follows Hooke’s law, Fᵢⱼ = −k(xᵢ − xⱼ), where k is the spring constant and xᵢ is the Cartesian location vector of ball vᵢ. The ball trajectories are then simulated by solving the ordinary differential equations given by Newton’s equations of motion. The simulator is written in Python [2] and uses leapfrog integration to solve this system of equations numerically.
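A minimal NumPy sketch of such a simulator follows. It is not the simulator from [2]: it assumes unit masses, zero-rest-length springs, reflecting walls at ±box, and a kick-drift-kick leapfrog step, reads 𝒩(0, 0.5) as having standard deviation 0.5, and uses illustrative values for the box size, spring constant, time step, and sampling interval.

```python
import numpy as np

def simulate_springs(n_balls=5, n_steps=5000, dt=0.001, spring_k=0.1, box=5.0,
                     sample_every=100, seed=0):
    """Simulate balls in a 2D box and return (positions, velocities, adjacency).

    Sketch only: unit masses, zero-rest-length Hooke springs, elastic wall
    bounces, leapfrog (kick-drift-kick) integration. Parameter values are assumed.
    """
    rng = np.random.default_rng(seed)
    # Random symmetric 0/1 adjacency without self-loops: each pair gets a spring or not.
    upper = np.triu(rng.integers(0, 2, size=(n_balls, n_balls)), k=1)
    adj = upper + upper.T
    # Initial positions ~ Gaussian (std 0.5 assumed); velocities rescaled to norm 0.5.
    x = rng.normal(0.0, 0.5, size=(n_balls, 2))
    v = rng.normal(size=(n_balls, 2))
    v = 0.5 * v / np.linalg.norm(v, axis=1, keepdims=True)

    def forces(pos):
        # Force on ball i: F_i = -k * sum_j A_ij (x_i - x_j).
        diff = pos[:, None, :] - pos[None, :, :]  # diff[i, j] = x_i - x_j
        return -spring_k * (adj[:, :, None] * diff).sum(axis=1)

    frames_x, frames_v = [], []
    f = forces(x)
    for step in range(n_steps):
        v = v + 0.5 * dt * f  # half kick
        x = x + dt * v        # drift
        # Elastic collisions with the walls at +/- box: mirror position, flip velocity.
        over, under = x > box, x < -box
        x[over] = 2 * box - x[over]
        v[over] = -v[over]
        x[under] = -2 * box - x[under]
        v[under] = -v[under]
        f = forces(x)
        v = v + 0.5 * dt * f  # second half kick
        if step % sample_every == 0:
            frames_x.append(x.copy())
            frames_v.append(v.copy())
    return np.stack(frames_x), np.stack(frames_v), adj

locs, vels, adj = simulate_springs(n_balls=5)
print(locs.shape)  # (50, 5, 2): sampled frames, balls, (x, y)
```

Leapfrog is a natural choice here because it is cheap (one force evaluation per step) and, being symplectic, keeps the total energy of the spring system from drifting over long rollouts.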

To keep data generation fast without sacrificing much accuracy, I limited the number of training, validation, and test simulations to 100, 25, and 25, respectively. Although these simulations can be customized, I followed the simulation parameters used in [2]. Preprocessing was limited to normalizing the data to the range [−1, 1] using statistics from the training split.
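The normalization is a min-max rescaling. A sketch with hypothetical helper names, assuming each quantity (for example positions and velocities) is rescaled separately using the minimum and maximum of its own training split:

```python
import numpy as np

def fit_minmax(train):
    # Use training-split statistics only, so nothing leaks from valid/test.
    return float(train.min()), float(train.max())

def to_unit_range(data, lo, hi):
    # Linearly map [lo, hi] onto [-1, 1].
    return (data - lo) * 2.0 / (hi - lo) - 1.0

def from_unit_range(data, lo, hi):
    # Inverse map, e.g. to convert forecasts back to box coordinates.
    return (data + 1.0) * (hi - lo) / 2.0 + lo

train_locs = np.array([[-2.0, 0.5], [3.0, 1.0]])  # toy training positions
test_locs = np.array([[0.5, 4.0]])                # toy test positions
lo, hi = fit_minmax(train_locs)
print(to_unit_range(test_locs, lo, hi))
```

Because the statistics come from the training split, validation or test values can fall slightly outside [−1, 1]; here the test value 4.0 maps to 1.4.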

Graph Neural Networks

In the described dataset, the goal is to predict, in an unsupervised manner, whether any two balls are connected by a spring. In both the multi-simulation encoder and the single-simulation GNN architectures, a spring is first hypothesized between every pair of balls (a fully connected graph). The model then learns the likelihood of each hypothesis, thereby predicting a weighted adjacency matrix whose entries indicate the probability of each edge. GNNs are particularly suitable for this task because they are built from permutation-invariant/equivariant functions [3], so their predictions do not depend on the arbitrary order in which the balls are numbered. GNNs are based on a node-to-node, two-step process: message passing and aggregation. Here, I briefly present a more general formulation that incorporates features of both nodes and edges.
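In code, the "every pair is a candidate spring" hypothesis is just the list of all ordered node pairs, and the model's output is one probability per pair that can be scattered back into a weighted adjacency matrix. A minimal NumPy sketch; the function names are hypothetical:

```python
import numpy as np

def fully_connected_edges(n):
    # Every ordered pair (i, j) with i != j is a candidate spring (no self-loops).
    senders, receivers = np.nonzero(~np.eye(n, dtype=bool))
    return senders, receivers

def to_weighted_adjacency(edge_probs, senders, receivers, n):
    # Scatter one probability per candidate edge into an n x n matrix.
    adj = np.zeros((n, n))
    adj[senders, receivers] = edge_probs
    return adj

senders, receivers = fully_connected_edges(5)
print(len(senders))  # 5 * 4 = 20 candidate directed edges
```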

Formally, take a graph 𝒢 = (𝒱, ℰ) with vertices v ∈ 𝒱 and edges e ∈ ℰ. Let γˡᵢ and γˡ₍ᵢ,ⱼ₎ be the lᵗʰ-layer property vectors of node vᵢ and edge e₍ᵢ,ⱼ₎, where γ can be a feature vector (h) or a message vector (m). Also, let ηˡᵥ and ηˡₑ be the lᵗʰ-layer node and edge transformation functions, where η can be a message function (MSG), an aggregation function (AGG), or a nonlinear function (f). We will use these symbols as we explain GNNs visually.

Consider the graph below, to which you want to apply a two-layer GNN with Ball-5 as the target node. For demonstration purposes, let us focus on the links between Ball-2 and its children in the second layer. Step 1: Start by drawing the subtree structure. Then, initialize the node and edge feature vectors to the input features, hᵢ⁽⁰⁾ = xᵢ and h₍ᵢ,ⱼ₎⁽⁰⁾ = x₍ᵢ,ⱼ₎ for all i, j ∈ [1, N].

Step 1 requires the construction of subtree structures, which can also be used as computational graphs.
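Step 1 can be sketched in plain Python. On the fully connected hypothesis graph, the two-layer subtree for Ball-5 has the other four balls as first-layer children, and each of those has its own four neighbours (including Ball-5 itself) as second-layer children. The helper names and the nested-tuple representation are illustrative assumptions:

```python
def computation_tree(target, neighbors, depth):
    """Unroll the depth-hop computation tree (subtree) rooted at `target`.

    Returns a nested (node, [children]) tuple. Nodes reappear at deeper levels,
    which is expected: the tree describes computation, not the graph itself.
    """
    if depth == 0:
        return (target, [])
    return (target, [computation_tree(u, neighbors, depth - 1) for u in neighbors[target]])

def count_leaves(tree):
    node, children = tree
    return 1 if not children else sum(count_leaves(c) for c in children)

# Fully connected hypothesis over Balls 1..5, two GNN layers, target Ball-5.
balls = range(1, 6)
neighbors = {i: [j for j in balls if j != i] for i in balls}
tree = computation_tree(5, neighbors, depth=2)
print([child for child, _ in tree[1]])  # first-layer children of Ball-5
```

The tree has 4 × 4 = 16 leaves, which shows why deeper GNNs on dense graphs are computed layer by layer over the whole graph rather than by literally expanding each subtree.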

You may now think of the subtree structure as a computational graph. Step 2: For each node in the subtree, apply the node message transformation. Step 3: Compute node-to-edge message passing by applying edge aggregation followed by a nonlinear transformation. Step 4: For each edge in the subtree, apply the edge message transformation. Step 5: Compute edge-to-node message passing by applying node aggregation followed by a nonlinear transformation. Note that skip connections are also incorporated in this description.

Steps 2–5 of the node-to-node, two-step GNN process of message passing and aggregation.
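Steps 2–5 can be written as one vectorized layer over all nodes and edges at once. The sketch below is a hedged NumPy illustration, not the implementation from [2]: it assumes concatenation as the edge aggregation AGG_e, summation over incoming edges as the node aggregation AGG_v, skip connections implemented by concatenating the previous-layer features, and randomly initialized one-layer MLPs standing in for the trained MSG and f functions.

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp(in_dim, out_dim):
    # One-layer ReLU MLP with random weights (a stand-in for a trained network).
    w = rng.normal(0.0, 0.1, size=(in_dim, out_dim))
    return lambda z: np.maximum(z @ w, 0.0)

def gnn_layer(h_nodes, h_edges, senders, receivers, fns):
    msg_v, f_e, msg_e, f_v = fns
    m_nodes = msg_v(h_nodes)                                     # Step 2: node messages
    edge_in = np.concatenate(
        [m_nodes[senders], m_nodes[receivers], h_edges], axis=1)  # AGG_e = concat (+ skip)
    h_edges = f_e(edge_in)                                        # Step 3: node-to-edge
    m_edges = msg_e(h_edges)                                      # Step 4: edge messages
    agg = np.zeros((h_nodes.shape[0], m_edges.shape[1]))
    np.add.at(agg, receivers, m_edges)                            # AGG_v = sum of incoming edges
    h_nodes = f_v(np.concatenate([agg, h_nodes], axis=1))         # Step 5: edge-to-node (+ skip)
    return h_nodes, h_edges

n, d_x, d_e, d_h = 5, 4, 1, 16
senders, receivers = np.nonzero(~np.eye(n, dtype=bool))  # fully connected hypothesis
h_nodes = rng.normal(size=(n, d_x))      # h_i^(0) = x_i, e.g. position and velocity
h_edges = np.zeros((len(senders), d_e))  # h_(i,j)^(0) = x_(i,j); zeros here (assumed)
layers = [(mlp(d_x, d_h), mlp(2 * d_h + d_e, d_h), mlp(d_h, d_h), mlp(d_h + d_x, d_h)),
          (mlp(d_h, d_h), mlp(3 * d_h, d_h), mlp(d_h, d_h), mlp(2 * d_h, d_h))]
for fns in layers:
    h_nodes, h_edges = gnn_layer(h_nodes, h_edges, senders, receivers, fns)
print(h_nodes.shape, h_edges.shape)  # (5, 16) (20, 16)
```

Running the layer twice, as in the loop at the bottom, evaluates the two-layer computation graph for every node simultaneously, including Ball-5.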

This completes one round of node-to-node message passing and aggregation. In practice, these steps are applied to every node in the graph, once per GNN layer. The sketch below summarizes this graph computation.

Summary of how to construct a computational graph based on subtree structures for a target node.

Multi-Simulation Model

Summary of the NRI variational autoencoder using GNN architectures, as introduced in [2].

Assuming the reader is familiar with variational autoencoders [4], I will only provide a brief description. Denoting all observed trajectories of a single simulation by x, the encoder q(z|x), parameterized by ϕ, outputs a factorized distribution over discrete latent variables z, where zᵢⱼ represents the type of edge e₍ᵢ,ⱼ₎ (i.e., whether or not e₍ᵢ,ⱼ₎ exists in the dataset of interest). Note that this technique can also model multiple edge types, but that is not needed here, since we only care about whether or not springs exist. To enable backpropagation through the discrete samples, q(z|x) is replaced by a continuous relaxation (the Concrete distribution), which allows the reparameterization trick to be applied [5]. Denoting by xᵗ the features of all N balls at time t, the decoder forecasts ball trajectories by modeling p(xᵗ⁺¹|xᵗ, xᵗ⁻¹, …, x¹, z), parameterized by θ, using a GNN whose adjacency matrix is sampled from the learned encoder distribution q(z|x). As with other variational autoencoders, the model is trained by maximizing the evidence lower bound (ELBO). For simplicity, and to allow for faster training, I use multilayer perceptrons (MLPs) as the nonlinear transformations throughout this implementation.
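Two ingredients of this paragraph are easy to show in isolation: drawing relaxed edge-type samples with the Concrete (Gumbel-softmax) distribution [5], and a negative-ELBO loss made of a reconstruction term and a KL term. The NumPy sketch below assumes a Gaussian decoder likelihood with fixed variance and a uniform prior over edge types; the temperature `tau` and variance `var` are illustrative values, not taken from this post. In a framework with automatic differentiation, PyTorch for instance provides `torch.nn.functional.gumbel_softmax` for the sampling step.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def gumbel_softmax(logits, tau=0.5):
    # Concrete / Gumbel-softmax relaxation: add Gumbel(0, 1) noise, then apply a
    # temperature-scaled softmax. As tau -> 0, samples approach one-hot edge types.
    u = rng.uniform(1e-10, 1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    return softmax((logits + gumbel) / tau)

def negative_elbo(pred, target, edge_probs, var=5e-5):
    # Reconstruction: Gaussian negative log-likelihood with fixed variance (constants dropped).
    nll = ((pred - target) ** 2 / (2 * var)).sum()
    # KL(q(z|x) || uniform prior over K edge types), summed over all edges.
    k = edge_probs.shape[-1]
    kl = (edge_probs * (np.log(edge_probs + 1e-16) + np.log(k))).sum()
    return nll + kl

logits = rng.normal(size=(20, 2))  # 20 candidate edges, 2 types: no spring / spring
z = gumbel_softmax(logits)         # relaxed samples; each row sums to 1
spring_weight = z[:, 1]            # soft "spring" indicator passed to the decoder
```

Minimizing this quantity is equivalent to maximizing the ELBO; a lower temperature makes samples closer to hard 0/1 edges at the price of noisier gradients.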

Using the dataset above to train the described GNN variational autoencoder, we can learn to predict which balls are connected by springs with a test accuracy of 88.5%! Below are test examples comparing the ground-truth and predicted adjacency matrices (1: spring exists, 0: spring does not exist). At the same time, the model can forecast the ball trajectories; below I show a forecast for one test simulation. Note that, with 10-timestep teacher forcing, the forecasts are also decent! For improved results, you may use recurrent neural networks (RNNs) as encoders and decoders, at the cost of extra compute.

Adjacency matrix prediction performance of the GNN variational autoencoder. Note that the outputs lie in [0, 1], where values greater than 0.5 indicate that a spring exists between the respective balls.
Trajectory forecast performance of the GNN variational autoencoder. Note that we are using 10-timestep teacher forcing in this demonstration.
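Here, "10-timestep teacher forcing" is read, as an assumption about the setup, to mean that the decoder receives the observed state every 10 steps and otherwise feeds its own predictions back in, so errors can accumulate for at most 10 steps. A model-agnostic sketch, where `step_fn` is a hypothetical one-step decoder:

```python
import numpy as np

def rollout_with_teacher_forcing(traj, step_fn, teacher_every=10):
    """Forecast traj[1:] one step at a time, resetting to ground truth periodically.

    traj: array of shape (T, N, D) with observed states.
    step_fn: hypothetical one-step decoder mapping a state (N, D) to the next state.
    """
    preds = []
    state = traj[0]
    for t in range(1, traj.shape[0]):
        if (t - 1) % teacher_every == 0:
            state = traj[t - 1]  # teacher forcing: feed the observed state
        state = step_fn(state)   # otherwise the previous prediction is fed back
        preds.append(state)
    return np.stack(preds)
```

For example, with a slightly biased `step_fn` such as `lambda s: s + 1.1` on a trajectory that advances by exactly 1 per step, the error grows to about 1.0 just before each reset and drops back to about 0.1 right after it.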

Single-Simulation Model

In this part, the model learns to predict the springs of a given simulation without prior training on any other simulation runs, which makes the task significantly more challenging. I used PyG to implement a graph convolutional network (GCN) [6] such that the model learns the adjacency matrix as part of the forecasting process. The code below shows the GNN block I developed using the PyG framework. Note how this GNN first hypothesizes that the input graph is fully connected, then learns a weighted adjacency matrix that rejects some of these hypotheses and predicts where springs exist, if any. By construction, this dataset involves symmetric adjacency matrices, so symmetry is enforced algorithmically in line 32 of the code snippet below.
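As a hedged stand-in for this idea (not the author's PyG block, and unrelated to its line numbering), the NumPy sketch below keeps one free parameter per ball pair, symmetrizes it and squashes it into (0, 1) to obtain a weighted adjacency matrix, and uses that matrix in a single GCN propagation step to forecast the next state. The sigmoid parameterization, the residual forecast, and the names `theta` and `w` are my assumptions. In a real model, `theta` and `w` would be trained jointly by minimizing the forecasting error, and PyG's `GCNConv` accepts such per-edge weights through its `edge_weight` argument.

```python
import numpy as np

def symmetric_edge_weights(theta):
    """Turn an unconstrained n x n parameter matrix into a symmetric weighted adjacency.

    Averaging theta with its transpose makes w_ij == w_ji, the sigmoid keeps every
    weight in (0, 1), and zeroing the diagonal removes self-springs.
    """
    sym = 0.5 * (theta + theta.T)
    adj = 1.0 / (1.0 + np.exp(-sym))
    np.fill_diagonal(adj, 0.0)
    return adj

def gcn_propagate(x, adj, weight):
    """One GCN layer: D^{-1/2} (A + I) D^{-1/2} X W, with a weighted A."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    norm = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    return norm @ x @ weight

rng = np.random.default_rng(0)
n, d = 5, 4
theta = rng.normal(size=(n, n))  # learnable in a real model; random here
adj = symmetric_edge_weights(theta)
x_t = rng.normal(size=(n, d))    # ball states at time t
w = rng.normal(size=(d, d))      # layer weights, also learnable in practice
x_next = x_t + gcn_propagate(x_t, adj, w)  # residual forecast of the next state
```

Symmetrizing the parameters, rather than the output, guarantees a symmetric adjacency at every training step, so the model never has to learn that w_ij and w_ji describe the same spring.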

Because adjacency matrices can be sparse, accuracy alone can be misleading (a model that predicts no springs at all already scores well), so I evaluate this model using precision, recall, and F1 score in addition to accuracy. Below are two model tests. Trained on a single simulation run in an unsupervised setting, the single-simulation model does great, with perfect predictions! Note that the title of each plot shows the metrics at the last training epoch.
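These metrics can be computed over the unique ball pairs of a thresholded prediction. A minimal sketch, assuming a 0.5 threshold (as in the plots above) and a hypothetical helper name; the example uses the five-ball graph from the introduction:

```python
import numpy as np

def edge_metrics(pred_weights, true_adj, threshold=0.5):
    """Accuracy, precision, recall and F1 over the unique ball pairs (i < j)."""
    iu = np.triu_indices_from(true_adj, k=1)  # count each undirected pair once
    pred = pred_weights[iu] > threshold
    true = true_adj[iu].astype(bool)
    tp = np.sum(pred & true)
    fp = np.sum(pred & ~true)
    fn = np.sum(~pred & true)
    accuracy = np.mean(pred == true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1

true_adj = np.zeros((5, 5), dtype=int)
true_adj[0, 4] = true_adj[4, 0] = 1  # Ball-1 and Ball-5
true_adj[1, 2] = true_adj[2, 1] = 1  # Ball-2 and Ball-3
print(edge_metrics(np.zeros((5, 5)), true_adj))  # an "empty graph" guess
```

The empty-graph guess already reaches 0.8 accuracy (8 of 10 pairs correct) while its recall and F1 are 0, which is exactly why accuracy alone is not enough here.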

Example of the single-simulation model prediction of the adjacency matrix. This case involves 5 balls, and the GNN is implemented in PyG.
Another example of the single-simulation model prediction of the adjacency matrix. This case involves 5 balls, and the GNN is implemented in PyG.

Bonus: would this work with more bouncing balls? Let us put 10 balls in the box and see what happens … it works fairly well! Please visit my Google Colab Notebook to run these experiments yourself!

Another example of the single-simulation model prediction of the adjacency matrix. This case involves 10 balls, and the GNN is implemented in PyG.

[1] S. Sukhbaatar, A. Szlam and R. Fergus, Learning Multiagent Communication with Backpropagation (2016), NeurIPS 2016.

[2] T. Kipf, E. Fetaya, K. Wang, M. Welling and R. Zemel, Neural Relational Inference for Interacting Systems (2018), ICML 2018.

[3] N. Keriven and G. Peyré, Universal Invariant and Equivariant Graph Neural Networks (2019), NeurIPS 2019.

[4] X. Chen, D. Kingma, T. Salimans, Y. Duan, P. Dhariwal, J. Schulman, I. Sutskever and P. Abbeel, Variational Lossy Autoencoder (2017), ICLR 2017.

[5] C. Maddison, A. Mnih and Y. Teh, The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (2017), ICLR 2017.

[6] T. Kipf and M. Welling, Semi-Supervised Classification with Graph Convolutional Networks (2017), ICLR 2017.
