C&S: Use the graph structure in your data with prediction post-processing
Graph learning techniques are hard to add to existing ML pipelines. Use Correct & Smooth to refine your existing predictions by “smoothing” them over the graph structure.
Introduction
While Graph Neural Networks (GNNs) have recently dominated academic benchmarks [3], they have not seen comparable adoption in industry ML. Graph data structures and GNNs are difficult to scale to large data [1, 6] and are sufficiently different from standard supervised learning toolkits that teams would need to undergo a time-consuming and expensive process to adopt them. As a result, even if data does have a clear graph structure, it is often ignored or included in the form of row-level hand-tuned features like node degree or graphlets [4].
We want to upgrade from row-level hand-engineered node features to general graph learning techniques while making minimal changes to our modeling workflow. Using a recent graph ML algorithm called Correct & Smooth (C&S) [1], we can rival or outperform GNNs with a post-processing module that sits on top of your existing supervised classifier. We will showcase the method on the LastFM Asia dataset, where the task is to predict the home country of users in a social network. We’ll work through the following in order:
- Describe the LastFM Asia dataset and transductive node classification task
- Implement C&S from scratch using PyTorch and PyG [5]
- Use C&S to improve the performance of an off-the-shelf classifier model, and compare against baselines under different data ablations. In our experiments, we saw as much as a 20% increase in accuracy when post-processing with C&S over the underlying MLP.
You can follow along with the code and results in Colab, or install simple correct&smooth as a Python package directly:
pip install https://github.com/andrewk1/correctandsmooth/archive/refs/tags/v0.0.2.zip
Let’s get started!
Introducing our dataset and Transductive Node Classification
LastFM Asia
LastFM Asia was introduced in [2] and consists of 7,624 users and 55,612 friendship relations. For each user, we also have a 128 dimensional feature vector representing artists liked by the user, and an index describing their home country out of 18 classes. The objective is to predict which home country a user is from given the graph and their artist preferences.
We use the dataset helper class from PyG to load the dataset. The edges are stored as a tensor of node index pairs. No extra processing is needed.
Here’s how the network looks (Fig. 1), with node locations determined using a Voronoi-based technique from GraphViz.
Transductive Node Classification
In a standard supervised learning setup, we classify each user using the information about their favorite artists. In transductive node classification, we solve the same classification task but additionally know the entire graph structure at training time, including the locations of validation and test set nodes.
This looks like the following visually (Fig. 2), where we made random train/val/test splits for the nodes. We always see the entire network, but mask the classes of green (validation) and red (test) nodes during training.
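One way to build such a random split as boolean node masks (a sketch; the 80/10/10% fractions match the split used in our experiments below):

```python
import torch

torch.manual_seed(0)
num_nodes = 7624
perm = torch.randperm(num_nodes)
n_train, n_val = int(0.8 * num_nodes), int(0.1 * num_nodes)

# Boolean masks over nodes; the graph itself is always fully visible
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[perm[:n_train]] = True
val_mask[perm[n_train:n_train + n_val]] = True
test_mask[perm[n_train + n_val:]] = True
```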
Implementing C&S in PyTorch and PyG
C&S exploits the pattern that adjacent nodes tend to have the same labels and similar features, a concept called homophily. Once we have a predicted probability distribution over classes from a supervised learner, C&S first “corrects” errors on test nodes by using error from predictions on training nodes, and then “smooths” the predictions on the test data by incorporating nearby ground truth training labels.
We implement C&S from scratch following the paper, explaining its equations alongside code.
Step 1: Training a base predictor
The first step is to acquire a base classifier model that can output a probability distribution over the classes per sample. We train a shallow MLP in PyTorch:
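A minimal sketch of such a model (the hidden size, learning rate, and epoch count are our choices, not prescribed by the paper; the random tensors stand in for data.x, data.y, and your train mask):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins with the LastFM Asia shapes; swap in data.x, data.y and your masks
x = torch.randn(7624, 128)
y = torch.randint(0, 18, (7624,))
train_mask = torch.rand(7624) < 0.8

class MLP(torch.nn.Module):
    def __init__(self, in_dim=128, hidden=64, num_classes=18):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, hidden)
        self.lin2 = torch.nn.Linear(hidden, num_classes)

    def forward(self, x):
        return self.lin2(F.relu(self.lin1(x)))

model = MLP()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for epoch in range(100):
    opt.zero_grad()
    loss = F.cross_entropy(model(x[train_mask]), y[train_mask])
    loss.backward()
    opt.step()

# Soft class distributions Z over every node (rows sum to 1), used by C&S below
Z = model(x).softmax(dim=-1).detach()
```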
The MLP achieves 0.696 accuracy on the validation set and 0.726 on the test set. While predictions for nodes deep within each class cluster are consistent, the model makes errors on users with friends in different countries (at cluster borders).
Step 2: Correct
We first form a matrix E with each row equal to the residual error between the (one-hot encoded) labels and the predicted class distributions, computed for the training nodes only; rows for validation and test nodes are set to zero, since we don’t know their errors. Here, “Lt”, “Lv” and “U” are the train, validation, and test sets respectively.
In code, this looks as follows:
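A sketch with stand-in tensors for the base predictions Z and labels y (only the training rows of E are filled in):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, num_classes = 7624, 18

# Stand-ins: base predictions Z (rows sum to 1), labels y, and the train mask
Z = torch.rand(num_nodes, num_classes)
Z = Z / Z.sum(dim=1, keepdim=True)
y = torch.randint(0, num_classes, (num_nodes,))
train_mask = torch.rand(num_nodes) < 0.8

# Residual matrix: one-hot labels minus predictions on train rows, zero elsewhere
E = torch.zeros(num_nodes, num_classes)
E[train_mask] = F.one_hot(y[train_mask], num_classes).float() - Z[train_mask]
```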
Next we “smooth” the error across the graph. Due to homophily, we expect errors to be positively correlated for neighboring nodes, so for validation/test nodes, errors on neighboring training nodes can be predictive of the real error.
The correction is computed with the following recursion: E^(t+1) = (1 − α₁)E + α₁SE^(t). S = D^(−1/2)AD^(−1/2) is the normalized adjacency matrix, so SE sets a new error for each node as a weighted average of its neighbors’ errors, where nodes with lower degree (fewer neighbors) receive higher weight (for an in-depth explanation of S, check out this Math StackExchange thread). Here, α₁ (alpha) is a hyperparameter.
In code, this is a while loop where we check convergence using the L2 norm of the change in E.
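A sketch of both pieces: building S as a sparse tensor from the edge list, and the fixed-point iteration with an L2 convergence check (the tolerance and iteration cap are our choices):

```python
import torch

def normalized_adjacency(edge_index, num_nodes):
    # S = D^{-1/2} A D^{-1/2}: edge (i, j) gets weight 1 / sqrt(deg(i) * deg(j))
    row, col = edge_index
    deg = torch.bincount(row, minlength=num_nodes).float()
    dinv = deg.pow(-0.5)
    dinv[torch.isinf(dinv)] = 0.0  # guard isolated nodes
    values = dinv[row] * dinv[col]
    return torch.sparse_coo_tensor(edge_index, values, (num_nodes, num_nodes))

def propagate(S, X0, alpha, tol=1e-6, max_iters=100):
    # Iterate X <- (1 - alpha) * X0 + alpha * S X until the update stops changing
    X = X0
    for _ in range(max_iters):
        X_new = (1 - alpha) * X0 + alpha * torch.sparse.mm(S, X)
        if torch.norm(X_new - X) < tol:
            return X_new
        X = X_new
    return X

# Toy usage: a 4-node path graph, undirected edges stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
S = normalized_adjacency(edge_index, num_nodes=4)
E = torch.tensor([[0.5, -0.5], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
E_smoothed = propagate(S, E, alpha=0.8)
```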
We visualize the residual error matrices for the correct class at the first and final steps of the correct phase (brighter color is higher error). Empirically, we see that the Correct phase will reduce the magnitude of error on some nodes, although there is not much visible difference.
After smoothing the error, C&S rescales the new errors to the same scale as the original training errors. Adding the residuals back to the original predictions gives us a new prediction matrix Zr.
In code, we take the average L1 norm of the training errors per row, and then normalize and rescale the new predictions.
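A sketch of the scaling (the paper calls this “autoscale”; sigma is the mean L1 norm over training rows, and each remaining row of the smoothed residual E_hat is rescaled to that norm before being added back to Z — the stand-in tensors are our own):

```python
import torch

torch.manual_seed(0)
num_nodes, num_classes = 100, 18

# Stand-ins: base predictions Z, smoothed residuals E_hat, and the train mask
Z = torch.rand(num_nodes, num_classes)
Z = Z / Z.sum(dim=1, keepdim=True)
E_hat = 0.1 * torch.randn(num_nodes, num_classes)
train_mask = torch.rand(num_nodes) < 0.8

# sigma: average L1 norm of the training-node residual rows
sigma = E_hat[train_mask].abs().sum(dim=1).mean()

# Rescale each remaining row to L1 norm sigma, then add back to the predictions
Zr = Z.clone()
other = ~train_mask
row_norms = E_hat[other].abs().sum(dim=1, keepdim=True).clamp(min=1e-12)
Zr[other] = Z[other] + sigma * E_hat[other] / row_norms
```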
Step 3: Smooth
In the Correct step, we smoothed errors over adjacent nodes. In the Smooth step, we smooth the predictions themselves across adjacent nodes, following the same intuition. The smoothing operation is identical to the error correction, this time iterating over our best-guess matrix G, initialized to our scaled prediction matrix Zr.
The code is also nearly identical to the correct step:
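A self-contained sketch on a toy graph (a dense S here for brevity; on real data you would reuse the sparse S and iteration from the correct step). Training rows of G are clamped to their one-hot labels before propagating:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, c = 6, 3

# Toy symmetrically normalized adjacency S for a 6-node path graph
A = torch.zeros(n, n)
for i in range(n - 1):
    A[i, i + 1] = A[i + 1, i] = 1.0
d_inv_sqrt = A.sum(dim=1).rsqrt()
S = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)

# Stand-ins: corrected predictions Zr, labels, and the train mask
Zr = torch.rand(n, c)
Zr = Zr / Zr.sum(dim=1, keepdim=True)
y = torch.randint(0, c, (n,))
train_mask = torch.tensor([True, True, True, False, False, False])

# Initialize the best-guess matrix G, clamping training rows to ground truth
G0 = Zr.clone()
G0[train_mask] = F.one_hot(y[train_mask], c).float()

alpha2 = 0.8
G = G0
for _ in range(100):
    G_new = (1 - alpha2) * G0 + alpha2 * S @ G
    if torch.norm(G_new - G) < 1e-6:
        break
    G = G_new

pred = G.argmax(dim=1)  # final class predictions for every node
```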
Once we run the smoothing step, we have our final predictions! The GIF below visualizes our best guess at each iteration of the smoothing process. We can see that the nodes that change their predictions most live in-between the class clusters, where artist preference alone is not a clear predictor of nationality.
Measuring performance
We compare test set performance on LastFM Asia between the MLP and C&S using accuracy, including a hyperparameter sweep over the choice of α₁ and α₂ on the validation set (both of these were optimized over in the paper as well).
With each step, we see a nearly 10% increase in accuracy! Clearly, there are huge gains to be had in including the graph structure in this particular predictive task. We also see the importance of the two α hyperparameters to the performance of the smoothing steps, so run a hyperparameter sweep if you choose to implement this method.
Understanding C&S
A 20% increase in accuracy deserves some scrutiny — why does C&S perform so well here? We think homophily can explain the results, where the training nodes near test nodes can inform the test node labels. Let’s do a simple experiment: Run the smoothing step over the training node labels directly, and use the results to predict validation and test labels:
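A sketch of this baseline on a toy graph: initialize G to one-hot training labels (zero rows elsewhere) and run the same propagation as before, with no MLP involved at all (the alpha value and toy graph are our choices):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, c = 6, 3

# Toy symmetrically normalized adjacency for a 6-node path graph
A = torch.zeros(n, n)
for i in range(n - 1):
    A[i, i + 1] = A[i + 1, i] = 1.0
d_inv_sqrt = A.sum(dim=1).rsqrt()
S = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)

y = torch.tensor([0, 0, 1, 1, 2, 2])
train_mask = torch.tensor([True, False, True, False, True, False])

# One-hot training labels only; no model predictions involved
G0 = torch.zeros(n, c)
G0[train_mask] = F.one_hot(y[train_mask], c).float()

alpha = 0.9
G = G0
for _ in range(100):
    G_new = (1 - alpha) * G0 + alpha * S @ G
    if torch.norm(G_new - G) < 1e-6:
        break
    G = G_new

pred = G.argmax(dim=1)  # predicted labels for every node, from labels alone
```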
Training label smoothing on 80/10/10% split
An interesting result: just smoothing the training labels onto neighboring nodes nearly matches the performance of the C&S + MLP approach! This validates that the gains of C&S rely on sharing training node information with nearby test nodes.
For our train/val/test split of 80/10/10%, most test nodes may simply be adjacent to many training nodes, which allows an approach like smoothing training labels to work. Let’s try reducing to 50/25/25% and re-running the experiment. We might expect C&S to outperform here because it can take advantage of the underlying MLP to “fill in” labels.
Training label smoothing on 50/25/25% split
Even in the low training data regime, pure smoothing holds up very well, outperforming C&S and other learning approaches entirely! While we reduced the number of nodes in the training set, we didn’t touch the number of graph connections. Visualizing the data splits, this seems to be a reasonable result due to the density of the graph connections: most nodes still end up with a connection to a training node!
What if we sub-sample the edges in the graph and rerun the same model pipeline with 70% of the edges removed? The intuition here is that we can no longer rely on the sheer density of the graph, where test nodes have many training nodes to infer a label from. Our models will need to combine the individual node features and any remaining neighbors more effectively to form a prediction.
Here are the results for this ablation:
Training label smoothing on 80/10/10% split with 30% of edges
The results align with our intuitions! C&S is a clear winner here, outperforming both the MLP and simple smoothing baselines at 80% accuracy versus 70% and 60% respectively. The smoothing baseline is even worse than the MLP baseline because of the lack of edge density. However, the combination of the two signals (user features and graph structure) via C&S can still result in 10% gains in accuracy over the MLP.
Conclusion
C&S is a strong method to boost your existing classifier performance on graph data with two fast and easy postprocessing steps. In very dense graphs, you should try just adopting the smoothing step over your training labels for comparable accuracy gains. When your graphs are not as dense, or your test set does not share many edges with your training data, C&S can provide a great boost to supervised learners to exploit some knowledge of edge structure. The fundamental smoothing code is very similar for all methods, so it’s worth giving them all a try in your next project!
You can find me on Twitter for more content :)
Simple correct&smooth package and Colab
A packaged version of the model can be found here:
All of the code, results, and visualizations from this blog can be found in this Colab:
Citations
[1] Huang, Qian, et al. “Combining label propagation and simple models out-performs graph neural networks.” arXiv preprint arXiv:2010.13993 (2020).
[2] Rozemberczki, Benedek, and Rik Sarkar. “Characteristic functions on graphs: Birds of a feather, from statistical descriptors to parametric models.” Proceedings of the 29th ACM International Conference on Information & Knowledge Management. 2020.
[3] https://ogb.stanford.edu/docs/leader_nodeprop/
[4] https://en.wikipedia.org/wiki/Graphlets
[5] Fey, Matthias, and Jan Eric Lenssen. “Fast graph representation learning with PyTorch Geometric.” arXiv preprint arXiv:1903.02428 (2019).
[6] Bojchevski, Aleksandar, et al. “Scaling graph neural networks with approximate pagerank.” Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2020.