Interpreting CNNs using LIME

moving ahead from tabular data

Mehul Gupta
Data Science in your pocket

--

Till now we have covered a lot of ground in Interpretable AI, be it the basics, model-specific interpretation methods such as those for Linear Regression and Decision Trees, or model-agnostic methods like PDPs, LIME and SHAP.

But all these were on tabular data.

It’s time we use a more complex kind of data, say images!

My debut book “LangChain in your Pocket” is out now

Images differ a lot from tabular datasets (obviously). If you remember, in every interpretation method we have discussed so far, we talked about interpreting models in terms of some feature-importance metric or visualization, e.g. feature ‘A’ has more influence than feature ‘B’ on the prediction.

But we don’t have any specific features in images that we can name. How would interpretation work here?

Basically, we will be highlighting regions of importance in the image that lead to the model’s prediction. But let’s first understand how the LIME method is tweaked to handle images as well.

You might wish to revisit how LIME works for tabular data first

So, let me summarize LIME for you:

  • Generate summary stats for the training dataset
  • Generate an artificial dataset using the summary stats
  • Pick a random sample from the original dataset
  • Assign weights to the samples in the artificial dataset depending upon their closeness/proximity to the picked sample
  • Train a white-box model on these weighted samples
  • Interpret the white-box model
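
For intuition, below is a rough, hypothetical sketch of that loop for tabular data (the actual LimeTabularExplainer additionally handles discretization, kernel choice and feature selection; the function and variable names here are made up purely for illustration):

import numpy as np
from sklearn.linear_model import Ridge

def lime_tabular_sketch(X_train, black_box_predict, sample, num_samples=1000, kernel_width=0.75):
    # 1. summary stats of the training data
    mean, std = X_train.mean(axis=0), X_train.std(axis=0)
    # 2. artificial dataset drawn around those stats
    Z = np.random.normal(mean, std, size=(num_samples, X_train.shape[1]))
    # 3./4. weight artificial samples by proximity to the picked sample
    dist = np.linalg.norm(Z - sample, axis=1)
    weights = np.exp(-(dist ** 2) / kernel_width ** 2)
    # 5. train a white-box (linear) model on the weighted samples
    white_box = Ridge(alpha=1.0)
    white_box.fit(Z, black_box_predict(Z), sample_weight=weights)
    # 6. interpret: the coefficients act as local feature importances
    return white_box.coef_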

Now, in the case of images, the major blocker in the above approach is generating the random samples, as summary stats are of no use here.

How to generate the artificial dataset?

The naive method would be to:

  • Pick a random sample from the dataset
  • Randomly turn some pixels on (keep as is) and off (set to 0) to generate a new dataset

But this misses one point.

Usually in images, the objects (say for dog-vs-cat classification, the objects Dog and Cat) that lead to the model’s prediction span multiple pixels, not just one. So even if you turn off a pixel or two inside that object, the perturbed images will still look very similar to the sample we picked.

What we need to do is turn a pool of neighbouring pixels ON and OFF together to bring real randomness into the dataset we are creating. Hence, we segment the image into multiple regions called superpixels and then turn these superpixels on and off to generate the random samples.
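
To make this concrete, here is a minimal, hypothetical sketch of generating perturbed samples from superpixels (lime_image does this internally; it segments with quickshift by default, while slic is used below purely for illustration):

import numpy as np
from skimage.segmentation import slic

def perturb_with_superpixels(img, num_perturbations=150, n_segments=50):
    segments = slic(img, n_segments=n_segments)      # superpixel id for every pixel
    superpixel_ids = np.unique(segments)
    perturbed_images, on_off_vectors = [], []
    for _ in range(num_perturbations):
        # randomly decide which superpixels stay ON (1) or are turned OFF (0)
        on_off = np.random.randint(0, 2, size=len(superpixel_ids))
        mask = np.isin(segments, superpixel_ids[on_off == 1])
        perturbed_images.append(img * mask[..., np.newaxis])   # OFF superpixels set to 0
        on_off_vectors.append(on_off)
    return np.array(perturbed_images), np.array(on_off_vectors)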

Let’s jump into the code to interpret a baseline CNN using LIME for binary classification. So, we have the below samples to classify.

Class 0: Random images with white rectangles of arbitrary shapes

Class 1: Random images

Let’s create a baseline CNN

%matplotlib inline
import matplotlib.pyplot as plt
import random
import numpy as np
from sklearn.model_selection import train_test_split
from keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from keras.models import Sequential
from randimage import get_random_image
from lime import lime_image
from skimage.segmentation import mark_boundaries
# preparing the above dataset artificially
training_dataset = []
training_label = []
for x in range(200):

    img_size = (64, 64)
    img = get_random_image(img_size)

    # random top-left (a, b) and bottom-right (c, d) corners for the rectangle
    a, b = random.randrange(0, img_size[0] // 2), random.randrange(0, img_size[0] // 2)
    c, d = random.randrange(img_size[0] // 2, img_size[0]), random.randrange(img_size[0] // 2, img_size[0])

    value = random.sample([True, False], 1)[0]
    if value == False:
        # paint a rectangle (rendered white) across all three channels
        img[a:c, b:d, 0] = 100
        img[a:c, b:d, 1] = 100
        img[a:c, b:d, 2] = 100

    training_dataset.append(img)
    training_label.append(value)
# training a baseline CNN model
# value==False means a rectangle was drawn, so class 0 = rectangle images, class 1 = plain random images
training_label = [int(x) for x in training_label]
X_train, X_val, Y_train, Y_val = train_test_split(
    np.array(training_dataset).reshape(-1, 64, 64, 3),
    np.array(training_label).reshape(-1, 1),
    test_size=0.1, random_state=42)

epochs = 10
batch_size = 32

model = Sequential()
model.add(Conv2D(32, kernel_size=3, padding='same', activation='relu', input_shape=(64, 64, 3)))
model.add(MaxPooling2D(pool_size=2))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
# output layer
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train, Y_train, validation_data=(X_val, Y_val),
          epochs=epochs, batch_size=batch_size, verbose=1)

I won’t be explaining the above snippet as it is self-explanatory.
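
Before moving on, it can be worth a quick sanity check that the model actually learned to separate the two classes, for example:

# hold-out accuracy on the validation split
val_loss, val_acc = model.evaluate(X_val, Y_val, verbose=0)
print(f"validation accuracy: {val_acc:.3f}")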

Next, let’s bring in LIME into the picture

x = 10
explainer = lime_image.LimeImageExplainer(random_state=42)

# our model has a single sigmoid output, so wrap it to return
# probabilities for both classes, as LIME expects one column per class
def predict_fn(images):
    p1 = model.predict(images)
    return np.hstack([1 - p1, p1])

explanation = explainer.explain_instance(
    X_val[x].astype('double'),
    predict_fn,
    top_labels=2)

image, mask = explanation.get_image_and_mask(0, positive_only=True,
                                             hide_rest=True)

The above code snippet requires some explanation:

  • We initialized a LimeImageExplainer
  • This object can explain the output for a particular sample using explain_instance. Here we took the 10th sample from the validation set and wrapped model.predict so that it returns a probability per class
  • get_image_and_mask() returns the highlighted areas that led to the model’s prediction, alongside the original image (visualized below)
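
To actually see the yellow-bordered explanation shown in the figures below, one way (a sketch, using the mark_boundaries import from earlier) is to plot the original and explained images side by side:

# left: original validation image, right: LIME mask with yellow superpixel borders
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.imshow(X_val[x])
ax1.set_title("Original")
ax2.imshow(mark_boundaries(image, mask, color=(1, 1, 0)))
ax2.set_title("LIME explanation")
plt.show()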

Let’s see a few samples which were actually 1 (random images) but predicted as 0 (random images with white boxes).

Left (original) vs Right (highlighted region with yellow)

As can be observed in the image on the right, the highlighted region is bordered in yellow. This image has label=1 but was predicted as 0: as you can see, the highlighted region does look like a rectangle, which confused the model.

Left (original) vs Right (yellow border added)

Similar to the above example, this one was also predicted by the model as class=0. On interpreting it, we can see where the problem lies!! Again, some sort of shape has been misinterpreted by the model.

So you can see how we can pinpoint the actual problem with the model that leads to misclassification, and this is why interpretable and explainable AI is so important.

That’s a wrap !!
