Deep Learning Image Classification Using PyTorch & Fastai v2 on Colab

Sumit Redekar
5 min read · Oct 12, 2021


A step-by-step guide to training our own cat breed image classifier and then deploying it with Streamlit.

Introduction

Fastai is a deep learning library built on top of PyTorch, developed by Jeremy Howard and Rachel Thomas. Their goal is to democratize deep learning, in part through a massive open online course (MOOC) named "Practical Deep Learning for Coders", whose only prerequisite is knowledge of the Python programming language.

So in this article, we are going to develop a cat breed classifier in just a few lines of code with fastai.

Installing fastai

We are training a deep learning image classifier, which needs a GPU, so we will write our code in Google Colab.

Let’s first install and import fastbook, the companion package to the fastai book (installing it also installs fastai):

!pip install fastbook
import fastbook

Next, we import everything from fastbook, which also pulls in fastai's vision functionality, along with fastai's vision widgets:

from fastbook import *
from fastai.vision.widgets import *

Now that fastai is installed, we can gather data for training.

Data Gathering through DuckDuckGo API

As mentioned earlier, fastai keeps things simple: the fastbook package even provides helper functions for gathering images through DuckDuckGo and Bing (search_images_ddg and search_images_bing). For now, we will use DuckDuckGo to download our cat images.

cat_types = ('Asian', 'Australian Mist', 'Bengal', 'British Longhair',
             'Cyprus', 'Bombay', 'Japanese Bobtail', 'Russian Blue',
             'Selkirk Rex', 'Turkish Vankedisi')
path = Path('cats2')
if not path.exists():
    path.mkdir()
    for o in cat_types:
        dest = (path/o)
        dest.mkdir(exist_ok=True)
        # Search DuckDuckGo for each breed and download the results into its own folder
        urls = search_images_ddg(f'{o} cat')
        download_images(dest, urls=urls)

In the first statement, we define the cat breeds whose images we want to download.

Next, we create a directory to hold the downloaded images, with one subfolder per breed.

search_images_ddg() returns a list of image URLs for each search query, and download_images() saves those images into the breed's folder.

In total, this downloads around 2,000 images.
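To check that the download produced a reasonably balanced dataset, here is a small stdlib-only sketch that counts image files per breed folder. The helper count_images_per_class and the IMAGE_EXTS set are hypothetical (not part of fastai or fastbook), and the list of extensions is an assumption.

```python
from collections import Counter
from pathlib import Path

# Assumed set of image extensions; adjust to whatever the download produced
IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}


def count_images_per_class(root):
    """Count image files per breed, where the breed is the file's parent folder name."""
    counts = Counter()
    for p in Path(root).rglob('*'):
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS:
            counts[p.parent.name] += 1
    return counts
```

Calling count_images_per_class('cats2') after the download loop would give the number of images per breed; a breed with far fewer images than the others may need a different search query.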

After downloading, we collect the image file paths with get_image_files():

fns = get_image_files(path)
fns

Next, we check for broken or unreadable images with verify_images():

failed = verify_images(fns)
failed

verify_images() returns the files that could not be opened, and we delete them with Path.unlink:

failed.map(Path.unlink);
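fastai's verify_images tries to open each file with PIL, which also catches truncated downloads. As a rough, dependency-free illustration of why some downloads fail (for example, an HTML error page saved under a .jpg name), the sketch below only checks each file's leading signature bytes. The helpers has_image_signature and find_suspect_files are hypothetical, and this is a much weaker test than verify_images.

```python
from pathlib import Path

# File signatures ("magic numbers") for the formats we expect to download
SIGNATURES = {
    'jpeg': b'\xff\xd8\xff',
    'png': b'\x89PNG\r\n\x1a\n',
    'gif': b'GIF8',
}


def has_image_signature(path):
    """Return True if the file starts with a known JPEG, PNG or GIF signature."""
    with open(path, 'rb') as f:
        head = f.read(8)
    return any(head.startswith(sig) for sig in SIGNATURES.values())


def find_suspect_files(paths):
    """Return the paths whose first bytes do not look like a supported image."""
    return [p for p in map(Path, paths) if not has_image_signature(p)]
```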

Now we can move on to modeling.

DataBlock and DataLoader

Fastai provides a mid-level API, DataBlock, for describing how the data is loaded, labeled, and split:

cats = DataBlock(
blocks=(ImageBlock, CategoryBlock),
get_items=get_image_files,
splitter=RandomSplitter(valid_pct=0.2,seed=42),
get_y=parent_label
)

Many block types are available; here we use ImageBlock for the inputs (X) and CategoryBlock for the labels. get_image_files collects the image paths, and parent_label uses each image's folder name as its label. All of these are built-in fastai functions.
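fastai's parent_label boils down to taking the name of the folder that contains each file. The sketch below shows that idea in plain Python, together with a sorted vocabulary of labels similar to the one CategoryBlock builds (to my understanding, it sorts the unique labels). label_from_parent and build_vocab are illustrative names, not fastai functions.

```python
from pathlib import Path


def label_from_parent(path):
    """Return the name of the folder containing the file, i.e. the breed."""
    return Path(path).parent.name


def build_vocab(paths):
    """Sorted list of distinct labels; a label's position is its class index."""
    return sorted({label_from_parent(p) for p in paths})
```

For example, label_from_parent('cats2/Russian Blue/00000012.jpg') gives 'Russian Blue', which is why the download loop saves each breed into its own folder.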

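We then use RandomSplitter to split the data randomly into training and validation sets, holding out 20% of the images (valid_pct=0.2) for validation; fixing seed=42 makes the split reproducible.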
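A seeded random split can be sketched in a few lines of plain Python: shuffle the item indices with a fixed seed, then hold out the first valid_pct of them. This mirrors the idea behind RandomSplitter, but it uses Python's random module instead of PyTorch's generator, so the actual indices chosen will differ from fastai's. random_split is a hypothetical name.

```python
import random


def random_split(items, valid_pct=0.2, seed=42):
    """Return (train_indices, valid_indices), reproducible for a given seed."""
    idxs = list(range(len(items)))
    random.Random(seed).shuffle(idxs)  # local RNG: same seed gives the same split
    cut = int(valid_pct * len(idxs))
    return idxs[cut:], idxs[:cut]
```

With 100 items this yields 80 training and 20 validation indices with no overlap, and calling it again with the same seed reproduces the same split.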

cats = cats.new(
item_tfms=RandomResizedCrop(224, min_scale=0.5),
batch_tfms=aug_transforms()
)

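Here, RandomResizedCrop() takes a random crop covering between 50% and 100% of each image's area (min_scale=0.5) and resizes it to 224×224, while aug_transforms() applies a standard set of data augmentations (such as flips, small rotations, zoom, warping, and lighting changes) to whole batches, on the GPU when one is available.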

This is the basic fastai DataBlock template; if you want to know more about it, you can find the tutorial here.
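To make min_scale=0.5 concrete, here is a plain-Python sketch of how a random-resized-crop box is commonly sampled: pick a target area between min_scale and max_scale of the image and an aspect ratio in (3/4, 4/3), retry a few times if the box does not fit, and fall back to the whole image. This is the widely used torchvision-style recipe; fastai's RandomResizedCrop follows the same idea with the same default ratio range, but its details differ (for instance, it applies a centre crop on the validation set). random_resized_crop_box and the attempts count are hypothetical.

```python
import math
import random


def random_resized_crop_box(width, height, min_scale=0.5, max_scale=1.0,
                            ratio=(3/4, 4/3), rng=random, attempts=10):
    """Return (left, top, crop_w, crop_h); the crop is then resized to 224x224."""
    area = width * height
    for _ in range(attempts):
        target_area = area * rng.uniform(min_scale, max_scale)
        # Sample the aspect ratio uniformly in log space so 3/4 and 4/3 are symmetric
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        crop_w = int(round(math.sqrt(target_area * aspect)))
        crop_h = int(round(math.sqrt(target_area / aspect)))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            left = rng.randint(0, width - crop_w)
            top = rng.randint(0, height - crop_h)
            return left, top, crop_w, crop_h
    # Fallback: use the whole image if no sampled box fits
    return 0, 0, width, height
```

Because each epoch samples a different box, the model sees a slightly different view of every image each time, which is what makes this crop a form of data augmentation.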

Now that the DataBlock is ready, we can create the DataLoaders:

dls = cats.dataloaders(path)
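The DataLoaders feed the model mini-batches of images. Ignoring tensors and transforms, the indexing part can be sketched as below: shuffle the training indices each epoch and cut them into batches. fastai's default batch size is 64 and the training loader is shuffled; to my understanding it also drops the last incomplete training batch, which drop_last=True imitates here. iter_batches is a hypothetical helper, not fastai's implementation.

```python
import random


def iter_batches(n_items, bs=64, shuffle=True, drop_last=True, seed=None):
    """Yield lists of item indices, one list per mini-batch."""
    idxs = list(range(n_items))
    if shuffle:
        random.Random(seed).shuffle(idxs)
    # With drop_last, ignore the final batch if it would be smaller than bs
    stop = n_items - n_items % bs if drop_last else n_items
    for start in range(0, stop, bs):
        yield idxs[start:start + bs]
```

For example, 1,000 training images with bs=64 give 15 full batches per epoch, with 40 images left out of that epoch's final partial batch.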

Now let's see what our data looks like:

dls.valid.show_batch(max_n=12, nrows=4) 

This displays 12 validation images with their labels, and the output looks like this:

Validation data

Now I am excited to work with these cute cats!

Fitting ResNet-50 architecture

So what is ResNet-50 architecture?

ResNet-50 is a convolutional neural network that is 50 layers deep. You can load a pre-trained version of the network trained on more than a million images from the ImageNet database. The pre-trained network can classify images into 1000 object categories, such as keyboard, mouse, pencil, and many animals. As a result, the network has learned rich feature representations for a wide range of images. The network has an image input size of 224-by-224. For more information, you can refer to this article.

Let’s fine-tune a pre-trained ResNet-50 on our data (newer fastai releases rename cnn_learner to vision_learner):

learn = cnn_learner(dls, resnet50, metrics=error_rate)
learn.fine_tune(epochs=10)

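fine_tune first trains only the new classification head for one epoch with the pre-trained layers frozen, then unfreezes the whole network and trains it for 10 more epochs.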
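Under the hood, fine_tune runs fastai's one-cycle training (fit_one_cycle) for both phases. The sketch below reproduces the shape of that learning-rate schedule in plain Python: a cosine warm-up from lr_max/div to lr_max over the first pct_start of training, then a cosine decay down to lr_max/div_final. The defaults shown (div=25, div_final=1e5, pct_start=0.25) are fit_one_cycle's defaults as far as I know; fine_tune passes its own values for some of them, so treat the numbers as illustrative.

```python
import math


def sched_cos(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(pos, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1]."""
    low, high, final = lr_max / div, lr_max, lr_max / div_final
    if pos < pct_start:
        # Warm-up phase: rise from low to high
        return sched_cos(low, high, pos / pct_start)
    # Annealing phase: fall from high to a very small final value
    return sched_cos(high, final, (pos - pct_start) / (1 - pct_start))
```

Starting low avoids disrupting the pre-trained weights early on, and the long decay at the end lets the model settle into a good solution.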

After training, we check the confusion matrix. If you are not familiar with confusion matrices, you can refer to this.

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(10,10))

So the output looks like this,

Confusion Matrix for Cat Breeds
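ClassificationInterpretation computes this matrix for us: rows are the actual breeds, columns are the predicted breeds, and each cell counts validation images. The plain-Python sketch below shows what is being counted, plus a helper listing the most frequent mix-ups, similar in spirit to interp.most_confused(). The function names and the example labels are hypothetical.

```python
from collections import Counter


def confusion_matrix(actual, predicted, labels):
    """cm[i][j] = number of images of class labels[i] predicted as labels[j]."""
    index = {lab: i for i, lab in enumerate(labels)}
    cm = [[0] * len(labels) for _ in labels]
    for a, p in zip(actual, predicted):
        cm[index[a]][index[p]] += 1
    return cm


def most_confused(actual, predicted, min_val=1):
    """(actual, predicted, count) for misclassified pairs, most frequent first."""
    pairs = Counter((a, p) for a, p in zip(actual, predicted) if a != p)
    return [(a, p, n) for (a, p), n in pairs.most_common() if n >= min_val]
```

The diagonal holds the correct predictions, so a strong off-diagonal cell points to two breeds the model struggles to tell apart.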

Next, we plot the images on which our model made its biggest mistakes:

interp.plot_top_losses(10, nrows=2, figsize=(20,6))

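This plots the 10 validation images with the highest loss, that is, those where the model was most confidently wrong or least confident about the correct breed, each labeled with its prediction, actual class, loss, and probability.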
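For single-label classification, fastai uses a cross-entropy loss, so each image's loss is the negative log of the probability the model gave to the true breed. Ranking images by that value is what "top losses" means. Here is a plain-Python sketch using already-computed probabilities (fastai works from raw model outputs, which gives the same value after softmax). cross_entropy and top_losses are hypothetical helpers, and the probabilities must be non-zero for the log to be defined.

```python
import math


def cross_entropy(probs, target_idx):
    """Negative log of the probability assigned to the true class."""
    return -math.log(probs[target_idx])


def top_losses(all_probs, targets, k=10):
    """(index, loss) for the k samples with the highest loss, worst first."""
    losses = [cross_entropy(p, t) for p, t in zip(all_probs, targets)]
    order = sorted(range(len(losses)), key=losses.__getitem__, reverse=True)
    return [(i, losses[i]) for i in order[:k]]
```

A confident wrong answer (for example, probability 0.2 on the true class) gets a much larger loss than a hesitant right one, so these images tend to reveal mislabeled downloads or genuinely similar-looking breeds.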

So the next step is to save our model.

learn.export(fname='resnet50.pkl')

Our cat breed classification model is now trained and exported, so the next step is to deploy it with Streamlit.

Deploying the model

We are going to create the web app with Streamlit; if you are not familiar with it, you can refer to this.

First, install Streamlit, which takes a single command:

pip install streamlit

Now we will import our libraries,

import pathlib
from pathlib import Path
import streamlit as st
from fastai.vision.all import *
from fastai.vision.widgets import *

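Let’s load our model to check that the exported file loads correctly (the Predict class below loads it again for the app itself):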

learn_inf = load_learner('resnet50.pkl')

Now we create a Predict class that handles uploading an image, displaying it, and classifying it:

class Predict:
    def __init__(self, filename):
        # Load the exported learner from the current working directory
        self.learn_inference = load_learner(Path()/filename)
        self.img = self.get_image_from_upload()
        if self.img is not None:
            self.display_output()
            self.get_prediction()

    @staticmethod
    def get_image_from_upload():
        # Returns a fastai PILImage, or None until the user uploads a file
        uploaded_file = st.file_uploader("Upload Files", type=['png', 'jpeg', 'jpg'])
        if uploaded_file is not None:
            return PILImage.create(uploaded_file)
        return None

    def display_output(self):
        st.image(self.img.to_thumb(500, 500), caption='Uploaded Image')

    def get_prediction(self):
        if st.button('Classify'):
            # predict returns (decoded label, class index, probability tensor)
            pred, pred_idx, probs = self.learn_inference.predict(self.img)
            st.write(f'**Prediction**: {pred}')
            st.write(f'**Probability**: {probs[pred_idx]*100:.02f}%')
        else:
            st.write('Click the button to classify')

if __name__ == '__main__':
    file_name = 'resnet50.pkl'
    predictor = Predict(file_name)

Our web app is now ready: we can upload an image and find out which cat breed it shows.
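If you want the app to show more than the single best guess, the probabilities returned by predict can be paired with the class names and sorted. The sketch below does this with plain Python lists. In the app, vocab would presumably be self.learn_inference.dls.vocab and probs would be the returned tensor converted with probs.tolist(), with each result passed to st.write; top_k_predictions and the example values are hypothetical.

```python
def top_k_predictions(vocab, probs, k=3):
    """Return the k most likely (label, percentage string) pairs, best first."""
    ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
    return [(label, f'{p * 100:.2f}%') for label, p in ranked[:k]]
```

Showing the runner-up breeds is useful for look-alike breeds, where the top two probabilities are often close.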

To run the web app, open a terminal, change to the directory containing your script (and resnet50.pkl), and run the following, replacing filename with the name of your script:

streamlit run filename

This starts a local Streamlit server (on port 8501 by default) and opens the web app in your browser.

Live demo of our Streamlit web app

Whooosshhh!! Finally, we have created our image classification model and can showcase our work with Streamlit. If you want, you can host this web app on AWS, Heroku, or Streamlit Sharing.

Conclusion

In this article, we created our own image classifier in a few lines of code and built a web app that classifies different breeds of cats.

The complete code is available on GitHub.
