Looking behind the curtain: saliency maps for graph machine learning

Wu Huijun · Published in stellargraph · 10 min read · Aug 22, 2019

When it comes to understanding predictions generated by machine learning on graphs, one of the biggest questions we’re left with is: why? Why was a certain prediction generated?

There is high demand for interpretability on graph neural networks, and this is especially true when we’re talking about real-world problems. If we can’t understand how or why a prediction was made then it’s difficult to trust it, let alone act on it.

StellarGraph’s recent article, machine learning on graphs, introduced graph neural networks and the tasks they can solve. This time, we’re discussing interpretability: in other words, techniques to interpret the predictions returned by the machine learning model.

Using node classification with graph convolutional networks (GCN) as a case study, we’ll look at how to measure the importance of specific nodes and edges of a graph in the model’s predictions. This will involve exploring the use of saliency maps to look at whether the model’s prediction will change if we remove or add a certain edge, or change node features.

Definition of interpretability

Despite extensive interest in interpreting neural network models, there is no established definition of interpretability. Miller defines interpretability as the degree to which a human can understand the cause of a decision [1]. Kim et al. say interpretability is the degree to which humans can consistently predict the model’s decision [2]. The ongoing research on this topic generally relates interpretability to many other objectives such as trust, transparency, causality, informativeness, and fairness [3].

My personal view on interpretability of graph machine learning models is that it is a way to tell which node features or graph edges play a more important role in the model’s prediction. It is essentially asking a counterfactual question: how would the prediction change if these node features were different, or if these edges didn’t exist?

When applying GCNs to node classification tasks, the model predicts the label of a given target node based on both the features of nodes (including the target node itself) and the structure of the graph. The interpretability problem we want to tackle is: which nodes, edges and features led to this prediction? And, similarly, what would happen if certain edges did not exist, or if the features of some nodes were different?

Such interpretations can potentially provide insights into the underlying reasoning mechanism of the GCN models.

Saliency maps

Saliency mapping is a technique with origins in computer vision literature [4] used to change or simplify an image into something that has meaning for humans, making it easier to analyse.

Take the following image as an example: a saliency map based interpretability tool [5] has highlighted the bicycle as an explanation for the predicted label ‘bicycle’.

First, let’s look at the simplest way to get the highlighted pixels.

Given the bicycle image, a classification ConvNet takes the image as the input and gives a score indicating the probability that the image is classified as a bicycle. To be precise, the classifier yields a vector of scores, one score per class. For example, [0.01, 0.001, …, 0.8, …, 0.005] would be a vector of scores for classes [‘flower’, ‘person’, …, ‘bicycle’, …, ‘cat’]. The final prediction is the class with the highest score; in this case, ‘bicycle’.

So, the vanilla saliency map approach simply calculates the gradient of the score of interest (typically the highest score, corresponding to the winning class) with respect to the pixels of the input image.

Since the saliency map has the same shape as the input image, it can be used as a mask on top of the image to highlight the parts that are important for the predicted label. As per the definition of gradients, this saliency map tells us how much the prediction score for class c would change if the RGB pixel intensities in the highlighted area of the image were slightly increased.
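
A minimal sketch of this computation in TensorFlow, assuming model is any Keras image classifier that returns one score per class (the helper name and details below are illustrative, not a fixed API):

```python
import tensorflow as tf

def vanilla_saliency(model, image):
    """Gradient of the winning-class score with respect to the input pixels."""
    image = tf.convert_to_tensor(image[None, ...], dtype=tf.float32)  # add batch dim
    with tf.GradientTape() as tape:
        tape.watch(image)
        scores = model(image)                 # one score per class
        top_score = tf.reduce_max(scores[0])  # score of the winning class
    grads = tape.gradient(top_score, image)   # same shape as the input image
    # Collapse colour channels and take absolute values for visualisation.
    return tf.reduce_max(tf.abs(grads[0]), axis=-1).numpy()

# saliency = vanilla_saliency(model, image)   # model and image supplied by the user
```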

In the context of this discussion, we’ll be applying a saliency map approach to help interpret the decisions of GCN models.

Applying saliency maps in GCNs

From here, things start to get a little more technical.

In a node classification task, given the feature matrix X and adjacency matrix A of the graph, a two-layer GCN model is defined as:

f(X, A) = softmax(Â · ReLU(Â X W(0)) · W(1)),

where Â is the symmetrically normalised adjacency matrix with self-loops added, and W(0) and W(1) are the trainable weight matrices of the first and second graph convolutional layers.

A quick review of StellarGraph’s machine learning on graphs will be useful if you need a refresher here.

If we treat the features of nodes and the entries of the adjacency matrix in the same way as pixels in images, it seems intuitive to use saliency map methods to interpret the decisions of GCN models. However, for graph data, the features of nodes can be categorical or binary (unlike for images, where the features are continuous intensities of RGB pixel channels). The edge weights, particularly for unweighted graphs, are binary as well: an edge either exists or it doesn’t. This means the features and edge weights can only be flipped rather than slightly perturbed, which affects the meaning and clarity of the vanilla saliency map.

Taking the importance of edges as an example, the saliency map for class c can be described as the gradient of the predicted score with respect to the entries of the adjacency matrix:

∂fc(X, A) / ∂A

Here, fc is the predicted score for class c of the GCN model, which takes the adjacency matrix A and the feature matrix X as inputs. Because fc(X, A) is not a linear function of X and A, the vanilla gradients may fail to give accurate importance scores for discrete graph data.

Take a simple ReLU network, f(x) = ReLU(x - 0.5), as an example. When x increases from 0 to 1, the function value increases by 0.5. However, computing the gradient at x = 0 gives 0, which does not capture the model’s behaviour accurately.
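
For instance, a quick check with TensorFlow confirms the zero gradient:

```python
import tensorflow as tf

# f(x) = ReLU(x - 0.5): the gradient at x = 0 is 0, even though flipping
# x from 0 to 1 changes the output by 0.5.
x = tf.constant(0.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.nn.relu(x - 0.5)
print(tape.gradient(y, x))  # tf.Tensor(0.0, shape=(), dtype=float32)
```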

Given the discreteness of the edge weights, we can consider a direct alternative: flip each entry A_ij of the adjacency matrix between 0 and 1, and take the resulting change in the predicted score fc as the importance of the corresponding edge.

This naive method inevitably introduces a lot of computational overhead, since it requires a separate forward pass of the model for every candidate edge; this is prohibitive for graphs with a large number of nodes and edges. Computing the importance of only the existing edges may be feasible if the adjacency matrix is sparse.

To address the limitations of the vanilla saliency map approach and handle the discreteness of graph data in a continuous manner, we propose to exploit the idea of integrated gradients [6]. The integrated gradient of an input component is the path integral of the gradient along a straight-line path from a baseline input x′ (for example, an all-zero input) to the actual input x:

IG_i(x) = (x_i − x′_i) · ∫₀¹ ∂f(x′ + α(x − x′)) / ∂x_i dα

In practice, we approximate the integral by evaluating the gradient at m points along the path and averaging.

Integrated gradients reflect the importance of edges more accurately than vanilla saliency maps because they capture gradient information along the whole path rather than at a single point.

Let’s again apply the example f(x) = ReLU(x - 0.5). Unlike the vanilla saliency map approach, which only calculates the gradient at x = 0 or x = 1, integrated gradients take the gradient values between x = 0 and x = 1 into account, so the result is a positive value rather than 0. Further, since gradient operations are well optimised in a variety of frameworks, this method can exploit the high parallelism of GPUs.
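
A minimal sketch of the m-step approximation on the same toy function shows that integrated gradients recover the 0.5 change that the single-point gradient at x = 0 misses:

```python
import tensorflow as tf

def f(x):
    return tf.nn.relu(x - 0.5)

def integrated_gradient(f, x, baseline=0.0, m=50):
    """Approximate the integrated gradient of a scalar function f along the
    straight-line path from `baseline` to `x`, using m gradient evaluations."""
    total = tf.constant(0.0)
    for k in range(1, m + 1):
        point = tf.constant(baseline + (k / m) * (x - baseline))
        with tf.GradientTape() as tape:
            tape.watch(point)
            y = f(point)
        total += tape.gradient(y, point)
    return (x - baseline) * total / m

print(integrated_gradient(f, 1.0).numpy())  # ~0.5, matching f(1) - f(0)
```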

The StellarGraph open-source Python Library offers support for the calculation of integrated gradients for GCN models. Here we use node classification on a citation network dataset (CORA-ML), with GCN as an example.

The CORA-ML dataset consists of 2,708 scientific publications, classified into one of seven classes. The citation network consists of 5,429 links corresponding to citations between the papers. Each publication in the dataset is described by a 0/1-valued key word vector indicating the absence/presence of the corresponding key words from the dictionary. The dictionary consists of 1,433 unique words.

We first build the GCN model and train it:
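
A minimal sketch of this step, following the StellarGraph node classification demos; the exact class and method names (for example datasets.Cora, FullBatchNodeGenerator, GCN and in_out_tensors) are assumed from those demos and may differ between library versions:

```python
from sklearn import model_selection, preprocessing
from tensorflow.keras import layers, losses, optimizers, Model

from stellargraph import datasets
from stellargraph.layer import GCN
from stellargraph.mapper import FullBatchNodeGenerator

# Load the citation graph and the node labels (loader name assumed).
G, node_subjects = datasets.Cora().load()

# Small training split, with labels one-hot encoded for the softmax output.
train_subjects, test_subjects = model_selection.train_test_split(
    node_subjects, train_size=140, stratify=node_subjects
)
target_encoding = preprocessing.LabelBinarizer()
train_targets = target_encoding.fit_transform(train_subjects)

# Full-batch generator; sparse=False keeps the adjacency matrix dense so that
# gradients with respect to edge entries can be computed later.
generator = FullBatchNodeGenerator(G, method="gcn", sparse=False)
train_gen = generator.flow(train_subjects.index, train_targets)

# Two-layer GCN with a softmax classification head.
gcn = GCN(layer_sizes=[16, 16], activations=["relu", "relu"],
          generator=generator, dropout=0.5)
x_inp, x_out = gcn.in_out_tensors()
predictions = layers.Dense(train_targets.shape[1], activation="softmax")(x_out)

model = Model(inputs=x_inp, outputs=predictions)
model.compile(optimizer=optimizers.Adam(0.01),
              loss=losses.categorical_crossentropy, metrics=["acc"])
model.fit(train_gen, epochs=20, shuffle=False, verbose=2)
```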

Given a target node (e.g., target_idx = 7), we then calculate the integrated gradients to measure the importance of graph nodes and edges with regard to the model’s predicted ‘class_of_interest’ for the target node (the ‘class_of_interest’ is typically the class with the highest predicted score):
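
A sketch of this step; the import path and method names below (stellargraph.interpretability.saliency_maps, IntegratedGradients, node_importance, link_importance) are assumed from the library’s interpretability demo and may differ between versions:

```python
import numpy as np
from stellargraph.interpretability.saliency_maps import IntegratedGradients

target_idx = 7
target_nid = list(G.nodes())[target_idx]        # node ID of the target node

# The class of interest is the class with the highest predicted score.
target_gen = generator.flow([target_nid])
class_of_interest = int(np.argmax(model.predict(target_gen)))

# Integrated gradients for node features and for edges (method names assumed).
int_grads = IntegratedGradients(model, train_gen)
node_importance = int_grads.node_importance(target_nid, class_of_interest, steps=50)
link_importance = int_grads.link_importance(target_nid, class_of_interest, steps=50)
```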

To make the importance scores more intuitive, we plot the 2-hop ego graph of the target node. In the figures below, the red edges indicate positive importance, while the blue edges indicate negative importance. The width of edges measures the absolute value of the importance. Meanwhile, the round nodes indicate positive importance and the star nodes indicate negative importance. Again, the size of the nodes measures the absolute value of node importance. Finally, the class bar to the right shows node labels, with the node colours indicating true class labels of nodes:

Figure 1
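
A rough sketch of how such a plot can be drawn with networkx and matplotlib, assuming link_importance is a dense N x N array, node_importance is a length-N vector ordered consistently with G.nodes(), and G.to_networkx() is available (all of these are assumptions):

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

node_ids = list(G.nodes())
index_of = {nid: i for i, nid in enumerate(node_ids)}

# 2-hop ego graph around the target node (collapsed to a simple graph).
ego = nx.ego_graph(nx.Graph(G.to_networkx()), target_nid, radius=2)
pos = nx.spring_layout(ego, seed=42)

edge_scores = np.array([link_importance[index_of[u], index_of[v]] for u, v in ego.edges()])
node_scores = np.array([node_importance[index_of[n]] for n in ego.nodes()])

# Red/blue for positive/negative importance; width and size scale with magnitude.
nx.draw_networkx_edges(
    ego, pos,
    edge_color=["red" if s > 0 else "blue" for s in edge_scores],
    width=1 + 4 * np.abs(edge_scores) / (np.abs(edge_scores).max() + 1e-9),
)
nx.draw_networkx_nodes(
    ego, pos,
    node_size=20 + 180 * np.abs(node_scores) / (np.abs(node_scores).max() + 1e-9),
)
plt.axis("off")
plt.show()
```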

Overall, we see that the closest node of the same class generally has higher importance, while nodes with different labels tend to have negative importance.

To verify these results, we compare with the ground-truth importance. For edges, we obtain this by removing them one by one and measuring the corresponding change in the model’s predicted score for the class of interest (the winning class). For nodes, we select one at a time and set its features to all-zero, measuring the corresponding change in the model’s predicted score for the target node. In both cases, the ground-truth importance is simply the change in the predicted score after re-evaluating the model:
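
A sketch of this brute-force check, written in terms of a hypothetical predict_score(features, adjacency) helper that re-runs the trained GCN and returns the predicted score of the class of interest for the target node (this helper is not part of the StellarGraph API):

```python
import numpy as np

def ground_truth_edge_importance(features, adjacency, predict_score):
    """Brute-force edge importance: remove each existing edge in turn and
    record the drop in the predicted score for the class of interest."""
    base_score = predict_score(features, adjacency)
    importance = np.zeros_like(adjacency, dtype=float)
    for i, j in zip(*np.nonzero(adjacency)):
        perturbed = adjacency.copy()
        perturbed[i, j] = 0          # remove the edge (i, j)
        importance[i, j] = base_score - predict_score(features, perturbed)
    return importance

def ground_truth_node_importance(features, adjacency, predict_score):
    """Brute-force node importance: zero out each node's features in turn."""
    base_score = predict_score(features, adjacency)
    importance = np.zeros(features.shape[0])
    for n in range(features.shape[0]):
        perturbed = features.copy()
        perturbed[n, :] = 0          # blank out node n's features
        importance[n] = base_score - predict_score(perturbed, adjacency)
    return importance
```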

Figure 2

By comparing figures 1 and 2, we can see that the integrated gradients are quite consistent with the brute-force approach, as the relative importance of nodes and edges is similar.

For this part, we use the CORA-ML data together with the actual words behind the bag-of-words features. We then evaluate the feature importance for the prediction on target node #2741, which has the label “probabilistic methods”. The top-50 most important features across nodes are as follows:

Conversely, the top-50 negative features are:
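
Such rankings can be read off a per-feature importance vector; the sketch below assumes a hypothetical feature_importance array of length 1,433 (importance aggregated over nodes) and a matching list words of dictionary words:

```python
import numpy as np

order = np.argsort(feature_importance)                 # ascending by importance

top_positive = [words[i] for i in order[::-1][:50]]    # 50 most positive features
top_negative = [words[i] for i in order[:50]]          # 50 most negative features
```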

The top positive node features, such as ‘bayesian’, ‘conditional’, ‘probabilistic’ and ‘density’, make sense because the ground-truth class is “probabilistic methods”. Words like ‘computer’, ‘California’ and ‘Irvine’ generally appear in author affiliations and dataset descriptions (the UCI datasets, widely used in ML research, come from UC Irvine).

The top negative features, on the other hand, appear to be less relevant to the “probabilistic methods” class; instead, they tend to point to other classes. For instance, words like ‘rules’ and ‘pruning’ seem more relevant to the class ‘rule-based methods’.

By setting the top five most important features to 0 and the top five most negative features to 1, the prediction score for class 3 drops from 0.956 to 0.321, and the winning class becomes class 1. This further verifies the correctness of the importance scores.
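
A sketch of this check, reusing the hypothetical predict_score helper and feature_importance vector from above, with features and adjacency denoting the raw node-feature and adjacency matrices and target_idx the row of the target node:

```python
import numpy as np

# Flip the most helpful words off and the most harmful words on for the
# target node, then re-evaluate the model.
order = np.argsort(feature_importance)
perturbed = features.copy()
perturbed[target_idx, order[::-1][:5]] = 0   # switch off the top-5 positive words
perturbed[target_idx, order[:5]] = 1         # switch on the top-5 negative words

# The score for the original winning class should drop noticeably.
print(predict_score(features, adjacency), predict_score(perturbed, adjacency))
```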

In closing

Interpretability plays a critical role in graph machine learning tasks. In this post, we’ve considered interpretability based on the question: which node features and edges play a more important role for the model to give certain predictions?

We show that by using an integrated gradient approach, we can measure the importance of each edge and node feature and in turn, gain insights into the model’s predictions. What’s more, knowing how the model works is also the first step towards building better models.

Our next post in this series will take us beyond saliency maps to adversarial attacks and defences, and will seriously motivate us to rethink the capabilities of GCN models. If you’re eager to delve deeper in the meantime, take a look at our paper Adversarial Examples on Graph Data: Deep Insights into Attack and Defence.

Finally, the StellarGraph open-source Python Library is a great resource for exploring interpretations of learning tasks if you want to experiment with graph machine learning techniques and apply them to network-structured data.

This work is supported by CSIRO’s Data61, Australia’s leading digital research network.

References

  1. Miller, Tim. “Explanation in artificial intelligence: Insights from the social sciences.” arXiv preprint arXiv:1706.07269 (2017).
  2. Kim, Been, Rajiv Khanna, and Oluwasanmi O. Koyejo. “Examples are not enough, learn to criticize! Criticism for interpretability.” Advances in Neural Information Processing Systems (2016).
  3. Lipton, Z. C. (2018). The mythos of model interpretability. Communications of the ACM, 61(10), 36–43.
  4. Simonyan, K., Vedaldi, A., & Zisserman, A. (2013). Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034.
  5. Yang, J., & Yang, M. H. (2016). Top-down visual saliency via joint CRF and dictionary learning. IEEE transactions on pattern analysis and machine intelligence, 39(3), 576–588.
  6. Sundararajan, M., Taly, A., & Yan, Q. (2017, August). Axiomatic attribution for deep networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70 (pp. 3319–3328). JMLR. org.
  7. Molnar, C. (2018). Interpretable machine learning. A Guide for Making Black Box Models Explainable.
  8. Zügner, Daniel, Amir Akbarnejad, and Stephan Günnemann. “Adversarial attacks on neural networks for graph data.” In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 2847–2856. ACM, 2018.
