A Deep Dive into the U-Net Paper: Convolutional Networks for Biomedical Image Segmentation

Hello! ╰(*°▽°*)╯

🎇🎇🎇🎇🎇🎇🎇

It's been a while!

Well, I've spent some time studying the concept of segmentation, specifically with U-Net, and I want to share with you my humble explanation of this topic, primarily for those who are new to the field or looking for a quick refresher.

While my understanding of this topic might not be as extensive as that of seasoned professionals, I believe that sharing what I've learned can still be beneficial. It may serve as a time-saving resource for those who are just starting to explore the subject or who need a concise explanation. I welcome any feedback or criticism from senior professionals who may come across this post; your insights and guidance will be greatly appreciated and will contribute to my growth as a student.

Thank you for taking the time to read this post 💜💜💜. If you have any questions, please feel free to ask!

1. Introduction

Image segmentation is a computer vision task that involves dividing an image into different regions or segments, typically based on visual patterns or object boundaries. It plays a crucial role in various applications such as medical imaging, autonomous driving, and object recognition.

The goal of image segmentation is to assign a label or category to each pixel in an image, allowing for detailed analysis and understanding of its content. One of the most common architectures for this task is called U-Net.

This architecture has become widely used in various medical image analysis tasks, particularly in the field of biomedical image segmentation.

One of the key strengths of the U-Net architecture is its ability to handle small datasets, which are common in the biomedical field. This is achieved through data augmentation techniques of the input images.

Additionally, U-Net has demonstrated its speed in both training and inference: segmentation of a 512x512 image takes less than a second on a GPU, which further eases the segmentation process.

In this blog post, we are going to explain the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation", which was published in the Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI) in 2015.

2. The network architecture

The U-Net architecture is a fully convolutional neural network (FCN) that is specifically designed for biomedical image segmentation tasks.

The architecture is U-shaped and consists of an encoder, a decoder, and skip connections.

The encoder (contracting path) consists of the repeated application of two unpadded 3x3 convolutions, each followed by a ReLU, then a 2x2 max pooling operation with stride 2 for downsampling. At each downsampling step, the number of feature channels is doubled. The aim here is to reduce the spatial resolution of the input image while capturing rich contextual detail that describes the 'what' information in the picture.

Every level of the expansive path (decoder) repeats three steps:

1- An upsampling of the feature map with a 2x2 convolution ("up-convolution") that halves the number of feature channels.

2- A concatenation with the correspondingly cropped feature map from the contracting path.

3- Two 3x3 convolutions, each followed by a ReLU.

At the end of the expansive path, a 1x1 convolution maps the feature channels to the desired number of output classes.

The aim here is to gradually increase the spatial resolution of the encoded features to recover the 'WHERE' information. The skip connections between the encoder and decoder networks help to preserve high-resolution, fine-grained information and improve segmentation accuracy.
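To make this concrete, here is a minimal sketch of these building blocks in PyTorch. This is my own illustration (the authors' original implementation used Caffe), trimmed to a single downsampling level for brevity; the real network has four, with channel sizes 64 → 128 → 256 → 512 → 1024:

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch):
    # Two unpadded 3x3 convolutions, each followed by a ReLU
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3),
        nn.ReLU(inplace=True),
    )

def center_crop(feat, target):
    # Crop the (larger) encoder feature map to the decoder size before
    # concatenation, since the unpadded convolutions shrink the maps
    _, _, h, w = target.shape
    _, _, H, W = feat.shape
    dh, dw = (H - h) // 2, (W - w) // 2
    return feat[:, :, dh:dh + h, dw:dw + w]

class MiniUNet(nn.Module):
    def __init__(self, n_classes=2):
        super().__init__()
        self.enc1 = double_conv(1, 64)
        self.pool = nn.MaxPool2d(2)
        self.bottom = double_conv(64, 128)
        self.up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = double_conv(128, 64)
        self.out = nn.Conv2d(64, n_classes, kernel_size=1)  # final 1x1 conv

    def forward(self, x):
        s1 = self.enc1(x)                    # saved for the skip connection
        x = self.bottom(self.pool(s1))       # downsample, double the channels
        x = self.up(x)                       # 2x2 "up-convolution"
        x = torch.cat([center_crop(s1, x), x], dim=1)  # skip connection
        x = self.dec1(x)
        return self.out(x)                   # per-class score maps

scores = MiniUNet()(torch.randn(1, 1, 572, 572))  # -> shape (1, 2, 556, 556)
```

Note how the output is smaller than the input: the unpadded convolutions trim the borders, which is exactly what the overlap-tile strategy in the next section compensates for.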

________________________________________________________________

3. Training

This part may look a little complicated, but just stay focused and you'll get it easily.

For training, the authors used:

* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

3.1. The overlap-tile strategy:

The original input image is divided into smaller overlapping patches, or tiles (the overlap ensures that the network captures contextual information near tile boundaries). Each tile is fed into the network individually for processing, then the predicted outputs are stitched and blended together to form the final seamless prediction for the entire input image.

Another thing to mention: convolutional neural networks often struggle to accurately predict values near the borders of an image, because the receptive field of the network's convolutional layers extends past the image boundary there, or because of the information loss caused by the inconsistent pixel values added through padding. Therefore, the authors used a mirroring technique to provide additional information for making predictions in the border regions of the image. But how is it applied?

They take the original image and create a mirrored copy of it, effectively extending the image beyond its original boundaries. When processing the image, it is divided into smaller overlapping tiles that include both the original pixels and the mirrored pixels. This way, when the network predicts the values for the border pixels, it can take the reflected context into account. After the predictions are made, the border regions that include the mirrored pixels are cropped out, leaving only the original, non-mirrored pixels as the final predicted output.
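Here is a small sketch of the mirroring idea in NumPy (the margin and tile sizes are my own assumptions; the paper does not provide code):

```python
import numpy as np

def mirror_extend(image, margin):
    # Reflect the image content across its borders so that tiles near
    # the edge still see plausible context instead of zero padding
    return np.pad(image, margin, mode="reflect")

image = np.random.rand(512, 512)
margin = 30                                    # assumed context margin
extended = mirror_extend(image, margin)

# A tile cut from the top-left corner now carries mirrored context
# on two of its sides:
tile = extended[0:128 + 2 * margin, 0:128 + 2 * margin]
# ...run the network on each tile, crop the margin off each prediction,
# then stitch the cropped predictions back together.
```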

* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

3.2. The optimizer:

Stochastic gradient descent (SGD) with a learning rate of 0.01 and a momentum of 0.99.

SGD: useful when dealing with large datasets because it updates the model parameters based on a single example rather than the entire dataset; the term is also commonly used when updates are made after a small subset of training examples (a minibatch).

LR: the learning rate is a hyperparameter that determines the step size at which the model parameters are updated during training. A learning rate of 0.01 is a typical starting point that strikes a balance between convergence speed and stability.

Momentum: accumulates the past gradients and applies a weighted average to determine the direction and magnitude of the next update. A momentum of 0.99 means that a large fraction of the past gradients is taken into account, which helps to smooth out the updates and accelerate convergence. (The paper motivates this high momentum by the batch size of 1: many previously seen training samples keep influencing the current update.)
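In PyTorch, these settings would look like this (just a sketch of an equivalent; the original work used Caffe):

```python
import torch
import torch.nn as nn

model = nn.Conv2d(1, 2, kernel_size=1)   # stand-in for the U-Net model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99)
```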

* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

3.3. Weights initialization:

The initial weights were sampled from a Gaussian distribution with a standard deviation of np.sqrt(2/N), where N denotes the number of incoming nodes of one neuron .

For example, in a scenario where the previous layer is a 3x3 convolutional layer with 64 feature channels, the number of incoming nodes would be N = 9 x 64 = 576. Applying the formula, the standard deviation would be np.sqrt(2/576). (Weight initialization is crucial since it enables faster convergence during training, reduces the likelihood of getting stuck in suboptimal solutions, and helps avoid vanishing or exploding gradients.)

But why did they use this standard deviation specifically?

Well, the standard deviation np.sqrt(2/N) comes from He (Kaiming) initialization, a variant of Xavier ('Glorot') initialization adapted for ReLU activations. Both aim to set the initial weights of the network in a way that balances signal propagation and gradient flow during training.

The value of N represents the number of incoming connections contributing to the activation of a neuron, so by using np.sqrt(2/N) as the standard deviation, the weights are initialized in such a way that the variance of the inputs to each neuron remains approximately constant across the layers of the network:

np.sqrt(2/10) ≈ 0.45 and np.sqrt(2/100) ≈ 0.14 — a reasonable spread in both cases —

thus avoiding the issues of vanishing gradients (weights that are too small may cause the signal to diminish as it propagates through the network) or exploding gradients (the opposite case).
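A short sketch of this initialization for the N = 576 example above (my own illustration):

```python
import numpy as np

fan_in = 3 * 3 * 64                    # N = 576 incoming nodes per neuron
std = np.sqrt(2.0 / fan_in)            # He-initialization standard deviation
# weights for a 3x3 conv layer: 128 output channels, 64 input channels
weights = np.random.normal(0.0, std, size=(128, 64, 3, 3))
print(round(std, 3))                   # 0.059
```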

* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

3.4. The loss function:

Since biomedical images are largely microscopic, it is difficult to capture the separating gaps between adjacent cells. To encourage the network to learn fine separation borders between touching cells, and to address the variations in pixel frequencies across the different classes in the training dataset, the authors employed a precomputed weight map for each ground truth segmentation. This weight map serves two purposes: class frequency balancing and enhancement of the separation borders. (A weight map is a 2D image with the same size as the input image, where each pixel is assigned a weight based on its importance for the segmentation task.)

Therefore, false predictions on the separation regions cause a higher penalization of the model than errors in any other region of the input image.

The formula is as follows:

w(x) = wc(x) + w0 · exp( −(d1(x) + d2(x))² / (2σ²) )

where :

  • wc(x) represents the weight map used to balance the class frequencies. (wc may be assigned manually by domain experts who have prior knowledge about the importance of different regions in the images — they can assign higher weights to underrepresented classes or to regions that require more attention during training — or automatically, which can involve a statistical analysis of the dataset to determine the relative occurrence of each class and assign weights accordingly. For example, the weights can be inversely proportional to the class frequencies, assigning higher weights to less frequent classes.)
  • d1(x) corresponds to the distance from the pixel x to the border of the nearest cell.
  • d2(x) represents the distance to the border of the second nearest cell.
  • In the paper, w0 = 10 and σ ≈ 5 pixels.
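Here is a sketch of how such a weight map could be computed with SciPy (my own reconstruction; the paper only gives the formula, and unet_weight_map is a name I made up):

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def unet_weight_map(cell_masks, wc, w0=10.0, sigma=5.0):
    # cell_masks: list of binary arrays, one per cell instance (>= 2 cells)
    # wc:         per-pixel class-balancing weight map
    dists = np.stack([distance_transform_edt(1 - m) for m in cell_masks])
    dists.sort(axis=0)
    d1, d2 = dists[0], dists[1]        # distances to the two nearest cells
    return wc + w0 * np.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))
```

Pixels squeezed between two cells have a small d1 + d2, so their weight approaches wc + w0: mistakes there cost up to ten times more than elsewhere.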

Weighting the loss with this map, the loss function is calculated as a combination of a pixel-wise soft-max operation and the cross-entropy loss function. By combining these two, we achieve two important objectives: first, the soft-max function converts the model's activations into meaningful probabilities; second, the cross-entropy loss function provides a way to measure the dissimilarity between these predicted probabilities and the ground truth labels.

  • 3.4.1. The soft-max function:

pk(x) = exp(ak(x)) / Σk′ exp(ak′(x)), with the sum running over all K classes

where :

  • ak(x) represents the activation in the feature channel k at the pixel position x .
  • K is the total number of classes .
  • pk(x) is approximately 1 for the class k that has the highest activation ak(x), and it is approximately 0 for all other classes.

This function works by dividing the exponential activation of class k, exp(ak(x)), by the sum of the exponentials across all classes; thus we obtain the probability pk(x) for that class at that pixel position.

PS: the exponential transformation helps emphasize larger activation values and suppress smaller ones. This means that classes with higher activations are assigned higher probabilities, indicating a stronger likelihood of being the most probable class at that pixel position.
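A tiny numeric illustration (the activation values are arbitrary):

```python
import numpy as np

a = np.array([2.0, 0.5, -1.0])       # activations a_k(x) for K = 3 classes
p = np.exp(a) / np.exp(a).sum()      # soft-max probabilities p_k(x)
print(p.round(3))                    # [0.786 0.175 0.039] -> sums to 1
```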

  • 3.4.2. The cross-entropy loss function

After applying the soft-max, the cross-entropy function is used to penalize, at each position, the deviation of pℓ(x)(x) (the predicted probability of the true class ℓ(x)) from 1.

The general form of the binary cross-entropy function is:

L(y, ŷ) = −[y · log(ŷ) + (1 − y) · log(1 − ŷ)]

where:

L represents the binary cross entropy loss.

y is the true label (either 0 or 1).

ŷ is the predicted probability that the input belongs to class 1.

When y = 1, the first term becomes the negative log of the predicted probability ŷ. As ŷ approaches 1, the loss approaches 0, indicating a correct prediction for class 1; as ŷ approaches 0, the loss grows without bound, a severe penalty for predicting class 0 when the true label is 1.
When y = 0, the second term becomes the negative log of (1 − ŷ), which is equivalent to the negative log of the predicted probability for class 0, and it behaves symmetrically.
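A quick numeric check of both behaviours (values are arbitrary):

```python
import numpy as np

def bce(y, y_hat):
    # Binary cross-entropy for a single prediction
    return -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

print(round(bce(1, 0.9), 3))   # 0.105 -> small loss, confident correct prediction
print(round(bce(1, 0.1), 3))   # 2.303 -> large loss, confident wrong prediction
```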

In the paper, a small modification is applied to the cross-entropy function; the modified formula is given by:

E = Σx∈Ω w(x) · log(pℓ(x)(x))

where :

  • w(x) represents the precomputed weight map.
  • pℓ(x)(x) is the predicted probability of the true class ℓ(x) at pixel x.

In the binary case, we can write the weighted loss as:

E = −Σx∈Ω w(x) · [y(x) · log(p(x)) + (1 − y(x)) · log(1 − p(x))]

so the calculation is the same as for the regular cross-entropy, except that each pixel's contribution to the loss is scaled by its weight w(x).
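A sketch of this weighted pixel-wise cross-entropy in PyTorch (binary case; the tensor shapes and names are my own):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 1, 64, 64)            # raw network output
target = torch.randint(0, 2, (1, 1, 64, 64)).float()
weight_map = torch.ones(1, 1, 64, 64)         # precomputed w(x)

per_pixel = F.binary_cross_entropy_with_logits(
    logits, target, reduction="none")         # loss at every pixel
loss = (weight_map * per_pixel).mean()        # scale each pixel by w(x)
```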

* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

  • 3.5. Data augmentation:

They applied elastic deformations, which simulate realistic tissue variations by introducing smooth deformations that mimic the natural variations observed in biological tissues:

"We generate smooth deformations using random displacement vectors on a coarse 3 by 3 grid. The displacements are sampled from a Gaussian distribution with 10 pixels standard deviation. Per-pixel displacements are then computed using bicubic interpolation. Drop-out layers at the end of the contracting path perform further implicit data augmentation," the authors write.

Let me explain this with an example.

Suppose we have a microscopic image of a cell. We start by placing a coarse 3-by-3 grid over this image; each grid point (grid cell center) represents a region of interest in the image. We then generate a random displacement vector for each grid point (sampled from a Gaussian distribution with a standard deviation of 10 pixels), which determines the amount and direction of the shift for the corresponding region.

Once the random displacement vectors are assigned to the grid points, we compute the per-pixel displacements using bicubic interpolation, which estimates values based on the surrounding grid points, providing a continuous and smooth transformation so that the deformations appear realistic and seamless across the entire image.
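A rough sketch of this elastic deformation with SciPy (the 3x3 grid and σ = 10 follow the quoted text; the implementation details are my own assumptions):

```python
import numpy as np
from scipy.ndimage import map_coordinates, zoom

def elastic_deform(image, grid=3, sigma=10.0):
    h, w = image.shape
    # Random displacement vectors on a coarse grid x grid lattice
    coarse = np.random.normal(0.0, sigma, size=(2, grid, grid))
    # Upsample to per-pixel displacements with bicubic (order-3) splines
    dy = zoom(coarse[0], (h / grid, w / grid), order=3)
    dx = zoom(coarse[1], (h / grid, w / grid), order=3)
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    # Sample the image at the displaced coordinates
    return map_coordinates(image, [ys + dy, xs + dx],
                           order=1, mode="reflect")

deformed = elastic_deform(np.random.rand(512, 512))
```

The same displacement field would also be applied to the ground-truth mask so that the image and its labels stay aligned.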

* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

The authors trained their model for 100 epochs, with a batch size of 1. They used early stopping based on the validation loss to prevent overfitting.

4. Experiments

They experimented on three segmentation tasks:

The first is the segmentation of neuronal structures in electron microscopic recordings, from the EM segmentation challenge started at ISBI 2012. The dataset consists of:

  • 30 images (512x512 pixels) from serial section transmission electron microscopy of the Drosophila first instar larva ventral nerve cord (VNC).
  • a fully annotated, black-and-white ground truth segmentation.

The evaluation is done by thresholding the predicted map at 10 different levels and computing the "warping error", the "Rand error", and the "pixel error".

The comparison table in the paper shows that U-Net achieved the best warping error.

— — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — —

The second and third are cell segmentation tasks in light microscopic images from the ISBI cell tracking challenge 2014, on two datasets:

  • "PhC-U373": 35 partially annotated training images of glioblastoma-astrocytoma U373 cells, recorded by phase contrast microscopy.
  • "DIC-HeLa": 20 partially annotated training images of HeLa cells on flat glass, recorded by differential interference contrast (DIC) microscopy.

The comparison table in the paper shows that U-Net achieved significantly better intersection-over-union (IoU) results, with a large gap over the other architectures: roughly 92% IoU on "PhC-U373" and 77.5% on "DIC-HeLa", versus 83% and 46% for the second-best methods.

5. Conclusion

In conclusion, the U-Net architecture is a highly effective and efficient method for biomedical image segmentation tasks. Its unique design, with skip connections between the encoder and decoder networks, helps to preserve high-resolution information and improve segmentation accuracy. The U-Net architecture has become widely adopted in the biomedical field, and its impact on the field of medical image segmentation cannot be overstated.

Thank you for reading !

REFERENCES

Ronneberger, O., Fischer, P., Brox, T.: U-Net: Convolutional Networks for Biomedical Image Segmentation. MICCAI 2015. arXiv:1505.04597
