Using Flask to optimize performance with Mask R-CNN segmentation

How to improve Mask R-CNN segmentation performance using a Flask web service.

Dino Fejzagić
medialesson
4 min read · Jul 6, 2020


In a recent project I was facing the challenge of creating a machine-learning-powered photo booth solution. The software allows its users to take pictures of themselves in a car, have the background automagically removed without a green screen, and replaced with a more pleasing background scene. Among other things, performance was key: the segmentation needed to finish in the time it takes the user to exit the car and walk up to the nearby pickup station. That gives us around five to ten seconds.

In this story you’ll learn how you can use Flask, a micro web framework written in Python, to drastically reduce the time it takes for the pre-trained Mask R-CNN model to generate bounding boxes and segmentation masks for the captured photograph.

Mask R-CNN

Mask R-CNN for object detection and segmentation is a Python framework based on work published by Facebook Research. For any input image it produces bounding boxes, segmentation masks, and classifications for a variety of object types, such as cars and people. Of course, its performance is heavily bound to the computational power of the host machine.
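As a rough sketch of how the matterport implementation is typically driven (the file names and config values here are illustrative, and running it for real requires the TF1-era dependencies plus the downloaded COCO weights):

```python
import skimage.io
import mrcnn.model as modellib
from mrcnn.config import Config

# Minimal inference config; NUM_CLASSES matches the COCO-trained weights.
class InferenceConfig(Config):
    NAME = "coco_inference"
    NUM_CLASSES = 1 + 80   # background + 80 COCO classes
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

model = modellib.MaskRCNN(mode="inference", config=InferenceConfig(),
                          model_dir="logs")
model.load_weights("mask_rcnn_coco.h5", by_name=True)  # pre-trained weights

image = skimage.io.imread("street.png")
r = model.detect([image], verbose=0)[0]
# r["rois"]: bounding boxes, r["masks"]: per-instance masks,
# r["class_ids"] / r["scores"]: classifications and confidences
```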

A busy street with objects classified by Mask R-CNN.
Source: https://github.com/matterport/Mask_RCNN/raw/master/assets/street.png

That said, we built a powerful rig equipped with an Nvidia Titan RTX GPU with CUDA support and a high-end CPU for any CPU-bound work that needed to be done, mostly OpenCV-based image processing and error correction. With this setup we achieved a total computation time of around 15 seconds per photograph at an input/output resolution of 1920x1920 pixels.

Here’s one example of the processing steps from input to output with me and my colleague:

Background segmentation steps and output

So how does Flask help here?

Once you have started playing around with Mask R-CNN you will notice that the model and its weights need to be loaded and initialized whenever you want to run segmentation on an image. Say you had a Python script called my_segmentation_script.py that takes the path to the input image as an argument; you'd execute it like this:
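Something along these lines (the script name comes from the text above; the image path is a placeholder):

```shell
# Loads the model and weights, segments one image, then exits
python my_segmentation_script.py /path/to/input.jpg
```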

Now this runs fine and will perform the required work, assuming you have set up Mask R-CNN properly. The issue is that once the script is done it cleans up any allocated resources and releases the model and its weights. For every subsequent image they have to be loaded again, which costs us up to 10 seconds each time! So let's try to get rid of this overhead using Flask.

Follow the official Flask installation instructions to install and set up your Flask web service. Once it's running you'll be able to create views, which you can think of as API endpoints. The following example code creates a GET endpoint for the api/v1/ route of your Flask app that returns "Hello World" when called in a browser:
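A minimal version of such a view could look like this (the function name is arbitrary):

```python
from flask import Flask

app = Flask(__name__)

# GET endpoint for the api/v1/ route; returns "Hello World" in the browser.
@app.route("/api/v1/")
def hello_world():
    return "Hello World"

if __name__ == "__main__":
    app.run()
```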

All we need to do now is create a view / endpoint for our segmentation Python code and pass it an input and output path as parameters, so it can read in the image and write the segmentation to the desired path. The important difference is that we load the model and the weights only once, when the Flask app starts, and then re-use them throughout the lifetime of the service. Here's the full code, and then we'll take a look at it in more detail:
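The original listing is not reproduced here, so the following is a runnable sketch of the pattern it describes. The endpoint path, the JSON field names, and the stub model class are assumptions; in the real service the stub's place is taken by the TensorFlow session and graph setup, K.set_session(...), and the Mask R-CNN weight loading described below:

```python
from flask import Flask, jsonify, request

app = Flask(__name__)

# --- One-time initialization, runs when the Flask app starts -------------
# In the real service this is where the TF session and graph are created,
# assigned as the active Keras session, and the Mask R-CNN weights loaded.
class StubMaskRCNN:
    """Hypothetical stand-in for the pre-loaded Mask R-CNN model."""

    def __init__(self):
        self.loaded = True  # the real load_weights() takes ~10 seconds

rcnn = StubMaskRCNN()  # loaded ONCE, then re-used for every request


def do_segmentation(model, input_file_path, output_file_path):
    # Placeholder for main.doSegmentation(rcnn, input_path, output_path);
    # the real version runs inference inside the saved session and graph.
    return {"input": input_file_path, "output": output_file_path}


# --- POST endpoint: takes the input/output paths in its request body -----
@app.route("/api/v1/segmentation", methods=["POST"])
def segmentation():
    body = request.get_json()
    result = do_segmentation(rcnn, body["input_file_path"],
                             body["output_file_path"])
    return jsonify(result)

if __name__ == "__main__":
    app.run()
```

Because `rcnn` lives at module level, the expensive initialization happens exactly once per process, not once per request.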

Most of this is just boilerplate code for importing the required Python modules and creating the Flask view. Pay attention to lines 11 through 23, where we initialize the Keras session and then load the model weights. At line 23 we finally assign the created session to be the active session.

Next we simply define a POST endpoint that takes the input and output path in its request body. Note lines 30 through 32, where we make sure we are using the globally defined session and graph variables we just created.

With main.doSegmentation(rcnn, input_file_path, output_file_path) we then pass everything to our segmentation code, which makes use of the pre-loaded weights.

Summary

By loading the model weights only once and re-using them in a shared session throughout the lifetime of the Flask service, we were able to save another ~10 seconds of computation time per captured photograph running through background segmentation. Combined with decent hardware, this fulfills our performance requirements for the photo booth.

Please share in the comments if you run into any issues or if you have questions. I hope you find this little trick useful and can boost your performance just as I did!


Senior XR Engineer and CTO at CodeEffect GmbH. My topics include Unity, AR/VR/XR, game dev, and software engineering.