Segmentation of abnormality areas on medical images. Practical uses of U-Net and experiments.

Olga Mindlina
Feb 23, 2024


Semantic segmentation is one of the important problems in Computer Vision. It is used to identify groups of pixels associated with a particular feature or object. Such tasks are often carried out on medical images.

In this article, I discuss segmentation of abnormality areas on MR images. I use a brain MRI dataset from Kaggle. The dataset contains data from 110 patients: for every patient, there is a set of MR images of brain slices and a set of corresponding images with masks of the abnormality areas. The picture below shows examples of “brain slice image + mask image” pairs:

In the dataset, the number of pairs “brain slice image + mask image” for each person varies from 20 pairs to 88 pairs. The whole set contains 3935 pairs: 2556 pairs with zero-mask and 1379 pairs with non-zero masks for areas of abnormality.
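A “zero mask” here is a mask image with no abnormality pixels at all. As a minimal sketch, assuming the masks are already loaded as NumPy arrays (file loading is omitted, and `count_mask_types` is a hypothetical helper, not part of the dataset tooling), the zero/non-zero split could be counted like this:

```python
import numpy as np


def count_mask_types(masks):
    """Return (number of zero masks, number of non-zero masks).

    A mask counts as "zero" when none of its pixels is set, i.e. the
    slice has no annotated abnormality.
    """
    zero = sum(1 for mask in masks if not np.any(mask))
    return zero, len(masks) - zero


# Usage with toy 4x4 masks: two empty masks and one with a marked 2x2 region
toy = [np.zeros((4, 4)), np.zeros((4, 4)), np.pad(np.ones((2, 2)), 1)]
print(count_mask_types(toy))  # (2, 1)
```

Run over all 3935 masks, such a count would be expected to return the 2556 / 1379 split quoted above.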

Solving the semantic segmentation task for brain MRI. 2D U-Net

I need to build a system that can create a mask image of the abnormality area (including zero-mask cases) from an input MR image of a brain slice. I use the dataset described above to train and test the model, with MR images of brain slices as inputs and the known mask images as labels. The dataset is relatively small, and the number of images varies among patients. I use the 2D U-Net implementation from this link, with some minor modifications: the numbers of convolution channels are set through a parameter instead of constant values, and a sigmoid activation is added to the output tensor, since the output represents a mask for one class only (the abnormality area). Below is the code of my model:

import torch
import torch.nn as nn
import torch.nn.functional as F  # F.pad is used in Up.forward

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) twice; padding=1 keeps height and width."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)


class InConv(nn.Module):
    """First Encoder block: DoubleConv without down-sampling."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class Down(nn.Module):
    """Encoder block: 2x2 max pooling (halves height and width), then DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        return self.mpconv(x)


class Up(nn.Module):
    """Decoder block: up-sample x1, concatenate it with the Encoder output x2, then DoubleConv."""

    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Pad x1 so its height and width match x2 (they can differ for odd sizes)
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))
        x = torch.cat([x2, x1], dim=1)  # skip connection along the channel axis
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution mapping the feature channels to the class channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):

    def __init__(self, in_channels, classes, first_out_channels=64):
        super().__init__()
        self.n_channels = in_channels
        self.n_classes = classes
        # Channel widths are derived from one parameter instead of constants
        self.nc = first_out_channels
        self.nc2 = 2 * self.nc
        self.nc4 = 4 * self.nc
        self.nc8 = 8 * self.nc

        self.inc = InConv(in_channels, self.nc)
        self.down1 = Down(self.nc, self.nc2)
        self.down2 = Down(self.nc2, self.nc4)
        self.down3 = Down(self.nc4, self.nc8)
        self.down4 = Down(self.nc8, self.nc8)
        # Each Up block receives the up-sampled tensor plus the skip tensor
        self.up1 = Up(2 * self.nc8, self.nc4)
        self.up2 = Up(self.nc8, self.nc2)
        self.up3 = Up(self.nc4, self.nc)
        self.up4 = Up(self.nc2, self.nc)
        self.outc = OutConv(self.nc, classes)

    def forward(self, x):
        x1 = self.inc(x)        # Encoder
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)    # Decoder with skip connections
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        # Sigmoid turns logits into per-pixel probabilities for the single class
        return torch.sigmoid(x)


# UNet model creation: 3-channel (RGB) input, 1-channel mask output
net = Unet(3, 1).to(device)

In the last line of the code, I create a U-Net model object that takes an RGB image of a brain slice (3 channels) as input and produces a mask image of the abnormality area (1 channel) as output.

Let’s look inside the U-Net architecture. U-Net consists of two parts: an Encoder and a Decoder. The Encoder contains a series of convolution blocks, each followed by max pooling that reduces the image size. Its architecture is similar to that of classifiers, and it performs local and global feature extraction. The Decoder mirrors the Encoder: it contains a series of up-sampling blocks, each followed by a convolution block, and it finally restores the size of its input data to the size of the input image. A key element of the Decoder is the concatenation of its up-sampled features with the outputs of the symmetric Encoder blocks (skip connections). This lets the network associate the class features extracted by the Encoder with their spatial locations and form a segmentation mask. The picture below shows a schema of the U-Net architecture:
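To make the Encoder/Decoder symmetry concrete, the sketch below traces channel counts and spatial sizes through the `Unet` class above for a 256x256 input and `first_out_channels=64`. It is plain Python bookkeeping (no PyTorch needed), and `trace_unet` is a hypothetical helper; it shows why each `Up` block is constructed with twice the channels of its skip connection.

```python
def trace_unet(size=256, nc=64):
    """Trace (channels, height=width) through the Unet class defined above."""
    # Encoder: InConv keeps the size, each Down halves height and width
    x1 = (nc, size)
    x2 = (2 * nc, size // 2)
    x3 = (4 * nc, size // 4)
    x4 = (8 * nc, size // 8)
    x5 = (8 * nc, size // 16)  # bottleneck keeps 8 * nc channels
    encoder = [x1, x2, x3, x4, x5]

    # Decoder: up-sample x, concatenate with the skip tensor, then DoubleConv
    decoder = []
    x = x5
    for name, skip, out_ch in [("up1", x4, 4 * nc), ("up2", x3, 2 * nc),
                               ("up3", x2, nc), ("up4", x1, nc)]:
        in_ch = x[0] + skip[0]  # channels after torch.cat([x2, x1], dim=1)
        decoder.append((name, in_ch, out_ch, skip[1]))
        x = (out_ch, skip[1])   # up-sampling restores the skip tensor's size
    return encoder, decoder


encoder, decoder = trace_unet()
for row in decoder:
    print(row)
# ('up1', 1024, 256, 32)
# ('up2', 512, 128, 64)
# ('up3', 256, 64, 128)
# ('up4', 128, 64, 256)
```

The `in_ch` column (1024, 512, 256, 128) matches the first arguments passed to `Up` in `Unet.__init__`, and the last row is back at 256x256 with 64 channels, which `OutConv` maps to the single mask channel.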

Model training and results

I split the dataset into a training set and a test set. The training set contains 3541 MR images of different brain-slice views with their mask images (either with an abnormality or zero masks): 2300 images have zero masks, and 1241 images have masks with abnormality segments. The test set contains 394 MR images of different brain-slice views with their mask images: 256 images have zero masks, and 138 images have masks with abnormality segments.
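The article does not say how the split was made, but both groups are split at about 10% (256 of 2556 zero-mask pairs and 138 of 1379 non-zero pairs), which is what a split stratified by mask type would produce. A minimal sketch under that assumption, using only the standard library (`stratified_split` is a hypothetical helper; the toy labels stand in for image/mask pairs):

```python
import random


def stratified_split(items, has_abnormality, test_fraction=0.1, seed=0):
    """Split items into (train, test), keeping the zero/non-zero mask ratio.

    has_abnormality(item) -> bool tells whether the item's mask is non-zero.
    """
    rng = random.Random(seed)
    train, test = [], []
    for flag in (False, True):
        group = [item for item in items if has_abnormality(item) == flag]
        rng.shuffle(group)
        n_test = round(test_fraction * len(group))
        test.extend(group[:n_test])
        train.extend(group[n_test:])
    return train, test


# Toy stand-in for the dataset: 1 marks a pair with a non-zero mask
labels = [0] * 2556 + [1] * 1379
train, test = stratified_split(labels, lambda item: item == 1)
print(len(train), sum(train), len(test), sum(test))  # 3541 1241 394 138
```

Splitting at the pair level like this can place slices of the same patient in both sets; grouping by patient first would avoid that, and the article does not state which approach was used.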

To prepare an input brain-slice image, the following torchvision transforms are applied to the input data img, which is initially a NumPy array:

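import torch
from torchvision import transforms

img = transforms.ToTensor()(img)             # HxWxC NumPy array -> CxHxW float tensor
img = transforms.Resize((256, 256))(img)     # resize to 256x256
s, m = torch.std_mean(img, dim=(0, 1, 2))    # std and mean over all channels and pixels of this image
img = transforms.Normalize(m, 2*s)(img)      # (img - mean) / (2 * std)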

Conversion to a torch tensor, resizing, and normalization are applied to MR images. Only conversion to a torch tensor and resizing are applied to mask images.

All images are resized to the resolution 256x256.
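The `Normalize(m, 2*s)` call above uses a single mean and standard deviation computed over all channels and pixels of the current image, so each slice is normalized individually rather than with dataset-wide statistics. A NumPy sketch of the same arithmetic (the helper name `normalize_slice` is hypothetical; `ddof=1` mirrors the unbiased estimator that `torch.std_mean` uses by default):

```python
import numpy as np


def normalize_slice(img):
    """Per-image normalization: (img - mean) / (2 * std).

    mean and std are computed over every channel and pixel of this image,
    matching torch.std_mean(img, dim=(0, 1, 2)) followed by
    transforms.Normalize(m, 2 * s). A constant image (std == 0) is not handled.
    """
    img = np.asarray(img, dtype=np.float64)
    m = img.mean()
    s = img.std(ddof=1)  # unbiased std, as torch.std_mean computes by default
    return (img - m) / (2 * s)


rng = np.random.default_rng(0)
x = rng.random((3, 256, 256))  # stand-in for a CxHxW tensor with values in [0, 1]
y = normalize_slice(x)
print(abs(round(y.mean(), 6)), round(y.std(ddof=1), 6))  # 0.0 0.5
```

Every normalized slice therefore ends up with mean 0 and standard deviation 0.5, regardless of the scanner's original intensity range.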

Normalization of brain-slice images is important. My experiments with training the model with and without input-image normalization show that the quality of the model trained on normalized images is ~3% higher. The picture below shows the same input image with and without normalization:

I train the U-Net model described in the previous section. The summary of this model:

from torchinfo import summary
summary(model=net, input_size=(1, 3, 256, 256), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

The model contains more than 13 million trainable parameters.
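The “more than 13 million” figure can be checked by hand. The sketch below counts the trainable parameters of `Unet(3, 1)` from the layer definitions (3x3 convolutions with bias, two BatchNorm layers per `DoubleConv`, parameter-free bilinear up-sampling, and a 1x1 output convolution). The helper names are hypothetical; with PyTorch available, `sum(p.numel() for p in net.parameters() if p.requires_grad)` should give the same total.

```python
def double_conv_params(in_ch, out_ch):
    """Trainable parameters of one DoubleConv block."""
    conv1 = 3 * 3 * in_ch * out_ch + out_ch   # weights + bias
    conv2 = 3 * 3 * out_ch * out_ch + out_ch
    batch_norms = 2 * (2 * out_ch)            # weight and bias per BatchNorm2d
    return conv1 + conv2 + batch_norms


def unet_params(in_channels=3, classes=1, nc=64):
    """Trainable parameters of Unet(in_channels, classes, nc)."""
    blocks = [
        (in_channels, nc), (nc, 2 * nc), (2 * nc, 4 * nc),   # inc, down1, down2
        (4 * nc, 8 * nc), (8 * nc, 8 * nc),                  # down3, down4
        (16 * nc, 4 * nc), (8 * nc, 2 * nc),                 # up1, up2
        (4 * nc, nc), (2 * nc, nc),                          # up3, up4
    ]
    total = sum(double_conv_params(i, o) for i, o in blocks)
    # Bilinear nn.Upsample has no parameters; OutConv is a 1x1 conv with bias
    return total + nc * classes + classes


print(unet_params())  # 13395329
```

Most of the parameters sit in the two deepest blocks (`down3` and `down4`), where the 3x3 convolutions operate on 512 channels.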

I use the Binary Cross Entropy loss function to train the model to produce masks close to the label masks, with the Adam optimizer and a learning rate of 0.0001. As quality measures I use both IoU (Intersection over Union) and Dice; IoU = 1 and Dice = 1 mean ideal quality. During training, I save a model checkpoint every 2 epochs to choose the one showing the best performance on the test set. The best result was obtained after 30 epochs of training. Note that for all results below, including pictures, the segmentation masks generated by the trained model are thresholded: mask pixel values below 0.5 are set to 0, and all other values are set to 1. On the test set the following quality is achieved (the results below are mean values of IoU and Dice over 394 test entries):

test set: IoU = 0.8434411330978429, Dice = 0.9150724891122154

On the training set, the same model shows the following quality (the results below are mean values of IoU and Dice on 550 randomly selected training entries):

training set: IoU = 0.9289778242365415, Dice = 0.9631814451824671
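The training loss and the evaluation metrics described above can be sketched in NumPy as follows. `bce_loss` follows the formula of PyTorch’s `nn.BCELoss` with its default mean reduction (including its clamping of log values at -100); `iou_dice` applies the 0.5 threshold and computes per-image scores. The article does not state how images where both the prediction and the label are empty are scored; here that case is assumed to count as a perfect match (IoU = Dice = 1), which matters because most images have zero masks. Both function names are hypothetical.

```python
import numpy as np


def bce_loss(pred, target):
    """Mean binary cross-entropy over all pixels, as nn.BCELoss computes it.

    pred holds sigmoid outputs in [0, 1]; target holds 0/1 mask values.
    Log values are clamped at -100 (as PyTorch does) to avoid -inf.
    """
    with np.errstate(divide="ignore"):
        log_p = np.maximum(np.log(pred), -100.0)
        log_1mp = np.maximum(np.log(1.0 - pred), -100.0)
    return float(-np.mean(target * log_p + (1 - target) * log_1mp))


def iou_dice(prob, target, threshold=0.5):
    """Per-image IoU and Dice after thresholding the predicted mask."""
    pred = prob >= threshold           # values < 0.5 -> 0, others -> 1
    gt = np.asarray(target) > 0
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        # Assumption: an empty prediction on a zero mask is a perfect match
        return 1.0, 1.0
    return float(inter / union), float(2 * inter / (pred.sum() + gt.sum()))


prob = np.array([[0.9, 0.6], [0.2, 0.4]])
target = np.array([[1, 0], [1, 0]])
print(iou_dice(prob, target))                            # (0.3333333333333333, 0.5)
print(round(bce_loss(np.full((2, 2), 0.5), target), 4))  # 0.6931
```

Averaging `iou_dice` over all test images gives mean scores of the kind reported above. In PyTorch, the corresponding training setup would use `nn.BCELoss()` together with `torch.optim.Adam(net.parameters(), lr=1e-4)`.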

The pictures below show the trained model’s output on test images:

Note: I’ve tried U-Net modifications with fewer parameters (only 2 “Downsampling-Convolution” blocks in the Encoder + 2 “Upsampling-Convolution” blocks in the Decoder, or 3 “Downsampling-Convolution” blocks in the Encoder + 3 “Upsampling-Convolution” blocks in the Decoder), as well as with more blocks in the Encoder and Decoder. None of them improved the quality of the trained model on the training and test sets: the classic U-Net shows the best results.

Data augmentation and performance improvement

The dataset I use to train the model is relatively small: fewer than 4000 MRI + mask pairs. I expect better performance from a model trained on a larger, more representative dataset. According to doctors, horizontal flip (swapping the left and right sides of the image) is an entirely valid augmentation for this medical data. I’ve created a dataset with twice as many images as the source dataset by applying OpenCV’s flip function to both the MR images and the masks:

import cv2
img_f = cv2.flip(img, 1)  # flipCode=1: horizontal flip (mirror left and right)
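The same flip has to be applied to each image and its mask so that the labels stay aligned. A minimal sketch of the doubling step, assuming images and masks are NumPy arrays with height as the first axis; it uses NumPy slicing `arr[:, ::-1]`, which reverses the width axis just like `cv2.flip(arr, 1)`, so the example runs without OpenCV. `flip_augment` is a hypothetical helper.

```python
import numpy as np


def flip_augment(images, masks):
    """Return the original pairs plus their horizontally flipped copies.

    arr[:, ::-1] reverses the column (width) axis, the same as
    cv2.flip(arr, 1), for both HxW masks and HxWxC images.
    """
    flipped_images = [img[:, ::-1].copy() for img in images]
    flipped_masks = [mask[:, ::-1].copy() for mask in masks]
    return list(images) + flipped_images, list(masks) + flipped_masks


img = np.arange(6).reshape(2, 3)
mask = (img > 3).astype(np.uint8)
aug_images, aug_masks = flip_augment([img], [mask])
print(aug_images[1].tolist())  # [[2, 1, 0], [5, 4, 3]]
print(aug_masks[1].tolist())   # [[0, 0, 0], [1, 1, 0]]
```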

The picture below shows an example of the flip applied to an image and its mask:

I’ve trained the model on the augmented dataset. During training, I save a model checkpoint every 2 epochs to choose the one showing the best performance on the test set. The best result was obtained after 30 epochs of training. On the test set the following quality is achieved (the results below are mean values of IoU and Dice over 787 test entries):

test set: IoU = 0.853091030096872, Dice = 0.9207222054841806

On the augmented training set, the same model shows the following quality (the results below are mean values of IoU and Dice on 550 randomly selected training entries):

training set: IoU = 0.9220136233462951, Dice = 0.9594246493851964

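As we can see, augmenting the dataset slightly improves the model’s performance on the test set: IoU rises from 0.843 to 0.853 and Dice from 0.915 to 0.921 (note that the augmented test set, with 787 entries, is larger than the original one with 394).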

Conclusion

· U-Net can segment medical images with high quality.

· Normalizing input images during pre-processing helps improve the performance of the segmentation model.

· Data augmentation helps improve the performance of the segmentation model.

This article is the first part of my research on medical image segmentation. In the next part, I’ll try to solve this problem using a Transformer for image segmentation.
