Using SHAP to Explain Machine Learning Models
Do you understand how your machine learning model works? Despite the ever-increasing usage of machine learning (ML) and deep learning (DL) techniques, the majority of companies say they can’t explain the decisions of their ML algorithms [1]. This is, at least in part, due to the increasing complexity of both the data and models used. It’s not easy to find a nice, stable aggregation over 100 decision trees in a random forest to say which features were most important or how the model came to the conclusion it did. This problem grows even more complex in application domains such as computer vision (CV) or natural language processing (NLP), where we no longer have the same high-level, understandable features to help us understand the model’s failures.
Explainable AI is what many consider the ‘ideal’ solution to this problem. Although existing research is not yet far enough along to give us a detailed, human-readable description of a model’s reasoning during inference, some of the explainable AI packages are a great way to get a little more insight into what’s going on. In this article, I’ll focus on SHAP [2] (SHapley Additive exPlanations) and provide 3 examples of how its methods can be used for all sorts of interesting tasks. These examples serve as a simple introduction to how you can begin to explain your ML model’s output.
The structure of this article is as follows. I’ll start by giving a very brief overview of SHAP. Next, I’ll show how SHAP provides better feature importance estimates in a simple regression task with a random forest regressor. Then, I’ll show a simple example of how the SHAP GradientExplainer can be used to explain a deep learning model’s predictions on MNIST. Finally, I’ll end by demonstrating how we can use SHAP to analyze text data with transformers.
Note: All code used in this post is available on Github!
What is SHAP?
Although some algorithms come with native feature importance estimates, such as the coefficients of a linear or logistic regression or the Gini impurity-based importances in decision tree models, approaches like neural networks and support vector machines don’t offer such easily understandable insights. On top of that, even the estimates a linear regression does provide are often unstable under multicollinearity [3] (where one or more predictor variables can be accurately predicted from the others, which can bias the coefficients towards predictors that have no real influence on overall model performance; see the image below), and they exist only for the entire model, not for individual data points. For this reason, stable, model-agnostic, and interpretable approaches that analyze feature importances not only for the entire model but also for individual data points are greatly desirable. This is where SHAP comes in.
SHAP is a framework which can be used to interpret model predictions. Outlined by Lundberg and Lee in NIPS 2017 ([4], [5]), SHAP can be used to assign feature importances for every prediction. This can be immensely useful for identifying what features a model leveraged to make an individual prediction. These values are relatively stable and not susceptible to multicollinearity. See [6] for a great explanation of exactly how they achieved this.
SHAP contains several different approaches for model explainability: specific solutions for feature importance estimates in tree-based models (TreeExplainer), gradient-based explanations for deep learning models (GradientExplainer and DeepExplainer), and model-agnostic explanations through the KernelExplainer. It’s a great package to add to your toolset.
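To make this concrete, here is a minimal orientation sketch showing the common pattern each explainer follows. The toy data and model are placeholders of my own, not part of this article’s examples:

import numpy as np
import shap
from sklearn.ensemble import RandomForestRegressor

# Throwaway toy data and model, purely to illustrate the explainer pattern
X_toy = np.random.rand(200, 4)
y_toy = X_toy @ np.array([1.0, 0.5, 0.0, -2.0])
toy_model = RandomForestRegressor(n_estimators=10).fit(X_toy, y_toy)

tree_explainer = shap.TreeExplainer(toy_model)                          # specialised for tree ensembles
kernel_explainer = shap.KernelExplainer(toy_model.predict, X_toy[:50])  # model-agnostic, but slower

# Every explainer produces one SHAP value per sample per feature
print(tree_explainer.shap_values(X_toy[:5]).shape)  # (5, 4)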
Example #1: A Simple Regression Task
To start, we look at a simple regression task with a random forest regressor. In this scenario, we can use the SHAP TreeExplainer to get feature importance estimates. The purpose of this example is also to show how these estimates deviate from the initial estimate provided by the model.
Data: We examine the California Housing dataset, which can be obtained using the scikit-learn library.
from sklearn.datasets import fetch_california_housing
import pandas as pd

dataset = fetch_california_housing()
X = dataset.data
y = dataset.target
columns = dataset.feature_names
print(columns)

df = pd.DataFrame(X, columns=columns)
The task for this dataset is to use the following 8 columns:
- MedInc: median income
- HouseAge: median house age in block
- AveRooms: average number of rooms
- AveBedrms: average number of bedrooms
- Population: block population
- AveOccup: average house occupancy
- Latitude: house block latitude
- Longitude: house block longitude
to predict the median house value, which is given in units of $100,000 (i.e., a house worth $100,000 is given the target value 1).
We train a simple random forest regressor with 100 estimators and achieve an R-Squared of 0.80 and a root mean squared error of 0.51. In the figure below, we see that the predictions are definitely in line with the target values, although there is quite a bit of noise.
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

trainX, testX, trainy, testy = train_test_split(df, y, test_size=0.25)

rf = RandomForestRegressor(n_estimators=100)
rf.fit(trainX, trainy)

print("R2-Score", rf.score(testX, testy))  # 0.80
print("RMSE", mean_squared_error(rf.predict(testX), testy, squared=False))  # 0.51
Now we want to identify which features were most important. Before we turn to SHAP, let’s try the native scikit-learn approaches first, starting with the random forest’s feature_importances_ attribute, which is computed during training.
Feature                        Importance
Average Number of Bedrooms     0.029
Population                     0.032
Average Number of Rooms        0.041
Median House Age               0.056
Longitude                      0.089
Latitude                       0.092
Average House Occupancy        0.138
Median Income                  0.523
It’s quite clear that median income is by far the most important attribute, with average house occupancy a distant second. However, as scikit-learn warns in its documentation, this impurity-based feature importance estimate is not always the most trustworthy. A suggested alternative is permutation importance.
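A minimal sketch of how that might look with scikit-learn’s permutation_importance, assuming the rf, testX, testy, and columns variables defined above:

from sklearn.inspection import permutation_importance

result = permutation_importance(rf, testX, testy, n_repeats=10, random_state=0)
for name, mean, std in sorted(zip(columns, result.importances_mean, result.importances_std),
                              key=lambda t: t[1]):
    print(f"{name:<30} {mean:.3f} +- {std:.3f}")

Running this gives the following output: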
Feature                        Importance
Average Number of Bedrooms     0.034 +- 0.001
Population                     0.036 +- 0.001
Average Number of Rooms        0.062 +- 0.001
Median House Age               0.124 +- 0.002
Average House Occupancy        0.308 +- 0.004
Longitude                      0.381 +- 0.003
Latitude                       0.497 +- 0.005
Median Income                  0.904 +- 0.009
This paints a somewhat different picture from our initial estimates: location (latitude and longitude) is now deemed more important than average house occupancy. The differences aren’t so large that our initial analysis misled us, but they are clearly there.
Permutation feature importance isn’t a perfect metric either, however, as it can be misleading when features are correlated [7]. If we look at the feature correlations in our dataset, we see that this may indeed be an issue:
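One quick way to inspect those correlations yourself (a sketch, assuming the df DataFrame built earlier; seaborn is an extra dependency not used elsewhere in this post):

import matplotlib.pyplot as plt
import seaborn as sns

# Pairwise feature correlations, visualised as a heatmap
corr = df.corr()
sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm")
plt.show()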
Another point to note is that we don’t have per-datapoint feature importances; we only know the approximate, average importance of each feature. This is where we can use SHAP. In just 3 lines, we can compute and plot feature importances using the TreeExplainer class.
import shap
import numpy as np

explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(trainX[:1000])  # limit amount of data for increased speed
shap.summary_plot(shap_values, trainX[:1000], feature_names=columns, plot_type="bar")
list(zip(columns, np.mean(np.abs(shap_values), axis=0) / np.sum(np.mean(np.abs(shap_values), axis=0))))

Feature                        Importance
Population                     0.017
Average Number of Bedrooms     0.019
Average Number of Rooms        0.032
Median House Age               0.060
Longitude                      0.148
Average House Occupancy        0.174
Latitude                       0.180
Median Income                  0.369
Again we see a minor change in the order of the variables, but the differences between variables are somewhat less extreme than in the previous results.
We can also look at a single datapoint and see how its prediction was built up using a decision plot. In this example, the model predicted a high house value mainly due to a large deviation in the median income variable.
i = 3
X_test = testX.to_numpy()
# The shap_values above were computed on trainX, so compute SHAP values for the test point we want to explain
shap_values_test = explainer.shap_values(X_test[i:i + 1])
shap.decision_plot(explainer.expected_value, shap_values_test[0], X_test[i], feature_names=columns)
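As a quick sanity check of the additive property described in [6], the base value plus a data point’s SHAP values should reconstruct the model’s prediction for that point. A small sketch using the names above:

# expected_value may be a scalar or a length-1 array depending on the SHAP version
base = np.ravel(explainer.expected_value)[0]
print(rf.predict(X_test[i:i + 1])[0], base + shap_values_test[0].sum())  # should match up to numerical precision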
In this section, we showed that SHAP gives us stable feature importance estimates and, unlike the default estimates, also lets us analyze single data points.
FYI: If you’re not running a tree-based model and would like this level of analysis, the KernelExplainer in SHAP is model agnostic!
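As a rough sketch of what that could look like on this same regression task (KernelExplainer is much slower, so it’s usually run on a small background sample and a handful of points):

# Model-agnostic explanation: only needs a prediction function and some background data
kernel_explainer = shap.KernelExplainer(rf.predict, shap.sample(trainX, 100))
kernel_shap_values = kernel_explainer.shap_values(testX.iloc[:10])
shap.summary_plot(kernel_shap_values, testX.iloc[:10], feature_names=columns, plot_type="bar")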
Example #2: Deep Learning with MNIST
It’s easy to run SHAP on data which has high-level, interpretable features. However, with more complex data, such as images, it becomes far less useful to say “Pixel in location (10, 12) has the highest feature importance on average”. Instead, we would like a model explainer to give us higher-level insight into what it is focusing on. Although it’s difficult to quantify what this should look like, in this example we show how SHAP can provide some qualitative insight into what the models focus on.
Data: We examine MNIST, a simple and common computer vision dataset consisting of single channel 28x28 pixel images of handwritten digits. The task is simple: predict the digit in the image.
To achieve this, we train a CNN-based deep learning model using PyTorch. The model structure and training setup are from [8].
# Model for MNIST Classification
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
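The training loop itself is standard and follows [8]; a condensed sketch (hyperparameters taken from that script, so they may drift between versions) might look like this:

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(datasets.MNIST("../data", train=True, download=True, transform=transform),
                          batch_size=64, shuffle=True)
test_loader = DataLoader(datasets.MNIST("../data", train=False, download=True, transform=transform),
                         batch_size=128)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

for epoch in range(15):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)  # the model outputs log-probabilities
        loss.backward()
        optimizer.step()
    scheduler.step()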
Training the model for 15 epochs gives us over 99% accuracy, as can be seen below:
Test set: Average loss: 0.0259, Accuracy: 9923/10000 (99%)
With this type of data, it’s not quite as useful to ask which features (pixels, in this case) are, on average, most important for classification. There are two available options to inspect our model: DeepExplainer and GradientExplainer. As I’m using PyTorch, I’ll focus solely on GradientExplainer.
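The GradientExplainer snippet below assumes a tensor of test images called images; grabbing a batch from the test loader sketched above would look something like this:

# Take one batch of test images to use as both background data and examples to explain
images, _ = next(iter(test_loader))
images = images.to(device)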
Here is what we see when running the GradientExplainer against the first convolutional layer of the model:
to_explain = images[[3, 12, 18, 22]]

e = shap.GradientExplainer((model, model.conv1), images)
shap_values, indexes = e.shap_values(to_explain, ranked_outputs=4, nsamples=200)

class_names = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
index_names = np.vectorize(lambda x: class_names[x])(indexes.cpu().numpy())

# Move the channel axis to the end so shap.image_plot can render the images
shap_values = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
to_explain = np.swapaxes(np.swapaxes(to_explain.cpu().numpy(), 1, -1), 1, 2)

shap.image_plot(shap_values, -to_explain, index_names)
This image shows the top 4 predictions for each handwritten digit on the far left. For each class, the blue represents what the model considers ‘negative’, i.e., what doesn’t fit the class it’s looking at. Take the 5 on the first row, for example. When we look at why it wasn’t classified as a 9, we clearly see that the horizontal sections of the image deviate from what we might expect to see in a 9 (look at the parts of the image that are the most negative/blue).
Although the results in this scenario are not quite as actionable as the interpretable feature importance analysis, we can still get a much stronger idea of how the model goes about its classification. On top of that, it looks pretty cool :D.
Example #3: SHAP with Transformer Networks
In our final example, we look at how we can analyze natural language with transformers and SHAP. It’s important to note that NLP data can be analyzed in a somewhat more interpretable way than images. In this example I’m only showing a special case with a transformer network, but in practical scenarios you can use other SHAP explainers or Lime [9] to obtain similar explanations.
Data: We examine the fetch_20newsgroups dataset, which can be downloaded from sklearn. In this scenario, we’re not taking the targets into account. Instead, we treat it as a sentiment analysis task, since pretrained transformer networks are available for this.
An example entry looks like:
993Apr27.004240.24401@csi.jpl.nasa.gov , by eldred@rrunner.jpl.nasa.gov (Dan Eldred):In article <1rh9b0INN2r4@snoopy.cis.ufl.edu ruck@beach.cis.ufl.edu (John Ruckstuhl) writes:I know this is a long shot, but does anyone know what solvent I shoulduse to clean duct-tape adhesive from carpet?Someone taped wires to the carpet, and now it is time to move out.I don't know for sure that this will work, but you might try MEK (methylethyl keytone?). It worked getting the stickum left o
A bit messy, but it’ll be interesting to see what the model thinks about data like this.
As we already have a pretrained sentiment analysis model, we can immediately run it on the text to look at which parts of the input the model tags as positive or negative. In just a few lines we can run this analysis.
from sklearn.datasets import fetch_20newsgroups

data = fetch_20newsgroups(subset="all")
X, y = data.data, data.target
X = [x[256:512 + 256].replace(">", " ") for x in X]  # shorten the texts so the transformer has no input-length issues

import transformers
import shap

# load a transformers sentiment-analysis pipeline model
model = transformers.pipeline('sentiment-analysis', return_all_scores=True)

# explain the model on a batch of sample inputs
explainer = shap.Explainer(model)
shap_values = explainer(X[200:250])

i = 23
shap.plots.text(shap_values[i, :, "POSITIVE"])
Here are a few examples of what sort of output we can get:
The results here are not quite as insightful as those of the other approaches, but they let us look closely at the key sections of the input text that influence the overall prediction. At the very least, this gives us more insight than a standard prediction score alone.
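If you want a more aggregate view, the same SHAP values can also be averaged over the explained batch and plotted as a bar chart; a sketch following the patterns in the shap documentation:

# Which tokens push predictions towards POSITIVE, averaged over the explained batch
shap.plots.bar(shap_values[:, :, "POSITIVE"].mean(0))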
Conclusion
Explainable artificial intelligence is an underdeveloped area in ML/AI, but just because many models are black boxes does not mean we shouldn’t at least try to peek in! In this article, I showed a few examples of how you can use SHAP to gain further insight into your ML model’s predictions.
We find that in simple classification and regression tasks with high-level features, we can gain great insight from a SHAP feature importance analysis, especially when using tree-based methods. Although we may not be able to attain such high-quality insight in deep learning tasks, we can use the SHAP gradient and deep explainers to better understand why the model made a correct or incorrect prediction. These techniques can be easily applied in most ML applications and are immensely useful for model explainability and real-time prediction analysis.
Thanks for reading!
Author:
Nathan Bosch is the Head of Education at the KTH AI Society, MSc student in Machine Learning at the KTH Royal Institute of Technology, and R&D Intern at Ericsson. You can reach him on LinkedIn or by email at nathan@kthais.com
References
[2] https://github.com/slundberg/shap
[3] https://en.wikipedia.org/wiki/Multicollinearity
[4] https://proceedings.neurips.cc/paper/2017/hash/8a20a8621978632d76c43dfd28b67767-Abstract.html
[5] https://arxiv.org/abs/1705.07874
[6] https://christophm.github.io/interpretable-ml-book/shap.html
[7] https://christophm.github.io/interpretable-ml-book/feature-importance.html
[8] https://github.com/pytorch/examples/blob/master/mnist/main.py