Training Larger and Faster Recommender Systems with PyTorch Sparse Embeddings

Bo Liu
NVIDIA Merlin
Aug 5, 2021

In the recent RecSys 2021 Challenge, we used PyTorch sparse embedding layers to train one of the neural network models in our winning solution. Sparse embeddings made training nearly 6x faster while letting us embed 2.4x more users on a single GPU, which is critical when dealing with huge datasets with millions of users. In this post, we explain the rationale behind them and the tricks needed to make them work, with code examples.

What is an embedding layer?

In deep learning on tabular data, the standard way to use categorical features is categorical embeddings, i.e., representing each unique categorical value in the dataset by an n-dimensional embedding vector. The mapping from each categorical feature value to its embedding vector is learned via the embedding table.

For example, to represent a feature with 5,000 unique values by a 128-dimensional embedding, we would need an embedding layer consisting of a lookup table with dimension (5000, 128).
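A minimal sketch of that (5000, 128) lookup table with plain torch.nn.Embedding; the batch of category IDs is made up for illustration:

```python
import torch
import torch.nn as nn

# Lookup table for a feature with 5,000 unique values and 128-dim vectors
emb = nn.Embedding(num_embeddings=5000, embedding_dim=128)
print(emb.weight.shape)  # torch.Size([5000, 128])

# A hypothetical batch of 4 encoded category IDs in [0, 5000)
ids = torch.tensor([0, 42, 42, 4999])
vectors = emb(ids)
print(vectors.shape)  # torch.Size([4, 128])

# Each output row is simply the table row for that ID
assert torch.equal(vectors[1], emb.weight[42])
```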

Note that a categorical embedding is an optimization of one-hot encoding, because it is equivalent to combining two steps: (1) one-hot encoding the categorical variable, and (2) learning a fully connected dense layer without bias. The example above is equivalent to one-hot encoding the 5,000 unique values into 5,000-dimensional one-hot vectors, then learning a fully connected layer with input dimension 5,000 and output dimension 128. The weight matrix of this dense layer also has dimension (5000, 128).
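A quick numerical check of this equivalence, as a sketch: multiplying one-hot rows by the (5000, 128) weight matrix selects the same rows that the embedding lookup returns. The IDs are arbitrary; note that nn.Linear would store this weight transposed, as (128, 5000), so the sketch uses a plain matrix product instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_values, dim = 5000, 128
emb = nn.Embedding(num_values, dim)

ids = torch.tensor([3, 17, 4999])

# Step 1: one-hot encode the IDs into 5,000-dimensional vectors
one_hot = F.one_hot(ids, num_classes=num_values).float()  # shape (3, 5000)

# Step 2: a bias-free dense layer whose (5000, 128) weight is the embedding table;
# one_hot @ table just picks out one row of the table per input
dense_out = one_hot @ emb.weight

lookup_out = emb(ids)
print(torch.allclose(dense_out, lookup_out))  # True
```

The lookup gives the same result without materializing the 5,000-wide one-hot vectors or doing the matrix multiply, which is why embeddings are preferred in practice.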

Why do we need sparse embedding layers?

In domains such as recommender systems, some features’ cardinality (i.e., number of unique values) can be huge, such as user IDs. It’s common to see millions or tens of millions of unique users in a recommender dataset.

An obvious consequence is that the embedding table can take up a large share, or even the majority, of the model size, depending on the rest of the model architecture. It may be too large to fit into a single GPU’s memory. Although there are ways to circumvent this, such as lowering the embedding dimension (i.e., reducing the embedding table’s width) or grouping less frequent users into the same category (i.e., reducing the embedding table’s height), they all come at the cost of model accuracy.
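A back-of-the-envelope estimate shows how fast this grows. The sketch assumes fp32 values (4 bytes each) and a hypothetical table of 10 million users with 512-dimensional embeddings; during training with an Adam-style optimizer, a dense gradient and two moment buffers of the same shape come on top of the weights:

```python
def table_gib(rows: int, dim: int, bytes_per_value: int = 4) -> float:
    """Size in GiB of one (rows, dim) tensor, assuming fp32 by default."""
    return rows * dim * bytes_per_value / 2**30

rows, dim = 10_000_000, 512  # hypothetical user count and embedding width
weights = table_gib(rows, dim)
# weights + dense gradient + Adam's two moment buffers (exp_avg, exp_avg_sq)
with_training_state = 4 * weights
print(f"weights: {weights:.1f} GiB, with grad + Adam state: {with_training_state:.1f} GiB")
# weights: 19.1 GiB, with grad + Adam state: 76.3 GiB
```

Under these assumptions a single table would already exceed the memory of a 32 GB V100 during training.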

A less obvious ramification of having a huge embedding table is that training slows down dramatically, because calculating the gradients of a huge matrix is an expensive operation. But do we really need to compute the gradients of the whole matrix for every batch? Probably not, since the batch size (e.g., 1024) is usually a few orders of magnitude smaller than the number of rows in the embedding matrix (millions). The gradients are zero for embedding vectors that are not used in the batch: as they do not take part in the forward pass, there cannot be any learning signal from the target. Calculating gradients for these unused embedding vectors is wasted work and adds overhead to every training step.
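A small sketch makes this visible with a regular (dense) embedding: after one backward pass on a hypothetical batch of 1,024 IDs, the gradient tensor is as large as the whole table, yet only the rows for IDs that appear in the batch are non-zero. The table size and the sum loss are illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(100_000, 64)                  # dense (default) embedding table
batch_ids = torch.randint(0, 100_000, (1024,))   # hypothetical batch of user IDs

loss = emb(batch_ids).sum()
loss.backward()

grad = emb.weight.grad                           # dense tensor, same shape as the table
nonzero_rows = (grad != 0).any(dim=1).sum().item()
print(grad.shape)                                # torch.Size([100000, 64])
print(nonzero_rows, "of", grad.shape[0], "rows have a non-zero gradient")
```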

Wouldn’t it be nice if we only computed the gradients of the (at most) 1024 rows corresponding to the 1024 data points in each batch? That’s the idea behind PyTorch sparse embeddings: the gradient is represented by a sparse tensor, and gradients are only calculated for the embedding vectors that can be non-zero. This addresses not only speed but also GPU memory, since sparse tensors take up less memory than dense ones.
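As a sketch of the difference, the following compares the gradient produced by a dense and a sparse embedding table for the same hypothetical batch. The sparse gradient stores only the rows that were looked up (after `coalesce()` merges duplicate IDs), while holding the same values as the dense one:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_users, dim = 100_000, 64
batch_ids = torch.randint(0, num_users, (1024,))           # hypothetical batch

dense_emb = nn.Embedding(num_users, dim)                   # dense gradients
sparse_emb = nn.Embedding(num_users, dim, sparse=True)     # sparse gradients

for emb in (dense_emb, sparse_emb):
    emb(batch_ids).sum().backward()

dense_grad = dense_emb.weight.grad
sparse_grad = sparse_emb.weight.grad.coalesce()            # merge duplicate row indices

print(dense_grad.is_sparse, sparse_grad.is_sparse)         # False True
print(dense_grad.numel())                                  # 6400000 values stored
print(sparse_grad.values().numel())                        # (unique IDs in batch) * 64
```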

Let’s see how to use it in PyTorch.

Creating sparse embedding layers

In PyTorch, a sparse embedding layer is just a torch.nn.Embedding layer created with the argument sparse=True.

NVTabular’s handy utility class ConcatenatedEmbeddings can create and concatenate all the embedding layers for a model, with the option to specify which ones should be sparse. In this example, we make a_user_id and b_user_id sparse since both have high cardinality.

```python
import torch.nn as nn
from nvtabular.framework_utils.torch.layers import ConcatenatedEmbeddings

# (cardinality, embedding dimension) for each categorical column
embedding_table_shapes = {
    "a_user_id": (656096, 512),
    "b_user_id": (868034, 512),
    "language": (67, 16),
    "media": (14, 16),
    "tweet_type": (4, 16),
}

class Net(nn.Module):
    def __init__(
        self,
        embedding_table_shapes,
        dropout=0.2,
        sparse_columns=("a_user_id", "b_user_id"),
    ):
        super().__init__()
        self.dropout = dropout
        # Columns listed in sparse_columns get sparse embedding tables
        self.initial_cat_layer = ConcatenatedEmbeddings(
            embedding_table_shapes, dropout=dropout, sparse_columns=sparse_columns
        )
        # ... remaining layers and forward() omitted
```
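If NVTabular is not installed, the same idea can be sketched in plain PyTorch. `SimpleConcatenatedEmbeddings` below is a hypothetical stand-in written for illustration, not NVTabular’s implementation: it builds one nn.Embedding per column, sets sparse=True for the columns named in `sparse_columns`, and concatenates the lookups. The small table shapes are also made up.

```python
import torch
import torch.nn as nn

class SimpleConcatenatedEmbeddings(nn.Module):
    """Hypothetical stand-in: one nn.Embedding per column, outputs concatenated."""

    def __init__(self, embedding_table_shapes, dropout=0.0, sparse_columns=()):
        super().__init__()
        self.embedding_layers = nn.ModuleList(
            [
                nn.Embedding(cardinality, dim, sparse=(name in sparse_columns))
                for name, (cardinality, dim) in embedding_table_shapes.items()
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: LongTensor of shape (batch, num_columns), one column per feature
        embedded = [layer(x[:, i]) for i, layer in enumerate(self.embedding_layers)]
        return self.dropout(torch.cat(embedded, dim=1))

shapes = {"a_user_id": (1000, 8), "language": (67, 4), "media": (14, 4)}
layer = SimpleConcatenatedEmbeddings(shapes, sparse_columns=("a_user_id",))
x = torch.stack([torch.randint(0, n, (5,)) for n, _ in shapes.values()], dim=1)
out = layer(x)
print(out.shape)                                        # torch.Size([5, 16])
print([emb.sparse for emb in layer.embedding_layers])   # [True, False, False]
```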

Just one additional argument, very simple, isn’t it? But if you plug this code into the usual training pipeline with the Adam optimizer, you will get a RuntimeError: “Adam does not support sparse gradients, please consider SparseAdam instead”. This is because only a limited number of optimizers support sparse gradients. According to the PyTorch documentation: “currently it’s optim.SGD (CUDA and CPU), optim.SparseAdam (CUDA and CPU) and optim.Adagrad (CPU).”
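A minimal reproduction on CPU, with a tiny hypothetical table: regular Adam rejects the sparse gradient, while SparseAdam accepts it and only updates the rows that were looked up.

```python
import torch
import torch.nn as nn
import torch.optim as optim

emb = nn.Embedding(10, 4, sparse=True)
emb(torch.tensor([1, 2, 2])).sum().backward()   # sparse gradient for rows 1 and 2

message = ""
try:
    optim.Adam(emb.parameters(), lr=1e-3).step()
except RuntimeError as err:
    message = str(err)
print(message)  # the error message points to SparseAdam

before = emb.weight.detach().clone()
optim.SparseAdam(emb.parameters(), lr=1e-3).step()  # accepts the sparse gradient
changed = (emb.weight.detach() != before).any(dim=1).nonzero().flatten().tolist()
print(changed)  # [1, 2]
```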

Two optimizers for the same model

This means that, if we want to use Adam-type optimizers, we would need two optimizers for the same model: SparseAdam for the embedding layers and regular Adam or AdamW for all other layers.

In this example, we specify SparseAdam for the first two model parameters (the a_user_id and b_user_id embedding tables) and AdamW for all the other parameters.

```python
import torch.optim as optim

# for sparse embedding layers
optimizer = optim.SparseAdam(list(model.parameters())[:2], lr=lr)
# for the rest of the model
optimizer2 = optim.AdamW(list(model.parameters())[2:], lr=lr2)
```
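Slicing list(model.parameters()) relies on the sparse tables being registered before every other parameter. A less order-dependent split can be sketched by collecting the weights of all nn.Embedding modules created with sparse=True. The helper `split_sparse_params`, the toy model, and the learning rates are hypothetical:

```python
import torch.nn as nn
import torch.optim as optim

def split_sparse_params(model: nn.Module):
    """Return (sparse_embedding_params, other_params) for a model."""
    sparse_ids = {
        id(p)
        for m in model.modules()
        if isinstance(m, nn.Embedding) and m.sparse
        for p in m.parameters()
    }
    sparse = [p for p in model.parameters() if id(p) in sparse_ids]
    dense = [p for p in model.parameters() if id(p) not in sparse_ids]
    return sparse, dense

# Hypothetical toy model: one sparse user table, one dense table, one linear head
model = nn.ModuleDict({
    "user": nn.Embedding(1000, 16, sparse=True),
    "language": nn.Embedding(67, 4),
    "head": nn.Linear(20, 1),
})
sparse_params, dense_params = split_sparse_params(model)
optimizer = optim.SparseAdam(sparse_params, lr=1e-2)   # sparse embedding layers
optimizer2 = optim.AdamW(dense_params, lr=1e-4)        # everything else
print(len(sparse_params), len(dense_params))           # 1 3
```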

In the training loop, both optimizers need to be zeroed and stepped. In this example, we also use mixed precision training with torch.cuda.amp.

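```python
import torch.cuda.amp as amp

scaler = amp.GradScaler()

# training loop (assumes loader yields (features, targets) batches)
for data, targets in loader:
    data, targets = data.cuda(), targets.cuda()

    optimizer.zero_grad()
    optimizer2.zero_grad()
    with amp.autocast():
        logits = model(data)
        loss = criterion(logits, targets)

    # one scaler drives both optimizers; update() once per iteration
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.step(optimizer2)
    scaler.update()
```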
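The loop above needs a GPU and a real data loader. As a CPU-only sketch without mixed precision, one training step with the two optimizers looks like this; `ToyNet`, the batch, and the learning rates are hypothetical. It also shows that SparseAdam only modifies the embedding rows that appeared in the batch:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

class ToyNet(nn.Module):
    """Hypothetical model: a sparse user table followed by a dense head."""

    def __init__(self, num_users=1000, dim=16):
        super().__init__()
        self.user = nn.Embedding(num_users, dim, sparse=True)
        self.head = nn.Linear(dim, 1)

    def forward(self, user_ids):
        return self.head(self.user(user_ids)).squeeze(-1)

model = ToyNet()
optimizer = optim.SparseAdam(model.user.parameters(), lr=1e-2)  # sparse part
optimizer2 = optim.AdamW(model.head.parameters(), lr=1e-4)      # dense part
criterion = nn.BCEWithLogitsLoss()

user_ids = torch.tensor([3, 7, 7, 42])          # hypothetical batch
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])
before = model.user.weight.detach().clone()

optimizer.zero_grad()
optimizer2.zero_grad()
loss = criterion(model(user_ids), targets)
loss.backward()
optimizer.step()    # updates only rows 3, 7 and 42 of the table
optimizer2.step()   # updates the linear head

changed = (model.user.weight.detach() != before).any(dim=1).nonzero().flatten()
print(changed.tolist())  # [3, 7, 42]
```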

Lastly, note that when initializing the two optimizers above, we used different learning rates. The optimal learning rates are likely to differ between the sparse embedding layers and the rest of the network, so we should tune them separately. For our RecSys 2021 model, SparseAdam’s learning rate is 100x larger than AdamW’s.

Faster training, larger model

How helpful are sparse embeddings? In our RecSys model, training time per epoch dropped from 41 hours to 7 hours after we converted the user embedding layer to sparse, a speedup of about 5.86x (41 / 7).

Without sparse embeddings, we could embed about 8.2 million unique users on a single V100 GPU by using a frequency threshold of 25; with sparse embeddings, we could embed 19.7 million unique users by using a frequency threshold of 3. Embedding 2.4x as many users significantly boosted the model’s score.
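To see why a lower threshold means a taller embedding table, here is a sketch of frequency thresholding under one common interpretation: IDs seen fewer than `threshold` times share a single "rare" bucket at index 0. `encode_with_threshold` and the interaction log are hypothetical; this is not NVTabular’s Categorify implementation.

```python
from collections import Counter

def encode_with_threshold(ids, threshold):
    """Map frequent IDs to 1..K and every rare ID to the shared index 0."""
    counts = Counter(ids)
    frequent = sorted(i for i, c in counts.items() if c >= threshold)
    mapping = {raw_id: idx + 1 for idx, raw_id in enumerate(frequent)}
    encoded = [mapping.get(i, 0) for i in ids]
    num_rows = len(frequent) + 1  # embedding table height, incl. the rare bucket
    return encoded, num_rows

user_log = ["u1", "u1", "u1", "u2", "u2", "u3"]       # hypothetical interactions
print(encode_with_threshold(user_log, threshold=3))   # ([1, 1, 1, 0, 0, 0], 2)
print(encode_with_threshold(user_log, threshold=2))   # ([1, 1, 1, 2, 2, 0], 3)
```

Lowering the threshold gives more users their own row, which is exactly what grows the table and why the memory savings of sparse gradients matter.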

In a future post, we will explain how to speed up training of PyTorch tabular models even further by combining the power of RAPIDS and NVTabular’s PyTorch dataloader, so that the data stays on the GPU throughout the whole pipeline.

Optimize your Deep Learning Recommender Systems

In this blog post, we showed how to optimize the embedding layers of large deep learning recommender models in PyTorch. You can easily try the same optimization on your own neural networks. If your embedding tables still do not fit into a single GPU’s memory, you may want to check out HugeCTR, a custom-built deep learning framework designed to scale embedding tables across multiple GPUs or nodes. It is part of NVIDIA Merlin, an open source framework to scale and accelerate recommender systems on the GPU.

We have published several resources about our RecSys 2021 solution, including a blog post, code, a paper, and a video. We will publish more technical blog posts about our learnings soon. Stay tuned!

Acknowledgements

I would like to thank the Merlin team’s Benedikt Schifferer and Even Oldridge for their comments and suggestions, and Benedikt for pointing me to sparse embeddings in the first place.

References

https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

https://raberrytv.wordpress.com/2019/06/13/pytorch-combining-dense-and-sparse-gradients/

Bo Liu, Senior Deep Learning Data Scientist at NVIDIA