Interpretability tools for understanding your machine learning models (part 1)

Carmen Lai
Published in Edge Analytics
7 min read · Aug 19, 2022

There are powerful tools to help humans understand “black box” machine learning (ML) models. Incorporating them in our model development can help pave the way to more robust, safe, and unbiased ML systems.

Photo by DeepMind on Unsplash

Artificial intelligence (AI) systems have wide applications in biology, healthcare, autonomous driving, natural language processing, computer vision, and other fields. Cutting-edge systems such as GPT-3 (see this post by an Edge colleague!), DALL-E, and Imagen have made headlines in large part due to the underlying models’ high accuracy and ability to learn from complex data. However, an increase in model complexity often comes at the cost of interpretability: the human ability to reason about a model’s decisions. Interpretability is an increasingly important part of deploying “black box” models, helping to avoid inaccurate or biased decision-making. At Edge Analytics, we work on projects in digital medicine and the life sciences, and we develop models that impact real people. We believe that understanding our models will help us develop more robust models and ultimately drive better outcomes.

With increasing model complexity, accuracy typically increases while interpretability decreases. (image source)

In Part 1 of this blog series, we will illustrate the building blocks of interpretability from a statistical and machine learning perspective. Interpretability may be more familiar than you think and may already be part of your data science toolbox. In Part 2, we will walk through an example of deep learning interpretability based on attention layers from large language models, inspired by one of our projects at Edge. Stay tuned!

What is interpretability?

Interpretability is the degree to which humans can understand what a model has learned and how it makes predictions. While interpretability has been an active area of AI research for many years, there is no general consensus on how we define, measure, or apply interpretability.

Interpretability is the degree to which humans can understand what a model has learned and how it makes predictions. (image source)

One might wonder: if a model predicts with 99.99% accuracy, why do we care whether we understand its reasoning? Beyond accuracy, interpretable models can build trust among practitioners, increase adoption, and mitigate unintended bias. Widespread adoption of AI systems, particularly in high-stakes fields like medicine, often requires that decisions be reproducible and defensible. A deeper understanding of our models may also help us uncover our own subconscious biases: in the questions we ask, the data we collect, and the tasks we choose to model.

What are the goals for modeling?

To better understand the trade-off between accuracy and interpretability, it helps to consider the goal behind modeling. Some data scientists train models on existing data to make highly accurate predictions about future data (e.g., stock price prediction), while others train models to draw conclusions about existing data (e.g., a market research survey). Commonly used machine learning metrics like accuracy, precision, recall, F1, AUROC, or AUPRC primarily assess predictive performance, but they do not assess whether the representations learned by a model are justifiable or understandable. Prediction and inference are often not mutually exclusive goals, which is why it is important to think about how to assess both.

Prediction (accuracy-driven): training a model to learn from existing data to make highly accurate predictions about future data

Inference (explanation-driven): training a model to identify the input features that drive an outcome in order to draw conclusions about existing data

Machine learning tends to prioritize prediction, while statistical learning prioritizes inference. (image adapted from this blog)
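As a small illustration of the point about metrics, the sketch below computes each metric named above with scikit-learn on a handful of made-up labels and scores (`y_true` and `y_score` are hypothetical, not data from this post). Every one of these metrics is a function of the true labels and the predictions alone; none of them looks inside the model.

```python
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Hypothetical binary labels and predicted probabilities, for illustration only.
y_true = [0, 0, 1, 1, 1, 0, 1, 0]
y_score = [0.1, 0.4, 0.35, 0.8, 0.9, 0.2, 0.7, 0.6]
y_pred = [int(s >= 0.5) for s in y_score]  # threshold the scores at 0.5

scores = {
    "accuracy": accuracy_score(y_true, y_pred),
    "precision": precision_score(y_true, y_pred),
    "recall": recall_score(y_true, y_pred),
    "F1": f1_score(y_true, y_pred),
    "AUROC": roc_auc_score(y_true, y_score),            # threshold-free ranking metric
    "AUPRC": average_precision_score(y_true, y_score),  # average precision, a common AUPRC estimate
}
for name, value in scores.items():
    print(f"{name:>9}: {value:.3f}")
# None of these numbers says anything about *why* the model produced its scores.
```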

To illustrate how we may examine model interpretability, we will walk through traditional linear and nonlinear machine learning models on a tabular dataset in the following sections.

Linear models are highly interpretable

Linear regression and logistic regression are both examples of generalized linear models (GLMs). To illustrate linear model interpretability, we will train a linear regression model on sklearn’s diabetes dataset, using a number of input variables (age, sex, BMI, etc.) to predict diabetes disease progression. When assessing which inputs drive a model, data scientists often think about effect size and significance (read more in chapter 3.2). In linear regression, these can be assessed using the coefficients (i.e., “betas” or “weights”) and p-values, respectively, associated with each input variable.

Effect size: the strength of the effect of a given input variable on the output variable

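Significance: how unlikely the detected effect would be if it were due to random chance alone, usually summarized by a p-value (the probability of observing an effect at least this large if the true effect were zero)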

Volcano plot showing significance against effect size from linear regression. While BMI and LDL have similarly high effect sizes, BMI has higher significance overall for the model.
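A minimal sketch of how the effect sizes and p-values behind a plot like this can be computed. It solves ordinary least squares with NumPy and derives two-sided t-test p-values with SciPy; this is one standard way to obtain them, not necessarily the code used for the figure (a library such as statsmodels reports the same quantities in its OLS summary). `neg_log10_p` is the usual volcano-plot y-axis. Note that scikit-learn names the LDL feature `s2`.

```python
import numpy as np
from scipy import stats
from sklearn.datasets import load_diabetes

data = load_diabetes()
X, y, names = data.data, data.target, list(data.feature_names)
n_samples = X.shape[0]

# Design matrix with a leading column of ones for the intercept.
A = np.column_stack([np.ones(n_samples), X])
beta, *_ = np.linalg.lstsq(A, y, rcond=None)  # OLS coefficients (effect sizes)

# Residual variance and standard errors: sigma^2 * diag((A^T A)^-1).
residuals = y - A @ beta
dof = n_samples - A.shape[1]
sigma2 = residuals @ residuals / dof
se = np.sqrt(sigma2 * np.diag(np.linalg.inv(A.T @ A)))

# Two-sided p-value for H0: coefficient == 0, from the t distribution.
t_stat = beta / se
p_values = 2 * stats.t.sf(np.abs(t_stat), dof)
neg_log10_p = -np.log10(p_values)  # volcano-plot y-axis

for name, b, p in zip(names, beta[1:], p_values[1:]):
    print(f"{name:>4}  coef={b:9.2f}  p={p:.2g}")
```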

Linear models get us far in many use cases because they tend to generalize well to new or unseen data and, with regularization, are less prone to overfitting (read more about the bias-variance tradeoff). However, adapting linear models to fit more complex data often involves significant manual data processing and feature engineering. This process can be time-consuming, requires more specialized domain knowledge, and may not scale to high-dimensional data (images, time series, etc.).
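To make the feature-engineering point concrete, here is a toy sketch on synthetic data (not the diabetes dataset): a plain linear regression cannot capture a quadratic relationship, but the same linear model fits well once it is given an engineered `x**2` feature via scikit-learn's `PolynomialFeatures`. The data-generating function and noise level are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=(200, 1))
y = x[:, 0] ** 2 + rng.normal(scale=0.5, size=200)  # quadratic signal plus noise

# A straight line cannot follow a symmetric U shape.
plain = LinearRegression().fit(x, y)
r2_plain = plain.score(x, y)

# Same linear model, but on hand-engineered features [1, x, x^2].
engineered = make_pipeline(PolynomialFeatures(degree=2), LinearRegression()).fit(x, y)
r2_engineered = engineered.score(x, y)

# The result is still a readable linear model: y ~ intercept + b1*x + b2*x^2.
lin = engineered.named_steps["linearregression"]
print(f"R^2 plain: {r2_plain:.3f}, R^2 engineered: {r2_engineered:.3f}")
print("intercept:", round(lin.intercept_, 2), "coefficients for [x, x^2]:", np.round(lin.coef_[1:], 2))
```

The catch, as noted above, is that someone had to know to add `x**2`; with many features or raw signals like images, guessing the right transformations by hand quickly becomes impractical.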

Nonlinear models can fit more complex data

Nonlinear models can account for more complex relationships and higher dimensionality in data. Decision trees, random forests, and neural networks are all examples of nonlinear models. To illustrate interpretability for a nonlinear model, we can train a decision tree regressor on the same diabetes dataset.

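During training, the decision tree algorithm grows the tree greedily from the top: at each node, it picks the feature and threshold whose split most reduces impurity (information gain for classification; reduction in squared error, i.e., variance, for regression), then repeats the process on each child node. To compute a prediction, one traverses the splits using a sample's feature values until reaching a terminal (leaf) node; the prediction is the average target value of the training samples in that leaf (or the majority class, in the case of classification). Feature importances summarize how predictive each input feature is. In scikit-learn's default impurity-based version, a feature's importance is the total impurity reduction from all splits on that feature, weighted by the number of samples reaching each split and normalized to sum to 1, so features used near the top of the tree, where splits act on the most samples, tend to rank highest.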
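The sketch below checks both mechanisms on a depth-3 `DecisionTreeRegressor` fit to the same data (`max_depth=3` mirrors the figure caption; `random_state=0` is our own choice). It confirms that a prediction equals the mean training target in the sample's leaf, and recomputes scikit-learn's impurity-based `feature_importances_` by hand from the fitted `tree_` arrays.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# 1) A prediction is the mean training target of the leaf the sample lands in.
leaf_of_train = model.apply(X)          # leaf index for every training sample
sample = X[:1]
leaf = model.apply(sample)[0]
leaf_mean = y[leaf_of_train == leaf].mean()
print("predict:", model.predict(sample)[0], "leaf mean:", leaf_mean)

# 2) Impurity-based importance: weighted impurity decrease of every split,
#    summed per feature, then normalized to sum to 1.
t = model.tree_
importance = np.zeros(X.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf node: no split, no impurity decrease
        continue
    decrease = (
        t.weighted_n_node_samples[node] * t.impurity[node]
        - t.weighted_n_node_samples[left] * t.impurity[left]
        - t.weighted_n_node_samples[right] * t.impurity[right]
    )
    importance[t.feature[node]] += decrease
importance /= importance.sum()
print("manual importances:", np.round(importance, 3))
```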

Splitting criteria learned by a decision tree model (max_depth=3) on diabetes data, with the most important and predictive features at the top.
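A text version of a figure like this one can be printed with scikit-learn's `export_text`. The hyperparameters below are assumptions chosen to mirror the caption, and the feature names are scikit-learn's (`s1` to `s6` are blood-serum measurements; for example, `s2` is LDL).

```python
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor, export_text

data = load_diabetes()
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(data.data, data.target)

# Each branch prints as "feature <= threshold" or "feature > threshold";
# leaf lines show the predicted value (the mean training target in that leaf).
text = export_text(model, feature_names=list(data.feature_names), decimals=3)
print(text)
```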

Note: Examples above are shown only to illustrate interpretability tools. We have not tuned these models or tested on holdout data to draw actual conclusions about the diabetes dataset. These steps are important components of a machine learning workflow and a prerequisite for model interpretation in practice.
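For completeness, here is a minimal sketch of the evaluation step the note refers to: hold out a test split and compare training and test R² for both models. The 25% split, the `random_state` values, and `max_depth=3` are arbitrary assumptions; in practice one would also tune hyperparameters (e.g., with cross-validation on the training split) before interpreting either model.

```python
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
# Hold out 25% of samples that the models never see during fitting.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

models = {
    "linear regression": LinearRegression(),
    "decision tree": DecisionTreeRegressor(max_depth=3, random_state=0),
}
results = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    # .score returns R^2 for regressors; a large train/test gap suggests overfitting.
    results[name] = (model.score(X_train, y_train), model.score(X_test, y_test))
    print(f"{name:>17}: train R^2 = {results[name][0]:.3f}, test R^2 = {results[name][1]:.3f}")
```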

Global interpretability vs. local explanation

So far, we have explored interpretability at a global level (what are important features for the model overall?). We can extend interpretability to assess explanations at the local level (what are driving features for a single sample?). Here, we examine one sample from the diabetes dataset.

Waterfall plot showing input feature contributions (blue, red) to the final prediction (gray) for diabetes progression from linear regression.

For linear regression, a feature’s contribution to a prediction can be calculated as coefficient × feature value, and the prediction is the intercept plus the sum of these contributions. For nonlinear models like the decision tree, this approach does not apply, because there is no single coefficient associated with a given feature. Feature attribution (quantifying each feature’s contribution to a prediction) is a challenge for nonlinear models, because the contribution of one feature may be distributed across different parts of the model or may depend on the values of other features.
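A minimal sketch of that calculation for one sample (the first row, an arbitrary choice) with scikit-learn's `LinearRegression`: each contribution is `coef_ * value`, and the intercept plus their sum reproduces the model's prediction. Because the features in scikit-learn's copy of this dataset are mean-centered, the intercept fit on the full data equals the mean target, so each term can also be read as a shift away from the average prediction.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression

data = load_diabetes()
X, y = data.data, data.target
model = LinearRegression().fit(X, y)

x = X[0]                          # one sample to explain
contributions = model.coef_ * x   # coefficient * value, per feature
prediction = model.intercept_ + contributions.sum()

# Largest contributions first; positive values push the prediction up.
for i in np.argsort(-np.abs(contributions)):
    print(f"{data.feature_names[i]:>4}: {contributions[i]:+8.2f}")
print(f"intercept {model.intercept_:.2f} + contributions = {prediction:.2f}")
print(f"model.predict: {model.predict(x[None, :])[0]:.2f}")
```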

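SHAP (SHapley Additive exPlanations) is one method that can be used for feature attribution in nonlinear models. It is based on Shapley values from cooperative game theory, which allocate credit for an outcome among the “players” that produced it; here, the players are the input features and the outcome is the difference between a prediction and the model’s average prediction. SHAP can provide local explanations as well as, by aggregating across samples, global interpretability. SHAP values can be used to explain many “black box” models, but here we illustrate their usage with the decision tree we trained previously.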
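To make the definition concrete, here is a brute-force sketch that computes exact Shapley values for one sample by enumerating all 2^10 feature coalitions. It assumes the “interventional” formulation, where a coalition’s value is the model’s average prediction when the features outside the coalition are filled in from a background sample (50 random rows here, an arbitrary choice). This is for illustration only: the cost is exponential in the number of features, and the shap library uses much faster algorithms (e.g., for tree models). The checks at the end follow from Shapley-value properties: the base value plus all contributions equals the prediction, and features the tree never splits on receive exactly zero credit.

```python
import itertools
import math

import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor


def exact_shapley(predict, x, background):
    """Exact interventional Shapley values for one sample (cost grows as 2^n)."""
    n = x.shape[0]
    # Every coalition S as a boolean mask; row 0 is the empty set, the last row is all features.
    masks = np.array(list(itertools.product([False, True], repeat=n)))
    # For each coalition, use x's values for features in S and background values elsewhere.
    hybrid = np.where(masks[:, None, :], x, background[None, :, :])
    v = predict(hybrid.reshape(-1, n)).reshape(len(masks), -1).mean(axis=1)  # v(S)
    sizes = masks.sum(axis=1)
    # Shapley weight |S|! (n - |S| - 1)! / n! for each possible coalition size.
    weight = np.array([math.factorial(s) * math.factorial(n - s - 1) / math.factorial(n)
                       for s in range(n)])
    phi = np.zeros(n)
    for i in range(n):
        without_i = np.flatnonzero(~masks[:, i])  # coalitions not containing feature i
        with_i = without_i + 2 ** (n - 1 - i)     # index of S ∪ {i} in product() order
        phi[i] = np.sum(weight[sizes[without_i]] * (v[with_i] - v[without_i]))
    return phi, v[0]  # v(empty set) is the base value: mean prediction over the background


data = load_diabetes()
X, y = data.data, data.target
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

rng = np.random.default_rng(0)
background = X[rng.choice(len(X), size=50, replace=False)]
x = X[0]
phi, base = exact_shapley(model.predict, x, background)

for i in np.argsort(-np.abs(phi)):
    print(f"{data.feature_names[i]:>4}: {phi[i]:+8.2f}")
print(f"base {base:.2f} + sum(phi) = {base + phi.sum():.2f}; "
      f"prediction = {model.predict(x[None, :])[0]:.2f}")
```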

Local explanation — additive force plot using SHAP values to explain which input variables drove this decision tree prediction for a single sample. Red=increases prediction, Blue=decreases prediction.
Global interpretability — beeswarm plot showing direction of feature contributions (SHAP values) towards model predictions, across all samples. Each dot represents a contribution from one feature for one sample. Here, a higher BMI value (red) is associated with higher SHAP values, and therefore increases model predictions in aggregate.
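The global view is the local explanation repeated over many samples. This self-contained sketch repeats the same brute-force, interventional Shapley computation for 40 randomly chosen samples, then ranks features by mean absolute SHAP value, a common global importance summary. It also reports the correlation between each sample’s BMI value and its BMI SHAP value; a positive correlation corresponds to the “higher BMI increases predictions” pattern visible in a beeswarm plot. Sample and background sizes are arbitrary assumptions, and real analyses would use the shap library’s optimized explainers instead.

```python
import itertools
import math

import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor


def exact_shapley(predict, x, background):
    """Exact interventional Shapley values for one sample (brute force over 2^n coalitions)."""
    n = x.shape[0]
    masks = np.array(list(itertools.product([False, True], repeat=n)))
    hybrid = np.where(masks[:, None, :], x, background[None, :, :])
    v = predict(hybrid.reshape(-1, n)).reshape(len(masks), -1).mean(axis=1)
    sizes = masks.sum(axis=1)
    weight = np.array([math.factorial(s) * math.factorial(n - s - 1) / math.factorial(n)
                       for s in range(n)])
    phi = np.zeros(n)
    for i in range(n):
        without_i = np.flatnonzero(~masks[:, i])
        phi[i] = np.sum(weight[sizes[without_i]] * (v[without_i + 2 ** (n - 1 - i)] - v[without_i]))
    return phi


data = load_diabetes()
X, y, names = data.data, data.target, list(data.feature_names)
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

rng = np.random.default_rng(0)
background = X[rng.choice(len(X), size=40, replace=False)]
explained = X[rng.choice(len(X), size=40, replace=False)]

# One row of SHAP values per explained sample: the data behind a beeswarm plot.
shap_matrix = np.array([exact_shapley(model.predict, x, background) for x in explained])

# Global importance: mean |SHAP| per feature, largest first.
mean_abs = np.abs(shap_matrix).mean(axis=0)
for i in np.argsort(-mean_abs):
    print(f"{names[i]:>4}: mean |SHAP| = {mean_abs[i]:.2f}")

# Direction: does a higher BMI value tend to come with a higher BMI SHAP value?
b = names.index("bmi")
if np.ptp(shap_matrix[:, b]) > 0:
    r = np.corrcoef(explained[:, b], shap_matrix[:, b])[0, 1]
    print(f"corr(BMI value, BMI SHAP) = {r:+.2f}")
else:
    print("BMI is not used by this tree, so its SHAP values are all zero")
```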

Deep learning interpretability

In this post, we illustrated some examples of interpretability with traditional machine learning models. In Part 2, we will discuss interpretability for neural networks, focusing on attention layers from a large language model. While interpretability in deep learning is still nascent, we hope that new and existing tools will become more widely adopted. Specifically, we hope that tools for interpretability can be used as a force for good in future AI systems as the presence of algorithms in our lives grows over time.

Further reading

  • Colab notebook with code to reproduce the plots shown in this blog post
  • Model Interpretability, Machine Learning Best Practices in Healthcare and Life Sciences (paper)
  • Towards A Rigorous Science of Interpretable Machine Learning (paper)
  • Interpretable Machine Learning: A Guide for Making Black Box Models Explainable (book)
  • Interpretable AI (textbook)

Special thanks to Ren Gibbons, Brinnae Bent, and Lina Colucci for reviewing this blog and providing invaluable feedback.

Edge Analytics is a consulting company that specializes in data science, machine learning, and algorithm development both on the edge and in the cloud. We partner with our clients, who range from Fortune 500 companies to innovative startups, to turn their ideas into reality. Have a hard problem in mind? Get in touch at info@edgeanalytics.io.


Carmen Lai
Edge Analytics

Data scientist at Edge Analytics, with a background in neuroscience and passion for biotechnology.