Pytorch Semantic Image Segmentation

Stefan Herdy
Jul 13, 2023


Semantic image segmentation is a powerful computer vision technique that involves the understanding and analysis of images at a pixel level. It aims to assign a meaningful label to each pixel in an image, effectively dividing it into different regions or objects. Unlike traditional image classification, which assigns a single label to an entire image, semantic segmentation provides a more detailed and granular understanding of the visual content within an image.
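To make the difference from image classification concrete, here is a small NumPy illustration (not part of the training code). A classifier outputs one score vector per image. A segmentation network outputs one score vector per pixel, so taking the argmax over the class axis yields a full label map. The array sizes and random scores are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, height, width = 3, 4, 5

# Image classification: one score per class for the whole image -> a single label.
class_scores = rng.normal(size=num_classes)
image_label = int(np.argmax(class_scores))

# Semantic segmentation: one score per class *per pixel*, laid out as [C, H, W].
pixel_scores = rng.normal(size=(num_classes, height, width))
label_map = np.argmax(pixel_scores, axis=0)  # shape [H, W]: one class index per pixel

print(image_label, label_map.shape)
```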

In this article, we will walk through the process of training a custom neural network for semantic image segmentation.

Input Data

As a first step, we need to create our own segmentation dataset. It consists of image-target pairs: an image (HxWxC) and a target (HxW), with H as height, W as width and C as the number of channels. Standard RGB images have 3 channels; multi- or hyperspectral images used as model input can have more. The two-dimensional target array contains a categorical value at every pixel location that refers to a certain class.

Figure: Examples of an image-target pair as input data for the neural network training
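In array terms, a single training pair might look like the following sketch. The sizes, the three hypothetical classes and the rectangular "objects" are made up for illustration. The point is that the image carries C values per pixel, while the target carries exactly one integer class index per pixel.

```python
import numpy as np

height, width, channels = 6, 8, 3  # RGB; multispectral data would have more channels

# Input image: H x W x C, here random 8-bit values.
image = np.random.default_rng(42).integers(0, 256, size=(height, width, channels), dtype=np.uint8)

# Target: H x W, one categorical class index per pixel.
# Hypothetical classes: 0 = background, 1 = tree, 2 = road.
target = np.zeros((height, width), dtype=np.uint8)
target[1:4, 2:5] = 1   # a "tree" region
target[5, :] = 2       # a "road" row

print(image.shape, target.shape, np.unique(target))
```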

To get these image-target pairs, we have to annotate images manually, which means assigning a class to every pixel of an image. This can be done with dedicated annotation software; one tool I can recommend is the online annotation editor Labelbox.

There are many tutorials on how to use Labelbox, and its documentation is very good, so annotating images is not covered in this article. It is important not to save the target images as .jpg files: lossy JPEG compression can introduce new integer values (i.e. spurious classes) into your masks. Lossless formats like .png are fine.
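A cheap safeguard against such artifacts is to check every target for unexpected values before training. The helper `find_invalid_labels` below is a hypothetical NumPy sketch. It assumes the masks are already loaded as integer arrays (for example with `skimage.io.imread`, as in the dataset class later in this article) and that valid classes are 0 to num_classes - 1.

```python
import numpy as np


def find_invalid_labels(mask: np.ndarray, num_classes: int) -> np.ndarray:
    """Return the sorted pixel values in `mask` that are not valid class indices.

    Valid classes are assumed to be 0 .. num_classes - 1. A non-empty result
    usually means the mask was saved in a lossy format such as JPEG or was
    resized with an interpolating (non-nearest) method.
    """
    values = np.unique(mask)
    return values[(values < 0) | (values >= num_classes)]


clean_mask = np.array([[0, 1], [2, 2]], dtype=np.uint8)
jpeg_like_mask = np.array([[0, 1], [3, 255]], dtype=np.uint8)  # stray values

print(find_invalid_labels(clean_mask, num_classes=3))
print(find_invalid_labels(jpeg_like_mask, num_classes=3))
```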

Once we have our custom segmentation dataset, we can start writing the Python scripts to train and evaluate our models. Let's create a file utils.py containing some helper functions for training, for importing our custom dataset, and for loading a UNet model with an ImageNet-pretrained encoder.

```python
import os
from os import walk

import segmentation_models_pytorch as smp
import torch as t
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from customdatasets import CustomDataSet
from transformations import Compose, MoveAxis, Normalize01, RandomCrop, RandomFlip, Resize


def get_files(path):
    """Return all file paths below `path`, sorted so that images and targets line up."""
    files = []
    for (dirpath, dirnames, filenames) in walk(path):
        for name in filenames:
            files.append(os.path.join(dirpath, name))
    return sorted(files)


def makedirs(dirname):
    os.makedirs(dirname, exist_ok=True)


def get_model(device, cl):
    # UNet with a ResNet-152 encoder pretrained on ImageNet and `cl` output classes.
    # activation=None returns raw logits, as expected by nn.CrossEntropyLoss.
    unet = smp.Unet('resnet152', classes=cl, activation=None, encoder_weights='imagenet')
    return unet.to(device)


def import_data(args, batch_sz, project='project_3', resize=512, crop_size=256):
    # Expects ./input_data/<project>/image/ and ./input_data/<project>/target/
    inputs = get_files(f'./input_data/{project}/image/')
    targets = get_files(f'./input_data/{project}/target/')

    split = 0.8

    # Split inputs and targets in one call so that image-target pairs stay together
    inputs_train, inputs_valid, targets_train, targets_valid = train_test_split(
        inputs,
        targets,
        random_state=42,
        train_size=split,
        shuffle=True)

    # Training transformations: add your desired augmentations here
    transforms_train = Compose([
        MoveAxis(),
        Normalize01(),
        Resize(resize),
        RandomCrop(crop_size),
        RandomFlip()
    ])

    # Validation transformations: deterministic, without random augmentation
    transforms_valid = Compose([
        MoveAxis(),
        Normalize01(),
        Resize(resize)
    ])

    # train dataset
    dataset_train = CustomDataSet(inputs=inputs_train,
                                  targets=targets_train,
                                  transform=transforms_train)

    # validation dataset
    dataset_valid = CustomDataSet(inputs=inputs_valid,
                                  targets=targets_valid,
                                  transform=transforms_valid)

    # train dataloader
    dataloader_training = DataLoader(dataset=dataset_train,
                                     batch_size=batch_sz,
                                     shuffle=True)

    # validation dataloader (no shuffling needed for evaluation)
    dataloader_validation = DataLoader(dataset=dataset_valid,
                                       batch_size=batch_sz,
                                       shuffle=False)

    return dataloader_training, dataloader_validation


def checkpoint(f, tag, args, device, dataloader_training, dataloader_validation):
    # Save the weights from the CPU so the checkpoint can be loaded without a GPU
    f.cpu()
    ckpt_dict = {
        "model_state_dict": f.state_dict(),
        "train": dataloader_training,
        "valid": dataloader_validation
    }
    t.save(ckpt_dict, os.path.join(args.save_dir, tag))
    f.to(device)
```
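Matching image and target lists by position, as `import_data` does, only works if both folders contain files in corresponding order (for example, identical file names). A stricter option is to pair files explicitly by name. The helper `pair_by_stem` below is a hypothetical sketch. It assumes each target has the same base name as its image (e.g. `a.tif` and `a.png`), which may not match your naming scheme.

```python
from pathlib import PurePath


def pair_by_stem(inputs, targets):
    """Pair image and target paths that share the same file name without extension.

    Raises a ValueError if an image has no matching target, so mismatches
    surface before training instead of silently producing wrong pairs.
    """
    targets_by_stem = {PurePath(p).stem: p for p in targets}
    pairs = []
    for inp in sorted(inputs):
        stem = PurePath(inp).stem
        if stem not in targets_by_stem:
            raise ValueError(f"No target found for image {inp}")
        pairs.append((inp, targets_by_stem[stem]))
    return pairs


images = ['./input_data/project_1/image/b.tif', './input_data/project_1/image/a.tif']
masks = ['./input_data/project_1/target/a.png', './input_data/project_1/target/b.png']
paired = pair_by_stem(images, masks)
inputs, targets = [p[0] for p in paired], [p[1] for p in paired]
print(paired)
```

The resulting `inputs` and `targets` lists can then be passed to `train_test_split` as before.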

For the data import, we need our own dataset and transformation classes. A detailed description of how to import and augment data for semantic segmentation can be found in my previous article.

First, create a new file customdatasets.py with the following code:

```python
import torch
from skimage.io import imread
from torch.utils import data


class CustomDataSet(data.Dataset):
    def __init__(self,
                 inputs: list,
                 targets: list,
                 transform=None,
                 use_cache=False,
                 pre_transform=None,
                 ):
        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long
        self.use_cache = use_cache
        self.pre_transform = pre_transform

        if self.use_cache:
            # Load (and pre-transform) all samples once, in parallel, and keep them in memory
            from multiprocessing import Pool
            from itertools import repeat

            with Pool() as pool:
                self.cached_data = pool.starmap(self.read_images, zip(inputs, targets, repeat(self.pre_transform)))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index: int):
        if self.use_cache:
            x, y = self.cached_data[index]
        else:
            # Select the sample
            input_id = self.inputs[index]
            target_id = self.targets[index]

            # Load input and target
            x, y = imread(input_id), imread(target_id)

        # Preprocessing and augmentation (applied on every access, also to cached samples)
        if self.transform is not None:
            x, y = self.transform(x, y)

        # Typecasting: float32 image, int64 class indices as required by nn.CrossEntropyLoss
        x, y = torch.from_numpy(x.copy()).type(self.inputs_dtype), torch.from_numpy(y.copy()).type(self.targets_dtype)

        return x, y

    @staticmethod
    def read_images(inp, tar, pre_transform):
        inp, tar = imread(inp), imread(tar)
        if pre_transform:
            inp, tar = pre_transform(inp, tar)
        return inp, tar
```

Next, create a file named transformations.py. In this file, we specify our custom data transformations for preprocessing and augmentation.
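The crucial property of every geometric augmentation is that image and target are transformed in exactly the same way; otherwise the labels no longer sit on the pixels they describe. A quick check is to build a target that equals one image channel, apply the paired transform, and verify that the two still agree. The function `paired_flip_rot` is a hypothetical NumPy sketch of the flips and 90 degree rotation that `RandomFlip` applies, with the image as [C, H, W] and the target as [H, W].

```python
import random

import numpy as np


def paired_flip_rot(inp: np.ndarray, tar: np.ndarray, rng: random.Random):
    """Randomly flip/rotate an image [C, H, W] and its target [H, W] identically."""
    if rng.random() < 0.5:  # horizontal flip: width axis
        inp, tar = np.flip(inp, axis=2), np.flip(tar, axis=1)
    if rng.random() < 0.5:  # vertical flip: height axis
        inp, tar = np.flip(inp, axis=1), np.flip(tar, axis=0)
    if rng.random() < 0.5:  # 90 degree rotation in the spatial plane
        inp, tar = np.rot90(inp, k=1, axes=(1, 2)), np.rot90(tar, k=1, axes=(0, 1))
    # np.flip / np.rot90 return views with negative strides, which torch.from_numpy
    # cannot handle; copying gives contiguous arrays.
    return inp.copy(), tar.copy()


image = np.arange(3 * 4 * 5).reshape(3, 4, 5)  # [C, H, W] with non-square H x W
target = image[0].copy()                         # target equals channel 0 by construction

rng = random.Random(0)
for _ in range(20):
    aug_img, aug_tar = paired_flip_rot(image, target, rng)
    assert np.array_equal(aug_img[0], aug_tar)  # image and target stay aligned
print("all augmentations kept image and target aligned")
```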

```python
#!/usr/bin/python

import random

import cv2
import numpy as np
import torch
from torchvision import transforms


def normalize_01(inp: np.ndarray):
    """Min-max normalization to the range [0, 1]."""
    inp_out = (inp - np.min(inp)) / np.ptp(inp)
    return inp_out


def normalize(inp: np.ndarray, mean: float, std: float):
    inp_out = (inp - mean) / std
    return inp_out


def re_normalize(inp: np.ndarray,
                 low: int = 0,
                 high: int = 255
                 ):
    """Normalize the data to a certain range. Default: [0-255]"""
    # Linear rescaling to [low, high] as uint8, using NumPy only (the private
    # sklearn.externals._pilutil.bytescale is not available in recent scikit-learn)
    inp = inp.astype(np.float64)
    scale = (high - low) / max(np.ptp(inp), 1e-12)
    out = (inp - np.min(inp)) * scale + low
    return np.clip(np.round(out), low, high).astype(np.uint8)


class Compose:
    """
    Composes several transforms together.
    """

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, inp, target):
        for t in self.transforms:
            inp, target = t(inp, target)
        return inp, target

    def __repr__(self):
        return str([transform for transform in self.transforms])


class MoveAxis:
    """From [H, W, C] to [C, H, W]"""

    def __init__(self, transform_input: bool = True, transform_target: bool = False):
        self.transform_input = transform_input
        self.transform_target = transform_target

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        if self.transform_input:
            if inp.ndim == 2:  # single-channel image [H, W]: add a channel axis first
                inp = inp[..., np.newaxis]
            inp = np.moveaxis(inp, -1, 0)
        if self.transform_target:
            tar = np.moveaxis(tar, -1, 0)
        return inp, tar

    def __repr__(self):
        return str({self.__class__.__name__: self.__dict__})


class RandomFlip:
    """Random horizontal flip, vertical flip and 90 degree rotation.
    Input [C, H, W] and target [H, W] are always transformed identically."""

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        if random.choice([0, 1]) == 1:
            # Horizontal flip (width axis)
            inp = np.flip(inp, axis=2)
            tar = np.flip(tar, axis=1)

        if random.choice([0, 1]) == 1:
            # Vertical flip (height axis)
            inp = np.flip(inp, axis=1)
            tar = np.flip(tar, axis=0)

        if random.choice([0, 1]) == 1:
            # Rotation by 90 degrees in the spatial plane
            inp = np.rot90(inp, k=1, axes=(1, 2))
            tar = np.rot90(tar, k=1, axes=(0, 1))

        # np.flip / np.rot90 return views; copy to get contiguous arrays
        return inp.copy(), tar.copy()

    def __repr__(self):
        return str({self.__class__.__name__: self.__dict__})


class RandomCrop:
    """Random square crop; input [C, H, W], target [H, W]."""

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        max_x = inp.shape[1] - self.crop_size
        max_y = inp.shape[2] - self.crop_size
        if max_x < 0 or max_y < 0:
            raise ValueError(f"Crop size {self.crop_size} is larger than the image {inp.shape[1:]}")
        x = random.randint(0, max_x)
        y = random.randint(0, max_y)

        # Crop input and target at the same location
        inp = inp[:, x: x + self.crop_size, y: y + self.crop_size]
        tar = tar[x: x + self.crop_size, y: y + self.crop_size]

        return inp, tar


class Resize:
    """Resize input [C, H, W] and target [H, W] to img_size x img_size."""

    def __init__(self, img_size):
        self.img_size = img_size

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        size = (self.img_size, self.img_size)
        inp = np.moveaxis(inp, 0, -1)
        inp = cv2.resize(inp, size, interpolation=cv2.INTER_NEAREST)
        if inp.ndim == 2:  # cv2 drops a single channel axis; restore it
            inp = inp[..., np.newaxis]
        inp = np.moveaxis(inp, -1, 0)
        # The target must use nearest neighbour so that no new class values appear
        tar = cv2.resize(tar, size, interpolation=cv2.INTER_NEAREST)

        return inp, tar


class Normalize01:
    """Squash image input to the value range [0, 1] (no clipping)"""

    def __call__(self, inp, tar):
        inp = normalize_01(inp)

        return inp, tar


class Normalize:
    """Normalize based on mean and standard deviation."""

    def __init__(self,
                 mean: float,
                 std: float,
                 transform_input=True,
                 transform_target=False
                 ):
        self.transform_input = transform_input
        self.transform_target = transform_target
        self.mean = mean
        self.std = std

    def __call__(self, inp, tar):
        if self.transform_input:
            inp = normalize(inp, self.mean, self.std)

        return inp, tar


class ColorTransformations:
    """Random brightness, contrast, saturation and hue changes (input only).
    Expects a 3-channel [C, H, W] float input in [0, 1], e.g. after Normalize01."""

    def __init__(self):
        self.color_transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        inp_tensor = torch.from_numpy(np.ascontiguousarray(inp))
        inp = self.color_transform(inp_tensor).numpy()

        return inp, tar


class ColorNoise:
    """Add Gaussian noise to the input (not the target) and clip to [0, 1]."""

    def __init__(self, noise_std=0.05):
        self.noise_std = noise_std

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        inp_tensor = torch.from_numpy(np.ascontiguousarray(inp)).float()
        noise = torch.randn_like(inp_tensor) * self.noise_std
        # Out-of-place addition, so the original (possibly cached) array is not modified
        inp_tensor = torch.clamp(inp_tensor + noise, 0, 1)

        return inp_tensor.numpy(), tar
```
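Why does `Resize` use nearest-neighbour interpolation for the target? Interpolating between two class indices produces values in between, which would be read as classes that were never annotated at that location. The 1-D NumPy sketch below uses made-up values and stands in for `cv2.resize` on one row of a mask. It upsamples a row of labels 0 and 2 in both ways.

```python
import numpy as np

labels = np.array([0, 0, 2, 2], dtype=float)  # one row of a mask: classes 0 and 2
old_pos = np.arange(labels.size)               # original pixel positions 0..3
new_pos = np.linspace(0, labels.size - 1, 7)   # upsample to 7 pixels

# Linear interpolation (what a bilinear resize does along one axis).
linear = np.interp(new_pos, old_pos, labels)

# Nearest neighbour: take the label of the closest original pixel.
nearest = labels[np.rint(new_pos).astype(int)]

print(linear)   # contains 1.0, a class that was never annotated here
print(nearest)  # only 0 and 2
```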

Now we are ready to start training. Our main file, train.py, loads the data and the model with its pretrained encoder, trains the model on our custom data, and saves model checkpoints as well as the loss history. For your own project, you can tune hyperparameters such as the learning rate, optimizer, number of epochs and batch size to achieve the best possible model performance. It should look like this:

```python
#!/usr/bin/env python3

import argparse
import json
import os
import pickle
import random

import numpy as np
import torch as t
import torch.nn as nn
from tqdm import tqdm

from utils import checkpoint, get_model, import_data, makedirs

t.backends.cudnn.benchmark = True
t.backends.cudnn.enabled = True

seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
t.manual_seed(seed)


def eval_classification(f, dload, device):
    """Minimal validation helper (an assumed implementation): returns the pixel
    accuracy and the average of the per-batch cross-entropy losses."""
    criterion = nn.CrossEntropyLoss()
    correct_pixels, total_pixels, losses = 0, 0, []
    for x, y in dload:
        x, y = x.to(device), y.to(device)
        logits = f(x)
        losses.append(criterion(logits, y).item())
        correct_pixels += (logits.argmax(1) == y).sum().item()
        total_pixels += y.numel()
    return correct_pixels / total_pixels, float(np.mean(losses))


def main(args):
    args.save_dir = './checkpoints/'
    records_dir = './records/'

    makedirs(args.save_dir)
    makedirs(records_dir)
    with open(os.path.join(args.save_dir, 'params.txt'), 'w') as f:
        json.dump(args.__dict__, f)

    t.manual_seed(seed)
    if t.cuda.is_available():
        t.cuda.manual_seed_all(seed)

    dload_train, dload_valid = import_data(args, args.batch_size, args.project, args.resize, args.random_crop_size)

    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    f = get_model(device, args.num_classes)

    params = f.parameters()
    if args.optimizer == "adam":
        optim = t.optim.Adam(params, lr=args.learnrate, betas=(.9, .999), weight_decay=0.0)
    else:
        optim = t.optim.SGD(params, lr=args.learnrate, momentum=.9, weight_decay=0.0)

    criterion = nn.CrossEntropyLoss()
    best_valid_acc = 0.0
    iteration = 0

    train_losses = []
    val_losses = []
    val_corr = []

    for epoch in range(args.epochs):
        f.train()
        iter_losses = []
        for x_train, y_train in tqdm(dload_train):
            x_train, y_train = x_train.to(device), y_train.to(device)

            logits = f(x_train)                # [N, num_classes, H, W]
            loss = criterion(logits, y_train)  # y_train: [N, H, W] class indices
            iter_losses.append(loss.item())

            optim.zero_grad()
            loss.backward()
            optim.step()

            if iteration % args.print_every == 0:
                acc = (logits.argmax(1) == y_train).float().mean()
                print('P(y|x) {}:{:>d} loss={:>14.9f}, acc={:>14.9f}'.format(epoch, iteration, loss.item(), acc.item()))

            iteration += 1

        train_losses.append(np.mean(iter_losses))

        if epoch % args.eval_every == 0:
            f.eval()
            with t.no_grad():
                correct, val_loss = eval_classification(f, dload_valid, device)
            val_losses.append(val_loss)
            val_corr.append(correct)
            print("Epoch {}: Valid Loss {}, Valid Acc {}".format(epoch, val_loss, correct))
            if correct > best_valid_acc:
                best_valid_acc = correct
                print("Best Valid!: {}".format(correct))
                checkpoint(f, "best_validation_ckpt.pt", args, device, dload_train, dload_valid)
            f.train()

        if epoch % args.ckpt_every == 0:
            checkpoint(f, f'checkpoint_{epoch}.pt', args, device, dload_train, dload_valid)

    # Losses are pickled (binary files) and can be loaded for further analysis
    # You can also plot them here using matplotlib
    with open(os.path.join(records_dir, "trainlosses.txt"), "wb") as fp:
        pickle.dump(train_losses, fp)

    with open(os.path.join(records_dir, "vallosses.txt"), "wb") as fp:
        pickle.dump(val_losses, fp)

    with open(os.path.join(records_dir, "correct.txt"), "wb") as fp:
        pickle.dump(val_corr, fp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Pytorch Semantic Segmentation")
    parser.add_argument("--learnrate", type=float, default=0.0001, help='learning rate of the optimizer')
    parser.add_argument("--optimizer", choices=['sgd', 'adam'], default='adam')
    parser.add_argument("--epochs", type=int, default=7000)
    parser.add_argument("--eval_every", type=int, default=1, help="Epochs between evaluation")
    parser.add_argument("--print_every", type=int, default=5, help="Iterations between training-loss prints")
    parser.add_argument("--ckpt_every", type=int, default=2, help="Epochs between checkpoint saves")
    parser.add_argument("--project", choices=['project_1', 'project_2', 'project_3'], default='project_1')
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("--num_classes", type=int, default=8, help="Number of classes of your training dataset")
    parser.add_argument("--resize", type=int, default=512, help="Size of images for resizing")
    parser.add_argument("--random_crop_size", type=int, default=256, help="Size of random crops. Must not be larger than the resized images.")
    args = parser.parse_args()
    if args.random_crop_size > args.resize:
        raise ValueError("Crop size (--random_crop_size) must not be larger than the resized image (--resize)!")

    main(args)
```
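For reference, the two numbers printed during training can be written out explicitly. With its default mean reduction, `nn.CrossEntropyLoss` applies a log-softmax over the class axis of the logits [N, C, H, W] and averages the negative log-probability of the correct class over all pixels. The accuracy in train.py is the fraction of pixels whose argmax class equals the target. The NumPy function below sketches these formulas. It ignores options such as class weights or `ignore_index` and is not a replacement for the PyTorch implementation.

```python
import numpy as np


def pixel_ce_and_accuracy(logits: np.ndarray, target: np.ndarray):
    """Mean per-pixel cross-entropy and pixel accuracy.

    logits: float array [N, C, H, W]; target: int array [N, H, W] of class indices.
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class at every pixel -> [N, H, W].
    true_log_probs = np.take_along_axis(log_probs, target[:, None], axis=1)[:, 0]
    loss = -true_log_probs.mean()
    accuracy = (logits.argmax(axis=1) == target).mean()
    return float(loss), float(accuracy)


# One image, two classes, 1 x 2 pixels.
logits = np.array([[[[2.0, 0.0]], [[0.0, 2.0]]]])  # pixel 0 favours class 0, pixel 1 class 1
target = np.array([[[0, 0]]])                        # both pixels truly belong to class 0
loss, acc = pixel_ce_and_accuracy(logits, target)
print(round(loss, 4), acc)
```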

Congrats! Now we are ready to train our own custom semantic segmentation model in PyTorch.

The whole code of this article can be found on my GitHub repository.

I hope you liked this article!

All the best,
Stefan
