On-device ML with TFLite models in Flutter apps

Make use of Flutter’s platform-specific channel on Android

Sharon Zacharia · Flutter Community · Aug 22, 2019

Machine learning models can be accessed on mobile platforms by serving them as a REST API. But what if we could use a model locally on the mobile device and run inference on it? That would let us use models offline. TensorFlow Lite models can be easily loaded on a mobile device and used as needed.

We will see how to make a text classification app that can detect offensive sentences.

  • Build a Keras model.
  • Save the model and convert it to the tflite format.
  • Load the model on a mobile platform (Android).
  • Use Flutter’s platform-specific channel to communicate with the Java code and make inferences.

Keras Model

Our model takes a matrix as input, so we need to convert each string into the corresponding matrix before feeding it into the neural network. The method we use to convert a string to a matrix will be explained in a moment; before that, let's look at our dataset format.

Dataset

The dataset used for this model is a CSV file consisting of 3,500 texts and corresponding labels indicating whether each one is offensive (1) or not (0). You can find the dataset here. The quality of your data determines the quality of your model, so you can try a different dataset if needed; here we will use this one.

still proclaiming your stupidity,1
read the article find your answers there,0

We can load the CSV with pandas and remove stop words from the dataset.
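A minimal sketch of this step (the file name train.csv, the column names text and label, and the tiny stop-word list are assumptions):

import pandas as pd

# Assumed file and column names; the dataset has no header row.
data = pd.read_csv("train.csv", names=["text", "label"])

# A small illustrative stop-word list; a fuller list (e.g. NLTK's) could be used instead.
stop_words = {"the", "a", "an", "and", "or", "is", "are", "to", "of", "in", "that", "this"}

def remove_stop_words(sentence):
    # Keep only the words that are not stop words.
    return " ".join(w for w in sentence.lower().split() if w not in stop_words)

data["text"] = data["text"].apply(remove_stop_words)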

Text to matrix

We will make use of Tokenizer in Keras for this purpose.
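A minimal sketch of this step, continuing from the data loaded above (the variable names are assumptions):

import json
from tensorflow.keras.preprocessing.text import Tokenizer

max_words = 1000  # number of elements in each input matrix

tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(data["text"])

# Export the word index so the same mapping can be reused on the mobile device.
with open("word_index.json", "w") as f:
    json.dump(tokenizer.word_index, f)

# Binary mode: position i is 1 if the word with index i occurs in the text.
x_train = tokenizer.texts_to_matrix(data["text"], mode="binary")
y_train = data["label"].values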

The above code initializes the tokenizer. tokenizer.word_index maps each word to an integer index, with the most frequently occurring word getting index 1; the num_words parameter (1,000 here) limits how many of these words are used when texts are converted to matrices. The word index for this dataset will look like this:

“like”: 1, “would”: 2, “get”: 3, “people”: 4, “go”: 5, “know”: 6, “one”: 7, … “boston”: 997, “banned”: 998, “gas”: 999, “fall”: 1000

After this process, if our input text is “people like that”, the matrix we feed into the neural network will look like [0, 1, 0, 0, 1, …]; note that the first element is always 0, and the word ‘like’ sets index 1 of the matrix (and ‘people’ index 4). We also convert this word index to a JSON file, because we need it on the mobile device. There are different modes (we use binary here) for converting sequences to a matrix; if needed, you can find more about the Tokenizer here.

Next, we will define our model.
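A minimal sketch of such a model (the exact layer sizes, dropout rates, and training settings are assumptions):

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout

model = Sequential([
    Dense(512, activation="relu", input_shape=(max_words,)),  # max_words = 1000, from the tokenization step
    Dropout(0.5),
    Dense(256, activation="relu"),
    Dropout(0.5),
    Dense(1, activation="sigmoid"),  # probability that the text is offensive
])

model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# A simple validation split stands in here for whatever train/test split is used.
model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.2)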

We build the model using the Keras Sequential API; the input shape is the number of elements in our matrix, 1,000 here. Dropout layers help reduce overfitting. Since our model takes an array of 1,000 elements, training only uses the 1,000 most frequently occurring words, even though there are about 10,000 entries in our JSON word index; increasing the maximum number of words in the model may improve how many words the model learns.

The model reaches around 95% training accuracy and 80% test accuracy; to improve it, try changing the parameters or using a dataset with many more items. For now, we will stick with this.

Convert to TFLite

TensorFlow Lite is an open-source deep learning framework for on-device inference. Before we convert our model to the tflite format, we need to save the Keras model; the following code does that.

keras_file = "modelname.h5"
model.save(keras_file)

To convert the model to tflite, the following code can be used; if it does not work on your PC, try using Google Colab.
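A sketch of the conversion using the TF 1.x converter that reads the saved .h5 file (on TensorFlow 2.x you would use tf.lite.TFLiteConverter.from_keras_model(model) instead):

import tensorflow as tf

# Convert the saved Keras model to a TensorFlow Lite flat buffer.
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()

with open("modelname.tflite", "wb") as f:
    f.write(tflite_model)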

We are done with the model part; at this point we have a JSON file with the word index and a tflite file. You can find the complete code for the Keras model here, with explanations.

Mobile Platform

We will see how to load these files on an Android device and run inference on the model. Before that, move to the android folder in your Flutter app.

  • Add the following dependency to the build.gradle file:
implementation 'org.tensorflow:tensorflow-lite:+'
  • Place the tflite model inside the assets folder on the Android side.
  • Open MainActivity.java and add the missing methods; these methods communicate with the Dart code to get the input data and return the value predicted by our model (a sketch follows this list).
  • Add the JSON file to the assets folder of the Flutter project (not the Android one).
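A minimal sketch of what MainActivity.java could look like with the pre-1.12 Flutter Android embedding used at the time; the channel name ml.inference, the method name predict, and the asset file name modelname.tflite are assumptions and must match the Dart side.

// Keep your project's existing package declaration.
import android.content.res.AssetFileDescriptor;
import android.os.Bundle;
import io.flutter.app.FlutterActivity;
import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;
import io.flutter.plugins.GeneratedPluginRegistrant;
import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.List;

public class MainActivity extends FlutterActivity {
  private static final String CHANNEL = "ml.inference"; // assumed name, must match the Dart side
  private Interpreter interpreter;

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    GeneratedPluginRegistrant.registerWith(this);
    new MethodChannel(getFlutterView(), CHANNEL).setMethodCallHandler(
        new MethodChannel.MethodCallHandler() {
          @Override
          public void onMethodCall(MethodCall call, MethodChannel.Result result) {
            if (call.method.equals("predict")) {
              List<Double> matrix = call.arguments(); // the matrix sent from Dart
              result.success(runInference(matrix));
            } else {
              result.notImplemented();
            }
          }
        });
  }

  // Memory-map the tflite file placed in android/app/src/main/assets.
  private MappedByteBuffer loadModelFile() throws IOException {
    AssetFileDescriptor fd = getAssets().openFd("modelname.tflite");
    FileInputStream stream = new FileInputStream(fd.getFileDescriptor());
    FileChannel channel = stream.getChannel();
    return channel.map(FileChannel.MapMode.READ_ONLY, fd.getStartOffset(), fd.getDeclaredLength());
  }

  // Feed the 1x1000 matrix to the interpreter and return the single sigmoid output.
  private double runInference(List<Double> matrix) {
    try {
      if (interpreter == null) {
        interpreter = new Interpreter(loadModelFile());
      }
      float[][] input = new float[1][matrix.size()];
      for (int i = 0; i < matrix.size(); i++) {
        input[0][i] = matrix.get(i).floatValue();
      }
      float[][] output = new float[1][1];
      interpreter.run(input, output);
      return output[0][0];
    } catch (IOException e) {
      return -1;
    }
  }
}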

Now that we have added the required files and the dependency, there is one more Dart class to look at. Its purpose is to convert new text given by the user into a matrix that can be fed to our model for prediction; it produces the same output the Tokenizer gave us in the Keras model. Explanations of the logic are given along with the code.
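A minimal sketch of such a class (the field names, the separate word-index loading step, and the asset path handling are assumptions):

// tokenizer.dart
import 'dart:convert';

import 'package:flutter/services.dart' show rootBundle;

class Tokenizer {
  final int count;       // number of elements in the output matrix (1000 here)
  final String jsonPath; // asset path of the exported word index
  Map<String, dynamic> _wordIndex = {};

  Tokenizer(this.count, this.jsonPath);

  // Load the word index that was exported from the Keras Tokenizer.
  // Must complete before getTokenized() is called.
  Future<void> loadWordIndex() async {
    final raw = await rootBundle.loadString(jsonPath);
    _wordIndex = json.decode(raw) as Map<String, dynamic>;
  }

  // Reproduces texts_to_matrix(mode: 'binary') for a single sentence:
  // position i is 1.0 if the word with index i occurs in the text.
  List<double> getTokenized(String text) {
    final matrix = List<double>.filled(count, 0.0);
    for (final word in text.toLowerCase().split(RegExp(r'\s+'))) {
      final dynamic index = _wordIndex[word];
      if (index is int && index < count) {
        matrix[index] = 1.0;
      }
    }
    return matrix;
  }
}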

Instantiate this class with the count (number of elements in the matrix) and the path to the JSON file; calling the getTokenized() method then returns the matrix corresponding to the text passed as a parameter. We will then pass this matrix to the Java code using Flutter's platform-specific channel; the method below helps with this process.
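A sketch of that method; the channel name ml.inference and the method name predict are assumptions and must match the names registered in MainActivity.java.

import 'package:flutter/services.dart';

// Must match the channel name used on the Android side.
const MethodChannel _channel = MethodChannel('ml.inference');

// Sends the matrix to the native code and returns the model's prediction
// (a value between 0 and 1, where higher means more offensive).
Future<double> predict(List<double> matrix) async {
  final result = await _channel.invokeMethod('predict', matrix);
  return (result as num).toDouble();
}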

Find more on Flutter's platform-specific code here.

The Flutter application will have a TextField and a RaisedButton; when it is pressed, we convert the user's text to a matrix and send it to the native code using the platform-specific channel. The call returns a Future, so when we get the result back from the native code we call setState() and update the text color according to the sentiment of the text.
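A sketch of that screen, reusing the Tokenizer class and the predict() helper sketched above; the widget structure and the 0.5 threshold are assumptions.

import 'package:flutter/material.dart';

class ClassifierPage extends StatefulWidget {
  @override
  _ClassifierPageState createState() => _ClassifierPageState();
}

class _ClassifierPageState extends State<ClassifierPage> {
  final TextEditingController _controller = TextEditingController();
  final Tokenizer _tokenizer = Tokenizer(1000, 'assets/word_index.json');
  Color _textColor = Colors.black;

  Future<void> _classify() async {
    await _tokenizer.loadWordIndex();
    final matrix = _tokenizer.getTokenized(_controller.text);
    final score = await predict(matrix); // platform-channel call from the sketch above
    setState(() {
      // Assumed threshold: treat scores above 0.5 as offensive.
      _textColor = score > 0.5 ? Colors.red : Colors.green;
    });
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: Text('Text classification')),
      body: Padding(
        padding: const EdgeInsets.all(16.0),
        child: Column(
          children: <Widget>[
            TextField(
              controller: _controller,
              style: TextStyle(color: _textColor),
            ),
            RaisedButton(
              onPressed: _classify,
              child: Text('Check'),
            ),
          ],
        ),
      ),
    );
  }
}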

The Flutter application (screenshot)

The above procedure can be used with any tflite model; depending on the need, some changes will be required to match the input and output shapes of the model. The complete code can be found here.

Useful links
