Breaking down Network Dissection
Distilling the ideas from MIT CSAIL’s intriguing paper: “Network Dissection: Quantifying Interpretability of Deep Visual Representations”.
“Network Dissection: Quantifying Interpretability of Deep Visual Representations” is one of my favourite explainable AI papers. I have distilled the contents into a short blog post which may help you quickly grasp the ideas in the paper.
Motivation
Deep learning has revolutionized computer vision, natural language processing, and reinforcement learning by achieving state-of-the-art performance on many challenging tasks. One major drawback of deep networks is their “black box” nature: it is hard to understand how these networks arrive at their outputs. Researchers from MIT’s CSAIL propose a technique called “Network Dissection”, which provides a framework for describing which features individual convolutional filters in a CNN focus on.
Network Dissection
Network Dissection consists of three steps:
Step 1: Identify a broad set of human-labeled visual concepts
To identify a set of human-labelled visual concepts, the authors constructed the Broadly and Densely Labeled Dataset (Broden). Broden combines several densely labelled image datasets: ADE [43], OpenSurfaces [4], Pascal-Context [19], Pascal-Part [6], and the Describable Textures Dataset [7]. Each visual concept in an image is labelled with a pixel-wise binary segmentation map. For example, if the visual concept being labelled is a car, then all of the pixels that contain a car are labelled 1, and the rest of the pixels are labelled 0. The Broden dataset has 63,305 images with 1197 visual concepts, grouped into 6 categories: textures, colors, materials, parts, objects, and scenes.
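To make the idea of a per-concept binary map concrete, here is a minimal sketch. It assumes an annotation arrives as an integer label map where each pixel stores a concept ID; this is a simplification for illustration, not Broden's actual file format, and the function name and IDs are hypothetical.

```python
def concept_mask(label_map, concept_id):
    """Return a binary map: 1 where the pixel belongs to concept_id, else 0."""
    return [[1 if label == concept_id else 0 for label in row] for row in label_map]


# Toy 3x4 label map: 0 = background, 7 = "car" (hypothetical IDs).
labels = [
    [0, 0, 7, 7],
    [0, 7, 7, 7],
    [0, 0, 0, 0],
]
car_mask = concept_mask(labels, concept_id=7)
print(car_mask)  # [[0, 0, 1, 1], [0, 1, 1, 1], [0, 0, 0, 0]]
```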
Step 2: Gather the response of convolutional filters to known concepts
The goal of this step is to obtain, for each convolutional filter, a pixel-wise binary segmentation map that tells us which parts of the image highly activate that filter. The “known concepts” are the ones labelled in the Broden dataset from Step 1. To gather the response of convolutional filters to these concepts, images from the Broden dataset are fed into the CNN being dissected, a forward pass is performed, and the activation map of every convolutional filter is stored. The activation map of a convolutional filter is simply the filter's output, a 2D grid of values, when the image is passed through the network.
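To make “activation map” concrete, here is a toy sketch rather than the paper's code: a single 2x2 filter slid over a single-channel 4x4 image with stride 1 and no padding. Those settings are assumptions for illustration; a real CNN filter spans many input channels and sits behind earlier layers. The function name `conv2d_single` is hypothetical.

```python
def conv2d_single(image, kernel):
    """Valid (no padding), stride-1 2D cross-correlation of one channel with one kernel.

    This is the operation deep learning libraries call "convolution". The output
    grid is the filter's activation map for this input.
    """
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[i + di][j + dj] * kernel[di][dj]
                for di in range(kh) for dj in range(kw))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


# A 4x4 toy image with a bright vertical stripe in column 2.
image = [
    [0, 0, 1, 0],
    [0, 0, 1, 0],
    [0, 0, 1, 0],
    [0, 0, 1, 0],
]
# Hypothetical filter that responds to dark-to-bright transitions from left to right.
kernel = [
    [-1, 1],
    [-1, 1],
]
activation_map = conv2d_single(image, kernel)
print(activation_map)  # [[0, 2, -2], [0, 2, -2], [0, 2, -2]]
```

The activation map peaks where the stripe's left edge sits, and it is 3x3, smaller than the 4x4 input, which is a tiny version of the shrinking described in the next paragraph.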
Generally, the deeper a filter sits in a CNN, the smaller its activation map, because pooling and strided convolutions progressively reduce spatial resolution. Since we want a binary segmentation map at the resolution of the input image, the authors use bilinear interpolation to upsample each activation map to the size of the input image.
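Here is a minimal bilinear upsampling sketch in plain Python. It assumes corner-aligned coordinates, meaning the corner pixels of the small map land exactly on the corner pixels of the output; the paper only says “bilinear interpolation”, and libraries differ on this convention (for example, the `align_corners` option of PyTorch's `torch.nn.functional.interpolate`). The function name is hypothetical.

```python
def bilinear_upsample(amap, out_h, out_w):
    """Upsample a 2D activation map to (out_h, out_w) with bilinear interpolation.

    Assumes corner alignment: input corners map onto output corners.
    """
    in_h, in_w = len(amap), len(amap[0])
    out = []
    for y in range(out_h):
        # Position of output row y in input-row coordinates.
        src_y = y * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(src_y)
        y1 = min(y0 + 1, in_h - 1)
        wy = src_y - y0
        row = []
        for x in range(out_w):
            src_x = x * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(src_x)
            x1 = min(x0 + 1, in_w - 1)
            wx = src_x - x0
            # Interpolate along x on the two neighbouring rows, then along y.
            top = amap[y0][x0] * (1 - wx) + amap[y0][x1] * wx
            bottom = amap[y1][x0] * (1 - wx) + amap[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        out.append(row)
    return out


small = [
    [0.0, 2.0],
    [4.0, 6.0],
]
big = bilinear_upsample(small, 3, 3)
print(big)  # [[0.0, 1.0, 2.0], [2.0, 3.0, 4.0], [4.0, 5.0, 6.0]]
```

New pixels are weighted averages of their nearest neighbours in the small map, so the upsampled map stays smooth rather than blocky.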
The activation map is real-valued, but it must be converted to a binary map for Step 3. Let the activation map of unit k (the paper calls each convolutional filter a unit) be a_k. For each unit k, a top quantile level T_k is determined such that P(a_k > T_k) = 0.005 over every spatial location of the activation map across the dataset; in other words, only the top 0.5% of unit k's activation values exceed T_k. To obtain a binary map, all pixels in the upsampled activation map with a value greater than T_k are labelled 1, and the rest are labelled 0. The result is a binary map that is 1 over the parts of the input image that highly activated the convolutional filter.
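Below is a sketch of the thresholding step. It assumes one unit's activation maps are plain nested lists (one per image) and takes T_k as an empirical quantile of all pooled values; the paper's exact quantile estimator may differ, and the function names are hypothetical.

```python
def top_quantile_threshold(activation_maps, top_fraction=0.005):
    """Return T_k so that about `top_fraction` of all activation values exceed it.

    activation_maps: one unit's 2D activation maps, one per image in the dataset.
    Values are pooled over every spatial location of every image.
    """
    values = sorted(v for amap in activation_maps for row in amap for v in row)
    n_above = max(1, round(top_fraction * len(values)))
    if n_above >= len(values):
        raise ValueError("not enough activation values to estimate the quantile")
    # Largest value outside the top n_above; ties aside, exactly n_above values are > T.
    return values[len(values) - n_above - 1]


def binarize(upsampled_map, threshold):
    """1 where the (upsampled) activation exceeds the unit's threshold, else 0."""
    return [[1 if v > threshold else 0 for v in row] for row in upsampled_map]


# Synthetic data: 10 images, each with a 10x10 activation map; values 0..999 overall.
maps = [[[m * 100 + r * 10 + c for c in range(10)] for r in range(10)] for m in range(10)]
T_k = top_quantile_threshold(maps)
print(T_k)  # 994
# For brevity, threshold a raw map here; in the method it would be the upsampled one.
mask = binarize(maps[9], T_k)
print(sum(map(sum, mask)))  # 5
```

Because T_k is computed once per unit over the whole dataset, an image that barely excites the unit can end up with an all-zero binary map, which is the intended behaviour.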
Step 3: Quantify alignment of hidden variable-concept pairs
Now we have a binary map which tells us where a human-labelled concept is (from step 1), and a binary map which tells us where a convolutional filter is highly activated (from step 2). If the convolutional filter is highly activated for regions of the image that contain a human-labelled concept, then perhaps we can say that the filter is “looking for” that concept. To quantify the alignment between the two binary maps, intersection over union (IoU) is used.
IoU = (number of pixels labelled 1 in both binary maps) / (number of pixels labelled 1 in at least one of the binary maps)
In the paper, these pixel counts are summed over every image in the dataset before dividing, giving one dataset-wide IoU score for each filter-concept pair. If this IoU exceeds 0.04, the authors label the convolutional filter as a detector of that concept.
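A minimal sketch of the dataset-wide IoU and the 0.04 detector rule, assuming the binary maps for one filter and one concept are nested lists of 0/1 with matching shapes, one pair per image. The function names are hypothetical.

```python
IOU_THRESHOLD = 0.04  # detector threshold used by the authors


def dataset_iou(unit_masks, concept_masks):
    """Dataset-wide IoU between one unit's binary maps and one concept's binary maps.

    Intersections and unions are summed over all images before dividing.
    """
    intersection = union = 0
    for unit_mask, concept_mask in zip(unit_masks, concept_masks):
        for unit_row, concept_row in zip(unit_mask, concept_mask):
            for u, c in zip(unit_row, concept_row):
                intersection += u & c  # 1 only if both maps are 1
                union += u | c         # 1 if either map is 1
    return intersection / union if union else 0.0


def is_detector(unit_masks, concept_masks, threshold=IOU_THRESHOLD):
    """True if the unit counts as a detector for the concept."""
    return dataset_iou(unit_masks, concept_masks) > threshold


unit_masks = [
    [[1, 1], [0, 0]],  # image 1
    [[0, 0], [0, 1]],  # image 2
]
car_masks = [
    [[1, 0], [0, 0]],
    [[0, 0], [1, 1]],
]
print(dataset_iou(unit_masks, car_masks))  # 0.5 (2 shared pixels / 4 pixels in the union)
print(is_detector(unit_masks, car_masks))  # True
```

Summing before dividing means images where the concept is absent still count: any pixels the unit fires on there grow the union and lower the score.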
Network Dissection and interpretability
The more convolutional filters that align with the human-labelled concepts in your dataset, the more “interpretable” the CNN is. This definition of interpretability hinges on having a comprehensive dataset of concepts to compare convolutional filters against. As the Broden dataset has only 1197 different concepts, there are many human-understandable visual concepts that aren’t present in it. If a convolutional filter were to align closely with a human-understandable concept not present in the dataset, Network Dissection would still count this filter as uninterpretable.
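One way to turn IoU scores into an interpretability summary is sketched below. It assumes each filter is labelled with its single highest-IoU concept and that filters below the 0.04 threshold stay unlabelled; the input layout, concept names, and scores are all hypothetical.

```python
from collections import Counter

IOU_THRESHOLD = 0.04


def summarize_units(iou_scores, concept_category, threshold=IOU_THRESHOLD):
    """Label each unit with its best concept and count detectors per category.

    iou_scores: {unit_id: {concept: dataset-wide IoU}}
    concept_category: {concept: category name}
    Returns (labels, per-category counts, number of unique detected concepts).
    """
    labels = {}
    for unit, scores in iou_scores.items():
        best_concept, best_iou = max(scores.items(), key=lambda item: item[1])
        if best_iou > threshold:
            labels[unit] = best_concept
        # Units whose best IoU is not above the threshold stay unlabelled,
        # i.e. "uninterpretable" under this framework.
    per_category = Counter(concept_category[c] for c in labels.values())
    unique_concepts = len(set(labels.values()))
    return labels, per_category, unique_concepts


iou_scores = {
    0: {"car": 0.12, "wheel": 0.01, "striped": 0.03},
    1: {"car": 0.02, "wheel": 0.08, "striped": 0.01},
    2: {"car": 0.01, "wheel": 0.02, "striped": 0.03},  # nothing above 0.04
    3: {"car": 0.09, "wheel": 0.00, "striped": 0.02},
}
concept_category = {"car": "object", "wheel": "part", "striped": "texture"}
labels, per_category, unique = summarize_units(iou_scores, concept_category)
print(labels)        # {0: 'car', 1: 'wheel', 3: 'car'}
print(per_category)  # Counter({'object': 2, 'part': 1})
print(unique)        # 2
```

Unit 2 might well respond to something meaningful, but because no concept in the list scores above the threshold, this summary treats it as uninterpretable, which is exactly the limitation described above.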
Experiments
Now that we have a framework that can be used to label convolutional filters as detectors for concepts in a dataset, let’s go over some experiments to see what Network Dissection can tell us.
Quantifying interpretability of deep visual representations
In this experiment, Network Dissection was performed on AlexNet trained on ImageNet and AlexNet trained on Places205. The number of detectors identified by Network Dissection for each of the 6 concept categories was plotted for each convolutional layer of the network. The number of object and texture detectors increased in the deeper convolutional layers.
Effect of regularization on interpretability
The authors investigated whether random initialization, dropout, or batch normalization had any effect on the interpretability (number of detectors identified by Network Dissection) of a CNN. Training from different random initializations did not seem to affect interpretability. The network trained without dropout had more texture detectors but fewer object detectors. Batch normalization seemed to significantly decrease the interpretability of the network.
Number of detectors vs epoch
In this experiment, the number of detectors was plotted at different training iterations. As training progressed, the number of detectors mostly increased. In the plot below, after the dotted red line, interpretability drops as the network begins to overfit.
Conclusion
Network Dissection is a useful framework that can automatically quantify the interpretability of a CNN. I hope this article has helped you understand how Network Dissection works. Thanks for reading!