Image Classification with a Client-Side Neural Network Using TensorFlow.js

Deepak Puttarangaswamy
Agara
Apr 9, 2019

With a client-side neural network, we can build and train models in the browser that use the user's data locally. This gives the end user an application that is highly available and easy to interact with.

In this article, you will read about how to set up Keras’ MobileNet model, pre-trained on ImageNet, to classify images on the client side using TensorFlow.js.

Final application

What is TensorFlow.js?

TensorFlow.js is a library for developing and training ML models in JavaScript and deploying them in the browser or on Node.js.

Compared with some of the other libraries, TensorFlow.js offers a more complete API and better documentation for building and training complex models from scratch.

Some of the other comparable libraries:
1. ml5.js: It aims to make machine learning approachable for a broad audience of artists, creative coders, and students. The library provides access to machine learning algorithms and models in the browser, building on top of TensorFlow.js with no other external dependencies. You can get started with machine learning without worrying about low-level details like tensors or optimizers.

2. BrainJS: A library of neural networks written in JavaScript. It is a continuation of the harthur/brain repository (which is no longer maintained).

See the TensorFlow.js official site (listed in the references below) to learn more about TensorFlow.js.

Why a client-side neural network?

As web technology advances, it opens the door to unique web experiences and applications.

Keep user data local: With a client-side neural network, small models can run locally. User data stays on the client and does not have to be sent across the internet to be run through a model and return a response.

No dependencies: The application is highly usable and accessible to the end user. Anyone can use it without additional installations or prerequisites; it is simply accessed in a web browser via the URL where the frontend application is hosted.

Interactive and easy to use: Web-based applications are generally highly interactive and easy to use, so users can quickly get started and engage with the application.

The caveats: Because the models run in the browser, this approach is best suited to tasks like transfer learning, fine-tuning pre-trained models, and inference.

Transfer learning is the improvement of learning in a new task through the transfer of knowledge from a related task that has already been learned.
Here is a demo of transfer learning: you use your webcam as a controller to play Pac-Man, with a model trained on images captured in your browser. A pre-trained image classification model is further trained on these images in the browser to map them to the directions and controls of a joystick.

Loading and running very large models in a web application can cause performance issues. TensorFlow recommends keeping models used in the browser to 30 MB or less.

Check out some of the demos on the TensorFlow.js site.

Let's get started

We will build a web application in which the user chooses an image and submits it to our model; the app returns the five highest-scoring predictions for the image among the ImageNet classes.

I have created the frontend application for this example with React, Redux, and Bootstrap. Feel free to use the stack of your choice.

Here is the GitHub repo of the codebase.

Let's take a MobileNet model that is already built and trained with Keras and make use of it in the browser with TensorFlow.js.

First, we need to install the TensorFlow.js model converter tool. From a Python environment where Keras is already installed, run pip install tensorflowjs from the terminal.

Next, start a Jupyter notebook by running jupyter notebook from the terminal.

We are now in the Jupyter notebook. Let's import Keras and the TensorFlow.js library. We will create a MobileNet model, convert it, and save it in a directory. Below is the code from the notebook.

If everything goes well, you should see the converted model files in the specified output directory: a model.json file describing the model plus one or more binary weight shard files.
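To check what the converter produced, you can inspect model.json. The sketch below assumes the Layers-model format, in which a weightsManifest array lists the binary shard files (paths) along with each weight's name, shape, and dtype. It also assumes the weights are not quantized. The function name summarizeWeightsManifest is hypothetical.

```javascript
// Sketch: summarise converter output from model.json's weightsManifest.
// Assumes unquantized weights; bytes per element for each dtype below.
const BYTES_PER_ELEMENT = { float32: 4, int32: 4, bool: 1 };

function summarizeWeightsManifest(modelJson) {
  const shardFiles = [];
  let totalBytes = 0;
  for (const group of modelJson.weightsManifest) {
    // Each group lists the binary shard files that hold its weights.
    shardFiles.push(...group.paths);
    for (const w of group.weights) {
      // Number of elements = product of the shape (1 for a scalar []).
      const elements = w.shape.reduce((n, d) => n * d, 1);
      totalBytes += elements * BYTES_PER_ELEMENT[w.dtype];
    }
  }
  return { shardFiles, totalBytes };
}
```

In Node.js you could call it as summarizeWeightsManifest(JSON.parse(fs.readFileSync('output_dir/model.json', 'utf8'))). Here output_dir is a placeholder for the directory you passed to the converter. The result is a quick way to confirm the model stays within the 30 MB browser guidance.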

Loading the model

Now, let's load the model when our frontend application loads in the browser. I save the model in the browser's IndexedDB so that it is downloaded only once and then loaded from this cache on subsequent loads or refreshes of the application.

I am using Web Server for Chrome to host the model locally on port 8887.
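Here is a minimal sketch of the load-once-then-cache logic. It assumes the TensorFlow.js 1.x API, where you load with tf.loadLayersModel and cache with model.save using an indexeddb:// URL; in pre-1.0 releases the loader was called tf.loadModel. The model URL is built from the Web Server for Chrome port above, but both the URL and the cache key are hypothetical. tf is passed in as a parameter so the function can be exercised without a browser.

```javascript
// Sketch: load MobileNet from IndexedDB if cached, else from the local server.
// Assumes the TensorFlow.js 1.x API (tf.loadLayersModel, model.save).
// MODEL_URL and CACHE_KEY are hypothetical values for illustration.
const MODEL_URL = 'http://localhost:8887/model.json';
const CACHE_KEY = 'indexeddb://mobilenet';

async function loadMobileNet(tf, modelUrl = MODEL_URL, cacheKey = CACHE_KEY) {
  try {
    // Fast path: the model was saved to IndexedDB on an earlier visit.
    // `return await` keeps a rejected load inside this try block.
    return await tf.loadLayersModel(cacheKey);
  } catch (err) {
    // Cache miss: download model.json and its weight shards from the server.
    const model = await tf.loadLayersModel(modelUrl);
    // Persist to IndexedDB so later loads and refreshes skip the download.
    await model.save(cacheKey);
    return model;
  }
}
```

In the app you would call const model = await loadMobileNet(tf) once at startup, after import * as tf from '@tensorflow/tfjs'.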

Select an image

Let's provide an input element to browse for and select an image on the local disk. When an image is selected, we create an object URL for it and render the image as a preview. We then call the predictImage action, passing the img element as an argument.

Input element to choose an image
Image element to preview the selected image
On choosing a photo, call model to predict the label for the selected image
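The original app wires this up with React and Redux. Below is a framework-free sketch of the same flow using plain DOM APIs. The handler creates the object URL, shows the preview, and calls predictImage only after the image has loaded, so the pixels are available when the model reads them. predictImage stands in for the app's action. The function name and the injectable urlApi parameter, which exists only to make the code testable, are hypothetical.

```javascript
// Framework-free sketch of the "choose an image" flow.
// predictImage stands in for the app's Redux action (hypothetical signature:
// it receives the <img> element). urlApi defaults to the browser's URL object.
function showPreviewAndPredict(file, imgEl, predictImage, urlApi = URL) {
  // Object URL pointing at the local file; nothing is uploaded.
  const objectUrl = urlApi.createObjectURL(file);
  imgEl.onload = () => {
    // Wait until the image has loaded before handing it to the model.
    predictImage(imgEl);
  };
  // Setting src renders the preview and fires onload once it is ready.
  imgEl.src = objectUrl;
}
```

You might hook it up with input.addEventListener('change', (e) => showPreviewAndPredict(e.target.files[0], previewImg, predictImage)). Here input and previewImg are your own file input and img elements.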

Pre-processing the image

When we choose an input image in our application, we need to transform it into a rank-4 tensor of floats with a height and width of 224 x 224, i.e. shape [1, 224, 224, 3] (batch, height, width, RGB channels). The model expects this because MobileNet was pre-trained on ImageNet images of size 224 x 224.

To do this, we create a tensor from the image by calling the TensorFlow.js function tf.fromPixels() with our image element (in TensorFlow.js 1.0 and later, this function is tf.browser.fromPixels()). We then resize the image to 224 x 224, cast the tensor's type to float32, and expand the tensor's dimensions to rank 4.

Transforming an image to a tensor object
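Here is a sketch of this transformation. It assumes the TensorFlow.js 1.x API, where tf.fromPixels() became tf.browser.fromPixels(). It uses tf.image.resizeBilinear for the resize, although the article does not say which resize method the app uses. tf is passed in as a parameter, and the function name is hypothetical.

```javascript
// Sketch of the image -> tensor step (TensorFlow.js 1.x API assumed).
// `tf` is passed in, e.g. after `import * as tf from '@tensorflow/tfjs'`.
const IMAGE_SIZE = 224; // MobileNet's ImageNet input resolution

function imageToInputTensor(tf, imgEl) {
  // tf.tidy disposes the intermediate tensors created inside the callback
  // and keeps only the returned tensor.
  return tf.tidy(() => {
    const pixels = tf.browser.fromPixels(imgEl); // int32, shape [h, w, 3]
    const resized = tf.image.resizeBilinear(pixels, [IMAGE_SIZE, IMAGE_SIZE]); // [224, 224, 3]
    return resized
      .cast('float32') // ensure float32 values
      .expandDims(0);  // add batch axis: [1, 224, 224, 3]
  });
}
```

The returned tensor still holds values in 0 to 255. The MobileNet-specific scaling described next is applied to it before calling predict().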

MobileNet expects the image data to be further preprocessed in a specific way. Other libraries, such as Keras, include model-specific preprocessing functions in their APIs. TensorFlow.js does not include these functions, so we will build our own preprocessing step as follows:

The images that MobileNet was originally trained on were preprocessed so that their RGB values were rescaled from the range 0 to 255 to the range -1 to 1.

We do this by first creating a scalar value of 127.5, which is exactly half of 255. We then subtract this scalar from the original tensor and divide the result by the same scalar, which puts every value in the tensor in the range -1 to 1.

([0, 255] - 127.5) / 127.5 = [-127.5, 127.5] / 127.5 = [-1, 1]
Check out broadcasting to understand how a single scalar is applied to every element of the tensor.
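In TensorFlow.js this step can be written as tensor.sub(offset).div(offset) with offset = tf.scalar(127.5), and broadcasting applies the scalar to every element. The plain-JavaScript sketch below spells out that per-element arithmetic on a flat pixel array. The function name is hypothetical, and it is meant only to make the mapping concrete.

```javascript
// Per-element version of MobileNet's input scaling: maps [0, 255] to [-1, 1].
// Broadcasting in TensorFlow.js performs this same arithmetic on every element.
const OFFSET = 127.5; // half of 255

function scaleToMinusOneOne(pixels) {
  const out = new Float32Array(pixels.length);
  for (let i = 0; i < pixels.length; i++) {
    out[i] = (pixels[i] - OFFSET) / OFFSET;
  }
  return out;
}

console.log(Array.from(scaleToMinusOneOne([0, 127.5, 255])));
// [ -1, 0, 1 ]
```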

Getting a prediction

When the user selects an image, we call the action method that transforms the image into a tensor. We then get a prediction by calling predict() on the model with the transformed tensor. predict() returns a tensor containing the output predictions for the given input. Finally, we call data() on the predictions tensor, which asynchronously downloads the tensor's values and returns a Promise that resolves to a typed array.
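Here is a minimal sketch of this step. It assumes model is the loaded MobileNet and input is the preprocessed [1, 224, 224, 3] tensor. It also disposes the output tensor once its values have been read, because on the WebGL backend TensorFlow.js does not release tensor memory automatically when a tensor goes out of scope. The function name is hypothetical.

```javascript
// Sketch of the prediction step: run the model, then read the values.
async function predictProbabilities(model, input) {
  // Forward pass: output tensor of shape [1, 1000], one score per ImageNet class.
  const predictions = model.predict(input);
  try {
    // data() resolves to a typed array (Float32Array) of the 1000 values.
    return await predictions.data();
  } finally {
    // Free the output tensor's memory once its values have been copied out.
    predictions.dispose();
  }
}
```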

The resulting predictions array has 1000 elements, one per ImageNet class: the value at each index is the predicted probability of the ImageNet class at that index.

Next, we extract the five highest predictions and store them in a top5 variable. We iterate through the predictions array and create, for each element, an object with the prediction probability and the ImageNet className. We then sort the list of objects in descending order of probability and take the first five with the slice() function.
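The top-five selection just described can be sketched in a few lines. This assumes classNames is an array of the 1000 ImageNet labels in index order; where that list comes from in the app is not shown here. The function name topK is hypothetical.

```javascript
// Top-k selection: pair each probability with its ImageNet class name,
// sort in descending order of probability, and keep the first k entries.
// classNames is assumed to hold the 1000 ImageNet labels in index order.
function topK(probabilities, classNames, k = 5) {
  return Array.from(probabilities, (probability, i) => ({
    probability,
    className: classNames[i],
  }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, k);
}
```

With the typed array returned by data(), const top5 = topK(probabilities, classNames) gives the five entries to display. Sorting all 1000 items is cheap at this size, so a dedicated top-k algorithm is not needed.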

We then store the predictions in the application state and display them in the UI as a list, highlighting each prediction with a background color based on its probability score.
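The article does not specify its color scheme, so the thresholds below are hypothetical. The sketch maps a probability to one of Bootstrap 4's contextual list-group classes (list-group-item-success, list-group-item-warning, list-group-item-danger), which fits the Bootstrap-based frontend.

```javascript
// Hypothetical thresholds for coloring each prediction row in the list.
// Returns a Bootstrap 4 list-group contextual class name.
function backgroundClassFor(probability) {
  if (probability >= 0.5) return 'list-group-item-success'; // confident
  if (probability >= 0.2) return 'list-group-item-warning'; // plausible
  return 'list-group-item-danger';                          // unlikely
}
```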

Demo

We’ve finished building our application. You will see the predictions as soon as you select an image. Notice how quickly MobileNet generates these results.

The trained MobileNet model used in this example is about 17 MB in size. MobileNets have a reduced number of parameters (4.2 million), run faster, and are well suited to mobile applications: they are small, low-latency convolutional networks.
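The 17 MB figure follows from the parameter count. The back-of-the-envelope check below assumes unquantized float32 weights (4 bytes each) and decimal megabytes (1 MB = 10^6 bytes), and compares the result with the 30 MB browser guidance quoted earlier. The function name is hypothetical.

```javascript
// Back-of-the-envelope size check: parameters x bytes per parameter.
// Assumes unquantized float32 weights (4 bytes) and 1 MB = 1e6 bytes.
const BYTES_PER_FLOAT32 = 4;
const BROWSER_BUDGET_MB = 30; // guidance quoted earlier in the article

function estimateWeightsMB(numParams, bytesPerParam = BYTES_PER_FLOAT32) {
  return (numParams * bytesPerParam) / 1e6;
}

const mobileNetMB = estimateWeightsMB(4.2e6);
console.log(mobileNetMB.toFixed(1), mobileNetMB <= BROWSER_BUDGET_MB);
// 16.8 true
```

So 4.2 million float32 parameters come to about 16.8 MB, consistent with the roughly 17 MB model and comfortably under the 30 MB suggestion.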

Although MobileNet is smaller, has fewer parameters, and runs faster, it is less accurate than other state-of-the-art networks, as discussed in this paper. The reduction in accuracy compared with those networks, however, is only slight.

Hopefully, this article has illustrated how practical it is to classify images on the client side with MobileNet and TensorFlow.js. Let me know in the comments if you have anything to share.

Some results

References and useful links:
TensorFlow.js official site: https://js.tensorflow.org/
Keras: https://keras.io/
MobileNet with TensorFlow.js tutorial: https://gogul09.github.io/software/mobile-net-tensorflow-js
Deeplizard: http://deeplizard.com/learn/playlist/PLZbbT5o_s2xr83l8w44N_g3pygvajLrJ-
Stanford’s CS231n (Convolutional Neural Networks for Visual Recognition): http://cs231n.stanford.edu/
