A Gentle Introduction to U-Net

Rishav Banerjee
Published in SRM MIC
Jun 11, 2020
I’d be surprised if this is your first time seeing this image; it is literally all over the place.

Right, this picture is a U-Net architecture, visualised. But before we get into the technical aspects, why even use a U-Net? You probably came across it while searching for semantic segmentation, on which I will be writing a blog post soon; until then, check this one out. The architecture was originally designed for biomedical image segmentation. Nowadays it is one of the most popular networks for semantic segmentation of any kind, and it is widely used in Kaggle competitions as a reliable method for a wide range of computer vision problems.

So what is a U-Net?

To understand U-Nets, we first need the intuition behind a skip connection. The additive form described here is just to get a better idea of what is happening; it is not used in exactly this form in the paper/model, which concatenates feature maps instead of adding them. A skip connection, as the name suggests, is a connection from an earlier part of the network to a later part, over which information is transferred. The intuition behind this is pretty simple: we want to bring back information that was lost over several layers, to give the network better context. Over the course of multiple convolution steps, as the network goes deeper and extracts more abstract features, it loses the detailed context held by the earlier layers. A skip connection sends this lost information over directly, i.e. it connects two layers that are not adjacent in the normal order in which information flows through the network.

Here we can see how the input is added to the output of the layers it skips over to form a new output.
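To make the additive version concrete, here is a minimal PyTorch sketch of a residual block. `ResidualBlock` is a hypothetical name, not taken from the paper or from this article's original code. The input bypasses two 3x3 convolutions and is added back to their output, so the block needs the same number of input and output channels.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Hypothetical example of an additive skip: out = relu(F(x) + x)."""

    def __init__(self, channels: int):
        super().__init__()
        # F(x): two 3x3 convolutions; padding=1 keeps height and width unchanged
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip connection: x bypasses self.body and is added back
        return torch.relu(self.body(x) + x)


block = ResidualBlock(16)
y = block(torch.randn(2, 16, 8, 8))  # same shape as the input: (2, 16, 8, 8)
```

Because the input is added element-wise, shapes must match exactly; this is the main practical difference from the concatenation U-Net uses.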

So what does this have to do with U-Nets?

Simple: the intuition behind the convolution and pooling (downsampling) process is twofold: to extract features, and to reduce computation by working on a smaller image. However, information is lost along the way. A U-Net makes the best of both worlds: when we upsample the image again, we concatenate the feature maps from the downsampling path to the corresponding upsampled set of maps.
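A minimal sketch of that concatenation step in PyTorch, with hypothetical tensor shapes: a feature map saved on the way down is joined, along the channel dimension, with a deeper map that has been upsampled back to the same height and width. Unlike an additive skip, concatenation keeps both sets of maps, so the channel counts add up.

```python
import torch
import torch.nn as nn

# Hypothetical shapes, laid out as (batch, channels, height, width)
skip = torch.randn(1, 64, 64, 64)   # saved on the way down
deep = torch.randn(1, 128, 32, 32)  # deeper, lower-resolution map coming back up

# Bring the deep map back to the skip map's height and width
up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
upsampled = up(deep)                # (1, 128, 64, 64)

# U-Net style skip connection: stack both along the channel dimension
merged = torch.cat([skip, upsampled], dim=1)
print(merged.shape)                 # torch.Size([1, 192, 64, 64])
```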

And now this image should make a lot more sense!

Observe how the skip connections work. For the first upsampling step, the feature maps from the last downsampling step before the bottleneck are sent over; for the second upsampling step, those from the second-to-last; and so on. To quote the authors (Olaf Ronneberger, Philipp Fischer, and Thomas Brox) from the paper:

To predict the pixels in the border region of the image, the missing context is extrapolated by mirroring the input image.
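The mirroring idea can be sketched with PyTorch's reflect padding. This only illustrates extrapolating missing border context; the tiny tensor and the padding amount are arbitrary, and the paper combines mirroring with an overlap-tile strategy for large images that is not shown here.

```python
import torch
import torch.nn.functional as F

# A tiny 4x4 "image" with values 0..15 (batch=1, channel=1)
x = torch.arange(16.0).reshape(1, 1, 4, 4)

# Mirror 2 pixels of context on every side: (left, right, top, bottom)
padded = F.pad(x, (2, 2, 2, 2), mode="reflect")

print(padded.shape)              # torch.Size([1, 1, 8, 8])
print(padded[0, 0, 2].tolist())  # original first row, mirrored at both ends:
                                 # [2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 2.0, 1.0]
```

`"reflect"` mirrors without repeating the edge pixel; `"replicate"` would instead repeat the edge value.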

A great feature of this network, as the authors point out, is that it can be trained on very little data by using heavy data augmentation, and still reach great accuracy.
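The paper highlights random elastic deformations, which are too involved for a short snippet. The sketch below uses a simpler geometric augmentation to show the one rule that matters for segmentation: whatever transform is applied to the image must be applied identically to its mask, or the labels no longer line up. `random_flip` is a hypothetical helper, not part of any library.

```python
import torch


def random_flip(image, mask, generator=None):
    """Hypothetical joint augmentation.

    image: (C, H, W) tensor; mask: (H, W) tensor of labels.
    Each flip is decided once and applied to both tensors.
    """
    if torch.rand(1, generator=generator).item() < 0.5:
        image = torch.flip(image, dims=[-1])  # horizontal flip
        mask = torch.flip(mask, dims=[-1])
    if torch.rand(1, generator=generator).item() < 0.5:
        image = torch.flip(image, dims=[-2])  # vertical flip
        mask = torch.flip(mask, dims=[-2])
    return image, mask
```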

What does the code look like?

I’ll be doing this in PyTorch; if you would like to see a TensorFlow/Keras approach, try this article (all the way at the bottom). Let us start off by defining some helper classes.

First, a double convolution block (two successive 3x3 convolutions, each followed by a ReLU), since every block in the network uses one, whether it is upsampling or downsampling.
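A minimal sketch of such a block, not the article's original code. It assumes `padding=1` so each 3x3 convolution keeps the height and width unchanged (the paper uses unpadded convolutions and crops the skip maps instead), and it adds BatchNorm, a common modern addition that the original paper does not use.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) twice; padding=1 preserves H and W."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            # bias=False because the following BatchNorm adds its own shift
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


y = DoubleConv(3, 64)(torch.randn(1, 3, 32, 32))  # (1, 64, 32, 32)
```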

Then a Downsample block, which halves the height and width with a 2x2 max pooling and then applies the DoubleConv that we made earlier.
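A minimal sketch of the downsampling block under the same assumptions (padded convolutions, BatchNorm). The class names are illustrative, and `DoubleConv` is repeated compactly so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) x 2, repeated so this snippet runs alone."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Down(nn.Module):
    """2x2 max pool (halves H and W), then DoubleConv (changes channels)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_ch, out_ch))

    def forward(self, x):
        return self.block(x)


y = Down(64, 128)(torch.randn(1, 64, 32, 32))  # (1, 128, 16, 16)
```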

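Next, an Upsample block, again built on the DoubleConv. Note that we can choose between bilinear upsampling and a transposed convolution here. We choose bilinear upsampling: interpolation-based upsampling is known to avoid the checkerboard artifacts that transposed convolutions can introduce, and the paper itself describes this step as an upsampling of the feature map followed by a 2x2 convolution. Feel free to experiment with the transposed convolution.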
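A sketch of the upsampling block with both options, under the same assumptions as before. The explicit `skip_ch` argument is an illustrative choice, not necessarily the original's signature, and `DoubleConv` is repeated so the snippet runs on its own. Since the padded convolutions keep sizes aligned, the sketch pads rather than crops when an odd input size leaves the upsampled map one pixel short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) x 2, repeated so this snippet runs alone."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Up(nn.Module):
    """Upsample x2, concatenate the skip tensor, then DoubleConv."""

    def __init__(self, in_ch, skip_ch, out_ch, bilinear=True):
        super().__init__()
        if bilinear:
            # Interpolation only: the channel count stays at in_ch
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            up_ch = in_ch
        else:
            # Learned upsampling that also halves the channel count
            self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
            up_ch = in_ch // 2
        self.conv = DoubleConv(up_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        # Pad if the skip map is larger (e.g. odd input sizes) so shapes match
        dy, dx = skip.size(2) - x.size(2), skip.size(3) - x.size(3)
        x = F.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # The skip connection: stack encoder and decoder maps along channels
        return self.conv(torch.cat([skip, x], dim=1))


out = Up(128, 64, 64)(torch.randn(1, 128, 16, 16), torch.randn(1, 64, 32, 32))
# out.shape: (1, 64, 32, 32)
```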

Finally, a class for the output layer: a 1x1 convolution that maps the features to the required number of output channels (one per class) for the completed segmentation map.
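The output layer is small enough to sketch in full; the class name and the three-class example are illustrative. It produces raw per-pixel scores (logits), one channel per class, and taking `argmax` over the channel dimension turns them into a class map.

```python
import torch
import torch.nn as nn


class OutConv(nn.Module):
    """1x1 convolution: one output channel (logit) per class at every pixel."""

    def __init__(self, in_ch, n_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, n_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


head = OutConv(64, n_classes=3)
logits = head(torch.randn(1, 64, 32, 32))  # (1, 3, 32, 32)
class_map = logits.argmax(dim=1)           # (1, 32, 32), values in {0, 1, 2}
```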

Now, let’s put it all together in the model itself. I have used some hyperparameters here from a previous project; feel free to tweak them for your needs.
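The author's original model code and hyperparameters are not reproduced here. The following self-contained sketch is assembled from the helper blocks described above, under the same assumptions (padded convolutions, BatchNorm, bilinear upsampling). Channel widths follow the paper's 64 to 1024 progression through a hypothetical `base` parameter; `n_channels`, `n_classes` and `base` are illustrative defaults, not the project's values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) x 2; padding=1 keeps H and W."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class Down(nn.Module):
    """2x2 max pool, then DoubleConv."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_ch, out_ch))

    def forward(self, x):
        return self.block(x)


class Up(nn.Module):
    """Bilinear upsample, pad to the skip's size, concatenate, DoubleConv."""
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = DoubleConv(in_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        dy, dx = skip.size(2) - x.size(2), skip.size(3) - x.size(3)
        x = F.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        return self.conv(torch.cat([skip, x], dim=1))


class OutConv(nn.Module):
    """1x1 convolution to one logit per class."""
    def __init__(self, in_ch, n_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, n_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, base=64):
        super().__init__()
        c1, c2, c3, c4, c5 = base, base * 2, base * 4, base * 8, base * 16
        self.inc = DoubleConv(n_channels, c1)
        self.down1, self.down2 = Down(c1, c2), Down(c2, c3)
        self.down3, self.down4 = Down(c3, c4), Down(c4, c5)  # down4: bottleneck
        self.up1, self.up2 = Up(c5, c4, c4), Up(c4, c3, c3)
        self.up3, self.up4 = Up(c3, c2, c2), Up(c2, c1, c1)
        self.outc = OutConv(c1, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)  # x3 feeds down3 and, later, the skip into up2
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)  # deepest skip first...
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)   # ...shallowest skip last
        return self.outc(x)


model = UNet(n_channels=3, n_classes=2, base=8)  # small width for a quick check
out = model(torch.randn(2, 3, 64, 64))           # (2, 2, 64, 64)
```

With the default `base=64`, the encoder widths are 64, 128, 256, 512 and 1024, as in the paper's diagram.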

It should be pretty easy to understand what is going on here (all hail PyTorch and its Pythonic style of code). Pay close attention to how we pass x4, x3 and so on to their corresponding upsampling blocks, to emulate the U-Net design and its skip connections. x3, for example, is sent to two places: to the next block for downsampling, and to the second upsampling block, where it restores the detail that would otherwise have been lost on the way down.

Applications

As discussed above, the biggest use of U-Net is semantic segmentation, where it has performed excellently so far.

A very popular Kaggle competition on creating masks of these cars so they can be easily extracted
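For a binary task like this one (car versus background), the model can be given a single output channel. A minimal sketch, assuming such a one-channel model and a 0.5 probability threshold (both assumptions), of turning its raw output into a mask and using the mask to cut the object out of the image; `extract_with_mask` is a hypothetical helper.

```python
import torch


def extract_with_mask(image, logits, threshold=0.5):
    """Hypothetical helper.

    image:  (C, H, W) tensor
    logits: (1, H, W) raw output of a one-class segmentation model
    """
    # Sigmoid turns logits into per-pixel probabilities; threshold gives 1/0
    mask = (torch.sigmoid(logits) > threshold).to(image.dtype)
    # The (1, H, W) mask broadcasts over all C channels, zeroing the background
    return image * mask, mask
```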

Conclusion

Hopefully you now have a good idea of what U-Nets are and how they work, or at least the intuition behind how they work. Feel free to check out this article for an incredibly in-depth look at the code for this amazing architecture. Have a good day!
