Speed up your BERT inference by 3x on CPUs using Apache TVM

Haichen Shen
Published in Apache MXNet · Jul 8, 2020

Introduction

BERT (Bidirectional Encoder Representations from Transformers) [1], a pre-trained natural language processing (NLP) model, was proposed by Google in 2018 and now plays an important role in NLP tasks such as text classification, sentiment analysis, question answering, and more. However, BERT is known to be compute intensive even for inference. As production pipelines shift from lightweight models such as LSTMs to BERT, BERT inference performance is becoming increasingly critical to reaching the desired latency and scalability.

Recently, deep learning compilers such as TVM [2], MLIR [3], and Glow [4], which use compiler techniques to optimize model inference, have gained traction in both academia and industry. Prior work [2, 5] has demonstrated that deep learning compilers can achieve lower inference latency than deep learning frameworks. At Amazon Web Services, we contribute to the Apache TVM open source project and use TVM to speed up many model inference use cases, including BERT, on various platforms. We also offer a service, Amazon SageMaker Neo, that provides a managed compilation experience across a variety of frameworks, operators, and hardware targets.

In this blog, we are going to share our recent progress on improving BERT inference performance on CPUs (e.g., c5 and m5 instances on Amazon EC2) and show you how to use TVM to reproduce our results. Overall, using TVM can help achieve up to 2.9x lower latency on EC2 c5.9xlarge instances and up to 2.3x higher throughput. More importantly, achieving these results requires almost no developer effort as we will demonstrate in this post.

BERT inference improvement

We will use the BERT-base [1] and DistilBERT [6] models to demonstrate our improvement of BERT inference performance using TVM. The BERT-base model contains 12 layers of transformer blocks, amounting to 11.2 GFLOPs (billions of floating-point operations) and 109M weights. DistilBERT is a distilled, smaller version of BERT with only 6 layers of transformer blocks; it contains 5.6 GFLOPs and 67M weights, and thus runs faster during inference. The computation counts are measured at batch size 1 and sequence length 128.

We take the pre-trained BERT-base and DistilBERT models from the GluonNLP (0.9.1) model zoo. The data type is fp32 in all benchmarks, as Intel CPUs do not support fp16 operations. We measure the inference latency for sequence lengths 64, 128, and 256 at batch size 1, comparing Apache MXNet (mkl 1.6.0) and Apache TVM on an EC2 c5.9xlarge instance (an Intel CPU with 18 physical cores). The table below shows that TVM optimization reduces latency by 2.1x to 2.9x for the BERT-base and DistilBERT models. Notably, the latency of DistilBERT at sequence length 128 is only 9.5ms on CPUs after optimization. In comparison, the ONNX runtime achieves 9ms on similar CPUs using a 3-layer BERT model, which is 2x smaller than DistilBERT.

We further evaluate the throughput of BERT with batch size 4 and sequence length 128 on more types of EC2 instances, including c5.2xlarge, c5.9xlarge, m5.2xlarge, and m5.12xlarge. TVM consistently achieves higher throughput for the BERT base and DistilBERT models with an average of 2x improvement.

The performance gain comes from three aspects: (1) small operators get fused together to reduce overhead in memory transfer between cache and main memory, (2) kernels generated by TVM achieve better performance in general, and (3) TVM performs graph-level optimizations and replaces some heavyweight math operators such as erf with an approximate implementation.
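
To illustrate the third point, the sketch below shows how erf can be replaced with a cheap polynomial approximation whose error stays far below fp32 model tolerances. This is our own illustration using the classic Abramowitz & Stegun formula; TVM’s FastMath pass applies its own rewrite, so this is not the exact implementation it generates.

import numpy as np
from scipy.special import erf

def approx_erf(x):
    # Polynomial approximation of erf (Abramowitz & Stegun 7.1.26);
    # maximum absolute error is about 1.5e-7
    sign = np.sign(x)
    x = np.abs(x)
    t = 1.0 / (1.0 + 0.3275911 * x)
    poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741
           + t * (-1.453152027 + t * 1.061405429))))
    return sign * (1.0 - poly * np.exp(-x * x))

x = np.linspace(-4, 4, 10001)
print(np.max(np.abs(erf(x) - approx_erf(x))))  # on the order of 1e-7

Because the approximation error is orders of magnitude below the 1e-3 tolerance we use to verify the optimized model later in this post, the rewrite is effectively free in terms of accuracy.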

How can you optimize BERT using TVM?

Next, let’s walk through the steps to reproduce the results shown above.

We first launch a CPU instance, such as c5.9xlarge, on Amazon EC2 using the AWS Deep Learning AMI (Ubuntu 18.04). After we ssh into the server, we activate the conda environment mxnet_p36 and install the GluonNLP Python package:

source activate mxnet_p36
pip install gluonnlp==0.9.1

Next, we install TVM:

pip install https://tvm-build-public.s3-us-west-2.amazonaws.com/dlami-cpu-mkl/tvm-0.7.dev1-cp36-cp36m-linux_x86_64.whl
pip install https://tvm-build-public.s3-us-west-2.amazonaws.com/dlami-cpu-mkl/topi-0.7.dev1-py3-none-any.whl

These Python wheel packages are compiled for the Deep Learning AMI, and are not guaranteed to be compatible with other environments. You can also compile TVM from source (instructions in the appendix).
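
As a quick sanity check (not part of the original install steps), you can verify that the wheels import correctly and report the expected development version:

import tvm
import topi

# Expect a 0.7 development version such as 0.7.dev1
print(tvm.__version__)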

After we’ve installed TVM, we can optimize the BERT model. You can find the steps to train and fine-tune a BERT model using GluonNLP in the tutorials on the GluonNLP website. Once we’ve instantiated a BERT model from GluonNLP, we can compile and optimize it using TVM.
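
For reference, a minimal sketch of instantiating a pre-trained BERT-base classifier from the GluonNLP model zoo and preparing dummy inputs might look like the following. The dataset name, the untrained two-class head, and the random dummy inputs are illustrative choices for benchmarking, not requirements:

import mxnet as mx
import numpy as np
import gluonnlp as nlp

batch, seq_length = 1, 128

# Load the pre-trained BERT-base encoder from the GluonNLP model zoo
bert, vocab = nlp.model.get_model(
    'bert_12_768_12',
    dataset_name='book_corpus_wiki_en_uncased',
    pretrained=True,
    use_pooler=True,
    use_decoder=False,
    use_classifier=False)

# Wrap the encoder with a classification head (left untrained here)
mx_model = nlp.model.BERTClassifier(bert, num_classes=2, dropout=0.1)
mx_model.classifier.initialize(ctx=mx.cpu())

# Dummy inputs: token ids, token types, and valid lengths
inputs = mx.nd.array(np.random.randint(0, len(vocab), (batch, seq_length)))
token_types = mx.nd.zeros((batch, seq_length))
valid_length = mx.nd.array([seq_length] * batch)

# Reference output from MXNet, used later to verify the TVM result
mx_out = mx_model(inputs, token_types, valid_length)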

We first need to convert the BERT model from MXNet to TVM Relay IR, given a mapping from input names to shapes. Relay IR is an intermediate representation used in TVM to represent the computation graph of the model.

# Load the model from GluonNLP (see the sketch above)
...
mx_model = nlp.model.BERTClassifier(...)

# Convert the MXNet model to TVM Relay IR
from tvm import relay

shape_dict = {
    'data0': (batch, seq_length),  # token ids
    'data1': (batch, seq_length),  # token types
    'data2': (batch,)              # valid lengths
}
mod, params = relay.frontend.from_mxnet(mx_model, shape_dict)

We then define the target and optimization level used to compile the model with TVM. The CPUs in both c5 and m5 instances support AVX-512 instructions, which accelerate floating-point computation via vectorization, so we specify the CPU architecture (-mcpu=skylake-avx512) in the target to enable them. We also utilize the Intel MKL library (-libs=cblas), an optimized BLAS library for Intel CPUs, to speed up operators such as matrix multiplication. Finally, we explicitly turn on an optimization pass called “FastMath”, which replaces certain heavyweight math operators such as erf with approximate implementations to reduce latency.

target = "llvm -mcpu=skylake-avx512 -libs=cblas"
with relay.build_config(opt_level=3, required_pass=["FastMath"]):
graph, lib, cparams = relay.build(mod, target, params=params)

Finally, we create a lightweight graph executor included in TVM and initialize it with the weights and input data. We can then run the executor and examine the output. We also sanity-check the TVM output against the MXNet output to verify correctness.

import tvm
from tvm.contrib import graph_runtime as runtime

ctx = tvm.cpu()
rt = runtime.create(graph, lib, ctx)
rt.set_input(**cparams)
rt.set_input(data0=inputs, data1=token_types, data2=valid_length)
rt.run()
out = rt.get_output(0)
print(out.asnumpy())
# Verify correctness against the MXNet output (mx_out)
tvm.testing.assert_allclose(out.asnumpy(), mx_out.asnumpy(), rtol=1e-3, atol=1e-3)
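
To reproduce the latency numbers reported above, you can time repeated calls to the executor. The wall-clock loop below is our own sketch (TVM also provides a built-in time evaluator); on CPU, rt.run() executes synchronously, so simple timing is sufficient:

import time

# Warm up before measuring
for _ in range(10):
    rt.run()

n_repeat = 100
start = time.time()
for _ in range(n_repeat):
    rt.run()
elapsed = (time.time() - start) / n_repeat
print("Average latency: %.2f ms" % (elapsed * 1000))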

The full script can be found at https://gist.github.com/icemelon9/860d3d2c9566d6f69fa8112840dd95c1.

Conclusion

In summary, we are seeing significant improvements in BERT inference using TVM on CPUs. The latency of BERT inference is reduced by up to 2.9x and the throughput is increased by up to 2.3x. It takes only a few lines of code to achieve these improvements and deploy the model.

Today, this solution works well in scenarios where padding to fixed sequence length is possible. We are currently working on expanding the support to include dynamic sequence length. Stay tuned for more news.

Appendix

Compile TVM from source

This document on the TVM website provides instructions on how to compile and install TVM from source. Here we only describe the specific dependencies and configuration needed to achieve the best performance for BERT inference.

First, you need to install two dependencies in addition to the build requirements listed in the TVM document: LLVM (>= 6.0) and the Intel MKL library (instructions can be found here).

Next, we can compile the source and build the shared library.

git clone --recursive https://github.com/apache/incubator-tvm.git tvm
cd tvm && mkdir build && cd build
cmake -DUSE_LLVM=/path/to/llvm-config -DUSE_BLAS=mkl -DUSE_OPENMP=intel ..
make -j4

Finally, we install the TVM python package.

cd python; python setup.py install --user; cd ..
cd topi/python; python setup.py install --user; cd ../..

References

[1] Jacob Devlin, et al. “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.” arXiv preprint arXiv:1810.04805, 2018.
[2] Tianqi Chen, et al. “TVM: An Automated End-to-End Optimizing Compiler for Deep Learning.” 13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18), 2018.
[3] Chris Lattner, et al. “MLIR: A Compiler Infrastructure for the End of Moore’s Law.” arXiv preprint arXiv:2002.11054, 2020.
[4] Nadav Rotem, et al. “Glow: Graph Lowering Compiler Techniques for Neural Networks.” arXiv preprint arXiv:1805.00907, 2018.
[5] Yizhi Liu, et al. “Optimizing CNN Model Inference on CPUs.” 2019 USENIX Annual Technical Conference (USENIX ATC 19), 2019.
[6] Victor Sanh, et al. “DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter.” arXiv preprint arXiv:1910.01108, 2019.
