Lung Localization with PyTorch Lightning

Sandaruwan Herath
Data Science and Machine Learning
13 min read · Apr 23, 2024

Over the last decade, deep learning, particularly Convolutional Neural Networks (CNNs), has significantly advanced the field of medical imaging. These advances include improved accuracy in image classification, segmentation, and enhancement.


Lung localization is a pivotal pre-processing step in analysing thoracic images for various clinical applications, including detecting pathological findings like nodules, tumours, and other anomalies.

Convolutional Neural Networks (CNNs) have become a powerful tool in medical image analysis, especially for tasks like lung localization. These specialized deep learning models automatically and adaptively learn different levels of detail, like patterns within patterns, from image data. This makes them exceptionally well-suited for image recognition tasks where preserving the relationships between features is crucial, such as identifying subtle abnormalities in medical images.

Why CNNs Shine in Lung Localization

Feature Learning

Unlike traditional machine learning approaches that require hand-crafted features, CNNs learn features directly from the pixel data. This benefits lung localization, where the relevant features are highly complex and difficult to define by hand.

Feature Learning [https://www.analyticsvidhya.com/blog/2021/05/convolutional-neural-networks-cnn/]

Hierarchy of Features

Hierarchy of learned features, adapted from DeepTox: Toxicity Prediction Using Deep Learning [6]

CNNs learn a hierarchy of features in a way that mimics the human visual system. In lung localization, lower layers may detect edges and blobs, while deeper layers can identify more complex structures pertinent to the lungs, such as the bronchial tree and vascular structures.

Translation Invariance

Using pooling layers, CNNs achieve translation invariance, meaning they can recognize lung patterns regardless of their position in the image. This is particularly important for medical imaging, where the position and orientation of the lungs can vary between scans.

Translation Invariance [https://pyimagesearch.com/2021/05/14/are-cnns-invariant-to-translation-rotation-and-scaling/]

Automated End-to-End Learning

CNNs provide an end-to-end learning approach where raw images can be inputted into the network, and through a series of convolutional and pooling layers, the network outputs the localization of the lungs.

CNN Feature Learning Process [https://towardsdatascience.com/basics-of-the-classic-cnn-a3dce1225add]

PyTorch Implementation

The implementation of lung localization with CNNs typically involves the following steps:

Image Dataset: Our evaluation employed a dataset comprising Computed Tomography Pulmonary Angiography (CTPA) images, specifically curated for Pulmonary Embolism (PE) detection. These images were sourced in DICOM format, a standard for handling, storing, printing, and transmitting information in medical imaging. The dataset was structured and accessed based on a predefined CSV file containing relevant metadata.

Data Preprocessing: The preprocessing stage involved standard procedures to prepare the CTPA image dataset for analysis. This included normalization, resizing, and augmentation techniques to ensure the models were trained on data that is representative of various clinical scenarios.

Model Application: Both the ResNet50 and EfficientNetB0 models were applied to the dataset, with the training process spanning 100 epochs. The ensemble of these two models was aimed at leveraging their complementary strengths, thus enhancing the overall performance in lung localization.

Setting Up the Environment and Installing Packages

# Basic data handling and visualization libraries
!pip install numpy pandas matplotlib
# PyTorch and torch-vision for deep learning
!pip install torch torchvision
# Image augmentation library and PyTorch Lightning for structured deep learning code
# TensorBoard for visualization
!pip install imgaug pytorch_lightning tensorboard
# Libraries for handling medical images
!pip install pydicom pylibjpeg pylibjpeg-libjpeg
# OpenCV for image processing - using headless version to avoid GUI issues on headless servers
!pip install opencv-python-headless
# Upgrading OpenCV to the latest version (headless, to match the install above)
!pip install --upgrade opencv-python-headless
# IPython kernel for Jupyter
!pip install ipykernel
# GDCM for reading DICOM files that are compressed with JPEG
# (the PyPI package is python-gdcm; "-c conda-forge" is conda syntax, not pip)
!pip install python-gdcm

Dataset

The RSNA Pulmonary Embolism (PE) Detection Challenge dataset is a comprehensive collection of CT scans designed for the development of deep-learning models to detect pulmonary embolism. Provided by the Radiological Society of North America, this dataset includes numerous high-resolution CT images annotated by experienced radiologists, making it ideal for training robust diagnostic models. Each CT scan in the dataset is represented in DICOM format and is accompanied by detailed CSV files containing patient IDs, scan details, and precise annotations regarding the presence of pulmonary embolism.

Given the dataset’s extensive volume and variety, it is particularly valuable for researchers looking to implement and test deep learning algorithms. The annotations include binary labels for PE presence in each image slice, along with additional metadata like patient demographics and scan characteristics, which are crucial for nuanced analysis. Utilizing a sample from this dataset allows for manageable computational demands while maintaining a broad spectrum of data critical for achieving high accuracy and generalizability in PE detection models.

CSV File Content

The dataset comes with accompanying CSV files that contain metadata and annotations. Key components of these files include the following (a quick inspection sketch follows the list):

Patient IDs: Unique identifiers for each patient.

Study and Series IDs: Identifiers for each CT scan session.

Labels: Binary labels indicating the presence or absence of pulmonary embolism in each image slice.

Location Data: Information on the specific location of embolisms within the lung arteries, if present.

Additional Features: May include demographic information like age and sex, and technical aspects of the scan like the slice thickness and imaging equipment.

Suitability for Deep Learning Projects

This dataset is particularly well-suited for deep learning projects due to its size, diversity, and detailed annotations. These characteristics enable the development of robust models capable of detecting subtle variations in imaging associated with different stages and manifestations of pulmonary embolism.

Data Preprocessing

The PrepareData class is specifically tailored for preparing and processing CT scan images for lung detection models, leveraging deep learning techniques. The process begins with manual annotation, where bounding boxes are drawn around the lung areas in a set of sample images. These annotations, which capture the coordinates of lung regions, are stored in a CSV file named lung_bbox.csv. Alongside this, a train.csv file provides additional training labels and metadata necessary for processing.

Initialization and Configuration

Upon initializing the PrepareData class, it loads the lung region coordinates and training data from their respective CSV files. The class is configured with paths to both the source DICOM images and the destination directory for processed image outputs. It also initializes variables to track image statistics, which are crucial for normalizing the data to improve model accuracy and training efficiency.

Image Processing Methodology

The prepare_data method iterates over each image listed in the bounding box CSV file. It uses the identifiers from the CSV to locate and read the DICOM files corresponding to each image. Key to enhancing the image data for lung detection, a windowing function is applied to adjust the image intensities. For lung tissue, a typical choice, used in the code below, is a window centre of -600 HU and a width of 1500 HU, which clips intensities to the range [-1350, 150] HU. This step emphasizes the lung areas within the CT scans, making the lung tissues more distinguishable from the surrounding anatomy.

Rescaling and Data Segregation

After applying the windowing technique, the images are resized to a uniform dimension of 512x512 pixels, scaled to the [0, 1] range, converted to half-precision floating point (float16) to conserve storage, and saved as NumPy files. The images are allocated to either a training or validation dataset based on a predetermined index, ensuring a balanced distribution for effective model training.

Statistical Analysis and Saving Processed Data

Throughout the processing, the class calculates the sums and squared sums of pixel values for images in the training dataset. These statistics enable the normalization of the dataset, ensuring consistent input across all training images. Finally, identifiers for the images designated to the training and validation datasets are saved, setting the stage for subsequent phases of model training and validation.

Python Code

Importing Libraries

import pydicom
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
import pandas as pd

Defining the PrepareData Class for Data Preprocessing

class PrepareData:
    def __init__(self, bbox_labels_path, training_labels_path, root_path, save_path):
        # Load CSV files containing bounding box labels and training labels
        self.bbox_labels = pd.read_csv(bbox_labels_path)
        self.training_labels = pd.read_csv(training_labels_path)
        self.root_path = Path(root_path)
        self.save_path = Path(save_path)
        self.sum_pixel_values, self.sum_squared_pixel_values = 0, 0
        self.train_ids = []
        self.val_ids = []
        self.image_area = 512 * 512  # All images are resized to 512x512

    @staticmethod
    def apply_windowing(image, window_center, window_width):
        # Apply a windowing function to enhance the contrast in the image
        lower_limit = window_center - (window_width / 2)
        upper_limit = window_center + (window_width / 2)
        windowed_image = np.clip(image, lower_limit, upper_limit)
        # Normalize the image to the 16-bit range
        windowed_image = (windowed_image - lower_limit) * (65535 / window_width)
        return windowed_image.astype(np.uint16)

    def prepare_data(self):
        for counter, image_id in enumerate(tqdm(self.bbox_labels.Image)):
            # Extract metadata for locating the DICOM file
            row = self.training_labels[self.training_labels['SOPInstanceUID'] == image_id]
            study_id = row['StudyInstanceUID'].values[0]
            series_id = row['SeriesInstanceUID'].values[0]

            # Construct the path to the DICOM file and load it
            dcm_path = (self.root_path / study_id / series_id / image_id).with_suffix(".dcm")
            dcm = pydicom.dcmread(dcm_path)  # read_file is deprecated in recent pydicom
            dcm_array = dcm.pixel_array

            # Parameters for the windowing function (standard lung window)
            window_center = -600
            window_width = 1500

            # Apply the lung window, resize, and scale to [0, 1]; the scale-down
            # must happen before the float16 cast, since values near 65535
            # overflow float16 (largest finite value: 65504)
            windowed_image = self.apply_windowing(dcm_array, window_center, window_width)
            resized_image = cv2.resize(windowed_image, (512, 512))
            resized_image = (resized_image / 65535).astype(np.float16)

            # Assign the image to the training or validation set
            dataset_type = "train" if counter < 20000 else "val"
            if dataset_type == "train":
                self.train_ids.append(image_id)
            else:
                self.val_ids.append(image_id)

            # Save the processed image
            dataset_path = self.save_path / dataset_type
            dataset_path.mkdir(parents=True, exist_ok=True)
            np.save(dataset_path / image_id, resized_image)

            # Accumulate image statistics for the training set; accumulate in
            # float64 to avoid precision loss when summing float16 pixels
            if dataset_type == "train":
                self.sum_pixel_values += np.sum(resized_image, dtype=np.float64) / self.image_area
                self.sum_squared_pixel_values += np.sum(np.square(resized_image, dtype=np.float64)) / self.image_area

        # Save the lists of training and validation IDs
        np.save('../lungDetection/train_subjects.npy', self.train_ids)
        np.save('../lungDetection/val_subjects.npy', self.val_ids)

        return self.sum_pixel_values, self.sum_squared_pixel_values

This methodical approach ensures that the dataset is well-prepared for training deep learning models aimed at detecting lung regions in CT scans, effectively leveraging the detailed annotations provided in the bounding box CSV file.

After data preparation, we implement a lung localization model using the prepared dataset. This involves several key steps, from setting up the model architecture to defining the training loop and, finally, evaluating the model's performance. Using PyTorch Lightning streamlines this process by abstracting much of the boilerplate code involved in running a PyTorch model, letting you focus on the experiment itself rather than the mechanics of the training loop.

Lung Localization Model Implementation

Below, I’ll outline a general approach using a convolutional neural network (CNN), which is commonly used for image-based tasks in deep learning due to its effectiveness in extracting hierarchical features from images.

Import Libraries

from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torchvision
from torchvision import transforms, models
from torchvision.utils import make_grid
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
from imgaug.augmenters import Sequential, GammaContrast, Affine
from torch.utils.data import Dataset, DataLoader
import cv2

Declare constants

# Constants for image normalization, derived from the training-set
# pixel statistics accumulated during preprocessing
MEAN = 143.57 / 255
STD = 123.49 / 255

LungLocalizeDataset Class

The LungLocalizeDataset class in the code is designed to handle the specific data loading and preprocessing needs for a lung localization task using CT scan images. This custom dataset class is crucial for several reasons:

Data Loading

Preprocessing

Data Augmentation

Conversion to PyTorch Tensors

Integration with the PyTorch DataLoader

class LungLocalizeDataset(Dataset):
    def __init__(self, labels_csv_path, patient_data_path, root_dir, augmentation=None):
        self.labels = pd.read_csv(labels_csv_path)
        self.patients = np.load(patient_data_path)
        self.root_dir = Path(root_dir)
        self.augmentation = augmentation

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        patient_id = self.patients[idx]
        data = self.labels[self.labels["Image"] == patient_id]
        # Bounding box coordinates are stored normalized to [0, 1];
        # scale them to the 512x512 pixel grid
        bbox = [data[dim].item() * 512 for dim in ["Xmin", "Ymin", "Xmax", "Ymax"]]
        image_path = self.root_dir / f"{patient_id}.npy"
        image = np.load(image_path).astype(np.float32)

        if self.augmentation:
            image, bbox = self.apply_augmentation(image, BoundingBox(*bbox))

        # Standardize with the dataset-wide mean and standard deviation
        image = (image - MEAN) / STD
        return torch.tensor(image).unsqueeze(0), torch.tensor(bbox)

    def apply_augmentation(self, image, bbox):
        # Wrap the box so imgaug transforms the image and its bounding box
        # together, keeping the box aligned with the augmented lungs
        bbs = BoundingBoxesOnImage([bbox], shape=image.shape)
        image_aug, bbs_aug = self.augmentation(image=image, bounding_boxes=bbs)
        return image_aug, bbs_aug.to_xyxy_array()[0]

LungDetectionModel Class

The LungDetectionModel class in the code leverages PyTorch Lightning to define and train a deep learning model specifically for lung localization tasks using CT scan images.

Model Architecture

This class modifies a pre-trained ResNet50 model, a popular deep convolutional neural network known for its effectiveness in image recognition tasks. The modifications include:

Adapting the Input Layer: The first convolutional layer is adjusted to accept single-channel (grayscale) images typical of CT scans, rather than the three-channel (RGB) images for which ResNet50 is originally designed.

Customizing the Output Layer: The final fully connected layer is replaced with a new sequence of layers that output four values. These four values represent the coordinates of a bounding box around the lungs in the CT image.

Forward Pass

The outputs of the network, which are the predicted bounding box coordinates, are clamped to the valid range of the image dimensions (0 to 512 pixels). This ensures that the model predictions are feasible bounding box coordinates within the image.

Loss Function

The model uses the Smooth L1 loss, which is less sensitive to outliers than the Mean Squared Error loss. This characteristic makes it suitable for regression tasks like predicting bounding box coordinates.

Optimizer

The model uses the Adam optimizer, whose adaptive per-parameter learning rates typically lead to faster convergence than plain stochastic gradient descent.

Learning Rate Scheduler

The ReduceLROnPlateau scheduler reduces the learning rate when the monitored validation loss stops improving, which helps fine-tune the model during the later stages of training to achieve better performance.

Integration with PyTorch Lightning

PyTorch Lightning enhances the basic functionality of PyTorch by providing a structured way to organize the training process, manage device placement, and scale the computation. It also simplifies repetitive tasks like checkpointing and logging, making the model easier to develop, maintain, and reproduce.

class LungDetectionModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Start from ImageNet-pretrained ResNet50 weights
        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Accept single-channel (grayscale) CT slices instead of RGB;
        # this layer is re-initialized from scratch
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Regression head: four bounding box coordinates
        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(2048, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(1024, 4),
        )
        self.loss_fn = torch.nn.SmoothL1Loss()

    def forward(self, x):
        # Clamp predictions to valid pixel coordinates in a 512x512 image
        return self.model(x).clamp(min=0, max=512)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        predictions = self(images)
        loss = self.loss_fn(predictions, targets.float())
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        predictions = self(images)
        loss = self.loss_fn(predictions, targets.float())
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
        # ReduceLROnPlateau needs a metric to monitor; use the validation loss
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

Overall, the LungDetectionModel class encapsulates a complete deep-learning workflow tailored for lung detection in medical images, optimizing it for both performance and ease of use. This setup allows researchers and developers to focus more on model experimentation and less on boilerplate code.

Training the Model

The training process for the LungDetectionModel is a structured and automated procedure facilitated by PyTorch Lightning, which significantly simplifies many aspects of running a deep learning training loop. Here’s a detailed breakdown of the components involved in this training setup:

Model Initialization

The LungDetectionModel is instantiated, which prepares a neural network based on the modified ResNet50 architecture. This model is specifically tailored to output bounding box coordinates for lung detection in CT images.

Checkpointing

  • ModelCheckpoint Callback: This PyTorch Lightning callback automatically saves the model during training. Here, it is configured to monitor the validation loss (val_loss, the metric name logged in validation_step) and keep the five checkpoints with the lowest validation loss. Even if the model's performance later degrades due to overfitting or other issues, the best-performing versions remain available.

Trainer Configuration

  • Accelerator and Devices: The model is set to train with GPU acceleration (accelerator='gpu') on a single device. This is far faster than CPU training, which matters for the large datasets typical of medical imaging.
  • Logger: The TensorBoardLogger is used to log training and validation metrics along with other statistical data. This logger outputs the data to the directory "./logs_lung_ResNet50", which can then be visualized using TensorBoard. This tool helps in monitoring the model's performance and understanding its learning dynamics over time.
  • Logging Frequency: The log_every_n_steps=1 setting ensures that metrics are logged after every training step, providing granular insight into the model's training process. This frequent logging can be very useful for detailed analysis and troubleshooting during the early stages of model development.
  • Callbacks: Besides checkpointing, additional callbacks can be added to perform tasks like early stopping, learning rate scheduling, etc. In this case, the checkpoint callback is primarily used.

Training Execution

  • Max Epochs: The trainer is set to run for a maximum of 100 epochs. An epoch represents one complete pass through the entire training dataset. This number is a balance between adequate learning time and preventing overfitting.
  • Data Loaders: The train_loader and val_loader supply batches of training and validation data to the model, efficiently handling the shuffling and batching needed during the training and validation phases (a construction sketch follows this list).

Starting the Training

The trainer.fit(model, train_loader, val_loader) command starts the training process using the defined model and data loaders. The training runs iteratively, updating the model weights to minimize the loss function defined within the model class. Validation steps occur at the end of each epoch, providing performance metrics that are not influenced by the training process, thus giving an unbiased indication of the model's ability to generalize.


model = LungDetectionModel()

# Keep the five checkpoints with the lowest validation loss
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",  # must match the metric name logged in validation_step
    save_top_k=5,
    mode="min",
)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    logger=TensorBoardLogger("./logs_lung_ResNet50"),
    log_every_n_steps=1,
    callbacks=[checkpoint_callback],
    max_epochs=100,
)
trainer.fit(model, train_loader, val_loader)

This training process is highly automated and monitored, allowing the researcher or developer to focus on higher-level tasks such as model tuning and evaluation, while PyTorch Lightning handles the underlying training dynamics. This setup is particularly beneficial in fields like medical imaging, where model performance and reliability are of utmost importance.

References

1. Aggarwal, R., Sounderajah, V., Martin, G., et al. (2021). Diagnostic accuracy of deep learning in medical imaging: a systematic review and meta-analysis. npj Digital Medicine, 4, 65. https://doi.org/10.1038/s41746-021-00438-z

2. Soori, M., Arezoo, B., & Dastres, R. (2023). Artificial intelligence, machine learning and deep learning in advanced robotics: a review. Cognitive Robotics, 3, 54-70. ISSN 2667-2413. https://doi.org/10.1016/j.cogr.2023.04.001

3. Xu, H., Yuan, J., & Ma, J. (2023). MURF: Mutually Reinforcing Multi-Modal Image Registration and Fusion. IEEE Transactions on Pattern Analysis and Machine Intelligence, 45(10), 12148-12166. doi: 10.1109/TPAMI.2023.3283682

4. Visser, M., Petr, J., Müller, D.M.J., Eijgelaar, R.S., Hendriks, E.J., Witte, M., Barkhof, F., van Herk, M., Mutsaerts, H.J.M.M., Vrenken, H., de Munck, J.C., & De Witt Hamer, P.C. (2020). Accurate MR Image Registration to Anatomical Reference Space for Diffuse Glioma. Frontiers in Neuroscience, 14, 585. doi: 10.3389/fnins.2020.00585

5. Rosenman, J.G., Miller, E.P., Tracton, G., & Cullip, T.J. (1998). Image registration: an essential part of radiation therapy treatment planning. International Journal of Radiation Oncology Biology Physics, 40(1), 197-205. doi: 10.1016/s0360-3016(97)00546-4. PMID: 9422577

6. Mayr, A., Klambauer, G., Unterthiner, T., & Hochreiter, S. (2015). DeepTox: Toxicity Prediction Using Deep Learning. Frontiers in Environmental Science, 3. doi: 10.3389/fenvs.2015.00080

