TensorFlow.js: Building a Drawable Handwritten Digits Classifier

Jawarneh
Published in The Startup · 7 min read · Jun 28, 2020

In this article we will dive into implementing a handwritten digit classifier in the browser using convolutional neural networks, the MNIST dataset and TensorFlow.js. We will then discuss the limitations of using convolutional neural networks in this particular scenario and why standard fully connected networks may be more suitable.

MNIST Dataset.

Our classifier needs to identify the digits that we draw onto the canvas. This is done by sending the raw pixel data into our model, a convolutional neural network, which returns a softmax prediction of which digit the input is most likely to be.

In order for our model to accurately identify our written digits, it needs to understand what characteristics each digit contains, hence we need to train the model with a dataset of labeled handwritten digits. This is where the MNIST dataset comes in…

Sample of MNIST dataset, courtesy of Wikipedia

The MNIST dataset consists of 60,000 training images and 10,000 validation images. Each image is a 28x28 grayscale image whose pixel values range from 0 to 255, where 0 is no colour and 255 means the pixel is fully coloured.

Rather than manually downloading the MNIST dataset, we will access and download it via TensorFlow’s Keras API.

Our Blueprint.

A variety of components go into building our neat little web classifier:

  1. Preprocess the MNIST dataset.
  2. Build and train our model.
  3. Export our model to HDF5 format.
  4. Convert our exported model to TensorFlow.js format.
  5. Load our model onto the browser.

It is always an excellent idea to break down the approach that you will be taking. Okay… now that we have understood the basic intuition of our dataset and the structure of our web app, let us dive into the code…

Prerequisites.

Before we dive into preprocessing and training our model, we need to ensure that we have the relevant libraries installed. We will only be using four libraries/frameworks: TensorFlow 2.1, the TensorFlow.js converter, NumPy and Scikit-Learn.

Copy and paste the script below into your terminal if you do not have the libraries above installed:

$ pip install tensorflow tensorflowjs numpy scikit-learn

Preprocessing our dataset.

Now that we have all of our prerequisites in place, we can begin to download and preprocess all of our MNIST images.

Let us begin by importing our libraries and initialising our train and test sets:
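The original snippet is embedded as a gist; a minimal equivalent using the Keras API looks like the following (the variable names are my own, not necessarily the repo’s):

```python
import tensorflow as tf

# Download MNIST via the Keras API: two (images, labels) tuples.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

print(x_train.shape)  # (60000, 28, 28)
print(x_test.shape)   # (10000, 28, 28)
```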

Here we begin to load our MNIST dataset into two tuples. The training set contains 60,000 images whilst the testing set contains only 10,000.

Reshaping and standardising our data

We first define the shape of each image to be [28, 28, 1] and reshape the entire training and testing sets accordingly. We also cast all of the pixel values in our tensors to 32-bit floating point: the values sometimes come in a 64-bit format, which is unnecessary for such a simple dataset and would take up more computational resources. Finally, we standardise all values to between 0 and 1 in order to speed up the training process.
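The reshaping and standardising described above can be sketched as a small helper (a minimal version; the repo’s gist may do this inline):

```python
import numpy as np

def preprocess(images):
    # Reshape to [N, 28, 28, 1], cast to float32, scale into [0, 1].
    return images.reshape(-1, 28, 28, 1).astype(np.float32) / 255.0
```

Applying `preprocess` to both the training and testing images gives tensors ready for the convolutional layers.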

Distribution of each digit in train set

If we take a moment to look at the distribution of our train set, we can see that it is fairly balanced. Looking at this is important as it allows us to determine what metrics are suitable for analysing and evaluating our model. Since our classes are fairly evenly distributed we can rely on accuracy to evaluate our model.
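A quick way to produce the counts behind that distribution plot (a sketch; the article’s own plotting code is in its gist):

```python
import numpy as np

def digit_distribution(labels):
    # Count how many times each digit 0-9 appears in the label array.
    return np.bincount(labels, minlength=10)
```

Calling `digit_distribution(y_train)` returns the ten per-class counts that the chart above visualises.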

One-hot encoding our labels

Currently, our labels are label encoded. This is not an optimal encoding for our dataset, so we need to replace it with one-hot encodings. This is simply done with Scikit-Learn’s OneHotEncoder class.

Building and training our model.

Taking a naive approach, we will train a convolutional neural network on our dataset to serve as the model for our handwritten digits classifier. CNNs are really good at identifying features in spatially represented data such as images. However, as you read on, it becomes apparent why a CNN architecture is not suitable for this particular problem.

But enough talking… lets dive into the architecture:

CNN architecture used for MNIST, courtesy of neuralnet.ai

This is the simplest CNN architecture that I have stumbled upon over the years that achieves over ninety-eight percent accuracy. Okay, cool… let’s write it:

You can, of course, experiment with different hyper-parameters, optimisers or loss functions.

Losses over time compared between train & test set
Accuracy over time compared between train & test set

Our model produced some very promising results with over 98% validation accuracy. This is all well and good, but the question remains, can this model predict our new, unseen handwritten digits that we create on our web app? Let’s find out!

Before we proceed, we need to export/save our model so that we can use the TensorFlow.js converter tool to reformat the HDF5 into a JSON file that TensorFlow.js accepts.
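Saving is a one-liner; the filename below is illustrative, and the tiny `model` here merely stands in for the trained CNN from the previous step:

```python
import os
import tensorflow as tf

# Stand-in for the trained CNN from the training step above.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,)),
])

# The .h5 extension tells Keras to export in HDF5 format, which is
# what the TensorFlow.js converter expects as input.
model.save('mnist_cnn.h5')
```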

In-Browser Machine Learning.

TensorFlow.js allows us to run and train machine learning models on the client’s end. Previously, you would need to train and host your models on an external server and request predictions via an API; as you might have guessed, this was a slow approach.

Provided by TensorFlow

In order to load a model onto TensorFlow.js it needs to be converted into a JSON file which is then hosted on a web server that can be accessed and loaded onto the client’s browser. To convert your .h5 model, you need to run this simple script:

$ tensorflowjs_converter --input_format=keras \
/path/to/model.h5 \
/path/to/dest/directory

Since we are using the Keras API with TensorFlow, we need to specify the input format as ‘keras’. The next argument is the path to your exported/saved model, and the final argument is the destination folder where you want to store the newly formatted model.

Now it is time to build our web app, but first we need to add the TensorFlow.js CDN into our HTML head, inside a new HTML file that will be called index.html:

TensorFlow.js CDN
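The head of index.html only needs one script tag; the version pinned below is an example, so pin whichever release you tested against:

```html
<head>
  <!-- Pulls the tf global into the page -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
</head>
```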

Great! Now we can access all of the key TensorFlow.js functions and modules. Our next step is to load our model in a new JavaScript file called index.js:
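A minimal sketch of that loading code; it assumes the `tf` global from the CDN script, and that the converted model.json is served from a `model/` directory next to index.html:

```javascript
// index.js — load the converted model from the web server.
let model;

async function loadModel() {
  // loadLayersModel fetches model.json plus its weight shards.
  model = await tf.loadLayersModel('model/model.json');
  console.log('Model loaded');
}

// Kick this off when the page loads, e.g. window.onload = loadModel;
```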

Our model has finally been loaded onto the browser. We can now carry out predictions by passing in a Tensor object that represents an image of a handwritten digit drawn by the user.

In order to extract that sort of input from the user, we need to set up a canvas that represents a 28x28 matrix where each value is initially set to 0. When the user’s mouse passes over a cell, the value in the matrix corresponding to that cell is flipped to 1. Once the user releases the mouse, the now altered matrix is transposed and passed through the model in order to return a probability distribution over the digits the input could have been.

Graphical representation Vs. Matrix representation

The code that was written for the canvas will not be covered in this article; instead, it can be found in the GitHub repo for this article.

The code below demonstrates how the state of the canvas is converted into a matrix and is then passed through the model:
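A sketch of that conversion; `grid` is the 28x28 nested array of 0s and 1s kept in sync with the canvas, and `model` is the loaded model from earlier. Both names are assumptions for illustration, not the repo’s exact identifiers:

```javascript
function predictDigit(grid) {
  // tf.tidy disposes of the intermediate tensors once we are done.
  return tf.tidy(() => {
    // [28, 28] -> [1, 28, 28, 1]: a batch of one grayscale image.
    const input = tf.tensor(grid, [28, 28]).reshape([1, 28, 28, 1]);
    // dataSync() pulls the 10 softmax probabilities out of the tensor.
    return model.predict(input).dataSync();
  });
}
```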

Our model now returns an array of the probabilities it has predicted for each digit. This is extended further by extracting these values and visualising them using Chart.js; this is all included in the GitHub repo!

A problem arises!

After spending a few good minutes testing out the web app, you begin to realise that it is heavily flawed, especially when the model tries to distinguish between a 6 and a 9. Here is a ridiculous prediction that the model has produced:

Incorrect prediction by the model

Now you may be wondering if our model overfitted during training; this cannot be the case, as our validation accuracy was in the high nineties. Hence I conclude that it is a problem with the architecture itself. Convolutional neural networks are designed to detect features of an image regardless of orientation/augmentation. This means that the model begins to confuse itself when predicting lopsided 6s, 8s and 9s. Also, if we look at the MNIST dataset itself, most, if not all, of the images are uniform and vary only slightly in features, which opens up another can of worms as to why the model is not performing particularly well in this application.

Ultimately, I strongly feel that a fully connected network would have outperformed a CNN in this particular application. But I am going to leave this as a challenge for my readers!

Thank you for reading, and follow us for more fun deep learning projects such as this one. This is also neuralnet.ai’s first Medium blog, so tell us how we did and follow us on Instagram!

GitHub Repository for this project:

Our GitHub is home to all of our projects and applications that we build at neuralnet.ai!

Follow us on Instagram:

Our page educates the field of deep learning through the power of art & design: neuralnet.ai
