Build and Deploy Image Classification web app using Django on Heroku (Part-2)
In Part 1, we saw how to build and save the model. Now let's use that model to develop the web app.
1. Install Django
Create a conda environment and install Django in it:
python -m pip install Django
2. Create the Django project
django-admin startproject classifyImage
3. Create a folder named “templates” in the project root, create an “index.html” file inside it, and make the following change in settings.py:
# Template configuration (replaces the TEMPLATES entry generated by startproject).
# BASE_DIR is the pathlib.Path defined at the top of the generated settings.py
# (Django 3.1+). Anchoring the templates folder to it makes template lookup
# independent of the directory the server is started from.
TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [BASE_DIR / 'templates'],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]
4. Create a folder named “models” inside the classifyImage subfolder and copy into it the ckpt.pth file that was generated by the model training.
The directory structure of the project will be:
classifyImage
├── manage.py
├── classifyImage
│   ├── models
│   │   └── ckpt.pth
│   ├── __init__.py
│   ├── asgi.py
│   ├── settings.py
│   ├── urls.py
│   └── wsgi.py
└── templates
    └── index.html
Add the URL route in urls.py:
from django.contrib import admin
from django.urls import path
from classifyImage import views
from django.conf.urls.static import static
from django.conf import settings

# URL routes: the Django admin plus the image-classification page at the root.
urlpatterns = [
    path('admin/', admin.site.urls),
    path('', views.index, name='index'),
]

# static() raises ImproperlyConfigured for an empty prefix, and MEDIA_URL
# defaults to '' unless it is set in settings.py. Only add the media route
# when MEDIA_URL is configured. static() itself adds nothing unless DEBUG is
# True. Note that this app inlines uploaded images as data URIs, so it does
# not depend on this route.
if settings.MEDIA_URL:
    urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
Let's design the HTML page (index.html):
<!DOCTYPE html>
<html lang="en" dir="ltr">
<head>
  <meta charset="utf-8">
  <title>ClassifyImage.</title>
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <link rel="preconnect" href="https://fonts.googleapis.com">
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
  <style>
    * {
      font-family: 'Poppins', sans-serif;
      box-sizing: border-box;
      margin: 0;
    }

    .column {
      float: left;
      padding: 10px;
      height: 600px;
    }

    .left {
      width: 70%;
    }

    .right {
      width: 30%;
      padding-left: 30px;
    }

    /* Clear the floated columns. */
    .row:after {
      content: "";
      display: table;
      clear: both;
    }

    .vertical-center {
      margin: 0;
      position: absolute;
      top: 50%;
      -ms-transform: translateY(-50%);
      transform: translateY(-50%);
    }

    .center {
      margin: auto;
      width: 50%;
      padding: 10px;
    }

    .button {
      background-color: #FEE13D;
      border: none;
      color: #000000;
      padding: 10px 60px;
      text-align: center;
      text-decoration: none;
      display: inline-block;
      font-size: 16px;
      margin: 4px 2px;
      cursor: pointer;
    }

    .text {
      font-size: 16px;
      text-align: center;
      background-color: #FAFAFA;
      padding: 20px 10px;
    }

    /* Small screens: stack the columns vertically. */
    @media screen and (max-width: 600px) {
      .column {
        float: none;
      }

      .left {
        width: 100%;
        padding-bottom: 10px;
        float: none;
      }

      .right {
        width: 100%;
        padding-bottom: 10px;
        float: none;
      }

      .vertical-center {
        position: absolute;
        transform: translateY(0%);
        top: 120%;
        margin-top: 10px;
      }

      .center {
        margin: 0;
      }
    }
  </style>
</head>
<body>
{% load static %}
<div class="row">
  <div class="column left" style="background-color:#FFFFFF;">
    <p><b>ClassifyImage.</b></p>
    <div class="center">
      {% if image_uri is not None %}
        <br><br>
        <img id="output" class="outputImage" src="{{ image_uri }}" width="400" height="400" alt="Uploaded image">
        {# predicted_label stays None when the view could not classify the image #}
        {% if predicted_label %}
          <div class="text">It seems to be <b>{{ predicted_label }}</b></div>
        {% else %}
          <div class="text">Sorry, this image could not be classified.</div>
        {% endif %}
      {% endif %}
    </div>
  </div>
  <div class="column right" style="background-color:#000000; color: #FFFFFF;">
    <div class="vertical-center">
      <h3>Choose Image to</h3>
      <h2><b>CLASSIFY</b></h2>
      <br><br>
      <form method="post" enctype="multipart/form-data">
        {% csrf_token %}
        {{ form }}
        <br><br><br><br>
        <button type="submit" id="button" class="button">Classify It</button>
      </form>
    </div>
  </div>
</div>
<script>
  // Preview the file chosen in `input` inside the #output image.
  // Uses plain DOM APIs because jQuery is not loaded on this page.
  // NOTE: not attached to any element in this template. Add
  // onchange="loadFile(this)" to the file input if a live preview is wanted.
  function loadFile(input) {
    var output = document.getElementById('output');
    if (!output || !input.files || !input.files[0]) {
      return;
    }
    var reader = new FileReader();
    reader.onload = function (e) {
      output.src = e.target.result;
    };
    reader.readAsDataURL(input.files[0]);
  }
</script>
</body>
</html>
Create forms.py inside the classifyImage subfolder:
from django import forms


class ImageUploadForm(forms.Form):
    """Single-field form used by the index view to upload the image to classify."""

    # label=False: render only the file input, without a <label>, when the
    # template outputs {{ form }}.
    image = forms.ImageField(label=False)
Install the CPU-only builds of PyTorch and torchvision:
pip install torch==1.8.1+cpu torchvision==0.9.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
Now let's do the final part: using the model to predict the class of the uploaded image.
Create views.py inside the classifyImage subfolder:
import base64
import functools
import io
import logging
from pathlib import Path

from django.shortcuts import render
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as transforms

from .forms import ImageUploadForm

logger = logging.getLogger(__name__)

# Class names indexed by the network's output position. They must be in the
# same order as the label indices used when the checkpoint was trained.
CLASSES = (
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
    'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
    'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
    'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
)

# Resolve the checkpoint relative to this file (classifyImage/views.py) so
# loading does not depend on the server's current working directory.
MODEL_PATH = Path(__file__).resolve().parent / "models" / "ckpt.pth"

# Input size and normalization must match the preprocessing used in training.
INPUT_SIZE = (32, 32)
TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the ResNet-18 classifier and load the trained weights, once per process."""
    # pretrained=False: every weight is replaced by load_state_dict below
    # (strict by default), so downloading ImageNet weights would be wasted.
    net = torchvision.models.resnet18(pretrained=False)
    net.fc = nn.Linear(net.fc.in_features, len(CLASSES))
    # The checkpoint is a trusted local file; torch.load unpickles it.
    checkpoint = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint['net'])
    net.eval()
    return net


def getClassOfImage(image):
    """Return the predicted class name for an image.

    Args:
        image: a filename or binary file-like object that PIL can open.

    Returns:
        The entry of CLASSES with the highest network score.

    Raises:
        RuntimeError: propagated from torch if inference fails.
    """
    net = _load_model()
    with Image.open(image) as img:
        # convert('RGB'): RGBA, palette and grayscale images would otherwise
        # give tensors whose channel count does not match Normalize's 3 channels.
        rgb = img.convert('RGB').resize(INPUT_SIZE)
    batch = TRANSFORM(rgb).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():  # inference only; no autograd graph needed
        output = net(batch)
    _, predicted = torch.max(output, 1)
    label = CLASSES[predicted[0].item()]
    logger.info('Predicted: %s', label)
    return label


def index(request):
    """Render the upload page and, after a valid POST, the classified image.

    Template context:
        form: the bound (POST) or unbound ImageUploadForm.
        image_uri: data URI of the uploaded image, or None.
        predicted_label: predicted class name, or None when nothing was
            uploaded or classification failed.
    """
    image_uri = None
    predicted_label = None

    if request.method == 'POST':
        form = ImageUploadForm(request.POST, request.FILES)
        if form.is_valid():
            image = form.cleaned_data['image']
            image_bytes = image.file.read()
            encoded_img = base64.b64encode(image_bytes).decode('ascii')
            # Django's ImageField sets content_type from the format Pillow
            # detected. It may be None for formats Pillow has no MIME type for.
            content_type = image.content_type or 'image/jpeg'
            image_uri = f'data:{content_type};base64,{encoded_img}'
            try:
                # Reuse the bytes already read instead of re-reading the upload.
                predicted_label = getClassOfImage(io.BytesIO(image_bytes))
            except RuntimeError:
                logger.exception('Image classification failed')
    else:
        form = ImageUploadForm()

    context = {
        'form': form,
        'image_uri': image_uri,
        'predicted_label': predicted_label,
    }
    return render(request, 'index.html', context)
Let's run the web app:
python manage.py runserver
GitHub Link: