Explainable AI (for image classification) on Vertex AI

Anurag Bhatia
Maven Wave Data Science Blog
8 min read · Feb 1, 2023

What is XAI?: Explainable AI is a subfield of machine learning where the focus is on understanding what goes on under the hood when a machine (or model) learns something, rather than obsessing over better performance metrics even if that means increasing model complexity several fold.

Image credit: Analytics Vidhya

Why should we care?:

  1. In heavily-regulated industries like finance and healthcare, it’s not just good enough to have good model performance metrics. Since it involves money — or even more so if it’s a matter of life and death — we also need to be in a position to interpret and explain the model behavior to other stakeholders (e.g. business teams, regulators, etc.).
  2. Troubleshooting: What if the model performance metrics are not good enough? It can be super helpful if we are able to a) get underneath the model, b) dig deeper into how the model is interpreting the annotated data, and finally c) check whether that aligns with what we assumed the model would learn. That is where model explainability comes in.

Levels of model explainability: There are two different ways in which model inference results can be explained. The global level (sometimes also referred to as the model level) is a summary of which features carry more weight than others. This is useful information if the dataset happens to be in a structured, tabular format, e.g. a chart showing the measured importance of each column..

Image credit: yellowbrick
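To make the global level concrete, here is a minimal sketch (my own, not from the original post) that trains a quick model on a tabular dataset and charts per-column importance; the yellowbrick chart above conveys the same idea, and the dataset and model below are purely illustrative choices.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

# Train a quick model on a tabular dataset and chart the importance of each column.
data = load_breast_cancer()
model = RandomForestClassifier(random_state=42).fit(data.data, data.target)

plt.barh(data.feature_names, model.feature_importances_)
plt.xlabel("feature importance (global level)")
plt.tight_layout()
plt.show()
```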

The global level is a generic summary across all predictions from a trained model, while the instance level provides details specific to each prediction from that model. e.g. In an image-classification problem, instance-level explainability can narrow down to the specific pixels in an image that decide which class gets assigned to that image. By definition, such pixel info for one image is completely independent from that of the next image.

Though LIME and KernelSHAP are also used to interpret the results of image models, we are primarily going to focus on the ‘Integrated Gradients’ (IG) and ‘Explanation with Ranked Area Integrals’ (XRAI) methods here.

To begin with IG, a baseline image is created, which usually consists of all pixels set to 0, all set to 1, or random values between 0 and 1. In a series of steps thereafter, each pixel value is moved from its value in the baseline image towards its counterpart value in the actual input image. For each such step/change, a gradient is calculated. These gradients are integrated, and the pixels with the highest integrated gradients are superimposed on top of the actual image, to highlight which specific pixels were considered most relevant in deciding the class assigned to that image. i.e. The image pixels the model finds most influential are given the most weightage and depicted accordingly. e.g.

Image: https://github.com/CVxTz/IntegratedGradientsPytorch

By definition, that implies that two of the most important factors in the IG method are the following (a minimal code sketch follows the list):

  1. the baseline image (pixel values) chosen. Reason: the model explanation is relative to the baseline. Hence, one option is to decide baseline pixel values based on the typical images in our image-classification problem. e.g. random values for X-ray-type images (where it’s safe to assume that blackish regions already represent something very specific, i.e. tissue in this case).
  2. num-integral-steps, i.e. the number of steps between the baseline image and the input image used for the numerical integration mentioned above. The higher this value, the better the integral approximation and, usually, the better the explanation. The flip side, of course, is that the operation becomes more computationally intensive.
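Here is the minimal sketch promised above: a self-contained take on the IG idea in TensorFlow. It is not what Vertex AI runs internally; the function and argument names are my own, and num_steps plays the role of num-integral-steps.

```python
import tensorflow as tf

def integrated_gradients(model, image, baseline, target_class, num_steps=50):
    """Approximate IG attributions for one image (H, W, C) against a baseline."""
    # Interpolate from the baseline towards the actual image in num_steps steps.
    alphas = tf.linspace(0.0, 1.0, num_steps + 1)[:, None, None, None]
    interpolated = baseline[None, ...] + alphas * (image - baseline)[None, ...]

    # Gradient of the target-class score w.r.t. each interpolated image.
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = model(interpolated)[:, target_class]
    grads = tape.gradient(scores, interpolated)

    # Integrate the gradients (trapezoidal rule) and scale by the input delta.
    avg_grads = tf.reduce_mean((grads[:-1] + grads[1:]) / 2.0, axis=0)
    return (image - baseline) * avg_grads
```

Summing the returned attributions across the channel axis gives a per-pixel heatmap like the one shown above.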

XRAI: XRAI builds on top of pixel-level attributions such as those produced by IG. The image is first over-segmented into patches using relatively well-known, simpler segmentation functions, so that the resulting regions are (hopefully) not too different from the objects the originally trained model actually responds to. For each patch in the image, the pixel-level attributions are integrated, and patches are clubbed into regions on the basis of the values of their integrated gradients. Regions are then ranked, and can be removed one by one while the trained model is invoked multiple times, to ascertain the importance of each region during inference.
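The real XRAI algorithm grows regions greedily and works across multiple segmentation scales, but a rough sketch of the core idea (my own simplification, reusing the integrated_gradients() helper above together with scikit-image’s felzenszwalb segmentation) could look like this:

```python
import numpy as np
from skimage.segmentation import felzenszwalb

def xrai_style_regions(image, pixel_attributions, top_k=3):
    """Group per-pixel attributions into ranked regions (simplified XRAI idea).

    image: (H, W, 3) float array in [0, 1]
    pixel_attributions: (H, W) array, e.g. channel-summed IG attributions
    """
    # Over-segment the image into small patches with a standard segmentation function.
    segments = felzenszwalb(image, scale=100, sigma=0.5, min_size=50)

    # Integrate (sum) the attributions inside each patch and rank the patches.
    region_scores = {
        seg_id: float(pixel_attributions[segments == seg_id].sum())
        for seg_id in np.unique(segments)
    }
    ranked = sorted(region_scores, key=region_scores.get, reverse=True)

    # Keep only the top-k most salient regions as a boolean mask.
    return np.isin(segments, ranked[:top_k])
```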

Regarding comparisons between IG and XRAI, their respective common use-cases and when to choose one over the other, here is a pretty concise summary ..

Source: Google Cloud

Alright. Enough theory. Time to dive into the code. We have chosen fire detection as our objective, and it’s a binary classification problem since our manually curated dataset has only two classes: Fire and No-Fire. For the purpose of this blog, we assume we have already trained a model using Keras, saved its checkpoints in Cloud Storage (GCS), and are now interested in knowing more about why our trained model behaves the way it does. We start by loading our checkpoints and specifying a few other details about what the existing trained model expects as an input image (for getting a prediction).
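The original snippet is not reproduced here; a sketch of this step could look like the following, assuming the checkpoints were saved in SavedModel format and with hypothetical paths, image sizes, and class names.

```python
import tensorflow as tf

# Hypothetical values -- adjust to your own bucket and training setup.
CHECKPOINT_DIR = "gs://my-bucket/fire-detection/chkpts"
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 224, 224, 3
CLASS_NAMES = ["Fire", "No-Fire"]

# Restore the trained Keras model from the checkpoints saved during training.
model = tf.keras.models.load_model(CHECKPOINT_DIR)
model.summary()
```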

Explainability methods rely on their ability to 1) create multiple (perturbed) versions of the original input image and 2) invoke the trained model and get predictions for each of those image versions. Hence, we’ll create model signatures to get this done. The first one performs the usual preprocessing steps for an image file, e.g. read the jpg file from GCS, get the bytes, apply padding, convert it to a tensor, etc..
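A sketch of such a preprocessing signature, with the tensor name and image sizes carried over from (and as hypothetical as) the previous snippet:

```python
@tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.string)])
def preprocess_filenames(filenames):
    """Read JPEG files from GCS and turn them into model-ready tensors."""
    def _read_one(filename):
        img_bytes = tf.io.read_file(filename)                       # raw bytes from GCS
        img = tf.image.decode_jpeg(img_bytes, channels=IMG_CHANNELS)
        img = tf.image.convert_image_dtype(img, tf.float32)         # scale to [0, 1]
        img = tf.image.resize_with_pad(img, IMG_HEIGHT, IMG_WIDTH)  # pad/resize
        return img

    images = tf.map_fn(_read_one, filenames, fn_output_signature=tf.float32)
    return {"input_images": images}
```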

We create a second one which takes the output of the first one as its input, invokes the model, gets the class with the highest probability, and assigns that label to the input image accordingly..
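A matching prediction signature could be sketched as follows; again, the output keys are my own naming, not necessarily those used in the original post:

```python
@tf.function(input_signature=[
    tf.TensorSpec([None, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS], dtype=tf.float32)
])
def predict_images(input_images):
    """Run the trained model and attach the winning class label to each image."""
    probs = model(input_images)                       # (batch, 2) class probabilities
    top_idx = tf.math.argmax(probs, axis=1)           # index of the most likely class
    return {
        "probabilities": probs,
        "class_index": top_idx,
        "class_label": tf.gather(tf.constant(CLASS_NAMES), top_idx),
    }
```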

We then stitch together everything we have done so far. Summary: take our existing trained model in GCS as the starting point, create the model signatures required to make the model compatible with the requirements of Explainable AI, and finally save it (locally, for now).
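Stitching it together might look like this, with a hypothetical local export path:

```python
EXPORT_PATH = "fire_model_xai"  # hypothetical local directory

# Save the model with both signatures attached: the default one serves predictions,
# while the preprocessing signature lets the explainer work on image tensors.
model.save(
    EXPORT_PATH,
    signatures={
        "preprocess": preprocess_filenames.get_concrete_function(),
        "serving_default": predict_images.get_concrete_function(),
    },
)
```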

Until now, we have adjusted our trained model for Explainable AI, but we have not yet explicitly specified which method to use or the values of its required arguments (e.g. the baseline pixels to be used, the number of steps involved, etc.). Now is the right time to do so. We do it in the form of a metadata file (in JSON format), created programmatically through a built-in method provided by the Explainable AI SDK..
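A sketch of that step, assuming the TF2 flavor of the SDK’s SavedModelMetadataBuilder; the generated dict can be tweaked (e.g. to set input baselines) before it is written out:

```python
import json
from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder

# Inspect the SavedModel's signatures and generate explanation metadata for them.
builder = SavedModelMetadataBuilder(EXPORT_PATH)
metadata = builder.get_metadata()   # dict describing the inputs/outputs to explain

# Persist it as explanation_metadata.json alongside the exported model.
with open(f"{EXPORT_PATH}/explanation_metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)
```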

We write a custom shell script to deploy our model on Vertex AI. It involves a) creating a model, b) creating a version of that model, and finally c) deploying it on a Vertex AI endpoint..
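The script itself is not reproduced here; as a rough equivalent, the same steps could be sketched with the Vertex AI Python SDK as below. The project, bucket, container image, and display names are placeholders, and the keys in the generated metadata file may need adjusting to match the ExplanationMetadata proto fields.

```python
import json
from google.cloud import aiplatform

aiplatform.init(project="my-gcp-project", location="us-central1")  # placeholders

# Explanation spec: Integrated Gradients plus the metadata file generated earlier.
with open(f"{EXPORT_PATH}/explanation_metadata.json") as f:
    metadata_dict = json.load(f)  # may need key renames to match the proto fields

explanation_metadata = aiplatform.explain.ExplanationMetadata(metadata_dict)
explanation_parameters = aiplatform.explain.ExplanationParameters(
    {"integrated_gradients_attribution": {"step_count": 50}}  # num-integral-steps
)

# Upload the model (assumes the exported directory was copied to GCS first).
vertex_model = aiplatform.Model.upload(
    display_name="fire-detection-xai",
    artifact_uri="gs://my-bucket/fire_model_xai",
    serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
    explanation_metadata=explanation_metadata,
    explanation_parameters=explanation_parameters,
)

# Create an endpoint and deploy the model to it.
endpoint = aiplatform.Endpoint.create(display_name="fire-detection-endpoint")
vertex_model.deploy(endpoint=endpoint, machine_type="n1-standard-4")
```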

And finally, we prepare a JSON payload request and save the response back in another JSON file..

request json
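A sketch of the request/response handling with the Python SDK, reusing names from the earlier snippets; the instance key input_images matches the (hypothetical) signature sketched above, and the test image path is a placeholder:

```python
import json

# Preprocess a test image locally so the instance carries raw pixel values.
filenames = tf.constant(["gs://my-bucket/test/fire_001.jpg"])  # placeholder test image
images = preprocess_filenames(filenames)["input_images"].numpy().tolist()
instances = [{"input_images": img} for img in images]

# Keep a copy of the request payload, as in the post.
with open("request.json", "w") as f:
    json.dump({"instances": instances}, f)

# Ask the deployed endpoint for predictions *and* attributions.
response = endpoint.explain(instances=instances)

# Save the explanations for later visualization.
with open("response.json", "w") as f:
    json.dump([type(e).to_dict(e) for e in response.explanations], f)
```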

Finally, we use matplotlib to visualize the results of the model explainability we had set out to achieve in the first place. Here are the results..
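For the visualization itself, a minimal matplotlib sketch could look like the one below, assuming a per-pixel attribution mask has already been extracted from the saved response (the exact response layout depends on the explanation metadata and visualization settings):

```python
import matplotlib.pyplot as plt

def show_attributions(original, attribution_mask, title="Fire"):
    """Show the input image next to the same image with attributions overlaid.

    original: (H, W, 3) array in [0, 1]; attribution_mask: (H, W) attribution scores.
    """
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(original)
    axes[0].set_title(title)
    axes[1].imshow(original)
    axes[1].imshow(attribution_mask, cmap="viridis", alpha=0.5)  # highlight hot pixels
    axes[1].set_title("attributions")
    for ax in axes:
        ax.axis("off")
    plt.show()
```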

How to interpret this?: Results for images 2 & 3 (from left) look fine, since maximum importance is given to the part of the image where the burning flame is strongest. Results for images 4 & 5 are mixed: though they rightly focus on the bright flame areas, they also include parts they shouldn’t have, e.g. a portion of the trees and the individuals in images 4 and 5 respectively. Most importantly, the result for image 1 certainly suggests something is not right. The upper half of the electricity tower seems to be getting the most attention, rather than the fire engulfing the field/hill elsewhere in the picture.

What exactly is the value-add?: Model explainability can guide us towards the next steps. e.g. Given our results in the image above, even if the model prediction is correct (as is the case with all 5 input images above), the model may have learnt something very different from what we thought it would. Some of our options, going forward, are:

  • Check whether there are a lot of images with electricity towers categorized as “Fire”. In that case, we might have inadvertently introduced a bias into our model.
  • Compare the difference after increasing the value of ‘num-integral-steps’.
  • Check explainability results for 1) images with the “No Fire” label as well, and 2) images with the “Fire” label taken in daylight (rather than just at night, as is the case in our rather limited sample of 5 input images chosen above).
  • Add more images for model training, preferably in different conditions.
  • Do some more data augmentation before/while training.

Hope this was helpful in explaining the concept of Explainable-AI and its implementation in a specific scenario. For more details about this fire-detection project, take a look at my github repository.

Happy model explaining..

References: The biggest inspiration for this blog has been the awesome computer vision book by Valliappa Lakshmanan, Martin Görner, and Ryan Gillard. I found their code snippets super helpful and applied them to an image classification problem for an entirely different use case and dataset, i.e. fire detection.
