Deploying Keras Deep Learning Models with Flask
This post demonstrates how to set up an endpoint for serving predictions from a deep learning model built with Keras. It first introduces an example that uses Flask to set up an endpoint in Python, and then covers some of the issues to work around when building a Keras endpoint for predictions with Flask.
Productizing deep learning models is challenging, or at least has been for me in the past, for a number of reasons:
- Model Serialization: Standard approaches for serializing models, such as PMML, have only limited support. For example, keras2pmml lacks support for relu activations, which means that the DataFlow + PMML approach I presented in my model production post is not viable.
- Large Libraries: In my past post, I showed how to use AWS lambda functions to host scikit-learn models. This approach is problematic for Keras, because the uncompressed Keras and Tensorflow libraries exceed the 256MB limit of file uploads for AWS lambda.
- Run Time: Scaling both batch and real-time predictions is challenging when a model requires a Python library, since most of my past model-serving experience is in Java. I previously blogged about providing real-time estimates using Jetty, and batch estimates using Google’s DataFlow. This post shows how to use Flask in place of Jetty when you need a Python library to make estimates.
The goal of this post is to show how to set up a Keras model as an endpoint on an AWS EC2 instance.
Some of the issues that I’ll cover include handling a custom metric when using model persistence with Keras, dealing with multi-threading concerns when using Keras in combination with Flask, and getting it all running on an EC2 instance. The complete code listing for this post is available on GitHub.
This post covers setting up a simple Flask app, and then shows how to use Flask to set up an endpoint with a Keras model. It assumes that readers are familiar with setting up an EC2 instance with Jupyter, which is covered here.
Hello World with Flask
Flask is a Python library that makes it easy to set up Python functions that can be invoked via the web. It uses annotations to provide metadata about which functions to set up at which endpoints. To use Flask, you’ll first need to install the module:
pip3 install --user Flask
To get familiar with Flask, we’ll set up a simple function that echoes a passed-in parameter. The snippet below first instantiates a Flask app, defines the function, and then launches the app. With Flask, the app.route annotation is used to specify where to make functions available on the web, and which methods to allow. With the code below, the function will be available at location:5000/predict. The function checks the request.json and request.args objects for input parameters, which are used based on how the function is called (e.g. browser GET vs curl POST). If a msg parameter has been passed to the function, it is echoed in the JSON response returned by the function.
# load Flask
import flask
app = flask.Flask(__name__)

# define a predict function as an endpoint
@app.route("/predict", methods=["GET", "POST"])
def predict():
    data = {"success": False}

    # get the request parameters
    params = flask.request.json
    if params is None:
        params = flask.request.args

    # if parameters are found, echo the msg parameter
    if params is not None:
        data["response"] = params.get("msg")
        data["success"] = True

    # return a response in json format
    return flask.jsonify(data)

# start the flask app, allow remote connections
app.run(host='0.0.0.0')
When you run python3 Flask_Echo.py, you’ll get the following result:
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
We can now connect to the function to test it. I included host='0.0.0.0' in order to enable remote connections, since the Flask app is running on an EC2 instance. If you’re using EC2, you’ll need to modify the security group to allow access to Flask on port 5000, similar to allowing Jupyter access on port 8888.
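If you’d rather script the security group change than use the AWS console, it can be done with boto3. This is a minimal sketch, assuming AWS credentials are already configured; the security group ID is a hypothetical placeholder:

# open port 5000 on a security group with boto3
# (the group ID below is a hypothetical placeholder)
import boto3

ec2 = boto3.client('ec2')
ec2.authorize_security_group_ingress(
    GroupId='sg-0123456789abcdef0',  # replace with your security group ID
    IpProtocol='tcp',
    FromPort=5000,
    ToPort=5000,
    CidrIp='0.0.0.0/0'  # open to all IPs; restrict to your own IP in practice
)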
The function can be called using a web browser or curl. I used curl from a Windows environment, which requires escaping the quotes in the JSON payload; on Linux or macOS you can use -d '{"msg":"Hello World"}' directly. The result is the same for both approaches: a JSON response from the service that echoes the passed-in msg parameter.
# Browser
http://54.227.110.43:5000/predict?msg=HelloWorld

# Curl
>curl -X POST -H "Content-Type: application/json" -d "{ \"msg\": \"Hello World\" }" http://54.227.110.43:5000/predict

# Response
{
  "response": "Hello World",
  "success": true
}
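The endpoint can also be tested from Python with the requests library, which exercises both invocation styles. A short sketch, using the same placeholder IP address as above:

# test the echo endpoint from Python
import requests

# GET with query string parameters (like the browser call)
result = requests.get("http://54.227.110.43:5000/predict",
                      params={"msg": "Hello World"})
print(result.json())

# POST with a JSON body (like the curl call)
result = requests.post("http://54.227.110.43:5000/predict",
                       json={"msg": "Hello World"})
print(result.json())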
We now have the ability to set up Python functions as web endpoints; the next step is to have the function call a trained deep net.
Flask & Keras
To use Keras for Deep Learning, we’ll need to first set up the environment with the Keras and Tensorflow libraries and then train a model that we will expose on the web via Flask.
# Deep Learning setup
pip3 install --user tensorflow
pip3 install --user keras
pip3 install --user pandas
Since I used an EC2 instance without an attached GPU, no additional configuration is necessary for running Keras in CPU mode.
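Since the code in this post relies on the TensorFlow 1.x and standalone Keras APIs (e.g. tf.get_default_graph and keras.backend.get_session), it’s worth confirming that the installs import correctly before moving on. A quick sanity check:

# verify the libraries import and print their versions
import tensorflow as tf
import keras
import pandas as pd

print(tf.__version__)
print(keras.__version__)
print(pd.__version__)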
Model Training
I created a binary classifier using a simple network structure. The input to the model is a feature array that describes which games a user has previously played, and the output is the likelihood that the player will play a specific game in the future. More details about training the model are available in my past post on deep learning.
# import pandas, keras, and tensorflow
import pandas as pd
import tensorflow as tf
import keras
from keras import models, layers

# Load the sample data set and split into x and y data frames
df = pd.read_csv("https://github.com/bgweber/Twitch/raw/master/Recommendations/games-expand.csv")
x = df.drop(['label'], axis=1)
y = df['label']

# Define the keras model
model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=(10,)))
model.add(layers.Dropout(0.1))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.1))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

# Use a custom metric function
def auc(y_true, y_pred):
    auc = tf.metrics.auc(y_true, y_pred)[1]
    keras.backend.get_session().run(tf.local_variables_initializer())
    return auc

# Compile and fit the model
model.compile(optimizer='rmsprop', loss='binary_crossentropy',
              metrics=[auc])
history = model.fit(x, y, epochs=100, batch_size=100,
                    validation_split=0.2, verbose=0)

# Save the model in h5 format
model.save("games.h5")
The code snippet defines a custom metric function, which is used to train the model to optimize for the ROC AUC metric. The main addition to this code is the last step, which serializes the model to the h5 format. We can later load this model in the Flask app to serve model predictions. The model can be trained by running python3 Flask_Train.py, which generates games.h5.
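Before wiring the model into Flask, it’s worth a quick sanity check that the serialized file reloads and produces a prediction. A minimal sketch, assuming the auc function and the x data frame from the training script are still in scope:

# reload the saved model, passing the custom metric via custom_objects
from keras.models import load_model
reloaded = load_model("games.h5", custom_objects={'auc': auc})

# score the first training row as a smoke test
print(reloaded.predict(x[:1])[0][0])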
Model Deployment
Now that we have our environment set up and have trained a deep learning model, we can productize the Keras model with Flask. The complete code listing for serving model predictions is shown below. The overall structure of the code is the same as our previous example, but the main differences are loading the model before defining the predict function, and using the model within the predict function. In order to reload the model, we need to pass the custom metric function to load_model using the custom_objects parameter.
# Load libraries
import flask
import pandas as pd
import tensorflow as tf
import keras
from keras.models import load_model

# instantiate flask
app = flask.Flask(__name__)

# we need to redefine our metric function in order
# to use it when loading the model
def auc(y_true, y_pred):
    auc = tf.metrics.auc(y_true, y_pred)[1]
    keras.backend.get_session().run(tf.local_variables_initializer())
    return auc

# load the model, and pass in the custom metric function
global graph
graph = tf.get_default_graph()
model = load_model('games.h5', custom_objects={'auc': auc})

# define a predict function as an endpoint
@app.route("/predict", methods=["GET", "POST"])
def predict():
    data = {"success": False}
    params = flask.request.json
    if params is None:
        params = flask.request.args

    # if parameters are found, return a prediction
    if params is not None:
        x = pd.DataFrame.from_dict(params, orient='index').transpose()
        with graph.as_default():
            data["prediction"] = str(model.predict(x)[0][0])
        data["success"] = True

    # return a response in json format
    return flask.jsonify(data)

# start the flask app, allow remote connections
app.run(host='0.0.0.0')
It’s also necessary to set up a reference to the TensorFlow graph using tf.get_default_graph(). If this step is omitted, an exception may occur during the predict step. The with graph.as_default() statement is used to grab a thread-safe reference to the graph when making predictions, since Flask may handle requests on a different thread than the one that loaded the model. In the predict function, the request arguments are converted to a data frame and then passed to the Keras model to make a prediction. Additional details on using the passed-in arguments are covered in my models as a service post.
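To make the conversion step concrete, the snippet below sketches what happens to the request parameters before they reach the model; the parameter values mirror the g1 to g10 attributes used in the test call below:

# mimic how the predict function turns request parameters into model input
import pandas as pd

# flask provides the query string parameters as a dict of strings
params = {"g1": "1", "g2": "0", "g3": "0", "g4": "0", "g5": "0",
          "g6": "0", "g7": "0", "g8": "0", "g9": "0", "g10": "0"}

# from_dict with orient='index' yields one row per key; transposing
# produces a single-row frame with one column per feature
x = pd.DataFrame.from_dict(params, orient='index').transpose()
print(x.shape)  # (1, 10)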
The Flask app can be deployed by running python3 Flask_Deploy.py. You can connect to the app the same way as before, but you’ll need to specify values for the attributes g1 to g10. I used a browser to test the endpoint, which produced the following results:
# Browser
http://54.227.110.43:5000/predict?g1=1&g2=0&g3=0&g4=0&g5=0&g6=0&g7=0&g8=0&g9=0&g10=0

# Response
{
  "prediction": "0.04930059",
  "success": true
}
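The same test can be scripted with the requests library, which makes it easy to vary the feature values. A short sketch using the placeholder IP address from above:

# call the prediction endpoint from Python
import requests

# build the ten feature parameters, flagging the first game as played
features = {"g%d" % i: 0 for i in range(1, 11)}
features["g1"] = 1

result = requests.get("http://54.227.110.43:5000/predict", params=features)
print(result.json())  # e.g. {'prediction': '0.04930059', 'success': True}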
You now have an EC2 instance that can serve Keras predictions on the web!
Conclusion
Deploying deep learning models is non-trivial, because you need an environment that supports a TensorFlow runtime. To provide a Keras model as a service, I showed how Flask can be used to serve predictions with a pre-trained model. This approach isn’t as scalable as the AWS lambda method I discussed in my last post, but it may be more suitable for your use case. Flask is also useful for setting up local services when prototyping.
Ideally, I’d like to be able to combine the annotations in Flask with the scalability of AWS lambda, without the intermediate step of installing libraries to a directory and uploading the result. AWS SageMaker helps work towards this goal, and I’ll be exploring this tool in more detail in the coming weeks.