Optimizing Inference on CPU in the Upcoming Apache MXNet 2.0

Adam Grygielski
Published in Apache MXNet
Feb 22, 2021

Authors: Adam Grygielski, Bartłomiej Gawrych, Sylwester Fraczek

Introduction

Deep learning inference is the process of deploying a trained neural network to make predictions on unseen data. It is a common workload on cloud servers. To provide a good user experience, inference has to be fast, so it is important to use optimized solutions. Optimization also means reduced hardware load and lower energy costs.

There are two types of performance bottlenecks to consider when optimizing a neural network: heavy compute-bound operations such as convolution or fully-connected layers, and many small memory-bound operations such as elementwise ReLU or transposition.

Several methods have been developed for optimizing neural networks. Operator fusion chains operations together to speed up memory-bound operations by reducing memory I/O. Quantization speeds up compute-bound operations by lowering precision, which simplifies computation and also reduces the amount of data being processed.

Apache MXNet (incubating) introduces some interface changes in the upcoming version 2.0. The Gluon API is now the default, superseding the symbolic and Module APIs. It unifies the flexibility of imperative programming with the performance benefits of symbolic programming. MXNet 2.0 also fully supports NumPy semantics.

Operator Fusion

In previous versions of Apache MXNet, operator fusion was enabled by default when MXNet was built with the Intel oneDNN library. However, in MXNet 2.0 the Module API and GraphExecutor were replaced by the CachedOp executor, and Gluon's Block is now the preferred API for defining and executing a model instead of the symbolic API. As a result, fusion is no longer applied automatically and has to be requested explicitly.

Let's recap what operator fusion is and what benefits it brings. Before we do, it is worth mentioning that MXNet supports two execution modes:

  • imperative mode: the model is executed step by step as defined, and we can access each tensor between operator executions
  • symbolic mode: based on the defined model, the MXNet engine creates a graph that can be optimized, e.g. by pre-allocating memory; we can't access[1] tensors between operator executions

Imperative mode is great for debugging, while symbolic mode provides a big performance boost. You can find more information about imperative and symbolic modes here.

Every model can be represented as a directed graph in which operators and tensors are nodes. Each connection where the output of one operator is the input to another is represented in the graph as an edge.

Operator fusion simply replaces two or more consecutive operators with a single new operator that combines all of their functions. As a result, we get a smaller graph, and only one kernel is invoked instead of several (one per operator). These fused kernels are optimized to run as fast as possible by utilizing all available features of a modern CPU. We also get rid of MXNet engine overhead between operator calls and reduce memory access (read/write) operations.
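To make the memory-traffic argument concrete, here is a minimal pure-Python sketch (not MXNet code; every function name is hypothetical) that contrasts a chain of three unfused elementwise operators with a single fused one. Both produce the same result, but the unfused version walks over the data three times and materializes two temporary tensors, while the fused version reads and writes each element once.

```python
def scale(xs, a):
    # Kernel 1: reads every element and writes a new intermediate tensor
    return [a * x for x in xs]

def add_bias(xs, b):
    # Kernel 2: another full read and write of the data
    return [x + b for x in xs]

def relu(xs):
    # Kernel 3: a third pass over memory
    return [max(0.0, x) for x in xs]

def unfused(xs, a, b):
    # Three separate operators and two temporary tensors
    return relu(add_bias(scale(xs, a), b))

def fused(xs, a, b):
    # One fused operator: each element is read once and written once
    return [max(0.0, a * x + b) for x in xs]

data = [-1.0, 0.5, 2.0]
print(unfused(data, 2.0, 0.5))  # [0.0, 1.5, 4.5]
print(fused(data, 2.0, 0.5))    # [0.0, 1.5, 4.5]
```

In a real framework the saving also includes the per-operator dispatch overhead mentioned above, which this sketch does not model.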

How do I fuse operators in my model?
To fuse a model in MXNet 2.0, two requirements must be met:

  • the model must be defined as a subclass of HybridBlock or as a Symbol
  • the model must contain specific operator patterns that can be fused

At the time of publishing this article, MXNet supports three major fusion patterns:

  • Convolution + BatchNorm + ReLU/GELU/LeakyReLU/sigmoid/SoftReLU/tanh + Elementwise Add
  • FullyConnected + elementwise op (ReLU, Square, Logistic, SoftReLU, BoundedReLU, Sqrt, Exp, Abs)
  • BatchNorm + ReLU
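The pattern list above can be illustrated with a toy fusion pass. The sketch below is a simplification under stated assumptions: the model is a linear chain of operator labels (real MXNet passes work on the full graph, where the elementwise Add takes a second, residual input), the labels such as "conv" and "batchnorm" are hypothetical rather than MXNet operator names, and every stage after the convolution is treated as optional.

```python
CONV_ACTIVATIONS = {"relu", "gelu", "leaky_relu", "sigmoid", "soft_relu", "tanh"}
FC_ELTWISE = {"relu", "square", "logistic", "soft_relu", "bounded_relu",
              "sqrt", "exp", "abs"}

def fuse_chain(ops):
    """Greedily replace known operator patterns with single fused nodes."""
    fused, i = [], 0
    while i < len(ops):
        op = ops[i]
        if op == "conv":
            # Convolution [+ BatchNorm] [+ activation] [+ elementwise Add]
            group, j = ["conv"], i + 1
            if j < len(ops) and ops[j] == "batchnorm":
                group.append(ops[j]); j += 1
            if j < len(ops) and ops[j] in CONV_ACTIVATIONS:
                group.append(ops[j]); j += 1
            if j < len(ops) and ops[j] == "add":
                group.append(ops[j]); j += 1
            fused.append("+".join(group))
            i = j
        elif op == "fc" and i + 1 < len(ops) and ops[i + 1] in FC_ELTWISE:
            # FullyConnected + elementwise op
            fused.append("fc+" + ops[i + 1]); i += 2
        elif op == "batchnorm" and i + 1 < len(ops) and ops[i + 1] == "relu":
            # BatchNorm + ReLU
            fused.append("batchnorm+relu"); i += 2
        else:
            # No pattern starts here: keep the operator unchanged
            fused.append(op); i += 1
    return fused

print(fuse_chain(["conv", "batchnorm", "relu", "fc", "relu", "batchnorm", "relu", "pooling"]))
# ['conv+batchnorm+relu', 'fc+relu', 'batchnorm+relu', 'pooling']
```

Each fused label stands for one kernel launch instead of several, which is where the graph shrinks.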

As an example, we define a sample network (a block from the ResNet architecture):

Both the HybridBlock and Symbol classes provide an API to easily run operator fusion. All we have to do is add a single line of code that enables fusion passes on our model:

net.optimize_for(data, backend='MKLDNN')

To apply the passes to a Symbol, we call the optimize_for function on the symbol instance without passing input data. Note that Symbol's optimize_for is not done in place, so we must assign its result to a new variable:

optimized_symbol = sym.optimize_for(backend='MKLDNN')

For the above model definition, in a naive benchmark with artificial data, we gain up to a 1.25x speedup without any accuracy loss on our test machine. Detailed speedup data for other models is shown in the "Performance and accuracy results" section of the article.
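Why fusing a BatchNorm into the preceding layer need not cost accuracy can be shown with a few lines of algebra. At inference time BatchNorm uses frozen statistics, so it is just an affine transform that can be folded into the weights and bias of the layer before it. The pure-Python sketch below assumes this folding is representative of what a Conv/FC + BatchNorm fusion does; it uses a fully connected layer for brevity (the same per-output-channel algebra applies to convolution), and all names are hypothetical. The folded layer matches the original up to floating-point rounding.

```python
import math

def fc(W, b, x):
    # Fully connected layer: y_i = sum_j W_ij * x_j + b_i
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]

def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    # Inference-mode BatchNorm with frozen running statistics
    return [g * (yi - m) / math.sqrt(v + eps) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]

def fold_bn_into_fc(W, b, gamma, beta, mean, var, eps=1e-5):
    """Return (W', b') such that fc(W', b', x) == batchnorm(fc(W, b, x), ...)."""
    W_f, b_f = [], []
    for row, bi, g, bt, m, v in zip(W, b, gamma, beta, mean, var):
        s = g / math.sqrt(v + eps)          # per-output-channel scale
        W_f.append([w * s for w in row])    # scale the weights
        b_f.append((bi - m) * s + bt)       # shift and scale the bias
    return W_f, b_f

W = [[0.5, -1.0], [2.0, 0.25]]
b = [0.1, -0.3]
bn = dict(gamma=[1.5, 0.8], beta=[0.2, -0.1], mean=[0.05, 0.4], var=[0.9, 2.0])
x = [1.2, -0.7]

reference = batchnorm(fc(W, b, x), **bn)
folded = fc(*fold_bn_into_fc(W, b, **bn), x)
print(reference, folded)  # equal up to rounding
```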

Quantization

As in the 1.x versions, MXNet 2.0 continues to support model quantization from float32 to either signed or unsigned INT8 (s8/u8). Running inference with a quantized model on modern CPUs can greatly increase the performance of your workloads. Quantized models use the VNNI (Vector Neural Network Instructions)[2] instruction set to speed up compute-heavy operations such as convolution or dot product. Moreover, the int8 data type reduces the amount of data read in memory-bound operations such as pooling or elementwise functions. MXNet uses optimized kernels from the Intel® oneDNN[3] library to speed up model execution.
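The core of float32 to s8/u8 quantization is a scale that maps a chosen threshold onto the integer range. The pure-Python sketch below shows symmetric quantization; the choice of 127 for s8 and 255 for u8 and the function names are assumptions made for illustration, not MXNet's exact kernels. For values inside the threshold, the round-trip error is at most half a quantization step.

```python
def quantize(xs, threshold, signed=True):
    """Map floats in [-threshold, threshold] (s8) or [0, threshold] (u8) to ints."""
    qmax = 127 if signed else 255
    qmin = -qmax if signed else 0
    scale = qmax / threshold
    # Round to the nearest integer, then saturate to the representable range
    q = [min(qmax, max(qmin, round(x * scale))) for x in xs]
    return q, scale

def dequantize(q, scale):
    return [v / scale for v in q]

xs = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, s = quantize(xs, threshold=1.0)              # signed int8
print(q[0], q[-1])                              # -127 127
print(max(abs(x - d) for x, d in zip(xs, dequantize(q, s))) <= 0.5 / s)  # True

# u8 suits non-negative data, e.g. activations after ReLU; 5.0 saturates to 255
qu, su = quantize([0.0, 0.3, 2.0, 5.0], threshold=2.0, signed=False)
```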

During the quantization procedure, the floating-point model is first fused using the fusions described in the previous section. After that, operators that have int8 kernels are marked as quantized and surrounded by quantize/requantize/dequantize operators. Finally, the model can either be calibrated to get rid of the requantize operators or left as it is, in which case scales are calculated at runtime. There are two major ways of quantizing an fp32 model:

  • without calibration. All we have to do is call the quantize_net function with an fp32 Gluon model and a list of input data shapes. However, this approach is not recommended for performance: it leaves requantize nodes in the graph that calculate min/max values during every forward pass. Calibrating a model before deploying it results in much faster inference.
  • with calibration. After the graph is quantized, the model is run on user-supplied calibration data to collect statistics of the quantized layers, and the resulting min/max values are stored as parameters. How these thresholds are chosen depends on the calibration method.
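The cost difference between the two options comes from how the int32 accumulators of an int8 kernel are requantized. In the pure-Python sketch below (hypothetical names; to keep it short, scales are expressed directly in the accumulator domain and input/weight scales are left out), the uncalibrated path needs an extra reduction over the outputs on every call, while the calibrated path only applies a constant scale and saturates.

```python
def int8_matvec(qW, qx):
    # Integer products accumulated in a wide (int32-like) accumulator
    return [sum(w * x for w, x in zip(row, qx)) for row in qW]

def requantize_dynamic(acc):
    # No calibration: scan the outputs on every forward pass to find the range
    amax = max(abs(a) for a in acc) or 1
    scale = 127 / amax
    return [round(a * scale) for a in acc], scale

def requantize_calibrated(acc, calib_amax):
    # Calibrated: the range was fixed ahead of time, so no extra scan is needed,
    # but values outside it must be saturated
    scale = 127 / calib_amax
    return [max(-127, min(127, round(a * scale))) for a in acc], scale

acc = int8_matvec([[1, 2], [3, -4]], [10, 20])  # [50, -50]
print(requantize_dynamic(acc)[0])               # [127, -127]
print(requantize_calibrated(acc, 50)[0])        # [127, -127]
```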

Currently, three calibration methods are supported:

  • naive: uses the min/max values observed during the calibration run.
  • entropy: uses KL divergence to determine the best symmetric quantization thresholds for a given histogram of values.
  • custom: uses a user-defined CalibrationCollector to control the calibration process.

In MXNet 2.0, the quantization procedure has been adjusted to work well with Gluon models, since Gluon is now the main API. The goal was to let users quantize an fp32 HybridBlock model in just a few lines of code.

Quantization flow in MXNet 2.0

As an example of the quantization procedure, we will use a pretrained resnet50_v1 from model_zoo.vision. To get it, we simply run the following code:

To compare performance, we will use a simple function that calculates the total inference time of the model on artificial data:
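As a framework-agnostic sketch of such a benchmark (not the original snippet; the signature and the `speedup` helper are hypothetical), the function below runs a few warm-up iterations and then times a fixed number of forward passes. Because MXNet executes operators asynchronously, timing an MXNet model would also require waiting for pending work before reading the clock; that is what the `sync` callable is for (in MXNet 1.x, for example, `mx.nd.waitall` serves this purpose).

```python
import time

def benchmark_net(net, data, iterations=100, warmup=10, sync=lambda: None):
    """Return the total wall-clock time of `iterations` forward passes."""
    for _ in range(warmup):
        net(data)                 # warm-up: caches, memory pools, lazy init
    sync()
    start = time.perf_counter()
    for _ in range(iterations):
        net(data)
    sync()                        # wait for asynchronous work to finish
    return time.perf_counter() - start

def speedup(baseline_time, optimized_time):
    # e.g. fp32 (fused) time divided by int8 time
    return baseline_time / optimized_time

# Usage with a stand-in "model": any callable taking one batch works
calls = []
total = benchmark_net(lambda batch: calls.append(batch), data=[0.0] * 8,
                      iterations=5, warmup=2)
print(len(calls), speedup(2.0, 0.5))  # 7 4.0
```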

Now, to get a quantized model, all we need to do is call the quantize_net function from contrib.quantization and provide a list of input data shapes:

This gives us a quantized `resnet50_v1` ready to run inference tasks. However, we didn't use any calibration data, so we won't get a satisfactory performance boost.
We can compare the two models with our benchmark_net function to calculate the total speedup. To get better results, we should first hybridize both models with the static_shape and static_alloc flags set to True. This tells MXNet that shapes in the model won't change at runtime, so it can pre-allocate memory and avoid runtime allocations. To properly evaluate the performance benefits of quantization, we should compare against an fp32 model with the MKLDNN backend enabled. This is because the quantization procedure fuses the graph before quantizing it; if we don't fuse the fp32 model as well, we will measure the combined benefit of fusion and quantization.

Output:

> Speedup: 0.72x

As we can see, we didn't get any performance benefit from using int8. It turns out that calculating min/max at runtime adds a big overhead. Now let's try calibrating the model before the actual execution. The only difference in the code is that we have to provide quantize_net with calibration data stored in a DataLoader. For this example we will use the same dummy_data, but in a real use case it would most likely be a small portion of the validation dataset. Notice that we no longer have to provide the data_shape argument, because it will be taken from calib_data. If we don't specify the calib_batches parameter, the whole calib_data will be used.
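The role of the batch limit in naive calibration can be sketched in a few lines of pure Python. The function below is hypothetical (its `calib_batches` argument only mirrors the parameter named above): it feeds batches through a layer and tracks the global min/max of the output, using only the first `calib_batches` batches when a limit is given and the whole dataset otherwise.

```python
from itertools import islice

def naive_calibrate(layer_fn, calib_data, calib_batches=None):
    """Track the global (min, max) of layer_fn's output over calibration batches."""
    batches = calib_data if calib_batches is None else islice(calib_data, calib_batches)
    lo, hi = float("inf"), float("-inf")
    for batch in batches:
        out = layer_fn(batch)                 # stand-in for one layer's output
        lo, hi = min(lo, min(out)), max(hi, max(out))
    return lo, hi

batches = [[1.0, 2.0], [-3.0, 4.0], [10.0, -20.0]]
print(naive_calibrate(lambda b: b, batches))                   # (-20.0, 10.0)
print(naive_calibrate(lambda b: b, batches, calib_batches=2))  # (-3.0, 4.0)
```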

Output:

> Speedup: 3.8x

This time we get a significant performance boost with just a few lines of code.

The other aspect of lowering a model's precision is its effect on accuracy. We will check this on the previously tested resnet50_v1 with the ImageNet dataset. To run this example, you will need the ImageNet dataset prepared with this tutorial and stored in path_to_imagenet. Let's compare the top-1 and top-5 accuracy of the standard fp32 model with the int8 model calibrated using the naive and entropy calibration modes. We will use only 10 batches of the validation dataset to calibrate the quantized model.
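For reference, this is what top-1 and top-5 accuracy measure: a prediction counts as correct if the true label is among the k highest-scoring classes. In practice you would likely use the metric classes that ship with MXNet; the pure-Python sketch below (hypothetical function name) only illustrates the definition.

```python
def topk_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        hits += label in ranked[:k]
    return hits / len(labels)

scores = [[0.1, 0.7, 0.2],   # predicts class 1
          [0.5, 0.3, 0.2],   # predicts class 0
          [0.2, 0.3, 0.5]]   # predicts class 2
labels = [1, 1, 0]
print(topk_accuracy(scores, labels, k=1))  # 1 of 3 correct
print(topk_accuracy(scores, labels, k=2))  # 2 of 3 correct
```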

Output:

> FP32 Top1 Accuracy: 0.76364 Top5 Accuracy: 0.93094
> INT8Naive Top1 Accuracy: 0.76028 Top5 Accuracy: 0.92796
> INT8Entropy Top1 Accuracy: 0.76404 Top5 Accuracy: 0.93042

The quantized models achieve almost the same accuracy, while running with much less computing power and lower latency. The difference between calibration methods depends on the model itself, the activation layers it uses, and the size of the calibration data, so the method should be chosen empirically. However, if you are not experienced with custom calibration methods, you should stick to entropy by default.

Custom layer collectors and calibrating the model

We designed the interface to give the user as much flexibility as possible in almost every step of quantization. One result of this approach is the custom LayerOutputCollector parameter of the quantization API mentioned earlier.

Layer collectors are responsible for collecting statistics for each node in the graph, which means we can observe the input/output data of every executed operator. This is possible thanks to the register_op_hook method of the HybridBlock class.

To write your own layer collector, your class has to inherit from the CalibrationCollector class, which we provide to keep the API consistent. As "code is worth a thousand words", below we show an example implementation of a CalibrationCollector:

We 'inject' the names of the nodes that require calibration into the include_layers instance attribute of the custom collector. It is organized this way because collecting statistics for every node is very time-consuming; this way you can skip nodes that do not need calibration (you can also ignore this attribute and implement your own logic for picking the nodes to calibrate).

After all statistics have been collected, the post_collect function is called. In post_collect you can implement additional logic that processes the gathered data and finally returns a dictionary with node names as keys and (min, max) threshold tuples as values, which will be used to calibrate the nodes.
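The collector life cycle described above (names injected into include_layers, a collect call per observed node, and post_collect returning a name-to-(min, max) dictionary) can be sketched without MXNet. The base class below is a stand-in written for illustration, and the `collect(name, op_name, arr)` signature is an assumption modeled on the description; check the CalibrationCollector definition in MXNet's contrib.quantization module before relying on it. In MXNet `arr` would be an array, whereas here it is a plain list.

```python
class CalibrationCollectorSketch:
    """Stand-in for MXNet's CalibrationCollector (illustrative only)."""

    def __init__(self):
        self.include_layers = None   # names injected by the quantization flow
        self.min_max_dict = {}

    def collect(self, name, op_name, arr):
        raise NotImplementedError    # called for each observed node output

    def post_collect(self):
        # Must return {node_name: (min_threshold, max_threshold)}
        return self.min_max_dict


class NaiveMinMaxCollector(CalibrationCollectorSketch):
    def collect(self, name, op_name, arr):
        if self.include_layers is not None and name not in self.include_layers:
            return                   # skip nodes that do not need calibration
        lo, hi = min(arr), max(arr)
        if name in self.min_max_dict:
            old_lo, old_hi = self.min_max_dict[name]
            lo, hi = min(lo, old_lo), max(hi, old_hi)
        self.min_max_dict[name] = (lo, hi)

# Simulate the hook firing on two batches; only included layers are recorded
collector = NaiveMinMaxCollector()
collector.include_layers = {"conv0_output"}
collector.collect("conv0_output", "Convolution", [1.0, -2.0, 3.0])
collector.collect("conv0_output", "Convolution", [5.0, 0.0])
collector.collect("pool0_output", "Pooling", [100.0])
print(collector.post_collect())  # {'conv0_output': (-2.0, 5.0)}
```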

Example of usage with the quantization API:

Writing your own collector logic is not easy, but if you know what you are doing, you can get better accuracy. For the previous version of MXNet and GluonNLP, our colleagues wrote a layer collector for BERT models, BertLayerCollector, which clips the minimum/maximum values of some layers to get better accuracy (note that it is not compliant with MXNet 2.0; you can refer to the predefined collectors in the MXNet 2.0 code).
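A clipping step of the kind described above could live in a custom collector's post_collect. The helper below is a hypothetical pure-Python sketch, not BertLayerCollector's actual logic: it narrows the collected (min, max) range of every layer whose name contains a given pattern, and the layer names and bounds in the example are made up.

```python
def clip_thresholds(min_max_dict, clip_rules):
    """Return a copy of min_max_dict with selected layers' ranges clipped.

    clip_rules maps a substring of a layer name to a (min, max) bound;
    the first matching rule is applied, other layers are left unchanged.
    """
    clipped = {}
    for name, (lo, hi) in min_max_dict.items():
        for pattern, (clip_lo, clip_hi) in clip_rules.items():
            if pattern in name:
                lo, hi = max(lo, clip_lo), min(hi, clip_hi)
                break
        clipped[name] = (lo, hi)
    return clipped

collected = {"attention_out": (-50.0, 40.0), "ffn_gelu": (-0.2, 12.0)}
print(clip_thresholds(collected, {"attention": (-10.0, 10.0)}))
# {'attention_out': (-10.0, 10.0), 'ffn_gelu': (-0.2, 12.0)}
```

A narrower range spends the 8-bit levels on the values that actually matter, at the cost of saturating rare extremes, which is the trade-off such clipping exploits.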

Performance and accuracy results

Here we present performance results for several CV models. We compare fused and quantized models to base models hybridized with static memory allocation.

Relative Inference Performance (img/sec) for Batch Size 128

As you can see, the popular ResNet50 gains a 62% speedup from operator fusion alone, and when quantized it is more than 6x faster than the base fp32 model. We can also observe that MobileNet benefits the most from the presented optimizations, mainly thanks to efficient fusion. Moreover, in the quantized version, ReLU6 (the activation function used in MobileNet) can be implemented just by adjusting the scale factors of the int8 computation, which lets us optimize it even further.
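The ReLU6 remark can be illustrated as follows, under our own assumption about the mechanism: if a layer's output is stored as u8 with its range calibrated to [0, 6], the saturating conversion clamps negative values to 0 and values above 6 to 255, which is exactly ReLU6 up to quantization error, so no separate activation kernel is needed. The function names below are hypothetical.

```python
def relu6(x):
    return min(max(x, 0.0), 6.0)

def quantize_u8_with_range(xs, hi=6.0):
    # Scale chosen so that `hi` maps to 255; saturation does the clamping
    scale = 255 / hi
    return [min(255, max(0, round(x * scale))) for x in xs], scale

xs = [-3.0, 0.0, 2.5, 6.0, 9.0]
q, s = quantize_u8_with_range(xs)
print(q[0], q[-1])  # 0 255
print([round(v / s, 3) for v in q])  # approximately relu6 of each input
```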

ImageNet (ILSVRC2012) Top-1 validation accuracy

These are the Top-1 ImageNet accuracy results for computer vision models. As you can see, by calling a single function you can gain a significant performance boost with almost no loss of accuracy. The biggest accuracy drop in the chart is for MobileNet v2, but it is still only half a percentage point below the fp32 model. On the other hand, VGG19 loses no accuracy at all and, after quantization, is over 4 times faster than the baseline model.

Summary

In this article, we showed how you can speed up your MXNet model on CPU with oneDNN support. If you can accept a tiny accuracy drop, model quantization is a great way to significantly improve performance. Otherwise, you can still use operator fusion backed by oneDNN primitives without any side effects. In that case, the performance boost is smaller than with a quantized model, but accuracy does not change.

Moreover, we described how you can use the new CalibrationCollector class to gain better control over the calibration process. This gives you flexibility and a chance to increase the accuracy of your quantized model.

Many things have changed with the introduction of MXNet 2.0. The Gluon API has taken over from Executor-based execution, so users now have to call the optimize_for function manually to benefit from CPU optimization passes. We have also introduced a new BatchNorm + ReLU fusion to address ResNet v2 models. We plan to add support for more int8 models, not only in the CV domain. We are also exploring new fusion opportunities to speed up popular models even further.

References

[1] Devlin, Jacob, et al. "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." arXiv preprint arXiv:1810.04805 (2018).
[2] Nagasundaram, Banu. "Vector Neural Network Instructions Enable Int8 AI Inference on Intel Architecture." https://www.intel.ai/vnni-enables-inference/ (2019).
[3] Intel® oneDNN library. https://github.com/oneapi-src/oneDNN

Benchmark environment

CPU: Intel® Xeon® Platinum 8280L CPU @ 2.70GHz
Memory: 187 GB RAM
OS: CentOS Linux 7 (Core)
Compiler: gcc (GCC) 7.3.1 20180303 (Red Hat 7.3.1–5)
MXNet Commit SHA: 3746babc8fdb211584a9a661207061cb646b01a8
oneDNN Commit SHA: 2e4732679f0211bb311780d0f383cf2dce9baca7

Notices and Disclaimers

© Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.

Performance varies by use, configuration and other factors. Learn more at www.Intel.com/PerformanceIndex.
