
Unavoidability of Model Interpretability

A high-scoring model is not necessarily interpretable, and worse than that, its results can be misleading. Never trust a model that reports 99% accuracy on the first attempt. Tools like LIME, SHAP or Shapash can be strongly informative.

Co-authors: Yaniv Goldfrid, Dana Velibekov and Yair Stemmer

Interpretable AI

We looked at a few Kaggle notebooks from the pneumonia dataset challenge that report high accuracy, and we decided to check their interpretability. The model should identify the same symptoms as a doctor; if it does not, there is a problem somewhere. In the end, we understood that a high-scoring model is not necessarily explainable, and worse than that, its results can be misleading.

In this article, we will first run a short EDA, apply baseline models and a CNN model, and finally get inside the algorithm’s head to explain what makes it predict the outcome the way it does.

EDA (Exploratory Data Analysis)

The original dataset consists of two classes:

  • Healthy
  • Pneumonia (Bacteria or Virus)

We are working with a total of 5856 X-ray images:

The dataset is slightly imbalanced towards the “Pneumonia” class.

Patients defined as healthy, on average, had larger and more detailed X-ray images than those who were sick.

Here’s what the images look like:

Healthy patients
Sick patients

We believe that a person without medical training cannot tell the difference between these images.

Conclusions we made by looking at pictures:

  • Some images exhibit the letter R on the top or middle left
  • Some images exhibit text on the top right or bottom left
  • Some images in the pneumonia category display medical devices in the form of little circles and wires
  • Images are not perfectly centered
  • Chests are not exactly vertical in all images
  • Chest widths vary from image to image

Trying out baseline models

In order to apply the baseline models (logistic regression and random forest), images were resized to 64×64 and converted into tabular form: each row represents an image and each column a pixel.
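As an illustration, here is a minimal sketch of that preprocessing and the two baselines with scikit-learn. The folder layout, loading helper and hyperparameters are assumptions made for illustration, not the exact code we used.

import numpy as np
from pathlib import Path
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def load_flat_images(root):
    # Each image becomes one row of 64*64 grayscale pixel values
    X, y = [], []
    for label, folder in enumerate(["NORMAL", "PNEUMONIA"]):  # assumed folder names
        for path in Path(root, folder).glob("*.jpeg"):
            img = Image.open(path).convert("L").resize((64, 64))
            X.append(np.asarray(img, dtype=np.float32).ravel() / 255.0)
            y.append(label)
    return np.array(X), np.array(y)

X, y = load_flat_images("chest_xray/train")  # assumed dataset path
X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y, random_state=42)

logreg = LogisticRegression(max_iter=1000).fit(X_train, y_train)
rf = RandomForestClassifier(n_estimators=200, random_state=42).fit(X_train, y_train)
print("Logistic regression validation accuracy:", logreg.score(X_val, y_val))
print("Random forest validation accuracy:", rf.score(X_val, y_val))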

When both the random forest and the logistic regression produced nearly perfect results on the validation set (with no data leakage), we started wondering what could explain such faultless predictions.

Using scikit-learn’s built-in feature importance API was a poor choice: every X-ray is unique, so aggregated feature importance is not meaningful here. And since each pixel is a separate feature, applying LIME (or ELI5, SHAP and other interpretability tools) did not really make sense either. We still did it to build intuition about the method, and obtained the following pictures.

LIME results for most important pixels with random forest model
Sick patient’s X-ray
Model Heatmap

Here we have an X-ray classified as sick by the Random Forest classifier, with LIME explaining the most important pixels.

Red circles: around the most significant pixels explaining sickness

Blue circles: around the most significant pixels explaining health

The model’s most important points for decision making, based on the heatmap, are concentrated and localized on the lungs, but also present on the image border.

This result cannot be reliable: the most important pixels are on the image border, and the model does not base its decision sufficiently on the patient’s lungs.
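For reference, one way such an aggregated heatmap can be produced is by reshaping the random forest’s built-in feature importances back into the 64×64 pixel grid. This is only a sketch of the idea discussed above (assuming the rf model from the baseline sketch), not necessarily how the figure itself was generated.

import matplotlib.pyplot as plt

# feature_importances_ holds one value per pixel column; reshape to the image grid
importance_map = rf.feature_importances_.reshape(64, 64)
plt.imshow(importance_map, cmap="hot")
plt.title("Aggregated pixel importance (Random Forest)")
plt.colorbar()
plt.show()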

Moving to Neural Networks

For training a CNN model, we decided to go with transfer learning. Since the model was initially intended to help doctors, we chose MobileNet. As the name suggests, it is lightweight and mostly used in mobile applications.

As usual, we left out the original fully connected layers and made sure the learning rate was very low (to avoid catastrophic forgetting).
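A minimal sketch of such a transfer-learning setup with tf.keras is shown below; the classification head, input size and optimizer settings are our assumptions for illustration, not necessarily the exact architecture summarized next.

import tensorflow as tf

# Pretrained MobileNet backbone, without its original fully connected layers
base = tf.keras.applications.MobileNet(weights="imagenet", include_top=False,
                                       input_shape=(224, 224, 3), pooling="avg")
base.trainable = True  # fine-tune the backbone, but gently

model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(2, activation="softmax"),  # 0 = healthy, 1 = sick
])

# Very low learning rate to avoid catastrophic forgetting of the pretrained weights
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.summary()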

Final structure of the model:

Model Summary

Making prediction on test set brought us the following results:

Confusion Matrix

precision score is: 81.13%

recall score is: 99.23%

0 = healthy

1 = sick

In the medical field it is extremely important to decrease the number of False Negatives (predicting Healthy when the patient is Sick), which is why we concentrated on the Recall score.
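As a reference, the two scores can be computed from the test-set predictions like this (a sketch; X_test and y_test are placeholder names for whatever test arrays or generator were actually used):

from sklearn.metrics import confusion_matrix, precision_score, recall_score

y_prob = model.predict(X_test)   # shape (n, 2): probabilities for [healthy, sick]
y_pred = y_prob.argmax(axis=1)   # 0 = healthy, 1 = sick

print(confusion_matrix(y_test, y_pred))
print("precision score is: %.2f%%" % (100 * precision_score(y_test, y_pred)))
print("recall score is: %.2f%%" % (100 * recall_score(y_test, y_pred)))  # false negatives lower recall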

We were really happy with the results (they were similar to those in the other notebooks) until we started wondering what made the model reach its decisions.

Model Interpretability

It was extremely important for us not to treat the model as a black box and to see what it bases its prediction on.

Local Interpretable Model-agnostic Explanations (LIME, 2016) has three basic ideas behind it:

  • Model-agnosticism — no assumptions are made about the underlying model
  • Interpretability — results can be easily interpreted (Tabular or Images)
  • Locality — explanation “in the neighbourhood” of the instance we want to explain

Let’s understand the syntax of LIME

  1. Instantiate an explainer
from lime import lime_image
explainer = lime_image.LimeImageExplainer()

2. Produce the superpixels, which show the regions the model relied on most when making its decision

explain_instance(image, classifier_fn, top_labels=2, hide_color=None, num_samples=1000, distance_metric='cosine')
  • image — the image, preprocessed to work with the tf.keras API
  • classifier_fn — a function that takes a batch of images and returns prediction probabilities (here, model.predict)
  • top_labels — if not None, ignore labels and produce explanations for the K labels with the highest prediction probabilities, where K is this parameter
  • num_samples — size of the neighborhood used to learn the linear model
  • distance_metric — the distance metric to use for the weights
explanation = explainer.explain_instance(process_img(img_path)[0], model.predict, top_labels=2, hide_color=0, num_samples=1000, distance_metric='cosine')

3. Intermediate step for plotting

get_image_and_mask(label, positive_only=True, hide_rest=False, min_weight=0.0)
  • label — label to explain
  • positive_only — if True, only take superpixels that contribute to the prediction of the label. Otherwise, use the top num_features superpixels, which can be positive or negative towards the label
  • hide_rest — if True, make the non-explanation part of the return image gray
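
Putting the three steps together, an end-to-end sketch could look like the following (process_img is the preprocessing helper used in the call above, model is the trained Keras model, and skimage’s mark_boundaries is used only for display):

import matplotlib.pyplot as plt
from lime import lime_image
from skimage.segmentation import mark_boundaries

# 1. Instantiate the explainer
explainer = lime_image.LimeImageExplainer()

# 2. Explain a single preprocessed image with the trained model
explanation = explainer.explain_instance(process_img(img_path)[0], model.predict,
                                         top_labels=2, hide_color=0,
                                         num_samples=1000, distance_metric='cosine')

# 3. Get the superpixels supporting the top predicted label and plot them
image, mask = explanation.get_image_and_mask(explanation.top_labels[0],
                                             positive_only=True, hide_rest=False,
                                             min_weight=0.0)
plt.imshow(mark_boundaries(image / image.max(), mask))  # normalize for display
plt.axis("off")
plt.show()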

Below are some images produced by LIME:

Healthy patient’s analysis with LIME
Sick patient’s analysis with LIME

We were disappointed to find that the model mostly relies on the background, limbs, bones and medical devices as signs of being Healthy or Sick.

Obviously, we cannot decide whether a patient is sick based on the medical devices attached to them or on the image background. The medical devices are there because the patient is sick; they are a consequence, not a cause. The image background depends more on the radiologist’s technique: they will probably position a sick patient differently to improve the visibility of the result.

Conclusion

To avoid overfitting and interpretability problems of this kind, we should:

  1. Use data augmentation techniques (a sketch of points 1–3 follows this list)
  2. Crop images so that they only include the lung areas
  3. Apply color equalization to the pictures
  4. Along with the prediction (Healthy or Sick), show the boundaries of the affected lung areas
  5. Use a dataset from a different origin
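
A rough sketch of what points 1–3 could look like in code (the crop fraction, augmentation ranges and equalization method are assumptions, not tested settings):

import tensorflow as tf
from skimage import exposure

# 1. Data augmentation layers applied during training
augment = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(0.02),
    tf.keras.layers.RandomTranslation(0.05, 0.05),
    tf.keras.layers.RandomZoom(0.1),
])

# 2. Crude crop towards the lung area before resizing
def crop_and_resize(image):
    image = tf.image.central_crop(image, central_fraction=0.8)
    return tf.image.resize(image, (224, 224))

# 3. Histogram equalization to normalize contrast across images
def equalize(image_np):
    return exposure.equalize_hist(image_np)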

It is critical to look at interpretability in depth when we know how the model should be making its decision.

If we want models to help us make important decisions, we must know the reasons why the model made its decision.

This article was written by Yaniv Goldfrid, Dana Velibekov and Yair Stemmer

