DeepViz, a visual exploration tool — Explainable AI and Visualization (Part 12)
This article is a continuation of the research and design for the study ‘Explainable Deep Learning and Visual Interpretability.’
Design Goals
Alongside model complexity, trust and transparency have been receiving increasing attention for deep neural networks and other non-linear models. As a result, several methods have been developed to interpret what a deep neural network (DNN) has learned. A large body of work on DNN interpretability is dedicated to visualizing specific neurons or layers. This work instead focuses on a method that visualizes how particular regions of a given input instance influence the prediction for the class of interest.
The goal is to build an interactive visualization tool that interprets a visual classifier model and provides the end user with visual evidence justifying its recommendation, decision, or action. I use the explanation approach in the above figure as the guiding principle for designing and developing the prototype’s explanation interface.
The prototype’s objectives can be summarized as (i) providing visual interpretations that justify the model’s predictions or decisions to non-specialists, (ii) conveying the concepts of deep learning to a non-technical audience, and (iii) broadening people’s access to interactive tools for deep learning.
The prototype is intended and designed for people with a non-technical background to interact with and visually understand the decision of an image classifier. The high-level intent is to develop a set of utilities or toolkits as the first step towards building a robust explanation framework that can make any visual classifier more interpretable and explainable for the end user.
DeepViz — Visual Exploration Tool
DeepViz serves as a visual exploration tool that enables interactive, explorable explanations of the model. It uses a visualization technique that focuses on the class-discriminative properties of the visual object to help the end user understand the model’s inference. The tool also visualizes the intermediate outputs of the network as a directed graph that displays a hierarchical representation of how the input is transformed.
First, when an image is submitted, DeepViz predicts a class label and then shows why the predicted label is appropriate for the image using a relevance heat map. This method highlights which parts of the image are most relevant to the classification decision.
Second, DeepViz produces a feature activation graph that visualizes the feature maps output by intermediate layers of the network. The graph shows what features the network has learned, how it builds its internal representation of the features it detects during inference, and how a given input is decomposed into the individual filters learned by the network.
Design Approach
This section introduces two visualization techniques implemented in the prototype for interpreting and exemplifying the prediction of an image classifier model.
Sensitivity Analysis
Sensitivity Analysis visualizes a heat map of the activation output corresponding to the predicted label. This method is useful for interpreting which part of an input image led the model to its final classification decision. It localizes the class-discriminative region by generating a heat map superimposed on the input image, which clearly highlights the region that distinguishes the desired class.
For instance, when an image is submitted, the model predicts its class; in this example, the input image is correctly classified as “bee-eater.” To understand why the model arrived at this decision, the tool generates a heat map that visualizes the importance of each pixel in the input image for the prediction of that class. Here, the bee-eater’s beak and neck are the basis for the model’s decision. With the relevance heat map, the user can verify that the model works as intended.
The tool uses localization to determine which region of the image contributed most to the classification decision, identifying the pixels in the input image with the highest activations. Localization, in the context of CNNs, is the task of localizing objects in images using only whole-image class labels. It indicates where the model looks to make a specific decision and identifies the pixels that are pivotal for predicting the object.
I use a class activation mapping approach, specifically its gradient-weighted variant (Grad-CAM), to obtain the localization map. I focused on the output of the last convolutional layer because it captures high-level, class-specific features while still retaining the spatial information that the fully connected layers discard. Therefore, I computed the activation output of the last convolutional layer and then the gradient of the predicted class’s score with respect to that activation output, that is, with respect to its feature maps. This step captures how important each feature map is for the target class.
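The gradient step can be illustrated on a toy model. The pure-Python sketch below assumes a hypothetical classification head in which each feature map A^k of the last convolutional layer is global-average-pooled and fed to a linear layer with class weights w_k, so the class score is y_c = Σ_k w_k · mean(A^k). For this head the gradient ∂y_c/∂A^k_ij is simply w_k / Z, where Z = H·W is the number of spatial positions, and the code checks that analytic result against a central finite-difference estimate. VGG16 ends in fully connected layers rather than a pooled linear head, so in practice the gradient would come from the framework’s automatic differentiation (for example `tf.GradientTape` in TensorFlow); the function names here are illustrative only.

```python
def class_score(feature_maps, weights):
    """Hypothetical toy head: global-average-pool each map A^k, then a linear layer.

    feature_maps: K maps, each an H x W list of lists.
    weights: K class weights w_k for one class c.
    """
    score = 0.0
    for w_k, fmap in zip(weights, feature_maps):
        positions = sum(len(row) for row in fmap)
        score += w_k * sum(sum(row) for row in fmap) / positions
    return score


def analytic_gradient(feature_maps, weights):
    """For the toy head, dy_c/dA^k_ij = w_k / Z with Z = H * W."""
    grads = []
    for w_k, fmap in zip(weights, feature_maps):
        z = sum(len(row) for row in fmap)
        grads.append([[w_k / z for _ in row] for row in fmap])
    return grads


def numeric_gradient(feature_maps, weights, eps=1e-6):
    """Central finite differences: nudge one activation at a time."""
    grads = []
    for fmap in feature_maps:
        g_k = []
        for i, row in enumerate(fmap):
            g_row = []
            for j in range(len(row)):
                original = fmap[i][j]
                fmap[i][j] = original + eps
                up = class_score(feature_maps, weights)
                fmap[i][j] = original - eps
                down = class_score(feature_maps, weights)
                fmap[i][j] = original  # restore the activation
                g_row.append((up - down) / (2 * eps))
            g_k.append(g_row)
        grads.append(g_k)
    return grads
```

For two 2x2 maps with weights `[0.8, -0.4]`, every entry of the first gradient map is 0.8 / 4 = 0.2, and the finite-difference estimate agrees to within rounding error. The numeric version is only a sanity check; on a real network it would be far too slow.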
The gradients are then pooled within each filter by global average pooling over the spatial dimensions, giving one importance weight per feature map. Each feature map of the last convolutional layer’s activation output is multiplied by its weight, and the weighted feature maps are averaged across all filters, collapsing them into a single two-dimensional heat map.
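The pooling and weighting steps can be sketched in plain Python. The sketch assumes activations and gradients are stored channels-first as nested lists, `activations[k][i][j]` for channel k at position (i, j), and that the gradients have already been computed; the helper names are hypothetical. Averaging across channels differs from the summation used in the original Grad-CAM formulation only by a constant factor 1/K, which the later normalization step removes.

```python
def pooled_weights(gradients):
    """Global-average-pool each channel's gradient map: one weight per channel."""
    weights = []
    for g_k in gradients:
        total = sum(sum(row) for row in g_k)
        count = sum(len(row) for row in g_k)
        weights.append(total / count)
    return weights


def weighted_combination(activations, weights):
    """Scale each channel's activation map by its weight and average over channels."""
    k_count = len(activations)
    height, width = len(activations[0]), len(activations[0][0])
    heatmap = [[0.0] * width for _ in range(height)]
    for w_k, a_k in zip(weights, activations):
        for i in range(height):
            for j in range(width):
                heatmap[i][j] += w_k * a_k[i][j] / k_count
    return heatmap
```

For example, gradients whose channel means are 1.0 and 0.5 give weights `[1.0, 0.5]`, and the resulting H x W map is the average of the two weighted activation maps.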
This is followed by a ReLU (Rectified Linear Unit) to discard the negative values from the heat map, after which the heat map is normalized. The heat map is then resized to match the size of the input image, and an RGB colormap is applied to transform the single-channel grayscale map into a color image. Finally, the color heat map is overlaid on top of the input image to form the final output.
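The post-processing steps can be sketched in plain Python on nested lists. Several choices here are assumptions rather than details from the prototype: normalization divides by the maximum value, the resize uses nearest-neighbour sampling (an image library would typically use bilinear interpolation), and `colorize` is a hypothetical two-colour blue-to-red ramp standing in for a full colormap such as jet. The overlay is a per-pixel alpha blend with an assumed opacity of 0.4.

```python
def relu_and_normalize(heatmap):
    """Drop negative evidence (ReLU), then scale to [0, 1] by the maximum."""
    rectified = [[max(v, 0.0) for v in row] for row in heatmap]
    peak = max(max(row) for row in rectified)
    if peak == 0.0:
        return rectified  # no positive evidence: keep the all-zero map
    return [[v / peak for v in row] for row in rectified]


def resize_nearest(heatmap, out_h, out_w):
    """Nearest-neighbour upsampling to the input image size."""
    in_h, in_w = len(heatmap), len(heatmap[0])
    return [[heatmap[i * in_h // out_h][j * in_w // out_w] for j in range(out_w)]
            for i in range(out_h)]


def colorize(value):
    """Hypothetical ramp: 0 -> blue, 1 -> red (stand-in for a jet-style colormap)."""
    return (round(255 * value), 0, round(255 * (1 - value)))


def overlay(image, heatmap, alpha=0.4):
    """Alpha-blend the coloured heat map over an RGB image of (r, g, b) tuples."""
    out = []
    for img_row, heat_row in zip(image, heatmap):
        out_row = []
        for pixel, v in zip(img_row, heat_row):
            color = colorize(v)
            out_row.append(tuple(round((1 - alpha) * p + alpha * c)
                                 for p, c in zip(pixel, color)))
        out.append(out_row)
    return out
```

A typical call chain would be `overlay(image, resize_nearest(relu_and_normalize(raw), h, w))`, where `raw` is the weighted-combination map and `h`, `w` are the input image dimensions.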
Feature Activation Graph
A feature activation graph visualizes the intermediate activations of the network. It takes the activation outputs of all the convolutional layers and produces a directed acyclic graph. The output of a layer is the stack of the activation outputs of all its channels, and the graph decomposes each layer’s activation output into its individual channels.
Visualizing intermediate activations displays the feature maps that the various convolution and pooling layers in a network output for a given input. This gives an overview of how an input instance is decomposed into the different filters learned by the network. Each layer’s output has three dimensions (width, height, and depth), where each channel along the depth axis encodes a relatively independent feature. We therefore visualize the feature map of every channel as a 2D image in its own vertex.
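Splitting a layer’s output into per-channel images can be sketched as below. The sketch assumes the activation tensor is stored channels-last as nested lists, `layer_output[i][j][c]` (the layout Keras uses by default), and min-max scales each channel separately to the range 0 to 255 so it can be drawn as a grayscale thumbnail in its vertex. Per-channel scaling is an assumed display choice, and the function names are hypothetical.

```python
def split_channels(layer_output):
    """Turn an H x W x C activation tensor (channels-last) into C separate H x W maps."""
    height, width = len(layer_output), len(layer_output[0])
    channels = len(layer_output[0][0])
    return [[[layer_output[i][j][c] for j in range(width)] for i in range(height)]
            for c in range(channels)]


def to_grayscale(channel_map):
    """Min-max scale one channel map to 0-255 integers for display."""
    values = [v for row in channel_map for v in row]
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return [[0 for _ in row] for row in channel_map]  # flat channel
    return [[round(255 * (v - lo) / span) for v in row] for row in channel_map]
```

Scaling each channel independently keeps weakly activated channels visible, at the cost of hiding how their magnitudes compare with other channels.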
The graph is directed from bottom to top, from lower-level to higher-level layers. Each level represents a convolutional layer, and each vertex represents an individual channel. An edge connecting two vertices represents the weights between neurons in successive layers.
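The graph structure itself can be sketched as a layered list of vertices and weighted edges. The text does not say how the weights between two channels are reduced to a single edge value, so this sketch assumes one option: the L1 norm of the convolution kernel slice connecting input channel i to output channel j. Kernels are assumed to be nested lists in the `(kernel_h, kernel_w, in_channels, out_channels)` layout that Keras `Conv2D` uses; pooling layers, which map each channel only to itself, would need one-to-one edges instead and are not covered. All names are hypothetical.

```python
def edge_weight(kernel, c_in, c_out):
    """Assumed summary: L1 norm of the kernel slice W[:, :, c_in, c_out]."""
    return sum(abs(kernel[y][x][c_in][c_out])
               for y in range(len(kernel)) for x in range(len(kernel[0])))


def build_activation_graph(kernels):
    """Build a layered DAG: vertices are (level, channel); edges run bottom-up.

    kernels[l] connects the channels at level l to those at level l + 1.
    """
    vertices, edges = [], []
    for level, kernel in enumerate(kernels):
        c_in_count, c_out_count = len(kernel[0][0]), len(kernel[0][0][0])
        if level == 0:
            vertices += [(0, c) for c in range(c_in_count)]
        vertices += [(level + 1, c) for c in range(c_out_count)]
        for c_in in range(c_in_count):
            for c_out in range(c_out_count):
                edges.append(((level, c_in), (level + 1, c_out),
                              edge_weight(kernel, c_in, c_out)))
    return vertices, edges
```

With this representation, a renderer such as D3.js only needs the vertex list for layout by level and the weighted edge list for drawing connections.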
Users can explore the edges to discover what the network detects at each layer as the activation output propagates forward. This helps the user understand how successive layers of the network transform an input image and gives an idea of what the individual network filters mean.
Although VGG16 has 18 such layers (13 convolutional and 5 pooling layers) with 64 to 512 channels each, we limited the view to 8 layers and 8 channels per layer to keep the application responsive, since rendering every channel is expensive.
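A quick count shows why the view has to be limited. The sketch below lists the output channels of VGG16’s 18 convolutional and pooling layers, counts one vertex per channel, and, as an upper bound, assumes every channel is linked to every channel one level up (in reality a pooling layer maps each channel only to itself). The 8 x 8 configuration mirrors the limit described above; the function name is hypothetical.

```python
# Output channels of VGG16's 13 convolutional and 5 max-pooling layers, in order.
VGG16_CHANNELS = [64, 64, 64,            # block1: conv, conv, pool
                  128, 128, 128,         # block2: conv, conv, pool
                  256, 256, 256, 256,    # block3: 3 conv + pool
                  512, 512, 512, 512,    # block4: 3 conv + pool
                  512, 512, 512, 512]    # block5: 3 conv + pool


def graph_size(channels):
    """Vertices: one per channel. Edges: every channel to every channel one level up."""
    vertices = sum(channels)
    edges = sum(a * b for a, b in zip(channels, channels[1:]))
    return vertices, edges


full_view = graph_size(VGG16_CHANNELS)
limited_view = graph_size([8] * 8)
```

Under these assumptions the full graph would have 5,696 vertices, each rendering its own feature-map image, and more than two million edges, compared with 64 vertices and 448 edges in the limited view.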
Technical Design
The following segment provides an overview of the technical design and resources for the prototype development.
Client-side Neural Network
Developing AI applications using a modern deep learning framework is a non-trivial task. These frameworks and libraries are normally used by native applications that run on native platforms such as Linux, Windows, macOS/iOS, and Android. As a result, most production-level libraries are developed for, and written in, Python, Java, and C++.
Developing a machine learning application that is cross-platform and portable on multiple devices is not easy. The development of a native application is an intricate and time-consuming process. It is particularly complicated for mobile applications as the app vendors usually need to develop and maintain both iOS and Android versions in addition to the desktop application.
Compared to native applications, client-side web applications make cross-platform portability simpler. A sample implementation of a deep learning-powered web application can be deployed on multiple platforms regardless of operating system, hardware, or device type.
Deep learning in the browser is still at an experimental stage, but several JavaScript-based deep learning frameworks have recently been introduced that make it possible to perform many deep learning tasks directly in the browser. Supported features include model training, importing pre-trained models, transfer learning, and inference.
However, the feasibility and effectiveness of web-based deep learning applications are debated. On one hand, skeptics argue that browsers are not suited to running deep learning tasks and that doing so is simply impractical because of the poor performance of client-side scripting and the limitations browsers impose.
On the other hand, advocates argue that the browser is an ideal platform for client-side machine learning, allowing rich interaction and better personalization for end users. The benefits include, but are not limited to, faster user interaction, preserved data privacy, a lighter back-end load, less data transfer, and lower latency because HTTP client-server round trips are avoided.
Supported Browser Features
I analyzed modern browsers’ support for machine learning tasks and considered the factors that may affect efficiency when building and deploying deep learning applications on the web. One of these is debugging capability, which supports model and data inspection while deep learning tasks run.
Furthermore, in-browser deep learning lets users train a model directly in the browser on local data, so no back end or server is required.
Model Selection
As model selection is central to the data science workflow, I evaluated a set of candidate pre-trained models for my classification project through a series of experiments. Specifically, I wrote a Python script that uses Keras to load any of these pre-trained network architectures from disk and classify a set of sample images. I ran the experiments on MobileNet, Inception V1, Inception V3, VGG16, and VGG19 and compared them in terms of performance, accuracy, network architecture, and the size of the learned parameters.
In these experiments, VGG16 and VGG19 achieved better accuracy than the other models; both are trained on the ImageNet database. The “16” and “19” stand for the number of weight layers (convolutional plus fully connected) in each network. Because of their depth and the number of fully connected units, VGG16 is about 553 MB and VGG19 about 574 MB in size.
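The parameter counts behind these sizes can be reproduced from the architecture. The sketch below assumes the standard configuration from the VGG paper: 3x3 convolutions with biases, 2x2 max pooling, a 224x224 RGB input, and two 4,096-unit fully connected layers followed by a 1,000-class output. The same configuration shows where the names come from: VGG16 has 13 convolutional and 3 fully connected weight layers. Multiplying the parameter count by 4 bytes per float32 weight gives roughly 553 MB of raw weights for VGG16 and 575 MB for VGG19; saved weight files add a small amount of format overhead. The function name is illustrative.

```python
# Numbers are 3x3 conv output channels; "M" is a 2x2 max-pooling layer.
VGG16 = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
         512, 512, 512, "M", 512, 512, 512, "M"]
VGG19 = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
         512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def count_parameters(config, input_size=224, in_channels=3, classes=1000):
    """Count trainable weights and biases of a VGG-style network."""
    params, channels, size = 0, in_channels, input_size
    for layer in config:
        if layer == "M":
            size //= 2  # pooling halves the spatial size and adds no weights
        else:
            params += 3 * 3 * channels * layer + layer  # kernel + bias
            channels = layer
    features = size * size * channels  # 7 * 7 * 512 = 25,088 after five poolings
    for units in (4096, 4096, classes):  # fc1, fc2, predictions
        params += features * units + units
        features = units
    return params


vgg16_mb = count_parameters(VGG16) * 4 / 1e6  # float32 = 4 bytes per weight
vgg19_mb = count_parameters(VGG19) * 4 / 1e6
```

About 124 million of VGG16’s 138 million parameters sit in the three fully connected layers, which is why the model is so much larger than deeper but convolution-heavy architectures.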
This experiment also helped assess whether the data collected for inference is suitable for the classification task. I took into account the hyperparameter settings and other configuration details that might need to be evaluated during inference, and tested the model’s compatibility with the available dataset. Based on these evaluations and initial experiments, I selected VGG16 as the primary model for the prototyping phase.
VGG16 is the convolutional neural network model proposed by the Visual Geometry Group at the University of Oxford in the paper “Very Deep Convolutional Networks for Large-Scale Image Recognition.” The model achieves 92.7% top-5 test accuracy on ImageNet, evaluated on the 1,000-category subset used in the ILSVRC benchmark. It was one of the popular models submitted to the ILSVRC-2014 competition. The list of hidden layers can be seen in the table. This model is used as the backbone network in our project.
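Top-5 accuracy, the metric behind the 92.7% figure, counts a prediction as correct when the true label appears among the five highest-scoring classes. A minimal sketch of the metric, with a hypothetical function name and toy scores:

```python
def top_k_accuracy(scores, labels, k=5):
    """Fraction of examples whose true label is among the k highest-scoring classes."""
    hits = 0
    for class_scores, label in zip(scores, labels):
        ranked = sorted(range(len(class_scores)),
                        key=lambda c: class_scores[c], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)
```

With 1,000 classes, top-5 is a more forgiving measure than top-1, since many ImageNet categories (dog breeds, for instance) are visually close.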
Dataset
ImageNet is a large-scale image database of over 15 million labeled high-resolution images belonging to roughly 22,000 categories. The images were collected from the web using image search engines and labeled by humans using Amazon’s Mechanical Turk crowd-sourcing tool. It took about two and a half years to label all the images.
A subset of ImageNet is used in the ImageNet Large-Scale Visual Recognition Challenge (ILSVRC), with roughly 1,000 images in each of 1,000 categories. In total, there are about 1.2 million training images, 50,000 validation images, and 150,000 test images. Because ImageNet images vary in resolution, they are typically down-sampled to a fixed size, such as 256x256, before training to keep the inputs consistent.
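One common way to bring variable-resolution ImageNet photos to a fixed input size is to rescale the shorter side to 256 pixels and then take a centered 224x224 crop, the input size VGG16 expects. The sketch below only computes the resized dimensions and the crop box; the 256/224 convention and the function name are assumptions, not details taken from the prototype, and an image library such as Pillow would perform the actual resampling.

```python
def resize_then_center_crop(width, height, short_side=256, crop=224):
    """Return the resized (w, h) and the (left, top, right, bottom) crop box."""
    scale = short_side / min(width, height)  # shorter side becomes short_side
    new_w, new_h = round(width * scale), round(height * scale)
    left = (new_w - crop) // 2
    top = (new_h - crop) // 2
    return (new_w, new_h), (left, top, left + crop, top + crop)
```

For a 500x375 photo this gives a 341x256 resized image and the crop box (58, 16, 282, 240), preserving the aspect ratio before cropping instead of stretching the image.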
Framework Selection
To develop client-side deep learning applications, I surveyed open-source machine learning frameworks for the web, including TensorFlow, Keras, and WebDNN. Based on my assessment and feedback from the developer community, I selected TensorFlow (in Python) and TensorFlow.js for the prototype development. I also use JavaScript ES6 and Node.js for server-side processing and environment setup.
I found TensorFlow.js to be the better choice for training and running inference with deep learning models in the browser. While several other open-source JavaScript platforms for machine learning have appeared recently, TensorFlow.js is more feature-rich and better documented than the other web libraries I evaluated. Further, TensorFlow.js uses the GPU to accelerate deep learning tasks in the browser via WebGL, which is an essential criterion for a tool that runs inference on a vision model. WebGL is a JavaScript API for rendering interactive 2D and 3D graphics within web browsers without additional plug-ins, and TensorFlow.js uses it as a compute back end.
Since both the client-side and server-side frameworks are part of the TensorFlow ecosystem, I can easily use APIs that are compatible with either one. This also simplifies model conversion and allows models to be ported between the Python and JavaScript ecosystems.
Prototyping
Because one of my goals for the prototype is to broaden people’s access to interactive tools for deep learning, and the tool targets a non-technical audience, DeepViz is built as a web-based interactive visualization tool that can be accessed from any modern web browser.
The development environment for the client-side application is set up using the JavaScript ecosystem and W3C web standards. DeepViz uses web technologies: HTML, CSS, JavaScript ES6, and SVG. Node.js is used for back-end scripting. For client-side visualization, I use SVG and D3.js to render vector graphics for the graph views. I also use the Lucid library to experiment with and test feature visualizations of the VGG16 model.
I use transpilers to transform JavaScript ES6 code into standard ES5 JavaScript that runs in any browser, along with bundling and tooling for local development and for deploying the prototype on the server.
The client-side development environment runs on Node.js, which is also used for tooling and for bundling production deployment assets. For the back-end deep learning framework, we executed all our code on a workstation hosted on a virtual server provided by the NYU High-Performance Computing services.
The next article in this series covers the results and conclusions of the XAI research experiment.