Digging into Detectron 2 — part 4

Region Proposal Network

Hiroto Honda
10 min read · Mar 10, 2020
Figure 1. Inference result of Faster (Base) R-CNN with Feature Pyramid Network.

Hi, I’m Hiroto Honda, a computer vision researcher¹. [homepage] [linkedin]

In this article I would like to share my learnings about Detectron 2 — repo structure, building and training a network, handling a data set and so on.

Detectron 2² is a next-generation open-source object detection system from Facebook AI Research.

In part 1, part 2 and part 3, we have seen the overview of the Base-RCNN-FPN, feature pyramid network, and ground truth preparation, respectively.
This time, we are going deep into the most complicated but important part — the Region Proposal Network (see Fig. 2).

Figure 2. Detailed architecture of Base-RCNN-FPN. Blue labels represent class names.

As we have seen in part 2, the output feature maps from the feature pyramid network are:

output["p2"].shape -> torch.Size([1, 256, 200, 320]) # stride = 4
output["p3"].shape -> torch.Size([1, 256, 100, 160]) # stride = 8
output["p4"].shape -> torch.Size([1, 256, 50, 80])   # stride = 16
output["p5"].shape -> torch.Size([1, 256, 25, 40])   # stride = 32
output["p6"].shape -> torch.Size([1, 256, 13, 20])   # stride = 64

which are also the inputs to the RPN. Each tensor shape stands for (batch, channels, height, width). We use the feature dimensions above throughout this part.

We also have ground truth boxes loaded from the dataset (see part 3) :

'gt_boxes': Boxes(tensor([
[100.58, 180.66, 214.78, 283.95],
[180.58, 162.66, 204.78, 180.95]
])),
'gt_classes': tensor([9, 9]) # not used in RPN!

How can object detectors connect the feature maps with the ground-truth box locations and sizes? Let’s see how the RPN, the core component of the R-CNN detector, works.

Fig. 3 shows the detailed schematic of the RPN. The RPN consists of a neural network (the RPN Head) and non-neural-network functionality. In Detectron 2, all the computation in the RPN³ is performed on the GPU.

Figure 3. Schematic of Region Proposal Network. Blue and red labels represent class names and chapter titles respectively.

Firstly, let’s see the RPN Head that processes the feature maps fed from the FPN.

1. RPN Head

The neural network part of the RPN is simple. It is called the RPN Head and consists of three convolution layers, defined in the StandardRPNHead class.

1. conv (3×3, 256 -> 256 ch)
2. objectness logits conv (1×1, 256 -> 3 ch)
3. anchor deltas conv (1×1, 256 -> 3×4 ch)

The feature maps of five levels (P2 to P6) are fed to the network one by one.
The output feature maps at one level are:

1. pred_objectness_logits (B, 3 ch, Hi, Wi): objectness map (logits of the object-existence probability)
2. pred_anchor_deltas (B, 3×4 ch, Hi, Wi): box shapes relative to the anchors

where B stands for batch size and Hi and Wi correspond to the feature map sizes of P2 to P6.
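
As a reference, here is a minimal PyTorch sketch of such a head (my own code, not the actual StandardRPNHead implementation), using the 256 input channels and 3 anchors per grid cell described above:

import torch
import torch.nn as nn

class TinyRPNHead(nn.Module):
    """Minimal sketch of an RPN head: a shared 3x3 conv followed by two
    1x1 convs predicting objectness and anchor deltas (3 anchors per cell)."""
    def __init__(self, in_channels=256, num_anchors=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.objectness_logits = nn.Conv2d(in_channels, num_anchors, 1)
        self.anchor_deltas = nn.Conv2d(in_channels, num_anchors * 4, 1)

    def forward(self, features):  # features: list of (B, 256, Hi, Wi) tensors, P2 to P6
        logits, deltas = [], []
        for x in features:
            t = torch.relu(self.conv(x))
            logits.append(self.objectness_logits(t))  # (B, 3, Hi, Wi)
            deltas.append(self.anchor_deltas(t))      # (B, 12, Hi, Wi)
        return logits, deltas

# e.g. the P6 feature map from the shapes above:
head = TinyRPNHead()
logits, deltas = head([torch.zeros(1, 256, 13, 20)])
print(logits[0].shape, deltas[0].shape)  # (1, 3, 13, 20) and (1, 12, 13, 20)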

What do they actually look like? In Fig. 4, the objectness logits map at each level is overlaid on the input image. You can see that small objects are detected at P2 and P3, and larger ones at P4 to P6. This is exactly what the feature pyramid network aims for: the multi-scale network can detect tiny objects that a single-scale detector cannot find. Please see the Appendix (updated on Dec. 6, 2022) if you wish to produce visualizations like Fig. 4.

Next, let’s proceed to anchor generation, which is essential to associate the ground truth boxes with the two output feature maps above.

Figure 4. Visualization of objectness maps. Sigmoid function has been applied to the objectness_logits map. The objectness maps for 1:1 anchor are resized to the P2 feature map size and overlaid on the original image.

2. Anchor Generation

To connect the objectness map and anchor deltas map to the ground truth boxes, the reference boxes called ‘anchors’ are necessary.

2–1. Generate Cell Anchors

In the Base-RCNN-FPN of Detectron 2, the anchors are defined as follows:

MODEL.ANCHOR_GENERATOR.SIZES = [[32], [64], [128], [256], [512]]
MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]]

What does it mean?
The five elements of the ANCHOR_GENERATOR.SIZES list correspond to the five levels of feature maps (P2 to P6). For example, P2 (stride=4) has one anchor size, 32.

Aspect ratios define the shapes of the anchors. In the example above, there are three shapes: 0.5, 1.0 and 2.0. Let’s look at the actual anchors (Fig. 5). The three anchors on the P2 feature map have aspect ratios of 1:2, 1:1 and 2:1, and all have the same area as a 32 × 32 square. At the P3 level, the anchors are twice as large as the P2 anchors.

Figure 5. Cell anchors for P2 and P3 feature maps. (from left: 1:2, 1:1 and 2:1 aspect ratios)

These anchors are called ‘cell anchors’ in Detectron 2 (the code for anchor generation is here). As a result, we obtain 3 × 5 = 15 cell anchors for the five feature map levels.

# cell anchors for P2, P3, P4, P5 and P6. (x1, y1, x2, y2)
tensor([[-22.6274, -11.3137, 22.6274, 11.3137],
[-16.0000, -16.0000, 16.0000, 16.0000],
[-11.3137, -22.6274, 11.3137, 22.6274]])
tensor([[-45.2548, -22.6274, 45.2548, 22.6274],
[-32.0000, -32.0000, 32.0000, 32.0000],
[-22.6274, -45.2548, 22.6274, 45.2548]])
tensor([[-90.5097, -45.2548, 90.5097, 45.2548],
[-64.0000, -64.0000, 64.0000, 64.0000],
[-45.2548, -90.5097, 45.2548, 90.5097]])
tensor([[-181.0193, -90.5097, 181.0193, 90.5097],
[-128.0000, -128.0000, 128.0000, 128.0000],
[ -90.5097, -181.0193, 90.5097, 181.0193]])
tensor([[-362.0387, -181.0193, 362.0387, 181.0193],
[-256.0000, -256.0000, 256.0000, 256.0000],
[-181.0193, -362.0387, 181.0193, 362.0387]])
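
These values can be reproduced with a few lines (a sketch of my own, not Detectron 2’s anchor generation code): every anchor keeps the area size × size, and its width and height are scaled according to the aspect ratio (interpreted here as height / width).

import math

def cell_anchors(size, aspect_ratios=(0.5, 1.0, 2.0)):
    # (x1, y1, x2, y2) boxes centered at the origin, all with area size * size
    anchors = []
    for ratio in aspect_ratios:
        w = math.sqrt(size * size / ratio)
        h = ratio * w
        anchors.append((-w / 2, -h / 2, w / 2, h / 2))
    return anchors

for size in [32, 64, 128, 256, 512]:  # P2 to P6
    print([[round(v, 4) for v in box] for box in cell_anchors(size)])
# the first row reproduces [-22.6274, -11.3137, 22.6274, 11.3137] and so on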

2–2. Place Anchors on the Grid Points

Next we place the cell anchors on grids whose sizes are the same as those of the predicted feature maps.

For example, our predicted feature map P6 has a size of (13, 20) and a stride of 64. In Fig. 6 the P6 grid is shown with the three anchors placed at (5, 5). At the input image resolution, the grid coordinate (5, 5) corresponds to (320, 320), and the square anchor’s size is (512, 512). The anchors are placed at every grid point, so 13 × 20 × 3 = 780 anchors are generated for P6.
The same process is carried out for the other levels (see the example of the P5 grid in Fig. 6), and 255,780 anchors are generated in total.
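
As a quick sanity check on these numbers, here is a sketch (my own helper, not Detectron 2’s code) that shifts the cell anchors to every grid point of one level:

import torch

def grid_anchors(cell_anchors, grid_h, grid_w, stride):
    # place the (x1, y1, x2, y2) cell anchors at every grid point of one level
    anchors = []
    for y in range(grid_h):
        for x in range(grid_w):
            shift = torch.tensor([x * stride, y * stride, x * stride, y * stride],
                                 dtype=torch.float32)
            anchors.append(cell_anchors + shift)
    return torch.cat(anchors)

# P6 cell anchors from the previous chapter, grid size (13, 20), stride 64
p6_cell = torch.tensor([[-362.0387, -181.0193, 362.0387, 181.0193],
                        [-256.0000, -256.0000, 256.0000, 256.0000],
                        [-181.0193, -362.0387, 181.0193, 362.0387]])
print(grid_anchors(p6_cell, 13, 20, 64).shape)  # torch.Size([780, 4])

# over all levels: (200*320 + 100*160 + 50*80 + 25*40 + 13*20) * 3 = 255780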

Figure 6. Placing anchors on grid points. The top-left corner of each grid corresponds to (0, 0).

3. Ground Truth Preparation

In this chapter we associate ground truth boxes with the generated anchors.

3–1. Calculate the Intersection-over-Union (IoU) Matrix

Let’s assume we have two ground truth (GT) boxes loaded from a dataset.

'gt_boxes': 
Boxes(tensor([
[100.58, 180.66, 214.78, 283.95],
[180.58, 162.66, 204.78, 180.95]
])),

We now try to find, out of the 255,780 anchors, the ones similar to the two GT boxes. How can you tell whether one box is similar to another? The answer is the Intersection over Union (IoU) calculation. In Detectron 2, the pairwise_iou function calculates the IoU for every pair from two lists of boxes. In our case, the result of pairwise_iou is a matrix of size (2 (GT boxes), 255,780 (anchors)).

# Example of IoU matrix, the result of pairwise_iou
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],#GT 1
[0.0000, 0.0000, 0.0000, ..., 0.0087, 0.0213, 0.0081],#GT 2
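
For illustration, the same kind of IoU matrix can be computed with torchvision’s box_iou utility (the three anchor boxes below are made up for this example):

import torch
from torchvision.ops import box_iou  # pairwise IoU, like Detectron 2's pairwise_iou

gt_boxes = torch.tensor([[100.58, 180.66, 214.78, 283.95],
                         [180.58, 162.66, 204.78, 180.95]])
anchors = torch.tensor([[ 96.0, 176.0, 224.0, 288.0],   # overlaps GT 1 well
                        [  0.0,   0.0,  32.0,  32.0],   # far from both GT boxes
                        [176.0, 160.0, 208.0, 192.0]])  # overlaps GT 2 partially
iou = box_iou(gt_boxes, anchors)  # shape (2, 3): (number of GT boxes, number of anchors)
print(iou)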

3–2. Examine the IoU matrix by Matcher

The IoU matrix is examined by the Matcher, and every anchor is labeled as foreground, background, or ignored. As shown in Fig. 7, if the IoU is larger than a pre-defined threshold (typically 0.7), the anchor is assigned to one of the GT boxes and labeled as foreground (‘1’). If the IoU is smaller than another threshold (typically 0.3), the anchor is labeled as background (‘0’); otherwise it is ignored (‘-1’).
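
A minimal sketch of this rule (my own simplification; the actual Matcher additionally keeps the best-matching anchor of each GT box as foreground, the so-called low-quality matches) looks like this:

import torch

def label_anchors(iou_matrix, bg_thresh=0.3, fg_thresh=0.7):
    # iou_matrix: (num GT, num anchors). Returns the matched GT index per anchor
    # and a label: 1 = foreground, 0 = background, -1 = ignored.
    max_iou, matched_gt = iou_matrix.max(dim=0)  # best GT box for each anchor
    labels = torch.full_like(matched_gt, -1)     # start with everything 'ignored'
    labels[max_iou < bg_thresh] = 0              # background
    labels[max_iou >= fg_thresh] = 1             # foreground
    return matched_gt, labels

iou = torch.tensor([[0.80, 0.10, 0.50],
                    [0.05, 0.20, 0.65]])
print(label_anchors(iou))  # (tensor([0, 1, 1]), tensor([ 1,  0, -1]))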

Figure 7. Matcher determines assignment of anchors to ground-truth boxes. The table shows the IoU matrix whose shape is (number of GT boxes, number of anchors).

In Fig. 8 we show the matching result overlaid on the input image. As you can see, most of the grid points are labeled as background (0), and only a few as foreground (1) or ignored (-1).

Figure 8. Matching result overlaid on the input image.

3–3. Calculate anchor deltas

The anchor boxes labeled as foreground have shapes similar to the GT boxes. However, the network should learn to propose the exact locations and shapes of the GT boxes. To achieve this, four regression parameters have to be learned: Δx, Δy, Δw, and Δh. These ‘deltas’ are calculated as shown in Fig. 9 using the Box2BoxTransform.get_deltas function. The formulation is given in the Faster R-CNN paper⁴.

Figure 9. Anchor deltas calculation.

As a result, we obtain a tensor called gt_anchor_deltas whose shape is (255780, 4) in our case.

# calculated deltas (dx, dy, dw, dh)
tensor([[ 9.9280, 24.6847, 0.8399, 2.8774],
[14.0403, 17.4548, 1.1865, 2.5308],
[19.8559, 12.3424, 1.5330, 2.1842],
...,
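
A sketch of this computation (my own code; Box2BoxTransform.get_deltas also supports scaling weights, left at 1.0 here): the deltas encode the offset of the GT box center relative to the anchor size, and the log ratios of the widths and heights.

import torch

def get_deltas_sketch(anchors, gt_boxes):
    # (x1, y1, x2, y2) boxes -> (dx, dy, dw, dh) regression targets
    aw = anchors[:, 2] - anchors[:, 0]
    ah = anchors[:, 3] - anchors[:, 1]
    ax = anchors[:, 0] + 0.5 * aw
    ay = anchors[:, 1] + 0.5 * ah
    gw = gt_boxes[:, 2] - gt_boxes[:, 0]
    gh = gt_boxes[:, 3] - gt_boxes[:, 1]
    gx = gt_boxes[:, 0] + 0.5 * gw
    gy = gt_boxes[:, 1] + 0.5 * gh
    return torch.stack([(gx - ax) / aw,        # dx
                        (gy - ay) / ah,        # dy
                        torch.log(gw / aw),    # dw
                        torch.log(gh / ah)],   # dh
                       dim=1)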

3–4. Re-sample the boxes for loss calculation

Now we have ground-truth objectness_logits and anchor_deltas at every grid point of the feature maps, against which we can compare the predicted feature maps.

Fig. 10 (left) shows an example breakdown of the number of anchors per image. As you can see, the majority of anchors are background: typically there are fewer than 100 foreground anchors and fewer than 1,000 ignored anchors among the 255,780 anchors, and the rest are background. If we trained on all of them, it would be hard to learn the foreground ones due to the label imbalance.

The labels are re-sampled using the subsample_labels function to solve this imbalance issue.

Let N be the target number of foreground + background boxes and F be the target number of foreground boxes. N and F / N are defined by the following config parameters.

N: MODEL.RPN.BATCH_SIZE_PER_IMAGE (typically 256)
F/N: MODEL.RPN.POSITIVE_FRACTION (typically 0.5)

Fig. 10 (center) shows the breakdown of the re-sampled boxes. Background and foreground boxes are randomly selected so that N and F/N match the values defined by the parameters above. If the number of foreground boxes is less than F, as shown in Fig. 10 (right), additional background boxes are sampled to fill the N samples.
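
A rough sketch of the re-sampling (my own simplification of subsample_labels): pick at most F foreground anchors at random, then fill the rest of the N-sample budget with background anchors.

import torch

def subsample(labels, num_samples=256, positive_fraction=0.5):
    # labels: per-anchor 1 / 0 / -1. Returns the indices of the sampled
    # foreground and background anchors; all the others are ignored in the loss.
    fg_idx = torch.nonzero(labels == 1).flatten()
    bg_idx = torch.nonzero(labels == 0).flatten()
    num_fg = min(int(num_samples * positive_fraction), fg_idx.numel())
    num_bg = min(num_samples - num_fg, bg_idx.numel())  # fill up with background
    fg_idx = fg_idx[torch.randperm(fg_idx.numel())[:num_fg]]
    bg_idx = bg_idx[torch.randperm(bg_idx.numel())[:num_bg]]
    return fg_idx, bg_idx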

Figure 10. Re-sampling the foreground and background boxes.

4. Loss Calculation

Two loss functions are applied to the prediction and ground-truth maps in the rpn_losses function.

localization loss (loss_rpn_loc)

  • L1 loss⁵.
  • Calculated only at the grid points where the ground-truth objectness is 1 (foreground), which means all the background grid points are excluded from this loss.

objectness loss (loss_rpn_cls)

  • Binary cross-entropy loss.
  • Calculated only at the grid points where the ground-truth objectness is 1 (foreground) or 0 (background).

The actual loss results are as follows:

{
'loss_rpn_cls': tensor(0.6913, device='cuda:0', grad_fn=<MulBackward0>),
'loss_rpn_loc': tensor(0.1644, device='cuda:0', grad_fn=<MulBackward0>)
}
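
As a reference, the two losses could be sketched as follows (my own simplification of the rpn_losses function, assuming the predictions and targets have already been flattened to per-anchor tensors and the sampled indices from chapter 3–4 are available):

import torch
import torch.nn.functional as F

def rpn_losses_sketch(pred_logits, pred_deltas, gt_labels, gt_deltas, fg_idx, bg_idx):
    # pred_logits: (A,), pred_deltas and gt_deltas: (A, 4), gt_labels: (A,) in {1, 0, -1}
    # localization loss: pure L1, computed on the foreground anchors only
    loss_loc = F.l1_loss(pred_deltas[fg_idx], gt_deltas[fg_idx], reduction="sum")

    # objectness loss: binary cross entropy on the sampled foreground + background anchors
    sampled = torch.cat([fg_idx, bg_idx])
    loss_cls = F.binary_cross_entropy_with_logits(
        pred_logits[sampled], gt_labels[sampled].float(), reduction="sum")

    # both are normalized by the number of sampled anchors
    norm = sampled.numel()
    return {"loss_rpn_loc": loss_loc / norm, "loss_rpn_cls": loss_cls / norm}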

5. Proposal Selection

Lastly, we choose 1,000 ‘region proposal’ boxes from the predicted boxes, following the four steps below.

  1. Apply predicted anchor_deltas to the corresponding anchors, which is the reverse process of 3–3.
  2. The predicted boxes are sorted by the predicted objectness scores at each feature level.
  3. As shown in Fig. 11, the top-K scored boxes (K is defined by the config parameters) are chosen from each feature level of an image⁶. For example, 2,000 boxes are chosen from the 192,000 boxes at P2. At P6, where fewer than 2,000 boxes exist, all the boxes are selected.
  4. Non-maximum suppression (batched_nms) is applied to each level independently. The 1,000 top-scored boxes survive as a result.

Figure 11. Choosing the top-K proposal boxes from each feature level. The numbers of boxes are examples for an input image size of (H=800, W=1280).
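
A single-level sketch of steps 2 to 4 (my own code, using torchvision’s batched_nms; it assumes the predicted boxes of one level have already been decoded from the anchor deltas in step 1):

import torch
from torchvision.ops import batched_nms

def select_proposals_one_level(boxes, logits, level_id, pre_nms_topk=2000, nms_thresh=0.7):
    # boxes: (A, 4), logits: (A,) predicted objectness for one feature level
    k = min(pre_nms_topk, logits.numel())
    scores, idx = logits.topk(k)   # steps 2-3: sort by objectness and keep the top-K
    boxes = boxes[idx]
    # step 4: NMS; passing the level id keeps the levels independent
    lvl = torch.full_like(scores, level_id, dtype=torch.int64)
    keep = batched_nms(boxes, scores, lvl, nms_thresh)
    return boxes[keep], scores[keep]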

Finally, we obtain the proposal boxes as ‘Instances’ with:

  • ‘proposal_boxes’: 1,000 boxes
  • ‘objectness_logits’: 1,000 scores

which are used in the next stage.

To Be Continued…

In the next part we proceed to the Box head, the second stage of R-CNN. Thank you for reading and please wait for the next part!

part 1: Introduction — Basic Network Architecture and Repo Structure
part 2: Feature Pyramid Network
part 3: Data Loader and Ground Truth Instances
part 4 (you are here): Region Proposal Network
part 5: ROI (Box) Head

Check this out too!

I also published slides titled ‘Digging into Sample Assignment Methods for Object Detection’, where I focus on how detectors (Faster R-CNN, RetinaNet, YOLOv1–5, etc.) define training samples for the given feature maps and ground truth boxes.

https://speakerdeck.com/hirotohonda/digging-into-sample-assignment-methods-for-object-detection

[1] This is a personal article and the opinions expressed here are my own and not those of my employer.
[2] Yuxin Wu, Alexander Kirillov, Francisco Massa, Wan-Yen Lo and Ross Girshick, Detectron2. https://github.com/facebookresearch/detectron2, 2019. The file, directory, and class names are cited from the repository ( Copyright 2019, Facebook, Inc. )
[3] the files used for RPN are: modeling/proposal_generator/rpn.py, modeling/proposal_generator/rpn_outputs.py, modeling/anchor_generator.py, modeling/box_regression.py, modeling/matcher.py and modeling/sampling.py
[4] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. Faster R-CNN: Towards real-time object detection with region proposal networks. In NIPS, 2015. (link)
[5] Implemented as a smooth-L1 loss, but it is actually a pure L1 loss, unlike in Detectron 1 or MMDetection (see: link).
[6] In Detectron 1 and maskrcnn-benchmark, the top-K proposals are chosen from the whole batch during training (see: link).

Appendix (updated on Dec. 6, 2022)

As I have received many reactions to Fig. 4, I am leaving the code I used to visualize the heatmaps here. Just add the following lines after this line inside the for loop for v0.1 (Dec. 2019, when this blog was published) or after this line for v0.6 (not verified yet):

# assumes cv2 and numpy (np) are imported at the top of the file; i is the loop index over the levels
o = pred_objectness_logits[-1].sigmoid() * 255   # logits -> 0-255 heatmap for the current level
o = o.cpu().detach().numpy()[0, 1]               # batch 0, channel 1 (the 1:1 aspect-ratio anchor)
o = cv2.resize(o, (320, 184))                    # resize to a common visualization size
cv2.imwrite('objectness' + str(i) + '.png', np.asarray(o, dtype=np.uint8))

To visualize the heatmaps, I used the following code in a notebook:

import matplotlib.pyplot as plt
import cv2

imgs = []
imgs.append(cv2.imread('original_image.jpg'))
for i in range(5):  # five levels (P2 to P6)
    heatmap = cv2.imread('objectness' + str(i) + '.png')
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    imgs.append(cv2.resize(imgs[0], (320, 184)) // 2 + heatmap // 2)  # blending

fig = plt.figure(figsize=(16, 7))
for i, img in enumerate(imgs):
    fig.add_subplot(2, 3, i + 1)
    if i > 0:
        plt.imshow(img[0:-1, :, ::-1])  # [0:-1] trims the bottom edge, [::-1] converts BGR to RGB
        plt.title("objectness on P" + str(i + 1))
    else:
        plt.imshow(img[:, :, ::-1])
        plt.title("input image")
plt.show()
