Satellite image segmentation: part 3
Ship surveillance & tracking, part 3: identifying where ships are in an image, using a U-Net built on top of the encoder trained in part 2.
This is the third part of a three-post blog series:
- Ship detection (Part 1): binary prediction of whether there is at least 1 ship, or not. Part 1 is a simple solution showing great results in a few lines of code.
- Ship detection & early localisation (Part 2): towards object segmentation, by (a) training a fully convolutional neural network through transfer learning, to build an encoder for our U-Net, and (b) observing the first emerging localisation properties through CAM attention interpretability.
- Ship localisation (Part 3): identify where ships are within the image, building a U-Net from the encoder developed in Part 2.
Introduction to image segmentation and U-Net
Image segmentation has a long history of traditional algorithms, from thresholding to clustering and Markov random fields. With the progress of deep learning, bounding-box detectors such as the Single Shot MultiBox Detector (SSD) and You Only Look Once (YOLO), popular in autonomous vehicles, became state of the art for localising objects. In this article, we will instead build on the implicit localisation properties of fully convolutional networks that we demonstrated with CAM in Part 2 of this blog post series. We will build a U-Net, an architecture introduced in 2015 in the paper “U-Net: Convolutional Networks for Biomedical Image Segmentation”, itself an improvement on “Fully convolutional networks for semantic segmentation”. Both of these networks take a different approach from YOLO and similar detectors: they segment the image pixel by pixel instead of drawing bounding boxes.
A segmentation task differs from a classification task because it requires predicting a class for each pixel of the input image, instead of a single class for the whole input. Classification only needs to understand what is in the input (namely, the context); segmentation also needs to recover where it is.
To do this, we expand in this post on the Fully Convolutional Networks (FCNs) we developed in part 2, built only from locally connected layers such as convolution, pooling and upsampling. In part 2, we discovered that the classifier had implicit localisation properties, stemming from this local topology, and we visualised the results through a very rough static upsampling step. The idea in this post is to train a network at this upsampling step, this time with labeled masks of how we want the image to be reconstituted. The resulting U-Net has two paths:
- Downsampling path: captures semantic/contextual information (what). In this post, we will reuse the Xception encoder trained in part 2 on a classification task (ship / no ship), and keep its weights frozen.
- Upsampling path: recovers spatial information (where). This is the part of the network we will train in part 3.
Labeled data: images and masks with ships
U-Net implementation in TensorFlow, for this ship localisation task
Managing class imbalance
In the dataset we’re using, 78% of images have no ship. For images with ≥ 1 ship, the ratio of pixels with a ship to pixels with no ship is about 1/1000. Together (0.22 × 1/1000 ≈ 1/4500), this makes a class imbalance of about 1/5000, which is tough to handle. To manage it, we’re taking 2 actions:
- separate the ship detection task from the image segmentation task
- for the segmentation task, pixel accuracy probably won’t cut it (the network would likely learn to always say “no ship”), so we’ll mix the loss function with a different metric called the Dice similarity coefficient, 2|X ∩ Y| / (|X| + |Y|), where X is the set of predicted ship pixels and Y the set of labeled ship pixels. We’ll also track a similar metric called intersection over union, |X ∩ Y| / |X ∪ Y|, and the true positive rate of pixels
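As a rough illustration of why pixel accuracy is misleading here, the following dependency-free sketch computes the metrics above on a toy mask. Several things are assumptions on my part rather than the post's actual code: the function names, the `smooth` term that avoids division by zero on empty masks, the use of binary cross-entropy as the base loss, and the 50/50 `dice_weight`. In a real training loop these would be written with TensorFlow ops on tensors rather than Python lists.

```python
import math


def _intersection(pred, true):
    """Soft |X ∩ Y| for flat sequences of probabilities (pred) and 0/1 labels (true)."""
    return sum(p * t for p, t in zip(pred, true))


def dice_coefficient(pred, true, smooth=1e-6):
    """Dice = 2|X ∩ Y| / (|X| + |Y|); `smooth` (an assumption) guards empty masks."""
    return (2.0 * _intersection(pred, true) + smooth) / (sum(pred) + sum(true) + smooth)


def iou(pred, true, smooth=1e-6):
    """Intersection over union = |X ∩ Y| / |X ∪ Y|."""
    inter = _intersection(pred, true)
    union = sum(pred) + sum(true) - inter
    return (inter + smooth) / (union + smooth)


def pixel_accuracy(pred, true, threshold=0.5):
    """Fraction of pixels whose thresholded prediction matches the label."""
    hits = sum((p >= threshold) == (t == 1) for p, t in zip(pred, true))
    return hits / len(true)


def true_positive_rate(pred, true, threshold=0.5):
    """Fraction of labeled ship pixels that are predicted as ship."""
    ship_preds = [p for p, t in zip(pred, true) if t == 1]
    if not ship_preds:
        return float("nan")  # undefined when the mask contains no ship
    return sum(p >= threshold for p in ship_preds) / len(ship_preds)


def binary_cross_entropy(pred, true, eps=1e-7):
    """Mean per-pixel binary cross-entropy, with clipping to keep log() finite."""
    total = 0.0
    for p, t in zip(pred, true):
        p = min(max(p, eps), 1.0 - eps)
        total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return total / len(true)


def mixed_loss(pred, true, dice_weight=0.5):
    """Hypothetical mix: weighted sum of cross-entropy and (1 - Dice)."""
    return ((1.0 - dice_weight) * binary_cross_entropy(pred, true)
            + dice_weight * (1.0 - dice_coefficient(pred, true)))


# Toy example: a 1000-pixel image containing a single ship pixel.
true_mask = [1] + [0] * 999
always_no_ship = [0.0] * 1000

print("pixel accuracy:", pixel_accuracy(always_no_ship, true_mask))
print("Dice:", dice_coefficient(always_no_ship, true_mask))
print("IoU:", iou(always_no_ship, true_mask))
print("mixed loss:", mixed_loss(always_no_ship, true_mask))
```

A model that always answers “no ship” gets almost every pixel right on this toy mask, yet its Dice and IoU are close to zero, which is why the Dice term is useful in the loss.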
Managing overfitting
In this first implementation, I have managed overfitting as far upstream as I could, by:
- performing a lot of data augmentation on our already pretty large training set (given the short training run I did for this blog post, the training set isn’t even seen once by the neural network, let alone with data augmentation, but this means we could train for much longer and see benefits)
- freezing the encoder weights, so the capacity of the model is reduced: out of 28 million weights, “only” 7 million are trainable. Fine-tuning the encoder at a later stage may be valuable, but freezing it is a simple way to start with fewer trainable weights.
What could be done next is more traditional regularisation, dropout and early stopping.
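Of these next steps, early stopping is the easiest to sketch without a framework. The class below is a minimal, hypothetical illustration (the names `EarlyStopping`, `patience` and `min_delta` are mine, and the validation Dice values are made up); Keras ships its own callback for this, which would be the natural choice in a TensorFlow training loop.

```python
class EarlyStopping:
    """Stop training once the validation metric has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # number of non-improving epochs tolerated
        self.min_delta = min_delta    # minimum gain that counts as an improvement
        self.best = float("-inf")
        self.best_epoch = None
        self.stale_epochs = 0

    def should_stop(self, epoch, val_metric):
        """Record this epoch's metric (higher is better) and say whether to stop."""
        if val_metric > self.best + self.min_delta:
            self.best = val_metric
            self.best_epoch = epoch
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


# Made-up validation Dice history, for illustration only.
history = [0.30, 0.33, 0.35, 0.35, 0.34, 0.35, 0.34]
stopper = EarlyStopping(patience=3)
for epoch, val_dice in enumerate(history):
    if stopper.should_stop(epoch, val_dice):
        print("stopping at epoch", epoch, "best epoch was", stopper.best_epoch)
        break
```

The patience needs to be chosen with care: the plateau described in the training section below lasted about 20 pseudo-epochs before it was overcome, so a short patience would have stopped that run too early.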
Training process
After only 10 min of training (iterating through 8000 images), our U-Net already made pretty decent predictions.
Although the performance started relatively high, presumably thanks to our pretrained fully convolutional encoder, the initial training attempts stalled at a validation Dice coefficient of about 0.35, stagnating for about 20 pseudo-epochs (of 200 steps of 20 images each). Using Adam and changing the learning rate didn’t seem to help. Assuming a premature convergence to a local optimum, I adopted a strategy of learning rate cycling (discussed in Part 2; code). This brought the validation Dice to 0.75, highlighting once more the importance of tuning convergence.
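The exact cycling schedule is the one from Part 2 and is not reproduced here. As a stand-in, the sketch below shows one common variant, a triangular cyclical learning rate that ramps linearly between a lower and an upper bound. The function name and all three default values are illustrative assumptions, not the settings used for the results above; `half_cycle=200` simply matches one pseudo-epoch of 200 steps.

```python
def triangular_lr(step, base_lr=1e-4, max_lr=1e-2, half_cycle=200):
    """Triangular cyclical learning rate (all defaults are hypothetical).

    The rate climbs linearly from base_lr to max_lr over `half_cycle` steps,
    then descends back to base_lr over the next `half_cycle` steps, and repeats.
    """
    position = (step % (2 * half_cycle)) / half_cycle  # in [0, 2)
    scale = position if position <= 1.0 else 2.0 - position  # in [0, 1]
    return base_lr + (max_lr - base_lr) * scale


# Learning rate at a few points of the first cycle and the start of the second.
for step in (0, 100, 200, 300, 400):
    print(step, triangular_lr(step))
```

The periodic return to a high learning rate is what gives the optimiser a chance to leave a poor local optimum, which is the assumption the paragraph above was working under.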
With pretty basic training, we reached decent performance for the model: validation Dice of 0.75, validation IoU of 0.37, validation true positive pixel rate of 0.67 (and validation binary pixel accuracy of 0.997). For some applications, this performance may already be enough. Note that what we did was overall pretty simple, and the results presented here should really be seen as a starting point from which many simple improvements can be made:
- train for longer: the model has not actually seen all the training data yet, let alone the augmented data!
- reduce overfitting: the model seems to overfit the training data. This is odd given that the model hasn’t seen the whole training set yet, and could be due to a bug (opportunity to improve!). Beyond the bug, regularisation, batch normalisation and dropout are classic next steps.
- refine the loss function to better represent the learning objective. For example, recent Kaggle competitions seem to use “Focal Loss”, defined in a recent Facebook paper (FAIR, 2018). It multiplies the standard cross-entropy criterion by a factor (1 − p_t)^γ, where p_t is the probability the model assigns to the true class. Setting γ > 0 reduces the relative loss for well-classified examples (p_t > 0.5), putting more focus on hard, misclassified examples
- hyperparameter optimisation: the U-Net structure was chosen fairly arbitrarily, and where the skip connections are placed probably matters
- better train the encoder (in part 2 we also stopped as soon as we got decent performance)
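To make the focal loss suggestion in the list above more concrete, here is a minimal per-pixel sketch of the formula −(1 − p_t)^γ · log(p_t), where p_t is the predicted probability of the true class. The function name, the clipping constant `eps` and the default `gamma=2.0` are illustrative choices of mine, and the class-balancing weight that the paper also discusses is left out for simplicity.

```python
import math


def focal_loss(p, y, gamma=2.0, eps=1e-7):
    """Focal loss for one pixel.

    p: predicted probability that the pixel is a ship.
    y: label, 1 for ship and 0 for no ship.
    gamma: focusing parameter; gamma = 0 gives back plain cross-entropy.
    """
    p = min(max(p, eps), 1.0 - eps)   # clip so that log() stays finite
    p_t = p if y == 1 else 1.0 - p    # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


# Compare an easy pixel (p_t = 0.9) with a hard one (p_t = 0.1), both labeled ship.
for p in (0.9, 0.1):
    ce = focal_loss(p, 1, gamma=0.0)
    fl = focal_loss(p, 1, gamma=2.0)
    print("p_t =", p, "cross-entropy =", ce, "focal =", fl, "ratio =", fl / ce)
```

With γ = 2, the loss of the easy pixel is scaled down by (1 − 0.9)² while that of the hard pixel is scaled by (1 − 0.1)², so the many easy “no ship” pixels contribute far less to the total.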
I hope you enjoyed the ride, and that you now feel that building an image detection and segmentation algorithm isn’t that big a deal!
To conclude, we will show more predictions where the model sometimes does well, and other times not so well: