Browser-based Models with TensorFlow.js

A key machine learning skill is not just building models but also deploying them. One of the most exciting deployment targets is JavaScript, where you can train, test, and run inference right in your web browser.

Tirth1306
Analytics Vidhya
9 min read · Sep 3, 2020


Source — Freepik

You might think that training a machine learning model and running inference with it requires a big data center, or at least a powerful GPU in your machine. But with modern web browsers and today's computers, we can train a model and deploy it directly in the browser. We can upload an image to the page or grab one from a webcam, and then have a model run inference right there, without sending the image to a server in the cloud for processing.

Design and architecture of TensorFlow.js

Source — Coursera

TensorFlow.js is designed to run in the browser as well as on a Node.js server. Its Layers API looks and feels a lot like Keras, although, because we are coding in JavaScript, the syntax differs slightly from what we are used to in Python. The low-level APIs are called the Core API. They are designed to work with the TensorFlow SavedModel format, which is meant to be a standard file format that can be used across the JavaScript and Python APIs, and even TensorFlow Lite for mobile and embedded devices. In the browser, the Core API can take advantage of WebGL for accelerated training and inference. On Node.js, you can build server-side or terminal applications that take advantage of CPUs, GPUs, and TPUs, depending on your machine.

Simple Neural Network in Browser

The first thing you'll need to do is add a script tag to the head of your HTML file, before any script that uses it, to load the TensorFlow.js file. For example, <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script> loads the latest release from the jsDelivr CDN.

This script must be included in your HTML page to use TensorFlow.js

The first line of the model code defines the model as sequential. The simplest possible neural network is one layer with one neuron, so we add just one dense layer to the sequence. This dense layer has a single neuron, as you can see from the units: 1 parameter. We then compile the network with mean squared error as the loss function, which works well for a linear relationship, and SGD (stochastic gradient descent) as the optimizer. model.summary() simply prints a summary of the model definition.
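The model described above can be sketched roughly as follows. This is a minimal sketch, not the original gist: the function name buildModel is hypothetical, and tf is passed in as a parameter only to make the dependency explicit (in the browser it is the global created by the TensorFlow.js script tag). The inputShape of [1] is an assumption: one number per training example.

```javascript
// Sketch only: `buildModel` is a hypothetical helper name.
// `tf` is the TensorFlow.js namespace (a global in the browser).
function buildModel(tf) {
  const model = tf.sequential();
  // One dense layer with a single neuron; inputShape [1] (assumed) means
  // each example is a single number.
  model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
  // Mean squared error loss with stochastic gradient descent.
  model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });
  return model;
}

// Only runs where TensorFlow.js has been loaded, e.g. in the browser.
if (typeof tf !== 'undefined') {
  const model = buildModel(tf);
  model.summary(); // prints the layer table to the console
}
```

Notice that layer options are passed as a plain object with camelCase keys, which is the main visible difference from the Keras version.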

Making a Model in JavaScript

Also, dotraining() should be an asynchronous function, because training can take an indeterminate amount of time and we don't want to block the browser while it runs. You call this function, passing it the model you just created, and once it completes you can make predictions.
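A training function of this kind might look like the sketch below. The text calls it dotraining(); it is written doTraining here, following the usual JavaScript camelCase convention. The epoch count of 500 and the extra xs and ys parameters are my assumptions, not values taken from the original code.

```javascript
// Sketch only: `model` is any compiled TensorFlow.js layers model,
// `xs` and `ys` are the training tensors.
async function doTraining(model, xs, ys) {
  // model.fit returns a promise, so `await` keeps the browser responsive.
  const history = await model.fit(xs, ys, {
    epochs: 500, // assumed value, tune it for your own data
    callbacks: {
      // Called after every epoch with the current loss.
      onEpochEnd: async (epoch, logs) => {
        console.log(`Epoch: ${epoch} Loss: ${logs.loss}`);
      },
    },
  });
  return history;
}
```

Because the function is async, anything that needs the trained model should run after its promise resolves, for example inside doTraining(model, xs, ys).then(...).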

Training the model in JavaScript

Next comes the data that you'll use to train the neural network. You'll notice that we define it with tf.tensor2d, whereas in Python we could use a NumPy array. We don't have NumPy in JavaScript, so we go a little lower level. As its name suggests, tf.tensor2d creates a two-dimensional tensor, which you can describe either with a nested two-dimensional array or with two one-dimensional arrays: the values and their shape. Here the training values are in the first array and the second array is the shape of those values. I'm using a set of 6 values in a one-dimensional array, so the second parameter is [6, 1], and we do the same for y. If you tweak this code to add or remove values, remember to update the shape array to match.
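Here is a small sketch of that idea. The six numbers are illustrative values I picked (they follow y = 2x - 1), not necessarily the ones from the original code, and matchesShape is a hypothetical helper, not part of TensorFlow.js. It just checks that the flat array and the shape agree before the tensors are created.

```javascript
// Hypothetical helper: true when a flat array holds exactly as many
// values as the shape describes (6 * 1 = 6 for shape [6, 1]).
function matchesShape(values, shape) {
  const expected = shape.reduce((a, b) => a * b, 1);
  return values.length === expected;
}

// Illustrative data (assumed): y = 2x - 1.
const xValues = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0];
const yValues = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0];
const shape = [6, 1]; // 6 rows, 1 column

console.log(matchesShape(xValues, shape)); // true

// Only runs where TensorFlow.js has been loaded, e.g. in the browser.
if (typeof tf !== 'undefined') {
  const xs = tf.tensor2d(xValues, shape);
  const ys = tf.tensor2d(yValues, shape);
  xs.print();
  ys.print();
}
```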

Predicting in JavaScript

After the asynchronous dotraining() function completes, we can make predictions with model.predict by passing the value to be predicted as a tensor2d. Now you are ready to run your first simple neural network in the web browser. Just open your HTML file and you will see the training progress in the console along with the summary of the model, and the predicted output appears as an alert message.
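The prediction step could be sketched like this. The helper name predictOne is hypothetical, and tf is again passed in explicitly. It assumes the model was trained on single-number inputs, so the value is wrapped in a [1, 1] tensor.

```javascript
// Sketch only: `predictOne` is a hypothetical helper name.
function predictOne(tf, model, x) {
  // One row, one column: the same rank the model was trained on.
  const input = tf.tensor2d([x], [1, 1]);
  const output = model.predict(input);
  // dataSync() copies the result out of the tensor into a typed array.
  return output.dataSync()[0];
}

// Browser usage, once training has finished (assumes a trained `model`):
//   alert(predictOne(tf, model, 10));
```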

Demo of the whole code (Simple Neural Network in Browser)

Image Classification In the Browser

Now we will look at image processing in JavaScript: training a convolutional neural network for image classification in the browser, and then writing a browser app that takes images and passes them to the classifier. We'll start by creating the convnet in JavaScript. The code will mostly look familiar if you know Keras, but there are a few minor differences.
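A convnet along these lines might be sketched as follows. The layer sizes, the 28x28x1 input shape (grayscale digits), and the function name buildConvNet are my assumptions for illustration, not the article's exact architecture.

```javascript
// Sketch only: `buildConvNet` and all layer sizes are assumptions.
function buildConvNet(tf) {
  const model = tf.sequential();
  // 28x28 grayscale images, so the input shape is [28, 28, 1].
  model.add(tf.layers.conv2d({
    inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: 'relu',
  }));
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 16, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
  // Ten output classes, one per digit.
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
  model.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  return model;
}
```

The minor differences from Keras are mostly in naming: options go in an object, keys are camelCase (kernelSize, poolSize), and the loss is spelled 'categoricalCrossentropy'.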

Making a Model in JavaScript

Training a model is done with the fit method of the model object. You pass it the training data and labels, as well as an object of parameters (a dictionary, in Python terms). Batching the data for training, instead of flooding the model with all of it at once, is always a good idea. In the browser it's an even better idea, so that you don't lock up the browser itself. If you want the model to validate as it trains and report back an accuracy, you can pass a list of validation data and labels. You can also specify the number of epochs that you want to train for, and you can shuffle the data so that the order of the examples does not bias training. As always, you can specify callbacks to update the user on the training status, and for that TensorFlow.js has a cool companion library called tfjs-vis.
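The parameters described above could be put together roughly like this. The function name trainClassifier, the shape of the data object, and the concrete numbers (batch size 512, 20 epochs) are assumptions for illustration.

```javascript
// Sketch only: `trainClassifier` and the `data` object layout are
// hypothetical; batch size and epoch count are assumed values.
function trainClassifier(model, data, callbacks) {
  return model.fit(data.trainXs, data.trainYs, {
    batchSize: 512, // train in batches so the browser stays responsive
    validationData: [data.testXs, data.testYs], // reported as val_loss / val_acc
    epochs: 20,
    shuffle: true, // reshuffle the training examples every epoch
    callbacks: callbacks, // e.g. logging or chart callbacks
  });
}
```

The function returns the promise from model.fit, so the caller can await it before using the trained model.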

Feeding the data

— Visualization tool in JavaScript (tfjs-vis)

TensorFlow.js has some extra tools that make visualizing the training a lot friendlier. First, you include the library called tfjs-vis in your page with a script tag, for example <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>.

Including the script for better visualization

To use the visualization library with model.fit, you declare the callbacks to be the return value of tfvis.show.fitCallbacks. This function requires a container in which it will render the feedback and a set of metrics that it should track. It's as straightforward as setting the metrics list to the metrics you want to capture, such as loss, validation loss, accuracy, and validation accuracy. For the container, you just set a name and any required styles, and the visualization library will create those elements in the browser.

Setting all the parameters for tfjs-vis
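As a rough sketch, the callbacks could be created like this. The helper name makeFitCallbacks, the container name, and the height are assumptions. tfvis is the global that the tfjs-vis script creates in the browser, passed in here as a parameter.

```javascript
// Sketch only: `makeFitCallbacks`, the container name and the style
// values are assumptions for illustration.
function makeFitCallbacks(tfvis) {
  // Metric names as they appear in the training logs.
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  // tfjs-vis creates this panel in the page for us.
  const container = { name: 'Model Training', styles: { height: '1000px' } };
  return tfvis.show.fitCallbacks(container, metrics);
}

// Browser usage (assumes a compiled `model` and training tensors):
//   model.fit(trainXs, trainYs, { epochs: 20, callbacks: makeFitCallbacks(tfvis) });
```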

— Sprite Sheet

Part of the MNIST Sprite Sheet ( Source — Sprite Sheet )

You might be familiar with the MNIST database, which contains tens of thousands of images of handwritten digits. When training a model in Python, each image can be read from the file system and fed to the model quickly and easily. In the browser, however, every image you load triggers an HTTP request, and making thousands of HTTP requests is not good practice. A good solution is to pack the whole MNIST dataset into a sprite sheet, which makes it very useful for training a classifier in JavaScript. The sprite sheet contains all 70,000 images in a single image, stacked on top of one another, so the browser downloads just one file. It is then sliced into individual images, and each individual image is converted into an array and processed by the model.
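The slicing step is plain index arithmetic, which the sketch below shows on a tiny synthetic sprite. It assumes the sprite has already been decoded into one flat array of pixel values in which each image occupies a consecutive run of imageSize values. The name sliceImage and the synthetic data are hypothetical.

```javascript
const IMAGE_SIZE = 28 * 28; // 784 pixels per MNIST digit

// Hypothetical helper: returns a view (no copy) of the pixels that
// belong to image number `index` in the flat sprite array.
function sliceImage(spritePixels, index, imageSize = IMAGE_SIZE) {
  const start = index * imageSize;
  const end = start + imageSize;
  if (!Number.isInteger(index) || index < 0 || end > spritePixels.length) {
    throw new RangeError(`image ${index} is outside the sprite sheet`);
  }
  return spritePixels.subarray(start, end);
}

// Tiny synthetic "sprite": 3 images of 4 pixels each.
const sprite = Float32Array.from([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]);
console.log(Array.from(sliceImage(sprite, 1, 4))); // [ 1, 1, 1, 1 ]
```

With real MNIST data, each slice of 784 values would then be turned into a tensor of shape [28, 28, 1] before it is handed to the model.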

Below is the Demo of Handwriting Classification using the MNIST dataset.

Demo of Image Classifier in Browser ( Whole code here)

Converting Models to JSON Format

We can take models that have been created with TensorFlow in Python and convert them to JSON format so that they can run in the browser using JavaScript. For this, we first have to install the tensorflowjs package in Python (for example, with pip install tensorflowjs).

Installing tensorflowjs in python

Once you have created and trained your model, you need to save it. We'll start by generating a directory for the file, named with a timestamp: we import time, get the current timestamp, and save the model to the path /tmp/saved_models/ followed by the timestamp. The next line then converts the saved model to JSON format. Depending on your model, the output may contain several files besides model.json. Next, you'll need to download these files and put them in the same directory as the HTML page that will host them. The TensorFlow.js converter documentation describes other conversion methods.

Code to convert the model into JSON format

First of all, the model has to be loaded over HTTP, from a URL. Because it's in the same directory as the HTML, we could just write the file name, but I still use the full URL path. Be sure to get this part right. To fetch the JSON and turn it into a model, I call await tf.loadLayersModel(), passing it that MODEL_URL. Once this completes, I have a trained model available. Now I can inspect the model by calling model.summary() and get results by calling model.predict(). Yeah, that's it. You can now build any type of model in Python and use it in JavaScript by converting it into JSON format.
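Loading the converted model might look like the sketch below. The function name loadConvertedModel is hypothetical, and the URL in the usage comment is only a placeholder for wherever your own model.json is served from.

```javascript
// Sketch only: `loadConvertedModel` is a hypothetical helper name.
async function loadConvertedModel(tf, modelUrl) {
  // Fetches model.json (and the weight files it references) over HTTP.
  const model = await tf.loadLayersModel(modelUrl);
  model.summary(); // inspect the layers that were loaded
  return model;
}

// Browser usage (placeholder URL, replace it with your own):
//   const MODEL_URL = 'http://localhost:8080/model.json';
//   loadConvertedModel(tf, MODEL_URL).then((model) => {
//     // call model.predict(...) here
//   });
```

Opening the HTML page straight from the file system will usually not work here, because the model is fetched over HTTP, so serve the folder with a local web server.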

Code to load your converted model in JavaScript

Transfer Learning in JavaScript

Transfer learning is a technique in which we apply knowledge obtained from one problem to another problem. Here, as before, you load the JSON model from its hosted URL with tf.loadLayersModel. I am using a hosted URL of a pre-trained MobileNet. From the loaded MobileNet you can then pick one of its layers to serve as the output. We select the layer called conv_pw_13_relu; everything above it, up to and including that layer, stays frozen. We then use tf.model to make a new model. It takes inputs and outputs, which we set to the MobileNet inputs (the top of the MobileNet) and the output of conv_pw_13_relu, so that everything beneath that layer is ignored when we connect a new set of layers to this model.
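The truncation step could be sketched as follows. The helper names are hypothetical, and the URL is an assumption on my part: it is the MobileNet v1 model hosted for the TensorFlow.js examples, which may not be the exact one used in this project.

```javascript
// Assumed URL of a hosted, pre-trained MobileNet.
const MOBILENET_URL =
  'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';

// Hypothetical helper: keeps everything from the model inputs down to
// the named layer and drops the layers after it.
function truncateModel(tf, baseModel, layerName) {
  const layer = baseModel.getLayer(layerName);
  return tf.model({ inputs: baseModel.inputs, outputs: layer.output });
}

// Hypothetical helper: loads MobileNet and cuts it at conv_pw_13_relu.
async function loadTruncatedMobileNet(tf, url = MOBILENET_URL) {
  const mobilenet = await tf.loadLayersModel(url);
  return truncateModel(tf, mobilenet, 'conv_pw_13_relu');
}
```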

Loading the layer up to which you want to use a pre-trained model for Transfer Learning

Unlike in Python, where we would add a new set of densely connected layers underneath the frozen layers of the original model, in JavaScript we create a new model whose input shape is the output shape of the chosen MobileNet layer. We then treat this as a separate model that we train.
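A separate model of this kind might be sketched like this. The name buildHead, the 100 hidden units, and the learning rate are assumptions for illustration. The only essential part is that the first layer's inputShape is taken from the output shape of the truncated MobileNet, minus the batch dimension.

```javascript
// Sketch only: `buildHead` and the hyperparameters are assumptions.
function buildHead(tf, truncatedModel, numClasses) {
  // e.g. [null, 7, 7, 256] -> [7, 7, 256] (drop the batch dimension)
  const inputShape = truncatedModel.outputs[0].shape.slice(1);
  const model = tf.sequential({
    layers: [
      tf.layers.flatten({ inputShape: inputShape }),
      tf.layers.dense({ units: 100, activation: 'relu' }),
      tf.layers.dense({ units: numClasses, activation: 'softmax' }),
    ],
  });
  model.compile({
    optimizer: tf.train.adam(0.0001),
    loss: 'categoricalCrossentropy',
  });
  return model;
}
```

This new model is trained on the embeddings produced by the truncated MobileNet, not on the raw images.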

Extending the pre-trained model according to your requirement using Transfer Learning

At prediction time, we first get a prediction from our truncated MobileNet, up to the layer we chose, which gives us a set of embeddings. We then pass those embeddings through the new model to get a prediction for the classes that the new model was trained on. It's a little bit different from what you might be used to.
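The two-step prediction could be sketched as below. The function name classify is hypothetical, and imageTensor is assumed to be an already preprocessed batch containing one image.

```javascript
// Sketch only: `classify` is a hypothetical helper name.
function classify(tf, truncatedModel, headModel, imageTensor) {
  // tf.tidy frees the intermediate tensors created inside the callback.
  return tf.tidy(() => {
    // Step 1: the truncated MobileNet turns the image into embeddings.
    const embeddings = truncatedModel.predict(imageTensor);
    // Step 2: the new model turns the embeddings into class probabilities.
    const probabilities = headModel.predict(embeddings);
    // Index of the most likely class, returned as a scalar tensor.
    return probabilities.as1D().argMax();
  });
}

// Browser usage: read the class index out of the returned tensor.
//   const classId = (await classify(tf, truncated, head, img).data())[0];
```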

Rock Paper Scissors Lizard Spock

Source — Fansite

Rock Paper Scissors Lizard Spock is an extension of the traditional game of chance Rock Paper Scissors, created by Karen Bryla and Sam Kass. Sam Kass explains that he designed the extended game because it seemed that most games of Rock Paper Scissors would end in a tie.

“Scissors cuts paper, paper covers rock, rock crushes lizard, lizard poisons Spock, Spock smashes scissors, scissors decapitates lizard, lizard eats paper, paper disproves Spock, Spock vaporizes rock, and as it always has, rock crushes scissors.” — Sheldon

Source — The Big Bang Theory

I have developed a classifier that uses a pre-trained MobileNet model to classify the hand gestures Rock, Paper, Scissors, Lizard, and Spock captured by a webcam. This project uses all the concepts discussed above, and below is the demo of my project along with a code explanation.

Demo for hand gesture classifier ( Whole code here)

Conclusion

With the help of TensorFlow.js, we can train, test, and validate a model in the browser using JavaScript. We also learned about tools like tfjs-vis, which gives better visualization of training in the browser, and techniques like the sprite sheet, which greatly reduces the number of HTTP calls to the server. Then we learned how to convert a pre-trained model to JSON format and use the converted model in JavaScript. We also saw transfer learning in JavaScript, which is slightly different from what we do in Python. If you want to go deeper into this domain, I would recommend this course by Coursera.

Tirth Patel — Computer Science and Engineering Student, Nirma University.

LinkedIn | Github| Instagram
