Segment Anything Model (SAM) for Medical Image Segmentation

Reza Kalantar
May 30, 2023


In the evolving landscape of artificial intelligence (AI), medical imaging is a field undergoing profound transformation. Riding this wave of change, Meta AI (the research group of Meta, formerly Facebook) has developed a model called Segment Anything (SAM). SAM generates segmentation masks for a wide range of objects in an image from simple prompts such as points or bounding boxes. This promptable design makes it adaptable to many tasks, from segmenting everyday objects in natural photographs to delineating specific anatomical structures in medical images.

The full code is available on my GitHub page.

Figure designed by Reza Kalantar

A Deep Dive into SAM

Fine-tuning SAM for a specific medical imaging task involves several steps. Here’s a breakdown:

  1. Data Loading and Preprocessing: The first step is handling the medical imaging data, which is usually stored in formats such as DICOM or NIfTI and read with libraries such as pydicom or nibabel. The images are then preprocessed: reoriented, intensity-normalized, and converted, together with their masks, into a format the model accepts.
  2. Bounding Box Prompt Creation: Bounding box prompts tell SAM which structure to segment. Each box should loosely enclose the structure you wish to segment. SAM also accepts multiple boxes per image, so several objects can be segmented in a single pass.
  3. Model and Processor Preparation: This involves loading the pre-trained SAM model and the associated processor. The latter is responsible for preparing your inputs and prompts for the model.
  4. Model Fine-Tuning: This crucial step entails running a training loop, computing the loss function (a comparison of the model’s output to the actual mask), backpropagating the gradients, and updating the model’s weights.
  5. Model Evaluation: After the model is trained, it’s time to evaluate its performance on a validation set to gauge how it will perform on unseen data. Metrics such as the Dice coefficient or Intersection over Union (IoU) come in handy here.
  6. Inference: The final step involves segmenting new medical images using your trained model. This process includes preparing the image and bounding box prompt, feeding them into the model, and post-processing the output to yield your final segmentation mask.
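To make step 5 concrete, here is a minimal NumPy sketch of the Dice coefficient and IoU for binary masks. The function names and the small `eps` smoothing term are my own choices, not taken from the original code.

```python
import numpy as np


def dice_coefficient(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """Dice = 2|A ∩ B| / (|A| + |B|) for two binary masks of the same shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    # eps keeps the ratio defined (and equal to 1.0) when both masks are empty
    return float((2.0 * intersection + eps) / (pred.sum() + target.sum() + eps))


def iou_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """IoU = |A ∩ B| / |A ∪ B| for two binary masks of the same shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return float((intersection + eps) / (union + eps))


pred = np.zeros((4, 4), dtype=np.uint8)
pred[1:3, 1:3] = 1    # 4 predicted pixels
target = np.zeros((4, 4), dtype=np.uint8)
target[1:3, 1:4] = 1  # 6 true pixels, 4 of which overlap the prediction
print(round(dice_coefficient(pred, target), 3))  # 0.8
print(round(iou_score(pred, target), 3))         # 0.667
```

Dice weights the overlap twice, so it is always at least as large as IoU for the same pair of masks.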

In this article, I’ll guide you through fine-tuning SAM to segment the lungs in CT scans using Google Colab. We will also cover the steps needed to preprocess the medical images and convert them into 2D slices.

Kick-starting with the Kaggle Dataset

To start with, you’ll need to install the Kaggle library:

!pip install -q kaggle

Next, create a directory named ".kaggle" in your home directory:

!mkdir -p ~/.kaggle

Then, upload your Kaggle API token, obtainable from the Kaggle website:

from google.colab import files
files.upload()  # upload your kaggle.json API token

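After uploading the token, copy it into the “.kaggle” directory and restrict its permissions so that only you can read it (the Kaggle CLI warns otherwise):

!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json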

Now, you’re all set to download the dataset. In this tutorial, we’ll be using the “finding-lungs-in-ct-data” dataset:

!kaggle datasets download -d kmader/finding-lungs-in-ct-data

Lastly, unzip the downloaded dataset:

!unzip -q /content/finding-lungs-in-ct-data.zip

Preprocess Data

Before processing the data, we install the essential libraries: MONAI (a PyTorch-based framework for medical imaging), SimpleITK for reading and processing medical images, and the Hugging Face Transformers library, which provides the SAM model:

!pip install -q monai
!pip install -q SimpleITK
!pip install -q git+https://github.com/huggingface/transformers.git

The dataset contains 3D .nii.gz volumes and lung masks from 4 patients. We split the patients into 2 for training, 1 for validation, and 1 for testing, and save 2D axial slices to the corresponding directories:

No. of images: 4  labels: 4
processing patient 0 (325, 512, 512) (325, 512, 512)
processing patient 1 (465, 512, 512) (465, 512, 512)
processing patient 2 (301, 512, 512) (301, 512, 512)
processing patient 3 (117, 512, 512) (117, 512, 512)
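A minimal sketch of the split and slicing logic is shown below, assuming each volume and mask has already been loaded into a NumPy array (for example with SimpleITK's `GetArrayFromImage`). The helper names `split_patients` and `axial_slices`, and the optional `skip_empty` flag, are hypothetical; the original notebook's exact slice selection is not shown here.

```python
import numpy as np


def split_patients(patient_ids, n_train=2, n_val=1):
    """Split patient IDs into train/val/test lists (2/1/1 for 4 patients)."""
    train = list(patient_ids[:n_train])
    val = list(patient_ids[n_train:n_train + n_val])
    test = list(patient_ids[n_train + n_val:])
    return train, val, test


def axial_slices(volume, mask, skip_empty=False):
    """Yield (z, image_slice, mask_slice) for a (Z, H, W) volume and its mask.

    SimpleITK's GetArrayFromImage returns arrays in (z, y, x) order, so axis 0
    indexes the axial slices, e.g. a (325, 512, 512) volume gives 325 slices.
    """
    if volume.shape != mask.shape:
        raise ValueError(f"shape mismatch: {volume.shape} vs {mask.shape}")
    for z in range(volume.shape[0]):
        # Optionally drop slices with no foreground (no box can be drawn for them)
        if skip_empty and not mask[z].any():
            continue
        yield z, volume[z], mask[z]
```

Each yielded slice could then be written to the training, validation, or test directory, for instance with `sitk.WriteImage(sitk.GetImageFromArray(image_slice), path)`.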

Next, we initialize a dictionary, data_paths, to store the paths of our image and label files.

We traverse the directory of each dataset split (training, validation, and testing) and, for each data type (images and masks), construct the directory path and collect every file in it that ends with “.nii.gz” into a list.

Each list is stored in the data_paths dictionary with a key that combines the dataset type and the data type.
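A minimal sketch of this step, assuming a hypothetical layout of `<root>/<split>/<images|masks>/*.nii.gz` and keys such as `train_images`; the actual directory and key names in the original code may differ.

```python
import os
from glob import glob


def collect_data_paths(root, splits=("train", "val", "test"), data_types=("images", "masks")):
    """Map '<split>_<data_type>' keys to sorted lists of .nii.gz file paths."""
    data_paths = {}
    for split in splits:
        for data_type in data_types:
            directory = os.path.join(root, split, data_type)
            # Sorting keeps images and masks aligned when paired by filename
            data_paths[f"{split}_{data_type}"] = sorted(
                glob(os.path.join(directory, "*.nii.gz"))
            )
    return data_paths
```

Sorting the file lists matters because the image and mask lists are later consumed in parallel; `glob` alone does not guarantee any order.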

Number of training images 655
Number of validation images 265
Number of test images 49

Next, we create an instance of SamProcessor for preprocessing. SamProcessor is part of the Hugging Face Transformers library and prepares images and prompts for the Segment Anything Model (SAM).

We initialize it from Meta’s pre-trained “facebook/sam-vit-base” checkpoint, which uses a ViT-Base image encoder. Printing the processor shows how it formats images for input into the SAM model:

SamProcessor:
- image_processor: SamImageProcessor {
"do_convert_rgb": true,
"do_normalize": true,
"do_pad": true,
"do_rescale": true,
"do_resize": true,
"image_mean": [
0.485,
0.456,
0.406
],
"image_processor_type": "SamImageProcessor",
"image_std": [
0.229,
0.224,
0.225
],
"pad_size": {
"height": 1024,
"width": 1024
},
"processor_class": "SamProcessor",
"resample": 2,
"rescale_factor": 0.00392156862745098,
"size": {
"longest_edge": 1024
}
}
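To show what these settings mean, the sketch below approximates the processor's steps in plain NumPy for a single grayscale slice and one box. It is a simplification: it uses nearest-neighbour instead of bilinear resampling (`"resample": 2` is PIL's bilinear filter), it assumes the slice is already a uint8 image in the range 0 to 255, and `sam_like_preprocess` is a hypothetical helper, not part of the Transformers API. In practice you would call `processor(image, input_boxes=[[box]], return_tensors="pt")`.

```python
import numpy as np

# Values copied from the SamImageProcessor configuration printed above
IMAGE_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGE_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
RESCALE_FACTOR = 1 / 255  # 0.00392156862745098


def resized_shape(h, w, longest_edge=1024):
    """Scale (h, w) so the longer side equals longest_edge, keeping aspect ratio."""
    scale = longest_edge / max(h, w)
    return int(h * scale + 0.5), int(w * scale + 0.5)


def nearest_resize(img, new_h, new_w):
    """Nearest-neighbour resize (simplification; the processor uses bilinear)."""
    rows = (np.arange(new_h) * img.shape[0] / new_h).astype(int)
    cols = (np.arange(new_w) * img.shape[1] / new_w).astype(int)
    return img[rows[:, None], cols[None, :]]


def sam_like_preprocess(image, box, longest_edge=1024, pad_size=1024):
    """Approximate SamProcessor for a 2D uint8 image and an (x0, y0, x1, y1) box."""
    h, w = image.shape
    rgb = np.repeat(image[:, :, None], 3, axis=2)             # do_convert_rgb
    new_h, new_w = resized_shape(h, w, longest_edge)
    x = nearest_resize(rgb, new_h, new_w).astype(np.float32)  # do_resize
    x = x * RESCALE_FACTOR                                    # do_rescale
    x = (x - IMAGE_MEAN) / IMAGE_STD                          # do_normalize
    padded = np.zeros((pad_size, pad_size, 3), dtype=np.float32)
    padded[:new_h, :new_w] = x                                # do_pad (bottom/right)
    pixel_values = padded.transpose(2, 0, 1)                  # (3, H, W)
    # Boxes are rescaled into the resized (pre-padding) coordinate frame
    sx, sy = new_w / w, new_h / h
    x0, y0, x1, y1 = box
    return pixel_values, [x0 * sx, y0 * sy, x1 * sx, y1 * sy]
```

The box is rescaled along with the image: a 512 × 512 slice becomes 1024 × 1024, so every box coordinate doubles, which is why `input_boxes` must be given in the original image's pixel coordinates.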

The get_bounding_box function creates bounding box coordinates for a given segmentation map. The coordinates are derived from the contours identified in the map and enlarged by a randomly selected padding to add variability. If the map contains no contours, the bounding box covers the whole image.
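A minimal sketch of such a function is below. It uses `np.nonzero` on the foreground pixels rather than explicit contour detection, which yields the same tight box for a binary mask; the `max_pad` default of 20 pixels and the independent per-side padding are assumptions, not values from the original code.

```python
import numpy as np


def get_bounding_box(ground_truth_map, max_pad=20, rng=None):
    """Return [x_min, y_min, x_max, y_max] around the foreground, randomly padded.

    Falls back to a box covering the whole image when the mask is empty.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = ground_truth_map.shape
    ys, xs = np.nonzero(ground_truth_map)
    if xs.size == 0:
        return [0, 0, w, h]
    # Independent random padding on each side, clipped to the image bounds
    pad = rng.integers(0, max_pad + 1, size=4)
    x_min = max(0, int(xs.min()) - int(pad[0]))
    y_min = max(0, int(ys.min()) - int(pad[1]))
    x_max = min(w, int(xs.max()) + int(pad[2]))
    y_max = min(h, int(ys.max()) + int(pad[3]))
    return [x_min, y_min, x_max, y_max]
```

The box is returned in `[x_min, y_min, x_max, y_max]` pixel order, the format SamProcessor expects for `input_boxes`.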

The SAMDataset class defines a custom dataset for our application. It applies several transformations to the data: loading the images, ensuring the correct orientation, normalizing the intensities, and cropping them to a fixed size. It also prepares the inputs for the model by converting the images to the expected format, generating bounding box prompts, and arranging the segmentation masks.
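Two of these transformations can be sketched in NumPy, assuming the slices store Hounsfield units (HU): clipping to an intensity window and mapping it to 0 to 255 (the range that SamProcessor's `rescale_factor` of 1/255 expects), and resizing the mask to 256 × 256, which matches the `ground_truth_mask` shape printed further down and the resolution of SAM's low-resolution mask output. The -1000 to 400 HU window and both helper names are hypothetical; the original class may rely on MONAI transforms instead.

```python
import numpy as np


def window_to_uint8(slice_hu, hu_min=-1000.0, hu_max=400.0):
    """Clip CT intensities to [hu_min, hu_max] HU and map them linearly to 0..255."""
    clipped = np.clip(slice_hu.astype(np.float32), hu_min, hu_max)
    scaled = (clipped - hu_min) / (hu_max - hu_min) * 255.0
    return scaled.astype(np.uint8)


def resize_mask_nearest(mask, size=256):
    """Nearest-neighbour resize of a 2D mask to (size, size), keeping it binary."""
    rows = (np.arange(size) * mask.shape[0] / size).astype(int)
    cols = (np.arange(size) * mask.shape[1] / size).astype(int)
    return (mask[rows[:, None], cols[None, :]] > 0).astype(np.uint8)
```

Nearest-neighbour resizing is used for the mask so that no intermediate label values are introduced.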

Next, we create data loaders for the training and validation sets. The SAMDataset objects are built from the image and mask paths, with the SamProcessor handling preprocessing. PyTorch's DataLoader class then batches and shuffles the data, so that it is fed to the model efficiently and in random order during training.

Now, we can visualize the processed data:

pixel_values torch.Size([3, 1024, 1024])
original_sizes torch.Size([2])
reshaped_input_sizes torch.Size([2])
input_boxes torch.Size([1, 4])
ground_truth_mask torch.Size([256, 256])
Sample frontal chest CT and corresponding lung mask

Train

Now that our data loaders are ready, we can configure the model for fine-tuning. We preserve the pre-trained SAM encoder weights by freezing them, so they are not updated during training.
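A minimal PyTorch sketch of the freezing step is below. It assumes, as is common when fine-tuning the Transformers `SamModel`, that the `vision_encoder` and `prompt_encoder` submodules are frozen while the `mask_decoder` stays trainable; `freeze_encoders` and the tiny `ToySam` stand-in are hypothetical and exist only so the example runs without downloading weights.

```python
import torch.nn as nn


def freeze_encoders(model: nn.Module, prefixes=("vision_encoder", "prompt_encoder")) -> int:
    """Disable gradients for parameters whose names start with any prefix.

    Returns the number of frozen parameter elements.
    """
    n_frozen = 0
    for name, param in model.named_parameters():
        if name.startswith(tuple(prefixes)):
            param.requires_grad_(False)
            n_frozen += param.numel()
    return n_frozen


class ToySam(nn.Module):
    """Hypothetical stand-in with the same top-level attribute names as SamModel."""

    def __init__(self):
        super().__init__()
        self.vision_encoder = nn.Linear(8, 8)
        self.prompt_encoder = nn.Linear(4, 4)
        self.mask_decoder = nn.Linear(8, 1)


toy = ToySam()
print(freeze_encoders(toy))  # 92 = (8*8 + 8) + (4*4 + 4)
```

With the real model, you would load `model = SamModel.from_pretrained("facebook/sam-vit-base")`, call `freeze_encoders(model)`, and pass only `[p for p in model.parameters() if p.requires_grad]` to the optimizer.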

Finally, we can start training our model:

Segmentation progress displayed during training

Inference

After the training is completed, we use the best weights to predict segmentation masks from the test data:

Test predictions from the finetuned SAM model
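The post-processing step can be sketched as follows, assuming the predicted mask logits have been converted to a NumPy array (for example `outputs.pred_masks.squeeze().cpu().numpy()` when calling the model with `multimask_output=False`). Applying a sigmoid and thresholding at 0.5 is a common choice; the threshold and the helper name `logits_to_mask` are assumptions.

```python
import numpy as np


def logits_to_mask(logits, threshold=0.5):
    """Convert raw mask logits to a binary uint8 mask via sigmoid + threshold."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    # Strict inequality: a logit of exactly 0 (probability 0.5) maps to background
    return (probs > threshold).astype(np.uint8)
```

The resulting 256 × 256 mask can then be compared with the ground truth using Dice or IoU, or upsampled to the original slice size for display.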

Congratulations! You’ve successfully fine-tuned the SAM model for lung segmentation in CT scans using bounding box prompts. Happy coding!

Try the code in Google Colab:

Thank you for reading! If you find my blogs interesting or would like to get in touch, reach out here, on GitHub, or on LinkedIn.
