Adaptive Message Passing: Learning to Mitigate Oversmoothing, Oversquashing, and Underreaching

Federico Errica
10 min read · Dec 29, 2023


This blog post summarizes the findings of our new contribution:

Title: Adaptive Message Passing: A General Framework to Mitigate Oversmoothing, Oversquashing, and Underreaching
Authors: F. Errica, H. Christiansen, V. Zaverkin, T. Maruyama, M. Niepert, F. Alesiani
Paper:
http://arxiv.org/abs/2312.16560

Introduction

The field of deep learning for graphs studies how to automatically extract patterns from graph-structured data, such as proteins, social networks, and physical systems [1]. Currently, message passing-based approaches represent the dominant paradigm of computation due to their efficiency and effectiveness on downstream tasks [2].

Figure 1: For the input graph on the left, a 2-layer message passing architecture repeatedly exchanges messages between all connected nodes (middle). As a result, the contextual information of node 3 can be visualized by drawing the computational tree of depth 2.

Message passing (MP) is conceptually very simple. Given a graph like the one above, MP iteratively performs two steps:

  • Each node computes a new message based on its incoming messages;
  • Each node sends its new message through its outgoing edges.

Deep Graph Nets (DGNs) are neural and/or probabilistic machine learning models that stack many MP layers on top of each other to compute node embeddings, which are then used to make predictions about individual nodes or entire graphs [1]. With DGNs, the graph is both the input and the computational medium: node embeddings strongly depend on the structural information contained in the graph. This can be seen in Figure 1 (right), where we visualize the amount of information captured, directly or indirectly, by node 3 after 2 rounds of MP.
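
To make these two steps concrete, here is a minimal, illustrative sketch of a single mean-aggregation MP layer in plain NumPy. The function and variable names are ours and do not come from the paper or from any specific graph library.

```python
import numpy as np

def mp_layer(h, neighbors, W_self, W_neigh):
    """One illustrative message passing layer.

    h:         (num_nodes, dim) array of current node embeddings
    neighbors: dict mapping each node to the list of its neighbors
    W_self, W_neigh: (dim, dim) weight matrices
    """
    h_new = np.zeros_like(h)
    for v in range(h.shape[0]):
        # Step 1: node v aggregates the messages arriving from its neighbors.
        incoming = [h[u] for u in neighbors[v]]
        agg = np.mean(incoming, axis=0) if incoming else np.zeros(h.shape[1])
        # Step 2: node v combines the aggregate with its own state; the result
        # is the message it sends along its outgoing edges at the next round.
        h_new[v] = np.tanh(h[v] @ W_self + agg @ W_neigh)
    return h_new

# Stacking L such layers yields a (simplified) DGN: after L rounds, a node has
# indirectly received information from nodes up to L hops away, as in Figure 1.
rng = np.random.default_rng(0)
h = rng.normal(size=(4, 8))
W1, W2 = rng.normal(size=(8, 8)), rng.normal(size=(8, 8))
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}   # the path graph 0-1-2-3
h = mp_layer(mp_layer(h, neighbors, W1, W2), neighbors, W1, W2)  # 2 MP rounds
```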

Motivation

So far, DGNs have been extremely successful in a wide range of applications, such as the discovery of antibiotics [3], aiding mathematicians in the discovery of conjectures and theorems [4], and scaling up materials discovery [5]. However, that does not mean they are flawless. For a few years now, researchers have been addressing some of the intrinsic problems of MP, among which:

  • Oversmoothing [6,7]: node representations tend to converge to the same value as we add more MP layers. This makes it harder to implement deeper networks that work well;
  • Oversquashing [8,9]: an exponentially growing amount of information has to be compressed into a single node representation as we add more MP layers, so part of it can be lost or become difficult to recover;
  • Underreaching [8]: whenever it is important to capture a dependency between nodes that are K hops away from each other, one needs at least K MP layers to capture it. Researchers typically rely on grid search to find a good depth for the DGN (i.e., the number of MP layers), but the choice of which depth values to try is completely arbitrary.

Oversmoothing, oversquashing, and underreaching are often linked to the inability of DGNs to capture long-range dependencies in a graph. This is a considerable issue for many practical applications: in computational physics, electrostatic and gravitational interactions decay very slowly with distance [10]; in computational chemistry and material sciences, the accurate modeling of non-local effects, such as non-bonded interactions in molecular systems, is necessary to correctly estimate properties like the free energy [11]; in biology, disrupting long-range interactions in mRNA can inhibit splicing [12]; in immunology, the distant interactions between a major histocompatibility complex and regions of the T-cell receptor molecule correlate with the binding process of these two compounds [13].

Needless to say, there have already been many attempts in this direction. Skip/residual connections have been employed to reduce oversmoothing [14], as has concatenating the node embeddings across layers when tackling graph prediction tasks [15, 16]. Concerning oversquashing, there is consensus that modifying the MP scheme is necessary, for instance by learning when a node should completely stop propagating a message [17] or whether it should only listen, isolate, receive, or broadcast its message [18]. Other approaches sample edges [19], design a completely asynchronous MP scheme [20], avoid backtracking of messages [21], prune incoming messages via attention [22], or design anti-symmetric MP architectures [23]. Finally, “rewiring” approaches alter the graph connectivity to facilitate the propagation of information between distant nodes [24, 25] (recently, a critical perspective on the effectiveness of rewiring has also been given [26]). As regards underreaching, one can think of adaptive approaches that learn the depth, for instance by applying the cascade-correlation approach to a DGN [27].

Adaptive Message Passing

In this post, we introduce a method called Adaptive Message Passing (AMP). Motivated by the observation that synchronous MP contributes to oversmoothing and oversquashing, the idea is to let the DGN decide how many messages each node should send (up to infinity!) and when to send them. In other words:

  • We learn the depth of the network during training (addressing underreaching);
  • We apply a differentiable, soft filter on messages sent by nodes, which in principle can completely shut down the propagation of a message (addressing oversmoothing and oversquashing).

To do so, we extend the variational framework for unbounded-depth networks of [28] to the processing of graphs, while also incorporating the message filtering operation. As a result, AMP’s behavior can range from asynchronous MP [20] to classical MP [2]!

Learning the Depth

Figure 2: A discrete folded normal distribution with a finite quantile and support over the integers.

To learn the depth, we rely on families of truncated distributions over the integers [28], whose parameters can be learned during training. If the position of the chosen quantile of the distribution (e.g., 0.99) moves up or down after a backpropagation step, we add or remove one layer from the architecture. To make this work, the DGN has to produce an output prediction at each layer, and each prediction is weighted by the probability mass function of the truncated distribution. This increases the parametrization of the model but allows us to learn a possibly unbounded depth if the task at hand requires it.
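
Below is a rough PyTorch sketch of how such a learnable layer distribution might look: a folded normal discretized over the positive integers and truncated at its 0.99 quantile, in the spirit of Figure 2. The class, its parameterization, and all names are our own simplifications, not the paper's code.

```python
import torch

class LayerDistribution(torch.nn.Module):
    """Illustrative discretized folded normal over the number of MP layers."""

    def __init__(self, init_mean=3.0, init_scale=1.0):
        super().__init__()
        self.mean = torch.nn.Parameter(torch.tensor(init_mean))
        self.log_scale = torch.nn.Parameter(torch.tensor(init_scale).log())

    def layer_weights(self, quantile=0.99, max_layers=64):
        scale = self.log_scale.exp()
        ell = torch.arange(1, max_layers + 1, dtype=torch.float32)
        # Folded normal density evaluated at the integers 1..max_layers.
        dens = torch.exp(-0.5 * ((ell - self.mean) / scale) ** 2) \
             + torch.exp(-0.5 * ((ell + self.mean) / scale) ** 2)
        dens = dens / dens.sum()
        # The current depth is the smallest L whose cumulative mass exceeds
        # the chosen quantile; it changes as mean and scale are learned.
        depth = int((dens.cumsum(0) < quantile).sum().item()) + 1
        w = dens[:depth]
        return w / w.sum(), depth

# The per-layer predictions are then mixed with these weights,
#   y_hat = sum_l w[l] * readout(h^(l)),
# so gradients also reach the distribution's parameters and the
# architecture can grow or shrink as training proceeds.
```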

Filtering / Soft Pruning of Messages

Filtering outgoing messages, on the other hand, leads to interesting behaviors (as also discussed in recent works under the lens of expressive power [18]), since the number of messages that flow to a node can be arbitrarily reduced by the model. For instance, in the figure below, if one considers the filtering scheme for each node shown in (b), which has been discretized for simplicity, then AMP propagates only a subset of the total messages at message passing iterations 1 and 2 (d). This process can be implemented in a simple and fully differentiable way (see the sketch after Figure 4). Similarly, one could decide whether a node should accept incoming messages, or only parts of them, but this is left to future work.

Figure 3: Given an input graph of seven nodes (a) and a simplified message filtering scheme where a node can be completely masked or not at each layer (b), we observe how an L=2-layer standard message passing (c) differs from AMP (d) in terms of the number of messages sent. In this discretized example, AMP is pruning part of the edges of the input graph differently at each layer, which has consequences on the embeddings of the nodes. In practice, the message filtering is implemented as a soft mask on the node embeddings to maintain differentiability.

In turn, filtering messages leads to different computational trees (Figure 4, where filters are discretized in the interest of the exposition) compared to standard MP. It is therefore evident that oversquashing can be mitigated, since only the messages relevant to the task are propagated, reducing the amount of information to be squashed into a single node embedding.

Figure 4: Comparison of the 2-hop computational tree necessary to compute the representation of node 3 in the graph of Figure 3 for standard message passing (left) and AMP (right), where we discretized message filtering to simplify the concept. AMP can effectively prune/filter information in sub-trees to propagate only the relevant information for the task.
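
As a reference, here is a minimal PyTorch sketch of a soft message filter in the spirit described above: a per-node gate in [0, 1] scales the outgoing messages, and a gate close to zero effectively silences that node at the current layer. The class, the sigmoid gate, and the mean aggregation are our own simplified choices, not the paper's implementation.

```python
import torch

class FilteredMessagePassing(torch.nn.Module):
    """Illustrative MP step with a learnable soft filter on outgoing messages."""

    def __init__(self, dim):
        super().__init__()
        self.gate = torch.nn.Sequential(torch.nn.Linear(dim, 1), torch.nn.Sigmoid())

    def forward(self, h, edge_index):
        # h: (num_nodes, dim) embeddings; edge_index: (2, num_edges) long tensor
        # holding (source, target) pairs.
        alpha = self.gate(h)                       # per-node filter in [0, 1]
        src, dst = edge_index
        messages = alpha[src] * h[src]             # soft-masked outgoing messages
        # Mean-aggregate the filtered messages at the target nodes.
        out = torch.zeros_like(h).index_add_(0, dst, messages)
        deg = torch.zeros(h.size(0), 1).index_add_(0, dst, torch.ones(dst.size(0), 1))
        return out / deg.clamp(min=1)
```

Because the gate is itself a differentiable network, gradients tell each node how much of its message is worth sending at every layer, which is the soft, end-to-end-trainable counterpart of the discretized masks shown in Figures 3 and 4.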

Considerations

In our paper, we also provide a critical discussion of some of the metrics used so far to evaluate oversquashing, such as the sensitivity of the last embedding of node v to the input features of node u:

$$\left\lVert \frac{\partial \mathbf{h}_v^{(L)}}{\partial \mathbf{x}_u} \right\rVert$$

This metric, which has been proposed as a proxy for oversquashing [24], may not tell the whole story! Methods that focus on pruning/filtering messages, like AMP, may significantly reduce this quantity while still addressing oversquashing. Similarly, some of the datasets proposed to evaluate whether a method solves oversquashing rely on the assumption that all information should be preserved [8]; this is only one of the possible scenarios, and one where methods like ADGN [23] are theoretically effective. We conclude that oversquashing is a multi-faceted problem, and it might be useful to consider each of these aspects in isolation rather than using the same term to describe all of them.
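
For completeness, this sensitivity can be estimated with automatic differentiation. The helper below is a hedged sketch for a generic DGN `model` mapping node features and edges to final node embeddings; it is our own utility, not part of the paper's code.

```python
import torch

def sensitivity(model, x, edge_index, u, v):
    """Estimate || d h_v^(L) / d x_u ||: how much the final embedding of node v
    reacts to perturbations of the input features of node u."""
    jac = torch.autograd.functional.jacobian(
        lambda inp: model(inp, edge_index)[v], x)  # (emb_dim, num_nodes, feat_dim)
    return jac[:, u].norm()
```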

Experiments

To evaluate the benefits of combining AMP with a classical DGN, we test on synthetic [23] and chemical datasets [29] where capturing long-range dependencies is important to solve the task. Below, you can find the results; please check the paper if you want more details. Despite requiring more parameters than the base models, AMP consistently improves the performance of the base model it incorporates (see the † symbol in the tables).

Synthetic datasets

Table 1: Mean log10(MSE) and standard deviation averaged over 20 final runs on Diameter, SSSP, and Eccentricity. The best mean performance is highlighted in bold, and a † indicates that applying AMP improves the mean score compared to the base architecture.

Chemical datasets

Table 2: We report the mean average precision (AP, higher is better) and the mean absolute error (MAE, lower is better) on the chemical datasets, where the standard deviation is computed over four final runs. The best mean performance is highlighted in bold, the second-best performance is underlined, and a † indicates that applying AMP improves the mean score compared to the base architecture.

Controlling Oversmoothing and Oversquashing

Figure 5: We present the natural logarithm of the Dirichlet energy (left) and of the sensitivity (right) across layers for the GCN model and its AMP version. For each dataset, the curve is drawn using one of the models trained during the final runs, whose best configuration was selected by grid search. We observe that both metrics vary depending on the task, which is an indication of AMP’s ability to control oversmoothing and oversquashing.
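
For readers unfamiliar with the first metric, a common (unnormalized) form of the Dirichlet energy simply sums the squared differences of embeddings across edges; values shrinking towards zero as layers are added are the usual symptom of oversmoothing. The snippet below sketches this definition and may differ from the exact normalization used in the paper.

```python
import torch

def dirichlet_energy(h, edge_index):
    """Unnormalized Dirichlet energy: 0.5 * sum over edges of ||h_u - h_v||^2.
    h: (num_nodes, dim) embeddings; edge_index: (2, num_edges) long tensor."""
    src, dst = edge_index
    return 0.5 * ((h[src] - h[dst]) ** 2).sum()
```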

Analysis of Underreaching

Figure 6: We report the depth distributions learned by the best configurations of each base model on the synthetic datasets (left) and the chemical ones (right).

Conclusions

Capturing long-range dependencies is a longstanding problem in the graph machine-learning community. This post introduced Adaptive Message Passing, a probabilistic framework that can endow most message passing architectures with the ability to learn how many messages to exchange between nodes and which messages to filter out.

Our approach can wrap most existing message passing layers, consistently improving their performance on five tasks that evaluate the ability to capture long-range dependencies.

Importantly, we can achieve competitive results on these datasets without imposing strong inductive biases, letting the models decide when a node should exchange its message or part of it.

We believe AMP will foster exciting research opportunities in the graph machine learning field and find successful applications in the fields of physics, chemistry, and material sciences.

Thanks for reading!

References

[1] D. Bacciu, F. Errica, M. Podda, and A. Micheli. A gentle introduction to deep learning for graphs. Neural Networks, 129:203–221, 2020.

[2] J. Gilmer, S. S. Schoenholz, P. F. Riley, O. Vinyals, and G. E. Dahl. Neural message passing for quantum chemistry. International Conference on Machine Learning (ICML), 2017.

[3] F. Wong et al. Discovery of a structural class of antibiotics with explainable deep learning. Nature, pages 1–9, 2023.

[4] A. Davies et al. Advancing mathematics by guiding human intuition with AI. Nature, 600(7887):70–74, 2021.

[5] A. Merchant et al. Scaling deep learning for materials discovery. Nature, pages 1–6, 2023.

[6] Q. Li, Z. Han, and X.-M. Wu. Deeper insights into graph convolutional networks for semi-supervised learning. AAAI Conference on Artificial Intelligence (AAAI), 2018.

[7] T. K. Rusch, M. M. Bronstein, and S. Mishra. A survey on oversmoothing in graph neural networks. arXiv preprint arXiv:2303.10993, 2023.

[8] U. Alon and E. Yahav. On the bottleneck of graph neural networks and its practical implications. International Conference on Learning Representations (ICLR), 2021.

[9] F. Di Giovanni, L. Giusti, F. Barbero, G. Luise, P. Lio, and M. M. Bronstein. On over-squashing in message passing neural networks: The impact of width, depth, and topology. International Conference on Machine Learning (ICML), 2023.

[10] A. Campa, T. Dauxois, D. Fanelli, and S. Ruffo. Physics of long-range interacting systems. OUP Oxford, 2014.

[11] S. Piana, K. Lindorff-Larsen, R. M. Dirks, J. K. Salmon, R. O. Dror, and D. E. Shaw. Evaluating the effects of cutoffs and treatment of long-range electrostatics in protein folding simulations. PLoS One, 7(6):e39918, 2012.

[12] U. Ruegsegger, J. H. Leber, and P. Walter. Block of HAC1 mRNA translation by long-range base pairing is released by cytoplasmic splicing upon induction of the unfolded protein response. Cell, 107(1):103–114, 2001.

[13] M. Ferber, V. Zoete, and O. Michielin. T-cell receptors binding orientation over peptide/MHC class I is driven by long-range interactions. PLoS One, 7(12):e51943, 2012.

[14] G. Li, M. Muller, A. Thabet, and B. Ghanem. DeepGCNs: Can GCNs go as deep as CNNs? IEEE/CVF International Conference on Computer Vision (ICCV), 2019.

[15] K. Xu, C. Li, Y. Tian, T. Sonobe, K.-I. Kawarabayashi, and S. Jegelka. Representation learning on graphs with jumping knowledge networks. International Conference on Machine Learning (ICML), 2018.

[16] D. Bacciu, F. Errica, and A. Micheli. Probabilistic learning on graphs via contextual architectures. Journal of Machine Learning Research, 21(134):1–39, 2020a.

[17] I. Spinelli, S. Scardapane, and A. Uncini. Adaptive propagation graph convolutional network. IEEE Transactions on Neural Networks and Learning Systems, 32(10):4755–4760, 2020.

[18] B. Finkelshtein, X. Huang, M. Bronstein, and İ. İ. Ceylan. Cooperative graph neural networks. arXiv preprint arXiv:2310.01267, 2023.

[19] A. Hasanzadeh, E. Hajiramezanali, S. Boluki, M. Zhou, N. Duffield, K. Narayanan, and X. Qian. Bayesian graph neural networks with adaptive connection sampling. International Conference on Machine Learning (ICML), 2020.

[20] L. Faber and R. Wattenhofer. GwAC: GNNs with asynchronous communication. Learning on Graphs Conference (LoG), 2023.

[21] S. Park, N. Ryu, G. Kim, D. Woo, S.-Y. Yun, and S. Ahn. Non-backtracking graph neural networks. arXiv preprint arXiv:2310.07430, 2023.

[22] P. Velickovic, G. Cucurull, A. Casanova, A. Romero, P. Lio, and Y. Bengio. Graph attention networks. International Conference on Learning Representations (ICLR), 2018.

[23] A. Gravina, D. Bacciu, and C. Gallicchio. Anti-symmetric DGN: a stable architecture for deep graph networks. International Conference on Learning Representations (ICLR), 2023.

[24] J. Topping, F. Di Giovanni, B. P. Chamberlain, X. Dong, and M. M. Bronstein. Understanding over-squashing and bottlenecks on graphs via curvature. International Conference on Learning Representations (ICLR), 2022.

[25] B. Gutteridge, X. Dong, M. M. Bronstein, and F. Di Giovanni. Drew: Dynamically rewired message passing with delay. International Conference on Machine Learning (ICML), 2023.

[26] D. Tortorella and A. Micheli. Leave graphs alone: Addressing over-squashing without rewiring. Learning on Graphs Conference (LoG), 2022.

[27] A. Micheli. Neural network for graphs: A contextual constructive approach. IEEE Transactions on Neural Networks, 20(3):498–511, 2009.

[28] A. Nazaret and D. Blei. Variational inference for infinitely deep neural networks. International Conference on Machine Learning (ICML), 2022.

[29] J. Tonshoff, M. Ritzert, E. Rosenbluth, and M. Grohe. Where did the gap go? Reassessing the long-range graph benchmark. Learning on Graphs Conference (LoG), 2023.

--

Federico Errica

I am a Senior Research Scientist at NEC Labs Europe. I have a PhD in Computer Science from the University of Pisa. My field of expertise is graph learning.