SHAP Force Plots for Classification
How to functionize SHAP force plots for binary and multi-class classification
In this post I will walk through two functions: one for plotting SHAP force plots for binary classification problems, and the other for multi-class classification problems.
At this point you may be thinking “Alright, but there’s already a shap.force_plot() function, so what are we even doing here?” And yes, technically you are correct. BUT pretty much all the examples of SHAP force plots I have seen are for continuous or binary targets. You actually can produce force plots for multi-class targets, it just takes a little extra digging. My goal is to help you do that digging so you get an output that is extra interpretable. After all, that’s what SHAP is all about: making complex, “blackbox” models more interpretable.
I’m not going to spend much time talking about or explaining what SHAP (SHapley Additive exPlanations) is because there are plenty of awesome resources that already do that. If you want to learn more about Shapley values and SHAP, I highly recommend starting here. That resource hits the highlights and shows the various kinds of cool plots you can produce with the SHAP Python library, and its story about hammering a log is a strange but effective visual for understanding what Shapley values are trying to tell us.
To sum up very briefly, Shapley values give us a metric for evaluating the importance of a predictor relative to other predictors. Essentially, these values measure how the model’s output (its prediction) changes when the model knows vs. doesn’t know about that predictive feature. The sign of each value also indicates the direction (positive or negative) in which the feature pushed the prediction.
The SHAP library provides easy-to-use tools for calculating and visualizing these values. To get the library up and running, run pip install shap, then import it.
Once you’ve successfully imported SHAP, one of the visualizations you can produce is the force plot. Force plots, like the one shown below, allow you to see how features contributed to the model’s prediction for a specific observation. This is perfect for being able to explain to someone exactly how your model arrived at the prediction it did for a specific observation.
My examples for this post come from a project I worked on to predict whether or not a registered voter cast a vote during the 2020 general election. The binary target was yes they voted (Vote = 1), or no they didn’t (No Vote = 0). In the plot above, the bold 0.80 is the model’s score for this observation. Higher scores lead the model to predict 1 and lower scores lead the model to predict 0. The features that were important to making the prediction for this observation are shown in red and blue, with red representing features that pushed the model score higher, and blue representing features that pushed the score lower. Features that had more of an impact on the score are located closer to the dividing boundary between red and blue, and the size of that impact is represented by the size of the bar.
So this particular person was ultimately classified as Vote (1), because their score was pushed higher by all the factors shown in red (birth region info is missing from their registration record, they are a Baby Boomer, etc.).
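To make the arithmetic behind a force plot concrete, here is a small worked example. The numbers and feature names are made up for illustration (they are not the values behind the plot described above), and it assumes a classifier whose raw output is in log-odds, as with XGBoost's default binary objective. The key fact is additivity: the base value plus the sum of the SHAP values equals the model's raw output for that observation.

```python
import math

# Hypothetical numbers, chosen only for illustration.
base_value = 0.2  # average raw model output (log-odds) over the data
contributions = {
    "birth_region_missing": 0.9,   # red: pushes the score higher
    "generation_boomer": 0.6,      # red: pushes the score higher
    "party_unaffiliated": -0.3,    # blue: pushes the score lower
}

# Additivity: base value + sum of SHAP values = raw model output.
raw_score = base_value + sum(contributions.values())

# A logistic (sigmoid) link converts log-odds to a score between 0 and 1.
score = 1 / (1 + math.exp(-raw_score))

print(f"raw score: {raw_score:.2f}, score after logit link: {score:.2f}")
# raw score: 1.40, score after logit link: 0.80
```

With these made-up contributions the red features outweigh the blue one, so the score lands well above 0.5 and the observation would be classified as 1.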
However, this plot is the only output you get from using shap.force_plot(). It doesn’t tell you the predicted output of the model, nor does it tell you the ground truth label for this specific observation. Wouldn’t it be helpful to be able to see all that information output at once? That’s the purpose of our first function.
Force Plots for Binary Classification
Before we get to the function, let’s make sure everything is set up in a way that will make it easy to use. Preprocess and split your data so you can train your model. Make sure to store the feature names that correspond to the preprocessed data! You’ll need those to make the plots actually informative. If your preprocessing pipeline involves creating new columns (such as when one hot encoding categorical variables) you can obtain the new feature names from the .get_feature_names() method of the correct step in the pipeline (in scikit-learn 1.0 and later, the method is .get_feature_names_out()). For example, my code for setting this all up was as follows:
If you’ve got more than just categorical variables, you’ll need to make sure to grab those feature names from their steps and then get all your feature names into a single list in the order they are handled in your preprocessing pipeline.
Once you train and tune your model, assign the fitted classifier and the booster each to a variable. (I used XGBoost classifiers for this project; if you’re using something else, these functions should be easy to adapt for other tree-based algorithms.) We want to be able to call on these easily in our next steps and as arguments in our function. And while we’re storing useful things, let’s turn our preprocessed X_train_tf array into a DataFrame that uses our feature_names as column names.
Next, we need to actually obtain the SHAP values for our trained model. Since I used a tree-based classifier, I do this using SHAP’s TreeExplainer(). This explainer is then used to calculate the SHAP values as shown below.
Important note: I have seen some functions for producing visualizations that include the calculation of SHAP values as part of the function. I do not recommend doing this. Depending on the size of your dataset, the calculation can take a decent amount of time and there is no reason to do this each time you want to produce a plot for the same model.
Now we’re all set up to write and call our first function which will produce a more informative output that looks like this:
And here’s how you can write that:
You can adjust the information included in the f-strings to format the output however you want. Maybe you want to be super serious and get rid of that exclamation mark that celebrates when your model predicted correctly. Maybe you want to map the labels to report the actual class names rather than a number so it’s easier for you to keep track of things.
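For example, mapping the numeric labels to class names could look something like this. The helper name describe_prediction and the LABELS dictionary are hypothetical, and this version drops the exclamation mark.

```python
# Map encoded labels to readable class names.
LABELS = {0: "No Vote", 1: "Vote"}


def describe_prediction(true_label, pred, labels=LABELS):
    """Build the text report using class names instead of numbers."""
    verdict = "Correct" if pred == true_label else "Incorrect"
    return (
        f"Ground truth: {labels[true_label]}\n"
        f"Prediction: {labels[pred]}\n"
        f"{verdict}"
    )


print(describe_prediction(1, 1))
```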
Force Plots for Multi-Class Classification
Again, preprocess your data, store the feature names after preprocessing, and this time don’t forget to label encode your target. You’ll need to access those label-encoded classes later!
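A minimal sketch of the label-encoding step, assuming scikit-learn's LabelEncoder. The five-row target is toy data that uses the same three class names as the example.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Toy target with the three classes.
y_train = pd.Series(["Early", "No Vote", "Election Day", "Early", "No Vote"])

le = LabelEncoder()
y_train_enc = le.fit_transform(y_train)  # classes become 0, 1, 2

# The encoder stores the class names in sorted order; position = encoded label.
le.classes_
```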
The last line returns the following output for my multi-class model:
array(['Early', 'Election Day', 'No Vote'], dtype=object)
because the multi-class version of my model split people who cast a vote in the election into 2 categories based on when they chose to vote. So those get coded as 0, 1, and 2, respectively.
Next, train and tune your model, store your best estimator and the booster same as for the binary model, and make sure you’ve got your preprocessed X training set converted into a DataFrame with the appropriate feature names as the column names. Then use the booster to get your TreeExplainer and calculate the SHAP values.
If you try to use shap.force_plot() like we did in our binary function, it throws an error:
TypeError: list indices must be integers or slices, not tuple.
This is because, when you calculate SHAP values for a multi-class target, you get a list of n arrays containing SHAP values. Here n is the total number of classes that make up your target. So since my model has 3 classes of the target variable, I should find that I have 3 arrays in my list. You can confirm these things with the following code:
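The checks can be sketched as below. To keep the example small, shap_values here is a stand-in list of zero arrays shaped like the multi-class output described above, not real output from an explainer. One caveat, as far as I know: newer SHAP releases return a single 3-D array (observations, features, classes) instead of a list, so the hypothetical helper split_by_class accepts either form.

```python
import numpy as np


def split_by_class(shap_values):
    """Return a list with one (observations, features) array per class."""
    if isinstance(shap_values, list):  # list of per-class arrays
        return [np.asarray(a) for a in shap_values]
    arr = np.asarray(shap_values)
    if arr.ndim == 3:  # single array: (observations, features, classes)
        return [arr[:, :, k] for k in range(arr.shape[2])]
    raise ValueError("expected multi-class SHAP values")


# Stand-in for explainer.shap_values(X): 3 classes, 5 observations, 4 features.
shap_values = [np.zeros((5, 4)) for _ in range(3)]

print(type(shap_values))      # <class 'list'>
print(len(shap_values))       # 3, one array per class
print(shap_values[0].shape)   # (5, 4)

# The indexing that works for a binary target fails on a list.
try:
    shap_values[0, :]
except TypeError as err:
    print(err)                # list indices must be integers or slices, not tuple
```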
We can think of each array of SHAP values in a similar way as we did for the single array of SHAP values we got when we had a binary target. In that binary case, the SHAP values were pushing the model towards a classification of Vote (1) or No Vote (0). Now with our 3 classes, each array is assessing each class as its own binary target. So Early Vote (1) vs. Not Early Vote (0), Election Day Vote (1) vs. Not Election Day Vote (0), and No Vote (1) vs. Not No Vote (0).
Is your head spinning a little? Stick with me. The visualizations will help.
So here’s the output we’re aiming for from our function:
Three force plots, one for each target class, and the model predicted the class with the highest score — Election Day (granted it was barely the highest). Here’s how to write a function that will get you this type of output:
Once you’ve taken a minute to examine what each piece of the function is doing (I think I provided enough context and annotation, but feel free to drop me a comment), you might be thinking “This is great and all, but you’ve hard-coded your target classes. Could be better.”
My response is “That’s fair. Now that you’ve got a clear idea of how this function works for our specific example, here’s how you can rewrite it to apply to your specific problem whether you’ve got 3 or 333 target classes.”
You just have to supply an additional argument (the label encoder’s classes_ attribute) so the function can loop through those to create your label dictionary. I also changed the default from displaying 'all' to only displaying 'both' in case you do have 333 classes in your target variable.
The SHAP library provides useful tools for assessing the feature importances of certain “blackbox” algorithms that have a reputation for being less interpretable. It also provides ways to visualize how the features impact your model’s predictions.
The SHAP force plot shows you exactly which features had the most influence on the model’s prediction for a single observation. This is interesting in and of itself, but particularly useful if you find yourself having to explain to your boss something like why your model determined you should deny a specific person’s loan application.
You can use these plots as part of a larger function to produce a more informative output for both binary and multi-class classification problems. Specifically, it’s useful to be able to see:
- the ground truth label for the observation,
- the model’s prediction for the same observation,
- a statement that does the comparing for you and tells you if your model predicted this observation correctly, and
- the force plot(s) explaining the model’s output for that observation.
I hope you found this helpful and are able to apply something you’ve learned to your own work!