Hands-On SHAP: Practical Implementation for Image, Text, and Tabular Data

SHREERAJ
7 min read · 2 days ago


Welcome to the sixth article in my series on Explainable AI.

(Image source: Google)

Brief Recap of the Fifth Article on Explainable AI:

In my fifth article on Explainable AI, we explored SHAP (SHapley Additive exPlanations) and its application to text, images, and tabular data. SHAP enhances model interpretability by providing consistent, theoretically grounded explanations for AI predictions. It uses Shapley values from cooperative game theory to attribute a prediction to individual features, offering both global and local insights into model behavior.

SHAP calculates the contribution of each feature to a prediction by considering all possible combinations (coalitions) of features, which ensures a fair distribution of the model’s output among the input features. A feature’s SHAP value is the change in the model’s output when that feature is added to a coalition (present versus absent), averaged over all coalitions with Shapley weights. Together, the SHAP values of an instance add up to the difference between the model’s prediction for that instance and its baseline (expected) output.
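To make the “all possible combinations” idea concrete, here is a minimal brute-force sketch of the Shapley value formula, φ_i = Σ over subsets S of the other features of |S|! (n − |S| − 1)! / n! × [f(S ∪ {i}) − f(S)]. It assumes a feature is “absent” when it takes a baseline value (zeros here). The helper `exact_shapley` and the toy model are hypothetical illustrations, not part of the SHAP library, and the enumeration grows exponentially with the number of features, which is why SHAP’s explainers rely on model-specific shortcuts or sampling instead.

```python
from itertools import combinations
from math import factorial

import numpy as np


def exact_shapley(f, x, baseline):
    """Exact Shapley values of f at x by enumerating every coalition.

    Assumption: a feature is "absent" when it takes its baseline value.
    The cost grows as 2**n, so this is only practical for a handful of features.
    """
    n = len(x)
    phi = np.zeros(n)

    def value(coalition):
        # Features in the coalition keep x's values; all others use the baseline
        z = baseline.copy()
        idx = np.array(coalition, dtype=int)
        z[idx] = x[idx]
        return f(z)

    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions S of this size: |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for S in combinations(others, size):
                phi[i] += weight * (value(S + (i,)) - value(S))
    return phi


def toy_model(z):
    # Hypothetical model with an interaction between features 1 and 2
    return 3 * z[0] + 2 * z[1] * z[2] - z[2]


x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)

phi = exact_shapley(toy_model, x, baseline)
print("Shapley values:", phi)  # 3, 6 and 3, up to floating-point rounding
# Efficiency: the values add up to f(x) - f(baseline) = 12
print("Sum:", phi.sum(), "| f(x) - f(baseline):", toy_model(x) - toy_model(baseline))
```

The interaction term’s contribution (2 × 2 × 3 = 12) is split evenly between features 1 and 2, which is the fair distribution property described above.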

(Image source: Google)

In this article, we’ll implement SHAP practically on these data types to gain deeper insights into model predictions.

1. Setting Up the Environment in Google Colab:

  • Open a Google Colab notebook
  • Install all required libraries
# Install required libraries
!pip install lime shap scikit-learn numpy pandas matplotlib tensorflow pillow xgboost

# Note: After installation, restart the runtime to ensure all libraries are properly loaded.
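After restarting, a quick way to confirm that the runtime picked up the installs is to query package metadata. This is a small sketch using only the standard library’s `importlib.metadata`; the `check_installed` helper is a hypothetical name, and the package list assumes the libraries used in this article’s examples.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names assumed from the examples in this article
REQUIRED = ["shap", "scikit-learn", "numpy", "pandas", "matplotlib",
            "tensorflow", "pillow", "xgboost"]


def check_installed(packages):
    """Return {package: installed version string, or None if not installed}."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


for name, ver in check_installed(REQUIRED).items():
    print(f"{name:15s} {ver or 'NOT INSTALLED'}")
```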

2. SHAP Implementation for Text Data:

Algorithm Flow:

  1. Prepare text data and train a model: Load the text data, vectorize it (e.g., with TF-IDF), and train a model (e.g., LinearSVC).
  2. Create a SHAP explainer: Use `shap.LinearExplainer` with the trained model and the vectorized training data as background data.
  3. Generate SHAP values for a text instance: Transform the text to explain into vector form and compute its SHAP values with the explainer.
  4. Visualize word importance: Use `shap.force_plot` to visualize the SHAP values and interpret the importance of each word (feature) in the text instance.
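For a linear model such as the LinearSVC used here, SHAP values have a closed form. When features are treated as independent (the interventional setting, which recent SHAP versions use by default), the SHAP value of feature i is w_i (x_i − E[x_i]), and the base value is the model’s average output over the background data. The numpy sketch below uses made-up weights and random background data rather than the IMDB model; it checks that the base value plus the SHAP values reproduces the prediction, and it illustrates what `shap.LinearExplainer` computes in that setting rather than reproducing its source code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear model f(x) = w . x + b (a stand-in for LinearSVC's decision function)
w = np.array([0.8, -1.5, 0.3])
b = 0.1
X_background = rng.random((100, 3))  # stand-in for the vectorized training data
x = np.array([0.9, 0.0, 0.5])         # instance to explain; feature 1 is absent (0)

# With independent (interventional) features, each SHAP value is w_i * (x_i - E[x_i])
mean_x = X_background.mean(axis=0)
phi = w * (x - mean_x)
base_value = w @ mean_x + b           # expected model output over the background data

print("SHAP values:", phi)
print("base value + sum of SHAP values:", base_value + phi.sum())
print("model output f(x):              ", w @ x + b)
```

Note the second feature: its value is 0, like a word that does not appear in the review, yet its SHAP value is −w_1 · E[x_1], which is not zero. In a TF-IDF setting most terms have small means, so such contributions are small but not exactly zero.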

Code Implementation:

import shap
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC

# Load SHAP's JavaScript library so force plots render in the notebook
shap.initjs()

# Prepare data: IMDB reviews with binary sentiment labels
train_text, train_labels = shap.datasets.imdb()
vectorizer = TfidfVectorizer(max_features=10000)
train_vectors = vectorizer.fit_transform(train_text)

# Train a linear classifier on the TF-IDF features
model = LinearSVC().fit(train_vectors, train_labels)

# Create the explainer; the training vectors serve as background data
explainer = shap.LinearExplainer(model, train_vectors)

# Vectorize the text to explain
text_to_explain = "This movie was really great"
x = vectorizer.transform([text_to_explain])

# Convert the sparse matrix to a dense array
x_dense = x.toarray()

# Generate SHAP values for the dense array (shape: 1 x 10000)
shap_values = explainer.shap_values(x_dense)

# Visualize the single instance (row 0)
shap.force_plot(explainer.expected_value, shap_values[0], x_dense[0])

Output:

This SHAP force plot shows how two features influence the model’s prediction for this instance. Relative to the base value of 1.52, Feature 3997 pushes the prediction higher by 0.5955 (shown in pink/red), while Feature 9683 pushes it lower by 0.3599 (shown in blue). The plot shows both the direction and the magnitude of each feature’s impact, and the prediction leans positive because Feature 3997’s contribution outweighs Feature 9683’s. Each feature index corresponds to one term in the TF-IDF vocabulary, so these features likely represent words from the input “This movie was really great,” with a positive word such as “great” plausibly corresponding to Feature 3997.
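Because the force plot above was drawn without feature names, it labels words by column index. A small helper like the hypothetical `top_tokens` below ranks tokens by absolute SHAP value. In the text pipeline, `feature_names` would be `vectorizer.get_feature_names_out()` (scikit-learn 1.0+) and the row would be `shap_values[0]`; the vocabulary and values here are toy numbers, not results from the IMDB model.

```python
import numpy as np


def top_tokens(shap_row, feature_names, k=5):
    """Return the k (token, SHAP value) pairs with the largest absolute SHAP values."""
    shap_row = np.asarray(shap_row).ravel()
    order = np.argsort(-np.abs(shap_row))[:k]
    return [(str(feature_names[i]), float(shap_row[i])) for i in order]


# Toy stand-ins: in the text pipeline these would be
# vectorizer.get_feature_names_out() and shap_values[0]
vocab = np.array(["awful", "great", "movie", "really", "was"])
row = np.array([-0.05, 0.40, 0.02, 0.10, -0.20])

for token, value in top_tokens(row, vocab, k=3):
    print(f"{token:>8s} {value:+.3f}")
```

Alternatively, passing `feature_names=vectorizer.get_feature_names_out()` to `shap.force_plot` labels the plot with words instead of indices.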

SHAP Explanation

3. SHAP Implementation for Image Data:

Detailed Algorithm:

  1. Load Pre-trained Image Model: Load a pre-trained image classification model (e.g., ResNet50).
  2. Prepare Image Data: Load and preprocess the image data.
  3. Retrieve ImageNet Class Names: Download and parse the ImageNet class names.
  4. Define the Model Prediction Function: Create a function that preprocesses input data and returns model predictions.
  5. Create a SHAP Explainer: Use SHAP to create an explainer object with the model and image masker.
  6. Generate SHAP Values for Images: Compute SHAP values for a subset of images.
  7. Visualize Feature Importance: Use shap.image_plot to visualize the SHAP values.
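Steps 5 and 6 are where pixels become “features”: the masker hides regions of the image, and the explainer measures how the class score changes as regions are hidden or revealed. The sketch below shows the idea with exact Shapley values over a 2×2 grid of patches, filling hidden patches with the image’s mean color. It is a simplified illustration under stated assumptions: `patch_shapley` and the toy score are hypothetical, and the `shap.maskers.Image("inpaint_telea", ...)` masker in the code below instead inpaints hidden regions and works through a hierarchy of image partitions rather than enumerating every coalition.

```python
from itertools import combinations
from math import factorial

import numpy as np


def patch_shapley(f, image, grid=2):
    """Exact Shapley values for a grid x grid set of image patches (hypothetical helper).

    A patch counts as "absent" when its pixels are replaced by the image's mean color.
    Rows or columns left over when the size is not divisible by `grid` always stay masked.
    """
    image = np.asarray(image, dtype=float)
    ph, pw = image.shape[0] // grid, image.shape[1] // grid
    fill = image.mean(axis=(0, 1))
    n = grid * grid

    def value(coalition):
        # Start from a fully masked image, then restore the patches in the coalition
        out = np.empty_like(image)
        out[...] = fill
        for p in coalition:
            r, c = divmod(p, grid)
            rows = slice(r * ph, (r + 1) * ph)
            cols = slice(c * pw, (c + 1) * pw)
            out[rows, cols] = image[rows, cols]
        return f(out)

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for S in combinations(others, size):
                phi[i] += weight * (value(S + (i,)) - value(S))
    return phi


# Toy 4x4 RGB image whose top-left patch is bright red
img = np.zeros((4, 4, 3))
img[:2, :2, 0] = 1.0


def toy_score(im):
    # Hypothetical "class score": mean red intensity in the top-left quadrant
    return im[:2, :2, 0].mean()


phi = patch_shapley(toy_score, img)
print("Patch SHAP values (row-major):", phi)  # only the top-left patch matters
```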

Code Implementation:

import json
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
import shap
import numpy as np

# Load a pre-trained model and 50 sample ImageNet images
model = ResNet50(weights="imagenet")
X, y = shap.datasets.imagenet50()

# Clip pixel values to [0, 255] and store them as 8-bit integers
X = np.clip(X, 0, 255).astype(np.uint8)

# Get the 1000 ImageNet class names
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as file:
    class_names = [v[1] for v in json.load(file).values()]

print("Number of ImageNet classes:", len(class_names))
print("First 10 class names:", class_names[:10])

def f(x):
    # preprocess_input can return a new array (e.g. when it converts the dtype),
    # so use its return value instead of relying on in-place modification
    tmp = preprocess_input(x.astype("float32"))
    return model(tmp)

# Define a masker that hides (inpaints) partitions of the input image
masker = shap.maskers.Image("inpaint_telea", X[0].shape)

# Create an explainer with the model and the image masker
explainer = shap.Explainer(f, masker, output_names=class_names)

# Explain two images using 100 evaluations of the underlying model to estimate the SHAP values,
# explaining the 4 top-scoring classes for each image
shap_values = explainer(
    X[1:3], max_evals=100, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]
)

# Plot the SHAP values over each image
shap.image_plot(shap_values)

Output:

4. SHAP Implementation for Tabular Data:

Detailed Algorithm:

First, install and import the required libraries: `numpy`, `xgboost`, `shap`, and `train_test_split` and `accuracy_score` from `sklearn`. Next, load the Adult census income dataset with `shap.datasets.adult()`, once in numeric form for the model and once with readable category labels (`display=True`) for plotting. Split the data into training and test sets (80/20) and wrap both in `xgboost.DMatrix` objects. Train an XGBoost model with the `binary:logistic` objective, using the test set for early stopping, and report its accuracy at a 0.5 probability threshold. Then create a `shap.TreeExplainer` for the trained model and compute SHAP values for every row. Finally, visualize a single prediction (row 100) with a force plot, and the model’s overall behavior with two summary plots: a bar chart of feature importance and a beeswarm plot of per-row SHAP values.
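One detail the plots do not make obvious: for an XGBoost model trained with `binary:logistic`, `shap.TreeExplainer` explains the raw margin by default, so the base value and the SHAP values in the force and summary plots are in log-odds units. The minimal sketch below uses hypothetical numbers, not outputs of the Adult model, to show how base value plus SHAP values converts back into a probability. For the real model, the margin for a row should match `model.predict(d, output_margin=True)` on a DMatrix holding that row.

```python
import numpy as np


def sigmoid(z):
    """Logistic function: maps log-odds to a probability."""
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical values for one row, in log-odds units:
expected_value = -1.2                        # plays the role of explainer.expected_value
shap_row = np.array([0.9, -0.3, 1.1, 0.05])  # plays the role of shap_values[i, :]

# Local accuracy: base value + sum of SHAP values = the model's raw margin
margin = expected_value + shap_row.sum()
probability = sigmoid(margin)

print(f"margin (log-odds): {margin:.3f}")
print(f"predicted probability: {probability:.3f}")
```

Because the sigmoid is nonlinear, the SHAP values must be summed in log-odds space first; adding per-feature “probability contributions” would not reproduce the prediction.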

Code Implementation:

# Install necessary libraries
!pip install shap xgboost

import numpy as np
import xgboost
import shap
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load SHAP's JavaScript library so force plots render in the notebook
shap.initjs()

# Load the dataset: numeric features for the model, readable labels for display
X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)
print(X_display)

# Create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
d_train = xgboost.DMatrix(X_train, label=y_train)
d_test = xgboost.DMatrix(X_test, label=y_test)

# Define parameters for the model
params = {
    "eta": 0.01,
    "objective": "binary:logistic",
    "subsample": 0.5,
    "base_score": np.mean(y_train),
    "eval_metric": "logloss",
}

# Train the model, stopping once the test logloss has not improved for 20 rounds
model = xgboost.train(
    params,
    d_train,
    5000,
    evals=[(d_test, "test")],
    verbose_eval=100,
    early_stopping_rounds=20,
)

# Make predictions (probabilities) and threshold them at 0.5
y_pred_prob = model.predict(d_test)
y_pred = (y_pred_prob > 0.5).astype(int)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy * 100:.2f}%")

# Explain the model's predictions using SHAP
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Visualize a specific prediction (row 100)
shap.force_plot(explainer.expected_value, shap_values[100, :], X_display.iloc[100, :])

# SHAP summary plot (bar type): mean absolute SHAP value per feature
shap.summary_plot(shap_values, X_display, plot_type="bar")

# SHAP summary plot (beeswarm): SHAP value of every feature for every row
shap.summary_plot(shap_values, X)
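The bar-type summary plot aggregates local explanations into a global ranking: each feature’s bar is its mean absolute SHAP value across all explained rows. A minimal numpy version of that computation, with a hypothetical helper name and a toy matrix rather than results from the Adult model, is below; for the tabular example you would pass `shap_values` and `X.columns`.

```python
import numpy as np


def global_importance(shap_values, feature_names):
    """Rank features by mean absolute SHAP value (largest first)."""
    importance = np.abs(np.asarray(shap_values)).mean(axis=0)
    order = np.argsort(-importance)
    return [(feature_names[i], float(importance[i])) for i in order]


# Toy SHAP matrix: 3 rows x 3 features (not output from the Adult model)
toy_shap = np.array([
    [0.5, -0.1, 0.0],
    [-0.7, 0.2, 0.1],
    [0.3, -0.3, -0.1],
])
for name, score in global_importance(toy_shap, ["feature_a", "feature_b", "feature_c"]):
    print(f"{name}: {score:.3f}")
```

Taking the absolute value before averaging matters: a feature that pushes some rows up and others down would otherwise look unimportant on average.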

Output:

Suggestions for Further Exploration:

(Image source: Google)
  • Apply SHAP to different models and datasets in your field.
  • Compare SHAP with other interpretability techniques like LIME.
  • Use SHAP to identify and address model biases.
  • Investigate how SHAP values change across different instances or subgroups.
  • Explore SHAP’s various visualization options for different data types.
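For the subgroup idea above, one simple starting point is to compare each feature’s average signed SHAP value between two groups of rows, for example a boolean mask built from a column of `X_display` in the tabular example. The helper below is a hypothetical sketch on toy numbers; a large gap flags a feature whose influence differs between groups, which calls for closer inspection rather than proving bias on its own.

```python
import numpy as np


def subgroup_shap_gap(shap_values, in_group_a, feature_names):
    """Compare mean (signed) SHAP values per feature between two subgroups.

    in_group_a: boolean array, True for rows in subgroup A, False for subgroup B.
    Returns (feature, mean_A, mean_B, gap) tuples sorted by |gap|, largest first.
    """
    sv = np.asarray(shap_values)
    mask = np.asarray(in_group_a, dtype=bool)
    mean_a = sv[mask].mean(axis=0)
    mean_b = sv[~mask].mean(axis=0)
    gap = mean_a - mean_b
    order = np.argsort(-np.abs(gap))
    return [(feature_names[i], float(mean_a[i]), float(mean_b[i]), float(gap[i])) for i in order]


# Toy data: 4 rows, 2 features; the first two rows belong to subgroup A
toy_shap = np.array([[0.4, 0.1], [0.6, -0.1], [0.1, 0.0], [-0.1, 0.2]])
in_a = np.array([True, True, False, False])

for name, a, b, gap in subgroup_shap_gap(toy_shap, in_a, ["feature_a", "feature_b"]):
    print(f"{name}: A={a:+.2f}  B={b:+.2f}  gap={gap:+.2f}")
```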

Conclusion:

(Image source: Google)

We’ve explored SHAP’s practical implementation for image, text, and tabular data, demonstrating its ability to provide local and global explanations for complex models. Key takeaways:

- SHAP uses game theory concepts to create consistent, theoretically grounded explanations
- It’s adaptable to various data types and model architectures
- SHAP values show each feature’s contribution to predictions relative to a baseline
- It offers both local (individual prediction) and global (overall model behavior) insights

Link to the Seventh Article on Explainable AI: Explainable AI for Communicable Disease Prediction: A Breakthrough in Healthcare Technology

In our next article, we will delve into a research paper that explores the intersection of healthcare and explainable AI (XAI). This paper presents innovative methods where AI aids healthcare decision-making while providing transparent, interpretable insights. We will examine advanced AI techniques, real-world applications, and the importance of explainability for healthcare professionals. Additionally, we will discuss the challenges of integrating explainable AI in healthcare and future directions for improving AI and healthcare collaboration. By addressing the “black box” problem, explainable AI ensures AI’s role in healthcare is powerful, accountable, and understandable. Stay tuned for an insightful exploration of how explainable AI is revolutionizing healthcare.

(Image generated by DALL·E 3)
