Satellite image segmentation — Part 3

Romain Guion
Published in VorTECHsa
Jun 12, 2020 · 7 min read

Ship surveillance & tracking — part 3: identifying where ships are on an image, using a U-net built on top of the encoder trained in part 2.

Inputs and outputs of the U-Net model we build in this blog post

This is the third part of a three-part blog series:

  • Ship detection — Part 1: binary prediction of whether there is at least 1 ship, or not. Part 1 is a simple solution showing great results in a few lines of code
  • Ship detection & early localisation — Part 2: towards object segmentation with (a) training fully convolutional NN through transfer learning, to build an encoder for our U-Net, and (b) first emerging localisation properties with CAM attention interpretability
  • Ship localisation — Part 3: identify where ships are within the image, building a U-Net from the encoder developed in Part 2

Code available on GitHub

Introduction to image segmentation and U-Net

Image segmentation has a long history of traditional algorithms, from thresholding to clustering and Markov random fields. With the progress in deep learning, a number of algorithms drawing bounding boxes then became state of the art, like the Single Shot MultiBox Detector (SSD) or You Only Look Once (YOLO), popular with autonomous vehicles. In this article, we will build on the implicit localisation properties of fully convolutional networks we demonstrated with CAM in Part 2 of this blog post series. We will build a U-Net, an architecture introduced in 2015 in the paper “U-Net: Convolutional Networks for Biomedical Image Segmentation”, itself an improvement on “Fully convolutional networks for semantic segmentation”. Both of these networks take a different strategy from YOLO and the like, and perform pixel-by-pixel segmentation instead of predicting bounding boxes.

U-Net architecture from the original U-Net paper, with the encoder / downsampling part on the left of the U, and the decoder / upsampling part on the right of the U. We trained the encoder on a classification task in part 2.
Illustration of downsampling in encoder until pool5, and then upsampling in decoder, from the original FCN paper

A segmentation task is different from a classification task because it requires predicting a class for each pixel of the input image, instead of only one class for the whole input. Classification only needs to understand what is in the input (the context), whereas segmentation also needs to recover where each object sits in the image.

To do this, in this post we expand on the Fully Convolutional Networks (FCNs) we developed in part 2, built only from locally connected layers, such as convolution, pooling and upsampling. In part 2, we discovered that the classifier had implicit localisation properties, stemming from this local topology, and we visualised the results through a very rough static upsampling step. The idea in this post is to train a network at this upsampling step, this time with labeled masks of how we want the image to be reconstituted.

  • Downsampling path: captures semantic/contextual information (what). In this post, we will reuse the Xception encoder trained in part 2 on a classification task (ship / no ship), and keep its weights frozen
  • Upsampling path: recovers spatial information (where). In this part 3, this is the part of the network we will train.

Labeled data: images and masks with ships

Data input: images and masks highlighting which pixels correspond to a ship

U-Net implementation in TensorFlow, for this ship localisation task

U-Net implementation in TensorFlow, from a pretrained encoder (here Xception pretrained in part 2). The choice of where to establish the skip connections was semi-arbitrary; this is a hyperparameter we could optimise.
U-Net implementation in TensorFlow. The variable “model” is the encoder (here Xception pretrained in part 2)
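For readers who prefer text to screenshots, below is a minimal sketch of how such a decoder can be wired on top of the frozen encoder. The `build_unet` function and the `skip_names` argument are illustrative rather than the exact code from the repo, and assume `encoder` is a Keras Model whose output is the last convolutional feature map of the Xception classifier trained in part 2.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def build_unet(encoder, skip_names, num_classes=1):
    # Wire a U-Net decoder on top of a pretrained, frozen encoder.
    # `skip_names` are the encoder layers to tap for skip connections,
    # ordered shallow to deep; they should sit at power-of-two scales
    # so shapes line up after each x2 upsampling step.
    encoder.trainable = False  # freeze the downsampling path

    skips = [encoder.get_layer(name).output for name in skip_names]
    x = encoder.output  # bottleneck feature map

    # Upsampling path: double the resolution at each step and
    # concatenate the matching encoder feature map (skip connection).
    for skip in reversed(skips):
        filters = skip.shape[-1]
        x = layers.Conv2DTranspose(filters, 3, strides=2, padding="same")(x)
        x = layers.Concatenate()([x, skip])
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)

    # Final upsampling back to the input resolution, then a per-pixel sigmoid.
    x = layers.Conv2DTranspose(32, 3, strides=2, padding="same")(x)
    outputs = layers.Conv2D(num_classes, 1, activation="sigmoid")(x)
    return Model(encoder.input, outputs, name="unet")
```

Where exactly to tap the encoder is the “semi-arbitrary” choice mentioned in the caption above.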

Managing class imbalance

In the dataset we’re using, 78% of images contain no ship. For images with ≥ 1 ship, the ratio of pixels with a ship to pixels with no ship is about 1/1000. Together, this gives an overall class imbalance of roughly 1/5000, which is tough to handle. To manage this class imbalance, we’re taking two actions:

  • separate the ship detection task from the image segmentation task
  • for the segmentation task, pixel accuracy probably won’t cut it (the network would likely learn to always say “no ship”), so we’ll mix the loss function with a different metric called the Dice similarity coefficient, 2·|X ∩ Y| / (|X| + |Y|). We’ll also track a similar metric called intersection over union, |X ∩ Y| / |X ∪ Y|, and the true positive rate of pixels (a sketch of these metrics and the mixed loss follows below)
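As a rough illustration, here is what the Dice coefficient, IoU metric and a mixed BCE + Dice loss can look like in Keras. The equal weighting between cross-entropy and Dice below is just an example, not necessarily what the repo uses.

```python
import tensorflow as tf
from tensorflow.keras import backend as K

def dice_coef(y_true, y_pred, smooth=1.0):
    # Dice = 2|X ∩ Y| / (|X| + |Y|), computed on flattened soft masks
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def iou_coef(y_true, y_pred, smooth=1.0):
    # IoU = |X ∩ Y| / |X ∪ Y|
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)

def bce_dice_loss(y_true, y_pred):
    # Mix pixel-wise binary cross-entropy with (1 - Dice)
    bce = K.mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + (1.0 - dice_coef(y_true, y_pred))
```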

Managing overfitting

In this first implementation, I have managed overfitting as upstream as I could, by:

  1. performing a lot of data augmentation on our already pretty large training set (given the short training for this blog post, the training set isn’t even seen once by the neural network, let alone the augmented data, but this means we could train for much longer and see benefits)
  2. freezing the encoder weights, so the capacity of the model is reduced: out of 28 million weights, “only” 7 million are trainable (see the snippet below). Fine-tuning at a later stage may be valuable, but to start with it reduces the number of weights.
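Reusing the names from the sketches above, freezing the encoder and compiling the model could look like this (again a sketch, not the exact repo code):

```python
# Freeze the pretrained Xception encoder so only the decoder trains
encoder.trainable = False

unet.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
             loss=bce_dice_loss,
             metrics=[dice_coef, iou_coef, "binary_accuracy"])

# unet.summary() reports trainable vs non-trainable parameter counts,
# e.g. roughly 7M trainable out of 28M total once the encoder is frozen
unet.summary()
```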

What could be done next is more traditional regularisation, dropouts and early stopping.

While managing the combined masks and images pipeline was a bit of engineering Kung Fu for this dataset, the data augmentation can very nicely use the Keras/TensorFlow API. The key when using it is to fix the random seed and apply the same transformation to both image and mask (sketched below the figure). More details in the repo.
Training data for U-Net: data augmentation multiplies the effective training set, helping the model to generalise / reduce overfitting
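A minimal sketch of paired augmentation with the Keras API: two ImageDataGenerator instances with identical parameters and the same seed, one for images and one for masks. The `train_images` / `train_masks` arrays are placeholders for however you load the data.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Same augmentation parameters for images and masks
aug_args = dict(rotation_range=15,
                width_shift_range=0.1,
                height_shift_range=0.1,
                zoom_range=0.1,
                horizontal_flip=True,
                vertical_flip=True)

image_gen = ImageDataGenerator(**aug_args)
mask_gen = ImageDataGenerator(**aug_args)

# Fixing the same seed on both flows keeps image and mask aligned
seed = 42
image_flow = image_gen.flow(train_images, batch_size=20, seed=seed)
mask_flow = mask_gen.flow(train_masks, batch_size=20, seed=seed)

# Yields (image batch, mask batch) pairs that underwent identical transforms
train_generator = zip(image_flow, mask_flow)
```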

Training process

After only 10 min of training (iterating through 8000 images), our U-Net already made pretty decent predictions.

Prediction after 2 pseudo-epochs (2 x 200 steps x 20 images). The U-net localises vessels after very little training, and that’s probably the reflection of the implicit localisation properties of our fully convolutional encoder, as highlighted by the CAM model in part 2. That may also contribute to the model being robust to mislabeling: on this image there are clearly 3 vessels, but the label highlights only 2.

Although the performance started relatively high, presumably stemming from our pretrained fully convolutional encoder, the initial training attempts stalled at a validation Dice coefficient of about 0.35, stagnating for about 20 pseudo-epochs (of 200 steps of 20 images). Using Adam and changing learning rates didn’t seem to help. Under the assumption of premature convergence to a local optimum, a strategy of learning rate cycling was adopted (discussed in Part 2; code; sketched below). This brought the val Dice to 0.75, highlighting once more the importance of tuning convergence.
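Here is a minimal sketch of such a triangular learning rate cycle with a Keras callback; the exact shape and period of the schedule used in the repo differ, only the 5e-6 to 1e-3 range is taken from the figure below.

```python
import tensorflow as tf

MIN_LR, MAX_LR = 5e-6, 1e-3
HALF_CYCLE = 10  # pseudo-epochs per half-cycle (illustrative)

def cyclic_lr(epoch, lr):
    # Triangular wave: rise for HALF_CYCLE epochs, then fall back down
    pos = epoch % (2 * HALF_CYCLE)
    frac = pos / HALF_CYCLE if pos < HALF_CYCLE else 2 - pos / HALF_CYCLE
    return MIN_LR + frac * (MAX_LR - MIN_LR)

lr_callback = tf.keras.callbacks.LearningRateScheduler(cyclic_lr, verbose=1)

# unet.fit(train_generator, steps_per_epoch=200, epochs=40,
#          validation_data=val_generator, callbacks=[lr_callback])
```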

Learning rate cycling schedule used to escape local optima, ranging from 5e-6 to 1e-3.
Performance during training of U-Net for a couple of hours.
Picks up all shapes, although some vessels moored alongside are merged into one.

Through pretty basic training we reached decent model performance: val Dice 0.75, val IoU 0.37, val true positive pixel rate of 0.67 (and val binary pixel accuracy of 0.997). For some applications, this performance may already be enough. And notice that what we did was overall pretty simple, so the results presented here should really be seen as a starting point from which many simple improvements can be made:

  • train for longer: the model has not actually seen all the training data yet, let alone the augmented data!
  • model seems to overfit the training data: this is odd given the model hasn’t seen all the training set yet, and could be due to a bug (opportunity to improve!). Beyond the bug, regularisation, batch normalisation and dropouts are classic next steps.
  • refine the loss function to better represent the learning objective. For example, recent Kaggle competitions seem to use “Focal Loss”, defined in a recent Facebook paper (FAIR, 2018). It adds a factor (1 − p_t)^γ to the standard cross-entropy criterion. Setting γ > 0 reduces the relative loss for well-classified examples (p_t > .5), putting more focus on hard, misclassified examples (a sketch is given after this list)
  • hyperparameter optimisation: U-Net structure was pretty random, and where the skips are placed probably matters
  • better train the encoder (in part 2 we also stopped as soon as we got decent performance)
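As an illustration of the focal loss idea, here is a sketch of a binary focal loss in Keras, with α and γ set to the defaults from the FAIR paper; this is not code from the repo.

```python
import tensorflow as tf
from tensorflow.keras import backend as K

def binary_focal_loss(gamma=2.0, alpha=0.25):
    def loss(y_true, y_pred):
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # p_t: predicted probability of the true class for each pixel
        p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
        alpha_t = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)
        # down-weight easy examples via the (1 - p_t)^gamma factor
        return -K.mean(alpha_t * K.pow(1.0 - p_t, gamma) * K.log(p_t))
    return loss
```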

I hope you enjoyed the ride, and that you now feel that building an image detection and segmentation algorithm isn’t that big a deal!

To conclude, we will show more predictions where the model sometimes does well, and other times not so well:

Model can learn complex shapes despite labels being rectangular
Model handles various contrasts quite well
Some ships are tiny
Partial view of a vessel is handled OK
Model may have learned to ignore wash, which makes it ignore some vessels in the wash of others here?
Struggling with the shadow, or just too small?
Robust to partial / corrupted images

Romain Guion, VorTECHsa
VP of data production at Vortexa, the data science & engineering team. Previously a consultant in the medical devices industry. Cambridge and Centrale Paris alumnus.