Why should I trust my Graph Neural Network?

An introduction to explainability methods for GNNs

By Alicja Chaszczewicz, Kyle Swanson, Mert Yuksekgonul as part of the Stanford CS224W course project.

Imagine we have a Graph Neural Network (GNN) model that predicts with fantastic accuracy on our test dataset whether chemical compounds are mutagenic or whether a molecule has the ability to inhibit HIV replication. How can we trust the model to work in settings beyond our test distribution? How can we understand if the model is using the underlying biological mechanisms rather than relying on spurious correlations contained in the training dataset?

Photo by Terry Vlisidis on Unsplash

Deep learning has proven extraordinarily successful at solving challenging problems in computer vision, natural language processing (NLP), and beyond [1]. However, most deep learning models contain a very large number of parameters that are involved in a complicated sequence of functions, making it difficult for humans to understand and reason about how these models make decisions. In many real-world situations, this lack of explainability can significantly hinder the acceptance of the model’s predictions. For example, if a model is able to classify images of skin lesions with high accuracy but cannot explain its predictions, doctors are unlikely to trust the model and will not take advantage of its predictive capabilities when diagnosing patients [2, 3]. A model that is able to explain its reasoning is much more likely to be trusted and accepted and can therefore have a greater impact in the world.

GNNs are great at processing graph-structured data such as molecules; however, just like most other deep learning models, their internal decision-making processes are typically opaque to humans. While a variety of interpretability methods have been developed for deep learning models in the computer vision and NLP domains [4], fewer methods exist for the nascent field of graph machine learning [5]. In this blog post, we explore explainability methods that are designed for GNNs and elucidate how GNNs make decisions about the graphs they operate on. All the results can be reproduced using our Colab.

Case Study

Datasets

In this post, we will focus on two problems.

BA-Shapes is a synthetic node classification dataset, where 5-node house motifs are added to a Barabási–Albert (BA) graph, and labels of the nodes are determined according to their position within the house motif (shown in the figure below). We follow the setting in [8], where we use a BA graph with 300 nodes and 80 house motifs. In this dataset, the ground truth mechanism is available to us: we expect the models to use the positions of the nodes with respect to house motifs to make their predictions.

House Motifs in the BA-Shapes dataset. There are 4 classes based on their positions with respect to these motifs: (1) The top node in the motif (2) The middle nodes in the motif (3) The bottom nodes in the motif (4) Nodes that are not in the motif.
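The construction described above can be sketched in plain Python. This is a minimal sketch under stated assumptions. It is not the generator used in [8] or in our Colab. The assumptions are:

- The base graph uses simple preferential attachment, with a hypothetical default of m = 5 edges per new node.
- Each house is attached to a random base node through one of its bottom nodes, which is an arbitrary choice.
- Classes are numbered as in the caption, with 4 meaning "not in a motif".

The helper names `barabasi_albert_edges` and `build_ba_shapes` are hypothetical.

```python
import random

def barabasi_albert_edges(n, m, rng):
    """Preferential attachment: each new node links to m distinct existing nodes,
    picked with probability proportional to their current degree."""
    edges = set()
    targets = list(range(m))          # the first new node links to the m seed nodes
    repeated = []                     # each node appears once per incident edge
    for new in range(m, n):
        edges.update((t, new) for t in targets)
        repeated.extend(targets)
        repeated.extend([new] * m)
        chosen = set()
        while len(chosen) < m:        # sample m distinct degree-weighted targets
            chosen.add(rng.choice(repeated))
        targets = sorted(chosen)
    return edges

def build_ba_shapes(n_base=300, n_houses=80, m=5, seed=0):
    """Return (edges, labels). Labels follow the caption: 1 = top, 2 = middle,
    3 = bottom, 4 = not in a motif. Attachment details are assumptions."""
    rng = random.Random(seed)
    edges = barabasi_albert_edges(n_base, m, rng)
    labels = [4] * n_base
    for _ in range(n_houses):
        top, mid_a, mid_b, bot_a, bot_b = range(len(labels), len(labels) + 5)
        edges.update([(top, mid_a), (top, mid_b), (mid_a, mid_b),       # roof triangle
                      (mid_a, bot_a), (mid_b, bot_b), (bot_a, bot_b)])  # walls and floor
        labels += [1, 2, 2, 3, 3]
        edges.add((rng.randrange(n_base), bot_a))   # attach the house to a random base node
    return edges, labels

edges, labels = build_ba_shapes()
print(len(labels), len(edges))   # 700 2035
```

With 300 base nodes and 80 five-node houses, the graph has 700 nodes. Only the 400 motif nodes carry labels 1 to 3.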

MUTAG [9] is a molecular graph classification dataset, where the goal is to predict the mutagenicity of nitroaromatic compounds on Salmonella typhimurium. This dataset consists of 188 compounds with binary labels. Carbon rings and NO₂ groups are known to be strongly associated with mutagenic effects [9], so we expect our models to discover these relations.

Model

Throughout our experiments, we use the Graph Isomorphism Network (GIN) architecture proposed in [6]. All variants of GNNs follow the same two-step framework to define a single layer: Message + Aggregation. The expressivity of the architecture depends on these two choices, along with how the layers are connected. We refer to [11] for various design choices that can be made when choosing a GNN.
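The aggregation choice matters for expressivity. The sketch below (plain Python) checks which of three common aggregators map two different multisets of neighbor features to the same value. It uses scalar features, which is a simplifying assumption. Note that sum by itself is not injective on real-valued features either: {2.0} and {1.0, 1.0} have the same sum. What [6] shows is that sum aggregation followed by an MLP can represent injective functions over multisets when features come from a countable set.

```python
from statistics import mean

AGGREGATORS = {"sum": sum, "mean": mean, "max": max}

def colliding_aggregators(neighbors_a, neighbors_b):
    """Names of the aggregators that give the same output for two different
    multisets of (scalar) neighbor features."""
    return sorted(name for name, aggregate in AGGREGATORS.items()
                  if aggregate(neighbors_a) == aggregate(neighbors_b))

# Same feature, different neighborhood sizes: mean and max both collide.
print(colliding_aggregators([1.0, 1.0], [1.0]))                 # ['max', 'mean']
# Same maximum, different proportions: only max collides.
print(colliding_aggregators([1.0, 2.0], [1.0, 2.0, 2.0]))       # ['max']
# Same proportions, different sizes: mean and max both collide.
print(colliding_aggregators([1.0, 2.0], [1.0, 1.0, 2.0, 2.0]))  # ['max', 'mean']
```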

GNN figure from Representation learning on graphs: Methods and applications by Hamilton et al. (2017)

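GIN was designed to satisfy simple theoretical properties that make it highly expressive. In particular, it is as powerful as the Weisfeiler-Lehman graph isomorphism test [6]. By contrast, standard GNN layers such as GraphSAGE [12] and GCN [13] aggregate neighbor features with mean or max pooling. These aggregators are not injective over multisets of neighbor features, so such layers can map different neighborhoods to the same representation [6]. GIN purposefully modifies the aggregation step, using sum aggregation followed by an MLP: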

Update step for the node representations of the GIN architecture.
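Written out, the GIN update from [6] is shown below. Here h_v^(k) is the representation of node v after layer k, N(v) is the set of neighbors of v, and ε^(k) is a scalar that can be either learned or fixed.

```latex
h_v^{(k)} = \mathrm{MLP}^{(k)}\left( \left(1 + \epsilon^{(k)}\right) \cdot h_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)} \right)
```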

While GIN is an interesting architecture in its own right, the details of the theoretical and methodological properties are beyond the scope of this post, and hence we refer the reader to the original paper for further details. The GIN model we train on BA-Shapes achieves a test accuracy of 0.986, and the GIN model we train on MUTAG achieves a test accuracy of 0.789 and a test AUC of 0.857.

GNNExplainer

GNNExplainer [8] is a model-agnostic explainability method that can be applied to any machine learning task on graphs, such as link prediction, node classification, or graph classification. The main idea behind the method is to mask unimportant features and edges and to discover a part of the graph that significantly influences the model’s predictions.

GNNExplainer learns a mask over graph edges. Important edges are selected (in green), and the subgraph induced by them (which might be disconnected) is the discovered explanation. Additionally, a mask over node attributes is learned simultaneously to identify crucial node features for the GNN prediction. Please see more details in the original work [8].

More technically, GNNExplainer learns a graph mask and a feature mask that together specify the subgraph (which might be disconnected) that maximizes the mutual information with the model's predictions. The mutual information objective quantifies how the model's predictions change if unimportant parts of the graph are masked and only the explanation subgraph is left. Finding the most informative subgraph under this metric exactly is computationally intractable, since the number of possible subgraphs grows exponentially with the size of the graph. GNNExplainer instead solves a simplified problem by optimizing a continuous relaxation of the objective.

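Here we consider a version of the objective that targets the question of why the model predicted a particular label. Let Φ denote the trained GNN, C the number of classes, and ⊙ element-wise multiplication. The objective for the prediction task of Y on the graph G with node features X then becomes

$$\min_{M} \; -\sum_{c=1}^{C} \mathbb{1}[y = c] \, \log P_{\Phi}\left(Y = y \mid G = A \odot \sigma(M),\; X\right),$$
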
which is a modified cross entropy loss. Intuitively, if labels c represent the original model’s predictions, masking edges that are unimportant for this prediction (as defined by the application of mask M to the adjacency matrix A) should not modify the predictions. Please note that the mask is represented as a matrix of real numbers and the sigmoid function is applied to transform the values to the (0, 1) range. The final explanation consists of edges with mask values higher than a set threshold (which is a hyperparameter). Please see more details on this objective and its extension to feature masking in the original paper [8].
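The mask optimization can be illustrated on a toy problem in NumPy. The sketch below is a heavily simplified, hypothetical stand-in for GNNExplainer [8], not its implementation. The simplifications are:

- The "GNN" is a single linear message-passing layer with precomputed neighbor messages `Z`.
- Only the edges of the explained node are masked.
- The only regularizer is a mask-size penalty (`size_coef`). The feature mask and other regularizers from [8] are omitted.
- Gradients are derived by hand instead of with an autodiff library.

The names `explain_node`, `Z`, and `size_coef` are our own.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def explain_node(A, Z, node, target, steps=300, lr=0.1, size_coef=0.05, seed=0):
    """Learn a soft mask over the edges of `node` for a toy one-layer linear GNN.

    Hypothetical model: logits = (A[node] * sigmoid(m)) @ Z, where Z[u] is the
    message sent by neighbor u. Loss: -log p[target] + size_coef * sum of mask values.
    """
    rng = np.random.default_rng(seed)
    m = rng.normal(0.0, 0.1, size=A.shape[1])        # real-valued mask parameters
    for _ in range(steps):
        s = sigmoid(m)                                # mask values in (0, 1)
        logits = (A[node] * s) @ Z                    # masked aggregation of neighbor messages
        p = np.exp(logits - logits.max())
        p /= p.sum()                                  # softmax over classes
        dlogits = p.copy()
        dlogits[target] -= 1.0                        # d(-log p[target]) / d logits
        grad_s = A[node] * (Z @ dlogits + size_coef)  # d loss / d mask values
        m -= lr * grad_s * s * (1.0 - s)              # chain rule through the sigmoid
    return A[node] * sigmoid(m)                       # mask values on existing edges only

# Hypothetical example: node 0 has neighbors 1 and 2; neighbor 1's message favors
# class 1 and neighbor 2's message favors class 0.
A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
Z = np.array([[0.0, 0.0], [0.0, 3.0], [2.0, 0.0]])
target = int(np.argmax(A[0] @ Z))                     # the model's own prediction on the full graph
mask = explain_node(A, Z, node=0, target=target)
threshold = 0.5                                       # a hyperparameter, as in the text
explanation = [u for u in range(len(mask)) if mask[u] > threshold]
print(target, explanation)   # 1 [1]
```

Because the target is the model's own unmasked prediction, the mask keeps the edge whose message supports that prediction. It suppresses the edge that argues against it.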

We apply the GNNExplainer method to the BA-Shapes dataset and obtain explanations highlighting important graph structures.

The subgraph induced by the edges that GNNExplainer selected when explaining the label prediction for BA-Shapes node 366. Different node colors represent node labels. GNNExplainer accurately identifies the important house motif graph structure.

To see more examples of accurate (and inaccurate) explanations, please refer to our Colab. There, you can also interactively adjust the threshold on the edge mask that is used to select the final explanation set of edges.

We also run GNNExplainer on the MUTAG dataset. The method sometimes results in non-intuitive explanations with disconnected induced subgraphs (depending on the chosen threshold). Please find more examples and experiment with different thresholds in our Colab.

GNNExplainer selected edges (in black) for a mutagenic molecule.
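Whether an explanation is connected at a given threshold can be checked directly. The sketch below assumes the learned edge mask is available as a dictionary from node pairs to mask values, which is a hypothetical format. It keeps the edges above the threshold and returns the connected components of the induced subgraph, using union-find. More than one component means the explanation is disconnected at that threshold.

```python
def explanation_components(edge_mask, threshold):
    """Keep edges whose mask value exceeds `threshold` and return the connected
    components (as sorted node lists) of the subgraph induced by those edges."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]   # path halving
            x = parent[x]
        return x

    for (u, v), value in edge_mask.items():
        if value > threshold:
            parent[find(u)] = find(v)       # union the endpoints' components
    components = {}
    for node in list(parent):
        components.setdefault(find(node), []).append(node)
    return sorted(sorted(c) for c in components.values())

# Hypothetical edge mask on a 6-node path 0-1-2-3-4-5.
edge_mask = {(0, 1): 0.9, (1, 2): 0.8, (2, 3): 0.3, (3, 4): 0.7, (4, 5): 0.95}
print(explanation_components(edge_mask, 0.5))   # [[0, 1, 2], [3, 4, 5]]  (disconnected)
print(explanation_components(edge_mask, 0.2))   # [[0, 1, 2, 3, 4, 5]]    (connected)
```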

Subgraph Explanations

Although GNNExplainer can accurately highlight important nodes and edges, these nodes and edges can lie in disconnected regions of the graph. This can make it challenging to reason about how the disconnected pieces combine to explain the GNN's prediction, given that the GNN cannot reasonably be run on those nodes and edges alone without the rest of the graph connecting them.

Perhaps a more intuitive type of explanation for a GNN's prediction is a connected subgraph. In many graph-structured datasets, subgraphs are the reason that a graph has a particular label. For example, in the BA-Shapes dataset, the location of nodes within the house motifs, which are subgraphs, explains the nodes' labels. In the MUTAG dataset, the presence of carbon rings and NO₂ groups, which are also subgraphs, largely determines whether the molecule as a whole is mutagenic. Below, we use the MUTAG dataset to demonstrate two methods for identifying subgraph explanations: one that explicitly counts subgraphs without a GNN and one that uses a GNN to identify salient subgraphs.

Counting Subgraphs

An illustration of the subgraph counting method. For a given subgraph size k, we enumerate and count every size-k subgraph in the graph (here k = 3), and we build a feature vector of subgraph counts for the graph. Once this is done for every graph, we train a logistic regression model to predict each graph's label using the subgraph count vectors. The trained coefficients of the logistic regression model are used to identify important subgraphs.

In order to identify subgraphs that explain a graph’s binary classification label, a simple method is to explicitly count the subgraphs in every graph and identify whether certain subgraphs tend to be associated with certain labels. For example, in the MUTAG dataset, carbon rings and NO₂ groups tend to appear more often in graphs with the 1 (mutagenic) label and less often in graphs with the 0 (non-mutagenic) label.

The first step to identifying such subgraphs using subgraph counts is to choose a desired subgraph size k. Then, for each graph, we enumerate every size-k subgraph in that graph using the ESU algorithm [10]. We then count the number of instances of each unique subgraph and build a vector of those counts for each graph. Now, we have a feature vector of subgraph counts along with a binary label for each graph. We use this data to train a logistic regression classifier with L1 penalty to predict the label. This method is summarized in the figure above.
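The enumeration and counting steps can be sketched as follows.

- `esu` follows the ESU algorithm of Wernicke [10] and enumerates connected induced subgraphs.
- `canonical_form` is a brute-force stand-in for a proper graph canonicalization tool such as nauty. It tries all k! node orderings, so it is only practical for small k.
- We assume node labels, such as atom types, are part of a subgraph's identity.
- The toy molecule in the usage lines is hypothetical.

```python
from collections import Counter
from itertools import permutations

def esu(adj, k):
    """Enumerate all connected induced subgraphs with k nodes (ESU, Wernicke [10]).

    `adj` maps each integer node to the set of its neighbors.
    """
    found = []

    def extend(sub, ext, root):
        if len(sub) == k:
            found.append(frozenset(sub))
            return
        ext = set(ext)
        neighborhood = set().union(*(adj[x] for x in sub))  # nodes adjacent to the current subgraph
        while ext:
            w = ext.pop()
            # Exclusive neighbors of w: larger than the root, not in sub, not adjacent to sub.
            exclusive = {u for u in adj[w]
                         if u > root and u not in sub and u not in neighborhood}
            extend(sub | {w}, ext | exclusive, root)

    for v in adj:
        extend({v}, {u for u in adj[v] if u > v}, v)
    return found

def canonical_form(adj, nodes, labels):
    """Brute-force canonical key: the smallest (labels, edges) encoding over all node orderings."""
    best = None
    for order in permutations(nodes):
        pos = {x: i for i, x in enumerate(order)}
        key = (tuple(labels[x] for x in order),
               tuple(sorted((pos[a], pos[b]) for a in order for b in adj[a]
                            if b in pos and pos[a] < pos[b])))
        if best is None or key < best:
            best = key
    return best

def subgraph_counts(adj, labels, k):
    """Count each unique (labeled) size-k subgraph in one graph."""
    return Counter(canonical_form(adj, s, labels) for s in esu(adj, k))

# Hypothetical toy molecule: a triangle of carbons (0, 1, 2) with a nitrogen (3) bonded to atom 2.
adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
atoms = {0: "C", 1: "C", 2: "C", 3: "N"}
counts = subgraph_counts(adj, atoms, k=3)
print(len(esu(adj, 3)), sorted(counts.values()))   # 3 [1, 2]
```

To build the feature vectors, index the canonical forms seen in the training set. Then write each graph's `Counter` into a fixed-length vector over that index.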

The three most important MUTAG subgraphs according to a logistic regression model trained on subgraph counts. A positive coefficient means the subgraph is predictive of mutagenicity while a negative coefficient means the subgraph is predictive of non-mutagenicity.

For the MUTAG dataset, we chose to use size k = 5 subgraphs. In the training set, we identified 94 unique subgraphs with 5 nodes, and the counts of these subgraphs became the feature vector for each graph. We trained a logistic regression model on this data. It achieves a test accuracy of 0.842 and a test AUC of 0.845 while using only 17 of the 94 subgraphs; the L1 penalty is what produced this sparse set. We then examined the subgraphs with the largest coefficients (in absolute value), as these are the subgraphs that contribute most to the model's predictions. As seen above, the model correctly identifies carbon rings and NO₂ groups as predictive of mutagenicity (positive coefficient) and other subgraphs as predictive of non-mutagenicity (negative coefficient). This illustrates that even a simple logistic regression model trained on subgraph counts can identify salient subgraphs.
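Fitting the sparse classifier and ranking its coefficients can be sketched in NumPy. This minimal sketch fits the L1 penalty with proximal gradient descent (ISTA), not necessarily the solver behind the numbers above. In practice, scikit-learn's `LogisticRegression(penalty="l1", solver="liblinear")` fits the same kind of model. The counts and labels below are synthetic and hypothetical. The features are standardized so the penalty treats every subgraph type alike, which is a design choice we assume rather than one stated above.

```python
import numpy as np

def fit_l1_logistic(X, y, lam=0.01, lr=0.1, steps=2000):
    """L1-penalized logistic regression fitted by proximal gradient descent (ISTA)."""
    n, d = X.shape
    w, b = np.zeros(d), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))                  # predicted P(label = 1)
        w = w - lr * (X.T @ (p - y)) / n                         # gradient step on the mean log-loss
        w = np.sign(w) * np.maximum(np.abs(w) - lr * lam, 0.0)   # soft-thresholding (prox of L1) sets exact zeros
        b -= lr * np.mean(p - y)                                 # the intercept is not penalized
    return w, b

# Hypothetical data: counts of 3 subgraph types in 200 graphs; only type 0 drives the label.
rng = np.random.default_rng(0)
counts = rng.poisson(2.0, size=(200, 3)).astype(float)
labels = (counts[:, 0] >= 3).astype(float)
X = (counts - counts.mean(axis=0)) / counts.std(axis=0)        # standardize each subgraph type
w, b = fit_l1_logistic(X, labels)
ranking = np.argsort(-np.abs(w))                                 # subgraph types, most important first
print(ranking, w)
```

The sign of each coefficient in `w` plays the same role as in the figure above. A positive coefficient pushes toward label 1, and a negative one pushes toward label 0.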

SubgraphX

Illustration of SubgraphX. Given a graph, SubgraphX uses Monte Carlo Tree Search (MCTS) to identify subgraph explanations. MCTS starts with the whole graph and iteratively prunes nodes until reaching a pre-determined smallest subgraph size k (here k = 3). It then applies a GNN to evaluate the subgraph. Larger subgraphs along the path from the original graph to the size-k subgraph keep track of the average score of all size-k subgraphs they contain. MCTS quickly explores the space of high-scoring subgraphs by balancing exploitation (pruning nodes that lead to high scoring subgraphs, e.g., top left subgraph) and exploration (pruning nodes that lead to subgraphs with few visits, e.g., bottom left subgraph).

Although the subgraph counting method can identify interpretable subgraph explanations, it has two key limitations. First, the time it takes to explicitly enumerate all subgraphs grows exponentially with k, meaning the method only works for very small values of k (generally k ≤ 5). Second, the method requires a simple model such as logistic regression to obtain subgraph coefficients, which limits the expressive power of the predictive model. Rather than sacrifice speed and accuracy for explainability, an alternative solution is to develop a method that efficiently extracts subgraph explanations directly from a trained GNN. SubgraphX [7] does precisely that (see figure above).

Given a graph with a particular label, the goal of SubgraphX is to identify small subgraphs that have high scores for that label according to the GNN. These subgraphs then explain the GNN’s prediction for the entire graph since each subgraph alone causes the GNN to predict the correct label. Since enumerating all subgraphs of size k and making GNN predictions on them would be prohibitively slow even for small values of k, SubgraphX employs Monte Carlo Tree Search (MCTS) to efficiently search the space of size-k subgraphs for high scoring subgraphs.

Starting with the whole graph, MCTS iteratively prunes nodes (and connecting edges) from the graph to obtain subgraphs. The pruning process continues until a subgraph of size k is obtained. This subgraph is evaluated by the GNN and the prediction serves as the score for this subgraph. (The original SubgraphX method uses Shapley values to score subgraphs, but we use the GNN prediction for simplicity.) All the larger subgraphs on the path from the original graph to the size-k subgraph keep track of the number of times that subgraph has been visited by the MCTS search along with the average score of all the size-k subgraphs that have been found starting from that subgraph. As MCTS proceeds, it decides which nodes to prune by balancing exploitation (pruning nodes that lead to high scoring subgraphs) and exploration (pruning nodes that lead to subgraphs with few visits). This allows MCTS to quickly identify a variety of high-scoring size-k subgraphs without explicitly enumerating all of them. The highest scoring size-k subgraphs that have been found then serve as the explanations for the GNN’s prediction.
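The search loop described above can be sketched as below. This is a simplified, hypothetical version, not the SubgraphX implementation [7]. The simplifications are:

- Each state is a set of nodes.
- A child state removes one node while keeping the subgraph connected. This is our own assumption, and it requires a connected input graph.
- Selection uses the standard UCT rule rather than the exact SubgraphX selection rule.
- `score_fn` stands in for the GNN's score on a subgraph.

```python
import math
from collections import defaultdict

def is_connected(nodes, adj):
    """Check connectivity of the subgraph induced by `nodes` with depth-first search."""
    if not nodes:
        return False
    start = next(iter(nodes))
    seen, stack = {start}, [start]
    while stack:
        for u in adj[stack.pop()]:
            if u in nodes and u not in seen:
                seen.add(u)
                stack.append(u)
    return len(seen) == len(nodes)

def mcts_subgraphs(adj, k, score_fn, iterations=200, c=1.0):
    """Search for high-scoring connected size-k subgraphs by pruning one node at a time."""
    root = frozenset(adj)
    if not is_connected(root, adj):
        raise ValueError("this sketch assumes a connected input graph")
    visits, total = defaultdict(int), defaultdict(float)
    child_cache, leaf_scores = {}, {}

    def children(state):
        # Remove one node at a time, keeping only removals that leave the subgraph connected.
        if state not in child_cache:
            child_cache[state] = [state - {v} for v in sorted(state)
                                  if is_connected(state - {v}, adj)]
        return child_cache[state]

    def uct(parent, child):
        if visits[child] == 0:
            return math.inf                                  # explore unvisited children first
        exploit = total[child] / visits[child]               # average score of leaves reached via child
        explore = c * math.sqrt(math.log(visits[parent] + 1) / visits[child])
        return exploit + explore

    for _ in range(iterations):
        state, path = root, [root]
        while len(state) > k:                                # prune until k nodes remain
            parent = state
            state = max(children(parent), key=lambda ch: uct(parent, ch))
            path.append(state)
        if state not in leaf_scores:
            leaf_scores[state] = score_fn(state)             # e.g. the GNN's score for the label
        for s in path:                                       # update every state on the path
            visits[s] += 1
            total[s] += leaf_scores[state]
    return sorted(leaf_scores.items(), key=lambda item: item[1], reverse=True)

# Hypothetical example: a triangle {0, 1, 2} with a tail 2-3-4-5, and a stand-in score
# that rewards overlap with the triangle.
adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2, 4}, 4: {3, 5}, 5: {4}}
triangle = {0, 1, 2}
best = mcts_subgraphs(adj, 3, lambda s: len(s & triangle) / 3)
print(sorted(best[0][0]), best[0][1])   # [0, 1, 2] 1.0
```

States are keyed by their node sets. As a result, a subgraph reachable through several pruning orders shares one set of statistics, and a subgraph is scored only once.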

Subgraph explanation (subgraph edges in black) for a mutagenic molecule produced by SubgraphX.

We extracted the key elements of the public SubgraphX implementation for binary graph classification and applied the method to our GNN model trained on the MUTAG dataset. The figure above shows the subgraph identified for one particular mutagenic molecule. As with the subgraph counting method, SubgraphX identifies carbon rings as explanations for the model's prediction. However, unlike the subgraph counting method, SubgraphX lets us use the full power of the GNN to make accurate predictions while still obtaining interpretable subgraph explanations.

Conclusion

In this post, we introduced GNNExplainer, subgraph counting, and SubgraphX as model explainability approaches. Many more explainability methods are being proposed as GNNs draw increasing attention. The Dive-into-Graphs repository is a wonderful resource for graph deep learning research and contains several explainability methods, including some of the ones in this post. As another excellent resource, a recent survey paper covers most of the interpretability methods proposed so far for GNNs.

References

[1] I. Goodfellow, Y. Bengio, and A. Courville, “Deep Learning,” MIT Press, 2016.

[2] S. Tonekaboni, S. Joshi, M. D. McCradden, and A. Goldenberg, “What Clinicians Want: Contextualizing Explainable Machine Learning for Clinical End Use,” in Proceedings of the 4th Machine Learning for Healthcare Conference, vol. 106, pp. 359–380, 2019.

[3] A. Gomolin, E. Netchiporouk, R. Gniadecki, and I. V. Litvinov, “Artificial Intelligence Applications in Dermatology: Where Do We Stand?” in Frontiers in Medicine, vol. 7, 2020.

[4] D. V. Carvalho, E. M. Pereira, and J. S. Cardoso, “Machine Learning Interpretability: A Survey on Methods and Metrics,” in Electronics, vol. 8, no. 8, 2019.

[5] H. Yuan, H. Yu, S. Gui, and S. Ji, “Explainability in Graph Neural Networks: A Taxonomic Survey,” in arXiv preprint arXiv:2012.15445, 2020.

[6] K. Xu, W. Hu, J. Leskovec, and S. Jegelka, “How Powerful are Graph Neural Networks?” in International Conference on Learning Representations, 2019.

[7] H. Yuan, H. Yu, J. Wang, K. Li, and S. Ji, “On Explainability of Graph Neural Networks via Subgraph Explorations,” in arXiv preprint arXiv:2102.05152, 2021.

[8] R. Ying, D. Bourgeois, J. You, M. Zitnik, and J. Leskovec, “GNNExplainer: Generating Explanations for Graph Neural Networks,” in Advances in Neural Information Processing Systems, vol. 32, 2019.

[9] A. K. Debnath, R. L. Lopez de Compadre, G. Debnath, A. J. Shusterman, and C. Hansch, “Structure-Activity Relationship of Mutagenic Aromatic and Heteroaromatic Nitro Compounds. Correlation with Molecular Orbital Energies and Hydrophobicity,” in Journal of Medicinal Chemistry, vol. 34, no. 2, pp. 786–797, 1991.

[10] S. Wernicke, “Efficient Detection of Network Motifs,” in IEEE/ACM Transactions on Computational Biology and Bioinformatics, vol. 3, no. 4, pp. 347–359, Oct.-Dec. 2006.

[11] J. You, Z. Ying, and J. Leskovec, “Design Space for Graph Neural Networks,” in Advances in Neural Information Processing Systems, vol. 33, 2020.

[12] W. Hamilton, Z. Ying, and J. Leskovec, “Inductive Representation Learning on Large Graphs,” in Advances in Neural Information Processing Systems, vol. 30, 2017.

[13] T. N. Kipf and M. Welling, “Semi-Supervised Classification with Graph Convolutional Networks,” in International Conference on Learning Representations, 2017.
