Detecting pneumonia on X-ray images — ConvNets and Transfer Learning

Unit8
Unit8 - Big Data & AI
Feb 12, 2019


Each year, pneumonia claims about one million victims globally. In general, the risk is most severe for children, though there are regional disparities as well. In sub-Saharan Africa, pneumonia kills half a million children annually.

According to UNICEF, pneumonia causes more deaths in children than HIV/AIDS and malaria combined. One of the most urgent needs in combatting pneumonia is a cost-effective manner of diagnosing X-ray images. Current methods require highly skilled personnel and are impractical for many regions of the world.

Aside from delivering top-quality solutions, the main goals of our company are to have a positive impact on the world and to keep improving our skills. This was a perfect opportunity to do both at once. We decided to publish the results of our experiments as learning material for others in the form of a workshop; you can find all our materials in this GitHub repository.

Problem statement

Take a look at these two X-ray images. Which one do you think shows lungs infected with pneumonia?

If you have medical training, you have probably guessed right. As for us, we couldn’t really spot any significant difference between the pictures. The problem can in fact be tricky even for professionals and might require additional tests.

We were therefore very curious to see whether a neural network, provided with properly labelled data, would be able to find patterns that split patients into two classes: normal and pneumonia.

Dataset

Our journey started with a Kaggle dataset available here [1]. It consists of 5,863 X-ray images of lungs taken from a group of paediatric patients aged 1 to 5 years. All of the images were labeled by 2 specialists to minimize the risk of labeling errors. The data was already split for us into training, validation and test datasets.

Our approach

A Convolutional Neural Network (CNN, or ConvNet) seems to be the obvious choice given the recent achievements of deep learning in image classification.

CNNs work in a way that is loosely similar to how the visual cortex operates in humans and animals. We will not provide yet another explanation of how a ConvNet works, since there are already a number of great articles on that topic. For an example, have a look here [2].

Since training large neural networks takes immense amounts of computational power (not to mention time), we decided to use transfer learning. This concept boils down to reusing parts of a network already trained in one domain to classify images in another domain. In our case, we decided to use a network trained on generic images from the ImageNet database and tune it to detect pneumonia.

We leverage one of the most basic pre-trained networks called VGG16 [3].

Preview of VGG16 network [source]

We hope that the features the network is already able to detect (shapes, curves, lines) are good enough to also capture the difference between healthy and infected lungs. This way we reuse the pre-trained part of the network that is responsible for feature learning and only need to retrain the classification part. Instead of classifying images into ImageNet tags, we classify them into 2 classes: normal and pneumonia.

To give you an impression of the scale we are discussing here: the full VGG16 has roughly 138M parameters, about 14.7M of which sit in the convolutional base we reuse. Transfer learning reduces our computational load to approx. 16k parameters to train.
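The figure of roughly 16k trainable parameters can be checked with a little arithmetic. The sketch below assumes the 150x150x3 input used in this article and the standard VGG16 layout (five 2x2 max-pooling stages, 512 channels in the last convolutional block); the helper name `pooled_size` is ours and purely illustrative.

```python
# Rough parameter count for a 2-class Dense head placed on top of the VGG16
# convolutional base. Assumes a 150x150 input and the standard VGG16 layout.

def pooled_size(size, num_pools=5):
    """Spatial size after `num_pools` 2x2 max-pooling layers (floor division)."""
    for _ in range(num_pools):
        size //= 2
    return size

INPUT_SIZE = 150            # height and width of the input image
LAST_BLOCK_CHANNELS = 512   # channels produced by the last VGG16 conv block
NUM_CLASSES = 2             # normal / pneumonia

side = pooled_size(INPUT_SIZE)                         # 150 -> 75 -> 37 -> 18 -> 9 -> 4
flattened = side * side * LAST_BLOCK_CHANNELS          # 4 * 4 * 512 = 8192
head_params = flattened * NUM_CLASSES + NUM_CLASSES    # weights + biases

print(side, flattened, head_params)  # 4 8192 16386
```

So a single Dense layer with two outputs on the flattened feature maps accounts for about 16k weights, which is all that has to be trained once the convolutional base is frozen.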

We used TensorFlow to preprocess the images, chaining the operations with TensorFlow input pipelines. For the training itself, we chose the easy option and used the VGG16 implementation that ships with the Keras library.

import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Flatten, Dense
from keras.applications.vgg16 import VGG16

NUM_CLASSES = 2

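# Load the convolutional base pre-trained on ImageNet, without its original classifier
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))

# Add a new classification head: flatten the feature maps and predict 2 classes
x = base_model.output
x = Flatten()(x)
x = Dense(NUM_CLASSES, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=x)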

model.summary()

In order to reduce the number of trained parameters, we freeze the first 20 layers of the model (up to the Flatten layer).

# Freeze the pre-trained layers so that only the new Dense layer is trained
for layer in model.layers[0:20]:
    layer.trainable = False

Once the model is compiled, training it is as simple as:

model.fit(
    x=x_train, y=y_train,
    validation_split=0.2,
    shuffle=True,
    batch_size=64,
    epochs=20,
    verbose=1
)

Right after the first attempt, we arrived at 80% classification accuracy. We then used several tricks in order to improve the final performance of the model. We introduced:

  • A weighted cost function to prevent bias towards the more frequent class in the imbalanced dataset (2.5x more examples of the PNEUMONIA class than of NORMAL)
  • CNN regularisation with additional Dropout and Batch Normalization layers to avoid overfitting to the training examples
  • Data augmentation with ImageDataGenerator [5] to increase the variety of images by manipulating them in several ways: flipping, resizing, random rotations, etc.
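To make the first point more concrete, here is a minimal sketch of one common way to derive class weights: each class is weighted inversely to its frequency. This is an illustration, not necessarily the exact weighting used in our notebook. The counts are hypothetical, chosen only to reproduce the 2.5x ratio, and the encoding 0 = NORMAL, 1 = PNEUMONIA is an assumption.

```python
def balanced_class_weights(counts):
    """Weight each class inversely to its frequency: total / (num_classes * count)."""
    total = sum(counts.values())
    num_classes = len(counts)
    return {label: total / (num_classes * n) for label, n in counts.items()}

# Hypothetical counts that only mimic the 2.5x imbalance; not the real dataset sizes.
# Assumed label encoding: 0 = NORMAL, 1 = PNEUMONIA.
example_counts = {0: 1000, 1: 2500}

class_weight = balanced_class_weights(example_counts)
print(class_weight)  # the rarer NORMAL class gets the larger weight
```

A dictionary of this shape can be passed to Keras through the `class_weight` argument of `model.fit`, so that mistakes on the rarer class contribute more to the loss.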

We eventually achieved 91% accuracy on the test dataset.

Not too bad for just a few evenings of work. That being said, we need to acknowledge the obvious limits of this early work, especially the lack of larger-scale validation of the results, which might lead to issues such as the ones described here [4].

The detailed notebook containing our work (and a few additional exercises we did before) can be found in our GitHub repository. We encourage you to clone the repo, follow our tutorials and play a bit with the network yourself!

[1]: https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia/home

[2]: https://medium.freecodecamp.org/an-intuitive-guide-to-convolutional-neural-networks-260c2de0a050

[3]: http://www.image-net.org/

[4]: https://healthitanalytics.com/news/deep-learning-for-medical-imaging-fares-poorly-on-external-data

[5]: https://keras.io/preprocessing/image/

In case you are interested in other cool things we do at Unit8, visit our website: http://unit8.co
