Bring Machine Learning to the Browser With TensorFlow.js — Part III
How to go from a web friendly format to a web application
Welcome to part three of a series of posts where I walk you through how TensorFlow.js makes it possible to bring machine learning to the browser. First, there’s an overview of how to bring a pre-trained model into a browser application. Then you’ll find greater detail on how to convert your pre-trained model to a web friendly format. Now in this post, we step through using that web friendly model in a web application.
Importing the model
The first step in importing the model is to include the TensorFlow.js library in your HTML via a script tag.
This will load the latest version of TensorFlow.js but you can also target a specific version or load it via NPM.
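As a sketch, the include can look like the following (the CDN shown is one common option; you can also pin a specific version in the URL):

```html
<!-- Load the latest version of TensorFlow.js from a CDN -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
```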
With the library loaded, a global tf variable becomes available for accessing its API. You can load the Image Segmenter model using the loadFrozenModel function.
Frozen models combine the graph definition and a set of checkpoints into a single file.
Pass the URLs of the dataflow graph (tensorflowjs_model.pb) and the weights manifest file (weights_manifest.json) to the loadFrozenModel function. The shard files generated by the converter are in the same location as the weights manifest file. If needed, you can store and serve the shard files from a different URL, but you will then need to edit weights_manifest.json and update the paths to the shard files.
For other models, you may need to use the loadModel API instead.
Depending on the model size, loading may take some time. Once loaded, the model is ready to accept inputs and return a prediction.
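A minimal loading sketch, assuming the converted model files are hosted alongside the page under a /model path (the paths and function name below reflect the 0.x-era TensorFlow.js API used in this series and are illustrative):

```javascript
// Paths to the converted model files (illustrative; adjust to where you host them)
const MODEL_URL = '/model/tensorflowjs_model.pb';
const WEIGHTS_URL = '/model/weights_manifest.json';

let model;

async function loadModel () {
  // loadFrozenModel fetches the graph definition and the weight shards
  // listed in the manifest, then resolves to a model ready for inference
  model = await tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL);
}
```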
Pre-processing the input
Models will need the inputs to be of a specific type and/or size. In most cases, the input needs some sort of pre-processing before sending it to the model. For example, some models may require a one-dimensional array of a certain size while others may require more complex multi-dimensional inputs. So the input (e.g., image, sentence, etc.) would need to be pre-processed to the expected format.
For the Image Segmenter, recall that when inspecting the model graph, the input was an ImageTensor of type and shape uint8[1, ?, ?, 3]. This is a four-dimensional array of 8-bit unsigned integer values. The ?s are placeholders and can represent any length. They correspond to the height and width of the image. And the 3 corresponds to the three RGB values for a given pixel.
For 8-bit unsigned integers, valid values range from 0 to 255. This corresponds with an image's pixel RGB values, which also range from 0 to 255. So, we should be able to take an image, convert it to a multi-dimensional array of RGB values, and send that to the model.
To get a Tensor with the pixel values, you can use the fromPixels function provided by TensorFlow.js. This produces a three-dimensional array with the shape [?, ?, 3] from the given HTMLImageElement. But the Image Segmenter is expecting a four-dimensional array. To insert the extra dimension and get the shape needed, you will also need to call the expandDims function.
You should now have the required input data to run the model.
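Putting the pre-processing together, a sketch might look like this (the element id is illustrative, and fromPixels reflects the 0.x-era API; newer releases expose it as tf.browser.fromPixels):

```javascript
// Grab the image element to segment (the id is an assumption for this sketch)
const img = document.getElementById('input-image');

// fromPixels produces a [height, width, 3] Tensor of RGB pixel values;
// expandDims(0) adds the batch dimension for the [1, height, width, 3] shape
const input = tf.fromPixels(img).expandDims(0);
```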
Running the model
Run the model by calling
predict with the input data. The function takes the input Tensor(s) and some optional configuration parameters. It returns a prediction.
Computations are done in batches. If needed, you can also run prediction on a single batch.
Depending on the model complexity, the prediction may take some time.
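Continuing the sketch, running the loaded model on the pre-processed input could look like this (model and input are the variables from the earlier snippets):

```javascript
// Run inference; the result is a Tensor holding the segmentation map
const output = model.predict(input);

// Pull the raw values out of the Tensor as a flat typed array
const segments = output.dataSync();
```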
Processing the output
The type and shape of the output returned depend on the model. To do something meaningful, extra processing of the prediction is most likely required.
For the Image Segmenter, the output is a segmentation map with integers between 0 and 20. Each integer corresponds to one of the pre-defined labels for each pixel in the input image.
In our web application, we are going to overlay the original image with the segments found, with each segment color coded. For example, RGB (192, 0, 0) for chairs and RGB (0, 64, 0) for potted plants.
With a color map, go through the converted array and assign the appropriate color to each segment. Then use this data to create the desired overlay image.
You can now add the resulting image to your HTML page.
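The color-mapping step can be sketched in plain JavaScript. Here the segmentation map is treated as a flat array of label integers, and COLOR_MAP is a hypothetical lookup containing only the two labels mentioned above (9 and 16 are the chair and potted plant indices in the Pascal VOC label ordering; any other label falls back to black):

```javascript
// Hypothetical color map: label index -> [R, G, B]
// (chair and potted plant colors taken from the text; others default to black)
const COLOR_MAP = {
  9:  [192, 0, 0],  // chair
  16: [0, 64, 0]    // potted plant
};

// Convert a flat segmentation map into RGBA pixel data for a canvas overlay
function segmentationToPixels (segMap) {
  const pixels = new Uint8ClampedArray(segMap.length * 4);
  for (let i = 0; i < segMap.length; i++) {
    const [r, g, b] = COLOR_MAP[segMap[i]] || [0, 0, 0];
    pixels[i * 4] = r;
    pixels[i * 4 + 1] = g;
    pixels[i * 4 + 2] = b;
    // Keep the background (label 0) transparent so the original image shows through
    pixels[i * 4 + 3] = segMap[i] === 0 ? 0 : 255;
  }
  return pixels;
}
```

The resulting array can be wrapped in an ImageData object and drawn onto a canvas positioned over the original image.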
Completing the web application
To complete the application, add buttons to load the model, upload an image, and run the model. Also, add the code to overlay the input image and output prediction.
You can find the completed project here. The repository contains the demo web application. It also includes the web friendly model produced by the tensorflowjs_converter. You will also find a Jupyter notebook for experimenting with the Image Segmenter in Python.
The ability to use machine learning technology on the web is often limited. Creating and training some models involves massive amounts of data and intense computations, and the browser may not be the ideal environment for that. But an exciting use case is to take models trained elsewhere, then import and run them in the browser.