For a long time I’ve been thinking of creating an API that fellow developers could use over the cloud. But I didn’t want it to be the traditional Hello World API or a simple SQL Flask API supporting the classic GET, PUT, POST and DELETE REST requests for user names and email IDs. Since AI and ML are so pervasive now, I thought of giving ML a try, and it turned out to be easy :)
Image Classification using PyTorch
Since everything runs in the cloud, and on free tiers at that, I needed an AI/ML project with small, lightweight dependencies. I forked a project named Img2vec that uses PyTorch and pretrained models to generate feature vectors for a dataset of images, and then does a simple match (cosine similarity from sklearn) of a test image against the others. The README provides enough snippets to make the demo work on your system.
Training the image data set
For the sake of simplicity, I’m skipping the dataset training part; there are plenty of docs on that across the internet. For our API, I’m generating the image vectors ahead of time and storing them as CSV files. Since the computation happens on a cloud server, generating every vector at runtime would make the API time out on each request. Instead, the API generates a vector only for the image you send and reads all the other vectors from the CSV files.
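A minimal sketch of that offline step, assuming the vectors have already been computed (for example with Img2vec's `get_vec`) and are held in a dict mapping each image file name to a list of floats. The helper names `save_vectors` and `load_vectors` and the one-row-per-image CSV layout are hypothetical choices for illustration, not necessarily what the forked project does.

```python
import csv
from typing import Dict, List


def save_vectors(vectors: Dict[str, List[float]], path: str) -> None:
    """Write one CSV row per image: the file name followed by its vector components."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        for name, vec in vectors.items():
            writer.writerow([name, *vec])


def load_vectors(path: str) -> Dict[str, List[float]]:
    """Read the rows back into a {file name: vector} dict."""
    vectors = {}
    with open(path, newline="") as f:
        for row in csv.reader(f):
            if row:  # skip blank lines, if any
                vectors[row[0]] = [float(x) for x in row[1:]]
    return vectors
```

The expensive `save_vectors` call runs once, offline; at request time the server only pays for `load_vectors` (which could also be done once at startup) plus a single new embedding.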
Setting up the Flask app
For a simple image match, I needed a POST or PUT request. Many frameworks support REST features; I used Flask on Python 3.6. As the official site puts it:
Flask is a lightweight WSGI web application framework. It is designed to make getting started quick and easy.
So here’s the idea:
- The server receives an image as a byte stream
- Those bytes are decoded back into an image using Pillow in Python
- A feature vector is generated for that image
- That vector is matched against the vectors of the dataset images
- The closest matching images are returned (with a download URL and a percentage match value)
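The matching and response steps can be sketched in plain Python. The original project uses sklearn's cosine similarity; the function below is a dependency-free stand-in that computes the same quantity. The `base_url` parameter, the result keys `url` and `match`, and reporting the percentage as similarity times 100 are assumptions for illustration.

```python
import math
from typing import Dict, List


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # avoid division by zero for all-zero vectors
    return dot / (norm_a * norm_b)


def top_matches(query: List[float], stored: Dict[str, List[float]],
                base_url: str, k: int = 3) -> List[dict]:
    """Rank stored images by similarity to the query and keep the best k."""
    scored = [(name, cosine_similarity(query, vec)) for name, vec in stored.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return [
        # hypothetical response shape: download URL plus percentage match
        {"url": base_url.rstrip("/") + "/" + name, "match": round(score * 100, 2)}
        for name, score in scored[:k]
    ]
```

For a few thousand stored vectors this linear scan is fast enough; a larger dataset would call for a vectorized version (NumPy or sklearn) or an approximate nearest-neighbour index.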
Here’s the gist of the API:
# endpoint to find the closest matching images
# the image arrives as a Base64-encoded byte stream
img_bytes = request.get_data()  # renamed from `bytes` to avoid shadowing the built-in
results = search(img_bytes)  # returns a list of image matches
Note: The test image posted is not a raw byte stream; it’s a Base64-encoded image byte stream, a standard way to ship binary data across text-based channels.
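Encoding on the client and decoding on the server are one-liners with Python's standard `base64` module. A minimal sketch; the helper names `encode_image` and `decode_image` are hypothetical:

```python
import base64


def encode_image(raw: bytes) -> bytes:
    """Client side: turn raw image bytes into a Base64 request body."""
    return base64.b64encode(raw)


def decode_image(body: bytes) -> bytes:
    """Server side: recover the raw image bytes from the request body."""
    return base64.b64decode(body)
```

On the server, the decoded bytes can then be handed to Pillow with `Image.open(io.BytesIO(raw))` before generating the feature vector.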
Note: Get the source code here.
Setting up the Heroku server
Setting up the server is pretty straightforward, but there are a few issues I encountered:
- Slug Size: Heroku caps the slug size at 500 MB, the total size of an app (code, executables, and media files like images and PDFs) after compression. This is a problem because Python libraries like SciPy and PyTorch are pretty hefty; PyTorch with CUDA 8 is ~600 MB alone. But there are workarounds.
- PyTorch Version: PyTorch provides a CPU-only build variant, a small ~45 MB library that provides all the features we need for deployment.
- SciPy Manual Install: I don’t know why, but an online install of SciPy on Heroku turned out to be buggy (more on that here). Instead, I downloaded the SciPy wheel (.whl) file (or the source code) and installed it manually on the server.
- JPEG vs PNG: In the demo code, I’ve used JPEG files (whose Base64-encoded data begins with /9j/). So if you ping the server with images in other formats like PNG (whose Base64-encoded data begins with iVBOR…), you will get an error.
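Rather than failing somewhere inside the model, the server could check the format up front. The sketch below inspects the decoded file signature (JPEG files begin with the bytes FF D8 FF, PNG files with 89 50 4E 47 0D 0A 1A 0A, which is why their Base64 forms start with /9j/ and iVBOR). The function names and the idea of rejecting non-JPEG uploads early are suggestions, not part of the demo code.

```python
import base64

JPEG_SIGNATURE = b"\xff\xd8\xff"
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def detect_format(raw: bytes) -> str:
    """Identify the image format from its leading signature bytes."""
    if raw.startswith(JPEG_SIGNATURE):
        return "jpeg"
    if raw.startswith(PNG_SIGNATURE):
        return "png"
    return "unknown"


def require_jpeg(body: bytes) -> bytes:
    """Decode a Base64 request body and reject anything that is not a JPEG."""
    raw = base64.b64decode(body)
    fmt = detect_format(raw)
    if fmt != "jpeg":
        raise ValueError("expected a JPEG image, got: " + fmt)
    return raw
```

In a Flask route, the `ValueError` could be caught and turned into a 400 response with a readable message. An alternative would be to accept PNGs and normalize every upload with Pillow's `Image.convert("RGB")` before embedding.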
You can try the demo in a few ways:
- You can install an app (get it here) on your Android device. Take a picture; after a short wait it will display the closest matching images.
- You can use this Python script. Just enter the relative path of the image on your desktop and it will display the results.
- You can make a PUT request online (I recommend Hurl.it for starters) to this URL with the Base64-encoded image data as the body.
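The PUT option can also be scripted with nothing but the standard library. A sketch, assuming the endpoint accepts the raw Base64 text as the request body and replies with JSON; `API_URL` is a hypothetical placeholder rather than the real endpoint, and the `Content-Type` header is likewise an assumption.

```python
import base64
import json
import urllib.request

API_URL = "https://example.herokuapp.com/search"  # hypothetical: substitute the real endpoint


def build_request(image_path: str, url: str) -> urllib.request.Request:
    """Read an image file and wrap its Base64 encoding in a PUT request."""
    with open(image_path, "rb") as f:
        body = base64.b64encode(f.read())
    # assumption: the server reads the raw body, so plain text is a safe content type
    return urllib.request.Request(
        url, data=body, method="PUT", headers={"Content-Type": "text/plain"}
    )


def search_image(image_path: str, url: str = API_URL):
    """Send the request and parse the (assumed) JSON response."""
    req = build_request(image_path, url)
    with urllib.request.urlopen(req, timeout=60) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

The generous timeout is there because a free-tier dyno may need to wake up before it can answer the first request.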
For the sample cat image shown, the results are