How We Used Transfer Learning to Create an Image Classifier in the Browser

Joshua Jarvis
Carnegie Mellon Robotics Academy
Aug 11, 2020 · 6 min read

Introduction

We recently released an interactive activity that provides an introductory understanding of how an Image Classifier might be used in a factory context. The goal of this activity is to produce as many ‘good’ widgets as possible while also minimizing ‘defective’ widgets. The robot can be trained to classify any part coming down the assembly line and can also be told to remove an object that is identified as ‘defective’.

Image Classifier Activity Gameplay

We developed the conveyor belt and camera scenes in Unity. We have an article that takes a deeper dive into how we like to do Unity and JS communication here.

Initially, we thought it might be a good idea to invoke an image classifier endpoint on a remote server. This would offload the computational demands from the client to the server, but it wouldn’t be without its own pitfalls. Our requirements called for ‘training’ the image classifier in real time, so communication over the network would have to be fast. This approach would also scale poorly: with many concurrent users, each client could send a few hundred requests. There had to be a better way. We talk about how you can host an ML model on a serverless backend in another article.

Fortunately for us, the answer to our problems was TensorFlow.js (TF). The most convenient part was that we could use an existing model, train it through transfer learning, and immediately run inferences against the classifier without leaving the browser.

Image Classification through Transfer Learning

TF provides two models that we can combine into a custom image classifier: the KNN classifier (KNN) and MobileNet.

The MobileNet model is an image classifier designed to run in low-resource environments like a browser. However, if we use the model directly to predict our game’s widget parts, we won’t get the predictions we want. We will instead get a label drawn from the very large, general-purpose training set that MobileNet was trained on.

Transfer learning allows us to take the existing capabilities of the MobileNet model and ‘teach it new tricks’. In our case, we would like to teach it to identify the difference between the defective and pristine widgets coming down the conveyor belt. We will need to use another model, the KNN, to achieve this.

The TF MobileNet API provides a convenient method to extract the features of an image. We can pair these features with a custom label such as ‘defective gear’ or ‘good gear’ and add the example to our KNN model. The KNN model classifies an input by looking at its K (3 by default for this model) nearest neighbors among the stored examples. For each prediction, we take the output of MobileNet and feed it into the KNN model to see which zone the image falls into. The KNN zone maps to the identifying label and gives us our custom prediction: ‘defective gear’ or ‘good gear’.

A KNN can be used to group the K nearest neighbors of a dataset
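To illustrate the idea behind the KNN step (this is not the TF.js implementation, just a minimal sketch), here is K-nearest-neighbor majority voting over feature vectors in plain JavaScript. The function names and the tiny two-dimensional ‘features’ are hypothetical stand-ins for the much larger activation vectors MobileNet produces:

```javascript
// Euclidean distance between two equal-length feature vectors.
const distance = (a, b) =>
  Math.sqrt(a.reduce((sum, v, i) => sum + (v - b[i]) ** 2, 0));

// Classify a query vector by majority vote among its k nearest examples.
const predictLabel = (examples, features, k = 3) => {
  const nearest = [...examples]
    .sort(
      (x, y) => distance(x.features, features) - distance(y.features, features)
    )
    .slice(0, k);
  const votes = {};
  nearest.forEach(({ label }) => {
    votes[label] = (votes[label] || 0) + 1;
  });
  return Object.keys(votes).reduce((best, label) =>
    votes[label] > votes[best] ? label : best
  );
};

// Each stored example pairs a feature vector with a custom label,
// mirroring how MobileNet activations plus a label go into the KNN.
const examples = [
  { features: [0.9, 0.1], label: 'good gear' },
  { features: [0.8, 0.2], label: 'good gear' },
  { features: [0.1, 0.9], label: 'defective gear' },
  { features: [0.2, 0.8], label: 'defective gear' },
];

console.log(predictLabel(examples, [0.85, 0.15])); // → 'good gear'
```

The TF.js KNN classifier does this bookkeeping for us, but the voting intuition is the same: a new image lands in whichever ‘zone’ its nearest labeled examples occupy.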

Implementing our Image Classifier

The image classifier activity was built with Unity, React, Redux, and TensorFlow.js. We used React for our interactive GUI, Redux to manage and store the activity’s state, and TensorFlow.js models to build our image classifier. The source of the project can be viewed here.

For this article we will mostly focus on the TensorFlow.js code used to build the image classifier.

The MobileNet model is loaded asynchronously via a Redux action when the game is loaded. The model takes some time to load so we want to make sure that the model is available before the activity begins.

// Redux Action
import * as MobileNetModule from '@tensorflow-models/mobilenet';

const loadMobileNet = (model) => {
  return {
    type: 'LOAD_MOBILENET',
    model,
  };
};

export const loadMobileNetAsync = () => {
  return (dispatch) => {
    MobileNetModule.load().then((m) => {
      dispatch(loadMobileNet(m));
    });
  };
};

// Redux Reducer
const featureExtractor = (state = {}, action) => {
  switch (action.type) {
    case "LOAD_MOBILENET":
      return action.model;
    default:
      return state;
  }
};

export default featureExtractor;

The KNN classifier can be instantiated synchronously and is initiated as part of the default state for the classifier reducer.

import * as knn from '@tensorflow-models/knn-classifier';

const classifier = (state = knn.create(), action) => {
  switch (action.type) {
    default:
      return state;
  }
};

export default classifier;

Adding Examples to the KNN

Each example comes from a binary file sent from Unity using our comms protocol between Unity and React. The MobileNet model (labeled as the feature extractor in the function) expects an Image, so we convert the binary from Unity into an image.

export const addExample = (binary, label) => {
  const { featureExtractor, classifier } = store.getState();
  preProcessImageWithCallback(binary, (img) => {
    const features = featureExtractor.infer(img, "conv_preds");
    classifier.addExample(features, label);
  });
};

The important part to note here is that we do this asynchronously and only invoke the callback once the image’s onload event fires. The callback function uses the MobileNet model (featureExtractor) to generate a set of extracted features that the KNN model (classifier) stores as an example.

const preProcessImageWithCallback = async (binary, callback) => {
  let img = new Image(800, 600);
  img.src = `data:image/png;base64, ${binary}`;
  img.onload = () => {
    callback(img);
  };
};

Predicting Images

Once the KNN has been built with a sufficient set of examples, predictions become remarkably accurate. The predict function returns a Promise that is resolved once the classifier.predict method returns a result. If the result label does not map to the KNN’s label classes, we reject the predict promise and rebuild the model.

We also send a defaultLabel class in cases where the predict method is invoked before the KNN has examples.

export const predict = (binary, uid) => {
  return new Promise((resolve, reject) => {
    const { featureExtractor, classifier, labelClasses } = store.getState();
    preProcessImageWithCallback(binary, (img) => {
      const features = featureExtractor.infer(img, "conv_preds");
      classifier
        .predictClass(features)
        .then((result) => {
          const labelClass = getLabelClassByName(
            labelClasses.list,
            result.label
          );
          const confidences = result.confidences;
          if (labelClass) {
            resolve({
              ...labelClass,
              uid,
              confidences,
            });
          } else {
            reject("rebuilding model");
          }
        })
        .catch((_) => {
          resolve({
            ...defaultLabelClass,
            uid,
            rgb: {
              a: 0,
              r: 0,
              g: 0,
              b: 0,
            },
          });
        });
    });
  });
};
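The getLabelClassByName helper referenced above isn’t shown in this article; a plausible sketch (the actual implementation lives in the project source and may differ) simply looks up the stored label object by name, returning undefined when the predicted label no longer exists so that the promise is rejected and the model is rebuilt:

```javascript
// Hypothetical sketch: find the full label object kept in the
// labelClasses reducer whose name matches the KNN's predicted label.
// Returns undefined when no label matches, which predict() treats as
// the signal to reject and rebuild the model.
const getLabelClassByName = (labelClassList, name) =>
  labelClassList.find((labelClass) => labelClass.name === name);

// Example label objects shaped like the ones predict() spreads into
// its resolved value (fields here are illustrative).
const labelClassList = [
  { id: 1, name: 'good gear', rgb: { r: 0, g: 255, b: 0, a: 255 } },
  { id: 2, name: 'defective gear', rgb: { r: 255, g: 0, b: 0, a: 255 } },
];

console.log(getLabelClassByName(labelClassList, 'defective gear').id); // → 2
```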

Reloading and Deleting Labels

Our activity interface allows the learner to delete individual labels. Unfortunately, this behavior was not directly supported in the KNN API. To support this feature, we had to delete all examples and then reload the examples from the Redux state for each deletion.

Clearing an entire set of class labels is pretty straightforward.

export const deleteExamples = (label) => {
  const { classifier } = store.getState();
  try {
    classifier.clearClass(label);
  } catch (e) {
    // Ignore errors, e.g. when the label has no stored examples.
  }
};

Reloading the examples involves taking all of the KNN’s class labels and adding the examples back to the KNN. The KNN’s classDatasetMatrices property only exposes the label names and does not provide the full object for each label. We store the full objects in the labelClasses reducer, iterate over each one, clear all of its examples, and then add each example image back. This is computationally inefficient and could be improved by adding direct support for individual label deletion to the KNN object. Fortunately, the challenge can be completed with a small number of examples and does not run into problems even though reloading is O(n²).

export const reloadAllExamples = () => {
  const { classifier, labelClasses } = store.getState();
  const labels = Object.keys(classifier.classDatasetMatrices);
  labelClasses.list.forEach((label) => {
    const labelObject = labels[label.id - 1];
    classifier.clearClass(labelObject);
    label.images.forEach((img) => addExample(img, label.name));
  });
};

Conclusion

Hopefully this article helps demystify the use of an image classifier in a JavaScript application. In review, we built an image classifier by first extracting the features of an image using the MobileNet model, feeding those features into a KNN model along with a label, and then invoking the KNN predict function on the MobileNet features of subsequent images to get a prediction. We managed to customize a model in real time without much overhead. We recommend TensorFlow.js models to any JS developer interested in conveniently leveraging ML in their applications without depending on a remote server. Be sure to check out other TF.js models for your next project.

This material is based upon work supported by the National Science Foundation under Grant Number 1937063.

Github Repository

Online Demo

Rapid Prototyping with Unity and React

Deploying a Machine Learning Model to a Serverless Backend with SAM CLI


Full Stack Ruby on Rails engineer at Carnegie Mellon Robotics Academy in Pittsburgh.