TensorFlow: How to freeze a model and serve it with a python API
Morgan
1.4K44

Hi, I trained an image classification model using convnet and then froze it.

I followed your tutorial as is and served the model with Flask for image prediction. It works well for the first 10–20 images; after that, it hogs all of the RAM and CPU.

Currently I am running it on hardware with 12 GB of RAM and an i7 processor.

Code-

# Input geometry the frozen model was trained on: 128x128 RGB images,
# flattened into a single vector before being fed to placeholder 'x'.
img_size = 128
num_channels = 3
img_size_flat = img_size ** 2 * num_channels

# Output labels, ordered to match the model's logit columns.
classes = ['class_a', 'class_b']

app = Flask(__name__)


@app.route('/api/image/classify/upload', methods=['POST'])
def upload():
    """Classify an uploaded image and return per-class softmax scores.

    Expects a multipart POST with a 'file' field.  Responds with JSON:
    {"results": {"result": <class name>,
                 "class_a_score": <percent>, "class_b_score": <percent>}}
    """
    import math  # local import: only needed by this handler

    file = request.files['file']
    # Strip any directory components from the client-supplied name so a
    # crafted filename ("../../etc/...") cannot escape /tmp.
    filename = os.path.basename(file.filename)
    path = os.path.join('/tmp', filename)
    file.save(path)
    try:
        img = image_util.read_img(path, img_size)
        flat = img[0].reshape(img_size_flat)
        # Run the graph ONCE per request and do softmax/argmax in plain
        # Python.  The original handler called tf.nn.softmax(y) on every
        # request, which appended a new op to the graph each time -- the
        # graph grew without bound and RAM/CPU climbed until the process
        # became unusable.  It also ran the session twice with the same
        # feed, doubling inference cost.
        logits = persistent_sess.run(y, feed_dict={x: [flat], keep_prob: 1})
        row = logits[0]
        # Numerically stable softmax, scaled to percentages (matches the
        # original `tf.nn.softmax(y) * 100`).
        peak = max(float(v) for v in row)
        exps = [math.exp(float(v) - peak) for v in row]
        total = sum(exps)
        scores = [e / total * 100.0 for e in exps]
        # argmax of the logits == argmax of the softmax, so `result` is
        # identical to the original `tf.argmax`-based prediction.
        result = classes[scores.index(max(scores))]
        data = {"result": result,
                "class_a_score": round(scores[0], 2),
                "class_b_score": round(scores[1], 2)}
    finally:
        # Always delete the temp file, even if inference raises.
        os.remove(path)
    return jsonify(results=data)

def _load_graph(frozen_graph_filename):
    """Load a frozen TensorFlow graph (.pb) into a fresh tf.Graph.

    Args:
        frozen_graph_filename: Path to the serialized GraphDef protobuf.

    Returns:
        A tf.Graph containing the imported ops, each renamed under the
        'prefix/' namespace (callers look up e.g. 'prefix/x:0').
    """
    # Read and parse the serialized GraphDef from disk.
    # NOTE: the original paste had the def line broken in two
    # ("def _" / "load_graph(...)"), which is a syntax error; rejoined here.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the GraphDef into a new default graph.  The explicit-None
    # keyword args (input_map, return_elements, producer_op_list and the
    # deprecated op_dict) were just the defaults, so they are dropped.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="prefix")
    return graph
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename",
                        default="model/frozen_model.pb", type=str,
                        help="Frozen model file to import")
    args = parser.parse_args()

    print('Loading the model')
    graph = _load_graph(args.frozen_model_filename)

    # Handles into the imported (frozen) graph.  NOTE(review): 'add_4' and
    # the anonymous 'Placeholder' are op names baked in by the training
    # script -- confirm they match the exported graph if retrained.
    x = graph.get_tensor_by_name('prefix/x:0')
    y = graph.get_tensor_by_name('prefix/add_4:0')
    keep_prob = graph.get_tensor_by_name('prefix/Placeholder:0')
    prediction = tf.argmax(y, 1)

    # One long-lived session, reused by every request handler.
    persistent_sess = tf.Session(graph=graph)

    print('Starting the API')
    app.run(host="0.0.0.0", port=8080)  # was int("8080") -- same value