Custom TensorFlow Lite model implementation in Android

Dheeraj Kumar
Walmart Global Tech Blog
May 17, 2024

Nowadays everyone is looking to empower their apps with machine learning. One way to add machine learning capabilities to mobile apps is by using ML (machine learning) models. These models are already pre-trained on data and are ready for use. In general, we use tflite (TensorFlow Lite) models on Android and Core ML models on iOS. In this blog we will explore how a tflite model can be implemented on the Android platform.

Prerequisites:

  • Basic understanding of Android development using Kotlin
  • Android sample project
  • A tflite model

Implementation

Let’s start with the implementation. Following are the steps to implement a tflite model in Android.

Step 1: Add tflite model to assets folder

To demonstrate the implementation we will be using the following tflite model, which can be downloaded from TensorFlow Hub.

https://tfhub.dev/sayakpaul/lite-model/cartoongan/fp16/1

This tflite model converts an image to a cartoon image. We will be using this model for running the inference. The term inference refers to the process of executing a TensorFlow Lite model on-device in order to make predictions based on input data.

Every tflite model defines the format of its input and output, and we need to know these formats before using the model. The model used in this example is no exception: it requires its input in a defined format and produces its output in a defined format.

As the tflite model used in this blog processes images, we need to know the shape of the input and output before running the inference.

According to the docs this model was quantized using float16 quantization. It takes fixed-shape 224*224 (width * height) input images with BGR channel ordering, so while running inference on this model we need to provide a 224*224 image as input.

Quantization is a conversion technique that can reduce model size while also improving CPU and hardware accelerator latency, with little degradation in model accuracy. It can be performed using the TensorFlow Lite Converter.
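For example, with float16 quantization the 32-bit floating-point weights are stored as 16-bit floats, which can cut the model size roughly in half while typically costing very little accuracy.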

You can run the following Python script to find the input and the output shape of the tflite model. To run this script, Python and the TensorFlow library must be installed on your machine.

tflite-model-shape-checker.py

import tensorflow as tf
TFLITE_PATH = "./lite_model_cartoongan.tflite"  # use the absolute path of the tflite model here
interpreter = tf.compat.v1.lite.Interpreter(model_path=TFLITE_PATH)
interpreter.allocate_tensors()
# printing the result
print("Input shape :\n" , interpreter.get_input_details())
print("Output shape: \n", interpreter.get_output_details())

You can run this script with the following command:

python3 tflite-model-shape-checker.py

Output:

Input shape :
[{'name': 'input_photo', 'index': 0, 'shape': array([ 1, 224, 224, 3], dtype=int32), 'shape_signature': array([ 1, 224, 224, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
Output shape:
[{'name': 'final_output', 'index': 108, 'shape': array([ 1, 224, 224, 3], dtype=int32), 'shape_signature': array([ 1, 224, 224, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

Step 2: Adding Dependencies

The following dependencies are required to run inference on a custom tflite model. Add them to your app's (module-level) build.gradle file.

implementation("org.tensorflow:tensorflow-lite:2.9.0")
implementation("org.tensorflow:tensorflow-lite-task-vision:0.3.1")
implementation("org.tensorflow:tensorflow-lite-gpu:2.9.0")

Step 3: Loading tflite model

In Android we use the Interpreter class to load the model and run the inference. Use the following code snippet to create an Interpreter object.

val TFLITE_MODEL_NAME = "lite_model_cartoongan.tflite" // must be placed in the assets folder
private var interpreter: Interpreter? = null

private fun createInterpreter() {
    val tfLiteOptions = Interpreter.Options() // can be configured to use a GPU delegate
    interpreter = Interpreter(FileUtil.loadMappedFile(context, TFLITE_MODEL_NAME), tfLiteOptions)
}

Here, you can configure the tflite options to run the inference with a GPU delegate, but the model we are using in this example does not support the GPU delegate, so we are not configuring one here. A GPU delegate can improve performance by reducing inference time.
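For models that do support it, the snippet below is a minimal sketch of how the options could be configured to prefer the GPU and fall back to the CPU (it assumes the tensorflow-lite-gpu dependency from Step 2; createGpuAwareOptions and gpuDelegate are illustrative names, not part of the TensorFlow Lite API):

import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate

private var gpuDelegate: GpuDelegate? = null

private fun createGpuAwareOptions(): Interpreter.Options {
    val options = Interpreter.Options()
    val compatList = CompatibilityList()
    if (compatList.isDelegateSupportedOnThisDevice) {
        // Run on the GPU using the delegate options recommended for this device
        val delegate = GpuDelegate(compatList.bestOptionsForThisDevice)
        gpuDelegate = delegate
        options.addDelegate(delegate)
    } else {
        // Fall back to CPU inference with multiple threads
        options.setNumThreads(4)
    }
    return options
}

Remember to call close() on both the delegate and the Interpreter once inference is finished, to release the native resources they hold.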

Step 4: Preparing input for inference

The Interpreter class takes a ByteBuffer as input, so we first need to convert the bitmap image into a ByteBuffer. As we already know, this model has a [1, 224, 224, 3] input shape, which defines how the input must be fed to the model while running inference.
- The 0th index represents the number of images the model processes at one time (the batch size).
- The 1st and 2nd indices represent the 224*224 image dimensions.
- The 3rd index represents the three colour channel values for one pixel.

Now, let's convert the bitmap into a ByteBuffer. The following method takes a bitmap, width and height as input parameters and converts the bitmap into a ByteBuffer of the required [1, 224, 224, 3] input shape.
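Since every channel value is stored as a 4-byte float, this buffer has to hold 1 × 224 × 224 × 3 × 4 = 602,112 bytes, which is exactly what the allocateDirect call below reserves.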

private fun getInputImage(width: Int, height: Int): ByteBuffer {
    // 1 image * width * height * 3 channels * 4 bytes per float = required input buffer size
    val inputImage = ByteBuffer.allocateDirect(1 * width * height * 3 * 4)
    inputImage.order(ByteOrder.nativeOrder())
    inputImage.rewind()
    return inputImage
}

private fun convertBitmapToByteBuffer(bitmapIn: Bitmap, width: Int, height: Int): ByteBuffer {
    // Scale the bitmap to the model's required input size
    val bitmap = Bitmap.createScaledBitmap(bitmapIn, width, height, false)
    // These values can differ per channel; if they are all equal you may use a single value instead of an array
    val mean = arrayOf(127.5f, 127.5f, 127.5f)
    val standard = arrayOf(127.5f, 127.5f, 127.5f)
    val inputImage = getInputImage(width, height)
    for (y in 0 until height) {
        for (x in 0 until width) {
            val px = bitmap.getPixel(x, y)
            // Get channel values from the pixel value.
            val r = Color.red(px)
            val g = Color.green(px)
            val b = Color.blue(px)
            // Normalize channel values to [-1.0, 1.0]. This requirement depends on the model.
            // For example, some models might require values to be normalized to the range
            // [0.0, 1.0] instead.
            val rf = (r - mean[0]) / standard[0]
            val gf = (g - mean[1]) / standard[1]
            val bf = (b - mean[2]) / standard[2]
            // Putting channels in B, R, G order because this model demands input in this order
            inputImage.putFloat(bf)
            inputImage.putFloat(rf)
            inputImage.putFloat(gf)
        }
    }
    return inputImage
}

We need to normalize the RGB channel values into the [0.0, 1.0] or [-1.0, 1.0] range using mean and standard values. These values are, or should be, provided by the model provider. The model we are using in this blog does not define mean and standard values, so we fall back to the default value of 127.5f for both.
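For example, with mean and standard values of 127.5, a channel value of 255 maps to (255 − 127.5) / 127.5 = 1.0, a value of 0 maps to −1.0, and 127.5 maps to 0.0, so the full 0–255 range is squeezed into [−1.0, 1.0].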

Step 5: Running inference

We will be using this ByteBuffer to run the inference. We also have to pass an output array to the run method of the Interpreter class; the dimensions of this array must match the output shape of the tflite model.

private fun runInference(bitmap: Bitmap): Array<Array<Array<FloatArray>>> {
    // Output container matching the model's [1, 224, 224, 3] output shape
    val outputArr = Array(1) {
        Array(224) {
            Array(224) {
                FloatArray(3)
            }
        }
    }
    val byteBuffer = convertBitmapToByteBuffer(bitmap, 224, 224)
    interpreter?.run(byteBuffer, outputArr)
    return outputArr
}

Step 6: Interpreting Output

Now, the output of the inference must be interpreted in a meaningful way. As mentioned in Step 1, this tflite model returns an image with a cartoon effect applied, so let's convert the output array back into an image.

private fun convertOutputArrayToImage(inferenceResult: Array<Array<Array<FloatArray>>>): Bitmap {
    val output = inferenceResult[0]
    val bitmap = Bitmap.createBitmap(224, 224, Bitmap.Config.ARGB_8888)
    val pixels = IntArray(224 * 224)
    var index = 0
    for (y in 0 until 224) {
        for (x in 0 until 224) {
            // De-normalize channel values from [-1.0, 1.0] back to [0, 255]
            val b = (output[y][x][0] + 1) * 127.5
            val r = (output[y][x][1] + 1) * 127.5
            val g = (output[y][x][2] + 1) * 127.5
            val a = 0xFF
            // Pack the ARGB channels into a single pixel value
            pixels[index] = a shl 24 or (r.toInt() shl 16) or (g.toInt() shl 8) or b.toInt()
            index++
        }
    }
    bitmap.setPixels(pixels, 0, 224, 0, 0, 224, 224)
    return bitmap
}

In some cases we may need to normalize the channel values to get the correct output. The normalization process is defined by the original model provider; in our case it is described in the Metadata section of the model's TensorFlow Hub page linked in Step 1.
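To tie everything together, here is a minimal sketch of how the functions defined in the steps above could be called end to end (cartoonify, inputBitmap and imageView are illustrative names from this sample, not part of the TensorFlow Lite API):

private fun cartoonify(inputBitmap: Bitmap) {
    createInterpreter()                                       // Step 3: load the tflite model
    val outputArr = runInference(inputBitmap)                 // Steps 4 and 5: preprocess the bitmap and run inference
    val cartoonBitmap = convertOutputArrayToImage(outputArr)  // Step 6: convert the output array back into a bitmap
    imageView.setImageBitmap(cartoonBitmap)                   // display the cartoonized image
    interpreter?.close()                                      // release native resources when done
}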

Great! Now run your app and let's look at the output.

Input Image: This is the input image that we feed to the model.

Output Image: This is the output image after running the inference. The tflite model has converted a normal image into a cartoon image.

Git repository url for this sample app:

By following these steps you can implement any tflite model in your Android app.

Following are some examples of problems that custom tflite models can help solve:

  1. Human face and body detection
  2. Image recognition
  3. Gesture recognition
  4. Object detection (vehicles, animals, fruits etc.)

Thank you for reading. Happy coding 🥳 !!

References:

https://www.tensorflow.org/lite/guide/inference
https://www.tensorflow.org/lite/performance/gpu
https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview
