This tutorial covers how to train a model from scratch with the tf.Keras Sequential API, convert the tf.Keras model to TFLite format, and run the model on Android. I will walk through an example with the MNIST data for image classification and share some of the common issues you may face. This tutorial focuses on the end-to-end experience; I will not go in depth on deep learning, the various tf.Keras APIs, or Android development.
Download my sample code and follow along:
- Run in Colab - Training model with tf.Keras and convert Keras model to TFLite (link to Colab notebook) .
- Run in Android Studio - DigitRecognizer (link to Android app).
1. Train a custom classifier
Load the data
We will use the MNIST data, which is available as part of the tf.Keras framework.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
Next, we will reshape the input images from 28x28 to 28x28x1, normalize them, and one-hot encode the labels.
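As a sketch, this preprocessing can be done with plain NumPy; the random arrays below are small stand-ins for the real x_train/y_train loaded above:

```python
import numpy as np

# Stand-in for the MNIST arrays loaded above (the real data has 60000 images)
x_train = np.random.randint(0, 256, size=(100, 28, 28)).astype(np.uint8)
y_train = np.random.randint(0, 10, size=(100,))

# Reshape 28x28 -> 28x28x1 (add the single grayscale channel)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)

# Normalize pixel values from [0, 255] to [0, 1]
x_train = x_train.astype(np.float32) / 255.0

# One-hot encode the labels (10 classes, digits 0-9)
y_train = np.eye(10, dtype=np.float32)[y_train]

print(x_train.shape, y_train.shape)  # (100, 28, 28, 1) (100, 10)
```

With the real data, `keras.utils.to_categorical(y_train, 10)` does the same one-hot encoding as the `np.eye` trick.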
Define Model architecture
Then we will define the network architecture with a CNN.
# Define the model architecture
model = keras.models.Sequential([
    # Must define the input shape in the first layer of the neural network
    keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', input_shape=(28, 28, 1)),
    keras.layers.Dropout(0.3),
    keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax'),
])

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
Then we train the model with model.fit().
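The training call itself can be sketched as follows. The tiny stand-in model and random data are only there to show the call pattern; the batch size and epoch count are illustrative, not from the original run:

```python
import numpy as np
from tensorflow import keras

# Tiny stand-in data; in the tutorial this would be the preprocessed MNIST arrays
x_train = np.random.rand(64, 28, 28, 1).astype(np.float32)
y_train = keras.utils.to_categorical(np.random.randint(0, 10, 64), 10)

# Tiny stand-in model; in the tutorial this would be the CNN defined above
model = keras.models.Sequential([
    keras.layers.Input(shape=(28, 28, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model; with real MNIST you would also pass validation_data
history = model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0)
```

`model.fit()` returns a History object whose `history` dict holds the per-epoch loss and metric values, which is handy for plotting training curves.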
2. Model Saving and Conversion
After training, we will save the Keras model and convert it to TFLite format.
Save a Keras Model
Here is how to save a Keras model:
# Save tf.keras model in HDF5 format
keras_model = "mnist_keras_model.h5"
keras.models.save_model(model, keras_model)
Convert Keras model to tflite
There are two options for converting the Keras model to TFLite format with the TFLite converter: 1) from the command line, or 2) directly in your Python code, which is recommended.
1) Conversion via command line
$ tflite_convert \
    --output_file=mymodel.tflite \
    --keras_model_file=mnist_keras_model.h5
2) Conversion via Python code
This is the preferred method of conversion, if you have access to the model training code.
# Convert the model
converter = tf.lite.TFLiteConverter.from_keras_model_file("mnist_keras_model.h5")
tflite_model = converter.convert()

# Create the tflite model file
tflite_model_name = "mymodel.tflite"
open(tflite_model_name, "wb").write(tflite_model)
You can also enable post-training quantization by setting the converter's quantize flag to true:
# Set quantize to true
converter.post_training_quantize = True
Validate the Converted Model
After converting a Keras model to TFLite format, it's important to validate that it performs on par with your original Keras model. See the Python code snippet below on how to run inference with your TFLite model. The example uses random input data; you will need to update it for your own data.
# Load TFLite model and allocate tensors
interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test model on random input data
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
Protip: always test the tflite model after conversion and before putting it on Android. Otherwise, when it's not working in your Android app, it's unclear whether the problem lies in your Android code or in the ML model.
3. Implement the tflite model on Android
Now we are ready to implement the tflite model on Android. Create a new Android project and follow these steps:
- Place mnist.tflite model under assets folder.
- Update build.gradle to include tflite dependency.
- Create a CustomView for user to draw digits.
- Create a Classifier that does the digit classification:
  - Take the input image from the custom view
  - Preprocess the image
  - Classify the image with the model
  - Post-process the output
  - Display the result in the UI
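For the build.gradle step above, the changes typically look like the following sketch; the tensorflow-lite artifact version is illustrative, so use a current release. The noCompress setting matters because the tflite file is memory-mapped at load time, and a compressed asset cannot be memory-mapped:

```groovy
android {
    // tflite model files must not be compressed, otherwise
    // memory-mapping the model at load time will fail
    aaptOptions {
        noCompress "tflite"
    }
}

dependencies {
    // TensorFlow Lite Android runtime (version is illustrative)
    implementation 'org.tensorflow:tensorflow-lite:1.12.0'
}
```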
The Classifier class is where most of the ML magic happens. Make sure the dimensions you set in the class match what the model is expecting:
- image shape of 28x28x1
- 10 classes for the 10 digits: 0, 1, 2, 3…9
To classify an image, follow these steps:
- Pre-process the input image: convert the Bitmap to a ByteBuffer and convert the pixels to grayscale, since the MNIST dataset is grayscale.
- Run inference with the interpreter, which was created by memory-mapping the model file under the assets folder.
- Post-process the output result for display in the UI. The result we get back contains 10 probabilities, and we pick the digit with the highest probability to display in the UI.
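As a minimal sketch, this post-processing step boils down to an argmax over the 10 output probabilities; the probability vector below is made up for illustration:

```python
import numpy as np

# Hypothetical model output: 10 probabilities, one per digit 0-9
output_probs = np.array([0.01, 0.02, 0.01, 0.05, 0.02,
                         0.03, 0.01, 0.80, 0.03, 0.02])

# Pick the digit with the highest probability
predicted_digit = int(np.argmax(output_probs))
confidence = float(output_probs[predicted_digit])

print(predicted_digit, confidence)  # 7 0.8
```

On Android the same logic is a simple loop over the output float array, tracking the index of the maximum value.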
Challenges throughout the process
Here are some challenges you might encounter:
- During tflite conversion, if you get an error that an operation is not supported by tflite, you can request that the TensorFlow team add the operation or create the custom operator yourself. See the list of supported operations here.
- Sometimes the conversion appears to succeed, but the converted model does not work: for example, after conversion the classifier may classify randomly, with ~0.5 accuracy on both positive and negative tests. (I encountered that bug in tf 1.10; it was later fixed in tf 1.12.)
If the Android app crashes, look at the stacktrace errors from Logcat:
- make sure the input image size and color channels are set correctly to match the input tensor size the model is expecting.
- make sure aaptOptions in build.gradle is set to not compress the tflite file.
Overall, training a simple image classifier with tf.Keras is a breeze, and saving and converting the Keras model to tflite is fairly easy too. At the moment, implementing the tflite model on Android is still a bit tedious; hopefully this will improve in the future.