Graph Analysis Made Easy with PyG Explainability
By: Anh Hoang Nguyen, Rajanie Prabha, Kevin Su as part of the Stanford CS224W course project.
Greetings, fellow graph enthusiasts! Do you want to understand why your GNN is doing what it's doing, without having to resort to mind-reading or witchcraft? Staring at your model for a really long time might work. In case that doesn't give you the 'why' of things, you can turn to the PyTorch Geometric (PyG) explainability module. Please refer to the Colab for the full code!
Quick outline of the blog:
- First, we train a GNN model on the node property prediction task using the ogbn-arxiv dataset
- Then, we'll use PyG's explainability module to apply the GNNExplainer algorithm (Ying et al. 2019) to explain this model's predictions, and show how to visualize the explanation results
- Talk about Explanation evaluation metrics
- Take a closer look at PyG’s implementation of the GNNExplainer algorithm
Why AI Explainability Matters: Trust, Ethics, and Accountability
Explainability in machine learning is an ongoing effort in the research community. In traditional machine learning pipelines, neural networks are often viewed as black boxes that learn arbitrary non-linear functions optimizing certain objectives. Many issues can arise, however, because there is no guarantee that a neural network actually performs the task it was optimized for. Unwanted behavior can creep into neural networks from bad data, or from artifacts of the network's inductive bias. The field of machine learning explainability aims to build trust and transparency in models by providing important context for why a model made its predictions.
The PyG community has been actively working on an explainability framework along with many benchmark datasets, evaluation methods, etc., to start exploring the world of interpretability in Graph Machine Learning.
What PyG’s Explainability Module Can Do:
PyG's explainability module provides a powerful set of tools for understanding how our models make decisions on graphs. With detailed visualizations and explanations of the decision-making process, we can gain a deeper understanding of the inner workings of our models and make more informed decisions about how to improve them.
The Explainability Toolset
The PyG Explainability module has four main parts:
1. Explainer: Class for instance-level explanations of GNNs. Explainer is the centerpiece of the Explainability module. On a high level, it takes in:
- a model to explain,
- explanation configurations, and
- an explainability algorithm (represented by ExplainerAlgorithm class). The output of this class is an explanation wrapped in the Explanation class.
2. Explanations: The Explanation class is a type of Data or HeteroData object that holds masks for different components of the graph data, such as nodes, edges, and features, along with their attributes. With the help of the Explanation class, one can extract the induced explanation subgraph, which consists of the non-zero explanation attributions, as well as its complement. The class also provides methods for thresholding and visualizing the explanation (see the short sketch after this list).
3. Algorithms: An abstract base class for implementing explainer algorithms. We plan to use GNNExplainer to explain the reasons for model recommendations trained on the ogbn-arxiv dataset. Given a trained model and a prediction, GNNExplainer identifies a subgraph structure and a subset of node features that are most influential for the prediction.
4. Explanation Metrics: Metrics to judge the quality of explanations (explained in detail later).
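Here is a minimal sketch of what you can do with an Explanation object once an Explainer has produced one (explanation below stands in for the object we will generate later in this post):
# Minimal sketch: inspecting an Explanation object.
# `explanation` is assumed to be the output of an Explainer call (shown later).
sub = explanation.get_explanation_subgraph()    # keeps only non-zero attributions
comp = explanation.get_complement_subgraph()    # everything the explanation left out
print(sub)   # a Data-like object with the masked-out parts removed
print(comp)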
Okay, but how do these interpretations work?
In the context of GNNs, an “explanation” refers to a subset of the original graph that is represented as a mask or subgraph. This subset consists of weighted nodes, edges, and possibly node features, and the weights assigned to these entities reflect their relative significance in explaining the model’s results.
More formally, we can define our problem as follows.
Let G = (V, E) represent a graph with |𝑉 | nodes and |𝐸| edges. Each node can have d-dimensional features. Also, we can treat our GNN model as a function f: V → Y where in the context of node classification, Y is the finite set of possible labels.
Our explanation for the predicted class 𝑦ₜ of a target node 𝑣ₜ will consist of an edge mask M_E(𝐸, 𝑓, 𝑣ₜ, 𝑦ₜ) ∈ ℝ^(|𝑉|×|𝑉|) and, optionally, a node feature mask M_NF(𝑉, 𝑓, 𝑣ₜ, 𝑦ₜ) ∈ ℝ^(|𝑉|×d), where each element is an importance score of that edge or node feature for the prediction. The importance score (sometimes also referred to as the weight) lies in the range [0, 1] for soft masking and in {0, 1} for hard masking.
(See Amara et al. (2022) for more details)
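In PyG's implementation, these masks are stored on the Explanation object in a more compact form; roughly:
# How the formal masks map onto a PyG Explanation object (shapes, roughly):
#   explanation.node_mask -> [num_nodes, d]  (one score per node feature)
#   explanation.edge_mask -> [num_edges]     (one score per edge in edge_index)
# (`explanation` is the Explanation object produced by an Explainer, as shown later.)
print(explanation.node_mask.shape, explanation.edge_mask.shape)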
Dataset ogbn-arxiv: Scholarly Network Exploration
Let's unlock the secrets of scholarly communication networks. The ogbn-arxiv dataset, an Open Graph Benchmark (OGB) dataset, provides a wealth of graph data that can be used to gain valuable insight into the world of research papers and their references. With hundreds of thousands of nodes and over a million edges, this dataset offers a unique opportunity to understand the relationships and patterns of scholarly work.
Each paper is represented as a node, and each edge between nodes represents a citation relationship, where one paper cites another. The dataset also includes a 128-dimensional feature vector for each paper, created by averaging the embeddings of the words in the paper's title and abstract. The word embeddings are generated with the skip-gram model applied over the MAG corpus. The dataset contains 169,343 papers and 1,166,243 citation relationships.
Prediction task: a 40-class classification problem that aims to predict the primary categories of arXiv papers where categories are subject areas of the arXiv CS papers, such as cs.AI, cs.LG, and cs.OS.
Data split: the idea is to train models on past papers and subsequently use them to predict the topic areas of newly released papers. Specifically, papers published up to 2017 are used for training, those published in 2018 for validation, and those published since 2019 for testing.
Let’s put on our coding caps and explore this resource.
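As a sketch, the dataset and its official split can be loaded through OGB's PyG interface (the data, split_idx, and device names below are the ones used in the rest of the post):
import torch
from ogb.nodeproppred import PygNodePropPredDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Download ogbn-arxiv and grab the single graph it contains
# (~170K nodes, ~1.17M edges, 128-dimensional node features).
dataset = PygNodePropPredDataset(name='ogbn-arxiv')
data = dataset[0]

# Official time-based split: train (<=2017), valid (2018), test (>=2019).
split_idx = dataset.get_idx_split()
train_idx = split_idx['train']
print(data)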
Model: Graph Convolution Networks
Let's use a Graph Convolutional Network (GCN) to build our GNN model (Kipf and Welling 2017). PyG's built-in GCNConv layer will come in handy for our implementation. Each GCNConv layer is followed by a batch normalization (BN) layer, a ReLU activation, and a dropout layer.
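As a quick refresher, a single GCN layer computes H⁽ˡ⁺¹⁾ = σ(D̂^(−1/2) Â D̂^(−1/2) H⁽ˡ⁾ W⁽ˡ⁾), where Â = A + I is the adjacency matrix with added self-loops, D̂ is its diagonal degree matrix, W⁽ˡ⁾ is a learnable weight matrix, and σ is a non-linearity. GCNConv implements exactly this normalized propagation; the non-linearity is applied separately, as in the forward function below.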
The below snippet shows how you can define various layers for your network:
self.convs = torch.nn.ModuleList(
    [GCNConv(in_channels=input_dim, out_channels=hidden_dim)] +
    [GCNConv(in_channels=hidden_dim, out_channels=hidden_dim)
     for i in range(num_layers - 2)] +
    [GCNConv(in_channels=hidden_dim, out_channels=output_dim)]
)
self.bns = torch.nn.ModuleList(
    [torch.nn.BatchNorm1d(num_features=hidden_dim)
     for i in range(num_layers - 1)]
)
self.softmax = torch.nn.LogSoftmax(dim=-1)  # explicit dim avoids the implicit-dimension warning
And, this is how you can proceed with the forward function:
out = None
for i in range(len(self.bns)):
    x = self.convs[i](x, adj_t)
    x = self.bns[i](x)
    x = torch.nn.functional.relu(x)
    x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, adj_t)
if not self.return_embeds:
    x = self.softmax(x)
out = x
return out
Here, adj_t is the graph connectivity in COO format with shape [2, num_edges] (i.e., the edge_index tensor passed to the model).
And this is how you can train your model. We use the negative log-likelihood (NLL) loss and accuracy as the evaluation metric.
def train(model, data, train_idx, optimizer, loss_fn):
    model.train()
    optimizer.zero_grad()
    o = model(data.x, data.edge_index)
    o = o[train_idx]
    y = torch.flatten(data.y[train_idx])
    loss = loss_fn(o, y)
    loss.backward()
    optimizer.step()
    return loss.item()
def _eval_acc(self, y_true, y_pred):
    acc_list = []
    for i in range(y_true.shape[1]):
        # NaN labels compare unequal to themselves, so this mask skips them.
        is_labeled = y_true[:, i] == y_true[:, i]
        correct = y_true[is_labeled, i] == y_pred[is_labeled, i]
        acc_list.append(float(np.sum(correct)) / len(correct))
    return {'acc': sum(acc_list) / len(acc_list)}
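This per-column accuracy mirrors what the official OGB evaluator computes. As a sketch, evaluation could be wired up like this (the test helper and its arguments are our own naming, assuming the data and split_idx variables from the loading snippet above):
import torch
from ogb.nodeproppred import Evaluator

evaluator = Evaluator(name='ogbn-arxiv')

@torch.no_grad()
def test(model, data, split_idx, evaluator):
    model.eval()
    out = model(data.x, data.edge_index)
    y_pred = out.argmax(dim=-1, keepdim=True)   # predicted class per node, shape [N, 1]
    accs = {}
    for split in ['train', 'valid', 'test']:
        idx = split_idx[split]
        accs[split] = evaluator.eval({
            'y_true': data.y[idx],
            'y_pred': y_pred[idx],
        })['acc']
    return accs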
# Training params
args = {
    'device': device,
    'num_layers': 3,
    'hidden_dim': 64,
    'dropout': 0.5,
    'lr': 0.01,
    'epochs': 100,
}
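Putting the pieces together, a training loop could look roughly like the sketch below. The GCN constructor keywords and the use of F.nll_loss are assumptions inferred from the snippets above; see the Colab for the exact code.
import torch
import torch.nn.functional as F

# Hypothetical keyword names for the GCN class built from the snippets above.
model = GCN(input_dim=data.num_features, hidden_dim=args['hidden_dim'],
            output_dim=dataset.num_classes, num_layers=args['num_layers'],
            dropout=args['dropout']).to(device)
data = data.to(device)
train_idx = split_idx['train'].to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
loss_fn = F.nll_loss   # pairs with the LogSoftmax output of the model

for epoch in range(1, args['epochs'] + 1):
    loss = train(model, data, train_idx, optimizer, loss_fn)
    if epoch % 10 == 0:
        print(f'Epoch {epoch:03d} | Loss {loss:.4f}')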
What do we get?
Train Accuracy: 70.73%, Valid Accuracy: 70.17%, Test Accuracy: 69.36%
Now, you are wondering, What? Why? How? What does it mean? We can answer some of these burning questions via the PyG explainability framework!
Let’s open this black box:
IT IS GAME TIME!
Below is the code snippet that initializes the Explainer class. Specifically, it focuses on explaining the prediction for a single node in the graph by identifying the top 40 contributing attributions. The Explainer object is instantiated with various configuration parameters, including the model being explained, the algorithm used for generating explanations (GNNExplainer in this case), and the explanation_type, which is set to 'model', indicating that the goal is to explain the model's own predictions rather than the underlying ground-truth phenomenon. The node_mask_type and edge_mask_type parameters indicate that the explanation will mask node attributes and object-level (whole) edges. Additionally, the model_config dictionary specifies the task being performed by the model (multiclass_classification) and the level at which the explanation is generated (node). Finally, the threshold_config dictionary specifies the thresholding method, here keeping only the top 40 contributions.
# Explainability for a single node
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.explain.metric import unfaithfulness, fidelity, characterization_score

# Threshold contributions by keeping the top 40 values.
topk = 40
node_index = 10

explainer_individual = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
    threshold_config=dict(threshold_type='topk', value=topk),
)

# Get the explanation for the node with index 10
explanation_individual = explainer_individual(data.x, data.edge_index, index=node_index)
In order to interpret the prediction for the node with index 10, GNNExplainer provides the feature importance bar plot below for the top 10 features.
We can visualize the feature importance and the explanation subgraph as follows:
path_features = "feature_importance.png"
explanation_individual.visualize_feature_importance(path_features, top_k=10)
path_graph = "graph_importance.png"
explanation_individual.visualize_graph(path_graph)
How do we know if this makes sense?
Metrics are all you need!
- unfaithfulness
This metric evaluates how faithful an Explanation is to an underlying GNN predictor, as described in Agarwal et al. (2023). Graph explanation unfaithfulness (GEF) can be expressed as GEF(y, ŷ) = 1 − exp(−KL(y ∥ ŷ)), where y is the prediction probability vector obtained from the original graph and ŷ is the prediction probability vector obtained from the masked subgraph. The Kullback-Leibler (KL) divergence quantifies the distance between the two probability distributions, so lower GEF values indicate a more faithful explanation.
- fidelity
Fidelity evaluates the contribution of the produced explanatory subgraph to the initial prediction, either by feeding only the subgraph to the model (fidelity–) or by removing it from the entire graph (fidelity+). The fidelity scores capture how well an explainable model reproduces the natural phenomenon or the GNN model logic. The metric evaluates the fidelity of an Explainer given an Explanation, as described in Amara et al. (2022).
For phenomenon explanations, the fidelity scores are given by
fid+ = (1/N) Σᵢ |𝟙(ŷᵢ = yᵢ) − 𝟙(ŷᵢᶜ = yᵢ)| and fid– = (1/N) Σᵢ |𝟙(ŷᵢ = yᵢ) − 𝟙(ŷᵢˢ = yᵢ)|.
For model explanations, the fidelity scores are given by
fid+ = 1 − (1/N) Σᵢ 𝟙(ŷᵢᶜ = ŷᵢ) and fid– = 1 − (1/N) Σᵢ 𝟙(ŷᵢˢ = ŷᵢ),
where ŷᵢ is the model's prediction on the full graph, ŷᵢˢ its prediction on the explanation subgraph, ŷᵢᶜ its prediction on the complement of the explanation subgraph, yᵢ the ground-truth label, and 𝟙(·) the indicator function.
For the metrics, we use a second Explainer instance, explainer_metrics, configured similarly to explainer_individual (see the Colab for its exact setup), and generate an explanation with it:
explanation_metrics = explainer_metrics(data.x, data.edge_index)
print(f'Generated explanations in {explanation_metrics.available_explanations}')
fid_pm = fidelity(explainer_metrics, explanation_metrics)
print("Fidelity:", fid_pm)
char_score = characterization_score(fid_pm[0], fid_pm[1])
print("Characterization score:", char_score)
The Fidelity values achieved are (0.9929, 0.4951)
Characterization score: 0.7426
But what exactly does the above result mean for our explanation?
Before going into what this score means, we need some background on what types of explanations are considered good.
Given a GNN model, there can be many possible explanations, which can be grouped into two categories, as proposed in Amara et al. (2022)!
- Sufficient Explanation: An explanation is considered sufficient if it can independently lead to the model's initial prediction. The same prediction may have multiple sufficient explanations, since different configurations of the graph can yield it. The fidelity score fid– of a sufficient explanation is close to 0.
- Necessary Explanation: An explanation is considered necessary if the model’s prediction changes when it’s removed from the original graph. Necessary explanations have a fidelity score fid+ close to 1.
An explanation is a characterization of the prediction if it is both necessary and sufficient.
The characterization_score is recommended as a global evaluation metric: as a weighted harmonic mean of fid+ and (1 − fid–), namely charact = (w₊ + w₋) / (w₊/fid+ + w₋/(1 − fid–)), it balances the sufficiency and necessity requirements for an explanation. Ring any bells? We can't read your mind, but we bet you are also thinking of the F1-score, which combines precision and recall in the same way.
With equal weights on fid+ and (1 − fid–), the characterization score is low as soon as either of the two terms is low.
Our explanation has a characterization score of about 0.74, which means it is pretty good!
Implementing Your Own Explainer Algorithm
To implement your own custom explainability function, you can extend the ExplainerAlgorithm abstract base class. All you need to implement are the following two abstract methods:
- forward(model, x, edge_index, target, index): The function that computes the explanation. model is the model being explained, x holds the input node features, edge_index is the graph connectivity, target is the target of the model output that is being explained (used in phenomenon explanations), and index is the index of the model output to explain.
- supports(self) -> bool: Returns whether or not the algorithm supports the current explainer_config and model_config parameters
The ExplainerAlgorithm abstract base class also contains a list of helpful utility functions that can be used when implementing your explainer algorithm. See the source for more details.
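To make this concrete, here is a minimal (and deliberately silly) sketch of a custom algorithm that assigns a random importance score to every edge. The RandomExplainer name and its logic are ours, purely to illustrate the required structure:
import torch
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm


class RandomExplainer(ExplainerAlgorithm):
    """Toy explainer: scores every edge with a random value in [0, 1]."""

    def forward(self, model, x, edge_index, *, target, index=None, **kwargs):
        # One importance score per edge; a real algorithm would optimize these.
        edge_mask = torch.rand(edge_index.size(1), device=x.device)
        return Explanation(edge_mask=edge_mask)

    def supports(self) -> bool:
        # Accept any explainer/model configuration for this toy example.
        return True


# It can then be plugged into the Explainer class just like GNNExplainer above:
# explainer = Explainer(model=model, algorithm=RandomExplainer(), ...)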
Finally, A Closer Look at GNNExplainer
In case you are still lost, let’s take a closer look at the GNNExplainer algorithm under the hood so that we can examine how this explainability algorithm works.
In the GNNExplainer problem setup, we consider node classification. We let G be a graph on edges E and nodes V with d-dimensional node features 𝑋 = {𝑥₁, …, 𝑥ₙ}. Additionally, let f be a label function on nodes, f: V → {1, …, C}, that maps each node to one of C classes. The GNN model ɸ seeks to approximate f.
Recall that at layer l, the update of a GNN model ɸ involves three key computations.
- First, the model computes a message for every pair of nodes. The message function is mᵢⱼ⁽ˡ⁾ = MSG(hᵢ⁽ˡ⁻¹⁾, hⱼ⁽ˡ⁻¹⁾, rᵢⱼ), where the h terms are the representations of nodes 𝑣ᵢ and 𝑣ⱼ at layer l−1 and rᵢⱼ is the relation between the two nodes.
2. Second, for each node 𝑣ᵢ, the GNN gathers the messages from 𝑣ᵢ's neighborhood Nᵥᵢ and combines them via an aggregation function: Mᵢ⁽ˡ⁾ = AGG({mᵢⱼ⁽ˡ⁾ : 𝑣ⱼ ∈ Nᵥᵢ}).
3. Finally, the GNN takes the aggregated message Mᵢ⁽ˡ⁾ and 𝑣ᵢ's previous representation hᵢ⁽ˡ⁻¹⁾, and non-linearly transforms them to obtain 𝑣ᵢ's new representation: hᵢ⁽ˡ⁾ = UPDATE(Mᵢ⁽ˡ⁾, hᵢ⁽ˡ⁻¹⁾).
The key insight of GNNExplainer is that the computation graph of a node 𝑣, which is defined by the aggregation procedure, fully determines all of the information the GNN uses to make a prediction. The mathematical details of the specific optimization that GNNExplainer performs are out of scope for this tutorial, mostly because there are different objective functions for single-instance explanations versus the joint learning of graph structural and node feature information. For greater detail, see Ying et al. (2019).
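Briefly though, the single-instance objective from Ying et al. (2019) is to learn masks that maximize the mutual information between the GNN's prediction Y and the masked computation graph: max over G_S of MI(Y, (G_S, X_S)) = H(Y) − H(Y | G = G_S, X = X_S). In practice, this is optimized as a relaxed cross-entropy objective over soft masks, with regularization terms that keep the masks small and close to binary.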
Conclusion
In this blog post, we have trained a GCN on the ogbn-arxiv
dataset (an Open Graph Benchmark (OGB) dataset) and shown how we can use PyG's built-in explainability module to produce explanations for both node features and graph structure. Then, we examined ways to evaluate the quality of the explanations, and looked at how the GNNExplainer algorithm works under the hood. Hopefully, this tutorial has given you a greater understanding of the capabilities of PyG's explainability module, and now life has started to make sense again!
References:
- PyTorch Geometric documentation.
- Ying et al., GNNExplainer: Generating Explanations for Graph Neural Networks, 2019.
- Amara et al., GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks, 2022.
- Agarwal et al., Evaluating Explainability for Graph Neural Networks, 2023.
- Kipf and Welling, Semi-Supervised Classification with Graph Convolutional Networks, 2017.
- Blaž Stojanovič, Graph Machine Learning Explainability with PyG, 2023, Medium blog post.