Introducing Java APIs for Deep Learning Inference with Apache MXNet

By: Andrew Ayres, Qing Lan, Naveen Swamy, Piyush Ghai, Yizhi Liu

Apache MXNet is an open-source deep learning framework used to train and deploy deep neural networks. It is scalable, allowing for fast model training, and supports a flexible programming model and multiple languages (C++, Python, Julia, Clojure, JavaScript, R, Scala).

Today, the Apache MXNet community is pleased to announce the preview of Java APIs for Inference. The MXNet Java APIs make it easy to use deep learning models within systems and applications built on the popular Java language and runtime environment. A machine learning expert can train a model in Python, fine-tune it, and save it; a Java developer can then load the saved model and use it for inference in production with the Java APIs.

Inference using MXNet Java APIs

In this post, we outline the available Java APIs and walk you through a quick hands-on setup to get started with them.


MXNet Java API: What’s New?

Ease of use was central to our thinking when we designed these APIs. With just a few lines of code, you can load state-of-the-art models trained with Apache MXNet and begin making predictions. The Java APIs also include an automated resource management system that makes it easy to manage the native memory footprint without degrading performance.

Here are the inference APIs available in MXNet’s Java package:

  • Predictor API
  • Object Detector (Single and Batch)

The Predictor API provides methods to perform inference on a pre-trained model.

ObjectDetector is a wrapper on top of the Predictor API. It provides methods to detect distinct objects in an image as well as their locations in the image.


Quick Start with Java APIs

The Java APIs currently support Java 8 and higher. A preview of the MXNet Java package is available on Maven Central. The official release will accompany the upcoming MXNet 1.4 release.

Step 0: Environment Setup

Follow the step-by-step tutorial here to configure your development environment to use the Java APIs.

Step 1: Loading a pre-trained model with the Predictor API

MXNet has a large collection of pre-trained, state-of-the-art deep learning models available in the MXNet Model Zoo. The Gluon Model Zoo offers further pre-trained, state-of-the-art models implemented with Gluon (an imperative API for MXNet). Later in this post, we will run an inference example with a ResNet-50 based Single Shot MultiBox Detector (SSD) model taken from the Model Zoo.

To load a pre-trained MXNet model with the Java APIs, you need the path to the downloaded model and a description of the model’s inputs. To specify an input expected by the model, we create an object of type DataDesc. The input details are typically defined during the model training phase. DataDesc takes the input layer name, the shape of the input, the data type, and the data layout (the order of the dimensions).

Shape inputShape = new Shape(new int[] {1,3,224,224});
DataDesc inputDescriptor = new DataDesc("data", inputShape, DType.Float32(), "NCHW");
List<DataDesc> inputDescList = new ArrayList<DataDesc>();
inputDescList.add(inputDescriptor);

You can choose whether to run inference on CPUs or GPUs (if you have a GPU-backed machine) by specifying it in a Context object.

List<Context> context = new ArrayList<>();
context.add(Context.cpu());
// For GPU, context.add(Context.gpu());
The examples in this tutorial were run on a MacBook Pro using the CPU context. To use MXNet with an NVIDIA GPU on a Mac, follow the instructions here.
String modelPathPrefix = "path-to-model";
Predictor predictor = new Predictor(modelPathPrefix, inputDescList, context);

The modelPathPrefix points to the model files inside a folder that contains a symbol file (representing the model layers), a params file (the trained model weights), and any other auxiliary files required by the model.

Step 2 : Inference using the Predictor

The Predictor class has three prediction methods, each of which takes an input and returns the predictions as output. The input to a prediction method can be an NDArray, a one-dimensional Java List, or a one-dimensional Java array.
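When you pass a one-dimensional List or array, it has to contain every element of the declared input shape. The sketch below assumes the values are flattened in row-major NCHW order, matching the "NCHW" layout given to DataDesc; NchwLayout is a hypothetical helper for illustration, not part of the MXNet API.

```java
// Hypothetical helper (not part of the MXNet API): maps an NCHW shape onto a flat float[].
final class NchwLayout {
    private NchwLayout() {}

    // Total number of floats needed for a tensor of shape {n, c, h, w}.
    static int size(int n, int c, int h, int w) {
        return n * c * h * w;
    }

    // Offset of element (ni, ci, hi, wi) in row-major NCHW order:
    // w varies fastest, then h, then c, then n.
    static int index(int ni, int ci, int hi, int wi, int c, int h, int w) {
        return ((ni * c + ci) * h + hi) * w + wi;
    }

    public static void main(String[] args) {
        // Same shape as in Step 1: {1, 3, 224, 224}.
        float[] input = new float[size(1, 3, 224, 224)];
        // Write a value into channel 1 (green, if channels are RGB) of the top-left pixel.
        input[index(0, 1, 0, 0, 3, 224, 224)] = 0.5f;
        System.out.println(input.length);
    }
}
```

Running main prints 150528, the number of floats (1 × 3 × 224 × 224) a single-image input with the Step 1 shape needs.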

In MXNet, NDArray is the core data structure for all the mathematical computations involved in dealing with deep learning models.

Here are sample calls to each of the three prediction methods:

List<NDArray> result = predictor.predictWithNDArray(inputNDArray);

or

List<Float> result = predictor.predict(inputFloatList);

or

float[] result = predictor.predict(inputFloatArray);
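For an image classification model, the float[] returned by predict typically holds one score per class. Assuming that output format (the SSD model used in Step 3 returns detections instead, which ObjectDetector decodes for you), a minimal sketch of picking the k highest-scoring class indices could look like this; TopK is a hypothetical helper, not an MXNet API.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical post-processing helper (not part of the MXNet API).
final class TopK {
    private TopK() {}

    // Returns the indices of the k largest scores, highest score first.
    static List<Integer> topK(float[] scores, int k) {
        if (k < 0) {
            throw new IllegalArgumentException("k must be non-negative");
        }
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            indices.add(i);
        }
        // List.sort is stable, so equal scores keep the lower index first.
        indices.sort(Comparator.comparingDouble((Integer idx) -> scores[idx]).reversed());
        return new ArrayList<>(indices.subList(0, Math.min(k, indices.size())));
    }

    public static void main(String[] args) {
        System.out.println(topK(new float[] {0.1f, 0.7f, 0.2f}, 2));
    }
}
```

For the scores {0.1, 0.7, 0.2} and k = 2, topK returns [1, 2]; turning those indices into names requires the model’s class labels.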

Step 3: It’s showtime!

Now that we’ve briefly looked at the Predictor API, let’s try the object detection example written in Java from the MXNet repository. We will use the ObjectDetector API to identify objects and their locations in an image.

Follow the steps here to download the model files. This tutorial requires three model files: the symbol file, the params file, and the synset.txt file. The symbol file describes the model architecture, i.e. the layers used in the model. The params file contains the trained weights assigned to those layers. The synset file contains the class labels used in training.
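If you post-process raw predictions yourself, you need the labels from synset.txt. A minimal sketch, assuming the file holds one class label per line with line i naming class index i; Synset is a hypothetical helper built on java.nio, not part of the MXNet API.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

// Hypothetical helper for reading class labels; not part of the MXNet API.
final class Synset {
    private Synset() {}

    // Reads the labels, one per line; the 0-based line number is the class index.
    static List<String> load(Path synsetFile) throws IOException {
        return Files.readAllLines(synsetFile, StandardCharsets.UTF_8);
    }

    // Returns the label for a class index, failing loudly on an out-of-range index.
    static String labelFor(List<String> labels, int classIndex) {
        if (classIndex < 0 || classIndex >= labels.size()) {
            throw new IndexOutOfBoundsException("No label for class index " + classIndex);
        }
        return labels.get(classIndex).trim();
    }

    public static void main(String[] args) throws IOException {
        if (args.length != 1) {
            System.err.println("usage: java Synset <path-to-synset.txt>");
            return;
        }
        List<String> labels = load(Paths.get(args[0]));
        for (int i = 0; i < labels.size(); i++) {
            System.out.println(i + ": " + labelFor(labels, i));
        }
    }
}
```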

Let’s define the input shape, the input descriptor list, and the context that we will use later on.

Shape inputShape = new Shape(new int[] {1,3,512,512});
DataDesc inputDescriptor = new DataDesc("data", inputShape, DType.Float32(), "NCHW");
List<DataDesc> inputDescList = new ArrayList<DataDesc>();
inputDescList.add(inputDescriptor);
List<Context> context = new ArrayList<Context>();
context.add(Context.cpu());

Let’s go over the variables we described.

  1. The inputShape defined above represents a batch of one image (batch size 1) with 3 channels (RGB), a height of 512 pixels, and a width of 512 pixels.
  2. The inputDescriptor object’s first parameter is the name of the input layer in the model’s symbol file, followed by inputShape, the input data type (Float32), and the layout “NCHW”, where N stands for batch size, C for channels, H for image height, and W for image width.
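ObjectDetector’s own utilities handle image pre-processing, so you do not need this code to run the example. Purely to make the NCHW layout concrete, here is a sketch of laying out an image’s pixels channel by channel. It assumes the image is already 512 × 512 and that raw 0 to 255 values in R, G, B order are acceptable; the real pre-processing may resize, reorder, or normalize differently. ImageToNchw is hypothetical and uses only java.awt.image.

```java
import java.awt.image.BufferedImage;

// Hypothetical sketch of NCHW pre-processing (not the MXNet implementation).
final class ImageToNchw {
    private ImageToNchw() {}

    // Converts an image into a float[] of shape {1, 3, height, width} in NCHW order:
    // the whole R plane first, then G, then B. No resizing or normalization.
    static float[] toNchw(BufferedImage img) {
        int h = img.getHeight();
        int w = img.getWidth();
        int plane = h * w; // number of elements per channel
        float[] out = new float[3 * plane];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = img.getRGB(x, y); // packed 0xAARRGGBB
                int offset = y * w + x;
                out[offset] = (rgb >> 16) & 0xFF;         // R plane
                out[plane + offset] = (rgb >> 8) & 0xFF;  // G plane
                out[2 * plane + offset] = rgb & 0xFF;     // B plane
            }
        }
        return out;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(512, 512, BufferedImage.TYPE_INT_RGB);
        System.out.println(toNchw(img).length);
    }
}
```

For a 512 × 512 image the array holds 3 × 512 × 512 = 786,432 values, the same element count as the {1, 3, 512, 512} inputShape.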

Now, let’s create an instance of ObjectDetector, a high-level API that provides methods to detect distinct objects in an image along with their predicted labels and locations. It also contains utility methods for pre-processing input images. To create an ObjectDetector, we need the modelPathPrefix, the list of input descriptors, and the context.

String modelPathPrefix = "path-to-model/resnet50_ssd";
ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescList, context, 0);
We need to append the model name to the model path as a prefix so that the underlying Predictor API can load the model correctly. For example, if the model name is resnet50_ssd and the model files are downloaded to the /tmp/model/ folder, the modelPathPrefix would be /tmp/model/resnet50_ssd.
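MXNet checkpoints are conventionally stored as <prefix>-symbol.json plus <prefix>-<epoch>.params, with the epoch zero-padded to four digits. Assuming the downloaded files follow that convention, and assuming the trailing 0 passed to the ObjectDetector constructor selects epoch 0, the hypothetical helper below (not an MXNet API) derives the expected file names from modelPathPrefix so you can check they exist before constructing the detector.

```java
import java.io.File;
import java.util.Locale;

// Hypothetical helper (not part of the MXNet API); assumes the usual MXNet checkpoint naming.
final class ModelFiles {
    private ModelFiles() {}

    // Symbol file: <prefix>-symbol.json
    static String symbolFile(String modelPathPrefix) {
        return modelPathPrefix + "-symbol.json";
    }

    // Params file: <prefix>-<epoch, zero-padded to 4 digits>.params
    static String paramsFile(String modelPathPrefix, int epoch) {
        return String.format(Locale.ROOT, "%s-%04d.params", modelPathPrefix, epoch);
    }

    // True only if both files exist as regular files.
    static boolean filesExist(String modelPathPrefix, int epoch) {
        return new File(symbolFile(modelPathPrefix)).isFile()
                && new File(paramsFile(modelPathPrefix, epoch)).isFile();
    }

    public static void main(String[] args) {
        String prefix = args.length > 0 ? args[0] : "/tmp/model/resnet50_ssd";
        System.out.println(symbolFile(prefix));
        System.out.println(paramsFile(prefix, 0));
        System.out.println("both present: " + filesExist(prefix, 0));
    }
}
```

For the example prefix /tmp/model/resnet50_ssd, the helper expects /tmp/model/resnet50_ssd-symbol.json and /tmp/model/resnet50_ssd-0000.params.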

Now let’s load the input image so that we can feed it to our model. We will use an image taken from here as input.

String inputImagePath = "path-to-downloaded-image";
BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
Input image: A grumpy dog. Source

Now we can perform inference by calling:

int numberOfObjectsToDetect = 3; // returns top 3 objects in image
List<List<ObjectDetectorOutput>> output = objDet.imageObjectDetect(img, numberOfObjectsToDetect);
ProTip: the imageObjectDetect method returns a nested List (one inner list of detections per image in the batch), which you can flatten into a single list for easier processing.
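A minimal sketch of that flattening, assuming you want one flat list across all images and optionally only confident detections. The generic flatten method is a hypothetical helper, not an MXNet API; because it only relies on List<List<T>>, it should accept the List<List<ObjectDetectorOutput>> returned by imageObjectDetect as-is. The small Detection class is a stand-in used only so the sketch runs without MXNet.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;

final class Detections {
    private Detections() {}

    // Flattens List<List<T>> (one inner list per image) into a single List<T>,
    // keeping only the elements accepted by `keep`, in their original order.
    static <T> List<T> flatten(List<List<T>> nested, Predicate<T> keep) {
        return nested.stream()
                .flatMap(List::stream)
                .filter(keep)
                .collect(Collectors.toList());
    }

    // Hypothetical stand-in for ObjectDetectorOutput (label and probability only).
    static final class Detection {
        final String label;
        final float probability;

        Detection(String label, float probability) {
            this.label = label;
            this.probability = probability;
        }
    }

    public static void main(String[] args) {
        // One image with the three detections shown in the output in this post.
        List<List<Detection>> output = Collections.singletonList(Arrays.asList(
                new Detection("car", 0.98847263f),
                new Detection("bicycle", 0.94833825f),
                new Detection("dog", 0.8281818f)));
        for (Detection d : flatten(output, d -> d.probability >= 0.9f)) {
            System.out.println(d.label + " " + d.probability);
        }
    }
}
```

With a 0.9 threshold, main keeps the car and bicycle detections and drops the dog; pass d -> true to keep everything.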

ObjectDetectorOutput is a utility class that contains the predicted label of the object, the probability of that label, and Xmax, Xmin, Ymax, and Ymin. The X and Y values are pixel coordinates of the detected object’s location in the original image and can be used to draw a bounding box around the detected object.
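To illustrate the bounding box, here is a sketch that draws one onto an image with plain java.awt, assuming the coordinates are pixel positions in the original image as described above. BoxDrawer is hypothetical; with real results you would pass the Xmin, Xmax, Ymin, and Ymax values read from each ObjectDetectorOutput.

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;

// Hypothetical helper (not part of the MXNet API) for visualizing a detection.
final class BoxDrawer {
    private BoxDrawer() {}

    // Draws the outline of the box spanning [xMin, xMax] x [yMin, yMax] onto img.
    // Coordinates are rounded to the nearest pixel.
    static void drawBox(BufferedImage img, float xMin, float xMax, float yMin, float yMax, Color color) {
        int x = Math.round(xMin);
        int y = Math.round(yMin);
        int w = Math.round(xMax) - x;
        int h = Math.round(yMax) - y;
        Graphics2D g = img.createGraphics();
        try {
            g.setColor(color);
            g.drawRect(x, y, w, h); // outline only; the interior is left untouched
        } finally {
            g.dispose(); // release the graphics context
        }
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(640, 480, BufferedImage.TYPE_INT_RGB);
        drawBox(img, 100.4f, 300.6f, 50.0f, 200.0f, Color.RED);
        System.out.println(Integer.toHexString(img.getRGB(100, 50)));
    }
}
```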

Here’s the output after flattening the result as described in the ProTip:

Class: car
Probabilties: 0.98847263
Coord:312.21335,72.0291,456.01443,150.66176
Class: bicycle
Probabilties: 0.94833825
Coord:155.95807,149.96362,383.8369,418.94513
Class: dog
Probabilties: 0.8281818
Coord:83.82353,179.13998,206.63783,476.7875

Here’s a visualization of the predicted results:

Step 4: What’s Next?

You can try running more examples from the Java examples folder in the MXNet repository.


Conclusion

The MXNet Java Inference APIs allow Java developers to use pre-trained deep learning models built with Apache MXNet and get started with deep learning quickly.

Apache MXNet is an open source project. If you are excited by this and would like to contribute, join the project here.