BERT Explainability

Shesh Narayan Gupta
Published in AIGuys
11 min read · Dec 22, 2021

Photo by Clarisse Croset on Unsplash

In this post, we are going to explore a few methods for explaining BERT, and why doing so may be worth the time.

What is BERT?

BERT is an open-source machine learning framework for natural language processing (NLP). It is designed to help computers understand the meaning of ambiguous language in text by using the surrounding text to establish context.

BERT, which stands for Bidirectional Encoder Representations from Transformers, is based on Transformers, a deep learning model in which every output element is connected to every input element, and the weightings between them are dynamically calculated based upon their connection (in NLP, this process is called attention).
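To make the "dynamically calculated weightings" concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention, the building block defined in the original Transformer paper. It is illustrative only: real BERT layers add learned query/key/value projections, multiple heads and masking, and the function name and toy sizes below are assumptions.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    # Similarity of every query (output position) with every key (input position)
    scores = Q @ K.T / np.sqrt(d_k)
    # Row-wise softmax turns scores into weights that sum to 1 for each output position
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Every output element is a weighted combination of every input element
    return weights @ V, weights

rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 8))  # 4 toy tokens with 8-dimensional embeddings
# In self-attention Q, K and V all come from the same tokens; the learned
# projections are skipped here and the embeddings are used directly
output, weights = scaled_dot_product_attention(tokens, tokens, tokens)
print(weights.round(2))
```

Each row of weights shows how much one output position draws on each input position; because the weights are computed from the inputs themselves, they change with every sentence.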

Motivation for BERT explainability: Why?

Let me explain the problem in layman's terms before delving into details and throwing a lot of technical and machine learning jargon at you. Put simply: if our dataset has text columns along with numerical columns, and we want to understand how the content of a text column contributes to our predictions (which words, bigrams and trigrams play an important role), how can we do it?

The most common approach to a text column in the pre-processing stage is to drop it, convert it into dummy variables, or apply some sort of mathematical transformation (because our models love numbers!). This can be useful, but most of the time it throws away the valuable information carried by the language and the context of the words. For example, in reviews of a movie, "special effects", "Brad Pitt" and "Angelina Jolie" could be very important words and bigrams. So instead of applying crude transformations or dropping the text columns completely, we could do something more useful and productive.

If our model could also understand the context of the text columns with respect to what we are trying to predict, it could do wonders and give more accurate results. Models like BERT can capture the context of our text columns, and investigating which aspects of language BERT is learning helps verify the robustness of our model. But a small issue remains: even after using a complex model like BERT, how do we know which words in our text columns the model deems important?

Many AutoML tools do come with built-in feature selection and feature importance. But the problem persists: most of the time, these AutoML libraries, software and packages fail to explain features that come from the text columns of the input dataset (if there are any), i.e. which words made such a feature important. If these features are deemed important by the feature selection pipeline, the AutoML software simply shows them to the end user in the list of important features under names like "<column_name>_<some_number>_<feature number>".

Now that AutoML tools (such as TPOT, H2O.ai and HyperOpt) are booming, model interpretability has become the need of the hour.

From this point onwards, I'll explain the rest of the article using the IMDB dataset as the input dataset. It has two text columns, "Reviews" and "Sentiments", and can be downloaded from https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews

We will consider "Reviews" as our input feature, which is simply a text column.

With that said, let's dive into the main idea of this article. As a first step, look at the image below, which illustrates what we discussed in the previous few paragraphs.

Positive and negative feature impacts of IMDB dataset

A little background on what the image above shows and where it comes from: it is a SHAP (SHapley Additive exPlanations) explanation of feature importance on the IMDB dataset. The IMDB dataset was taken as input, its columns were cleaned and pre-processed, and the text was then fed into BERT (the bert-base-uncased pre-trained model) to capture the language context.

A note for readers new to BERT: for every text input you provide, BERT (in its base configuration) produces a 768-number vector for each token, and a single 768-number vector is commonly used to represent the whole text. We call these numbers output features. In other words, for each row of your dataset, each text column gets its own 768 outputs (think of it as a conversion that keeps the language context intact). A dataset with two text columns would therefore be converted into 768 × 2 numerical columns.

In our case, however, only "Reviews" goes through BERT. For any model to make predictions it needs an input and an output, so we made "Sentiments" the output and "Reviews" the input (which, after passing through BERT, is 768 numerical values per row). Sentiments were not passed to BERT because they are the target we are going to predict. Once these features were created, they were fed into H2O.ai AutoML to calculate feature importance; for H2O they are just 768 numerical input features. So the image above shows the feature importance of H2O's top model (with the single text column converted into 768 numerical columns), explained by SHAP. I hope everything is clear up to this point.
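As a rough illustration of turning a text column into 768 numeric columns, here is a sketch using the Hugging Face transformers library. It is a sketch under assumptions, not the author's pipeline: the article does not say how one vector per text was obtained, so this takes the final hidden state of the [CLS] token (mean pooling would be an equally plausible choice), and the function name bert_feature_columns and the column names Reviews_0, Reviews_1, ... are hypothetical.

```python
import pandas as pd
import torch

def bert_feature_columns(texts, tokenizer, model, prefix="Reviews",
                         batch_size=16, max_length=512):
    """Turn each text into one row of hidden_size numbers (768 for bert-base-uncased)."""
    texts = list(texts)  # accept a list or a pandas Series
    model.eval()
    batches = []
    with torch.no_grad():  # inference only, so no gradients are needed
        for start in range(0, len(texts), batch_size):
            enc = tokenizer(texts[start:start + batch_size], padding=True,
                            truncation=True, max_length=max_length,
                            return_tensors="pt")
            hidden = model(**enc).last_hidden_state  # (batch, seq_len, hidden_size)
            # Assumption: the [CLS] token's vector (position 0) represents the text
            batches.append(hidden[:, 0, :])
    features = torch.cat(batches).numpy()
    columns = [f"{prefix}_{i}" for i in range(features.shape[1])]
    return pd.DataFrame(features, columns=columns)
```

With tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") and model = BertModel.from_pretrained("bert-base-uncased"), calling bert_feature_columns(df["Reviews"], tokenizer, model) would return one row per review and 768 columns, which can then be joined with the "Sentiments" target before handing the table to an AutoML tool. The sketch runs on CPU; a GPU version would need to move both the model and the encoded batch to the same device.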

Assuming the end user is a non-technical person, it's natural for them not to understand the meaning of these derived features (remember our problem?). BERT explainability helps end users intuitively understand what these derived features mean.

BERT explainability is not limited to the problem above, and there could be many more reasons to explore it, but I feel the problem described here is one of the most worthwhile.

Where to start from?

If we look closely at how BERT derives these features from text, we find that they are computed by a model pre-trained on billions of words of text. Essentially, these features are high-dimensional vectors with no direct meaning in human language. To help the user make sense of these vectors, we can use the text that produced them to explain them. Simply put, we can find the words in our input text that are most influential in computing each of these output features. Since BERT is based on the attention mechanism, a natural first idea is to look at its attention weights; more generally, we can compute attribution scores for the tokens of each input sentence.

What approaches could we take?

a. Document term approach

We can use a bag-of-words approach to find the words that are most correlated with a given BERT output feature. First, we concatenate all the feature vectors to be explained with the text columns (if the dataset is too large, we take a sample of it; otherwise we use the entire dataset). Then we create a document term matrix from the sampled dataset. A document term matrix is a 2D matrix with all the words in our text corpus as columns; each row corresponds to a row of the sampled dataset and holds the number of times each word appears in that row. Below is an example of a document term matrix.

Document Term Matrix

We can compute the document term matrix while removing stop words, applying a minimum document frequency, and counting both unigrams and bigrams.

import pandas as pd
import torch
from sklearn.feature_extraction.text import CountVectorizer
from xgboost import XGBRegressor

def create_document_term_matrix(text_sub, stop_words="english"):
    # Count unigrams and bigrams that appear in at least 3 documents,
    # keeping at most the 10,000 most frequent terms
    vec = CountVectorizer(min_df=3, stop_words=stop_words,
                          ngram_range=(1, 2), max_features=10000)
    X = vec.fit_transform(text_sub)
    # get_feature_names() was removed in scikit-learn 1.2; use get_feature_names_out()
    df = pd.DataFrame(X.toarray(), columns=vec.get_feature_names_out())
    return df

Below is the document term matrix the code produces for the first 4 rows of the IMDB dataset.

Document Term Matrix for first 4 rows “Review” column of IMDB dataset

NOTE: The output ignores all stop words, if you have provided any.

Using this matrix as the input features X and the BERT feature to be explained as the target vector Y, we can train an XGBRegressor (or any other model that reports per-feature importances). This gives us the importance of each word in the vocabulary for that particular target vector Y, which is one of BERT's output features.

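def compute_important_words(fv, top_k=20):
    # `document_term_matrix` is the global DataFrame built once by
    # create_document_term_matrix(); `fv` holds one BERT output feature (one value per row)
    if torch.cuda.is_available():
        # XGBoost >= 2.0 syntax; older releases use XGBRegressor(tree_method="gpu_hist")
        model = XGBRegressor(tree_method="hist", device="cuda")
    else:
        model = XGBRegressor()
    model.fit(document_term_matrix, fv)
    xgb_importance = model.feature_importances_
    # Indices of the top_k most important terms, highest importance first
    imp_word_indices = xgb_importance.argsort()[-top_k:][::-1]
    return [(document_term_matrix.columns[ind], str(round(float(xgb_importance[ind]), 4)))
            for ind in imp_word_indices]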
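The function above relies on the regressor's feature_importances_. As a self-contained toy check of why this can work, the sketch below builds a synthetic document term matrix in which one word's count drives a fake BERT feature, then confirms that the regressor ranks that word first. All data here is made up, and scikit-learn's GradientBoostingRegressor is swapped in for XGBRegressor only to keep the dependencies light; both expose feature_importances_.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(42)
words = ["brad pitt", "special effects", "boring", "plot", "music"]
# Synthetic document term matrix: word counts for 500 hypothetical reviews
dtm = pd.DataFrame(rng.poisson(1.0, size=(500, len(words))), columns=words)
# Synthetic "BERT feature" driven mostly by how often "boring" appears
fv = -0.8 * dtm["boring"] + 0.1 * rng.normal(size=500)

model = GradientBoostingRegressor(random_state=0).fit(dtm, fv)
importance = pd.Series(model.feature_importances_, index=words)
print(importance.sort_values(ascending=False).head(3))
```

One caveat: tree-based importances are unsigned. They say how much a word matters for the feature, not whether it pushes the feature up or down (here "boring" lowers it).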

Final code output

Output from the XGBRegressor in the actual code

Word importance output from XGBRegressor

Sorting the vocabulary by importance, we can take the top k words to explain that BERT feature. We then repeat the training of the XGBRegressor with a different target vector Y for each BERT feature.

word_groups = compute_important_words(df_feature_vector)

NOTE: The document term matrix is created only once, while the XGBRegressor is trained again for each BERT feature. Also note that words that are repeated more often, i.e. have high counts in the document term matrix, tend to be given higher importance by XGBRegressor. Although a word that appears more often is not always more important, this approach still does a reasonable job of surfacing important words.
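Repeating the fit for every BERT feature while building the document term matrix only once could be sketched as below. This is a self-contained illustration rather than the article's code: the function name explain_bert_features, the synthetic data and the column names Reviews_0 and Reviews_1 are hypothetical, and GradientBoostingRegressor again stands in for XGBRegressor.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor

def explain_bert_features(dtm, bert_features, top_k=5):
    """Return {feature column: [(word, importance), ...]} for each BERT feature."""
    explanations = {}
    for col in bert_features.columns:  # one regression per BERT output feature
        model = GradientBoostingRegressor(n_estimators=50, random_state=0)
        model.fit(dtm, bert_features[col])  # the same DTM is reused every time
        importance = pd.Series(model.feature_importances_, index=dtm.columns)
        top = importance.sort_values(ascending=False).head(top_k)
        explanations[col] = [(word, round(float(score), 4)) for word, score in top.items()]
    return explanations

rng = np.random.default_rng(0)
dtm = pd.DataFrame(rng.poisson(1.0, size=(300, 4)),
                   columns=["great", "awful", "plot", "actor"])
# Two synthetic BERT feature columns, each driven by a different word
bert_features = pd.DataFrame({
    "Reviews_0": dtm["great"] + 0.1 * rng.normal(size=300),
    "Reviews_1": -dtm["awful"] + 0.1 * rng.normal(size=300),
})
explanations = explain_bert_features(dtm, bert_features, top_k=2)
print(explanations)
```

With 768 features this means 768 model fits, so on a large sample it may be worth restricting the loop to the features the AutoML model actually found important.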

b. Word Importance approach

An attribution method scores the input data based on the predictions the model makes, i.e. it attributes the predictions to the input signals or features, giving each feature a score. Integrated Gradients is one such method. In rough terms, it is (feature × gradient); more precisely, it is the difference between the input and a baseline input, multiplied by the gradient averaged along a straight path from the baseline to the input. During training, gradients tell the network how much to increase or decrease each weight during backpropagation; here, instead, we take the gradient of the output with respect to each input feature, which tells us how sensitive the output is to that feature and therefore gives us a clue about how important it is. Integrated Gradients allows us to attribute a selected output feature of the BERT model to its inputs. Word importances can be generated for each of the output features like so:

Word Importance

This lets end users understand which words led to the BERT output feature they're looking at.

NOTE: Green means that the token contributes positively to the output feature, red means it contributes negatively. Each token gets a real-valued score, which can be positive or negative.
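For reference, the Integrated Gradients attribution for input feature i, with input x, baseline x′ and model F, is IG_i(x) = (x_i − x′_i) × ∫₀¹ ∂F(x′ + α(x − x′))/∂x_i dα. The NumPy sketch below approximates that integral with a midpoint Riemann sum on a toy function whose gradient can be written by hand; toy_model, its weights and the all-zero baseline are illustrative assumptions. A useful sanity check is completeness: the attributions should add up to F(x) − F(x′).

```python
import numpy as np

w = np.array([2.0, -1.0, 0.5])  # weights of a toy differentiable "model"

def toy_model(x):
    # F(x) = tanh(w . x), a small nonlinear stand-in for a network output
    return np.tanh(w @ x)

def toy_grad(x):
    # dF/dx = (1 - tanh(w . x)^2) * w
    return (1.0 - np.tanh(w @ x) ** 2) * w

def integrated_gradients(x, baseline, grad_fn, steps=200):
    # Midpoint Riemann sum of the gradient along the straight path baseline -> x
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.array([grad_fn(baseline + a * (x - baseline)) for a in alphas])
    # (input - baseline) times the average gradient along the path
    return (x - baseline) * grads.mean(axis=0)

x = np.array([1.0, 0.5, -2.0])
baseline = np.zeros_like(x)  # assumption: an all-zero baseline
attributions = integrated_gradients(x, baseline, toy_grad)
print(attributions, attributions.sum(), toy_model(x) - toy_model(baseline))
```

The sign of each attribution plays the same role as the green/red colouring: here the first feature pushes the output up and the third pushes it down.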

The approach above can be implemented with ease using Captum's Integrated Gradients. The first step is to fine-tune a BERT model on the desired dataset and then load the saved model and tokenizer:

import torch

from transformers import BertTokenizer, BertForQuestionAnswering

from captum.attr import visualization as viz  # for colour-coded token visualisations
from captum.attr import LayerIntegratedGradients

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# replace <PATH-TO-SAVED-MODEL> with the real path of the saved model
model_path = '<PATH-TO-SAVED-MODEL>'

# load model; this code follows Captum's BERT question-answering (SQuAD) tutorial,
# so the saved model must be a BertForQuestionAnswering checkpoint
model = BertForQuestionAnswering.from_pretrained(model_path)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = BertTokenizer.from_pretrained(model_path)
ref_token_id = tokenizer.pad_token_id  # token used to build the reference (baseline) sequence
sep_token_id = tokenizer.sep_token_id  # separator between question and text; also appended to the end of the text
cls_token_id = tokenizer.cls_token_id  # token prepended to the concatenated question-text sequence

Next, create a helper function that performs a forward pass of the model and makes predictions:

def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    # Forward pass; a QA model returns logits for the answer's start and end positions
    output = model(inputs, token_type_ids=token_type_ids,
                   position_ids=position_ids, attention_mask=attention_mask)
    return output.start_logits, output.end_logits

Once the helper function is defined, we create a custom forward function that lets us select the start or end position of our prediction via the position input argument.

def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    pred = predict(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    pred = pred[position]  # 0 -> start logits, 1 -> end logits
    # Captum needs one scalar per example: take the highest logit in the sequence
    return pred.max(1).values

Then we can make predictions using the input ids, token type ids, position ids and a default attention mask, and compute attributions with respect to the BertEmbeddings layer:

def compute_attributions(input_ids, ref_input_ids):
    # All tokens belong to segment 0, positions run 0..seq_len-1, and nothing is masked
    seq_len = input_ids.size(1)
    token_type_ids = torch.zeros_like(input_ids)
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
    attention_mask = torch.ones_like(input_ids)

    # Attribute the model output to the output of the embeddings layer
    lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

    # The last forward argument selects the position: 0 = start logits, 1 = end logits
    attributions_start, delta_start = lig.attribute(
        inputs=input_ids, baselines=ref_input_ids,
        additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
        return_convergence_delta=True)
    attributions_end, delta_end = lig.attribute(
        inputs=input_ids, baselines=ref_input_ids,
        additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
        return_convergence_delta=True)

    # Collapse the embedding dimension into one score per token
    attributions_start_sum = summarize_attributions(attributions_start)
    attributions_end_sum = summarize_attributions(attributions_end)
    return attributions_start_sum, attributions_end_sum, delta_start, delta_end

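(We also need references, or baselines, for the inputs. For word tokens, the helper below keeps [CLS] and [SEP] and replaces every text token with the padding token; token type ids can simply be all zeros and position ids 0 to n-1.)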

def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # construct input token ids: [CLS] [SEP] text [SEP] (an empty "question" segment)
    input_ids = [cls_token_id] + [sep_token_id] + text_ids + [sep_token_id]

    # construct reference token ids: same layout, text tokens replaced by the pad token
    ref_input_ids = [cls_token_id] + [sep_token_id] + \
        [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device)

Now that we have the attributions, we can summarize them for each word token in the sequence and pick out the important text to display to our end users. Voilà!

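def summarize_attributions(attributions):
    # Sum over the embedding dimension to get one score per token
    attributions = attributions.sum(dim=-1).squeeze(0)
    # L2-normalise so that scores are comparable across examples
    attributions = attributions / torch.norm(attributions)
    return attributions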
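The scores come out per WordPiece token, while end users want whole words. One way to post-process them (an assumption; the article does not describe this step) is to drop the special tokens, glue "##" continuation pieces onto the preceding word while summing their scores, and then list the most positive and most negative words. In practice the tokens would come from tokenizer.convert_ids_to_tokens(input_ids[0]) and the scores from summarize_attributions(...).tolist(); the function names word_scores and top_words and the toy values below are hypothetical.

```python
def word_scores(tokens, scores, special_tokens=("[CLS]", "[SEP]", "[PAD]")):
    """Merge WordPiece sub-tokens ("##...") into words and add up their scores."""
    words = []
    for token, score in zip(tokens, scores):
        if token in special_tokens:
            continue  # special tokens carry no user-facing meaning
        if token.startswith("##") and words:
            # Continuation piece: glue it onto the previous word and accumulate its score
            prev_word, prev_score = words[-1]
            words[-1] = (prev_word + token[2:], prev_score + score)
        else:
            words.append((token, score))
    return words

def top_words(words, k=3):
    """Return the k most positive and the k most negative (word, score) pairs."""
    ranked = sorted(words, key=lambda pair: pair[1], reverse=True)
    positive = [pair for pair in ranked if pair[1] > 0][:k]
    negative = [pair for pair in reversed(ranked) if pair[1] < 0][:k]
    return positive, negative

tokens = ["[CLS]", "[SEP]", "brad", "pitt", "was", "un", "##watch", "##able", "[SEP]"]
scores = [0.0, 0.0, 0.30, 0.25, 0.01, -0.20, -0.15, -0.10, 0.0]
words = word_scores(tokens, scores)
positive, negative = top_words(words, k=2)
print(positive)
print(negative)
```

Summing sub-token scores treats a word's importance as the total of its pieces; averaging would be an equally defensible choice.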

Final code output

Below is the final output of the IG approach: the important words that BERT may have deemed significant while calculating one of its output features (#551 in our example), which H2O found important overall for our dataset.

Important bigrams for a positive and negative impact from IG approach

As we can see, this approach works far better than the document term matrix approach: the important words it outputs make much more sense.

NOTE: Because this approach uses back-propagation to calculate the important words, it is computationally intensive. It took almost 30 minutes to explain the important words for only 500 rows, even on a powerful GPU. Imagine how long it would take for tens of thousands of rows in your dataset. If we can find a way to make it run faster, this could be THE approach.

Conclusion

In this article, we have seen how to use the document term and Integrated Gradients approaches to explain BERT, and how they address the problem of AutoML software failing to fully explain feature importance when the features originate from text in the input dataset. Although both approaches appear to work well, they can be improved further so that the important words make more sense to end users when they see which words actually had a positive or negative impact on their predictions.

