Introduction to Captum — A model interpretability library for PyTorch

PyTorch
Mar 23 · 7 min read

Author: Narine Kokhlikyan, Research Scientist at Facebook AI

The paper Explaining Explanations: An Overview of Interpretability of Machine Learning characterizes the interpretability of AI as the ability to describe AI models in human-understandable terms. By better understanding AI models and why they make certain predictions, we can start to answer difficult questions about the internals of our models, trust, accountability, and fairness.

Captum is a model interpretability library for PyTorch. It currently offers a number of attribution algorithms that allow us to understand the importance of input features, hidden neurons, and layers. The word Captum means “comprehension” in Latin and serves as the linguistic root of the word “understanding” in many Latin-derived languages.

We announced Captum as an open-source library in mid-October 2019 at the PyTorch Developer Conference. Since then we have been actively expanding the library by adding new algorithms, functionality, and tutorials, and by helping users adopt it. The current 0.2.0 version of the library contains well-tested gradient-based and perturbation-based attribution algorithms.

Captum supports any PyTorch model, meaning it is not limited to classification models but can also be used for any application or domain. Some of the algorithms might be more common for certain types of applications, such as computer vision. Still, the implementations are generic so that the users can apply them to any PyTorch model, interpret, and visualize the attributions. For visualization, we also built an interactive tool called Captum Insights. We will talk about it in detail later in this blog post.

The Algorithms in Captum

The diagram below shows all attribution algorithms available in the Captum library divided into two groups. The first group, listed on the left side of the diagram, allows us to attribute the output predictions or the internal neurons to the inputs of the model. The second group, listed on the right side, includes several attribution algorithms that allow us to attribute the output predictions to the internal layers of the model. Some of those algorithms in the second group are different variants of the ones in the first group.

Most algorithms can also be grouped into gradient-based and perturbation-based approaches. In the diagram, algorithms outlined in green are perturbation-based, and those outlined in orange are gradient-based. The other two algorithms, outlined in blue, are general purpose and can’t be classified as either perturbation-based or gradient-based.

Another important aspect of these algorithms is that many of them require baselines, also called references. A baseline is often described as an uninformative input that is compared and contrasted with the original input; based on that comparison, the algorithm holds certain parts of the original input responsible for the output prediction. The choice of baseline is an essential part of some of the attribution algorithms mentioned above, since the judgment of what is necessary for a particular prediction rests on that comparison. Finding good baselines nevertheless remains an open research topic.
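To make the role of a baseline concrete, here is a minimal sketch of one baseline-based method, Integrated Gradients, written in plain PyTorch rather than with Captum's own implementation. The small network, the all-zeros baseline, the step count, and the helper name `integrated_gradients` are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model: 3 input features -> 1 output score.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
model.eval()


def integrated_gradients(forward, inputs, baseline, steps=200):
    """Approximate Integrated Gradients with a midpoint Riemann sum.

    Gradients are averaged along the straight line from `baseline`
    to `inputs`, then scaled by (inputs - baseline).
    """
    alphas = (torch.arange(steps, dtype=inputs.dtype) + 0.5) / steps
    total_grad = torch.zeros_like(inputs)
    for alpha in alphas:
        # A point on the path between the baseline and the real input.
        point = (baseline + alpha * (inputs - baseline)).detach().requires_grad_(True)
        output = forward(point).sum()
        (grad,) = torch.autograd.grad(output, point)
        total_grad += grad
    return (inputs - baseline) * total_grad / steps


inputs = torch.rand(2, 3)              # batch of 2 examples
baseline = torch.zeros_like(inputs)    # assumed "uninformative" input

attributions = integrated_gradients(model, inputs, baseline)

# The attributions of each example should add up (approximately) to the
# change in the model output between the baseline and the real input.
with torch.no_grad():
    output_change = (model(inputs) - model(baseline)).squeeze(1)
completeness_gap = attributions.sum(dim=1) - output_change

print("attributions =", attributions)
print("completeness gap =", completeness_gap)
```

Because the attributions explain the difference between the prediction for the input and the prediction for the baseline, swapping in a different baseline generally produces a different explanation for the same input.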

It is also important to note that all the algorithms in the Captum library can also be used with models that support PyTorch `DataParallel`.

A comprehensive list of algorithms, their applications, pros and cons can be found here: https://captum.ai/docs/algorithms_comparison_matrix.

Getting Started with Captum

Before installing, make sure your environment meets the installation requirements:

  • Python >= 3.6
  • PyTorch >= 1.2

To get started with Captum, you need to install it via Anaconda (recommended):

conda install captum -c pytorch

Or pip:

pip install captum

Below you can find an example of how to use Captum for a simple toy model.

Let’s define our toy model as such:
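A minimal sketch of such a toy model, assuming two linear layers with a ReLU between them. The attribute name `lin1` matches the layer-attribution snippet further down; the layer sizes, the other attribute names, and the default random weight initialization are assumptions, so the attribution values printed later in this post will not be reproduced exactly.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical toy model: 3 input features -> 3 hidden units -> 2 outputs."""

    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)   # first linear layer (attributed to later)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)   # assumed to produce two output scores

    def forward(self, input):
        return self.lin2(self.relu(self.lin1(input)))
```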

Let’s create an instance of our model and set it in eval mode.

model = ToyModel()
model.eval()

To use an attribution algorithm from the Captum library and understand which input features and neurons are important, we create an instance of the algorithm by passing it the forward function of our model (or any modification of it), then call `attribute` on that instance, passing the inputs of the model. We also need to specify the output that we would like to interpret. The `target` argument selects that output index.

Output
--------------------------------
attributions = tensor([[ 4.4417, 9.2981, 5.2851],
[10.3284, 1.3315, 18.1970]])

In the example described above the attributions have the same shape and dimensionality as the inputs of our model.

Now let’s attribute to one of the layers:

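from captum.attr import LayerConductance

# Attribute output index 0 to the neurons of the first linear layer
lc = LayerConductance(model, model.lin1)
# `input` is the batch of model inputs being explained
attributions = lc.attribute(input, target=0)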

Output
--------------------------------
attributions = tensor([[ 0.0000, 4.2083, 14.7958],
[ 0.0000, 6.6869, 23.1374]])

The code snippet above computes the attributions with respect to the output of the first linear layer. The returned attributions have the same shape and dimensionality as the first linear layer’s output.

With the same principle we can compute attributions for any other PyTorch model.

Visualizations in Captum

Depending on the application, the users have the freedom to visualize the attributions in their preferred way. For images, it is common to highlight the most and least important pixels, for text, the most and least important tokens, and for a general case, we can always use bar charts to visualize the attributions. Visualizing the attributions of high dimensional layers and neurons can become especially challenging.

Below, on the right side of the figure, we visualize the important pixels of an image that depicts a swan. A pre-trained ResNet18 model was used to make the prediction, which was the class black swan with a probability of 0.34.

Attributions can be especially interesting when we look into multimodal use cases. They can help us understand which modality contributes to the final prediction, and to what extent. In the picture below, we used a pre-trained VQA model to answer a question about a given image. We can see which tokens and image pixels are highlighted as important.

We can also see the contributions of each modality. In this case, those numbers reveal that text is a more important feature than image.

We can also perform feature attributions by ablating a group of features together. In the case of images, we can use image segmentation and define each segment as a group of features that we would like to ablate together. In the example below, we segmented an image into three groups: monitor, screens, and the background and defined feature ablation masks based on those segments.

On the right side of the image, we can see the attribution maps obtained when attributing to the target class `monitor`. The pixels on the monitors appear to be very important and are highlighted in dark green. The background is neutral, meaning it has no effect on the prediction, and the bottles in front of the monitors have negative attributions. Interestingly, the borders that separate the monitors from other classes are also assigned negative attributions.

Similarly, we can also visualize attributions in the layers. Below we applied `LayerIntegratedGradients` on all 12 layers of a BERT model for a question answering task. We attributed one of our predicted tokens, namely the output token `kinds`, to all 12 layers. We can see a heatmap of all 12 layers and all tokens, where each cell corresponds to the aggregated attribution score. The lighter the color, the higher the aggregated attribution score for a given token in a given layer.

It is also interesting to observe how attribution scores and their distributions change across layers as we go deeper into the network.

For more details about interpreting BERT models, read the full tutorial: https://captum.ai/tutorials/Bert_SQUAD_Interpret.

Model Debugging with Captum Insights

Visualizing feature attributions and model internals can be very challenging. Captum Insights is an interactive debugging and visualization tool built on top of the core Captum library that enables feature visualizations.

Captum Insights works across images, text, and other features to help users understand feature attribution. Some examples of the widget are below.

Captum Insights can be embedded into Jupyter and Google Colab notebooks and allows users to perform interactive, instance-based feature attribution.

More details on Captum Insights and how to set it up in Jupyter notebooks can be found here: https://captum.ai/docs/captum_insights.

Looking Ahead

Attributions are the beginning of our journey. In the near future we plan to expand Captum to at least four more packages:

  1. captum.optim → this component will focus on optimization-based approaches to model understanding, including optimization-based visualizations
  2. captum.robust → this component will focus on the algorithms and visualizations that lie at the intersection of adversarial robustness and interpretability
  3. captum.metric → this component will focus on different types of metrics related to model interpretability, sensitivity, trust, and robustness
  4. captum.benchmark → this component will focus on different benchmarking datasets and methodologies. Although evaluating different attribution approaches is important, finding good evaluation metrics can be more challenging than we think.

Stay up to date with Captum by bookmarking the Captum website and starring our GitHub page.

We would love to hear your feedback, receive your contributions and stay connected using GitHub and our discussion forums:

  1. Discussion Forums: https://discuss.pytorch.org/c/captum
  2. GitHub Issues: https://github.com/pytorch/captum/issues

Thanks to Jspisak
