Interpreting Machine Learning Models

Source: Antonio Robers on Flickr

Visit us at to learn more about how we’re using data science to improve hiring

Why Interpretability Matters

In the mid-1990s, a national effort was undertaken to build algorithms to predict which pneumonia patients should be admitted to hospitals and which should be treated as outpatients. Initial findings indicated neural nets were far more accurate than classical statistical methods. Doctors, however, wanted to understand the “thinking” behind this algorithm, so statisticians catalogued “decision rules” from the more easily interpreted regression results.

It turned out that both the regression and the neural net had inferred that pneumonia patients with asthma have a lower risk of dying, and shouldn’t be admitted. Obviously, this is counterintuitive. But it reflected a real pattern in the training data — asthma patients with pneumonia usually were admitted not only to the hospital but directly to the ICU, treated aggressively, and survived. [1]

Had this “high-performing” neural net been deployed in a clinical setting, it could have caused unnecessary deaths. Only by interpreting the model was a crucial problem discovered and avoided. Understanding why a model makes a prediction can literally be an issue of life and death. As algorithms are used to make decisions in more facets of everyday life, it’s important for data scientists to train them thoughtfully to ensure the models make decisions for the right reasons.

Many machine learning textbooks present students with a chart that shows a tradeoff between model interpretability and model accuracy. This is a heuristic, but many students come away thinking that this tradeoff is as strict as a law of physics.

In this post, we’ll explore (and question) this tradeoff, offer a framework for measuring interpretability, and apply that framework to a few common models.

First, let’s return to why interpretability matters, and when it doesn’t. As the pneumonia example illustrates, interpretability is key for “debugging” models. It’s also required in regulated industries like finance and healthcare to audit the decision process and ensure it’s not discriminatory. The US Fair Credit Reporting Act requires that agencies disclose “all of the key factors that adversely affected the credit score of the consumer in the model used, the total number of which shall not exceed four.” They’re not wrong to set this limit. Models implemented in popular software packages can easily accept thousands of features, and a huge feature set can quickly make a straightforward explanation nearly impossible (not to mention that collinear features can complicate things further, but we won’t address that here).

Interpretability is also key to winning trust in algorithms that try to improve upon human judgement, instead of just automating it. Take our work at Ansaro trying to predict which job applicants will perform best. Our goal is to do better than human intuition. For users to accept our predictions, they have to understand them.

So when can we deprioritize interpretability? We think these criteria are reasonable guidelines:

  • Global Interpretability: How well can we understand the relationship between each feature and the predicted value at a global level, that is, across our entire observation set? Can we understand both the magnitude and direction of the impact of each feature on the predicted value?
  • Local Interpretability: How well can we understand the relationship between each feature and the predicted value at a local level, that is, for a specific observation?
  • Feature Selection: Does the model help us focus on only the features that matter? Can it zero out the features that are just “noise”?

It’s important to note we’re not talking about interpreting model accuracy — we assume that models have been cross-validated using train, validation, and test datasets, and that an appropriate evaluation metric like AUC or F1-score has been chosen. We also assume that the feature set has been chosen thoughtfully, though this is a big assumption — and one with many interpretations. Next, we’ll apply this framework to a few common model types, to get an idea of how strict the accuracy–interpretability tradeoff really is.

Linear Regression

We’ll start with linear regression. There’s a reason it has been the go-to model for the scientific community for the past century: it’s the gold standard in interpretability.

An ordinary least squares (OLS) model generates a coefficient for each feature. These coefficients are signed, allowing us to describe both the magnitude and direction of each feature’s effect at the global level. For local interpretability, we need only multiply each coefficient by the corresponding value in a specific feature vector to see the contribution of each feature; summing those contributions (plus the intercept) gives the predicted value.
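To make the global and local readings of an OLS model concrete, here is a minimal sketch using NumPy. The data, the coefficient values, and the observation x_new are all made up for illustration; they are assumptions, not results from our own models.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data (hypothetical): 200 observations, 3 features.
X = rng.normal(size=(200, 3))
true_coefs = np.array([2.0, -1.0, 0.0])
y = 5.0 + X @ true_coefs + rng.normal(scale=0.1, size=200)

# Fit OLS by least squares on a design matrix with an intercept column.
design = np.column_stack([np.ones(len(X)), X])
beta, *_ = np.linalg.lstsq(design, y, rcond=None)
intercept, coefs = beta[0], beta[1:]

# Global view: signed coefficients give direction and magnitude.
print("coefficients:", np.round(coefs, 2))

# Local view: element-wise product of coefficients and one observation's features.
x_new = np.array([1.0, 2.0, -3.0])
contributions = coefs * x_new
prediction = intercept + contributions.sum()
print("contributions:", np.round(contributions, 2))
print("prediction:", round(float(prediction), 2))
```

The per-feature contributions add up (with the intercept) to the prediction, which is what makes the local explanation exact rather than approximate.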

A classic OLS regression doesn’t eliminate noise features, but we can accomplish that by removing features for which the coefficient’s confidence interval crosses zero and rerunning the model. Or we can use slightly more sophisticated regularized methods: Lasso regression can shrink the coefficients of noise features all the way to zero, while Ridge regression shrinks them toward zero without eliminating them.
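As a rough illustration of how regularized regression treats noise features, the sketch below fits scikit-learn’s Lasso and Ridge to synthetic data with two informative features and eight noise features. The data and the alpha values are assumptions chosen for the example. In a setup like this we would expect only Lasso to produce coefficients that are exactly zero; Ridge typically leaves small nonzero values.

```python
import numpy as np
from sklearn.linear_model import Lasso, Ridge

rng = np.random.default_rng(0)

# Hypothetical data: 2 informative features followed by 8 pure-noise features.
X = rng.normal(size=(300, 10))
y = 3.0 * X[:, 0] - 2.0 * X[:, 1] + rng.normal(scale=0.5, size=300)

# alpha values are illustrative; in practice they would be tuned by cross-validation.
lasso = Lasso(alpha=0.1).fit(X, y)
ridge = Ridge(alpha=1.0).fit(X, y)

lasso_zeroed = int(np.sum(lasso.coef_ == 0.0))
ridge_zeroed = int(np.sum(ridge.coef_ == 0.0))
print("Lasso coefficients:", np.round(lasso.coef_, 3))
print("Ridge coefficients:", np.round(ridge.coef_, 3))
print("Lasso coefficients set exactly to zero:", lasso_zeroed)
print("Ridge coefficients set exactly to zero:", ridge_zeroed)
```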

Random Forest

In the middle of the accuracy-interpretability spectrum are random forests. We’ve often seen them described as “black boxes,” which we think is unfair. Maybe “gray,” but certainly not “black”!

Random forests are collections of decision trees, like the one drawn below. The splits in each tree are chosen from random subsets of our features, so the trees all look slightly different. A single tree can be easily interpreted, assuming it is not grown too deep. But how can we interpret a random forest that contains hundreds or thousands of trees?

Many implementations of random forest classifiers include out-of-the-box methods for quantifying the overall importance of each feature. For example, the feature_importances_ attribute of a fitted scikit-learn RandomForestClassifier lets us assess the relative importance of features with one line of code. Feature importances, when used with proper cross-validation, can also allow us to identify the features that are pure noise.
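A minimal sketch of reading feature importances from a fitted forest follows. The synthetic dataset (two informative features, three noise features) is an assumption made for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)

# Hypothetical data: the label depends on features 0 and 1; features 2-4 are noise.
X = rng.normal(size=(500, 5))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)

forest = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

# Impurity-based importances: non-negative and normalized to sum to 1.
importances = forest.feature_importances_
for idx in np.argsort(importances)[::-1]:
    print(f"feature {idx}: {importances[idx]:.3f}")
```

Note that these importances are unsigned: they rank features but say nothing about direction.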

However, understanding features’ directionality is more difficult. We can quickly identify that Feature X may be the most important, but does it make Outcome Y more or less likely? There may not be a yes-or-no answer. Unlike a linear regression, random forests can identify non-monotonic relationships (a big part of the reason they outrank regression on the accuracy axis). In one region of the observation space, a feature’s direction may be positive; in another it may be negative.
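One way to see this kind of region-dependent direction is to sweep a single feature across a grid while holding the rest of the data fixed, and average the model’s predictions at each grid value. This is commonly called a partial dependence curve; it is a technique we are bringing in for illustration, and the data and the helper partial_dependence_curve below are hypothetical.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)

# Hypothetical data: the outcome is likely when feature 0 is far from zero
# in either direction, so its effect is non-monotonic.
X = rng.uniform(-2, 2, size=(1000, 3))
y = (np.abs(X[:, 0]) > 1.0).astype(int)

forest = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

def partial_dependence_curve(model, X, feature, grid):
    """Average predicted P(class 1) with one feature forced to each grid value."""
    curve = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value
        curve.append(model.predict_proba(X_mod)[:, 1].mean())
    return np.array(curve)

grid = np.linspace(-2, 2, 9)
curve = partial_dependence_curve(forest, X, feature=0, grid=grid)
for value, p in zip(grid, curve):
    print(f"x0 = {value:+.1f} -> mean P(y=1) = {p:.2f}")
```

In this toy setup the curve should be high at both ends and low in the middle, so no single sign describes the feature’s direction.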

Understanding how a particular observation’s features contribute to the prediction is also challenging, but doable. To achieve local interpretability, we can catalogue the decision paths for a specific observation through all our decision trees. We then sum the decreases in the Gini index for each feature, across all these paths. For the non-statisticians, Gini index decrease is a measure of how much more “cleanly” the classes are separated after a split in a decision tree. This sounds like heavy lifting, but libraries like Ando Saabas’ excellent treeinterpreter make this practical. [2]
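The path-cataloguing idea can be sketched directly against scikit-learn’s tree internals. One caveat: instead of summing Gini decreases, this sketch credits each split feature with the change in predicted class probability between a node and the child the observation falls into. We chose that variant because the pieces then add up exactly to the forest’s prediction (a baseline plus per-feature contributions), which is the kind of decomposition treeinterpreter reports. The data and helper names are hypothetical, and this is not treeinterpreter’s own code.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(400, 4))
y = (X[:, 0] - X[:, 2] > 0).astype(int)
forest = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)

def class_probs(tree, node):
    """Class distribution at a node, normalized so it works across sklearn versions."""
    counts = tree.value[node, 0, :]
    return counts / counts.sum()

def tree_contributions(tree, x):
    """Walk x's decision path; credit each probability change to the split feature."""
    contrib = np.zeros((x.shape[0], tree.value.shape[2]))
    node = 0
    bias = class_probs(tree, node)
    while tree.children_left[node] != -1:  # -1 marks a leaf
        feat = tree.feature[node]
        if x[feat] <= tree.threshold[node]:
            child = tree.children_left[node]
        else:
            child = tree.children_right[node]
        contrib[feat] += class_probs(tree, child) - class_probs(tree, node)
        node = child
    return bias, contrib

x = X[0].astype(np.float32)  # sklearn trees compare float32 inputs
results = [tree_contributions(est.tree_, x) for est in forest.estimators_]
bias = np.mean([b for b, _ in results], axis=0)
contributions = np.mean([c for _, c in results], axis=0)

print("baseline P(class):", np.round(bias, 3))
print("per-feature contribution to P(class 1):", np.round(contributions[:, 1], 3))
print("reconstructed:", np.round(bias + contributions.sum(axis=0), 3))
print("predict_proba:", np.round(forest.predict_proba(X[:1])[0], 3))
```

The last two printed lines should match, since the baseline plus the summed contributions reconstructs the forest’s predicted probabilities for that observation.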

Neural Networks

Neural networks have been the hottest topic in machine learning over the past decade, so we’d be remiss if we didn’t mention them. Hailed for outstanding accuracy in difficult domains like image recognition and language translation, they’ve also generated criticism for lacking interpretability:

“Nobody understands how these systems — neural networks modeled on the human brain — produce their results. Computer scientists “train” each one by feeding it data, and it gradually learns. But once a neural net is working well, it’s a black box. Ask its creator how it achieves a certain result and you’ll likely get a shrug.” — Wired Magazine, October 2010 [3]

We think that’s a dramatic overstatement. We’re also cognizant of the fact that there are many types of neural network architectures, and making blanket statements about them is difficult [see the Asimov Institute’s terrific Neural Net Zoo]. For simplicity, we’ll focus on convolutional neural nets (CNNs), widely used for image recognition (and many other applications).

Imagine we’re training a CNN to predict the probability that a 64x64 pixel image contains a cat. We could start to gain insight into how important features are by tweaking them and seeing how the resulting probability changes. But the features fed into this CNN are the RGB values of 4,096 pixels. Knowing that a particular pixel correlates with being a cat isn’t particularly useful, nor would we even expect such a relationship to occur, since pixels representing a cat could appear anywhere in the image. For this model to be semantically interpretable, we need to understand its features at a more abstract level.

However, building up raw pixel values (or any high-dimensional data like audio waveforms or unstructured text) into abstract features through the network layers is where interpretability breaks down, where we start to lose the global understanding of what a specific feature contributes. The ability to encode non-linear relationships across such an array of features is where neural nets outperform many other models in terms of accuracy. Despite this complication, all is not lost in terms of interpretability.

Back to our example. Instead of tweaking individual pixels, we can track and inspect the training images that maximally activate individual neurons. By looking at neurons in deeper layers of the CNN, we may be able to find some that correspond to semantically meaningful concepts like “ear” or “tail.” We can then examine the weights that the network’s final layer assigns to those neurons, giving us a rough idea of global feature importance.

For local interpretability, we can use occlusion to understand where a CNN is “paying attention.” We iterate across the image, setting a patch of pixels to zero, running the occluded image through the CNN, and logging the “cat” probability. We can then visualize the contribution of each part of the image to the “cat” probability as a 2D heat map. These methods aren’t as simple as examining coefficients, but they show that neural networks are not completely black boxes.
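The occlusion loop itself is short. In the sketch below, toy_cat_probability is a hypothetical stand-in for a trained CNN (it simply responds to brightness in one region of the image), so the example runs without a deep learning framework. With a real model, only that function would change.

```python
import numpy as np

def toy_cat_probability(image):
    """Hypothetical stand-in for a trained CNN: responds to brightness in one region."""
    score = image[16:32, 16:32].mean()
    return 1.0 / (1.0 + np.exp(-10.0 * (score - 0.5)))

def occlusion_heatmap(predict_fn, image, patch=8):
    """Drop in predicted probability when each patch of the image is zeroed out."""
    baseline = predict_fn(image)
    rows, cols = image.shape[0] // patch, image.shape[1] // patch
    heatmap = np.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            occluded = image.copy()
            occluded[i * patch:(i + 1) * patch, j * patch:(j + 1) * patch] = 0.0
            heatmap[i, j] = baseline - predict_fn(occluded)
    return heatmap

# A 64x64 RGB image with values in [0, 1]; bright everywhere for simplicity.
image = np.ones((64, 64, 3))
heatmap = occlusion_heatmap(toy_cat_probability, image, patch=8)
print(np.round(heatmap, 2))
```

Cells with large values mark patches whose removal hurts the “cat” probability most. The patch size is a tuning choice: smaller patches give finer maps at the cost of more forward passes.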

Concluding Thoughts

Rather than being a static tradeoff, we think of accuracy-vs-interpretability as a frontier, one which is constantly being pushed outwards.

Over the next decade, we believe developing more interpretable models will be as important as developing more accurate models for the data science community. Much of the work on advancing interpretability will be done by domain-specific experts — better ways to visualize CNN results, for example. But there are also exciting advancements in approaches that transcend specific model types. One approach that’s caught our attention is LIME (Local Interpretable Model-Agnostic Explanations) [4], which allows you to build any model you like, then use perturbation and linear approximation to explain specific predictions.
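To give a feel for the perturb-and-approximate idea, here is a simplified sketch in the spirit of LIME. It does not use the lime package or reproduce its exact algorithm; the helper explain_locally, the Gaussian proximity kernel, and all parameter values are assumptions made for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)

# Any model with predict_proba would do; a random forest is used for illustration.
X = rng.normal(size=(500, 4))
y = (X[:, 0] - 2.0 * X[:, 1] > 0).astype(int)
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

def explain_locally(model, x, n_samples=2000, scale=0.5, kernel_width=1.0):
    """Fit a proximity-weighted linear surrogate around one observation x."""
    # 1. Perturb: sample points in a neighborhood of x.
    neighbors = x + rng.normal(scale=scale, size=(n_samples, x.shape[0]))
    # 2. Query the black-box model at the perturbed points.
    probs = model.predict_proba(neighbors)[:, 1]
    # 3. Weight neighbors by closeness to x (Gaussian kernel, an assumed choice).
    dists = np.linalg.norm(neighbors - x, axis=1)
    weights = np.exp(-(dists ** 2) / (kernel_width ** 2))
    # 4. Fit a weighted linear model; its coefficients are the local explanation.
    surrogate = Ridge(alpha=1.0).fit(neighbors, probs, sample_weight=weights)
    return surrogate.coef_

x = np.array([0.1, 0.0, 1.0, -1.0])  # near the decision boundary
local_coefs = explain_locally(model, x)
print("local linear coefficients:", np.round(local_coefs, 3))
```

The signs and relative sizes of the local coefficients play the role that regression coefficients play for OLS, but they describe the model only in the neighborhood of x.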

As machine learning becomes more important in our day-to-day lives, promoting trust in good algorithms — and the ability to detect bad algorithms — is critical. So, too, is designing algorithms from the start with interpretability in mind.

We’re hiring! If you’d like to help companies make the best hiring decisions, check out our job postings: