Diagnosing Pneumonia with Deep Learning
Using PyTorch and CNNs to classify X-ray images that show signs of pneumonia.
Using computer vision and deep learning to diagnose medical conditions is an extensive area of research, and there have been several breakthroughs in this field. Sorting images into categories based on their content is called Image Classification. In this article, I’ll explain Binary Classification, a subset of Image Classification with exactly two categories. We’ll develop a deep learning model using PyTorch and achieve reasonably good results with the help of a technique called Transfer Learning. The model will analyze X-ray images and classify them into two categories: “Pneumonia” and “Normal”.
Let’s now walk through the pipeline:
As always, the first step is to import the necessary libraries and packages. My framework of choice for this project is PyTorch.
The dataset we’ll be using is called “Chest X-Ray Images (Pneumonia)” by Paul Mooney and is available on Kaggle. It is organized into 3 folders (train, test, val), each containing a sub-folder per image category (Pneumonia/Normal). There are 5,863 X-ray images (JPEG) in total. The dataset is around 1 GB, but you can use Kaggle’s built-in notebook environment if you don’t want to download it.
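The class counts quoted in the next paragraph can be obtained by counting the files in each class sub-folder. A minimal sketch, assuming the archive has been extracted to a folder called chest_xray (a hypothetical path) containing train, val and test; count_images is an illustrative helper name, not taken from the original notebook.

```python
import os

# Hypothetical location of the extracted Kaggle dataset; adjust to your setup.
DATA_DIR = "chest_xray"

IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")


def count_images(split_dir):
    """Return {class_name: number_of_image_files} for one split (e.g. train)."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files that are not class folders
        counts[class_name] = sum(
            1 for f in os.listdir(class_dir) if f.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts


if __name__ == "__main__":
    for split in ("train", "val", "test"):
        split_dir = os.path.join(DATA_DIR, split)
        if os.path.isdir(split_dir):  # only report splits that actually exist
            print(split, count_images(split_dir))
```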
Counting the files in the training folders shows that there are 3,875 X-ray images showing signs of pneumonia and 1,341 normal X-ray images. Since the dataset is imbalanced and fairly small, we apply random operations such as horizontal flips and rotations to the training images so that the model sees more varied examples. This is known as Data Augmentation. Besides augmenting the images, the transforms convert the inputs into tensors and normalize them. The normalization values are the standard ones for models trained on the ImageNet dataset. As stated earlier, we’ll be using transfer learning, which I’ll explain in a later section.
Also, note that we do not apply the random augmentation operations to the test or validation sets. Those images should stay unaltered so that the evaluation reflects how the model performs on real, unmodified X-rays.
We load the data in batches using the DataLoader class from torch.utils.data. The show batch function helps us visualize a batch of the image data after all the transforms have been applied.
We can train the model on a GPU if a compatible (CUDA-enabled) one is present in the system. Deep learning models are computationally intensive, and training on a GPU is usually much faster. We use a few helper functions that determine whether a GPU is available and, if one is, move the data onto it.
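Helpers like these are a common way to handle device placement in PyTorch. The names get_default_device, to_device and DeviceDataLoader are illustrative, not necessarily those used in the original notebook.

```python
import torch


def get_default_device():
    """Pick the GPU if a CUDA-enabled one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    """Move a tensor, or a (nested) list/tuple of tensors, to `device`."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)


class DeviceDataLoader:
    """Wrap a DataLoader so every batch it yields is already on `device`."""

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dl)
```

Wrapping a loader as DeviceDataLoader(train_dl, device) means the training loop never has to call .to(device) on each batch itself.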
Our data has now been pre-processed and is ready to be fed into the model. So, let’s talk a bit about the process involved in training.
A batch of data is passed into the model and a forward pass is performed. The loss is then calculated between the predicted and actual values; the loss function I’ve used here is cross entropy, available in the torch.nn.functional module. The loss is used to compute gradients, which in turn are used to update the model’s parameters and improve its performance. This process is repeated for a number of epochs, where one epoch is one full pass of the model over the entire training data.
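One iteration of the process described above can be sketched as a single function. The tiny linear model and random batch at the end are stand-ins used only to show that the step runs; they are not the article’s network or data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def training_step(model, batch, optimizer):
    """One iteration: forward pass, cross-entropy loss, backward pass, parameter update."""
    images, labels = batch
    logits = model(images)                  # forward pass: raw class scores
    loss = F.cross_entropy(logits, labels)  # softmax + negative log-likelihood, batch mean
    optimizer.zero_grad()                   # clear gradients from the previous step
    loss.backward()                         # gradients of the loss w.r.t. every parameter
    optimizer.step()                        # update parameters using those gradients
    return loss.item()


# Stand-in demo (assumed, not the article's model): 4 samples, 10 features,
# 2 classes (Normal / Pneumonia).
torch.manual_seed(0)
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = (torch.randn(4, 10), torch.tensor([0, 1, 1, 0]))
print(training_step(model, batch, optimizer))
```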
As promised, let’s briefly discuss transfer learning. Extensive research in deep learning has led to novel architectures that perform well on tasks such as classification, object detection and segmentation. Although we can design and train our own architecture, with a modest dataset and limited training time it rarely matches these models. An alternative is to use a model that was pre-trained on another, larger dataset: we take that model’s architecture and weights and modify or fine-tune the final layers for our needs. This is known as ‘Transfer Learning’. You can read more about this here.
I chose the VGG16 model and built upon it to obtain the classifier required for this task. Note that I also move the model to the GPU. Always make sure that the data and the model are on the same device (CPU or GPU); otherwise PyTorch raises an error during the forward pass.
Now that we have both the model and the data, we create two more functions: one to train the model and another to evaluate it. The fit function trains the model with the specified optimizer and learning rate for the given number of epochs.
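A minimal sketch of the two functions, assuming batches of (images, labels) that are already on the model’s device. Only parameters with requires_grad=True are handed to the optimizer, so frozen layers stay fixed. The signatures are my own, not necessarily the notebook’s.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def evaluate(model, loader):
    """Average cross-entropy loss and accuracy of `model` over `loader`."""
    model.eval()  # disables dropout
    total_loss, correct, count = 0.0, 0, 0
    for images, labels in loader:
        logits = model(images)
        total_loss += F.cross_entropy(logits, labels, reduction="sum").item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        count += labels.size(0)
    return {"val_loss": total_loss / count, "val_acc": correct / count}


def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.Adam):
    """Train for `epochs` epochs, evaluating on `val_loader` after each one."""
    params = [p for p in model.parameters() if p.requires_grad]  # skip frozen layers
    optimizer = opt_func(params, lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        batch_losses = []
        for images, labels in train_loader:
            loss = F.cross_entropy(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        result = evaluate(model, val_loader)
        result["train_loss"] = sum(batch_losses) / len(batch_losses)
        history.append(result)
        print(f"Epoch {epoch + 1}/{epochs}: train_loss={result['train_loss']:.4f}, "
              f"val_loss={result['val_loss']:.4f}, val_acc={result['val_acc']:.4f}")
    return history
```

The schedule described in the next paragraph would then look like history = fit(5, 1e-3, model, train_dl, val_dl) followed by history += fit(5, 1e-4, model, train_dl, val_dl). Each call creates a fresh Adam optimizer, so its internal moment estimates are reset when the learning rate changes.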
I chose the Adam optimizer for its good performance and fast convergence in practice. The model was trained for 5 epochs each with learning rates of 0.001 and 0.0001. Even with this small number of epochs, it achieved a validation accuracy of 93.5%. That’s the advantage of employing transfer learning!
What matters most for a model like this one is its performance on unseen data. That is the role of the ‘Test Set’, which is held out from training and is meant to approximate the data the model would encounter in the real world. The model we developed in this article achieved an accuracy of around 82% on the test set.
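To measure test accuracy and find misdiagnosed images, it helps to collect all predictions first. A sketch, with predict, confusion_matrix and misclassified as illustrative helper names; the indices returned by misclassified line up with the dataset only if the test loader is not shuffled.

```python
import torch


@torch.no_grad()
def predict(model, loader):
    """Collect predicted and true labels for every sample in `loader`, in order."""
    model.eval()
    preds, labels = [], []
    for images, targets in loader:
        preds.append(model(images).argmax(dim=1))
        labels.append(targets)
    return torch.cat(preds), torch.cat(labels)


def confusion_matrix(preds, labels, num_classes=2):
    """Rows are true classes, columns are predicted classes."""
    idx = labels * num_classes + preds  # flatten (true, predicted) pairs into one index
    counts = torch.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def misclassified(preds, labels):
    """Dataset indices where the prediction is wrong (valid for an unshuffled loader)."""
    return (preds != labels).nonzero(as_tuple=True)[0].tolist()
```

Accuracy is then (preds == labels).float().mean().item(). For a medical task the confusion matrix is worth reading too: with ImageFolder’s alphabetical class order (NORMAL = 0, PNEUMONIA = 1), the off-diagonal cells count normal X-rays flagged as pneumonia and pneumonia cases that were missed.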
An example where the model made an incorrect diagnosis:
Though we obtained a decent level of accuracy, it is still not enough to deploy the model in a real-world clinical setting. We could fine-tune the model further or try other pre-trained architectures to obtain better performance.
Final Thoughts:
- Before using a pre-trained model, I tried a custom architecture of my own. It gave a training accuracy of 83% and a test accuracy of 71%. Applying transfer learning improved the results considerably.
- You can try out the ResNet series of architectures and see how the results fare in comparison.
The notebook for this article can be found in this GitHub repository of mine. Feel free to fork/clone it to further improve the results yourselves.
I hope you found this article helpful and I’m glad to be a part of your journey :)
Feel free to comment and let me know what you think. You can also connect with me on LinkedIn.