Explainable AI: Interpreting BERT Model

Sai Sharanya Nalla
5 min read · Aug 17, 2022


Motivation and background

Why is it important to build interpretable AI models?

The future of AI is in enabling humans and machines to work together to solve complex problems. Organizations are attempting to improve process efficiency and transparency by combining AI/ML technology with human review.

With the recent advancement of AI, AI-specific regulations have emerged, for example Good Machine Learning Practices (GMLP) in the pharmaceutical industry and Model Risk Management (MRM) in finance, alongside broader regulations addressing data privacy such as the EU’s GDPR and California’s CCPA. Internal compliance teams may likewise want to interpret a model’s behavior when validating decisions based on its predictions. For instance, underwriters may want to learn why a specific loan application was flagged as suspicious by an ML model.

Overview

What is interpretability?

In the ML context, interpretability refers to tracing back which factors contributed to an ML model making a particular prediction. As shown in the graph below, simpler models are easier to interpret but often achieve lower accuracy than complex models such as deep learning and Transformer-based models, which can capture non-linear relationships in the data and often achieve high accuracy.

Source: Explainable Artificial Intelligence (XAI) paper

Loosely defined, there are two types of explanations:

Global explanation: explains the model as a whole, showing which features contribute most to its output overall. For example, in a finance use case where an ML model identifies the customers most likely to default, some of the most influential features for that decision are the customer’s credit score, total number of credit cards, revolving balance, etc.

Local explanation: lets you zoom in on a particular data point and observe the model’s behavior in that neighborhood. For example, in sentiment classification of movie reviews, certain words in a review have a higher impact on the outcome than others: “I have never watched something as bad…”

What is a transformer model?

A transformer model is a neural network that tracks relationships in sequential input, such as the words in a sentence, to learn context and, from it, meaning. Transformer models use an evolving set of mathematical techniques, called attention or self-attention, to detect subtle relationships even between distant elements in a sequence. Refer to Google’s publication for more information.
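To make self-attention a little more concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention, the formulation introduced in Google’s Transformer paper. The random projection matrices and toy dimensions are assumptions for illustration only; a real BERT layer uses learned weights, multiple heads, and additional components such as feed-forward sublayers and layer normalization.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    # X: (seq_len, d_model), one row per token embedding
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    d_k = K.shape[-1]
    # Pairwise token-to-token similarity, scaled by sqrt(d_k)
    scores = Q @ K.T / np.sqrt(d_k)
    # Each row is a probability distribution over all positions, near or far
    weights = softmax(scores, axis=-1)
    # Each output row is a weighted mix of the value vectors
    return weights @ V, weights

# Toy sizes and random (untrained) weights, purely illustrative
rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 8, 4
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape, weights.shape)  # (5, 4) (5, 5)
```

Because every position attends to every other position in one step, the distance between two tokens does not limit how strongly they can influence each other.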

Integrated Gradients

Integrated Gradients (IG) is an Explainable AI technique introduced in the paper Axiomatic Attribution for Deep Networks. It assigns an attribution value to each input feature, indicating how much that feature contributed to the final prediction.

IG is a local method and a popular interpretability technique because it applies to any differentiable model (e.g., for text, image, or structured data), is easy to implement, is computationally efficient relative to alternative approaches, and has theoretical justification. Integrated gradients are the integral of the gradients with respect to the inputs along the straight-line path from a given baseline to the input. The integral can be approximated using a Riemann sum or a Gauss-Legendre quadrature rule. Formally, it can be described as follows:

Integrated Gradients along the i-th dimension of input X. Alpha is the scaling coefficient. The equations are copied from the original paper.
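For reference, here is the definition from Axiomatic Attribution for Deep Networks written out in LaTeX, where F is the model, x the input, x′ the baseline, α the scaling coefficient, and m the number of steps in the Riemann-sum approximation:

```latex
\mathrm{IntegratedGrads}_i(x) ::= (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial F\left(x' + \alpha \times (x - x')\right)}{\partial x_i} \, d\alpha

\mathrm{IntegratedGrads}^{\mathrm{approx}}_i(x) ::= (x_i - x'_i) \times \sum_{k=1}^{m} \frac{\partial F\left(x' + \frac{k}{m} \times (x - x')\right)}{\partial x_i} \times \frac{1}{m}
```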

The cornerstones of this approach are two fundamental axioms, namely sensitivity and implementation invariance. More information can be found in the original paper.
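A minimal NumPy sketch of the approximation follows, assuming a hypothetical two-feature toy function F(x) = x0² + 3·x0·x1 whose gradient is known analytically (a real network would get its gradients from autograd). It compares a midpoint Riemann sum with Gauss-Legendre quadrature; the method strings borrow Captum’s `method` option names, but this is an illustration, not Captum’s implementation. The final print checks completeness, a further property shown in the same paper: the attributions sum to F(x) − F(x′).

```python
import numpy as np

def f(x):
    # Hypothetical toy "model": F(x) = x0^2 + 3*x0*x1
    return x[0] ** 2 + 3.0 * x[0] * x[1]

def grad_f(x):
    # Analytic gradient of the toy model
    return np.array([2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]])

def integrated_gradients(x, baseline, grad_fn, m=50, method="riemann_middle"):
    if method == "riemann_middle":
        # Midpoint Riemann sum: alphas at the centre of m equal sub-intervals
        alphas = (np.arange(m) + 0.5) / m
        weights = np.full(m, 1.0 / m)
    elif method == "gausslegendre":
        # Gauss-Legendre nodes/weights on [-1, 1], mapped to [0, 1]
        nodes, w = np.polynomial.legendre.leggauss(m)
        alphas = 0.5 * (nodes + 1.0)
        weights = 0.5 * w
    else:
        raise ValueError(f"unknown method: {method}")
    # Weighted average of gradients along the straight line baseline -> x
    avg_grad = sum(wt * grad_fn(baseline + a * (x - baseline))
                   for a, wt in zip(alphas, weights))
    # Scale by the input-baseline difference, per feature
    return (x - baseline) * avg_grad

x = np.array([1.0, 2.0])
baseline = np.zeros(2)
for method in ("riemann_middle", "gausslegendre"):
    attr = integrated_gradients(x, baseline, grad_f, m=20, method=method)
    # Completeness: attributions should sum to F(x) - F(baseline)
    print(method, attr, attr.sum(), f(x) - f(baseline))
```

Because this toy gradient is linear in α, both rules are exact here (attributions ≈ [4, 3], summing to 7 = F(x) − F(x′)); for a real network the quadrature error shrinks as m grows.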

Use Case

Now let’s see how the Integrated Gradients method can be applied in practice using the Captum package. We will fine-tune a question-answering BERT (Bidirectional Encoder Representations from Transformers) model on the SQuAD dataset using the HuggingFace transformers library; review the notebook for a detailed walkthrough.

Steps:

  • Load the tokenizer and the pre-trained BERT model, in this case bert-base-uncased.
  • Next, compute attributions with respect to the BertEmbeddings layer. To do so, define the baselines/references and numericalize both the baselines and the inputs.
def construct_whole_bert_embeddings(input_ids, ref_input_ids,
                                    token_type_ids=None, ref_token_type_ids=None,
                                    position_ids=None, ref_position_ids=None):
    # Assumes `model` is the loaded BERT QA model; embed the input and its baseline with the same layer
    input_embeddings = model.bert.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.bert.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    return input_embeddings, ref_input_embeddings
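Building the numericalized inputs and baselines can be sketched in pure Python. This assumes a scheme similar to Captum’s BERT tutorial: the baseline keeps [CLS] and [SEP] in place and replaces every other token with [PAD], and the baseline token-type and position ids are all zeros. The helper names, the hard-coded bert-base-uncased special-token ids, and the example ids are illustrative assumptions. In practice the content ids would come from `tokenizer.encode(question, add_special_tokens=False)`, the special ids from `tokenizer.cls_token_id`, `tokenizer.sep_token_id` and `tokenizer.pad_token_id`, and the lists would be wrapped in `torch.tensor([...])` before reaching the model.

```python
# bert-base-uncased special-token ids: [CLS]=101, [SEP]=102, [PAD]=0
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0

def construct_input_ref_pair(question_ids, text_ids):
    # Real input: [CLS] question [SEP] text [SEP]
    input_ids = [CLS_ID] + question_ids + [SEP_ID] + text_ids + [SEP_ID]
    # Baseline: same layout and length, content tokens replaced by [PAD]
    ref_input_ids = ([CLS_ID] + [PAD_ID] * len(question_ids) + [SEP_ID]
                     + [PAD_ID] * len(text_ids) + [SEP_ID])
    return input_ids, ref_input_ids

def construct_token_type_pair(input_ids):
    # Segment 0 up to and including the first [SEP], segment 1 afterwards
    sep_pos = input_ids.index(SEP_ID)
    token_type_ids = [0] * (sep_pos + 1) + [1] * (len(input_ids) - sep_pos - 1)
    ref_token_type_ids = [0] * len(input_ids)
    return token_type_ids, ref_token_type_ids

def construct_position_pair(input_ids):
    # Real positions 0..n-1; baseline positions all 0
    position_ids = list(range(len(input_ids)))
    ref_position_ids = [0] * len(input_ids)
    return position_ids, ref_position_ids

# Hypothetical token ids standing in for tokenizer output
question_ids = [1001, 1002, 1003]
text_ids = [2001, 2002]
input_ids, ref_input_ids = construct_input_ref_pair(question_ids, text_ids)
token_type_ids, ref_token_type_ids = construct_token_type_pair(input_ids)
position_ids, ref_position_ids = construct_position_pair(input_ids)
print(input_ids)       # [101, 1001, 1002, 1003, 102, 2001, 2002, 102]
print(ref_input_ids)   # [101, 0, 0, 0, 102, 0, 0, 102]
print(token_type_ids)  # [0, 0, 0, 0, 0, 1, 1, 1]
```

Keeping the special tokens in the baseline means IG attributes the prediction only to the content tokens, not to the sequence structure.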
  • Now let’s define a question and text pair as the input to our BERT model.

question = "What is important to us?"

text = "It is important to us to include, empower and support humans of all kinds."

  • Generate the corresponding baselines/references for the question and text pair.
  • Next, make predictions. One option is to use LayerIntegratedGradients and compute the attributions with respect to BertEmbeddings. LayerIntegratedGradients computes the integral of gradients with respect to a layer’s inputs or outputs along the straight-line path from the layer activations at the given baseline to the layer activations at the input.
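import torch
from captum.attr import LayerIntegratedGradients

# Assumes predict(), squad_pos_forward_func(), model, all_tokens and the input tensors are already defined
start_scores, end_scores = predict(input_ids,
                                   token_type_ids=token_type_ids,
                                   position_ids=position_ids,
                                   attention_mask=attention_mask)
print('Question: ', question)
# Answer span: from the highest-scoring start token to the highest-scoring end token
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1]))
# Attribute to the embedding layer; lig.attribute(inputs=input_ids, baselines=ref_input_ids, ...) computes the attributions
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)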

Output:

Question: What is important to us?

Predicted Answer: to include , em ##power and support humans of all kinds
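Layer IG produces one attribution per hidden dimension for every token, while the visualization expects a single score per token (`attributions_start_sum`). The NumPy toy below sketches both steps under stated assumptions: a random embedding table and a hypothetical tanh "head" stand in for BERT, the hidden size is 16 rather than 768, and there is no batch dimension. The per-token summary (sum over the embedding dimension, then L2-normalize) is similar to the helper in Captum’s BERT tutorial, but `summarize_attributions` here is a hypothetical re-implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden = 10, 16
E = rng.normal(size=(vocab_size, hidden))  # toy embedding table (hypothetical)
w = rng.normal(size=hidden)                # toy head weights (hypothetical)

def head(h):
    # Toy differentiable score on top of embeddings h: (seq_len, hidden)
    return np.tanh(h @ w).sum()

def head_grad(h):
    # d/dh of sum_t tanh(h_t . w) is (1 - tanh^2(h_t . w)) * w for each token t
    t = np.tanh(h @ w)
    return (1.0 - t ** 2)[:, None] * w[None, :]

def layer_ig(h_input, h_baseline, m=64):
    # Gauss-Legendre integration over alpha in [0, 1], from baseline activations to input activations
    nodes, weights = np.polynomial.legendre.leggauss(m)
    alphas, weights = 0.5 * (nodes + 1.0), 0.5 * weights
    avg_grad = sum(wt * head_grad(h_baseline + a * (h_input - h_baseline))
                   for a, wt in zip(alphas, weights))
    return (h_input - h_baseline) * avg_grad  # (seq_len, hidden)

def summarize_attributions(attr):
    # Collapse the embedding dimension, then L2-normalize across tokens
    per_token = attr.sum(axis=-1)
    return per_token / np.linalg.norm(per_token)

input_ids = [1, 5, 7, 2]  # toy ids; 1 and 2 play the roles of [CLS] and [SEP]
ref_ids = [1, 0, 0, 2]    # baseline with content tokens replaced by id 0
h_in, h_ref = E[input_ids], E[ref_ids]
attr = layer_ig(h_in, h_ref)
print(attr.sum(), head(h_in) - head(h_ref))  # completeness check
print(summarize_attributions(attr))          # one score per token
```

The gap between the two printed totals in the first line is what Captum reports as the convergence delta (`delta_start` in the visualization code); here it is close to zero, and the special-token positions receive zero attribution because their input and baseline embeddings are identical.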

  • Visualize the attributions for each word token in the input sequence using a helper function.
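import torch
from captum.attr import visualization as viz

# store the start-position sample as a record for visualization purposes
start_position_vis = viz.VisualizationDataRecord(
    attributions_start_sum,                            # word_attributions
    torch.max(torch.softmax(start_scores[0], dim=0)),  # pred_prob
    torch.argmax(start_scores),                        # pred_class
    torch.argmax(start_scores),                        # true_class
    str(ground_truth_start_ind),                       # attr_class
    attributions_start_sum.sum(),                      # attr_score
    all_tokens,                                        # raw_input_ids
    delta_start)                                       # convergence_score
# end_position_vis is built the same way from the end-position scores and attributions
print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])
print('\033[1m', 'Visualizations For End Position', '\033[0m')
viz.visualize_text([end_position_vis])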

From the results above, we can tell that for predicting the start position, the model focuses more on the question side, specifically on the tokens ‘what’ and ‘important’. It also places slight focus on the token sequence ‘to us’ on the text side.

In contrast, for predicting the end position, the model focuses more on the text side and assigns relatively high attribution to ‘kinds’, the token at the predicted end position.

Conclusion

This blog describes how explainable AI techniques like Integrated Gradients can be used to make a deep learning NLP model interpretable by highlighting positive and negative word influences on the outcome of the model.

References

https://arxiv.org/abs/1703.01365

https://captum.ai/

https://arxiv.org/pdf/1711.06104.pdf
