The Impact of Attention Mechanisms on U-Net Performance in Skin Cancer Segmentation: A Case Study

Joshua Sirusstara
Published in Python’s Gurus · Jun 7, 2024
A focus gaze art by David Sirusstara

Introduction

In today’s landscape of artificial intelligence, attention mechanisms have emerged as pivotal tools, transforming models’ capabilities across diverse domains. While initially prominent in natural language processing, attention mechanisms are increasingly proving their worth in computer vision tasks, particularly in image segmentation. In this article we explore their impact on a classic application: skin cancer segmentation. Through this examination, we aim to revisit the potential of attention mechanisms in computer vision segmentation, with skin cancer segmentation serving as a compelling case study.

Dataset

The dataset we will be using is the Skin Cancer Segmentation dataset, which contains RGB images of skin cancer lesions: 900 training images and 379 test images, each paired with a ground-truth segmentation mask. The training data is then split with a 4:1 ratio into 720 training and 180 validation samples, as sketched below.
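As a rough sketch of that split, torch.utils.data.random_split does the job; the SkinCancerDataset class and the data directory below are hypothetical placeholders for however you load the 900 training pairs.

import torch
from torch.utils.data import random_split

# Hypothetical Dataset holding the 900 training image/mask pairs
full_train_set = SkinCancerDataset(root="data/train")

# 4:1 split -> 720 training samples and 180 validation samples
train_size = int(0.8 * len(full_train_set))      # 720
val_size = len(full_train_set) - train_size      # 180
train_set, val_set = random_split(
    full_train_set, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)  # reproducible split
)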

Dataset sample: column 1 shows the input image, column 2 the ground-truth segmentation label, and column 3 a baseline produced using SLIC

Modelling

For modelling we will be using U-Net. The reason for choosing U-Net is interpretability: the model is easy to implement and understand while delivering decent performance. U-Net was also originally built for the biomedical field, so it fits our use case well.

a. U-Net

U-Net mainly consists of two parts: the encoding phase and the decoding phase. The model works by using the feature maps learned in the encoding phase to build up the segmentation map in the decoding phase.

The encoding phase extracts the image’s features via convolution operations and pooling, compressing them into embedding vectors. These embeddings are then used in the decoding phase, where they are progressively expanded back into an image. Each decoding block contains three components: an upsampling layer (such as a transposed convolution or an interpolation algorithm), a concatenating layer, and a double convolutional block. The decoding phase goes through the same number of blocks as the encoding phase until the final segmentation map is produced.

A closer look at the convolution blocks of the U-Net we will be using: each consists of two convolution layers, each followed by batch normalization and a ReLU activation function. In the decoder, the first convolution layer reduces the number of channels (aggregating the features), usually by half to maintain symmetry, while the second layer continues to extract deeper features.

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU."""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out,
                      kernel_size=3, stride=1,
                      padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out,
                      kernel_size=3, stride=1,
                      padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x

Next we define the up convolutional block, which is responsible for upsampling, conceptually the reverse of pooling. In this case study we use nn.Upsample with PyTorch’s default nearest-neighbour interpolation.

class UpConvBlock(nn.Module):
    """Upsample by a factor of 2, then refine with a 3x3 convolution."""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out,
                      kernel_size=3, stride=1,
                      padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.up(x)
        return x

After going through an up convolutional block, we crop the encoder feature map so it matches the size of the upsampled feature map, removing the unwanted border pixels introduced by the convolution operations while keeping the desired image size. This part is optional.

def crop_image(tensor, tensor_target):
    target_size = tensor_target.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]

Putting it all together, the final model takes an image with three channels as input and outputs a segmentation prediction with one channel.

class UNet(nn.Module):
    def __init__(self, n_classes=1, in_channel=3, out_channel=1):
        super().__init__()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv1 = ConvBlock(ch_in=in_channel, ch_out=64)
        self.conv2 = ConvBlock(ch_in=64, ch_out=128)
        self.conv3 = ConvBlock(ch_in=128, ch_out=256)
        self.conv4 = ConvBlock(ch_in=256, ch_out=512)
        self.conv5 = ConvBlock(ch_in=512, ch_out=1024)

        self.up5 = UpConvBlock(ch_in=1024, ch_out=512)
        self.upconv5 = ConvBlock(ch_in=1024, ch_out=512)

        self.up4 = UpConvBlock(ch_in=512, ch_out=256)
        self.upconv4 = ConvBlock(ch_in=512, ch_out=256)

        self.up3 = UpConvBlock(ch_in=256, ch_out=128)
        self.upconv3 = ConvBlock(ch_in=256, ch_out=128)

        self.up2 = UpConvBlock(ch_in=128, ch_out=64)
        self.upconv2 = ConvBlock(ch_in=128, ch_out=64)

        self.conv_1x1 = nn.Conv2d(64, out_channel,
                                  kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder
        x1 = self.conv1(x)

        x2 = self.maxpool(x1)
        x2 = self.conv2(x2)

        x3 = self.maxpool(x2)
        x3 = self.conv3(x3)

        x4 = self.maxpool(x3)
        x4 = self.conv4(x4)

        x5 = self.maxpool(x4)
        x5 = self.conv5(x5)

        # decoder + concat
        d5 = self.up5(x5)
        x4 = crop_image(x4, d5)
        d5 = torch.concat((x4, d5), dim=1)
        d5 = self.upconv5(d5)

        d4 = self.up4(d5)
        x3 = crop_image(x3, d4)
        d4 = torch.concat((x3, d4), dim=1)
        d4 = self.upconv4(d4)

        d3 = self.up3(d4)
        x2 = crop_image(x2, d3)
        d3 = torch.concat((x2, d3), dim=1)
        d3 = self.upconv3(d3)

        d2 = self.up2(d3)
        x1 = crop_image(x1, d2)
        d2 = torch.concat((x1, d2), dim=1)
        d2 = self.upconv2(d2)

        d1 = self.conv_1x1(d2)

        return d1

# don't forget to initialize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet(n_classes=1).to(device)
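As a quick sanity check (not part of the original notebook), we can push a dummy batch through the network to confirm the three-channel input and one-channel output described above; the 256x256 resolution here is an arbitrary choice.

# Dummy batch of 2 RGB images at 256x256 (any size divisible by 16 works here)
dummy = torch.randn(2, 3, 256, 256, device=device)

with torch.no_grad():
    out = unet(dummy)

print(out.shape)  # expected: torch.Size([2, 1, 256, 256])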

b. Soft Attention U-Net

Attention is a method of giving more weight to certain points of interest by scaling the features with learned coefficients. Attention U-Net applies this idea by combining two signals: the gating vector coming up through the decoding blocks and the feature map taken directly from the corresponding encoding level.

Both the gating vector and the encoder feature map first undergo convolutional operations before being added together. This addition enhances the importance of aligned features while diminishing the influence of unaligned ones, effectively guiding the model to focus on the relevant regions during segmentation. The result is then turned into attention coefficients, which are multiplied with the original feature map to scale up the points of interest.

In this case study we will be using soft attention because it is differentiable and can therefore be trained with backpropagation, letting the attention weights adapt to the query image.

class AttentionBlock(nn.Module):
    def __init__(self, f_g, f_l, f_int):
        super().__init__()

        # 1x1 convolution on the gating signal (decoder feature)
        self.w_g = nn.Sequential(
            nn.Conv2d(f_g, f_int,
                      kernel_size=1, stride=1,
                      padding=0, bias=True),
            nn.BatchNorm2d(f_int)
        )

        # 1x1 convolution on the skip connection (encoder feature)
        self.w_x = nn.Sequential(
            nn.Conv2d(f_l, f_int,
                      kernel_size=1, stride=1,
                      padding=0, bias=True),
            nn.BatchNorm2d(f_int)
        )

        # collapse to a single-channel attention map in [0, 1]
        self.psi = nn.Sequential(
            nn.Conv2d(f_int, 1,
                      kernel_size=1, stride=1,
                      padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.w_g(g)
        x1 = self.w_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)

        return psi * x

This time we don’t need the cropping step: with attention, the skip-connection features are already weighted toward the points of interest, so irrelevant parts such as the borders are suppressed.

class AttentionUNet(nn.Module):
    def __init__(self, n_classes=1, in_channel=3, out_channel=1):
        super().__init__()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv1 = ConvBlock(ch_in=in_channel, ch_out=64)
        self.conv2 = ConvBlock(ch_in=64, ch_out=128)
        self.conv3 = ConvBlock(ch_in=128, ch_out=256)
        self.conv4 = ConvBlock(ch_in=256, ch_out=512)
        self.conv5 = ConvBlock(ch_in=512, ch_out=1024)

        self.up5 = UpConvBlock(ch_in=1024, ch_out=512)
        self.att5 = AttentionBlock(f_g=512, f_l=512, f_int=256)
        self.upconv5 = ConvBlock(ch_in=1024, ch_out=512)

        self.up4 = UpConvBlock(ch_in=512, ch_out=256)
        self.att4 = AttentionBlock(f_g=256, f_l=256, f_int=128)
        self.upconv4 = ConvBlock(ch_in=512, ch_out=256)

        self.up3 = UpConvBlock(ch_in=256, ch_out=128)
        self.att3 = AttentionBlock(f_g=128, f_l=128, f_int=64)
        self.upconv3 = ConvBlock(ch_in=256, ch_out=128)

        self.up2 = UpConvBlock(ch_in=128, ch_out=64)
        self.att2 = AttentionBlock(f_g=64, f_l=64, f_int=32)
        self.upconv2 = ConvBlock(ch_in=128, ch_out=64)

        self.conv_1x1 = nn.Conv2d(64, out_channel,
                                  kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder
        x1 = self.conv1(x)

        x2 = self.maxpool(x1)
        x2 = self.conv2(x2)

        x3 = self.maxpool(x2)
        x3 = self.conv3(x3)

        x4 = self.maxpool(x3)
        x4 = self.conv4(x4)

        x5 = self.maxpool(x4)
        x5 = self.conv5(x5)

        # decoder + concat
        d5 = self.up5(x5)
        x4 = self.att5(g=d5, x=x4)
        d5 = torch.concat((x4, d5), dim=1)
        d5 = self.upconv5(d5)

        d4 = self.up4(d5)
        x3 = self.att4(g=d4, x=x3)
        d4 = torch.concat((x3, d4), dim=1)
        d4 = self.upconv4(d4)

        d3 = self.up3(d4)
        x2 = self.att3(g=d3, x=x2)
        d3 = torch.concat((x2, d3), dim=1)
        d3 = self.upconv3(d3)

        d2 = self.up2(d3)
        x1 = self.att2(g=d2, x=x1)
        d2 = torch.concat((x1, d2), dim=1)
        d2 = self.upconv2(d2)

        d1 = self.conv_1x1(d2)

        return d1

# don't forget to initialize the model
attention_unet = AttentionUNet(n_classes=1).to(device)
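Out of curiosity, we can also check how many trainable parameters the attention gates add on top of the plain U-Net; this small comparison is an illustrative extra, not part of the original notebook.

def count_params(model):
    # total number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"U-Net parameters:           {count_params(unet):,}")
print(f"Attention U-Net parameters: {count_params(attention_unet):,}")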

Segmentation Metric

Segmentation can be seen as an extension of image classification. In image classification, when an image is assigned to a class, every pixel of the image is assumed to belong to that class. In segmentation, only part of the pixels belong to the class, so to evaluate performance we take the intersection between the ground-truth pixels and the predicted pixels (True Positives) and divide it by the total area made up of the pixels predicted as positive but actually negative (False Positives), the intersection itself, and the pixels predicted as negative but actually positive (False Negatives). In other words, we want to maximize the intersection between prediction and ground truth while minimizing their union.

The metrics usually used for segmentation are IoU and the Dice coefficient. IoU is the ratio between the intersection of the ground truth and the predicted segmentation and the union of the two; conceptually it plays a role similar to accuracy in traditional classification.

IoU formula: IoU = TP / (TP + FP + FN), i.e. the area of intersection divided by the area of union.

If IoU measures the percentage of successfully intersected area, then the Dice coefficient is the F1 counterpart of IoU: a harmonic mean of precision and recall computed between the ground truth and the predicted segmentation.

Dice coefficient formula: Dice = 2·TP / (2·TP + FP + FN), i.e. twice the intersection divided by the sum of the two areas.

Similar to the difference between F1 and accuracy, Dice offers a more nuanced evaluation of segmentation overlap, whereas IoU is more interpretable.

def iou_metric(inputs, target):
    intersection = (target * inputs).sum()
    union = target.sum() + inputs.sum() - intersection

    if target.sum() == 0 and inputs.sum() == 0:
        return 1.0
    return intersection / union


def dice_coef_metric(inputs, target):
    intersection = 2.0 * (target * inputs).sum()
    union = target.sum() + inputs.sum()

    if target.sum() == 0 and inputs.sum() == 0:
        return 1.0
    return intersection / union
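To make the behaviour of these metrics concrete, here is a small illustrative example on toy binary masks (assuming the predictions have already been thresholded to 0/1):

# Toy 4x4 binary masks: the prediction covers half of the ground truth
target = torch.zeros(1, 1, 4, 4)
target[..., 1:3, 1:3] = 1.0   # 4 positive pixels

pred = torch.zeros(1, 1, 4, 4)
pred[..., 1:3, 1:2] = 1.0     # 2 positive pixels, both inside the target

print(iou_metric(pred, target))        # 2 / (4 + 2 - 2) = 0.50
print(dice_coef_metric(pred, target))  # (2 * 2) / (4 + 2) ≈ 0.67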

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice
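For completeness, below is a minimal sketch of how a training loop with this loss could look; the train_loader and val_loader DataLoaders, the Adam optimizer, the learning rate, and the number of epochs are illustrative assumptions rather than values taken from the original notebook.

criterion = DiceLoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=1e-4)  # illustrative choice

def run_epoch(model, loader, train=True):
    model.train(train)
    dice_scores = []
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        with torch.set_grad_enabled(train):
            logits = model(images)
            loss = criterion(logits, masks)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        # threshold the sigmoid output at 0.5 before computing the metric
        preds = (torch.sigmoid(logits) > 0.5).float()
        dice_scores.append(float(dice_coef_metric(preds, masks)))
    return sum(dice_scores) / len(dice_scores)

for epoch in range(25):  # number of epochs is arbitrary here
    train_dice = run_epoch(unet, train_loader, train=True)
    val_dice = run_epoch(unet, val_loader, train=False)
    print(f"epoch {epoch}: train dice {train_dice:.4f}, val dice {val_dice:.4f}")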

Training Result

Left: Dice coefficient training history for the plain U-Net. Right: Dice coefficient training history for the U-Net with soft attention.

From the training score history, we can see that the U-Net with attention experiences more frequent corrections than the regular U-Net, but both models generalize well on the training and validation sets without overfitting.

Test Result

Mean IoU on the test images:

  • U-Net: 73.77556284268697%
  • Attention U-Net: 74.68607862790425%

Mean Dice coefficient on the test images:

  • U-Net: 84.98269978194702%
  • Attention U-Net: 85.18728190991779%

Comparison between the segmentation results of the regular U-Net and the U-Net with the soft attention mechanism

The test results show that attention improves both the IoU and the Dice coefficient. This means the attention mechanism helps the model cover the segmented regions more accurately than the regular U-Net. It suggests that soft attention can suppress irrelevant regions and reduce redundant features by giving more weight to important areas, improving the model’s ability to segment the image in finer detail.
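For reference, the mean test metrics above can be reproduced with an evaluation loop along these lines; the test_loader DataLoader and the 0.5 threshold are assumptions rather than details taken from the original notebook.

def evaluate(model, loader):
    model.eval()
    iou_scores, dice_scores = [], []
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            # binarize the sigmoid output before scoring
            preds = (torch.sigmoid(model(images)) > 0.5).float()
            iou_scores.append(float(iou_metric(preds, masks)))
            dice_scores.append(float(dice_coef_metric(preds, masks)))
    return (sum(iou_scores) / len(iou_scores),
            sum(dice_scores) / len(dice_scores))

for name, model in [("U-Net", unet), ("Attention U-Net", attention_unet)]:
    mean_iou, mean_dice = evaluate(model, test_loader)
    print(f"{name}: mean IoU {mean_iou:.2%}, mean Dice {mean_dice:.2%}")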

Acknowledgement

Full notebook source code can be viewed here

https://colab.research.google.com/drive/1kI-ykYkhZroacO0IKVgH6cnfG1rq_7Rd?usp=sharing

