Graph Machine Learning Explainability with PyG
By Blaž Stojanovič
Graph Neural Networks (GNNs) have become increasingly popular for processing graph-structured data, such as social networks, molecular graphs, and knowledge graphs. However, the complex nature of graph-based data and the non-linear relationships between nodes in a graph can make it difficult to understand why a GNN makes a particular prediction. With the rise in popularity of Graph Neural Networks, there also came an increased interest in explaining their predictions.
The importance of explanations in practical Machine Learning applications cannot be overstated. They help build trust and transparency in models, as users can better understand how predictions are made and the factors that influence them. They improve decision-making, giving decision-makers additional understanding to make better informed decisions based on model predictions. Explanations also make it easier for practitioners to debug and improve the performance of the models they develop. In certain domains, such as finance and healthcare, explanations may even be required due to compliance and regulations.
Explanations in Graph Machine Learning are very much an ongoing research effort, and explainability on graphs is not as mature as interpretability in other subfields of ML, like computer vision or NLP. Additionally, the explanations themselves differ because of the complex relational data GNNs operate on:
- Contextual Understanding: Explanations require a contextual understanding of the relationships and entities between nodes in the graph, which can be complex and difficult to understand
- Dynamic relationships: The relationships between nodes in a graph can change over time, making it challenging to provide explanations for predictions made at different points in time
- Heterogeneous Data: GML often involves processing heterogeneous data types with complex features, making it difficult to provide a unified explanation methodology
- Explanation granularity: Explanations have to explain both the structural origin of the prediction, as well as feature importance. This means explaining which nodes, edges, or subgraphs are important, as well as which node or edge features contribute strongly to the prediction outcome.
Difficulties and complexities of Graph Machine Learning aside, there has been a lot of unifying work in the field recently, which aims to both provide a unified framework for evaluating the explanations [1,2], and provide a taxonomy of the existing zoo of explanation methods available [3].
In a recent community sprint the PyG community has implemented a core explainability framework along with various evaluation methods, benchmark datasets, and visualisations, which make it very easy to get started with Graph Machine Learning explanations in PyG. Moreover the framework is useful both if you just want to use common graph explainers like GNNExplainer [4] or PGExplainer [5] out of the box, or if you want to implement, test, and evaluate your own explanation methodologies.
In this blog post we will go step by step through the explainability module, shedding light on how each component of the framework works and what purpose it serves. Afterwards, we will go over various explanation evaluation methods and synthetic benchmarks, which work hand in hand to make sure you produce the best explanations for the task at hand. We will continue by taking a look at the visualisation methods that are available out of the box. Finally, we will go over the steps necessary to implement your own explanation method in PyG, and highlight work on advanced use cases such as heterogeneous graphs and link prediction explanations.
The Framework
When designing the Explainability framework our goal was to design an easy to use explainability module, which:
- can be extended to meet requirements of many GNN applications
- can be adapted to various types of graphs and explanation settings
- can provide explanation output to be comprehensively evaluated and visualised
There are really four concepts at the core of the framework:
- Explainer class: a wrapper of the PyG explainability module for instance level explanations
- Explanation class: the class to encapsulate the output of an Explainer
- ExplainerAlgorithm class: the explainability algorithm used by Explainer to generate Explanation(s) given training instance(s)
- metric package: evaluation metrics that use the Explanation output and potentially the GNN model / ground truth to evaluate the ExplainerAlgorithm
To see how they all come together, let us take a look at the figure below:
The user provides the explanation settings, as well as the model and data which need to be explained. The Explainer class is a PyG object that wraps an explainer algorithm (a particular explanation method) and generates the explanations for the given model and data. The explanations are encapsulated in the Explanation class and can be further post-processed, visualised, and evaluated. Let us now go into more depth regarding the various explanation settings available.
Example Explainer
Here is an example Explainer setup that uses the GNNExplainer for model explanations on the Cora dataset (see the gnn_explainer.py example).
from torch_geometric.explain import Explainer, GNNExplainer

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)
The node level mask is set for all attributes and the edge mask is set for edges as objects. To produce an explanation for a particular prediction of the model we simply call the explainer:
node_index = 10 # which node index to explain
explanation = explainer(data.x, data.edge_index, index=node_index)
Let us now take a look at all of the nuts and bolts that make explanations in PyG this easy!
The Explanation Class
We represent explanations using the Explanation class, which is a Data or HeteroData object containing masks for the nodes, edges, features and any attributes of the data. In this paradigm the masks act as explanation attributions for the respective nodes/edges/features. The larger the mask value, the more important the corresponding component is to the explanation (0 being completely unimportant). The Explanation class contains methods for obtaining the induced explanation subgraph, which comprises all elements with non-zero explanation attributions, and the complement of the explanation subgraph. Additionally, it includes thresholding and visualisation methods for the explanation.
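To make the idea of an induced explanation subgraph and its complement concrete, here is a small, library-independent sketch. It does not use the Explanation class itself; the helper name split_by_mask and the toy edge list are made up for illustration, and we assume an edge belongs to the explanation exactly when its attribution is non-zero.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def split_by_mask(
    edges: List[Edge],
    edge_mask: List[float],
) -> Tuple[List[Edge], List[Edge]]:
    """Split edges into the explanation subgraph and its complement.

    Hypothetical helper: an edge is part of the explanation if its
    attribution (mask value) is non-zero.
    """
    if len(edges) != len(edge_mask):
        raise ValueError("edge_mask needs exactly one value per edge")
    explanation = [e for e, m in zip(edges, edge_mask) if m != 0]
    complement = [e for e, m in zip(edges, edge_mask) if m == 0]
    return explanation, complement


# Toy graph: a 4-cycle with two edges marked as important.
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
edge_mask = [0.9, 0.0, 0.4, 0.0]

explanation_edges, complement_edges = split_by_mask(edges, edge_mask)
print(explanation_edges)  # [(0, 1), (2, 3)]
print(complement_edges)   # [(1, 2), (3, 0)]
```

Together the two edge sets always partition the original graph, which is what makes the fidelity-style evaluations discussed later possible.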
The Explainer Class and Explanation Settings
The Explainer class is designed to handle all of the explainability settings. These settings are set either as direct parameters to the Explainer or as configs in the case of ModelConfig or ThresholdConfig. There are many settings provided with this new interface. Let’s go over the available ones one-by-one.
# Explainer Parameters
model: torch.nn.Module,
algorithm: ExplainerAlgorithm,
explanation_type: Union[ExplanationType, str],
model_config: Union[ModelConfig, Dict[str, Any]],
node_mask_type: Optional[Union[MaskType, str]] = None,
edge_mask_type: Optional[Union[MaskType, str]] = None,
threshold_config: Optional[ThresholdConfig] = None,
The model can be any PyG model for which we want to produce explanations. Additional model settings are specified in the ModelConfig, which specifies the mode, the task_level, and the return_type of the model. The mode describes the task type, e.g. mode='multiclass_classification', the task_level signifies the task level (node-, edge-, or graph-level tasks), and the return_type describes the expected return type of the model (raw, probs, or log_probs).
There are two types of explanations, as specified with explanation_type (for a more in-depth discussion see [1]):
- explanation_type="phenomenon" aims to explain why a certain decision was made for a particular input. We are interested in the phenomenon that leads from inputs to outputs in our data. In this case the labels are used as targets for the explanation.
- explanation_type="model" aims to provide a post-hoc explanation for the model provided. In this setting we are trying to open the black-box and explain the logic behind it. In this case the model predictions are used as targets for the explanation.
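The practical difference between the two explanation types is which target the explanation is computed against. The following sketch illustrates this under simple assumptions: a classification task where the model prediction is the argmax of its output scores. The function explanation_target and the toy data are hypothetical and not part of PyG.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score."""
    return max(range(len(scores)), key=lambda i: scores[i])


def explanation_target(
    explanation_type: str,
    labels: List[int],
    model_outputs: List[Sequence[float]],
) -> List[int]:
    """Pick the target an explanation is computed against (illustrative)."""
    if explanation_type == "phenomenon":
        # Explain the data: use the ground-truth labels.
        return list(labels)
    if explanation_type == "model":
        # Explain the model: use its own predictions.
        return [argmax(row) for row in model_outputs]
    raise ValueError(f"Unknown explanation type: {explanation_type!r}")


labels = [0, 1, 1]
model_outputs = [[0.8, 0.2], [0.6, 0.4], [0.1, 0.9]]

phenomenon_target = explanation_target("phenomenon", labels, model_outputs)
model_target = explanation_target("model", labels, model_outputs)
print(phenomenon_target)  # [0, 1, 1]
print(model_target)       # [0, 0, 1]
```

The two targets differ only where the model is wrong (the second example here), which is exactly where the two explanation types can disagree.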
How exactly the Explanations are computed is specified with the algorithm parameter. Several off-the-shelf algorithms are available in the module:
- GNNExplainer: The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper.
- PGExplainer: The PGExplainer model from the “Parameterized Explainer for Graph Neural Network” paper.
- AttentionExplainer: An explainer that uses the attention coefficients produced by an attention-based GNN (e.g., GATConv, GATv2Conv, or TransformerConv) as edge explanations.
- CaptumExplainer: A Captum-based explainer.
- GraphMaskExplainer: The GraphMask-Explainer from the “Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking” paper (part of torch_geometric.contrib at the moment).
- PGMExplainer: The PGMExplainer model from the “PGMExplainer: Probabilistic Graphical Model Explanations for Graph Neural Networks” paper (part of torch_geometric.contrib at the moment).
We also support many different types of masks. These are set with node_mask_type and edge_mask_type, and can be:
- None will not mask any nodes/edges
- "object" will mask each node/edge
- "common_attributes" will mask each node feature/edge attribute
- "attributes" will mask each node feature/edge attribute separately across all nodes/edges
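One way to build intuition for the mask types is to look at the shape a node mask would have in each case. The sketch below is our reading of the descriptions above rather than library code: the function node_mask_shape is hypothetical, and the shapes assume one value per node for "object", one value per feature shared by all nodes for "common_attributes", and one value per node and feature for "attributes".

```python
from typing import Optional, Tuple


def node_mask_shape(
    mask_type: Optional[str],
    num_nodes: int,
    num_features: int,
) -> Optional[Tuple[int, int]]:
    """Shape of a node mask for a given mask type (illustrative assumption)."""
    if mask_type is None:
        return None  # nothing is masked
    if mask_type == "object":
        return (num_nodes, 1)  # one value per node
    if mask_type == "common_attributes":
        return (1, num_features)  # one value per feature, shared by all nodes
    if mask_type == "attributes":
        return (num_nodes, num_features)  # one value per node and feature
    raise ValueError(f"Unknown mask type: {mask_type!r}")


# Cora has 2,708 nodes with 1,433 features each.
for mask_type in (None, "object", "common_attributes", "attributes"):
    print(mask_type, node_mask_shape(mask_type, 2708, 1433))
```

The finer the mask, the more parameters an algorithm has to learn, so "attributes" gives the most detailed explanation at the highest cost.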
Finally, you can also set the thresholding behaviour through the ThresholdConfig. If you don’t want to threshold the explanation masks you can set this to None. Alternatively, you can apply a hard threshold at any value, retain only the top-k values with topk, or set the top-k values to 1 with topk_hard.
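The three thresholding behaviours can be sketched in a few lines of plain Python. This is an illustration rather than the PyG implementation: apply_threshold is a hypothetical helper operating on a flat list, and we assume that a hard threshold binarises the mask (values above the threshold become 1, all others 0).

```python
from typing import List


def apply_threshold(mask: List[float], threshold_type: str, value: float) -> List[float]:
    """Threshold a flat attribution mask (illustrative sketch)."""
    if threshold_type == "hard":
        # Assumption: binarise around the threshold value.
        return [1.0 if m > value else 0.0 for m in mask]
    if threshold_type in ("topk", "topk_hard"):
        k = min(int(value), len(mask))
        order = sorted(range(len(mask)), key=lambda i: mask[i], reverse=True)
        top = set(order[:k])  # indices of the k largest attributions
        if threshold_type == "topk":
            # Keep the top-k values as they are, zero out the rest.
            return [m if i in top else 0.0 for i, m in enumerate(mask)]
        # topk_hard: set the top-k values to 1, zero out the rest.
        return [1.0 if i in top else 0.0 for i in range(len(mask))]
    raise ValueError(f"Unknown threshold type: {threshold_type!r}")


mask = [0.1, 0.7, 0.3, 0.9]
print(apply_threshold(mask, "hard", 0.5))     # [0.0, 1.0, 0.0, 1.0]
print(apply_threshold(mask, "topk", 2))       # [0.0, 0.7, 0.0, 0.9]
print(apply_threshold(mask, "topk_hard", 2))  # [0.0, 1.0, 0.0, 1.0]
```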
Explanation Evaluation
Generating an explanation is by no means the end of the explainability workflow. The quality of an explanation can be judged by a variety of different methods. PyG supports some out-of-the-box explanation evaluation metrics, which you’ll find in the metric package.
Perhaps the most popular evaluation metric is Fidelity+/- (see [1] for details). Fidelity evaluates the contribution of the produced explanatory subgraph to the initial prediction, either by giving only the subgraph to the model (fidelity-) or by removing it from the entire graph (fidelity+).
The fidelity scores capture how well an explainable model reproduces the natural phenomenon or the GNN model logic. Once we have produced an explanation we can obtain both fidelities as:
from torch_geometric.explain.metric import fidelity
fid_pm = fidelity(explainer, explanation)
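To see what the two numbers mean, here is a rough, library-independent sketch of fidelity for the explanation_type="model" setting, following our reading of the definition in [1]: fidelity+ is the fraction of predictions that change when the explanation subgraph is removed, and fidelity- is the fraction that change when only the explanation subgraph is kept. The function fidelity_scores and the toy predictions are hypothetical.

```python
from typing import Sequence, Tuple


def fidelity_scores(
    full_pred: Sequence[int],
    subgraph_pred: Sequence[int],
    complement_pred: Sequence[int],
) -> Tuple[float, float]:
    """Return (fidelity+, fidelity-) from three sets of predictions.

    full_pred:       predictions on the original graph
    subgraph_pred:   predictions when only the explanation subgraph is kept
    complement_pred: predictions when the explanation subgraph is removed
    """
    n = len(full_pred)
    if n == 0 or len(subgraph_pred) != n or len(complement_pred) != n:
        raise ValueError("all prediction lists need the same non-zero length")
    unchanged_without = sum(c == f for c, f in zip(complement_pred, full_pred))
    unchanged_with_only = sum(s == f for s, f in zip(subgraph_pred, full_pred))
    fid_plus = 1.0 - unchanged_without / n
    fid_minus = 1.0 - unchanged_with_only / n
    return fid_plus, fid_minus


full_pred = [1, 0, 2, 1]
subgraph_pred = [1, 0, 2, 0]    # explanation alone almost reproduces the model
complement_pred = [0, 0, 1, 2]  # removing it changes most predictions

fid_plus, fid_minus = fidelity_scores(full_pred, subgraph_pred, complement_pred)
print(fid_plus, fid_minus)  # 0.75 0.25
```

A good explanation is necessary (high fidelity+) and sufficient (low fidelity-), which is the case in this toy example.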
We provide the characterization score as a means to combine both fidelities into a single metric [1]. Moreover, if we have fidelity pairs for explanations at many different thresholds (or entropies), we can calculate the area under the fidelity curve with the fidelity curve AUC. Additionally, we provide the unfaithfulness metric, which evaluates how faithful an Explanation is to an underlying GNN predictor [6].
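As an illustration of how the two fidelities can be combined, the sketch below assumes the characterization score takes the form of a weighted harmonic mean of fidelity+ and (1 - fidelity-), as described in [1]. The function name and the default weights of 0.5 are choices made for this example.

```python
def characterization(
    fid_plus: float,
    fid_minus: float,
    pos_weight: float = 0.5,
    neg_weight: float = 0.5,
) -> float:
    """Weighted harmonic mean of fid+ and (1 - fid-) (assumed form, see [1])."""
    sufficiency = 1.0 - fid_minus
    if fid_plus <= 0.0 or sufficiency <= 0.0:
        # The harmonic mean is zero as soon as one component is zero.
        return 0.0
    denominator = pos_weight / fid_plus + neg_weight / sufficiency
    return (pos_weight + neg_weight) / denominator


score = characterization(fid_plus=0.75, fid_minus=0.25)
print(round(score, 4))  # 0.75
```

Because it is a harmonic mean, the score is only high when the explanation is both necessary and sufficient; one strong fidelity cannot compensate for a weak one.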
Metrics like the fidelity score and unfaithfulness are useful for evaluating explanations when there is no “ground truth” explanation available, i.e. we don’t have a predetermined set of nodes/features/edges that fully explain a particular model prediction or phenomenon. When developing new explanation algorithms, however, we might be interested in performance on certain standard benchmark datasets [1,2] for which such a ground truth exists. The groundtruth_metrics function compares the predicted explanation mask with the ground-truth mask and returns a choice of standard metrics ("accuracy", "recall", "precision", "f1_score", "auroc"):
from torch_geometric.explain.metric import groundtruth_metrics

accuracy, auroc = groundtruth_metrics(
    pred_mask,
    target_mask,
    metrics=["accuracy", "auroc"],
)
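For intuition, the threshold-based metrics can be computed by hand from a predicted and a ground-truth mask. The following is a plain-Python sketch and not the library function: mask_metrics is a hypothetical helper, and we assume both masks are binarised at 0.5 before comparison.

```python
from typing import Dict, Sequence


def mask_metrics(
    pred_mask: Sequence[float],
    target_mask: Sequence[float],
    threshold: float = 0.5,
) -> Dict[str, float]:
    """Compare a predicted mask with a ground-truth mask (illustrative)."""
    if len(pred_mask) != len(target_mask) or len(pred_mask) == 0:
        raise ValueError("masks need the same non-zero length")
    pred = [p > threshold for p in pred_mask]
    target = [t > threshold for t in target_mask]

    tp = sum(p and t for p, t in zip(pred, target))
    fp = sum(p and not t for p, t in zip(pred, target))
    fn = sum(not p and t for p, t in zip(pred, target))
    tn = sum(not p and not t for p, t in zip(pred, target))

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {
        "accuracy": (tp + tn) / len(pred),
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
    }


pred_mask = [0.9, 0.2, 0.7, 0.1, 0.6]
target_mask = [1.0, 0.0, 1.0, 1.0, 0.0]
metrics = mask_metrics(pred_mask, target_mask)
print(metrics["accuracy"])  # 0.6
```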
Of course, to evaluate explainers in this way we first need benchmark datasets where ground truth explanations are available.
Benchmark Datasets
In order to facilitate development and rigorous evaluation of new graph explainer algorithms, PyG now provides several explainer datasets, like BA2MotifDataset, BAMultiShapesDataset, and InfectionDataset, as well as an easy way to create synthetic benchmark datasets. Support is provided via the ExplainerDataset class, which creates synthetic graphs coming from a GraphGenerator and randomly attaches num_motifs many motifs to it, which come from a MotifGenerator. Ground-truth node-level and edge-level explainability masks are given based on whether nodes and edges are part of a certain motif or not.
Currently supported GraphGenerators are:
- BAGraph: Random Barabasi-Albert (BA) graphs
- ERGraph: Random Erdos-Renyi (ER) graphs
- GridGraph: Two-dimensional grid graph
But you can easily implement your own by subclassing the GraphGenerator class. Additionally, for motifs we support:
- HouseMotif: House-structured motif from [4]
- CycleMotif: The cycle motif from [4]
- CustomMotif: Easy way to add any motif based on a custom structure, either from a Data object or a networkx.Graph object (e.g. a wheel shape)
The datasets we can generate with the above settings are a superset of the benchmark datasets used in GNNExplainer [4], PGExplainer [5], SubgraphX [8], PGMExplainer [9], GraphFramEx [1], etc.
We can generate new datasets on the fly with desired seeds and sizes. For example, to generate a dataset based on Barabasi-Albert graphs with 80 house motifs serving as ground truth explanation labels we would use:
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph

dataset = ExplainerDataset(
    graph_generator=BAGraph(num_nodes=300, num_edges=5),
    motif_generator='house',
    num_motifs=80,
)
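To illustrate where the ground-truth masks come from, here is a small, library-independent sketch of attaching motifs to a base graph. It is not how ExplainerDataset is implemented: the function attach_motifs, the five-node house layout, and the choice to connect each motif to one random base node are assumptions made for this example.

```python
import random
from typing import List, Tuple

Edge = Tuple[int, int]

# A five-node "house": a square (0-1-2-3) with a roof node (4).
HOUSE_EDGES: List[Edge] = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 4), (1, 4)]
HOUSE_NUM_NODES = 5


def attach_motifs(
    num_base_nodes: int,
    base_edges: List[Edge],
    num_motifs: int,
    seed: int = 0,
) -> Tuple[List[Edge], List[int], List[int]]:
    """Attach house motifs and return (edges, node_mask, edge_mask)."""
    rng = random.Random(seed)
    edges = list(base_edges)
    node_mask = [0] * num_base_nodes  # base nodes are not part of a motif
    edge_mask = [0] * len(edges)      # neither are base edges
    next_id = num_base_nodes

    for _ in range(num_motifs):
        offset = next_id
        for u, v in HOUSE_EDGES:
            edges.append((u + offset, v + offset))
            edge_mask.append(1)  # motif edge: ground-truth explanation
        node_mask.extend([1] * HOUSE_NUM_NODES)
        # Connect the motif to a random base node; this connecting edge
        # is treated as not being part of the motif.
        anchor = rng.randrange(num_base_nodes)
        edges.append((anchor, offset))
        edge_mask.append(0)
        next_id += HOUSE_NUM_NODES

    return edges, node_mask, edge_mask


# Base graph: a ring with 6 nodes.
base_edges = [(i, (i + 1) % 6) for i in range(6)]
edges, node_mask, edge_mask = attach_motifs(6, base_edges, num_motifs=2)
print(len(edges), sum(node_mask), sum(edge_mask))  # 20 10 12
```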
The BAMultiShapesDataset is a synthetic dataset for evaluating graph classification explainability algorithms [10]. Given three atomic motifs, namely House (H), Wheel (W), and Grid (G), BAMultiShapesDataset contains 1,000 Barabasi-Albert graphs with their labels depending on the attachment of the atomic motifs as follows:
The dataset is pre-computed in order to coincide with the official implementation.
Another precomputed dataset is the BA2MotifDataset [5]. It contains 1,000 Barabasi-Albert graphs. Half of the graphs are attached with a HouseMotif, and the rest are attached with a five-node CycleMotif. The graphs are assigned to one of the two classes according to the type of attached motifs. For creation of similar datasets, you can use ExplainerDataset with graph and motif generators.
Additionally, we provide the InfectionDataset [2] generator, where the nodes predict their distance from infected nodes (yellow) and use the unique path to infected nodes as explanation. Nodes with non-unique paths to infected nodes are excluded. Unreachable nodes and nodes with a distance of at least max_path_length are collapsed into one class.
To generate an Infection dataset, we specify a graph generator, infection path length, and the number of infected nodes:
from torch_geometric.datasets import InfectionDataset
from torch_geometric.datasets.graph_generator import BAGraph

# Generate Barabási-Albert base graph
graph_generator = BAGraph(num_nodes=300, num_edges=500)
# Create the InfectionDataset on the generated base graph
dataset = InfectionDataset(
    graph_generator=graph_generator,
    num_infected_nodes=50,
    max_path_length=3,
)
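The labelling rule of the infection benchmark can be sketched with a multi-source breadth-first search. This is a simplified illustration and not the InfectionDataset code: infection_labels is a hypothetical helper, the graph is treated as undirected, and the explanation paths (and the exclusion of nodes with non-unique paths) are left out.

```python
from collections import deque
from typing import List, Optional, Sequence, Tuple


def infection_labels(
    num_nodes: int,
    edges: Sequence[Tuple[int, int]],
    infected: Sequence[int],
    max_path_length: int,
) -> List[int]:
    """Label each node with its capped distance to the nearest infected node."""
    adjacency: List[List[int]] = [[] for _ in range(num_nodes)]
    for u, v in edges:
        adjacency[u].append(v)
        adjacency[v].append(u)  # assumption: undirected graph

    distance: List[Optional[int]] = [None] * num_nodes
    queue = deque()
    for node in infected:
        distance[node] = 0
        queue.append(node)

    # Multi-source BFS from all infected nodes at once.
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if distance[v] is None:
                distance[v] = distance[u] + 1
                queue.append(v)

    # Unreachable nodes and nodes at distance >= max_path_length
    # are collapsed into a single class.
    return [
        max_path_length if d is None else min(d, max_path_length)
        for d in distance
    ]


# Path graph 0-1-2-3-4-5 plus an isolated node 6; node 0 is infected.
path_edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
labels = infection_labels(7, path_edges, infected=[0], max_path_length=3)
print(labels)  # [0, 1, 2, 3, 3, 3, 3]
```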
We aim to add even more explanation datasets and graph generators in the future, so stay tuned!
Explainability Visualisation
As mentioned before, the Explanation class provides basic visualisation functionality with two methods, visualize_feature_importance() and visualize_graph().
For visualising features we can specify the number of top features to plot with top_k, or pass feature labels with feat_labels.
explanation.visualize_feature_importance('feature_importance.png', top_k=10)
The output is stored to the specified path. Here is an example output from the Cora dataset explainer above:
We can also very easily visualise the graph induced by the explanation. The output of visualize_graph() is a visualization of the explanation subgraph after filtering out edges according to their importance values (if needed, by configured threshold). We have a choice of two backends (graphviz or networkx):
explanation.visualize_graph('subgraph.png', backend="graphviz")
We get a local plot of the nodes and edges that contribute to the explanation, where the edge opacity corresponds to the edge importance.
Implementing your own ExplainerAlgorithm
All of the explanation computation magic happens within the ExplainerAlgorithm, which is passed to the Explainer class. A variety of popular explanation algorithms (GNNExplainer, PGExplainer, etc.) have already been implemented and can simply be used. However, if you find yourself in need of an unimplemented ExplainerAlgorithm, fear not: simply subclass the ExplainerAlgorithm interface and implement the two necessary abstract methods.
The forward method computes the explanations. It has the following signature:
def forward(
    self,
    # The model used for explanations.
    model: torch.nn.Module,
    # The input node features.
    x: Union[torch.Tensor, Dict[NodeType, torch.Tensor]],
    # The input edge indices.
    edge_index: Union[torch.Tensor, Dict[EdgeType, torch.Tensor]],
    # The target of the model (what we are explaining).
    target: torch.Tensor,
    # The index of the model output to explain (optional).
    # Can be a single index or a tensor of indices.
    index: Optional[Union[int, torch.Tensor]] = None,
    # Additional keyword arguments passed to the model (optional).
    **kwargs,
) -> Union[Explanation, HeteroExplanation]:
    ...
To assist in constructing forward() methods for different explanation algorithms, the base class ExplainerAlgorithm provides several useful helper functions, such as _post_process_mask, which post-processes any mask so that it does not include attributions of elements not involved during message passing, _get_hard_masks, which returns hard node and edge masks that only include the nodes and edges visited during message passing, _num_hops, which returns the number of hops the model is aggregating information from, and others.
The second method that needs to be implemented is the supports() method:
def supports(self) -> bool
The supports() function checks if the explainer supports the user-defined settings provided in self.explainer_config and self.model_config, i.e. whether the explanation algorithm is defined for the particular explanation settings being used.
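The kind of check a supports() method performs can be sketched without PyG. The classes below are hypothetical stand-ins (they are not ExplainerAlgorithm or the PyG config classes): Settings bundles a few of the options discussed above, and EdgeOnlyAlgorithm is an imaginary algorithm that can only produce edge-level masks for node-level tasks.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Settings:
    """Hypothetical bundle of explainer and model settings."""
    explanation_type: str
    node_mask_type: Optional[str]
    edge_mask_type: Optional[str]
    task_level: str


class EdgeOnlyAlgorithm:
    """Imaginary algorithm that only learns one mask value per edge."""

    def __init__(self, settings: Settings) -> None:
        self.settings = settings

    def supports(self) -> bool:
        """Return True only for settings this algorithm is defined for."""
        s = self.settings
        if s.task_level != "node":
            return False  # only node-level tasks are handled
        if s.node_mask_type is not None:
            return False  # no node or feature masks are produced
        if s.edge_mask_type != "object":
            return False  # exactly one mask value per edge
        return True


ok = EdgeOnlyAlgorithm(Settings("model", None, "object", "node"))
bad = EdgeOnlyAlgorithm(Settings("model", "attributes", "object", "node"))
print(ok.supports(), bad.supports())  # True False
```

Rejecting unsupported settings up front means a user gets a clear error at configuration time instead of a meaningless explanation later.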
Extensions to Heterogeneous Graphs
The Explanation, as described above, can be simply extended to heterogeneous graphs and HeteroData. In this case the explanation is also a mask, but applied to all node and edge features (with different types). For this purpose we have implemented the HeteroExplanation class, which has an almost identical interface to Explanation. Furthermore, to facilitate future work in this direction we added heterogeneous graph support to the CaptumExplainer, which can serve as a template for future implementations. Additionally, most of the explainability framework is already future-proof in this direction, with many parameters being set to optional dictionaries for the heterogeneous case.
Explaining Link prediction
For those wanting to provide explanations for link predictions, we have added GNNExplainer link explanation support. The idea is to treat edge explanation as just a new method of target indexing, by indexing into edge tensor instead of the node feature tensor. Link prediction explanations consider a union of k-hop-neighbourhoods of both endpoints.
This implementation integrates nicely with the existing code to support most explanation configurations. An example setup for explaining link prediction would look like the following:
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig

model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)

# Explain model output for a single edge:
edge_label_index = val_data.edge_label_index[:, 0]

explainer = Explainer(
    model=model,
    explanation_type='model',
    algorithm=GNNExplainer(epochs=200),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)
explanation = explainer(
    x=train_data.x,
    edge_index=train_data.edge_index,
    edge_label_index=edge_label_index,
)
print(f'Generated model explanations in {explanation.available_explanations}')
To see the full example, take a look at gnn_explainer_link_pred.py. In order to make it easier to get started with implementing explanation methods for any task level, we also provide example parameterised tests over all task levels (graph, node, edge). Interested readers can take a look at test/explain.
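The neighbourhood considered for a link explanation, i.e. the union of the k-hop neighbourhoods of both endpoints, can be sketched in plain Python. The helpers below are hypothetical and treat the graph as undirected for simplicity, so this is an illustration of the idea rather than the code used in PyG.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set, Tuple


def build_adjacency(edges: Iterable[Tuple[int, int]]) -> Dict[int, Set[int]]:
    """Undirected adjacency sets (assumption: edge direction is ignored)."""
    adjacency: Dict[int, Set[int]] = defaultdict(set)
    for u, v in edges:
        adjacency[u].add(v)
        adjacency[v].add(u)
    return adjacency


def k_hop_nodes(adjacency: Dict[int, Set[int]], source: int, num_hops: int) -> Set[int]:
    """All nodes within num_hops steps of source, including source."""
    visited = {source}
    frontier = {source}
    for _ in range(num_hops):
        frontier = {v for u in frontier for v in adjacency.get(u, ())} - visited
        visited |= frontier
    return visited


def link_neighbourhood(
    edges: Iterable[Tuple[int, int]],
    src: int,
    dst: int,
    num_hops: int,
) -> Set[int]:
    """Union of the k-hop neighbourhoods of both endpoints of a link."""
    adjacency = build_adjacency(edges)
    return k_hop_nodes(adjacency, src, num_hops) | k_hop_nodes(adjacency, dst, num_hops)


# Path graph 0-1-2-3-4-5-6; we explain the candidate link (1, 5).
path_edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
one_hop = link_neighbourhood(path_edges, 1, 5, num_hops=1)
two_hop = link_neighbourhood(path_edges, 1, 5, num_hops=2)
print(sorted(one_hop))  # [0, 1, 2, 4, 5, 6]
print(sorted(two_hop))  # [0, 1, 2, 3, 4, 5, 6]
```

With one hop the middle node 3 cannot influence either endpoint, so it is outside the region an explanation needs to consider.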
This was a whirlwind tour of explainability in PyG. At the moment, a lot of exciting things are being worked on in PyG, both on the graph explainability side and in other graph machine learning areas. If you would like to join the community of open-source developers, please check out our Slack and GitHub pages!
Until next time, the PyG team…
References
[1] Amara, K., Ying, R., Zhang, Z., Han, Z., Shan, Y., Brandes, U., Schemm, S. and Zhang, C., 2022. GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks. arXiv preprint arXiv:2206.09677.
[2] Faber, L., K. Moghaddam, A. and Wattenhofer, R., 2021, August. When comparing to ground truth is wrong: On evaluating gnn explanation methods. In Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (pp. 332–341).
[3] Yuan, H., Yu, H., Gui, S. and Ji, S., 2022. Explainability in graph neural networks: A taxonomic survey. IEEE Transactions on Pattern Analysis and Machine Intelligence.
[4] Ying, Z., Bourgeois, D., You, J., Zitnik, M. and Leskovec, J., 2019. Gnnexplainer: Generating explanations for graph neural networks. Advances in neural information processing systems, 32.
[5] Luo, D., Cheng, W., Xu, D., Yu, W., Zong, B., Chen, H. and Zhang, X., 2020. Parameterized explainer for graph neural network. Advances in neural information processing systems, 33, pp.19620–19631.
[6] Agarwal, C., Queen, O., Lakkaraju, H. and Zitnik, M., 2022. Evaluating explainability for graph neural networks. arXiv preprint arXiv:2208.09339.
[7] Baldassarre, F. and Azizpour, H., 2019. Explainability techniques for graph convolutional networks. arXiv preprint arXiv:1905.13686.
[8] Yuan, H., Yu, H., Wang, J., Li, K. and Ji, S., 2021, July. On explainability of graph neural networks via subgraph explorations. In International Conference on Machine Learning (pp. 12241–12252). PMLR.
[9] Vu, M. and Thai, M.T., 2020. Pgm-explainer: Probabilistic graphical model explanations for graph neural networks. Advances in neural information processing systems, 33, pp.12225–12235.
[10] Azzolin, S., Longa, A., Barbiero, P., Liò, P. and Passerini, A., 2022. Global explainability of gnns via logic combination of learned concepts. arXiv preprint arXiv:2210.07147.