Digging into Detectron 2 — part 2

Feature Pyramid Network

Hiroto Honda
Jan 10, 2020
Figure 1. Inference result of Faster (Base) R-CNN with Feature Pyramid Network.

Hi, I’m Hiroto Honda, a computer vision researcher¹. [homepage] [linkedin]

In this article I would like to share my learnings about Detectron 2 — repo structure, building and training a network, handling a data set and so on.

Detectron 2 ² is a next-generation open-source object detection system from Facebook AI Research.

In part 1, I showed the detailed schematic diagram (Fig. 2) that depicts the settings of the standard Base-RCNN-FPN network. [config file]
This time, we are going deeper into the Backbone Network — the Feature Pyramid Network³ (FPN).

Figure 2. Detailed architecture of Base-RCNN-FPN. Blue labels represent class names.

The role of the backbone network is to extract feature maps from the input image. At this stage there are no bounding boxes, anchors, or loss functions yet!

Input and Output of FPN

First, let's clarify the input and output of FPN. Figure 3 is a closer look at the FPN schematic.

Figure 3. Detailed architecture of the backbone of Base-RCNN-FPN with ResNet50. Blue labels represent class names. (a), (b) and (c) inside the blocks stand for the bottleneck types detailed in Fig. 5.

input (torch.Tensor): (B, 3, H, W) image

B, H and W stand for batch size, image height and image width respectively. Note that the input color channels are ordered Blue, Green and Red (BGR). If you feed an RGB image instead, detection accuracy might drop.
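As a minimal sketch of the channel-order point above (this is not Detectron 2 code, and the tensor names are made up for illustration), an RGB batch in (B, 3, H, W) layout can be turned into BGR by reversing the channel axis:

```python
import torch

# Hypothetical RGB batch in (B, 3, H, W) layout; random values stand in for pixels.
image_rgb = torch.rand(1, 3, 4, 6)

# Reorder the channel axis: R and B swap places, G stays in the middle.
image_bgr = image_rgb[:, [2, 1, 0]]

print(image_bgr.shape)
```

The shape is unchanged; only the order of the three channel planes differs.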

output (dict of torch.Tensor): (B, C, H / S, W / S) feature maps

C and S stand for channel size and stride. By default C=256 for all the scales and S = 4, 8, 16, 32 and 64 for P2, P3, P4, P5 and P6 outputs respectively.

For example, if we put a single image whose size is (H=800, W=1280) into the backbone, the input tensor size is torch.Size([1, 3, 800, 1280]) and the output dict should be:

output["p2"].shape -> torch.Size([1, 256, 200, 320]) # stride = 4 
output["p3"].shape -> torch.Size([1, 256, 100, 160]) # stride = 8
output["p4"].shape -> torch.Size([1, 256, 50, 80]) # stride = 16
output["p5"].shape -> torch.Size([1, 256, 25, 40]) # stride = 32
output["p6"].shape -> torch.Size([1, 256, 13, 20]) # stride = 64
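To make the shape arithmetic above concrete, here is a small plain-Python sketch. The helper fpn_output_shapes is my own illustration, not a Detectron 2 function, and it assumes the image sides are divisible by 32. It also shows why P6 is 13 pixels high rather than 12.5: P6 comes from a stride-2 max pool with kernel size 1 over P5 (described later in this article), which keeps every second row and column.

```python
def fpn_output_shapes(batch, height, width, channels=256):
    """Expected FPN output shapes for an input whose sides are divisible by 32."""
    shapes = {}
    for name, stride in (("p2", 4), ("p3", 8), ("p4", 16), ("p5", 32)):
        shapes[name] = (batch, channels, height // stride, width // stride)
    # Kernel size 1 with stride 2 keeps rows/columns 0, 2, 4, ...,
    # so a side of length n becomes (n - 1) // 2 + 1.
    _, _, h5, w5 = shapes["p5"]
    shapes["p6"] = (batch, channels, (h5 - 1) // 2 + 1, (w5 - 1) // 2 + 1)
    return shapes


shapes = fpn_output_shapes(1, 800, 1280)
for name, shape in shapes.items():
    print(name, shape)
```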

Fig. 4 shows what the actual output feature maps look like. One pixel of the ‘P6’ feature corresponds to a broader area of the input image than one pixel of ‘P2’. In other words, ‘P6’ has a larger receptive field than ‘P2’. FPN can thus extract multi-scale feature maps with different receptive fields. See the Appendix for the code to visualize feature maps like Fig. 4 (updated in Dec. 2022).

Figure 4: Example of input and output of FPN. The feature at the 0th channel is visualized from each output.

Code Structure⁴ for Backbone modeling

The related files are under the detectron2/modeling/backbone directory:

├─modeling   
│ ├─backbone
│ │ ├─backbone.py <- includes abstract base class Backbone
│ │ ├─build.py <- call builder function specified in config
│ │ ├─fpn.py <- includes FPN class and sub-classes
│ │ ├─resnet.py <- includes ResNet class and sub-classes

The following is the class hierarchy.

FPN (backbone/fpn.py)
├ ResNet (backbone/resnet.py)
│ ├ BasicStem (backbone/resnet.py)
│ └ BottleneckBlock (backbone/resnet.py)
└ LastLevelMaxPool (backbone/fpn.py)

ResNet

ResNet⁵ consists of a stem block and ‘stages’ that each contain multiple bottleneck blocks. For ResNet50, the block structure is:

BasicStem
(res2 stage, 1/4 scale)
BottleneckBlock (b)(stride=1, with shortcut conv)
BottleneckBlock (a)(stride=1, w/o shortcut conv) × 2
(res3 stage, 1/8 scale)
BottleneckBlock (c)(stride=2, with shortcut conv)
BottleneckBlock (a)(stride=1, w/o shortcut conv) × 3
(res4 stage, 1/16 scale)
BottleneckBlock (c)(stride=2, with shortcut conv)
BottleneckBlock (a)(stride=1, w/o shortcut conv) × 5
(res5 stage, 1/32 scale)
BottleneckBlock (c)(stride=2, with shortcut conv)
BottleneckBlock (a)(stride=1, w/o shortcut conv) × 2

ResNet101 and ResNet152 have a larger number of bottleneck blocks (a), as defined at: [code link].
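The listing above can also be generated from the per-stage block counts. The following is a plain-Python sketch of my own (the names BLOCKS_PER_STAGE and block_layout are hypothetical, not taken from the repository); the counts are the standard ones for ResNet50, ResNet101 and ResNet152.

```python
# Bottleneck blocks per stage (res2, res3, res4, res5) for the standard depths.
BLOCKS_PER_STAGE = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}


def block_layout(depth):
    """Return {stage name: list of block types}, using the (a)/(b)/(c) labels above."""
    layout = {}
    for index, count in enumerate(BLOCKS_PER_STAGE[depth]):
        stage = f"res{index + 2}"
        # res2 opens with type (b), stride 1; later stages open with type (c), stride 2.
        first = "b" if stage == "res2" else "c"
        layout[stage] = [first] + ["a"] * (count - 1)
    return layout


for stage, blocks in block_layout(50).items():
    print(stage, blocks)
```

Only the number of type (a) blocks changes between depths; the first block of each stage stays the same.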

(1) BasicStem (stem block) [code link]

The ‘stem’ block of ResNet is quite simple. It down-samples the input image twice: once by a 7×7 convolution with stride=2 and once by max pooling with stride=2.
The output of the stem block is a feature map tensor whose size is (B, 64, H / 4, W / 4).

- conv1 (kernel size = 7, stride = 2)
- batchnorm layer
- ReLU
- maxpool layer (kernel size = 3, stride = 2)
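The four layers above can be sketched with plain PyTorch modules to check the 1/4 down-sampling. This is an illustration rather than the repository's BasicStem: the padding values (3 for the convolution, 1 for the pooling) and the use of nn.BatchNorm2d as the normalization are my assumptions.

```python
import torch
from torch import nn

# Sketch of the stem; nn.BatchNorm2d stands in for whichever norm the config selects.
stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),  # 1/2 scale
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # 1/4 scale
).eval()

with torch.no_grad():
    stem_out = stem(torch.rand(1, 3, 64, 96))  # small dummy image, H=64, W=96

print(stem_out.shape)
```

For the 64×96 dummy input the output has 64 channels at 16×24, that is (B, 64, H / 4, W / 4).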

(2) BottleneckBlock [code link]

The bottleneck block was originally proposed in the ResNet paper⁵. The block has three convolution layers whose kernel sizes are 1×1, 3×3 and 1×1 respectively. The input and output channel numbers of the 3×3 convolution layer are smaller than those of the block itself, for efficient computation.

There are three types of bottleneck blocks, as shown in Fig. 5:
(a): stride=1, w/o shortcut conv
(b): stride=1, with shortcut conv
(c): stride=2, with shortcut conv

shortcut convolution (used in (b), (c))

ResNet has an identity shortcut that adds the block's input to its output features. For the first block of a stage (res2-res5), a shortcut convolution layer is used instead, to match the number of channels of the input and the output.

downsampling convolution with stride=2 (used in (c))

At the first block of the res3, res4 and res5 stages, the feature map is downsampled by a convolution layer with stride=2. The shortcut convolution also uses stride=2 here, because both the channel number and the spatial size of the input differ from those of the output.

Note that a ‘convolution layer’ mentioned above consists of a torch.nn.Conv2d convolution and a normalization layer (e.g. FrozenBatchNorm⁶). [code link]
ReLU activation is used after convolutions and feature addition (see Fig. 5).
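The three block types can be covered by a single module, sketched below. BottleneckSketch and conv_bn are hypothetical names of mine, nn.BatchNorm2d stands in for the normalization, and placing the stride in the first 1×1 convolution is a choice I made for illustration (the 3×3 convolution is the other common place). The channel numbers are the standard ResNet50 ones for res2 and res3.

```python
import torch
from torch import nn
import torch.nn.functional as F


def conv_bn(in_ch, out_ch, kernel_size, stride=1):
    """Convolution followed by normalization (BatchNorm2d as a stand-in)."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride,
                  padding=kernel_size // 2, bias=False),
        nn.BatchNorm2d(out_ch),
    )


class BottleneckSketch(nn.Module):
    def __init__(self, in_ch, out_ch, bottleneck_ch, stride=1):
        super().__init__()
        # Assumption: the stride sits in the first 1x1 convolution.
        self.conv1 = conv_bn(in_ch, bottleneck_ch, 1, stride=stride)
        self.conv2 = conv_bn(bottleneck_ch, bottleneck_ch, 3)
        self.conv3 = conv_bn(bottleneck_ch, out_ch, 1)
        # Shortcut conv only when channel numbers differ: types (b) and (c).
        self.shortcut = (conv_bn(in_ch, out_ch, 1, stride=stride)
                         if in_ch != out_ch else None)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = self.conv3(out)
        identity = x if self.shortcut is None else self.shortcut(x)
        return F.relu(out + identity)  # ReLU after the feature addition


x = torch.rand(1, 64, 16, 16)  # dummy stem output
block_b = BottleneckSketch(64, 256, 64).eval()              # (b): first block of res2
block_a = BottleneckSketch(256, 256, 64).eval()             # (a): identity shortcut
block_c = BottleneckSketch(256, 512, 128, stride=2).eval()  # (c): first block of res3
with torch.no_grad():
    y_b = block_b(x)
    y_a = block_a(y_b)
    y_c = block_c(y_a)
print(y_b.shape, y_a.shape, y_c.shape)
```

Types (b) and (a) keep the 16×16 resolution while (c) halves it to 8×8 and doubles the channels.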

Figure 5. Three types of bottleneck blocks.

FPN

FPN contains ResNet, lateral and output convolution layers, up-samplers and a last-level maxpool layer. [code link]

lateral convolution layers [code link]

This layer is called a ‘lateral’ convolution because FPN was originally depicted as a pyramid with the stem layer placed at the bottom (it is rotated in this article). The lateral convolution layers take features from the res2-res5 stages, which have different channel numbers, and return 256-channel feature maps.

output convolution layers [code link]

An output convolution layer contains a 3×3 convolution that does not change the number of channels.

forward process [code link]

Figure 6. Zooming into a part of FPN schematic that deals with res4 and res5.

The forward process of FPN starts from the res5 output (see Fig. 6).
After going through the lateral convolution, the 256-channel feature map is fed to the output convolution, to be registered to the results list as P5 (1/32 scale).

The 256-channel feature map is also fed to the up-sampler (F.interpolate with nearest-neighbor mode) and added to the res4 output (via its lateral convolution). The resulting feature map goes through the output convolution, and the result tensor P4 is inserted into the results list (1/16 scale).

The routine above (from up-sampling to insertion to the results) is carried out three times, and finally the result list contains four tensors — namely P2 (1/4 scale), P3 (1/8), P4 (1/16) and P5 (1/32).
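The whole top-down routine fits in a few lines of PyTorch. The sketch below is my simplified reading of the steps above, not the repository code: fpn_top_down, lateral and output are names I chose, the layers are plain nn.Conv2d without normalization, and the input channel numbers are the standard ResNet50 ones.

```python
import torch
from torch import nn
import torch.nn.functional as F

in_channels = {"res2": 256, "res3": 512, "res4": 1024, "res5": 2048}
# Lateral 1x1 convs bring every stage to 256 channels; output 3x3 convs keep 256.
lateral = nn.ModuleDict({k: nn.Conv2d(c, 256, kernel_size=1)
                         for k, c in in_channels.items()})
output = nn.ModuleDict({k: nn.Conv2d(256, 256, kernel_size=3, padding=1)
                        for k in in_channels})


def fpn_top_down(features):
    """features: dict of res2..res5 tensors; returns dict of p2..p5 tensors."""
    results = {}
    prev = lateral["res5"](features["res5"])
    results["p5"] = output["res5"](prev)
    for name in ("res4", "res3", "res2"):  # carried out three times
        upsampled = F.interpolate(prev, scale_factor=2, mode="nearest")
        prev = lateral[name](features[name]) + upsampled
        results["p" + name[-1]] = output[name](prev)
    return results


# Dummy ResNet50 stage outputs for a 64 x 96 image (strides 4, 8, 16, 32).
features = {k: torch.rand(1, c, 64 // s, 96 // s)
            for (k, c), s in zip(in_channels.items(), (4, 8, 16, 32))}
with torch.no_grad():
    pyramid = fpn_top_down(features)
for name in ("p2", "p3", "p4", "p5"):
    print(name, tuple(pyramid[name].shape))
```

Note that the tensor passed down to the next level is the sum before the output convolution, while the output convolution result is what gets stored as P5, P4, P3 and P2.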

LastLevelMaxPool [code link]

To make the P6 output, a max pooling layer with kernel size = 1 and stride = 2 is applied on top of the P5 output. This layer simply down-samples the P5 features (1/32 scale) to 1/64-scale features, which are then added to the results list.
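A quick check of what this pooling does, as a sketch with a random tensor standing in for P5 (the variable names are mine):

```python
import torch
import torch.nn.functional as F

p5 = torch.rand(1, 256, 25, 40)  # stand-in for P5 of an 800 x 1280 input

# With kernel size 1 each pooling window is a single pixel, so no maximum is
# really taken; stride 2 just keeps every second row and column.
p6 = F.max_pool2d(p5, kernel_size=1, stride=2, padding=0)

print(p6.shape)
print(torch.equal(p6, p5[:, :, ::2, ::2]))
```

The result has shape (1, 256, 13, 20) and is identical to plain strided slicing of P5, which is why the odd height 25 becomes 13.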

To Be Continued…

Now we have obtained multi-scale feature maps. In the next part we will look at the data loader and ground-truth data preparation. Thank you for reading and please wait for the next part!

part 1: Introduction — Basic Network Architecture and Repo Structure
part 2 (you are here) : Feature Pyramid Network
part 3 (next story!) : Data Loader and Ground Truth
part 4: Region Proposal Network
part 5: ROI (Box) Head

[1] This is a personal article and the opinions expressed here are my own and not those of my employer.
[2] Yuxin Wu, Alexander Kirillov, Francisco Massa, Wan-Yen Lo and Ross Girshick, Detectron2. https://github.com/facebookresearch/detectron2, 2019.
[3] T.-Y. Lin, P. Dollar, R. Girshick, K. He, B. Hariharan, and S. Belongie. Feature pyramid networks for object detection. In CVPR, 2017.
[4] As of Jan. 5, 2020. The file, directory, and class names are cited from the repository² (Copyright 2019, Facebook, Inc.).
[5] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, 2016.
[6] Why are the batchnorm parameters frozen? See this issue.

Appendix (updated on Dec. 6, 2022)

Since I have received many reactions to Fig. 4, here is the code I used to save the heatmaps during inference. Just add the following line after this line (v0.1) or this line (v0.6, not verified yet):

self.save_features_as_images(features)

and add the following function to the GeneralizedRCNN class:

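    @staticmethod
    def save_features_as_images(features):
        import numpy as np
        import cv2
        # features: dict of (B, C, H, W) tensors keyed by 'p2' ... 'p6'
        for k, v in features.items():
            v_ = v[:, 0].cpu().numpy()  # 0th channel only -> (B, H, W)
            v_ = (v_ / 16 + 0.5) * 255  # map values in [-8, 8] to [0, 255]
            # (H, W, B) layout; this is written as a grayscale image when B = 1
            v_ = np.asarray(v_.clip(0, 255), dtype=np.uint8).transpose((1, 2, 0))
            cv2.imwrite(k + '.png', v_)  # one PNG per pyramid level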
