Image Classification App using TensorFlow Lite & Flutter

On-device Machine Learning using Deep Learning Neural Networks in TensorFlow and Keras (Step By Step Process)

SAGNIK GHOSH
MLSAKIIT
12 min read · Aug 27, 2021

--

Machine learning and Artificial Intelligence take the development of mobile applications to a new level. Apps that use machine learning can detect speech, images, and body language. AI gives us new and compelling ways to engage and connect with people in the world around us. But how do we integrate machine learning into our mobile applications?

Developing mobile applications that include machine learning has been a difficult task for a long time. But dev platforms and tools like Firebase’s ML, and TensorFlow Lite have made it a piece of cake. These tools provide users with pre-trained machine learning models as well as tools for training and importing traditional models. But how do we develop a compelling experience on top of those machine learning models?
That’s where Flutter enters.

Flutter SDK is a portable UI toolkit developed by Google and its open community for building Android, iOS, web, and desktop applications. At its core, Flutter combines an efficient graphics engine with the Dart programming language. Using Flutter, we can build mobile applications with machine learning capabilities, such as image classification and object detection, on both Android and iOS platforms.

In this article, we will combine the power of Flutter and on-device ML to develop a Flutter application that can detect animals.

Let’s have a quick look at what we are gonna build today:

Get the Dataset

I have downloaded the dataset from here.
Just download the dataset, keep your favourite classes (in this case, animals), and delete the remaining folders.
Also, you can use any other dataset if you wish.

This article will be divided into 2 modules:

  • Training a CNN (Convolutional Neural Network)
  • Creating the Flutter UI and importing the .tflite file (ML model) into the Flutter project

The entire, completed project can be found and cloned from my GitHub Repo → https://git.io/JEqvP

Module 1: Training Neural Networks

There are two ways to train our model:

  1. Teachable Machine (if you don’t want to build TensorFlow models on your own)
  2. TensorFlow and Keras (if you have a little bit of knowledge about Deep Learning)

Method 1: Using Teachable Machine

What is a Teachable Machine? Teachable Machine is a web-based tool that makes creating machine learning models fast, easy, and accessible to everyone. It can be used to recognize images, sounds, or poses.

Now, let’s see how we can create our machine learning model using Teachable Machine.

First, go to Teachable Machine and open Image Project. To get started with training the model we will need to create five classes (as per my dataset), namely “Elephant”, “Kangaroo”, “Panda”, “Penguin” and “Tiger” and upload training images to build a model.

Teachable Machine

Replace the Class 1 name with Elephant and click on Upload to upload the training images for the elephants.

Click on the upload button and choose images from your folder or drag & drop images

Now repeat the process for the other classes. Change Class 2 to Kangaroo and upload the training images for kangaroos, change Class 3 to Panda and upload the training images for pandas, and do the same for the rest of the animals.

Click on the Train Model button after uploading all the images

It’s gonna take time if you have uploaded a lot of training images, so sit back and enjoy your coffee :)

After your model is trained, click on Export Model and download the TensorFlow Lite Floating Point Model.

Exporting Tensorflow Lite Floating Point Model

Warning!!! In case your model can’t recognize our cute penguin

You can find my model.tflite & labels.txt files directly from here.

Method 2: Using TensorFlow and Keras

What is TensorFlow? TensorFlow is an open-source artificial intelligence library, using data flow graphs to build models. It allows developers to create large-scale neural networks with many layers. TensorFlow is mainly used for: Classification, Perception, Understanding, Discovering, Prediction and Creation.

Let's get our hands dirty by writing some code:

Let’s start by opening Jupyter Notebook (or Google Colab):

Code Cell 1: Importing Libraries & Training Modules
  • The os module provides us with functions for fetching the contents of a directory and writing to it.
  • Set the base_dir variable to the location of the dataset containing the training images.
Code Cell 2: Preprocessing (format images before they are used by model training and inference)
  • IMAGE_SIZE = 224 → the size to which every image in the dataset will be resized.
  • BATCH_SIZE = 64 → the number of images we are inputting into the neural network at once.
  • rescale=1./255 normalizes every pixel value from the [0, 255] range down to [0, 1], which helps the network train faster and more stably.
  • Datasets have a training set and a test (validation) set. The training set is used to train our model and the validation set is used to measure how accurate the model is. So with validation_split=0.2, we are telling Keras to use 80% of the images for training and 20% for accuracy testing (validation).
  • Then, we have two generators (train_generator and val_generator), which take the path to the directory and generate batches of augmented data, which in this case gives the output: Found 2872 images belonging to 36 classes and Found 709 images belonging to 36 classes.
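As a quick sanity check, the numbers above can be verified by hand. This numpy sketch (not the Keras ImageDataGenerator itself) shows what rescale=1./255 does to a single image, and that the generator counts above really correspond to a roughly 80/20 split:

```python
import numpy as np

IMAGE_SIZE = 224
BATCH_SIZE = 64

# A fake 8-bit RGB image, as loaded from disk: pixel values in [0, 255]
img = np.random.randint(0, 256, (IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)

# rescale=1./255 normalizes every pixel into [0, 1]
scaled = img.astype(np.float32) / 255.0

# validation_split=0.2: the two "Found ... images" counts add up to ~20% validation
train_count, val_count = 2872, 709
val_fraction = val_count / (train_count + val_count)  # close to 0.2
```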
Code Cell 3: creating a labels.txt file that will hold all our labels (names of animals)
  • Print all keys and classes (labels) of the dataset to re-check that everything is working fine.
  • Flutter requires two files: model.tflite and labels.txt.
  • Opening the file in ‘w’ mode creates a new file called labels.txt containing the labels (names of animals), or overwrites it if it already exists.
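A minimal sketch of what this cell does, using a hypothetical class_indices dict standing in for the one the real train_generator exposes:

```python
# Hypothetical mapping, standing in for train_generator.class_indices
class_indices = {"Elephant": 0, "Kangaroo": 1, "Panda": 2, "Penguin": 3, "Tiger": 4}

# 'w' mode creates labels.txt, or overwrites it if it already exists
labels = "\n".join(class_indices.keys())
with open("labels.txt", "w") as f:
    f.write(labels)
```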
Now we will use MobileNetV2, which is a convolutional neural network architecture that seeks to perform well on mobile devices. It is based on an inverted residual structure where the residual connections are between the bottleneck layers.
Code Cell 4: Creating a base model for Transfer Learning
  • In our case, the fully connected output layers of the model used to make predictions are not loaded, allowing a new output layer to be added and trained. So, we are setting the include_top argument to False.
Code Cell 5: Adding Hidden Layers to Neural Networks
  • base_model.trainable=False to freeze all the weights before compiling the model.
  • Now we will add our hidden layers:
  1. Conv2D is a 2D convolution layer that creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. It learns to recognize the image’s patterns.
    relu stands for the rectified linear unit activation function. It is a piecewise linear function that outputs the input directly if it is positive, and zero otherwise.
  2. Dropout layer prevents neural networks from overfitting, i.e. being so precise that the NN is only able to recognize images that are present in the dataset and no other images.
    ▹ The Dropout layer randomly sets input units to 0 with a frequency of rate at each step during training time, which helps prevent overfitting.
    Note that the Dropout layer only applies when training is set to True, so no values are dropped during inference.
  3. GlobalAveragePooling2D layer calculates the average output of each feature map in the previous layer, thus reducing the data significantly and preparing the model for the final layer.
    The 2D Global average pooling block takes a tensor of size (input width) x (input height) x (input channels) and computes the average value of all values across the entire (input width) x (input height) matrix for each of the (input channels).
  4. Dense layer is a deeply connected layer in which each neuron receives input from all neurons of the previous layer. The 5 here stands for the number of classes (here, types of animals).
    softmax converts a real vector to a vector of categorical probabilities.
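To demystify the math of these head layers, here is a plain numpy sketch of GlobalAveragePooling2D, Dropout, and a softmax Dense layer acting on a toy feature map. The weights and sizes are made up; this is not the Keras model itself:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy feature map from a conv layer: (height, width, channels) = (7, 7, 3)
feature_map = rng.random((7, 7, 3)).astype(np.float32)

# GlobalAveragePooling2D: average each channel over the whole 7x7 grid
pooled = feature_map.mean(axis=(0, 1))        # shape (3,)

# Dropout at rate 0.5 (training only): zero random units, scale the survivors
rate = 0.5
mask = rng.random(pooled.shape) >= rate
dropped = np.where(mask, pooled / (1 - rate), 0.0)

# Dense layer with 5 units (one per class), then softmax
W = rng.normal(size=(3, 5))                   # made-up weights
b = np.zeros(5)
logits = dropped @ W + b
probs = np.exp(logits - logits.max())
probs = probs / probs.sum()                   # categorical probabilities, sum to 1
```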
Code Cell 6: Compiling the model
  • Before training the model we need to compile it and define the loss function, optimizer, and metrics for prediction. model.compile does exactly this, and a compiled model is needed to train, since training uses the loss function and the optimizer.
  • We will use Adam, which is a popular optimizer designed specifically for training deep neural networks. Adam is a replacement optimization algorithm for stochastic gradient descent for training deep learning models.
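To make the optimizer less of a black box, here is a single Adam update step on one parameter, written out in plain Python. The hyperparameters are Keras defaults, and the gradient value is made up:

```python
import math

# Keras default hyperparameters for Adam
lr, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-7

theta = 1.0        # the parameter being trained
m, v = 0.0, 0.0    # running estimates of the gradient's 1st and 2nd moments
grad = 0.5         # made-up gradient for this step
t = 1              # step counter

m = beta1 * m + (1 - beta1) * grad           # 1st moment (mean of gradients)
v = beta2 * v + (1 - beta2) * grad ** 2      # 2nd moment (mean of squared gradients)
m_hat = m / (1 - beta1 ** t)                 # bias correction for early steps
v_hat = v / (1 - beta2 ** t)
theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
```

Note that on the very first step the bias-corrected update is approximately lr regardless of the gradient’s magnitude, which is part of why Adam’s step sizes are well behaved early in training.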
Code Cell 7: Training
  • Epochs → An epoch means training the neural network with all the training data for one cycle; in an epoch, we use all of the data exactly once. An epoch is made up of one or more batches, and each batch (one forward pass plus one backward pass through part of the dataset) counts as one pass.
    The higher the number of epochs, the better the network fits the training data, but setting it too high causes overfitting, i.e. the NN becomes so tuned to the training images that it can only recognize those and no others.
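The epoch/batch relationship is just arithmetic. Using the generator counts from earlier and a hypothetical epoch count, the number of weight updates in a full training run works out like this:

```python
import math

num_train = 2872    # training images found by the generator
BATCH_SIZE = 64
EPOCHS = 10         # hypothetical; choose per experiment

# Batches needed per epoch to see every training image once
steps_per_epoch = math.ceil(num_train / BATCH_SIZE)

# Total weight updates over the whole training run
total_updates = steps_per_epoch * EPOCHS
```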
Code Cell 7: Output of training process

It’s gonna take time if you have a lot of training images, so sit back and wait for the model to get trained.

Now we have to convert our Neural Network Model to a .tflite file which we can use in our Flutter App.
Code Cell 8: Converting the Trained Neural Network Model into a TensorFlow Lite file
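For reference, the conversion itself is only a few lines with tf.lite.TFLiteConverter. This sketch converts a tiny stand-in model (in the article it is the trained MobileNetV2-based model from the cells above) and writes the model.tflite file:

```python
import tensorflow as tf

# Tiny stand-in model; in the article this is the trained `model` from Code Cell 7
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(5, activation="softmax")(inputs)
model = tf.keras.Model(inputs, outputs)

# Convert the Keras model to a TensorFlow Lite flatbuffer
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```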
If you are using Google Colab, first upload the dataset.zip file to Drive, mount the drive, extract the files using Colab, and use them. Finally, you can download the model.tflite and labels.txt files using the following code:
from google.colab import files
files.download('model.tflite')
files.download('labels.txt')
  • You can find my Jupyter Notebook from here and Google Colab file from here.

Warning!!! In case your model can’t recognize our cute panda

You can find my model.tflite & labels.txt files directly from here.

Module 2: Importing and using TensorFlow Lite file in our Flutter app

Open a terminal, navigate to the directory where you want the project, and run the command ‘flutter create project_name’. Or, if you are using Visual Studio Code, open the command palette from the View menu of the top bar (or by pressing Ctrl+Shift+P), choose the option for Flutter to create a new app, type the project name, and hit Enter.

Let’s get our hands dirty by writing some code once again:

Next, head over to ‘pubspec.yaml’, add the following dependencies, and save (you may have to run the ‘flutter pub get’ command, which fetches the packages and records the concrete versions resolved in the pubspec.lock lockfile):

dependencies:
  flutter:
    sdk: flutter
  tflite: ^1.1.2
  image_picker: ^0.8.3+2

For tflite to work, in android/app/build.gradle, set minSdkVersion to 19 and add the following setting inside the android block.

aaptOptions {
    noCompress 'tflite'
    noCompress 'lite'
}
The android block in android/app/build.gradle

For image_picker to work on iOS, add the following keys to the /ios/Runner/Info.plist file.

<key>NSCameraUsageDescription</key>
<string>Need Camera Access</string>
<key>NSMicrophoneUsageDescription</key>
<string>Need Microphone Access</string>
<key>NSPhotoLibraryUsageDescription</key>
<string>Need Gallery Access</string>

Create a folder named “assets” and place the model.tflite & labels.txt files within it.
Then declare them as assets in the pubspec.yaml file, like the following.
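The assets declaration, shown as a screenshot in the original, typically looks like this in pubspec.yaml (assuming the folder created above is named assets):

```yaml
flutter:
  assets:
    - assets/model.tflite
    - assets/labels.txt
```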

Then save the file and run the ‘flutter pub get’ command again, like before.

Now it’s time to create the UI and Functions of our Flutter App.

In the ‘main.dart’ file, return MaterialApp that has a parameter home: Home()

Code: main.dart

Then create a new ‘home.dart’ file containing the stateful class Home(). This will be our homepage. Let’s start making the functional Flutter app by importing the necessary packages and creating the functions:

Code: lib/home.dart at initial phase
  • _loading → used to check if an image has been chosen or not
  • _image → image that is chosen from gallery or camera
  • _output → prediction made by TensorFlow Model
  • picker → allows us to pick an image from gallery or camera.

Next, we will write 6 different methods for the class:

Code: First two methods in lib/home.dart
  • The first 2 methods :
  1. initState() → This is the first method called when the Home widget is created, i.e. when the app is launched and navigates to Home(). Anything inside initState() is called or initialized first, before the widget itself is built. In this case, we load our model using loadModel(), which is another method that will be written later. After that, we will pass in a value.
  2. dispose() → This method disposes of the resources and clears our memory.
Code: Last four methods in lib/home.dart
  • The last 4 methods:

3. classifyImage() → this method runs the classification model on the image. numResults is the number of classes (here, the number of animals) we have; then we add setState to save the changes.

4. loadModel() → this function loads our model, hence we call it inside the initState() method.

5. pickImage() → this function is used to grab the image from the camera.

6. pickGalleryImage() → this function is used to grab the image from the user’s gallery.

Let’s create the AppBar:

Code: AppBar of homepage in lib/home.dart

Now it’s time for the body part of our homepage. Let’s make a container to hold an image that the user has selected.

Code: Container part of homepage in lib/home.dart

We have used ClipRRect to give nice circular borders to the image.

Code: Output display of homepage in lib/home.dart

Next, we are going to make two GestureDetectors, that onTap: refer to the pickImage and the pickGalleryImage function respectively.

  • NOTE: pickImage (without parentheses) inside onTap is a function reference, which means it is not executed immediately; it runs only after the user taps the specific widget. This is known as a callback.
  • pickImage() (with parentheses) is a function call and is executed immediately.
Code: Two Gesture Detectors of homepage in lib/home.dart

Done!

Now run the project:
▹ In the terminal run the command flutter run
▹ Or if you are using VS Code, then from TopBar click on Run Without Debugging (ctrl+F5)

Hopefully your app builds without errors and doesn’t crash.

Open the app and take photos or select images from Gallery and predict 🤩

Now you can feel the power of AI 😉

Once you get the hang of it, you can see how easy it is to use TensorFlow Lite with Flutter to develop proof-of-concept machine learning mobile applications. It’s incredible that such a small piece of code produces such a powerful app. To improve your knowledge, you can visit Kaggle and download various datasets to develop different classification models and use them in your own app.

Thanks for reading, I hope you learnt something! Any comments, doubts, or suggestions are highly valuable to me. You can reach out to me via:

Github
Linkedin
Instagram
Facebook
Portfolio
Twitter
Gmail
