Deploy ML models using Flask as a REST API and access them via a Flutter app
The simplicity of Flask, the awesomeness of Flutter, and a Keras image classification model
Introduction
Machine Learning has become one of the coolest technologies in recent times; almost every software product on the market uses ML in one way or another. Let's see how to build an application that can upload images to a server and run predictions on them (image classification). These images can then be accessed by an app, and you can search for an image by its content.
We will use Flask (a Python framework) as the back end for our REST API, Flutter for the mobile app, and Keras for image classification. We will also use MongoDB as our database to store details about the images, and classify images using the Keras ResNet50 model; a pretrained model is well suited for this purpose. We can use a custom model instead by saving it with save_model() and loading it with load_model(), both available in Keras. Keras will download around 100 MB of weights for the model the first time it is used. To learn more about the available models, refer to the documentation.
Let us start with Flask
Defining a route in Flask is pretty simple: use the decorator @app.route('/'), where app is the object containing our Flask app. Let's look at an example:
from flask import Flask

app = Flask(__name__)

@app.route('/')
def hello_world():
    return 'Hello, World!'
Import Flask from flask and create the application object; here app is the object containing our app. Next, we use the decorator to define a route, so that whenever a request comes to that route (e.g. http://127.0.0.1:5000/) Flask will automatically return
Hello World!
as response. Yes, it is that simple!
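Putting the pieces together, the example above can be run as a single file; a minimal sketch (the run call and port are the Flask defaults, not from the original repo):

```python
from flask import Flask

app = Flask(__name__)

@app.route('/')
def hello_world():
    return 'Hello, World!'

if __name__ == '__main__':
    # Starts the development server on http://127.0.0.1:5000/
    app.run()
```

Running the file with `python app.py` and visiting the root URL in a browser returns the greeting.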
If you want to know more about Flask check out the documentation. Now take a look at our back end code:
import os
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as img
from tensorflow.keras.preprocessing.image import img_to_array
import numpy as np
from PIL import Image
from tensorflow.keras.applications.resnet50 import ResNet50, decode_predictions, preprocess_input
from datetime import datetime
import io
from flask import Flask, Blueprint, request, render_template, jsonify
from modules.dataBase import collection as db
We have some import statements, including tensorflow, since we are using TensorFlow as the back end for Keras, and numpy for dealing with multi-dimensional arrays.
mod = Blueprint('backend', __name__, template_folder='templates', static_folder='./static')
UPLOAD_URL = 'http://192.168.1.103:5000/static/'

model = ResNet50(weights='imagenet')
model._make_predict_function()
Since this example uses Flask blueprints (a way to organize an app into modules), the first line creates a Blueprint object.
One thing to note is that all our route decorators will now use @mod.route('/') to define routes. Our model is ResNet50 trained on the ImageNet dataset, and we call _make_predict_function() on it; there is a chance you will get errors if this method is not called. If you have a custom model, you can use it the same way by changing
model = ResNet50(weights='imagenet')
to
model = load_model("saved_model.h5")
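Blueprints only take effect once they are registered on the main Flask app; the repo presumably does this in its entry-point file. A minimal sketch of the wiring (the module names in the comments are assumptions for illustration):

```python
from flask import Flask, Blueprint

# In the real project these would come from the modules package, e.g.
#   from modules.backend import mod as backend
#   from modules.api import mod as api
backend = Blueprint('backend', __name__)
api = Blueprint('api', __name__)

@api.route('/')
def list_images():
    return 'images'

app = Flask(__name__)
app.register_blueprint(backend)                  # routes like /predict
app.register_blueprint(api, url_prefix='/api')   # served under /api/
```

With url_prefix set, the api blueprint's root route is reachable at /api/.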
@mod.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # check if the post request has the file part
        if 'file' not in request.files:
            return "No file found"
        user_file = request.files['file']
        if user_file.filename == '':
            return "file name not found ..."
        else:
            path = os.path.join(os.getcwd(), 'modules', 'static', user_file.filename)
            user_file.save(path)
            classes = identifyImage(path)
            # save image details to database
            db.addNewImage(
                user_file.filename,
                classes[0][0][1],
                str(classes[0][0][2]),
                datetime.now(),
                UPLOAD_URL + user_file.filename)
            return jsonify({
                "status": "success",
                "prediction": classes[0][0][1],
                "confidence": str(classes[0][0][2]),
                "upload_time": datetime.now()
            })
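One caveat worth noting: saving the upload under its raw filename lets a client smuggle path separators into the save path. Werkzeug (installed alongside Flask) ships a sanitizer that could be applied before building the path; a sketch (the helper name is my own):

```python
import os
from werkzeug.utils import secure_filename

def safe_upload_path(filename, upload_dir=os.path.join('modules', 'static')):
    # secure_filename strips path separators and other unsafe characters
    return os.path.join(upload_dir, secure_filename(filename))
```

A filename like "../../etc/passwd" is flattened to "etc_passwd" instead of escaping the upload directory.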
The above route accepts a POST request and checks for the file; the image is then passed to the identifyImage(file_path) method, which is shown below:
def identifyImage(img_path):
    image = img.load_img(img_path, target_size=(224, 224))
    x = img_to_array(image)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    preds = model.predict(x)
    preds = decode_predictions(preds, top=1)
    print(preds)
    return preds
This method accepts a file path as its argument. The image is resized to 224x224, since our model requires exactly this input shape. We then preprocess the image and pass it to model.predict(). The decoded predictions come back as a list of lists of tuples (top=1 returns only the single prediction with the highest confidence).
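decode_predictions returns one inner list per input image, each entry a (class_id, class_name, probability) tuple; that is why the /predict route indexes classes[0][0][1] and classes[0][0][2]. A sketch with a mocked result (the values are illustrative, not real model output):

```python
# Shape of decode_predictions(preds, top=1) for a single image:
# [ [ (class_id, class_name, probability) ] ]
preds = [[('n02123045', 'tabby', 0.8346)]]

prediction = preds[0][0][1]  # class name of the top guess
confidence = preds[0][0][2]  # its probability
```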
We will then save the image details to our MongoDB database; db.addNewImage() takes care of this. The database code looks like this:
from pymongo import MongoClient
from bson import ObjectId

client = MongoClient("mongodb://localhost:27017")  # host uri
db = client.image_predition  # select the database
image_details = db.imageData  # select the collection

def addNewImage(i_name, prediction, conf, time, url):
    image_details.insert_one({
        "file_name": i_name,
        "prediction": prediction,
        "confidence": conf,
        "upload_time": time,
        "url": url
    })

def getAllImages():
    data = image_details.find()
    return data
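A side note on serialization: MongoDB documents carry types the stdlib json module cannot serialize directly (ObjectId, datetime), which is why the API below relies on bson.json_util.dumps. The stdlib can approximate this with default=str; a sketch with plain dicts standing in for a cursor (the sample document is made up):

```python
import json
from datetime import datetime

# A stand-in for documents returned by image_details.find()
docs = [{'file_name': 'cat.jpg', 'prediction': 'tabby',
         'upload_time': datetime(2019, 8, 1, 12, 0)}]

# default=str stringifies anything json doesn't know how to encode
payload = json.dumps(docs, default=str)
```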
Since this example uses blueprints, our API code can live in a separate file, like this:
from flask import Flask, render_template, jsonify, Blueprint
from bson.json_util import dumps
from modules.dataBase import collection as db

mod = Blueprint('api', __name__, template_folder='templates')

@mod.route('/')
def api():
    return dumps(db.getAllImages())
We have a couple of import statements; dumps helps us convert PyMongo (the Python API for MongoDB) objects into JSON.
db.getAllImages() fetches all the images we have uploaded. Our API can be accessed through the endpoint
http://127.0.0.1:5000/api # address may vary depending upon the host provided in app.run()
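For reference, the endpoint returns a JSON array with the fields stored by addNewImage; a Python client could parse it like this (the sample string mirrors that schema with made-up values):

```python
import json

# Example payload in the shape the /api endpoint returns
sample = ('[{"file_name": "cat.jpg", "prediction": "tabby", '
          '"confidence": "0.83", '
          '"url": "http://192.168.1.103:5000/static/cat.jpg"}]')

# The Flutter app keeps only the url and prediction of each record
images = [(d['url'], d['prediction']) for d in json.loads(sample)]
```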
Only the important parts of the code are explained here. To see how the files are arranged, check out the GitHub repository. To know more about PyMongo, check here.
Flutter Application
The application will use the REST API to fetch images and display them; we can also search for an image by its content. Our app will look like this:
The ImageData class will act as our model class:
import 'dart:convert';
import 'dart:async';
import 'package:http/http.dart' as http;

class ImageData {
  String uri;
  String prediction;
  ImageData(this.uri, this.prediction);
}

Future<List<ImageData>> LoadImages() async {
  var data = await http.get('http://192.168.1.103:5000/api/'); // local API path
  var jsondata = json.decode(data.body);
  List<ImageData> list = [];
  for (var data in jsondata) {
    ImageData n = ImageData(data['url'], data['prediction']);
    list.add(n);
  }
  return list;
}
Here we fetch the JSON data, convert it into a List of ImageData objects, and return it to a FutureBuilder with the help of the LoadImages() function.
Upload image to server
uploadImageToServer(File imageFile) async {
  print("attempting to connect to server......");
  var stream = new http.ByteStream(DelegatingStream.typed(imageFile.openRead()));
  var length = await imageFile.length();
  print(length);
  var uri = Uri.parse('http://192.168.1.103:5000/predict');
  print("connection established.");
  var request = new http.MultipartRequest("POST", uri);
  var multipartFile = new http.MultipartFile('file', stream, length,
      filename: basename(imageFile.path));
      // contentType: new MediaType('image', 'png'));
  request.files.add(multipartFile);
  var response = await request.send();
  print(response.statusCode);
}
In order to make our Flask app available on the local network, make sure the app is not in debug mode by setting debug to False, and find your machine's IPv4 address from the command line using the ipconfig command (on Windows). Here the address is 192.168.1.103:
app.run(debug=False, host='192.168.1.103', port=5000)
A firewall may prevent the app from reaching the server, so make sure the firewall allows connections on port 5000 (or turn it off while testing).
The complete code for the Flutter app is available in the GitHub repository. Here are some other useful links:
Keras : https://keras.io/
Flutter : https://flutter.dev/
MongoDB : https://www.tutorialspoint.com/mongodb/
Harvard University CS50 course on Python and Flask: https://www.youtube.com/watch?v=j5wysXqaIV8&t=5515s (watch lectures 2, 3, 4)
GitHub : https://github.com/SHARONZACHARIA