Learn and Play with TensorFlow.js [Part 3: MNIST Classification]

Here we are. After the introduction to Binary Classification in the previous chapter, we’ve come to the most popular example in Machine Learning tutorials: the MNIST Classifier.

MNIST Database

The MNIST handwritten digit database is a popular multi-class classification dataset that is widely used in Machine Learning tutorials to understand the basics of classification. The original MNIST database consists of 60,000 training images and 10,000 testing images.

To use the dataset, we have provided a class called MnistData in the mnist_data.js JavaScript file. An instance of the MnistData class can download the MNIST dataset using its load() function. After the data has been downloaded, we can convert it into training and test tensors using the getTrainData() and getTestData() functions.

Download mnist_data.js file here

Main Page

As before, we start this application by creating a new index.html file and putting it in a new directory called part_3. For this third example, let’s make the application as complete as possible from top to bottom. Users will be able to define their own network architecture, train it, test it, and even save and load the trained models, so that if they want to use this application again later, they don’t need to retrain the model from scratch. Just load it, and it’s done.

Overall, there will be six <div> sections in the main view, each of which will contain:

  1. Section to download the MNIST dataset and display some sample images.
  2. Section to design and compile the model architecture.
  3. Section to display the training menu along with the progress of the training.
  4. Section to evaluate the trained model and show its per-class accuracy and confusion matrix. This section will also display some examples of predicted test images.
  5. Section to save and load the trained models.
  6. Section to test the model using handwriting input directly from the user.

Let’s build them one by one.

Load Dataset

In this section, we provide a Load Data button to download the dataset. Since downloading takes time, add a badge that shows whether the download has completed. Then add a div to display some examples of the downloaded images.

Initialize Model

For model initialization, insert a text area so that users can design and compile their own model architecture. We also add a couple of radio buttons to load predefined, ready-to-use model architectures. There are two predefined choices: the “standard” Artificial Neural Network (ANN) and the Convolutional Neural Network (CNN).

CNN, or ConvNet, is an Artificial Neural Network variant that has become very popular and widely used in recent years. The architecture makes the explicit assumption that the inputs are images (3-dimensional: height, width, and channels). Because each convolution filter is shared across all positions of the image, a ConvNet typically needs far fewer parameters than a vanilla, fully connected ANN and generalizes better on image data.
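
To make the “fewer parameters” claim concrete, here is a small sketch in plain JavaScript (no TensorFlow.js needed) that counts the trainable parameters of one dense layer and one convolution layer using the standard formulas. The layer sizes are illustrative assumptions, not the exact templates returned by util.getModel(), and a complete CNN still has dense layers after flattening, so this only compares a single first layer.

```javascript
// Parameter counts (weights + biases) for two common layer types.
// Layer sizes below are illustrative only, not the util.getModel() templates.

// Dense (fully connected): every input connects to every unit, plus one bias per unit.
function denseParams(inputSize, units) {
  return inputSize * units + units;
}

// 2D convolution: one kernelSize x kernelSize x inChannels filter per output
// channel, shared across every position of the image, plus one bias per filter.
function conv2dParams(kernelSize, inChannels, filters) {
  return kernelSize * kernelSize * inChannels * filters + filters;
}

// A flattened 28x28x1 MNIST image feeding 128 dense units...
console.log(denseParams(28 * 28, 128)); // 100480
// ...versus 8 convolution filters of size 5x5 on the same image.
console.log(conv2dParams(5, 1, 8)); // 208
```

The dense layer's count grows with the number of input pixels, while the convolution layer's count depends only on the kernel size and the number of channels.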

Lastly, add a button to process the designed architecture, and add a div section to display a summary of the architecture.

Train Model

For the training section, provide three columns. In the first column, put two input fields for the maximum number of epochs and the batch size, along with the Train button. The second and third columns will display the loss and accuracy progress during training.

Wait, what is batch size?

To understand what batch size is, let’s take a step back first.

Training in Neural Network

The learning iteration in training a Neural Network consists of two steps:

  1. the forward pass, which feeds the training data through the network to calculate its output, and
  2. the backward pass, which updates and corrects the weights (the parameters of the model) based on the output error produced.

An epoch is one full learning pass (forward and backward) over the entire training data. In other words, training has gone through one epoch once it has done the forward and backward pass for every example in the training set. The basic Gradient Descent algorithm uses all the training data at once in every iteration, so 1 iteration = 1 epoch. This training process is commonly referred to as Full Batch Gradient Descent. For example, with 100 two-dimensional data points, one forward and backward pass with Full Batch GD would feed the whole [100x2] matrix to the network at once.

The problem with this approach is that the network learns better with more data, and as the data grows it becomes almost impossible to process all of it as one giant matrix. For example, each grayscale MNIST image is [28x28x1] in size, and with, say, 50,000 training images, a float64 matrix of size [50000 x 784] alone takes about 314 MB, which is very heavy to process in one go.
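
To put numbers on this, the short plain-JavaScript sketch below computes the raw memory needed just to hold the pixel values, comparing one full batch of 50,000 images with a single mini-batch of 100. It counts only the input matrix (activations and gradients need more), and the helper name matrixBytes is hypothetical. TensorFlow.js stores tensors as float32 by default, so the float32 line is the closer estimate for this application.

```javascript
// Raw memory needed to hold MNIST pixel values as one dense matrix.
// Each 28x28x1 grayscale image is flattened into 784 values.
const PIXELS_PER_IMAGE = 28 * 28 * 1;
const MB = 1e6; // decimal megabytes

// rows x cols values, each taking bytesPerValue bytes.
function matrixBytes(rows, cols, bytesPerValue) {
  return rows * cols * bytesPerValue;
}

// Full batch: all 50,000 training images in a single matrix.
console.log(matrixBytes(50000, PIXELS_PER_IMAGE, 8) / MB); // float64: 313.6
console.log(matrixBytes(50000, PIXELS_PER_IMAGE, 4) / MB); // float32: 156.8

// Mini-batch: only 100 images (float32) are needed per iteration.
console.log(matrixBytes(100, PIXELS_PER_IMAGE, 4) / MB); // 0.3136
```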

There is, however, another variant of Gradient Descent called Stochastic Gradient Descent (SGD). In SGD, instead of using all the training data at once, you feed one randomly selected example at a time until all the training data has been used. So one epoch consists of n iterations, where n is the number of training examples. It does seem like training would be slower, doesn’t it? Actually, no. Because each update is computationally lighter and the parameters are updated in every iteration, SGD tends to converge much faster than Full Batch GD, although the results are usually not as optimal.

Then there is the middle ground, called Mini-Batch Gradient Descent. If we don’t want to use one example at a time, but we can’t use all the data at once, why not use a batch of examples, as many as fit in memory, in each iteration? That number of examples per iteration is the batch size. Mini-Batch GD typically lets the network converge in fewer epochs than Full Batch GD, because it updates the weights many times per epoch, and since it uses vectorized operations on each batch, it typically gives a computational performance gain over SGD.
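
The three variants differ only in how many examples are used per weight update. The following sketch, in plain JavaScript on a toy linear regression problem rather than MNIST, uses one loop for all three: batchSize = 1 is SGD, batchSize equal to the dataset size is Full Batch GD, and anything in between is Mini-Batch GD. The function name trainMiniBatch and the toy data are hypothetical. For reproducible results this sketch does not shuffle, whereas real training (including TensorFlow.js’s fit(), by default) shuffles the data every epoch.

```javascript
// Mini-batch gradient descent for y = w * x + b with mean squared error.
// batchSize = 1 behaves like SGD; batchSize = xs.length like Full Batch GD.
// No shuffling here, to keep the result deterministic.
function trainMiniBatch(xs, ys, { batchSize, epochs, learningRate }) {
  let w = 0;
  let b = 0;
  let iterations = 0;
  const n = xs.length;

  for (let epoch = 0; epoch < epochs; epoch++) {
    // One epoch = one pass over all n examples, split into batches.
    for (let start = 0; start < n; start += batchSize) {
      const end = Math.min(start + batchSize, n); // last batch may be smaller
      const m = end - start;
      let gradW = 0;
      let gradB = 0;

      // Forward pass (prediction and error) plus gradient accumulation,
      // using only the examples in this batch.
      for (let i = start; i < end; i++) {
        const error = w * xs[i] + b - ys[i];
        gradW += (2 / m) * error * xs[i];
        gradB += (2 / m) * error;
      }

      // One weight update per batch: this is one iteration.
      w -= learningRate * gradW;
      b -= learningRate * gradB;
      iterations++;
    }
  }
  return { w, b, iterations };
}

// 100 points on the line y = 2x + 1, with x in [0, 0.99].
const xs = Array.from({ length: 100 }, (_, i) => i / 100);
const ys = xs.map((x) => 2 * x + 1);

const result = trainMiniBatch(xs, ys, { batchSize: 10, epochs: 1000, learningRate: 0.3 });
console.log(result.iterations); // 10 batches x 1000 epochs = 10000
console.log(result.w, result.b); // approximately 2 and 1
```

Changing batchSize to 1 or to xs.length in the call above turns the same loop into SGD or Full Batch GD, and the iteration count changes accordingly (100 or 1 per epoch).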

Back to Main Page

Now that we know what batch size is, let’s continue and complete the training section.

As mentioned earlier, we add three columns here. The first column has inputs for the number of epochs and the batch size, and a button to begin training. The Train button stays disabled until the model has been initialized. Let’s give them default values of epoch = 1 and batch size = 100 (TensorFlow.js’s own default batch size for fit() is 32).

In the second and third columns, add a div section each to display the loss and accuracy graphs. Also add two badges to each column to display the training status, the total number of iterations to be performed, the training progress, and the training accuracy of the current iteration.

Evaluate Model

In the fourth section, we show the performance of the trained model on the test data in the form of per-class accuracy and a confusion matrix. Both visualizations help us better understand how well the trained model performs.

For that, add a button to start the evaluation, and two columns to display each visualization. Below that, add another button and a div section to display some sample images of test data and their predicted class.

Save/Load Model

After we train the model, we can save it so that we can use it later without retraining it from the start. When saving a model, TensorFlow.js creates two files: a .json file and a .weights.bin file.
The JSON file contains the structure/architecture of the model, while the BIN file stores the trained weight parameters.

Now add two columns, one to save and one to load the model. The first column only needs a Save Model button. The second column needs a button to load the model architecture from the JSON file, another button to load the weights from the BIN file, and a final button to process them.

Test Model

Last but not least is the part where the user can directly try out the trained model by giving it a new handwritten digit. Here we design three columns. The first column contains a canvas where the user can draw a new digit. The second column displays the cropped input image and the prediction. The last column displays a histogram of the prediction scores.

Helper Functions

Before we move on to creating the index.js file, let’s look at the provided utility functions that will make several things easier. They are grouped into MNIST Utility Functions and Drawing Utility Functions.

MNIST Utility Functions

There are four functions stored in the mnist_utils.js file.

  • function getModel(name)
    This function provides default templates for the ANN/CNN network architectures. It returns the architecture code as text: a 2-layer dense network for the ANN option, and a 3-layer network (2 convolution layers and 1 dense layer) for the CNN option. Both models use ReLU activations and a softmax output activation, and both are compiled with the RMSprop optimizer, categorical cross-entropy loss, and the accuracy metric.
  • function showExample(elementId, data, labels, prediction=null)
    This function displays several examples of the MNIST images passed in the data parameter inside the element with the given elementId, along with their labels and, optionally, the predicted labels.
  • function cropImage(img, width=140)
    This function is used during testing to crop the user’s handwritten input image and discard the excess white space, which improves the model’s performance. The width parameter specifies the size of the canvas used (140 pixels by default).
  • function firefoxSave(model)
    TensorFlow.js provides functions for saving and loading trained models. Unfortunately, at the time this tutorial was written, the provided save function only worked properly in the Chrome web browser; in other browsers such as Firefox, saving a model did not work properly. Therefore, we provide an additional function to help save models in other browsers such as Firefox.
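
The mnist_utils.js source is not shown here, so the following is only a hypothetical sketch of the cropping idea behind cropImage(), not its actual code: scan the image for the smallest rectangle that contains ink and cut everything else away. It works on a plain 2D array of grayscale values and assumes MNIST’s convention of a bright digit (values above 0) on a dark background of 0; if your canvas has dark ink on white, invert each value (255 - v) first. The names inkBoundingBox and cropToBox are made up for this example.

```javascript
// Hypothetical sketch, not the actual util.cropImage() code.
// image: 2D array of grayscale values in [0, 255], background = 0.

// Smallest rectangle containing every pixel brighter than threshold.
function inkBoundingBox(image, threshold = 0) {
  let top = Infinity, left = Infinity, bottom = -1, right = -1;
  for (let r = 0; r < image.length; r++) {
    for (let c = 0; c < image[r].length; c++) {
      if (image[r][c] > threshold) {
        top = Math.min(top, r);
        bottom = Math.max(bottom, r);
        left = Math.min(left, c);
        right = Math.max(right, c);
      }
    }
  }
  if (bottom < 0) return null; // nothing was drawn
  return { top, left, bottom, right };
}

// Keep only the rows and columns inside the box (bounds are inclusive).
function cropToBox(image, box) {
  return image
    .slice(box.top, box.bottom + 1)
    .map((row) => row.slice(box.left, box.right + 1));
}

const img = [
  [0, 0, 0, 0, 0],
  [0, 0, 255, 0, 0],
  [0, 128, 255, 0, 0],
  [0, 0, 255, 0, 0],
  [0, 0, 0, 0, 0],
];
const box = inkBoundingBox(img); // { top: 1, left: 1, bottom: 3, right: 2 }
console.log(cropToBox(img, box)); // 3 rows x 2 columns around the stroke
```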

Drawing Utility Functions

The helper function in the draw_utils.js file is initCanvas(canvasId), which adds functionality to the selected canvas so that it can receive mouse events and let the user draw on it. It accepts the ID of the canvas to draw on, as well as the size and color of the pen.

Download mnist_utils.js file here, and draw_utils.js file here

index.js

Again, create a new index.js file and put it in the part_3 directory. First, import all the necessary libraries as we’ve done before. In addition, import MnistData from mnist_data.js, all the functions from mnist_utils.js (referred to as util below), and initCanvas from draw_utils.js.

At this point, we can already see the view by executing

\ai_tfjs> parcel part_3\index.html

The view will look like this:

Load Data Button

Moving on, let’s add functionality to the Load Data button. Create an asynchronous click event handler; when the button is pressed, display a message that the MNIST data is being downloaded, then call the load() function to start the download.

There are a total of 65,000 MNIST images provided. So that training does not take too long, let’s use only 40,000 images as training data and another 10,000 as test data. When the data has been downloaded, update the displayed message, then show the first 8 images from the test data.

Init Button

First, add an event handler to the radio buttons that fetches the predefined model architecture using the util.getModel() function. Then declare model as a global variable. Next, create the click event handler for init-btn, which gets the model architecture written in the text area and executes the code string using the eval() function.

If the model initialization succeeds, display the architecture summary using the modelSummary() function and enable all the other buttons.

Train Button

Add an asynchronous event handler to train-btn. To start, display a message that the model is being trained, then fetch the epoch and batch size values.

Next, we get the MNIST training data using the getTrainData() function and, as explained earlier, divide it into multiple batches of the size given by the user. This lets us calculate and show how many iterations (forward and backward passes) will be executed for the given number of epochs. The total number of iterations is
iterations = ceil(number of training examples / batch size) × max epochs,
where rounding up accounts for a smaller last batch when the data size is not a multiple of the batch size.
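
As a quick check of this formula, here is a small plain-JavaScript helper (the name totalIterations is hypothetical) evaluated with the numbers used in this application: 40,000 training images and either this page’s default batch size of 100 or TensorFlow.js’s default of 32.

```javascript
// Number of weight updates (forward + backward passes) in mini-batch training.
// The last batch of an epoch may be smaller, so batches per epoch round up.
function totalIterations(numExamples, batchSize, epochs) {
  const batchesPerEpoch = Math.ceil(numExamples / batchSize);
  return batchesPerEpoch * epochs;
}

console.log(totalIterations(40000, 100, 1)); // 400 (this page's defaults)
console.log(totalIterations(40000, 32, 1)); // 1250 (TensorFlow.js default batch size)
console.log(totalIterations(40000, 128, 3)); // 313 batches x 3 epochs = 939
```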

Now we call the fit() function to train the model using the given training data, number of epochs, and batch size. During training, update the loss and accuracy graphs after each batch iteration completes.

At the end of the function, show a message that training is complete, then enable the Save button.

Evaluate and Show Example Button

Create an asynchronous click event handler for the Evaluate button. Inside it, fetch the test data using the getTestData() function and feed it into the predict() function to get the class predictions. Calculate the accuracy of the predictions and show it on the main page.

To show the detailed accuracy and the number of misclassified examples in each class, calculate the per-class accuracy and the confusion matrix using the metrics module of the tfjs-vis library. Display the results in the designated div sections using its show and render modules.
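
tfjs-vis computes these metrics for us, but it helps to see what they mean. The sketch below is a plain-JavaScript equivalent, not tfjs-vis’s implementation: it takes arrays of true and predicted class indices (for example, the argmax of the one-hot labels and of the predict() output) and builds the overall accuracy, the confusion matrix, and the per-class accuracy. The function names are hypothetical, and the “rows are true classes, columns are predicted classes” convention is an assumption of this sketch.

```javascript
// Metrics from arrays of class indices (0..numClasses-1).
// Convention here: rows = true class, columns = predicted class.

function accuracy(labels, preds) {
  let correct = 0;
  for (let i = 0; i < labels.length; i++) {
    if (labels[i] === preds[i]) correct++;
  }
  return correct / labels.length;
}

function confusionMatrix(labels, preds, numClasses) {
  const matrix = Array.from({ length: numClasses }, () => new Array(numClasses).fill(0));
  for (let i = 0; i < labels.length; i++) {
    matrix[labels[i]][preds[i]]++; // count one (true, predicted) pair
  }
  return matrix;
}

// Per-class accuracy = correct predictions for a class divided by the number
// of examples that truly belong to it (diagonal entry / row sum).
function perClassAccuracy(matrix) {
  return matrix.map((row, c) => {
    const count = row.reduce((sum, v) => sum + v, 0);
    return { classIndex: c, accuracy: count === 0 ? 0 : row[c] / count, count };
  });
}

// Tiny example with 3 classes instead of MNIST's 10.
const labels = [0, 0, 1, 1, 1, 2, 2, 2];
const preds = [0, 1, 1, 1, 2, 2, 2, 2];

console.log(accuracy(labels, preds)); // 0.75
const cm = confusionMatrix(labels, preds, 3); // rows: [1, 1, 0], [0, 2, 1], [0, 0, 3]
console.log(perClassAccuracy(cm).map((r) => r.accuracy)); // 0.5, 2/3 and 1
```

Off-diagonal entries show which digits get confused with which, for example how often a true 4 is predicted as a 9.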

For the Show Example button, add an event handler that displays another set of test images, this time along with the labels predicted by the model. Take 16 test images using the getTestData() function, and call the predict() function to get the predicted labels. Then use the util.showExample() function to display the result.

Save and Load Button

For the Save button, add an asynchronous event handler, then call the save() function with the argument ('downloads://') to download and save the model to your local directory.

As explained earlier, at the time this tutorial was written, the save() function only worked in the Chrome browser. To save a model in another browser such as Firefox, uncomment the util.firefoxSave(model) line.

For the Load button, also add an asynchronous event handler. Inside it, take the JSON and BIN files uploaded through the json-upload and weights-upload file inputs. Then call the loadLayersModel() function to build a new model from them. After the model is loaded, enable the Predict button for testing.

Predict and Clear Button

Finally, for the testing section, let’s first initialize predict-canvas by calling the initCanvas() function so that it can receive mouse input. Then, for the Clear button, add an event handler that clears the canvas by calling the clearRect() function.

For the Predict button, add an asynchronous event handler and extract the image drawn by the user from the canvas using TensorFlow.js’s browser.fromPixels() function. Then use the helper function util.cropImage() to crop and resize the image to fit the network’s input size. Show the cropped image on the preview canvas using the browser.toPixels() function.
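
After cropping, the image still has to be shrunk to MNIST’s 28x28 resolution. We do not know which resizing method util.cropImage() uses internally; the sketch below swaps in simple block averaging (each output pixel is the mean of the input pixels it covers), then scales the values from [0, 255] to [0, 1]. The function name resizeByBlockAverage and the [0, 1] scaling are assumptions, so match them to how your training data was normalized.

```javascript
// Hypothetical resize by block averaging, not the util.cropImage() code.
// image: 2D array of grayscale values in [0, 255]; returns outSize x outSize in [0, 1].
function resizeByBlockAverage(image, outSize) {
  const inH = image.length;
  const inW = image[0].length;
  const out = [];
  for (let i = 0; i < outSize; i++) {
    // Input rows covered by output row i (at least one row).
    const r0 = Math.floor((i * inH) / outSize);
    const r1 = Math.max(r0 + 1, Math.floor(((i + 1) * inH) / outSize));
    const row = [];
    for (let j = 0; j < outSize; j++) {
      // Input columns covered by output column j (at least one column).
      const c0 = Math.floor((j * inW) / outSize);
      const c1 = Math.max(c0 + 1, Math.floor(((j + 1) * inW) / outSize));
      let sum = 0;
      for (let r = r0; r < r1; r++) {
        for (let c = c0; c < c1; c++) sum += image[r][c];
      }
      // Average of the block, then scale from [0, 255] to [0, 1].
      row.push(sum / ((r1 - r0) * (c1 - c0)) / 255);
    }
    out.push(row);
  }
  return out;
}

// Usage: a 140x140 canvas where only rows 0-7, columns 0-4 are inked.
const canvasPixels = Array.from({ length: 140 }, (_, r) =>
  Array.from({ length: 140 }, (_, c) => (r < 8 && c < 5 ? 255 : 0))
);
const digit28 = resizeByBlockAverage(canvasPixels, 28);
console.log(digit28.length, digit28[0].length); // 28 28
console.log(digit28[0][0], digit28[1][0]); // 1 0.6
```

With a 140-pixel canvas each output pixel averages a 5x5 block, so a partially inked block (15 of 25 pixels here) becomes a gray value instead of disappearing.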

Cast the image into a tensor and feed it to the predict() function to get the class prediction. In addition, show the prediction scores as a bar chart on the main page.
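
The predict() output for one image is a vector of 10 scores (one per digit, from the softmax layer), and the predicted class is the index of the highest score. The plain-JavaScript sketch below shows that argmax step and converts the scores into { index, value } points, which, to our knowledge, is the data shape tfvis.render.barchart() accepts. The function names and the example scores are made up for illustration; with a real tensor you would first read its values, for example with dataSync().

```javascript
// Index of the highest score: the predicted digit.
function argMax(scores) {
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

// One { index, value } point per class, e.g. for a bar chart of scores.
// Array.from also accepts typed arrays such as a Float32Array.
function toBarChartData(scores) {
  return Array.from(scores, (value, index) => ({ index, value }));
}

// Made-up softmax scores for a single drawn digit.
const scores = [0.01, 0.02, 0.05, 0.8, 0.01, 0.03, 0.02, 0.03, 0.02, 0.01];
console.log(argMax(scores)); // 3
console.log(toBarChartData(scores)[3]); // { index: 3, value: 0.8 }
```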

Final Result

The MNIST Classifier application using TensorFlow.js is done. Try it out! If everything goes well, the display will look like the image below.

Do try saving the model after training and loading it later on. If you load the trained model after downloading the MNIST data and then check the evaluation and the prediction examples, you will see the same evaluation results as before.

That’s it.

And this is the end of our TensorFlow.js Tutorial series. You can play around with this library further and make something really incredible. In the next chapter, we’ll look at some awesome examples of applications already built with TensorFlow.js.


anditya.arifianto
Artificial Intelligence Laboratory — Telkom University

Lecturer at Telkom University, Software Engineer, Artificial Intelligence and Deep Learning Enthusiast