Making Decision Trees Accurate Again: Explaining what Explainable AI did not
Combining neural networks and decision trees for accurate and interpretable computer vision models (and how our method works).
This is an extended version with an expanded methods description of the Towards Data Science article “What Explainable AI fails to explain (and how we fix that)”.
The interpretability of neural networks is becoming increasingly necessary, as deep learning is being adopted in settings where accurate and justifiable predictions are required. These applications range from finance to medical imaging. However, deep neural networks are notorious for a lack of justification. Explainable AI (XAI) attempts to bridge this divide between accuracy and interpretability, but as we explain below, XAI justifies decisions without interpreting the model directly.
What is “Interpretable”?
Defining explainability or interpretability for computer vision is challenging: What does it even mean to explain a classification for high-dimensional inputs like images? As we discuss below, two popular definitions involve saliency maps and decision trees, but both approaches have their weaknesses.
What Explainable AI Doesn’t Explain
Saliency Maps¹
Many XAI methods produce saliency maps, but saliency maps focus on the input and neglect to explain how the model makes decisions. For more on saliency maps, see these saliency tutorials and Github repositories.
What Saliency Maps Fail to Explain
To illustrate why saliency maps do not fully explain how the model predicts, here is an example: Below, the saliency maps are identical, but the predictions differ. Why? Even though both saliency maps highlight the correct object, one prediction is incorrect. How? Answering this could help us improve the model, but as shown below, saliency maps fail to explain the model’s decision process.
Decision Trees
Another approach is to replace neural networks with interpretable models. Before deep learning, decision trees were the gold standard for accuracy and interpretability. Below, we illustrate the interpretability of decision trees.
In accuracy, however, decision trees lag behind neural networks by up to 40% on image classification datasets². Neural-network-and-decision-tree hybrids also underperform, failing to match neural networks even on CIFAR10, a dataset of tiny 32x32 images like the one below.
As we show in our paper (Sec 5.2), this accuracy gap damages interpretability: high-accuracy, interpretable models are needed to explain high-accuracy neural networks.
Enter Neural-Backed Decision Trees
We challenge this false dichotomy by building models that are both interpretable and accurate. Our key insight is to combine neural networks with decision trees, preserving high-level interpretability while using neural networks for low-level decisions, as shown below. We call these models Neural-Backed Decision Trees (NBDTs) and show they can match neural network accuracy while preserving the interpretability of a decision tree.
NBDTs are as interpretable as decision trees. Unlike neural networks today, NBDTs can output intermediate decisions for a prediction. For example, given an image, a neural network may output Dog. However, an NBDT can output both Dog and Animal, Chordate, Carnivore (below).
NBDTs achieve neural network accuracy. Unlike any other decision-tree-based method, NBDTs match neural network accuracy (< 1% difference) on CIFAR10, CIFAR100, and TinyImageNet200. NBDTs also achieve accuracy within 2% of neural networks on ImageNet, setting a new state-of-the-art accuracy for interpretable models. The NBDT’s ImageNet accuracy of 75.30% outperforms the best competing decision-tree-based method by ~14%.
How and what Neural-Backed Decision Trees Explain
Justifications for Individual Predictions
The most insightful justifications are for objects the model has never seen before. For example, consider an NBDT (below), and run inference on a Zebra. Although this model has never seen a Zebra, the intermediate decisions shown below are correct — Zebras are both Animals and Ungulates (hoofed animals). The ability to see justifications for individual predictions is especially valuable for unseen objects.
Justifications for Model Behavior
Furthermore, we find that with NBDTs, interpretability improves with accuracy. This runs contrary to the dichotomy in the introduction: NBDTs not only achieve both accuracy and interpretability; they also make accuracy and interpretability one and the same objective.
For example, ResNet10 achieves 4% lower accuracy than WideResNet28x10 on CIFAR10. Correspondingly, the lower-accuracy ResNet hierarchy (left) makes less sense, grouping Frog, Cat, and Airplane together. This is “less sensible,” as it is difficult to find an obvious visual feature shared by all three classes. By contrast, the higher-accuracy WideResNet hierarchy (right) makes more sense, cleanly separating Animal from Vehicle — thus, the higher the accuracy, the more interpretable the NBDT.
Understanding Decision Rules
With low-dimensional tabular data, decision rules in a decision tree are simple to interpret, e.g., if the dish contains a bun, then pick the right child, as shown below. However, decision rules are not as straightforward for inputs like high-dimensional images.
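For tabular data, such decision rules are literally readable conditionals. A minimal sketch, using hypothetical dish features:

```python
def classify_dish(dish):
    """A toy decision tree over tabular features (hypothetical rules)."""
    if dish["has_bun"]:            # root decision rule: bun -> right child
        if dish["has_patty"]:
            return "burger"
        return "hot dog"
    return "salad"

print(classify_dish({"has_bun": True, "has_patty": True}))   # burger
print(classify_dish({"has_bun": False, "has_patty": False})) # salad
```

Every step of the prediction is a human-readable rule; this is the interpretability that high-dimensional image inputs break.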
As we qualitatively find in the paper (Sec 5.3), the model’s decision rules are based not only on object type but also on context, shape, and color.
To interpret decision rules quantitatively, we leverage an existing hierarchy of nouns called WordNet³; with this hierarchy, we can find the most specific shared meaning between classes. For example, given the classes Cat and Dog, WordNet would provide Mammal. In our paper (Sec 5.2) and pictured below, we quantitatively verify these WordNet hypotheses.
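To make the lookup concrete, here is a toy stand-in for the WordNet query, using a small hand-built hierarchy rather than WordNet itself (the class names and parent links are illustrative):

```python
# A tiny hand-built hierarchy standing in for WordNet (illustrative only).
PARENT = {
    "Cat": "Mammal", "Dog": "Mammal", "Horse": "Ungulate",
    "Ungulate": "Mammal", "Mammal": "Animal", "Animal": None,
}

def ancestors(cls):
    """Path from a class up to the root, including the class itself."""
    path = []
    while cls is not None:
        path.append(cls)
        cls = PARENT[cls]
    return path

def most_specific_shared(a, b):
    """First ancestor of `a` that is also an ancestor of `b`."""
    b_ancestors = set(ancestors(b))
    for node in ancestors(a):
        if node in b_ancestors:
            return node

print(most_specific_shared("Cat", "Dog"))  # Mammal
```

The real pipeline performs the same "lowest shared ancestor" query against WordNet's noun hierarchy.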
Note that in small datasets with 10 classes, e.g., CIFAR10, we can find WordNet hypotheses for all nodes. However, in large datasets with 1000 classes, e.g., ImageNet, we can only find WordNet hypotheses for a subset of nodes.
How it Works
The training and inference process for a Neural-Backed Decision Tree can be broken down into four steps.
- Construct a hierarchy for the decision tree, called the Induced Hierarchy.
- This hierarchy yields a particular loss function, which we call the Tree Supervision Loss.
- Start inference by passing the sample through the neural network backbone. The backbone is all neural network layers before the final fully-connected layer.
- Finish inference by running the final fully-connected layer as a sequence of decision rules, which we call Embedded Decision Rules. These decisions culminate in the final prediction.
Running Embedded Decision Rules
We first discuss inference. As explained above, our NBDT approach featurizes each sample using the neural network backbone. To understand what happens next, we will first construct a degenerate decision tree that is equivalent to a fully-connected layer.
Fully-Connected Layer: Running inference with a featurized sample is a matrix-vector product, as shown below.
This matrix-vector product yields a vector of inner products, which we denote with y-hat. The index of the largest inner product is our class prediction.
Naive Decision Tree: We construct a basic decision tree with one root node and a leaf for each class. This is pictured by “B — Naive” in the figure above. Each leaf is directly connected to the root and has a representative vector, namely a row vector from W (Eqn. 1 above).
Also pictured above, running inference with a featurized sample x means taking inner products between x and each child node’s representative vector. Like the fully-connected layer, the index of the largest inner product is our class prediction.
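This equivalence can be checked numerically. A small sketch with random weights (the shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(10, 512))   # one row vector per class
x = rng.normal(size=512)         # featurized sample from the backbone

# Fully-connected layer: a single matrix-vector product.
fc_pred = int(np.argmax(W @ x))

# Naive decision tree: inner product between x and each leaf's
# representative vector (the corresponding row of W).
leaf_scores = [w_k @ x for w_k in W]
tree_pred = int(np.argmax(leaf_scores))

assert fc_pred == tree_pred  # the two formulations make the same prediction
```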
The direct equivalence between a fully-connected layer and a naive decision tree motivates our particular inference method, using an inner-product decision tree. In our work, we then extend this naive tree to deeper trees. However, that discussion is beyond the scope of this article. Our paper (Sec. 3.1) discusses how this works, in detail.
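To give a flavor of the deeper-tree case, here is a minimal sketch of hard inference over a hierarchy, where each step descends to the child with the largest inner product; the tree, labels, and vectors below are made up for illustration:

```python
import numpy as np

class Node:
    def __init__(self, label, vector=None, children=()):
        self.label = label
        self.vector = None if vector is None else np.asarray(vector, float)
        self.children = list(children)

def hard_inference(root, x):
    """Descend from the root, at each node picking the child whose
    representative vector has the largest inner product with x."""
    x = np.asarray(x, float)
    node, decisions = root, []
    while node.children:
        node = max(node.children, key=lambda child: child.vector @ x)
        decisions.append(node.label)
    return node.label, decisions

# A made-up two-level hierarchy.
root = Node("root", children=[
    Node("Animal", [1, 0],
         children=[Node("Cat", [1, 1]), Node("Dog", [1, -1])]),
    Node("Vehicle", [-1, 0], children=[Node("Car", [-1, 1])]),
])
print(hard_inference(root, [1.0, 0.5]))  # ('Cat', ['Animal', 'Cat'])
```

Note how the intermediate decisions (Animal, then Cat) fall out of inference for free — this is the sequence of justifications an NBDT reports.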
Building Induced Hierarchies
The decision tree’s hierarchy determines which sets of classes the NBDT must decide between. We refer to this hierarchy as an Induced Hierarchy because we build it using a pretrained neural network’s weights.
In particular, we view each row vector in the fully-connected layer’s weight matrix W as a point in d-dimensional space. This is illustrated by “Step B — Set Leaf Vectors”. We then perform hierarchical agglomerative clustering on these points; the successive merges determine the hierarchy, as illustrated above. Our paper (Sec. 3.2) discusses this in more detail.
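A bare-bones sketch of the idea, with a naive agglomerative clustering written from scratch (the actual pipeline uses standard hierarchical agglomerative clustering; the merge criterion here is simplified to "closest mean vectors"):

```python
import numpy as np

def induce_hierarchy(W):
    """Build a nested-tuple hierarchy from the FC layer's row vectors by
    repeatedly merging the two clusters with the closest mean vectors."""
    clusters = [(i, W[i]) for i in range(len(W))]  # (subtree, mean vector)
    while len(clusters) > 1:
        best = None
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                d = np.linalg.norm(clusters[i][1] - clusters[j][1])
                if best is None or d < best[0]:
                    best = (d, i, j)
        _, i, j = best
        (ti, vi), (tj, vj) = clusters[i], clusters[j]
        clusters = [c for k, c in enumerate(clusters) if k not in (i, j)]
        clusters.append(((ti, tj), (vi + vj) / 2))
    return clusters[0][0]

# Rows 0-1 and rows 2-3 are close, so they pair up before joining at the root.
W = np.array([[0., 0.], [0., 1.], [10., 10.], [10., 11.]])
print(induce_hierarchy(W))  # ((0, 1), (2, 3))
```

Classes whose weight vectors are similar end up as siblings, which is what makes the induced hierarchy reflect what the network itself learned.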
Training with Tree Supervision Loss
Consider “A — Hard” in the figure above. Say the green node corresponds to the Horse class. This is just one class. However, it is also an Animal (orange). As a result, we know that a sample arriving at the root node (blue) should go to the right, to Animal. The sample arriving at the node Animal also should go to the right again, towards Horse. We train each node to predict the correct child node. We call the loss that enforces this the Tree Supervision Loss, which is effectively a cross entropy loss for each node.
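In spirit, the hard tree supervision loss sums a standard cross-entropy at every node along the path to the true leaf. A numpy sketch (the vectors and path below are illustrative, not the paper's exact formulation):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def tree_supervision_loss(x, path_nodes):
    """Sum of per-node cross-entropy losses along the path to the true leaf.

    `path_nodes` lists (child_vectors, correct_child_index) pairs, one per
    internal node visited on the way to the ground-truth class.
    """
    loss = 0.0
    for child_vectors, correct_child in path_nodes:
        logits = child_vectors @ x              # inner product per child
        loss += -np.log(softmax(logits)[correct_child])
    return loss

x = np.array([1.0, 0.5])
path = [
    (np.array([[1., 0.], [-1., 0.]]), 0),  # root: Animal (correct) vs Vehicle
    (np.array([[1., 1.], [1., -1.]]), 0),  # Animal: Cat (correct) vs Dog
]
wrong = [(v, 1 - c) for v, c in path]      # same nodes, wrong children
assert tree_supervision_loss(x, path) < tree_supervision_loss(x, wrong)
```

Minimizing this loss trains each internal node to route samples toward the correct leaf, which is exactly what inference then exploits.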
Our paper (Sec. 3.3) discusses this in more detail and further explains “B — Soft”.
Trying NBDTs in under a minute
Interested in trying out an NBDT now? Without installing anything, you can view more example outputs online and even try out our web demo. Alternatively, use our command-line utility to run inference (install with pip install nbdt). Below, we run inference on a picture of a cat.
```shell
nbdt "https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32"  # this can also be a path to a local image
```
This outputs both the class prediction and all the intermediate decisions.
Prediction: cat // Decisions: animal (99.47%), chordate (99.20%), carnivore (99.42%), cat (99.86%)
You can load a pretrained NBDT in just a few lines of Python as well. Use the following to get started. We support several models, including WideResNet28x10 and ResNet18, for CIFAR10, CIFAR100, and TinyImageNet200.
```python
from nbdt.model import HardNBDT
from nbdt.models import wrn28_10_cifar10

model = wrn28_10_cifar10()
model = HardNBDT(
    pretrained=True,
    dataset='CIFAR10',
    arch='wrn28_10_cifar10',
    model=model)
```
For reference, see the script for the command-line tool we ran above; only ~20 lines are directly involved in transforming the input and running inference. For more instructions on getting started and examples, see our Github repository.
Conclusion
Explainable AI does not fully explain how the neural network reaches a prediction: Existing methods explain the image’s impact on model predictions but do not explain the decision process. Decision trees address this, but unfortunately, images⁴ are kryptonite for decision tree accuracy.
We thus combine neural networks and decision trees. Unlike predecessors that arrived at the same hybrid design, our neural-backed decision trees (NBDTs) simultaneously address the failures (1) of neural networks to provide justification and (2) of decision trees to attain high accuracy. This primes a new category of accurate, interpretable NBDTs for applications like medicine and finance. To get started, see the project page.
By Alvin Wan, *Lisa Dunlap, *Daniel Ho, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez
where * denotes equal contribution
[1] There are two types of saliency maps: one is white-box, where the method has access to the model and its parameters. One popular white-box method is Grad-CAM, which uses both gradients and class activation maps to visualize attention. You can learn more from the paper, “Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization” http://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf. The other type of saliency map is black-box, where the method does not have access to the model parameters. RISE is one such saliency method. RISE masks random portions of the input image and passes this image through the model — the mask that damages accuracy the most is the most “important” portion. You can learn more from the paper “RISE: Randomized Input Sampling for Explanation of Black-box Models”, http://bmvc2018.org/contents/papers/1064.pdf.
[2] This 40% gap between decision tree and neural network accuracy shows up on TinyImageNet200.
[3] WordNet is a lexical hierarchy of various words. A large majority of words are nouns, but other parts of speech are included as well. For more information, see the official website.
[4] In general, decision trees perform best with low-dimensional data. Images are the antithesis of this best-case scenario, being extremely high-dimensional.