Scratch to SOTA: Build Famous Classification Nets 5 (ResNet)
Introduction
The advent of the ResNet family at the end of 2015 shattered many computer vision records. Its model ensemble achieved a 3.57% top-5 error on the ImageNet classification challenge, nearly halving GoogLeNet’s 6.67%. It attained a 28% relative improvement on the COCO object detection challenge simply by replacing the detector’s backbone. It also popularized the use of skip connections (they are everywhere now!). In fact, ResNet is so effective that Geoffrey Hinton expressed some degree of regret that his team did not develop a similar idea into ResNet.
In this article, we will go through the ideas behind ResNet and its PyTorch implementation.
Overview
- Motivation and rationale behind skip connection.
- ResNet structures and bottleneck design.
- PyTorch implementation and discussion.
ResNet Basics
Motivation
We already have VGG19 and GoogLeNet (which has 22 layers), so why don’t we just design a VGG100 or a GoogLeNet with 200 layers? The reason is that adding depth has diminishing returns. Accuracy improves up to a certain point and then deteriorates.
“Oh, it’s overfitting, right?”, we’d conjecture. However, that is not the case. If we add more layers to our VGG or GoogLeNet, not only does the validation accuracy decrease, but so does the training accuracy. The graph below shows this phenomenon: the deeper network cannot even fit the training data as well as its shallower counterpart.
This observation is quite counter-intuitive. If we append more layers to a network, the “solution space” of the deeper network should contain that of the shallower network. If the newly added layers can at least learn an identity mapping, the deeper network should perform no worse.
This observation signifies the gap between theory and practice. While the mathematics says that a deeper network has greater representational power, our current optimizers are incapable of reaching the desirable locations in this larger solution space.
To partially overcome this gap, VGG trained its layers in stages (or used better initialization), GoogLeNet added auxiliary classifiers, and ResNet introduced skip connections.
Skip Connection
The authors of ResNet addressed this degradation problem by letting the layers of a network learn a transformation of the feature to add to the feature itself, instead of transforming the feature directly. As shown in the figure below, the two weight layers learn a “residual mapping” F(x) to add to the original feature x. If the underlying desirable transformation is H(x), layers without a skip connection must make their output approximate H(x) directly, while layers with a skip connection only need F(x) to approximate H(x) - x.
This design is based on the hypothesis that it is easier to incrementally modify the previous feature maps than to transform the previous feature maps directly. In the case that no modification is needed for good model performance, it is easier to push the output of the layers to zero than to learn an identity mapping.
At the same time, just like the auxiliary classifiers in GoogLeNet that enhance the backward flow of the gradient, these skip connections also alleviate the problem of vanishing gradients. The lower layers that are more remote from the source of the gradient enjoy a boost in gradient signal via the “expressway” of skip connections.
This remarkably simple modification allows us to increase the depth of the network drastically. A 152-layer network can now be trained, as we will see later in this article.
The last remaining detail about this skip connection is: how many layers should it skip? ResNet lets the connections skip 2 or 3 layers. More layers can be skipped as well (check out DenseNet, whose skip connections leap over a range of layers), but the authors point out that skipping just one layer is not advantageous: its formula Wx + x equals (W + I)x = W'x, which is essentially the same as a layer without a skip connection.
For implementation, the original feature map and the newly transformed feature map are added before the ReLU activation.
Bottleneck Block
The figure above shows the basic block, which implements skip connections in the vanilla way. However, for layers with a large number of filters, the complexity becomes too high to manage. The trick of using 1x1 filters to reduce the number of channels is used again here.
As shown in the figure on the left, the number of channels is first squeezed to a small number before applying filters with a larger receptive field. To let the output have the same dimensions as the input (so we can use the skip connection), we expand the number of channels again. The skip connection is always between feature maps with many channels; the feature maps with fewer channels form the “bottleneck” of the block. (Some readers may have heard of the inverted residual/bottleneck block introduced by MobileNetV2, whose skip connection is between feature maps of smaller dimension. We will look into that in a future article.)
With this bottleneck block, even the 152-layer ResNet has lower complexity than VGG16, despite an 8-fold increase in depth. Wow!!
Structure of ResNets
The original paper introduced 5 variants of ResNet with 18, 34, 50, 101 and 152 layers respectively. This table shows their architectures.
Other than the first two layers (the stride-2 convolution and the stride-2 max pooling), the first block of each later module (conv3_1, conv4_1, conv5_1) also has a stride of 2 to halve the spatial dimension. Every time the spatial dimension is halved, the number of channels doubles.
The structure of ResNets is straightforward, apart from one small complication: the skip connection of the first block in each module needs to be modified. The reason is that the doubling of channels happens at the first block, so the skip connection sits between two feature maps with different numbers of channels.
There are two off-the-shelf ways to address this issue. The first is to zero-pad the previous feature map before addition. The second is to use 1x1 filters yet again to expand the previous feature map’s depth. In the authors’ experiments, the second way gives better accuracy, and implementations of ResNet generally adopt this design.
Unlike all the other networks introduced in this series so far, ResNet only has one fully-connected layer to further reduce the network complexity. The final feature map is simply average pooled across spatial dimensions before going into this linear classifier.
Code
To build ResNet, let’s first create a helper class called Conv2dBn as before. This is because every convolution layer in ResNet is followed by batch normalization.
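A minimal sketch of this helper could look like the following (the exact signature here is my reconstruction and may differ from the original gist):

```python
import torch
import torch.nn as nn


class Conv2dBn(nn.Module):
    """A convolution immediately followed by batch normalization."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # bias=False: BatchNorm2d subtracts the per-channel mean anyway,
        # so a bias in the convolution would have no effect
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))
```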
The basic version of the residual block is coded as below. There are two things to note about this code.
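A sketch of such a basic block, building on the Conv2dBn helper above (the parameter names are my own reconstruction and may not match the original gist exactly):

```python
class Block(nn.Module):
    """Basic residual block: two 3x3 conv-bn layers plus a skip connection."""
    def __init__(self, in_planes, mid_planes, out_planes, stride=1):
        super().__init__()
        # for the basic block, mid_planes and out_planes are always equal
        self.conv1 = Conv2dBn(in_planes, mid_planes, 3, stride=stride, padding=1)
        self.conv2 = Conv2dBn(mid_planes, out_planes, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 projection when the shape of the skip path must change, identity otherwise
        if stride != 1 or in_planes != out_planes:
            self.res = Conv2dBn(in_planes, out_planes, 1, stride=stride)
        else:
            self.res = nn.Identity()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        # the skip connection is added before the final ReLU
        return self.relu(out + self.res(x))
```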
Firstly, notice that the bias parameter is set to False for the convolution inside Conv2dBn. The reason is that, with batch normalization, the mean of each channel is subtracted from the feature map, cancelling out any effect of a bias term in the preceding Conv2d. The bias term in BatchNorm2d() provides the offset instead.
Secondly, as mentioned above, 1x1 filters are used as self.res to match the input and output feature maps’ channel sizes when the stride is not 1; otherwise, an identity function is used.
The bottleneck block is coded below. self.squeeze and self.expand are 1x1 filters that reduce the channel size and then increase it back for the skip connection.
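A sketch of the bottleneck block, following the same conventions as the basic block above:

```python
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 squeeze, 3x3 convolution, 1x1 expand."""
    def __init__(self, in_planes, mid_planes, out_planes, stride=1):
        super().__init__()
        self.squeeze = Conv2dBn(in_planes, mid_planes, 1)                 # reduce channels
        self.conv = Conv2dBn(mid_planes, mid_planes, 3, stride=stride, padding=1)
        self.expand = Conv2dBn(mid_planes, out_planes, 1)                 # restore channels
        self.relu = nn.ReLU(inplace=True)
        # same shape-matching rule as in the basic block
        if stride != 1 or in_planes != out_planes:
            self.res = Conv2dBn(in_planes, out_planes, 1, stride=stride)
        else:
            self.res = nn.Identity()

    def forward(self, x):
        out = self.relu(self.squeeze(x))
        out = self.relu(self.conv(out))
        out = self.expand(out)
        return self.relu(out + self.res(x))
```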
Next, we create a ResNet template class that can help us build all 5 ResNets from a simple config dictionary. The code below shows the typical __init__() and forward() methods of such a class.
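One possible shape for this template, assuming the blocks defined above (the builder methods _build_modules() and _init_weights() are filled in below):

```python
class ResNet(nn.Module):
    def __init__(self, basic_block, num_filters, num_blocks, num_classes=1000):
        super().__init__()
        # stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling
        stem_out = num_filters if basic_block else num_filters[0]
        self.stem = nn.Sequential(
            Conv2dBn(3, stem_out, 7, stride=2, padding=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        # the four residual modules (conv2_x to conv5_x)
        self.stages = self._build_modules(basic_block, num_filters, num_blocks)
        # channels double across the 4 modules, so the last module outputs 8x the first
        final_channels = (num_filters if basic_block else num_filters[1]) * 8
        # a single fully-connected classifier after global average pooling
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(final_channels, num_classes)
        self._init_weights()

    def forward(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
```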
Two builder methods are shown below. The _build_module() method makes each of the 4 residual modules. The basic_block parameter indicates whether we want to use the vanilla Block or the Bottleneck, num_blocks is how many such blocks are in the module, planes refers to the numbers of channels of the input, intermediate and output feature maps in each block, and stride is the stride of the first block in the module.
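A sketch of _build_module(), continuing the ResNet class above; planes is assumed here to be an (input, intermediate, output) channel triple:

```python
    def _build_module(self, basic_block, num_blocks, planes, stride):
        """Stack num_blocks blocks; only the first block may change stride or channel count."""
        in_planes, mid_planes, out_planes = planes
        block_cls = Block if basic_block else Bottleneck
        blocks = [block_cls(in_planes, mid_planes, out_planes, stride=stride)]
        for _ in range(num_blocks - 1):
            blocks.append(block_cls(out_planes, mid_planes, out_planes, stride=1))
        return nn.Sequential(*blocks)
```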
The _build_modules() method computes the number of filters used for each module and calls _build_module() to build these modules.
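A matching sketch of _build_modules(), which doubles the channel counts from module to module:

```python
    def _build_modules(self, basic_block, num_filters, num_blocks):
        """Build the four residual modules (conv2_x to conv5_x)."""
        mid, out = (num_filters, num_filters) if basic_block else num_filters
        in_planes = mid  # the stem outputs this many channels in this sketch
        modules = []
        for i, blocks_in_module in enumerate(num_blocks):
            # conv2_x keeps the resolution left by the max pooling; later modules downsample
            stride = 1 if i == 0 else 2
            modules.append(self._build_module(
                basic_block, blocks_in_module, (in_planes, mid, out), stride))
            in_planes, mid, out = out, mid * 2, out * 2
        return nn.Sequential(*modules)
```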
The _init_weights() method is shown below. The only point of note is probably the mode for kaiming_normal_(). The official Torchvision model uses fan_out here. The reason for my choice of fan_in is the same as in the VGG article here.
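A sketch consistent with that choice (the BatchNorm2d and Linear initializations below are common defaults I have assumed, not necessarily those of the original code):

```python
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # fan_in preserves the variance of the forward signal;
                # the official Torchvision model uses fan_out instead
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
```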
We have finished our template/builder class now. Let’s transfer the networks’ structure from the table to a config dictionary.
“num_filters” refers to the number of convolution filters in the first module. For the Bottleneck, it has two numbers: one for the squeezing 1x1 / 3x3 filters and one for the expanding 1x1 filters. The numbers of filters for the other modules are calculated by the _build_modules() method. “num_blocks” refers to the number of blocks used in each of the 4 modules.
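Under the parametrization used in the sketches above, the config dictionary could look like this (the block counts come from the paper’s table; the exact keys and value formats are my reconstruction from the description above):

```python
resnet_configs = {
    "resnet18":  {"basic_block": True,  "num_filters": 64,        "num_blocks": [2, 2, 2, 2]},
    "resnet34":  {"basic_block": True,  "num_filters": 64,        "num_blocks": [3, 4, 6, 3]},
    "resnet50":  {"basic_block": False, "num_filters": (64, 256), "num_blocks": [3, 4, 6, 3]},
    "resnet101": {"basic_block": False, "num_filters": (64, 256), "num_blocks": [3, 4, 23, 3]},
    "resnet152": {"basic_block": False, "num_filters": (64, 256), "num_blocks": [3, 8, 36, 3]},
}
```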
Finally, we can build all 5 ResNets with the config dictionary and the builder class.
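For example, with the sketches above, the five networks could be instantiated and sanity-checked like this:

```python
resnets = {name: ResNet(**cfg) for name, cfg in resnet_configs.items()}

# quick shape check with a dummy ImageNet-sized batch
x = torch.randn(1, 3, 224, 224)
print(resnets["resnet50"](x).shape)  # expected: torch.Size([1, 1000])
```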
Conclusion
The two main concepts behind ResNet are the skip connection and the bottleneck block. The former allows us to train very deep networks, while the latter reduces the complexity of these networks dramatically. For various computer vision tasks, ResNets are often the default models for preliminary investigation. Many real-time applications use the likes of ResNet50 to form their feature extraction backbones. Deeper ResNets are usually used in competitions for better accuracy at the cost of lower speed.
Up to this article, we have covered four very famous models. These four models aim to pack as much depth/transformation as efficiently as possible to push the boundary of classification accuracy. However, there is another research direction that works to lower the complexity of models as much as possible without sacrificing too much accuracy.
In the following articles, we will introduce three such models: SqueezeNet, MobileNetV1 and MobileNetV2.