Notes on Focal Loss and RetinaNet

RetinaNet[1] is a one-stage object detector (like SSD and YOLO) that matches the accuracy of two-stage detectors (like Faster-RCNN). The main contribution of the paper is a new classification loss called Focal loss[1], which significantly increases accuracy. RetinaNet is essentially a Feature Pyramid Network with the cross-entropy loss replaced by Focal loss. The observations and ideas behind the loss are arguably more interesting than the architecture itself.

Figure: Performance boost[1]

Why is the one-stage detector worse than the two-stage detector?

One difference between the two is that a one-stage detector uses a fixed grid of boxes (like anchors in SSD), while a two-stage detector uses a proposal network to generate (or filter) box proposals. To avoid missing objects, the former normally evaluates 10k to 100k boxes per image, while the latter generates far fewer proposals (Faster-RCNN keeps only about 2k after Non-Maximum Suppression). For the same image, more boxes means more background boxes, which creates a typical data imbalance problem. Background boxes are normally easy to classify, so each one contributes only a small loss, but there are so many of them that together they dominate the overall loss. The classifier is then biased toward getting the background right, and the objects in the image are effectively down-weighted.
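To see how the imbalance plays out numerically, here is a minimal sketch with made-up numbers: 100,000 easy background boxes whose true class gets probability 0.99, and 10 hard object boxes whose true class gets probability 0.1. These counts and probabilities are illustrative assumptions, not figures from the paper. The script sums the ordinary cross-entropy loss of each group.

```python
import math

# Illustrative assumptions, not measurements from the paper.
num_background = 100_000   # easy negatives from a dense grid of boxes
p_background = 0.99        # probability assigned to the true (background) class
num_objects = 10           # hard positives
p_objects = 0.10           # probability assigned to the true (object) class


def cross_entropy(p_correct: float) -> float:
    """Cross-entropy loss for one box, given the probability of its true class."""
    return -math.log(p_correct)


background_total = num_background * cross_entropy(p_background)
objects_total = num_objects * cross_entropy(p_objects)
background_share = background_total / (background_total + objects_total)

print(f"background loss: {background_total:.1f}")
print(f"object loss:     {objects_total:.1f}")
print(f"background share of total loss: {background_share:.1%}")
```

Under these assumptions, the background boxes account for roughly 98% of the total loss even though each of them is already classified almost perfectly, so the gradient is mostly spent on examples the model has nothing left to learn from.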

Solutions to data imbalance

There are two popular ways to address this in machine learning: resampling (downsampling the dominant cases or oversampling the minority cases) and changing the weights in the loss function. Hard Negative Mining in SSD and the fixed-ratio minibatch sampling in Faster-RCNN belong to the first category. This paper follows the second way and proposes a new loss function. From the following figure, you can see that it lowers the loss for well-classified cases while emphasizing hard ones.

Figure: Focal Loss[1]
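The paper defines the loss as FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the probability the model assigns to the true class. The following is a minimal NumPy sketch of that formula for the binary case, not the authors' implementation. The function name `focal_loss` and its argument layout are my own choices; γ = 2 and α = 0.25 are the settings the paper reports as working well. Setting γ = 0 with no α recovers plain cross-entropy.

```python
import numpy as np


def focal_loss(p, y, gamma=2.0, alpha=None):
    """Per-example binary focal loss (hypothetical helper, not from the paper's code).

    p: predicted probability of the foreground class, shape (N,)
    y: labels, 1 for foreground and 0 for background, shape (N,)
    gamma: focusing parameter; 0 gives ordinary cross-entropy
    alpha: optional class weight for foreground (background gets 1 - alpha)
    """
    # Clip to keep log() finite for probabilities of exactly 0 or 1.
    p = np.clip(np.asarray(p, dtype=np.float64), 1e-12, 1.0 - 1e-12)
    y = np.asarray(y)
    # p_t is the probability assigned to the true class of each example.
    p_t = np.where(y == 1, p, 1.0 - p)
    loss = -((1.0 - p_t) ** gamma) * np.log(p_t)
    if alpha is not None:
        loss = np.where(y == 1, alpha, 1.0 - alpha) * loss
    return loss


# One easy background box (p_t = 0.9) and one hard object box (p_t = 0.1).
p = np.array([0.1, 0.1])
y = np.array([0, 1])

ce = focal_loss(p, y, gamma=0.0)   # plain cross-entropy
fl = focal_loss(p, y, gamma=2.0)   # focal loss without class weighting
print("cross-entropy:", ce)
print("focal loss:   ", fl)
print("reduction factor (ce / fl):", ce / fl)
```

With γ = 2 the easy example's loss shrinks by a factor of about 100, while the hard example's loss shrinks by only about 1.2, which is the down-weighting of well-classified cases that the figure illustrates.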

Focal loss gives the detector a significant accuracy boost. What about hinge loss (the loss used in SVMs)? The paper mentions that the authors tried it but could not train a stable network.


  1. T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár. Focal Loss for Dense Object Detection. ICCV, 2017.