Medical Image Classification and Segmentation — A Case Study Approach

Namrata Thakur
Published in Analytics Vidhya · May 15, 2021 · 21 min read


Using AI to predict and eventually prevent the disaster of lung collapse

Credit: https://www.medicalnewstoday.com/articles/318110

Artificial Intelligence is impacting nearly every business. The enormous investments pouring into this domain, and the pace at which it is advancing, prove one thing: AI is here to stay, and it is already disrupting the way every industry works.

The medical field is no exception. Many applications now in use have some AI component built in. Primarily, we see RPA (Robotic Process Automation), chatbots, and various kinds of healthcare analytics as AI use cases in this field. Another very important application of AI is medical image classification and segmentation.

Most applications of AI in medicine take some type of data as input, either numerical (such as heart rate or blood pressure) or image-based (the kind we will discuss shortly). The algorithms learn from the data and output either a probability or a classification, followed by segmentation of the affected region. For example, the actionable result could be the probability of an arterial clot given heart-rate and blood-pressure data, or labeling an imaged tissue sample as cancerous or non-cancerous.

In this blog, we are going to discuss a case study, ‘SIIM-ACR Pneumothorax Segmentation’, which involves detecting the disease from chest X-rays.

Contents:

  1. Business Problem
  2. Mapping the real-world problem as a Deep Learning problem
  3. Data set Analysis
  4. Real-World Business Constraints
  5. Performance Metrics
  6. Existing Approaches
  7. My first cut approach
  8. EDA
  9. Pre-processing
  10. Modeling
  11. Final Pipeline
  12. Deployment
  13. Future Work
  14. References
  15. Github Repository link
  16. Linkedin profile
1. Business Problem:

Before we proceed with our solution, let’s first understand this clinical condition. Normally, the lungs touch the walls of the chest. But sometimes air accumulates in the space between the chest wall and the lungs, i.e. in the pleural space. This air starts to pressurize the lung, and gradually a portion of it or (sometimes) the entire lung may collapse. This medical condition is known as Pneumothorax, a combination of the two words pneumo (air) and thorax (chest). It is also known as lung collapse.

A few different things can cause pneumothorax, and symptoms can vary widely. The causes of pneumothorax are categorized as either primary spontaneous, secondary spontaneous, or traumatic.

A primary spontaneous pneumothorax (PSP) occurs when the person has no known history of lung disease. The direct cause of PSP is unknown.

Secondary spontaneous pneumothorax (SSP) can be caused by a variety of lung diseases and disorders. SSP carries more serious symptoms than PSP, and it is more likely to cause death.

A traumatic pneumothorax is the result of an impact or injury. Potential causes include blunt chest injury or an injury that damages the chest wall and pleural space.

Pneumothorax is usually diagnosed by a radiologist on a chest x-ray, and can sometimes be very difficult to confirm. This is because symptoms of pneumothorax may hardly be noticeable at first and can be confused with other disorders. An accurate AI algorithm to detect pneumothorax would be useful in a lot of clinical scenarios. AI could be used to triage chest radiographs for priority interpretation or to provide a more confident diagnosis for non-radiologists.

2. Mapping the real-world problem as a Deep Learning problem:

The approach we use in this case study will first detect the presence of the disease in the input X-ray. If the condition is present, it will then highlight the affected portion. So, in the first phase, we do a Binary Classification (pneumothorax present or not). Depending on the result of the classification model, in the second phase we highlight the affected portion of the image, i.e. we apply the Image Segmentation technique.

Before going forward let’s understand the topic of Image Segmentation a bit more.

What is Image Segmentation?

Image segmentation is a way of classifying or segmenting different elements of an image into different classes. Though it sounds like object detection, it is actually more detailed. In object detection, we draw a bounding box around each object in the image; in image segmentation, we classify every pixel of the image into a class. So image segmentation gives us a much more fine-grained description of the image.

Types of Image Segmentation:

Credit: https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/108126

1. Semantic Segmentation: it marks all instances of an object class in the same color. For example, in the picture above, all the persons are segmented into the same class and hence share the same color segment.

2. Instance Segmentation: it identifies each instance of an object in an image separately. For example, in the picture above, each person is segmented as a separate instance.

In our case study, we will be using the semantic segmentation type.

3. Data set Analysis:

The dataset is taken from the Kaggle competition page.

The data comprises chest images in DICOM format, with pneumothorax annotated as run-length-encoded (RLE) binary masks. Some training images have multiple annotations, indicating multiple affected regions. Images without pneumothorax have a mask value of -1. The task is to predict the pneumothorax mask for a given X-ray image.

a) Files given:

train-rle.csv, stage_2_sample_submission.csv (test_data), train_images, test_images.

b) Total File Size : 4GB

c) Total number of records: 12,954 (train_data), 3204 (test_data)

d) The train-rle.csv contains image IDs and their corresponding RLE masks and the test CSV file only contains the image IDs.

4. Real-World Business Constraints:

a) Low latency is important.

b) The cost of misclassification/mis-segmentation is considerably high: we are dealing with medical data, which is very sensitive to such errors.

5. Performance Metrics:

A) Segmentation Part:

1. Dice Coefficient:

The Dice coefficient (closely related to, but not the same as, Intersection over Union/IoU) can be used to compare the pixel-wise agreement between a predicted segmentation and its corresponding ground truth. The formula is:

Dice(X, Y) = 2 · |X ∩ Y| / (|X| + |Y|)

(Credit: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient)

where X is the predicted set of pixels and Y is the ground truth. The Dice coefficient is defined to be 1 when both X and Y are empty.

2. Combo Loss (Binary Cross-Entropy + Dice Loss / F1 loss):

For training segmentation models, researchers have found that combining Binary Cross-Entropy with Dice Loss makes a very effective loss function. This combo loss is especially helpful for problems with imbalanced datasets.

B) Classification Part:

  1. False Negative (FN):

Since we are dealing with medical data, we have to minimize False Negatives: the situation where the patient has the disease (i.e. true label = positive) but the model predicts him to be safe (i.e. predicted label = negative) can prove disastrous.

2. Precision or Positive Predictive Value (PPV) :

Positive predictive value or Precision is the probability that people with a positive prediction truly have the disease.

So, a PPV of 0.71 means that if a person is predicted to have pneumothorax, there is a 71% probability that he actually has the disease.

3. True Positive Rate (TPR) or Sensitivity or Recall :

Sensitivity or True Positive Rate is the probability that the model outputs positive given that the case is actually positive. It shows what portion of the positive class got correctly classified.

We want the majority of the positive class to be correctly predicted; hence, we need a high Sensitivity (TPR).

6. Existing Approaches:

A) U-Net to Predict Pneumothorax:

This kernel uses the vanilla UNET structure for the segmentation problem. It has 6 encoder layers with filters in a 64-128-256-256-256-256 pattern and 5 decoder layers with filters in a 512-512-256-128-64 pattern. The author found that the combo loss was not performing well, so he used custom metrics such as a custom IoU (Intersection over Union) as well as the MeanIoU metric from the tf.keras.metrics package. The callbacks used are ModelCheckpoint, for saving the best weights, and ReduceLROnPlateau, with a patience of 5 and a cooldown of 1. No augmentations are used in this kernel.

B) Unet Xception Keras for Pneumothorax Segmentation:

This kernel uses the UNet++ architecture with a pretrained EfficientNetB4 model as the encoder and residual blocks in the decoder. The author also used Stochastic Weight Averaging over the last 3 epochs for better convergence. The augmentations used were ElasticTransform, GridDistortion, OpticalDistortion, and a simple HorizontalFlip. He used images of size 256x256 with a batch size of 16, and callbacks like ModelCheckpoint, to save the best model, and a custom LearningRateScheduler with cosine annealing.

7. My first cut approach:

a. Since the given data are images, we cannot do much EDA on them directly. But the images are provided in .dcm format, which stores a lot of metadata. Using this metadata, I intend to dig into the images to understand the quality of the dataset.

b. Once we have extracted and analyzed the metadata, we need to convert the format from .dcm to .png for further steps.

c. Also, the annotations are provided as RLE (run-length encoding), which stores runs of data as a single value and count rather than the original sequence of pixels. We need to convert these RLEs to masks, which will be needed for the segmentation step.

d. Once we have done the above steps, our data will be ready and different classification and segmentation algorithms can be applied to it.

e. Using the logs we need to see which model is performing better both in terms of performance speed and metric value.

f. We then need to deploy the best model so that we can build an end-to-end application around it.

8. Exploratory Data Analysis:

Let’s first read the CSV file that contains the RLEs and the ImageId.
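A minimal sketch of this step, assuming the CSV sits in the working directory (we also normalize the header names, since the raw file can carry stray whitespace in them):

```python
import pandas as pd

# Read the training annotations: one row per (ImageId, RLE mask) pair.
rle_df = pd.read_csv('train-rle.csv')
rle_df.columns = ['ImageId', 'EncodedPixels']   # normalize header names

print(rle_df.shape)
print(rle_df.head())
```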


We see that there is information about the images in the form of individual ImageId and its corresponding EncodedPixels, i.e. the RLE (Run Length Encoding of the masks).

Let’s check the count of the records.

We have a total of 12,954 train records. As evident from above, some of these are duplicates: 907 records in total. After dropping the duplicates, we are left with 12,047 unique images.

The folder structure that we got after unzipping showed that each image is present in a separate folder. We need to move those images from their individual folder to a common directory that will hold all the images together. After that, let’s save the file path for each image in the .csv as a separate column. This will help us in accessing the images easily in the later stages.
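A sketch of that reorganisation; the src_pattern below is an assumed name for the unzipped layout, so adjust it to your folder structure:

```python
import glob
import os
import shutil

# Collect every DICOM from its nested per-image folder into one flat directory.
src_pattern = 'dicom-images-train/**/*.dcm'              # assumed source layout
dst_dir = 'Train Dataset/siim/train_dicom_images'
os.makedirs(dst_dir, exist_ok=True)

for path in glob.glob(src_pattern, recursive=True):
    shutil.copy(path, dst_dir)

# Store the flat path of each image as a new column for easy access later.
rle_df['imagePath'] = rle_df['ImageId'].apply(
    lambda img_id: os.path.join(dst_dir, img_id + '.dcm'))
```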

The CSV after these steps looks like this:

All the train Dicom images are moved to this common path ‘Train Dataset/siim/train_dicom_images’.

Analysis of DICOM Images :

We have files in .dcm format. This format, common in the medical imaging field, is known as DICOM (Digital Imaging and Communications in Medicine). Nearly all forms of medical imaging have become digitized nowadays, and DICOM is the file format used for storing such images (e.g. X-ray scans and CT scans) along with their metadata.

Let’s plot some of the images to see what they look like.

Let’s examine one image in detail to see the type of metadata it stores. Python has a library for working with DICOM images: ‘pydicom’. We can install it simply with pip3 install pydicom.
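A quick look at one file with pydicom; the attributes printed below are the standard DICOM keywords for the fields we analyze next:

```python
import matplotlib.pyplot as plt
import pydicom

# Read one DICOM file and dump all the metadata tags it stores.
ds = pydicom.dcmread(rle_df['imagePath'].iloc[0])
print(ds)

# The specific attributes examined in this section:
print(ds.PatientAge, ds.PatientSex, ds.Modality, ds.ViewPosition)

# The pixel data is exposed as a numpy array, which makes plotting easy.
plt.imshow(ds.pixel_array, cmap='bone')
plt.show()
```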

We learn that there is a lot of information given for each patient. We are going to examine only some of the features like ‘Age’, ‘Gender’, ‘Modality’, and ‘View Position’.


The column ‘Pneumothorax’ is created with the logic that if the value of the feature ‘EncodedPixels’ is ‘-1’, we assign ‘No’ else ‘Yes’.

Analyzing the ‘Gender’ field:

We have a majority of Male (M) records at 55%; Female (F) makes up the remaining 45%. Only these two genders are represented in the data provided.

Analyzing the target (Pneumothorax) field:

As expected, the majority of the records (77.85%) do not have any pneumothorax occurrence recorded. A small percentage (22.15%) of records have the disease.

Because this is medical data, it is quite expected that it will be heavily imbalanced as it reflects the real-world scenario where the majority of the patients who get an X-Ray don't have the mentioned disease.

Analyzing the Occurrence of Pneumothorax and Gender together:

Out of the 55% of records that are male, 77.5% are healthy and 22.5% have pneumothorax. Out of the 45% that are female, 78.2% are healthy and 21.8% have pneumothorax.

Analyzing the ViewPosition field:

We get different views of the chest by changing the orientation of the patient’s body and the direction of the x-ray beam. In the dataset, we see two positions — PA (posteroanterior) and AP (anteroposterior). In a PA view, the x-ray beam enters through the posterior (back) aspect of the chest and exits out of the anterior (front) aspect, where the beam is detected. In AP view, the positions of the x-ray source and detector are reversed. AP chest x-rays are harder to read than PA x-rays and are therefore generally reserved for situations where it is difficult for the patient to get an ordinary chest x-ray, such as when the patient is bedridden. We see a similar trend in the above graph where a significant majority of the records (60.38%) have PA with only 39.62% of records having AP.

Analyzing the ViewPosition and Occurrence of Pneumothorax as parameters together :

To summarise the plot:

There are 60.38% of records with PA as the ViewPosition and 39.62% with AP. Of all the PA records, 76.7% are healthy with no pneumothorax detected; the remaining 23.3% have the disease. Of all the AP records, 79.6% are healthy; the remaining 20.4% have the disease.

Analyzing the Age field:

We see that the age distribution peaks at 58, which has the maximum count; the most common age bracket is 50-60. We also see 1 record with an age of 413, which is clearly an erroneous outlier.

9. Pre-processing:

Under the data pre-processing portion, we will do two things:

  1. Convert DICOM images to PNG
  2. Create Masks from RLEs given

Let's start with the first task:

  1. Convert DICOM images to PNG:

Since we cannot use .dcm images directly in the model-building process, we need to convert them to .png.

In this step, we also reduce the image size from 1024x1024 to a smaller (and hence more manageable, given our limited computing resources) 256x256. Along with this, we change the format from .dcm to .png and store the new, smaller PNG images in a different folder, as sketched below.
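A sketch of the conversion, assuming OpenCV for resizing and writing (any image library would do):

```python
import os
import cv2
import pydicom

png_dir = 'Train Dataset/siim/train_png_images'
os.makedirs(png_dir, exist_ok=True)

def dicom_to_png(dcm_path, png_path, size=256):
    """Read a DICOM, downscale 1024x1024 -> size x size, save as PNG."""
    pixels = pydicom.dcmread(dcm_path).pixel_array
    cv2.imwrite(png_path, cv2.resize(pixels, (size, size)))

for _, row in rle_df.drop_duplicates('ImageId').iterrows():
    dicom_to_png(row['imagePath'],
                 os.path.join(png_dir, row['ImageId'] + '.png'))

# From here on, imagePath points at the PNGs instead of the DICOMs.
rle_df['imagePath'] = rle_df['ImageId'].apply(
    lambda img_id: os.path.join(png_dir, img_id + '.png'))
```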

2. Create Masks from RLEs given:

We have the mask data in the form of run-length-encoded pixels, which we need to convert to .png format so that they can serve as the ground truth for each PNG image. The organizers have provided a function for decoding the RLEs into pixel masks; a sketch of the idea follows.
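A sketch along the lines of the organizers' helper. Note that this competition's RLE stores relative start offsets (each start is counted from the end of the previous run), and the decoded mask may need a transpose when overlaying it on the image:

```python
import numpy as np

def rle2mask(rle, width=1024, height=1024):
    """Decode a run-length-encoded string into a binary mask."""
    mask = np.zeros(width * height, dtype=np.uint8)
    array = np.asarray([int(x) for x in rle.split()])
    starts, lengths = array[0::2], array[1::2]

    current_position = 0
    for start, length in zip(starts, lengths):
        current_position += start            # relative offset
        mask[current_position:current_position + length] = 255
        current_position += length

    return mask.reshape(width, height)

def rle_to_mask(rle):
    """'-1' means no pneumothorax: return an all-zero mask."""
    if rle.strip() == '-1':
        return np.zeros((1024, 1024), dtype=np.uint8)
    return rle2mask(rle)
```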

Now, let’s visualize one PNG image with the ground truth mask:

Sample X-ray with the mask

Given the chest X-ray, our model should produce the masked image (the rightmost one).

10. Modeling:

Note: The objective of this study is not to achieve great accuracy or outperform any model, but to explore the hows and whys of the model behavior and also to experiment with different architectures.

As mentioned in Section 2, we first need to classify the images as pneumothorax or not; based on that result, we proceed with segmentation. So we start by building a Binary Classification model.

A) PART 1 — Pneumothorax Classification :

First, we classify the chest X-rays as either No Pneumothorax Present (label 0) or Pneumothorax Present (label 1). This is the image classification part, where we apply a transfer learning technique using the pre-trained CheXNet model (a 121-layer DenseNet fine-tuned on chest X-ray images) to classify the images.

To do this binary classification task, we need the ground truth as binary labels. Currently, we have the ground truths as either RLEs (as given) or masks (as converted above), so we need to create binary labels from them. We observe that for images without any pneumothorax condition, the ‘EncodedPixels’ column has the value ‘-1’; otherwise it holds the actual RLE. Using this knowledge, we assign label 1 to images whose ‘EncodedPixels’ is not ‘-1’, and label 0 otherwise.

With this, our data is now in the right format for a classification model. Let’s split it into Train and CV sets.
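A sketch of the label creation and the split; the stratification and the 90/10 ratio are assumptions consistent with the dataset sizes reported below:

```python
from sklearn.model_selection import train_test_split

# Binary label: 1 if the image has a mask (EncodedPixels != '-1'), else 0.
rle_df['Label'] = (rle_df['EncodedPixels'].str.strip() != '-1').astype(int)

df = rle_df.drop_duplicates('ImageId')          # 12,047 unique images
train_df, valid_df = train_test_split(
    df, test_size=0.1, stratify=df['Label'], random_state=42)

print(len(train_df), len(valid_df))
```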

imagePath contains the full location path of individual images in the format Train Dataset/siim/train_png_images/imageID.

Total Train Dataset Size: 10842 & Total Valid Dataset Size: 1205.

Class Distribution in Train Dataset

Handling Class Imbalance:

We have imbalanced data: the negative class count far exceeds that of the positive class. This is pretty common in the medical domain. We need to handle this imbalance so that the model doesn't become too biased toward the majority class. To make sure that both classes contribute to the loss equally, we give a weight to each class. This weight is inversely proportional to the frequency of the class, which we calculate as count_of_class / total_records_present.

Class Frequencies

As expected, the positive class has far fewer records, and hence its frequency is significantly lower than that of the negative class. With this, we have assigned a weight to each class. Just to confirm, let’s check the contribution of each class to the loss, calculated as the product class_weight x class_frequency; it should come out equal for both classes.
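A sketch of the weighting scheme described above; with the constant 0.5, the product class_weight x class_frequency is equal for both classes:

```python
# Frequency of each class in the training set.
counts = train_df['Label'].value_counts()
total = len(train_df)
freq = {0: counts[0] / total, 1: counts[1] / total}

# Weight inversely proportional to frequency.
class_weight = {0: 0.5 / freq[0], 1: 0.5 / freq[1]}

# Sanity check: each class now contributes equally to the loss.
print(class_weight[0] * freq[0], class_weight[1] * freq[1])   # 0.5, 0.5

# Later passed to model.fit(..., class_weight=class_weight).
```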

Now, our model will not be biased due to the class imbalance.

Handling Augmentations:

Data augmentation increases the amount of training data by adding slightly modified copies of existing images. It acts as a regularizer and helps reduce overfitting. The augmentations we use here are random_flip_left_right, random_contrast, random_brightness, random_saturation, random_hue, and adjust_gamma from the tf.image module. We also apply them randomly, so every image receives a random subset of augmentations, as in the sketch below.
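A sketch of such a pipeline, assuming float32 images in [0, 1] with 3 channels (random_saturation and random_hue need 3 channels); here each augmentation fires with probability 0.5:

```python
import tensorflow as tf

def maybe(image, fn, p=0.5):
    # Apply fn with probability p; tf.cond keeps this graph-safe for tf.data.
    return tf.cond(tf.random.uniform([]) < p,
                   lambda: fn(image), lambda: image)

def augment(image):
    image = tf.image.random_flip_left_right(image)
    image = maybe(image, lambda x: tf.image.random_contrast(x, 0.8, 1.2))
    image = maybe(image, lambda x: tf.image.random_brightness(x, 0.1))
    image = maybe(image, lambda x: tf.image.random_saturation(x, 0.8, 1.2))
    image = maybe(image, lambda x: tf.image.random_hue(x, 0.05))
    image = maybe(image, lambda x: tf.image.adjust_gamma(x, gamma=1.2))
    return tf.clip_by_value(image, 0.0, 1.0)
```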

With this, our pipeline is built for the classification model.

We are using the CheXNet model to build our classifier. It is a 121 layer DenseNet that is pre-trained on chest x-ray images (Ref: https://arxiv.org/abs/1711.05225).

Let’s understand DenseNet a bit more before building our architecture.

DenseNet is an architecture for image classification and object recognition. It is quite similar to the ResNet architecture, with one fundamental difference: ResNet merges the previous layer (identity) with the future layer additively (+), whereas DenseNet concatenates the output of the previous layer with the future layer channel-wise.

DenseNet: https://arxiv.org/pdf/1608.06993.pdf

Let’s see the model structure that we created on top of DenseNet, shown in the sketch below. We load the CheXNet weights directly into our DenseNet-121 architecture, and we also add some extra layers on top while fine-tuning the model.
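A sketch reconstructing that structure; the weights filename and the sizes of the extra Dense/Dropout layers are assumptions:

```python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import DenseNet121

# Backbone: no ImageNet weights, no 14-class CheXNet top.
base = DenseNet121(weights=None, include_top=False,
                   input_shape=(256, 256, 3))

# Hypothetical filename: load the CheXNet weights into the matching layers.
base.load_weights('chexnet_weights.h5', by_name=True)

# Extra layers added on top for fine-tuning.
x = layers.GlobalAveragePooling2D()(base.output)
x = layers.BatchNormalization()(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.5)(x)
out = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(base.input, out)
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=[tf.keras.metrics.AUC(name='auc')])
```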

As we can see, while defining DenseNet121 we gave weights = None and include_top = False.

This is because we are not using the ImageNet weights (DenseNet is originally trained on the ImageNet dataset); instead, we use the CheXNet weights, which are loaded separately.

By specifying include_top=False, we load the network without the classification layers at the top. CheXNet is trained to classify 14 classes, one of which is Pneumothorax; since we are dealing with a single binary label here, we don't need the original top layer.

In the extra layers, we added a GlobalAveragePooling2D and a series of BatchNormalization, Dense, and Dropout layers to further fine-tune the model.

We used the below callbacks:

We are using ModelCheckpoint, for saving the model with the best (maximum) validation AUC, EarlyStopping, for stopping training if the metric (val_auc) doesn't improve for 6 epochs, and ReduceLROnPlateau, for reducing the learning rate if the metric doesn't improve for 3 epochs.
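A sketch matching those settings (the checkpoint filename is an assumption):

```python
from tensorflow.keras.callbacks import (EarlyStopping, ModelCheckpoint,
                                        ReduceLROnPlateau)

callbacks = [
    ModelCheckpoint('best_classifier.h5', monitor='val_auc',
                    mode='max', save_best_only=True),
    EarlyStopping(monitor='val_auc', mode='max', patience=6),
    ReduceLROnPlateau(monitor='val_auc', mode='max', patience=3, factor=0.1),
]
```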

Now that we have trained the model let’s check the metric values to decide how the model is performing.

Evaluating Metrics Values:

AUC-ROC value:

We got an AUC Score of 0.908 on the CV Dataset.

Confusion Matrix

We got a Validation Accuracy of 86.1%.

Sensitivity (TRUE POSITIVE RATE) is the probability that the model outputs positive given that the case is actually positive. It shows what portion of the positive class got correctly classified. This is also the Recall.

Negative predictive value (NPV) is the probability that people with a negative prediction truly don’t have the disease. So, NPV of 0.914 means that if a person is predicted to not have pneumothorax there is a 91.4% probability that he doesn't actually have the disease.

Sensitivity and NPV are the most important metric values. Since we are dealing with medical data, we have to minimize the False Negative as the situation where the patient has the disease (ie true label = positive) but the model predicts him to be safe (ie predicted label = negative) can prove to be disastrous. We also need to maximize True Positive as we want the majority of the positive class to be correctly predicted. Hence, we need a high Sensitivity (TPR) and a high NPV.

Precision-Recall Curve

We got an Average Precision Score of 0.731 on the CV Dataset. The Precision-Recall curve (PRC) shows the trade-off between precision and recall. A high area under the curve represents both high recall and high precision, where high precision relates to a low false-positive rate, and high recall relates to a low false-negative rate. High scores for both show that the classifier is returning accurate results (high precision), as well as returning a majority of all positive results (high recall).

The model is performing well as indicated by the metric values.

With this, we complete our classification part.

B) PART 2 — Pneumothorax Segmentation:

In this second section, the most important one, we build the segmentation models, applying different architectures (Nested UNet, and Double UNet with a pre-trained VGG19 as the encoder backbone) to predict the masks.

  1. UNet++ : Nested UNet architecture for Medical Image Segmentation:

UNET++ has many similarities with UNET as both have the encoder-decoder architecture.

Let’s now see the difference between UNET and UNET++:

Ref: Biomedical Image Segmentation: UNet++ by Jingles (Hong Jing)

UNET++ (Nested UNET) has the following advancements over UNET:

a) It has re-designed skip pathways (shown as green)

b) It has dense skip connections (shown as blue)

c) Deep Supervision (shown as red)

Nested Unet

a) Re-designed skip pathways (shown as green):

The convolution layers that are present in the re-designed skip pathways (as shown in green) are used to bridge the semantic gap between the feature sets of the encoder and the decoder part.

Skip connections in UNET connect the encoder feature maps of a particular level to the upsampled feature maps of the corresponding decoder level. This can end up fusing semantically dissimilar feature maps. UNET++ remedies this: the output of the previous convolution layer in the same dense block is fused with the corresponding upsampled output of the block below. This makes the encoder feature maps semantically more similar to those of the corresponding decoder, and optimization is easier when semantically similar feature maps are fused.

b) Dense skip connections (shown as blue):

Dense skip connections (shown in blue) implement dense skip pathways between the encoder and decoder. These dense blocks are inspired by DenseNet, with the purpose of improving segmentation accuracy and gradient flow. The connections ensure that all prior feature maps are present in the current node, generating full-resolution feature maps at multiple semantic levels.

c) Deep Supervision (shown as red):

Ref: https://arxiv.org/pdf/1807.10165.pdf

Deep supervision (shown in red) is added so that the model can be pruned to adjust its complexity, i.e. to balance between prediction time and performance.

To understand UNet++ better, you can read the paper here (https://arxiv.org/pdf/1807.10165.pdf).

Below is the code of a vanilla UNET++ without any pre-trained backbone.
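A condensed sketch of the idea: four resolution levels with dense skip connections, deep supervision omitted. The filter counts and the 256x256 single-channel input are choices of this sketch:

```python
import tensorflow as tf
from tensorflow.keras import layers

def conv_block(x, filters):
    """Two 3x3 conv + BN + ReLU: the basic unit of every node X(i, j)."""
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x

def up(x):
    return layers.UpSampling2D(2)(x)

def build_unet_pp(input_shape=(256, 256, 1), filters=(32, 64, 128, 256)):
    inp = layers.Input(input_shape)

    # Encoder column X(i, 0).
    x00 = conv_block(inp, filters[0])
    x10 = conv_block(layers.MaxPooling2D(2)(x00), filters[1])
    x20 = conv_block(layers.MaxPooling2D(2)(x10), filters[2])
    x30 = conv_block(layers.MaxPooling2D(2)(x20), filters[3])

    # Nested nodes X(i, j): fuse all previous nodes of the same row
    # with the upsampled node from the row below (dense skip connections).
    x01 = conv_block(layers.concatenate([x00, up(x10)]), filters[0])
    x11 = conv_block(layers.concatenate([x10, up(x20)]), filters[1])
    x21 = conv_block(layers.concatenate([x20, up(x30)]), filters[2])

    x02 = conv_block(layers.concatenate([x00, x01, up(x11)]), filters[0])
    x12 = conv_block(layers.concatenate([x10, x11, up(x21)]), filters[1])

    x03 = conv_block(layers.concatenate([x00, x01, x02, up(x12)]), filters[0])

    out = layers.Conv2D(1, 1, activation='sigmoid')(x03)
    return tf.keras.Model(inp, out)

model = build_unet_pp()
```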

Below is the combo loss:
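A minimal sketch: binary cross-entropy plus Dice loss, with a smoothing term to keep the Dice stable on empty masks:

```python
import tensorflow as tf
import tensorflow.keras.backend as K

def dice_coef(y_true, y_pred, smooth=1.0):
    y_true_f, y_pred_f = K.flatten(y_true), K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return ((2.0 * intersection + smooth)
            / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))

def combo_loss(y_true, y_pred):
    bce = K.mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + (1.0 - dice_coef(y_true, y_pred))
```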

Compiling the model:
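A sketch of the compile step; the Adam optimizer and the 1e-4 learning rate are assumptions:

```python
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss=combo_loss,
              metrics=[dice_coef])
```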

Training the model:
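A sketch of the training call; train_ds and valid_ds stand for the tf.data pipelines built earlier (names assumed):

```python
history = model.fit(train_ds,
                    validation_data=valid_ds,
                    epochs=30,          # matches the run discussed below
                    callbacks=callbacks)
```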

Let’s plot the metrics to see how our model is performing:

Evaluation Metrics

As we can see, the loss decreases every epoch, but very slowly, and the Dice score does not improve much beyond 0.21. The remedy would be to train the model for more epochs, change the learning rate more aggressively, and use a pre-trained model as the encoder backbone.

As stated earlier, the objective of this blog is to experiment with different architectures and see which one performs better, and why.

So, keeping that in mind, we will not add a pre-trained backbone to the previous architecture. Instead, we will try the Double UNet structure with VGG19 as its encoder backbone.

2. Double UNET:

As the name suggests, this is a combination of two UNET models, one after the other. You can check the paper here.

A short explanation of architecture by the author:

DoubleU-Net starts with a VGG19 as the encoder sub-network, which is followed by a decoder sub-network. The input image is fed to the modified UNet (UNet1), which generates a predicted mask (output1). We then multiply the input image with the produced mask (output1), and this acts as the input for the second modified U-Net (UNet2), which produces another generated mask (output2). Finally, we concatenate both masks (output1 and output2) to get the final predicted mask (output).

Below is the architecture:

DC-UNET

We can see that the above architecture uses 4 encoder blocks and 4 decoder blocks. More such blocks mean more trainable parameters and hence more computational power needed to train properly. Since we have limited access to such computing power, we will use only 2 encoder and 2 decoder blocks in our architecture.

The TensorFlow code for building this model is quite large; it is available in my Github profile, so I am not adding it here.

Let’s check the metrics:

Compared to UNET++, the loss decreases much more convincingly by the end of 30 epochs. The Dice score is still not very good; training for more epochs should improve it, but for the sake of time we will stop here.

Let’s see the prediction of this model:

RED: Predicted Mask; GREEN: Original Mask

We see that the model correctly predicts when there is no mask. For the cases where masks are present, the prediction is more or less similar to the ground truth. We can obviously improve the performance by handling the data imbalance better, training for more epochs, or applying the other techniques mentioned above.

11. Final Pipeline:

To build the final inference pipeline, we saved the above two model weights.

Let’s build the pipeline now:
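A sketch of the two-stage inference, with assumed filenames for the saved weights; the segmentation model runs only when the classifier flags the X-ray as positive:

```python
import cv2
import numpy as np
from tensorflow.keras.models import load_model

# Assumed filenames for the two saved models.
classifier = load_model('best_classifier.h5')
segmenter = load_model('best_segmenter.h5',
                       custom_objects={'combo_loss': combo_loss,
                                       'dice_coef': dice_coef})

def predict(png_path, threshold=0.5):
    """Stage 1: classify. Stage 2: segment only if classified positive."""
    img = cv2.imread(png_path) / 255.0                     # 256x256, 3 channels
    prob = classifier.predict(img[np.newaxis, ...])[0, 0]
    if prob < threshold:
        return None                                        # no pneumothorax
    gray = cv2.imread(png_path, cv2.IMREAD_GRAYSCALE) / 255.0
    mask = segmenter.predict(gray[np.newaxis, ..., np.newaxis])[0, ..., 0]
    return (mask > threshold).astype(np.uint8)             # predicted mask
```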

Let’s see some predictions:

Positive Predictions:

Negative Predictions:

Failure Cases:

It is important to show some failure cases: those where the classifier predicts that pneumothorax is present but the segmentation model fails to highlight any segment.

Failure Case

12. Deployment:

Let’s build an end-to-end application around the model and deploy it using Flask APIs. Below is a video of the deployed model in action:

Deployed Video

13. Future Work:

a) We can use different methods available for handling data imbalance, better augmentation techniques, etc.

b) We can use a pre-trained backbone for UNET++ and also enable the deep supervision option to check how the model performance improves.

16. LinkedIn Profile Link:

https://www.linkedin.com/in/namrata-thakur893

Thanks for reading the blog. I would like to thank my mentor and also the whole AppliedAiCourse team.

That’s it for this case study. If you have any suggestions to improve it please leave them in the comments..!!
