Pytorch Semantic Image Segmentation
Semantic image segmentation is a powerful computer vision technique that involves the understanding and analysis of images at a pixel level. It aims to assign a meaningful label to each pixel in an image, effectively dividing it into different regions or objects. Unlike traditional image classification, which assigns a single label to an entire image, semantic segmentation provides a more detailed and granular understanding of the visual content within an image.
In this article, we will walk through the process of how to train our custom neural networks for semantic image segmentation.
Input Data
As a first step, we need to generate our own segmentation dataset. Such a dataset consists of image-label pairs: an image (HxWxC) and a target (HxW), with H as height, W as width and C as the number of channels. Standard RGB images have 3 channels; multi- or hyperspectral images used as model input can have more. The two-dimensional target array contains categorical values that assign a class to every pixel location.
To get these image target pairs we have to manually annotate images, which means that we have to assign a class to every pixel of an image. This image annotation can be done using some special software. One such tool that I could recommend is the online annotation editor Labelbox.
There are a lot of tutorials on how to use Labelbox, and it provides very good documentation, so a tutorial on how to annotate the images is not part of this article. It is important not to save the target images as .jpg files, since the lossy JPEG compression can introduce new integer values (i.e. spurious classes) into your files. Lossless formats like .png are fine.
Once we have our own custom segmentation dataset we can start writing our Python scripts to train and evaluate our models. Let’s make a file utils.py, where we specify some helper functions for the training and to import our custom datasets and our pretrained UNet model.
import os
import pickle
import json
from torch.utils.data import DataLoader
import pathlib
from customdatasets import CustomDataSet
from transformations import Compose, DenseTarget, RandomFlip, Resize_Sample
from transformations import MoveAxis, Normalize01, RandomCrop
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split
from os import walk
import torch as t
import numpy as np
import torch.nn as nn
def get_files(path):
    """Recursively collect the paths of all files below ``path``.

    The result is sorted so that two parallel directory trees (e.g. an
    ``image/`` and a ``target/`` folder holding identically named files)
    yield lists whose entries line up index by index. The raw ``os.walk``
    order depends on the filesystem and gives no such guarantee.

    Args:
        path: Root directory to scan.

    Returns:
        Sorted list of file paths. Empty if ``path`` does not exist or
        contains no files.
    """
    files = [
        os.path.join(dirpath, name)
        for dirpath, _dirnames, filenames in os.walk(path)
        for name in filenames
    ]
    return sorted(files)
def makedirs(dirname):
    """Create ``dirname`` (including missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check, so a
    directory created by another process between check and creation does
    not raise.

    Args:
        dirname: Directory path to create.
    """
    os.makedirs(dirname, exist_ok=True)
def get_model(device, cl, encoder='resnet152', encoder_weights='imagenet'):
    """Build a U-Net from ``segmentation_models_pytorch`` and move it to ``device``.

    Args:
        device: Target device the model is moved to.
        cl: Number of output classes.
        encoder: Name of the encoder backbone passed to ``smp.Unet``.
        encoder_weights: Pretrained weights identifier passed to ``smp.Unet``.

    Returns:
        The model on ``device``. ``activation=None`` is passed to
        ``smp.Unet``, matching the training script's use of the output as
        logits for ``nn.CrossEntropyLoss``.
    """
    unet = smp.Unet(encoder, classes=cl, activation=None, encoder_weights=encoder_weights)
    # A single .to(device) is enough; an unconditional .cuda() beforehand would
    # only add a detour through the default GPU when ``device`` is something else.
    return unet.to(device)
def import_data(args, batch_sz, set='project_3', resize=None, crop_size=256):
    """Build the training and validation dataloaders for one project.

    Images are read from ``./input_data/<set>/image/`` and targets from
    ``./input_data/<set>/target/``. Both listings are sorted so that the
    i-th image is paired with the i-th target, then split 80/20 with a
    fixed random state.

    Args:
        args: Parsed command line arguments (currently unused here).
        batch_sz: Batch size for both dataloaders.
        set: Name of the project sub-directory below ``./input_data/``.
        resize: If not ``None``, every sample is resized to
            ``resize x resize`` before cropping.
        crop_size: Edge length of the random square crop.

    Returns:
        Tuple ``(dataloader_training, dataloader_validation)``.

    Raises:
        ValueError: If no images are found or the number of images and
            targets differs.
    """
    inputs = sorted(get_files(f'./input_data/{set}/image/'))
    targets = sorted(get_files(f'./input_data/{set}/target/'))
    if not inputs:
        raise ValueError(f"No input images found for project '{set}'.")
    if len(inputs) != len(targets):
        raise ValueError(
            f"Project '{set}' has {len(inputs)} images but {len(targets)} targets."
        )

    split = 0.8
    # Splitting both lists in one call guarantees identical indices for
    # images and targets.
    inputs_train, inputs_valid, targets_train, targets_valid = train_test_split(
        inputs,
        targets,
        random_state=42,
        train_size=split,
        shuffle=True)

    # Add your desired transformations here
    steps = [MoveAxis(), Normalize01()]
    if resize is not None:
        # Imported locally: Resize is not part of the module-level imports.
        from transformations import Resize
        steps.append(Resize(resize))
    steps.extend([RandomCrop(crop_size), RandomFlip()])
    transforms = Compose(steps)

    # train dataset
    dataset_train = CustomDataSet(inputs=inputs_train,
                                  targets=targets_train,
                                  transform=transforms)
    # validation dataset (uses the same transformation pipeline)
    dataset_valid = CustomDataSet(inputs=inputs_valid,
                                  targets=targets_valid,
                                  transform=transforms)

    # train dataloader
    dataloader_training = DataLoader(dataset=dataset_train,
                                     batch_size=batch_sz,
                                     shuffle=True)
    # validation dataloader
    dataloader_validation = DataLoader(dataset=dataset_valid,
                                       batch_size=batch_sz,
                                       shuffle=True)
    return dataloader_training, dataloader_validation
def checkpoint(f, tag, args, device, dataloader_training, dataloader_validation):
    """Write a checkpoint file named ``tag`` into ``args.save_dir``.

    The model is moved to the CPU before its state dict is captured and is
    moved back to ``device`` afterwards. Both dataloaders are stored in the
    checkpoint next to the weights.

    Args:
        f: Model to checkpoint.
        tag: File name of the checkpoint.
        args: Object with a ``save_dir`` attribute.
        device: Device the model is returned to after saving.
        dataloader_training: Stored under the key ``"train"``.
        dataloader_validation: Stored under the key ``"valid"``.
    """
    destination = os.path.join(args.save_dir, tag)
    f.cpu()
    payload = dict(
        model_state_dict=f.state_dict(),
        train=dataloader_training,
        valid=dataloader_validation,
    )
    t.save(payload, destination)
    f.to(device)
For the data import, we need to make some import classes. A detailed description of how to import and augment data for semantic segmentation can be found in my previous article.
As shown in that article, we have to write our own dataset and transformation classes. Create a new file customdatasets.py with the following code:
import torch
from skimage.io import imread
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
class CustomDataSet(data.Dataset):
    """Dataset of (image, target) pairs read from file paths.

    Samples are either read on access with ``imread`` or, when
    ``use_cache`` is set, read once up front in a process pool (with
    ``pre_transform`` applied) and served from memory. ``transform`` is
    applied on every access; the result is returned as a ``float32`` input
    tensor and a ``long`` target tensor.
    """

    def __init__(self,
                 inputs: list,
                 targets: list,
                 transform=None,
                 use_cache=False,
                 pre_transform=None,
                 ):
        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long
        self.use_cache = use_cache
        self.pre_transform = pre_transform

        if self.use_cache:
            from multiprocessing import Pool
            from itertools import repeat

            # One job per (input path, target path, pre_transform) triple.
            jobs = zip(inputs, targets, repeat(self.pre_transform))
            with Pool() as pool:
                self.cached_data = pool.starmap(self.read_images, jobs)

    def __len__(self):
        """Number of samples, taken from the input list."""
        return len(self.inputs)

    def __getitem__(self,
                    index: int):
        """Return sample ``index`` as an ``(input, target)`` tensor pair."""
        if self.use_cache:
            image, mask = self.cached_data[index]
        else:
            # Load input and target from disk
            image = imread(self.inputs[index])
            mask = imread(self.targets[index])

        # Preprocessing
        if self.transform is not None:
            image, mask = self.transform(image, mask)

        # Copy before converting, then cast to the configured dtypes
        image_tensor = torch.from_numpy(image.copy()).type(self.inputs_dtype)
        mask_tensor = torch.from_numpy(mask.copy()).type(self.targets_dtype)
        return image_tensor, mask_tensor

    @staticmethod
    def read_images(inp, tar, pre_transform):
        """Read one image/target pair and apply ``pre_transform`` if given."""
        image, mask = imread(inp), imread(tar)
        if pre_transform:
            image, mask = pre_transform(image, mask)
        return image, mask
Next, we have to make a file named transformations.py. In this file we specify our custom data transformations for preprocessing and augmentation.
#!/usr/bin/python
import numpy as np
from sklearn.externals._pilutil import bytescale
import random
import matplotlib.pyplot as plt
import cv2
import random
def normalize_01(inp: np.ndarray):
    """Linearly rescale ``inp`` to the value range [0, 1].

    Args:
        inp: Array to rescale.

    Returns:
        ``(inp - min) / (max - min)``. For a constant array (max == min)
        an all-zero float array of the same shape is returned instead of
        dividing by zero.
    """
    value_range = np.ptp(inp)
    if value_range == 0:
        # Constant image: avoid 0/0, which would yield NaNs.
        return np.zeros_like(inp, dtype=float)
    return (inp - np.min(inp)) / value_range
def normalize(inp: np.ndarray, mean: float, std: float):
    """Shift ``inp`` by ``mean`` and scale it by ``1 / std``."""
    centered = inp - mean
    return centered / std
def re_normalize(inp: np.ndarray, low: int = 0, high: int = 255):
    """Rescale the data to the range [low, high] via ``bytescale``. Default: [0-255]"""
    return bytescale(inp, low=low, high=high)
class Compose:
    """
    Chains several paired transforms.

    Each transform is called with ``(inp, target)`` and must return the
    transformed pair, which is fed into the next transform.
    """

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, inp, target):
        for step in self.transforms:
            inp, target = step(inp, target)
        return inp, target

    def __repr__(self):
        return str(list(self.transforms))
class MoveAxis:
    """From [H, W, C] to [C, H, W].

    Args:
        transform_input: Move the last axis of the input to the front.
        transform_target: Move the last axis of the target to the front.
            Off by default, so the target is passed through unchanged.
    """

    def __init__(self, transform_input: bool = True, transform_target: bool = False):
        self.transform_input = transform_input
        self.transform_target = transform_target

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        # Honour the constructor flags; with the defaults only the input moves.
        if self.transform_input:
            inp = np.moveaxis(inp, -1, 0)
        if self.transform_target:
            tar = np.moveaxis(tar, -1, 0)
        return inp, tar

    def __repr__(self):
        return str({self.__class__.__name__: self.__dict__})
class RandomFlip:
    """Randomly flip and rotate an input/target pair in lockstep.

    Expects the input as [C, H, W] and the target as [H, W]. Three
    independent coin flips decide whether to apply a horizontal flip, a
    vertical flip and a 90 degree rotation; each operation is applied to
    input and target alike so they stay aligned.

    The flips are done with NumPy on the spatial axes directly. Going
    through ``cv2.flip`` on an [H, W, C] view drops the channel axis for
    single-channel images, after which moving the axis back transposes the
    image instead of restoring [C, H, W].
    """

    def __init__(self):
        pass

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        # Horizontal flip (mirror along the width axis)
        if random.choice([0, 1]) == 1:
            inp = np.flip(inp, axis=2)
            tar = np.flip(tar, axis=1)
        # Vertical flip (mirror along the height axis)
        if random.choice([0, 1]) == 1:
            inp = np.flip(inp, axis=1)
            tar = np.flip(tar, axis=0)
        # 90 degree rotation in the spatial plane
        if random.choice([0, 1]) == 1:
            inp = np.rot90(inp, k=1, axes=(1, 2))
            tar = np.rot90(tar, k=1, axes=(0, 1))
        # Flips and rotations return strided views; hand back plain
        # contiguous arrays.
        return np.ascontiguousarray(inp), np.ascontiguousarray(tar)

    def __repr__(self):
        return str({self.__class__.__name__: self.__dict__})
class RandomCrop:
    """Cut the same random square window out of input and target.

    Expects the input as [C, H, W] and the target as [H, W].

    Args:
        crop_size: Edge length of the square crop.
    """

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        height, width = inp.shape[1], inp.shape[2]
        if self.crop_size > height or self.crop_size > width:
            # random.randint would fail with an opaque "empty range" error.
            raise ValueError(
                f"crop_size {self.crop_size} exceeds image size {height}x{width}"
            )
        x = random.randint(0, height - self.crop_size)
        y = random.randint(0, width - self.crop_size)

        # Crop: slice the spatial axes directly, channels stay untouched
        rows = slice(x, x + self.crop_size)
        cols = slice(y, y + self.crop_size)
        inp = inp[:, rows, cols]
        tar = tar[rows, cols]
        return inp, tar

    def __repr__(self):
        return str({self.__class__.__name__: self.__dict__})
class Resize:
    """Resize input and target to a square of ``img_size`` pixels.

    Expects the input as [C, H, W] and the target as [H, W]. Both are
    resized with nearest-neighbour interpolation.

    Args:
        img_size: Edge length of the output in pixels.
    """

    def __init__(self, img_size):
        self.img_size = img_size

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        size = (self.img_size, self.img_size)
        inp = np.moveaxis(inp, 0, -1)
        inp = cv2.resize(inp, size, interpolation=cv2.INTER_NEAREST)
        if inp.ndim == 2:
            # cv2 drops a singleton channel axis; restore it so the move
            # below yields [1, H, W] instead of transposing the image.
            inp = inp[:, :, np.newaxis]
        inp = np.moveaxis(inp, -1, 0)
        tar = cv2.resize(tar, size, interpolation=cv2.INTER_NEAREST)
        return inp, tar
class Normalize01:
    """Rescale the image input with ``normalize_01``; the target passes through unchanged."""

    def __init__(self):
        pass

    def __call__(self, inp, tar):
        return normalize_01(inp), tar
class Normalize:
    """Normalize based on mean and standard deviation.

    Args:
        mean: Value subtracted from the data.
        std: Value the shifted data is divided by.
        transform_input: Normalize the input (default).
        transform_target: Also normalize the target (off by default).
    """

    def __init__(self,
                 mean: float,
                 std: float,
                 transform_input=True,
                 transform_target=False
                 ):
        self.transform_input = transform_input
        self.transform_target = transform_target
        self.mean = mean
        self.std = std

    def __call__(self, inp, tar):
        # normalize() requires mean and std; calling it with the array alone
        # raises a TypeError.
        if self.transform_input:
            inp = normalize(inp, self.mean, self.std)
        if self.transform_target:
            tar = normalize(tar, self.mean, self.std)
        return inp, tar
class ColorTransformations:
    """Apply a random colour jitter to the input; the target is returned unchanged.

    Brightness, contrast, saturation and hue are each jittered with a
    strength of 0.2 using torchvision's ``ColorJitter``.
    """

    def __init__(self):
        pass

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        # Imported here because this module does not import torch/torchvision
        # at the top level.
        import torch
        from torchvision import transforms

        color_transform = transforms.Compose([
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        ])
        # from_numpy cannot wrap arrays with negative strides (e.g. flipped
        # views), so make the input contiguous first.
        inp_tensor = torch.from_numpy(np.ascontiguousarray(inp))
        inp = color_transform(inp_tensor).numpy()
        return inp, tar
class ColorNoise:
    """Add Gaussian noise to the input and clamp the result to [0, 1].

    The target is returned unchanged and the input array passed in is not
    modified.

    Args:
        noise_std: Standard deviation of the added noise.
    """

    def __init__(self, noise_std=0.05):
        self.noise_std = noise_std

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        # Imported here because this module does not import torch at the top level.
        import torch

        inp_tensor = torch.from_numpy(np.ascontiguousarray(inp))
        noise = torch.randn_like(inp_tensor) * self.noise_std
        # Out-of-place add: the tensor shares memory with the caller's array,
        # so an in-place += would silently alter (e.g. cached) source data.
        noisy = torch.clamp(inp_tensor + noise, 0, 1)
        return noisy.numpy(), tar
Now we are ready to start our training. Our main file (train.py) loads the data and our pretrained model, trains the model on our custom data and saves the model etc. as checkpoints. For your own project you can specify hyperparameters such as the learning rate, optimizer, number of epochs, batch size etc. to achieve the best possible model performance. It should look like this:
#!/usr/bin/env python3
import os
import pickle
import json
from utils import *
import torch as t
import torch.nn as nn
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
# Enable cuDNN and its benchmark mode.
# NOTE(review): benchmark mode may select different kernels between runs, which
# can work against the fixed seeding below — confirm this trade-off is intended.
t.backends.cudnn.benchmark = True
t.backends.cudnn.enabled = True
# Fixed seed shared by Python's `random`, NumPy and PyTorch; main() reads it too.
seed = 42
# NOTE(review): presumably only affects child processes — hash randomization of
# the running interpreter is fixed at startup; verify if this is relied upon.
os.environ['PYTHONHASHSEED']=str(seed)
random.seed(seed)
np.random.seed(seed)
t.manual_seed(seed)
def _eval_classification(f, dload, device):
    """Return ``(pixel accuracy, mean per-pixel cross-entropy)`` of ``f`` over ``dload``.

    The caller is responsible for ``f.eval()`` and ``t.no_grad()``.
    """
    criterion = nn.CrossEntropyLoss(reduction='sum')
    n_correct = 0
    n_pixels = 0
    loss_sum = 0.0
    for x, y in dload:
        x, y = x.to(device), y.to(device)
        logits = f(x)
        loss_sum += criterion(logits, y).item()
        n_correct += (logits.max(1)[1] == y).sum().item()
        n_pixels += y.numel()
    if n_pixels == 0:
        return 0.0, float('nan')
    return n_correct / n_pixels, loss_sum / n_pixels


def _save_records(train_losses, val_losses, val_corr):
    """Pickle the loss/accuracy histories into ``./records/``."""
    os.makedirs("./records/", exist_ok=True)
    with open("./records/trainlosses.txt", "wb") as fp:
        pickle.dump(train_losses, fp)
    with open("./records/vallosses.txt", "wb") as fp:
        pickle.dump(val_losses, fp)
    with open("./records/correct.txt", "wb") as fp:
        pickle.dump(val_corr, fp)


def main(args):
    """Train the segmentation model and write checkpoints and loss records.

    Args:
        args: Parsed command line arguments. ``args.save_dir`` is set here
            to ``./checkpoints/``; the remaining attributes read are
            ``batch_size``, ``project``, ``resize``, ``random_crop_size``,
            ``num_classes``, ``optimizer``, ``learnrate``, ``epochs``,
            ``print_every``, ``eval_every`` and ``ckpt_every``.
    """
    args.save_dir = './checkpoints/'
    makedirs(args.save_dir)
    with open(f'{args.save_dir}/params.txt', 'w') as params_file:
        json.dump(args.__dict__, params_file)

    t.manual_seed(seed)
    if t.cuda.is_available():
        t.cuda.manual_seed_all(seed)

    dload_train, dload_valid = import_data(args, args.batch_size, args.project, args.resize, args.random_crop_size)
    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    f = get_model(device, args.num_classes)

    params = f.parameters()
    if args.optimizer == "adam":
        optim = t.optim.Adam(params, lr=args.learnrate, betas=[.9, .999], weight_decay=0.0)
    else:
        optim = t.optim.SGD(params, lr=args.learnrate, momentum=.9, weight_decay=0.0)

    # The loss module is stateless, so one instance serves every step.
    criterion = nn.CrossEntropyLoss()

    best_valid_acc = 0.0
    iteration = 0
    train_losses = []
    val_losses = []
    val_corr = []

    for epoch in range(args.epochs):
        iter_losses = []
        # Use the batch the loop yields. Drawing next(iter(dload_train)) here
        # would build a fresh iterator every step and throw the yielded batch away.
        for x_train, y_train in tqdm(dload_train):
            x_train, y_train = x_train.to(device), y_train.to(device)
            logits = f(x_train)
            l_dis = criterion(logits, y_train)
            iter_losses.append(l_dis.item())

            optim.zero_grad()
            l_dis.backward()
            optim.step()

            if iteration % args.print_every == 0:
                acc = (logits.max(1)[1] == y_train).float().mean()
                print('P(y|x) {}:{:>d} loss={:>14.9f}, acc={:>14.9f}'.format(epoch,
                                                                             iteration,
                                                                             l_dis.item(),
                                                                             acc.item()))
            iteration += 1
        train_losses.append(np.mean(iter_losses))

        if epoch % args.eval_every == 0:
            f.eval()
            with t.no_grad():
                correct, loss = _eval_classification(f, dload_valid, device)
                val_losses.append(loss)
                val_corr.append(correct)
                print("Epoch {}: Valid Loss {}, Valid Acc {}".format(epoch, loss, correct))
                if correct > best_valid_acc:
                    best_valid_acc = correct
                    print("Best Valid!: {}".format(correct))
                    checkpoint(f, "best_validation_ckpt.pt", args, device, dload_train, dload_valid)
            f.train()

        if epoch % args.ckpt_every == 0:
            checkpoint(f, f'checkpoint_{epoch}.pt', args, device, dload_train, dload_valid)

        # Losses are saved and can be loaded for further analysis
        # You can also plot them here using matplotlib
        # Rewritten every epoch so an interrupted run keeps its history.
        _save_records(train_losses, val_losses, val_corr)
if __name__ == "__main__":
    # Command line interface of the training script.
    parser = argparse.ArgumentParser("Pytorch Semantic Segmentation")
    # The learning rate is fractional: with type=int any value given on the
    # command line (e.g. 0.001) would be rejected by argparse.
    parser.add_argument("--learnrate", type=float, default=0.0001, help='learn rate of optimizer')
    parser.add_argument("--optimizer", choices=['sgd', 'adam'], default='adam')
    parser.add_argument("--epochs", type=int, default=7000)
    parser.add_argument("--eval_every", type=int, default=1, help="Epochs between evaluation")
    # Compared against the iteration counter in main(), not the epoch.
    parser.add_argument("--print_every", type=int, default=5, help="Iterations between print")
    parser.add_argument("--ckpt_every", type=int, default=2, help="Epochs between checkpoint save")
    parser.add_argument("--project", choices=['project_1', 'project_2', 'project_3'], default='project_1')
    parser.add_argument("--batch_size", type=int, default=2, help="Batch Size")
    parser.add_argument("--num_classes", type=int, default=8, help="Number of classes of your training dataset")
    parser.add_argument("--resize", type=int, default=512, help="Size of images for resizing")
    parser.add_argument("--random_crop_size", type=int, default=256, help="Size of random crops. Must be smaller than resized images.")
    args = parser.parse_args()

    # A crop larger than the resized image cannot be taken.
    if args.random_crop_size > args.resize:
        raise ValueError("Crop size (--random_crop_size) must be smaller than resized image (--resize)!")
    main(args)
Congrats! Now we are ready to train our own custom semantic segmentation model in Pytorch.
The whole code of this article can be found on my GitHub repository.
I hope you liked this article!
All the best,
Stefan