Brain MRI Segmentation with Segment Anything Model (SAM) Part 1: Background — Data Preparation
The implementation code can be found on GitHub or Kaggle.
1. Problem Statement and Background
Medical image segmentation is a crucial process in healthcare, specifically in the identification and delineation of anatomical structures within medical images. These structures include organs, lesions, and tissues, which are vital for accurate diagnosis and treatment planning. This technique is indispensable in clinical applications such as computer-aided diagnosis, treatment planning, and the monitoring of disease progression.
Brain MRI segmentation is particularly important in the detection and diagnosis of brain cancer. By accurately segmenting MRI images, clinicians can identify cancerous tissues, which is essential for determining the appropriate course of treatment. This process not only aids in early detection but also significantly improves the chances of successful treatment outcomes.
Existing brain MRI segmentation methods range from manual delineation by medical experts to more sophisticated automated techniques. Among the notable automated methods is the U-Net architecture, a convolutional neural network designed specifically for medical image analysis. Other methods include various algorithm-based approaches that strive to accurately identify and segment relevant brain structures, facilitating effective diagnosis and treatment planning.
The introduction of the Segment Anything Model (SAM) by META AI’s FAIR Lab in April 2023 represented a major advancement in image segmentation technology. Detailed in their groundbreaking paper “Segment Anything”, SAM’s promptable architecture has brought a new level of versatility and accuracy to the field. Its ability to generate precise object masks automatically has set a new standard in the realm of image segmentation-related tasks. SAM, a model trained on the extensive SA-1B dataset, which includes 11 million images and over 1.1 billion masks, primarily focuses on natural image data known for its pronounced edge details. This aspect markedly contrasts with typical medical imaging data, where the boundaries between objects are rarely as distinct.
This project aims to explore the effectiveness of SAM in the context of brain MRI imaging. Our objective is to utilize SAM’s advanced segmentation capabilities to create accurate and detailed masks for brain MRI scans. We anticipate that this exploration will significantly enhance the accuracy and efficiency of medical diagnoses and treatment planning, potentially ushering in a new era of technological advancement and clinical application in medical imaging.
2. Segment Anything Model
Segment Anything Model, a foundation (pretrained) model developed for image segmentation tasks, was launched by Meta AI in April 2023. Trained on an extensive dataset comprising 11 million images and over 1.1 billion masks, known as the SA-1B dataset, SAM demonstrates remarkable zero-shot generalization abilities. A key feature of this model is its prompt-based interface, which allows users to input various types of prompts, such as points, bounding boxes, or text. The model then utilizes these prompts to generate precise segmentation masks on images, showcasing its adaptability and advanced understanding of image segmentation.
The features of SAM are supported by 3 main components: image encoder, prompt encoder, and mask decoder.
Image Encoder
The image encoder in SAM transforms the input image into a representation the model can work with: an image embedding of shape C x H x W. Specifically, it uses a ViT-H/16 encoder, which reduces the spatial resolution by a factor of 16, so an input image of 1024 x 1024 becomes a 64 x 64 feature map. To reduce the channel dimension of this embedding, the encoder then applies a 1 x 1 convolution layer with 256 channels, followed by a 3 x 3 convolution layer also with 256 channels, each followed by a normalization layer. The image encoder has 632M parameters.
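To make these shapes concrete, here is a minimal sketch (not part of the original pipeline) that pushes a dummy 1024 x 1024 image through the smaller facebook/sam-vit-base checkpoint via the Hugging Face transformers library; the embedding should come out with shape 256 x 64 x 64.
import torch
from transformers import SamModel

sam = SamModel.from_pretrained("facebook/sam-vit-base")
dummy_image = torch.randn(1, 3, 1024, 1024)  # a fake 1024 x 1024 RGB image
with torch.no_grad():
    embeddings = sam.get_image_embeddings(pixel_values=dummy_image)
print(embeddings.shape)  # expected: torch.Size([1, 256, 64, 64])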
Prompt Encoder
SAM, as a promptable model, uses prompts to generate masks. This part of SAM contains about 4M parameters. SAM supports two types of prompts: sparse (points, boxes, and text) and dense (masks). The prompt encoder converts these prompts into prompt embeddings (see the sketch after the list below):
- For sparse prompts, it transforms the prompt into a 256-dimensional embedding vector.
- For dense prompts, the encoder downscales the prompt using two stride-2 convolutions and then applies a final 1×1 convolution to adjust the channel dimension to 256.
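As a rough illustration of how a sparse (box) prompt reaches the model, the sketch below uses the Hugging Face SamProcessor (introduced formally later in this article) to turn one bounding box into a tensor; the shapes in the comment are what we expect from that API, not figures from the SAM paper.
import numpy as np
from PIL import Image
from transformers import SamProcessor

demo_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
blank = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))  # placeholder image
# one box prompt given as (x_min, y_min, x_max, y_max) in pixel coordinates
inputs = demo_processor(blank, input_boxes=[[[60, 60, 180, 180]]], return_tensors="pt")
print(inputs["input_boxes"].shape)  # expected: torch.Size([1, 1, 4]), coordinates rescaled to the 1024-pixel input space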
Mask Decoder
The mask decoder is a crucial component that integrates the representations of both the image and the prompts to accurately predict the appropriate mask for the image.
3. Dataset
In this project, the LGG segmentation dataset is utilized. It comprises brain MRI scans paired with manually created segmentation masks highlighting FLAIR abnormalities. These images are sourced from The Cancer Imaging Archive (TCIA). They represent a subset of 110 patients from The Cancer Genome Atlas (TCGA) that focuses on lower-grade glioma, each having available data on fluid-attenuated inversion recovery (FLAIR) sequences and genomic clusters.
4. Library and Configuration
Moving forward, we will develop Python code to investigate the features of SAM. The first step in this process is to import the essential libraries. If executing this code leads to an error, ensure that all necessary packages are installed by using the pip install command.
# for general purposes
import random
import time
import warnings
# for mathematical computation
import numpy as np
from statistics import mean
# for visualization
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
from tqdm.notebook import tqdm
from IPython.display import clear_output
# for data management and manipulation
import os
import glob
import io
import cv2
import pandas as pd
import datasets as dts
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from monai.transforms import Compose, NormalizeIntensityd
# for modeling and model evaluation
import torch
import monai
import torchvision
from sklearn.model_selection import train_test_split
from transformers import SamProcessor
from transformers import SamModel
from torch import nn
from torch.optim import Adam
from monai.metrics import compute_iou
We will define a CFG class to encapsulate key variables that are essential throughout the execution of the code.
class CFG:
    # define paths
    DATASET_PATH = "/kaggle/input/lgg-mri-segmentation/"
    TRAIN_PATH = "/kaggle/input/lgg-mri-segmentation/kaggle_3m/"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # if a GPU is available, use it to accelerate training.
    TRAIN_BATCH_SIZE = 2   # the number of samples passed through the network in each training iteration.
    TEST_BATCH_SIZE = 1    # the number of samples passed through the network in each testing iteration.
    LEARNING_RATE = 1e-3   # how fast the model learns
    WEIGHT_DECAY = 0       # for regularization
    EPOCH = 10             # 1 epoch = all data have passed through the network once.
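Because these are class-level attributes, they can be read directly from CFG without creating an instance, for example:
# read configuration values directly from the class
print(CFG.DEVICE)            # "cuda" on a GPU machine, otherwise "cpu"
print(CFG.TRAIN_BATCH_SIZE)  # 2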
5. Data Preparation and Exploration
Before inputting data into the model, the data must undergo several preprocessing steps. As observed on the Kaggle dataset page, the data is organized into numerous folders. Within each folder, there are images along with their corresponding masks. The initial step involves retrieving the file paths for all these items.
# fetch the paths of all files
dataset_images = glob.glob(f"{CFG.TRAIN_PATH}**/*.tif")
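A quick, optional sanity check confirms that the glob pattern actually picked up the .tif files:
# how many .tif files were found, and what a path looks like
print(len(dataset_images))
print(dataset_images[0])  # something like .../kaggle_3m/<patient_folder>/<patient_folder>_<slice>.tif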
Constructing The Dataframe
Subsequently, we will create a dataframe to organize the data details, including the file paths. This will also involve assigning an identifier to distinguish whether each file is an image or a mask. We will accomplish this using the following functions.
# get the patient details
def get_sample_patient_id(image_paths):
    return [path.split('/')[-2] for path in image_paths]

# get the sample number and whether the file is a mask
def get_sample_number(image_paths):
    sample_numbers = []
    is_mask = []
    for path in image_paths:
        path_list = path.split('/')[-1].split('_')
        if 'mask.tif' in path_list:
            sample_numbers.append(int(path_list[-2]))
            is_mask.append(1)
        else:
            sample_numbers.append(int(path_list[-1].replace('.tif', '')))
            is_mask.append(0)
    return sample_numbers, is_mask
# construct the dataframe
def build_df(image_paths):
    sample_numbers, mask_label = get_sample_number(image_paths)
    df = pd.DataFrame({
        'id'        : sample_numbers,
        'patient'   : get_sample_patient_id(image_paths),
        'image_path': image_paths,
        'is_mask'   : mask_label
    })
    return df

# build the dataframe from all file paths
dataset_df = build_df(dataset_images)
Here is what the resulting dataframe looks like:
Next, we will perform a groupby operation to segregate the dataframes, creating separate ones for images and masks. Following this, we will merge them so that the paths for both images and masks are aligned side-by-side.
# images_df: for images
# mask_df: for masks
grouped_df = dataset_df.groupby(by='is_mask')
images_df, mask_df = (
grouped_df.get_group(0).drop('is_mask', axis=1).reset_index(drop=True),
grouped_df.get_group(1).drop('is_mask', axis=1).reset_index(drop=True)
)
mask_df = mask_df.rename({'image_path': 'mask_path'}, axis=1)
mask_df.head()
# merge images dataframe and masks dataframe
ds = images_df.merge(
mask_df,
on=['id', 'patient'],
how='left'
)
Following this, we will generate a label for each data row by checking for abnormalities, indicated by the presence of nonzero pixels in the mask. If such pixels are found, the diagnosis is labeled as 1, signifying cancerous tissue; if not, it is labeled as 0.
def _load(image_path, as_tensor=True):
    # load the image and scale pixel values to the range [0, 1]
    image = Image.open(image_path)
    return np.array(image).astype(np.float32) / 255.

def generate_label(mask_path, load_fn):
    mask = load_fn(mask_path)
    if mask.max() > 0:
        return 1 # brain tumor present
    return 0     # normal
# generate MRI Label
ds['diagnosis'] = [generate_label(_, _load) for _ in tqdm(ds['mask_path'])]
ds.head()
For this project, our focus will be solely on the diagnosis of cancerous tissue. Therefore, we will limit our dataset to cases where the diagnosis equals 1, using a total of 1360 rows of data. This decision is guided by the need for simplicity and the constraints of our computational resources.
# filter valid masks and choose only 1360
ds = ds[ds['diagnosis']==1]
ds = ds.head(1360)
Train-Test Splitting
A common practice in machine learning involves dividing our data into training and testing sets. We train our model using the training set and evaluate its performance with the testing set. To achieve this, we utilize the train_test_split() function from scikit-learn, which returns four components: training features, testing features, training labels, and testing labels. In our context, the “features” refer to the images, while the “labels” denote the masks.
# train-test splitting
image_train, image_test, mask_train, mask_test = train_test_split(
ds['image_path'], ds['mask_path'], test_size = 0.10)
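With the 1360 filtered rows and test_size = 0.10, we expect roughly 1224 training and 136 testing samples, which can be verified as follows:
# verify the split sizes
print(len(image_train), len(image_test))  # expected: 1224 136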
Next, we will merge the paths of the images and their corresponding masks into a single dataframe.
# train_df: contains the path to image and mask of training set
train_df = pd.concat([image_train, mask_train], axis=1).reset_index(drop=True)
# test_df: contains the path to image and mask of testing set
test_df = pd.concat([image_test, mask_test], axis=1).reset_index(drop=True)
For both the train_df and test_df, we will utilize the Dataset.from_pandas function from the datasets package, which we have aliased as dts. This function will encapsulate our data, facilitating ease in the modeling process.
train_dataset = dts.Dataset.from_pandas(train_df)
test_dataset = dts.Dataset.from_pandas(test_df)
For both train_dataset and test_dataset, we will apply transformations to ensure that each dataset stores both the image and its corresponding mask. The images will be stored in RGB format, while the masks will be stored as grayscale.
def transform(data):
    # load the image in RGB format
    with open(data['image_path'], 'rb') as f:
        image = Image.open(io.BytesIO(f.read())).convert('RGB')
    data['image'] = image
    # load the mask in grayscale
    with open(data['mask_path'], 'rb') as f:
        mask = Image.open(io.BytesIO(f.read())).convert('L') # to grayscale
    data['mask'] = mask
    return data
# input both image_path and mask_path to `transform()` and return the image and the mask
train_dataset = train_dataset.map(transform, remove_columns=['image_path','mask_path'])
test_dataset = test_dataset.map(transform, remove_columns=['image_path','mask_path'])
Now, train_dataset and test_dataset store the files instead of the paths. For example:
# take one sample from train_dataset
example = train_dataset[0]
img = example['image']
msk = example['mask']
If you call either img or msk, the respective image or mask will be rendered.
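You can also inspect the mode and size programmatically; the slices in this dataset are 256 x 256 pixels, so the output should look like the comments below.
# inspect one image-mask pair
print(img.mode, img.size)  # RGB (256, 256)
print(msk.mode, msk.size)  # L (256, 256)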
Visualizing The Dataset
Next, we will proceed with visualizing our data, a crucial step for a deeper understanding. By overlaying the mask on the image in the same axes, we can effectively identify cancerous regions in the brain. This can be achieved using the following code snippet.
# code for showing the mask
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# overlay the ground-truth mask on the MRI image
fig, axes = plt.subplots()
axes.imshow(np.array(img))
ground_truth_seg = np.array(example["mask"])
show_mask(ground_truth_seg, axes)
axes.title.set_text("Ground truth mask")
axes.axis("off")
The code above produces the following image:
Defining The Dataset Class
To prepare our data for input into the model, several processing steps are necessary to align with the model’s requirements. SAM is built on the PyTorch framework, which facilitates dataset processing through the Dataset and DataLoader classes. The following code snippets illustrate how these are implemented in our context.
class SAMDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        ground_truth_mask = np.array(item["mask"])
        # get bounding box prompt
        prompt = get_bounding_box(ground_truth_mask)
        # prepare image and prompt for the model
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        # remove batch dimension which the processor adds by default
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        # add ground truth segmentation, scaled to the range [0, 1]
        inputs["ground_truth_mask"] = ground_truth_mask / 255
        return inputs
def get_bounding_box(ground_truth_map):
    '''
    This function creates varying bounding box coordinates based on the segmentation
    contours as a prompt for the SAM model.
    The padding is a random integer between 5 and 20 pixels.
    '''
    if len(np.unique(ground_truth_map)) > 1:
        # get bounding box from mask
        y_indices, x_indices = np.where(ground_truth_map > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates
        H, W = ground_truth_map.shape
        x_min = max(0, x_min - np.random.randint(5, 20))
        x_max = min(W, x_max + np.random.randint(5, 20))
        y_min = max(0, y_min - np.random.randint(5, 20))
        y_max = min(H, y_max + np.random.randint(5, 20))
        bbox = [x_min, y_min, x_max, y_max]
        return bbox
    else:
        # no foreground pixels: fall back to a box covering the whole 256 x 256 image
        return [0, 0, 256, 256]
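As an illustrative check (not part of the original pipeline), we can run get_bounding_box() on one ground-truth mask from train_dataset to see the kind of prompt it produces:
# derive a box prompt from one ground-truth mask
example_mask = np.array(train_dataset[0]['mask'])
print(get_bounding_box(example_mask))  # [x_min, y_min, x_max, y_max] in pixel coordinates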
processor = SamProcessor.from_pretrained("facebook/sam-vit-base", do_normalize=False)
train_sam_ds = SAMDataset(dataset=train_dataset, processor=processor)
First, we will define SAMDataset for our custom dataset, a class that inherits from torch.utils.data.Dataset. When creating a class for a custom dataset in PyTorch, it is essential to include three mandatory methods: __init__(), __len__(), and __getitem__().
- The __init__() function in the SAMDataset class is automatically executed upon instantiation of the class. Within this function, 2 variables are assigned: self.dataset, which holds the set of image-mask pairs, and self.processor, which refers to the SAM processor. The SAM processor is designed to transform the data, ensuring it is appropriately formatted for the SAM model.
- The __len__() function returns the number of image-mask pairs in the dataset. Essentially, it indicates the total amount of data available in the dataset.
- The __getitem__() function is called when accessing data from the dataset. It begins by retrieving the data at the specified index, namely the image and its corresponding mask. The mask is converted into an array, a necessary step for extracting the bounding box location. Given that SAM is a promptable segmentation model, it requires a prompt to identify a valid mask for each image. In this project, the bounding box serves as this prompt. To determine the bounding box, we employ the get_bounding_box() function, which identifies the location of cancerous tissue in the ground truth mask. This function searches for nonzero pixels and creates a bounding box encompassing these areas. Both the MRI image and the generated prompt are then processed through self.processor() (the SAM processor). This processor outputs several values. Subsequently, we remove the batch dimension added by default during processing. Finally, we scale each pixel value of the ground truth mask by dividing by 255. This normalization ensures that pixel values range between 0 and 1, a crucial step for evaluating model performance later in the modeling phase and for preventing anomalous values during training. A quick inspection of one processed sample is sketched right after this list.
After defining the SAMDataset class, we proceed to utilize it. The first step is to define the processor using SamProcessor.from_pretrained(), specifying the variant of the processor to be used. In this instance, we opt for the variant released by Facebook: facebook/sam-vit-base. For exploring other variants, refer to the provided link.
SamProcessor.from_pretrained() has a do_normalize parameter, which controls whether pixel values are normalized by a specific mean and standard deviation, a process that can produce negative pixel values. By setting do_normalize to False, we explicitly indicate that we do not wish to normalize the pixel values in the MRI images, ensuring that no negative values are fed to the network.
By default, the processor performs rescaling (do_rescale = True), ensuring that each pixel value in the MRI images falls within the range of 0 to 1. Rescaling simply multiplies the pixel values by a constant, in this case 1/255. This rescaling is vital as it guarantees that the model processes values between 0 and 1 for both the image and the mask, maintaining consistency in data handling and model performance.
Finally, we put the class to use by passing train_dataset and processor to it.
Defining The DataLoader
During model training, data is not loaded all at once but in mini-batches. Each iteration involves passing a mini-batch of data through the model. Completing one pass of all the data constitutes an epoch, meaning an epoch comprises several iterations.
To manage the loading of these mini-batches, we employ the DataLoader class. This class accepts a batch_size argument, which defines the quantity of data to be passed to the network in each iteration. For our purposes, we will set batch_size = 2. This configuration allows for efficient data processing, balancing computational load and training speed.
# DataLoader to enable iteration during training process
# batch_size = 2
train_dataloader = DataLoader(train_sam_ds, batch_size=CFG.TRAIN_BATCH_SIZE, shuffle=False)
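As a final check, we can pull one mini-batch from the dataloader and look at the tensor shapes the model will receive; this assumes PyTorch’s default collate function, which stacks the per-sample numpy masks into a tensor.
# fetch one mini-batch and inspect its shapes
batch = next(iter(train_dataloader))
print(batch["pixel_values"].shape)       # torch.Size([2, 3, 1024, 1024])
print(batch["input_boxes"].shape)        # torch.Size([2, 1, 4])
print(batch["ground_truth_mask"].shape)  # torch.Size([2, 256, 256])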
Click here to read part 2 of this article.