Interactive Handwritten Digit Classifier Web App from Scratch: A Step-by-Step Guide

Saikat Sutradhar
8 min read · Dec 18, 2022


Giving a friendly face to machine-learning problems using HTML/CSS/JS + Flask(Python)

Note: I have tried to keep things concise for a general audience, but this article gets technical in certain sections.

I have always been fascinated by the fact that, using mathematical algorithms and relevant data, computers are capable of making predictions. In this article, I will show you how to leverage this computational capability and build a machine-learning model to recognize handwritten digits.

But unlike traditional machine-learning exercises, here I have integrated the model with a user interface (UI) and made it a playground for testing the model’s capabilities.
Here’s the demo of the application:

Real-time predictions on user’s drawing

Table of contents:

  1. Architecture — boilerplate for the app
  2. User Interface — where the user interacts
  3. Prediction Model — where the prediction happens
  4. Integration — stitching UI and Model together

1. Architecture

Broadly speaking, there are two sides to building this solution.

  • Front-End: The interface which is visible to the user and is used for drawing digits, displaying output, etc.
  • Back-End: It’s not directly accessible by the user but it is responsible for recognizing the user’s drawing.
Fig. Solution Architecture

Now, to make the Front-End “talk to” and “hear back” from the Back-End you need to integrate the two parts. This is where Flask comes in.

Flask is a Python web framework that lets the server receive HTTP requests from the user interface (for example, a POST request carrying the drawing) and send responses back.

With its help, you will be able to integrate both ends and make a web application that hides all the irrelevant information from the user and ensures a pleasant experience.

2. User Interface

I wanted to have a minimal UI that would be a playground for testing the model. The wireframe of the UI is shown below.

HTML elements like <form>, <canvas>, <button>, etc. are used to prepare the boilerplate of the UI. To show the output of the model prediction, multiple <div> elements are used, labeled 0 to 9 to indicate the prediction class. Each of these has a progress bar that represents the predicted probability for that class as a percentage.

All the logic in the page is implemented using vanilla JavaScript and the styling is done using CSS. Getting into the code details is out of the scope of this article.
Here’s how the final UI looks:

This is the UI for the web app
Meet the friendly UI

A few salient features of the UI:

  • The predict button is enabled only after the drawing is significant enough i.e. only after it exceeds a pixel count threshold.
  • The output labels have progress bars that adjust according to the model’s probability of predicting the input digit’s class.
  • The clear button cleans the canvas for a new drawing and resets the output labels.
  • The UI displays an alert if the prediction has failed. This is described later in the article.

The interaction is very simple. You draw a digit and click the predict button. The model analyzes your drawing and makes its best guess, which is displayed back to you on the UI. How this works is explained in the following sections.

3. Prediction Model

To understand the user’s drawing and make a guess we need to build a machine-learning model. Technically speaking, this type of scenario is called classification.

3.1. Overview

In the context of handwritten digit classification, the model is trained to predict the digit (0 through 9) that is written in a given handwriting sample. The model is trained on a labeled dataset of handwriting samples and learns to identify patterns and features in the images that are associated with specific digits. When given a new handwriting sample, the model uses a decision function to predict the digit that the image most closely resembles based on the patterns and features it learned during training.

3.2. Dataset

One of the most well-known and widely used datasets for this task is the MNIST (Modified National Institute of Standards and Technology) dataset. This dataset consists of 70,000 images of handwritten digits (60,000 for training and 10,000 for testing), each represented as a 28x28 pixel grayscale image.

Fig. MNIST dataset with labels

3.3. Preprocessing

Computers only understand numbers. In the case of these 28x28 pixel grayscale images, you have to process them so that you can use them in your model. Each pixel value is typically represented as a number in the range 0 to 255, where 0 represents the minimum intensity (black) and 255 represents the maximum intensity (white).

Fig. Image and its respective pixel values after converting to 1-bit

I converted each grayscale image to a 1-bit image, so that every pixel value is either 0 (black) or 255 (white), and then replaced every 255 with 1. Strictly speaking, the first step is binarization (thresholding), and the second rescales the values to the range 0 to 1, which is a form of normalization. Keeping all features on the same small scale generally helps an SVM train faster and more reliably.
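The 1-bit conversion described above can be sketched in a few lines of plain Python. This is only an illustration: the cutoff of 128 and the function name binarize are assumptions, since the article does not state which threshold was used.

```python
def binarize(image, threshold=128):
    """Convert a grayscale image (rows of 0-255 values) to rows of 0/1 values.

    The threshold of 128 is an assumed cutoff, not a value from the article.
    """
    return [[1 if pixel >= threshold else 0 for pixel in row] for row in image]


# A tiny 2x3 "image" instead of a full 28x28 one, to keep the output readable.
sample = [
    [0, 40, 200],
    [130, 255, 90],
]
binary = binarize(sample)
print(binary)  # [[0, 0, 1], [1, 1, 0]]
```

With NumPy, the same step is usually written as a single vectorized comparison on the whole array.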

After the preprocessing is done, you need to arrange the 2D matrix of pixel data in a structured way so that you can feed it into your model. The pixel values are flattened into one long array, row by row: left to right along the top row, then left to right along the next row, and so on. Here’s a preview of the processed data:

The 28 * 28 = 784 pixel values are stored in the columns “0” to “783”, one value per column. The “label” column indicates the digit class to which that pixel data belongs.
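The flattening step can be sketched as follows. The helper names flatten_image and to_row are hypothetical, and the label is placed first to match the column order the training code further down reads (labels in the first column, pixels after it).

```python
def flatten_image(image):
    """Flatten a 2D grid of pixels into one list, row by row."""
    return [pixel for row in image for pixel in row]


def to_row(label, image):
    """Build one dataset row: the label followed by the flattened pixels."""
    return [label] + flatten_image(image)


# A blank 28x28 image with a single "on" pixel at row 1, column 2.
image = [[0] * 28 for _ in range(28)]
image[1][2] = 1

row = to_row(7, image)
print(len(row))  # 785: 1 label + 784 pixel values

# The pixel at (row 1, column 2) lands in pixel column 1 * 28 + 2 = 30,
# which is position 31 in the row because the label occupies position 0.
print(row[31])  # 1
```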

3.4. Model

I have used the SVM (Support Vector Machine) algorithm here. Here’s a brief overview of how the algorithm works:

Support Vector Machines (SVMs) are a type of supervised machine learning algorithm that can be used for classification tasks. The SVM algorithm tries to find a line or boundary (called a “decision surface”) that best separates the different classes of digits in the dataset.

Fig. SVM finds a decision boundary between different categories

Overall, the goal of the SVM algorithm is to find the decision surface that separates the different classes of digits in the training data with the widest possible margin, so that it can accurately classify new images of digits.
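To make the idea of a decision boundary concrete, here is a minimal sketch on made-up 2D points rather than digit images. The data is purely illustrative and assumes scikit-learn is installed.

```python
from sklearn.svm import SVC

# Two well-separated toy classes in 2D (illustrative data, not MNIST).
X = [[0, 0], [0, 1], [1, 0], [1, 1], [4, 4], [4, 5], [5, 4], [5, 5]]
y = [0, 0, 0, 0, 1, 1, 1, 1]

clf = SVC(kernel="linear")
clf.fit(X, y)

new_points = [[0.5, 0.5], [4.5, 4.5]]

# For two classes, decision_function returns a signed score per sample:
# negative on the class-0 side of the boundary, positive on the class-1 side.
scores = clf.decision_function(new_points)
labels = clf.predict(new_points)
print(labels)  # [0 1]
```

Digit recognition works the same way in principle, except that each sample has 784 features instead of 2 and there are 10 classes instead of 2.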

One popular library for implementing SVM algorithms in Python is scikit-learn. The following code shows how to implement it.

import pandas as pd
from sklearn.svm import SVC

# "train.csv" is a placeholder path; point it at your own training data
train_data = pd.read_csv("train.csv")
X_train = train_data.iloc[:, 1:]  # DataFrame with just the pixel values
y_train = train_data.iloc[:, 0]   # Series with the labels, i.e. 0, 1, 2, ...

model = SVC(kernel='linear', probability=True, random_state=42)  # building the model
model.fit(X_train, y_train)  # fitting the model on the training data

You can use the predict_proba() method to predict the probability of each digit class for a given set of input data. It returns one probability per class for every input row, in the order given by model.classes_. For example, with two classes, a result of [0.7, 0.3] means that the image has a 70% chance of belonging to the first category and a 30% chance of belonging to the second. For the digit model, each row contains ten probabilities.

Additionally, so that users know when a prediction is unreliable, the UI alerts the user if the probability of the predicted digit is less than 50%.
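A small sketch of what predict_proba() returns, again on made-up two-class data rather than the digit dataset. It assumes scikit-learn is installed; the data points are invented for illustration.

```python
from sklearn.svm import SVC

# Toy data with 10 points per class, so the internal cross-validation
# that probability=True relies on has enough samples (not MNIST).
X = [[i, i % 3] for i in range(10)] + [[i + 20, i % 3] for i in range(10)]
y = [0] * 10 + [1] * 10

model = SVC(kernel="linear", probability=True, random_state=42)
model.fit(X, y)

# The input must be 2D: one row per sample, even for a single sample.
proba = model.predict_proba([[1, 1]])

print(model.classes_)  # the class that each probability column refers to
print(proba.shape)     # (1, 2): one sample, two classes
```

Each row of the result sums to 1, so the values can be shown directly as percentages on the UI's progress bars.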

Fig. CSS & JavaScript are used here to update the UI if the prediction probability is less than 50%

I trained the model on more than 6000 images to achieve a good accuracy score. The model is now ready to predict unseen images. To make it accessible to your web application, you need to save the trained model to disk. Here it is serialized to a pickle file with joblib, as shown.
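In the app this check runs in the browser, but the logic is easy to express in Python too. The sketch below uses a hypothetical helper, top_prediction, and assumes the probabilities are ordered so that position i belongs to digit i.

```python
def top_prediction(probabilities, threshold=0.5):
    """Return (digit, probability, is_confident) for a list of class probabilities.

    Assumes index i holds the probability of digit i.
    """
    digit = max(range(len(probabilities)), key=lambda i: probabilities[i])
    best = probabilities[digit]
    return digit, best, best >= threshold


clear = [0.02, 0.01, 0.03, 0.80, 0.02, 0.04, 0.02, 0.03, 0.02, 0.01]
unsure = [0.30, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.35, 0.03, 0.02]

print(top_prediction(clear))   # (3, 0.8, True)
print(top_prediction(unsure))  # (7, 0.35, False) -> the UI would show an alert
```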

import joblib
joblib.dump(model, 'model.pkl')
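Before wiring the file into the web app, it can be worth checking that it loads back correctly. The sketch below does a save-and-load round trip in a temporary folder. A plain dictionary stands in for the trained model here, which is an assumption made only to keep the example self-contained.

```python
import os
import tempfile

import joblib

# Stand-in for the trained model; joblib handles a fitted SVC the same way.
model = {"kernel": "linear", "classes": list(range(10))}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pkl")
    joblib.dump(model, path)      # write the object to disk
    restored = joblib.load(path)  # read it back

print(restored == model)  # True
```

Only load pickle files that you created yourself or that come from a source you trust, because loading a pickle can execute arbitrary code.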

4. Integration

Let's recap — the UI is built and the prediction model is ready. The UI is the way to send data and receive output from the machine-learning model which holds the logic to recognize handwritten digits. As mentioned previously, you need to “stitch” these two parts together using the Flask web framework.

Let me show you how to implement it in code.

First, you will need to install Flask. It can be done by running the pip install Flask command. A beginner’s guide on using Flask can be found here.

Next, create a new Python file called app.py and add the code below:

from flask import Flask, render_template, request, jsonify
import joblib

app = Flask(__name__)

model = joblib.load("model.pkl")  # loading the pre-trained model

@app.route('/')
def home():
    return render_template("index.html")  # landing page

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    # assumes pixel_data is a flat list of pixel values after preprocessing
    pixel_data = data['pixel_data']
    # predict_proba expects 2D input (one row per sample), so wrap the sample
    prediction_proba = model.predict_proba([pixel_data])[0]
    # NumPy arrays are not JSON serializable, so convert to a plain list
    return jsonify({'prediction_proba': prediction_proba.tolist()})

if __name__ == "__main__":
    app.run(debug=True)
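One way to try the /predict route without opening a browser is Flask's built-in test client. The sketch below is self-contained and uses a hypothetical StubModel in place of the trained classifier, so it only demonstrates the request and response flow, not real predictions.

```python
from flask import Flask, jsonify, request

app = Flask(__name__)


class StubModel:
    """Hypothetical stand-in for the trained model: returns uniform probabilities."""

    def predict_proba(self, rows):
        return [[0.1] * 10 for _ in rows]


model = StubModel()


@app.route("/predict", methods=["POST"])
def predict():
    data = request.get_json()
    proba = model.predict_proba([data["pixel_data"]])[0]
    return jsonify({"prediction_proba": list(proba)})


# The test client sends requests to the app without starting a server.
client = app.test_client()
response = client.post("/predict", json={"pixel_data": [0] * 784})
payload = response.get_json()

print(response.status_code)              # 200
print(len(payload["prediction_proba"]))  # 10
```

Swapping StubModel for the model loaded from the pickle file gives a quick check of the endpoint before connecting the UI.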

The index.html file is the Front-End for this app that I discussed earlier. This is how the <body> of the file should look:

<!-- Note: This is only a snippet from index.html file  -->

<body>
    <form id="form">
        <!-- add all the elements here -->
        <button type="submit" id="pred-btn">Predict</button>
    </form>
</body>

You need to add JavaScript to the HTML file to handle the <form> submission and make AJAX requests to the Flask server using fetch().

// Note: This is only a snippet from script.js file

const predict_button = document.getElementById("pred-btn");

// On clicking the predict button, the canvas image data is sent
// to the server, which runs the model on it and sends the
// prediction back here.
predict_button.addEventListener('click', e => {
    e.preventDefault(); // stop the form from reloading the page
    const pixel_data = 0; // replace 0 with the image pixel data
    fetch('/predict', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json'
        },
        body: JSON.stringify({pixel_data: pixel_data}) // sending the data
    })
    .then(response => response.json()) // getting the response
    .then(data => {
        // add here what you want to do with the predictions
        console.log(`Prediction: ${data.prediction_proba}`);
    })
    .catch(error => {
        console.error(error);
    });
});

The code above uses fetch() to send and receive data when predict_button is clicked. It sends pixel_data to app.py, where model.predict_proba() computes the prediction probabilities, which are then sent back to script.js as JSON. The response is available as data.prediction_proba, which is your model’s output for the given input. You can handle it as desired inside .then(data => {...}).

To run the app, start the Flask development server by running flask run in the terminal, from the folder that contains app.py. The app should now be accessible in your web browser at http://localhost:5000/.

Once all the parts are in place, you should be able to draw a digit and get model predictions by clicking the predict button.

Thanks for making it this far. I hope you found this article helpful. For the code and templates, you can reach out to me here.
