Deep Learning for Road Detection in Satellite Imagery
Introduction
Satellite image segmentation is a computer vision task that involves partitioning an image into multiple segments or regions to simplify its representation. In this project, I focus on segmenting satellite images to identify road areas, which are crucial for applications such as urban planning, transportation management, and infrastructure development.
The code builds on a Kaggle notebook on road extraction from satellite images (https://www.kaggle.com/code/balraj98/road-extraction-from-satellite-images-deeplabv3) and uses deep learning to identify road areas. The project involves training a segmentation model on satellite imagery and making predictions on new satellite images.
GitHub Repo: https://github.com/Milad84/DeepLab_Road_Detection
How is this different from the Kaggle post?
Although much of the code is inspired by (and in places copied from) the Kaggle post by BALRAJ ASHWATH, this project differs in scope and scalability, and it introduces new processing steps in the second and third parts. I am trying to make the model work with other satellite imagery formats, specifically TIF. The other goal is to detect the full width of the street, not only a line segment. A side project consists of customizing TIF files to get the most out of them without having to convert them to PNG.
I’ll walk through the entire process, from data preprocessing to model training and inference. By the end, you’ll have a comprehensive understanding of how to apply state-of-the-art deep learning models to tackle road detection tasks.
Traditional road detection methods often rely on manual feature engineering and do not scale well. Deep learning techniques offer a promising alternative by automatically learning discriminative features from data.
Dataset and Libraries Used
Before diving into the implementation details, let’s briefly discuss the dataset and libraries used in this project:
- Dataset: We’ll train our model with the DeepGlobe Road Extraction Dataset (LINK), which consists of high-resolution satellite images labeled with road segments.
- Libraries: Our implementation relies on several Python libraries, including OpenCV, NumPy, pandas, PyTorch, Segmentation Models PyTorch, and Albumentations, plus GDAL for the geospatial output in Part 3. These libraries provide essential functionality for image processing, data manipulation, data augmentation, and deep learning model training.
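To make the mask format concrete: the training code in Part 1 reads each mask as an RGB image and turns it into one channel per class. The NumPy sketch below mirrors the one_hot_encode, reverse_one_hot, and colour_code_segmentation helpers from Part 1 on a tiny synthetic mask. It assumes that class_dict.csv maps background to (0, 0, 0) and road to (255, 255, 255). That colour table is an assumption here, so check it against your copy of the file, because the helpers depend on it.

```python
import numpy as np

# Assumed colour table (background, road); in the pipeline these values come from class_dict.csv
CLASS_RGB_VALUES = [[0, 0, 0], [255, 255, 255]]


def one_hot_encode(label, label_values):
    """(H, W, 3) RGB mask -> (H, W, num_classes) boolean array, one channel per class colour."""
    channels = [np.all(np.equal(label, colour), axis=-1) for colour in label_values]
    return np.stack(channels, axis=-1)


def reverse_one_hot(one_hot):
    """(H, W, num_classes) -> (H, W) array of class indices."""
    return np.argmax(one_hot, axis=-1)


def colour_code_segmentation(class_map, label_values):
    """(H, W) class indices -> (H, W, 3) RGB image using the colour table."""
    return np.array(label_values)[class_map.astype(int)]


# Tiny 2 x 3 synthetic mask: one road pixel in the first row, a full road row below it
mask = np.zeros((2, 3, 3), dtype=np.uint8)
mask[0, 1] = 255
mask[1, :] = 255

encoded = one_hot_encode(mask, CLASS_RGB_VALUES)  # shape (2, 3, 2)
class_map = reverse_one_hot(encoded)  # 0 = background, 1 = road
restored = colour_code_segmentation(class_map, CLASS_RGB_VALUES)
print(encoded.shape)  # (2, 3, 2)
print(class_map.tolist())  # [[0, 1, 0], [1, 1, 1]]
```

Because reverse_one_hot uses argmax, a pixel whose colour matches no entry in the table (for example, a slightly off-white value) ends up with all channels False and is silently assigned class 0 (background).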
Part 1: Data Preprocessing and Model Training
Overview
The first part of our implementation involves data preprocessing and model training. We’ll prepare the satellite imagery data, preprocess it, and train a deep-learning model for road detection.
Code Explanation
Let’s break down the key components of the code:
import os
import cv2
import numpy as np
import pandas as pd
import random
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album
# Load metadata and preprocess data
DATA_DIR = r'C:\Users\MohammadalizadehkorM\Downloads\DeepGlobe Road Extraction Dataset'
metadata_df = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
metadata_df = metadata_df[metadata_df['split']=='train']
metadata_df = metadata_df[['image_id', 'sat_image_path', 'mask_path']]
metadata_df['sat_image_path'] = metadata_df['sat_image_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
# Shuffle DataFrame
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
# Perform 90/10 split for train / val
valid_df = metadata_df.sample(frac=0.1, random_state=42)
train_df = metadata_df.drop(valid_df.index)
print(len(train_df), len(valid_df))
# Define class names and RGB values
class_dict = pd.read_csv(os.path.join(DATA_DIR, 'class_dict.csv'))
class_names = class_dict['name'].tolist()
class_rgb_values = class_dict[['r','g','b']].values.tolist()
# Useful to shortlist specific classes in datasets with a large number of classes
select_classes = ['background', 'road']
# Get RGB values of required classes
select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
select_class_rgb_values = np.array(class_rgb_values)[select_class_indices]
# Define helper functions
def visualize(**images):
    """
    Plot images in one row
    """
    n_images = len(images)
    plt.figure(figsize=(20, 8))
    for idx, (name, image) in enumerate(images.items()):
        plt.subplot(1, n_images, idx + 1)
        plt.xticks([])
        plt.yticks([])
        # get title from the parameter names
        plt.title(name.replace('_', ' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()
def one_hot_encode(label, label_values):
    """
    Convert a segmentation image label array to one-hot format
    by replacing each pixel value with a vector of length num_classes
    """
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map
def reverse_one_hot(image):
    """
    Transform a 2D array in one-hot format (depth is num_classes),
    to a 2D array with only 1 channel, where each pixel value is
    the classified class key.
    """
    x = np.argmax(image, axis=-1)
    return x
def colour_code_segmentation(image, label_values):
    """
    Given a 1-channel array of class keys, colour code the segmentation results.
    """
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x
# Define dataset class
class RoadsDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        df,
        class_rgb_values=None,
        augmentation=None,
        preprocessing=None,
    ):
        self.image_paths = df['sat_image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing
    def __getitem__(self, i):
        # Read image and mask (OpenCV loads BGR, so convert both to RGB)
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)
        # One-hot encode the mask: one channel per selected class
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')
        # Apply augmentations (the same random flips to image and mask)
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        # Apply preprocessing (encoder normalisation and HWC -> CHW)
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask
    def __len__(self):
        return len(self.image_paths)
# Define preprocessing and augmentation functions
def get_training_augmentation():
    train_transform = [
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ]
    return album.Compose(train_transform)
def to_tensor(x, **kwargs):
    # HWC -> CHW, the layout PyTorch expects
    return x.transpose(2, 0, 1).astype('float32')
def get_preprocessing(preprocessing_fn=None):
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(_transform)
# Create dataset instances
dataset = RoadsDataset(train_df, class_rgb_values=select_class_rgb_values)
random_idx = random.randint(0, len(dataset)-1)
# Show a random image/mask pair
image, mask = dataset[random_idx]
visualize(
    original_image=image,
    ground_truth_mask=colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    one_hot_encoded_mask=reverse_one_hot(mask)
)
# Create augmented dataset
augmented_dataset = RoadsDataset(
    train_df,
    augmentation=get_training_augmentation(),
    class_rgb_values=select_class_rgb_values,
)
random_idx = random.randint(0, len(augmented_dataset)-1)
# Different augmentations of the same image/mask pair
for idx in range(3):
    image, mask = augmented_dataset[random_idx]
    visualize(
        original_image=image,
        ground_truth_mask=colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
        one_hot_encoded_mask=reverse_one_hot(mask)
    )
# Define model parameters
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = select_classes
ACTIVATION = 'sigmoid'
# Create segmentation model
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
# Get train and val dataset instances
train_dataset = RoadsDataset(
    train_df,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)
valid_dataset = RoadsDataset(
    valid_df,
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)
# Get train and val data loaders
# Note: on Windows, every DataLoader worker re-imports this script, so the visualisation code above
# may also run in each worker; set num_workers=0 or move that code under the __main__ guard if it does.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False, num_workers=4)
if __name__ == '__main__':
    # Set flag to train the model or not
    TRAINING = True
    EPOCHS = 3
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss = smp.utils.losses.DiceLoss()
    metrics = [smp.utils.metrics.IoU(threshold=0.5)]
    optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=0.00008),])
    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )
    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=DEVICE,
        verbose=True,
    )
    if TRAINING:
        best_iou_score = 0.0
        train_logs_list, valid_logs_list = [], []
        for i in range(0, EPOCHS):
            print('\nEpoch: {}'.format(i))
            train_logs = train_epoch.run(train_loader)
            valid_logs = valid_epoch.run(valid_loader)
            train_logs_list.append(train_logs)
            valid_logs_list.append(valid_logs)
            # Save the whole model whenever the validation IoU improves
            if best_iou_score < valid_logs['iou_score']:
                best_iou_score = valid_logs['iou_score']
                torch.save(model, './best_model.pth')
                print('Model saved!')
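The training loop keeps the checkpoint with the best validation iou_score, and it optimises a Dice loss. The sketch below shows in plain NumPy what these two quantities measure for a single binary road mask. It is a simplified illustration, not the segmentation_models_pytorch implementation. That library works on batched tensors and uses its own smoothing constants and channel handling, and the eps value here is an assumption.

```python
import numpy as np


def iou_score(pred_probs, target, threshold=0.5, eps=1e-7):
    """Intersection over union after binarising the predictions at `threshold`."""
    pred = (pred_probs > threshold).astype(np.float64)
    intersection = np.sum(pred * target)
    union = np.sum(pred) + np.sum(target) - intersection
    return (intersection + eps) / (union + eps)


def dice_loss(pred_probs, target, eps=1e-7):
    """1 - soft Dice coefficient; uses the raw probabilities, so no threshold is applied."""
    intersection = np.sum(pred_probs * target)
    dice = (2.0 * intersection + eps) / (np.sum(pred_probs) + np.sum(target) + eps)
    return 1.0 - dice


# Four pixels: the first two are road in the ground truth
target = np.array([1.0, 1.0, 0.0, 0.0])
pred_probs = np.array([0.9, 0.4, 0.8, 0.1])

print(round(iou_score(pred_probs, target), 4))  # 0.3333: 1 overlapping pixel, 3 in the union
print(round(dice_loss(pred_probs, target), 4))  # 0.381: 1 - 2 * 1.3 / (2.2 + 2)
```

The threshold makes IoU a step function of the predictions, which is why it serves as a validation metric here. The soft Dice loss changes smoothly with the probabilities and can therefore be used for gradient-based training.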
Part 2: Inference on New Images
Overview
Once we have a trained model, we can use it to perform inference on new satellite images to detect roads. This part focuses on loading and applying the trained model to new images.
How is the TIF file prepared to test the model?
The TIF samples come from USGS EarthExplorer (EarthExplorer.usgs.gov).
Major steps to prepare the target TIF and apply the model to it:
→ Download satellite imagery from EarthExplorer.usgs.gov
→ Split the Raster
→ Use the model saved in the directory and inject your TIF
→ Assign the coordinate system
→ Save the output
→ Convert the output to a binary raster (0 and one other value), where one value stands for the road (255 in the mask written by the code below) and the other value must be deleted
→ Extract the needed value (Extract by Attribute)
→ Vectorize the output
→ Aggregate polygons (vectorized polygons).
Note that I do not provide code for the last four steps (from "Convert the output to binary raster" through "Aggregate polygons") because I performed them in ArcGIS Pro.
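The "Split the Raster" step can also be done in code instead of in a GIS. Below is a minimal NumPy sketch. It assumes the raster has already been read into an array, for example with GDAL's ReadAsArray as in Part 3. The function names are hypothetical, and the default tile size of 1024 pixels is an assumption chosen to match the DeepGlobe training images. The row and column offsets returned for each tile are what you would need to georeference that tile afterwards.

```python
import numpy as np


def split_raster(array, tile_size=1024):
    """Return (row_offset, col_offset, tile) for non-overlapping tiles.

    Tiles on the right and bottom edges are smaller when the raster size
    is not a multiple of tile_size.
    """
    rows, cols = array.shape[:2]
    tiles = []
    for r in range(0, rows, tile_size):
        for c in range(0, cols, tile_size):
            tiles.append((r, c, array[r:r + tile_size, c:c + tile_size]))
    return tiles


def merge_tiles(tiles, shape, dtype):
    """Paste tiles (e.g. per-tile predicted masks) back into a full-size array."""
    merged = np.zeros(shape, dtype=dtype)
    for r, c, tile in tiles:
        merged[r:r + tile.shape[0], c:c + tile.shape[1]] = tile
    return merged


# Small synthetic 5 x 7 RGB raster split into 4 x 4 tiles
raster = np.arange(5 * 7 * 3).reshape(5, 7, 3)
tiles = split_raster(raster, tile_size=4)
print([t.shape for _, _, t in tiles])  # [(4, 4, 3), (4, 3, 3), (1, 4, 3), (1, 3, 3)]
restored = merge_tiles(tiles, raster.shape, raster.dtype)
```

Each tile could then go through the inference code of Part 2 or Part 3, and the per-tile road masks could be stitched back together with merge_tiles using the same offsets.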
Code Explanation
Let’s dissect the code for performing inference on new images:
import os
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp
# Define model parameters
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['background', 'road']
ACTIVATION = 'sigmoid'
# Define preprocessing function
def preprocess_image(image):
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    image = preprocessing_fn(image)
    # HWC -> CHW, the layout the model expects
    return image.transpose(2, 0, 1).astype('float32')
# Load the whole model object saved in Part 1, on CPU.
# weights_only=False (available since PyTorch 1.13) is needed on PyTorch 2.6+, whose default
# cannot unpickle a full model; only load checkpoints you trust.
model = torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=False)
model.eval()
# Load and preprocess your single TIF file
# (cv2.imread returns an 8-bit, 3-channel BGR array by default, or None if the file can't be read)
input_image_path = r'path_to_your_input_image.tif'
input_image = cv2.imread(input_image_path)
if input_image is None:
    raise FileNotFoundError(f"Could not read '{input_image_path}'.")
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
# Pad the input image to make dimensions divisible by 16
h, w, _ = input_image.shape
new_h = int(np.ceil(h / 16) * 16)
new_w = int(np.ceil(w / 16) * 16)
pad_top = (new_h - h) // 2
pad_bottom = new_h - h - pad_top
pad_left = (new_w - w) // 2
pad_right = new_w - w - pad_left
input_image = cv2.copyMakeBorder(input_image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0)
input_image = preprocess_image(input_image)
# Perform inference
with torch.no_grad():
    input_tensor = torch.from_numpy(input_image).unsqueeze(0)
    output = model(input_tensor)
# Process the output as needed
output_mask = output.squeeze(0).cpu().numpy()  # Remove batch dimension -> (num_classes, H, W)
predicted_class_index = np.argmax(output_mask, axis=0)  # Get the index of the class with the highest probability
# Crop away the padding so the mask has the same size as the input image
predicted_class_index = predicted_class_index[pad_top:pad_top + h, pad_left:pad_left + w]
# Assuming road class is class 1, create binary mask for road class
road_mask = (predicted_class_index == 1).astype(np.uint8) * 255
# Save the output mask
output_path = r'path_to_output_mask.png'
cv2.imwrite(output_path, road_mask)  # Save the road mask as an image
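The padding step has to be undone on the prediction. Otherwise the mask is larger than the input and shifted by pad_top rows and pad_left columns. One way to keep padding and cropping consistent is to return the offsets from the padding function and reuse them for cropping. The helper names below are hypothetical, and np.pad stands in for cv2.copyMakeBorder here. With constant zero padding, the two should produce the same array.

```python
import numpy as np


def pad_to_multiple(image, multiple=16):
    """Zero-pad an (H, W) or (H, W, C) array symmetrically so H and W are divisible by `multiple`.

    Returns the padded array and (pad_top, pad_left, original_h, original_w).
    """
    h, w = image.shape[:2]
    new_h = -(-h // multiple) * multiple  # ceiling division
    new_w = -(-w // multiple) * multiple
    pad_top = (new_h - h) // 2
    pad_left = (new_w - w) // 2
    pad_width = ((pad_top, new_h - h - pad_top), (pad_left, new_w - w - pad_left))
    pad_width += ((0, 0),) * (image.ndim - 2)  # never pad the channel axis
    padded = np.pad(image, pad_width, mode="constant", constant_values=0)
    return padded, (pad_top, pad_left, h, w)


def crop_to_original(mask, pad_info):
    """Remove the padding added by pad_to_multiple from an (H, W, ...) array."""
    pad_top, pad_left, h, w = pad_info
    return mask[pad_top:pad_top + h, pad_left:pad_left + w]


image = np.ones((100, 150, 3), dtype=np.uint8)
padded, pad_info = pad_to_multiple(image, multiple=16)
print(padded.shape, pad_info)  # (112, 160, 3) (6, 5, 100, 150)
# Stand-in for a predicted (H, W) mask with the same size as the padded input
mask = padded[:, :, 0]
print(crop_to_original(mask, pad_info).shape)  # (100, 150)
```

Returning the offsets together with the padded array means the same numbers are used in both directions. This matters most in Part 3, where the cropped mask must line up pixel for pixel with the input raster's geotransform.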
Part 3: Saving Output with Geospatial Information
Overview
In some cases, it’s essential to retain geospatial information when saving the output of road detection. This part demonstrates how to save the predicted road mask with geospatial information.
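GDAL stores this georeferencing as a projection string plus a six-element geotransform GT, which the code below reads with GetGeoTransform. In GDAL's raster data model, the top-left corner of pixel (column, row) sits at X = GT[0] + column*GT[1] + row*GT[2] and Y = GT[3] + column*GT[4] + row*GT[5]. The sketch below applies that formula to the road pixels of a mask. The function names are hypothetical, and the geotransform values are made up for illustration: a north-up image with 30 m pixels, where GT[5] is negative because row numbers grow downwards while northings grow upwards. The formula also shows why the mask must match the input exactly. A mask padded by pad_left columns but written with the original geotransform would place every road pixel pad_left*GT[1] too far east.

```python
import numpy as np


def pixel_to_map(geotransform, col, row):
    """GDAL affine transform: top-left corner of pixel (col, row) -> map coordinates (x, y)."""
    x = geotransform[0] + col * geotransform[1] + row * geotransform[2]
    y = geotransform[3] + col * geotransform[4] + row * geotransform[5]
    return x, y


def road_pixel_centres(road_mask, geotransform, road_value=255):
    """Map coordinates of the centre of every pixel equal to road_value, as an (N, 2) array."""
    rows, cols = np.nonzero(road_mask == road_value)
    # Adding 0.5 moves from each pixel's top-left corner to its centre
    xs, ys = pixel_to_map(geotransform, cols + 0.5, rows + 0.5)
    return np.column_stack([xs, ys])


# Hypothetical north-up geotransform: origin (500000, 4000000), 30 m pixels, no rotation
gt = (500000.0, 30.0, 0.0, 4000000.0, 0.0, -30.0)

road_mask = np.zeros((3, 3), dtype=np.uint8)
road_mask[1, 2] = 255  # a single road pixel at row 1, column 2

print(pixel_to_map(gt, 0, 0))  # (500000.0, 4000000.0)
print(road_pixel_centres(road_mask, gt))  # one row: x = 500075.0, y = 3999955.0
```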
Code Explanation
Here’s a breakdown of the code for saving the output with geospatial information:
import os
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp
from osgeo import gdal
gdal.UseExceptions()  # Raise Python exceptions on GDAL errors instead of returning None
# Define model parameters
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['background', 'road']
ACTIVATION = 'sigmoid'
# Define preprocessing function
def preprocess_image(image):
    """Return the (C, H, W) float32 model input and the (pad_top, pad_left) padding offsets."""
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    # Keep the first three bands, assumed to be R, G, B (this also drops an alpha band)
    image = np.ascontiguousarray(image[:, :, :3])
    # Pad the image to make its dimensions divisible by 16
    h, w, _ = image.shape
    new_h = int(np.ceil(h / 16) * 16)
    new_w = int(np.ceil(w / 16) * 16)
    pad_top = (new_h - h) // 2
    pad_bottom = new_h - h - pad_top
    pad_left = (new_w - w) // 2
    pad_right = new_w - w - pad_left
    image = cv2.copyMakeBorder(image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0)
    # Apply preprocessing function (assumes 8-bit values in 0-255; rescale 16-bit imagery first)
    image = preprocessing_fn(image)
    return image.transpose(2, 0, 1).astype('float32'), (pad_top, pad_left)
# Specify the path to the model checkpoint file
model_checkpoint_path = 'best_model.pth'
# Check if the model checkpoint file exists
if not os.path.exists(model_checkpoint_path):
    raise FileNotFoundError(f"Model checkpoint file '{model_checkpoint_path}' not found.")
# Load the whole model object saved in Part 1, on CPU (weights_only=False is needed on PyTorch 2.6+)
model = torch.load(model_checkpoint_path, map_location=torch.device('cpu'), weights_only=False)
model.eval()
# Load and preprocess your single TIF file
input_image_path = r'path_to_your_input_image.tif'
# Open the image using GDAL to retain geospatial information
ds = gdal.Open(input_image_path)
h, w = ds.RasterYSize, ds.RasterXSize
# ReadAsArray returns (bands, rows, cols) for a multi-band raster; reorder to (rows, cols, bands)
input_image = np.transpose(ds.ReadAsArray(), (1, 2, 0))
# Preprocess the input image
input_image, (pad_top, pad_left) = preprocess_image(input_image)
# Perform inference
with torch.no_grad():
    input_tensor = torch.from_numpy(input_image).unsqueeze(0)
    output = model(input_tensor)
# Process the output as needed
output_mask = output.squeeze(0).cpu().numpy()  # Remove batch dimension -> (num_classes, H, W)
predicted_class_index = np.argmax(output_mask, axis=0)  # Get the index of the class with the highest probability
# Crop the padding so every output pixel lines up with the input geotransform
predicted_class_index = predicted_class_index[pad_top:pad_top + h, pad_left:pad_left + w]
# Assuming road class is class 1, create binary mask for road class
road_mask = (predicted_class_index == 1).astype(np.uint8) * 255
# Get the geotransform and projection from the input image
geotransform = ds.GetGeoTransform()
projection = ds.GetProjection()
# Save the output mask with geospatial information
output_path = r'path_to_output_mask_with_GCS.tif'
driver = gdal.GetDriverByName('GTiff')
output_ds = driver.Create(output_path, road_mask.shape[1], road_mask.shape[0], 1, gdal.GDT_Byte)
output_ds.SetGeoTransform(geotransform)
output_ds.SetProjection(projection)
output_ds.GetRasterBand(1).WriteArray(road_mask)
# Dereferencing the datasets closes them and flushes the output to disk
output_ds = None
ds = None
The process above gives you a starting point for editing the width of the roads. The picture below shows the aggregation result; notice how the noise yields unwanted polygons:
Conclusion
In this post, I covered the entire pipeline for road detection in satellite imagery using deep learning. Each step, from data preprocessing and model training to inference and output saving, is crucial for achieving accurate results. By following these steps and using the provided code snippets, you can build your own road detection system and apply it to real-world scenarios.
You can use this base, as shown in the last picture, to edit the width of the streets. The pipeline is far from perfect, but I will get there. Meanwhile, if you want to help, you are welcome to contribute to the GitHub repo.