Classifying COVID-19 X-ray images

gyiernahfufieland · Published in Analytics Vidhya · Sep 9, 2021
Chest X-ray image

Can you guess whether the X-ray image above belongs to someone with or without COVID-19? To be honest, I have no idea. But today, let’s create a CNN model that can classify these X-ray images!

Let’s get started!

The dataset used here is the COVID-19 Radiography Database (Rahman, 2021), and I referred to an existing project while building this one.

The dataset contains 4 classes: Normal, Viral Pneumonia, COVID, and Lung Opacity. It was split into training, validation, and test sets, with 60 images from each class used for the validation set and another 60 images from each class used for the test set.
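As a rough illustration of this split, here is a minimal sketch in plain Python. It assumes the image paths have already been grouped by class; `split_per_class` and its `seed` argument are hypothetical names, not taken from the original project.

```python
import random


def split_per_class(files_by_class, n_holdout=60, seed=0):
    """Split {class_name: [file, ...]} into train/val/test dictionaries.

    n_holdout files per class go to validation, another n_holdout go to test,
    and everything that remains goes to training."""
    rng = random.Random(seed)             # local RNG so the split is reproducible
    splits = {"train": {}, "val": {}, "test": {}}
    for cls, files in files_by_class.items():
        files = sorted(files)             # copy in a fixed order before shuffling
        rng.shuffle(files)
        splits["val"][cls] = files[:n_holdout]
        splits["test"][cls] = files[n_holdout:2 * n_holdout]
        splits["train"][cls] = files[2 * n_holdout:]
    return splits
```

Because a fixed number of images per class is held out, the validation and test sets are balanced by construction, while the training set keeps the class imbalance of the original data.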

Image augmentation was then applied to the resulting dataset. All images were resized to 224 x 224, the input size expected by the pre-trained models. In addition, random images were flipped horizontally. Finally, since models pre-trained on ImageNet are used, the images were normalized with the ImageNet mean and standard deviation.

Image Augmentation

The augmentation is applied by loading the dataset with PyTorch’s ImageFolder class (from torchvision), which takes the transforms as an argument. A data loader is also defined here, with a batch size of 20, 2 worker processes, and shuffle=True.

Image Loader

Before we proceed, let’s further explore our dataset.

Data Exploration

Number of classes in the dataset
Number of images in each dataset

In the previous section, we described the augmentation applied to the dataset. Here, we plot some augmented images to visualize it.
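Because the tensors are normalized, they have to be mapped back to [0, 1] and moved to channels-last order before matplotlib can display them. A minimal helper, assuming the ImageNet statistics mentioned above; `denormalize` is a hypothetical name and not necessarily part of the original plotting code.

```python
import torch

# ImageNet statistics shaped (3, 1, 1) so they broadcast over a (3, H, W) image.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def denormalize(img):
    """Undo ImageNet normalization on a (3, H, W) tensor.

    Returns an (H, W, 3) NumPy array clipped to [0, 1], ready for plt.imshow."""
    img = img * STD + MEAN
    return img.clamp(0, 1).permute(1, 2, 0).numpy()
```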

Plot augmented images
Augmented images

To explore the dataset further, we create a frequency count function that counts the number of images in each class.
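ImageFolder stores the integer label of every image in its `targets` attribute and the class names in `classes`, so the counts can be computed without loading any images. A minimal sketch; `frequency_count` is a hypothetical name and may differ from the original function.

```python
from collections import Counter


def frequency_count(targets, classes):
    """Map each class name to its number of images, given integer labels
    such as ImageFolder.targets and the matching ImageFolder.classes."""
    counts = Counter(targets)
    # Classes with no images still appear, with a count of 0.
    return {name: counts.get(idx, 0) for idx, name in enumerate(classes)}
```

Usage, for a dataset built with ImageFolder: `frequency_count(train_dataset.targets, train_dataset.classes)`.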

Calculating the number of images in each class

As the result below shows, the validation and test datasets contain an equal number of images from each class, whereas the training dataset is imbalanced.

Frequency count result
Number of Chest X-Rays vs Classes

Next, let’s create some helper functions for the steps that follow.

The following function counts the total number of correctly predicted labels.
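A common way to write this in PyTorch, assuming `preds` holds raw logits of shape (batch, num_classes) and `labels` holds integer class indices; the name `get_num_correct` is hypothetical.

```python
import torch


def get_num_correct(preds, labels):
    """Number of rows whose highest-scoring class equals the true label."""
    return preds.argmax(dim=1).eq(labels).sum().item()
```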

The following function retrieves predictions for all samples. It iterates over the batches coming from the data loader and concatenates the output of each batch into a single prediction tensor, all_preds, which is returned to the caller.
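A sketch of such a function, assuming the loader yields `(images, labels)` pairs and the model is already on `device`; `get_all_preds` is a hypothetical name. Collecting the batch outputs in a list and concatenating once at the end gives the same `all_preds` tensor as concatenating after every batch, while avoiding repeated copies.

```python
import torch


@torch.no_grad()  # inference only, so no gradient tracking is needed
def get_all_preds(model, loader, device="cpu"):
    """Run the model over every batch and return all outputs as one (N, num_classes) tensor."""
    model.eval()  # note: the model is left in evaluation mode afterwards
    batch_preds = []
    for images, _ in loader:
        batch_preds.append(model(images.to(device)))
    all_preds = torch.cat(batch_preds, dim=0)
    return all_preds
```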

This function generates the confusion matrix from the true labels and the predictions.
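One vectorized way to build the matrix, assuming integer labels and logit predictions like those gathered from the data loader. The `bincount` encoding is a swapped-in technique and `get_confmat` is a hypothetical name; the original may build the matrix differently.

```python
import torch


def get_confmat(labels, preds, num_classes):
    """Confusion matrix with rows = true class and columns = predicted class."""
    pred_labels = preds.argmax(dim=1)
    # Encode each (true, predicted) pair as one index in [0, num_classes**2) and count them.
    idx = labels * num_classes + pred_labels
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)
```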

The following function calculates the true positives, true negatives, false positives, and false negatives for each class from the confusion matrix. From these, evaluation metrics such as accuracy, recall, precision, and F1-score can be derived.
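With rows as true classes and columns as predicted classes, a class's TP is its diagonal entry, FP is the rest of its column, FN is the rest of its row, and TN is everything else. A minimal NumPy sketch under that assumption; `per_class_metrics` is a hypothetical name, and the accuracy it reports is the one-vs-rest accuracy of each class.

```python
import numpy as np


def per_class_metrics(cm):
    """Per-class TP, FP, FN, TN, accuracy, precision, recall and F1 from a confusion matrix."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp   # predicted as this class, but actually another class
    fn = cm.sum(axis=1) - tp   # actually this class, but predicted as another class
    tn = cm.sum() - tp - fp - fn
    with np.errstate(divide="ignore", invalid="ignore"):  # 0/0 cases are replaced by 0 below
        accuracy = (tp + tn) / cm.sum()
        precision = np.where(tp + fp > 0, tp / (tp + fp), 0.0)
        recall = np.where(tp + fn > 0, tp / (tp + fn), 0.0)
        f1 = np.where(precision + recall > 0,
                      2 * precision * recall / (precision + recall), 0.0)
    return {"tp": tp, "fp": fp, "fn": fn, "tn": tn, "accuracy": accuracy,
            "precision": precision, "recall": recall, "f1": f1}
```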

The following function is used to train the pre-trained models. model.train() puts the model in training mode during the training stage, and model.eval() puts it in evaluation mode during the evaluation phase. Setting the mode is necessary because layers such as batch normalization behave differently during training and evaluation. In the training phase, zero_grad() is called to zero out the gradients before backpropagation, because PyTorch accumulates gradients on subsequent backward passes, i.e. on each call to loss.backward() (Anon, 2021). In the evaluation phase, torch.no_grad() disables gradient tracking, which reduces memory usage and therefore speeds up computation.
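The training and evaluation phases described above can be sketched as two functions. This is a minimal sketch, assuming the loader yields `(images, labels)` batches and `criterion` is a classification loss such as `nn.CrossEntropyLoss`; the function names are hypothetical, and the loss and accuracy bookkeeping may differ from the original code.

```python
import torch


def train_one_epoch(model, loader, criterion, optimizer, device="cpu"):
    """One pass over the training data; returns (mean loss, accuracy)."""
    model.train()                        # batch norm (and dropout) use training behavior
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()            # clear gradients left over from the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()                  # new gradients are accumulated into .grad
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += images.size(0)
    return total_loss / seen, correct / seen


@torch.no_grad()                         # no autograd graph: less memory, faster
def evaluate(model, loader, criterion, device="cpu"):
    """One pass over validation/test data without updating weights; returns (mean loss, accuracy)."""
    model.eval()                         # batch norm uses its running statistics
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        total_loss += criterion(outputs, labels).item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += images.size(0)
    return total_loss / seen, correct / seen
```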

Because training can take a long time, a checkpoint is defined that saves the model whenever the validation loss decreases. In addition, the running results during training are saved for plotting later.
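The checkpoint logic can live in a small helper. A sketch under stated assumptions: `Checkpointer`, the default file name and the history keys are hypothetical; `torch.save(model.state_dict(), path)` is the standard way to save PyTorch weights.

```python
import math

import torch


class Checkpointer:
    """Record per-epoch results and save the weights whenever validation loss drops."""

    def __init__(self, path="best_model.pt"):
        self.path = path
        self.best_val_loss = math.inf
        self.history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    def update(self, model, train_loss, train_acc, val_loss, val_acc):
        """Append this epoch's results; return True if a new checkpoint was written."""
        for key, value in zip(self.history, (train_loss, train_acc, val_loss, val_acc)):
            self.history[key].append(value)
        if val_loss < self.best_val_loss:   # validation loss decreased: save a checkpoint
            self.best_val_loss = val_loss
            torch.save(model.state_dict(), self.path)
            return True
        return False
```

The saved `history` dictionary is what a plotting function can later read to draw the loss and accuracy curves.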

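When a DataLoader uses several worker processes, np.random can generate the same random numbers for each data batch, because each worker may start from the same NumPy random state (Anon, 2018). To ensure randomness for each data batch, a worker init function is defined that seeds each worker differently.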
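The article follows the fix discussed in the cited GitHub issue. The sketch below instead uses the `seed_worker` pattern from PyTorch's reproducibility notes, which derives each worker's NumPy seed from `torch.initial_seed()`; treat it as one possible implementation, not the author's exact function.

```python
import random

import numpy as np
import torch


def seed_worker(worker_id):
    """Seed NumPy and Python's random module inside a DataLoader worker.

    In each worker, torch.initial_seed() already differs (base seed + worker_id),
    so deriving the NumPy seed from it gives every worker its own random stream."""
    worker_seed = torch.initial_seed() % 2**32   # NumPy seeds must fit in 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```

It is passed to the loader as `DataLoader(dataset, batch_size=20, num_workers=2, worker_init_fn=seed_worker)`.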

Alright! We are now ready to train the models!

Training the models.

WeightedRandomSampler is used because the training dataset is imbalanced. By weighting each sample inversely to the size of its class, it gives every class roughly equal probability of being drawn. The class weights are defined first, and the sample weights are then created from them.
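A sketch of the class-weight and sample-weight construction, assuming `targets` is the list of integer training labels (for example `ImageFolder.targets`); `make_balanced_sampler` is a hypothetical name. Weighting each sample by 1 / (size of its class) gives every class the same total weight, so each class is drawn with roughly equal probability.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(targets):
    """Sampler that draws each class with roughly equal probability."""
    counts = Counter(targets)
    class_weights = {cls: 1.0 / n for cls, n in counts.items()}  # rarer class -> larger weight
    sample_weights = torch.tensor([class_weights[t] for t in targets], dtype=torch.double)
    # Draw as many samples per epoch as there are training images, with replacement.
    return WeightedRandomSampler(sample_weights, num_samples=len(targets), replacement=True)
```

The sampler replaces shuffling: DataLoader raises an error if both `sampler` and `shuffle=True` are passed, so the training loader is built with `sampler=...` and `shuffle` left at `False`.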

Each model is trained for 25 epochs with cross-entropy as the loss function. nn.CrossEntropyLoss computes the categorical cross-entropy loss for multi-class classification.

Because these pre-trained models are fine-tuned, the final fully connected layer, which maps the extracted features to class scores, has to be replaced. The models were trained on ImageNet, so this layer originally outputs 1,000 classes; for this project, out_features is set to 4.

The models are then trained with the Adam optimizer and a learning rate of 0.0001. The same steps are repeated for all 4 models.
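The loss and optimizer settings from the text, in code. A minimal sketch, assuming all layers are fine-tuned (every parameter is passed to Adam) rather than frozen; `make_training_setup` and the constant names are hypothetical.

```python
import torch
from torch import nn

NUM_EPOCHS = 25        # epochs per model, as stated in the text
LEARNING_RATE = 1e-4   # 0.0001


def make_training_setup(model):
    """Return the loss function and optimizer described in the text."""
    # CrossEntropyLoss takes raw logits and integer class labels and applies
    # log-softmax internally, so the model needs no final softmax layer.
    criterion = nn.CrossEntropyLoss()
    # Assumption: every parameter is fine-tuned (no layers frozen).
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    return criterion, optimizer
```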

It took me about 30 minutes to 1 hour to train each model on a Colab GPU. Densenet121 took the longest of the 4 models.

To improve the results further, let’s tune the learning rate.

Tuning and Validation

Based on the results of the models above, Densenet121 is the best-performing model, as the results section below shows. To improve the models further, a learning rate scheduler was added that decays the learning rate of each parameter group by a factor (gamma) of 0.1 every 3 epochs (3 × 1029 iterations per epoch = 3087 iterations).
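Because the step size is given in iterations, the scheduler's `step()` appears to be called once per batch rather than once per epoch. A sketch with `torch.optim.lr_scheduler.StepLR` under that per-batch assumption; `make_scheduler` is a hypothetical helper.

```python
import torch
from torch.optim.lr_scheduler import StepLR


def make_scheduler(optimizer, iters_per_epoch, epochs_per_step=3, gamma=0.1):
    """Multiply the learning rate by gamma every `epochs_per_step` epochs.

    Assumes scheduler.step() is called after every batch, so step_size is in iterations."""
    return StepLR(optimizer, step_size=epochs_per_step * iters_per_epoch, gamma=gamma)
```

In the training loop, `optimizer.step()` is followed by `scheduler.step()` for each batch, so the learning rate goes from 1e-4 to 1e-5 after 3 epochs, to 1e-6 after 6 epochs, and so on.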

Learning Rate Scheduler for densenet121
Learning Rate Scheduler for resnet18

The plot_summary function plots the accuracy and loss for the training and validation datasets.

The plot_confmat function plots the confusion matrix of predicted vs. actual labels for the training and test datasets.

Results Evaluation

Densenet121

The lowest validation loss, 0.064216, occurs at epoch 6. As the number of epochs increases, the validation loss rises, showing that the model starts to overfit.

Densenet121 confusion matrix

Densenet121 correctly classifies all 60 COVID and all 60 Viral Pneumonia test images.

Densenet121 with Learning Rate Scheduler

For Densenet121 with the learning rate scheduler, the validation loss reaches its lowest value, 0.065519, at epoch 13. Although this is slightly higher than the lowest validation loss of Densenet121 without the scheduler, this model generalizes better than the previous model without the learning rate scheduler.

Densenet121 with lr_scheduler confusion matrix

Densenet121 with lr_scheduler correctly classifies all 60 Viral Pneumonia test images.

Running loss and accuracy for Resnet18

The validation loss is lowest at epoch 6, at 0.064792. As the number of epochs increases, overfitting can be observed.

Resnet18 confusion matrix

Resnet18 performs well in classifying Viral Pneumonia images.

Running loss and accuracy for Resnet18 with lr_scheduler

For Resnet18 with the learning rate scheduler, the validation loss reaches its lowest value, 0.069836, at epoch 19. Although this is slightly higher than the lowest validation loss of Resnet18 without the scheduler, this model generalizes better than the previous model without the learning rate scheduler.

Confusion Matrix for Resnet18 with lr_scheduler

Resnet18 with lr_scheduler performs well in classifying COVID and Viral Pneumonia images.

Running loss and accuracy for Squeezenet

The lowest validation loss for Squeezenet is 0.117238, at epoch 24.

Squeezenet confusion matrix

Running loss and accuracy for Resnet50

The lowest validation loss for Resnet50 is 0.078578 at epoch 11.

Resnet50 confusion matrix

Resnet50 performs well in classifying Viral Pneumonia images.

Visualization and Critical Analysis

To better understand which parts of an image influence the model’s classification decision, Grad-CAM is implemented.

apply_mask is a function that creates a heatmap from the Grad-CAM mask and blends it with the input image to synthesize the Grad-CAM images.
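A minimal sketch of such a function, assuming the image and mask are NumPy arrays scaled to [0, 1]. The jet colormap and the 50/50 blend are assumptions, and this `apply_mask` is not necessarily identical to the original; `matplotlib.colormaps` requires Matplotlib 3.5 or later.

```python
import matplotlib
import numpy as np


def apply_mask(image, mask, alpha=0.5):
    """Turn a Grad-CAM mask into a heatmap and blend it with the image.

    image: float array (H, W, 3) in [0, 1]; mask: float array (H, W) in [0, 1].
    Returns (heatmap, overlay), both float arrays (H, W, 3) in [0, 1]."""
    heatmap = matplotlib.colormaps["jet"](mask)[..., :3]   # RGBA -> RGB
    overlay = (1 - alpha) * image + alpha * heatmap
    return heatmap, np.clip(overlay, 0, 1)
```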

plot_gradcam then plots the generated Grad-CAM images for each class.

A forward hook and a backward hook are registered on the target layer to store the activations from the forward pass and the gradients from the backward pass at that layer.

The images below show which regions of the X-rays each model relies on when making its classification decision.

Localization with GradCAM
Results in evaluating COVID-19 images
Results in evaluating Lung Opacity images
Results in evaluating Normal images
Results in evaluating Viral Pneumonia images
Summary performance for all models

In this section, the performance of each model is evaluated further. Based on evaluation metrics such as accuracy, recall, precision, and F1-score, Densenet121 and Resnet18 with lr_scheduler performed well in classifying COVID-19 images. For Lung Opacity and Normal images, Densenet121 with lr_scheduler outperforms the other models. Finally, several models perform well in classifying Viral Pneumonia images: Densenet121 with lr_scheduler, Resnet18, Resnet18 with lr_scheduler, and Resnet50.

The full project can be found here.

References

Anon (2018). np.random generates the same random numbers for each data batch #5059. [Online]. GitHub. Available from: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562. [Accessed: 26 August 2021].

Anon (2021). Zeroing Out Gradients in PyTorch. [Online]. PyTorch. Available from: https://pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html. [Accessed: 26 August 2021].

Inkawhich, N. (n.d.). Finetuning Torchvision Models. [Online]. PyTorch. Available from: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html.

Minaee, S., Kafieh, R., Sonka, M., Yazdani, S. and Jamalipour Soufi, G. (2020). Deep-COVID: Predicting COVID-19 from chest X-ray images using deep transfer learning. Medical Image Analysis, 65.

Misra, P. (2021). Classification and Gradient-based Localization of Chest Radiographs. [Online]. GitHub. Available from: https://github.com/priyavrat-misra/xrays-and-gradcam.

Rahman, T. (2021). COVID-19 Radiography Database. [Online]. Kaggle. Available from: https://www.kaggle.com/tawsifurrahman/covid19-radiography-database.
