Getting started with Deep Learning in the browser with TF.js

Rishit Dagli

Published in Analytics Vidhya, Mar 31, 2020

source: tensorflow.org

If you are a beginner in Machine Learning and want to get started with developing models in the browser, this is for you. If you already know the TensorFlow Python API and want to explore TF.js, you have come to the right place! Unlike traditional blogs, I believe in hands-on learning, and that is what we will do here. I would advise you to first read the prequel.

All the code is available here:

We will continue from where we left off in the earlier blog, where we built our model and pre-processed the data. Let us now fit the model and make predictions with it.

Training the model

With our model instance created and our data represented as tensors, we have everything in place to start the training process. Add this code to your script.js file.

Let us now break this down.

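// Inside trainModel(model, inputs, labels): prepare the model for training.
model.compile({
  optimizer: tf.train.adam(),         // how the weights get updated
  loss: tf.losses.meanSquaredError,   // what training tries to minimize
  metrics: ['mse'],                   // extra values reported during training
});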

We have to ‘compile’ the model before we train it. To do so, we have to specify a number of very important things:

  • optimizer: This is the algorithm that governs the updates to the model as it sees examples. There are many optimizers available in TensorFlow.js. Here we have picked the adam optimizer, as it is quite effective in practice and works well with its default settings, so we do not need to configure it.
  • loss: This is a function that tells the model how well it is doing on each of the batches (data subsets) it is shown. Here we use meanSquaredError to compare the predictions made by the model with the true values.
  • metrics: These are additional values reported during training. Here we track 'mse', which measures the same quantity as our loss.

You can read more about them here.
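To make the loss concrete, here is a minimal plain-JavaScript sketch of what mean squared error computes over a batch: the average of the squared differences between predictions and true values. It does not use TensorFlow.js; `meanSquaredError` is a hypothetical helper written for illustration, not the library's implementation.

```javascript
// Plain-JavaScript sketch of mean squared error (not TF.js code):
// the average of the squared differences between predictions and true values.
function meanSquaredError(yTrue, yPred) {
  if (yTrue.length !== yPred.length || yTrue.length === 0) {
    throw new Error("yTrue and yPred must be non-empty and the same length");
  }
  let sum = 0;
  for (let i = 0; i < yTrue.length; i++) {
    const diff = yPred[i] - yTrue[i];
    sum += diff * diff; // squaring penalizes large errors more than small ones
  }
  return sum / yTrue.length;
}

// Errors of 1, 0 and -2 give (1 + 0 + 4) / 3.
console.log(meanSquaredError([3, 5, 7], [4, 5, 5]));
```

Because the errors are squared, one prediction that is off by 2 costs four times as much as one that is off by 1, which is why this loss pushes the model hardest on its worst predictions.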

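Next, we pick a batch size and a number of epochs:

const batchSize = 32;
const epochs = 50;

batchSize is the number of examples the model sees before each update of its weights, and epochs is the number of times the model goes over the entire dataset. These values are a reasonable starting point rather than a guaranteed optimum, so feel free to experiment with them.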
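To see how batchSize and epochs interact, the arithmetic below counts how many weight updates training performs. It assumes a hypothetical dataset of 400 examples (substitute the size of your own data); each batch produces one update, and the final batch of an epoch may be smaller than batchSize.

```javascript
// Hypothetical dataset size, used only for this illustration.
const numExamples = 400;
const batchSize = 32;
const epochs = 50;

// The last batch of each epoch may be smaller than batchSize, hence Math.ceil.
const batchesPerEpoch = Math.ceil(numExamples / batchSize);
// One weight update per batch, repeated for every epoch.
const totalUpdates = batchesPerEpoch * epochs;

console.log(batchesPerEpoch); // 13
console.log(totalUpdates);    // 650
```

A smaller batch size means more (noisier) updates per epoch, while more epochs simply repeat the whole pass; both increase training time.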

// Still inside trainModel: start training and plot progress with tfjs-vis.
return await model.fit(inputs, labels, {
  batchSize,
  epochs,
  callbacks: tfvis.show.fitCallbacks(
    { name: 'Training Performance' },
    ['loss', 'mse'],
    { height: 200, callbacks: ['onEpochEnd'] }
  )
});

model.fit is the function we call to start the training loop. It is an asynchronous function, so we return the promise it gives us so that the caller can determine when training is complete. To monitor training progress, we pass some callbacks to model.fit. We use tfvis.show.fitCallbacks to generate functions that plot charts for the 'loss' and 'mse' metrics we specified earlier.
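To give a feel for what happens inside a training loop like model.fit, here is a simplified sketch in plain JavaScript. It is an assumption-laden illustration, not TensorFlow.js code: it fits a straight line y = w * x + b with plain mini-batch gradient descent (not Adam), does not shuffle the data, is synchronous (model.fit returns a Promise), and recomputes the loss over the whole dataset at the end of each epoch before passing it to an onEpochEnd callback. `fitLinear` and the data are hypothetical.

```javascript
// Simplified training loop: split data into batches, update the weights
// once per batch, and report the loss to a callback after every epoch.
// Plain mini-batch gradient descent on y = w * x + b (NOT Adam, NOT TF.js).
function fitLinear(xs, ys, { batchSize, epochs, learningRate, callbacks }) {
  let w = 0;
  let b = 0;
  const n = xs.length;
  for (let epoch = 0; epoch < epochs; epoch++) {
    for (let start = 0; start < n; start += batchSize) {
      const end = Math.min(start + batchSize, n); // last batch may be smaller
      const m = end - start;
      let gradW = 0;
      let gradB = 0;
      for (let i = start; i < end; i++) {
        const err = w * xs[i] + b - ys[i];
        // Partial derivatives of this batch's mean squared error.
        gradW += (2 / m) * err * xs[i];
        gradB += (2 / m) * err;
      }
      w -= learningRate * gradW;
      b -= learningRate * gradB;
    }
    // Mean squared error over the whole dataset at the end of the epoch.
    let loss = 0;
    for (let i = 0; i < n; i++) {
      const err = w * xs[i] + b - ys[i];
      loss += err * err;
    }
    loss /= n;
    if (callbacks && callbacks.onEpochEnd) {
      callbacks.onEpochEnd(epoch, { loss });
    }
  }
  return { w, b };
}

// Hypothetical noiseless data: 100 points on the line y = 3x + 1.
const trainXs = Array.from({ length: 100 }, (_, i) => i / 99);
const trainYs = trainXs.map((x) => 3 * x + 1);

const lossHistory = [];
const trained = fitLinear(trainXs, trainYs, {
  batchSize: 32,
  epochs: 300,
  learningRate: 0.1,
  callbacks: { onEpochEnd: (epoch, logs) => lossHistory.push(logs.loss) },
});
console.log(trained.w.toFixed(2), trained.b.toFixed(2));
```

The onEpochEnd callback plays the same role as the functions that tfvis.show.fitCallbacks generates: it receives the epoch number and the logged values, and here it just records the loss so you can check that it goes down.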

You now have a model and all the functions needed to train it. Next, we will call these functions.

Calling the model

const tensorData = convertToTensor(data);
const {inputs, labels} = tensorData;

// Train the model
await trainModel(model, inputs, labels);
console.log('Done Training');

When you add this code to the run function, you are essentially calling model.fit through trainModel. When you refresh the page now, you will see the training charts.

These charts are drawn by the callbacks we set up earlier. They display the loss and mse, averaged over the whole dataset, at the end of each epoch. When training a model, we want to see the loss go down. In this case, because our metric is a measure of error, we want to see it go down as well.

Make predictions

Now let us make predictions with our model. Add this code to your script.js file.

Let’s break this down:

const xs = tf.linspace(0, 1, 100);      
const preds = model.predict(xs.reshape([100, 1]));

We generate 100 new ‘examples’ to feed to the model. model.predict is how we feed those examples into the model. Note that they need to have the same shape ([num_examples, num_features_per_example]) as the data we used for training.
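As a plain-JavaScript illustration of that shape requirement (no TensorFlow.js involved), the sketch below builds 100 evenly spaced inputs from 0 to 1 and arranges them as 100 examples with one feature each. `linspace` is a hypothetical helper that mirrors the values tf.linspace(0, 1, 100) produces, and the nested array stands in for the [100, 1] tensor that model.predict receives.

```javascript
// Hypothetical helper: `num` evenly spaced values from start to stop (inclusive).
function linspace(start, stop, num) {
  if (num === 1) return [start];
  const step = (stop - start) / (num - 1);
  return Array.from({ length: num }, (_, i) => start + i * step);
}

const flat = linspace(0, 1, 100);    // shape [100]: one long list of numbers
const asRows = flat.map((x) => [x]); // shape [100, 1]: 100 examples, 1 feature each

console.log(flat.length, asRows.length, asRows[0].length); // 100 100 1
```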

const unNormXs = xs
  .mul(inputMax.sub(inputMin))
  .add(inputMin);

const unNormPreds = preds
  .mul(labelMax.sub(labelMin))
  .add(labelMin);

This piece of code un-normalizes the data, mapping both the inputs and the predictions from the 0–1 range back to their original scale.
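Un-normalizing is simply the inverse of min-max scaling. The plain-JavaScript sketch below assumes the data was scaled with (value - min) / (max - min), which is exactly what multiplying by (max - min) and adding min undoes, and shows the round trip. `normalize`, `unNormalize` and the sample values are hypothetical.

```javascript
// Min-max scaling: maps [min, max] onto [0, 1].
function normalize(values, min, max) {
  return values.map((v) => (v - min) / (max - min));
}

// Inverse: same arithmetic as the snippet above, value * (max - min) + min.
function unNormalize(values, min, max) {
  return values.map((v) => v * (max - min) + min);
}

// Hypothetical raw values.
const raw = [46, 100, 230];
const rawMin = Math.min(...raw);
const rawMax = Math.max(...raw);

const scaled = normalize(raw, rawMin, rawMax);        // values in [0, 1]
const restored = unNormalize(scaled, rawMin, rawMax); // back to the original scale

console.log(scaled, restored);
```

The min and max must be the ones computed from the training data; un-normalizing with different values would put the predictions on the wrong scale.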

Let’s now see this in action. Add this to your run function:

testModel(model, data, tensorData);

You should get results similar to mine.

Concluding

We created a small model with just 2 layers; tweak the hyperparameters to see what more you can do. You just worked through the basics of creating a simple model using TF.js completely in the browser and created some wonderful visualizations.

About Me

Hi everyone, I am Rishit Dagli.

LinkedIn: linkedin.com/in/rishit-dagli-440113165/

Website: rishit.tech

If you want to ask me questions, report a mistake, suggest improvements, or give feedback, you are free to do so via the chat box on the website or by emailing me at:

  • rishit.dagli@gmail.com
  • hello@rishit.tech
