From Three Hours to 25 Minutes: Our Journey of Optimizing Mask R-CNN Training Time Using Apache MXNet

Lin Yuan
Published in Apache MXNet
Apr 1, 2020 (12 min read)

Written by Karan Jariwala, Jerry Zhang, Lin Yuan and Omar Orqueda (Amazon) and Przemek Tredak (NVIDIA)

Thanks also to Sandeep Krishnamurthy, Andrea Olgiati and Zdravko Pantic (Amazon) and Triston Cao (NVIDIA) for their feedback.

Overview

The size and complexity of deep neural network (DNN) models have increased drastically in the past few years. On the one hand, these sophisticated models have raised the bar for accuracy in domains such as computer vision and natural language understanding. On the other hand, the sheer number of model parameters and gradients makes these models challenging to train on existing compute infrastructure. For example, the BERT model that drew a lot of attention in the NLP domain contains over 330 million parameters, while Mask R-CNN, which achieves state-of-the-art performance on instance segmentation, has 51 million parameters.

One practical approach is to scale training horizontally using data-parallel distributed training. However, keeping scaling efficiency close to linear as the number of workers increases is not trivial. The challenges include: (1) communication overhead becomes significant in large training clusters; (2) load balancing across workers becomes harder in synchronous stochastic gradient descent (SGD) training; (3) large batch sizes can lower accuracy and require more sophisticated hyperparameter tuning; (4) validation between epochs needs to be distributed, otherwise it becomes a bottleneck; and (5) training needs to exploit the network topology of the cluster.

Given the task of minimizing the total training time of large DNNs while meeting a target accuracy, we needed to address all of the challenges above. This blog post records our journey of reducing the training time of Mask R-CNN from three hours to 25 minutes using the Apache MXNet framework on the AWS cloud.

Where did we start

We began training Mask R-CNN using Apache MXNet v1.5 together with the Horovod distributed training library on four Amazon EC2 P3dn.24xlarge instances, the most powerful GPU instances on AWS. These instances provide up to 100 Gbps of networking throughput, 96 custom Intel Xeon Scalable (Skylake) vCPUs, eight NVIDIA V100 Tensor Core GPUs with 32 GB of memory each, and 1.8 TB of local NVMe-based SSD storage.

The GluonCV model zoo contains numerous pre-trained models for various computer vision applications, and we took the Mask R-CNN network from it. We trained the model on the MS-COCO 2017 dataset, with target accuracies taken from the MLPerf benchmark:

  • Box min AP: 0.377
  • Mask min AP: 0.339

On four instances with a global batch size of 32, training took more than four hours, and we could not meet the target accuracy on the test data (our box min AP was 0.357 and our mask min AP was 0.324). This training time did not include validation, because distributed validation had not been implemented yet; instead, we saved the model at the end of training and ran validation once to get the accuracy metrics. We decided to first improve training time and accuracy on single-node and four-node clusters, and then increase the cluster size to 24 nodes and improve scaling performance.

What challenges we solved

There were two problems we needed to solve before we could start optimizing training time. The first was accuracy: we needed to reach the target accuracy, even at the cost of longer training time. To do that, we scaled training down to one node with a smaller batch size (32 samples) and tuned hyperparameters to help the model converge. We achieved a bounding box min AP of 0.379 and a mask min AP of 0.343, both above the targets. The second was network bandwidth: we measured the bandwidth in the cluster using the nccl-tests scripts and iftop, and identified the nodes that provided the fastest inter-node and inter-GPU communication.

We investigated three areas to reduce total training time. The first was NVIDIA GPU performance: we profiled training on a single GPU using NVIDIA Nsight Systems. Figure 1 below shows the percentage of runtime spent on each operation; we targeted the operations with the largest share and found ways to optimize them. The second was scaling efficiency: we measured throughput as we scaled up the number of training instances and noticed a drop in scaling efficiency beyond four nodes. We ran horovod_timeline to understand the time spent on communication, and analyzed network bandwidth utilization as well as the overlap between GPU computation and inter-GPU communication in each iteration. The third was data loading: we checked whether the GPUs sat idle waiting for data and whether data loading throughput was optimal.

Fig. 1: CUDA profiling result when training Mask R-CNN on a single GPU before and after our optimization.

Performance optimization is an iterative process

Our performance optimization was not a one-shot effort; it took many iterations of trial and correction. With a tight deadline, we had to estimate the payoff of each candidate and pick the most promising experiments to try. This required intuition and good judgment, since early estimates could differ greatly from the actual results. As training time came down and the solution space shrank, our estimates became more accurate. To speed up the work, we also split it into parallel tracks: one group of engineers focused on optimizing runtime on a single NVIDIA GPU, while another group worked on scaling efficiency on a large cluster. Several applied scientists also focused on improving model accuracy using techniques such as optimizer enhancements, hyperparameter tuning, and gradient accumulation.

Fig. 2: Iterative optimization process

We worked in two-week sprints. In each sprint, we built a priority queue of all the optimization techniques we could employ, guesstimated the percentage improvement each technique would bring, and picked a few to implement based on both impact (estimated percentage improvement) and feasibility. At the end of each sprint, we measured the actual gain from each technique we implemented. Comparing actual with expected improvements gave us a better understanding of the overall performance bottleneck, which made our next estimates more accurate. We then re-estimated the improvement for each remaining technique in the queue, added new techniques as needed, and repeated the process in the next sprint.
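The prioritization loop described above can be sketched with Python's heapq module. The post does not give the exact scoring rule, so ranking by estimated gain multiplied by a feasibility factor is an assumption, as is the crude re-estimation rule; the technique names and numbers in the example are placeholders, not measured data.

```python
import heapq


def plan_sprint(candidates, slots):
    """Pick the `slots` most promising techniques for the next sprint.

    candidates: iterable of (name, estimated_gain_pct, feasibility) tuples,
    where feasibility is a hypothetical 0-1 factor for how likely the work
    is to land within the sprint.
    """
    # heapq is a min-heap, so negate the score to pop the best candidate first.
    heap = [(-gain * feasibility, name) for name, gain, feasibility in candidates]
    heapq.heapify(heap)
    return [heapq.heappop(heap)[1] for _ in range(min(slots, len(heap)))]


def reestimate(candidates, observed_ratio):
    """Scale remaining gain estimates by how measured gains compared with estimates.

    observed_ratio = (sum of measured gains) / (sum of estimated gains) from the
    last sprint; a hypothetical calibration rule.
    """
    return [(name, gain * observed_ratio, feas) for name, gain, feas in candidates]


# Hypothetical backlog: names and numbers are placeholders, not measured data.
backlog = [
    ("technique_a", 5.0, 0.9),
    ("technique_b", 15.0, 0.2),
    ("technique_c", 10.0, 0.8),
]
print(plan_sprint(backlog, slots=2))  # ['technique_c', 'technique_a']
```

After each sprint, the remaining backlog would be passed through `reestimate` (or re-scored by hand) before calling `plan_sprint` again.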

As we got closer to the optimal training time, we became much more confident in our estimates for each technique, while the solution space for further optimization kept shrinking. We shortened the sprint from two weeks to one week to iterate faster. As successive iterations integrated the optimization techniques described below, total training plus validation time dropped from three hours to 35 minutes, then to 31.8 minutes, and finally to 25 minutes, which was the fastest on AWS compared with PyTorch (26 minutes) and TensorFlow (27 minutes).

What worked and what did not

We experimented with various techniques trying to achieve the fastest training time on 24 nodes while meeting the target accuracy. Some of them worked and gave us the expected performance boosts while others did not perform as well as expected. In this section, we list the techniques that helped us to achieve our performance goal and also those that did not work.

GPU Performance Optimization

For performance optimization on a single NVIDIA GPU, we identified several large kernels and applied optimization techniques to reduce their compute time. Specifically:

  • Remove redundant mask prediction: The original implementation used both positive and negative samples in the box regression and mask prediction branches. This is inefficient, because only positive samples are used in the loss calculation. Although the number of positive samples varies at runtime, it is capped at 25% of the total samples. We could therefore sort the samples so that positives come first and slice out the first 25%, discarding the remaining 75%, which are negatives that would not be used in the box regression and mask prediction losses. This yielded a 20% speed improvement.
  • Optimize RoIAlign operator: The standard RoIAlign operator does not handle invalid RoIs. Invalid RoIs arise from distributing RoIs to different feature maps based on their size, a step that naturally produces input tensors with dynamic shapes. We converted this dynamic operation into a static one by using a tensor of static shape and masking out the unused entries, so that we could leverage MXNet hybridization. The masked entries are the invalid RoIs, so our RoIAlign needed to skip them to work efficiently. In practice, we also sorted the RoIs by size so that masked RoIs formed consecutive groups that the GPU could skip efficiently.
  • Optimize NMS operator: Non-maximum suppression (NMS) filters the detection proposals down to only the best ones. The NMS algorithm is as follows:
      valid proposals = all proposals
      until the set of valid proposals is empty:
          take the valid proposal `p` with the highest score
          compare it with all the other valid proposals
          remove the proposals that share the class of `p` and have a large enough intersection with `p`
          move `p` to the result proposals

A direct implementation of this algorithm exposes only limited parallelism, because each step compares all candidate proposals against a single proposal to determine which ones are still valid. In Mask R-CNN there are not enough proposals to fully occupy the GPU this way. To overcome this, we employ a speculative algorithm that computes comparisons against multiple proposals at once, some of them unnecessarily, to increase the parallelism of the task and fill the entire GPU.

  • Fuse Box Encoder and Decoder: The box encoder transforms bounding box coordinates into offsets relative to anchor boxes, which are easier for the model to learn. The box decoder transforms the offsets predicted by the model back into bounding box coordinates. Both blocks are composed of multiple operators, which introduced a non-trivial amount of kernel launch overhead. We fused these operators together to reduce the overhead, which improved throughput by 5%.
  • Mixed precision: We used Automatic Mixed Precision (AMP), which improved runtime by 10%. Casting gradients to FP16 for communication improved throughput by another 2%.
  • Hybridize model: MXNet lets users construct models in an imperative programming style, which is easier to test and debug. It also provides a hybridize API that converts the model into a symbolic graph for better performance. Hybridizing the model gave a 5% performance improvement, and enabling the static_alloc option in the hybridize API yielded another 1% throughput improvement.
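The positive-first slicing from the "Remove redundant mask prediction" bullet can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the GluonCV code: per-RoI labels use 0 for background and positive integers for foreground classes, the sampler caps positives at 25%, and the function name is hypothetical. The key property is that the slice length depends only on the number of sampled RoIs, not on how many are positive, so downstream tensors keep a static shape.

```python
def select_rois_for_mask_branch(labels, pos_fraction=0.25):
    """Return a fixed-size list of RoI indices that contains every positive RoI.

    labels: per-RoI class labels; 0 = background (negative), > 0 = foreground.
    Assumes the sampler caps positives at `pos_fraction` of all RoIs.
    """
    num_keep = int(len(labels) * pos_fraction)
    num_pos = sum(1 for label in labels if label > 0)
    if num_pos > num_keep:
        raise ValueError("sampler produced more positives than the slice can hold")
    # Python's sort is stable: positives (key False) move to the front, and
    # relative order within positives and within negatives is preserved.
    order = sorted(range(len(labels)), key=lambda i: labels[i] <= 0)
    # The slice length depends only on len(labels), not on num_pos,
    # so downstream tensors keep a static shape.
    return order[:num_keep]


print(select_rois_for_mask_branch([0, 3, 0, 0, 1, 0, 0, 0]))  # [1, 4]
```

When a batch has fewer positives than the slice holds, the remaining slots are filled with negatives, which the loss would then ignore.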
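The sequential NMS algorithm listed in the "Optimize NMS operator" bullet translates directly into Python. This sketch is the plain greedy baseline, not the speculative GPU algorithm; the (score, class_id, box) tuple layout, corner-format boxes, and the 0.5 IoU threshold are assumptions chosen for illustration.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(proposals, iou_threshold=0.5):
    """Class-aware greedy NMS over (score, class_id, box) tuples."""
    valid = sorted(proposals, key=lambda p: p[0], reverse=True)
    result = []
    while valid:
        best = valid.pop(0)  # highest-scoring remaining proposal
        result.append(best)
        # Each pass depends on the survivors of the previous one, which is
        # why this loop exposes little parallelism on a GPU.
        valid = [
            q for q in valid
            if q[1] != best[1] or iou(q[2], best[2]) <= iou_threshold
        ]
    return result


proposals = [
    (0.9, 1, (0, 0, 10, 10)),
    (0.8, 1, (1, 1, 10, 10)),    # IoU 0.81 with the first box, same class: removed
    (0.7, 2, (1, 1, 10, 10)),    # different class: kept
    (0.6, 1, (20, 20, 30, 30)),  # no overlap: kept
]
print([p[0] for p in nms(proposals)])  # [0.9, 0.7, 0.6]
```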
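The box encoder and decoder from the "Fuse Box Encoder and Decoder" bullet can each be written as one function, which is conceptually what fusing their operators achieves: one pass over the data instead of one kernel per elementwise step. The sketch assumes the widely used center-offset parameterization with corner-format boxes and normalization constants `stds=(0.1, 0.1, 0.2, 0.2)`; GluonCV's exact constants and conventions may differ, and the function names are hypothetical.

```python
import math

STDS = (0.1, 0.1, 0.2, 0.2)  # assumed normalization constants


def _center_form(box):
    """Convert (x1, y1, x2, y2) into (cx, cy, w, h)."""
    x1, y1, x2, y2 = box
    return (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1


def encode(box, anchor, stds=STDS):
    """Box coordinates -> offsets relative to an anchor (training targets)."""
    gx, gy, gw, gh = _center_form(box)
    ax, ay, aw, ah = _center_form(anchor)
    return (
        (gx - ax) / aw / stds[0],
        (gy - ay) / ah / stds[1],
        math.log(gw / aw) / stds[2],
        math.log(gh / ah) / stds[3],
    )


def decode(offsets, anchor, stds=STDS):
    """Predicted offsets -> box coordinates (inverse of encode)."""
    dx, dy, dw, dh = offsets
    ax, ay, aw, ah = _center_form(anchor)
    cx = ax + dx * stds[0] * aw
    cy = ay + dy * stds[1] * ah
    w = aw * math.exp(dw * stds[2])
    h = ah * math.exp(dh * stds[3])
    return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2


anchor = (0.0, 0.0, 10.0, 10.0)
box = (2.0, 2.0, 14.0, 12.0)
offsets = encode(box, anchor)
restored = decode(offsets, anchor)
print([round(v, 6) for v in restored])  # [2.0, 2.0, 14.0, 12.0]
```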

Scaling Efficiency

  • Elastic Fabric Adapter: We used Amazon EC2 P3dn.24xlarge instances, which provide 100 Gbps of network bandwidth and the new Elastic Fabric Adapter (EFA) network interface for highly scalable inter-node communication. The higher inter-node bandwidth let us train the model on more nodes, which lowered training time.
  • Physical CPU bind: The Amazon EC2 P3dn.24xlarge has 96 custom Intel Xeon Scalable (Skylake) vCPUs and eight NVIDIA V100 Tensor Core GPUs with 32 GB of memory each. We bound each GPU to 12 vCPUs (six from the first CPU and six from the second CPU) using the numactl --physcpubind command. numactl is a utility that controls the NUMA (Non-Uniform Memory Access) policy for processes or shared memory. NUMA is a memory architecture in which each processor has a memory region attached directly to it for fast access (local memory), while accessing other regions is slower (non-local memory). Binding each GPU to 12 vCPUs improved throughput by approximately 8% without affecting the accuracy metrics.
  • HOROVOD_NUM_NCCL_STREAMS=2: This sets the number of streams used for NCCL operations. With two NCCL streams, we saw a 2% throughput improvement with a batch size of one per GPU on 24 P3dn.24xlarge instances. MXNet performs a combination of FP32 and FP16 reductions, and a second NCCL stream lets the final FP32 and FP16 reductions run concurrently, which reduces overhead.
  • NCCL_TREE_THRESHOLD=4294967296: In the two-node case, NCCL uses the tree algorithm for all message sizes by default, which gives a 2x improvement in theoretical bandwidth over the ring algorithm. This 2x factor does not hold for more than two nodes, so for larger clusters NCCL calculates a tree threshold and switches back to the ring algorithm for messages larger than that threshold, which can reduce throughput. By setting the tree threshold to 4294967296 bytes (4 GiB), we kept NCCL on the tree algorithm instead of ring and observed a 4-5% throughput improvement.
  • Distributed validation: Distributing validation across workers significantly reduced validation time. Validation takes 13 seconds per epoch on 24 P3dn.24xlarge instances, compared with one to two minutes without distributed validation.
  • Multi-image per device and aspect ratio grouping: We implemented support for multiple images per device and for aspect ratio grouping (batching images with similar aspect ratios together to reduce padding) to speed up training on single-node and multi-node clusters. Users can now train the Mask R-CNN model with a batch size of one, two, or four per GPU. Note that batch sizes greater than four per GPU may cause GPU out-of-memory errors, depending on the GPU memory.
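The per-GPU CPU binding and the two environment variables above can be combined in a small launcher helper. This is a sketch under loud assumptions: it pretends that vCPUs 0-47 belong to the first socket and 48-95 to the second (check the real layout with `lscpu` or `numactl --hardware`), and the training script name is hypothetical. It only builds the environment and command line; how each process learns its local rank depends on the MPI or Horovod launcher in use.

```python
NCCL_ENV = {
    # Two NCCL streams so the FP32 and FP16 reductions can overlap.
    "HOROVOD_NUM_NCCL_STREAMS": "2",
    # 2**32 bytes (4 GiB): larger than any gradient message in this workload,
    # so NCCL stays on the tree algorithm.
    "NCCL_TREE_THRESHOLD": str(2**32),
}


def cpu_list_for_gpu(local_rank, vcpus_per_gpu=12, vcpus_per_socket=48):
    """Six vCPUs from each socket per GPU, under the assumed vCPU numbering."""
    half = vcpus_per_gpu // 2
    first = local_rank * half                      # slice on socket 0
    second = vcpus_per_socket + local_rank * half  # matching slice on socket 1
    return f"{first}-{first + half - 1},{second}-{second + half - 1}"


def launch_command(local_rank, script="train_mask_rcnn.py"):
    """Return (env, argv) for one training process; the script name is hypothetical."""
    argv = ["numactl", f"--physcpubind={cpu_list_for_gpu(local_rank)}", "python", script]
    return dict(NCCL_ENV), argv


for rank in (0, 7):
    print(" ".join(launch_command(rank)[1]))
# numactl --physcpubind=0-5,48-53 python train_mask_rcnn.py
# numactl --physcpubind=42-47,90-95 python train_mask_rcnn.py
```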
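Distributed validation boils down to giving each worker a disjoint slice of the validation set and merging the per-worker results before scoring. The following is a minimal sketch of the sharding and merge step, assuming a strided split; in a real job the merge would be a collective such as an allgather, and the COCO evaluation itself is out of scope. Function names are hypothetical.

```python
def shard_indices(num_samples, rank, world_size):
    """Strided split: worker `rank` evaluates samples rank, rank + world_size, ..."""
    return list(range(rank, num_samples, world_size))


def merge_shards(shards):
    """Stand-in for gathering per-worker results; orders them by sample index."""
    return sorted(idx for shard in shards for idx in shard)


# MS-COCO 2017 val has 5,000 images; 24 nodes x 8 GPUs = 192 workers.
world_size = 192
shards = [shard_indices(5000, r, world_size) for r in range(world_size)]
print(len(shards[0]), len(shards[-1]))  # 27 26
```

Each worker therefore runs inference on only 26 or 27 images, and the merged list covers every validation image exactly once.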
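Aspect ratio grouping can be illustrated with a simple two-bucket scheme (landscape vs. portrait): images are only batched with others of the same orientation, so less padding is needed when several images share one device. This bucketing rule is an assumption, not necessarily what GluonCV implements; the function name is hypothetical, and the shuffling a real sampler would do is omitted.

```python
def group_by_aspect_ratio(sizes, batch_size):
    """Batch image indices so that each batch holds a single orientation.

    sizes: list of (width, height). Landscape means width >= height.
    Returns a list of index batches; leftover partial batches come last.
    """
    buckets = {True: [], False: []}
    batches = []
    for idx, (w, h) in enumerate(sizes):
        key = w >= h
        buckets[key].append(idx)
        if len(buckets[key]) == batch_size:
            batches.append(buckets[key])  # bucket is full: emit it as a batch
            buckets[key] = []
    for key in (True, False):
        if buckets[key]:
            batches.append(buckets[key])  # partial batch at the end of the epoch
    return batches


sizes = [(640, 480), (480, 640), (800, 600), (600, 800)]
print(group_by_aspect_ratio(sizes, 2))  # [[0, 2], [1, 3]]
```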

Convergence

When we trained Mask R-CNN with a large global batch size (e.g., larger than 128), convergence became an issue because the default hyperparameters caused divergence. We tuned the following hyperparameters to make sure the model converged correctly and efficiently.

  • Scaling learning rate: According to Goyal et al., the learning rate should scale linearly with the batch size. In our case, however, increasing the learning rate linearly caused instability. We therefore used a base learning rate of 0.01 on one P3dn.24xlarge instance with a batch size of 8, and 0.16 on 24 P3dn.24xlarge instances with a batch size of 192 (linear scaling would give 0.24).
  • Learning rate schedule: Instead of a standard learning rate schedule, we adopted a custom one. Because gradient magnitudes were very large in the initial training iterations, we decreased the initial learning rate and increased the number of warmup iterations before reaching the target learning rate. In the decay stage, we changed the scheduler to decay only once, at epoch 10, instead of twice, at epochs 8 and 10, as in the single-node case. Figure 3 below illustrates the difference between the learning rate schedules on a single node and on 24 nodes.
Fig. 3: Learning rate schedule in single node and 24 nodes
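The arithmetic behind the learning rate choice in the "Scaling learning rate" bullet is short enough to spell out. The helper below applies the linear scaling rule from Goyal et al. to the single-node setting (0.01 at a batch size of 8); the function name is hypothetical.

```python
def linear_scaled_lr(base_lr, base_batch, global_batch):
    """Linear scaling rule: the learning rate grows in proportion to batch size."""
    return base_lr * global_batch / base_batch


linear = linear_scaled_lr(0.01, 8, 192)
used = 0.16
print(f"linear rule: {linear:.2f}, used: {used:.2f}, ratio: {used / linear:.2f}")
# linear rule: 0.24, used: 0.16, ratio: 0.67
```

In other words, the 24-node run used about two thirds of the learning rate that strict linear scaling would suggest.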
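A schedule with the shape described in the "Learning rate schedule" bullet (linear warmup, then step decay) can be sketched as a pure function of epoch and iteration. The target learning rates and decay epochs come from the text; the warmup lengths, warmup start factors, the 0.1 decay factor, and the 1,000 iterations per epoch in the usage example are hypothetical placeholders, since the post does not state them.

```python
def learning_rate(epoch, iteration, iters_per_epoch, target_lr,
                  warmup_iters, warmup_start_factor, decay_epochs, decay_factor=0.1):
    """Linear warmup to target_lr, then multiply by decay_factor at each decay epoch."""
    step = epoch * iters_per_epoch + iteration
    if step < warmup_iters:
        # Ramp linearly from target_lr * warmup_start_factor up to target_lr.
        alpha = step / warmup_iters
        return target_lr * (warmup_start_factor + (1.0 - warmup_start_factor) * alpha)
    lr = target_lr
    for boundary in decay_epochs:
        if epoch >= boundary:
            lr *= decay_factor
    return lr


# Hypothetical warmup settings; target_lr and decay_epochs follow the text.
single_node = dict(target_lr=0.01, warmup_iters=1000,
                   warmup_start_factor=0.001, decay_epochs=(8, 10))
multi_node = dict(target_lr=0.16, warmup_iters=3000,
                  warmup_start_factor=0.0001, decay_epochs=(10,))

for epoch in (7, 8, 9, 10):
    print(epoch,
          round(learning_rate(epoch, 0, 1000, **single_node), 6),
          round(learning_rate(epoch, 0, 1000, **multi_node), 6))
```

The multi-node configuration starts from a smaller fraction of its target and warms up for longer, matching the direction of the changes described in the bullet.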

Data loader

  • Increase the number of data workers from four to eight: A data worker is a multiprocessing worker process used for data preprocessing. Increasing the number of workers can improve utilization of the CPU resources available for preprocessing, but too large a value may increase contention for those resources. Empirically, eight data workers gave us the best throughput.

What did not work

We also tried a few techniques that did not yield the performance improvements we expected:

  • Cythonize RPNTargetSampler
  • Hierarchical Allreduce
  • Dynamic batching
  • Pointwise fusion

Final results

By integrating all of the above optimization techniques, we achieved an approximately 7x improvement in end-to-end training time on a 24-node P3dn.24xlarge cluster: total runtime dropped from 175 minutes before optimization to 25.7 minutes after, with the same number of epochs and the same target accuracy. As of December 2019, this was the fastest training time for Mask R-CNN on the MS-COCO 2017 dataset on the AWS cloud.
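The "approximately 7x" figure is the ratio of the two end-to-end runtimes:

```latex
\text{speedup} = \frac{175\ \text{min}}{25.7\ \text{min}} \approx 6.8
```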

Tab. 1: Training time and accuracy of Mask R-CNN on MS-COCO 2017 dataset on different clusters

Fig. 4 below shows throughput before and after our optimization on 1, 4, and 24 nodes. Our single-instance optimizations improved throughput by 3.5X; combined with the multi-node optimization techniques, they improved throughput on 24 nodes by 4.8X.

Fig. 4: Throughput comparison before and after our optimization. On a single node, throughput improved by 3.5X; on 24 nodes, it improved by 4.8X. Notably, our four-node throughput after optimization is almost as high as our 24-node throughput before optimization.

Future Improvements

Given the time frame and available resources, we prioritized integrating the optimizations above. The following future improvements could increase training throughput even further:

  1. Implement NHWC layout for operators in Mask R-CNN
  2. Fuse the many small tail kernels to reduce overhead and improve training throughput

Summary

We reduced the training time of Mask R-CNN using Apache MXNet from three hours to 25 minutes on 24 Amazon EC2 P3dn.24xlarge instances, with optimization techniques that target both single-GPU performance and communication. Due to limited time, there are still other areas where we could further optimize runtime, such as NHWC layout conversion, better scheduling algorithms for communicating gradients, and more efficient data loaders on NVIDIA GPUs.

Our script to train Mask R-CNN on single and 24 nodes can be found at: https://github.com/dmlc/gluon-cv/tree/master/scripts/instance/mask_rcnn/benchmark

This work will also be presented at the 2020 GPU Technology Conference (GTC 2020). You can go to https://www.nvidia.com/en-us/gtc/session-catalog/?search=S22483 to find more information about this talk.
