Explainable Artificial Intelligence
Interpreting Deep Learning Models for Computer Vision
Interpreting Convolutional Neural Network Models built with TensorFlow
Introduction
Artificial Intelligence (AI) is no longer a field restricted to research papers and academia. Businesses and organizations across diverse domains are building large-scale applications powered by AI. Two questions follow naturally: “Do we trust decisions made by AI models?” and “How does a machine learning or deep learning model make its decisions?” Interpreting machine learning or deep learning models is often overlooked in the data science lifecycle, since data scientists and machine learning engineers tend to be more involved with getting a model up and running and pushing it out to production.
However, unless we are building a machine learning model for fun, accuracy is not the only thing which counts! Business stakeholders and consumers will often ask the tough questions of fairness, accountability and transparency of any machine learning model being used to solve real-world problems!
There is a whole branch of research on explainable artificial intelligence (XAI) now! While covering XAI in full is beyond the scope of this article, if you are interested you can always read my series on XAI.
In this article, we will look at concepts, techniques and tools to interpret deep learning models used in computer vision, to be more specific — convolutional neural networks (CNNs). We will take a hands-on approach and implement our deep learning models using Keras and TensorFlow 2.0 and leverage open-source tools to interpret decisions made by these models! In short, the purpose of the article is to find out — what do deep learning models really see?
Convolutional Neural Networks
The most popular deep learning models leveraged for computer vision problems are convolutional neural networks (CNNs)!
CNNs typically consist of multiple convolution and pooling layers, which help the deep learning model automatically extract relevant features from visual data like images. Due to this multi-layered architecture, CNNs learn a robust hierarchy of features that is largely translation invariant and, to a lesser degree, tolerant of small rotations and other spatial distortions.
The key operations in a CNN model are depicted in the figure above. Any image can be represented as a tensor of pixel values. The convolution layers extract features from this image, producing feature maps. Shallower layers (closer to the input data) learn very generic features like edges and corners. Deeper layers (closer to the output layer) learn very specific features pertaining to the input image. The following graphic helps summarize the key aspects of any CNN model.
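To make the convolution and pooling operations described above a little more concrete, here is a minimal NumPy sketch. It is only an illustration under stated assumptions: the 6x6 image and the vertical-edge kernel are made up, the function names are hypothetical, and a real CNN would learn its kernels during training rather than use a hand-picked one.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and sum the products.

    Like most deep learning frameworks, this is technically cross-correlation:
    the kernel is not flipped.
    """
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def relu(x):
    """Keep positive responses, zero out the rest."""
    return np.maximum(x, 0.0)

def max_pool2d(fmap, size=2):
    """Non-overlapping max pooling; rows/columns that do not fit are dropped."""
    h = (fmap.shape[0] // size) * size
    w = (fmap.shape[1] // size) * size
    blocks = fmap[:h, :w].reshape(h // size, size, w // size, size)
    return blocks.max(axis=(1, 3))

# Made-up 6x6 grayscale image: dark left half, bright right half.
image = np.zeros((6, 6))
image[:, 3:] = 1.0

# Hand-picked vertical-edge kernel (an assumption for illustration).
kernel = np.array([[-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0]])

feature_map = relu(conv2d_valid(image, kernel))  # responds only near the edge
pooled = max_pool2d(feature_map)                 # coarser summary of the map
```

The feature map is non-zero only in the columns where the dark-to-bright edge sits, which is the kind of generic edge detector the shallow layers tend to learn.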
Since we are only concerned with understanding how CNN models perceive images, we won’t be training any CNN models from scratch here. Rather, we will be leveraging the power of transfer learning and pre-trained CNN models in our examples.
A pre-trained model like VGG-16 has already been trained on a huge dataset (ImageNet) with a lot of diverse image categories, so it should have already learned a robust hierarchy of features. Hence the model, having learned a good representation of features from over a million images belonging to 1,000 different categories, can act as a good feature extractor for new images in computer vision problems.
Interpreting CNN Models — What does a deep learning model really see?
Here’s the interesting part: can we really unbox the opacity of a seemingly black-box CNN model and understand what’s going on under the hood, and what the model really sees when it looks at an image? There is a wide variety of techniques and tools for interpreting decisions made by vision-based deep learning models. Some of the major techniques covered in this article are depicted as follows.
Let’s look at each of these techniques and interpret some deep learning CNN-based models built with Keras and TensorFlow.
SHAP Gradient Explainer
This technique combines ideas from Integrated Gradients, SHapley Additive exPlanations (SHAP) and SmoothGrad. It explains model decisions using expected gradients (an extension of integrated gradients), a feature attribution method designed for differentiable models and based on an extension of Shapley values to infinite player games. We will use the shap framework here for this technique.
Integrated gradients values are a bit different from SHAP values and require a single reference value to integrate from. In the SHAP Gradient Explainer, however, expected gradients reformulates the integral as an expectation and combines that expectation with sampling reference values from the background dataset. Thus, this technique uses an entire dataset as the background distribution instead of just a single reference value. Let’s try and implement this on some sample images. To get started, we load up some basic dependencies and model visualization function utilities.
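As an aside on the idea itself, separate from the shap library calls, the expectation described above can be sketched in a few lines of NumPy. This is a rough Monte Carlo sketch and not the shap implementation: the function name, the toy quadratic model, its weights, and the random background set are all assumptions made purely for illustration. Each sample draws one reference from the background and one point on the straight path from that reference to the input, then weights the gradient at that point by the input-minus-reference difference.

```python
import numpy as np

def expected_gradients(grad_fn, x, background, n_samples=200, seed=0):
    """Monte Carlo estimate of expected-gradients attributions for input `x`.

    grad_fn(z) must return dF/dz for the model output F being explained.
    """
    rng = np.random.default_rng(seed)
    total = np.zeros_like(x, dtype=float)
    for _ in range(n_samples):
        ref = background[rng.integers(len(background))]  # reference from the background set
        alpha = rng.uniform()                            # random position on the path ref -> x
        point = ref + alpha * (x - ref)
        total += (x - ref) * grad_fn(point)
    return total / n_samples

# Hypothetical toy model F(z) = sum(w * z**2), so dF/dz = 2 * w * z.
w = np.array([0.5, -1.0, 2.0, 1.0])
grad_fn = lambda z: 2.0 * w * z

# Made-up background distribution and input to explain.
background = np.random.default_rng(1).normal(size=(50, 4))
x = np.array([1.0, 2.0, -1.0, 0.5])
attributions = expected_gradients(grad_fn, x, background)
```

With a single reference in the background, this reduces to a sampled version of integrated gradients, which is the contrast the paragraph above draws.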
The next step is to load a pre-trained VGG-16 model, which was trained previously on the ImageNet dataset. We can do that easily with the following code.
Once our CNN model is loaded, we will now load a small dataset of images which can be used as a background distribution and we will use four sample images for model interpretation.
We have four different types of images including a picture of one of my cats! Let’s first look at our model’s prediction for each of these images.
[['n02999410', 'chain'],
['n01622779', 'great_grey_owl'],
['n03180011', 'desktop_computer'],
['n02124075', 'Egyptian_cat']]
Let’s start by trying to visualize what the model sees in the 7th layer of the neural network (typically one of the shallower layers in the model).
This gives us some good perspective on the top two predictions made by the model for each image and why it made those decisions. Let’s take a look at one of the deeper layers in the VGG-16 model and visualize the 14th layer’s decisions.
Now you can see the model gets stronger and more confident in its prediction decisions, based on the SHAP value intensities. We also get a sense of why the model predicts a screen vs. a desktop_computer, where it also looks at the keyboard, and why it predicts my cat as a tabby: because of specific features like the nose, whiskers, facial patterns and so on!
Interpreting CNN Models built with TensorFlow 2.0
For the remaining four techniques, we will leverage a pre-trained model using TensorFlow 2.0 and use the popular open-source framework tf-explain. The idea here is to look at different model interpretation techniques for CNNs.
Load Pre-trained CNN Model
Let’s load one of the most complex pre-trained CNN models out there, the Xception model, which is reported to be slightly better than the Inception V3 model. Let’s start by loading the necessary dependencies and our pre-trained model.
You can see from the model architecture snapshot above that this model has a total of 14 blocks with multiple layers in each block. Definitely one of the deeper CNN models!
Model Predictions on Sample Image
We will reuse the sample image of my cat and make the top-5 predictions with our Xception model. Let’s load our image first before making predictions.
Let’s now make the top-5 predictions on this image using our Xception model. We will pre-process the image before inference.
[[('n02124075', 'Egyptian_cat', 0.80723596),
('n02123159', 'tiger_cat', 0.09508163),
('n02123045', 'tabby', 0.042587988),
('n02127052', 'lynx', 0.00547999),
('n02971356', 'carton', 0.0014547487)]]
Interesting predictions, at least the top 3 here definitely are relevant!
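For intuition about what a top-k listing like the one above involves, here is a small NumPy sketch: turn raw class scores into probabilities with a softmax, then sort and keep the k largest. The label list and the score values are made up for illustration, and the function names are hypothetical; this is not the Keras utility itself, just the underlying idea.

```python
import numpy as np

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

def top_k(probs, labels, k=5):
    """Return the k most probable (label, probability) pairs, best first."""
    order = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in order]

# Made-up labels and raw scores (assumptions, not real model outputs).
labels = ["Egyptian_cat", "tiger_cat", "tabby", "lynx", "carton", "screen"]
logits = np.array([4.0, 2.0, 1.0, -1.0, -2.0, -3.0])

top3 = top_k(softmax(logits), labels, k=3)
```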
Activation Layer Visualizations
This technique is typically used to visualize how a given input comes out of specific activation layers. The key idea is to explore which feature maps are getting activated in the model and visualize them. Usually this is done by looking at each specific layer. The following code showcases activation layer visualizations for one of the layers in Block 2 of the CNN model.
This kind of gives us an idea of which feature maps are getting activated and what parts of the image they typically focus on.
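Once the feature maps of a layer are available as an array, a common way to inspect them is to tile them into a single image. The sketch below shows one way this could be done with NumPy; the random array is only a stand-in for real layer activations, and the helper name and grid layout are assumptions rather than what any particular library does.

```python
import numpy as np

def activation_grid(activations, n_cols=4):
    """Tile an (H, W, C) stack of feature maps into one 2-D image, one tile per channel."""
    h, w, c = activations.shape
    n_rows = int(np.ceil(c / n_cols))
    grid = np.zeros((n_rows * h, n_cols * w))
    for idx in range(c):
        fmap = activations[:, :, idx]
        span = fmap.max() - fmap.min()
        # Rescale each map to [0, 1] on its own so weakly activated channels stay visible.
        tile = (fmap - fmap.min()) / span if span > 0 else np.zeros_like(fmap)
        row, col = divmod(idx, n_cols)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = tile
    return grid

# Random stand-in for the output of one layer: 6 feature maps of size 5x5.
activations = np.random.default_rng(0).normal(size=(5, 5, 6))
grid = activation_grid(activations, n_cols=4)  # ready to pass to an image plotting call
```

Rescaling each map independently makes the tiles easy to compare visually, at the cost of hiding differences in absolute activation strength between channels.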
Occlusion Sensitivity
The idea of interpretation using occlusion sensitivity is quite intuitive. We try to visualize how parts of the image affect our neural network model’s confidence by occluding (hiding) parts iteratively. This is done by systematically occluding different portions of the input image with a grey square and monitoring the output of the classifier.
Ideally, specific patches of the image should be highlighted in red/yellow like a heatmap, but for my cat image the overall image was highlighted in a red hue, possibly because the picture is a zoomed-in image of a cat. However, the left side of the image has a higher intensity, focusing to an extent more on the shape of the cat than on the texture of the image.
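The occlusion loop itself is simple enough to sketch without a deep learning framework. In the minimal version below, `predict_fn` stands in for "run the model and return the confidence for the class of interest"; the toy scoring function, the 8x8 image, the patch size, and the grey fill value of 0.5 (assuming pixel values in [0, 1]) are all assumptions for illustration, not the tf-explain implementation.

```python
import numpy as np

def occlusion_sensitivity(predict_fn, image, patch_size, fill_value=0.5):
    """Map of how much the class score drops when each patch is hidden."""
    h, w = image.shape[:2]
    baseline = predict_fn(image)
    heatmap = np.zeros((h, w))
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            occluded = image.copy()
            # Cover one patch with a grey square.
            occluded[top:top + patch_size, left:left + patch_size] = fill_value
            # A large drop means the model relied on this region.
            drop = baseline - predict_fn(occluded)
            heatmap[top:top + patch_size, left:left + patch_size] = drop
    return heatmap

# Hypothetical stand-in for a classifier: the "class score" is the mean
# brightness of the top-left 4x4 corner, so only that corner matters.
def toy_score(img):
    return float(img[:4, :4].mean())

image = np.ones((8, 8))
heatmap = occlusion_sensitivity(toy_score, image, patch_size=4)
```

Here only the top-left patch produces a score drop, so it is the only region that lights up in the heatmap. The patch size trades resolution against cost, since each patch requires one extra forward pass.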
GradCAM
This is perhaps one of the most popular and effective methods for interpreting CNN models. Using GradCAM, we try to visualize how parts of the image affect the neural network’s output by looking at class activation maps (CAMs). Class activation maps are a simple technique for finding the discriminative image regions a CNN uses to identify a specific class in the image. In other words, a class activation map lets us see which regions in the image were relevant to this class.
- The output of Grad-CAM will be the pixels that contribute to the maximization of a target function. If, for example, you are interested in what maximizes category number 285, then zero out all the other categories.
- Compute the gradients of the target function with respect to the convolutional layer outputs. This can be done efficiently with backpropagation.
Given an image and a class of interest (e.g., ‘tiger cat’ or any other type of differentiable output) as input, we forward propagate the image through the CNN part of the model and then through task-specific computations to obtain a raw score for the category. The gradients are set to zero for all classes except the desired class (tiger cat), which is set to 1. This signal is then backpropagated to the rectified convolutional feature maps of interest, which we combine to compute the coarse Grad-CAM localization (blue heatmap) which represents where the model has to look to make the particular decision.
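Once the feature maps of a convolutional layer and the gradients of the class score with respect to them are in hand, the combination step described above is short. The NumPy sketch below assumes both are given as (H, W, K) arrays. To keep it self-contained, the gradients come from a made-up toy head (global average pooling followed by a linear layer with hypothetical class weights) rather than from backpropagation through a real network, and the final rescaling to [0, 1] is only a display convenience.

```python
import numpy as np

def grad_cam(feature_maps, gradients):
    """Coarse Grad-CAM map from (H, W, K) feature maps and matching gradients."""
    # One importance weight per channel: the spatially averaged gradient.
    weights = gradients.mean(axis=(0, 1))
    # Weighted sum of the feature maps over channels -> (H, W).
    cam = np.tensordot(feature_maps, weights, axes=([2], [0]))
    # ReLU keeps only regions with a positive influence on the class.
    cam = np.maximum(cam, 0.0)
    if cam.max() > 0:
        cam = cam / cam.max()  # rescale to [0, 1] for display
    return cam

H, W, K = 4, 4, 3
feature_maps = np.zeros((H, W, K))
feature_maps[:2, :2, 0] = 1.0  # channel 0 fires in the top-left
feature_maps[2:, 2:, 1] = 1.0  # channel 1 fires in the bottom-right
feature_maps[:, :, 2] = 0.5    # channel 2 fires everywhere

# Toy head: score = sum_k w_k * mean(A_k), so d(score)/dA[i, j, k] = w_k / (H * W).
class_weights = np.array([2.0, -1.0, 0.0])
gradients = np.broadcast_to(class_weights / (H * W), (H, W, K))

cam = grad_cam(feature_maps, gradients)
```

In this toy case the map highlights only the top-left region: channel 0 supports the class, channel 1 counts against it and is removed by the ReLU, and channel 2 has zero weight. For a real model the coarse map would then be upsampled to the input size and overlaid on the image.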
Let’s look at GradCAM visualizations for specific blocks in our CNN model. We start by visualizing one of the layers from block 1 (shallower layer).
As we expected, this being one of the shallow layers, we see low-level features like edges and corners being activated in the network. Let’s now visualize GradCAM outputs from one of the deeper layers in the network in Block 6.
Things definitely start to get more interesting. We can clearly see that when the model predicts the cat as tabby, it focuses on both the textures and the overall shape and structure of the cat, versus when it predicts the cat as an Egyptian_cat. Finally, let’s take a look at one of the deepest layers in the model from Block 14.
It is very interesting to observe that for the tabby cat label prediction, the model is also looking at the region surrounding the cat, essentially focusing on the overall shape and structure of the cat along with some aspects of the cat’s facial structure!
SmoothGrad
This technique helps us visualize stabilized gradients on the inputs towards the decision. The key objective is to identify pixels that strongly influence the final decision. A starting point for this strategy is the gradient of the class score function with respect to the input image. This gradient can be interpreted as a sensitivity map, and there are several techniques that elaborate on this basic idea.
SmoothGrad is a simple method that can help visually sharpen gradient-based sensitivity maps. The core idea is to take an image of interest, sample similar images by adding noise to the image, then take the average of the resulting sensitivity maps for each sampled image.
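The averaging step can be sketched in a few lines of NumPy. In the version below, `grad_fn` stands in for "compute the gradient of the class score with respect to the input image"; the toy score function, its weights, the sample count, and the noise level are assumptions chosen for illustration, not values from the original method or from tf-explain.

```python
import numpy as np

def smoothgrad(grad_fn, image, n_samples=50, noise_std=0.1, seed=0):
    """Average the sensitivity maps of noisy copies of `image`."""
    rng = np.random.default_rng(seed)
    total = np.zeros_like(image, dtype=float)
    for _ in range(n_samples):
        # Sample a similar image by adding Gaussian noise to the input.
        noisy = image + rng.normal(scale=noise_std, size=image.shape)
        total += grad_fn(noisy)
    return total / n_samples

# Hypothetical toy class score: score(x) = sum(w * x**2), so d(score)/dx = 2 * w * x.
w = np.array([[1.0, 0.0],
              [0.0, 2.0]])
grad_fn = lambda x: 2.0 * w * x

image = np.array([[0.5, 0.5],
                  [0.5, 0.5]])
sensitivity = smoothgrad(grad_fn, image)
```

With `noise_std=0` this collapses to the plain gradient map, which makes the relationship between the two easy to check. More samples give a smoother map at the cost of one gradient computation per sample.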
For the tabby cat, the focus is definitely on key points on the face, including patches and stripes, which are very distinguishing characteristics.
Conclusion
This should give you a good idea of how you can not only leverage complex pre-trained CNN models to predict on new images, but also attempt to visualize what the neural network models are really seeing! The list of techniques here is not exhaustive, but it definitely covers some of the most popular and widely used methods to interpret CNN models. I recommend trying these out with your own data and models!
All the code used in this article is available on my GitHub in this repository as Jupyter notebooks.
I am a Google Developer Expert in Machine Learning and I look forward to sharing more interesting aspects of machine learning and deep learning over time. Feel free to check out my Medium and LinkedIn for updates on interesting content!