On-Device ML using Flutter and TensorFlow Lite (pt.2): Consume your trained model in Flutter

Roman Jaquez
Published in Flutter Community
10 min read · Jul 15, 2023

This is a two-part series on on-device ML using Flutter. This article covers building the Flutter app that consumes the trained model; see part 1 of this series, where I train the TensorFlow Lite model. I invite you to visit other related articles from the community, such as this one and this one. This is my take on how I’d approach it, with my own spin on it.

With our trained TensorFlow Lite ML model ready to go, let’s build a Flutter app to consume it: we’ll feed it the expected input (a value in Celsius) and display the output it produces (a value in Fahrenheit).

The following schematic shows what we’ll be accomplishing in this tutorial:

We’ll create a Flutter app that lets users slide their fingers vertically across the screen: sliding upwards increases the degrees, sliding downwards decreases them. As values between 0 and 100 are generated, we’ll feed them to the ML model and get the corresponding output.

Prerequisites: You must have Flutter installed on your machine before proceeding with the following steps, as well as an IDE suitable for Flutter development (e.g. Visual Studio Code).

Package Requirements / Dependencies:

  • tflite_flutter: We’ll be using this plugin to perform inference on-device. The TensorFlow Lite Flutter plugin provides a flexible and fast solution for accessing the TensorFlow Lite interpreter and performing inference.
  • google_fonts: Not going with the default Flutter fonts; we want to give it some styling.
  • flutter_riverpod: We want to implement some level of state management in order to share data across widgets and have widgets listen to data changes being broadcast; you get the gist!
  • flutter_animate: We want to add some flair to certain aspects of this app in a simple, easy, straightforward, but robust way.

Let’s proceed!

Create a new project

Let’s go ahead and create a new project:

flutter create temp_tflite_flutter

At the root of the temp_tflite_flutter project, go ahead and install the required packages:

flutter pub add flutter_animate flutter_riverpod google_fonts tflite_flutter

At the root of the project, create a folder called assets, with a subfolder called models. From the previous tutorial, add the generated .tflite model inside the models folder. Your structure should look like this:

# FOLDER STRUCTURE

/temp_tflite_flutter
  /assets
    /models
      modelc2f.tflite
  /lib
    main.dart

  # ... other files

In the pubspec.yaml file, add / enable the assets section under the flutter: key, as such:

flutter:
  assets:
    - assets/models/

Edit the main.dart

Let’s start bringing the pieces in. Let’s create a StatelessWidget called TempApp wrapped inside a ProviderScope (since we’ll be using Riverpod) which encapsulates our MaterialApp, with a home widget called TempAppMain, as such:

import 'package:flutter/material.dart';
import 'package:flutter_animate/flutter_animate.dart';
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:google_fonts/google_fonts.dart';
// aliased as tfl; used later by the tfLiteProvider
import 'package:tflite_flutter/tflite_flutter.dart' as tfl;

void main() {
  runApp(
    const ProviderScope(
      child: TempApp()
    )
  );
}

class TempApp extends StatelessWidget {
  const TempApp({super.key});

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      debugShowCheckedModeBanner: false,
      theme: ThemeData(
        textTheme: GoogleFonts.montserratTextTheme()
      ),
      home: const TempAppMain(),
    );
  }
}

class TempAppMain extends StatefulWidget {
  const TempAppMain({super.key});

  @override
  State<TempAppMain> createState() => _TempAppMainState();
}

class _TempAppMainState extends State<TempAppMain> {
  @override
  Widget build(BuildContext context) {

    // rest of the code will go here
    return Container();
  }
}

With the main scaffold in place, let’s recap what we’ll be doing.

We’ll create the following custom widgets:

  • TempSlideDetector: this will be a widget that sits on top of all other widgets and serves as the surface that interprets sliding events; this surface will broadcast the direction and the position value as we slide our fingers (in Flutter, a negative delta.dy means dragging up and a positive one means dragging down).
  • TempSlider: I decided to use a simple layout built from some core widgets (Stack, Positioned, Container) to make a simple yet fancier slider widget (I know, I know: I could’ve just used the Slider widget and been done with it; but if you know me by now, I’m a bit creative and love creating Flutter UIs outside the norms!)
  • TempInfoDisplay: this widget will listen to the Celsius values (from 0 to 100) generated while sliding and display them, as well as show the appropriate icon depending on the range in which the currently displayed value falls.
  • TempMLDisplay: this will display the values generated by the TensorFlow Lite ML model sitting in the assets folder. We’ll display it in a label that will show the Fahrenheit value generated.

Create the corresponding Riverpod Providers

We need to create some Riverpod providers that will be capturing the state in a decoupled fashion as well as broadcasting the updates to listening widgets. We need the following:

  • sliderValueProvider: type StateProvider; will keep track of the value generated by our custom slider widget; this represents the vertical offset that our TempSlider needs in order to slide the bar up or down, according to the values captured while sliding our fingers on the TempSlideDetector.
// add this at the bottom of the main.dart

// tracks the current value being generated by the slider
final sliderValueProvider = StateProvider<double>((ref) => 0.0);
  • tempValueProvider: type StateProvider; will keep track of the Celsius temp value (from 0 to 100) generated by the slide detector; both the TempInfoDisplay and TempMLDisplay will be listening to changes broadcast by this provider.
// add this at the bottom of the main.dart

// tracks the temperature value generated by the slider
final tempValueProvider = StateProvider<double>((ref) => 0);
  • tfLiteProvider: type FutureProvider.family; it’s a FutureProvider because loading the .tflite model happens asynchronously, so we need to wait for it to load before proceeding; and .family because it takes a parameter: the double value representing the degrees in Celsius, for which we want the model to give us the Fahrenheit value.
// requires this import at the top of main.dart:
// import 'package:tflite_flutter/tflite_flutter.dart' as tfl;

// provides the value generated by the trained ML model given the proper input
final tfLiteProvider = FutureProvider.family<double, double>((ref, arg) async {

  final interpreter = await tfl.Interpreter.fromAsset('assets/models/modelc2f.tflite');
  final input = [[arg]];
  final output = List<double>.filled(1, 0).reshape([1, 1]);

  interpreter.run(input, output);
  interpreter.close(); // release the native interpreter once inference is done
  return output[0][0] as double;
});

Let’s dissect the code:

  • Using the tfl alias for the tflite_flutter package, we call tfl.Interpreter.fromAsset to load the TFLite model from the assets folder asynchronously, and hold it in a final variable called interpreter. This is an instance of the TensorFlow Lite interpreter, against which we’ll perform inference with our trained model.
  • Set up the input of the model, the data structure through which we’ll supply the values the model needs to process: a tensor of shape [1,1] (one row, one column), represented by an array with one value (the arg parameter being passed in, holding the degree value in Celsius) inside another array, held in a final variable called input.
  • Set up the output of the model, the data structure that we’ll use to receive what the model spits out. In our case, we are receiving a tensor of shape [1,1] (one row, one column), held in a final variable called output.
  • Then, call the run method on the interpreter reference, which performs the inference against a model that takes one input and provides one output.
  • Lastly, if the execution was successful, we extract the value from the output based on its shape (the first value out of the first array, thus [0][0], since the model will supply it as [[31.999]]).
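To make the input and output shapes concrete, here is a minimal pure-Dart sketch that builds the same [1,1] buffers without the plugin. Note that fakeRun is a hypothetical stand-in for interpreter.run (it applies the exact formula F = C × 9 / 5 + 32 instead of the trained model), so it only illustrates how data flows in and out, not how the plugin works internally.

```dart
// Exact Celsius-to-Fahrenheit conversion, used here as a reference value.
double celsiusToFahrenheit(double celsius) => celsius * 9 / 5 + 32;

// Hypothetical stand-in for interpreter.run: reads input[0][0] and writes
// the result into output[0][0], mimicking the [1,1] -> [1,1] contract.
void fakeRun(List<List<double>> input, List<List<double>> output) {
  output[0][0] = celsiusToFahrenheit(input[0][0]);
}

void main() {
  const arg = 100.0;

  // One row containing one value: shape [1,1].
  final input = [[arg]];

  // Pre-allocated buffer with the same [1,1] shape to receive the result.
  final output = List.generate(1, (_) => List<double>.filled(1, 0));

  fakeRun(input, output);
  print(output[0][0]); // 212.0
}
```

A reference function like celsiusToFahrenheit is also handy for sanity-checking the trained model: its predictions should land close to these exact values (for example, roughly 32 for an input of 0).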

With the providers in place, let’s add the widgets that will consume them.

TempSlideDetector

This widget is nothing more than a transparent Container serving as a surface for capturing gestures. We wrap it inside a GestureDetector and tap into its onVerticalDragUpdate event, from which we extract delta.dy to determine whether the drag is upwards or downwards. We then generate some synthetic values representing the drag offset and a value between 0 and 100 for the degrees in Celsius, and broadcast the generated values via their corresponding providers.

// widget to detect sliding up and down the screen
class TempSlideDetector extends ConsumerStatefulWidget {

  final double initialYPosition;
  const TempSlideDetector({
    required this.initialYPosition,
    super.key
  });

  @override
  TempSlideDetectorState createState() => TempSlideDetectorState();
}

class TempSlideDetectorState extends ConsumerState<TempSlideDetector> {

  double yPosition = 0;
  double offset = 5;
  double tempValue = 0;
  double tempIncrement = 1;

  @override
  void initState() {
    super.initState();

    yPosition = widget.initialYPosition;
  }

  @override
  Widget build(BuildContext context) {

    return GestureDetector(
      onVerticalDragUpdate: (details) {

        if (details.delta.dy > 0) { // dragging downwards
          if (tempValue > 0) {
            yPosition += offset;
            tempValue -= tempIncrement;
          }
        }
        else { // dragging upwards
          if (tempValue < 100) {
            yPosition -= offset;
            tempValue += tempIncrement;
          }
        }

        ref.read(sliderValueProvider.notifier).state = yPosition;
        ref.read(tempValueProvider.notifier).state = tempValue;
      },
      child: Container(
        color: Colors.transparent,
      ),
    );
  }
}
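If you’d like to check the drag logic without running the app, here is a minimal pure-Dart sketch of the same update rule. SlideState and applyDrag are hypothetical names introduced just for this illustration; the default offset (5) and tempIncrement (1) mirror the fields in the widget above.

```dart
// Immutable snapshot of the two values the detector broadcasts.
class SlideState {
  const SlideState(this.yPosition, this.tempValue);

  final double yPosition;
  final double tempValue;
}

// Applies one vertical drag update and returns the resulting state.
// dy > 0 means the finger moved down (screen coordinates grow downwards);
// anything else is treated as a move up, as in the widget.
SlideState applyDrag(
  SlideState current,
  double dy, {
  double offset = 5,
  double tempIncrement = 1,
}) {
  if (dy > 0) {
    // Dragging downwards: lower the temperature until it reaches 0.
    if (current.tempValue > 0) {
      return SlideState(
          current.yPosition + offset, current.tempValue - tempIncrement);
    }
  } else {
    // Dragging upwards: raise the temperature until it reaches 100.
    if (current.tempValue < 100) {
      return SlideState(
          current.yPosition - offset, current.tempValue + tempIncrement);
    }
  }
  // At a boundary: nothing changes.
  return current;
}

void main() {
  var state = const SlideState(600, 0);

  state = applyDrag(state, -3); // drag up
  print('${state.yPosition} ${state.tempValue}'); // 595.0 1.0

  state = applyDrag(state, 2); // drag down
  state = applyDrag(state, 2); // drag down again: already at 0, no change
  print('${state.yPosition} ${state.tempValue}'); // 600.0 0.0
}
```

Note that the size of dy is ignored: every drag update moves the bar by a fixed offset and the temperature by a fixed increment, so only the direction matters.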

TempSlider

This is nothing more than a set of Container widgets inside a Stack: one for the bottommost background, whose color we derive from the broadcast tempValue so that it shifts from blue to red; another for a slightly darker gradient (for styling purposes); and lastly a Container wrapped inside a Positioned.fill widget that serves as the sliding bar. The bar moves up and down based on the drag offset we listen to from the sliderValueProvider, which we apply to its top property.

class TempSlider extends ConsumerWidget {

  final double initialYPosition;
  const TempSlider({
    required this.initialYPosition,
    super.key});

  @override
  Widget build(BuildContext context, WidgetRef ref) {

    var tempValue = ref.watch(tempValueProvider);
    var verticalPosition = ref.watch(sliderValueProvider);
    var yPosition = verticalPosition > 0 ? verticalPosition : initialYPosition;

    var calculatedOpacity = (1 - (tempValue * 0.025));
    calculatedOpacity = calculatedOpacity < 0 ? 0 : calculatedOpacity;

    return Stack(
      children: [

        // background container
        Container(
          color: Color.fromRGBO((tempValue.toInt() + 75), 0, 255 - (tempValue.toInt() + 75), 1)
        ),

        // background gradient
        Container(
          decoration: BoxDecoration(
            gradient: LinearGradient(
              colors: [
                Colors.black.withOpacity(0.5),
                Colors.black.withOpacity(0.8)
              ],
              begin: Alignment.topCenter,
              end: Alignment.bottomCenter
            )
          )
        ),

        // slider container
        Positioned.fill(
          top: yPosition,
          child: Container(
            decoration: BoxDecoration(
              gradient: LinearGradient(
                colors: [
                  Color.fromRGBO((tempValue.toInt() + 75), 0, 255 - (tempValue.toInt() + 75), 1),
                  Color.fromRGBO((tempValue.toInt() + 75), 0, 255 - (tempValue.toInt() + 75), 0.5),
                ],
                begin: Alignment.topCenter,
                end: Alignment.bottomCenter,
              ),
              borderRadius: const BorderRadius.only(topLeft: Radius.circular(50), topRight: Radius.circular(50))
            ),
            child: Padding(
              padding: const EdgeInsets.all(20),
              child: Opacity(
                opacity: calculatedOpacity,
                child: Column(
                  mainAxisSize: MainAxisSize.min,
                  children: [
                    const Icon(Icons.swipe_vertical, size: 40, color: Colors.white)
                      .animate(
                        onComplete: (controller) {
                          controller.repeat();
                        },
                      ).slideY(
                        begin: 0.5, end: 0,
                        duration: 1.seconds,
                        curve: Curves.easeInOut
                      ).fadeIn(),
                    const SizedBox(height: 30),
                    const Text(
                      'Drag up and down the screen\nto change temperature', textAlign: TextAlign.center, style: TextStyle(color: Colors.white)
                    ),
                  ],
                ),
              ),
            ),
          ),
        )
      ]
    );
  }
}
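The color and opacity math in this widget is easier to reason about in isolation, so here is a small pure-Dart sketch of it. The function names tempToRgb and hintOpacity are hypothetical (they don’t exist in the app); the arithmetic is the same as in the widget above, with the extra assumption that the opacity is also capped at 1.

```dart
// Red, green and blue channels for a temperature between 0 and 100.
// Red grows with the temperature while blue shrinks, which gives the
// blue-to-red shift of the slider.
List<int> tempToRgb(double tempValue) {
  final red = tempValue.toInt() + 75;
  return [red, 0, 255 - red];
}

// Opacity of the "drag up and down" hint: fully visible at 0 degrees,
// fading out as the temperature rises, and never below 0.
double hintOpacity(double tempValue) =>
    (1 - tempValue * 0.025).clamp(0.0, 1.0).toDouble();

void main() {
  print(tempToRgb(0)); // [75, 0, 180]
  print(tempToRgb(100)); // [175, 0, 80]

  print(hintOpacity(0)); // 1.0
  print(hintOpacity(100)); // 0.0
}
```

With a factor of 0.025, the hint is effectively gone once the temperature passes about 40 degrees, which is why it only shows up near the start of the slide.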

TempInfoDisplay

This is nothing more than a widget that listens to the tempValue broadcast by the tempValueProvider, which we use to display the Celsius value and a corresponding icon (iconWidget) based on some hard-coded ranges (cold, warm, hot; for illustration purposes).

// widget to display the value from the slider (from 0 to 100)
class TempInfoDisplay extends ConsumerWidget {
  const TempInfoDisplay({super.key});

  @override
  Widget build(BuildContext context, WidgetRef ref) {

    IconData iconWidget = Icons.ac_unit;
    final tempValue = ref.watch(tempValueProvider);

    // set the appropriate icon widget
    if (tempValue >= 30) {
      iconWidget = Icons.local_fire_department;
    }
    else if (tempValue >= 15 && tempValue < 30) {
      iconWidget = Icons.waves;
    }

    return Column(
      children: [
        Icon(iconWidget, size: 100, color: Colors.white),
        Row(
          mainAxisSize: MainAxisSize.min,
          children: [
            const Icon(Icons.thermostat, size: 60, color: Colors.white),
            Text('${tempValue.toInt()}°c',
              style: const TextStyle(fontSize: 90, color: Colors.white)
            ),
          ],
        ),
      ],
    );
  }
}
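The range check is the only real logic in this widget, so it can be pulled out and verified on its own. Below is a minimal pure-Dart sketch; the TempBand enum and the bandFor function are hypothetical names used only for illustration, with the same thresholds (15 and 30) as the widget.

```dart
// The three illustrative ranges used to pick an icon.
enum TempBand { cold, warm, hot }

// Maps a Celsius value to its band: below 15 is cold,
// 15 up to (but not including) 30 is warm, 30 and above is hot.
TempBand bandFor(double tempValue) {
  if (tempValue >= 30) return TempBand.hot;
  if (tempValue >= 15) return TempBand.warm;
  return TempBand.cold;
}

void main() {
  for (final temp in [0.0, 15.0, 29.0, 30.0]) {
    print('$temp -> ${bandFor(temp).name}');
  }
}
```

In the widget, each band would then map to one icon: Icons.ac_unit for cold, Icons.waves for warm, and Icons.local_fire_department for hot.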

TempMLDisplay

Now, this is the star of the show; this is the widget that displays the output generated by the TensorFlow Lite ML model. Remember we’re loading the model asynchronously and then feeding a Celsius value into it so we can get the corresponding Fahrenheit value from it.

We read both the tempValue (to get the Celsius value) and the tfLiteProvider (to which we need to feed the tempValue).

We use the .when from the AsyncValue obtained by the FutureProvider tfLiteProvider, and we handle all three cases — data, error and loading — and display the appropriate widgets accordingly. We pull the value obtained from the trained ML model via the data callback, perform a .round() operation to round it appropriately and show it inside a Text widget.

// widget to display the temperature value generated by the ML model
class TempMLDisplay extends ConsumerWidget {

  const TempMLDisplay({super.key});

  @override
  Widget build(BuildContext context, WidgetRef ref) {
    final tempValue = ref.watch(tempValueProvider);
    final processedTempValue = ref.watch(tfLiteProvider(tempValue));

    return processedTempValue.when(
      data: (data) {
        return Text('${data.round()}°f', style: const TextStyle(color: Colors.white, fontSize: 30));
      },
      error: (error, stackTrace) => Text(error.toString()),
      loading: () => const CircularProgressIndicator(valueColor: AlwaysStoppedAnimation(Colors.white),),
    );
  }
}

Bringing it all Together

Now let’s go back up to the TempAppMain widget and replace the whole _TempAppMainState with the following code, which brings all the pieces together inside a Stack widget. Just for kicks, we want the bar to start 200px from the bottom of the screen, so we generate an initial value using MediaQuery.sizeOf and feed it to both the TempSlider and TempSlideDetector so they can use it as an initial anchor.

// TempAppMain is a plain StatefulWidget and this state never uses ref,
// so it extends State (ConsumerState would require a ConsumerStatefulWidget)
class _TempAppMainState extends State<TempAppMain> {

  double yPosition = 0;
  double initialValue = 0;
  bool initialValueSet = false;

  @override
  Widget build(BuildContext context) {

    // compute the starting position once: 200px from the bottom of the screen
    if (!initialValueSet) {
      initialValue = MediaQuery.sizeOf(context).height - 200;
      yPosition = initialValue;
      initialValueSet = true;
    }

    return Scaffold(
      body: Stack(
        children: [
          TempSlider(initialYPosition: yPosition),

          const Center(
            child: Column(
              mainAxisAlignment: MainAxisAlignment.center,
              crossAxisAlignment: CrossAxisAlignment.center,
              children: [
                TempInfoDisplay(),
                TempMLDisplay()
              ],
            ),
          ),

          TempSlideDetector(initialYPosition: yPosition)
        ],
      )
    );
  }
}

After running the project, you should get the following output:

And that’s On-Device ML for you! Looking slick, and pretty responsive I might add!

Recommendations: this model is pretty simple and small, so I’m wrapping the loading and inference inside a simple FutureProvider. However, your model may do something more complex or perform heavier computations, in which case I’d recommend using Isolates: spin up a separate background isolate to handle the workload and reply back to the main isolate with just the result of the computation. This relieves the UI thread of work that would otherwise take cycles away from what it’s good at: rendering the UI.
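As a rough idea of what that could look like, here is a minimal pure-Dart sketch using Isolate.run (available from Dart 2.19 onwards). The heavyConversion function is a hypothetical stand-in for an expensive computation; it is not the TFLite inference itself, and moving an actual interpreter to another isolate would need more care than shown here.

```dart
import 'dart:isolate';

// Hypothetical stand-in for heavy work: repeats a cheap calculation many
// times to simulate a computation that would otherwise block the UI.
double heavyConversion(double celsius) {
  var result = 0.0;
  for (var i = 0; i < 1000000; i++) {
    result = celsius * 9 / 5 + 32;
  }
  return result;
}

Future<void> main() async {
  const celsius = 100.0;

  // Runs the function in a short-lived background isolate and sends
  // only the result back to the calling isolate.
  final fahrenheit = await Isolate.run(() => heavyConversion(celsius));

  print(fahrenheit); // 212.0
}
```

The calling isolate only awaits a Future, so in a Flutter app the same call could sit inside a FutureProvider while the UI keeps rendering.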

Thank you so much for taking the time to come along for this series; I hope you found these articles useful.

Full code for the project (including the Google Colab file (.ipynb), the trained model file (.tflite), and the Flutter app) is inside this GitHub repo.

Clap the heck out of it to motivate me to write more and continue sharing with the community.

Cheers!

Roman Jaquez
Flutter Community

Flutter GDE / GDG Lawrence Lead Organizer / Follow me on Twitter @drcoderz — Subscribe to my YouTube Channel https://tinyurl.com/romanjustcodes