Building Models that Learn to Discover Structure and Relations

Thomas Kipf, Ethan Fetaya, Jackson Wang, Max Welling, Rich Zemel

Some argue that a key component of human intelligence is our ability to reason about objects and their relations (e.g. [1,2]). This enables us, for example, to build rich compositional models of physics (how objects or particles interact) and intuitive theories of causation (what causes what) [3].

For artificial systems, these tasks remain a challenge. Most sophisticated pattern recognition models, e.g. those based on Convolutional Neural Networks (CNNs) or Recurrent Neural Networks (RNNs), lack a relational inductive bias [4], which impedes their ability to generalize well on problems with inherent compositional structure.

In our recent ICML 2018 paper, Neural Relational Inference for Interacting Systems, we explore a class of models named Graph Neural Networks (GNNs) that reflect the inherent structure of the problem domain in their architecture¹. This enables variants of GNNs, for example, to learn to predict the physical dynamics of an interacting system (e.g. billiard balls on a table) [5] or to reason about relations between objects in a given image [6].

Physical simulation of particles coupled by invisible springs. The connectivity is given by a hidden interaction graph. Unconnected particles do not interact.

In our work, we investigate whether this class of models is also capable of recognizing the underlying structure and types of relations in the data we observe in a completely unsupervised way (i.e. without ever showing it the ground truth relations or interactions).

To illustrate this task, let us consider a more concrete setting: we are looking at a physical simulation of balls rolling on a 2D surface (see Figure above), where some of the balls are connected by (invisible) springs that exert an attractive force. If we knew the position and velocity of every ball, as well as the connectivity structure, we would be able to predict where they will move next.
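To make this concrete, here is a minimal sketch of such a spring system, assuming zero-rest-length Hooke springs and simple Euler integration (the function name and parameters are illustrative, not the code used to generate our datasets):

```python
import numpy as np

def simulate_springs(pos, vel, adjacency, k=0.1, dt=0.01, steps=100):
    """Simulate 2D balls coupled by zero-rest-length springs (Hooke's law).

    pos, vel:  (n, 2) arrays of positions and velocities.
    adjacency: (n, n) symmetric 0/1 matrix; adjacency[i, j] = 1 means
               balls i and j are connected by a spring.
    """
    trajectory = [pos.copy()]
    for _ in range(steps):
        # Pairwise displacements x_i - x_j, shape (n, n, 2).
        diff = pos[:, None, :] - pos[None, :, :]
        # Spring force on ball i: F_i = -k * sum_j A_ij (x_i - x_j).
        force = -k * (adjacency[:, :, None] * diff).sum(axis=1)
        vel = vel + dt * force  # unit mass, Euler step
        pos = pos + dt * vel
        trajectory.append(pos.copy())
    return np.stack(trajectory)

rng = np.random.default_rng(0)
n = 5
pos = rng.standard_normal((n, 2))
vel = np.zeros((n, 2))
adj = np.zeros((n, n))
adj[0, 1] = adj[1, 0] = 1.0  # only balls 0 and 1 are connected
traj = simulate_springs(pos, vel, adj)
```

Only the hidden adjacency matrix distinguishes interacting from non-interacting balls; the model below never gets to see it.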

Can you guess which balls are connected by springs? The solution appears after a few seconds.

Without knowing this latent interaction graph, predicting the system’s dynamics is quite difficult. Similarly, inferring which ball is connected to which by a spring is challenging in its own right. Have a look at the video to the left (or above, if you’re viewing this on your phone) and see if you can correctly guess the interaction graph. You will notice that once the interaction structure is shown, it suddenly becomes much easier and highly intuitive to understand the dynamics.

In our work, we task GNNs with simultaneously inferring this latent interaction structure and predicting the dynamics of the interacting system. After training on a wide set of simulations, the model recognizes these hidden relations and gives accurate predictions in 99.9% of cases², interestingly without ever seeing a ground truth interaction graph during training.

Our Neural Relational Inference (NRI) model can be seen as an auto-encoder in which the encoder forms a hypothesis about how the system interacts, and the decoder learns a dynamical model of the interacting system constrained by that “interaction hypothesis”. We frame this as a probabilistic model where the latent code corresponds to a distribution over relation types between objects (see Figure below).
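In standard variational auto-encoder fashion, the encoder produces an approximate posterior $q_\phi(z \mid x)$ over relation types $z$ given observed trajectories $x$, the decoder models the dynamics $p_\theta(x \mid z)$, and both are trained jointly by maximizing the evidence lower bound (ELBO):

```latex
\mathcal{L} = \mathbb{E}_{q_\phi(z \mid x)}\bigl[\log p_\theta(x \mid z)\bigr]
            - \mathrm{KL}\bigl[\, q_\phi(z \mid x) \,\big\|\, p(z) \,\bigr]
```

Here $p(z)$ is a prior over edge types, e.g. uniform, or one that favors sparse interaction graphs (see footnote 3).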

Neural Relational Inference model architecture. Both encoder and decoder are based on Graph Neural Networks (GNNs) and the latent code represents a distribution over relation types between objects.

The encoder is based on a GNN that passes messages between all pairs of objects, i.e. it operates on the fully-connected graph. After multiple steps of message passing, each node (i.e. object) is informed about its relations to its neighbors, about how its neighbors relate to other nodes, and so on. This enables the model to predict which pairs of objects are likely to interact, and by which type of interaction (e.g. attractive or repulsive).
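A minimal PyTorch sketch of this idea, assuming one node-to-edge (v→e) step, one edge-to-node (e→v) aggregation, and a final v→e step that scores each pair’s relation type (module and layer names are hypothetical, and the real encoder in our paper is deeper):

```python
import torch
import torch.nn as nn

class EdgeEncoder(nn.Module):
    """Sketch of a GNN encoder operating on the fully-connected graph."""

    def __init__(self, n_in, n_hid, n_edge_types):
        super().__init__()
        self.node_mlp = nn.Sequential(nn.Linear(n_in, n_hid), nn.ReLU())
        self.edge_mlp = nn.Sequential(nn.Linear(2 * n_hid, n_hid), nn.ReLU())
        self.node_mlp2 = nn.Sequential(nn.Linear(n_hid, n_hid), nn.ReLU())
        self.edge_out = nn.Linear(2 * n_hid, n_edge_types)

    def forward(self, x):
        # x: (B, N, n_in) flattened input trajectory per object.
        B, N, _ = x.shape
        h = self.node_mlp(x)                                 # (B, N, H)
        # v->e: concatenate sender/receiver embeddings for every pair.
        hi = h.unsqueeze(2).expand(B, N, N, -1)              # sender i
        hj = h.unsqueeze(1).expand(B, N, N, -1)              # receiver j
        e = self.edge_mlp(torch.cat([hi, hj], dim=-1))       # (B, N, N, H)
        # e->v: each node aggregates its incoming messages.
        h2 = self.node_mlp2(e.sum(dim=1))                    # (B, N, H)
        # Second v->e: logits over relation types for each pair.
        hi2 = h2.unsqueeze(2).expand(B, N, N, -1)
        hj2 = h2.unsqueeze(1).expand(B, N, N, -1)
        return self.edge_out(torch.cat([hi2, hj2], dim=-1))  # (B, N, N, K)

enc = EdgeEncoder(n_in=4, n_hid=16, n_edge_types=2)
logits = enc(torch.randn(3, 5, 4))  # batch of 3 systems with 5 objects
```

The output is one vector of relation-type logits per ordered pair of objects, which parameterizes the latent distribution described next.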

From this predictive distribution³ over interaction types, we sample a particular type of interaction for each pair of objects. The crucial ingredient lies in how we construct the decoder: we model it with another GNN, but this time the GNN may only pass messages (of a certain type) along the edges predicted by the encoder. This structural constraint forces the model to find the correct underlying structure of interactions, as it would otherwise be unable to predict the future dynamics of the system.
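The sampling-plus-constrained-messaging step can be sketched as follows, assuming the Gumbel-softmax trick for differentiable discrete samples and one message MLP per edge type (the helper name and shapes are hypothetical):

```python
import torch
import torch.nn.functional as F

def sample_and_message(logits, h, edge_mlps, tau=0.5):
    """Sample discrete edge types, then pass messages only along them.

    logits:    (B, N, N, K) relation-type logits from the encoder.
    h:         (B, N, H) current node states in the decoder.
    edge_mlps: one small MLP per relation type.
    """
    # Differentiable (straight-through) one-hot sample per object pair.
    z = F.gumbel_softmax(logits, tau=tau, hard=True)       # (B, N, N, K)
    B, N, _, K = z.shape
    hi = h.unsqueeze(2).expand(B, N, N, -1)
    hj = h.unsqueeze(1).expand(B, N, N, -1)
    pair = torch.cat([hi, hj], dim=-1)                     # (B, N, N, 2H)
    # Each edge type k has its own message function; the sampled one-hot
    # z[..., k] gates the messages, so only the predicted type is used.
    # Type 0 is reserved as "no interaction" by skipping its MLP.
    msgs = sum(z[..., k:k + 1] * edge_mlps[k](pair) for k in range(1, K))
    return msgs.sum(dim=1)                                 # aggregate per node

B, N, H, K = 3, 5, 16, 2
edge_mlps = torch.nn.ModuleList(
    [torch.nn.Linear(2 * H, H) for _ in range(K)])
msgs = sample_and_message(torch.randn(B, N, N, K), torch.randn(B, N, H), edge_mlps)
```

Because messages flow only along sampled edges, a wrong “interaction hypothesis” directly hurts the decoder’s predictions, which is what drives the encoder towards the true graph.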

We test this simple structural constraint on a number of different environments, ranging from physical simulations with spring forces, charged particles, and phase-coupled oscillators (the Kuramoto model) to real-world settings in the form of human motion capture data and basketball sports analytics data. We find that our NRI model can predict the dynamics of these interacting systems with high accuracy (see examples in the Figure below), while discovering latent interaction graphs that accurately represent the underlying structure of the problem (for physical simulations) or provide insightful explanations (for real-world data).

Example trajectories predicted by the NRI model in comparison to the ground truth dynamics of the system. Shaded lines denote ground truth trajectories which are provided to the NRI model as input (to condition the generative process).

Our model provides only a first step towards designing systems that can build rich, intuitive, and compositional theories of their environment, and many questions remain unanswered. For instance, how can learned structural representations help us generalize to new, unseen tasks? How can we effectively constrain models so that their learned latent structure corresponds to causal relationships? These are important research directions that have so far received little attention in the deep learning community, and we hope that our work will inspire and guide future efforts in this area.


  1. ^ At a high level, a GNN passes vector-valued messages, parameterized by small neural networks, along the edges of a graph. The overall per-layer update rule can be seen as a node-to-edge (v→e) transformation followed by an edge-to-node (e→v) aggregation function. For more details, have a look at our paper.
  2. ^ For a system with 5 particles randomly connected with springs.
  3. ^ NRI learns a probability distribution over the structure and types of relations between objects. This allows us to systematically include prior beliefs, e.g. about the sparsity of the latent relational graph.

Code and citation

Our PyTorch implementation of the NRI model (including code for reproducing some of the experiments) is available on GitHub.

If you found this post useful, please consider citing our paper and sharing your insights in the comments section.

@inproceedings{kipf2018neural,
  title={Neural Relational Inference for Interacting Systems},
  author={Kipf, Thomas and Fetaya, Ethan and Wang, Kuan-Chieh and Welling, Max and Zemel, Richard},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2018}
}

Acknowledgements

Thanks to Ethan Fetaya and Petar Veličković for providing feedback on an earlier draft of this blog post.


References

[1] Spelke, Elizabeth S., and Katherine D. Kinzler. “Core knowledge.” Developmental Science 10 (2007).
[2] Lake, Brenden M., Tomer D. Ullman, Joshua B. Tenenbaum, and Samuel J. Gershman. “Building Machines that Learn and Think Like People.” Behavioral and Brain Sciences 40 (2017).
[3] Gerstenberg, Tobias, and Joshua B. Tenenbaum. “Intuitive theories.” Oxford Handbook of Causal Reasoning (2017).
[4] Battaglia, Peter W. et al. “Relational inductive biases, deep learning, and graph networks.” arXiv preprint arXiv:1806.01261 (2018).
[5] Battaglia, Peter W., Razvan Pascanu, Matthew Lai, Danilo Rezende, and Koray Kavukcuoglu. “Interaction Networks for Learning about Objects, Relations and Physics.” In Advances in Neural Information Processing Systems (NIPS) 2016.
[6] Santoro, Adam, David Raposo, David G. Barrett, Mateusz Malinowski, Razvan Pascanu, Peter Battaglia, and Tim Lillicrap. “A simple neural network module for relational reasoning.” In Advances in Neural Information Processing Systems (NIPS) 2017.