SHAP Part 3: Tree SHAP

Rakesh Sukumar · Published in Analytics Vidhya · 13 min read · Mar 30, 2020

Tree SHAP is an algorithm to compute exact SHAP values for decision tree based models. SHAP (SHapley Additive exPlanation) is a game-theoretic approach to explain the output of any machine learning model. The goal of SHAP is to explain the prediction for any instance xᵢ as a sum of contributions from its individual feature values. Refer to part 1 of this series for a brief theoretical introduction to SHAP.

As explained in the first article, SHAP values are obtained from the equation:

\phi_i = \sum_{S \subseteq M \setminus \{i\}} \frac{|S|!\,(|M|-|S|-1)!}{|M|!} \left[ f_{S \cup \{i\}}(x_{S \cup \{i\}}) - f_S(x_S) \right]

Notation: |M| is the total number of features. S represents any subset of features that doesn't include the i-th feature and |S| is the size of that subset. fₛ() represents the prediction function of the model for the subset S.

Let's work out the SHAP values for a simple regression tree to get a better understanding of the algorithm. Consider a hypothetical dataset of 10 samples with three numeric independent variables (namely x, y, z) and a target variable t. We get the tree structure below on fitting a regression tree to this dataset. See the code file here: Tree_SHAP_Hypothetical_Example.ipynb.

n1, n2, n3, …, n7 represent the nodes of the tree. The s value in each node is the number of samples from the training set that fall into that node.

Let's compute the SHAP values for an instance i given by [x=150, y=75, z=200]. The prediction for this instance is t=20. Remember that SHAP is a local feature attribution technique that explains every prediction from the model as a sum of individual feature contributions.

From the theoretical explanation of SHAP in part 1 of this series, we understand that we can compute SHAP values by starting with a null model without any independent variables, computing the marginal contribution as each variable is added to this model in a sequence, and averaging these contributions over all possible sequences. Since we have 3 independent variables here, we have to consider 3! = 6 sequences. Let's compute the marginal contributions for each sequence. Note that SHAP makes the assumption that the prediction of a model with any subset S of independent variables is the expected value of the prediction given the subset xₛ.

The prediction of the null model, ϕ⁰ (also called the base value), is the mean prediction for the training set = (50*2 + 30*2 + 20*1 + 10*5)/10 = 23.

Consider the sequence: x > y > z:

1) First, the feature x is added to the null model. Note that for the selected instance i, we can compute the exact prediction with just this information, as only the variable x is used in the nodes (n1 & n3) leading up to the leaf node n6. Thus the prediction of the model with just the feature x is 20. Therefore the marginal contribution of x in this sequence is ϕˣ¹ = 20 - 23 = -3.

2) Now, let's add the feature y to the above model (from step 1). Since adding y does not alter the prediction for the selected instance i, the marginal contribution of y in this sequence is ϕʸ¹ = 20 - 20 = 0.

3) Similarly, the marginal contribution of z in this sequence is ϕᶻ¹ = 0.

Next, let's consider the sequence y > z > x:

1) First, the feature y is added to the null model. The first node n1 uses x as the split variable; since x is not available yet, we compute the prediction as (4/10)*(prediction from left child node n2) + (6/10)*(prediction from right child node n3), with 10, 4 and 6 being the number of training samples falling into nodes n1, n2 and n3 respectively.

i) Prediction from node n2: n2 uses y as the split variable. Since y is available (yᵢ = 75 for instance i), the prediction from node n2 = 50.

ii) Prediction from node n3: Again, n3 uses x as the split variable. Therefore, weighting its children by their sample counts, the prediction from n3 = (1/6)*20 + (5/6)*10 = 70/6.

iii) Therefore, the prediction of the model with just the feature y is (4/10)*50 + (6/10)*(70/6) = 27. Hence, the marginal contribution of y in this sequence is ϕʸ² = 27 - 23 = 4.

2) Next, we add the feature z to the above model. Since z is not used as a split variable in any of the internal nodes of the tree, adding this feature does not alter the prediction in any way. Thus the marginal contribution of z in this sequence is ϕᶻ² = 0. You may validate this by following the same approach as in step 1.

3) Finally, we add the feature x to the model, which gives the prediction 20. Therefore the marginal contribution of x in this sequence is ϕˣ² = 20 - 27 = -7.

Similarly, we compute the marginal contributions of the feature values for the remaining sequences:

Sequence x > z > y: ϕˣ³ = -3, ϕʸ³ = 0, ϕᶻ³ = 0

Sequence z > x > y: ϕˣ⁴ = -3, ϕʸ⁴ = 0, ϕᶻ⁴ = 0

Sequence z > y > x: ϕˣ⁵ = -7, ϕʸ⁵ = 4, ϕᶻ⁵ = 0

Sequence y > x > z: ϕˣ⁶ = -7, ϕʸ⁶ = 4, ϕᶻ⁶ = 0

Hence, SHAP values for the instance i are given by:

ϕˣ = (ϕˣ¹ + ϕˣ² + ϕˣ³ + ϕˣ⁴ + ϕˣ⁵ + ϕˣ⁶)/6 = (-3 - 7 - 3 - 3 - 7 - 7)/6 = -5

ϕʸ = (ϕʸ¹ + ϕʸ² + ϕʸ³ + ϕʸ⁴ + ϕʸ⁵ + ϕʸ⁶)/6 = (0+4+0+0+4+4)/6 = 2

ϕᶻ = (ϕᶻ¹ + ϕᶻ² + ϕᶻ³ + ϕᶻ⁴ + ϕᶻ⁵ + ϕᶻ⁶)/6 = (0+0+0+0+0+0)/6 = 0

And the explanation of the prediction for instance i (20) = ϕ⁰ + ϕˣ + ϕʸ + ϕᶻ = 23 + (-5) + 2 + 0 = 20, which can be read as follows:

The base value of the prediction in the absence of any information on the independent variables is 23; knowing x = 150 decreased the prediction by 5 and knowing y = 75 increased the prediction by 2, giving a final prediction of 20. Knowing z = 200 had no impact on the model prediction.
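To make the bookkeeping above concrete, here is a minimal brute-force sketch of the same calculation in Python. The split thresholds (x <= 100, y <= 100 and x <= 200) are hypothetical values chosen only so that the instance [x=150, y=75, z=200] is routed exactly as in the worked example; the real thresholds come from the fitted tree in the notebook.

from itertools import permutations

# Each internal node: (split_feature, threshold, left, right); each leaf: {"value", "n"}
tree = ("x", 100,
        ("y", 100, {"value": 50, "n": 2}, {"value": 30, "n": 2}),   # n2 -> n4, n5
        ("x", 200, {"value": 20, "n": 1}, {"value": 10, "n": 5}))   # n3 -> n6, n7

def n_samples(node):
    if isinstance(node, dict):                       # leaf node
        return node["n"]
    return n_samples(node[2]) + n_samples(node[3])

def expected_value(node, instance, subset):
    """Expected tree prediction when only the features in `subset` are known."""
    if isinstance(node, dict):                       # leaf node: return its value
        return node["value"]
    feature, threshold, left, right = node
    if feature in subset:                            # feature known: follow the split
        child = left if instance[feature] <= threshold else right
        return expected_value(child, instance, subset)
    n_left, n_right = n_samples(left), n_samples(right)
    return (n_left * expected_value(left, instance, subset) +
            n_right * expected_value(right, instance, subset)) / (n_left + n_right)

instance = {"x": 150, "y": 75, "z": 200}
phi = {feature: 0.0 for feature in instance}
for sequence in permutations(instance):              # 3! = 6 sequences
    known = set()
    previous = expected_value(tree, instance, known) # base value = 23
    for feature in sequence:
        known.add(feature)
        current = expected_value(tree, instance, known)
        phi[feature] += (current - previous) / 6     # average marginal contribution
        previous = current

print(phi)   # {'x': -5.0, 'y': 2.0, 'z': 0.0}

The printed values match the hand calculation, and together with the base value of 23 they sum to the prediction of 20.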

SHAP provides a good pictorial representation of this explanation, as below. Blue indicates that the x value (=150) decreased the prediction and red indicates that the y value (=75) increased the prediction. The code file for this example is available here: Tree_SHAP_Hypothetical_Example.ipynb.

Actual Tree SHAP Algorithm

The computational complexity of the above algorithm is of the order O(LT2ᴹ), where T is the number of trees in the tree ensemble model, L is the maximum number of leaves in any tree and M is the number of features. In the Tree SHAP paper², the authors propose a modified version of this algorithm that keeps track of the number of subsets S that flow into each node of the tree. The modified algorithm has a computational complexity of O(LTD²), where D is the maximum depth of a tree.

SHAP Interaction Values

SHAP allows us to compute interaction effects by considering pairwise feature attributions. This leads to a matrix of attribution values representing the impact of all pairs of features on a given model prediction. The SHAP interaction effect is based on the Shapley interaction index from game theory and is given by

\phi_{i,j} = \sum_{S \subseteq M \setminus \{i,j\}} \frac{|S|!\,(|M|-|S|-2)!}{2\,(|M|-1)!}\,\nabla_{ij}(S), \qquad i \neq j

where

\nabla_{ij}(S) = f_{S \cup \{i,j\}}(x_{S \cup \{i,j\}}) - f_{S \cup \{i\}}(x_{S \cup \{i\}}) - f_{S \cup \{j\}}(x_{S \cup \{j\}}) + f_S(x_S)

The above equation indicates that the SHAP interaction value of the i-th feature with respect to the j-th feature can be interpreted as the difference between the SHAP values of the i-th feature with and without the j-th feature. This allows us to reuse the algorithm for computing SHAP values to compute SHAP interaction values.

The SHAP interaction effect between the i-th and j-th features is split equally (i.e. ϕᵢⱼ = ϕⱼᵢ) and the total interaction effect is ϕᵢⱼ + ϕⱼᵢ. The main effect for the prediction can then be obtained as the difference between a feature's SHAP value and the sum of its SHAP interaction values:

\phi_{i,i} = \phi_i - \sum_{j \neq i} \phi_{i,j}

Tree SHAP On A Real Dataset

Now let's explore the Tree SHAP algorithm further using the UCI credit card default dataset. The target variable (dependent variable) is the binary variable "default payment next month" with values {0: No, 1: Yes}, which indicates whether a customer defaulted on his/her credit card payment; 23 variables related to the customer, such as age, education and his/her previous billing and payment history, are available as explanatory variables (independent variables). Read more about the dataset on the UCI website. We will use Google Colab to run our code. Find the code file uploaded here: Tree_SHAP_UCI_Credit_Card_Default.ipynb.

Let’s start by installing the shap library and loading all required libraries.
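The notebook has the exact cells; a minimal sketch of the setup, assuming a Colab/Jupyter environment, might look like this:

!pip install shap

import numpy as np
import pandas as pd
import lightgbm as lgb
import matplotlib.pyplot as plt
import shap
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import roc_auc_score

shap.initjs()   # loads the JavaScript used by the interactive force plots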

Next, download the dataset from the UCI website and read the data into a pandas dataframe. Note that we have dropped the ID column in the dataset.
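A hypothetical loading step is sketched below. The URL and the header=1 argument (the real column names sit on the second row of the Excel file) are assumptions; adjust them to match the copy of the dataset you download.

url = ("https://archive.ics.uci.edu/ml/machine-learning-databases/00350/"
       "default%20of%20credit%20card%20clients.xls")
df = pd.read_excel(url, header=1)       # hypothetical URL for the UCI Excel file
df = df.drop(columns=["ID"])            # the ID column carries no predictive information
df.head()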

Let’s visualize the explanatory variables in the dataset.

Axis labels and legend in the plots: 0 = No Default, 1 = Default.

Let’s look at the distribution of the target variable.

We will train a LightGBM model on this dataset. We see that the PAY_* columns have values ranging from -2 to 8. Since LightGBM treats all negative values of categorical features as missing values, we add 3 to these columns to shift the range to 1 to 11.
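A sketch of the shift (the PAY_* column names follow the UCI documentation, which has PAY_0 and PAY_2 through PAY_6):

pay_cols = ["PAY_0", "PAY_2", "PAY_3", "PAY_4", "PAY_5", "PAY_6"]
df[pay_cols] = df[pay_cols] + 3   # range moves from [-2, 8] to [1, 11]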

Let’s copy the target to a new variable y and split the data into training & test sets stratified on the target variable.
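A sketch of the split; the test size and random seed here are arbitrary choices, not necessarily those used in the notebook.

target = "default payment next month"
y = df[target]
X = df.drop(columns=[target])
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42)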

We will now build a baseline lightgbm model.
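For example, a baseline classifier with mostly default settings:

base_model = lgb.LGBMClassifier(random_state=42)
base_model.fit(X_train, y_train)
print("Baseline ROC AUC:",
      roc_auc_score(y_test, base_model.predict_proba(X_test)[:, 1]))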

Let’s now tune the hyperparameters of this model using hyperopt. Refer to my blog post here if you need more information on using hyperopt.
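A compact hyperopt sketch is shown below; the search space, the cross-validation scheme and the number of evaluations are illustrative, not the exact settings used in the notebook.

from hyperopt import fmin, tpe, hp, Trials, STATUS_OK

space = {
    "num_leaves": hp.quniform("num_leaves", 16, 128, 1),
    "learning_rate": hp.loguniform("learning_rate", np.log(0.01), np.log(0.3)),
    "min_child_samples": hp.quniform("min_child_samples", 10, 100, 1),
}

def objective(params):
    model = lgb.LGBMClassifier(
        num_leaves=int(params["num_leaves"]),
        learning_rate=params["learning_rate"],
        min_child_samples=int(params["min_child_samples"]),
        random_state=42)
    # maximize cross-validated ROC AUC by minimizing its negative
    auc = cross_val_score(model, X_train, y_train, cv=3, scoring="roc_auc").mean()
    return {"loss": -auc, "status": STATUS_OK}

best = fmin(fn=objective, space=space, algo=tpe.suggest,
            max_evals=50, trials=Trials())
print(best)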

The tuned model gave a ROC AUC of 78.4%. The hyperparameter values that gave the best results are:

Let’s refit the model for the best hyperparameters found.
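Continuing the sketch above, the refit might look like:

model = lgb.LGBMClassifier(
    num_leaves=int(best["num_leaves"]),
    learning_rate=best["learning_rate"],
    min_child_samples=int(best["min_child_samples"]),
    random_state=42)
model.fit(X_train, y_train)
print("Test ROC AUC:", roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]))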

We will create a new copy (X_test_disp) of our test dataset with the integer-coded categorical variables replaced with the corresponding category labels so that the SHAP plots will be more intuitive. We get the details of the category levels from the UCI website. Remember that we added 3 to all PAY_* variables. Also, we use "Unk_*" for levels not defined on the UCI site.
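A partial sketch of the display copy; the labels follow the UCI documentation, and any code not documented there falls back to an "Unk_*" placeholder:

X_test_disp = X_test.copy()
X_test_disp["SEX"] = X_test["SEX"].map({1: "Male", 2: "Female"}).fillna("Unk_SEX")
X_test_disp["EDUCATION"] = X_test["EDUCATION"].map(
    {1: "Graduate School", 2: "University", 3: "High School", 4: "Others"}
).fillna("Unk_EDUCATION")
X_test_disp["MARRIAGE"] = X_test["MARRIAGE"].map(
    {1: "Married", 2: "Single", 3: "Others"}).fillna("Unk_MARRIAGE")
# The (shifted) PAY_* columns can be mapped to readable strings in the same way.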

Computing SHAP Values

Let’s compute the shap values.
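With the model from the sketches above, this takes two lines; since no data argument is passed, the tree_path_dependent approach is used:

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)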

Let me briefly explain the arguments to the shap.TreeExplainer() function.

  • model: A tree-based model. The following models are supported by Tree SHAP at present: XGBoost, LightGBM, CatBoost, PySpark and most tree-based models in scikit-learn.
  • data: The dataset used to compute the marginal contributions of the features in the Tree SHAP algorithm. This is an optional argument with a default value of None. If it is not provided, statistics from the training data (the number of samples falling into each node) are used, as explained in the algorithm section above.
  • feature_perturbation: Can take two values. 'tree_path_dependent' is the default when the data argument is None; 'interventional' is the default when a data argument is provided. The interventional approach uses the supplied dataset to compute the expectations, which helps account for correlated input features. The 'tree_path_dependent' approach, on the other hand, uses the training data and the number of samples that fall into each node, as explained in the algorithm section.
  • model_output: With model_output='raw' (the default), SHAP values explain the raw predictions from the leaf nodes of the trees. Since most gradient boosting classification models predict the logit (log-odds) in their leaf nodes, SHAP values explain the logit prediction for GBM models by default. See the explanation here. Other possible values are 'probability', 'log_loss', or any model method name. With 'log_loss', SHAP values explain the log of the model's loss function. See ?shap.TreeExplainer for more details.

For classification problems, explainer.shap_values() returns a list of length n_classes. Since this is a binary classification model, n_classes = 2. Each element of this list is an array of shape [n_samples, n_features] holding the SHAP values for the corresponding class. In this example, shap_values[1] holds the SHAP values for the positive class (default payment next month = Yes) and shap_values[0] those for the negative class. For regression models, we get a single array of SHAP values of shape [n_samples, n_features].

Explaining a Single Prediction

Let’s explain the prediction for the first item in the testset.
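A force plot for the first test sample might be produced as below; with the shap version used here, expected_value and shap_values have one entry per class, and index 1 corresponds to the positive class:

shap.force_plot(explainer.expected_value[1],
                shap_values[1][0, :],
                X_test_disp.iloc[0, :])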

The base logit value for the positive class over the training dataset is -1.538. The logit prediction for this sample is -0.27. PAY_2 = 2 months (i.e. the repayment status for the month of Aug 2005 was a 2-month delay) contributed the most to increasing the logit, followed by LIMIT_BAL = 5e+4.

Explaining Predictions for More Than One Sample

If we take the above plot for each sample, rotate them 90 degrees and stack them side-by-side, we can explain the predictions for multiple samples in a single plot:
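For example, stacking the force plots of the first 1000 test samples (the sample count is arbitrary and only keeps the interactive plot responsive):

shap.force_plot(explainer.expected_value[1],
                shap_values[1][:1000, :],
                X_test_disp.iloc[:1000, :])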

Samples are ordered by similarity by default. However, we can change this ordering to be by output value, by original sample order, or by any numeric independent variable in the dataset. The color codes mean the same as in the previous chart, with red-colored features increasing the logit and blue-colored features decreasing the logit for each data point.

SHAP Summary Plots
shap.summary_plot() can plot the mean SHAP values for each class if provided with a list of SHAP values (the output of explainer.shap_values() for a classification problem), as below:
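For example, passing the full list produces a bar chart of the mean absolute SHAP values per class (a sketch using the variables defined above):

shap.summary_plot(shap_values, X_test_disp)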

Note that the SHAP values for the two classes are additive inverses of each other for a binary classification problem. The above plot is much more informative for a multi-class classification problem. We can also generate the above plot for just the class of interest, as below.
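For the positive class only, a bar summary might look like this (a sketch):

shap.summary_plot(shap_values[1], X_test_disp, plot_type="bar")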

If provided with a single set of SHAP values (shap values for a single class for a classification problem or shap values for a regression problem), shap.summary_plot() creates a density scatter plot of SHAP values for each feature to identify how much impact each feature has on the model output. Features are sorted by the sum of the SHAP value magnitudes across all samples.
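For the positive class, the density scatter plot can be produced as below (a sketch):

shap.summary_plot(shap_values[1], X_test_disp)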

Note that we get grey-colored points for categorical features, as the integer-encoded values of a categorical variable cannot always be arranged from low to high. However, for this dataset the PAY_* variables roughly correspond to the number of months of payment delay and can therefore be used to sort the values. Let's plot the same graph as above, treating the PAY_* variables as numeric (all we have to do is replace X_test_disp with X_test above).

This plot is much more intuitive. We observe that the log-odds of default increase as the payment delay increases for all PAY_* variables. Also note that the log-odds of default decrease as the payment amount (PAY_AMT*) increases. A surprising observation from the above plot is that the log-odds of default decrease as the amount of given credit (LIMIT_BAL) increases. This may be because the credit card company offers higher credit to customers with a lower probability of default, so customers with a high LIMIT_BAL likely have high values for PAY_AMT* too. Let's investigate this further with SHAP dependence plots.

SHAP Dependence Plots
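A dependence plot of LIMIT_BAL, explicitly colored by PAY_AMT1, might be produced as below (a sketch; display_features only supplies the human-readable labels):

shap.dependence_plot("LIMIT_BAL", shap_values[1], X_test,
                     interaction_index="PAY_AMT1",
                     display_features=X_test_disp)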

We see that customers with a high LIMIT_BAL also have high values for PAY_AMT1. The cluster of reddish dots towards the bottom of the band for high LIMIT_BAL samples shows the interaction effect of LIMIT_BAL with PAY_AMT1.

We can also create a SHAP dependence plot with a categorical variable, as below.
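For instance, using the EDUCATION column (a sketch; the notebook may use a different variable):

shap.dependence_plot("EDUCATION", shap_values[1], X_test,
                     display_features=X_test_disp)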

SHAP can also automatically select the interaction variable that appears to have the strongest interaction with the main variable. Let's create a SHAP dependence plot for the top 3 features with the highest impact on the model (mean |SHAP| values).
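A sketch that loops over the three features with the largest mean absolute SHAP value and lets shap pick the interaction feature automatically:

top_features = X_test.columns[np.argsort(-np.abs(shap_values[1]).mean(axis=0))[:3]]
for feature in top_features:
    shap.dependence_plot(feature, shap_values[1], X_test,
                         interaction_index="auto",
                         display_features=X_test_disp)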

SHAP Interaction Values and Interaction Summary Plot

SHAP interaction values are currently not supported for LightGBM & CatBoost models because of the special treatment these libraries give to categorical variables. See the comments on the GitHub issues here and here. See the NHANES survival model with XGBoost and SHAP interaction values notebook on the SHAP GitHub page for an example use case.
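If you want to try interaction values on this dataset anyway, a hedged sketch with an XGBoost model (hyperparameters arbitrary, and only a slice of the test set to keep the quadratic-in-features computation manageable) might look like this:

import xgboost as xgb

xgb_model = xgb.XGBClassifier(n_estimators=200, max_depth=4, random_state=42)
xgb_model.fit(X_train, y_train)

xgb_explainer = shap.TreeExplainer(xgb_model)
shap_interaction_values = xgb_explainer.shap_interaction_values(X_test.iloc[:500])
shap.summary_plot(shap_interaction_values, X_test_disp.iloc[:500])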

The code file for this article is uploaded here: Tree_SHAP_UCI_Credit_Card_Default.ipynb.

Link to other articles in this series:

SHAP Part 1: An Introduction to SHAP

SHAP Part 2: Kernel SHAP

References

  1. Lundberg, S. M., Lee, S.-I. A Unified Approach to Interpreting Model Predictions. arXiv:1705.07874
  2. Lundberg, S. M., Erion, G. G., Lee, S.-I. Consistent Individualized Feature Attribution for Tree Ensembles. arXiv:1802.03888
  3. Molnar, C. Interpretable Machine Learning: A Guide for Making Black Box Models Explainable.
  4. https://github.com/slundberg/shap
