A Principled Approach to Aggregations
By Guohao Li
PyG has released version 2.1.0, with contributions from over 60 contributors. One of its primary features is a new aggregation operator package that lets you customize your GNNs with different aggregations. The author, Guohao, is part of the PyG team that developed this feature. See the complete roadmap, along with the discussions among all the contributors. In this post, we introduce the motivation for building it and show how to use it in practice.
Graph Neural Networks (GNNs) have become a master key to unlocking many applications that deal with non-grid data structures. One of the most critical components of a GNN is its aggregation function, which shapes both the symmetry and invariance properties [1, 2] and the expressive power [3, 4] of the model.
Different aggregation functions are good at capturing different properties of graphs [3, 4]. For instance, mean aggregation is good at learning the neighborhood distribution, max aggregation can identify the most representative features, and sum aggregation can learn structural properties of graphs such as degree information.
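As a plain-Python illustration of this point (no PyG involved), take four made-up toy neighborhoods, each neighbor carrying one scalar feature. Mean and max cannot tell neighborhoods a and b apart, while sum can, because it reflects the number of neighbors. Sum and mean cannot tell c and d apart, while max can.

```python
# Made-up toy neighborhoods: one scalar feature per neighbor.
neighborhoods = {
    "a": [1.0, 1.0],            # two identical neighbors
    "b": [1.0, 1.0, 1.0, 1.0],  # same distribution, twice the degree
    "c": [1.0, 3.0],
    "d": [2.0, 2.0],            # same sum and mean as "c", different max
}


def summarize(xs):
    """Return the (sum, mean, max) aggregation of one neighborhood."""
    return sum(xs), sum(xs) / len(xs), max(xs)


stats = {name: summarize(xs) for name, xs in neighborhoods.items()}

for name, (s, m, mx) in stats.items():
    print(f"{name}: sum={s} mean={m} max={mx}")
```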
Many other works [5, 6, 7] have also found empirically that the choice of aggregation function is crucial to the performance of GNNs on different datasets.
Below you can see how different aggregators perform on different datasets.
Recent works also show that using multiple aggregations [4] and learnable aggregations [6] can potentially gain substantial improvements.
Inspired by these works, we made the concept of aggregation a first-class principle in PyG, which lets you build your GNNs like playing with Lego bricks. PyG currently supports a variety of aggregations: simple ones (e.g., mean, max, sum), advanced ones (e.g., median, var, std), learnable ones (e.g., SoftmaxAggregation, PowerMeanAggregation), and exotic ones (e.g., LSTMAggregation, SortAggregation, EquilibriumAggregation). Please feel free to open an issue if you want more!
More detailed documentation and the design principles of the aggregation package can be found in the PyG documentation. Now let’s install PyG 2.1.0 and try them out on a real dataset! The following experiment is available in this IPython notebook.
Customizing Aggregations within Message Passing
To facilitate experimentation with these different aggregation schemes, and to unify the concept of aggregation within GNNs across both MessagePassing and global readouts, we provide modular and reusable aggregations in the newly defined torch_geometric.nn.aggr.* package. Unifying these concepts also lets us perform optimizations and specialized implementations in a single place. In the new integration, the following functionality is applicable:
In this tutorial, we explore the new aggregation package with SAGEConv [5] and ClusterLoader [8] and showcase it on the PubMed graph from the Planetoid node classification benchmark suite [9].
Loading the dataset
Let’s first load the Planetoid dataset and create subgraphs with ClusterData for training.
Define train, test, and run functions
Here we define a simple run function for training the GNN model.
Define a GNN class and Import Aggregations
Now, let’s define a GNN helper class and import all those new aggregation operators!
Original interface with string type as the aggregation argument
Previously, PyG only supported customizing MessagePassing with simple aggregations passed as a string (e.g., 'mean', 'max', 'sum'). Let’s define a GNN with mean aggregation and run it for 5 epochs.
Use a single aggregation module as the aggregation argument
With the new interface, the MessagePassing class can take an Aggregation module as an argument. Here we define the mean aggregation with MeanAggregation. The model achieves the same performance as before.
Use a list of aggregation strings as the aggregation argument
To define multiple aggregations, we can pass a list of strings as the input argument. The aggregations are resolved from pure strings via a lookup table, following the design principles of the class-resolver library. For example, passing "mean" to the MessagePassing module automatically resolves it to the MeanAggregation class. Let’s see how a PNA-like GNN [4] works. It converges much faster!
Use a list of aggregation modules as the aggregation argument
You can also use a list of Aggregation modules to specify your convolutions.
Use a list of mixed modules and strings as the aggregation argument
Mixing modules and strings in one list is supported as well, for your convenience.
Define multiple aggregations with MultiAggregation module
When a list is passed, MessagePassing automatically stacks these aggregators via the MultiAggregation module. You can also pass a MultiAggregation directly instead of a list. Now let’s see how we can define multiple aggregations with MultiAggregation. Here we use different initial temperatures for SoftmaxAggregation [6]. Each temperature yields an aggregation of different softness.
There’s More!
There are many other aggregation operators you can use to “lego” your GNNs. PowerMeanAggregation [6] lets you define, and potentially learn, generalized means beyond the simple arithmetic mean, such as the harmonic and geometric means. LSTMAggregation [5] performs permutation-variant aggregation. Other interesting aggregation operators, such as Set2Set [10], DegreeScalerAggregation [4], SortAggregation [11], GraphMultisetTransformer [12], AttentionalAggregation [13, 14], and EquilibriumAggregation [15], are ready for you to explore.
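The generalized (power) mean behind PowerMeanAggregation is M_p(x) = ((1/n) Σ x_i^p)^(1/p) for positive inputs. The plain-Python sketch below shows which familiar means it recovers, without using PyG. It is only an illustration of the formula; PowerMeanAggregation additionally exposes p as a learnable parameter via learn=True. In this sketch, p = 1 gives the arithmetic mean, p = -1 the harmonic mean, the limit p → 0 the geometric mean, and large p approaches the max.

```python
import math


def power_mean(xs, p):
    """Generalized mean of positive numbers: (mean of x_i ** p) ** (1 / p)."""
    if p == 0:  # the limit p -> 0 is the geometric mean
        return math.exp(sum(math.log(x) for x in xs) / len(xs))
    return (sum(x ** p for x in xs) / len(xs)) ** (1.0 / p)


xs = [1.0, 2.0, 4.0]
arithmetic = power_mean(xs, 1)   # 7/3
geometric = power_mean(xs, 0)    # 2, up to floating-point rounding
harmonic = power_mean(xs, -1)    # 12/7
large_p = power_mean(xs, 50)     # approaches max(xs) = 4 from below
```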
Conclusion
In this tutorial, we presented the torch_geometric.nn.aggr package, which provides a flexible interface for experimenting with different aggregation functions in your message passing convolutions and unifies aggregation within GNNs across MessagePassing and global readouts. This new abstraction also makes it easier to design new types of aggregation functions: you can create your own by subclassing the base Aggregation class. Please refer to the docs for more details.
Have fun!
[1] Battaglia, Peter W., Jessica B. Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti et al. “Relational inductive biases, deep learning, and graph networks.” arXiv preprint arXiv:1806.01261 (2018).
[2] Bronstein, Michael M., Joan Bruna, Taco Cohen, and Petar Veličković. “Geometric deep learning: Grids, groups, graphs, geodesics, and gauges.” arXiv preprint arXiv:2104.13478 (2021).
[3] Xu, Keyulu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. “How powerful are graph neural networks?” In International Conference on Learning Representations (ICLR), 2019.
[4] Corso, Gabriele, Luca Cavalleri, Dominique Beaini, Pietro Liò, and Petar Veličković. “Principal neighbourhood aggregation for graph nets.” Advances in Neural Information Processing Systems 33 (2020): 13260–13271.
[5] Hamilton, Will, Zhitao Ying, and Jure Leskovec. “Inductive representation learning on large graphs.” Advances in Neural Information Processing Systems 30 (2017).
[6] Li, Guohao, Chenxin Xiong, Ali Thabet, and Bernard Ghanem. “DeeperGCN: All you need to train deeper GCNs.” arXiv preprint arXiv:2006.07739 (2020).
[7] You, Jiaxuan, Zhitao Ying, and Jure Leskovec. “Design space for graph neural networks.” Advances in Neural Information Processing Systems 33 (2020): 17009–17021.
[8] Chiang, Wei-Lin, Xuanqing Liu, Si Si, Yang Li, Samy Bengio, and Cho-Jui Hsieh. “Cluster-GCN: An efficient algorithm for training deep and large graph convolutional networks.” In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 257–266. 2019.
[9] Yang, Zhilin, William Cohen, and Ruslan Salakhutdinov. “Revisiting semi-supervised learning with graph embeddings.” In International Conference on Machine Learning, pp. 40–48. PMLR, 2016.
[10] Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. “Order matters: Sequence to sequence for sets.” arXiv preprint arXiv:1511.06391 (2015).
[11] Zhang, Muhan, Zhicheng Cui, Marion Neumann, and Yixin Chen. “An end-to-end deep learning architecture for graph classification.” In Proceedings of the AAAI conference on artificial intelligence, vol. 32, no. 1. 2018.
[12] Baek, Jinheon, Minki Kang, and Sung Ju Hwang. “Accurate learning of graph representations with graph multiset pooling.” arXiv preprint arXiv:2102.11533 (2021).
[13] Li, Yujia, Daniel Tarlow, Marc Brockschmidt, and Richard Zemel. “Gated graph sequence neural networks.” arXiv preprint arXiv:1511.05493 (2015).
[14] Veličković, Petar, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. “Graph attention networks.” arXiv preprint arXiv:1710.10903 (2017).
[15] Bartunov, Sergey, Fabian B. Fuchs, and Timothy P. Lillicrap. “Equilibrium aggregation: encoding sets via optimization.” In Uncertainty in Artificial Intelligence, pp. 139–149. PMLR, 2022.