Build and Train U-Net from Scratch Using TensorFlow 2.0

Prateek Bhatt
Published in Analytics Vidhya
Jul 2, 2020 · 8 min read

As a Machine Learning Engineer, I have always tried to understand published research papers. It has always been hard to understand them, and even harder to replicate their results.

One such paper is U-Net: Convolutional Networks for Biomedical Image Segmentation [2]. I have seen quite a few implementations of U-Net, but most of them do not implement the exact architecture described in the paper.

While implementing the code and writing this article, I referred to the image segmentation tutorial provided by TensorFlow [3], which also implements a modified version of the model.

Later, I came across a YouTube video from Abhishek Thakur [1] that shows a step-by-step implementation of the U-Net model. He uses PyTorch, which I have not used much myself, so I decided to build the U-Net using TensorFlow.

U-Net model

The model is shaped like a U, which is where the name comes from.
There are two main parts of the U-Net model: the contracting path, which is the left side of the model, and the expansive path, which is the right side.

The contracting path consists of the repeated application of two convolutional layers (3x3, unpadded). Each convolutional layer is followed by a ReLU, and each pair is followed by max pooling (2x2) with stride 2. [2]

The expansive path consists of up-convolutions (2x2), implemented as Conv2DTranspose, for upsampling. Each up-convolution halves the number of feature channels while doubling the spatial size. The result is then concatenated with the correspondingly cropped feature map from the contracting path. As in the contracting path, this is followed by two consecutive convolutional layers, each with a ReLU. [2]

Do not worry if the model is not yet clear from the above explanation; I will go through each step along with the code.

Let us concentrate on the contracting and the expansive path separately.

The Contracting Path

The contracting path is shown below. The image is from the official U-Net paper [2]; I have isolated the contracting path for a clearer explanation.

The different arrows represent different layers and operations.

Contracting Path — U-Net [2]
Different Arrows [2]

For the contracting part we need a convolutional layer (3x3) and a max pooling (2x2) layer. You can also see from the image that the convolutional layers come in pairs. The number written above a feature map is the number of filters, and the number written at its side is the spatial size of the feature map at that point.

So consider the example below from the U-Net architecture, where 64 is the number of filters used in the convolutional layers and (572x572), (570x570), (568x568) are the feature map sizes. The paper states the output size after each layer, which makes the architecture very easy to implement.
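These sizes follow from simple arithmetic: an unpadded 3x3 convolution shrinks each spatial dimension by 2, a 2x2 max pool with stride 2 halves it, and a 2x2 up-convolution with stride 2 doubles it. A minimal sketch to sanity-check the sizes:

# size arithmetic for the U-Net building blocks (square feature maps)
def conv3x3(size):    # unpadded 3x3 convolution: shrinks by 2
    return size - 2

def maxpool2x2(size): # 2x2 max pooling with stride 2: halves
    return size // 2

def upconv2x2(size):  # 2x2 transposed convolution with stride 2: doubles
    return size * 2

print(conv3x3(572))           # 570
print(conv3x3(conv3x3(572)))  # 568, as in the figure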

First part [2]

So for the above image, the implemented layers will look something like this.

import tensorflow as tf
from tensorflow.keras import layers

# In the original paper the network has only one (grayscale) input channel.
inputs = layers.Input(shape=(572, 572, 1))
# first part of the U - contracting part
c0 = layers.Conv2D(64, activation='relu', kernel_size=3)(inputs)
c1 = layers.Conv2D(64, activation='relu', kernel_size=3)(c0)
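You can verify the sizes directly on the Keras symbolic tensors; they match the figure:

print(c0.shape)  # (None, 570, 570, 64)
print(c1.shape)  # (None, 568, 568, 64)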

I hope the above example gives you the idea. Let's take another example and implement the code for it.

Second part [2]
c2 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c1)

c3 = layers.Conv2D(128, activation='relu', kernel_size=3)(c2)
c4 = layers.Conv2D(128, activation='relu', kernel_size=3)(c3) # output used for concatenation in the expansive part
c5 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c4)

Here we have two max pooling (2x2) layers: one before the two convolutional layers and one after. The complete contracting part looks like this:

# declaring the input layer
# In the original paper the network has only one (grayscale) input channel.
inputs = layers.Input(shape=(572, 572, 1))
# first part of the U - contracting part
c0 = layers.Conv2D(64, activation='relu', kernel_size=3)(inputs)
c1 = layers.Conv2D(64, activation='relu', kernel_size=3)(c0) # output used for concatenation in the expansive part
c2 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c1)

c3 = layers.Conv2D(128, activation='relu', kernel_size=3)(c2)
c4 = layers.Conv2D(128, activation='relu', kernel_size=3)(c3) # output used for concatenation in the expansive part
c5 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c4)

c6 = layers.Conv2D(256, activation='relu', kernel_size=3)(c5)
c7 = layers.Conv2D(256, activation='relu', kernel_size=3)(c6) # output used for concatenation in the expansive part
c8 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c7)

c9 = layers.Conv2D(512, activation='relu', kernel_size=3)(c8)
c10 = layers.Conv2D(512, activation='relu', kernel_size=3)(c9) # output used for concatenation in the expansive part
c11 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c10)

c12 = layers.Conv2D(1024, activation='relu', kernel_size=3)(c11)
c13 = layers.Conv2D(1024, activation='relu', kernel_size=3)(c12)
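Tracing the shapes confirms the sizes from the paper; the pre-pooling outputs are exactly the feature maps we will later crop and concatenate:

print(c1.shape)   # (None, 568, 568, 64)
print(c4.shape)   # (None, 280, 280, 128)
print(c7.shape)   # (None, 136, 136, 256)
print(c10.shape)  # (None, 64, 64, 512)
print(c13.shape)  # (None, 28, 28, 1024)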

The Expansive Path

Now we will implement the expansive part of the model. The expansive part also uses outputs from the contracting part: the outputs of the convolutional layers marked above are cropped and then concatenated with the upsampled feature maps. We will look at the expansive part step by step.

First part [2]

You can see that here we have to use an up-convolution. The up-convolution has 512 filters; the other 512 channels come from the cropped counterpart in the contracting path.
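For a transposed convolution with no padding, the output size is (input_size - 1) * stride + kernel_size. The (28x28) output of c13 therefore becomes (28 - 1) * 2 + 2 = 56, i.e. a (56x56) feature map.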

# We will now start the second part of the U - expansive part
t01 = layers.Conv2DTranspose(512, kernel_size=2, strides=(2, 2), activation='relu')(c13)
Concatenation of contracting part [2]

We take the output of layer c10 and crop it from (64x64) to (56x56) so that it can be concatenated with the upsampled feature map.

crop01 = layers.Cropping2D(cropping=(4, 4))(c10)
concat01 = layers.concatenate([t01, crop01], axis=-1)
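Cropping2D(cropping=(4, 4)) removes 4 rows and columns from each side of the feature map. The crop amount can always be derived the same way, and you can check the resulting shapes:

# crop per side = (contracting size - expansive size) / 2
crop = (64 - 56) // 2  # 4; the same rule gives 16, 40 and 88 later on
print(t01.shape)       # (None, 56, 56, 512)
print(crop01.shape)    # (None, 56, 56, 512)
print(concat01.shape)  # (None, 56, 56, 1024)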

Once concatenated, we add two convolutional layers, each with 512 filters.

c14 = layers.Conv2D(512, activation='relu', kernel_size=3)(concat01)
c15 = layers.Conv2D(512, activation='relu', kernel_size=3)(c14)

We repeat this pattern until we arrive at the following architecture.

U-Net Architecture [2]

The complete model is given below, along with a visualization of its architecture.

# declaring the input layer
# In the original paper the network has only one (grayscale) input channel.
inputs = layers.Input(shape=(572, 572, 1))
# first part of the U - contracting part
c0 = layers.Conv2D(64, activation='relu', kernel_size=3)(inputs)
c1 = layers.Conv2D(64, activation='relu', kernel_size=3)(c0) # output used for concatenation in the expansive part
c2 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c1)

c3 = layers.Conv2D(128, activation='relu', kernel_size=3)(c2)
c4 = layers.Conv2D(128, activation='relu', kernel_size=3)(c3) # output used for concatenation in the expansive part
c5 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c4)

c6 = layers.Conv2D(256, activation='relu', kernel_size=3)(c5)
c7 = layers.Conv2D(256, activation='relu', kernel_size=3)(c6) # output used for concatenation in the expansive part
c8 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c7)

c9 = layers.Conv2D(512, activation='relu', kernel_size=3)(c8)
c10 = layers.Conv2D(512, activation='relu', kernel_size=3)(c9) # output used for concatenation in the expansive part
c11 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(c10)

c12 = layers.Conv2D(1024, activation='relu', kernel_size=3)(c11)
c13 = layers.Conv2D(1024, activation='relu', kernel_size=3)(c12)

# We will now start the second part of the U - expansive part
t01 = layers.Conv2DTranspose(512, kernel_size=2, strides=(2, 2), activation='relu')(c13)
crop01 = layers.Cropping2D(cropping=(4, 4))(c10)

concat01 = layers.concatenate([t01, crop01], axis=-1)

c14 = layers.Conv2D(512, activation='relu', kernel_size=3)(concat01)
c15 = layers.Conv2D(512, activation='relu', kernel_size=3)(c14)

t02 = layers.Conv2DTranspose(256, kernel_size=2, strides=(2, 2), activation='relu')(c15)
crop02 = layers.Cropping2D(cropping=(16, 16))(c7)

concat02 = layers.concatenate([t02, crop02], axis=-1)

c16 = layers.Conv2D(256, activation='relu', kernel_size=3)(concat02)
c17 = layers.Conv2D(256, activation='relu', kernel_size=3)(c16)

t03 = layers.Conv2DTranspose(128, kernel_size=2, strides=(2, 2), activation='relu')(c17)
crop03 = layers.Cropping2D(cropping=(40, 40))(c4)

concat03 = layers.concatenate([t03, crop03], axis=-1)

c18 = layers.Conv2D(128, activation='relu', kernel_size=3)(concat03)
c19 = layers.Conv2D(128, activation='relu', kernel_size=3)(c18)

t04 = layers.Conv2DTranspose(64, kernel_size=2, strides=(2, 2), activation='relu')(c19)
crop04 = layers.Cropping2D(cropping=(88, 88))(c1)

concat04 = layers.concatenate([t04, crop04], axis=-1)

c20 = layers.Conv2D(64, activation='relu', kernel_size=3)(concat04)
c21 = layers.Conv2D(64, activation='relu', kernel_size=3)(c20)

outputs = layers.Conv2D(2, kernel_size=1)(c21)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name="u-net-model")
Visualize Model Architecture
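If you want to reproduce such a visualization yourself, Keras provides two options (plot_model additionally requires the pydot and graphviz packages):

# textual summary of all layers and output shapes
model.summary()
# graphical plot of the architecture
tf.keras.utils.plot_model(model, to_file='unet.png', show_shapes=True)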

Another important aspect of the model is understanding the input and output channels. If you have a grayscale image, the input has 1 channel; an RGB image has 3. The number of output channels depends on the image mask: if you are classifying each pixel into three different classes, the output layer will have 3 channels.

It is now time for a practical example. For this purpose, I have used the TensorFlow tutorial on image segmentation [3] and reused the parts that I needed for the demonstration; why reinvent the wheel?

Practical Example

We will be looking at the oxford_iiit_pet dataset. The images are RGB and, following the tutorial's preprocessing, sized (128x128). The first thing that might come to your mind is that the U-Net expects the input image to be (572x572) and grayscale. Yes, you are right, and we can handle this in different ways; the way I have handled it is as follows.

First I changed my input layer to handle RGB images.

inputs = layers.Input(shape=(572, 572, 3))

Secondly, I have to resize the images from (128x128) to (572x572). This can be achieved using the tf.image.resize_with_pad API. The input mask needs to be resized as well, to the (388x388) output size of the network.

input_image = tf.image.resize_with_pad(datapoint['image'], 572, 572)
input_mask = tf.image.resize_with_pad(datapoint['segmentation_mask'], 388, 388)
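Putting this together, a minimal input pipeline could look like the sketch below; it follows the TensorFlow tutorial [3], and the function and variable names are my own:

import tensorflow as tf
import tensorflow_datasets as tfds

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

def load_image(datapoint):
    # pad/resize the image to the (572x572) input the network expects
    input_image = tf.image.resize_with_pad(datapoint['image'], 572, 572)
    # the mask must match the (388x388) output of the network
    input_mask = tf.image.resize_with_pad(datapoint['segmentation_mask'], 388, 388)
    # normalize pixel values to [0, 1]
    input_image = tf.cast(input_image, tf.float32) / 255.0
    return input_image, input_mask

train_batches = dataset['train'].map(load_image).batch(8)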

Lastly, we have to change our output layer to match the channels that the input_mask requires. There are three possible mask values for this particular dataset (1, 2, 3), but because we padded the input_mask with zeros, a fourth value appears, so we need 4 output channels.

# For this dataset the mask has three classes, i.e. each pixel is classified
# into one of three classes; padding the mask with 0 introduces a fourth class.
outputs = layers.Conv2D(4, kernel_size=1)(c21)
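Training then follows the standard Keras workflow. A minimal sketch, assuming the batched train_batches pipeline from above (the exact settings are in the repository):

model.compile(optimizer='adam',
              # from_logits=True because the final 1x1 convolution has no activation
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_batches, epochs=20)  # epoch count is illustrative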

You can find the complete code in the GitHub repository. The preparation of the dataset and the training part are available in the same repository.

The code guides you step by step through training a U-Net model from scratch on the oxford_iiit_pet dataset. I have tried to make the tutorial as easy to adapt to other datasets as possible, since you will probably want to apply the U-Net to some other problem.

Please let me know how you liked this tutorial and whether you would like more details on the training part.

Happy reading!

References:

[1] Abhishek Thakur, U-Net implementation walkthrough: https://www.youtube.com/watch?v=u1loyDCoGbE&t=2111s

[2] Ronneberger, Fischer, Brox, U-Net: Convolutional Networks for Biomedical Image Segmentation: https://arxiv.org/pdf/1505.04597.pdf

[3] TensorFlow tutorial, Image segmentation: https://www.tensorflow.org/tutorials/images/segmentation
