Detecting COVID-19 from Chest X-Rays using Attention Maps in Keras/Tensorflow and making a Flask web app out of it.

Sagarnil Das · Published in Analytics Vidhya · 12 min read · Mar 30, 2020

In March 2015, Mr. Bill Gates gave a TED talk titled "The next outbreak? We're not ready." On 18 March 2015, he published a post on his blog, GatesNotes, which includes the TED Talk. He wrote that the next global pandemic could be worse than the Ebola outbreak of 2014 to 2016, which killed about 11,000 people.

Fast forward five years, and we are witnessing a world that has surpassed every imagination we could have had about an apocalyptic one, forcing us to face the ultimate existential crisis. All across the world, lockdowns are in place and death tolls are rising, making this a crucial time for humanity to unite.

Coronaviruses are a large family of viruses that cause illnesses ranging from the common cold to more severe diseases. An outbreak of a new strain, given the name “2019-nCoV” and known as “novel coronavirus”, was identified in China in late December 2019. The virus causes the disease Covid-19.

As of 28 March 2020, there were nearly 650,000 confirmed cases of Covid-19 worldwide and close to 31,000 deaths.

Source: Reuters

At this point, the best defence we can deploy is to stay at home and avoid social gatherings at all costs. Social distancing seems to be the only way forward in the present scenario. However, as confirmed cases grow exponentially, medical facilities in many nations are being overwhelmed by the sheer number of patients, and there is no cure as of now.

China recently developed a helmet powered by artificial intelligence which can measure the body temperature of people in crowds from up to 5 meters away.

Other AI-based solutions are also being extensively researched. But even though such models have been reported to have high sensitivity and specificity, they should not be left to operate in isolation, without human intervention, because studies have shown that their predictions can be very volatile when distinguishing COVID-19 from other types of viral illness.

In this article, we will explore the methodology of implementing such a process and analyze the results obtained. We will use transfer learning coupled with a visual attention mechanism, so that the deep learning agent only attends to the relevant parts of the image instead of looking at every part with equal importance. But first, we will discuss how visual attention works, and then we will dive straight into the code (full GitHub repo here).

The dataset has been compiled by Adrian Yijie Xu. You can find his method of implementation here. It is a combination of the Kaggle Chest X-ray dataset and the COVID-19 chest X-ray dataset collected by Dr. Joseph Paul Cohen of the University of Montreal. The dataset contains 4 categories:

  1. COVID19: 60 train images, 9 test images
  2. Normal: 70 train images, 9 test images
  3. Pneumonia (Bacterial): 70 train images, 9 test images
  4. Pneumonia (Viral): 70 train images, 9 test images

A few samples of the images look like this:

Sample chest X-ray images: COVID-19, Normal/Healthy, Pneumonia

One of the major issues with this type of approach is that there simply aren't many positive instances of COVID-19 among the X-ray images. So even though a model might achieve high test accuracy, it may not generalize well to the real world, as we simply don't know how much variance there is in the X-rays of positively identified patients. With that being said, this model reached a validation accuracy of 85% and a test accuracy of 100%, so we can only hope that we are on the right track. Without further ado, let's first understand what visual attention is.

Visual Attention in Convolutional Neural Networks

Visual attention works in a way very similar to how our own vision works. Let's consider a scenario. Suppose someone tells you: "Hey look, there's a tiger!" You turn in the direction they are pointing, and at first you take in your whole field of view as one image as the visual cortex sends the stimuli to your brain. But then you start focusing on the tiger, because you know what it looks like (a long history of transfer learning, I must say!). Your peripheral vision becomes blurry as you lock your focus on the tiger and start paying 'attention' to it.

Now think about it. In order to recognize or process the image of the tiger in your brain, did you absorb and process the whole scene part by part? No, you instantly knew where to look and locked your focus on that spot while your peripheral vision got blurrier and blurrier. That is you implementing visual attention yourself. Neural networks are not so different.

Our main reference for understanding visual attention will be the paper "Learn to pay attention". Two types of visual attention mechanisms are prevalent:

  1. Hard attention: This method crops the image down to the regions of interest. Algorithms like REINFORCE are typically used to train hard attention mechanisms. The output of hard attention is a binary value of 0 or 1, where 1 corresponds to preserving a pixel and 0 means the pixel has been cropped out.
  2. Soft attention: Soft attention uses soft shadings to focus on regions. The values of the output maps are decimal numbers between 0 and 1.

The paper Learn to pay attention uses soft attention to solve a multiclass classification problem. The authors demonstrate that soft trainable attention improves performance on multiclass classification by 7% on CIFAR-100, and they show example heat maps highlighting how the attention helps the model focus on parts of the image most relevant to the correct class label.

The following picture depicts their model architecture:

Another novel approach, suggested by K. Scott Mader, is to build an attention mechanism that turns pixels in the GAP on and off before the pooling and then rescales (with a Lambda layer) the results based on the number of pixels. The basic idea is that a plain Global Average Pooling is too simplistic, since some regions are more relevant than others.

In this work, we will do something similar to these methods to solve the X-Ray classification problem. You can find the notebook here.

Let’s write some codes! We will subdivide it into two parts:

  1. The deep learning part
  2. The web app creation part

The Deep Learning part

A) Defining Global parameters

As the dataset has too few images, we are going to use image augmentation so that our model does not overfit too badly. We will use the ImageDataGenerator class for this.
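The gist with the global parameters is not embedded here, so below is only a minimal sketch of the kind of constants involved; the names, values, and directory layout are assumptions based on the rest of the article (150x150 inputs, a small dataset):

import os

# Hypothetical global parameters -- values are assumptions, not the author's exact settings
IMG_HEIGHT, IMG_WIDTH = 150, 150     # input size used throughout the article
BATCH_SIZE = 16                      # small batches, since the dataset is tiny
EPOCHS = 50
TRAIN_DIR = os.path.join('data', 'train')   # assumed layout: one sub-folder per class
TEST_DIR = os.path.join('data', 'test')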

B) Loading Data and Augmentation

We will use the ImageDataGenerator class in Keras to load and augment the images. This class performs the augmentation in memory; the image files on disk are not changed.
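A sketch of how the generators could be set up is shown below; the specific augmentation parameters are assumptions and may differ from the original notebook:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Augmentation happens on the fly, in memory -- files on disk are untouched
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,           # normalise pixel values to [0, 1]
    rotation_range=10,           # mild rotations and shifts (assumed values)
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
)
test_datagen = ImageDataGenerator(rescale=1.0 / 255)   # no augmentation for evaluation

train_gen = train_datagen.flow_from_directory(
    TRAIN_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)
test_gen = test_datagen.flow_from_directory(
    TEST_DIR,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,               # keep order fixed so predictions line up with labels
)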

C) Model Architecture

Now we will create the model architecture and apply attention to it. We are using VGG16 as the pretrained model with 'imagenet' weights. While instantiating this model, we set the include_top parameter to False, as we won't include the fully-connected layers at the top of the network. Its output is followed by a Batch Normalization layer, a technique for training very deep neural networks that standardizes the inputs to a layer for each mini-batch. This has the effect of stabilizing the learning process and dramatically reducing the number of training epochs required.

In the attention part of the model, we use multiple 1x1 convolutions, because we want to avoid a very large number of feature maps and therefore downsample them. We build an attention mechanism that turns pixels in the GAP on and off before the pooling and then rescales (with a Lambda layer) the results based on the number of pixels. The model can be seen as a sort of 'global weighted average' pooling. In the final model, we combine the pretrained VGG16 model and the attention model in a linear fashion. The diagram below represents the attention model.

Attention model architecture

And this diagram represents the final model architecture.

Final model architecture

The code to create this architecture is as follows:
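The original gist is not embedded here, so the snippet below is only a sketch of this kind of architecture, in the spirit of K. Scott Mader's attention/GAP approach described above; the 1x1 convolution widths, dropout rates, dense layer size, and optimizer are assumptions, not the author's exact code:

from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16

def build_attention_model(input_shape=(150, 150, 3), n_classes=4):
    # Pretrained convolutional base, without the fully-connected top
    base = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)
    base.trainable = False                      # transfer learning: freeze the base

    in_img = layers.Input(shape=input_shape)
    features = base(in_img)                     # VGG16 feature maps, depth 512
    bn_features = layers.BatchNormalization()(features)

    # Attention branch: 1x1 convolutions ending in a single-channel soft mask
    attn = layers.Conv2D(64, (1, 1), padding='same', activation='relu')(bn_features)
    attn = layers.Conv2D(16, (1, 1), padding='same', activation='relu')(attn)
    attn = layers.Conv2D(1, (1, 1), padding='valid', activation='sigmoid')(attn)

    # Broadcast the 1-channel mask across all 512 feature channels
    # using a fixed, all-ones 1x1 convolution that is not trained
    up_c2 = layers.Conv2D(512, (1, 1), padding='same', activation='linear',
                          use_bias=False, kernel_initializer='ones')
    up_c2.trainable = False
    attn_mask = up_c2(attn)

    # 'Global weighted average' pooling: weight the features by the mask,
    # pool, then rescale (Lambda layer) by how much of the image the mask kept
    masked_features = layers.multiply([attn_mask, bn_features])
    gap_features = layers.GlobalAveragePooling2D()(masked_features)
    gap_mask = layers.GlobalAveragePooling2D()(attn_mask)
    gap = layers.Lambda(lambda x: x[0] / x[1], name='RescaleGAP')([gap_features, gap_mask])

    # Classification head
    x = layers.Dropout(0.5)(gap)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.25)(x)
    out = layers.Dense(n_classes, activation='softmax')(x)

    model = models.Model(inputs=in_img, outputs=out)
    model.compile(optimizer='adam', loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model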

D) Callbacks

We will use two callbacks for this model.

  1. Model checkpointing: We will recursively overwrite the same file with the best weights as the model trains. Later we will need this file to create the web app.
  2. Reduce learning rate on plateau: This callback reduces the learning rate when a metric has stopped improving. This is really important in our case, as we have very few images and we don't want to skip over the global minimum and get stuck in a local one. A sketch of both callbacks follows this list.
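The callbacks could look roughly like this; the monitored metric, factor, and patience values are assumptions, while the checkpoint file name matches the MODEL_PATH used later in the Flask app:

from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

checkpoint = ModelCheckpoint(
    'models/covid_attn_weights_best_vgg16.h5',
    monitor='val_loss',
    save_best_only=True,        # keep overwriting the same file with the best model so far
    save_weights_only=False,    # save architecture + weights, needed later by load_model()
    verbose=1,
)
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,                 # halve the learning rate when validation loss plateaus
    patience=3,
    min_lr=1e-6,
    verbose=1,
)
callbacks = [checkpoint, reduce_lr]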

E) Fitting the model

Now we are all set to train the model.

The steps_per_epoch parameter represents the total number of steps (batches of samples) to yield from the generator before declaring one epoch finished and starting the next. It should typically be equal to ceil(num_samples / batch_size).
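Putting it together, and assuming the generators and callbacks defined above, fitting could look roughly like this (fit_generator was current when the article was written; newer Keras versions accept the same arguments via model.fit):

import math

model = build_attention_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
                              n_classes=train_gen.num_classes)

history = model.fit_generator(
    train_gen,
    steps_per_epoch=math.ceil(train_gen.samples / BATCH_SIZE),
    epochs=EPOCHS,
    validation_data=test_gen,
    validation_steps=math.ceil(test_gen.samples / BATCH_SIZE),
    callbacks=callbacks,
)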

F) Plotting the loss function and accuracy

Let’s create a plot for accuracy and the loss function.
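The plotting gist is not embedded here; a simple matplotlib version, using the history object returned by the fitting step above, might look like this:

import matplotlib.pyplot as plt

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

# Loss curves
ax_loss.plot(history.history['loss'], label='train loss')
ax_loss.plot(history.history['val_loss'], label='val loss')
ax_loss.set_title('Loss')
ax_loss.legend()

# Accuracy curves (the key is 'accuracy' or 'acc' depending on the Keras version)
acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'
ax_acc.plot(history.history[acc_key], label='train acc')
ax_acc.plot(history.history['val_' + acc_key], label='val acc')
ax_acc.set_title('Accuracy')
ax_acc.legend()

plt.show()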

G) Evaluate the model on test data and test it on some samples

We will evaluate the model on the test data and plot some images along with the predicted label and the probability of the label.
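As a rough sketch, assuming the generators from section B, the evaluation and sample predictions could be done like this; the number of displayed images and the formatting are arbitrary choices:

import math
import numpy as np
import matplotlib.pyplot as plt

# Overall metrics on the held-out test generator
loss, acc = model.evaluate_generator(
    test_gen, steps=math.ceil(test_gen.samples / BATCH_SIZE))
print('Test loss: {:.3f}  Test accuracy: {:.3f}'.format(loss, acc))

# Show a few test images with the predicted label and its probability
class_names = {v: k for k, v in test_gen.class_indices.items()}
x_batch, y_batch = next(test_gen)
preds = model.predict(x_batch)

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for img, pred, ax in zip(x_batch, preds, axes):
    ax.imshow(img)
    ax.set_title('{} ({:.2f})'.format(class_names[int(np.argmax(pred))],
                                      float(np.max(pred))))
    ax.axis('off')
plt.show()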

H) Creating the main function and invoking it

Let's call all these functions in the appropriate order inside the main function, and then let's execute it.

I) Loss and accuracy plots

With Attention

We see that the plots are quite volatile, but there is a steady trend of improvement. One of the main ways these results can be improved is to get more training data. Nevertheless, we achieve a validation accuracy of around 85% with this data.

In the github repo, you will also find a method without attention. Comparing the test results of attention vs without attention on some sample test data:

Whew! That was a long part, but it concludes our deep learning work. Let's now move on to the web application part, where we will create a simple web interface using Flask. In that interface we can upload an X-ray image, and it will be classified as COVID19, NORMAL, or TERTIARY PNEUMONIA right in the app.

The Web App creation part

For the web app creation part, I used the keras-flask-deploy-webapp repo. However, there are some significant changes you need to make to the code in order to make it work for this case, so I am going to walk you through them step by step.

  1. Clone the repo:
git clone https://github.com/mtobeiyf/keras-flask-deploy-webapp

2. Change directory

cd keras-flask-deploy-webapp

3. Install requirements

pip install -r requirements.txt

4. Copy and paste your model (the .h5 file) into the models directory within keras-flask-deploy-webapp. This should contain not only the model weights but also the model architecture. That's why, if you look back at the callbacks function above, you will find that we set save_weights_only = False, so that the whole model is saved and not just the weights.

models folder (highlighted)

5. Open the app.py file. Here you will have to make some changes.

a) In the top import lines, change the line

from tensorflow.keras.applications.imagenet_utils import preprocess_input, decode_predictions

to whatever model you are using. So if you are using VGG16, then it would be

from keras.applications.vgg16 import preprocess_input

Also, note that I got rid of the decode_predictions method in the import statement. That function is only needed if you are classifying ImageNet data and want to turn the labels into human-readable form. Since we are not, we won't need it.

b) Next, comment out these two lines:

#from keras.applications.mobilenet_v2 import MobileNetV2
#model = MobileNetV2(weights='imagenet')

As we will use our own model, we won’t need these two lines.

c) In the MODEL_PATH parameter, specify the name of the model file you placed in the models folder in step 4, e.g.

MODEL_PATH = 'models/covid_attn_weights_best_vgg16.h5'

d) Uncomment the next two lines, i.e.

model = load_model(MODEL_PATH)
model._make_predict_function()

e) In the model_predict function, change the image size to 150, 150, as this was the image size we worked with.

img = img.resize((150, 150))

f) In the predict function, get rid of these two lines:

pred_class = decode_predictions(preds, top=1)   # ImageNet Decode
result = str(pred_class[0][0][1])               # Convert to string

Add these lines instead:

pred_class = np.argmax(preds)
# Process your result for human
pred_proba = "{:.3f}".format(np.amax(preds))    # Max probability

print(pred_class)
if pred_class == 0:
    result = 'COVID19'
elif pred_class == 1:
    result = 'NORMAL'
elif pred_class == 2:
    result = 'TERTIARY PNEUMONIA'

So effectively the model_predict function now looks like this:
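The gist is not embedded here; based on the structure of the original repo's function and the edits above, the modified model_predict would look roughly like this (I use the tensorflow.keras import path, which should match however the model was built):

import numpy as np
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input

def model_predict(img, model):
    # Resize to the input size the model was trained on
    img = img.resize((150, 150))

    x = image.img_to_array(img)      # PIL image -> numpy array
    x = np.expand_dims(x, axis=0)    # add the batch dimension
    x = preprocess_input(x)          # VGG16-style preprocessing

    preds = model.predict(x)
    return preds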

And the predict function looks like this:
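Again as a sketch, following the structure of the original repo's /predict route and the label mapping above (it assumes app, model, model_predict, and base64_to_pil are defined elsewhere in app.py, as they are in the repo):

import numpy as np
from flask import request, jsonify

@app.route('/predict', methods=['GET', 'POST'])
def predict():
    if request.method == 'POST':
        # The uploaded image arrives as base64-encoded data
        img = base64_to_pil(request.json)

        # Run the model
        preds = model_predict(img, model)

        # Map the argmax index to a human-readable label
        pred_class = np.argmax(preds)
        pred_proba = "{:.3f}".format(np.amax(preds))   # probability of the predicted class

        if pred_class == 0:
            result = 'COVID19'
        elif pred_class == 1:
            result = 'NORMAL'
        else:
            result = 'TERTIARY PNEUMONIA'

        # Serialize the result back to the front-end
        return jsonify(result=result, probability=pred_proba)

    return None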

6. Changing utils.py — Almost there. This should now be enough for the web app to run, but there is one last caveat to solve. When you upload an image in the web application, it arrives at our application as base64-encoded data. We convert it into a PIL image with a function named base64_to_pil, which lives in our utils.py file, so that we can read and preprocess it with Keras and OpenCV. Now, I don't know how common this knowledge is, but from my experience around the VFX industry I know it for a fact: JPG and PNG files differ in one respect. PNG files have an extra channel besides the customary R, G, B channels, called the 'alpha' channel. Our dataset contains both JPEG and PNG files, so the problem that arises when you upload a PNG file is that, because it contains this extra channel, the shape of the decoded image becomes (150x150x4) instead of (150x150x3), and this raises an error. The best way to solve this is to add logic that converts the image to RGB whenever image.mode is not RGB.

So go to utils.py and rewrite the function base64_to_pil like this:
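A sketch of the rewritten function, based on the repo's original base64_to_pil plus the RGB conversion described above:

import re
import base64
from io import BytesIO
from PIL import Image

def base64_to_pil(img_base64):
    # Strip the 'data:image/...;base64,' prefix sent by the browser
    image_data = re.sub('^data:image/.+;base64,', '', img_base64)
    pil_image = Image.open(BytesIO(base64.b64decode(image_data)))

    # PNG uploads carry an alpha channel (RGBA -> 150x150x4 after resizing);
    # convert everything to 3-channel RGB so the model always sees 150x150x3
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')

    return pil_image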

Now open the command terminal and run

python app.py

Go to http://127.0.0.1:5000/ and everything should work properly. :)

7. Upload an image and see the classification your model makes.

End Notes

I hope this article has been helpful to you. The major problem with this approach is that there simply isn't enough data. But we will most likely get more data in the coming days, and that will help us help humanity in our own way, however tiny it might be. The collaborative effort of every one of us is what's required in these dark times.

Fear is the worst possible thing that could have happened in the present times: the fear of how many more will get infected, the fear of running out of the basic necessities of survival, and the fear of what's in store for us tomorrow. In these times, I felt that maybe this is how I could offer my small share of a contribution towards mitigating the largest pandemic we have seen this century. If you think you can add to this dataset or improve my procedure, please fork this repo, contribute, and then make a pull request. If you think someone you know in the medical field could benefit from this, please let me know, or if you can, run this model for them as needed. It can be a very good sanity checker, as it produces a prediction in a fraction of a second, and time is of the essence.

Above all, we must be hopeful that we will get through this together. As they say: “The night is the darkest just before the dawn.”

You are not alone…

REFERENCES

  1. https://healthcare-in-europe.com/en/news/imaging-the-coronavirus-disease-covid-19.html
  2. https://towardsdatascience.com/detecting-covid-19-induced-pneumonia-from-chest-x-rays-with-transfer-learning-an-implementation-311484e6afc1
  3. https://arxiv.org/pdf/1804.02391.pdf
  4. https://www.kaggle.com/kmader
