The Impact of Attention Mechanisms on U-Net Performance in Skin Cancer Segmentation: A Case Study
Introduction
In today’s landscape of artificial intelligence, attention mechanisms have emerged as pivotal tools that have transformed model capabilities across diverse domains. Although they first became prominent in natural language processing, attention mechanisms are increasingly proving their worth in computer vision, particularly in image segmentation. In this article we explore their impact on a classic application, skin cancer segmentation, which serves as a case study for revisiting the potential of attention in computer vision segmentation.
Dataset
We use the Skin Cancer Segmentation dataset, which contains RGB images of skin lesions together with their ground-truth segmentation masks: 900 training images and 379 test images. The training images are further split with a 4:1 ratio into 720 training images and 180 validation images.
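The 4:1 split can be done with PyTorch's `random_split`. This is a minimal sketch that assumes the image/mask pairs are already wrapped in a `Dataset`; the random tensors below are placeholders for the real data, and the seed value is arbitrary.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for the 900 training image/mask pairs
# (tiny 8x8 tensors so the example runs instantly).
images = torch.randn(900, 3, 8, 8)
masks = (torch.rand(900, 1, 8, 8) > 0.5).float()
full_train = TensorDataset(images, masks)

# 4:1 split -> 720 training and 180 validation samples.
# A fixed generator seed makes the split reproducible across runs.
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(full_train, [720, 180], generator=generator)

print(len(train_set), len(val_set))  # 720 180
```

Fixing the seed matters when comparing two models: both U-Net variants should be trained and validated on exactly the same split.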
Modelling
For modelling we use U-Net. We chose U-Net for its interpretability: the model is easy to implement and understand, and it performs decently. U-Net was also originally designed for biomedical image segmentation, so it fits our use case well.
a. U-Net
U-Net consists of two main parts: an encoder (the contracting path) and a decoder (the expanding path). The encoder learns feature maps at progressively lower resolutions, and the decoder reuses them, passed across through skip connections, to build up the segmentation map.
The encoder extracts the image’s features through repeated convolutions and max-pooling, compressing the input into a low-resolution, high-channel feature map at the bottleneck (U-Net does not flatten it into a vector). The decoder then expands this representation back towards the input resolution. Each decoding block contains three components: an upsampling layer (such as a transposed convolution or an interpolation algorithm), a concatenation with the encoder feature map at the same level, and a double convolution block. The decoder has as many blocks as the encoder, and the last one produces the final segmentation map.
Each convolution block in our U-Net consists of two 3×3 convolution layers, each followed by batch normalization and a ReLU activation. The first convolution changes the number of channels from ch_in to ch_out: in the encoder it doubles them (e.g. 64 to 128) to capture richer features, while in the decoder it halves them (e.g. 1024 to 512) to aggregate the concatenated features and keep the architecture symmetric. The second convolution keeps ch_out channels and extracts further, more in-depth features.
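```python
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Double convolution: (3x3 conv -> BatchNorm -> ReLU) x 2."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(
            # first conv changes the channel count; padding=1 keeps H and W
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1,
                      padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            # second conv refines features at ch_out channels
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1,
                      padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
```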
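A quick shape check on dummy tensors shows why the padded convolutions in ConvBlock keep the spatial size, which is what later makes cropping optional. The 64×64 input size is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 64, 64)  # dummy RGB batch: (N, C, H, W)

# 3x3 convolution with padding=1, as used in ConvBlock: H and W are preserved,
# only the channel count changes.
same = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
print(same(x).shape)   # torch.Size([1, 64, 64, 64])

# Without padding (as in the original U-Net paper) each 3x3 convolution
# trims one pixel from every border, so two of them shrink H and W by 4.
valid = nn.Sequential(nn.Conv2d(3, 64, 3), nn.Conv2d(64, 64, 3))
print(valid(x).shape)  # torch.Size([1, 64, 60, 60])
```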
Next we define the up-convolution block, which is responsible for upsampling (roughly the reverse of pooling). In this case study we use nn.Upsample with PyTorch’s default nearest-neighbour interpolation, followed by a 3×3 convolution that halves the number of channels.
```python
class UpConvBlock(nn.Module):
    """Upsample by 2x (nearest neighbour), then 3x3 conv -> BatchNorm -> ReLU."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),  # mode="nearest" by default
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1,
                      padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.up(x)
```
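To make "nearest-neighbour upsampling" concrete, here is what nn.Upsample(scale_factor=2) does to a tiny 2×2 input: every pixel is simply repeated into a 2×2 patch, with no new values interpolated. The following 3×3 convolution in UpConvBlock is what learns to smooth and refine this blocky result.

```python
import torch
import torch.nn as nn

x = torch.tensor([[[[1., 2.],
                    [3., 4.]]]])       # shape (1, 1, 2, 2)

up = nn.Upsample(scale_factor=2)      # mode defaults to "nearest"
y = up(x)
print(y.shape)  # torch.Size([1, 1, 4, 4])
print(y[0, 0])
# tensor([[1., 1., 2., 2.],
#         [1., 1., 2., 2.],
#         [3., 3., 4., 4.],
#         [3., 3., 4., 4.]])
```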
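Before concatenating, the original U-Net crops the encoder feature map to the size of the upsampled decoder map, because its unpadded convolutions lose border pixels. Since our convolutions use padding=1, the sizes already match whenever the input height and width are divisible by 16, so this step is optional here; it mainly guards against size mismatches when they are not. The function below center-crops the skip tensor to the spatial size of the decoder tensor.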
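```python
def crop_image(tensor, tensor_target):
    """Center-crop `tensor` (N, C, H, W) to the spatial size of `tensor_target`.

    Height and width are handled separately, and odd size differences are
    supported, so the result always matches `tensor_target` exactly.
    """
    target_h, target_w = tensor_target.shape[2:]
    h, w = tensor.shape[2:]
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]
```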
Putting it all together, the final model takes a three-channel RGB image as input and outputs a one-channel segmentation map of raw logits (no sigmoid is applied inside the model).
```python
class UNet(nn.Module):
    def __init__(self, n_classes=1, in_channel=3, out_channel=1):
        # n_classes is unused; out_channel sets the number of output maps
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # encoder: channels double at each level
        self.conv1 = ConvBlock(ch_in=in_channel, ch_out=64)
        self.conv2 = ConvBlock(ch_in=64, ch_out=128)
        self.conv3 = ConvBlock(ch_in=128, ch_out=256)
        self.conv4 = ConvBlock(ch_in=256, ch_out=512)
        self.conv5 = ConvBlock(ch_in=512, ch_out=1024)  # bottleneck
        # decoder: upsample, concatenate the skip (doubling channels), convolve
        self.up5 = UpConvBlock(ch_in=1024, ch_out=512)
        self.upconv5 = ConvBlock(ch_in=1024, ch_out=512)
        self.up4 = UpConvBlock(ch_in=512, ch_out=256)
        self.upconv4 = ConvBlock(ch_in=512, ch_out=256)
        self.up3 = UpConvBlock(ch_in=256, ch_out=128)
        self.upconv3 = ConvBlock(ch_in=256, ch_out=128)
        self.up2 = UpConvBlock(ch_in=128, ch_out=64)
        self.upconv2 = ConvBlock(ch_in=128, ch_out=64)
        # 1x1 conv maps the 64 features to out_channel logits per pixel
        self.conv_1x1 = nn.Conv2d(64, out_channel,
                                  kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder
        x1 = self.conv1(x)
        x2 = self.conv2(self.maxpool(x1))
        x3 = self.conv3(self.maxpool(x2))
        x4 = self.conv4(self.maxpool(x3))
        x5 = self.conv5(self.maxpool(x4))

        # decoder + concat (skip features first, then upsampled features)
        d5 = self.up5(x5)
        x4 = crop_image(x4, d5)
        d5 = self.upconv5(torch.cat((x4, d5), dim=1))

        d4 = self.up4(d5)
        x3 = crop_image(x3, d4)
        d4 = self.upconv4(torch.cat((x3, d4), dim=1))

        d3 = self.up3(d4)
        x2 = crop_image(x2, d3)
        d3 = self.upconv3(torch.cat((x2, d3), dim=1))

        d2 = self.up2(d3)
        x1 = crop_image(x1, d2)
        d2 = self.upconv2(torch.cat((x1, d2), dim=1))

        return self.conv_1x1(d2)  # logits, shape (N, out_channel, H, W)


# don't forget to initialize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet(n_classes=1).to(device)
```
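Because the encoder halves the spatial size four times, each skip connection lines up with its upsampled decoder tensor only when the input height and width are divisible by 2^4 = 16; otherwise max-pooling floors odd sizes, pixels are lost, and cropping (or resizing the inputs) becomes necessary. The hypothetical helper `trace_sizes` below only does this arithmetic for one dimension; it is an illustration, not part of the model.

```python
def trace_sizes(size, depth=4):
    """Follow one spatial dimension through `depth` 2x2 max-pools and back.

    Returns (encoder_sizes, decoder_sizes). The skip at decoder level k
    (1-based) is encoder_sizes[depth - k]; it lines up with the upsampled
    tensor decoder_sizes[k - 1] only if the two numbers are equal.
    """
    encoder = [size]
    for _ in range(depth):
        encoder.append(encoder[-1] // 2)  # MaxPool2d(2, 2) floors odd sizes
    # each Upsample(scale_factor=2) doubles the bottleneck size again
    decoder = [encoder[-1] * 2 ** k for k in range(1, depth + 1)]
    return encoder, decoder


print(trace_sizes(256))  # ([256, 128, 64, 32, 16], [32, 64, 128, 256])
print(trace_sizes(250))  # ([250, 125, 62, 31, 15], [30, 60, 120, 240])
```

With 256 every level matches, so crop_image does nothing; with 250 the skips are larger than the decoder tensors (31 vs 30 at the deepest level), and only a crop that handles odd differences can make them concatenable.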
b. Soft Attention U-Net
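Attention gives more weight to regions of interest by scaling features with learned coefficients. Attention U-Net applies this idea to the skip connections: each attention gate takes two inputs, the gating signal g, which comes from the decoder (the upsampled output of the previous decoding block), and the skip features x, which come directly from the encoder feature map at the same level.
Both inputs pass through 1×1 convolutions with batch normalization that project them to the same number of intermediate channels, and the results are added together (additive attention, rather than the dot-product attention common in NLP). This addition reinforces features that are strong in both inputs (aligned features) while diminishing the influence of unaligned ones, guiding the model to focus on relevant regions during segmentation. The sum then goes through a ReLU, a 1×1 convolution down to a single channel, and a sigmoid, producing an attention coefficient between 0 and 1 for every pixel. These coefficients multiply the skip features x, scaling up the points of interest and suppressing the rest.
In this case study we use soft attention: the coefficients are continuous values from a sigmoid rather than hard 0/1 selections, so the whole gate is differentiable and can be trained end to end with ordinary backpropagation, adapting its weighting to each input image.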
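Written out, the gate computes the following. The notation mirrors the AttentionBlock code (g, x, f_int), with * denoting a 1×1 convolution; W_g maps f_g channels and W_x maps f_l channels to f_int, and ψ maps f_int channels to 1. The batch-normalization terms come from the code rather than from the original paper's formula.

```latex
\begin{aligned}
q &= \mathrm{ReLU}\big(\mathrm{BN}(W_g * g + b_g) + \mathrm{BN}(W_x * x + b_x)\big) \\
\alpha &= \sigma\big(\mathrm{BN}(\psi * q + b_\psi)\big) \in (0, 1)^{H \times W} \\
\hat{x} &= \alpha \odot x
\end{aligned}
```

The single-channel map α is broadcast across all channels of x, so every feature at a given pixel is scaled by the same coefficient.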
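The difference between soft and hard gating shows up directly in autograd. In this small sketch, random `logits` stand in for the gate's pre-sigmoid scores (an assumption for illustration); the soft gate passes gradients back to them, while a thresholded 0/1 gate cuts the graph.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 8, 4, 4)                           # skip-connection features
logits = torch.randn(1, 1, 4, 4, requires_grad=True)  # pre-sigmoid gate scores

# Soft gate: sigmoid coefficients in (0, 1); the product is differentiable,
# so gradients reach the gate and, in a real model, its conv weights.
soft = torch.sigmoid(logits) * x
soft.sum().backward()
print(logits.grad.abs().sum() > 0)  # tensor(True)

# Hard gate: thresholding to {0, 1} is a step function with no gradient,
# so autograd drops the connection to `logits` entirely.
hard = (torch.sigmoid(logits) > 0.5).float() * x
print(hard.requires_grad)  # False
```

Hard attention can still be trained, but it needs workarounds such as sampling-based estimators; soft attention avoids that by staying differentiable.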
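```python
class AttentionBlock(nn.Module):
    """Additive attention gate for a skip connection.

    g: gating signal from the decoder (f_g channels)
    x: skip features from the encoder (f_l channels), same H and W as g
    f_int: number of intermediate channels
    """

    def __init__(self, f_g, f_l, f_int):
        super().__init__()
        # 1x1 projection of the gating signal
        self.w_g = nn.Sequential(
            nn.Conv2d(f_g, f_int, kernel_size=1, stride=1,
                      padding=0, bias=True),
            nn.BatchNorm2d(f_int),
        )
        # 1x1 projection of the skip features
        self.w_x = nn.Sequential(
            nn.Conv2d(f_l, f_int, kernel_size=1, stride=1,
                      padding=0, bias=True),
            nn.BatchNorm2d(f_int),
        )
        # collapse to one channel and squash to (0, 1) attention coefficients
        self.psi = nn.Sequential(
            nn.Conv2d(f_int, 1, kernel_size=1, stride=1,
                      padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.w_g(g)
        x1 = self.w_x(x)
        psi = self.relu(g1 + x1)  # additive attention
        psi = self.psi(psi)       # (N, 1, H, W) coefficients
        return psi * x            # broadcast over the channels of x
```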
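This time we do not use cropping. The attention gate adds g and x element-wise, and the result is then concatenated with the decoder tensor, so both must already have the same spatial size. With padded convolutions this holds as long as the input height and width are divisible by 16 (2^4, one factor of 2 per max-pooling step). Attention re-weights the skip features but does not change their size, so it does not remove the need for matching sizes.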
```python
class AttentionUNet(nn.Module):
    def __init__(self, n_classes=1, in_channel=3, out_channel=1):
        # n_classes is unused; out_channel sets the number of output maps
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # encoder (identical to the plain U-Net)
        self.conv1 = ConvBlock(ch_in=in_channel, ch_out=64)
        self.conv2 = ConvBlock(ch_in=64, ch_out=128)
        self.conv3 = ConvBlock(ch_in=128, ch_out=256)
        self.conv4 = ConvBlock(ch_in=256, ch_out=512)
        self.conv5 = ConvBlock(ch_in=512, ch_out=1024)
        # decoder: an attention gate filters each skip before concatenation
        self.up5 = UpConvBlock(ch_in=1024, ch_out=512)
        self.att5 = AttentionBlock(f_g=512, f_l=512, f_int=256)
        self.upconv5 = ConvBlock(ch_in=1024, ch_out=512)
        self.up4 = UpConvBlock(ch_in=512, ch_out=256)
        self.att4 = AttentionBlock(f_g=256, f_l=256, f_int=128)
        self.upconv4 = ConvBlock(ch_in=512, ch_out=256)
        self.up3 = UpConvBlock(ch_in=256, ch_out=128)
        self.att3 = AttentionBlock(f_g=128, f_l=128, f_int=64)
        self.upconv3 = ConvBlock(ch_in=256, ch_out=128)
        self.up2 = UpConvBlock(ch_in=128, ch_out=64)
        self.att2 = AttentionBlock(f_g=64, f_l=64, f_int=32)
        self.upconv2 = ConvBlock(ch_in=128, ch_out=64)
        self.conv_1x1 = nn.Conv2d(64, out_channel,
                                  kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder
        x1 = self.conv1(x)
        x2 = self.conv2(self.maxpool(x1))
        x3 = self.conv3(self.maxpool(x2))
        x4 = self.conv4(self.maxpool(x3))
        x5 = self.conv5(self.maxpool(x4))

        # decoder + attention-gated skips + concat
        d5 = self.up5(x5)
        x4 = self.att5(g=d5, x=x4)
        d5 = self.upconv5(torch.cat((x4, d5), dim=1))

        d4 = self.up4(d5)
        x3 = self.att4(g=d4, x=x3)
        d4 = self.upconv4(torch.cat((x3, d4), dim=1))

        d3 = self.up3(d4)
        x2 = self.att3(g=d3, x=x2)
        d3 = self.upconv3(torch.cat((x2, d3), dim=1))

        d2 = self.up2(d3)
        x1 = self.att2(g=d2, x=x1)
        d2 = self.upconv2(torch.cat((x1, d2), dim=1))

        return self.conv_1x1(d2)  # logits


# don't forget to initialize the model
attention_unet = AttentionUNet(n_classes=1).to(device)
```
Segmentation Metric
Segmentation can be seen as classification at the pixel level. In image classification, the whole image is assigned to one class; in segmentation, each pixel is classified separately, so only part of the image belongs to the lesion. To evaluate a segmentation, we compare the predicted mask with the ground-truth mask. The pixels where both mark the lesion form the intersection (true positives, TP), and we divide it by the union: the true positives plus the pixels wrongly predicted as lesion (false positives, FP) plus the lesion pixels the prediction missed (false negatives, FN). We therefore want the intersection to be as large as possible relative to the union.
The metrics most commonly used for segmentation are IoU (Intersection over Union, also called the Jaccard index) and the Dice coefficient. IoU is the ratio between the intersection of the predicted and ground-truth masks and their union. Conceptually it plays a role similar to accuracy in traditional classification, although, unlike pixel accuracy, it ignores the true-negative background pixels.
If IoU measures the fraction of overlap, Dice is its F1 counterpart: the harmonic mean of precision and recall computed between the predicted and ground-truth masks.
The two metrics are closely related. For any single prediction, Dice is never lower than IoU and can be computed directly from it, so the two rank individual predictions the same way; IoU penalizes partial overlap more harshly, which makes its values lower and arguably easier to read as a plain fraction of overlap.
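In formulas, with P the set of predicted lesion pixels and G the set of ground-truth lesion pixels:

```latex
\begin{aligned}
\mathrm{IoU} &= \frac{|P \cap G|}{|P \cup G|} = \frac{TP}{TP + FP + FN} \\
\mathrm{Dice} &= \frac{2\,|P \cap G|}{|P| + |G|} = \frac{2\,TP}{2\,TP + FP + FN} = \frac{2\,\mathrm{IoU}}{1 + \mathrm{IoU}}
\end{aligned}
```

For example, an IoU of 0.5 corresponds to a Dice of 2(0.5)/1.5 ≈ 0.667, which is why Dice values run higher than IoU values for the same predictions.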
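```python
def iou_metric(inputs, target):
    """IoU for binary masks: `inputs` is the thresholded prediction (0/1)."""
    # both masks empty: perfect agreement by convention
    if target.sum() == 0 and inputs.sum() == 0:
        return 1.0
    intersection = (target * inputs).sum()
    union = target.sum() + inputs.sum() - intersection
    return intersection / union


def dice_coef_metric(inputs, target):
    """Dice coefficient for binary masks: `inputs` is the thresholded prediction."""
    if target.sum() == 0 and inputs.sum() == 0:
        return 1.0
    intersection = 2.0 * (target * inputs).sum()
    total = target.sum() + inputs.sum()
    return intersection / total
```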
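Because the model outputs logits, predictions have to be turned into a binary mask before these metrics are computed. This worked example assumes a sigmoid followed by a 0.5 threshold (the exact threshold used in the notebook is not shown here); the helpers `iou` and `dice` are hypothetical re-implementations of the same formulas, without the empty-mask case, so the block runs on its own.

```python
import torch


def iou(pred, target):
    inter = (pred * target).sum()
    union = pred.sum() + target.sum() - inter
    return (inter / union).item()


def dice(pred, target):
    inter = (pred * target).sum()
    return (2 * inter / (pred.sum() + target.sum())).item()


# Raw model output (logits) -> probabilities -> binary mask.
logits = torch.tensor([[3.0, 2.0, -1.0],
                       [1.5, -2.0, -3.0]])
pred = (torch.sigmoid(logits) > 0.5).float()  # [[1, 1, 0], [1, 0, 0]]
target = torch.tensor([[1., 0., 0.],
                       [1., 1., 1.]])

# TP = 2, FP = 1, FN = 2
print(round(iou(pred, target), 4))   # 0.4    (2 / 5)
print(round(dice(pred, target), 4))  # 0.5714 (4 / 7)
```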
```python
class DiceLoss(nn.Module):
    """Soft Dice loss on logits; `smooth` avoids division by zero."""

    def __init__(self, weight=None, size_average=True):
        # weight and size_average are accepted but unused
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # remove this line if your model already ends with a sigmoid
        inputs = torch.sigmoid(inputs)
        # flatten prediction and label tensors
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice
```
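Unlike the evaluation metrics, the loss uses soft probabilities, so it stays differentiable. The sketch below mirrors DiceLoss.forward in a hypothetical standalone function `soft_dice_loss` and checks three cases on a 4×4 mask: a confident correct prediction, a confident wrong one, and a lesion-free image. The logit magnitude of 20 is an arbitrary stand-in for "very confident".

```python
import torch


def soft_dice_loss(logits, targets, smooth=1.0):
    # Same computation as DiceLoss.forward: sigmoid, flatten, soft Dice.
    probs = torch.sigmoid(logits).reshape(-1)
    targets = targets.reshape(-1)
    intersection = (probs * targets).sum()
    dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
    return 1 - dice


target = torch.zeros(4, 4)
target[:2] = 1.0                            # top half is "lesion" (8 pixels)

confident_right = target * 40.0 - 20.0      # +20 on lesion, -20 elsewhere
confident_wrong = -confident_right

empty = torch.zeros(4, 4)                   # image with no lesion
empty_pred = torch.full((4, 4), -20.0)      # model predicts no lesion

print(soft_dice_loss(confident_right, target).item())  # close to 0
print(soft_dice_loss(confident_wrong, target).item())  # close to 16/17
print(soft_dice_loss(empty_pred, empty).item())        # close to 0
```

The last case is where `smooth` matters: without it, a lesion-free image predicted as lesion-free would give 0/0; with smooth=1 the ratio is about 1/1 and the loss is near zero.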
Training Result
In the training score history graph, the Attention U-Net’s curves fluctuate more (it makes more frequent corrections) than the regular U-Net’s, but both models generalize well from the training set to the validation set without overfitting.
Test Result
Mean IoU on the test images
U-Net: 73.78%
Attention U-Net: 74.69%
Mean Dice coefficient on the test images
U-Net: 84.98%
Attention U-Net: 85.19%
On the test set, attention improves both metrics: mean IoU by about 0.9 percentage points and mean Dice by about 0.2 percentage points. This suggests that the attention gates help the model cover the lesion regions more accurately than the regular U-Net, which is consistent with soft attention’s intended role: suppressing irrelevant regions, reducing redundant features by giving more weight to important areas, and improving the model’s ability to segment the image in finer detail. Since the gains are modest, confirming them would require repeated runs or a significance test.
Acknowledgement
Full notebook source code can be viewed here
https://colab.research.google.com/drive/1kI-ykYkhZroacO0IKVgH6cnfG1rq_7Rd?usp=sharing
The main code reference is taken from
- Preprocessing and attention U-Net https://www.kaggle.com/code/truthisneverlinear/attention-u-net-pytorch#Attention-U-Net
- Image comparison https://www.kaggle.com/code/joshuasirusstara/bdcspn-tim/notebook
- SLIC from https://scikit-image.org/
Concepts and explanations are covered from
- https://towardsdatascience.com/a-detailed-explanation-of-the-attention-u-net-b371a5590831
- https://arxiv.org/pdf/1804.03999.pdf