A Principled Approach to Aggregations

PyTorch Geometric
Nov 2, 2022 · 7 min read

By Guohao Li

PyG released version 2.1.0 with contributions from over 60 contributors. One of the primary features is the new aggregation operator package that allows you to customize your GNNs with different aggregations. The author, Guohao, is part of the PyG team that developed this new feature. See the complete roadmap here, along with the discussions from all the contributors. In this post, we will introduce the motivation for building it and how to use it in practice.

Graph Neural Networks (GNNs) have become a master key to unlocking many applications that deal with non-grid data structures. One of the most critical components of a GNN is its aggregation function, which plays a central role in the symmetry and invariance properties [1, 2] and the expressive power [3, 4] of the model.

Different aggregation functions are good at capturing different types of graph properties [3, 4]. For instance, mean aggregation is good at learning the neighborhood feature distribution, max aggregation can identify the most representative features, and sum aggregation can learn structural properties of graphs such as degree information.

Many other works [5, 6, 7] also empirically found that the choice of aggregation functions is crucial to the performance of GNNs on different datasets.

[Figure: how different aggregators perform on different datasets.]

Recent works also show that using multiple aggregations [4] and learnable aggregations [6] can potentially gain substantial improvements.

Inspired by these works, we made the concept of aggregation a first-class principle in PyG, which allows you to assemble your GNNs like Lego bricks. As of now, PyG provides support for various aggregations: simple ones (e.g., mean, max, sum), advanced ones (e.g., median, var, std), learnable ones (e.g., SoftmaxAggregation, PowerMeanAggregation), and exotic ones (e.g., LSTMAggregation, SortAggregation, EquilibriumAggregation). Please feel free to open an issue if you want more!

Available aggregation functions in PyG 2.1.0.

More detailed documentation and the design principles of the aggregation package can be found in the PyG documentation. Now let’s install PyG 2.1.0 and try them out on a real dataset! The following experiment is available in this IPython notebook.

Customizing Aggregations within Message Passing

To facilitate experimentation with these different aggregation schemes, and to unify the concept of aggregation across both MessagePassing and global readouts, we provide modular and reusable aggregations in the newly defined torch_geometric.nn.aggr.* package. Unifying these concepts also lets us implement optimizations and specialized kernels in a single place. With the new integration, the following usage patterns are all supported:

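The post embeds this as a gist (aggr_api.py). Here is a minimal sketch of the call patterns it illustrates; the channel sizes are arbitrary placeholders:

```python
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.aggr import MeanAggregation, SumAggregation, MultiAggregation

# 1) A single aggregation, given as a string or as an Aggregation module:
conv = SAGEConv(16, 32, aggr='mean')
conv = SAGEConv(16, 32, aggr=MeanAggregation())

# 2) Multiple aggregations, given as a list of strings, of modules, or a mix:
conv = SAGEConv(16, 32, aggr=['mean', 'max', 'sum'])
conv = SAGEConv(16, 32, aggr=[MeanAggregation(), 'max', SumAggregation()])

# 3) Multiple aggregations, given explicitly as a MultiAggregation module:
conv = SAGEConv(16, 32, aggr=MultiAggregation(['mean', 'max', 'sum']))
```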

In this tutorial, we explore the new aggregation package with SAGEConv [5] and ClusterLoader [8], and showcase it on the PubMed graph from the Planetoid node classification benchmark suite [9].

Loading the dataset

Let’s first load the Planetoid dataset and create subgraphs with ClusterData for training.

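The post embeds this step as a gist (aggr_load_data.py). A sketch of the loading code follows; the number of partitions, the batch size, and the feature normalization are illustrative assumptions rather than the post’s exact settings:

```python
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import ClusterData, ClusterLoader

dataset = Planetoid(root='data/Planetoid', name='PubMed',
                    transform=T.NormalizeFeatures())
data = dataset[0]

# Partition the graph with METIS and iterate over mini-batches of clusters:
cluster_data = ClusterData(data, num_parts=128)
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)
```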

Define train, test, and run functions

Here we define simple train, test, and run functions for training and evaluating the GNN model.

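The post embeds these as a gist (aggr_run.py). A minimal sketch of standard helpers matching how they are called later; the learning rate and the printed metrics are assumptions:

```python
import torch
import torch.nn.functional as F

def train(model, loader, optimizer):
    model.train()
    total_loss = 0
    for batch in loader:  # Train on mini-batches of clusters.
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += float(loss)
    return total_loss / len(loader)

@torch.no_grad()
def test(model, data):  # Evaluate on the full graph.
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs

def run(model, data, loader, epochs=5, lr=0.01):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(1, epochs + 1):
        loss = train(model, loader, optimizer)
        train_acc, val_acc, test_acc = test(model, data)
        print(f'Epoch {epoch:02d} | Loss: {loss:.4f} | '
              f'Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}')
```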

Define a GNN class and import aggregations

Now, let’s define a GNN helper class and import all those new aggregation operators!

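The post embeds this as a gist (aggr_gnn.py). A sketch of a two-layer GraphSAGE helper whose aggregation is fully configurable; the hidden size and the exact set of imports are assumptions:

```python
import torch
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.aggr import (
    MeanAggregation, MaxAggregation, MinAggregation, StdAggregation,
    SoftmaxAggregation, PowerMeanAggregation, MultiAggregation,
)

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, aggr='mean'):
        super().__init__()
        # `aggr` may be a string, an Aggregation module, a list, or a mix.
        # Note: module instances passed here are shared across both layers,
        # so any learnable aggregation parameters are shared as well.
        self.conv1 = SAGEConv(in_channels, hidden_channels, aggr=aggr)
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr=aggr)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)
```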

Original interface with string type as the aggregation argument

Previously, PyG only supported customizing MessagePassing with simple aggregations by passing a string (e.g., 'mean', 'max', 'sum'). Let’s define a GNN with mean aggregation and run it for 5 epochs.

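The gist aggr_str.py boils down to something like the following, assuming the GNN helper and run function sketched above:

```python
model = GNN(dataset.num_features, 64, dataset.num_classes, aggr='mean')
run(model, data, train_loader, epochs=5)
```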

Use a single aggregation module as the aggregation argument

With the new interface, the MessagePassing class can take an Aggregation module as an argument. Here, we define mean aggregation via the MeanAggregation module. The model achieves the same performance as before.

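A sketch of the equivalent call with a module (gist aggr_single_module.py), again assuming the helpers above:

```python
from torch_geometric.nn.aggr import MeanAggregation

model = GNN(dataset.num_features, 64, dataset.num_classes, aggr=MeanAggregation())
run(model, data, train_loader, epochs=5)
```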

Use a list of aggregation strings as the aggregation argument

To define multiple aggregations, we can use a list of strings as the input argument. The aggregations are resolved from plain strings via a lookup table, following the design principles of the class-resolver library: passing "mean" to the MessagePassing module, for example, automatically resolves it to the MeanAggregation class. Let’s see how a PNA-like GNN [4] works. It converges much faster!

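A sketch of the list-of-strings variant (gist aggr_multi_str.py); the exact set of aggregators is an assumption, chosen to mirror PNA’s mean/min/max/std combination [4]:

```python
# A PNA-like combination of aggregators, resolved from plain strings:
model = GNN(dataset.num_features, 64, dataset.num_classes,
            aggr=['mean', 'min', 'max', 'std'])
run(model, data, train_loader, epochs=5)
```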

Use a list of aggregation modules as the aggregation argument

You can also use a list of Aggregation modules to specify your convolutions.

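A sketch of the list-of-modules variant (gist aggr_multi_module.py), equivalent to the string version above:

```python
from torch_geometric.nn.aggr import (MeanAggregation, MinAggregation,
                                     MaxAggregation, StdAggregation)

model = GNN(dataset.num_features, 64, dataset.num_classes,
            aggr=[MeanAggregation(), MinAggregation(),
                  MaxAggregation(), StdAggregation()])
run(model, data, train_loader, epochs=5)
```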

Use a list of mixed modules and strings as the aggregation argument

A mix of the two is supported as well, for your convenience.

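A sketch of the mixed variant (gist aggr_mix_str_module.py):

```python
model = GNN(dataset.num_features, 64, dataset.num_classes,
            aggr=[MeanAggregation(), 'min', MaxAggregation(), 'std'])
run(model, data, train_loader, epochs=5)
```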

Define multiple aggregations with MultiAggregation module

When a list is given, MessagePassing automatically stacks these aggregators via the MultiAggregation module. But you can also directly pass a MultiAggregation instead of a list. Now let’s see how we can define multiple aggregations with MultiAggregation. Here, we use different initial temperatures for SoftmaxAggregation [6]; each temperature results in an aggregation with a different softness.

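A sketch of the MultiAggregation variant (gist aggr_multi_aggr.py). The three initial temperatures are illustrative assumptions: a small t behaves like mean aggregation, a large t approaches max aggregation, and learn=True makes each temperature a trainable parameter:

```python
from torch_geometric.nn.aggr import MultiAggregation, SoftmaxAggregation

aggr = MultiAggregation([
    SoftmaxAggregation(t=0.1, learn=True),   # soft, close to mean aggregation
    SoftmaxAggregation(t=1.0, learn=True),
    SoftmaxAggregation(t=10.0, learn=True),  # sharp, close to max aggregation
])
model = GNN(dataset.num_features, 64, dataset.num_classes, aggr=aggr)
run(model, data, train_loader, epochs=5)
```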

There’s More!

There are many other aggregation operators available for you to “lego” your GNNs. PowerMeanAggregation [6] allows you to define, and potentially learn, generalized means beyond the simple arithmetic mean, such as the harmonic mean and the geometric mean. LSTMAggregation [5] can perform permutation-variant aggregation. Other interesting aggregation operators such as Set2Set [10], DegreeScalerAggregation [4], SortAggregation [11], GraphMultisetTransformer [12], AttentionalAggregation [13, 14], and EquilibriumAggregation [15] are ready for you to explore.

Conclusion

In this tutorial, you have been introduced to the torch_geometric.nn.aggr package, which provides a flexible interface for experimenting with different aggregation functions in your message passing convolutions and unifies aggregation within GNNs across MessagePassing and global readouts. This new abstraction also makes it easier to design new types of aggregation functions: you can create your own by subclassing the base Aggregation class. Please refer to the docs for more details.

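The post embeds this as a gist (aggr_base.py). As a hypothetical example of the pattern, here is a toy aggregation that sums absolute feature values, built on the reduce helper of the base Aggregation class:

```python
from typing import Optional

from torch import Tensor
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.aggr import Aggregation

class SumAbsAggregation(Aggregation):
    # A toy custom aggregation: the sum of absolute feature values per node.
    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        # `self.reduce` is the scatter/segment helper provided by the base class:
        return self.reduce(x.abs(), index, ptr, dim_size, dim, reduce='sum')

# It plugs into any MessagePassing layer just like the built-in aggregations:
conv = SAGEConv(16, 32, aggr=SumAbsAggregation())
```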

Have fun!

[1] Battaglia, Peter W., Jessica B. Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti et al. “Relational inductive biases, deep learning, and graph networks.” arXiv preprint arXiv:1806.01261 (2018).

[2] Bronstein, Michael M., Joan Bruna, Taco Cohen, and Petar Veličković. “Geometric deep learning: Grids, groups, graphs, geodesics, and gauges.” arXiv preprint arXiv:2104.13478 (2021).

[3] Xu, Keyulu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. “How powerful are graph neural networks?” ICLR (2019).

[4] Corso, Gabriele, Luca Cavalleri, Dominique Beaini, Pietro Liò, and Petar Veličković. “Principal neighbourhood aggregation for graph nets.” Advances in Neural Information Processing Systems 33 (2020): 13260–13271.

[5] Hamilton, Will, Zhitao Ying, and Jure Leskovec. “Inductive representation learning on large graphs.” Advances in neural information processing systems 30 (2017).

[6] Li, Guohao, Chenxin Xiong, Ali Thabet, and Bernard Ghanem. “DeeperGCN: All you need to train deeper GCNs.” arXiv preprint arXiv:2006.07739 (2020).

[7] You, Jiaxuan, Zhitao Ying, and Jure Leskovec. “Design space for graph neural networks.” Advances in Neural Information Processing Systems 33 (2020): 17009–17021.

[8] Chiang, Wei-Lin, Xuanqing Liu, Si Si, Yang Li, Samy Bengio, and Cho-Jui Hsieh. “Cluster-GCN: An efficient algorithm for training deep and large graph convolutional networks.” In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 257–266. 2019.

[9] Yang, Zhilin, William Cohen, and Ruslan Salakhutdinov. “Revisiting semi-supervised learning with graph embeddings.” In International Conference on Machine Learning, pp. 40–48. PMLR, 2016.

[10] Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. “Order matters: Sequence to sequence for sets.” arXiv preprint arXiv:1511.06391 (2015).

[11] Zhang, Muhan, Zhicheng Cui, Marion Neumann, and Yixin Chen. “An end-to-end deep learning architecture for graph classification.” In Proceedings of the AAAI conference on artificial intelligence, vol. 32, no. 1. 2018.

[12] Baek, Jinheon, Minki Kang, and Sung Ju Hwang. “Accurate learning of graph representations with graph multiset pooling.” arXiv preprint arXiv:2102.11533 (2021).

[13] Li, Yujia, Daniel Tarlow, Marc Brockschmidt, and Richard Zemel. “Gated graph sequence neural networks.” arXiv preprint arXiv:1511.05493 (2015).

[14] Veličković, Petar, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. “Graph attention networks.” arXiv preprint arXiv:1710.10903 (2017).

[15] Bartunov, Sergey, Fabian B. Fuchs, and Timothy P. Lillicrap. “Equilibrium aggregation: encoding sets via optimization.” In Uncertainty in Artificial Intelligence, pp. 139–149. PMLR, 2022.
