Published in The Startup

Deep Learning for Image Classification: Creating a CNN From Scratch Using PyTorch

Introduction

This article explains the general architecture of a Convolutional Neural Network (CNN) and shows how to classify images into different categories (different types of animals, in our case) by writing a CNN model from scratch using PyTorch.

Prerequisites

  • Python
  • Basic understanding of Neural Networks
  • Basic understanding of Convolutional Neural Networks (CNNs)

Let's Code

Step 1: Downloading the Dataset

  • Download the dataset from this Kaggle link and extract the zip.
  • Alternatively, you can clone the dataset and the project files from this GitHub link.
  • The dataset contains about 28,000 images belonging to 10 categories: dog, cat, horse, spider, butterfly, chicken, sheep, cow, squirrel and elephant.
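
Before building anything, it can help to check how many images each category actually contains. The sketch below assumes the extracted zip has one sub-folder per class directly under a root directory; the helper name `count_images_per_class` and the set of accepted file extensions are hypothetical choices, not taken from the original project.

```python
from pathlib import Path

# File extensions treated as images (an assumption; adjust to your copy of the dataset).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root: str) -> dict[str, int]:
    """Return {class_folder_name: number_of_image_files} for every sub-folder of root."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            # Count only regular files whose extension looks like an image.
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts
```

For example, `count_images_per_class("data/animals")` would map each folder name to its image count, where `data/animals` stands in for wherever you extracted the dataset. A strongly unbalanced count is worth knowing about before reading too much into overall accuracy.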

Step 2: Creating Datasets & Data Loaders to Load the Images

Step 3: Creating the CNN Model Architecture

Let's create a simple CNN architecture.

Like most standard CNN architectures, our model has two components:

  1. A set of convolutions, each followed by a non-linearity (ReLU in our case) and a max-pooling layer
  2. A linear classification layer that classifies an image into one of the 10 animal categories

[Figure: CNN Model Architecture]

  • The model contains around 2.23 million parameters.
  • As we go deeper through the convolution layers, the number of channels increases from 3 (for RGB images) to 16, 32, 64, 128 and then 256.
  • A ReLU layer provides a non-linearity after each convolution operation.
  • As the number of channels increases, the height and width of the feature maps decrease because of the max-pooling layers.
  • We added Dropout to the classification layer to reduce overfitting.
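
The exact layer configuration is not shown in the text, so the following is only a sketch that follows the description above: five convolution blocks (3x3 convolution, ReLU, 2x2 max-pooling) that take the channels from 3 to 256, then a classifier with Dropout. The 224x224 input size, the 128-unit hidden layer and the 0.5 dropout rate are assumptions, and `SimpleCNN`, `conv_block` and `count_parameters` are illustrative names.

```python
import torch
from torch import nn


def conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
    """3x3 convolution -> ReLU -> 2x2 max-pooling (the pooling halves height and width)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )


class SimpleCNN(nn.Module):
    def __init__(self, num_classes: int = 10, image_size: int = 224,
                 hidden_units: int = 128, dropout: float = 0.5):
        super().__init__()
        # Channels grow 3 -> 16 -> 32 -> 64 -> 128 -> 256, as described above.
        channels = [3, 16, 32, 64, 128, 256]
        self.features = nn.Sequential(
            *[conv_block(c_in, c_out) for c_in, c_out in zip(channels, channels[1:])]
        )
        # Five 2x2 poolings shrink each spatial side by a factor of 2**5 = 32.
        spatial = image_size // 32
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(256 * spatial * spatial, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, num_classes),  # raw logits, one per class
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


def count_parameters(model: nn.Module) -> int:
    """Number of trainable parameters in the model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

With these assumed sizes the sketch has 1,999,658 trainable parameters (about 2.0M), close to but not identical to the 2.23M reported for the article's model. A 224x224 input shrinks to 112, 56, 28, 14 and finally 7x7 with 256 channels before it is flattened.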

Step 4: Defining the Model, Optimizer and Loss Function

We use the Adam optimizer with a learning rate of 0.0001, along with cross-entropy loss.
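
In code, this step amounts to a couple of lines. The sketch below wraps them in a hypothetical helper, `build_optimizer_and_loss`, and also picks a GPU when one is available.

```python
import torch
from torch import nn


def build_optimizer_and_loss(model: nn.Module, lr: float = 1e-4):
    """Adam with learning rate 0.0001 plus cross-entropy loss, as described above."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # CrossEntropyLoss expects raw logits of shape (batch, classes) and integer labels.
    criterion = nn.CrossEntropyLoss()
    return optimizer, criterion


# Use the GPU when one is available (the article trained on a Colab Tesla T4).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

`nn.CrossEntropyLoss` applies log-softmax internally, which is why the model should output raw logits rather than softmax probabilities. The PyTorch `torch.optim` documentation also recommends moving the model to the GPU before constructing its optimizer.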

Step 5: Start Training

Finally, the moment we have all been waiting for has arrived: training the model.

For training and testing, I created two helper functions.
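
The article's helper functions are not included in this text. A minimal sketch of what such a pair might look like is shown below; `train_one_epoch` and `evaluate` are hypothetical names, and both return the mean loss and the accuracy over a data loader.

```python
import torch


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one pass over the training data; return (mean loss, accuracy)."""
    model.train()  # enables Dropout
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Accumulate sums so the averages are weighted by batch size.
        total_loss += loss.item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Compute (mean loss, accuracy) on a loader without updating any weights."""
    model.eval()  # disables Dropout
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen
```

Calling `model.train()` and `model.eval()` matters here because the classifier uses Dropout, which should only be active during training.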

Now let's start the training.

Thanks to the helper functions above, starting the training process takes only a few lines of code.

We train the model for 50 epochs and save it to disk after every 10th epoch.
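
A sketch of such a loop is shown below. To keep it independent of any particular helper implementation, the hypothetical `fit` function takes two callables, `train_fn` and `eval_fn`, that each run one epoch and return `(mean_loss, accuracy)`; the checkpoint directory and file names are also assumptions.

```python
import os

import torch


def fit(model, epochs, train_fn, eval_fn, checkpoint_dir="checkpoints", save_every=10):
    """Train for `epochs` epochs, saving the model weights every `save_every` epochs.

    train_fn(model) and eval_fn(model) are hypothetical callables that each run one
    pass over the data and return (mean_loss, accuracy).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    history = {"train_loss": [], "test_loss": [], "test_acc": []}
    for epoch in range(1, epochs + 1):
        train_loss, _ = train_fn(model)
        test_loss, test_acc = eval_fn(model)
        history["train_loss"].append(train_loss)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)
        print(f"epoch {epoch:3d} | train loss {train_loss:.4f} | "
              f"test loss {test_loss:.4f} | test acc {test_acc:.2%}")
        if epoch % save_every == 0:
            # Saving the state_dict (weights only) is the usual PyTorch convention.
            path = os.path.join(checkpoint_dir, f"model_epoch_{epoch}.pt")
            torch.save(model.state_dict(), path)
    return history
```

For the run described here, it would be called with `epochs=50` and `save_every=10`, producing five checkpoint files. The returned `history` dictionary holds the per-epoch losses needed for a loss plot.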

Here is what the training output shows:

  • Training took around 2 hours for 50 epochs on Google Colab with a Tesla T4 GPU runtime.
  • The accuracy went up from 21% after the 1st epoch to 75% after the 50th epoch (after training for another 50 epochs, it reached 78%).
  • This is quite good, considering that our CNN is very basic and has only 2.23M parameters.

Evaluating the Model

Here is the plot of our Training & Testing Loss

  • After around the 20th epoch, we can see noticeable variance in the curve.
  • We'll look at how to improve on this in the next section, but so far everything looks good.
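
A plot like this can be produced with matplotlib from the per-epoch losses collected during training. The `plot_losses` helper and its default output file name are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render to a file; no display needed (e.g. on a server)
import matplotlib.pyplot as plt


def plot_losses(train_losses, test_losses, path="loss_curve.png"):
    """Plot per-epoch training and testing loss on one set of axes and save it to `path`."""
    epochs = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epochs, train_losses, label="Training loss")
    ax.plot(epochs, test_losses, label="Testing loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Cross-entropy loss")
    ax.set_title("Training & Testing Loss")
    ax.legend()
    fig.savefig(path, dpi=100, bbox_inches="tight")
    plt.close(fig)  # free the figure's memory
    return path
```

In a notebook such as Colab, you would typically leave out the `matplotlib.use("Agg")` line and call `plt.show()` to display the figure inline.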

Finally, let's test the model on some random images.


Congratulations on successfully training the model, and thanks for sticking around until the end.

Please share your views or questions in the comments section.

Abhishek
