
Vision Transformer for semantic segmentation on medical images. Practical uses and experiments.

10 min read · Mar 9, 2024

The focus of this article is the Vision Transformer (ViT) and its practical application to the semantic segmentation problem. I return to the task of segmenting areas of abnormality on MR images. I’ve already solved this task using U-Net and discussed it here. I’ve also described my solution for image classification with ViT on a custom dataset of medical images here.

Essential things about ViT are the following:

· ViT architecture is based on representing an image as a set of patches. Image patches are non-overlapping blocks of the image. Each block has an embedding vector, initially formed from the pixels contained in that block.

· The Transformer Encoder is the main part of ViT: its self-attention layers learn the similarity between patches according to their class affiliation. Besides self-attention, it contains a sequence of linear, normalization, and activation layers.

· A ViT model pre-trained on a large dataset (e.g. ImageNet21K) can be used for transfer learning on a custom dataset, and the fine-tuned model shows good performance.
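To make the first point concrete, here is a minimal NumPy sketch of splitting an image into non-overlapping patches and projecting each flattened patch to an embedding vector. The helper name image_to_patches, the random image, and the random projection matrix (standing in for learned weights) are illustrative assumptions; real ViT implementations typically perform the same projection with a strided convolution.

```python
import numpy as np


def image_to_patches(image: np.ndarray, patch: int) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping patch x patch blocks.

    Returns (num_patches, patch * patch * C): one flattened pixel vector per
    patch, in row-major order over the patch grid.
    """
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "H and W must be divisible by patch"
    gh, gw = h // patch, w // patch
    # (gh, patch, gw, patch, C) -> (gh, gw, patch, patch, C)
    blocks = image.reshape(gh, patch, gw, patch, c).transpose(0, 2, 1, 3, 4)
    return blocks.reshape(gh * gw, patch * patch * c)


rng = np.random.default_rng(0)
image = rng.random((192, 192, 3)).astype(np.float32)
patches = image_to_patches(image, patch=16)  # classic ViT: 16x16 patches

# A learned linear projection maps each flattened patch to an embedding;
# random weights stand in for the trained ones (embed_dim is an example value).
embed_dim = 768
W = rng.standard_normal((patches.shape[1], embed_dim)).astype(np.float32) * 0.02
tokens = patches @ W
print(patches.shape, tokens.shape)  # (144, 768) (144, 768)
```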

Essential things about U-Net are the following:

· U-Net consists of two parts: an Encoder and a Decoder. The Encoder contains a series of blocks that extract features and reduce the image resolution. The Decoder mirrors the Encoder and restores the image resolution.

· U-Net is one of the best CNN architectures for semantic segmentation of medical images.

An overview of ViTs for segmentation can be found in this article. In my segmentation system, I use Swin Transformer V2 from Hugging Face as the Encoder. Swin Transformer (Hierarchical Vision Transformer using Shifted Windows) processes patch embeddings in 4 encoder stages. The initial patch size is 4x4 pixels. Between stages, the patch size doubles because the embeddings of neighboring 2x2 patches from the previous stage are merged. This means that the spatial resolution of the image, expressed in patches, is halved at each subsequent stage. The picture below (from the Hugging Face documentation) shows the high-level architecture of Swin Transformer:


Note that this sequence of encoder blocks with down-sampling is similar to the high-level architecture of the U-Net Encoder discussed in my previous article. Also note that the encoder of a ViT for classification operates on 16x16 patches at every stage (see the picture above).
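The patch-merging step described above can be sketched in NumPy as follows. This is an illustration under assumptions, not the Hugging Face code: the helper name patch_merging and the random reduction matrix are hypothetical, the order in which the four neighbors are concatenated is assumed, and the layer normalization that Swin Transformer V2 applies around this step is omitted.

```python
import numpy as np


def patch_merging(tokens: np.ndarray, grid: int, W: np.ndarray) -> np.ndarray:
    """Merge each 2x2 group of neighboring patches into one larger patch.

    tokens: (grid * grid, C) embeddings in row-major order over the grid.
    W: (4 * C, 2 * C) reduction matrix (learned in the real model).
    Returns ((grid // 2) ** 2, 2 * C).
    """
    c = tokens.shape[1]
    x = tokens.reshape(grid, grid, c)
    # Concatenate the four neighbors of every 2x2 block along the channel axis.
    merged = np.concatenate(
        [x[0::2, 0::2], x[1::2, 0::2], x[0::2, 1::2], x[1::2, 1::2]], axis=-1
    )  # (grid / 2, grid / 2, 4C)
    return merged.reshape(-1, 4 * c) @ W  # reduce 4C -> 2C


rng = np.random.default_rng(0)
grid, c = 48, 192  # 192x192 input with 4x4 patches -> 48x48 patch grid
tokens = rng.standard_normal((grid * grid, c))
W = rng.standard_normal((4 * c, 2 * c)) * 0.02
out = patch_merging(tokens, grid, W)
print(out.shape)  # (576, 384): a 24x24 grid of 8x8-pixel patches

# Patch size in pixels and patch-grid side for each of the 4 stages (192x192 input).
stages = [(4 * 2**s, 192 // (4 * 2**s)) for s in range(4)]
print(stages)  # [(4, 48), (8, 24), (16, 12), (32, 6)]
```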

Several pre-trained Swin Transformer models are available, including a large model trained on the ImageNet21K dataset (~14 million images). A full segmentation pipeline consists of an Encoder and a Decoder, and the pre-trained Swin Transformer Encoder from Hugging Face can be reused for custom datasets. In other words, I use the pre-trained Swin Transformer large model as the Encoder and implement and train my own Decoder to build a full semantic segmentation system for my dataset.

Swin Transformer V2 from Hugging Face: look deep inside

Let’s look at the Swin Transformer V2 model from Hugging Face using the following code blocks:

Installation:

```
!pip install torchvision
!pip install torchinfo
!pip install -q git+https://github.com/huggingface/transformers.git
```

Imports:

```python
from PIL import Image
from torchinfo import summary
import torch
```

Google Drive mounting (for Google Colab):

```python
from google.colab import drive
drive.mount('/content/gdrive')
```

CUDA device setting:

```python
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
```

Large pre-trained model loading (trained on ImageNet21K):

```python
from transformers import AutoImageProcessor, Swinv2Model

image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-large-patch4-window12-192-22k")
model = Swinv2Model.from_pretrained("microsoft/swinv2-large-patch4-window12-192-22k").to(device)
```

image_processor defines the set of transforms applied to an input image, which is initially a PIL Image:

```
ViTImageProcessor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 192,
    "width": 192
  }
}
```

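The input PIL image is resized to 192x192 (resample=3 is bicubic resampling in PIL), rescaled from [0, 255] to [0, 1] (rescale_factor = 1/255), normalized with the ImageNet per-channel mean and standard deviation, and converted to a torch tensor.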
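The rescaling and normalization steps can be reproduced with NumPy using the values from the printed configuration. This is a sketch, not the Hugging Face implementation: the resize step is left out (the input is assumed to be already 192x192), and the helper name rescale_and_normalize is hypothetical.

```python
import numpy as np

# Values copied from the printed ViTImageProcessor configuration.
IMAGE_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGE_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
RESCALE_FACTOR = 1 / 255  # 0.00392156862745098


def rescale_and_normalize(pixels: np.ndarray) -> np.ndarray:
    """(H, W, 3) uint8 image, already resized -> (3, H, W) float32 array."""
    x = pixels.astype(np.float32) * RESCALE_FACTOR  # do_rescale: [0, 255] -> [0, 1]
    x = (x - IMAGE_MEAN) / IMAGE_STD                 # do_normalize, per channel
    return x.transpose(2, 0, 1)                      # channels-first, like pixel_values


white = np.full((192, 192, 3), 255, dtype=np.uint8)
out = rescale_and_normalize(white)
print(out.shape)  # (3, 192, 192)
```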

Model summary:

```python
summary(model=model, input_size=(1, 3, 192, 192), col_names=['input_size', 'output_size', 'num_params', 'trainable'])
```

This is a large model containing more than 195 million parameters.

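Calling

```python
model.eval()
```

in a notebook cell displays all the layers the model contains: the call switches the model to evaluation mode and returns the model itself, which the notebook then prints (print(model) shows the same structure).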

The Swin Transformer V2 model consists of the following parts:

· A patch-embeddings layer that forms 2304 = 48*48 patches of size 4x4 for an input image with a resolution of 192x192. For each patch, a linear projection produces an embedding vector of length 192.

· 4 encoder stages, each built from blocks of (shifted) window Multi-Head Self-Attention. At the end of each of the first three stages, patch merging doubles the patch size in pixels and halves the image resolution measured in patches.

· Layer normalization of the encoder output, producing last_hidden_state.

· Average pooling of the last_hidden_state tensor over all patches, producing the pooler_output vector, an image-level embedding that can be used for classification.
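The average pooling in the last item simply averages the patch embeddings. A NumPy sketch, where a random array is an assumed stand-in for last_hidden_state (36 patches of length 1536, the shape for a 192x192 input):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for last_hidden_state of one 192x192 image: 6 * 6 = 36 patches, 1536 dims.
last_hidden_state = rng.standard_normal((1, 36, 1536)).astype(np.float32)

# Average over the patch axis -> one embedding vector per image.
pooler_output = last_hidden_state.mean(axis=1)
print(pooler_output.shape)  # (1, 1536)
```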

Let’s see how it works step by step.

Load any image and pre-process it with image_processor:

```python
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = image_processor(image, return_tensors="pt")
```

The inputs are then passed to the Swin Transformer V2 model. The picture below shows that calling the parts of Swin Transformer V2 sequentially (left) is equivalent to calling the model as a whole (right):
Fig. 1

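Look at the outputs of the patch-embeddings layer (im0) and of the four encoder stages (im1 to im4, where im4 also includes the final layer normalization):

```python
print(im0.shape)
print(im1.shape)
print(im2.shape)
print(im3.shape)
print(im4.shape)
```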

We see the following shapes:

im0 -> torch.Size([1, 2304, 192]) -> 2304=48*48 — number of patches, 192 — patch-embeddings length

im1 -> torch.Size([1, 576, 384]) -> 576=24*24 — number of patches, 384 — patch-embeddings length

im2 -> torch.Size([1, 144, 768]) -> 144=12*12 — number of patches, 768 — patch-embeddings length

im3 -> torch.Size([1, 36, 1536]) -> 36=6*6 — number of patches, 1536 — patch-embeddings length

im4 -> torch.Size([1, 36, 1536]) -> 36=6*6 — number of patches, 1536 — patch-embeddings length
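To use these token sequences as 2D feature maps, a transpose plus a reshape is enough, because the patches are stored in row-major order over the patch grid (the Decoder code later in this article does the same with permute and reshape). A NumPy sketch, where tokens_to_feature_map is a hypothetical helper and a random array stands in for im0:

```python
import numpy as np


def tokens_to_feature_map(tokens: np.ndarray) -> np.ndarray:
    """(B, N, C) token sequence on a square patch grid -> (B, C, H, W) feature map."""
    b, n, c = tokens.shape
    side = int(round(n ** 0.5))
    assert side * side == n, "expected a square patch grid"
    # Row-major token order means a plain reshape recovers the (H, W) layout.
    return tokens.transpose(0, 2, 1).reshape(b, c, side, side)


rng = np.random.default_rng(0)
im0 = rng.standard_normal((1, 2304, 192)).astype(np.float32)  # stand-in for im0
fmap = tokens_to_feature_map(im0)
print(fmap.shape)  # (1, 192, 48, 48)

# Token number r * 48 + c corresponds to grid cell (r, c).
assert np.array_equal(fmap[0, :, 2, 5], im0[0, 2 * 48 + 5])
```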

These outputs of the pre-trained Swin Transformer V2 large model become the inputs of my Decoder model. I train the Decoder to produce segmentation masks for areas of abnormality on brain MRI. The flowchart below shows the high-level architecture of this Decoder:

Fig. 2

Note that the “image resizing to 256x256” step in the last block of the flowchart above is specific to my data rather than part of the general Decoder design: my brain MRI and segmentation mask images have a resolution of 256x256.

My implementation of a semantic segmentation system for brain MRI using Swin Transformer V2. Transformer vs U-Net

Let’s return to the task of semantic segmentation on brain MRI. I use a brain MRI dataset from Kaggle. This dataset contains data for 110 patients: for every patient, a set of MR images of brain slices and a set of corresponding images with masks of abnormality areas. The picture below shows examples of “brain slice image + mask image” pairs:

In the dataset, the number of “brain slice image + mask image” pairs per patient varies from 20 to 88. The whole set contains 3929 pairs: 2556 pairs with an empty (all-zero) mask and 1373 pairs with non-zero masks of abnormality areas.

I implement and train a Decoder model that takes the outputs of the Swin Transformer V2 stages (im0, im1, im2, im3, im4; see Fig. 1) as its inputs. I implement the model in PyTorch, following the flowchart in Fig. 2.

The code below pre-processes one image, given as a PIL Image, into the im0, im1, im2, im3, im4 tensors produced by the stages of the pre-trained Swin Transformer V2 model. The variable model in the code below is the Swin Transformer V2 model pre-trained on ImageNet21K and loaded in the previous section:

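```python
image_path = "brain_slice.tif"  # hypothetical placeholder: path to one MRI slice
img = Image.open(image_path).convert("RGB")
img = image_processor(images=img, return_tensors="pt")

with torch.no_grad():
    # Patch embeddings: x[0] is (1, 2304, 192), x[1] is the patch grid size (48, 48)
    x = model.embeddings(**img.to(device))
    input_dimensions = x[1]
    im0 = x[0].squeeze(0)

    # Stage 1 (patch merging at its end): 48x48 -> 24x24 patches
    x = model.encoder.layers[0](x[0], input_dimensions=input_dimensions)
    im1 = x[0].squeeze(0)

    # Stage 2: 24x24 -> 12x12 patches
    x = model.encoder.layers[1](x[0], input_dimensions=(input_dimensions[0]//2, input_dimensions[1]//2))
    im2 = x[0].squeeze(0)

    # Stage 3: 12x12 -> 6x6 patches
    x = model.encoder.layers[2](x[0], input_dimensions=(input_dimensions[0]//4, input_dimensions[1]//4))
    im3 = x[0].squeeze(0)

    # Stage 4 (no merging), followed by the final layer normalization
    x = model.encoder.layers[3](x[0], input_dimensions=(input_dimensions[0]//8, input_dimensions[1]//8))
    x = model.layernorm(x[0])
    im4 = x.squeeze(0)
```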

Note: I use squeeze(0) to remove the batch dimension of a single image, because these tensors are later fed through a PyTorch DataLoader, which adds the batch dimension back when it collates samples into batches.
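The reason for squeeze(0) can be illustrated with NumPy: PyTorch's default collate function stacks per-sample tensors along a new leading axis, much like np.stack does here. The five random stand-ins for im0 are an assumption for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)
# Five per-image im0 tensors after squeeze(0): (2304, 192) each.
samples = [rng.standard_normal((2304, 192)).astype(np.float32) for _ in range(5)]

# Stacking along a new leading axis yields the batch dimension back.
batch = np.stack(samples, axis=0)
print(batch.shape)  # (5, 2304, 192)

# Without squeeze(0), every sample would keep its own batch axis of size 1.
unsqueezed = np.stack([s[np.newaxis] for s in samples], axis=0)
print(unsqueezed.shape)  # (5, 1, 2304, 192)
```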

Mask images are only resized and converted to torch tensors. Batches of the five input tensors (im0 to im4) from the DataLoader are passed to the following model:

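```python
import torch
import torch.nn as nn
from torchvision import transforms


class Up_Linear(nn.Module):
    """Fuses two token sequences on the same patch grid and upsamples the grid 2x."""

    def __init__(self, in_ch, size, coef=1):
        super(Up_Linear, self).__init__()
        # PixelShuffle(2): (B, 4C, H, W) -> (B, C, 2H, 2W)
        self.shuffle = nn.PixelShuffle(upscale_factor=2)

        n_ch = int(coef * in_ch)  # hidden width of the MLP

        # MLP applied to every token of the concatenated (2 * in_ch) features
        self.ln = nn.Sequential(
            nn.Linear(in_ch * 2, n_ch),
            nn.ReLU(inplace=True),
            nn.Linear(n_ch, in_ch * 2),
            nn.ReLU(inplace=True),
        )

        self.size = size  # side of the input patch grid (size x size tokens)

    def forward(self, x1, x2):
        # x1, x2: (B, size*size, in_ch), concatenated along the channel axis
        x = torch.cat((x1, x2), 2)
        x = self.ln(x)                      # (B, size*size, 2*in_ch)
        x = x.permute(0, 2, 1)              # (B, 2*in_ch, size*size)
        x = torch.reshape(x, (x.shape[0], x.shape[1], self.size, self.size))
        x = self.shuffle(x)                 # (B, in_ch/2, 2*size, 2*size)
        x = torch.reshape(x, (x.shape[0], x.shape[1], self.size*self.size*4))
        x = x.permute(0, 2, 1)              # (B, 4*size*size, in_ch/2)
        return x


class MRI_Seg(nn.Module):
    def __init__(self):
        super(MRI_Seg, self).__init__()

        self.ups3 = Up_Linear(1536, 6, 1)   # 6x6 grid -> 12x12, 1536 -> 768 channels
        self.ups2 = Up_Linear(768, 12, 1)   # 12x12 grid -> 24x24, 768 -> 384 channels
        self.ups1 = Up_Linear(384, 24, 2)   # 24x24 grid -> 48x48, 384 -> 192 channels
        self.ups0 = Up_Linear(192, 48, 3)   # 48x48 grid -> 96x96, 192 -> 96 channels

        self.shuffle = nn.PixelShuffle(upscale_factor=2)

        # 1x1 convolution to a single-channel mask with values in [0, 1]
        self.out = nn.Sequential(
            nn.Conv2d(24, 1, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

    def forward(self, x0, x1, x2, x3, x4):
        # Skip connections as in U-Net: each step fuses the upsampled features
        # with the encoder output of the same resolution
        x = self.ups3(x4, x3)
        x = self.ups2(x, x2)
        x = self.ups1(x, x1)
        x = self.ups0(x, x0)                # (B, 9216, 96)

        x = x.permute(0, 2, 1)
        x = torch.reshape(x, (x.shape[0], x.shape[1], 96, 96))
        x = self.shuffle(x)                 # (B, 24, 192, 192)
        x = transforms.Resize((256, 256))(x)  # match the 256x256 masks

        x = self.out(x)                     # (B, 1, 256, 256)
        return x


net = MRI_Seg().to(device)
```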
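PixelShuffle is the step that lets Up_Linear double the patch grid: it trades a 4x reduction in channels for a 2x increase in height and width. The NumPy function below is written to follow the documented element ordering of nn.PixelShuffle (output pixel (h*r + i, w*r + j) of channel c comes from input channel c*r*r + i*r + j); it is an illustrative re-implementation, not PyTorch code.

```python
import numpy as np


def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """(B, C*r*r, H, W) -> (B, C, H*r, W*r)."""
    b, crr, h, w = x.shape
    c = crr // (r * r)
    x = x.reshape(b, c, r, r, h, w)    # split channels into (C, r_h, r_w)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # -> (B, C, H, r_h, W, r_w)
    return x.reshape(b, c, h * r, w * r)


# The ups3 step: 3072 channels on a 6x6 grid -> 768 channels on a 12x12 grid.
x = np.arange(1 * 3072 * 6 * 6, dtype=np.float32).reshape(1, 3072, 6, 6)
y = pixel_shuffle(x, 2)
print(y.shape)  # (1, 768, 12, 12)

# Output channel 1, pixel (2*2 + 1, 3*2 + 0) comes from input channel 1*4 + 1*2 + 0 = 6.
assert y[0, 1, 5, 6] == x[0, 6, 2, 3]
```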

The summary of this model:

```python
summary(model=net, input_size=[(1, 2304, 192), (1, 576, 384), (1, 144, 768), (1, 36, 1536), (1, 36, 1536)], col_names=['input_size', 'output_size', 'num_params', 'trainable'])
```

The model contains more than 13 million trainable parameters (comparable to the U-Net model).
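The parameter count can be checked by hand: only the nn.Linear layers and the final 1x1 convolution have weights (ReLU, PixelShuffle, and Resize have none). A plain-Python sketch of the arithmetic, with illustrative helper names:

```python
def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weight matrix + bias


def up_linear_params(in_ch: int, coef: float) -> int:
    n_ch = int(coef * in_ch)
    # nn.Linear(2*in_ch, n_ch) followed by nn.Linear(n_ch, 2*in_ch)
    return linear_params(2 * in_ch, n_ch) + linear_params(n_ch, 2 * in_ch)


blocks = {"ups3": (1536, 1), "ups2": (768, 1), "ups1": (384, 2), "ups0": (192, 3)}
counts = {name: up_linear_params(ch, coef) for name, (ch, coef) in blocks.items()}
conv_out = 24 * 1 * 1 * 1 + 1  # Conv2d(24, 1, kernel_size=1): weight + bias
total = sum(counts.values()) + conv_out
print(total)  # 13427929
```

Most of the parameters sit in ups3, whose MLP operates on 3072-dimensional concatenated features.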

I use the Binary Cross Entropy loss function to train the model so that the predicted masks are close to the label masks, with the Adam optimizer and a learning rate of 0.0001. As quality measures, I use both IoU (Intersection over Union) and Dice; IoU = 1 and Dice = 1 mean ideal quality. Note that for all results below, including the pictures, I threshold the masks generated by the trained model: mask pixel values below 0.5 are set to 0, and all other values are set to 1.
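The loss, thresholding, and metrics described above can be sketched in NumPy. This is not the evaluation code used for the results below: the helper names are hypothetical, returning 1.0 for a pair of empty masks is an assumption (the dataset contains many zero-mask images, and the article does not say how they are scored), and whether metrics are averaged per image or over all pixels is not specified. On the training side, the PyTorch counterparts are nn.BCELoss on the sigmoid outputs and torch.optim.Adam(net.parameters(), lr=1e-4).

```python
import numpy as np


def bce(prob: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between predicted probabilities and a 0/1 mask."""
    p = np.clip(prob, eps, 1 - eps)
    return float(-np.mean(target * np.log(p) + (1 - target) * np.log(1 - p)))


def binarize(prob: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    return (prob >= threshold).astype(np.uint8)  # < 0.5 -> 0, otherwise 1


def iou_dice(pred: np.ndarray, target: np.ndarray) -> tuple[float, float]:
    p, t = pred.astype(bool), target.astype(bool)
    inter = int(np.logical_and(p, t).sum())
    union = int(np.logical_or(p, t).sum())
    total = int(p.sum() + t.sum())
    if total == 0:  # both masks empty: scored as a perfect match (assumption)
        return 1.0, 1.0
    return inter / union, 2 * inter / total


prob = np.array([[0.9, 0.2], [0.6, 0.1]])
target = np.array([[1, 1], [0, 0]])
pred = binarize(prob)
print(pred.tolist())           # [[1, 0], [1, 0]]
print(iou_dice(pred, target))  # (0.3333333333333333, 0.5)
```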

The picture below compares the results of the U-Net architecture (obtained and presented here) with those of the MRI_Seg model (above). For both models, I’ve chosen the checkpoints showing the best results on the test set. To evaluate the performance of both the U-Net and the Transformer-based models, I use 550 randomly selected entries from the training set and all 394 entries in the test set (the same training and test sets are used for both models):

As we can see, the performance of the Transformer-based segmentation model is significantly lower than that of the U-Net model. I tried changing the Decoder architecture and the number of trainable parameters, but this did not improve the performance.

In my previous article, I experimented with an augmented dataset for the segmentation model. I applied a horizontal flip (swapping the left and right sides of the image) to the MRI and mask images and created a dataset with twice as many images as the source dataset. Now I use this augmented dataset to train my Transformer-based model. The picture below compares the results of the U-Net architecture (obtained and presented here) with those of the MRI_Seg model (above) on the augmented dataset. For both models, I’ve chosen the checkpoints showing the best results on the test set. To evaluate the performance of both models, I use 550 randomly selected entries from the training set and all 787 entries in the test set (the same training and test sets are used for both models):

We can see that the performance of the Transformer-based segmentation model trained on the augmented dataset increased significantly and became close to that of the U-Net model.
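The horizontal-flip augmentation can be sketched with NumPy for in-memory arrays. The helper add_horizontal_flips and the random stand-in images are illustrative assumptions; loading and saving the dataset files is not shown. The key point is that each image and its mask are flipped together, so the mask stays aligned with the anatomy.

```python
import numpy as np


def add_horizontal_flips(images, masks):
    """Return the original pairs plus a left-right mirrored copy of each pair."""
    aug_images, aug_masks = list(images), list(masks)
    for img, msk in zip(images, masks):
        # Flip along the width axis (axis 1); image and mask are flipped together.
        aug_images.append(np.flip(img, axis=1).copy())
        aug_masks.append(np.flip(msk, axis=1).copy())
    return aug_images, aug_masks


rng = np.random.default_rng(0)
images = [rng.integers(0, 256, (256, 256, 3), dtype=np.uint8) for _ in range(4)]
masks = [rng.integers(0, 2, (256, 256), dtype=np.uint8) for _ in range(4)]
aug_images, aug_masks = add_horizontal_flips(images, masks)
print(len(aug_images), len(aug_masks))  # 8 8
```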

The pictures below show the results of the trained Transformer-based model on test images.

Conclusion

· The high-level architecture of Swin Transformer V2 makes it possible to implement a custom Decoder for a semantic segmentation system using the concepts of the U-Net model.

· The custom dataset should be large enough (~10K images). In this case, the performance of the Transformer-based semantic segmentation system is good.

· A Transformer-based semantic segmentation system can achieve performance close to that of the best segmentation models, such as U-Net.

Links:

The full code in Google Colab, along with additional text data, is available on GitHub.
