Explainable AI: Counterfactual Search For Generative Models

Dmytro Shvetsov
Jun 11, 2023

Team members: Dmytro Shvetsov, Illia Tsiporenko

In this article

  1. Introduction
  2. What Is Explainable AI and Why It Is Needed
  3. Project Goals
  4. Paper Summary
  5. Dataset
  6. Methodology for Replicating Results
  7. Classification Model
  8. Explanation Function
  9. Loss Functions
  10. Evaluation Metrics
  11. Experiments and Results
  12. Qualitative Analysis
  13. Classifier’s Audit Analysis
  14. Conclusions
  15. References

Introduction

What Is Explainable AI and Why It Is Needed

Deep learning models have shown impressive performance across tasks such as classification, segmentation, and object detection. However, when these models are deployed in sensitive domains, particularly medicine, it becomes crucial to understand the specific biases they have acquired from the training data. This knowledge helps us anticipate failures at test time and verify that the models base their decisions on the right parts of the input image. In this blog post, we explore counterfactual search using generative models. Put simply, this approach answers the question: “How can we modify the input image to completely change the model’s prediction?” Generative models give us insight into how the input influences predictions, so we can empirically evaluate which aspects of an input image the model attends to when producing a specific output class.

Project Goals

The main goal of this project is to replicate the results of the paper “Explaining the Black-box Smoothly — A Counterfactual Approach”, which employs a Conditional Generative Adversarial Network (cGAN) as an explanation function in order to audit the classifier’s predictions. The primary function of the network is to generate counterfactual examples given the input image, so that the black-box classifier’s output changes significantly.

Paper Summary

Consider a pre-trained black-box multiclass classifier $f$ that produces point estimates of the posterior probability of class $k$:

$$f_k(x) = P(y = k \mid x)$$

We denote class $k$ as the target class for the explanation. Our explanation function $I_f$ is a conditional GAN that produces a realistic perturbation of the query image $x$ such that the classifier’s decision for class $k$ changes to a desired value $c$. Thus, $c$ acts as a “knob” in our formulation: we can sweep it over the range [0; 1] and verify that the classifier’s prediction changes from negative to positive accordingly. The figure below summarizes this process:

To achieve this, the following properties are enforced for the explanation model:

Data consistency: the generated image $x_c$ should resemble a data instance from the input space, with minimal artifacts or blurring.

Classifier consistency: $x_c$ should produce the desired output from the classifier $f$, i.e., $f(I_f(x, c)) \approx c$.

Context-aware self-consistency: when the original decision is used as the condition, i.e., $c = f(x)$, the explanation function should reconstruct the query image. This is enforced both for self-consistency, $I_f(x, f(x)) = x$, and for cyclic consistency, $I_f(x_c, f(x)) = x$.

Dataset

Because the original dataset is not publicly available, we perform the experiments on the COVID-19 Radiography Database. The dataset contains 21K chest X-ray images of COVID-19, Viral Pneumonia, and Lung Opacity (non-COVID lung infection) cases, as well as Normal cases. In addition to the class labels, binary lung masks are provided for reference. The distribution of the dataset looks as follows:

Some classes are clearly under-represented because the corresponding diseases are rarer, which makes building a classification model more challenging. We split the entire collection into three parts in a stratified manner:

  1. 5% — training set for the classifier.
  2. 80% — validation set for the classifier and the training set for the generative model.
  3. 15% — validation set for the explanation function based on the counterfactual validity score.
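A stratified split like the one above can be sketched with NumPy as follows. This is a minimal illustration under our own assumptions (the function name, the seed, and per-class rounding are not taken from the project code): each class is shuffled independently and cut at the cumulative fractions 5% / 85% / 100%, so every part keeps roughly the same class proportions.

```python
import numpy as np

def stratified_split(labels, fractions=(0.05, 0.80, 0.15), seed=0):
    """Split sample indices into parts that preserve per-class proportions.

    Returns one sorted index array per fraction, e.g. (classifier training,
    classifier validation / cGAN training, explainer validation).
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    parts = [[] for _ in fractions]
    for cls in np.unique(labels):
        # Shuffle this class independently of the others.
        idx = rng.permutation(np.flatnonzero(labels == cls))
        # Cumulative cut points for this class, rounded to whole samples.
        bounds = np.round(np.cumsum(fractions) * len(idx)).astype(int)
        start = 0
        for part, stop in zip(parts, bounds):
            part.extend(idx[start:stop].tolist())
            start = stop
    return [np.sort(np.array(p, dtype=int)) for p in parts]

# Usage: a toy label vector with an imbalanced class distribution.
labels = [0] * 100 + [1] * 40 + [2] * 20 + [3] * 40
train_clf, val_clf, val_explainer = stratified_split(labels)
print(len(train_clf), len(val_clf), len(val_explainer))  # 10 160 30
```

Splitting per class rather than globally is what guarantees that rare classes still appear in the small 5% classifier-training part.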

Methodology for Replicating Results

To decompose the development process, we propose the following steps:

  1. Building a black-box classifier: this step involves training a baseline classifier, which will serve as the explanation target for one of the classes in the dataset.
  2. Building an explanation function: this step entails implementing the end-to-end pipeline from the original paper to generate counterfactual examples.
  3. Evaluating the trained explainer: here we inspect how effective the explanation function is in terms of the Counterfactual Validity (CV) score (i.e., the fraction of all generated samples for which the prediction actually “flipped”).

Classification Model

As a baseline, we take a ResNet-18 classification model pre-trained on ImageNet and fine-tune it for the 4 classes of our dataset. We use a balanced sampler to counter the class imbalance. The input grayscale images are augmented with random horizontal/vertical flips and shift/scale/rotation affine transforms. Preprocessing consists of resizing the images to 256x256 and normalizing them to the [-1; 1] range. The resulting model achieves the following metrics on the validation set:
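The balanced sampler and the [-1; 1] normalization described above can be illustrated with a short NumPy sketch. It assumes inverse-class-frequency sampling weights (the post does not specify the weighting scheme), and the function names are hypothetical; resizing to 256x256 is left to an image library and omitted here. In PyTorch, weights of this kind are typically passed to `torch.utils.data.WeightedRandomSampler`.

```python
import numpy as np

def balanced_sample_weights(labels):
    """Per-sample probabilities inversely proportional to class frequency.

    Every class ends up with the same total probability mass, so batches
    drawn with replacement are class-balanced on average.
    """
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    inv_freq = dict(zip(classes.tolist(), (1.0 / counts).tolist()))
    w = np.array([inv_freq[c] for c in labels.tolist()], dtype=np.float64)
    return w / w.sum()

def sample_balanced_batch(labels, batch_size, rng):
    """Draw one batch of sample indices using the balanced weights."""
    w = balanced_sample_weights(labels)
    return rng.choice(len(w), size=batch_size, replace=True, p=w)

def to_minus_one_one(img_uint8):
    """Map an 8-bit grayscale image from [0, 255] to [-1, 1]."""
    return img_uint8.astype(np.float32) / 127.5 - 1.0

labels = [0] * 90 + [1] * 10  # 9:1 imbalance
w = balanced_sample_weights(labels)
print(round(float(w[:90].sum()), 6), round(float(w[90:].sum()), 6))  # 0.5 0.5
```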

Explanation Function

The original paper builds a separate generative model (explanation function) for each class. The [0; 1] range of the posterior probability is discretized into N bins to generate the conditions c. The larger N is, the finer-grained the explanation function can be, since more conditions can be analyzed. However, conditional GANs require class-balanced batches of training images, so the classifier’s posterior probabilities must be distributed such that each bin (condition) contains a sufficient number of samples. In other words, if we build a histogram (with N bins) of the classifier’s posterior probabilities on its validation set, we need it to be as uniform as possible.
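Discretizing the posteriors into conditions and checking how uniform the resulting histogram is can be sketched as follows. We assume equal-width bins over [0; 1] (with p = 1 placed in the last bin); the helper names are our own.

```python
import numpy as np

def posterior_to_condition(p, n_bins):
    """Map posteriors in [0, 1] to equal-width bin indices 0..n_bins-1."""
    p = np.asarray(p, dtype=np.float64)
    # p == 1.0 would land in bin n_bins, so clip it into the last bin.
    return np.minimum((p * n_bins).astype(int), n_bins - 1)

def bin_histogram(p, n_bins):
    """Number of samples per condition bin."""
    return np.bincount(posterior_to_condition(p, n_bins), minlength=n_bins)

def bin_balance(p, n_bins):
    """Smallest-to-largest bin count ratio: 1.0 means a perfectly flat histogram."""
    h = bin_histogram(p, n_bins)
    return float(h.min() / h.max()) if h.max() > 0 else 0.0

p_normal = [0.02, 0.10, 0.45, 0.55, 0.97, 1.0]
print(bin_histogram(p_normal, 2).tolist())  # [3, 3]
```

With N=2, condition index 0 corresponds to posteriors below 0.5 and index 1 to posteriors of at least 0.5, which is the split used for class Normal below.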

Here are the per-class histograms of our classifier’s posterior probabilities on our dataset for N=2:

Therefore, considering the size of our dataset and the distributions of posterior probabilities, we concluded that Normal is the most viable class to explain. In other words, given an image $x$ of normal lungs with $f(x) > 0.5$ and a condition $c < 0.5$, the generative model should produce an abnormal image $x_c$ (i.e., viral pneumonia, COVID-19, or lung opacity) such that $f(x_c) < 0.5$. Conversely, given an abnormal image, it should generate a normal one.

We implement the conditional GAN architecture from scratch, exactly as described in the original paper. The generator (encoder E and decoder G) and the discriminator D are built from the following blocks:

We reuse existing implementations of conditional batch normalization (cBN) and spectrally normalized layers for these blocks.
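For readers unfamiliar with these two building blocks, here is a NumPy sketch of both. Spectral normalization divides a weight matrix by an estimate of its largest singular value obtained by power iteration; conditional batch normalization normalizes activations as usual but looks up the scale and shift from per-condition tables. This is an illustration under our own naming (`gamma_table` and `beta_table` are hypothetical), not the implementations we reused, and it covers only the training-mode forward pass.

```python
import numpy as np

def spectral_normalize(W, n_iters=1, u=None, rng=None):
    """Return (W / sigma, u), where sigma estimates W's largest singular value.

    Conv kernels of shape (out, in, kh, kw) are treated as (out, in*kh*kw).
    Requires n_iters >= 1; u can be carried over between training steps.
    """
    W2 = W.reshape(W.shape[0], -1)
    if u is None:
        if rng is None:
            rng = np.random.default_rng(0)
        u = rng.standard_normal(W2.shape[0])
    for _ in range(n_iters):
        # One step of power iteration on W2^T W2.
        v = W2.T @ u
        v = v / (np.linalg.norm(v) + 1e-12)
        u = W2 @ v
        u = u / (np.linalg.norm(u) + 1e-12)
    sigma = u @ W2 @ v
    return W / sigma, u

def conditional_batch_norm(x, cond, gamma_table, beta_table, eps=1e-5):
    """Training-mode batch norm with condition-dependent scale and shift.

    x: (B, C, H, W) activations; cond: (B,) integer condition indices;
    gamma_table, beta_table: (num_conditions, C) learnable parameters.
    """
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    gamma = gamma_table[cond][:, :, None, None]  # (B, C, 1, 1)
    beta = beta_table[cond][:, :, None, None]
    return gamma * x_hat + beta

W_sn, _ = spectral_normalize(np.diag([3.0, 1.0, 0.5]), n_iters=50)
print(round(float(np.linalg.svd(W_sn, compute_uv=False)[0]), 6))  # 1.0
```

In a real network, a single power-iteration step per training step is usually enough because `u` is reused and the weights change slowly.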

Loss Functions

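First, to enforce data consistency, the original paper uses the regular BCE adversarial loss. However, according to the Least Squares GAN paper, a least-squares (L2) adversarial loss is more beneficial than BCE, so we use it as our adversarial loss $\mathcal{L}_{adv}$:

$$\min_D \; \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\text{data}}}\big[(D(x) - b)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - a)^2\big], \qquad \min_G \; \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - c)^2\big]$$

where $a$ and $b$ are the labels for fake and real data, respectively, and $c$ denotes the value that $G$ wants $D$ to believe for fake data (this $c$ is unrelated to the explanation condition $c$).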
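As a sketch, the two least-squares objectives can be written over batches of discriminator outputs. We assume the 0-1 coding a = 0, b = c = 1, which is one of the settings proposed in the LSGAN paper; the post does not state which values were used.

```python
import numpy as np

def lsgan_d_loss(d_real, d_fake, a=0.0, b=1.0):
    """Discriminator loss: pull D(real) toward b and D(fake) toward a."""
    d_real, d_fake = np.asarray(d_real), np.asarray(d_fake)
    return 0.5 * np.mean((d_real - b) ** 2) + 0.5 * np.mean((d_fake - a) ** 2)

def lsgan_g_loss(d_fake, c=1.0):
    """Generator loss: pull D(fake) toward c, the value G wants D to output."""
    d_fake = np.asarray(d_fake)
    return 0.5 * np.mean((d_fake - c) ** 2)

# A discriminator that outputs 0.5 everywhere is maximally unsure:
print(lsgan_d_loss([0.5, 0.5], [0.5, 0.5]), lsgan_g_loss([0.5, 0.5]))  # 0.25 0.125
```

Unlike BCE, the squared error keeps penalizing fake samples that are classified as real but lie far from the decision boundary, which is the motivation given in the LSGAN paper.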

Second, for the classifier-consistency term, we follow the original paper and compute the KL divergence loss between the desired and predicted probability distributions:

$$\mathcal{L}_{KL} = D_{KL}(y \,\|\, \hat{y}) = \sum_i y_i \log \frac{y_i}{\hat{y}_i}$$

where the predictions $\hat{y}$ are the classifier’s posterior probabilities for the generated images $x_c$, and the ground truth $y$ is the classifier’s desired output $1 - f(x)$.
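Because the explanation targets a single class, the KL term can be computed between two Bernoulli distributions per image: the desired posterior 1 - f(x) and the classifier’s posterior f(x_c). The NumPy sketch below assumes this binary form and a batch mean; the clipping constant is our own choice. Note that PyTorch’s `torch.nn.KLDivLoss` expects its input as log-probabilities and its target as probabilities by default.

```python
import numpy as np

def bernoulli_kl(target, pred, eps=1e-7):
    """Batch mean of KL(Ber(target) || Ber(pred)).

    target: desired posteriors of the explained class, here 1 - f(x);
    pred: classifier posteriors f(x_c) for the counterfactuals.
    Both are clipped away from 0 and 1 to keep the logarithms finite.
    """
    t = np.clip(np.asarray(target, dtype=np.float64), eps, 1 - eps)
    q = np.clip(np.asarray(pred, dtype=np.float64), eps, 1 - eps)
    kl = t * np.log(t / q) + (1 - t) * np.log((1 - t) / (1 - q))
    return float(kl.mean())

f_x = np.array([0.9, 0.2])    # posteriors of the original images
f_xc = np.array([0.15, 0.7])  # posteriors of the counterfactuals
print(bernoulli_kl(1 - f_x, f_xc))
```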

Third, we implement the original context-aware reconstruction loss (CARL), which penalizes the model for altering small but important details in the generated images. Instead of computing a pixel-wise loss over the entire image, it uses additional information, such as semantic segmentation masks or object detection results, to enforce local consistency. In our implementation, we omit the object detection term and keep only the semantic segmentation part, where the L1 loss is computed only on the positive pixels of each mask:

$$\mathcal{L}_{CARL}(x, x') = \sum_j \frac{\sum S_j(x) \odot |x - x'|}{\sum S_j(x)}$$

where $S_j(x)$ is the binary segmentation mask of class $j$ for an image $x$, and $x'$ is the reconstructed image.
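A NumPy sketch of the segmentation-only CARL term follows, assuming a single image of shape (H, W) and a stack of J binary masks (for this dataset, the provided lung masks). The function name is ours, and batching is left out.

```python
import numpy as np

def carl(x, x_rec, masks, eps=1e-8):
    """Segmentation-only context-aware reconstruction loss for one image.

    x, x_rec: (H, W) original and reconstructed images.
    masks: (J, H, W) binary masks S_j(x), e.g. the provided lung masks.
    The L1 error is averaged over each mask's positive pixels, then summed.
    """
    err = np.abs(np.asarray(x, dtype=np.float64) - np.asarray(x_rec, dtype=np.float64))
    masks = np.asarray(masks, dtype=np.float64)
    per_mask = (masks * err).sum(axis=(1, 2)) / (masks.sum(axis=(1, 2)) + eps)
    return float(per_mask.sum())

x = np.zeros((4, 4))
x_rec = x.copy()
x_rec[0, 0] = 2.0        # change outside the mask: ignored
lungs = np.zeros((1, 4, 4))
lungs[0, 2:, 2:] = 1.0   # 4 positive pixels
print(carl(x, x_rec, lungs))  # 0.0
```

Normalizing by each mask’s area keeps small regions from being drowned out by large ones, which is the point of making the loss context-aware.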

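Accordingly, for the context-aware self-consistency term discussed earlier, CARL is applied to both the self-reconstruction and the cyclic reconstruction:

$$\mathcal{L}_{rec} = \mathcal{L}_{CARL}\big(x,\, I_f(x, f(x))\big) + \mathcal{L}_{CARL}\big(x,\, I_f(x_c, f(x))\big)$$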

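Finally, the whole objective is formulated as follows:

$$\mathcal{L} = \lambda_{adv}\,\mathcal{L}_{adv} + \lambda_{KL}\,\mathcal{L}_{KL} + \lambda_{rec}\,\mathcal{L}_{rec}$$

where $\mathcal{L}_{adv}$, $\mathcal{L}_{KL}$, and $\mathcal{L}_{rec}$ are the adversarial, classifier-consistency, and context-aware reconstruction terms, and each $\lambda$ is the weight assigned to the corresponding term.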
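To show how the pieces fit together, here is a sketch of the weighted generator objective for a single image, including the self- and cyclic-consistency reconstructions. The classifier `f`, the explainer `explainer`, and the discriminator `disc` are hypothetical callables standing in for the real networks; the LSGAN generator target of 1 and the default weights of 1.0 are assumptions, not the values from our experiments.

```python
import numpy as np

def masked_l1(x, y, mask, eps=1e-8):
    """L1 error averaged over the positive pixels of one binary mask."""
    return float((mask * np.abs(x - y)).sum() / (mask.sum() + eps))

def generator_objective(x, mask, f, explainer, disc,
                        lam_adv=1.0, lam_kl=1.0, lam_rec=1.0, eps=1e-7):
    """Weighted generator loss for a single image (illustrative only).

    f(img) -> posterior of the explained class, explainer(img, c) -> image,
    disc(img, c) -> scalar realism score; all three are placeholders.
    """
    p = f(x)
    c = 1.0 - p                    # condition: flip the decision
    x_c = explainer(x, c)          # counterfactual image
    # Adversarial term (LSGAN generator side, target value 1).
    l_adv = 0.5 * (disc(x_c, c) - 1.0) ** 2
    # Classifier consistency: KL(Ber(c) || Ber(f(x_c))).
    t = np.clip(c, eps, 1 - eps)
    q = np.clip(f(x_c), eps, 1 - eps)
    l_kl = t * np.log(t / q) + (1 - t) * np.log((1 - t) / (1 - q))
    # Self-consistency I_f(x, f(x)) and cyclic consistency I_f(x_c, f(x)).
    l_rec = (masked_l1(x, explainer(x, p), mask)
             + masked_l1(x, explainer(x_c, p), mask))
    return float(lam_adv * l_adv + lam_kl * l_kl + lam_rec * l_rec)

# Toy stand-ins: the "posterior" is the mean intensity and the "explainer"
# returns a constant image at the requested posterior.
f = lambda img: float(img.mean())
explainer = lambda img, c: np.full_like(img, c)
x = np.full((2, 2), 0.8)
print(round(generator_objective(x, np.ones((2, 2)), f, explainer, lambda img, c: 0.0), 6))  # 0.5
```

With these stand-ins the KL and reconstruction terms vanish, so the printed value is just the adversarial penalty for a discriminator that rejects the counterfactual.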

Evaluation Metrics

To validate the usefulness of our explanation function, we use the original Counterfactual Validity (CV) metric, defined as the fraction of counterfactual explanations that successfully flipped the classification decision:

$$\mathrm{CV} = \frac{1}{M} \sum_{i=1}^{M} \mathbb{1}\big[\,|f(x_i) - f(x_{c,i})| > \tau\,\big]$$

where $M$ is the number of evaluated images and $\tau$ is the margin threshold between the posterior probability of the original image and that of its counterfactual example.

Additionally, we calculate the Counterfactual Accuracy (CA) metric, which is a regular accuracy score computed for the expected and predicted bucket indices obtained after discretization of the classifier’s posterior probabilities for original and counterfactual images.
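Both metrics are straightforward to compute from the classifier’s posteriors. The sketch below follows the CV definition given earlier in this section and, for CA, assumes that the expected bin comes from the requested condition (e.g., 1 - f(x)) and the predicted bin from f(x_c), using equal-width discretization; the function names are ours.

```python
import numpy as np

def counterfactual_validity(p_orig, p_cf, tau):
    """Fraction of samples whose posterior changed by more than tau."""
    p_orig = np.asarray(p_orig, dtype=np.float64)
    p_cf = np.asarray(p_cf, dtype=np.float64)
    return float(np.mean(np.abs(p_orig - p_cf) > tau))

def counterfactual_accuracy(p_target, p_cf, n_bins):
    """Accuracy between expected bins (from the target posteriors) and
    predicted bins (from f(x_c)), using equal-width discretization."""
    def to_bin(p):
        p = np.asarray(p, dtype=np.float64)
        return np.minimum((p * n_bins).astype(int), n_bins - 1)
    return float(np.mean(to_bin(p_target) == to_bin(p_cf)))

p_orig = [0.95, 0.90, 0.10, 0.60]    # f(x)
p_cf = [0.05, 0.50, 0.95, 0.55]      # f(x_c)
p_target = [0.05, 0.10, 0.90, 0.40]  # requested conditions, 1 - f(x)
print(counterfactual_validity(p_orig, p_cf, tau=0.8))     # 0.5
print(counterfactual_accuracy(p_target, p_cf, n_bins=2))  # 0.5
```

CV with a large τ is the stricter of the two: a counterfactual can land in the correct bin (counting toward CA) while moving the posterior by less than the margin.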

Experiments and Results

In all the experiments, the generative model is trained for 1000 epochs on an NVIDIA A100 40GB GPU. As discussed earlier, its training set is the classifier’s validation set, and the posterior probabilities of the target class Normal are discretized into N=2 bins.

In addition, we replicate the training schema of the original paper: the generator is updated once per 5 update steps of the discriminator. We use the Adam optimizer (lr=0.0002, β1 = 0, β2 = 0.999) with batches of 16 grayscale 256x256 images. We report the 4 most notable experiments, each with a different configuration of the total loss function:

A table of the final evaluation results for the trained explanation functions (cGAN)
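The training schedule described above can be summarized as a small configuration plus an update-loop skeleton. This is a framework-agnostic sketch and the field names are our own; in PyTorch, the optimizers would be created with `torch.optim.Adam(params, lr=cfg.lr, betas=cfg.betas)`.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TrainConfig:
    # Values reported in this post; the field names are our own.
    epochs: int = 1000
    batch_size: int = 16
    image_size: int = 256
    lr: float = 2e-4
    betas: tuple = (0.0, 0.999)
    n_critic: int = 5  # discriminator updates per generator update

def update_schedule(num_steps, n_critic):
    """Yield the networks to update at each step: D always, G every n_critic-th step."""
    for step in range(1, num_steps + 1):
        yield ("D", "G") if step % n_critic == 0 else ("D",)

cfg = TrainConfig()
steps = list(update_schedule(10, cfg.n_critic))
print(sum("G" in s for s in steps))  # 2
```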

The most successful experiment is the second one: it achieves 94.29% prediction flips according to the CA metric, and 75.8% flips with a large margin of τ = 0.8 according to the CV metric. Below, we briefly go through each experiment to analyze the differences between the trained models.

Qualitative Analysis

Experiment 1

First, we demonstrate that classifier consistency is necessary to achieve any prediction “flips” for an arbitrary classification model. The first experiment uses no classifier-consistency term, so nothing forces the generative model to produce perturbations that actually change the classification probability.

In our visualizations, the first row shows the input images with posterior probability p. The second row shows the generated counterfactuals conditioned on probability 1-p. The third row shows the absolute difference between the first two (the brighter the pixel, the stronger the perturbation in that area).
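The difference maps in the third row can be computed as below. Rescaling each map by its own maximum is our assumption for display purposes; a fixed scale would make brightness comparable across images.

```python
import numpy as np

def difference_map(x, x_c):
    """|x - x_c| rescaled to [0, 1] for display; brighter = stronger change."""
    d = np.abs(np.asarray(x, dtype=np.float64) - np.asarray(x_c, dtype=np.float64))
    peak = d.max()
    return d / peak if peak > 0 else d

x = np.array([[0.0, 0.5], [1.0, -1.0]])   # input, in [-1, 1]
x_c = np.array([[0.0, 0.0], [0.0, 1.0]])  # counterfactual
print(difference_map(x, x_c).tolist())    # [[0.0, 0.25], [0.5, 1.0]]
```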

An example of lungs labeled “Normal” and a counterfactual image generated in the first experiment that did not flip the prediction.

As the example above shows, although the output images look realistic, they do not flip the predictions.

Experiment 2

An example of lungs labeled “Normal” and a counterfactual image generated in the second experiment that did flip the prediction.

The second experiment is the most successful one. Later, in the audit analysis, we examine why its predictions actually flip, which reveals a fundamental drawback of our dataset that would bias any classifier trained on it.

Experiment 3

An example of lungs labeled “viral pneumonia” and a counterfactual image generated in the third experiment that did not flip the prediction.

In the third experiment, the results look promising in terms of distinctive perturbations. However, because of the fairly large weight assigned to the KL loss, artifacts resembling a “flag” appear in the generated images. Here are a few generated images from the training set:

Examples of artifacts in images generated from the training set in the third experiment.

Experiment 4

An example of lungs labeled “lung opacity” and a counterfactual image generated in the fourth experiment that did flip the prediction.

The fourth experiment suffers from essentially the same problem as the third one. The perturbations look too “aggressive”, and changing the adversarial loss function does not solve that. The artifacts produced by the model resemble a grid of connected dots. Here are a few generated images from the training set:

Examples of artifacts in images generated from the training set in the fourth experiment.

Classifier’s Audit Analysis

In this section, we examine the counterfactual examples of the second experiment more closely to analyze which changes in the input images actually affect the classifier’s decision. Below are the most notable examples for each class label, along with their counterfactuals that flip the predictions.

An example of lungs labeled “covid” and a counterfactual image generated in the second experiment that did flip the prediction.
An example of lungs labeled “lung opacity” and a counterfactual image generated in the second experiment that did flip the prediction.
An example of lungs labeled “normal” and a counterfactual image generated in the second experiment that did flip the prediction.
An example of lungs labeled “viral pneumonia” and a counterfactual image generated in the second experiment that did flip the prediction.

Interestingly, in all the examples for every class, the areas most changed by the explanation function lie outside the lungs themselves.

Therefore, the main conclusion of these experiments is that our classifier, trained on 5% of the data, is biased toward various markers present in the images (e.g., arrows and text), as revealed by our explanation function (cGAN). This also means that we have a working baseline that replicates the results of the original paper and lets us audit the decisions of the classifier model.

More counterfactual examples for each experiment can be found here.

Conclusions

In this project, we study the importance of explainable AI for image classification networks. We replicate the approach of the paper “Explaining the Black-box Smoothly - A Counterfactual Approach” from scratch and demonstrate its viability on the public COVID-19 Radiography Database. Our findings show that image classifiers can easily become biased toward visual cues that should not correlate with the target labels. This calls for an explanation function, such as a cGAN, that helps audit the classifier’s attention and builds confidence in how robust it is across different counterfactual examples.

Having implemented this project, we can also conclude that the 90%+ classification results (e.g., 1, 2, 3, 4) reported in the “Code” section of this Kaggle dataset are highly likely affected by the same bias we have shown (redundant markers, arrows, etc.). This is an important caution for medical institutions: not all off-the-shelf models are production-ready or suitable for real-world testing.

References

  1. https://www.sciencedirect.com/science/article/abs/pii/S1361841522003498
  2. https://arxiv.org/abs/1512.03385
  3. https://arxiv.org/pdf/1611.04076.pdf
  4. https://arxiv.org/abs/1905.07697
  5. https://openreview.net/pdf?id=B1QRgziT-
  6. https://github.com/godisboy/SN-GAN
  7. https://github.com/pfnet-research/sngan_projection
