U-net Unleashed: A step-by-step guide on implementing and training your own segmentation model in Tensorflow: Part 2

Vipul Sarode
5 min read, Jan 20, 2024


In the last part, we saw how to build the U-Net from scratch. In this part, we will use the U-Net and perform segmentations on real-world data. You can find the dataset used in this part here.

If you want the full version of the code, you can get it here. To reproduce the results, you may need multiple GPUs. I used the GPUs provided for free by Kaggle for this task.

Data

I wrote a custom function that loads each image together with its mask. It reads both in grayscale to reduce the computational cost, and it resizes them to (256, 256) so that every sample has the same shape. Here is the code for the function.

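```python
import os

import cv2
from tqdm import tqdm


def load_images(imgsort, masksort, image_dir, mask_dir):
    '''
    Takes the directories and image/mask file names and reads them using cv2
    '''
    images, masks = [], []

    for img, msk in tqdm(zip(imgsort, masksort), total=len(imgsort), desc='Loading Images and Masks'):
        # Read both files as single-channel (grayscale) arrays
        image = cv2.imread(os.path.join(image_dir, img), cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(os.path.join(mask_dir, msk), cv2.IMREAD_GRAYSCALE)

        # Resize to a fixed 256x256; nearest-neighbour interpolation keeps the mask's
        # label values intact instead of blending them along object boundaries
        image = cv2.resize(image, (256, 256))
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

        images.append(image)
        masks.append(mask)

    return images, masks


images, masks = load_images(imgsort, masksort, image_dir, mask_dir)
```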
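load_images expects two lists, imgsort and masksort, whose i-th entries belong to the same sample; the article builds them elsewhere in the notebook. The sketch below shows one way to produce such aligned lists. It assumes a hypothetical naming scheme in which the mask for x.png is called x_mask.png, so adjust mask_suffix (or the matching rule) to the dataset's real file names.

```python
import os


def pair_image_mask_files(image_dir, mask_dir, mask_suffix="_mask"):
    """Return two sorted, aligned lists of image and mask file names.

    Hypothetical naming convention: the mask for "x.png" is "x_mask.png".
    Images without a matching mask are skipped.
    """
    mask_names = set(os.listdir(mask_dir))
    imgsort, masksort = [], []
    for img in sorted(os.listdir(image_dir)):
        stem, ext = os.path.splitext(img)
        msk = f"{stem}{mask_suffix}{ext}"
        if msk in mask_names:  # keep only complete image/mask pairs
            imgsort.append(img)
            masksort.append(msk)
    return imgsort, masksort
```

Usage: `imgsort, masksort = pair_image_mask_files(image_dir, mask_dir)` before calling load_images. Deriving each mask name from its image name, rather than sorting the two directories independently, avoids silent misalignment when one directory has an extra or missing file.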

Once the images and masks are loaded, we come to arguably the most important step in Machine Learning:

Visualize, visualize, visualize! Before training any Machine Learning model, we need to become one with the data, because better data is probably the most crucial ingredient for cooking up a better model. Visualizing the images also lets us check that each image is aligned with its mask.

```python
import random

import matplotlib.pyplot as plt


# Plotting images for a sanity check
def plot_image_with_mask(image_list, mask_list, num_pairs=4):
    '''
    Plots num_pairs random image/mask pairs: images on the top row, masks below
    '''
    plt.figure(figsize=(18, 9))
    for i in range(num_pairs):
        # randrange excludes the upper bound, so idx is always a valid index
        idx = random.randrange(len(image_list))
        img = image_list[idx]
        mask = mask_list[idx]
        plt.subplot(2, num_pairs, i + 1)
        plt.imshow(img, cmap='gray')
        plt.title(f'Real Image, index = {idx}')
        plt.axis('off')
        plt.subplot(2, num_pairs, i + num_pairs + 1)
        plt.imshow(mask, cmap='gray')
        plt.title(f'Segmented Image, index = {idx}')
        plt.axis('off')
    plt.show()


plot_image_with_mask(images, masks, num_pairs=4)
```
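Plots let us eyeball alignment, but a few programmatic checks catch problems that are easy to miss in a handful of random samples, such as a mask list that is one element short or masks containing more than two pixel values. The helper below is a hypothetical sketch, not part of the original notebook; it assumes both lists hold 2-D NumPy arrays like those returned by load_images.

```python
import numpy as np


def check_pairs(images, masks):
    """Hypothetical helper: basic consistency checks on image/mask lists.

    Raises ValueError on a count or shape mismatch and returns the distinct
    pixel values found across all masks.
    """
    if len(images) != len(masks):
        raise ValueError(f"{len(images)} images but {len(masks)} masks")
    for i, (img, msk) in enumerate(zip(images, masks)):
        if img.shape != msk.shape:
            raise ValueError(f"shape mismatch at index {i}: {img.shape} vs {msk.shape}")
    # A binary mask should contain exactly two values (e.g. 0 and 255);
    # extra values usually come from interpolation when resizing.
    return np.unique(np.stack(masks))
```

For example, `print(check_pairs(images, masks))` should print only two values if the masks are binary.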

Now that the data is ready, I converted all the images and masks to tensors and performed a train-test-split using the 60–20–20 distribution for the training, validation, and testing set respectively. Again, you can refer to the code here.
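The author's own split lives in the linked notebook. As a minimal sketch of one way to get a 60/20/20 split while keeping every image paired with its mask, the function below shuffles a single index array and applies it to both arrays. The function name and the fixed seed are illustrative assumptions.

```python
import numpy as np


def split_60_20_20(images, masks, seed=42):
    """Shuffle paired arrays and split them 60/20/20 into train/val/test."""
    images, masks = np.asarray(images), np.asarray(masks)
    n = len(images)
    # One permutation shared by both arrays keeps image i with mask i
    idx = np.random.default_rng(seed).permutation(n)
    n_train, n_val = int(0.6 * n), int(0.2 * n)
    train, val, test = np.split(idx, [n_train, n_train + n_val])
    return ((images[train], masks[train]),
            (images[val], masks[val]),
            (images[test], masks[test]))
```

Each returned pair can then be wrapped with `tf.data.Dataset.from_tensor_slices((x, y))` and batched for training.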

Setting up Evaluation Metrics

Plain pixel accuracy is a poor way to evaluate a segmentation model. Let’s think for a second. Segmentation is just pixel-by-pixel classification, and, as you can see in the masks above, the majority of the pixels are black (in our case, no cancerous tissue is present there). So a model that recovers only 25% of the segmented area will still score a “high accuracy”, because it gets most of the background right, even though in practice it might be useless. Hence, we evaluate segmentations with a different metric: the “Dice Coefficient”.
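To make the imbalance argument concrete, here is a small NumPy sketch on a made-up 100x100 mask (not from the dataset) in which the cancerous region covers 4% of the pixels and the prediction finds only a quarter of it.

```python
import numpy as np

# Made-up 100x100 mask: a 20x20 square (400 pixels, 4% of the image) is foreground
y_true = np.zeros((100, 100), dtype=bool)
y_true[40:60, 40:60] = True

# A poor prediction that finds only a quarter of it (a 10x10 corner, 100 pixels)
y_pred = np.zeros_like(y_true)
y_pred[40:50, 40:50] = True

pixel_accuracy = (y_true == y_pred).mean()
dice_score = 2 * np.logical_and(y_true, y_pred).sum() / (y_true.sum() + y_pred.sum())
print(f"pixel accuracy = {pixel_accuracy:.2f}, dice = {dice_score:.2f}")
# pixel accuracy = 0.97, dice = 0.40
```

Only the 300 missed foreground pixels count as errors, so accuracy stays at 97%, while the Dice score (200 / 500) exposes how poor the prediction is.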

Defining the Dice Coefficient

Visual representation of the Dice Coefficient (Source: omicsonline)

The Dice coefficient measures the overlap between two sets. In our case, the sets are the foreground pixels of the actual mask and of the predicted mask, so it measures how similar the two masks are. A Dice coefficient of 1 means the predicted mask is exactly equal to the actual mask (100% overlap); a Dice coefficient of 0 means the two masks do not overlap at all.

Mathematically, the Dice coefficient is twice the size of the intersection of the two sets divided by the sum of their sizes. (The ratio of intersection to union is a different metric, the Intersection over Union, or IoU.)

Dice coefficient = 2 |A ∩ B| / (|A| + |B|), where A and B are the two sets whose overlap is being measured, and |A| and |B| are the number of elements in A and B respectively.
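A worked example helps separate the Dice coefficient from IoU. The sketch below treats each mask as a Python set of foreground pixel coordinates (a toy example, not the article's data) and computes both metrics from their definitions.

```python
def dice(a, b):
    """Dice coefficient of two sets: 2|A ∩ B| / (|A| + |B|)."""
    if not a and not b:
        return 1.0  # convention: two empty masks agree perfectly
    return 2 * len(a & b) / (len(a) + len(b))


def iou(a, b):
    """Intersection over union (Jaccard index), for comparison."""
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


# Foreground pixel coordinates in each toy mask
actual = {(0, 0), (0, 1), (1, 0), (1, 1)}                       # |A| = 4
predicted = {(0, 1), (1, 1), (1, 2), (2, 2), (2, 1), (1, 0)}    # |B| = 6

print(dice(actual, predicted))            # 0.6   = 2*3 / (4 + 6)
print(round(iou(actual, predicted), 3))   # 0.429 = 3 / 7
```

The two metrics are monotonically related (Dice = 2·IoU / (1 + IoU)), so they rank predictions the same way, but Dice gives the larger number for partial overlaps.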

Let’s see how to code the dice coefficient in Python.

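```python
import tensorflow as tf

# Setting the Dice coefficient to evaluate our model (per image, over all pixels)
def dice_coeff(y_true, y_pred, smooth=1):
    batch_size = tf.shape(y_pred)[0]
    y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), [batch_size, -1])  # flatten each mask
    y_pred = tf.reshape(y_pred, [batch_size, -1])
    intersection = tf.reduce_sum(y_true * y_pred, axis=-1)
    total = tf.reduce_sum(y_true, axis=-1) + tf.reduce_sum(y_pred, axis=-1)  # |A| + |B|
    return (2 * intersection + smooth) / (total + smooth)
```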
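The axes over which the sums run matter. If the masks reach the metric with shape (batch, height, width, 1), as is typical for a single-channel sigmoid output, then summing over axis=-1 covers only the one-element channel axis. The smoothed score is then effectively computed pixel by pixel, and every background pixel contributes a score of 1. The NumPy sketch below, on a made-up 4x4 mask, compares the two reductions for a model that predicts no foreground at all.

```python
import numpy as np


def dice_np(y_true, y_pred, axis, smooth=1.0):
    """Smoothed Dice, summing over the given axis/axes (NumPy reference)."""
    intersection = np.sum(y_true * y_pred, axis=axis)
    total = np.sum(y_true, axis=axis) + np.sum(y_pred, axis=axis)
    return (2 * intersection + smooth) / (total + smooth)


# One 4x4 single-channel mask shaped like a Keras batch: (batch, H, W, 1)
y_true = np.zeros((1, 4, 4, 1))
y_true[0, :2, :2, 0] = 1          # 4 foreground pixels
y_pred = np.zeros_like(y_true)    # the model predicts no foreground at all

per_pixel = dice_np(y_true, y_pred, axis=-1).mean()         # sums over the channel axis only
per_image = dice_np(y_true, y_pred, axis=(1, 2, 3)).mean()  # sums over the whole mask
print(f"{per_pixel:.3f} {per_image:.3f}")  # 0.875 0.200
```

An empty prediction should score poorly. Only the whole-mask reduction delivers that, which is why it is worth checking the mask shapes before trusting the metric.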

Training the U-Net and creating Segmentations

Now that we are all set with the data and the evaluation metric, let’s train the U-Net model and perform Segmentations on the test data. I saved the U-Net we built from scratch in Part 1 in a .py file and imported it into this notebook for training.

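```python
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

# MirroredStrategy replicates the model across all visible GPUs on one machine
# (the usual choice for Kaggle's multi-GPU sessions); the model must be built
# and compiled inside its scope so that its variables are mirrored
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # unet() is the model-building function from Part 1, imported from its .py file;
    # note that this line rebinds the name unet from the function to the model
    unet = unet()
    unet.compile(loss='binary_crossentropy',
                 optimizer='adam',
                 metrics=['accuracy', dice_coeff])

# Defining early stopping to regularize the model and prevent overfitting
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Training the model for up to 50 epochs (early stopping will usually end training sooner)
unet_history = unet.fit(train_data, validation_data=val_data,
                        epochs=50, callbacks=[early_stopping])
```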
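With patience=3, EarlyStopping keeps training until val_loss has failed to improve for three consecutive epochs, and restore_best_weights=True then rolls the model back to the weights from the epoch with the lowest val_loss. The pure-Python function below is a simplified model of that behaviour, not Keras's actual implementation (it ignores options such as min_delta, baseline and start_from_epoch), run on made-up loss values.

```python
def simulate_early_stopping(val_losses, patience=3):
    """Simplified model of EarlyStopping(monitor='val_loss', patience=...,
    restore_best_weights=True).

    Returns (epochs_run, best_epoch), both 1-indexed; best_epoch is the epoch
    whose weights would be restored.
    """
    best_loss, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:          # strict improvement resets the counter
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:      # no improvement for `patience` epochs
                return epoch, best_epoch
    return len(val_losses), best_epoch


# Made-up validation losses; the improvement at epoch 7 is never reached
val_losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.38, 0.30]
print(simulate_early_stopping(val_losses, patience=3))  # (6, 3)
```

The example also shows the trade-off in choosing patience: a small value saves training time but can stop just before a later improvement.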

This is the final epoch before the model stopped training due to the early-stopping callback we set earlier. As you can see, the accuracy is very high (94%), but the Dice coefficient is only about 67%. The gap shows why accuracy alone is misleading here and justifies using the Dice coefficient for evaluation. Now, let’s predict masks for some images in the test set and plot them beside the actual masks to see how well our model is doing.

```python
import random

import matplotlib.pyplot as plt
import tensorflow as tf


# Function to plot the predictions with the original image, original mask and predicted mask
def plot_preds(idx):
    '''
    This function plots a test image, its actual mask and the predicted
    mask side by side.
    '''
    plt.figure(figsize=(15, 15))
    test_img = images_test[idx]
    # Add batch and channel axes: (256, 256) -> (1, 256, 256, 1)
    test_img = tf.expand_dims(test_img, axis=0)
    test_img = tf.expand_dims(test_img, axis=-1)
    pred = unet.predict(test_img, verbose=0)
    pred = pred.squeeze()
    # Threshold the predicted probabilities to get a binary mask
    thresh = pred > 0.5
    plt.subplot(1, 3, 1)
    plt.imshow(images_test[idx], cmap='gray')
    plt.title(f'Original Image {idx}')
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(masks_test[idx], cmap='gray')
    plt.title('Actual Mask')
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.imshow(thresh, cmap='gray')
    plt.title('Predicted Mask')
    plt.axis('off')
    plt.show()


# Plotting 10 distinct random test images with their true and predicted masks
for i in random.sample(range(len(images_test)), 10):
    plot_preds(i)
```
10 randomly picked images with their actual and predicted mask
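Ten plots give a qualitative impression. To put a number on the whole test set, one can compute the Dice coefficient of every thresholded prediction. The NumPy helper below is a hypothetical sketch, not from the original notebook. It assumes ground-truth masks of shape (N, H, W) in which any nonzero value is foreground, and predicted probabilities of the same shape, thresholded at 0.5 as in plot_preds.

```python
import numpy as np


def dice_per_image(y_true, y_prob, threshold=0.5, eps=1e-7):
    """Dice of each thresholded prediction against its ground-truth mask.

    y_true: (N, H, W), nonzero = foreground; y_prob: (N, H, W) probabilities.
    Returns an array of N scores; two empty masks score 1.0.
    """
    y_true = y_true.astype(bool)
    y_pred = y_prob > threshold              # same 0.5 threshold as plot_preds
    axes = tuple(range(1, y_true.ndim))      # sum over everything except the sample axis
    inter = np.logical_and(y_true, y_pred).sum(axis=axes)
    total = y_true.sum(axis=axes) + y_pred.sum(axis=axes)
    return (2 * inter + eps) / (total + eps)
```

For example, with the test images batched as (N, 256, 256, 1), `scores = dice_per_image(masks, unet.predict(images)[..., 0])` followed by `scores.mean()` and `scores.min()` reports the average and the worst case, which is a useful complement to eyeballing a few random samples.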

Looks good! I’d say the model is performing well on these segmentations. I hope you learnt something new from this article, and thanks for sticking around till the end. If you have any further questions or concerns, please let me know in the comments.

