Offline On Device Text Classification using MLKit

Machine learning often outperforms simple rule-based systems. However, it brings its own complexities, such as model training, model size, and computation cost. This makes it challenging to use machine learning in mobile applications, where users expect a quick response.

With the release of TensorFlow Lite by Google, it is now possible to ship a deep learning model with your app and run it directly on the device using Firebase ML Kit.

Before delving deeper into this, let’s first understand the key advantages of having an ML model on the device:

  • No server communication and hence reduced hosting cost
  • Offline support — Will work without Internet
  • Speed — Speed of the task will improve as all processes are running locally
  • Privacy — Data will reside inside the user’s device

We will use Python to train a model and convert it to the TFLite format. Below is an overview of the topics we will cover:

  • Data preparation and preprocessing
  • Building word tokenizer
  • Building a text classifier model in Keras using bag-of-words features
  • Converting the Keras model (.h5) to TFLite format
  • Creating an Android application to run inference on the offline model

Data Preparation

We first need a dataset for text classification. For simplicity, we can use the SNIPS intent classification dataset, in which each sentence is labeled with an intent class.

You can download the dataset from here.

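import csv

sentences, labels = [], []
with open('data.csv', 'r') as f:
    data = csv.reader(f)
    for row in data:
        # Assumption: each row is (sentence, label); adjust the indices to your CSV layout
        sentences.append(row[0])
        labels.append(row[1])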
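Before training, it can help to check how many examples each class has, since a heavily skewed dataset tends to bias the classifier. The snippet below is a minimal sketch of that check. The helper name `class_distribution` and the sample labels are hypothetical; in practice you would pass the `labels` list read from data.csv.

```python
from collections import Counter


def class_distribution(labels):
    """Return (label, count) pairs sorted from most to least frequent."""
    return Counter(labels).most_common()


# Hypothetical labels, standing in for the list read from data.csv
sample_labels = ["PlayMusic", "GetWeather", "PlayMusic", "BookRestaurant", "PlayMusic"]
for label, count in class_distribution(sample_labels):
    print(f"{label}: {count}")
```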

Building Word Tokenizer

Since machine learning models work only on numbers, we first need to transform each sentence into a fixed-length numeric representation. For this, we will create a word_index dictionary that maps each word to a unique integer index.

Here we read the unique words from the sentence list and assign each one a unique index. This index is then used to convert sentences into lists of numbers:

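import json
import re

# Replace punctuation with spaces and lowercase the text, mirroring cleanText() on the Android side
sentences = [re.sub(r'[.,:?{}]', ' ', sentence).lower() for sentence in sentences]
corpus = " ".join(sentences)
# Sorting keeps the word-to-index mapping stable across runs
words = sorted(set(corpus.split()))
word_index = {word: index for index, word in enumerate(words)}
with open('word_index.json', 'w') as file:
    json.dump(word_index, file)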
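To make the encoding concrete, here is a small pure-Python sketch of how a sentence becomes a binary bag-of-words vector from a word_index mapping. The function name `sentence_to_bow` and the toy vocabulary are hypothetical. The vector length of `len(word_index) + 1` is chosen to match what Keras' `Tokenizer.texts_to_matrix` produces when `num_words` is not set.

```python
def sentence_to_bow(sentence, word_index):
    """Encode a sentence as a binary bag-of-words vector.

    The vector has len(word_index) + 1 slots, the width that Keras'
    Tokenizer.texts_to_matrix uses when num_words is not set.
    """
    vector = [0.0] * (len(word_index) + 1)
    for word in sentence.lower().split():
        index = word_index.get(word)
        if index is not None:  # words outside the vocabulary are skipped
            vector[index] = 1.0
    return vector


# Hypothetical toy vocabulary, for illustration only
toy_index = {"play": 0, "some": 1, "music": 2, "weather": 3}
print(sentence_to_bow("play music please", toy_index))
# prints [1.0, 0.0, 1.0, 0.0, 0.0]
```

Word order and repetition are discarded: only the presence of a word matters, which is why the same vector width can be used for every sentence.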

Building a Text Classifier model

We will build a text classifier using a DNN architecture with bag-of-words as the input feature:

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import activations, losses, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import BatchNormalization, Dense, Dropout, Input

LE = LabelEncoder()

def train_and_eval(sentences, labels):
    # Convert categorical labels to integer ids, then to one-hot vectors
    labels = np.array(LE.fit_transform(labels))
    num_classes = len(LE.classes_)  # number of distinct classes, not number of samples
    onehot_labels = tf.keras.utils.to_categorical(labels, num_classes=num_classes)

    # Bag-of-words features built from the word_index created earlier
    sentences_tokens = [sentence.split() for sentence in sentences]
    tokenizer = tf.keras.preprocessing.text.Tokenizer()
    tokenizer.word_index = word_index
    sentences_features = tokenizer.texts_to_matrix(sentences_tokens)

    train_features, val_features, train_labels, val_labels = train_test_split(
        sentences_features, onehot_labels, test_size=0.1)

    feature_input = Input(shape=(sentences_features.shape[1],))
    merged = Dense(128, activation=activations.relu)(feature_input)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.2)(merged)
    merged = Dense(64, activation=activations.relu)(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(0.2)(merged)
    preds = Dense(num_classes, activation=activations.softmax)(merged)
    model = models.Model(inputs=[feature_input], outputs=preds)
    # Assumption: categorical cross-entropy, the usual loss for one-hot labels with softmax
    model.compile(loss=losses.categorical_crossentropy,
                  optimizer='nadam', metrics=['acc'])

    early_stopping = EarlyStopping(monitor='val_loss', patience=5)
    model.fit([train_features], train_labels,
              validation_data=([val_features], val_labels),
              epochs=200, batch_size=8, shuffle=True,
              callbacks=[early_stopping])
    # Assumption: save path chosen to match the model path used when testing
    model.save('models/models.h5')
    return model

Run the method given below to test your model by giving a model path and word_index path:

def test(sentence, model_path, word_index_path):
    classifier = models.load_model(model_path)
    tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='.,:?{} ')
    # Apply the same cleaning that was used when building word_index
    sentence = re.sub(r'[.,:?{}]', ' ', sentence).lower()
    with open(word_index_path, 'r') as f:
        tokenizer.word_index = json.load(f)
    # Wrap the token list so the whole sentence becomes a single input row
    tokenized_messages = tokenizer.texts_to_matrix([sentence.split()])
    p = list(classifier.predict(tokenized_messages)[0])
    for index, each in enumerate(p):
        print(index, each)
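The printed values are per-class probabilities indexed by class id. To turn them into a readable label, take the index with the highest score and look it up in the list of class names. A minimal sketch, with a hypothetical helper `top_prediction` and made-up scores; with scikit-learn's `LabelEncoder`, the class names in id order are available as `LE.classes_`.

```python
def top_prediction(probabilities, class_names):
    """Return (class_name, probability) for the highest-scoring class."""
    best_index = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return class_names[best_index], probabilities[best_index]


# Hypothetical class names and scores, for illustration only
class_names = ["BookRestaurant", "GetWeather", "PlayMusic"]
label, score = top_prediction([0.1, 0.7, 0.2], class_names)
print(label, score)
```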

Converting Keras Model (.h5) to Tflite format

We need to convert the above model file to the TFLite format, which we will then upload to ML Kit and ship with the Android app.

import tensorflow as tf

def convert_model_to_tflite(keras_model_path):
    tf.logging.set_verbosity(tf.logging.ERROR)
    # TensorFlow 1.x API; from TF 1.13 the converter is also exposed as tf.lite.TFLiteConverter
    converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_model_path)
    converter.post_training_quantize = True
    tflite_buffer = converter.convert()
    with open('model.tflite', 'wb') as f:
        f.write(tflite_buffer)
    print('TFLite model created.')
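The listing above targets TensorFlow 1.x, and `tf.contrib` is no longer available in TensorFlow 2.x. If you are on TensorFlow 2.x, a rough equivalent might look like the sketch below. The function name `convert_keras_to_tflite` and its `tflite_path` parameter are hypothetical, and the default optimization setting is assumed to be an acceptable stand-in for `post_training_quantize`.

```python
def convert_keras_to_tflite(keras_model_path, tflite_path="model.tflite"):
    """Convert a saved Keras model to TFLite with the TensorFlow 2.x converter."""
    # Imported here so the helper can be defined without TensorFlow installed
    import tensorflow as tf

    model = tf.keras.models.load_model(keras_model_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Default optimizations enable post-training quantization of the weights
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_buffer = converter.convert()
    with open(tflite_path, "wb") as f:
        f.write(tflite_buffer)
    # Size in bytes, handy for checking how large the shipped model is
    return len(tflite_buffer)
```

Calling `convert_keras_to_tflite("models/models.h5")` would write model.tflite to the current directory and return its size in bytes.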

Creating the Device Application

Given below is the basic flow of how the ML model works on the device.

Let’s now discuss step-by-step the process we will be following to run inference.

Starting your project

  1. Add word_index.json and model.tflite to the assets folder of your Android project.
  2. Add the dependencies for the ML Kit Android libraries to your module (app-level) Gradle file (usually app/build.gradle):
dependencies {
    // ...
    implementation 'com.google.firebase:firebase-ml-model-interpreter:21.0.0'
}
apply plugin: 'com.google.gms.google-services'

Also, in your app-level build.gradle, add these lines, which prevent .tflite files from being compressed:

android {
    buildTypes {
        release {
            // ...
        }
    }
    aaptOptions {
        noCompress "tflite"
    }
}

Hosting Models on Firebase

Follow the steps below to host your model.tflite file on the ML Kit console.

  1. In the ML Kit section of the Firebase console, click the Custom tab.
  2. Click Add custom model (or Add another model).
  3. Specify a name that will be used to identify your model in your Firebase project, then upload the TensorFlow Lite model file (usually ending in .tflite or .lite ).
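Then declare the INTERNET permission in your app's manifest so that the hosted model can be downloaded:
<uses-permission android:name="android.permission.INTERNET" />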

Defining Constants Used by the Model

public class Constant {
    // model name given to custom model stored on MLKit
    public static final String REMOTE_MODEL_NAME = "mlmodel";
    // model name given to model stored locally (can be the same as on MLKit)
    public static final String LOCAL_MODEL_NAME = "mlmodel";
    // file for word dict with word to index map
    public static final String WORD_DICT_FILE = "word_index.json";
    // file for model stored locally inside assets
    public static final String LOCAL_MODEL_FILE = "model.tflite";
    // input width of your model: the length of the bag-of-words vector used in training
    // (number of entries in word_index.json plus one, as produced by the Keras Tokenizer)
    public static final int MODEL_INPUT_SHAPE = 30;
    // number of classes for your text classification task
    public static final int MODEL_NUM_CLASS = 8;
}
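The input shape constant has to agree with the width of the vectors the model was trained on, so it is safer to compute it from word_index.json than to type it by hand. A minimal Python sketch, assuming the model was trained on Keras `texts_to_matrix` output with `num_words` unset (which yields `len(word_index) + 1` columns); the helper names and the sample vocabulary are hypothetical.

```python
import json


def load_word_index(path):
    """Read the word-to-index mapping written during training."""
    with open(path, "r") as f:
        return json.load(f)


def model_input_width(word_index):
    """Width of the bag-of-words vector for a given word_index mapping."""
    # texts_to_matrix uses len(word_index) + 1 columns when num_words is not set
    return len(word_index) + 1


# Hypothetical vocabulary; in practice use load_word_index("word_index.json")
print(model_input_width({"play": 0, "music": 1, "weather": 2}))
# prints 4
```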

Creating Model Input for Given Text

This method returns a float array in the shape expected by the model. Here are the steps involved:

  1. Read word_index file from assets.
  2. Clean the text by lowercasing it and removing punctuation and extra spaces.
  3. Create an array of zeros the size of the model input shape.
  4. Split the text into words. For each word present in word_index, look up its index and set that position in the array to 1.

Code for the above implementation is given below:

public static String cleanText(String text) {
    String clean_text = text.toLowerCase();
    clean_text = clean_text.replaceAll("[.,:?{}]+", " ");
    clean_text = clean_text.trim();
    return clean_text;
}

private float[][] textToInputArray(String text) throws JSONException {
    float[][] input = new float[1][MODEL_INPUT_SHAPE];
    JSONObject word_dict = new JSONObject(readJSONFromAsset(WORD_DICT_FILE));
    String clean_text = cleanText(text);
    String[] words = clean_text.split(" ");
    for (String word : words) {
        // Words that are not in word_index are ignored
        if (word_dict.has(word)) {
            int index = word_dict.getInt(word);
            input[0][index] = 1;
        }
    }
    return input;
}

Run Classification

Call the runInference method with the model input prepared above. It returns the confidence score for each class; the predicted label (int) is the index with the maximum score.

// Builder chains follow the ML Kit custom-model API for this library version; adjust if yours differs.
public class MLModel {
    float[] probabilities = new float[Constant.MODEL_NUM_CLASS];

    public void configureHostedModelSource() {
        FirebaseModelDownloadConditions.Builder conditionsBuilder =
                new FirebaseModelDownloadConditions.Builder().requireWifi();
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
            // Enable advanced conditions on Android Nougat and newer.
            conditionsBuilder = conditionsBuilder
                    .requireCharging()
                    .requireDeviceIdle();
        }
        FirebaseModelDownloadConditions conditions = conditionsBuilder.build();

        // Build a remote model source object by specifying the name you assigned the model
        // when you uploaded it in the Firebase console.
        FirebaseRemoteModel cloudSource = new FirebaseRemoteModel.Builder(Constant.REMOTE_MODEL_NAME)
                .enableModelUpdates(true)
                .setInitialDownloadConditions(conditions)
                .setUpdatesDownloadConditions(conditions)
                .build();
        FirebaseModelManager.getInstance().registerRemoteModel(cloudSource);
    }

    public void configureLocalModelSource() {
        FirebaseLocalModel localSource =
                new FirebaseLocalModel.Builder(Constant.LOCAL_MODEL_NAME) // Assign a name to this model
                        .setAssetFilePath(Constant.LOCAL_MODEL_FILE)
                        .build();
        FirebaseModelManager.getInstance().registerLocalModel(localSource);
    }

    private FirebaseModelInterpreter createInterpreter() throws FirebaseMLException {
        FirebaseModelOptions options = new FirebaseModelOptions.Builder()
                .setRemoteModelName(Constant.REMOTE_MODEL_NAME)
                .setLocalModelName(Constant.LOCAL_MODEL_NAME)
                .build();
        FirebaseModelInterpreter firebaseInterpreter =
                FirebaseModelInterpreter.getInstance(options);
        return firebaseInterpreter;
    }

    private FirebaseModelInputOutputOptions createInputOutputOptions() throws FirebaseMLException {
        FirebaseModelInputOutputOptions inputOutputOptions =
                new FirebaseModelInputOutputOptions.Builder()
                        .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, Constant.MODEL_INPUT_SHAPE})
                        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, Constant.MODEL_NUM_CLASS})
                        .build();
        return inputOutputOptions;
    }

    public float[] runInference(float[][] input) throws FirebaseMLException {
        FirebaseModelInterpreter firebaseInterpreter = createInterpreter();
        FirebaseModelInputOutputOptions inputOutputOptions = createInputOutputOptions();
        FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
                .add(input) // add() as many input arrays as your model requires
                .build();
        firebaseInterpreter.run(inputs, inputOutputOptions)
                .addOnSuccessListener(new OnSuccessListener<FirebaseModelOutputs>() {
                    @Override
                    public void onSuccess(FirebaseModelOutputs result) {
                        float[][] output = result.getOutput(0);
                        for (int i = 0; i < Constant.MODEL_NUM_CLASS; i++) {
                            probabilities[i] = output[0][i];
                        }
                        Log.d("Success prediction", "" + probabilities[7]);
                    }
                })
                .addOnFailureListener(new OnFailureListener() {
                    @Override
                    public void onFailure(@NonNull Exception e) {
                        // Task failed with an exception
                        Log.d("Error prediction", e.toString());
                    }
                });
        // run() is asynchronous: this array is only filled once onSuccess has fired.
        return probabilities;
    }
}

I hope the above helps you in getting started with ML on-device. Please do try the above and let us know if you have any feedback. We will be sharing more details in the following blog.


Originally published at https://haptik.ai on November 5, 2019.



