Interpretability tools for understanding your machine learning models (part 1)
There are powerful tools to help humans understand “black box” machine learning (ML) models. Incorporating them in our model development can help pave the way to more robust, safe, and unbiased ML systems.
Artificial intelligence (AI) systems have wide applications in biology, healthcare, autonomous driving, natural language processing, computer vision, and other fields. Cutting-edge systems such as GPT-3 (see this post by an Edge colleague!), DALL-E, and Imagen have made headlines in large part because of the underlying models' high accuracy and their ability to learn from complex data. However, greater model complexity often comes at the cost of interpretability: the human ability to reason about a model's decisions. Interpretability is becoming an increasingly important part of deploying "black box" models, because it helps avoid inaccurate or biased decision making. At Edge Analytics, we work on projects in digital medicine and the life sciences, and we develop models that affect real people. We believe that understanding our models will help us build more robust models and ultimately drive better outcomes.
In Part 1 of this blog series, we will illustrate the building blocks of interpretability from a statistical and machine learning perspective. Interpretability may be more familiar than you think and may already be part of your data science toolbox. In Part 2, we will walk through an example of deep learning interpretability based on attention layers from large language models, inspired by one of our projects at Edge. Stay tuned!
What is interpretability?
Interpretability is the degree to which humans can understand what a model has learned and how it makes predictions. While interpretability has been an active area of AI research for many years, there is no general consensus on how we define, measure, or apply interpretability.
One might wonder: if a model predicts with 99.99% accuracy, why should we care whether we understand its reasoning? Beyond accuracy, interpretable models can build trust among practitioners, increase adoption, and mitigate unintended bias. Widespread adoption of AI systems, particularly in high-stakes fields like medicine, often requires that decisions be reproducible and defensible. A deeper understanding of our models may also help us uncover our own subconscious biases in the questions we ask, the data we collect, and the tasks we choose to model.
What are the goals for modeling?
To better understand the trade-off between accuracy and interpretability, one important consideration is the goal behind modeling. Some data scientists train models on existing data to make highly accurate predictions about future data (e.g. stock price prediction), while others may train models to draw conclusions about existing data (e.g. market research survey). Commonly used machine learning metrics like accuracy, precision, recall, F1, AUROC, or AUPRC primarily assess prediction accuracy, but they do not assess whether representations learned by a model are justifiable or understandable. Prediction and inference are often not mutually exclusive goals, which is why it is important to think about how to assess both.
Prediction (accuracy-driven): training a model to learn from existing data to make highly accurate predictions about future data
Inference (explanation-driven): training a model to understand driving input features in order to draw conclusions about existing data
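The standard metrics mentioned above only compare predictions with ground truth. The minimal sketch below uses scikit-learn's metric functions on made-up labels and scores (hypothetical toy values, not output from any real model). Every function receives only true labels and predictions or scores, so none of the resulting numbers reflects why a model produced those predictions.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Hypothetical ground-truth labels and model scores (toy values for illustration).
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
scores = np.array([0.1, 0.4, 0.8, 0.65, 0.3, 0.2, 0.9, 0.55])
y_pred = (scores >= 0.5).astype(int)  # threshold scores into hard class labels

# Each metric sees only labels and predictions/scores, never the model itself.
metrics = {
    "accuracy": accuracy_score(y_true, y_pred),
    "precision": precision_score(y_true, y_pred),
    "recall": recall_score(y_true, y_pred),
    "f1": f1_score(y_true, y_pred),
    "auroc": roc_auc_score(y_true, scores),  # threshold-free, uses raw scores
    "auprc": average_precision_score(y_true, scores),
}
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

Two models with very different internal logic but the same predictions would receive identical values here. That is why these metrics alone cannot tell us whether a model's learned representations are justifiable.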
To illustrate how we may examine model interpretability, we will walk through traditional linear and nonlinear machine learning models on a tabular dataset in the following sections.
Linear models are highly interpretable
Linear regression and logistic regression are both examples of generalized linear models (GLMs). To illustrate linear model interpretability, we will train a linear regression model on sklearn's diabetes dataset, using several input variables (age, sex, BMI, etc.) to predict diabetes disease progression. When assessing the driving inputs of a model, data scientists often think about effect size and significance (read more in chapter 3.2). In linear regression, these can be assessed using the coefficients (i.e. "betas" or "weights") and the p-values, respectively, associated with each input variable.
Effect size: the strength of the effect of a given input variable on the output variable
Significance: how unlikely an effect at least this large would be if there were no true effect (i.e. if it arose from random chance alone), typically summarized by a p-value
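As a minimal sketch of how these quantities can be computed, the snippet below fits ordinary least squares on scikit-learn's diabetes dataset. It then derives each coefficient's standard error, t-statistic, and two-sided p-value with NumPy and SciPy. This is the textbook OLS calculation and not necessarily how the plots in this post were produced. Libraries such as statsmodels report the same quantities in a summary table.

```python
import numpy as np
from scipy import stats
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression

data = load_diabetes()
X, y = data.data, data.target
n, p = X.shape

model = LinearRegression().fit(X, y)

# Design matrix with an intercept column, matching the fitted model.
X_design = np.column_stack([np.ones(n), X])
beta = np.concatenate([[model.intercept_], model.coef_])

# Residual variance estimate with n - (p + 1) degrees of freedom.
residuals = y - X_design @ beta
dof = n - (p + 1)
sigma2 = residuals @ residuals / dof

# Standard errors come from the diagonal of sigma^2 * (X^T X)^{-1}.
cov_beta = sigma2 * np.linalg.inv(X_design.T @ X_design)
std_err = np.sqrt(np.diag(cov_beta))

# t-statistics and two-sided p-values under the null hypothesis beta_j = 0.
t_stats = beta / std_err
p_values = 2 * stats.t.sf(np.abs(t_stats), dof)

names = ["intercept"] + list(data.feature_names)
for name, b, pv in zip(names, beta, p_values):
    print(f"{name:>9}: coef = {b:9.2f}, p = {pv:.3g}")
```

The coefficient gives the effect size, and the p-value gives the significance. Coefficients are only directly comparable across features because the diabetes features share a common scale. On raw data, you would standardize features first.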
Linear models get us far in many use cases because they tend to generalize well to new or unseen data and, especially when regularized, are less prone to overfitting (read more about the bias-variance tradeoff). However, adapting linear models to fit more complex data often involves significant manual data processing and feature engineering. This process can be time-consuming, requires more specialized domain knowledge, and may be impractical for high-dimensional data (images, time series, etc.).
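You can see the effect of regularization directly in a linear model's coefficients. This sketch uses scikit-learn's Ridge estimator, which adds an L2 penalty whose strength is set by alpha. The specific alpha values are arbitrary choices for illustration. Larger penalties shrink the coefficients toward zero, trading a little bias for lower variance.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression, Ridge

X, y = load_diabetes(return_X_y=True)

# Unregularized baseline for comparison.
ols = LinearRegression().fit(X, y)
ols_norm = np.linalg.norm(ols.coef_)
print(f"OLS       : ||coef|| = {ols_norm:8.1f}")

# Increasing alpha strengthens the L2 penalty on the coefficients.
alphas = [0.01, 0.1, 1.0, 10.0]  # arbitrary values chosen for illustration
norms = []
for alpha in alphas:
    ridge = Ridge(alpha=alpha).fit(X, y)
    norms.append(np.linalg.norm(ridge.coef_))
    print(f"alpha={alpha:<5}: ||coef|| = {norms[-1]:8.1f}")
```

In practice, alpha would be chosen by cross-validation rather than fixed by hand.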
Nonlinear models can fit more complex data
Nonlinear models can account for more complex relationships and higher dimensionality in data. Decision trees, random forests, and neural networks are all examples of nonlinear models. To illustrate interpretability for a nonlinear model, we can train a decision tree regressor on the same diabetes dataset.
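During training, the decision tree algorithm works recursively and greedily, starting from the root. At each node, it chooses the feature and threshold whose split most reduces impurity. For regression, impurity is mean squared error. For classification, it is Gini impurity or entropy, and the entropy reduction is known as information gain. To compute a prediction, one traverses the tree from the root, following the splits according to a sample's feature values. The prediction is the average target value of the training samples in the leaf node it reaches (or the majority class, in the case of classification). Feature importances tell us about the predictiveness of each input feature. In scikit-learn, the default impurity-based importance of a feature is the total impurity reduction from all splits on that feature, weighted by the fraction of samples reaching each split. As a result, features used near the top of the tree, where splits affect many samples, tend to rank highest.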
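Both mechanisms can be checked against scikit-learn's fitted tree. The minimal sketch below uses max_depth=3, an arbitrary choice that keeps the tree small. First, it walks each sample from the root to a leaf using the arrays exposed on the estimator's tree_ attribute. Second, it recomputes impurity-based importances by summing each split's weighted impurity decrease. Both results should match predict and feature_importances_.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

data = load_diabetes()
X, y = data.data, data.target

# A shallow tree keeps the example readable; max_depth=3 is an arbitrary choice.
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)
tree = model.tree_


def traverse(x):
    """Follow the splits from the root to a leaf and return the leaf's mean target."""
    x = x.astype(np.float32)  # scikit-learn compares features as float32 internally
    node = 0
    while tree.children_left[node] != -1:  # -1 marks a leaf
        if x[tree.feature[node]] <= tree.threshold[node]:
            node = tree.children_left[node]
        else:
            node = tree.children_right[node]
    return tree.value[node, 0, 0]


manual_preds = np.array([traverse(x) for x in X])
print("traversal matches predict:", np.allclose(manual_preds, model.predict(X)))

# Impurity-based (MDI) importance: add each split's weighted impurity decrease
# to the feature it splits on, then normalize so importances sum to 1.
importances = np.zeros(X.shape[1])
w = tree.weighted_n_node_samples
for node in range(tree.node_count):
    left, right = tree.children_left[node], tree.children_right[node]
    if left == -1:
        continue  # leaves do not split, so they contribute nothing
    decrease = (
        w[node] * tree.impurity[node]
        - w[left] * tree.impurity[left]
        - w[right] * tree.impurity[right]
    )
    importances[tree.feature[node]] += decrease
importances /= importances.sum()

print("MDI matches feature_importances_:", np.allclose(importances, model.feature_importances_))
for name, imp in sorted(zip(data.feature_names, importances), key=lambda t: -t[1]):
    print(f"{name:>4}: {imp:.3f}")
```

Impurity-based importances are computed on training data and can favor features with many distinct values. Permutation importance on held-out data is a common complement.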
Note: Examples above are shown only to illustrate interpretability tools. We have not tuned these models or tested on holdout data to draw actual conclusions about the diabetes dataset. These steps are important components of a machine learning workflow and a prerequisite for model interpretation in practice.
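Here is one sketch of the steps this note refers to. It holds out a test set, tunes the tree's max_depth with cross-validation on the training data only, and evaluates once on the held-out data. The 80/20 split, the candidate depths, and the R² metric are illustrative assumptions, not settings used in this post.

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)

# Hold out 20% of the data; it is not touched until final evaluation.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# 5-fold cross-validation on the training split chooses max_depth.
search = GridSearchCV(
    DecisionTreeRegressor(random_state=0),
    param_grid={"max_depth": [2, 3, 4, 5, 6, None]},
    cv=5,
    scoring="r2",
)
search.fit(X_train, y_train)

best_tree = search.best_estimator_  # refit on the full training split
test_r2 = best_tree.score(X_test, y_test)
print("best max_depth:", search.best_params_["max_depth"])
print(f"held-out R^2: {test_r2:.3f}")
```

Interpretation tools would then be applied to best_tree, the model whose held-out performance has actually been checked.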
Global interpretability vs. local explanation
So far, we have explored interpretability at a global level (what are important features for the model overall?). We can extend interpretability to assess explanations at the local level (what are driving features for a single sample?). Here, we examine one sample from the diabetes dataset.
For linear regression, the contribution of an input feature towards a prediction can be calculated as coefficient * value for that feature. However, this approach does not carry over to nonlinear models like the decision tree, because no single coefficient is associated with a given feature. Feature attribution (the contribution of features towards a prediction) is a challenge for nonlinear models. The contribution from one feature may be distributed across different parts of the model, or it may depend on the values of other features.
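For a linear model, you can verify this decomposition directly. The sketch below fits a linear regression on the diabetes data and picks one sample (index 0, an arbitrary choice). It then shows that the intercept plus the per-feature contributions (coefficient * value) reproduces the model's prediction.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression

data = load_diabetes()
X, y = data.data, data.target
model = LinearRegression().fit(X, y)

sample = X[0]  # arbitrary sample chosen for illustration

# Per-feature contribution for a linear model: coefficient * feature value.
contributions = model.coef_ * sample
prediction = model.predict(sample.reshape(1, -1))[0]

# The intercept plus all contributions reconstructs the prediction.
print(f"intercept + sum(contributions) = {model.intercept_ + contributions.sum():.2f}")
print(f"model prediction               = {prediction:.2f}")

# Rank features by the magnitude of their contribution to this one prediction.
for idx in np.argsort(-np.abs(contributions)):
    print(f"{data.feature_names[idx]:>4}: {contributions[idx]:+8.2f}")
```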
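SHAP (SHapley Additive exPlanations) is one method for feature attribution in nonlinear models. It is adapted from cooperative game theory, where Shapley values divide a shared payout among players according to each player's average marginal contribution. In SHAP, the players are input features, and the payout is the difference between the model's prediction for a sample and its average prediction. SHAP values provide local explanations and, when aggregated across samples, global interpretability. They can be used to explain many "black box" models, but here we illustrate their usage with the decision tree we trained previously.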
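To show what a SHAP value actually computes, here is a brute-force sketch rather than the shap package itself. The shap package provides optimized explainers such as TreeExplainer, and its numbers can differ depending on how it defines the value function. This sketch assumes the interventional definition. A coalition S of features is valued as the model's average prediction when features in S take the explained sample's values and the remaining features are drawn from a background dataset. The 100-row background sample, the explained sample, and max_depth=3 are arbitrary choices. With 10 features there are only 2^10 = 1,024 coalitions, so exact enumeration is feasible.

```python
import math

import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

data = load_diabetes()
X, y = data.data, data.target
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

rng = np.random.default_rng(0)
background = X[rng.choice(len(X), size=100, replace=False)]  # reference distribution
x = X[0]  # the sample to explain (arbitrary choice)
M = X.shape[1]

# Value of every coalition (a bitmask over features): the average prediction when
# features in the coalition are fixed to x and the rest come from the background.
masks = np.arange(2**M)
in_coalition = ((masks[:, None] >> np.arange(M)) & 1).astype(bool)  # (2^M, M)
batch = np.where(in_coalition[:, None, :], x, background[None, :, :])  # (2^M, 100, M)
values = model.predict(batch.reshape(-1, M)).reshape(2**M, -1).mean(axis=1)

# Shapley weight |S|! (M - |S| - 1)! / M! for a coalition S that excludes feature i.
sizes = in_coalition.sum(axis=1)
weights = np.array(
    [math.factorial(s) * math.factorial(M - s - 1) / math.factorial(M) for s in range(M)]
)

phi = np.zeros(M)
for i in range(M):
    without_i = ~in_coalition[:, i]
    S = masks[without_i]
    # Weighted average of the marginal contribution of adding feature i to S.
    phi[i] = np.sum(weights[sizes[without_i]] * (values[S | (1 << i)] - values[S]))

# Efficiency: attributions sum to the prediction minus the average background prediction.
base_value = values[0]
prediction = model.predict(x.reshape(1, -1))[0]
print(f"base value + sum(phi) = {base_value + phi.sum():.2f}, prediction = {prediction:.2f}")
for idx in np.argsort(-np.abs(phi)):
    print(f"{data.feature_names[idx]:>4}: {phi[idx]:+8.2f}")
```

Two properties are easy to check here. The attributions add up exactly to the prediction minus the base value. Any feature the tree never splits on receives an attribution of zero. Averaging the absolute values of phi over many samples gives a global importance ranking.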
Deep learning interpretability
In this post, we illustrated some examples of interpretability with traditional machine learning models. In Part 2, we will discuss interpretability for neural networks, focusing on attention layers from a large language model. While interpretability in deep learning is still nascent, we hope that new and existing tools will become more widely adopted. Specifically, we hope that tools for interpretability can be used as a force for good in future AI systems as the presence of algorithms in our lives grows over time.
Further reading
- Colab notebook with code to reproduce the plots shown in this blog post
- Model Interpretability, Machine Learning Best Practices in Healthcare and Life Sciences (paper)
- Towards A Rigorous Science of Interpretable Machine Learning (paper)
- Interpretable Machine Learning: A Guide for Making Black Box Models Explainable (book)
- Interpretable AI (textbook)
Special thanks to Ren Gibbons, Brinnae Bent, and Lina Colucci for reviewing this blog and providing invaluable feedback.
Edge Analytics is a consulting company that specializes in data science, machine learning, and algorithm development both on the edge and in the cloud. We partner with our clients, who range from Fortune 500 companies to innovative startups, to turn their ideas into reality. Have a hard problem in mind? Get in touch at info@edgeanalytics.io.