📝 eXplainability, Graphs, and ML

Adrianna Janik
Published in Labs Notebook
12 min read · Aug 28, 2023

Adrianna Janik and Eda Bayram

With the growing adoption of AI solutions, the need for explainability is growing even faster. In June 2023, the European Parliament adopted its position on the AI Act¹, the EU’s proposed legal framework for AI, which requires high-risk AI systems to be transparent and to provide information to their users. This means that a model’s explainability is no longer a nice-to-have feature but, increasingly, a requirement. Current trends, such as those in Gartner’s Hype Cycle for AI, point to an exciting opportunity: Generative AI requires responsible AI practices and technologies, but how do we get there? The answer may lie in Knowledge Graphs.

Sketch of Gartner’s Hype Cycle for Artificial Intelligence, 2023

The post has six parts:

  1. Introduction to graph machine learning with knowledge graphs, and to explainability in ML more broadly.
  2. What is an explanation in graph machine learning?
  3. Explanations for Knowledge Graph Embedding Models on the Link Prediction Task.
  4. Why knowledge graphs are crucial in the context of other machine learning models.
  5. What is Accenture Labs doing?
  6. Do it Yourself: code examples to recreate some of the material in the post.

In this article, we focus on the explainability of graph machine learning models in a broad sense. We first look at how graph machine learning differs from “regular” machine learning, then at what an explanation is in the context of machine learning in general, and how this translates to graph machine learning and, specifically, to the link prediction task on knowledge graphs. We conclude with a section on how knowledge graphs can assist other branches of machine learning, such as Generative AI.

1. What is machine learning on graphs?

Machine learning (ML) is an umbrella term covering many different ways of learning from a variety of data. Graph ML, specifically, looks at data in a relational form, called a graph, which we can define as a set of relations between entities. This is only the starting point: a knowledge graph is a specific type of graph whose definition adds some properties, most notably that all entities and relations have types. An example is the gene PRKN and its association with Parkinson’s Disease, depicted below.

Visual representation of a triple: gene PRKN (subject) connected via the relation (predicate) hasRelationTo to the object Disease: Parkinson’s Disease.
Triple: the basic unit of a knowledge graph, alongside entities and edges. We can represent PRKN as a node of type Gene and link it via a hasRelationTo edge to the node Parkinson’s Disease of type Disease. Icon sources: 1 and 2.
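To make the definition concrete, here is a minimal sketch, assuming plain Python data structures rather than any particular graph database, of a typed knowledge graph: facts are (subject, predicate, object) triples, every entity carries a type, and a hypothetical relation signature states which entity types a predicate may connect. The names `entity_types`, `relation_signatures`, `is_well_typed`, and `describe` are illustrative only.

```python
from typing import Dict, List, Tuple

Triple = Tuple[str, str, str]

# Every entity has a type, as in the figure: PRKN is a Gene, Parkinson's Disease is a Disease
entity_types: Dict[str, str] = {"PRKN": "Gene", "Parkinson's Disease": "Disease"}

# Hypothetical schema entry: hasRelationTo links a Gene (subject) to a Disease (object)
relation_signatures: Dict[str, Tuple[str, str]] = {"hasRelationTo": ("Gene", "Disease")}

# Facts are stored as (subject, predicate, object) triples
triples: List[Triple] = [("PRKN", "hasRelationTo", "Parkinson's Disease")]


def is_well_typed(triple: Triple) -> bool:
    """True if the subject and object types match the predicate's signature."""
    s, p, o = triple
    expected = relation_signatures.get(p)
    return expected is not None and (entity_types.get(s), entity_types.get(o)) == expected


def describe(triple: Triple) -> str:
    """Render a triple together with the types of its entities."""
    s, p, o = triple
    return f"{s} ({entity_types[s]}) --{p}--> {o} ({entity_types[o]})"


for t in triples:
    print(describe(t), is_well_typed(t))
```

The type check is what distinguishes this from a plain graph: a reversed triple such as (Parkinson's Disease, hasRelationTo, PRKN) is rejected because its types do not match the assumed signature.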

We can define different tasks on such a data representation. For example, when a graph is incomplete or evolving over time, we can try to predict which missing links are plausible (link prediction). Another task is classifying nodes based on their properties (node classification), and yet another is classifying whole graphs (graph classification).

Simplified Hetionet⁷ graph schema: the connections between the different entity types, labelled with edge types.

Graph machine learning comes in many forms. One of them is Graph Neural Networks (GNNs), for which we encourage readers to see this gentle introduction to GNNs⁸; an example of using GNNExplainer, a popular explainability method, to explain a GNN’s link prediction is presented in section 6. The other form of graph machine learning, which we discuss in more depth, is Knowledge Graph Embedding models (KGEs). For an in-depth guide on the topic, we refer readers to these resources: tutorial on KGEs during ECAI 2020, tutorial on KGEs for NLP during COLING 2022, and tutorial on Link Predictions from Raw Data to Insights during KGC 2023.

With so many tasks and kinds of graph machine learning, defining explainable GraphML is not easy. Let’s see why in the next section.

🗃 2. What is an explanation for ML on Graphs?

Explanation is a very vague concept, and there have been several attempts to formalise it² in machine learning. We define the explanation of a model’s prediction as the answer to the question: why did the model make this prediction? Of course, we are not interested in just any explanation, only in those that provide evidence supported by data and are useful to the user.

Visualisation of a PCA projection of the Hetionet⁷ latent space. The model was trained with TransE. To recreate it, see Example 1 in section 6, Do it Yourself.

We can distinguish two main approaches in the explainability literature: explaining a black box or designing a transparent box. By a transparent box, we mean a model that is explainable on its own, with parameters that are easy to interpret and understand, for example linear regression or decision trees. Explaining a so-called black box means building external methods for models that are not explainable by themselves. These methods are either applied after training (post hoc) or require modifying the training loop.
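The difference between the two approaches can be sketched in a few lines of plain Python. The linear model is a transparent box: each feature’s contribution is read directly from its weight. The `black_box` function stands in for an opaque model and is explained post hoc by occlusion, which replaces one feature at a time with a baseline value and measures the change in output. Occlusion is just one simple post-hoc technique chosen for illustration; the weights, the baseline of 0.0, and all function names are assumptions.

```python
from typing import Callable, List, Sequence

WEIGHTS = (2.0, -1.0, 0.0)  # hypothetical weights of the transparent model
BIAS = 0.5


def linear_model(x: Sequence[float]) -> float:
    """Transparent box: the prediction is a weighted sum, so the weights explain it."""
    return sum(w * xi for w, xi in zip(WEIGHTS, x)) + BIAS


def linear_contributions(x: Sequence[float]) -> List[float]:
    # Each feature's contribution is read directly off the parameters
    return [w * xi for w, xi in zip(WEIGHTS, x)]


def black_box(x: Sequence[float]) -> float:
    """Stand-in for an opaque model whose parameters we cannot interpret."""
    return max(0.0, x[0] * x[1]) - x[2] ** 2


def occlusion_importance(f: Callable[[Sequence[float]], float],
                         x: Sequence[float], baseline: float = 0.0) -> List[float]:
    """Post-hoc explanation: output change when one feature is set to a baseline."""
    full = f(x)
    scores = []
    for i in range(len(x)):
        occluded = list(x)
        occluded[i] = baseline
        scores.append(full - f(occluded))
    return scores


x = [1.0, 2.0, 3.0]
print(linear_contributions(x))             # [2.0, -2.0, 0.0]
print(occlusion_importance(black_box, x))  # [2.0, 2.0, -9.0]
```

For the linear model, the contributions plus the bias add up exactly to the prediction; the occlusion scores for the black box carry no such guarantee, since the first two features interact through `x[0] * x[1]`.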

What is an explanation for GraphML? It depends on many factors, first of all on the task the model is trying to solve. For GraphML, we can distinguish quite a few tasks: link prediction, knowledge graph completion, node classification, graph classification, and even answering complex queries. Each of these tasks may call for its own approach to explaining predictions. We focus specifically on the link prediction task with Knowledge Graph Embedding models.

🔗 3. Explanations for Knowledge Graph Embedding Models on the Link Prediction Task

Let us define an explanation for a prediction of a machine learning model as the collection of training samples deemed most influential for that prediction. There are different approaches to finding such explanations. When training knowledge graph embedding models for link prediction, the input is a triple, such as <PRKN, hasRelationTo, Parkinson’s Disease>, and the output is a score representing the plausibility of that triple according to the model (e.g. 99%). An explanation is therefore a ranked set of training triples responsible for that score.
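As an illustration of how a KGE model turns a triple into a score, the sketch below uses the TransE scoring function (the model trained in Example 1), under the assumption of an L1 distance: the score of <s, p, o> is −‖e_s + w_p − e_o‖₁, so higher scores mean more plausible triples. Link prediction then ranks candidate objects for a (subject, predicate) pair. The two-dimensional embeddings and the entity DiseaseX are hand-picked, hypothetical values, and turning a raw score into a probability such as 99% would need an extra calibration step that is not shown.

```python
from typing import Dict, List, Tuple

Vec = List[float]


def transe_score(e_s: Vec, w_p: Vec, e_o: Vec) -> float:
    """TransE with the L1 norm: -||e_s + w_p - e_o||_1 (higher is more plausible)."""
    return -sum(abs(s + p - o) for s, p, o in zip(e_s, w_p, e_o))


def rank_objects(subject: str, predicate: str, candidates: List[str],
                 ent: Dict[str, Vec], rel: Dict[str, Vec]) -> List[Tuple[str, float]]:
    """Score <subject, predicate, candidate> for every candidate and sort best-first."""
    scored = [(c, transe_score(ent[subject], rel[predicate], ent[c])) for c in candidates]
    return sorted(scored, key=lambda kv: kv[1], reverse=True)


# Hypothetical 2-dimensional embeddings, chosen by hand for illustration
ent = {"PRKN": [0.25, 0.5], "Parkinson's Disease": [0.5, 0.75], "DiseaseX": [0.0, 1.0]}
rel = {"hasRelationTo": [0.5, 0.25]}

for obj, score in rank_objects("PRKN", "hasRelationTo",
                               ["DiseaseX", "Parkinson's Disease"], ent, rel):
    print(obj, score)
```

An explanation method then answers a different question: which training triples pushed the embeddings into positions that give Parkinson’s Disease the top score.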

👱 Human Readability

However, when a system makes predictions for an end user, for example a medical doctor using a Clinical Decision Support System, its internals are often hidden from the user, as they would not make sense to anyone without the relevant background knowledge. A crucial part of such a system is a well-designed human-computer interaction that effectively conveys the information from the machine learning model to the user in a given context.

The problem grows when we add explanations, since the concept is difficult to define. From the user’s perspective, an explanation should make sense; from the model’s perspective, it has to be faithful to the prediction. Consider the same medical doctor as before. The end user is typically not familiar with the input data format, nor with each entity and relation type stored in the knowledge graph. Thus, the user may be unable to judge whether the input samples returned by the explanation sub-system are relevant to the output, even though the purpose of the explanations is to convince the user that the ML model’s output is reasonable.

To be effective, an explanation therefore needs to be adapted to the level of its audience. A single triple is not a sufficient explanation for the client; we also need to include contextual information and the reasoning behind it.
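One simple way to add such context is to verbalise triples with templates that include entity types. The sketch below is an assumption for illustration, not a method from the works discussed here: the templates, the `PathwayP` entity, and the evidence triple are hypothetical, and a real system would need a template for every relation type in the graph.

```python
from typing import Dict, List, Tuple

Triple = Tuple[str, str, str]

# Hypothetical templates mapping each relation to an English sentence pattern
TEMPLATES = {
    "hasRelationTo": "{s_type} {s} is related to {o_type} {o}",
    "participatesIn": "{s_type} {s} participates in {o_type} {o}",
}


def verbalize(triple: Triple, types: Dict[str, str]) -> str:
    """Turn a triple into a sentence fragment that names the entity types."""
    s, p, o = triple
    return TEMPLATES[p].format(s=s, o=o, s_type=types[s].lower(), o_type=types[o].lower())


def explain_in_words(prediction: Triple, plausibility: float,
                     evidence: List[Triple], types: Dict[str, str]) -> str:
    """Combine the prediction and its ranked evidence triples into readable text."""
    lines = [f"The model predicts that {verbalize(prediction, types)} "
             f"(plausibility {plausibility:.0%})."]
    if evidence:
        lines.append("Most influential training facts:")
        lines += [f"  {i}. {verbalize(t, types)}" for i, t in enumerate(evidence, 1)]
    return "\n".join(lines)


types = {"PRKN": "Gene", "Parkinson's Disease": "Disease", "PathwayP": "Pathway"}
print(explain_in_words(("PRKN", "hasRelationTo", "Parkinson's Disease"), 0.99,
                       [("PRKN", "participatesIn", "PathwayP")], types))
```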

Explaining Link Predictions with Gradient Rollback³ and Adversarial Explanations⁴

Let us look at two works, Gradient Rollback³ and Adversarial Explanations for KGEs⁴, which take different approaches to producing explanations for link predictions made by KGEs.

White-box vs Black-box Explanations. In both articles, an explanation is the set of training triples most responsible for a particular prediction. [Betz et al., 2022]⁴ further define an adversarial attack as a minimal modification of the training triples that leads to the maximal degradation in link prediction performance. In their study, the modification consists of deleting the most explanatory triple for a prediction and adding a corrupted version of it. Their framework therefore reduces to identifying the most influential input triples, as in Gradient Rollback’s framework. However, [Betz et al., 2022]⁴ draw an important distinction between their framework and Gradient Rollback³: they call their framework black-box, meaning that it does not require access to the KGE model architecture, whereas methods such as Gradient Rollback³ are white-box explanations, in the sense that they require access to the embeddings and loss function of the KGE model, together with the training data. In other words, the Adversarial Explanations method is model-independent and only requires the input training data. Model independence suits a framework that claims to perform adversarial attacks, where the attacker is an outsider trying to break the model’s training. Gradient Rollback makes no such claim: it provides explanations as an insider to the training procedure.

Adversarial Explanations. You might wonder how they produce explanations from the input training data alone, without access to the training procedure of the KGE model. The answer is that they explain why the model outputs a prediction using a set of logical rules. The first step is therefore rule learning, followed by identifying the input triples that are influential for a prediction according to the learned rules. To learn logical rules from the input KG triples, they adopt a method called AnyBURL [Meilicke et al., 2019]⁵. AnyBURL learns logical rules that capture statistical regularities in the input data, such as speaks(X, English) ← nationality(X, English). Each rule is assigned a confidence: the rate of correct predictions made by that rule. The training triple that leads to the prediction through the highest-confidence learned rule is then designated as the explanation. Next, an attack takes place by removing that input triple and adding a corrupted version of it.
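The rule-based idea can be sketched in plain Python. This is a deliberate simplification and not AnyBURL itself: the rules are given rather than learned, only rules of the form head_rel(X, c) ← body_rel(X, c′) are supported, and confidence is the plain ratio of body groundings for which the head also holds in the training data. The toy people, languages, and cities are hypothetical.

```python
from typing import List, NamedTuple, Optional, Set, Tuple

Triple = Tuple[str, str, str]


class Rule(NamedTuple):
    """Constant rule head_rel(X, head_obj) <- body_rel(X, body_obj)."""
    head_rel: str
    head_obj: str
    body_rel: str
    body_obj: str


def confidence(rule: Rule, kg: Set[Triple]) -> float:
    """Fraction of body groundings for which the rule head also holds."""
    body = {s for (s, r, o) in kg if r == rule.body_rel and o == rule.body_obj}
    if not body:
        return 0.0
    hits = sum((x, rule.head_rel, rule.head_obj) in kg for x in body)
    return hits / len(body)


def explain(target: Triple, rules: List[Rule],
            kg: Set[Triple]) -> Optional[Tuple[Triple, float]]:
    """Return the training triple that fires the highest-confidence rule predicting target."""
    s, r, o = target
    best = None
    for rule in rules:
        if (rule.head_rel, rule.head_obj) != (r, o):
            continue  # this rule cannot predict the target
        body_triple = (s, rule.body_rel, rule.body_obj)
        if body_triple not in kg:
            continue  # the rule body does not hold for this subject
        conf = confidence(rule, kg)
        if best is None or conf > best[1]:
            best = (body_triple, conf)
    return best


# Hypothetical toy training data
train = {
    ("alice", "nationality", "English"), ("alice", "speaks", "English"),
    ("bob", "nationality", "English"), ("bob", "speaks", "English"),
    ("carol", "nationality", "English"), ("carol", "speaks", "French"),
    ("dave", "nationality", "English"), ("dave", "livesIn", "London"),
    ("alice", "livesIn", "London"), ("erin", "livesIn", "London"),
    ("erin", "speaks", "English"),
}
rules = [
    Rule("speaks", "English", "nationality", "English"),  # confidence 2/4
    Rule("speaks", "English", "livesIn", "London"),       # confidence 2/3
]
target = ("dave", "speaks", "English")
explanation, conf = explain(target, rules, train)
print(explanation, round(conf, 3))  # ('dave', 'livesIn', 'London') 0.667

# Attack: delete the explanation and add a corrupted copy (hypothetical city)
attacked = (train - {explanation}) | {(explanation[0], explanation[1], "Paris")}
print(explain(target, rules, attacked))  # (('dave', 'nationality', 'English'), 0.5)
```

After the attack, the explanation falls back to the weaker nationality rule, which mirrors the deletion-plus-corruption step described above in a very reduced form.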

Gradient Rollback³, on the other hand, gives an insider’s view of the prediction by probing the training procedure: it records the gradient updates (from back-propagation) that every input triple contributes to the embeddings during training. To estimate the influence of each training input on a particular output, Gradient Rollback³ computes a modified representation of the output (the embeddings of its subject, predicate, and object) that approximates what would have been learned had that input been absent from training. The modified embeddings are computed via the operation called Gradient Rollback³: the gradients recorded on the subject, predicate, and object are accumulated and subtracted from their original embeddings. The prediction score of the output is then recomputed using the modified embeddings, which effectively negates the contribution of that input triple to the prediction. The input triple causing the largest drop in the probability of the predicted output is deemed the top explanation.
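A toy version of this bookkeeping is sketched below in plain Python. It is a simplified illustration, not the authors’ implementation: it assumes a DistMult scoring function, plain SGD, positive triples only (no negative sampling), and that entity and relation names do not collide. During training, every update a triple applies to an embedding is accumulated under that triple; at explanation time, those updates are subtracted from the final embeddings, and the drop in the sigmoid of the target’s score ranks the training triples. The triples `a`, `r1`, and so on are hypothetical.

```python
import math
import random
from collections import defaultdict
from typing import Dict, List, Tuple

Triple = Tuple[str, str, str]
Params = Dict[str, List[float]]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def score(params: Params, s: str, r: str, o: str) -> float:
    """DistMult score: sum_k e_s[k] * w_r[k] * e_o[k]."""
    return sum(a * b * c for a, b, c in zip(params[s], params[r], params[o]))


def train_with_rollback(triples: List[Triple], dim=4, lr=0.1, epochs=50, seed=0):
    rng = random.Random(seed)
    names = sorted({n for t in triples for n in t})
    params = {n: [rng.uniform(-0.5, 0.5) for _ in range(dim)] for n in names}
    # updates[t][name] = sum of all SGD steps that triple t applied to `name`
    updates = {t: defaultdict(lambda: [0.0] * dim) for t in triples}
    for _ in range(epochs):
        for t in triples:
            s, r, o = t
            es, wr, eo = params[s], params[r], params[o]
            # Loss = -log(sigmoid(score)), so dLoss/dscore = sigmoid(score) - 1
            g = sigmoid(score(params, s, r, o)) - 1.0
            grads = defaultdict(lambda: [0.0] * dim)
            for k in range(dim):
                grads[s][k] += g * wr[k] * eo[k]
                grads[r][k] += g * es[k] * eo[k]
                grads[o][k] += g * es[k] * wr[k]
            for name, grad in grads.items():
                step = [-lr * gk for gk in grad]
                params[name] = [p + d for p, d in zip(params[name], step)]
                acc = updates[t][name]
                for k in range(dim):
                    acc[k] += step[k]
    return params, updates


def gradient_rollback(params: Params, updates, target: Triple):
    """Rank training triples by the drop in sigmoid(score(target)) after rollback."""
    s, r, o = target
    base = sigmoid(score(params, s, r, o))
    drops = {}
    for t, delta in updates.items():
        rolled = dict(params)  # shallow copy; only rows touched by t are replaced
        for name, d in delta.items():
            rolled[name] = [p - dk for p, dk in zip(params[name], d)]
        drops[t] = base - sigmoid(score(rolled, s, r, o))
    return sorted(drops.items(), key=lambda kv: kv[1], reverse=True)


train = [("a", "r1", "b"), ("b", "r1", "c"), ("x", "r2", "y")]
params, updates = train_with_rollback(train)
for triple, drop in gradient_rollback(params, updates, ("a", "r1", "c")):
    print(triple, round(drop, 4))
```

Because ("x", "r2", "y") shares no entity or relation with the target ("a", "r1", "c"), rolling it back leaves the target’s embeddings untouched and its influence is exactly zero; only triples that updated the target’s parameters can receive a non-zero score.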

Human-readability. Although Gradient Rollback³ gives an insider’s justification for a decision made by the trained KGE model, its explanation is not necessarily interpretable by the end user: the explanatory input triple it returns may not seem relevant to the prediction from the end user’s point of view. Since KG embedding models are sub-symbolic rather than symbolic, it is a natural consequence that their justifications do not always make sense to humans. On the other hand, adversarial explanations, generated by a symbolic approach that provides an outsider’s justification for a prediction, are more aligned with human reasoning.

🧩 4. Why are knowledge graphs useful for ML?

As a data representation format, knowledge graphs are useful for many decision support systems, not only those based on machine learning. The relational form makes it easy to run complex queries over large graph databases, and the results can be presented to a user who requests contextual information on a given subject. By adding a machine learning layer on top of a knowledge graph, we build on a data format that is already explainable on its own. We can also observe a trend in other ML areas, where non-explainable black-box models are paired with a knowledge graph database that provides contextual information for a prediction; this is sometimes seen as an explanation⁶, although we prefer the term justification. This is a particularly promising area in Generative AI, where a model generates content and then uses external databases to find references for the generated outputs.
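A minimal sketch of such a justification, assuming the knowledge graph is available as a list of triples: given a predicted link, search the graph for short paths connecting its subject and object and present them as supporting context. The path search, the three-hop limit, and the entities GeneA, PathwayP, and DiseaseD are hypothetical; this is not the method of any specific work cited here.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

Triple = Tuple[str, str, str]


def find_paths(kg: List[Triple], source: str, target: str,
               max_hops: int = 3) -> List[List[Triple]]:
    """Enumerate simple paths (edges usable in either direction) from source to target."""
    adj: Dict[str, List[Triple]] = defaultdict(list)
    for t in kg:
        s, _, o = t
        adj[s].append(t)
        adj[o].append(t)
    paths: List[List[Triple]] = []

    def walk(node: str, path: List[Triple], visited: Set[str]) -> None:
        if node == target and path:
            paths.append(list(path))
            return
        if len(path) == max_hops:
            return
        for t in adj[node]:
            s, _, o = t
            nxt = o if s == node else s
            if nxt in visited:
                continue  # keep paths simple (no revisited nodes)
            visited.add(nxt)
            path.append(t)
            walk(nxt, path, visited)
            path.pop()
            visited.remove(nxt)

    walk(source, [], {source})
    return paths


kg = [
    ("GeneA", "participatesIn", "PathwayP"),
    ("GeneB", "participatesIn", "PathwayP"),
    ("GeneB", "associatedWith", "DiseaseD"),
    ("GeneC", "associatedWith", "DiseaseE"),
]
predicted = ("GeneA", "associatedWith", "DiseaseD")  # output of some link predictor
for path in find_paths(kg, predicted[0], predicted[2]):
    print(" ; ".join(f"{s} -{r}-> {o}" for s, r, o in path))
```

Here the predicted link is justified by the path GeneA → PathwayP ← GeneB → DiseaseD, which a user can read and check against their own knowledge; with a two-hop limit, no justification is found.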

💼 5. What is Accenture Labs doing?

At Accenture Labs Dublin, we use Knowledge Graph Embedding models to tackle different problems, from predicting gene-disease associations, through customer segmentation, to predicting the risk of relapse in lung cancer patients. In this last work, within the CLARIFY H2020 Horizon Project, we proposed an example-based approach to explaining a KGE model that predicts patients’ relapse. We provide AmpliGraph, an actively maintained open-source library, where you can find many examples of KGE predictions. To learn more, refer to our tutorials on the subject: tutorial on KGEs during ECAI 2020, tutorial on KGEs for NLP during COLING 2022, and tutorial on Link Predictions from Raw Data to Insights during KGC 2023. Not enough? See also: Driving Better Decisions with Knowledge Graphs, AI R&D Explainable AI, Responsible AI from principles to practice.

⚒️ 6. Do it Yourself

Example 1

Train a simple TransE model on the Hetionet⁷ dataset and visualise the embeddings with TensorBoard:

Requirements:

  • download the Hetionet edge list (hetionet-v1.0-edges.sif.gz) from here.
  • install the necessary packages in an environment of your choice:
pip install tensorflow ampligraph pandas
from ampligraph.evaluation import train_test_split_no_unseen
from ampligraph.latent_features import ScoringBasedEmbeddingModel
from ampligraph.latent_features.loss_functions import SelfAdversarialLoss
from ampligraph.utils import create_tensorboard_visualizations
import tensorflow as tf
import pandas as pd

# Hetionet edge list: tab-separated columns source, metaedge, target (with a header row)
path = 'hetionet-v1.0-edges.sif.gz'
df = pd.read_csv(path, sep='\t')

# Hold out test and validation triples whose entities also appear in the training set
X_train, X_test = train_test_split_no_unseen(df.values, test_size=1000)
X_train, X_valid = train_test_split_no_unseen(X_train, test_size=500)

X = {'train': X_train, 'test': X_test, 'valid': X_valid}

# TransE with 200-dimensional embeddings and 5 negatives (eta) per positive triple
optim = tf.optimizers.Adam(learning_rate=0.01)
loss = SelfAdversarialLoss({'margin': 0.1, 'alpha': 5, 'reduction': 'sum'})
model = ScoringBasedEmbeddingModel(eta=5,
                                   k=200,
                                   scoring_type='TransE',
                                   seed=0)
model.compile(optimizer=optim, loss=loss)
history = model.fit(X['train'],
                    batch_size=10000,
                    epochs=5)

# Export the entity embeddings for the TensorBoard Embedding Projector
create_tensorboard_visualizations(model,
                                  entities_subset='all',
                                  loc='./embeddings_vis')
# In a terminal, run: tensorboard --logdir='./embeddings_vis' --port=8891
# Then open the following URL in a browser: http://127.0.0.1:8891/#projector

Example 2

Train a simple GCN for link prediction and explain one of its predictions with GNNExplainer:

Requirements:

  • install the necessary packages in an environment of your choice (networkx is only needed if graphviz is not installed):
pip install torch torch_geometric scikit-learn matplotlib networkx
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig
from torch_geometric.nn import GCNConv

# Pick the fastest available device
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')


class GCN(torch.nn.Module):
    """Two-layer GCN encoder with a dot-product decoder for link prediction."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        # Score a candidate edge (src, dst) as the dot product of its node embeddings
        src, dst = edge_label_index
        return (z[src] * z[dst]).sum(dim=-1)

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        return self.decode(z, edge_label_index).view(-1)


def train_model(dataset, train_data, val_data, test_data):
    model = GCN(dataset.num_features, 128, 64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

    def train():
        model.train()
        optimizer.zero_grad()
        out = model(train_data.x, train_data.edge_index,
                    train_data.edge_label_index)
        loss = F.binary_cross_entropy_with_logits(out, train_data.edge_label)
        loss.backward()
        optimizer.step()
        return float(loss)

    @torch.no_grad()
    def test(data):
        model.eval()
        out = model(data.x, data.edge_index, data.edge_label_index).sigmoid()
        return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())

    for epoch in range(1, 201):
        loss = train()
        if epoch % 20 == 0:
            val_auc = test(val_data)
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}')
    print(f'Test AUC: {test(test_data):.4f}')

    # The model returns raw logits for a binary decision: does the edge exist?
    model_config = ModelConfig(
        mode='binary_classification',
        task_level='edge',
        return_type='raw',
    )
    return model, model_config


dataset = 'CiteSeer'
path = osp.join('..', 'data', 'Planetoid')
# RandomLinkSplit builds train/val/test edge splits, with negative samples in edge_label
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True),
])
dataset = Planetoid(path, dataset, transform=transform)
train_data, val_data, test_data = dataset[0]

model, model_config = train_model(dataset, train_data, val_data, test_data)

# Explain the model's prediction for the first validation edge
edge_label_index = val_data.edge_label_index[:, 0]

explainer = Explainer(
    model=model,
    explanation_type='model',
    algorithm=GNNExplainer(epochs=200),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)
explanation = explainer(
    x=train_data.x,
    edge_index=train_data.edge_index,
    edge_label_index=edge_label_index,
)
# Draw the most important edges (uses graphviz if available, otherwise networkx)
explanation.visualize_graph('explain.png')

Explanation generated by GNNExplainer.

References:

  • [1] EUR-Lex-52021PC0206. 2021. Proposal for a REGULATION OF THE EUROPEAN PARLIAMENT AND OF THE COUNCIL LAYING DOWN HARMONISED RULES ON ARTIFICIAL INTELLIGENCE (ARTIFICIAL INTELLIGENCE ACT) AND AMENDING CERTAIN UNION LEGISLATIVE ACTS, COM/2021/206
  • [2] Barredo Arrieta, Alejandro, et al. ‘Explainable Artificial Intelligence (XAI): Concepts, Taxonomies, Opportunities and Challenges toward Responsible AI’. Information Fusion, vol. 58, June 2020, pp. 82–115. DOI.org (Crossref), https://doi.org/10.1016/j.inffus.2019.12.012.
  • [3] Lawrence, Carolin, et al. ‘Explaining Neural Matrix Factorization with Gradient Rollback’. Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, no. 6, May 2021, pp. 4987–95. DOI.org (Crossref), https://doi.org/10.1609/aaai.v35i6.16632.
  • [4] Betz, Patrick, et al. ‘Adversarial Explanations for Knowledge Graph Embeddings’. Proceedings of the Thirty-First International Joint Conference on Artificial Intelligence, International Joint Conferences on Artificial Intelligence Organization, 2022, pp. 2820–26. DOI.org (Crossref), https://doi.org/10.24963/ijcai.2022/391.
  • [5] Meilicke, Christian, et al. ‘Anytime Bottom-Up Rule Learning for Knowledge Graph Completion’. Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence, International Joint Conferences on Artificial Intelligence Organization, 2019, pp. 3137–43. DOI.org (Crossref), https://doi.org/10.24963/ijcai.2019/435.
  • [6] Tiddi, Ilaria, and Stefan Schlobach. ‘Knowledge Graphs as Tools for Explainable Machine Learning: A Survey’. Artificial Intelligence, vol. 302, Jan. 2022, p. 103627. DOI.org (Crossref), https://doi.org/10.1016/j.artint.2021.103627.
  • [7] Himmelstein, Daniel Scott, et al. ‘Systematic Integration of Biomedical Knowledge Prioritizes Drugs for Repurposing’. eLife, vol. 6, Sept. 2017, p. e26726. DOI.org (Crossref), https://doi.org/10.7554/eLife.26726.
  • [8] Sanchez-Lengeling, Benjamin, et al. ‘A Gentle Introduction to Graph Neural Networks’. Distill, vol. 6, no. 8, Aug. 2021, p. 10.23915/distill.00033. DOI.org (Crossref), https://doi.org/10.23915/distill.00033.
