Interpreting CNNs using LIME
Moving ahead from tabular data
So far we have covered a lot in Interpretable AI: the basics, model-specific interpretation methods such as those for Linear Regression and Decision Trees, and model-agnostic methods like PDPs, LIME and SHAP.
But all of these were applied to tabular data.
It’s time to try a more complex kind of data, say images!
Images differ a lot from tabular datasets (obviously). If you remember, every interpretation method we have discussed so far explains a model through some feature-importance metric or visualization, for example that feature ‘A’ has more influence on the prediction than feature ‘B’.
But images don’t have specific features that we can name. How would interpretation work here?
Basically, we highlight the regions of the image that were important for the model’s prediction. But let’s first understand how the LIME method is adapted to handle images.
You might wish to revisit how LIME works for tabular data first.
So, let me summarize LIME for you:
- Generate summary statistics for the training dataset
- Generate an artificial dataset using the summary statistics
- Pick a random sample from the original dataset
- Assign weights to the samples in the artificial dataset according to their proximity to the picked sample
- Train a white-box model on these weighted samples
- Interpret the white-box model
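The steps above can be sketched in a few lines of NumPy. This is only a rough illustration and not the lime library's exact procedure: the Gaussian sampling, the exponential kernel, the kernel width and the toy `black_box` function are all assumptions made for the example.

```python
import numpy as np

def lime_tabular_sketch(black_box, X_train, instance, n_samples=2000, seed=0):
    """Return local linear slopes of `black_box` around `instance` (illustrative only)."""
    rng = np.random.default_rng(seed)
    n_features = X_train.shape[1]
    # Summary statistics of the training data
    mean, std = X_train.mean(axis=0), X_train.std(axis=0)
    # Artificial dataset drawn from those statistics (assumed Gaussian)
    Z = rng.normal(mean, std, size=(n_samples, n_features))
    # Proximity weights: closer to the picked instance means a larger weight
    dist = np.linalg.norm((Z - instance) / std, axis=1)
    kernel_width = 0.75 * np.sqrt(n_features)  # assumed value
    weights = np.exp(-(dist ** 2) / kernel_width ** 2)
    # White-box model: weighted least squares on the black-box outputs
    y = black_box(Z)
    A = np.hstack([Z, np.ones((n_samples, 1))])  # features plus intercept
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    # Interpretation: one slope per feature (intercept dropped)
    return coef[:-1]

def black_box(Z):
    # Hypothetical model in which feature 0 matters far more than feature 1
    return 3.0 * Z[:, 0] + 0.2 * np.sin(Z[:, 1])

rng = np.random.default_rng(0)
X_train = rng.normal(size=(500, 2))
instance = X_train[0]  # the sample we want to explain
coef = lime_tabular_sketch(black_box, X_train, instance)
print(coef)
```

The slope for feature 0 comes out much larger than the one for feature 1, which is the "feature A matters more than feature B" style of statement we are after.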
Now, in the case of images, the major blocker in the above approach is generating the artificial samples, since summary statistics are of no use here.
How to generate the artificial dataset?
The naive method would be to:
- Pick a random sample from the dataset
- Randomly turn some pixels on (keep them as they are) and off (set them to 0) to generate a new dataset
But this misses one point.
Usually in images, the objects that drive the model’s prediction (say, the dog and the cat in dog vs cat classification) span many pixels, not just one. So even if you turn off a pixel or two of such an object, the new sample will still look very similar to the sample we picked.
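To put a number on this, here is a small sketch of the naive approach. The 32×32 bright square standing in for an object and the 0.5 keep probability are assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy image: random background with a bright square playing the role of the object
image = rng.random((64, 64, 3))
image[16:48, 16:48, :] = 1.0
object_area = np.zeros((64, 64), dtype=bool)
object_area[16:48, 16:48] = True

def perturb_pixels(img, rng, keep_prob=0.5):
    """Naive perturbation: switch each pixel on or off independently."""
    keep = rng.random(img.shape[:2]) < keep_prob
    return img * keep[:, :, None], keep

# Fraction of the object's pixels that survives in each perturbed sample
surviving = np.array([
    perturb_pixels(image, rng)[1][object_area].mean() for _ in range(100)
])
print(surviving.min(), surviving.max())
```

Every perturbed sample still keeps roughly half of the object's pixels, so the object stays recognizable in all of them and no sample ever shows the model what happens when the object is truly gone.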
What we need to do is switch a pool of neighbouring pixels ON and OFF together to bring real variation into the dataset we are creating. Hence, we segment the image into multiple segments called superpixels and then turn these superpixels on and off to generate random samples.
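A minimal sketch of this idea follows. It assumes a regular grid of 8×8 blocks as the superpixels purely to keep the example dependency-free; a real segmentation algorithm (lime uses quickshift from scikit-image by default) produces irregular segments that follow the image content. The function names are made up for this example.

```python
import numpy as np

def grid_superpixels(height, width, block=8):
    """Label every pixel with the id of the block (superpixel) it falls into."""
    rows = np.arange(height) // block
    cols = np.arange(width) // block
    n_cols = -(-width // block)  # ceiling division
    return rows[:, None] * n_cols + cols[None, :]

def perturb_superpixels(img, segments, rng, n_samples=100):
    """Create samples by switching whole superpixels on (kept) or off (set to 0)."""
    n_segments = int(segments.max()) + 1
    states = rng.integers(0, 2, size=(n_samples, n_segments))
    states[0, :] = 1  # keep the first sample identical to the original image
    samples = np.stack([img * state[segments][:, :, None] for state in states])
    return samples, states

rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))
segments = grid_superpixels(64, 64)
samples, states = perturb_superpixels(image, segments, rng)
print(samples.shape, states.shape)
```

Each row of `states` is a binary vector with one entry per superpixel, and that vector (not the raw pixels) is what the white-box model is later trained on.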
Let’s jump into the code and interpret a baseline CNN for binary classification using LIME. We have the following two classes of samples to classify:
Class 0: Random images with a white rectangle of arbitrary size
Class 1: Random images
Let’s create a baseline CNN.
%matplotlib inline
import random

import matplotlib.pyplot as plt
import numpy as np
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from keras.models import Sequential
from lime import lime_image
from randimage import get_random_image
from skimage.segmentation import mark_boundaries
from sklearn.model_selection import train_test_split

# Prepare the dataset described above artificially
img_size = (64, 64)
training_dataset = []
training_label = []
for _ in range(200):
    img = get_random_image(img_size)
    # (a, b) is the top-left and (c, d) the bottom-right corner of the rectangle;
    # integer division keeps the randrange arguments integers
    a, b = random.randrange(0, img_size[0] // 2), random.randrange(0, img_size[0] // 2)
    c, d = random.randrange(img_size[0] // 2, img_size[0]), random.randrange(img_size[0] // 2, img_size[0])
    has_rectangle = random.choice([True, False])
    if has_rectangle:
        # Paint the rectangle with one constant value on all three channels
        img[a:c, b:d, :] = 100
    training_dataset.append(img)
    # Labels follow the class definitions above:
    # 0 = image with a rectangle, 1 = plain random image
    training_label.append(0 if has_rectangle else 1)

# Train the baseline CNN model
X_train, X_val, Y_train, Y_val = train_test_split(
    np.array(training_dataset).reshape(-1, 64, 64, 3),
    np.array(training_label).reshape(-1, 1),
    test_size=0.1,
    random_state=42,
)

epochs = 10
batch_size = 32

model = Sequential()
model.add(Conv2D(32, kernel_size=3, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
# Output layer: a single sigmoid unit giving P(class 1)
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, Y_train, validation_data=(X_val, Y_val), epochs=epochs, batch_size=batch_size, verbose=1)
I won’t walk through the above snippet, as it is largely self-explanatory.
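For anyone who wants to double-check the architecture, the layer shapes and parameter counts can be worked out by hand. The sketch below assumes 64×64×3 channels-last inputs and Keras' default stride of 1 for the convolution; the helper names are made up for this example.

```python
def conv2d_params(kernel_size, in_channels, filters):
    # One kernel_size x kernel_size x in_channels kernel plus one bias per filter
    return kernel_size * kernel_size * in_channels * filters + filters

def dense_params(n_inputs, n_units):
    # One weight per input-unit pair plus one bias per unit
    return n_inputs * n_units + n_units

height, width, channels = 64, 64, 3
conv_shape = (height, width, 32)            # padding='same' keeps height and width
pool_shape = (height // 2, width // 2, 32)  # 2x2 max pooling halves both
flat_size = pool_shape[0] * pool_shape[1] * pool_shape[2]

param_counts = {
    "conv2d": conv2d_params(3, channels, 32),
    "dense_hidden": dense_params(flat_size, 32),
    "dense_output": dense_params(32, 1),
}
total_params = sum(param_counts.values())
print(conv_shape, pool_shape, flat_size, param_counts, total_params)
```

Almost all of the parameters sit in the first dense layer, which is typical when a large feature map is flattened directly.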
Next, let’s bring LIME into the picture.
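x = 10

def predict_proba(images):
    # The model outputs only P(class 1), while LIME indexes labels by output column,
    # so return both columns: [P(class 0), P(class 1)]
    p1 = model.predict(images).reshape(-1, 1)
    return np.hstack([1 - p1, p1])

explainer = lime_image.LimeImageExplainer(random_state=42)
explanation = explainer.explain_instance(
    X_val[x],
    predict_proba,
    top_labels=2,
)
# Keep only the superpixels that push the prediction towards label 0
image, mask = explanation.get_image_and_mask(0, positive_only=True, hide_rest=True)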
The above code snippet requires some explanation:
- We initialize a LimeImageExplainer.
- This object can explain the model’s output for a particular sample through explain_instance. Here we took the sample at index 10 of the validation set.
- get_image_and_mask() returns the image along with a mask of the highlighted areas (superpixels) that led to the model’s prediction.
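To get a feel for what happens behind these calls, here is a heavily simplified sketch of the idea, not the library's actual implementation. It assumes a regular 8×8 grid as superpixels, a cosine-distance kernel with an arbitrarily chosen width, plain weighted least squares as the white-box model, and a toy `toy_predict` function in place of the CNN.

```python
import numpy as np

def grid_superpixels(height, width, block=8):
    # Stand-in segmentation: label each pixel with the id of its 8x8 block
    rows, cols = np.arange(height) // block, np.arange(width) // block
    return rows[:, None] * (-(-width // block)) + cols[None, :]

def explain_with_superpixels(img, segments, predict_fn, n_samples=500,
                             kernel_width=0.25, num_features=4, seed=0):
    rng = np.random.default_rng(seed)
    n_segments = int(segments.max()) + 1
    # Random on/off state per superpixel; sample 0 is the untouched image
    states = rng.integers(0, 2, size=(n_samples, n_segments))
    states[0, :] = 1
    perturbed = np.stack([img * s[segments][:, :, None] for s in states])
    scores = predict_fn(perturbed)
    # Cosine distance between each state vector and the all-on vector
    dist = 1.0 - np.sqrt(states.sum(axis=1) / n_segments)
    weights = np.exp(-(dist ** 2) / kernel_width ** 2)
    # White-box model: weighted linear regression on the on/off states
    A = np.hstack([states, np.ones((n_samples, 1))])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], scores * sw, rcond=None)
    # Superpixels with the largest positive weights form the highlighted mask
    top = np.argsort(coef[:-1])[-num_features:]
    return top, np.isin(segments, top)

def toy_predict(batch):
    # Hypothetical black box: its score depends only on the patch [16:32, 16:32]
    return batch[:, 16:32, 16:32, :].mean(axis=(1, 2, 3))

rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))
image[16:32, 16:32, :] = 1.0
segments = grid_superpixels(64, 64)
top, mask = explain_with_superpixels(image, segments, toy_predict)
print(sorted(top.tolist()))
```

The mask ends up covering exactly the patch the toy model looks at, which is the same kind of output get_image_and_mask() gives us for the CNN.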
Let’s see a few samples that were actually class 1 (random images) but were predicted as class 0 (random images with white rectangles).
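Such samples can be picked out by comparing thresholded predictions with the true labels. A small sketch, assuming a 0.5 decision threshold and using made-up probabilities in place of the output of model.predict(X_val):

```python
import numpy as np

def misclassified_as_zero(probs, y_true, threshold=0.5):
    """Indices of samples whose true label is 1 but whose predicted label is 0."""
    probs = np.asarray(probs).reshape(-1)
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = (probs >= threshold).astype(int)
    return np.flatnonzero((y_true == 1) & (y_pred == 0))

# Made-up P(class 1) values and true labels, shaped like the Keras outputs
probs = np.array([[0.9], [0.2], [0.4], [0.7], [0.1]])
labels = np.array([[1], [1], [0], [0], [1]])
idx = misclassified_as_zero(probs, labels)
print(idx)  # [1 4]
```

Each returned index can then be used as x in the LIME snippet above.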
As can be observed in the image on the right, the highlighted region is bordered in yellow. This image has label 1 but was predicted as 0: as you can see, the highlighted region does look like a rectangle, which confuses the model.
Similar to the above example, this one has also been predicted by the model as class 0. On interpreting it, we can see where the problem lies: again, some sort of shape has been misread by the model as a rectangle.
So you see how we can understand the actual problem that leads the model to misclassify, and this is why interpretable and explainable AI is so important.