Classification Model on a Custom Dataset Using TensorFlow.js Made Simple: Here’s What You Need to Know
Build a model from scratch and use it to get predictions in the browser.
So, I recently had a chance to try out TensorFlow.js, a library for machine learning in JavaScript. You can develop ML models in JavaScript and use them directly in the browser or in Node.js.
Since I was new to it, I wanted to try it on a custom dataset for a classification task. Believe me when I tell you I couldn’t find a single tutorial that shows how to build a model from scratch and use it to get predictions in the browser. I explored a lot and applied everything I learned to this single project. I hope this article gives you a jumpstart on your TensorFlow.js learning journey.
Get the dataset
For this project, I’ll be using the ‘Flower Classification’ dataset, which I downloaded from Kaggle. The files are organized as follows:
Train -> 5 folders, each named after its class
Test -> 5 folders, each named after its class
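Since each split stores one folder per class, the class list can be read straight from the directory names. Here’s a minimal Node.js sketch, assuming the Train/Test layout above and .png/.jpg images; `listClasses` and `listImages` are hypothetical helper names, not code from the original project.

```javascript
const fs = require("fs");
const path = require("path");

// Returns the class names for a split directory such as "Train" or "Test",
// assuming one sub-folder per class. Sorting keeps the class -> index
// mapping identical for both splits.
function listClasses(splitDir) {
  return fs
    .readdirSync(splitDir, { withFileTypes: true })
    .filter((entry) => entry.isDirectory())
    .map((entry) => entry.name)
    .sort();
}

// Pairs every image path with the index of its class folder.
// Only .png/.jpg/.jpeg files are kept (an assumption about the dataset).
function listImages(splitDir) {
  const classes = listClasses(splitDir);
  const samples = [];
  classes.forEach((className, label) => {
    const classDir = path.join(splitDir, className);
    for (const file of fs.readdirSync(classDir)) {
      if (/\.(png|jpe?g)$/i.test(file)) {
        samples.push({ file: path.join(classDir, file), label });
      }
    }
  });
  return { classes, samples };
}

module.exports = { listClasses, listImages };
```

Calling `listImages("Train")` and `listImages("Test")` then gives you the same class order for both splits, which matters once the labels are one-hot encoded.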
Getting and processing the data
Create a data.js file, which will handle data loading. Here we assume that every image in a class folder is named “n_flowername.png”, where “n” is the image number and “flowername” is the flower name.
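Under that naming scheme, the label can be recovered from the file name alone and turned into a one-hot vector, which is the label format categorical crossentropy expects. A minimal sketch in plain JavaScript; `flowerNameFromFile`, `oneHot`, and `buildLabels` are hypothetical names, and the regular expression assumes exactly the “n_flowername.ext” pattern.

```javascript
// Extracts the flower name from a file named "n_flowername.png",
// e.g. "12_rose.png" -> "rose". Assumes the pattern <digits>_<name>.<ext>.
function flowerNameFromFile(fileName) {
  const match = /^(\d+)_(.+)\.[^.]+$/.exec(fileName);
  if (!match) {
    throw new Error(`Unexpected file name: ${fileName}`);
  }
  return match[2];
}

// Converts a class index into a one-hot vector of length numClasses.
function oneHot(index, numClasses) {
  const vector = new Array(numClasses).fill(0);
  vector[index] = 1;
  return vector;
}

// Maps each file name to a one-hot label, using the position of its
// flower name in classNames as the class index.
function buildLabels(fileNames, classNames) {
  return fileNames.map((fileName) => {
    const index = classNames.indexOf(flowerNameFromFile(fileName));
    if (index === -1) {
      throw new Error(`Unknown class in ${fileName}`);
    }
    return oneHot(index, classNames.length);
  });
}

module.exports = { flowerNameFromFile, oneHot, buildLabels };
```

If you’d rather work on tensors directly, TensorFlow.js’s `tf.oneHot(indices, depth)` does the same job given an int32 tensor of class indices.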
Create a model
Now let’s create a demo model to train on the data. This goes in a new file, model.js, which will contain only the model architecture.
Since this is a multiclass classification problem, we’ll use ‘categorical crossentropy’ (categoricalCrossentropy in TensorFlow.js) with a softmax output layer. If you have only two classes, you can use ‘binary crossentropy’ with a single sigmoid output unit instead.
Let’s put it all together
Create a main.js file that calls all of these functions and runs the training. Be sure to change the dataset path to wherever your dataset is stored.
Now let’s use the browser to get a prediction
We’ll first create predict.js, in which we’ll load our pre-trained model and get a prediction.
NOTE!
Even though all of the files are static, the HTML page still has to be served by a web server so the browser can fetch model.json; opening the file directly from disk won’t load the model. It’s the same as hosting a website.
You can do so with the ‘Web Server for Chrome’ extension in Google Chrome (I’m assuming you’re using Google Chrome to open the HTML file).
Just launch the app and keep the settings as below.
Choose the folder where your HTML file is located, and you’re done. Open the first link listed under “Web Server URL”, and your web app should work just fine.
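If you’d rather not rely on a browser extension, a few lines of Node.js can serve the folder over HTTP instead, so that `tf.loadLayersModel` can fetch model.json and its weight files. A minimal sketch using only Node’s built-in `http`, `fs`, and `path` modules; `createStaticServer`, the `index.html` default, and port 8080 are assumptions for illustration.

```javascript
const http = require("http");
const fs = require("fs");
const path = require("path");

// Content types for the kinds of files this project serves.
const CONTENT_TYPES = {
  ".html": "text/html; charset=utf-8",
  ".js": "text/javascript; charset=utf-8",
  ".json": "application/json",
  ".png": "image/png",
  ".jpg": "image/jpeg",
};

function contentType(filePath) {
  return CONTENT_TYPES[path.extname(filePath).toLowerCase()] || "application/octet-stream";
}

// Serves static files from rootDir and refuses paths that escape it.
function createStaticServer(rootDir) {
  const root = path.resolve(rootDir);
  return http.createServer((req, res) => {
    let urlPath;
    try {
      urlPath = decodeURIComponent(new URL(req.url, "http://localhost").pathname);
    } catch {
      res.writeHead(400).end("Bad request");
      return;
    }
    const filePath = path.join(root, urlPath === "/" ? "index.html" : urlPath);
    if (!filePath.startsWith(root + path.sep)) {
      res.writeHead(403).end("Forbidden");
      return;
    }
    fs.readFile(filePath, (err, data) => {
      if (err) {
        res.writeHead(404).end("Not found");
        return;
      }
      res.writeHead(200, { "Content-Type": contentType(filePath) }).end(data);
    });
  });
}

module.exports = { createStaticServer, contentType };

// Usage: createStaticServer("./web").listen(8080);
// then open http://localhost:8080/ in the browser.
```

Weight files such as weights.bin fall back to `application/octet-stream`, which is fine for the binary weights.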
Find the whole code here → Flower Classification
Conclusion
In this article, we explored how to use a custom dataset for a classification task in TensorFlow.js. If you want to create a web app, this project shows how to build one where you only need to ship a model and all inference runs on the client side.
Thank you!
I hope you enjoyed the article!