Image Classification App using TensorFlow Lite & Flutter
On-device Machine Learning using Deep Learning Neural Networks in TensorFlow and Keras (Step By Step Process)
Machine learning and Artificial Intelligence take the development of mobile applications to a new level. Apps that use machine learning can detect speech, images, and body language. AI gives us new and compelling ways to engage and connect with people in the world around us. But how do we integrate machine learning into our mobile applications?
Developing mobile applications that include machine learning has long been a difficult task. But platforms and tools like Firebase ML and TensorFlow Lite have made it a piece of cake. These provide pre-trained machine learning models as well as tools for training and importing custom models. But how do we develop a compelling experience on top of those machine learning models?
That’s where Flutter enters.
The Flutter SDK is a portable UI toolkit developed by Google and its open community for building Android, iOS, Web, and Desktop applications. At its core, Flutter combines an efficient graphics engine with the Dart programming language. Using Flutter, we can build mobile applications with machine learning capabilities such as image classification and object detection, on both Android and iOS platforms.
In this article, we will combine the power of Flutter and on-device ML to develop a Flutter application that can detect animals.
Let’s have a quick look at what we are gonna build today:
Get the Dataset
You can download the dataset from here. Just download it, keep your favourite classes (in this case, animals), and delete the remaining folders. You can also use any other dataset if you wish.
This article will be divided into 2 modules:
- Training a CNN (Convolutional Neural Network)
- Creating the Flutter UI and importing the .tflite file (ML model) into the Flutter project
The entire, completed project can be found and cloned from my GitHub Repo → https://git.io/JEqvP
Module 1: Training Neural Networks
There are two ways to train our model:
- Teachable Machine (if you don’t want to build Tensorflow models on your own)
- TensorFlow and Keras (if you have a little bit of knowledge about Deep Learning)
Method 1: Using Teachable Machine
What is Teachable Machine? → Teachable Machine is a web-based tool that makes creating machine learning models fast, easy, and accessible to everyone. It can be used to recognize images, sounds, or poses.
Now, let’s see how we can create our machine learning model using Teachable Machine.
First, go to Teachable Machine and open Image Project. To get started with training the model we will need to create five classes (as per my dataset), namely “Elephant”, “Kangaroo”, “Panda”, “Penguin” and “Tiger” and upload training images to build a model.
Replace the Class 1 name with Elephant and click on Upload to upload the training images for the elephants.
Now repeat the process for the other classes: change Class 2 to Kangaroo and upload the training images for kangaroos, change Class 3 to Panda and upload the training images for pandas, and do the same for the remaining animals.
It’s gonna take time if you have uploaded a lot of training images, so sit back and enjoy your coffee :)
After your model is trained, click on Export Model and download the TensorFlow Lite Floating Point model.
Warning!!! In case your model can’t recognize our cute penguin, you can find my model.tflite & labels.txt files directly from here.
Method 2: Using TensorFlow and Keras
What is TensorFlow? → TensorFlow is an open-source artificial intelligence library that uses data flow graphs to build models. It allows developers to create large-scale neural networks with many layers. TensorFlow is mainly used for: Classification, Perception, Understanding, Discovering, Prediction, and Creation.
Let's get our hands dirty by writing some code:
Let’s start by opening Jupyter Notebook (or Google Colab):
- The `os` module provides functions for fetching the contents of a directory and writing to it.
- Set the `base_dir` variable to the location of the dataset containing the training images.
- `IMAGE_SIZE = 224` → the size to which we will resize every image in the dataset.
- `BATCH_SIZE = 64` → the number of images we feed into the neural network at once.
- `rescale=1./255` normalizes the pixel values from [0, 255] to [0, 1], which helps the network train.
- Datasets have a training set and a test set. The training set is used to train our model, and the test (validation) set is used to measure how accurate our model is. So with `validation_split=0.2`, we are telling Keras to use 80% of the images for training and 20% for accuracy testing (validation).
- Then we have two generators (`train_generator` and `val_generator`), which take the path to the directory and generate batches of augmented data, in this case printing: Found 2872 images belonging to 36 classes and Found 709 images belonging to 36 classes.
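Putting the pieces above together, the data pipeline might look like the sketch below. The `base_dir` path is a placeholder, and a tiny synthetic dataset is generated so the snippet runs on its own; in the real project, `base_dir` points at your downloaded animal images and the synthesis loop is unnecessary:

```python
import os
import numpy as np
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = 224
BATCH_SIZE = 64

# Hypothetical path -- in the real project this is your dataset folder.
base_dir = "dataset/animals"

# For a self-contained demo only: synthesize 5 random images per class.
for cls in ["Elephant", "Kangaroo", "Panda", "Penguin", "Tiger"]:
    os.makedirs(os.path.join(base_dir, cls), exist_ok=True)
    for i in range(5):
        arr = np.random.randint(0, 255, (IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
        Image.fromarray(arr).save(os.path.join(base_dir, cls, f"{i}.jpg"))

# Normalize pixels to [0, 1] and reserve 20% of images for validation.
datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)

train_generator = datagen.flow_from_directory(
    base_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='training')

val_generator = datagen.flow_from_directory(
    base_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='validation')

print(train_generator.class_indices)
```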
- Print all keys and classes (labels) of the dataset to re-check if everything is working fine.
- Flutter requires two files: model.tflite and labels.txt.
- The `'w'` mode creates a new file called labels.txt containing the labels (names of animals); if the file already exists, it is overwritten.
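A minimal sketch of the labels.txt step; the `class_indices` dict here is a hard-coded stand-in for the mapping the training generator produces:

```python
# Stand-in for train_generator.class_indices (label -> index).
class_indices = {'Elephant': 0, 'Kangaroo': 1, 'Panda': 2, 'Penguin': 3, 'Tiger': 4}

# Sort the labels by their index so line N of the file matches class N.
labels = '\n'.join(sorted(class_indices, key=class_indices.get))

# 'w' creates labels.txt, overwriting it if it already exists.
with open('labels.txt', 'w') as f:
    f.write(labels)
```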
Now we will use MobileNetV2, which is a convolutional neural network architecture that seeks to perform well on mobile devices. It is based on an inverted residual structure where the residual connections are between the bottleneck layers.
- In our case, the fully connected output layers of the model used to make predictions are not loaded, allowing a new output layer to be added and trained. To do this, we set the `include_top` argument to `False`.
- We set `base_model.trainable = False` to freeze all the pretrained weights before compiling the model.
- Now we will add our hidden layers:
- Conv2D is a 2D convolution layer that creates a convolution kernel which is convolved with the layer input to produce a tensor of outputs. It tries to learn the image’s patterns.
▹ `'relu'` stands for the rectified linear unit activation function. It is a piecewise linear function that outputs the input directly if it is positive, and zero otherwise.
- The Dropout layer prevents neural networks from overfitting, i.e. becoming so specialized that the network can only recognize images present in the dataset and no others.
▹ The Dropout layer randomly sets input units to 0 with a frequency of rate at each step during training time, which helps prevent overfitting.
▹ Note that the Dropout layer only applies when training is set to True, so no values are dropped during inference.
- The GlobalAveragePooling2D layer calculates the average output of each feature map in the previous layer, reducing the data significantly and preparing the model for the final layer.
▹ The 2D global average pooling block takes a tensor of size (input width) x (input height) x (input channels) and computes the average of all values across the entire (input width) x (input height) matrix for each of the (input channels).
- A Dense layer is a fully connected layer in which each neuron receives input from all neurons of the previous layer. The ‘5’ here stands for the number of classes (here, types of animals).
▹ `'softmax'` converts a vector of real values to a vector of categorical probabilities.
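The model described above can be sketched like this. The Conv2D and Dropout sizes are illustrative, and `weights=None` keeps the snippet self-contained; for real transfer learning you would pass `weights='imagenet'` (which downloads the pretrained weights):

```python
import tensorflow as tf

IMAGE_SIZE = 224

# MobileNetV2 without its fully connected top layers (include_top=False),
# so we can attach and train our own classification head.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    include_top=False,
    weights=None)  # use weights='imagenet' in practice

# Freeze the base model's weights before compiling.
base_model.trainable = False

# Hidden layers plus the 5-class softmax output described above.
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(5, activation='softmax'),
])
```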
- Before training the model we need to compile it and define the loss function, optimizer, and metrics for prediction. We do this with `model.compile`; a compiled model is required for training, since training uses the loss function and the optimizer.
- We will use `Adam`, a popular optimizer designed specifically for training deep neural networks. It is a drop-in replacement for stochastic gradient descent.
- Epochs → an epoch means training the neural network on all the training data for one cycle; in an epoch, we use all of the data exactly once (a forward pass and a backward pass together count as one pass). An epoch is made up of one or more batches, where we use a part of the dataset to train the neural network.
▹ The higher the number of epochs, the more accurate the neural network, but too high a number can cause overfitting, i.e. the network becomes so specialized that it can only recognize images present in the dataset and no others.
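The compile-and-train step can be sketched as follows. A tiny stand-in model and random data replace the MobileNetV2 network and the image generators so the snippet runs anywhere; in the real project you would call `model.fit(train_generator, validation_data=val_generator, epochs=...)`:

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in model; in the article `model` is the MobileNetV2 network.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(8, 8, 3)),
    tf.keras.layers.Dense(5, activation='softmax'),
])

# Compile: Adam optimizer, categorical cross-entropy (the one-hot labels
# produced by flow_from_directory), and accuracy as the reported metric.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Random stand-in data; one epoch = one full pass over the training data.
x = np.random.rand(32, 8, 8, 3).astype('float32')
y = tf.keras.utils.to_categorical(np.random.randint(0, 5, 32), 5)
history = model.fit(x, y, epochs=3, verbose=0)
```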
It’s gonna take time if you have uploaded a lot of training images, so sit back and wait for the model to get trained.
Now we have to convert our Neural Network Model to a .tflite file which we can use in our Flutter App.
- The Keras SavedModel format uses `tf.saved_model.save` to save the model and all trackable objects attached to it.
- To convert a SavedModel to a TensorFlow Lite model, we use `tf.lite.TFLiteConverter.from_saved_model`.
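The two conversion steps above look roughly like this; a tiny stand-in model replaces the trained MobileNetV2 network so the snippet is self-contained, and `saved_model_dir` is an arbitrary directory name:

```python
import tensorflow as tf

# Stand-in model; in the article this is the trained MobileNetV2 network.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(5, activation='softmax', input_shape=(4,)),
])

# Save the model and its trackable objects in the SavedModel format...
tf.saved_model.save(model, 'saved_model_dir')

# ...then convert the SavedModel to a TensorFlow Lite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_dir')
tflite_model = converter.convert()

# Write the .tflite file that the Flutter app will load.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
```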
If you are using Google Colab, first upload the dataset.zip file to Drive, mount the drive, extract the files in Colab, and use them from there. Finally, you can download the model.tflite and labels.txt files with the following code:
from google.colab import files
files.download('model.tflite')
files.download('labels.txt')
Warning!!! In case your model can’t recognize our cute panda, you can find my model.tflite & labels.txt files directly from here.
Module 2: Importing and using TensorFlow Lite file in our Flutter app
Open a terminal, navigate to your projects directory, and run the command ‘flutter create project_name’. Or, if you are using Visual Studio Code, open the command palette from the View menu of the top bar (or press Ctrl+Shift+P), choose the option to create a new Flutter app, type the project name, and hit Enter.
Let’s get our hands dirty by writing some code once again:
Next, head over to pubspec.yaml, add the following dependencies, and save (you may have to run the ‘flutter pub get’ command; it resolves the dependencies and records the concrete package versions in the pubspec.lock lockfile):
dependencies:
flutter:
sdk: flutter
tflite: ^1.1.2
image_picker: ^0.8.3+2
For tflite to work, set minSdkVersion to 19 in android/app/build.gradle and add the following setting inside the android block:
aaptOptions {
noCompress 'tflite'
noCompress 'lite'
}
For image_picker to work on iOS, add the following to your /ios/Runner/Info.plist file:
<key>NSCameraUsageDescription</key>
<string>Need Camera Access</string>
<key>NSMicrophoneUsageDescription</key>
<string>Need Microphone Access</string>
<key>NSPhotoLibraryUsageDescription</key>
<string>Need Gallery Access</string>
Create a folder named “assets” and place the model.tflite & labels.txt files within it.
Then declare them in the pubspec.yaml file as follows.
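The assets section of pubspec.yaml might look like this (indentation matters, and `assets:` sits under the top-level `flutter:` key):

```yaml
flutter:
  uses-material-design: true

  assets:
    - assets/model.tflite
    - assets/labels.txt
```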
Then save the file (again, you may have to run ‘flutter pub get’, as before).
Now it’s time to create the UI and Functions of our Flutter App.
In the ‘main.dart’ file, return a MaterialApp that has the parameter home: Home().
Then create a new ‘home.dart’ file containing the stateful class Home(). This will be our homepage. Let’s start building the functional Flutter app by importing the necessary packages and declaring the state variables:
- _loading → tracks whether an image has been chosen yet
- _image → the image chosen from the gallery or camera
- _output → the prediction made by the TensorFlow model
- picker → allows us to pick an image from the gallery or camera
Next, we will write 6 different methods for the class:
- The first 2 methods:
1. initState() → the first method called when the Home widget is created, i.e. when the app is launched and navigates to Home(). Anything inside initState() is called or initialized before the widget itself is built. In this case, we load our model using loadModel(), another method that will be written later.
2. dispose() → this method disposes of the model and frees our memory.
- The last 4 methods:
3. classifyImage() → runs the classification model on the image. numResults is the number of classes (here, the number of animals); we then call setState to save the changes.
4. loadModel() → loads our model, which is why we call it inside initState().
5. pickImage() → grabs an image from the camera.
6. pickGalleryImage() → grabs an image from the user’s gallery.
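As a rough sketch (assuming the tflite and image_picker plugin APIs; your state handling and variable names may differ), the model-related methods could look like this inside the _HomeState class:

```dart
import 'dart:io';
import 'package:tflite/tflite.dart';
import 'package:image_picker/image_picker.dart';

// Loads the model and labels from the assets folder.
loadModel() async {
  await Tflite.loadModel(
    model: 'assets/model.tflite',
    labels: 'assets/labels.txt',
  );
}

// Runs the classifier on an image file and saves the prediction.
classifyImage(File image) async {
  var output = await Tflite.runModelOnImage(
    path: image.path,
    numResults: 5, // number of classes (animals)
    threshold: 0.5,
    imageMean: 127.5,
    imageStd: 127.5,
  );
  setState(() {
    _output = output;
    _loading = false;
  });
}

// Grabs an image from the camera and classifies it.
pickImage() async {
  var image = await picker.pickImage(source: ImageSource.camera);
  if (image == null) return;
  setState(() => _image = File(image.path));
  await classifyImage(_image);
}

// Same as above, but from the gallery.
pickGalleryImage() async {
  var image = await picker.pickImage(source: ImageSource.gallery);
  if (image == null) return;
  setState(() => _image = File(image.path));
  await classifyImage(_image);
}
```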
Let’s create the AppBar:
Now it’s time for the body part of our homepage. Let’s make a container to hold an image that the user has selected.
We have used ClipRRect to give the image nice rounded corners.
Next, we are going to make two GestureDetectors whose onTap callbacks refer to the pickImage and pickGalleryImage functions respectively.
- NOTE: pickImage (without parentheses) inside onTap is a function reference, which means it is not executed immediately; it runs only after the user taps the widget. This is known as a callback. pickImage() would be a function call, executed immediately.
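For illustration, the difference looks like this (the child widgets are hypothetical placeholders):

```dart
// onTap takes a callback, so we pass the function itself:
GestureDetector(
  onTap: pickImage,        // reference: runs only when the user taps
  child: cameraButton,     // hypothetical camera-option widget
),
GestureDetector(
  onTap: pickGalleryImage, // reference: runs only when the user taps
  child: galleryButton,    // hypothetical gallery-option widget
),
// By contrast, onTap: pickImage() would call pickImage immediately
// while the widget is being built, not when it is tapped.
```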
Done!
Now run the project:
▹ In the terminal run the command flutter run
▹ Or if you are using VS Code, then from TopBar click on Run Without Debugging (ctrl+F5)
Hopefully your app builds without crashing and no errors come up.
Open the app and take photos or select images from Gallery and predict 🤩
Now you can feel the power of AI 😉
Once you get the hang of it, you can see how easy it is to use TensorFlow Lite with Flutter to develop proof-of-concept machine learning mobile applications. It’s incredible that such a powerful app can be built with so little code. To improve your knowledge, you can visit the Kaggle site and download various datasets to develop different classification models and use those in your own app.