E2E tf.Keras to TFLite to Android

Margaret Maynard-Reid

This tutorial covers how to train a model from scratch with the tf.Keras Sequential API, convert the tf.Keras model to tflite format, and run the model on Android. I will walk through an example with the MNIST data for image classification and share some of the common issues you may face. This tutorial focuses on the end-to-end experience, and I will not go in depth on deep learning, the various tf.Keras APIs, or Android development.

Download my sample code and follow along:

  • Run in Colab - Train a model with tf.Keras and convert the Keras model to TFLite (link to Colab notebook).
  • Run in Android Studio - DigitRecognizer (link to Android app).

1. Train a custom classifier

Load the data

We will use the MNIST data, which is available as part of the tf.Keras framework.

from tensorflow import keras
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

Preprocess data

Next we will reshape the input image from 28x28 to 28x28x1, normalize it and one-hot encode the labels.
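
A minimal sketch of this preprocessing step, written with NumPy only so that it runs without downloading the dataset. The helper name `preprocess` and the scaling of pixel values to [0, 1] are assumptions for illustration; the notebook may use `keras.utils.to_categorical` for the one-hot step instead.

```python
import numpy as np

NUM_CLASSES = 10  # digits 0-9


def preprocess(images, labels, num_classes=NUM_CLASSES):
    """Reshape to (N, 28, 28, 1), scale pixels to [0, 1], one-hot encode labels."""
    # Add a trailing channel dimension: (N, 28, 28) -> (N, 28, 28, 1)
    x = images.reshape(images.shape[0], 28, 28, 1)
    # Scale uint8 pixel values from [0, 255] to floats in [0, 1] (assumed scaling)
    x = x.astype("float32") / 255.0
    # One-hot encode: label 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    y = np.eye(num_classes, dtype="float32")[labels]
    return x, y


# Tiny synthetic batch standing in for the MNIST arrays
fake_images = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
fake_labels = np.array([0, 3, 9, 3])
x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)  # (4, 28, 28, 1) (4, 10)
```

With the real data, the call would be `x_train, y_train = preprocess(x_train, y_train)`, and likewise for the test split.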

Define Model architecture

Then we will define the network architecture as a CNN.

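def create_model():
    # Define the model architecture
    model = keras.models.Sequential([
        # Must define the input shape in the first layer of the neural network
        keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', input_shape=(28,28,1)),
        keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),
        # Flatten the feature maps so the Dense layers output one vector per image
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])

    # Compile the model (optimizer and metric are assumed; the loss matches one-hot labels)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    return model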

Train Model

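Then we build the model and train it with model.fit(). Batch size and the number of epochs are left at the Keras defaults here.

model = create_model()
model.fit(x_train, y_train, validation_data=(x_test, y_test))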

2. Model Saving and Conversion

After training we will save a Keras model and convert it to TFLite format.

Save a Keras Model

Here is how to save a Keras model:

# Save tf.keras model in HDF5 format
keras_model = "mnist_keras_model.h5"
keras.models.save_model(model, keras_model)

Convert Keras model to tflite

There are two options when using the TFLite converter to convert the Keras model to the tflite format: 1) from the command line, or 2) directly in your Python code, which is recommended.

1) Conversion via command line

$ tflite_convert \
    --output_file=mymodel.tflite \
    --keras_model_file=mymodel.h5

2) Conversion via Python code

This is the preferred method for conversion if you have access to the model training code.

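# Create a converter from the saved Keras model file
# (tf.contrib.lite matches the TF 1.12 API used in this tutorial; later releases moved it to tf.lite)
converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_model)
# Convert the model
tflite_model = converter.convert()
# Create the tflite model file
tflite_model_name = "mymodel.tflite"
with open(tflite_model_name, "wb") as f:
    f.write(tflite_model)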

You can enable post-training quantization on the converter. Set it before calling converter.convert():

# Set quantize to true
converter.post_training_quantize = True

Validate the Converted Model

After converting the Keras model to tflite format, it’s important to validate that it performs on par with your original Keras model. The Python snippet below shows how to run inference with your tflite model. The example input is random data, and you will need to update it for your own data.

import numpy as np
import tensorflow as tf
# Load TFLite model and allocate tensors.
interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test model on random input data
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
# Run inference; without this call the output tensor is never computed
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

Protip: make sure to always test the tflite model after conversion and before putting it on Android. Otherwise when it’s not working on your Android app, it’s unclear whether the problem is with your Android code or with the ML model.
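
One way to make this check concrete is to run the same inputs through both models and compare the outputs. The sketch below assumes you have already collected the two sets of predictions as arrays of shape (N, num_classes), for instance from `model.predict()` and from repeated interpreter calls. The helper name `compare_predictions` and the tolerance value are hypothetical.

```python
import numpy as np


def compare_predictions(keras_probs, tflite_probs, atol=1e-3):
    """Summarize how closely two (N, num_classes) probability arrays agree."""
    keras_probs = np.asarray(keras_probs, dtype=np.float32)
    tflite_probs = np.asarray(tflite_probs, dtype=np.float32)
    # Fraction of samples where both models pick the same class
    same_class = keras_probs.argmax(axis=1) == tflite_probs.argmax(axis=1)
    agreement = float(np.mean(same_class))
    # Largest element-wise difference between the two outputs
    max_abs_diff = float(np.max(np.abs(keras_probs - tflite_probs)))
    return {
        "agreement": agreement,
        "max_abs_diff": max_abs_diff,
        "close": max_abs_diff <= atol,
    }


# Made-up predictions for two samples and three classes, for illustration only
keras_out = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]
tflite_out = [[0.1, 0.6, 0.3], [0.8, 0.1, 0.1]]
report = compare_predictions(keras_out, tflite_out)
print(report)
```

A quantized model may not match the original exactly, so agreement on the predicted class is often the more useful number to watch.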

3. Implement the tflite model on Android

Now we are ready to implement the tflite model on Android. Create a new Android project and follow these steps:

  1. Place the mnist.tflite model under the assets folder.
  2. Update build.gradle to include the tflite dependency.
  3. Create a CustomView for the user to draw digits.
  4. Create a Classifier that does digit classification.
  5. Input the image from the custom view.
  6. Preprocess the image.
  7. Classify the image with the model.
  8. Post-process the result.
  9. Display the result in the UI.

The Classifier class is where most of the ML magic happens. Make sure the dimensions you set in the class match what the model is expecting:

  • image shape of 28x28x1
  • 10 classes for the 10 digits: 0, 1, 2, 3…9
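
A quick way to catch a mismatch early is to check the shapes reported by the interpreter before wiring the model into the app. This sketch takes the shapes as plain sequences (for example `input_details[0]['shape']` from the validation step). The function name `check_model_io` and the leading batch dimension of 1 are assumptions.

```python
EXPECTED_INPUT_SHAPE = (1, 28, 28, 1)  # one 28x28 grayscale image (batch size assumed)
EXPECTED_OUTPUT_SHAPE = (1, 10)        # one probability per digit


def check_model_io(input_shape, output_shape):
    """Return a list of human-readable mismatches; empty if shapes are as expected."""
    problems = []
    actual_in = tuple(int(d) for d in input_shape)
    actual_out = tuple(int(d) for d in output_shape)
    if actual_in != EXPECTED_INPUT_SHAPE:
        problems.append(f"input shape {actual_in} != {EXPECTED_INPUT_SHAPE}")
    if actual_out != EXPECTED_OUTPUT_SHAPE:
        problems.append(f"output shape {actual_out} != {EXPECTED_OUTPUT_SHAPE}")
    return problems


print(check_model_io([1, 28, 28, 1], [1, 10]))  # no mismatches
print(check_model_io([1, 28, 28, 3], [1, 10]))  # a color input would be flagged
```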

To classify an image, follow these steps:

  • Pre-process the input image. Convert the Bitmap to a ByteBuffer and convert the pixels to grayscale, since the MNIST images are grayscale.
  • Run inference with the interpreter, which is created by memory-mapping the model file under the assets folder.
  • Post-process the output result for display in UI. The result we get back has 10 probabilities and we will pick the digit with the highest probability to display in the UI.
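
The pre- and post-processing steps can be sketched as follows. The sketch is in Python to stay consistent with the rest of this tutorial; the Android app does the equivalent with Bitmap and ByteBuffer. It assumes pixels arrive as packed ARGB integers (as from Android's `Bitmap.getPixels`), converts to grayscale by averaging the color channels, and scales to [0, 1]. Those choices and the function names `bitmap_to_buffer` and `pick_digit` are illustrative assumptions, not the app's exact code.

```python
import struct

IMG_SIZE = 28      # model expects a 28x28x1 input
NUM_CLASSES = 10   # digits 0-9


def bitmap_to_buffer(argb_pixels):
    """Pack ARGB int pixels into a float32 byte buffer of grayscale values in [0, 1]."""
    if len(argb_pixels) != IMG_SIZE * IMG_SIZE:
        raise ValueError("expected %d pixels" % (IMG_SIZE * IMG_SIZE))
    gray = []
    for p in argb_pixels:
        r, g, b = (p >> 16) & 0xFF, (p >> 8) & 0xFF, p & 0xFF
        # Average the three channels (assumed grayscale formula), then normalize
        gray.append((r + g + b) / 3.0 / 255.0)
    # 4 bytes per float in native byte order, similar to a native-order ByteBuffer
    return struct.pack("=%df" % len(gray), *gray)


def pick_digit(probabilities):
    """Return (digit, probability) for the most likely class."""
    if len(probabilities) != NUM_CLASSES:
        raise ValueError("expected %d probabilities" % NUM_CLASSES)
    digit = max(range(NUM_CLASSES), key=lambda i: probabilities[i])
    return digit, probabilities[digit]


# An all-white image stands in for the drawing from the custom view
white_image = [0xFFFFFFFF] * (IMG_SIZE * IMG_SIZE)
input_buffer = bitmap_to_buffer(white_image)
print(len(input_buffer))  # 28 * 28 * 4 = 3136 bytes

# Made-up model output, for illustration only
fake_output = [0.01, 0.02, 0.05, 0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
print(pick_digit(fake_output))  # (3, 0.8)
```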

Challenges throughout the process

Here are the challenges you might encounter:

  • During tflite conversion, if you get an error that “an operation is not supported by tflite”, you can ask the TensorFlow team to add the operation or create the custom operator yourself. See the list of supported operations here.
  • Sometimes the conversion seems to be successful but the converted model turns out not to work: for example, after conversion the classifier may classify randomly, with ~0.5 accuracy on either positive or negative tests. (I encountered that bug in tf 1.10, and it was later fixed in tf 1.12.)

If the Android app crashes, look at the stacktrace errors from Logcat:

  • Make sure the input image size and color channels are set correctly to match the input tensor size that the model is expecting.
  • Make sure aaptOptions in build.gradle is set to not compress the tflite file:
aaptOptions {
    noCompress "tflite"
}

Overall, training a simple image classifier with tf.Keras is a breeze, and saving and converting the Keras model to tflite is fairly easy too. At the moment, implementing the tflite model on Android is still a bit tedious, which will hopefully improve in the future.

Written by Margaret Maynard-Reid, Google Developer Expert for ML | TensorFlow & Android
