How to make your data and models interpretable by learning from cognitive science
(This post accompanies a talk that Been Kim gave at South Park Commons on interpretable machine learning. If you want the technical details from Been Kim herself, check out the video and papers at the bottom of this post.)
Intro: Most machine learning models are inscrutable. What can we do?
It’s an unfortunate truth of modern machine learning: even if your model performs perfectly well on the metric you optimized for, that’s no guarantee that you’re going to be happy with its performance in the real world.
Sure, the test-set accuracy is great. But you may not have noticed that the model’s errors are concentrated in categories where mistakes are especially costly (such as tagging black people as gorillas). It may reinforce discriminatory biases because you didn’t encode fairness into your objective function (Bolukbasi et al. 2016). It may fail spectacularly if the real-world environment differs imperceptibly from the testing environment (adversarial examples, e.g. Goodfellow et al. 2014). Or it may satisfy the letter of your request, but definitely not the spirit (https://blog.openai.com/faulty-reward-functions/).
The problem is that a single metric, such as classification accuracy, is an incomplete description of most real-world tasks (Doshi-Velez and Kim 2017). Other important outcomes — such as fairness, privacy, safety, or usability — are not captured in simple performance metrics.
As we continue deploying ML in more and more real-world applications, unintended outcomes have the potential to become increasingly problematic for society (as discussed by the AINow initiative, the Future of Life Institute, and other groups). What can we do?
Interpretability: one path forward
One line of research to address these difficulties is to design explainable or interpretable models. The ability to understand which examples the model is getting right or wrong, and how it’s coming to the answers it gets, could help users of ML systems notice important gaps between the formalized problem description and the desired real-world outcomes.
In recent years, researchers have started workshops and conferences on model interpretability, such as the NIPS Interpretable ML workshop and the Fairness, Accountability, and Transparency (FAT*) conference. Funders and regulators are also looking towards explainability as a solution, from the EU’s recent Right to Explanation legislation to DARPA’s Explainable AI program.
Been Kim: putting the “human” back in “human-interpretable”
Been Kim is a research scientist who builds interpretable ML models at the People+AI Research Initiative at Google Brain. In her recent talk at the South Park Commons AI Speaker Series, she presented a series of methods that use example-based reasoning inspired by the cognitive science of human decision-making, and showed that they are easier for humans to predict and collaborate with.
Unlike other approaches, Kim’s work is explicitly inspired by the cognitive science of human reasoning. Specifically: human reasoning is often prototype-based, using representative examples as a basis for categorization and decision-making. Similarly, Kim’s models use representative examples to explain and cluster data.
Throughout Kim’s talk, assertions of “interpretability” were backed up with experimental data showing concrete desired outcomes — for example, that users can more consistently predict the model’s results, or that they give a higher subjective satisfaction rating.
In the rest of this post I’ll explain the two main methods that Been Kim showed during her talk:
- The first method, called MMD-critic (the MMD stands for Maximum Mean Discrepancy), is not an ML model but a way to understand the data itself. It’s an unsupervised method that can be applied to an unlabeled dataset, or to individual categories within a labeled dataset.
- The second method, called the Bayesian Case Model (BCM), is an unsupervised learning method that uses both prototypes and sparse features to be more interpretable without any loss of power compared to standard methods. Kim also demonstrates that BCMs are easier for humans to collaborate with, by incorporating an interactive BCM into the task of grading course assignments.
I’ll give a brief overview of how MMD-critic and BCM work. For more detail than I provide here, check out the videos and papers at the end of this post.
MMD-critic: using prototypes and criticisms to look at your data
A common refrain among advisors of data analysis trainees is to “look at your data!” rather than leaping blindly into model-fitting. This is great advice. Excessive trust in summary statistics can mask bizarre input distributions, broken data pipelines, or bad assumptions. Prematurely reaching for a modeling framework when your raw data is a mess is a prime recipe for “garbage in, garbage out”.
That said, how exactly should you go about looking at your data? If your data consists of thousands of images, you can’t look at all of them. Should you just look at image 000001.png through 000025.png and call that good enough?
To answer this question, Kim took inspiration from the cognitive science of how humans understand categories. Specifically, human categorization can be modeled as using prototypes: examples that are representative of the category as a whole. An item’s category membership is determined by its similarity to the category’s prototypes. (See https://en.wikipedia.org/wiki/Prototype_theory and https://en.wikipedia.org/wiki/Recognition_primed_decision for more detail on the cognitive science.)
One disadvantage of prototype-based reasoning is that it is prone to overgeneralization. That is, the properties of the prototypical members are assumed to be universally shared among the group, even if there is substantial variation within the group. One technique that can help avoid overgeneralizing is to show exceptions or criticisms to the rule: minority datapoints that differ substantially from the prototype, but nonetheless belong in the category.
For example, the distribution of cat images mostly consists of single cats sitting, standing, or lying down. However, an image of a cat sprawled across a keyboard belly-up, wearing a costume, or hiding inside a bag is still a cat image, even though it differs substantially from the prototypical images. Notably, these unusual examples are important minorities, rather than lone outliers. There are many cat images showing atypical positions and costumes, and so these images are important to a full understanding of cat images.
Kim et al. developed an unsupervised algorithm for automatically finding prototypes and criticisms for a dataset, called MMD-critic. When applied to unlabeled data, it finds prototypes and criticisms that characterize the dataset as a whole. It can also be used to visualize a category of images within a labeled dataset.
The MMD-critic algorithm works in two stages. First, prototypes are selected so that the set of prototypes is similar to the full dataset. Maximum Mean Discrepancy (MMD) is the specific measure of the difference between the prototype distribution and the full data distribution. Second, criticisms are selected from parts of the dataset that are underrepresented by the prototypes, with an additional constraint to ensure the criticisms are diverse. The result of this method is a set of prototypes that are typical of the dataset as a whole, and a set of criticisms that identify large parts of the dataset that differ most from the prototypes.
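To make the two stages concrete, here is a minimal sketch in Python with NumPy. It is not the code from the MMD-critic repository: the function names, the RBF kernel, the `gamma` value, and the toy data are all assumptions made for illustration. It also simplifies the criticism stage by omitting the diversity constraint and by keeping only points where the data has more mass than the prototypes (the "underrepresented" regions described above).

```python
import numpy as np


def rbf_kernel(X, gamma=1.0):
    """Pairwise RBF kernel matrix: K[i, j] = exp(-gamma * ||x_i - x_j||^2)."""
    sq_norms = np.sum(X ** 2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * X @ X.T
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def mmd_squared(K, selected):
    """Squared MMD between the full dataset and the subset `selected`."""
    S = np.asarray(selected)
    return K.mean() - 2.0 * K[:, S].mean() + K[np.ix_(S, S)].mean()


def select_prototypes(K, num_prototypes):
    """Stage 1: greedily add the point that gives the lowest MMD^2."""
    selected = []
    for _ in range(num_prototypes):
        candidates = [i for i in range(K.shape[0]) if i not in selected]
        best = min(candidates, key=lambda i: mmd_squared(K, selected + [i]))
        selected.append(best)
    return selected


def select_criticisms(K, prototypes, num_criticisms):
    """Stage 2 (simplified): pick points the prototypes underrepresent.

    The witness value is (similarity to the data) minus (similarity to the
    prototypes). No diversity term is applied in this sketch.
    """
    witness = K.mean(axis=1) - K[:, prototypes].mean(axis=1)
    witness[prototypes] = -np.inf  # a prototype cannot also be a criticism
    return [int(i) for i in np.argsort(-witness)[:num_criticisms]]


rng = np.random.default_rng(0)
# Toy data: a large cluster near (0, 0) and a small minority cluster near (4, 4).
X = np.vstack([rng.normal(0.0, 0.3, size=(90, 2)),
               rng.normal(4.0, 0.3, size=(10, 2))])
K = rbf_kernel(X, gamma=0.5)
prototypes = select_prototypes(K, num_prototypes=1)
criticisms = select_criticisms(K, prototypes, num_criticisms=3)
```

On this toy data the single prototype comes from the large cluster, and the criticisms come from the small minority cluster that the prototype fails to represent.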
If you’d like to try MMD-critic on your own data, an implementation is available at https://github.com/BeenKim/MMD-critic.
Pilot study with human subjects
In order to validate the MMD-critic method, Kim set up a small pilot study in which human subjects did a categorization task. Users were shown an image of an animal, and were asked to predict which sub-group it came from (for example, if they were shown a dog, they would have to categorize it as breed 1 or breed 2, based on example images from each breed).
Users were given this task in four different conditions, which showed example group members in different ways: 1) all images in each group (200–300 of them); 2) just the prototypes; 3) prototypes and criticisms; and 4) a random selection of images from each group, with the same number of images as condition 3.
In her pilot results, Kim found evidence that:
- Viewing just the prototypes of each group allowed users to make more accurate and more time-efficient predictions, compared with viewing all the group members or a random subset.
- Including criticisms improves accuracy over prototypes alone, at a small cost to time-efficiency.
- Viewing a random subset of images is the least accurate and the least efficient.
Bayesian Case Model (BCM): cog-sci-inspired clustering
A selection of prototypes and criticisms can provide insight into a dataset, but it isn’t itself a machine learning model. How can prototype-based reasoning be extended to a fully-fledged and operational ML model?
The second method Been Kim presented in her talk was a novel type of admixture model, designed to incorporate the interpretability of case-based reasoning without any loss of performance over standard mixture models.
In order to understand the Bayesian Case Model as an application of “case-based reasoning” to “admixture models”, it’s useful to clarify what those terms refer to:
- Case-based reasoning is a human reasoning method for real-world problem-solving. Previously-seen examples are used as a scaffold for solving novel problems. The relevant features that relate the old problem to the new problem are identified, and previous problem-solving strategies are reused and revised. This is more than just a formal problem-solving procedure; it is also a description of everyday, informal human reasoning.
- An admixture model is a type of generative model for unsupervised learning. The features of a data distribution are modeled as being derived from a mixture of underlying sources (such as topics, subpopulations, or clusters) that are inferred but not directly observed. Fitting an admixture model to an observed dataset is a form of unsupervised learning. The identified underlying sources can be inspected directly to gain insight into the underlying structure of the data, or used as the basis for a clustering analysis. (For more depth, see Wikipedia on mixture models, plus this explanation on how admixture models differ from mixture models)
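The distinction between mixture and admixture models can be sketched as two toy generative processes. This is an illustrative assumption, not code from Kim’s papers: the number of sources, the feature vocabulary, the Dirichlet parameter `alpha`, and the function names are all hypothetical. In the mixture model, one hidden source generates every feature of an item; in the admixture model, each item has its own proportions over sources, and each feature can come from a different source.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy setup: 3 hidden sources, each a distribution over 5 feature values.
num_sources, vocab_size, features_per_item = 3, 5, 8
source_dists = rng.dirichlet(np.ones(vocab_size), size=num_sources)


def sample_mixture_item():
    """Mixture model: one source is drawn per item and generates all its features."""
    source = rng.choice(num_sources)
    features = rng.choice(vocab_size, size=features_per_item, p=source_dists[source])
    return np.full(features_per_item, source), features


def sample_admixture_item(alpha=0.5):
    """Admixture model (LDA-style): each item has its own proportions over sources,
    and every feature picks its source independently from those proportions."""
    proportions = rng.dirichlet(alpha * np.ones(num_sources))
    sources = rng.choice(num_sources, size=features_per_item, p=proportions)
    features = np.array([rng.choice(vocab_size, p=source_dists[s]) for s in sources])
    return sources, features


mix_sources, mix_features = sample_mixture_item()
adm_sources, adm_features = sample_admixture_item()
```

Fitting either model means running this process in reverse: given only the observed features, infer the hidden sources and, for the admixture case, each item’s proportions.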
In order to understand the difference in interpretability between a traditional admixture model such as Latent Dirichlet Allocation (LDA) and the Bayesian Case Model (BCM) that Kim presented, consider the following figure from Kim 2015:
In this example, a hypothetical dataset of cartoon faces with different shapes, colors, eyes, and mouths has been analyzed using an admixture model, and three underlying clusters have been discovered (left column). LDA and BCM would discover similar underlying clusters; they differ only in how the clusters are represented. BCM represents the clusters in a more interpretable format, without any loss of representational power.
A typical admixture model (middle column, LDA) would represent the identities of the three clusters as a long list of feature probabilities — 26% likelihood of green color, 23% likelihood of square shape, etc. This can be difficult for humans to interpret because it provides a laundry list of continuous values, rather than a concise and memorable handle for the cluster (see Doshi-Velez and Kim 2017’s discussion on “cognitive chunks”). By contrast, a Bayesian Case Model would represent each cluster using 1) a prototypical example of a representative class member (right column, “prototype”), and 2) a subspace of the prototype’s features that are actually important for cluster membership (right column, “subspaces”). This provides a more cognitively accessible handle for each cluster: a single example as a prototype, paired with guidance about which of the prototype’s features are important to pay attention to.
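As a rough illustration of the representation only (not of BCM inference), a cluster summary can be written as a prototype plus a subspace. The class name, the `match_score` helper, and the feature values below are hypothetical, loosely modeled on the cartoon-face example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClusterSummary:
    """A cluster described by one prototype plus the prototype's important features."""
    prototype: dict       # a concrete example: feature name -> value
    subspace: frozenset   # names of the prototype's features that matter


def match_score(item, cluster):
    """Fraction of the cluster's important features on which the item matches the prototype."""
    matches = sum(item.get(f) == cluster.prototype[f] for f in cluster.subspace)
    return matches / len(cluster.subspace)


# Hypothetical cluster: the prototype is a specific face, but only shape and color matter.
green_squares = ClusterSummary(
    prototype={"shape": "square", "color": "green", "eyes": "round", "mouth": "smile"},
    subspace=frozenset({"shape", "color"}),
)
new_face = {"shape": "square", "color": "green", "eyes": "narrow", "mouth": "frown"}
other_face = {"shape": "square", "color": "red", "eyes": "round", "mouth": "smile"}

score_new = match_score(new_face, green_squares)
score_other = match_score(other_face, green_squares)
```

Here `new_face` matches the cluster fully even though its eyes and mouth differ from the prototype, because those features lie outside the subspace; `other_face` matches on only half of the important features.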
Evaluating the BCM using an interactive grading system
In evaluating interpretability in this case, Kim focused on users’ ability to collaborate with the model by altering it interactively.
She built a BCM-based interactive extension to OverCode (http://people.csail.mit.edu/elg/overcode), a system that uses cluster analysis to allow instructors to visualize thousands of programming solutions. The interactive extension allowed instructors to directly manipulate the clusters by selecting which submissions should be used as prototypes of the BCM, and which keywords are important subspaces for each prototype.
When instructors were tasked with using the interactive BCM system to select a set of examples to review in a recitation, they reported that they were more satisfied, better explored the full spectrum of students’ submissions, and discovered more useful features and prototypes (p < 0.001), compared to a non-interactive version.
The path ahead
During the Q&A, Kim gave a sense of some interesting future challenges that interpretable ML has left to tackle:
- Examples are not the final answer to everything. For example, in medical research, researchers want to discover new patterns that they can’t yet see or notice. An example of a representative patient might elicit a reaction of “I know everything about this patient; so what?”
- You can’t expect a human to understand or predict what a system with super-human performance is going to do, almost by definition. Conceiving of interpretability in a way that boils down to human prediction will no longer be straightforwardly useful once systems exceed our ability to predict their actions. That said, Kim believes interpretability will remain relevant to super-human systems. Even if they cannot be holistically understood, there’s still the possibility of understanding locally, for a single datapoint, why the decision was made in a certain way.
Conclusion / summary
My major takeaways from Been Kim’s talk:
- When you look at your raw data, focus on prototypical examples if you want a more efficient and accurate way to view your data than a random sample. Additionally, for a maximally accurate sense of the diversity of your data, include criticisms.
- To ensure that your users will be able to collaborate with your models, consider tailoring the models themselves to the quirks of human cognition. If your system thinks the way your users do, then your users will likely be better able to impart their knowledge back to the system.
- “Interpretability” has many meanings. Define your goals clearly for your specific application, and run experiments with human subjects to verify that your model achieves the user outcomes you were aiming for.
As ML systems become increasingly powerful, it will be increasingly important for us to have confidence in what they do. And in order for that confidence to be well-founded and not misplaced, we will need to take into account what it means for a human specifically to “have confidence”, to “trust”, or to “understand”. Our attention span is limited, and our cognitive capacities are idiosyncratic and inescapably human. If we are to truly understand the ML systems of today and of the future, we will need to account for our own process of understanding.
This is a summary of a talk Been Kim gave at the South Park Commons AI Speaker Series titled “Interactive and Interpretable Machine Learning Models”. Images and video from the talk provided by Google and used with permission.
Appendix: Full video, slides, papers, and code
- Talk slides for download
- Full video below:
MMD-critic paper and code:
- Examples are not Enough, Learn to Criticize! Criticism for Interpretability. Kim, Khanna, and Koyejo, NIPS 2016.
- GitHub code: https://github.com/BeenKim/MMD-critic
- NIPS oral presentation: slides and 15-minute talk
BCM papers and code:
- The Bayesian Case Model: A Generative Approach for Case-Based Reasoning and Prototype Classification. Kim, Rudin, and Shah, NIPS 2014.
- iBCM: Interactive Bayesian Case Model Empowering Humans via Intuitive Interaction. Kim, Glassman, Johnson, and Shah, MIT CSAIL TR 2015.
- Code: https://users.cs.duke.edu/~cynthia/code/BCM.zip