Convolutional Neural Network in PyTorch

In this article, I will explain how CNNs work and implement a slightly modified LeNet-5 model using PyTorch. These are my notes about Convolutional Neural Networks, summarized in an accessible way to help you understand the topic.

Photo by Matt Noble on Unsplash

Why are we using Convolutional Neural Networks?

CNNs are similar to regular Neural Networks: they have weights and biases, but they are specifically designed for computer vision tasks like object recognition or image segmentation. Since 2012, they have been the driving force of the deep learning revolution because of their ability to utilize GPU power and handle extensive data.

Alright, let’s see what the building blocks of ConvNets are and how they work!

Architecture

LeNet-5 architecture, showing the overall structure of a ConvNet [3]

The feature extractor consists of:

  • Convolution layers
  • Pooling layers (“Subsampling”)

The classification block uses a Fully-connected layer (“Full connection”) to give the final prediction.

After every Convolution or Fully-connected layer, we add a non-linear activation function (except after the last layer, where we use Softmax). Currently, ReLU is the go-to function because it is fast to compute and gives the best results in practice.

I will briefly introduce each layer’s purpose and properties, and then we can jump straight into the coding.

Let’s start with the convolution layer!

In a nutshell, convolution is a filter sliding along the image, computing the sum of element-wise multiplications at each position. So, what is the filter, then?

You have probably used the sepia effect, which adds a warm brown tone to a picture. In fact, it is a 3x3 window with predefined values that is moved around the picture.

Convolution operation [1]

In a CNN, we want to learn these values to extract relevant features. The learning process uses the backpropagation algorithm, the same as in regular Neural Networks.
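To make the sliding-filter idea concrete, here is a minimal sketch using PyTorch’s functional API. It assumes a hand-picked 3x3 edge-detection kernel rather than a learned one, and a random dummy image:

```python
import torch
import torch.nn.functional as F

# A hand-picked 3x3 edge-detection kernel (illustrative values, not learned)
kernel = torch.tensor([[-1., -1., -1.],
                       [-1.,  8., -1.],
                       [-1., -1., -1.]]).reshape(1, 1, 3, 3)

# A dummy grayscale image: (batch, channels, height, width)
image = torch.rand(1, 1, 28, 28)

# Slide the filter over the image, summing the element-wise products at each position
feature_map = F.conv2d(image, kernel)
print(feature_map.shape)  # torch.Size([1, 1, 26, 26])
```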

The convolution layer has four hyperparameters that determine the size of the output:

  • Filter size — the standard choices are 3x3 and 5x5, where empirically 3x3 yields the best accuracy
  • Depth — the number of filters in the output, usually a power of 2 for computational reasons. As the network gets deeper, each layer looks for a broader scope of features and more abstract connections, so we increase the number of filters with each layer, e.g. [64, 128, 256, …].
  • Strides — controls the movement of the filter. If the stride is 1, we move the filter by one pixel; if it is 2, we move it by two pixels. A larger stride reduces the size of the output, thus speeding up computation.

To check whether a stride value is usable, the result of this equation has to be an integer:

(W - F + 2P) / S

where W is the input size, F the filter size, P the padding (explained below), and S the stride.

  • Padding — another way to control the spatial output size, this time by adding zeros around the border. Zero-padding (“same”) preserves the initial size; “valid” means no padding at all. Padding is beneficial in deeper networks, where the volume would otherwise shrink too quickly.

The formula to calculate the output size:

O = (W - F + 2P) / S + 1
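As a quick sanity check, the first convolution of LeNet-5 takes a 32x32 input with a 5x5 filter, stride 1, and no padding, so O = (32 - 5 + 0)/1 + 1 = 28. A short sketch confirming this in PyTorch:

```python
import torch
import torch.nn as nn

# First convolution of LeNet-5: W=32, F=5, P=0, S=1 -> O = (32 - 5 + 0)/1 + 1 = 28
conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0)
x = torch.rand(1, 1, 32, 32)  # (batch, channels, H, W)
print(conv(x).shape)          # torch.Size([1, 6, 28, 28])
```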

Pooling layer

Pooling is another downsampling operation, which gives the network some amount of translation invariance. There are two variants of pooling:

  • Max Pooling — takes the maximum value of each window, so it extracts the sharpest and brightest features
  • Average Pooling — takes the average of each window, so it extracts smooth features

These days, pooling operations are mostly replaced by strided convolutions. [2]
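A minimal sketch comparing both pooling variants on the same (random) input:

```python
import torch
import torch.nn as nn

x = torch.rand(1, 6, 28, 28)  # e.g. feature maps from a convolution layer

max_pool = nn.MaxPool2d(kernel_size=2, stride=2)  # keeps the strongest activation per window
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)  # averages each window -> smoother features

print(max_pool(x).shape)  # torch.Size([1, 6, 14, 14])
print(avg_pool(x).shape)  # torch.Size([1, 6, 14, 14])
```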

Fully-Connected layer

They are used to learn the connections between the features extracted by different filters and to output the probabilities of the class predictions.
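In PyTorch this amounts to flattening the feature maps and passing them through nn.Linear; a small sketch (the 16x5x5 input shape is just an example matching LeNet-5’s last pooling output):

```python
import torch
import torch.nn as nn

feature_maps = torch.rand(1, 16, 5, 5)           # e.g. output of the last pooling layer
flat = torch.flatten(feature_maps, start_dim=1)  # shape: (1, 400)

fc = nn.Linear(16 * 5 * 5, 10)                   # one score (logit) per class
logits = fc(flat)
probs = torch.softmax(logits, dim=1)             # class probabilities
```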

Output calculations and hyperparameters

Coding

Finally, after a bit of theory, we are ready to do some programming. If you want to see the complete code, check out my GitHub.

LeNet model
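The original code embed may not render here, so below is a minimal sketch of what such a slightly modified LeNet-5 can look like. My assumptions: ReLU activations and max pooling instead of the original paper’s tanh and average pooling; the exact model lives in the linked GitHub repo:

```python
import torch
import torch.nn as nn

class LeNet5(nn.Module):
    """Slightly modified LeNet-5: ReLU activations and max pooling
    (assumptions; the original paper used tanh and average pooling)."""
    def __init__(self, n_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),   # 32x32 -> 28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 28x28 -> 14x14
            nn.Conv2d(6, 16, kernel_size=5),  # 14x14 -> 10x10
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 10x10 -> 5x5
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, n_classes),         # logits; Softmax is applied inside the loss
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
```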

Training function
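A sketch of a typical training loop (the exact code is in the linked repo; the name train_one_epoch is my own):

```python
import torch
import torch.nn as nn

def train_one_epoch(model, loader, optimizer, device):
    """Run one pass over the training data and return the average loss."""
    criterion = nn.CrossEntropyLoss()  # combines LogSoftmax and NLLLoss
    model.train()
    total_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()   # backpropagation
        optimizer.step()  # update weights and biases
        total_loss += loss.item()
    return total_loss / len(loader)
```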

Dataset

I used the well-known MNIST dataset, which fits perfectly for such a simple model.
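Loading MNIST with torchvision might look like this (a sketch; I resize the 28x28 images to 32x32 so they match LeNet-5’s expected input):

```python
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([
    T.Resize((32, 32)),  # LeNet-5 expects 32x32 inputs; MNIST images are 28x28
    T.ToTensor(),
])

train_set = torchvision.datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(root="data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
```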

Training and testing
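Putting it together, a sketch of the training-and-evaluation driver, reusing the names from the sketches above (the optimizer, learning rate, and epoch count are illustrative assumptions):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = LeNet5().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f"epoch {epoch + 1}: loss {loss:.4f}")

# Evaluate accuracy on the test set
model.eval()
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images.to(device)).argmax(dim=1)
        correct += (preds == labels.to(device)).sum().item()
print(f"test accuracy: {correct / len(test_set):.4f}")
```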

Results

After 2 minutes of training, the model achieved around 98.5% accuracy. I also plotted the feature maps to see how the filters extract features. If you are interested in how feature maps look for different inputs, I recommend this website.

Conclusions

Convolutional Neural Networks are the bread and butter of Deep Learning and play a massive role in many domains like object recognition, image segmentation, or medical imaging. So it is essential to know how they work in order to optimize and upgrade them. I hope that after reading this article you are able to build your own ConvNet in PyTorch!

If you want to see my other projects, check out my Medium and GitHub profiles.

References

[1] V. Dumoulin and F. Visin, “A guide to convolution arithmetic for deep learning” (2016)

[2] J. T. Springenberg, A. Dosovitskiy, T. Brox, and M. Riedmiller, “Striving for Simplicity: The All Convolutional Net” (2014)

[3] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner, “Gradient-Based Learning Applied to Document Recognition” (1998)
