ResNet with TensorFlow (Transfer Learning)

mrgrhn
Jan 21

ResNet owes its name to its residual blocks with skip connections, which enable the model to be extremely deep. Even though skip connections are a common idea in the community now, they were a revolutionary architectural choice that allowed ResNet to reach 152 layers without suffering from vanishing or exploding gradients during training.

ResNet-34 architecture (illustration taken from the original paper [1])


With advances in hardware and the growing variety of design techniques in deep learning, deeper and deeper models became popular in the ImageNet competition. Unlike LeNet and AlexNet, VGG and GoogLeNet managed to train much deeper structures. However, training deeper networks requires some intuition about how gradients flow through the model, as well as heuristics for how to train it. As models got deeper, research groups had to push their imaginations and come up with more creative designs. After the Inception module[2] in GoogLeNet, another interesting breakthrough came with the residual learning mechanism of ResNet. Other notable improvements were 1x1 convolutions, dropout layers, and ReLU[3], yet none of them were as daring as these two.

While this was discussed in detail in earlier posts, the problem of vanishing or exploding gradients is worth revisiting. In gradient-based learning, the gradient of the final loss is computed with respect to every weight, that is, with respect to each coordinate of the weight space. “In machine learning, an artificial neural network is a model that consists of a directed graph, with weights (real numbers) on the edges of the graph. The parameter space is known as a weight space, and learning consists of updating the parameters, most often by gradient descent or some variant.”[4] In the final layer, the loss is partially differentiated with respect to each of the weights, and the weights are updated in the direction opposite to the gradient in order to decrease the loss. The step size of this update is controlled by the learning rate, which can be constant or adaptive. For earlier layers, the chain rule is applied: the gradient with respect to a layer's weights is the gradient flowing back from the following layer multiplied by that layer's input, and the gradient passed further back is multiplied by the layer's weight matrix (and by the derivative of its activation). As the gradient flows backward, it is multiplied by weight matrices over and over again, which may result in vanishing or exploding gradients. The risk grows with the number of layers, since the gradient traverses a longer path. Although exploding gradients can be controlled by batch normalization[6] and gradient clipping, vanishing gradients are a much bigger problem.
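To make the repeated-multiplication argument concrete, here is a minimal NumPy sketch. It assumes a plain stack of linear layers (no activations or biases, a deliberate simplification) with random Gaussian weight matrices, and it backpropagates a unit-norm gradient by multiplying it with each transposed weight matrix. The function name `backprop_gradient_norms` and the scale factors 0.5 and 1.5 are illustrative choices, not values from the post.

```python
import numpy as np


def backprop_gradient_norms(depth: int, weight_scale: float,
                            width: int = 64, seed: int = 0) -> list[float]:
    """Return the gradient norm after passing back through each of `depth` linear layers.

    Each layer is y = W x with W drawn from N(0, weight_scale**2 / width), so on
    average every layer rescales the gradient norm by roughly `weight_scale`:
    values below 1 make it vanish, values above 1 make it explode.
    """
    rng = np.random.default_rng(seed)
    grad = np.ones(width) / np.sqrt(width)  # unit-norm gradient arriving from the loss
    norms = []
    for _ in range(depth):
        w = rng.normal(0.0, weight_scale / np.sqrt(width), size=(width, width))
        grad = w.T @ grad  # chain rule for y = W x: dL/dx = W^T dL/dy
        norms.append(float(np.linalg.norm(grad)))
    return norms


if __name__ == "__main__":
    for scale in (0.5, 1.5):
        norms = backprop_gradient_norms(depth=50, weight_scale=scale)
        print(f"scale={scale}: after 10 layers {norms[9]:.2e}, after 50 layers {norms[-1]:.2e}")
```

With a scale of 0.5 the norm after 50 layers is many orders of magnitude below 1, and with 1.5 it is many orders above. Gradient clipping would cap the second case by rescaling `grad` whenever its norm exceeds a threshold, but it does nothing for the first.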

ResNet introduces bypass connections into the network, allowing the gradient to flow backward without being multiplied by weight matrices at every layer.

Skip connections (illustration taken from the original paper [1])

In a branched network structure, if a layer feeds into multiple modules, the gradients coming back from all of those modules are summed, and backpropagation continues with the chain rule. In the figure above, if the weight layers contain very small values (on the order of 10^-5 or smaller), the part of the gradient that passes through them could shrink to a millionth of the gradient in the upper layer, or even less. ResNet, however, provides an identity connection that skips some of the layers, so part of the gradient flows back without being subject to any multiplication. This has become a very common architectural choice in recent networks with hundreds of layers.
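The effect of the identity path can be checked numerically with a small NumPy sketch. It simplifies the skip-connection block to a single linear map per layer, without ReLU or batch normalization, and the weight scale of 0.01 and depth of 50 are arbitrary illustrative values; `input_gradient_norm` is a hypothetical helper. For a residual layer y = x + Wx, the chain rule gives dL/dx = dL/dy + W^T dL/dy, so the upstream gradient is always passed back unchanged alongside the possibly tiny weight-layer term.

```python
import numpy as np


def input_gradient_norm(depth: int, weight_scale: float, residual: bool,
                        width: int = 64, seed: int = 0) -> float:
    """Norm of dL/dx after backpropagating a unit gradient through `depth` linear layers.

    Plain layer:    y = W x      ->  dL/dx = W^T dL/dy
    Residual layer: y = x + W x  ->  dL/dx = dL/dy + W^T dL/dy
    The identity path passes the upstream gradient back without any multiplication.
    """
    rng = np.random.default_rng(seed)
    grad = np.ones(width) / np.sqrt(width)  # unit-norm gradient from the layer above
    for _ in range(depth):
        w = rng.normal(0.0, weight_scale / np.sqrt(width), size=(width, width))
        through_weights = w.T @ grad  # contribution of the weight layer
        grad = grad + through_weights if residual else through_weights
    return float(np.linalg.norm(grad))


if __name__ == "__main__":
    for residual in (False, True):
        norm = input_gradient_norm(depth=50, weight_scale=0.01, residual=residual)
        print(f"residual={residual}: input gradient norm = {norm:.2e}")
```

With these settings the plain stack shrinks the gradient by roughly a factor of 100 per layer, while the residual stack keeps its norm close to 1, because the identity term dominates when the weights are small.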

ResNet with TensorFlow

Even though skip connections make it possible to train extremely deep networks, training them is still a tedious process that requires a huge amount of data. As covered in the VGGNet post, trying to train these kinds of networks from scratch on MNIST may not lead to convergence or acceptable accuracy. ResNet was originally trained on the ImageNet dataset, and using transfer learning[7], it is possible to load the pretrained convolutional weights and train a classifier on top of them.

First, the required libraries are imported.
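The post's own import list is not reproduced in the text. Assuming TensorFlow 2.x with its bundled Keras API, a minimal set covering the steps in this post might be:

```python
import numpy as np            # array manipulation for padding and reshaping MNIST
import tensorflow as tf       # core library
from tensorflow import keras  # high-level API: datasets, applications, layers

print("TensorFlow", tf.__version__)
```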

The Data

Then, the data is loaded as in the LeNet implementation. One important point is that the original ResNet model receives images of size 224 x 224 x 3, whereas MNIST images are 28 x 28. The images are padded with zeros to 32 x 32 (the smallest input size Keras accepts for its ResNet models), and a channel axis is added and repeated 3 times, giving images of size 32 x 32 x 3. When loading the model from Keras, it is possible to specify the input shape, which will be 32 x 32 x 3 in our case instead of 224 x 224 x 3.
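The preprocessing step can be sketched with NumPy as below. `mnist_to_rgb32` and `load_mnist_rgb32` are hypothetical helper names, and casting to float32 is an assumption. The post does not say whether `tf.keras.applications.resnet.preprocess_input` was applied to match the ImageNet weights, so it is left out here.

```python
import numpy as np


def mnist_to_rgb32(images: np.ndarray) -> np.ndarray:
    """Turn a batch of 28x28 grayscale images into 32x32x3 arrays.

    2 zero pixels are padded on every side (28 + 2 + 2 = 32), then a channel
    axis is added and repeated 3 times to mimic an RGB image.
    """
    padded = np.pad(images, ((0, 0), (2, 2), (2, 2)), mode="constant", constant_values=0)
    with_channel = np.expand_dims(padded, axis=-1)  # (N, 32, 32, 1)
    return np.repeat(with_channel, 3, axis=-1)      # (N, 32, 32, 3)


def load_mnist_rgb32():
    """Load MNIST through Keras (downloads on first use) and apply the conversion above."""
    import tensorflow as tf

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = mnist_to_rgb32(x_train).astype("float32")
    x_test = mnist_to_rgb32(x_test).astype("float32")
    return (x_train, y_train), (x_test, y_test)
```

Calling `load_mnist_rgb32()` yields training images of shape (60000, 32, 32, 3) and test images of shape (10000, 32, 32, 3). The labels stay as integers 0 to 9, which is the form the sparse categorical cross-entropy loss expects.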

The Model

The ResNet model consists of lots and lots of convolutional layers. Most of them use 3x3 kernels, and deeper variants such as ResNet-152 also use 1x1 convolutions in their bottleneck blocks; the only exception is the first layer, which has 7x7 kernels. There are a few variations of the model, but ResNet-152 was the model that won ILSVRC in 2015, and it is the one implemented in this post. Yes, you guessed it right: 152 is the number of layers the model has. The network is eight times deeper than VGGNet (152 versus 19 layers) while still having lower complexity.[5] The reason is that the 1x1 bottleneck convolutions keep the 3x3 layers thin, and that the network has a single fully connected layer with 1000 neurons instead of two fully connected layers with 4096 neurons. In many CNNs, most of the parameters come from the fully connected layers, since convolutional layers share their weights across spatial positions.
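The weight-sharing argument can be checked with a little arithmetic. The sketch below counts parameters (weights plus biases) with the standard formulas; the layer sizes are those of the original VGG-16 classifier (a 7x7x512 feature map followed by 4096, 4096 and 1000 fully connected units) and of ResNet's final layer (2048 pooled features to 1000 classes). `dense_params` and `conv_params` are illustrative helpers.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer: every input connects to every output."""
    return n_in * n_out + n_out


def conv_params(kernel: int, c_in: int, c_out: int) -> int:
    """Weights plus biases of a 2D convolution; the same kernel is reused at every spatial position."""
    return kernel * kernel * c_in * c_out + c_out


# VGG-16 classifier: 7x7x512 feature map -> 4096 -> 4096 -> 1000
vgg_fc = dense_params(7 * 7 * 512, 4096) + dense_params(4096, 4096) + dense_params(4096, 1000)

# ResNet: global average pooling leaves 2048 features -> 1000 classes
resnet_fc = dense_params(2048, 1000)

# One of VGG's widest 3x3 convolutions, for comparison
vgg_conv = conv_params(3, 512, 512)

print(f"VGG-16 fully connected layers: {vgg_fc:,}")         # 123,642,856
print(f"ResNet final fully connected layer: {resnet_fc:,}")  # 2,049,000
print(f"A 3x3 conv, 512 -> 512 channels: {vgg_conv:,}")     # 2,359,808
```

Roughly 124 million of VGG-16's approximately 138 million parameters sit in its three fully connected layers, while even a wide 3x3 convolution needs only about 2.4 million.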

Apart from max-pooling at the beginning and global average pooling at the end, the design is pretty monotonous: only convolutional layers (each followed by batch normalization) and skip connections.
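To give a feel for the unit that is repeated throughout the network, below is a sketch of a bottleneck residual block written with Keras layers. It follows the general 1x1, 3x3, 1x1 bottleneck design from the ResNet paper [1], but it is not the Keras `ResNet152` source. `bottleneck_block` is a hypothetical helper, and placing the stride on the first 1x1 convolution is one of several common variants.

```python
import tensorflow as tf
from tensorflow.keras import layers


def bottleneck_block(x, filters: int, stride: int = 1):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

    The output has 4 * filters channels. When the shape changes (stride > 1 or a
    different channel count), the shortcut is a 1x1 projection convolution;
    otherwise it is the identity.
    """
    shortcut = x
    if stride != 1 or x.shape[-1] != 4 * filters:
        shortcut = layers.Conv2D(4 * filters, 1, strides=stride)(x)
        shortcut = layers.BatchNormalization()(shortcut)

    y = layers.Conv2D(filters, 1, strides=stride)(x)  # 1x1: reduce channels
    y = layers.BatchNormalization()(y)
    y = layers.Activation("relu")(y)
    y = layers.Conv2D(filters, 3, padding="same")(y)  # 3x3: spatial features on thin tensors
    y = layers.BatchNormalization()(y)
    y = layers.Activation("relu")(y)
    y = layers.Conv2D(4 * filters, 1)(y)              # 1x1: restore channels
    y = layers.BatchNormalization()(y)

    y = layers.Add()([shortcut, y])                   # F(x) + x
    return layers.Activation("relu")(y)
```

Stacking this block 50 times (3, 8, 36 and 3 blocks in the four stages), with a 7x7 convolution and max-pooling in front and global average pooling plus one fully connected layer at the end, gives the 152 layers of ResNet-152.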

A comparison of layer depths

As you can see in the visuals above, ResNet-152 is absurdly deep, so it is usually a good idea to load the model using Keras or another deep learning library. We do not include the top (the fully connected classifier), because that is what we want to train ourselves. We only load the convolutional weights trained on ImageNet. After loading the model, its layers are set to “not trainable”, that is, frozen.
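Loading and freezing the backbone can be sketched as follows. `tf.keras.applications.ResNet152` and its `include_top`, `weights` and `input_shape` arguments are the real Keras API; the wrapper `load_frozen_resnet152` is a hypothetical helper. Freezing is done layer by layer to mirror the description; setting `base.trainable = False` would have the same effect.

```python
import tensorflow as tf


def load_frozen_resnet152(input_shape=(32, 32, 3), weights="imagenet") -> tf.keras.Model:
    """Load ResNet-152 without its classifier top and freeze every layer."""
    base = tf.keras.applications.ResNet152(
        include_top=False,        # drop the 1000-class ImageNet classifier
        weights=weights,          # "imagenet" downloads the pretrained convolutional weights
        input_shape=input_shape,  # 32x32x3 instead of the default 224x224x3
    )
    for layer in base.layers:
        layer.trainable = False   # frozen: the optimizer will not update these weights
    return base
```

For a 32 x 32 x 3 input, the five stride-2 downsampling steps reduce the spatial size to 1 x 1, so the backbone outputs a feature tensor of shape (None, 1, 1, 2048).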

The top is added as follows:

The computational graph begins with the inputs of the base model and ends with a vector of size 10 holding the probability of each MNIST category. The model is compiled to be trained with the Adam optimizer, which adapts the learning rate, and the sparse categorical cross-entropy loss.
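The post's original code for the top is not reproduced in the text. A minimal sketch consistent with the description might look like this; `add_classifier_top` is a hypothetical helper, and both the `Flatten` layer and the hidden width of 256 are assumptions, since the post only states that a single hidden fully connected layer was used.

```python
import tensorflow as tf


def add_classifier_top(base: tf.keras.Model, num_classes: int = 10,
                       hidden_units: int = 256) -> tf.keras.Model:
    """Stack a small classifier on a (frozen) backbone and compile it.

    hidden_units=256 is a hypothetical choice, not taken from the post.
    """
    x = tf.keras.layers.Flatten()(base.outputs[0])                 # (1, 1, 2048) -> 2048 for ResNet-152 at 32x32
    x = tf.keras.layers.Dense(hidden_units, activation="relu")(x)  # the single hidden layer
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)  # class probabilities
    model = tf.keras.Model(inputs=base.inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),    # adaptive learning rate
        loss="sparse_categorical_crossentropy",  # integer labels 0..9, no one-hot encoding needed
        metrics=["accuracy"],
    )
    return model
```

Using sparse categorical cross-entropy lets the integer labels from `mnist.load_data()` be passed directly, without converting them to one-hot vectors.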

The model is trained with the backbone layers frozen, so only the fully connected layers added afterward are trained.
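One way to confirm that training really leaves the backbone untouched is to snapshot its weights before `model.fit` and compare them afterward. `fit_with_frozen_backbone` is a hypothetical helper; any extra keyword arguments are passed straight to Keras' `fit`.

```python
import numpy as np
import tensorflow as tf


def fit_with_frozen_backbone(model: tf.keras.Model, backbone_layers, x, y, **fit_kwargs):
    """Train `model` and report whether the frozen backbone layers kept their weights.

    Only layers with trainable=True (the newly added top) receive gradient updates,
    so the snapshot taken before training should match the weights afterward.
    """
    before = [w.copy() for layer in backbone_layers for w in layer.get_weights()]
    history = model.fit(x, y, **fit_kwargs)
    after = [w for layer in backbone_layers for w in layer.get_weights()]
    unchanged = len(before) == len(after) and all(
        np.array_equal(b, a) for b, a in zip(before, after)
    )
    return history, unchanged
```

In the setting of this post, a call might look like `fit_with_frozen_backbone(model, base.layers, x_train, y_train, epochs=40, validation_split=0.1)`, where the validation split is an assumption. In TensorFlow 2.x, a `BatchNormalization` layer with `trainable = False` also runs in inference mode, so its moving statistics stay fixed as well.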

Even after two epochs, validation accuracy reaches nearly 90%. After 40 epochs, the model comfortably converges. It is possible to reach higher accuracy by adding a couple more fully connected layers. The success with only one hidden fully connected layer suggests that ResNet-152 does a pretty good job of extracting features for the classifier, even though ImageNet and MNIST contain fairly different kinds of images.


Test accuracy came out at 92.75%.

ResNet-152 made a valuable contribution to the literature by being the first model to employ residual learning principles. While many other research groups were looking for ways to train deeper models, ResNet managed to do it by adding skip connections to the architecture, and it swept several competitions in 2015, including ILSVRC. Nowadays, skip connections are used not only in CNNs but in many other types of networks. Moreover, thanks to its impressive learning capacity, ResNet has become a preferred feature extraction module in many tasks, ranging from object recognition to autoencoders. In the ImageNet competition, with a 3.6% top-5 error rate, ResNet even surpassed human performance.

Hope you enjoyed it. See you in the following Deep Learning articles.

Best wishes…

mrgrhn

  1. He, Kaiming & Zhang, Xiangyu & Ren, Shaoqing & Sun, Jian. (2016). “Deep Residual Learning for Image Recognition”. 770–778. 10.1109/CVPR.2016.90.
  2. Szegedy, Christian & Liu, Wei & Jia, Yangqing & Sermanet, Pierre & Reed, Scott & Anguelov, Dragomir & Erhan, Dumitru & Vanhoucke, Vincent & Rabinovich, Andrew. (2014). “Going Deeper with Convolutions”.
  3. Krizhevsky, Alex & Sutskever, Ilya & Hinton, Geoffrey. (2012). “ImageNet Classification with Deep Convolutional Neural Networks”. Neural Information Processing Systems. 25. 10.1145/3065386.
  4. https://www.wikiwand.com/en/Parameter_space
  5. https://web.cs.hacettepe.edu.tr/~aykut/classes/spring2016/bil722/slides/w04-ResNet.pdf
  6. Ioffe, Sergey & Szegedy, Christian. (2015). “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift”.
  7. Bozinovski, Stevo & Fulgosi, Ante (1976). “The Influence of Pattern Similarity and Transfer Learning upon Training of a Base Perceptron B2.” (original in Croatian) Proceedings of Symposium Informatica 3–121–5, Bled.

The Startup

Medium's largest active publication, followed by +771K people. Follow to join our community.

mrgrhn

Written by

mrgrhn

Boğaziçi Üniversitesi ’20 Electrical & Electronics Engineering — Physics | Articles on various Deep Learning topics
