Image Segmentation — A Beginner’s Guide

The essentials of Image Segmentation + implementation in TensorFlow

Raj Pulapakura
6 min read · Feb 4, 2024

Image segmentation is a computer vision technique that assigns a label to every pixel in an image such that pixels with the same label share certain characteristics.

For example, in a street scene, all pixels belonging to cars might be labeled with one color, while those belonging to the road might be labeled with another.

But to understand image segmentation and why it is useful, let’s go back to basics….

Boring Classifiers

Cute doggo. Source

Is there a cute little doggo in this picture? Of course there is.

This is a classification task. It tells us if there is a dog in the image.

But what if we want to know exactly where the dog is.

One approach is to draw a bounding box around the dog, which is called Object Detection.

Cute doggo + bounding box. Source + Author

If that’s all you want, then you’re done! But if you want to know exactly where the dog is, on the pixel level, then you’ll need something better. That’s where image segmentation comes into play.

Image Segmentation

Street segmentation. Source

The core task of image segmentation is to classify each pixel in an image. In the above street scene, there are 5 classes: road (pink), vehicles (red), buildings (yellow), nature (green), sky (blue). Each pixel is assigned one of these classes.

But sometimes you want to distinguish between individual cars, or individual trees. To this end, there are three main types of image segmentation, each providing a different level of detail and information.

Semantic vs. Instance vs. Panoptic

Semantic vs. Instance vs. Panoptic segmentation. Source
  • Semantic segmentation classifies each pixel based on its semantic class. All the birds belong to the same class.
  • Instance segmentation assigns unique labels to different instances, even if they share a semantic class. Each bird is a separate instance.
  • Panoptic segmentation combines the two, providing both class-level and instance-level labels. Each bird gets its own instance label, but they are all identified as a “bird” (see the sketch below).
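
To make the distinction concrete, here is a tiny NumPy sketch with hand-made masks for a 4x4 image containing two “birds” (the arrays are purely illustrative):

import numpy as np

# Semantic mask: every bird pixel gets the same class id (1 = bird)
semantic_mask = np.array([
    [0, 1, 0, 0],
    [0, 1, 0, 1],
    [0, 0, 0, 1],
    [0, 0, 0, 0],
])

# Instance mask: each bird gets its own instance id (1 and 2)
instance_mask = np.array([
    [0, 1, 0, 0],
    [0, 1, 0, 2],
    [0, 0, 0, 2],
    [0, 0, 0, 0],
])

# Panoptic output pairs the two: both birds are class "bird",
# but their instance ids (1 and 2) keep them distinguishable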

Cool, but how do we actually implement image segmentation?

There are a couple of ways, such as thresholding and clustering, but deep learning (my fav) really takes the spotlight when it comes to image segmentation.

Real-time body part panoptic segmentation. GIF from TensorFlow Blog

U-Net

The U-Net architecture was initially designed for medical image segmentation, but it has since been adapted for many other use cases.

U-Net. Image by author.

The U-Net has an encoder-decoder structure.

The encoder is used to compress the input image into a latent space representation through convolutions and downsampling.

The decoder expands the latent representation back into a full-resolution segmentation map, through convolutions and upsampling.

The long gray arrows running across the “U” are skip connections, and they serve two main purposes:

  1. During the forward pass, they enable the decoder to access information from the encoder.
  2. During the backward pass, they act as a “gradient superhighway” for gradients from the decoder to flow to the encoder.

The output of the model has the same width and height as the input; however, the number of channels equals the number of classes we are segmenting.

Code it up

If you’re keen to code, let’s implement the U-Net architecture for semantic segmentation in TensorFlow.

U-Net Architecture

Defining the model architecture is rather straightforward.

from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose
from tensorflow.keras.models import Model

def conv_block(x, n_filters):
    """Two 3x3 convolutions."""
    x = Conv2D(n_filters, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(n_filters, (3, 3), padding='same', activation='relu')(x)
    return x

def encoder_block(x, n_filters):
    """Conv block followed by max pooling."""
    x = conv_block(x, n_filters)
    p = MaxPooling2D((2, 2))(x)
    return x, p  # we will need x for the skip connections later

def decoder_block(x, skip, n_filters):
    """Upsample, skip connection, and conv block."""
    x = Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = concatenate([x, skip])  # concatenate = skip connection
    x = conv_block(x, n_filters)
    return x

def unet_model(n_classes, img_height, img_width, img_channels):
    inputs = Input((img_height, img_width, img_channels))  # 512x512x3

    # Contraction path (encoder)
    c1, p1 = encoder_block(inputs, n_filters=64)  # c1=512x512x64  p1=256x256x64
    c2, p2 = encoder_block(p1, n_filters=128)     # c2=256x256x128 p2=128x128x128
    c3, p3 = encoder_block(p2, n_filters=256)     # c3=128x128x256 p3=64x64x256
    c4, p4 = encoder_block(p3, n_filters=512)     # c4=64x64x512   p4=32x32x512

    # Bottleneck
    bridge = conv_block(p4, n_filters=1024)       # bridge=32x32x1024

    # Expansive path (decoder): the skip connections use the pre-pooling
    # encoder outputs (c4..c1) so the spatial dimensions match after upsampling
    u4 = decoder_block(bridge, c4, n_filters=512) # 64x64x512
    u3 = decoder_block(u4, c3, n_filters=256)     # 128x128x256
    u2 = decoder_block(u3, c2, n_filters=128)     # 256x256x128
    u1 = decoder_block(u2, c1, n_filters=64)      # 512x512x64

    outputs = Conv2D(n_classes, (1, 1), activation='softmax')(u1)  # 512x512xn_classes
    # notice the softmax activation in the final layer

    model = Model(inputs=[inputs], outputs=[outputs])

    return model

# example classes: [road, vehicles, buildings, nature, background]
# instantiate the model to predict 5 classes
IMG_HEIGHT, IMG_WIDTH = 512, 512
model = unet_model(
    n_classes=5,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    img_channels=3,
)
# input:  512x512x3
# output: 512x512x5
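
A quick sanity check is to push a dummy batch through the model and confirm the output shape (assuming the model instantiated above):

import numpy as np

dummy_batch = np.zeros((1, 512, 512, 3), dtype="float32")  # one blank image
predictions = model(dummy_batch)
print(predictions.shape)  # (1, 512, 512, 5)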

Loss Function: Categorical Cross Entropy

How do we optimize this model? Well, since image segmentation is really just classification on the pixel level, we can use the standard classification loss function, which is Categorical Cross Entropy.

model.compile(
    optimizer="adam",  # any standard optimizer works; Adam is a common default
    loss="categorical_crossentropy",
)

We can interpret each pixel of the resulting (512x512x5) volume as a vector of length 5. Since the last layer uses a softmax activation across the last dimension, each pixel vector contains the probabilities of that pixel belonging to each class.
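
For a single pixel, the loss is simply the negative log of the probability the model assigned to the true class. Here is a tiny hand-crafted example (the probability values are made up):

import tensorflow as tf

# one pixel, 5 classes; the true class is index 2
y_true = tf.constant([[0.0, 0.0, 1.0, 0.0, 0.0]])
y_pred = tf.constant([[0.05, 0.05, 0.70, 0.10, 0.10]])

loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
print(loss.numpy())  # ~0.357, i.e. -log(0.7)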

Intuition for model output

Before we can train the model, we need a dataset. The dataset should contain (image, mask) pairs, where the image (x) is of shape (512x512x3) and the mask (y) is of shape (512x512x5).

Here is an example ground truth mask:

Image by Prince Canuma

Each pixel can only belong to one class, so it contains a “1” in one of the class channels, and a “0” in the other channels. You can think of each pixel as a one-hot vector (because that’s what it is).
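
In practice, masks are usually stored as a single-channel image of integer class ids and converted to one-hot form during preprocessing. With TensorFlow, tf.one_hot does the job (a sketch for our 5 classes):

import tensorflow as tf

# integer mask of shape (512, 512) with values in {0, ..., 4}
int_mask = tf.zeros((512, 512), dtype=tf.int32)  # placeholder mask

one_hot_mask = tf.one_hot(int_mask, depth=5)
print(one_hot_mask.shape)  # (512, 512, 5)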

Once you have your dataset prepared, you’re ready to train:

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
)

Of course, this code alone is not enough to train a successful model. If you actually want to implement this, you also need to handle preprocessing, rescaling, batching, and so on (a rough sketch follows).
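
Here is what that preparation might look like with tf.data. This is only a sketch: raw_train_ds is a hypothetical dataset of unbatched (image, integer mask) pairs, and the batch size and shuffle buffer are illustrative:

import tensorflow as tf

def preprocess(image, mask):
    image = tf.cast(image, tf.float32) / 255.0           # rescale to [0, 1]
    mask = tf.one_hot(tf.cast(mask, tf.int32), depth=5)  # (512, 512, 5)
    return image, mask

train_ds = (
    raw_train_ds  # assumed tf.data.Dataset of (image, mask) pairs
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(buffer_size=100)
    .batch(8)
    .prefetch(tf.data.AUTOTUNE)
)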

I’ve prepared a Kaggle notebook which tackles car segmentation (segmenting different parts of a car). It contains the complete code to run an image segmentation model, so check it out here.

Final Notes

  • Class Imbalance: Often in image segmentation, there is severe class imbalance. For example, in an average street view image, cars and buildings take up a lot of pixels, but stop signs take up very few. Because the model sees far fewer stop-sign pixels, it will perform poorly at segmenting them. To address this, you can use Focal Categorical Cross Entropy and class weights, which place more emphasis on minority classes (sketched in code after these notes).
  • Other Architectures: U-Net is not the only image segmentation architecture, although it is conceptually the simplest. Others include SegNet, Mask R-CNN, and PSPNet.
  • Binary Segmentation: If there is only one class you’re segmenting (e.g. segmenting a brain tumor in an MRI scan), then the output of the model only needs to be (512x512x1). In the mask, each pixel contains a “1” if it belongs to a tumor and a “0” if it does not. Make sure to also change “softmax” to “sigmoid” in the final activation of the model, and use the (Focal) Binary Cross Entropy loss function (see the sketch below).
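
For reference, here is what those two tweaks might look like in code. This is a sketch against the model defined earlier; CategoricalFocalCrossentropy ships with recent Keras releases, so check that it exists in your TensorFlow version:

import tensorflow as tf

# Class imbalance: a focal loss down-weights easy pixels, so rare classes
# (like stop signs) contribute relatively more to the gradient
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.CategoricalFocalCrossentropy(gamma=2.0),
)

# Binary segmentation: one output channel + sigmoid + binary cross entropy
# (these lines would replace the final layer and loss inside unet_model)
# outputs = Conv2D(1, (1, 1), activation="sigmoid")(u1)  # 512x512x1
# model.compile(optimizer="adam", loss="binary_crossentropy")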

Thanks for reading!

Have a fantastic day!
