Accelerating Deep Learning Recommender Systems by Over 15x Using RAPIDS, PyTorch and fast.ai

Even Oldridge
RAPIDS AI
Sep 16, 2019

This June the RAPIDS Deep Learning team took part in the RecSys 2019 Challenge, where we placed 15th out of 1,534 teams despite joining the competition in its final weeks. The competition, hosted by the hotel search website Trivago, centered around recommending hotel listings to its users. Given a list of impressions, the task was to rank them so that the clicked item appeared as close to the top as possible, with the final score determined by Mean Reciprocal Rank (MRR) on a holdout test set. Here’s an example of the data from the competition website:

Figure 1. RecSys Challenge 2019 Example Data

Our methodology relied on feature engineering, a stacked ensemble of models, and the fastai library’s tabular deep learning model, which was the main architecture we used for deep learning. This blog post, and the more technical paper we’re presenting at the RecSys Competition Workshop, focus on the training-time improvements we achieved by:

  1. accelerating preprocessing with cuDF (the RAPIDS open-source dataframe library),
  2. replacing the standard item-by-item dataloader with our custom PyTorch batch dataloader,
  3. scaling batch size to extremes, and
  4. improving the kernel used to compute embedding gradients in PyTorch.

In our final solution we sped up training of the fastai tabular model by a factor of 15.6x, from a workflow of 891.8s (14m52s) to 57.2s on the 40M row RecSys dataset. We also accelerated our feature generation by a factor of 9.7x, reducing the time taken to generate the features used in our model from 51 minutes to 5. With training time so drastically reduced, the iteration involved in generating new features and training new models is much more fluid. This allows for the rapid prototyping of deep learning based recommender systems in hours as opposed to days.

We’re very excited to apply the acceleration of RAPIDS to the fastai library. We’ve long been fans of Jeremy, Rachel, Sylvain and their vision of a more accessible deep learning education for anyone interested in the topic. Work is underway to integrate these changes directly into fastai v2, making it even easier to accelerate your tabular deep learning workflow with RAPIDS. For now you can use the basic_data methodology and try it out in our example repo. We’d love to hear from anyone who tries these methods on their own data. Now, let’s dive in!

Feature Engineering and Preprocessing

RAPIDS is becoming known for accelerating the feature engineering phase of tabular data workflows, in part through previous posts by our teammate Jiwei Liu, whom we were lucky enough to work with on the competition. Unsurprisingly, RAPIDS shines in this case as well, accelerating the feature creation phase of the project from over 50 minutes to just over 5 minutes. This 9.7x speedup is typical of the kind of dataframe acceleration RAPIDS provides through the use of cuDF instead of Pandas. The feature engineering code can be found in its own directory here.
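To give a flavour of the kind of operation being accelerated, here is a minimal sketch of a groupby-aggregate feature in cuDF. The column names (item_id, clickout_id) are illustrative placeholders rather than the competition's actual schema, and the same code runs on the CPU by swapping the cudf import for pandas.

```python
import cudf

# Illustrative only: column names are placeholders, not the competition's actual schema.
df = cudf.read_parquet("train.parquet")

# A typical popularity feature: how often each item appears in an impression list.
item_counts = (
    df.groupby("item_id")
      .agg({"clickout_id": "count"})
      .rename(columns={"clickout_id": "item_impression_count"})
      .reset_index()
)

# Join the aggregate back onto the interaction-level dataframe on the GPU.
df = df.merge(item_counts, on="item_id", how="left")
```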

Table 1. Feature Engineering Acceleration Using RAPIDS

Getting data ready for the model also involves dataframe operations that can be accelerated on the GPU. Categorical variables must be encoded into a numerical representation for embedding lookups, while null values in continuous variables are mapped to the median and a binary is_null variable is created to indicate rows where the value was replaced. New in RAPIDS 0.8 is the ability to encode categorical string variables using nvstrings. We were excited to try this out in the context of the fastai workflow.
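A minimal cuDF sketch of these preprocessing steps is shown below; the device and price columns are hypothetical stand-ins for real categorical and continuous features.

```python
import cudf

df = cudf.read_parquet("features.parquet")

# Encode a categorical column as contiguous integer codes for embedding lookups.
df["device"] = df["device"].astype("category").cat.codes

# Fill continuous nulls with the median and keep an indicator of where filling happened.
df["price_is_null"] = df["price"].isnull()
df["price"] = df["price"].fillna(df["price"].median())
```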

Fastai performs preprocessing steps during the creation of the databunch, which aggregates the training, validation, and optionally test dataloaders into a single object. The original workflow uses Pandas before eventually converting to NumPy, stacking the categorical and continuous variables separately and converting them to a LongTensor and a FloatTensor respectively. The creation of the RecSys databunch using the original workflow took 397s (6m37s).
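For reference, the original CPU-bound workflow looks roughly like the standard fastai v1 tabular setup below; the dataframe, column lists, validation indices, and batch size are placeholders.

```python
from fastai.tabular import TabularList, FillMissing, Categorify, Normalize

procs = [FillMissing, Categorify, Normalize]  # pandas-based preprocessing on the CPU

data = (TabularList.from_df(train_df, cat_names=cat_names,
                            cont_names=cont_names, procs=procs)
        .split_by_idx(valid_idx)
        .label_from_df(cols="target")
        .databunch(bs=4096))
```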

The cuDF implementation takes advantage of the byte transfer capabilities of its parquet reader and avoids any unnecessary conversions, copying the parquet files created during feature generation directly into GPU memory (a CSV version is also available but is not quite as performant). After preprocessing has taken place, the cuDF dataframes are mapped directly into tensors using a DLPack-based zero-copy transfer. These tensors are optionally transferred to the CPU to be used in the PyTorch dataset; the transfer is unnecessary if your GPU has enough memory to hold both the dataset and the model. A future blog post will go into detail on how the dataloader works.
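A minimal sketch of the DLPack hand-off is shown below, assuming hypothetical column lists cat_names and cont_names from the preprocessing step.

```python
import cudf
import torch
from torch.utils.dlpack import from_dlpack

# Read the feature-engineered parquet files straight into GPU memory.
gdf = cudf.read_parquet("train_features.parquet")

# Zero-copy hand-off from cuDF to PyTorch via DLPack; the tensors share
# the dataframe's GPU memory (cat_names / cont_names are placeholders).
cat_tensor = from_dlpack(gdf[cat_names].astype("int64").to_dlpack())
cont_tensor = from_dlpack(gdf[cont_names].astype("float32").to_dlpack())

# Optionally stage in CPU memory if the dataset and model don't both fit on the GPU.
# cat_tensor, cont_tensor = cat_tensor.cpu(), cont_tensor.cpu()
```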

The unoptimized workflow took 401.7 seconds to preprocess the feature-engineered data and get it ready for model training. The optimized cuDF workflow shortens this to 47.41 seconds when using CPU memory to store the tensors, or 41.8 seconds if the tensors can stay on the GPU. This 9.6x speedup is similar to what was achieved in feature generation.

Dataloaders

Dataloaders are the mechanism by which data is passed in batches to the model. In the vanilla PyTorch dataloader this takes the form of an iterator that randomly selects indices from the dataset, grabs the data, collates the results into a batch, and then passes that batch to the GPU. This process of grabbing from the dataset item by item is a relatively inexpensive one, but it’s not free, and in order to keep the GPU utilized you generally need to assign multiple workers to create batches in parallel.
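In code, the vanilla setup looks roughly like the following; cat_tensor and cont_tensor are the hypothetical tensors from the DLPack sketch above, and target_tensor is a placeholder for the labels.

```python
from torch.utils.data import DataLoader, TensorDataset

# Standard item-by-item loading: indices are sampled, rows are fetched one at a
# time and collated into batches, with several workers running in parallel.
dataset = TensorDataset(cat_tensor, cont_tensor, target_tensor)
loader = DataLoader(dataset, batch_size=4096, shuffle=True,
                    num_workers=8, pin_memory=True)
```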

The Dataloader returns an iterator over the dataset and creates batches. We reimplemented this functionality as a batch dataloader that loads an entire batch from CPU memory into a tensor in a single memory access. BatchDataset and BatchDataloader replace the vanilla Dataset and Dataloader but otherwise expose the same API. To maintain the randomization required to train DNN models effectively, we shuffle the data before training and at the beginning of each epoch.
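The sketch below conveys the idea (it is a simplification, not the exact implementation in our repo): each dataset item is itself a whole batch, sliced contiguously from pre-stacked tensors, and rows are re-permuted once per epoch. A thin wrapper analogous to the standard dataloader then iterates over these batch-sized items directly.

```python
import torch
from torch.utils.data import Dataset

class BatchDataset(Dataset):
    """Each item is a whole batch sliced from pre-stacked tensors."""
    def __init__(self, cats, conts, target, batch_size):
        self.cats, self.conts, self.target = cats, conts, target
        self.batch_size = batch_size

    def __len__(self):
        return (len(self.target) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        # One contiguous slice per tensor: a single memory access per batch.
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.cats[s], self.conts[s], self.target[s]

def shuffle_epoch(cats, conts, target):
    # Re-permute the rows before training and at the start of each epoch
    # so contiguous slices still contain randomized examples.
    perm = torch.randperm(len(target))
    return cats[perm], conts[perm], target[perm]
```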

Loading data in this way has significant benefits. First, no multiprocessing is required, which reduces bus errors and significantly speeds up workflows in environments like Windows where multiprocessing is slower due to the way new processes are spawned. Second, performance of the batch dataloader is better than that of the single-item, multi-worker dataloader on tabular datasets.

Figure 2 below shows the before and after profiles of the training loop for an extreme batch size of 204,800. Initial testing was done at a more “modest” batch size of 4096, and the profiles are similar in nature. In the unoptimized batch size 4096 profile, the dataloader is the largest component, taking 260 seconds, or 54.6% of the total training runtime. The batch dataloader improves this to 2.45 seconds, only 2.08% of the total training time. This is a relative speedup of over 100x for data loading and a speedup of more than 2x for total training. Part of this two-orders-of-magnitude improvement comes from how fastai uses the PyTorch dataloader, performing tensor conversions during __getitem__ rather than during the initialization phase. More typically we see a 5–50% improvement over the multi-worker dataloader for similarly sized batches across different workloads, depending on the batch and tensor size.

Figure 2. Before and After Profiles of the Training Loop

PyTorch Kernel Improvements

With the batch dataloader in place, training is no longer dominated by dataloading, and we can more easily explore the performance of the CUDA kernels within the model's execution by scaling the batch size to better utilize the GPU. Analyzing the model using the nvprof tool (Figure 3), we see that 69.6% of the work on the GPU happens within a single kernel (EmbGPK), which is responsible for calculating the embedding gradients.

Figure 3. NVProf Profile of the unoptimized workflow. Almost 70% of the time on GPU is spent in one kernel

The calculation in EmbGPK is essentially a segmented sum over the rows of a matrix to compute the embedding gradients. In the original implementation, index access was not distributed evenly and there were not enough threads, resulting in the underutilization of the GPU's processors shown in Figure 4 (top). Instead of doing a flat sum where a single variable accumulates the weights, we modified the kernel to compute the sum of the weights in two steps:

  • Each GPU warp sums ‘NROWS_PER_THREAD’ rows given by ‘indices’ into partial sums.
  • Each partial sum from step 1 is then scattered into ‘grad_weight’.

Figure 4. Improved GPU utilization (bottom) after optimizing the EmbGPK kernel.

Not only does this increase utilization, it also improves numerical stability. Figure 4 (bottom) shows the utilization of the GPU under the new kernel, where all GPU streaming multiprocessors are fully utilized. The optimizations have been implemented and merged into PyTorch, and any method using embeddings at scale should see a similar improvement in performance. When we examine the model again in nvprof we see a significantly improved profile: the calculation of embedding gradients has gone from ~70% of the workload to 2.3%, and the GPU computation is now dominated by matrix multiply kernels, which are already heavily optimized. The overall speedup due to this change is ~6.5x.
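For intuition, here is a pure-PyTorch sketch of the two-step reduction strategy. It mirrors the structure of the approach but is not the CUDA kernel itself, and the chunking parameter is illustrative.

```python
import torch

def embedding_grad_two_step(grad_output, indices, num_embeddings, rows_per_chunk=32):
    """Illustrative two-step segmented sum for embedding gradients."""
    grad_weight = torch.zeros(num_embeddings, grad_output.shape[1],
                              device=grad_output.device, dtype=grad_output.dtype)
    for start in range(0, indices.numel(), rows_per_chunk):
        chunk = slice(start, start + rows_per_chunk)
        idx, inverse = torch.unique(indices[chunk], return_inverse=True)
        # Step 1: partial sums over this chunk's rows, grouped by embedding index.
        partial = torch.zeros(idx.numel(), grad_output.shape[1],
                              device=grad_output.device, dtype=grad_output.dtype)
        partial.index_add_(0, inverse, grad_output[chunk])
        # Step 2: scatter/accumulate the partial sums into the full gradient buffer.
        grad_weight.index_add_(0, idx, partial)
    return grad_weight
```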

Extreme Batch Sizes

One of the further advantages of the batch dataloader introduced above is the ability to scale to extreme batch sizes. On the Tesla V100 GPUs we used for evaluation, the maximum batch size we could scale our baseline model to was 819,600. Scaling to this degree sped up training to 15s per epoch; however, the model took two epochs to converge. After further evaluation we reduced our batch size to 204,800, which converged in a single epoch in 15.4 seconds. This is a 7.5x speedup over the 4096 batch size GPU-memory variant, while the speedup for the unoptimized fastai workflow at the larger batch size was only 4.7%, in part due to increased dataloader costs.

The scaling of batch size to this degree was aided by using layer-wise adaptive large batch optimization, more commonly known as the LAMB optimizer, which scales the learning rate for each layer of the network. We validated model performance by evaluating mean and stdev of AUC and MRR across five runs of each batch size. The 4096 batch size version had an AUC of 0.8827 +/- 0.0049 and an MRR of 0.6143 +/- 0.005 while the 204,800 batch version had an AUC of 0.8816 +/- 0.0012 and an MRR of 0.6136 +/- 0.0008 for our baseline model, a relative difference of ~0.1% for both metrics.
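LAMB is not part of core PyTorch; the sketch below assumes a third-party implementation such as the Lamb class in the torch_optimizer package, and model, loss_func, and batch_dl are placeholders for the tabular model, loss function, and batch dataloader described above.

```python
import torch_optimizer  # third-party package providing a LAMB implementation

# Hypothetical setup: `model`, `loss_func` and `batch_dl` stand in for the
# tabular model and batch dataloader described earlier in this post.
optimizer = torch_optimizer.Lamb(model.parameters(), lr=1e-2, weight_decay=1e-4)

for cats, conts, target in batch_dl:
    optimizer.zero_grad()
    loss = loss_func(model(cats, conts), target)  # one extreme-sized batch per step
    loss.backward()
    optimizer.step()
```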

Conclusions and Future Work

In this blog post we have shared several novel optimizations that can be used when training deep learning models in PyTorch using the fastai library. Using the RAPIDS cuDF library we improve the preprocessing steps that prepare data for the model by a factor of 9x, performing categorical encoding, normalization and null value filling on the GPU. We introduce a novel batch dataloader which loads an entire batch from memory in a single read, accelerating the PyTorch dataloader by over 100x and training time by 2x. Upon further analysis of the GPU computations we also update the kernel responsible for calculating the embedding gradient in PyTorch, reducing it from 70% of the workload to 2.3% and improving the training time by 1.9x. Then, taking advantage of the LAMB optimizer, we are able to scale our batch size to the limits of memory, achieving a further 2.12x speedup in training time. A further 15% improvement was achieved by keeping the entire dataset in GPU memory during training.

These optimizations collectively improve the end-to-end training time by a factor of 15.6x over our initial model on this problem. Figure 5 highlights the improvements to each stage of the recommendation pipeline. By reducing training time from 15 minutes to 57.2s we enable a much richer exploration of the feature space in hours instead of days, and significantly reduce the cost of training deep learning based recommenders. The iterative nature of feature engineering and model training means that these speedups compound; long feature creation and training times result in a lot of cognitive downtime and a disrupted work process. By shortening feature creation and training time so dramatically, we hope to provide a more immediate feedback loop where data scientists and ML engineers can stay in the flow of their work.

Figure 5. Summary of Optimizations (time in s)

We’re working on implementing variations of this method that work in different situations. We currently have an in-GPU-memory version, described in detail here, and we are close to completing a version that operates on a subset of the data in GPU memory for mid-sized datasets or for GPUs with less memory. We’re also developing a larger-than-CPU-memory preprocessor and dataloader that will allow for efficient preprocessing and data loading when the data is too large to fit even in CPU memory.

In addition to these improvements we plan to explore FP16 mixed-precision training using Apex and multi-GPU training via Hogwild, Horovod and BytePS, further improving training performance at the cost of additional hardware. These performance improvements aren’t limited to tabular data or recommendation; they should be applicable to natural language processing, time series, and other problems where dataloading is a bottleneck relative to compute. We hope to demonstrate their effectiveness in other domains and provide easy-to-use examples for the community.

Finally, we’re very excited to see the GPU preprocessing integrated into the fastai library v2. Integration into fastai and other deep learning frameworks is extremely important to us; we want to make it as simple as possible to use these methods, and in that context fastai is a great fit.

The work shared here is the joint effort of our two amazing interns Sara Rabhi and Wenbo Sun, Kaggle Grandmaster Jiwei Liu, Julio Perez, Mads B. Kristensen, Rick Zamora and myself.

The source code of the solution is available in our repo and we encourage you to try these methods out on your own workflows. Although the full 15x speedup requires all of them, each can be used individually to provide a significant speedup where applying them all isn’t possible. As with all RAPIDS related projects, we’d love to hear about your experiences with RAPIDS and we welcome contributions from the community.

Even Oldridge is a research scientist at NVIDIA working on deep learning for tabular data.