Explaining Trees with FastTreeSHAP and the What-If Tool (Part-2)

Zabir Al Nazi Nabil
5 min read · Apr 2, 2023


Machine Learning Workflow (adjusted)

Read part-1 here!

There are many methods for explaining machine learning models; frequently used ones include Grad-CAM, SHAP, LIME, and Integrated Gradients. In this article, we will focus only on SHAP.
SHAP (SHapley Additive exPlanations) is a method based on cooperative game theory that is used to increase the transparency and interpretability of machine learning models. A SHAP value represents a feature's responsibility for a change in the model's output: using the SHAP values of a feature, we can understand how that single feature affects the model's prediction.
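For reference, the SHAP value of feature i is the classical Shapley value from game theory, computed over all subsets S of the full feature set F that do not contain i (f_S denotes the model restricted to the features in S), as defined in the SHAP paper:

\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F|-|S|-1)!}{|F|!} \left[ f_{S \cup \{i\}}(x_{S \cup \{i\}}) - f_S(x_S) \right]

Each term measures how much adding feature i to the subset S changes the prediction, weighted by the fraction of feature orderings in which i arrives right after S.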

Tree SHAP is a fast and exact method to estimate SHAP values for tree models and ensembles of trees, under several different possible assumptions about feature dependence. The Tree SHAP implementation in the shap library is usually pretty efficient, but I will show you another interesting implementation from LinkedIn: https://github.com/linkedin/FastTreeSHAP, which is based on the paper "Fast TreeSHAP: Accelerating SHAP Value Computation for Trees".
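The library can be installed from PyPI (assuming the package name fasttreeshap, as listed in the repository's README); we will also keep the regular shap package around for its plotting helpers:

! pip install fasttreeshap shap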

from sklearn.datasets import load_breast_cancer
cancer_ds = load_breast_cancer()
cancer_x = cancer_ds.data
cancer_y = cancer_ds.target
print("Feature set shape: ", cancer_x.shape)
print("Target shape: ", cancer_y.shape)
print("Features: ", cancer_ds.feature_names)
print("Targets: ", cancer_ds.target_names)
Feature set shape:  (569, 30)
Target shape: (569,)
Features: ['mean radius' 'mean texture' 'mean perimeter' 'mean area'
'mean smoothness' 'mean compactness' 'mean concavity'
'mean concave points' 'mean symmetry' 'mean fractal dimension'
'radius error' 'texture error' 'perimeter error' 'area error'
'smoothness error' 'compactness error' 'concavity error'
'concave points error' 'symmetry error' 'fractal dimension error'
'worst radius' 'worst texture' 'worst perimeter' 'worst area'
'worst smoothness' 'worst compactness' 'worst concavity'
'worst concave points' 'worst symmetry' 'worst fractal dimension']
Targets: ['malignant' 'benign']

I use the breast_cancer dataset from sklearn. The dataset has 569 rows and 30 features, listed in the output above: the mean value, the error, and the "worst" value of measurements such as radius, texture, perimeter, area, smoothness, compactness, concavity, concave points, symmetry, and fractal dimension. We have to solve a binary classification problem: detecting whether a cancer sample is 'malignant' or 'benign'.

Let’s do a simple 75:25 train-test split.

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(cancer_x, cancer_y, test_size = 0.25, random_state = 42)
print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)
(426, 30)
(143, 30)
(426,)
(143,)

Now, let's train a LightGBM model on the training set. We achieve an F1 score of about 95% on the test set, which is not entirely useless!

! pip install lightgbm

import lightgbm as lgb
import numpy as np
train_data = lgb.Dataset(data=x_train, label=y_train)
model = lgb.train(train_set=train_data, params={})  # empty params: LightGBM defaults to a regression objective, so predictions are continuous scores
y_preds = np.array(model.predict(x_test))

from sklearn import metrics
fpr, tpr, thresholds = metrics.roc_curve(y_test, y_preds, pos_label=1)
metrics.auc(fpr, tpr)

0.9945900957136913
from sklearn.metrics import f1_score
f1_score(y_test, np.array([0 if a < 0.5 else 1 for a in y_preds ]))
0.9545454545454545

Now, let's try to explain the predictions of the model. We use the TreeExplainer from the fasttreeshap library. There are multiple versions of the algorithm; we will choose v2, since it is relatively faster.
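Below is a minimal sketch of how the SHAP values can be computed, assuming the TreeExplainer interface shown in the FastTreeSHAP README (a TreeExplainer class that takes an algorithm argument and mirrors the shap library's shap_values method); the plotting calls come from the regular shap package.

import fasttreeshap
import shap  # used only for its plotting helpers

# "v2" precomputes more per-tree information, trading memory for speed
shap_explainer = fasttreeshap.TreeExplainer(model, algorithm="v2", n_jobs=-1)
shap_values = shap_explainer.shap_values(x_test)  # one SHAP value per feature per test row

# Global view: which features move the model score the most across the test set
shap.summary_plot(shap_values, x_test, feature_names=cancer_ds.feature_names)

# Local view: per-feature contributions for the first test row
shap.force_plot(
    shap_explainer.expected_value,
    shap_values[0, :],
    x_test[0, :],
    feature_names=cancer_ds.feature_names,
    matplotlib=True,
)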

I have plotted the SHAP values for the features of the first test row. It is evident that worst area, mean concave points, and worst concave points impact the prediction score the most, while mean perimeter and worst concavity push the model score in the opposite direction. But it is a little tricky to step through different random samples and understand how the features impact the results. What if we had an interactive tool for that?

SHAP Explanations for Cancer Classification

The What-If Tool (WIT) will be very useful in this scenario. We can plug in a custom predict function that takes the feature inputs from the UI, passes them to the model, and returns the model score along with the SHAP values; a minimal setup sketch follows the list below. WIT can be used for the following purposes (and more).

  • Interactively changing a feature value and seeing how the prediction score and the SHAP values change. This is really useful if we want to check whether the model has learned the feature space in a way that can be explained with common sense. For instance, say we are trying to predict the price of an apartment: with all other features held constant, we expect the price to go up if we increase the area.
Feature perturbation with SHAP attributes in WIT
  • Instead of experimenting with many samples, we can also look directly at the partial dependence plot of a feature if we are only interested in the change in the model score. The partial dependence plot (PDP or PD plot for short) shows the marginal effect that one or two features have on the predicted outcome of a machine learning model (J. H. Friedman). A partial dependence plot can show whether the relationship between the target and a feature is linear, monotonic, or more complex.
Partial dependence plot in WIT
  • We can also simply look at the feature distribution.
Feature distribution plot in WIT
  • A counterfactual explanation indicates the smallest change in feature values that translates to a different outcome. What-if counterfactuals answer the question of what the model would predict if you changed a particular input. They make it possible to understand and debug a machine learning model by observing how it responds to changes in its input features. Conventional interpretability approaches rank features by their predictive value or approximate the machine learning model; counterfactual analysis, by contrast, "interrogates" the model to discover which adjustments to a given data point would cause the model's decision to change. In simple words: say we have two very similar cancer samples whose feature values look almost identical, yet one is benign and the other is malignant. WIT will find such counterfactual samples for any data point.
Counterfactual analysis with WIT
  • You can adjust the threshold for a classification problem to optimize precision or recall in WIT too.
Tuning binary classification threshold in WIT
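The sketch below shows roughly how the model and the SHAP explainer can be wired into WIT through a custom predict function. It assumes the witwidget package (pip install witwidget), that WitConfigBuilder accepts raw feature rows together with feature names as in the non-TensorFlow examples in the WIT documentation, and that a custom predict function may return a dict with "predictions" and "attributions" as described there; model, x_test, y_test, cancer_ds, and shap_explainer come from the earlier snippets.

import numpy as np
from witwidget.notebook.visualization import WitWidget, WitConfigBuilder

# Append the ground-truth label as the last column so WIT can display it.
examples = np.column_stack([x_test, y_test]).tolist()
feature_names = list(cancer_ds.feature_names) + ["label"]

def custom_predict(examples_to_infer):
    x = np.array(examples_to_infer)[:, :-1]   # drop the appended label column
    scores = model.predict(x)                 # model score for class 1 ("benign")
    predictions = np.stack([1 - scores, scores], axis=1).tolist()
    # Per-example feature attributions (SHAP values) shown next to each prediction.
    attributions = [
        dict(zip(cancer_ds.feature_names, row.tolist()))
        for row in shap_explainer.shap_values(x)
    ]
    return {"predictions": predictions, "attributions": attributions}

config_builder = (
    WitConfigBuilder(examples, feature_names)
    .set_custom_predict_fn(custom_predict)
    .set_target_feature("label")
    .set_label_vocab(["malignant", "benign"])  # matches sklearn's encoding: 0 = malignant, 1 = benign
)
WitWidget(config_builder, height=800)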

PAIR Code recently released the “Learning Interpretability Tool (LIT)”. The Learning Interpretability Tool (🔥LIT, formerly known as the Language Interpretability Tool) is a visual, interactive ML model-understanding tool that supports text, image, and tabular data. It can be run as a standalone server, or inside of notebook environments such as Colab, Jupyter, and Google Cloud Vertex AI notebooks.

The code used in this article can be found here: https://github.com/zabir-nabil/What-If-Explainability
