Graph Machine Learning Explainability with PyG

PyTorch Geometric
Feb 2, 2023

By Blaž Stojanovič

Graph Neural Networks (GNNs) have become increasingly popular for processing graph-structured data, such as social networks, molecular graphs, and knowledge graphs. However, the complex nature of graph-based data and the non-linear relationships between nodes in a graph can make it difficult to understand why a GNN makes a particular prediction. As GNNs have grown in popularity, so has interest in explaining their predictions.

The importance of explanations in practical Machine Learning applications cannot be overstated. They help build trust and transparency in models, because users can better understand how predictions are made and which factors influence them. They support decision-making by giving decision-makers the context they need to act on model predictions. Explanations also make it easier for practitioners to debug the models they develop and to improve their performance. In certain domains, such as finance and healthcare, explanations may even be required for regulatory compliance.

Explanations in Graph Machine Learning are very much an ongoing research effort, and explainability on graphs is not as mature as interpretability in other subfields of ML, such as computer vision or NLP. Additionally, the explanations themselves differ because of the complex relational data that GNNs operate on:

  • Contextual understanding: Explanations require a contextual understanding of the entities in the graph and the relationships between them, which can be complex and difficult to interpret.
  • Dynamic relationships: The relationships between nodes in a graph can change over time, making it challenging to explain predictions made at different points in time.
  • Heterogeneous data: Graph Machine Learning often involves heterogeneous data types with complex features, which makes it difficult to provide a unified explanation methodology.
  • Explanation granularity: Explanations have to cover both the structural origin of a prediction and feature importance. This means identifying which nodes, edges, or subgraphs are important, as well as which node or edge features contribute strongly to the prediction.
Figure from [4], which highlights the complexities of explanations in graph machine learning. The left-hand side shows the GNN computation graph for making the prediction at node v. Some edges in the computation graph are important neural message-passing pathways (green), while others are not (orange). However, the GNN needs to aggregate both important and unimportant features to make the prediction, and the goal of explanation methods is to identify a small set of features and pathways that are crucial for the prediction.

Difficulties and complexities of Graph Machine Learning aside, there has been a lot of unifying work in the field recently, which aims to both provide a unified framework for evaluating the explanations [1,2], and provide a taxonomy of the existing zoo of explanation methods available [3].

In a recent community sprint, the PyG community implemented a core explainability framework along with various evaluation methods, benchmark datasets, and visualisations, which make it very easy to get started with Graph Machine Learning explanations in PyG. The framework is useful both if you want to use common graph explainers like GNNExplainer [4] or PGExplainer [5] out of the box and if you want to implement, test, and evaluate your own explanation methods.

In this blog post we will go step by step through the explainability module, shedding light on how each component of the framework works and what purpose it serves. Afterwards, we will go over various explanation evaluation methods and synthetic benchmarks, which work hand in hand to help you produce the best explanations for the task at hand. We will continue by taking a look at the visualisation methods available out of the box. Finally, we will go over the steps necessary to implement your own explanation method in PyG, and highlight work on advanced use cases such as heterogeneous graphs and link prediction explanations.

The Framework

When designing the explainability framework, our goal was an easy-to-use explainability module that:

  • can be extended to meet requirements of many GNN applications
  • can be adapted to various types of graphs and explanation settings
  • can provide explanation output to be comprehensively evaluated and visualised

There are really four concepts at the core of the framework:

  • Explainer class: a wrapper around the PyG explainability module for instance-level explanations
  • Explanation class: the class to encapsulate the output of an Explainer
  • ExplainerAlgorithm class: the explainability algorithm used by Explainer to generate Explanation(s) given training instance(s)
  • metric package: evaluation metrics that use the Explanation output and potentially the GNN model / ground truth to evaluate the ExplainerAlgorithm

To see how they all come together, let us take a look at the figure below:

High level overview of the PyG Explainability framework

The user provides the explanation settings, as well as the model and data to be explained. The Explainer class wraps an explainer algorithm (a particular explanation method) and uses it to generate explanations for the given model and data. The explanations are encapsulated in the Explanation class and can be further post-processed, visualised, and evaluated. Let us now go into more depth on the various explanation settings available.

Example Explainer

Here is an example Explainer setup, that uses the GNNExplainer for model explanations on the Cora dataset (see the gnn_explainer.py example).

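```python
from torch_geometric.explain import Explainer, GNNExplainer

# `model` is a trained node classifier and `data` is the Cora graph
# (see the gnn_explainer.py example for how both are set up).
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)
```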

Here the node mask is learned separately for every node feature (node_mask_type='attributes'), while the edge mask assigns one value to each edge as a whole (edge_mask_type='object'). To produce an explanation for a particular prediction of the model, we simply call the explainer:

```python
node_index = 10  # index of the node whose prediction we explain
explanation = explainer(data.x, data.edge_index, index=node_index)
```

Let us now take a look at all of the nuts and bolts that make explanations in PyG this easy!

The Explanation Class

We represent explanations using the Explanation class, which is a Data or HeteroData object containing masks for the nodes, edges, features, and any other attributes of the data. In this paradigm, the masks act as explanation attributions for the respective nodes, edges, and features: the larger the mask value, the more important the corresponding component is to the explanation (0 being completely unimportant). The Explanation class contains methods for obtaining the induced explanation subgraph, which consists of all elements with non-zero attributions, and the complement of the explanation subgraph. Additionally, it includes thresholding and visualisation methods for the explanation.
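To make the explanation subgraph and its complement concrete, here is a minimal plain-Python sketch. It looks only at an edge mask (one importance value per edge) and splits the edges by whether their attribution is non-zero. The real Explanation methods also take node masks into account and operate on tensors, so treat the helper name split_edges_by_mask as hypothetical.

```python
def split_edges_by_mask(edges, edge_mask):
    """Split edges into the explanation subgraph (non-zero attribution)
    and its complement (zero attribution).

    Hypothetical helper: a simplified, edge-only view of what the
    Explanation class computes on tensors.
    """
    explanation = [e for e, m in zip(edges, edge_mask) if m > 0]
    complement = [e for e, m in zip(edges, edge_mask) if m == 0]
    return explanation, complement


# A toy graph with one importance value per edge.
edges = [(0, 1), (1, 2), (2, 3), (3, 0), (1, 3)]
edge_mask = [0.9, 0.0, 0.4, 0.0, 0.7]

explanation, complement = split_edges_by_mask(edges, edge_mask)
print("explanation subgraph:", explanation)
print("complement:", complement)
```

Every edge lands in exactly one of the two lists, which is what allows a prediction on the explanation alone to be compared with a prediction on everything else.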

The Explainer Class and Explanation Settings

The Explainer class is designed to handle all of the explainability settings. These are passed either as direct parameters to the Explainer or, in the case of ModelConfig and ThresholdConfig, as configs. The new interface provides many settings, so let’s go over them one by one.

```python
# Explainer parameters
model: torch.nn.Module,
algorithm: ExplainerAlgorithm,
explanation_type: Union[ExplanationType, str],
model_config: Union[ModelConfig, Dict[str, Any]],
node_mask_type: Optional[Union[MaskType, str]] = None,
edge_mask_type: Optional[Union[MaskType, str]] = None,
threshold_config: Optional[ThresholdConfig] = None,
```

The model can be any PyG model whose predictions we want to explain. Additional model settings are specified in the ModelConfig, which specifies the mode, the task_level, and the return_type of the model. The mode describes the task type (e.g. mode='multiclass_classification'), the task_level signifies whether the task is node-, edge-, or graph-level, and the return_type describes the expected return type of the model ('raw', 'probs', or 'log_probs').
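The return_type matters because the explainer has to interpret the model output consistently, for example to recover class probabilities or the predicted class. The following sketch (plain Python, with a hypothetical to_probs helper rather than PyG's internal code) shows that the three return types are different encodings of the same prediction.

```python
import math


def to_probs(output, return_type):
    """Convert one row of model output into class probabilities.

    return_type is one of the three options described above:
    'raw' (unnormalised logits), 'probs', or 'log_probs'.
    """
    if return_type == "raw":
        # Softmax, shifted by the max logit for numerical stability.
        top = max(output)
        exps = [math.exp(o - top) for o in output]
        total = sum(exps)
        return [e / total for e in exps]
    if return_type == "probs":
        return list(output)
    if return_type == "log_probs":
        return [math.exp(o) for o in output]
    raise ValueError(f"unknown return_type: {return_type}")


# The same prediction encoded in all three formats.
logits = [2.0, 0.5, -1.0]
probs = to_probs(logits, "raw")
log_probs = [math.log(p) for p in probs]

for out, kind in [(logits, "raw"), (probs, "probs"), (log_probs, "log_probs")]:
    print(kind, [round(p, 3) for p in to_probs(out, kind)])
```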

There are two types of explanations, specified with explanation_type (for a more in-depth discussion see [1]):

  • explanation_type="phenomenon" aims to explain why a certain decision was made for a particular input. We are interested in the phenomenon that leads from inputs to outputs in our data. In this case the labels are used as targets for the explanation.
  • explanation_type="model" aims to provide a post-hoc explanation for the model provided. In this setting we are trying to open the black-box and explain the logic behind it. In this case the model predictions are used as targets for the explanation.
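The difference between the two settings comes down to which targets the explainer has to justify. Here is a minimal sketch that assumes a multiclass classifier producing per-class scores. The helper explanation_targets is hypothetical. Node 1 shows the two settings asking the explainer to justify different classes.

```python
def explanation_targets(explanation_type, class_scores, labels):
    """Return the targets an explainer should explain.

    'phenomenon' -> the ground-truth labels
    'model'      -> the model's own predictions (argmax per row)
    """
    if explanation_type == "phenomenon":
        return list(labels)
    if explanation_type == "model":
        return [max(range(len(row)), key=row.__getitem__) for row in class_scores]
    raise ValueError(f"unknown explanation_type: {explanation_type}")


# Three nodes, three classes; the model gets node 1 wrong.
class_scores = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1], [0.2, 0.2, 0.6]]
labels = [1, 1, 2]

print("phenomenon targets:", explanation_targets("phenomenon", class_scores, labels))
print("model targets:", explanation_targets("model", class_scores, labels))
```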

How exactly the Explanations are computed is specified with the algorithm parameter. Several off-the-shelf algorithms are available in the module, including GNNExplainer [4], PGExplainer [5], and CaptumExplainer.

We also support several types of masks, which are set with node_mask_type and edge_mask_type and can be:

  • None will not mask any nodes/edges
  • "object" will mask each node/edge as a whole (one mask value per node/edge)
  • "common_attributes" will mask each node feature/edge attribute with one value shared across all nodes/edges
  • "attributes" will mask each node feature/edge attribute separately for every node/edge
Different node mask types available in the Explainer class
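The sketch below spells out the node mask shapes we assume for each node_mask_type: one value per node, one per node feature, or one per feature shared by all nodes. It also shows how such a mask is broadcast over the feature matrix. It is plain Python with hypothetical helper names. In PyG the masks are tensors, so check the documentation of your version for the exact shapes.

```python
# Assumed mask shapes for num_nodes nodes with num_features features each.
NODE_MASK_SHAPES = {
    "object": lambda n, f: (n, 1),             # one value per node
    "attributes": lambda n, f: (n, f),         # one value per node feature
    "common_attributes": lambda n, f: (1, f),  # one value per feature, shared
}


def node_mask_shape(node_mask_type, num_nodes, num_features):
    """Shape of the learnable node mask, or None when nothing is masked."""
    if node_mask_type is None:
        return None
    return NODE_MASK_SHAPES[node_mask_type](num_nodes, num_features)


def apply_node_mask(x, mask):
    """Multiply features x (a list of rows) by a mask of shape (N, 1),
    (N, F) or (1, F), broadcasting the way a tensor library would."""
    masked = []
    for i, row in enumerate(x):
        m_row = mask[i] if len(mask) == len(x) else mask[0]
        masked.append([v * (m_row[j] if len(m_row) == len(row) else m_row[0])
                       for j, v in enumerate(row)])
    return masked


# Cora has 2,708 nodes with 1,433 features each.
for kind in ["object", "attributes", "common_attributes"]:
    print(kind, node_mask_shape(kind, 2708, 1433))

x = [[1.0, 2.0], [3.0, 4.0]]
print(apply_node_mask(x, [[0.5], [0.0]]))            # "object"
print(apply_node_mask(x, [[1.0, 0.0]]))              # "common_attributes"
print(apply_node_mask(x, [[1.0, 0.0], [0.0, 1.0]]))  # "attributes"
```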

Finally, you can also set the thresholding behaviour through the ThresholdConfig. If you don’t want to threshold the explanation masks, set it to None. Alternatively, you can apply a hard threshold at a given value (hard), retain only the top-k values (topk), or set the top-k values to 1 and all others to 0 (topk_hard).
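The three thresholding behaviours are easiest to compare side by side. Below is a plain-Python sketch operating on a flat list of mask values. The function threshold_mask is hypothetical. In PyG the choice is made through the ThresholdConfig; as we understand the API, it is set with something like threshold_config=dict(threshold_type='topk', value=10), but please check the documentation for your version.

```python
def threshold_mask(mask, threshold_type=None, value=None):
    """Sketch of the thresholding options described above.

    None        -> mask unchanged
    'hard'      -> 1.0 where mask > value, else 0.0
    'topk'      -> keep the `value` largest entries, zero the rest
    'topk_hard' -> 1.0 for the `value` largest entries, 0.0 elsewhere
    """
    if threshold_type is None:
        return list(mask)
    if threshold_type == "hard":
        return [1.0 if m > value else 0.0 for m in mask]
    if threshold_type in ("topk", "topk_hard"):
        k = min(int(value), len(mask))
        # Indices of the k largest mask values.
        top = sorted(range(len(mask)), key=lambda i: mask[i], reverse=True)[:k]
        out = [0.0] * len(mask)
        for i in top:
            out[i] = mask[i] if threshold_type == "topk" else 1.0
        return out
    raise ValueError(f"unknown threshold_type: {threshold_type}")


mask = [0.1, 0.8, 0.3, 0.6]
print(threshold_mask(mask, "hard", 0.5))
print(threshold_mask(mask, "topk", 2))
print(threshold_mask(mask, "topk_hard", 2))
```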

Explanation Evaluation

Generating an explanation is by no means the end of the explainability workflow. The quality of an explanation can be judged in a variety of ways, and PyG supports several explanation evaluation metrics out of the box, which you’ll find in the metric package.

Perhaps the most popular evaluation metric is Fidelity+/- (see [1] for details). Fidelity evaluates the contribution of the produced explanatory subgraph to the initial prediction, either by giving only the subgraph to the model (fidelity-) or by removing it from the entire graph (fidelity+).

Fidelity+/- definitions for both Phenomenon and Model modes (source [1])

The fidelity scores capture how well an explanation reproduces the natural phenomenon or the GNN model logic. Once we have produced an explanation, we can obtain both fidelities as:

```python
from torch_geometric.explain.metric import fidelity

# Returns the pair (fidelity+, fidelity-)
pos_fidelity, neg_fidelity = fidelity(explainer, explanation)
```
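The definitions from [1] can be written down directly. The sketch below is plain Python with a hypothetical function name, fidelity_scores. It operates on predicted class labels rather than PyG tensors and compares three sets of predictions: on the full graph, on the explanation subgraph alone, and on the graph with the explanation removed. A good explanation has high fidelity+ (removing it changes the prediction) and low fidelity- (it alone suffices to reproduce the prediction).

```python
def fidelity_scores(y_hat, y_hat_expl, y_hat_compl, y=None,
                    explanation_type="model"):
    """Fidelity+/- for a batch of predictions, following [1].

    y_hat       -- predictions on the full graph
    y_hat_expl  -- predictions using only the explanation subgraph
    y_hat_compl -- predictions with the explanation removed
    y           -- ground-truth labels (needed for 'phenomenon' only)
    """
    n = len(y_hat)
    if explanation_type == "model":
        # The model's own full-graph predictions are the targets.
        pos = sum(c != p for p, c in zip(y_hat, y_hat_compl)) / n
        neg = sum(e != p for p, e in zip(y_hat, y_hat_expl)) / n
    elif explanation_type == "phenomenon":
        # Compare correctness w.r.t. the labels before and after masking.
        pos = sum(abs((p == t) - (c == t))
                  for p, c, t in zip(y_hat, y_hat_compl, y)) / n
        neg = sum(abs((p == t) - (e == t))
                  for p, e, t in zip(y_hat, y_hat_expl, y)) / n
    else:
        raise ValueError(f"unknown explanation_type: {explanation_type}")
    return pos, neg


y_hat = [0, 1, 2, 1]        # full graph
y_hat_expl = [0, 1, 2, 0]   # explanation subgraph only
y_hat_compl = [1, 1, 0, 0]  # explanation removed
y = [0, 1, 1, 1]            # ground truth

print(fidelity_scores(y_hat, y_hat_expl, y_hat_compl))
print(fidelity_scores(y_hat, y_hat_expl, y_hat_compl, y, "phenomenon"))
```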

We provide the characterization score (characterization_score) as a means to combine both fidelities into a single metric [1]. Moreover, if we have fidelity pairs for explanations at many different thresholds (or entropies), we can calculate the area under the fidelity curve with fidelity_curve_auc. Additionally, we provide the unfaithfulness metric, which evaluates how faithful an Explanation is to an underlying GNN predictor [6].
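Both combined metrics are short formulas. Here is a plain-Python sketch of the characterization score from [1], a weighted harmonic mean of fidelity+ and 1 - fidelity-. It also covers the unfaithfulness measure from [6], which is one minus the exponentiated negative KL divergence between the predictions on the original and the masked input. These functions reflect our reading of the papers, not PyG's tensor code, and the probability vectors are made up.

```python
import math


def characterization_score(pos_fid, neg_fid, pos_weight=0.5, neg_weight=0.5):
    """Weighted harmonic mean of fidelity+ and (1 - fidelity-), as in [1]."""
    if not math.isclose(pos_weight + neg_weight, 1.0):
        raise ValueError("pos_weight and neg_weight must sum to 1")
    return 1.0 / (pos_weight / pos_fid + neg_weight / (1.0 - neg_fid))


def unfaithfulness(p, q):
    """1 - exp(-KL(p || q)) between the prediction on the original input (p)
    and on the explanation-masked input (q), following [6].

    Assumes q is strictly positive wherever p is.
    """
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return 1.0 - math.exp(-kl)


print(characterization_score(0.75, 0.25))
print(unfaithfulness([0.7, 0.2, 0.1], [0.7, 0.2, 0.1]))  # identical predictions
print(unfaithfulness([0.7, 0.2, 0.1], [0.2, 0.5, 0.3]))  # diverging predictions
```

The harmonic mean rewards explanations only when both fidelity+ and 1 - fidelity- are high. Unfaithfulness is 0 when masking leaves the prediction untouched and approaches 1 as the two distributions diverge.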

Metrics like fidelity and unfaithfulness are useful for evaluating explanations when there is no “ground truth” explanation available, i.e. when we don’t have a predetermined set of nodes/features/edges that fully explains a particular model prediction or phenomenon. When developing new explanation algorithms, however, we are often interested in performance on standard benchmark datasets where such ground truth exists [1,2]. The groundtruth_metrics method compares predicted and ground-truth explanation masks and returns a choice of standard metrics ("accuracy", "recall", "precision", "f1_score", "auroc"):

```python
from torch_geometric.explain.metric import groundtruth_metrics

accuracy, auroc = groundtruth_metrics(pred_mask,
                                      target_mask,
                                      metrics=["accuracy", "auroc"])
```
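For intuition, here is what such metrics compute for a single explanation. This is a plain-Python sketch with a hypothetical mask_metrics helper. The 0.5 threshold that turns the soft predicted mask into a binary one is an assumption of this sketch. Threshold-based metrics compare binarised masks, while AUROC ranks the raw scores.

```python
def mask_metrics(pred_mask, target_mask, threshold=0.5):
    """Compare a soft predicted mask with a binary ground-truth mask."""
    pred = [p > threshold for p in pred_mask]
    target = [bool(t) for t in target_mask]
    tp = sum(p and t for p, t in zip(pred, target))
    fp = sum(p and not t for p, t in zip(pred, target))
    fn = sum(t and not p for p, t in zip(pred, target))
    tn = len(pred) - tp - fp - fn

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)

    # AUROC on the raw scores: probability that a random ground-truth element
    # is scored higher than a random non-ground-truth one (ties count 0.5).
    pos = [s for s, t in zip(pred_mask, target) if t]
    neg = [s for s, t in zip(pred_mask, target) if not t]
    pairs = [(a > b) + 0.5 * (a == b) for a in pos for b in neg]
    auroc = sum(pairs) / len(pairs) if pairs else float("nan")

    return {
        "accuracy": (tp + tn) / len(pred),
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "auroc": auroc,
    }


print(mask_metrics([0.9, 0.2, 0.7, 0.4, 0.1], [1, 0, 0, 1, 0]))
```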

Of course, to evaluate explainers in this way, we first need benchmark datasets where ground-truth explanations are available.

Benchmark Datasets

In order to facilitate the development and rigorous evaluation of new graph explainer algorithms, PyG now provides several explainer datasets, such as BA2MotifDataset, BAMultiShapesDataset, and InfectionDataset, as well as an easy way to create synthetic benchmark datasets. This is supported via the ExplainerDataset class, which creates synthetic graphs from a GraphGenerator and randomly attaches num_motifs motifs, produced by a MotifGenerator, to them. Ground-truth node-level and edge-level explanation masks are determined by whether nodes and edges are part of a motif.
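The way ground-truth masks follow from motif membership can be illustrated with a small plain-Python sketch. The helper attach_motifs is hypothetical, and the ring base graph stands in for a random graph from a GraphGenerator. Linking each motif to the base graph by one random edge is our simplifying assumption, not a description of ExplainerDataset's exact procedure.

```python
import random

# House motif: a square (0-1-2-3) with a roof node 4 on top of edge 0-1.
HOUSE_EDGES = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 4), (1, 4)]
HOUSE_NUM_NODES = 5


def attach_motifs(num_base_nodes, base_edges, motif_edges, motif_num_nodes,
                  num_motifs, seed=0):
    """Append num_motifs copies of a motif to a base graph and return
    ground-truth node/edge masks (1 = part of a motif, 0 = otherwise).

    Assumption of this sketch: each motif is linked to a random base node
    by a single edge, and that linking edge is not part of the ground truth.
    """
    rng = random.Random(seed)
    edges = list(base_edges)
    edge_gt = [0] * len(edges)
    node_gt = [0] * num_base_nodes
    num_nodes = num_base_nodes
    for _ in range(num_motifs):
        offset = num_nodes
        # Copy the motif with its node ids shifted past the existing nodes.
        for u, v in motif_edges:
            edges.append((u + offset, v + offset))
            edge_gt.append(1)
        node_gt.extend([1] * motif_num_nodes)
        num_nodes += motif_num_nodes
        # Link the motif to a random node of the base graph.
        edges.append((rng.randrange(num_base_nodes), offset))
        edge_gt.append(0)
    return num_nodes, edges, node_gt, edge_gt


# A 10-node ring as a stand-in base graph, with 3 house motifs attached.
ring = [(i, (i + 1) % 10) for i in range(10)]
num_nodes, edges, node_gt, edge_gt = attach_motifs(
    10, ring, HOUSE_EDGES, HOUSE_NUM_NODES, num_motifs=3)
print(num_nodes, len(edges), sum(node_gt), sum(edge_gt))
```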

Currently supported GraphGenerators are:

  • BAGraph: Random Barabasi-Albert (BA) graphs
  • ERGraph: Random Erdos-Renyi (ER) graphs
  • GridGraph: Two-dimensional grid graph

But you can easily implement your own by subclassing the GraphGenerator class. Additionally, we support the following motifs:

  • HouseMotif: House structured motif from [4]
  • CycleMotif: The cycle motif from [4]
  • CustomMotif: Easy way to add any motif based on a custom structure either from a Data object or a networkx.Graph object (e.g. a wheel shape)
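To give a feel for what a Barabási–Albert generator produces, here is a minimal preferential-attachment sketch in plain Python. We assume, as in the standard Barabási–Albert model, that num_edges is the number of edges each new node adds to existing nodes. The function name and the star-shaped seed graph are choices of this sketch, not PyG's BAGraph implementation.

```python
import random


def barabasi_albert_edges(num_nodes, num_edges, seed=0):
    """Preferential attachment: each new node links to `num_edges` distinct
    existing nodes, chosen with probability proportional to their degree."""
    if not 1 <= num_edges < num_nodes:
        raise ValueError("need 1 <= num_edges < num_nodes")
    rng = random.Random(seed)
    # Seed graph: a star on num_edges + 1 nodes, so every node has degree >= 1.
    edges = [(0, i) for i in range(1, num_edges + 1)]
    # Each node appears once per incident edge, so sampling uniformly from
    # this list picks nodes with probability proportional to their degree.
    repeated = [n for e in edges for n in e]
    for new in range(num_edges + 1, num_nodes):
        targets = set()
        while len(targets) < num_edges:
            targets.add(rng.choice(repeated))
        for t in sorted(targets):
            edges.append((new, t))
            repeated.extend([new, t])
    return edges


edges = barabasi_albert_edges(num_nodes=300, num_edges=5, seed=42)
degree = [0] * 300
for u, v in edges:
    degree[u] += 1
    degree[v] += 1
print(len(edges), max(degree), min(degree))
```

Early nodes accumulate many edges and become hubs, which is the heavy-tailed structure that motifs are planted into.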

The datasets we can generate with the above settings form a superset of the benchmark datasets used in GNNExplainer [4], PGExplainer [5], SubgraphX [8], PGMExplainer [9], GraphFramEx [1], etc.

Random graph generators and motif generators

We can generate new datasets on the fly with desired seeds and sizes. For example, to generate a dataset based on Barabasi-Albert graphs with 80 house motifs serving as ground truth explanation labels we would use:

```python
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph

dataset = ExplainerDataset(
    graph_generator=BAGraph(num_nodes=300, num_edges=5),
    motif_generator='house',
    num_motifs=80,
)
```

The BAMultiShapesDataset is the synthetic dataset for evaluating graph classification explainability algorithms [10]. Given three atomic motifs, namely House (H), Wheel (W), and Grid (G), BAMultiShapesDataset contains 1,000 Barabasi-Albert graphs with their labels depending on the attachment of the atomic motifs as follows:

Classes in the BAMultiShapesDataset depend on the presence of atomic motifs

The dataset is pre-computed in order to coincide with the official implementation.
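As we read the labelling rule in [10], a graph belongs to class 1 exactly when two of the three motifs are present, and to class 0 when it has none, one, or all three. Because of this rule, no single motif explains a class on its own. The tiny sketch below encodes our reading of the rule with a hypothetical helper. Please verify it against the dataset documentation before relying on it.

```python
from itertools import product


def bamultishapes_label(has_house, has_wheel, has_grid):
    """Class 1 iff exactly two of the three motifs are attached;
    class 0 for none, a single motif, or all three."""
    return int(has_house + has_wheel + has_grid == 2)


for h, w, g in product([False, True], repeat=3):
    print(f"house={h!s:5} wheel={w!s:5} grid={g!s:5} -> class {bamultishapes_label(h, w, g)}")
```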

Another precomputed dataset is BA2MotifDataset [5]. It contains 1,000 Barabasi-Albert graphs: half of them have a HouseMotif attached, and the rest have a five-node CycleMotif attached. Each graph is assigned to one of two classes according to the type of attached motif. To create similar datasets, you can use ExplainerDataset with graph and motif generators.

Additionally, we provide the InfectionDataset [2] generator, in which nodes predict their distance from infected nodes (yellow), and the unique path to an infected node serves as the explanation. Nodes with non-unique paths to infected nodes are excluded. Unreachable nodes and nodes at a distance of at least max_path_length are collapsed into one class.
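The labelling rule can be sketched with a multi-source breadth-first search that records each node's distance to the nearest infected node and counts its shortest paths. This plain-Python version illustrates the rule described above and is not InfectionDataset's code. In particular, applying the uniqueness check only to nodes closer than max_path_length is our assumption.

```python
from collections import deque


def infection_labels(num_nodes, edges, infected, max_path_length):
    """Return (labels, parent) for an undirected graph.

    labels[v] is the distance from v to the nearest infected node, capped at
    max_path_length (unreachable nodes also get max_path_length), or None if
    v has more than one shortest path to the infected set (excluded).
    parent[v] is v's predecessor on its first-found shortest path.
    """
    adj = [[] for _ in range(num_nodes)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    dist = [None] * num_nodes
    num_paths = [0] * num_nodes
    parent = [None] * num_nodes
    queue = deque()
    for s in set(infected):
        dist[s], num_paths[s] = 0, 1
        queue.append(s)

    # Multi-source BFS that also counts shortest paths.
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if dist[v] is None:
                dist[v] = dist[u] + 1
                num_paths[v] = num_paths[u]
                parent[v] = u
                queue.append(v)
            elif dist[v] == dist[u] + 1:
                num_paths[v] += num_paths[u]

    labels = []
    for d, p in zip(dist, num_paths):
        if d is None or d >= max_path_length:
            labels.append(max_path_length)  # far away or unreachable
        elif p > 1:
            labels.append(None)  # ambiguous explanation: excluded
        else:
            labels.append(d)
    return labels, parent


def explanation_path(parent, node):
    """Follow parent pointers back to the infected node (the explanation)."""
    path = [node]
    while parent[path[-1]] is not None:
        path.append(parent[path[-1]])
    return path


# Path 0-1-2-3-4-5 plus an isolated node 6; node 0 is infected.
labels_path, parent_path = infection_labels(7, [(i, i + 1) for i in range(5)], [0], 3)
print(labels_path, explanation_path(parent_path, 2))

# Diamond 0-1-3, 0-2-3: node 3 has two shortest paths and is excluded.
labels_diamond, _ = infection_labels(4, [(0, 1), (0, 2), (1, 3), (2, 3)], [0], 3)
print(labels_diamond)
```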

The infection dataset from [2]

To generate an InfectionDataset, we specify a graph generator, the maximum infection path length, and the number of infected nodes:

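```python
from torch_geometric.datasets import InfectionDataset
from torch_geometric.datasets.graph_generator import BAGraph

# Generate a Barabási-Albert base graph with 300 nodes, in which every
# new node attaches to 5 existing nodes
graph_generator = BAGraph(num_nodes=300, num_edges=5)
# Create the InfectionDataset on top of the generated base graph
dataset = InfectionDataset(
    graph_generator=graph_generator,
    num_infected_nodes=50,
    max_path_length=3,
)
```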

We aim to add even more explanation datasets and graph generators in the future, so stay tuned!

Explainability Visualisation

As mentioned before, the Explanation class provides basic visualisation functionality with two methods: visualize_feature_importance() and visualize_graph().

For visualising features, we can specify the number of top features to plot with top_k, or pass feature labels with feat_labels.

```python
# Save a bar plot of the 10 most important features
explanation.visualize_feature_importance('feature_importance.png', top_k=10)
```

The output is saved to the specified path. Here is an example output from the Cora dataset explainer above:

Feature importance on Cora, for details see the gnn_explainer.py example
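Conceptually, plotting the top features requires aggregating a per-node feature mask into one score per feature. The sketch below assumes the scores are summed over nodes and then ranked. The helper top_k_features is hypothetical, and PyG's plotting code may differ in details.

```python
def top_k_features(node_mask, top_k, feat_labels=None):
    """Sum a (num_nodes x num_features) mask over nodes and return the
    top_k (label, score) pairs, highest score first."""
    num_features = len(node_mask[0])
    scores = [sum(row[j] for row in node_mask) for j in range(num_features)]
    labels = feat_labels or [str(j) for j in range(num_features)]
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_k]


node_mask = [[0.25, 0.0, 0.5],
             [0.5, 0.25, 0.75]]
print(top_k_features(node_mask, top_k=2))
print(top_k_features(node_mask, top_k=2, feat_labels=["word_a", "word_b", "word_c"]))
```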

We can also easily visualise the graph induced by the explanation. The output of visualize_graph() is a visualisation of the explanation subgraph after filtering out edges according to their importance values (and according to the configured threshold, if any). We have a choice of two backends (graphviz or networkx):

```python
explanation.visualize_graph('subgraph.png', backend="graphviz")
```

We get a local plot of the nodes and edges that contribute to the explanation; the edge opacity corresponds to the edge importance.

Subgraph induced by the explanation from the gnn_explainer.py example

Implementing your own ExplainerAlgorithm

All of the explanation computation magic happens within the ExplainerAlgorithm, which is passed to the Explainer class. A variety of popular explanation algorithms (GNNExplainer, PGExplainer, etc.) have already been implemented and can simply be used. However, if you need an ExplainerAlgorithm that is not implemented yet, fear not: simply subclass the ExplainerAlgorithm interface and implement its two abstract methods.

The forward() method computes the explanations and has the following signature:

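```python
from typing import Dict, Optional, Union

import torch
from torch import Tensor
from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.typing import EdgeType, NodeType


def forward(
    self,
    # The model used for explanations
    model: torch.nn.Module,
    # The input node features (a dict per node type for heterogeneous graphs)
    x: Union[Tensor, Dict[NodeType, Tensor]],
    # The input edge indices (a dict per edge type for heterogeneous graphs)
    edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
    *,
    # The target of the model (what we are explaining)
    target: Tensor,
    # The index of the model output to explain.
    # Can be a single index or a tensor of indices.
    index: Optional[Union[int, Tensor]] = None,
    # Additional keyword arguments passed to the model
    **kwargs,
) -> Union[Explanation, HeteroExplanation]:
    ...
```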

To assist in writing forward() methods for different explanation algorithms, the ExplainerAlgorithm base class provides several useful helper functions. For example, _post_process_mask post-processes a mask so that it contains no attributions for elements not involved in message passing, _get_hard_masks returns hard node and edge masks that only include the nodes and edges visited during message passing, and _num_hops returns the number of hops the model aggregates information from.

The second method that needs to be implemented is supports():

```python
def supports(self) -> bool:
    ...
```

The supports() method checks whether the explanation algorithm is defined for the user-defined settings provided in self.explainer_config and self.model_config.
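To illustrate the division of labour between forward() and supports(), here is a deliberately framework-free toy. It is an occlusion explainer that scores each edge by how much the model output at index changes when that edge is removed. Everything here is hypothetical: the class, the one-layer sum-aggregation model, and the dictionary returned instead of an Explanation. A real implementation would subclass ExplainerAlgorithm, read its settings from self.explainer_config and self.model_config, and return an Explanation.

```python
def toy_model(x, edges):
    """One round of sum aggregation: out[v] = x[v] + sum of x[u] for (u, v)."""
    out = list(x)
    for u, v in edges:
        out[v] += x[u]
    return out


class OcclusionEdgeExplainer:
    """Hypothetical explainer mirroring the forward()/supports() split."""

    def __init__(self, explanation_type="model", task_level="node"):
        self.explanation_type = explanation_type
        self.task_level = task_level

    def supports(self):
        # This toy algorithm is only defined for node-level model explanations.
        return self.explanation_type == "model" and self.task_level == "node"

    def forward(self, model, x, edges, index):
        if not self.supports():
            raise ValueError("unsupported explanation settings")
        base = model(x, edges)[index]
        edge_mask = []
        for i in range(len(edges)):
            # Drop edge i and measure how much the output at `index` moves.
            occluded = edges[:i] + edges[i + 1:]
            edge_mask.append(abs(base - model(x, occluded)[index]))
        return {"edge_mask": edge_mask}


x = [1.0, 2.0, 3.0, 4.0]
edges = [(0, 2), (1, 2), (3, 0)]
explainer = OcclusionEdgeExplainer()
print(explainer.forward(toy_model, x, edges, index=2))
print(OcclusionEdgeExplainer(task_level="graph").supports())
```

Edge (3, 0) receives a score of zero because it lies outside node 2's one-hop computation graph. Helpers such as _post_process_mask are described above as removing exactly this kind of attribution.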

Extensions to Heterogeneous Graphs

The Explanation, as described above, extends naturally to heterogeneous graphs and HeteroData. In this case, the explanation again consists of masks, now applied to the nodes, edges, and features of each node and edge type. For this purpose we have implemented the HeteroExplanation class, which has an almost identical interface to Explanation. Furthermore, to facilitate future work in this direction, we added heterogeneous graph support to the CaptumExplainer, which can serve as a template for future implementations. Most of the explainability framework is also already future-proof in this direction, with many parameters accepting optional dictionaries for the heterogeneous case.

Explaining Link prediction

For those wanting to explain link predictions, we have added link explanation support to GNNExplainer. The idea is to treat edge explanation as just a new form of target indexing: we index into the edge tensor instead of the node feature tensor. Link prediction explanations consider the union of the k-hop neighbourhoods of both endpoints.

This implementation integrates nicely with the existing code and supports most explanation configurations. An example setup for explaining link prediction looks like the following:

```python
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.explain.config import ModelConfig

# `model`, `train_data` and `val_data` are set up as in the
# gnn_explainer_link_pred.py example.
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)
# Explain model output for a single edge:
edge_label_index = val_data.edge_label_index[:, 0]
explainer = Explainer(
    model=model,
    explanation_type='model',
    algorithm=GNNExplainer(epochs=200),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)
explanation = explainer(
    x=train_data.x,
    edge_index=train_data.edge_index,
    edge_label_index=edge_label_index,
)
print(f'Generated model explanations in {explanation.available_explanations}')
```
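The union of k-hop neighbourhoods used for link explanations can be sketched in a few lines of plain Python. The sketch assumes an undirected graph and k equal to the number of message-passing layers. The helper names are hypothetical, and PyG performs this computation on tensors.

```python
from collections import deque


def k_hop_nodes(adj, start, k):
    """Nodes reachable from start in at most k hops (breadth-first)."""
    seen = {start}
    frontier = deque([(start, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue
        for nb in adj.get(node, ()):
            if nb not in seen:
                seen.add(nb)
                frontier.append((nb, depth + 1))
    return seen


def link_explanation_nodes(edges, src, dst, k):
    """Union of the k-hop neighbourhoods of both endpoints of (src, dst)."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, []).append(v)
        adj.setdefault(v, []).append(u)
    return k_hop_nodes(adj, src, k) | k_hop_nodes(adj, dst, k)


# A path 0-1-2-3-4-5-6; explain a candidate link between nodes 0 and 6.
path = [(i, i + 1) for i in range(6)]
print(sorted(link_explanation_nodes(path, 0, 6, k=2)))
```

With k=2, node 3 lies outside both neighbourhoods, so a two-layer model cannot use it when scoring this link, and it should receive no attribution.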

To see the full example, take a look at gnn_explainer_link_pred.py. To make it easier to get started with implementing explanation methods for any task level, we also provide example tests parameterised over all task levels (graph, node, edge); interested readers can take a look at test/explain.

This was a whirlwind tour of explainability in PyG. A lot of exciting work is currently under way in PyG, both on graph explainability and in other areas of graph machine learning. If you would like to join the community of open-source developers, please check out our Slack and GitHub pages!

Until next time, the PyG team…

References

[1] Amara, K., Ying, R., Zhang, Z., Han, Z., Shan, Y., Brandes, U., Schemm, S. and Zhang, C., 2022. GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks. arXiv preprint arXiv:2206.09677.

[2] Faber, L., K. Moghaddam, A. and Wattenhofer, R., 2021, August. When comparing to ground truth is wrong: On evaluating GNN explanation methods. In Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (pp. 332–341).

[3] Yuan, H., Yu, H., Gui, S. and Ji, S., 2022. Explainability in graph neural networks: A taxonomic survey. IEEE Transactions on Pattern Analysis and Machine Intelligence.

[4] Ying, Z., Bourgeois, D., You, J., Zitnik, M. and Leskovec, J., 2019. GNNExplainer: Generating explanations for graph neural networks. Advances in Neural Information Processing Systems, 32.

[5] Luo, D., Cheng, W., Xu, D., Yu, W., Zong, B., Chen, H. and Zhang, X., 2020. Parameterized explainer for graph neural network. Advances in Neural Information Processing Systems, 33, pp.19620–19631.

[6] Agarwal, C., Queen, O., Lakkaraju, H. and Zitnik, M., 2022. Evaluating explainability for graph neural networks. arXiv preprint arXiv:2208.09339.

[7] Baldassarre, F. and Azizpour, H., 2019. Explainability techniques for graph convolutional networks. arXiv preprint arXiv:1905.13686.

[8] Yuan, H., Yu, H., Wang, J., Li, K. and Ji, S., 2021, July. On explainability of graph neural networks via subgraph explorations. In International Conference on Machine Learning (pp. 12241–12252). PMLR.

[9] Vu, M. and Thai, M.T., 2020. PGM-Explainer: Probabilistic graphical model explanations for graph neural networks. Advances in Neural Information Processing Systems, 33, pp.12225–12235.

[10] Azzolin, S., Longa, A., Barbiero, P., Liò, P. and Passerini, A., 2022. Global explainability of gnns via logic combination of learned concepts. arXiv preprint arXiv:2210.07147.
