MXNet 1.2 adds built-in support for ONNX

Authors: Anirudh Acharya, Rajan Singh, Roshani Nagmote, Hagay Lupesko (Amazon AI software engineers).

With the latest Apache MXNet 1.2 release, MXNet users can now use a built-in API to import ONNX models into MXNet. With this new functionality, developers can import models created with other neural network frameworks into MXNet and use them for inference or fine-tuning.

What is ONNX?

Open Neural Network Exchange (ONNX) is an open source serialization format for encoding deep learning models. ONNX defines the format for the neural network computational graph, as well as an extensive list of operators commonly used in neural network architectures. With ONNX supported by a growing list of frameworks and hardware vendors, this serialization standard makes it easier for deep learning developers to move between frameworks.

Overview of ONNX Import API

The new ONNX import API provides functionality to import ONNX models into MXNet as a symbolic graph. The API is easy to use and, as we’ll see below, requires just a few lines of code to import an ONNX model. This new API, built into MXNet 1.2, replaces the older onnx-mxnet GitHub repository and significantly improves ONNX conformance and operator support.

The import API takes the path to the ONNX model file as input and returns an MXNet symbol object representing the model graph, along with two Python dictionaries containing the model parameters:

sym, arg_params, aux_params = mx.contrib.onnx.import_model("model.onnx")

To learn more about MXNet’s ONNX API, see the MXNet docs website.

Quick Start

Let’s get started by using the new import API to load an emotion detection model. The model we will import is called FERPlus, and is based on a 2016 paper by Barsoum et al. It was built and trained using Microsoft Cognitive Toolkit (CNTK), and the code is available in the FERPlus repository.

We’ll start by installing MXNet and ONNX. Note that the ONNX installation below uses Conda, which you can install by following the instructions on the Conda website.

pip install mxnet==1.2
conda install -c conda-forge onnx==1.1.2
pip install Pillow
pip install matplotlib

Now that we have the prerequisites installed, let’s go ahead and import the model into MXNet.

Run the code below in your Python interpreter or IDE to download the ONNX model into your working directory, and import it into MXNet:

import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet
model_file = 'emotion_ferplus.onnx'
model_bucket_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/emotion_ferplus/'
mx.test_utils.download(model_bucket_url + model_file)
sym, arg_params, aux_params = onnx_mxnet.import_model(model_file)

We have now successfully imported the ONNX model into an MXNet symbolic graph, and have also obtained the model’s trained parameters in the arg_params and aux_params variables.
 
We will go ahead and run inference through this model, and examine the inference output.

Let’s download an input image and plot it so we can see what it looks like:

from PIL import Image
mx.test_utils.download('https://s3.amazonaws.com/onnx-mxnet/examples/emotion_ferplus/input.png')
Image.open("input.png").show()

A Warriors fan at half time, feeling optimistic

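The model was trained on 64x64, 8-bit grayscale images. Let’s pre-process the image to fit the model input requirements by resizing it, converting it to grayscale, and scaling the pixel values. Note that the expression used in the code below maps pixel values from [0, 255] to [-1, 1]: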

import numpy as np
color_image = Image.open("input.png").resize((64,64))
bw_image = color_image.convert('L')
processed_image = np.asarray(bw_image.getdata(),dtype=np.float64)
processed_image = (processed_image - 127.5)/127.5
# Plot the processed image on a graph.
from matplotlib import pyplot as plt
plt.imshow(processed_image.reshape((64,64)), cmap='Greys', interpolation='nearest')
plt.show()

The resulting plot shows the processed image that will be fed to the network.
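
To make the scaling step concrete, here is a minimal sketch of what the expression (pixel - 127.5) / 127.5 does to 8-bit pixel values. It uses plain Python rather than NumPy, and the helper name scale_pixel is hypothetical, introduced only for illustration:

```python
def scale_pixel(value):
    """Apply the same affine scaling as the pre-processing code above."""
    return (value - 127.5) / 127.5

# The darkest, middle, and brightest points of the 8-bit grayscale range.
for pixel in (0, 127.5, 255):
    print(pixel, "->", scale_pixel(pixel))
```

Black (0) maps to -1.0 and white (255) maps to 1.0, so every value in the array passed to the network lies in [-1, 1].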

We will now create an MXNet Module object, bind the imported symbol graph and the input data shapes, and assign the model’s parameters. The input name is obtained from the ONNX model file:

# Prepare the input data; the input name and shape are defined in the ONNX model
inputs = mx.nd.reshape(mx.nd.array(processed_image),shape=(1,1,64,64))
input_name = u'Input2505'
data_shapes = [(input_name, inputs.shape)]
# Initialize and bind the Module
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), data_names=[input_name], label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
mod.set_params(arg_params=arg_params, aux_params=aux_params)
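
The input name above is hard-coded. As a hedged alternative, the name can usually be inferred from the imported graph itself: any graph input that has no entry in arg_params or aux_params is a data input rather than a trained parameter. The sketch below assumes the symbol exposes list_inputs(), as MXNet Symbol objects do; the helper name find_data_names is hypothetical and not part of the MXNet API:

```python
def find_data_names(sym, arg_params, aux_params):
    """Return the graph inputs that are not trained parameters.

    `sym` is expected to be the symbol returned by import_model, and
    `arg_params` / `aux_params` the two parameter dictionaries.
    """
    return [name for name in sym.list_inputs()
            if name not in arg_params and name not in aux_params]
```

For the FERPlus model imported earlier, find_data_names(sym, arg_params, aux_params) would be expected to return a one-element list holding the input name, which can then be passed as data_names when constructing the Module.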

Let’s run inference on the model with the test image:

mod.forward(mx.io.DataBatch([inputs]))
result = mod.get_outputs()[0]
print(result)

Printing the inference result should show the tensor below:

[[ 2.6267748   6.7440767  -0.09765175 -0.2639766  -0.40274802 -2.807253   -3.051616   -1.7467498 ]]
<NDArray 1x8 @cpu(0)>

The model is trained to classify facial images into 8 emotion classes: Neutral, Happiness, Surprise, Sadness, Anger, Disgust, Fear and Contempt. The output printed above shows the aggregated activations of the different emotion classes.

We will now post-process the output by running it through a softmax function, which maps the aggregated activations into probabilities across the 8 output classes, and then plot the results:

# Running the results through softmax
softmax_output = result[0].softmax()
# Let's plot the histogram, and map it to the emotion classes
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111)
emotion_classes = ['Neutral', 'Happy', 'Surprise', 'Sad', 'Anger', 'Disgust', 'Fear', 'Contempt']
plt.ion()
plt.bar(emotion_classes, softmax_output.asnumpy())
plt.show()

As the resulting plot clearly shows, the model predicts this Warriors fan is pretty happy about his team! Go Warriors!
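
To see where the probabilities come from, the softmax step can be reproduced without MXNet. The sketch below applies a numerically stable softmax, written in plain Python for illustration, to the activations printed earlier and pairs each probability with its emotion class. The variable and function names are illustrative only:

```python
import math

# Activations copied from the inference output printed earlier.
logits = [2.6267748, 6.7440767, -0.09765175, -0.2639766,
          -0.40274802, -2.807253, -3.051616, -1.7467498]
emotion_classes = ['Neutral', 'Happy', 'Surprise', 'Sad',
                   'Anger', 'Disgust', 'Fear', 'Contempt']

def softmax(values):
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

probabilities = softmax(logits)

# Index of the most probable class.
best = max(range(len(probabilities)), key=probabilities.__getitem__)

for name, p in zip(emotion_classes, probabilities):
    print("{:<9s} {:.4f}".format(name, p))
print("Predicted:", emotion_classes[best])
```

The 'Happy' class receives about 98% of the probability mass, because its activation is far larger than any of the others.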

Learn More

We’re working with the ONNX and Apache MXNet communities to further develop ONNX and enhance ONNX support in MXNet.

To learn more about ONNX support in MXNet, see the ONNX API docs on the MXNet website.

ACKNOWLEDGEMENTS - Special thanks to Hao Jin for his contributions to MXNet backend operators for ONNX.