Bring Your Image Classification Model to Life with Flutter

Andrii Makarenko
Published in Geek Culture · 10 min read · Mar 26, 2023

Welcome to the final article in my series on Conquering Rock-Paper-Scissors with Flutter and a Custom VGG16 Model for Image Classification. I’m excited to have you here with me to walk through the process of bringing an image classification model to life in a Flutter app.

Throughout this series, I’ve covered everything from data preparation and training of custom CNNs with Keras to integrating pre-trained models for improved performance. Now, in this final article, I’ll guide you through building a user interface and integrating your model with a Flutter app for real-time prediction of hand gestures.

Series overview

Before we get started, let’s recap what you have learned in the previous articles. In the first article Creating a Winning Model with Flutter and VGG16: A Comprehensive Guide, we covered the basics of data preparation and training for a custom Convolutional Neural Network using the Keras framework. In the second article Boost Your Image Classification Model with pretrained VGG-16, we integrated a pre-trained VGG-16 model into our custom detection task, which resulted in a significant performance boost, even with a small dataset.

Now, let’s dive into the final article and see how you can bring your image classification model to life in a Flutter app. I will use the Flutter framework to build a user interface that streams frames from the device camera and uses our image classification model to predict which hand gesture (rock, paper, or scissors) the user is showing.

To get started, you will need to create a new Flutter project and import the necessary dependencies. We will also need to add code to handle the camera input and display the output on the screen.

Next, you will need to integrate your image classification model into the Flutter app. You will do this by loading the model into memory and then using it to make predictions on the camera stream.

So no more theory, just code.

Model preparation

Since you created and trained your model using the Keras framework, which is built on top of TensorFlow, you can convert it to the TensorFlow Lite format and run it in a Flutter app.

To do the conversion, use the tf.lite.TFLiteConverter.from_keras_model() function. It creates a converter object for your model, and the converter’s convert() method returns the model serialized in the TensorFlow Lite format.

TensorFlow Lite is a lightweight version of TensorFlow designed for mobile and embedded devices. It enables faster inference and smaller model sizes, which makes it a good fit for mobile applications.

Here’s an example of how to convert a Keras model to TensorFlow Lite format using the from_keras_model() function:

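import tensorflow as tf

def convert(model):
    # Create a TFLite converter from the in-memory Keras model
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Save the serialized model so it can be bundled as a Flutter asset
    with open('rock_paper_scissors_model.tflite', 'wb') as f:
        f.write(tflite_model)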
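Before copying the file into the Flutter project, a quick sanity check that the converter really produced a TensorFlow Lite file can save some debugging later. This is a minimal sketch, assuming the standard TFLite FlatBuffer layout, where the schema's file identifier `TFL3` follows the 4-byte root offset at the start of the file; the helper names `has_tflite_identifier` and `looks_like_tflite` are hypothetical:

```python
TFLITE_IDENTIFIER = b"TFL3"  # file identifier declared by the TFLite FlatBuffer schema


def has_tflite_identifier(data: bytes) -> bool:
    # Bytes 0-3 hold the root table offset, bytes 4-7 the FlatBuffer file identifier
    return len(data) >= 8 and data[4:8] == TFLITE_IDENTIFIER


def looks_like_tflite(path: str) -> bool:
    """Hypothetical helper: reads only the first 8 bytes of the file."""
    with open(path, "rb") as f:
        return has_tflite_identifier(f.read(8))
```

For example, `looks_like_tflite('rock_paper_scissors_model.tflite')` should return `True` right after `convert(model)` has run. It only checks the header, not whether the model is valid.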

That’s it: you now have everything you need to start building your Flutter app.

Flutter Part

To begin integrating your image classification model with a Flutter app, you need to create a new Flutter project and add some dependencies to your pubspec.yaml file. You will also need to create an assets folder at the root of your project and put the rock_paper_scissors_model.tflite file generated in the previous step into it.

Here are the steps to follow:

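1. Create a new Flutter project using your preferred IDE or by running flutter create myapp in the terminal.
2. Open the pubspec.yaml file and add the following dependencies:
dependencies:
  flutter:
    sdk: flutter
  tflite_flutter: ^0.9.0
  tflite_flutter_helper: ^0.3.1
  camera: ^0.9.4+5
  # Imported directly by the image-conversion code; that code uses the 3.x API
  image: ^3.0.0

The tflite_flutter and tflite_flutter_helper packages provide the tools for running your TensorFlow Lite model on mobile devices, and the camera package gives access to the device's camera so you can capture frames from it. The image package is imported directly by the conversion and classification code later in this article, which relies on its 3.x API (Image(width, height), getPixel, getRed).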

3. Create an assets folder in the root of your project if it doesn’t already exist.

4. Put the rock_paper_scissors_model.tflite file that you generated in the previous step into the assets folder.

5. Don’t forget to register the asset in pubspec.yaml:

flutter:
  assets:
    - assets/rock_paper_scissors_model.tflite
  uses-material-design: true

Once that’s done, you can start building the UI and running the model on the camera stream.

Create a StatefulWidget and initialize the camera in its initState method:

import 'package:camera/camera.dart';
import 'package:flutter/material.dart';

class ScannerScreen extends StatefulWidget {
  const ScannerScreen({super.key});

  @override
  State<ScannerScreen> createState() => _ScannerScreenState();
}

class _ScannerScreenState extends State<ScannerScreen> {
  late CameraController cameraController;

  bool initialized = false;
  bool isWorking = false;

  // Last detected gesture (the DetectionClasses enum is defined in the next step)
  DetectionClasses detected = DetectionClasses.nothing;

  @override
  void initState() {
    super.initState();
    initialize();
  }

  Future<void> initialize() async {
    final cameras = await availableCameras();
    // Create a CameraController object
    cameraController = CameraController(
      cameras[0], // Choose the first camera in the list
      ResolutionPreset.medium, // Choose a resolution preset
    );

    // Initialize the CameraController and start the camera preview
    await cameraController.initialize();
    // Listen for image frames
    await cameraController.startImageStream((image) {
      // Make predictions only if not busy
      if (!isWorking) {
        processCameraImage(image);
      }
    });

    // The widget may have been disposed while we were awaiting
    if (!mounted) return;
    setState(() {
      initialized = true;
    });
  }
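The `if (!isWorking)` check in the stream callback means frames that arrive while a prediction is still running are dropped rather than queued, so the model always works on a recent frame. The Python simulation below (an illustration, not Flutter code) makes that behavior concrete under simplified assumptions: frames arrive at fixed timestamps and every prediction takes the same time. `simulate_frame_dropping` is a hypothetical name:

```python
def simulate_frame_dropping(frame_times_ms, inference_ms):
    """Return the timestamps of frames that actually get processed.

    Mirrors the `if (!isWorking)` guard: a frame is processed only if the
    previous prediction has finished; otherwise it is dropped, not queued.
    """
    processed = []
    busy_until = float("-inf")
    for t in frame_times_ms:
        if t >= busy_until:  # not working: start a prediction on this frame
            processed.append(t)
            busy_until = t + inference_ms
        # otherwise the frame arrives while busy and is discarded
    return processed


# A ~30 fps stream (one frame every 33 ms) with 100 ms per prediction
frames = [i * 33 for i in range(10)]
print(simulate_frame_dropping(frames, 100))  # [0, 132, 264]
```

Dropping frames keeps latency bounded: the label on screen never lags behind a growing backlog of stale frames, at the cost of classifying only a fraction of them.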

Next, create a Classifier class that will be responsible for the classification process:

import 'package:tflite_flutter/tflite_flutter.dart';

enum DetectionClasses { rock, paper, scissors, nothing }

class Classifier {
  /// Instance of Interpreter
  late Interpreter _interpreter;

  // tflite_flutter 0.9.x looks this name up inside the assets/ folder
  static const String modelFile = "rock_paper_scissors_model.tflite";

  /// Loads interpreter from asset
  Future<void> loadModel({Interpreter? interpreter}) async {
    try {
      _interpreter = interpreter ??
          await Interpreter.fromAsset(
            modelFile,
            options: InterpreterOptions()..threads = 4,
          );

      _interpreter.allocateTensors();
    } catch (e) {
      print("Error while creating interpreter: $e");
    }
  }

  /// Gets the interpreter instance
  Interpreter get interpreter => _interpreter;
}

In the loadModel method, I initialize an Interpreter from the tflite_flutter package and load the model into it. You can also inspect the model through _interpreter. For example, you can check the model’s input shape with the following line:

_interpreter.getInputTensor(0).shape // Will return [1, 150, 150, 3]
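The shape [1, 150, 150, 3] means one image, 150 rows, 150 columns and 3 color channels (NHWC order), so the model expects 1 × 150 × 150 × 3 = 67,500 values. When the predict method below fills a flat Float32List, the value for row y, column x and channel c must land at the index computed in this Python sketch; `nhwc_index` is a hypothetical helper shown only to illustrate the row-major layout:

```python
from math import prod

INPUT_SHAPE = (1, 150, 150, 3)  # (batch, height, width, channels)


def nhwc_index(y, x, c, height=150, width=150, channels=3, batch=0):
    """Flat index of pixel (y, x), channel c in a row-major NHWC buffer."""
    return ((batch * height + y) * width + x) * channels + c


print(prod(INPUT_SHAPE))    # 67500 values in the input buffer
print(nhwc_index(0, 1, 0))  # 3: the next pixel in the same row starts 3 values later
print(nhwc_index(1, 0, 0))  # 450: the next row starts 150 * 3 values later
```

This is the order in which the predict loop writes values: rows in the outer loop, columns in the inner loop, then R, G, B for each pixel.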

Next, create an instance of Classifier in _ScannerScreenState and call loadModel in the initialize method:

class _ScannerScreenState extends State<ScannerScreen> {
  final classifier = Classifier();

  ...

  Future<void> initialize() async {
    await classifier.loadModel();
    ...

Now you can add the processCameraImage method, which takes images from the camera frame by frame and runs a prediction on each one:

Future<void> processCameraImage(CameraImage cameraImage) async {
  setState(() {
    isWorking = true;
  });

  // Convert the YUV420 camera frame to an RGB image (ImageUtils is shown below)
  final convertedImage = ImageUtils.convertYUV420ToImage(cameraImage);
  final DetectionClasses result = await classifier.predict(convertedImage);

  if (detected != result) {
    setState(() {
      detected = result;
    });
  }

  setState(() {
    isWorking = false;
  });
}

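The Classifier should convert each CameraImage into a list of pixel values with shape [1, 150, 150, 3] and feed it to the Interpreter. To do that, you first need to convert the camera frame into an Image instance from package:image/image.dart. The code below handles the YUV420 format that the camera plugin delivers on Android by default; on iOS, frames arrive as BGRA8888 and need a different conversion. You can use the following code: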

import 'package:camera/camera.dart';
import 'package:image/image.dart' as imageLib;

/// ImageUtils
class ImageUtils {
  /// Converts a [CameraImage] in YUV420 format to [imageLib.Image] in RGB format
  static imageLib.Image convertYUV420ToImage(CameraImage cameraImage) {
    final int width = cameraImage.width;
    final int height = cameraImage.height;

    // Rows of the Y plane may be padded, so index it with its own row stride
    final int yRowStride = cameraImage.planes[0].bytesPerRow;
    // U and V planes are subsampled 2x2 and share the same strides
    final int uvRowStride = cameraImage.planes[1].bytesPerRow;
    final int uvPixelStride = cameraImage.planes[1].bytesPerPixel!;

    final image = imageLib.Image(width, height);

    for (int w = 0; w < width; w++) {
      for (int h = 0; h < height; h++) {
        final int uvIndex = uvPixelStride * (w ~/ 2) + uvRowStride * (h ~/ 2);
        final int index = h * width + w;

        final y = cameraImage.planes[0].bytes[h * yRowStride + w];
        final u = cameraImage.planes[1].bytes[uvIndex];
        final v = cameraImage.planes[2].bytes[uvIndex];

        image.data[index] = ImageUtils.yuv2rgb(y, u, v);
      }
    }
    return image;
  }

  /// Convert a single YUV pixel to RGB
  static int yuv2rgb(int y, int u, int v) {
    // Convert yuv pixel to rgb
    int r = (y + v * 1436 / 1024 - 179).round();
    int g = (y - u * 46549 / 131072 + 44 - v * 93604 / 131072 + 91).round();
    int b = (y + u * 1814 / 1024 - 227).round();

    // Clipping RGB values to be inside boundaries [ 0 , 255 ]
    r = r.clamp(0, 255);
    g = g.clamp(0, 255);
    b = b.clamp(0, 255);

    // image 3.x stores pixels as 0xAABBGGRR
    return 0xff000000 |
        ((b << 16) & 0xff0000) |
        ((g << 8) & 0xff00) |
        (r & 0xff);
  }
}
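To see what convertYUV420ToImage does for a single pixel, here is a Python transcription of the same arithmetic. It assumes the Android YUV_420_888 layout the Dart code targets: a full-resolution Y plane plus U and V planes subsampled 2x2 that share the same row and pixel strides. The function names are hypothetical, and `round_half_away` exists only to reproduce Dart's `round()`, which rounds halves away from zero (Python's `round()` rounds halves to even):

```python
def yuv_indices(x, y, y_row_stride, uv_row_stride, uv_pixel_stride):
    """Byte offsets of pixel (x, y) in the Y plane and in the U/V planes.

    U and V are subsampled 2x2, so four neighbouring pixels share one U/V sample.
    """
    y_index = y * y_row_stride + x
    uv_index = uv_pixel_stride * (x // 2) + uv_row_stride * (y // 2)
    return y_index, uv_index


def round_half_away(value):
    # Dart's num.round() rounds .5 away from zero
    return int(value + 0.5) if value >= 0 else -int(-value + 0.5)


def yuv2rgb(y, u, v):
    """Same constants as ImageUtils.yuv2rgb; result packed as 0xAABBGGRR."""
    r = round_half_away(y + v * 1436 / 1024 - 179)
    g = round_half_away(y - u * 46549 / 131072 + 44 - v * 93604 / 131072 + 91)
    b = round_half_away(y + u * 1814 / 1024 - 227)
    # Clamp each channel into [0, 255]
    r, g, b = (min(max(c, 0), 255) for c in (r, g, b))
    return 0xFF000000 | (b << 16) | (g << 8) | r
```

A neutral sample such as `yuv2rgb(128, 128, 128)` comes out as a near-gray pixel (R=129, G=126, B=128), which is a quick way to check that the constants were copied correctly. Doing this for every pixel of every frame in interpreted code is expensive, which matters for the performance discussion later in the article.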

Then add a predict method to Classifier:

import 'dart:typed_data';

import 'package:image/image.dart' as img;
// Project-specific import: adjust it to wherever DetectionClasses lives in your app
import 'package:rock_paper_scissors_mobile/classes.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {

  ...

  Future<DetectionClasses> predict(img.Image image) async {
    img.Image resizedImage = img.copyResize(image, width: 150, height: 150);

    // Convert the resized image to a 1D Float32List.
    // Pixels are scaled from [0, 255] to [-1, 1]; this must match the
    // preprocessing that was used when the model was trained.
    Float32List inputBytes = Float32List(1 * 150 * 150 * 3);
    int pixelIndex = 0;
    for (int y = 0; y < resizedImage.height; y++) {
      for (int x = 0; x < resizedImage.width; x++) {
        int pixel = resizedImage.getPixel(x, y);
        inputBytes[pixelIndex++] = img.getRed(pixel) / 127.5 - 1.0;
        inputBytes[pixelIndex++] = img.getGreen(pixel) / 127.5 - 1.0;
        inputBytes[pixelIndex++] = img.getBlue(pixel) / 127.5 - 1.0;
      }
    }

    // Reshape to the model's input format: 1 image of 150x150 pixels with 3 RGB channels
    final input = inputBytes.reshape([1, 150, 150, 3]);

    // Output container: 1 row of 4 class scores
    final output = Float32List(1 * 4).reshape([1, 4]);

    // Run the data through the model
    interpreter.run(input, output);

    // Get the index of the maximum value in the output. The model's outputs mean:
    // index 0 - rock, 1 - paper, 2 - scissors, 3 - nothing.
    // List<double>.from works whether reshape produced List<double> or List<dynamic>.
    final predictionResult = List<double>.from(output[0] as List);
    double maxElement = predictionResult.reduce(
      (double maxElement, double element) =>
          element > maxElement ? element : maxElement,
    );
    return DetectionClasses.values[predictionResult.indexOf(maxElement)];
  }
}
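The numeric core of predict is two small steps: scaling each channel into [-1, 1] and picking the index of the largest score. Here they are as a Python sketch, assuming, as the Dart code does, that the model was trained on inputs scaled to [-1, 1] and that the outputs are ordered rock, paper, scissors, nothing. If your training pipeline rescaled differently (for example by 1/255), the normalization has to change to match. `normalize` and `decode_prediction` are hypothetical names:

```python
LABELS = ["rock", "paper", "scissors", "nothing"]  # output order used in this article


def normalize(channel_value):
    """Map an 8-bit channel value from [0, 255] to [-1.0, 1.0], as predict() does."""
    return channel_value / 127.5 - 1.0


def decode_prediction(scores):
    """Return the label with the highest score (argmax); ties go to the first index."""
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return LABELS[best_index]


print(normalize(0), normalize(255))                # -1.0 1.0
print(decode_prediction([0.1, 0.7, 0.15, 0.05]))  # paper
```

Like the Dart `reduce`/`indexOf` pair, `max` returns the first index when several scores are equal, so both versions agree even on ties.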

Now you can build the UI for your app:

class _ScannerScreenState extends State<ScannerScreen> {

  ...

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: const Text('Flutter Camera Demo'),
      ),
      body: initialized
          ? Column(
              children: [
                // Square camera preview as wide as the screen
                SizedBox(
                  height: MediaQuery.of(context).size.width,
                  width: MediaQuery.of(context).size.width,
                  child: CameraPreview(cameraController),
                ),
                Text(
                  // Enum.name gives "rock", "paper", "scissors" or "nothing"
                  "Detected: ${detected.name}",
                  style: const TextStyle(
                    fontSize: 28,
                    color: Colors.blue,
                  ),
                ),
              ],
            )
          : const Center(child: CircularProgressIndicator()),
    );
  }

That’s it. However, if you run the app now, you will probably notice that it lags. This is because image conversion and inference are resource-intensive and currently run on the UI isolate, where they compete with rendering. To make the app run smoothly, move these calculations to a separate isolate.

Calculation isolate

First, create an IsolateUtils class using the following code:

/// Bundles data to pass between Isolate
class IsolateData {
CameraImage cameraImage;
int interpreterAddress;
SendPort responsePort;

IsolateData({
required this.cameraImage,
required this.interpreterAddress,
required this.responsePort,
});
}

class IsolateUtils {
static const String DEBUG_NAME = "InferenceIsolate";

late Isolate _isolate;
final ReceivePort _receivePort = ReceivePort();
late SendPort _sendPort;

SendPort get sendPort => _sendPort;

Future<void> start() async {
_isolate = await Isolate.spawn<SendPort>(
entryPoint,
_receivePort.sendPort,
debugName: DEBUG_NAME,
);

_sendPort = await _receivePort.first;
}

static void entryPoint(SendPort sendPort) async {
final port = ReceivePort();
sendPort.send(port.sendPort);

await for (final IsolateData isolateData in port) {
Classifier classifier = Classifier();
// Restore interpreter from main isolate
await classifier.loadModel(interpreter: Interpreter.fromAddress(isolateData.interpreterAddress));

final convertedImage = ImageUtils.convertYUV420ToImage(isolateData.cameraImage);
DetectionClasses results = await classifier.predict(convertedImage);
isolateData.responsePort.send(results);
}
}

void dispose() {
_isolate.kill();
}
}

The entry point receives IsolateData from the port and initializes a Classifier with the Interpreter from the main isolate, restored from its address via Interpreter.fromAddress(isolateData.interpreterAddress).

It then applies the same CameraImage preprocessing as before, passes the converted image to the Classifier, and sends the resulting DetectionClasses value back to the main isolate with isolateData.responsePort.send(results).
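The key idea in this design is that every request carries its own reply channel (the responsePort inside IsolateData), so each caller waits for exactly one answer to its own frame. The Python sketch below reproduces only that request/reply pattern with a worker thread and queues. It is an analogy, not an equivalent: Python threads share memory and the GIL, whereas Dart isolates have separate heaps and communicate only through ports. The names `worker` and `inference` are hypothetical, and the string reply stands in for a real prediction:

```python
import queue
import threading


def worker(requests):
    """Background loop: like entryPoint, handle each request and answer on its reply queue."""
    while True:
        item = requests.get()
        if item is None:  # shutdown signal, analogous to killing the isolate
            break
        frame, reply = item
        reply.put(f"prediction for {frame}")  # stand-in for classifier.predict(...)


def inference(requests, frame):
    """Like inference(): send data plus a fresh reply channel, then wait for one answer."""
    reply = queue.Queue(maxsize=1)  # plays the role of the per-call ReceivePort
    requests.put((frame, reply))
    return reply.get(timeout=5)


requests = queue.Queue()  # plays the role of the isolate's SendPort
thread = threading.Thread(target=worker, args=(requests,), daemon=True)
thread.start()

print(inference(requests, "frame-1"))  # prediction for frame-1
requests.put(None)
thread.join()
```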

With that in place, make a few changes to your _ScannerScreenState code. First, create an IsolateUtils instance and call isolateUtils.start() in the initialize method.

Then add an inference method:

Future<DetectionClasses> inference(CameraImage cameraImage) async {
  // A fresh port for this request; the inference isolate answers on it
  final responsePort = ReceivePort();
  final isolateData = IsolateData(
    cameraImage: cameraImage,
    interpreterAddress: classifier.interpreter.address,
    responsePort: responsePort.sendPort,
  );

  isolateUtils.sendPort.send(isolateData);
  final result = await responsePort.first;

  return result as DetectionClasses;
}

Finally, call it in the processCameraImage method:

Future<void> processCameraImage(CameraImage cameraImage) async {
  setState(() {
    isWorking = true;
  });

  // Conversion and prediction now run in the inference isolate
  final result = await inference(cameraImage);

  if (detected != result) {
    setState(() {
      detected = result;
    });
  }

  setState(() {
    isWorking = false;
  });
}

Here is the full code of the scanner screen:

import 'dart:isolate';

import 'package:camera/camera.dart';
import 'package:flutter/material.dart';
// Also import the files that define Classifier, DetectionClasses and IsolateUtils

class ScannerScreen extends StatefulWidget {
  const ScannerScreen({super.key});

  @override
  State<ScannerScreen> createState() => _ScannerScreenState();
}

class _ScannerScreenState extends State<ScannerScreen> {
  late CameraController cameraController;
  final classifier = Classifier();
  final isolateUtils = IsolateUtils();

  bool initialized = false;
  bool isWorking = false;
  DetectionClasses detected = DetectionClasses.nothing;

  @override
  void initState() {
    super.initState();
    initialize();
  }

  Future<void> initialize() async {
    // Load main isolate Interpreter
    await classifier.loadModel();

    final cameras = await availableCameras();
    // Create a CameraController object
    cameraController = CameraController(
      cameras[0], // Choose the first camera in the list
      ResolutionPreset.medium, // Choose a resolution preset
    );

    // Start Inference isolate
    await isolateUtils.start();

    // Initialize the CameraController and start the camera preview
    await cameraController.initialize();
    // Listen for image frames
    await cameraController.startImageStream((image) {
      // Make predictions only if not busy
      if (!isWorking) {
        processCameraImage(image);
      }
    });

    if (!mounted) return;
    setState(() {
      initialized = true;
    });
  }

  Future<void> processCameraImage(CameraImage cameraImage) async {
    setState(() {
      isWorking = true;
    });

    final result = await inference(cameraImage);

    // The screen may have been closed while inference was running
    if (!mounted) return;
    setState(() {
      detected = result;
      isWorking = false;
    });
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: const Text('Flutter Camera Demo'),
      ),
      body: initialized
          ? Column(
              children: [
                SizedBox(
                  height: MediaQuery.of(context).size.width,
                  width: MediaQuery.of(context).size.width,
                  child: CameraPreview(cameraController),
                ),
                Text(
                  "Detected: ${detected.name}",
                  style: const TextStyle(
                    fontSize: 28,
                    color: Colors.blue,
                  ),
                ),
              ],
            )
          : const Center(child: CircularProgressIndicator()),
    );
  }

  Future<DetectionClasses> inference(CameraImage cameraImage) async {
    // A fresh port for this request; the inference isolate answers on it
    final responsePort = ReceivePort();
    final isolateData = IsolateData(
      cameraImage: cameraImage,
      interpreterAddress: classifier.interpreter.address,
      responsePort: responsePort.sendPort,
    );

    isolateUtils.sendPort.send(isolateData);
    final result = await responsePort.first;

    return result as DetectionClasses;
  }

  @override
  void dispose() {
    cameraController.dispose();
    isolateUtils.dispose();
    super.dispose();
  }
}

After completing these steps, you should have an app that shows a live camera preview with the detected gesture displayed below it.

Summary

In this final article, you have successfully integrated your custom image classification model into a Flutter app. By combining the power of machine learning with the user-friendly Flutter framework, you have created a real-time hand gesture detection app that is accurate, responsive, and easy to use.

Throughout this series, I have covered a range of topics, from the basics of building a custom CNN model to the more advanced techniques of transfer learning and model optimization. We have also explored the challenges of integrating machine learning models into mobile apps and provided practical solutions to overcome them.

By following this series, you now have the skills and knowledge to build your own custom image classification models, optimize them for performance, and deploy them in a mobile app using Flutter and TensorFlow Lite. I hope you have found this series useful and informative.

If you have any questions or feedback, feel free to leave a comment below or check out the code for this project on my GitHub repository:

Thank you for reading and see you soon :P


Mobile software engineer. Passionate about Android & Flutter dev. Curious about ML. Taking first steps in career as a tech lead.