With the release of MXNet version 1.2.0, MXNet introduced the new MXNet Scala Inference API. This release focuses on improving the developer experience for inference applications written in Scala. Scala is a general-purpose programming language that supports functional programming and has a strong static type system. It is widely used for large-scale distributed processing on platforms such as Apache Spark.
The new MXNet Scala Inference API supports inference on both CPUs and GPUs. With its model loader and predictor modules, you can write an image classification app in Scala in just a few lines of code. We also made it easy for you to start using MXNet-Scala by making the packages available on Maven.
This is the first post in a 3-part blog series introducing the capabilities of the MXNet Scala Inference API.
MXNet Scala Inference API Overview
The API's highlights are its model loading and prediction capabilities: implementing inference (loading a model and running predictions) in an application requires only a few lines of code.
The Scala Inference API supports the following tasks:
- Image Classification
- Object Detection
- Single and Batch Image Classification
Loading a model with Predictor
Loading a model is simple.
Predictor's primary inputs are the path to the model and a description of the model's inputs. Specifying CPU or GPU processing and the model checkpoint to load are optional. The model's input description, inputDescriptor, describes the input the model expects. These details are typically defined during model training and are available in the training code. They include the input layer name, input shape, input data type, and, for image data, the data order.
import org.apache.mxnet.{DataDesc, DType, Shape}
import org.apache.mxnet.infer.Predictor
val inputDescriptor = IndexedSeq(DataDesc("data", Shape(1, 3, 224, 224), DType.Float32, "NCHW"))
val predictor = new Predictor("/path/to/model", inputDescriptor)
inputDescriptor is required to define the input source and configuration for the model. In the previous code block, "data" is the name of the model's input layer, and the second parameter, Shape(1, 3, 224, 224), is the shape of the input image. DType.Float32 is the expected data type of the input data, and "NCHW" defines the data order: batch size (in this case, one image), number of channels, height, and width.
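To make the NCHW order concrete, the sketch below computes how many Float values one input of Shape(1, 3, 224, 224) holds and where a given element sits in the flattened buffer. NchwLayout and its methods are hypothetical illustrations, not part of the MXNet API, and they assume a dense, row-major layout.

```scala
// Hypothetical helpers (not part of MXNet) for a dense, row-major NCHW tensor.
object NchwLayout {
  // Number of values in a tensor of the given shape, e.g. 1 * 3 * 224 * 224.
  def elementCount(shape: Seq[Int]): Int = shape.product

  // Flat offset of element (n, c, h, w): width varies fastest,
  // then height, then channel, then the batch index.
  def offset(n: Int, c: Int, h: Int, w: Int,
             channels: Int, height: Int, width: Int): Int =
    ((n * channels + c) * height + h) * width + w
}

// NchwLayout.elementCount(Seq(1, 3, 224, 224)) == 150528
// NchwLayout.offset(0, 1, 0, 0, 3, 224, 224) == 50176 (first value of the second channel)
```

Under this layout, a single flattened 224×224 image with three channels is an array of 150,528 Float values.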
Inference with Predictor
Predictor has two functions that ingest content as input and generate predictions as output. Inputs are either one-dimensional arrays or NDArrays, and the output predictions have the same type as the input. In MXNet, NDArray is the core data structure for all mathematical computations: it represents a multidimensional, fixed-size, homogeneous array. Passing an NDArray is useful when you need to perform multiple operations on the same input. Both functions take their inputs as an IndexedSeq, so a model with multiple inputs receives one entry per input.
Given an NDArray, predictWithNDArray returns the result in NDArray format:
val resultNDArray = predictor.predictWithNDArray(IndexedSeq(inputNDArray))
Given an IndexedSeq of one-dimensional Array[Float] inputs, predict returns a result in the same format:
val result = predictor.predict(input)
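The sketch below illustrates the array-based contract described above, assuming one IndexedSeq entry per model input and one flat Array[Float] per model output. PredictionIO, wrapSingleInput, and topClassIndex are hypothetical names, and the array in the usage comment stands in for a real predictor.predict result.

```scala
// Hypothetical helpers (not part of MXNet) around predict's
// IndexedSeq[Array[Float]] input and output.
object PredictionIO {
  // Wrap one flattened input in an IndexedSeq (one entry per model input),
  // checking its length against the input shape first.
  def wrapSingleInput(flat: Array[Float], shape: Seq[Int]): IndexedSeq[Array[Float]] = {
    require(flat.length == shape.product,
      s"expected ${shape.product} values, got ${flat.length}")
    IndexedSeq(flat)
  }

  // For a model whose first output holds one score per class, return the
  // index of the highest score (the first one wins on ties).
  def topClassIndex(result: IndexedSeq[Array[Float]]): Int = {
    val scores = result.head
    scores.indices.reduceLeft((best, i) => if (scores(i) > scores(best)) i else best)
  }
}

// PredictionIO.topClassIndex(IndexedSeq(Array(0.1f, 0.7f, 0.15f, 0.05f))) == 1
```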
Classifier is useful for classification tasks, such as image classification and text classification. Classifier has the features of Predictor and also attaches labels to the results. The index-to-label mapping is expected in a file called synset.txt, placed in the same directory as your model. You create a Classifier the same way you create a Predictor, by providing the path to the model and the model's input description; the arguments are the same. The following code creates a Classifier instance and loads the model and labels:
import org.apache.mxnet.infer.Classifier
val classifier = new Classifier("/path/to/model", inputDescriptor)
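The Classifier reads its labels from synset.txt. As an illustration only, not MXNet's own loader, the sketch below assumes one label per line, with line i holding the label for class index i (for example, "n01440764 tench, Tinca tinca" in an ImageNet synset file). The Synset object and its methods are hypothetical.

```scala
import scala.io.Source

// Hypothetical synset reader (not MXNet's implementation): line i of the
// file holds the label for class index i.
object Synset {
  // Trim each line and drop trailing blank lines so indices stay aligned.
  def parse(lines: Seq[String]): IndexedSeq[String] =
    lines.map(_.trim).reverse.dropWhile(_.isEmpty).reverse.toVector

  // Read labels from a file such as synset.txt next to the model files.
  def load(path: String): IndexedSeq[String] = {
    val source = Source.fromFile(path)
    try parse(source.getLines().toList)
    finally source.close()
  }
}
```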
In the following code block, the classify function takes an input and topK, the number of highest-probability predictions to return. The result is a sorted list of pairs, each containing a detected class (a label from the synset file) and the class confidence. Like Predictor, Classifier also has a function that handles NDArray inputs:
val result = classifier.classify(input, Some(topK)) // topK is passed as an Option[Int]
val resultNDArray = classifier.classifyWithNDArray(IndexedSeq(inputNDArray), Some(topK))
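Conceptually, top-K selection pairs each output probability with its synset label, sorts the pairs by descending probability, and keeps the first K. The sketch below shows that idea in plain Scala; TopK.select is a hypothetical helper, not the Classifier's actual code, and it assumes exactly one label per output score.

```scala
// Hypothetical top-K helper (not MXNet's implementation).
object TopK {
  // Pair labels with probabilities, sort by descending probability
  // (sortWith is stable, so ties keep their original order), and keep k pairs.
  def select(labels: IndexedSeq[String], probs: IndexedSeq[Float], k: Int): IndexedSeq[(String, Float)] = {
    require(labels.length == probs.length, "expected one label per score")
    labels.zip(probs).sortWith((a, b) => a._2 > b._2).take(k)
  }
}

// TopK.select(Vector("cat", "dog", "fish"), Vector(0.2f, 0.5f, 0.3f), 2)
//   == Vector(("dog", 0.5f), ("fish", 0.3f))
```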
ImageClassifier is a functional example of how you can extend Classifier for many types of classification tasks. In this case, ImageClassifier is set up to accept both arrays and NDArrays of raw data, or a single image or a batch of images that can be preprocessed to conform to the model's inputs. It returns a list of class labels and class confidences. It also provides the image preprocessing tools the classification process requires, such as loading, resizing, and converting an image to pixels:
import org.apache.mxnet.Shape
import org.apache.mxnet.infer.ImageClassifier
val img = ImageClassifier.loadImageFromFile("/path/to/image")
val reshapedImage = ImageClassifier.reshapeImage(img, 224, 224)
val imgInNDArray = ImageClassifier.bufferedImageToPixels(reshapedImage, Shape(224, 224))
Then, create an ImageClassifier instance to perform image classification:
val classifier = new ImageClassifier("/path/to/model", inputDescriptor)
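The preprocessing step turns a resized image into the planar layout the model expects. The sketch below shows one way to do that conversion in plain Scala, assuming pixels arrive as packed ARGB ints (the format java.awt.image.BufferedImage.getRGB returns) and that the model wants unnormalized RGB values in CHW order. PixelsToChw is a hypothetical helper, not the source of bufferedImageToPixels; a real model may need a different channel order or mean subtraction.

```scala
// Hypothetical conversion (not MXNet's implementation): packed ARGB pixels in
// row-major order become three planar channels (R, G, B) of Float values.
object PixelsToChw {
  def toChw(pixels: Array[Int], height: Int, width: Int): Array[Float] = {
    require(pixels.length == height * width, "expected height * width pixels")
    val plane = height * width
    val out = new Array[Float](3 * plane)
    for (i <- 0 until plane) {
      val p = pixels(i)                              // pixel at y = i / width, x = i % width
      out(i) = ((p >> 16) & 0xff).toFloat            // red plane
      out(plane + i) = ((p >> 8) & 0xff).toFloat     // green plane
      out(2 * plane + i) = (p & 0xff).toFloat        // blue plane (alpha is ignored)
    }
    out
  }
}
```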
ObjectDetector demonstrates extending ImageClassifier as a Single-Shot Detector. Input images are analyzed for distinct objects, and the labels and locations of those objects are returned. As in the previous examples, ObjectDetector takes a model path and an input descriptor.
val objectDetector = new ObjectDetector("/path/to/model", inputDescriptor)
ObjectDetector accepts both image and NDArray inputs for inference.
val result = objectDetector.imageObjectDetect(inputImage, Some(topK))
val resultNDArray = objectDetector.objectDetectWithNDArray(IndexedSeq(inputNDArray), Some(topK))
The prediction result contains the topK detections (for example, the top 10 results), in descending order of confidence. Each detection is a (String, Array[Float]) pair, where String is the category the object belongs to and Array[Float] holds five values: Accuracy (the confidence score), Xmin, Ymin, Xmax, and Ymax. The X and Y values are the detected object's pixel locations in the original image, running from the upper-left corner to the lower-right corner, and together they form a box around the detected object.
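Because each detection is a plain (String, Array[Float]) pair, post-processing is ordinary Scala. The sketch below uses hypothetical names (Detections, Box, confident) to turn each pair into a named box and keep only detections at or above a confidence threshold; it treats the coordinates as already being in whatever units the detector returns.

```scala
// Hypothetical post-processing (not part of MXNet) for detections of the form
// (label, Array(confidence, xmin, ymin, xmax, ymax)).
object Detections {
  final case class Box(label: String, confidence: Float,
                       xmin: Float, ymin: Float, xmax: Float, ymax: Float) {
    def width: Float = xmax - xmin
    def height: Float = ymax - ymin
  }

  def toBox(d: (String, Array[Float])): Box = {
    val (label, v) = d
    require(v.length == 5, "expected confidence, xmin, ymin, xmax, ymax")
    Box(label, v(0), v(1), v(2), v(3), v(4))
  }

  // Keep only detections whose confidence is at least the threshold.
  def confident(ds: Seq[(String, Array[Float])], threshold: Float): Seq[Box] =
    ds.map(toBox).filter(_.confidence >= threshold)
}
```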
Conclusion and Learn More
The Scala Inference API was designed for ease of use and it offers additional capabilities beyond those discussed in this blog post.
Predictor can also handle text inputs and outputs, so you can try it with an LSTM or another text-based model. You could also use it for multimodal models such as visual question answering. Community participation is encouraged through questions, requests, and contributions!
To learn more, check out the image classification example in the MXNet GitHub repository.