PyTorch implementation of Semantic Segmentation for a Single Class from Scratch

Shashank Shekhar
Published in Analytics Vidhya · Dec 14, 2019 · 7 min read
Image by MIDHUN GEORGE via Unsplash

INTRODUCTION

Semantic segmentation can be thought of as classification at the pixel level. More precisely, it is the process of linking each pixel in an image to a class label. Here we are trying to answer what is in the image and where it is. Semantic segmentation differs from instance segmentation, where objects of the same class receive different labels (e.g. car1 and car2, drawn in different colours).

There are tons of repositories where semantic segmentation is implemented in very complex forms for multiple classes. In this blog I have tried to implement semantic segmentation from scratch for a single class. With some tweaks, the same approach should apply to multiple classes.

Autonomous driving Image Source
Aerial Imaging Image Source

Some of its primary applications are in autonomous driving, medical image diagnostics, aerial imaging, photo editing/creativity tools and many more.

Brain Tumour Prediction Image Source

Road-map of topics covered

  1. Overview
  2. Brief description of the data and framework we will be using
  3. Code implementation for a single class
    a. Data Preprocessing Pipeline
    b. Dataloaders Pipeline
    c. Scores Pipeline
    d. Training Pipeline
  4. Inference

Overview of Semantic Image Segmentation

Given a grayscale (H,W,1) or RGB (H,W,3) image, we want to generate a segmentation mask that has the same height and width as the image and consists of categorical values from 1 to N (N is the number of classes).

Input Image to Segmentation Labels Credit
Categorical to One hot Encoded Labels Credit

Semantic labels consisting of categorical values can also be one-hot encoded, as shown in the image on the left.
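A minimal sketch of this one-hot encoding uses PyTorch's `torch.nn.functional.one_hot` on a made-up 3×4 label map with N = 3 classes. `one_hot` expects class ids 0..N-1, so labels numbered 1..N would first need to be shifted down by one.

```python
import torch
import torch.nn.functional as F

# Toy H x W = 3 x 4 label map with N = 3 classes (ids 0..2); values are made up
labels = torch.tensor([[0, 0, 1, 1],
                       [0, 2, 2, 1],
                       [2, 2, 2, 1]])

one_hot = F.one_hot(labels, num_classes=3)  # shape [H, W, N], dtype int64
one_hot = one_hot.permute(2, 0, 1)          # -> [N, H, W], channels-first as PyTorch models expect
```

Each pixel now has exactly one channel set to 1, and channel k is simply the binary mask `labels == k`.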

Overlap of Original image and Semantic label

After taking the argmax across the class dimension and overlaying the result on the original image, we end up with the image on the left.
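The argmax-and-overlay step can be sketched in NumPy. This is an illustration, not the article's code. The function name `overlay`, the per-class `palette` and the blending factor `alpha` are hypothetical choices.

```python
import numpy as np

def overlay(image, class_scores, palette, alpha=0.5):
    """Collapse per-class scores to labels with argmax, then alpha-blend one colour per label.

    image:        H x W x 3 uint8 RGB image
    class_scores: N x H x W array (one-hot labels or per-class scores)
    palette:      N x 3 uint8 array, one RGB colour per class (hypothetical choice)
    """
    labels = np.argmax(class_scores, axis=0)      # H x W categorical map, values 0..N-1
    color_mask = np.asarray(palette)[labels]      # H x W x 3 via fancy indexing
    blended = (1 - alpha) * image.astype(np.float32) + alpha * color_mask.astype(np.float32)
    return labels, np.clip(blended, 0, 255).astype(np.uint8)
```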

Dataset & Framework

We will be using the Carvana dataset provided by Kaggle, which contains a large number of car images. Each car has exactly 16 images, each taken from a different angle.

16 orientations for Single car Image

File descriptions

  • train: this folder contains the training set images (.jpg) [1280,1918]
  • test: this folder contains the test set images (.jpg) [1280,1918]
  • train_masks: this folder contains the training set masks (.gif) [1280,1918]
  • train_masks.csv: this file gives a run-length encoded version of the training masks.
  • metadata.csv: contains basic information about all the cars.

As for frameworks, we will mainly be using PyTorch and scikit-learn (for the train/val split).

Implementation for Single Class

The implementation is subdivided into four pipelines:

  1. Data Preprocessing Pipeline: convert the train_mask images from .gif to .png, then resize both the train and train mask images (.png) from their original dimensions to [128,128]. All training uses these 128×128 images.
  2. Dataloaders Pipeline: fetch images in batches, apply transforms to them, and return dataloaders for the train and validation phases.
  3. Scores Pipeline: calculate the required score (in our case the dice score, the metric used by Kaggle).
  4. Training Pipeline: the final pipeline, where training runs, losses are calculated and parameters are updated.

Getting started!!!

Throughout the code snippets in this blog, I have tried to add comments wherever they were needed. We will start by importing all the required libraries.

We will be using the U-Net architecture, through the high-level API provided by segmentation_models.pytorch.

For image and mask augmentation, we will use the albumentations library.

  1. Data Preprocessing Pipeline

First, we convert the train masks from .gif to .png. Then we resize the train and mask images to [128,128], using ThreadPoolExecutor to run these operations in parallel.

Data Preprocessing
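The original preprocessing gist is not reproduced here. The following is a minimal sketch of the same idea using Pillow and `concurrent.futures.ThreadPoolExecutor`, and it combines the .gif-to-.png conversion and the resize in a single pass. The function names `resize_one`/`resize_folder` and the worker count are hypothetical. The sketch also assumes masks keep the source file stem, so `xxx_mask.gif` becomes `xxx_mask.png`.

```python
import os
from concurrent.futures import ThreadPoolExecutor
from PIL import Image

SIZE = (128, 128)  # target (width, height); square, so the order does not matter here

def resize_one(src_path, dst_dir, is_mask=False):
    """Open one image or mask, resize it to SIZE and save it as .png in dst_dir (hypothetical helper)."""
    img = Image.open(src_path)
    if is_mask:
        # nearest-neighbour keeps mask pixels at their original values instead of blending them
        img = img.convert("L").resize(SIZE, Image.NEAREST)
    else:
        img = img.convert("RGB").resize(SIZE, Image.BILINEAR)
    name = os.path.splitext(os.path.basename(src_path))[0] + ".png"
    dst_path = os.path.join(dst_dir, name)
    img.save(dst_path)
    return dst_path

def resize_folder(src_dir, dst_dir, is_mask=False, workers=8):
    """Resize every file in src_dir in parallel threads; returns the written paths."""
    os.makedirs(dst_dir, exist_ok=True)
    paths = [os.path.join(src_dir, f) for f in sorted(os.listdir(src_dir))]
    with ThreadPoolExecutor(max_workers=workers) as ex:
        return list(ex.map(lambda p: resize_one(p, dst_dir, is_mask), paths))
```

Threads work well here because most of the time is spent in file I/O and Pillow's C code rather than in Python bytecode.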

Now we load train_masks.csv into a dataframe, just to get the image names. I will not use the run-length encoded version of the training masks provided in this CSV. Instead, I will directly use the mask images we just generated.

import pandas as pd

df = pd.read_csv('/home/arun/Shashank/carvana/train_masks.csv')
# location of the resized original and mask images
img_fol = '/media/shashank/New Volume/carvana/train-128'
mask_fol = '/media/shashank/New Volume/carvana/train_masks-128'
# imagenet mean/std will be used as the resnet backbone is trained on imagenet stats
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

As we will be using a ResNet backbone pretrained on ImageNet, we set the mean and standard deviation of the ImageNet data for normalization.

2. Dataloaders pipeline

In this section we implement custom transforms, a dataset and a dataloader.
Transforms depend on the phase. For “train” we use a horizontal flip along with Normalize and ToTensor. For “val” we only use Normalize and ToTensor.

applying transforms
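The article builds these transforms with albumentations. The sketch below writes the same three operations (random horizontal flip, Normalize, ToTensor) directly in NumPy/PyTorch to keep it dependency-light. It swaps out the library but keeps the technique. It assumes masks are stored with values 0/255 and scales them to 0/1, and `get_transforms` is a hypothetical name.

```python
import numpy as np
import torch

# ImageNet statistics, as used in the article
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def get_transforms(phase, p_flip=0.5, rng=None):
    """Return a callable (image, mask) -> (image_tensor, mask_tensor).

    image: H x W x 3 uint8; mask: H x W uint8 with values 0/255 (assumed).
    'train' adds a random horizontal flip; 'val' only normalizes and converts.
    """
    rng = rng if rng is not None else np.random.default_rng()

    def apply(image, mask):
        if phase == "train" and rng.random() < p_flip:
            image = image[:, ::-1]  # flip along the width axis
            mask = mask[:, ::-1]    # flip the mask identically so labels stay aligned
        img = (image.astype(np.float32) / 255.0 - MEAN) / STD                   # Normalize
        img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))   # HWC -> CHW
        msk = np.ascontiguousarray(mask, dtype=np.float32) / 255.0              # 0/255 -> 0/1
        return img, torch.from_numpy(msk).unsqueeze(0)                          # mask -> [1, H, W]

    return apply
```

The flip must be applied to the image and mask together. Flipping only the image would silently misalign the labels.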

After the transforms, we create a custom dataset class named CarDataset. Here we fetch the original image and mask using the index passed in by the dataloader and apply the transformation to them. The output of this class is an image tensor of shape [3,128,128] and a mask tensor of shape [1,128,128]. The mask tensor has only one channel because we are training for a single class.

Mask Single channel representation [1,128,128].
Custom Dataset
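A sketch of such a dataset built on `torch.utils.data.Dataset` is shown below. It assumes the resized files follow the naming `<stem>.png` for images and `<stem>_mask.png` for masks, where the stem comes from the `img` column of train_masks.csv. This naming, and the `default_transform` stand-in, are hypothetical.

```python
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

def default_transform(image, mask):
    """Hypothetical stand-in: HWC uint8 -> CHW float in [0, 1]; mask 0/255 -> [1, H, W] in {0, 1}."""
    img = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
    msk = torch.from_numpy(mask).float().div(255.0).unsqueeze(0)
    return img, msk

class CarDataset(Dataset):
    def __init__(self, names, img_fol, mask_fol, transform=default_transform):
        self.names = list(names)  # e.g. df["img"], values like "00087a6bd4dc_01.jpg"
        self.img_fol = img_fol
        self.mask_fol = mask_fol
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        stem = os.path.splitext(self.names[idx])[0]
        # assumed naming of the resized files: <stem>.png and <stem>_mask.png
        image = np.array(Image.open(os.path.join(self.img_fol, stem + ".png")).convert("RGB"))
        mask = np.array(Image.open(os.path.join(self.mask_fol, stem + "_mask.png")).convert("L"))
        return self.transform(image, mask)
```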

Using the CarDataloader function, we split the input dataframe into a training dataframe and a validation dataframe, which only hold the image names. From these dataframes we create the training and validation dataloaders.

DataLoader
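The split-then-wrap step can be sketched as follows. The article uses scikit-learn's `train_test_split`. This sketch substitutes a seeded NumPy permutation so it needs only NumPy and PyTorch. `car_dataloaders` and its `make_dataset(names, phase)` factory argument are hypothetical.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

def car_dataloaders(names, make_dataset, val_frac=0.1, batch_size=16, num_workers=0, seed=42):
    """Split image names into train/val and wrap each part in a DataLoader.

    make_dataset(names, phase) is a hypothetical factory, e.g.
    lambda n, p: CarDataset(n, img_fol, mask_fol, get_transforms(p)).
    """
    names = list(names)
    perm = np.random.RandomState(seed).permutation(len(names))  # reproducible shuffle
    n_val = int(round(len(names) * val_frac))
    splits = {
        "val": [names[i] for i in perm[:n_val]],
        "train": [names[i] for i in perm[n_val:]],
    }
    return {
        phase: DataLoader(
            make_dataset(subset, phase),
            batch_size=batch_size,
            shuffle=(phase == "train"),           # shuffle only while training
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
        )
        for phase, subset in splits.items()
    }
```

Every car appears in 16 photos, so a purely random per-image split puts photos of the same car in both sets, which can make validation scores optimistic. Splitting on the car-id prefix of the file name avoids this.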

3. Scores Pipeline

To tackle the problem of class imbalance, we use the Soft Dice Score (SDS) instead of pixel-wise cross-entropy. To calculate the SDS for a class, we take twice the sum of (pred × target) and divide it by the sum of (pred² + target²).

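Written out, the score is SDS = 2·Σ(p·t) / (Σp² + Σt²), where p are predicted probabilities and t the binary target. Used as a loss, it becomes 1 − SDS. A minimal PyTorch sketch follows. It assumes the model outputs raw logits (as with `activation=None`) and adds a small `eps` to avoid division by zero. `soft_dice_score` is a hypothetical name.

```python
import torch

def soft_dice_score(logits, target, eps=1e-7):
    """Soft Dice: 2 * sum(p * t) / (sum(p^2) + sum(t^2)), averaged over the batch."""
    probs = torch.sigmoid(logits)             # logits -> probabilities in (0, 1)
    dims = tuple(range(1, probs.dim()))       # sum over C, H, W for each image
    inter = (probs * target).sum(dims)
    denom = (probs ** 2).sum(dims) + (target ** 2).sum(dims)
    return ((2 * inter + eps) / (denom + eps)).mean()
```

For example, all-zero logits (p = 0.5 everywhere) against an all-ones target give 2·0.5n / (0.25n + n) = 0.8.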

Inside every epoch, we calculate the dice score for each batch and append it to a list. At the end of the epoch, the mean of these dice scores gives the dice score for that epoch.

At the end of each epoch, we log the loss and dice values using the epoch_log function.
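The per-batch bookkeeping can be sketched as a small meter class. `DiceMeter`, its `update`/`get_metrics` methods and the `epoch_log(epoch_loss, meter)` signature are hypothetical. The sketch also assumes the logged dice is computed on predictions thresholded at 0.5, which is a guess about the original metric.

```python
import torch

class DiceMeter:
    """Collects one dice value per batch; the epoch dice is their mean (hypothetical helper)."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.dice_scores = []

    def update(self, targets, logits):
        preds = (torch.sigmoid(logits) > self.threshold).float()  # hard 0/1 prediction
        inter = (preds * targets).sum().item()
        total = (preds.sum() + targets.sum()).item()
        # if both prediction and target are empty, count it as perfect agreement
        self.dice_scores.append(2.0 * inter / total if total > 0 else 1.0)

    def get_metrics(self):
        return sum(self.dice_scores) / len(self.dice_scores)

def epoch_log(epoch_loss, meter):
    """Print the epoch summary in the same style as the training log; returns the dice."""
    dice = meter.get_metrics()
    print(f"Loss: {epoch_loss:.4f} |dice: {dice:.4f}")
    return dice
```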

4. Training Pipeline

In this last pipeline, we create a Trainer class and initialize most of the values in its constructor.

Trainer class

In the start method, for every epoch we first call the iterate method for training, then the iterate method for validation, and then step the learning-rate scheduler. If the current validation loss is lower than the best one so far, we save the model parameters.
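A sketch of this epoch loop is shown below as a plain function rather than the article's `Trainer.start` method. It assumes a `ReduceLROnPlateau`-style scheduler that takes the validation loss in `step()`. The `iterate(epoch, phase)` callable, the name `fit` and the checkpoint layout are hypothetical.

```python
import torch

def fit(model, iterate, scheduler, num_epochs, ckpt_path="model.pth"):
    """Train then validate each epoch, step the scheduler on the validation loss,
    and save a checkpoint whenever the validation loss beats the best so far."""
    best_loss = float("inf")
    for epoch in range(num_epochs):
        iterate(epoch, "train")
        with torch.no_grad():
            val_loss = iterate(epoch, "val")
        scheduler.step(val_loss)  # ReduceLROnPlateau reacts to the monitored value
        if val_loss < best_loss:
            print("******** New optimal found, saving state ********")
            best_loss = val_loss
            # hypothetical checkpoint layout
            torch.save({"epoch": epoch, "best_loss": best_loss,
                        "state_dict": model.state_dict()}, ckpt_path)
    return best_loss
```

Comparing against the best loss so far, rather than the previous epoch, matches the log below. Epochs 5, 6 and 8 do not save, while epochs 7 and 9 do.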

In the iterate method, we call the forward method, which calculates the loss. The loss is divided by the number of accumulation steps and added to the running loss. Each call to loss.backward() adds that batch's gradients to the .grad attribute of every model parameter, so gradients keep accumulating there until the number of accumulation steps is reached.

Once the accumulation steps are reached, we take an optimization step and zero the gradients. Finally, we compute the epoch loss and dice score and clear the CUDA cache.

Inside the forward method, we send the original image and target mask to the GPU and run a forward pass to get the predicted mask. The loss function then calculates the loss.
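The `forward`/`iterate` pair with gradient accumulation can be sketched as standalone functions. This is not the article's Trainer code. The signatures are hypothetical, `criterion` stands for whichever loss the Trainer was initialised with, and dice tracking is left out for brevity.

```python
import torch

def forward(model, criterion, images, targets, device):
    """Send the batch to `device`, run the model and compute the loss."""
    images = images.to(device)
    masks = targets.to(device)
    outputs = model(images)  # raw logits, e.g. [B, 1, H, W]
    loss = criterion(outputs, masks)
    return loss, outputs

def iterate(model, loader, criterion, optimizer, phase, accumulation_steps=1, device="cpu"):
    """One pass over `loader`; in 'train', gradients of `accumulation_steps`
    batches are summed in each param.grad before a single optimizer step."""
    training = phase == "train"
    model.train(training)
    running_loss = 0.0
    optimizer.zero_grad()
    with torch.set_grad_enabled(training):
        for itr, (images, targets) in enumerate(loader):
            loss, _ = forward(model, criterion, images, targets, device)
            loss = loss / accumulation_steps
            if training:
                loss.backward()  # adds this batch's gradients to each param.grad
                if (itr + 1) % accumulation_steps == 0:
                    optimizer.step()  # one update per `accumulation_steps` batches
                    optimizer.zero_grad()
            running_loss += loss.item()
    # undo the division so the result is the mean per-batch loss
    epoch_loss = running_loss * accumulation_steps / len(loader)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return epoch_loss
```

Accumulation gives an effective batch size of `batch_size × accumulation_steps` without needing the memory to hold that many images at once.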

Loading Architecture from smp

Now it is time to load the U-Net architecture from smp, using ResNet-18 as the backbone. We set the number of classes to 1, as our mask has shape [1,128,128].

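import segmentation_models_pytorch as smp

# activation=None: the model returns raw logits (apply a sigmoid to get probabilities)
model = smp.Unet("resnet18", encoder_weights="imagenet", classes=1, activation=None)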

Let the magic begin!!!!!!!!!!!!!!!

model_trainer = Trainer(model)
model_trainer.start()
Starting epoch: 0 | phase:train | 🙊':02:02:11
Loss: 0.1084 |dice: 0.9460
Starting epoch: 0 | phase:val | 🙊':02:02:48
Loss: 0.0358 |dice: 0.9783
******** New optimal found, saving state ********

Starting epoch: 1 | phase:train | 🙊':02:02:55
Loss: 0.0288 |dice: 0.9800
Starting epoch: 1 | phase:val | 🙊':02:03:29
Loss: 0.0239 |dice: 0.9815
******** New optimal found, saving state ********

Starting epoch: 2 | phase:train | 🙊':02:03:36
Loss: 0.0205 |dice: 0.9836
Starting epoch: 2 | phase:val | 🙊':02:04:11
Loss: 0.0185 |dice: 0.9844
******** New optimal found, saving state ********

Starting epoch: 3 | phase:train | 🙊':02:04:18
Loss: 0.0172 |dice: 0.9854
Starting epoch: 3 | phase:val | 🙊':02:04:53
Loss: 0.0167 |dice: 0.9853
******** New optimal found, saving state ********

Starting epoch: 4 | phase:train | 🙊':02:04:59
Loss: 0.0155 |dice: 0.9863
Starting epoch: 4 | phase:val | 🙊':02:05:34
Loss: 0.0154 |dice: 0.9860
******** New optimal found, saving state ********

Starting epoch: 5 | phase:train | 🙊':02:05:40
Loss: 0.0149 |dice: 0.9864
Starting epoch: 5 | phase:val | 🙊':02:06:14
Loss: 0.0158 |dice: 0.9850

Starting epoch: 6 | phase:train | 🙊':02:06:20
Loss: 0.0142 |dice: 0.9869
Starting epoch: 6 | phase:val | 🙊':02:06:55
Loss: 0.0158 |dice: 0.9848

Starting epoch: 7 | phase:train | 🙊':02:07:00
Loss: 0.0132 |dice: 0.9877
Starting epoch: 7 | phase:val | 🙊':02:07:35
Loss: 0.0145 |dice: 0.9863
******** New optimal found, saving state ********

Starting epoch: 8 | phase:train | 🙊':02:07:41
Loss: 0.0127 |dice: 0.9881
Starting epoch: 8 | phase:val | 🙊':02:08:15
Loss: 0.0151 |dice: 0.9855

Starting epoch: 9 | phase:train | 🙊':02:08:21
Loss: 0.0125 |dice: 0.9882
Starting epoch: 9 | phase:val | 🙊':02:08:57
Loss: 0.0136 |dice: 0.9869
******** New optimal found, saving state ********

Inference

In about 7 minutes (10 epochs), we reached a validation dice score of 0.9869 (about 98.7%), which is impressive. Using the saved weights, we then run inference on our validation data.
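A minimal inference sketch loads the saved weights, applies a sigmoid to the logits and thresholds at 0.5 to get a binary mask. The checkpoint key `"state_dict"`, the function name `predict_masks` and the 0.5 threshold are assumptions, not the article's code.

```python
import torch

@torch.no_grad()
def predict_masks(model, images, ckpt_path=None, threshold=0.5, device="cpu"):
    """Return binary masks [B, 1, H, W] for a batch of normalized images [B, 3, H, W]."""
    if ckpt_path is not None:
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state["state_dict"])  # assumed checkpoint layout
    model.to(device).eval()
    probs = torch.sigmoid(model(images.to(device)))  # logits -> probabilities
    return (probs > threshold).float().cpu()
```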

Left: predicted mask. Right: target mask.

Conclusion

It may not be a state-of-the-art result, but with just about 200 lines of code we get a clear idea of how semantic segmentation works. By tweaking a few lines of code, the same can be done for multi-class labels. Please share, and leave a comment if you have any questions.
