Learning Embeddings for Music Recommendation with MXNet’s Sparse API

Kat Ellis
Published in Apache MXNet
Feb 21, 2019

Amazon Music is a digital music streaming service available to all Prime customers and paid subscribers of our premium service, Amazon Music Unlimited. A goal of the Amazon Music Machine Learning team is to provide an effortless, “lean-back” experience, wherein the right music is recommended to each individual customer at the right time, and with minimal input from the customer. Toward this ambitious goal, we develop large-scale recommendation systems that personalize the experience for each customer. Training these large-scale recommendation systems for the millions of customers we serve can be computationally expensive, requiring many matrix computations and taking days to train. In this blog post we’ll explain how we use MXNet’s sparse API to avoid unnecessary computation and reduce training time significantly.

One particular piece of our system uses embedding (a.k.a. latent factor) models to learn representations of customers and recommendable items (e.g. songs, albums, artists, etc.) from listening behavior. In this semantic embedding space, customers and items are represented by low-dimensional vectors, such that customers are located close to the items that they listen to most, as well as music that they haven’t listened to but might appreciate. Finding relevant items for each customer amounts to finding the items that are closest to the customer using nearest neighbor search in the embedding space.
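As a rough illustration of the retrieval step, the following plain-Python sketch scores every item against a customer vector with a dot product and keeps the top k. The function names and the tiny 2-D embeddings are hypothetical, and dot-product similarity is an assumption. At catalog scale, a linear scan like this would typically be replaced by an approximate nearest neighbor index.

```python
import heapq


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def top_k_items(customer_vec, item_vecs, k):
    """Return the IDs of the k items with the highest dot-product score.

    Brute force (hypothetical helper): scores every item, then keeps the k largest.
    """
    scored = ((dot(customer_vec, vec), item_id) for item_id, vec in item_vecs.items())
    return [item_id for _, item_id in heapq.nlargest(k, scored)]


# Hypothetical 2-D embeddings, for illustration only.
item_vecs = {
    "song_a": [0.9, 0.1],
    "song_b": [0.8, 0.3],
    "song_c": [-0.7, 0.6],
    "song_d": [0.1, -0.9],
}
customer_vec = [1.0, 0.2]

print(top_k_items(customer_vec, item_vecs, k=2))  # → ['song_a', 'song_b']
```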

Embedding models have a number of attractive qualities:

  • The embedding space provides useful geometric relationships between entities (e.g., affinities, similarities, synonyms and differences).
  • Embeddings are flexible, general-purpose primitives that can be used as building blocks to create many compelling customer experiences. In addition to music recommendation, learned embeddings are also useful in genre tagging, music sequencing, and voice and visual user experience (UX) optimization.
  • Embeddings can be trained on large quantities of unsupervised or “weakly” supervised data, such as the customer interactions and implicit feedback that power the majority of industrial recommender systems.
  • Embeddings facilitate knowledge transfer between related tasks (e.g., using pre-trained unsupervised word embeddings to bootstrap complex natural language understanding (NLU) models), thereby speeding up learning and effectively amplifying small amounts of supervised training data.

Embedding via Collaborative Filtering

For music recommendation, we mine patterns from our customers’ listening behavior. Our goal is to learn an embedding model that places customers close to the music that they listen to. As a by-product of optimizing this goal, the model also learns to place similar music close together: music that is listened to by similar customers, or in similar contexts, will be clustered nearby in the embedding space.

The embedding model can be viewed as a form of collaborative filtering, a broad category of algorithms that extract preference patterns from the collective set of customers in order to predict the preferences of individual customers. As an illustration, consider how people discovered new music in the 1980s and 90s. The invention of cassette tapes allowed people to create personal music mixes, which they could share with their friends, who often had similar tastes. We can think of collaborative filtering as doing something similar on a much larger scale: looking for “customers like you” and leveraging their listening behavior to recommend new music to you.
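The “customers like you” idea can be sketched with a toy neighborhood-based collaborative filter. Note that this is a different, simpler technique than the embedding model described in this post: it compares listening histories directly using Jaccard similarity and recommends unheard songs, weighted by how similar their listeners are to the target customer. All names and data here are hypothetical.

```python
def jaccard(a, b):
    """Jaccard similarity between two sets (0.0 if both are empty)."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def recommend(target, histories, k=1):
    """Recommend up to k songs the target hasn't heard.

    Each candidate song scores the summed similarity of the other customers
    who listened to it; ties are broken alphabetically.
    """
    heard = histories[target]
    scores = {}
    for other, songs in histories.items():
        if other == target:
            continue
        sim = jaccard(heard, songs)
        for song in songs - heard:
            scores[song] = scores.get(song, 0.0) + sim
    ranked = sorted(scores, key=lambda s: (-scores[s], s))
    return ranked[:k]


# Hypothetical listening histories.
histories = {
    "ana": {"song_a", "song_b", "song_c"},
    "ben": {"song_a", "song_b", "song_d"},
    "chloe": {"song_x", "song_y"},
}
print(recommend("ana", histories, k=1))  # → ['song_d']
```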

2-D visualization of music embeddings using t-SNE. Similar music (i.e., music from the same genre) is close together in the embedding space.

Model Setup

The exact design of the embedding model, and the way that we train it, varies depending on the specific experience we are optimizing for. At an abstract level, we define an embedding model as a neural network in which the input is an entity index, such as the ID of a song or a customer, and the output is the embedding. As we are modeling interactions between two (or more) entities, we have multiple embedding networks, which are combined in the loss function via a function such as a dot product, cosine similarity or Euclidean distance. Loosely speaking, the loss function measures how well the combined embeddings explain the observed customer-item interactions. We train the networks using stochastic gradient descent (SGD), with mini-batches of customer-item interactions (e.g., listening events, likes or library adds).
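The setup above can be sketched in plain Python with two embedding tables (customers and items), a dot product as the combining function, and a logistic loss. The loss, the use of sampled non-interactions as negative examples (label 0), the initialization and all names are assumptions made for illustration; this is not the production model. The step computes the mini-batch gradient first and then applies it, so only rows that appear in the batch change, which is the property the next section exploits.

```python
import math
import random


def init_table(num_rows, dim, rng):
    """One embedding row per entity ID, initialized with small Gaussian noise."""
    return [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(num_rows)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sgd_step(customer_emb, item_emb, batch, learning_rate):
    """One mini-batch SGD step on (customer_id, item_id, label) triples.

    Score = dot(customer, item); loss = logistic loss against the label
    (1 = observed interaction, 0 = sampled non-interaction, an assumption).
    Returns the mean loss before the update.
    """
    n = len(batch)
    grads_c, grads_i = {}, {}
    total_loss = 0.0
    for c, i, label in batch:
        u, v = customer_emb[c], item_emb[i]
        p = sigmoid(dot(u, v))
        total_loss -= math.log(p if label == 1 else 1.0 - p)
        g = (p - label) / n  # d(mean loss) / d(score)
        gc = grads_c.setdefault(c, [0.0] * len(u))
        gi = grads_i.setdefault(i, [0.0] * len(v))
        for d in range(len(u)):
            gc[d] += g * v[d]  # d(score)/du = v
            gi[d] += g * u[d]  # d(score)/dv = u
    # Apply accumulated gradients; rows absent from the batch are untouched.
    for table, grads in ((customer_emb, grads_c), (item_emb, grads_i)):
        for row, grad in grads.items():
            for d, x in enumerate(grad):
                table[row][d] -= learning_rate * x
    return total_loss / n


rng = random.Random(0)
customer_emb = init_table(num_rows=3, dim=4, rng=rng)
item_emb = init_table(num_rows=6, dim=4, rng=rng)  # item 5 never appears below
batch = [(0, 1, 1), (0, 3, 0), (1, 2, 1), (1, 0, 0), (2, 4, 1)]

first = sgd_step(customer_emb, item_emb, batch, learning_rate=0.5)
for _ in range(200):
    last = sgd_step(customer_emb, item_emb, batch, learning_rate=0.5)
print(f"mean loss: {first:.3f} -> {last:.3f}")
```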

Scaling Up

Given the massive scale of our catalog and customer base, the datasets that we train on are on the order of billions of examples. Moreover, since we learn an embedding for each entity, the number of model parameters is also in the billions. Processing that much data is challenging, even for a beefy, GPU-enabled host. Luckily, we can exploit the inherent sparsity of the problem to speed up training. The key observation is that any given customer only listens to a small subset of the catalog, and each mini-batch of customers only represents a small fraction of the total customer base. Thus, for certain loss functions, each mini-batch update will only involve a small subset of the embeddings. Exploiting the sparsity of the mini-batch updates can yield tremendous savings in computation and I/O.
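To put a number on that sparsity, this sketch samples a mini-batch of item IDs and measures the fraction of embedding rows (one per item) that the batch touches. The catalog size, batch size and uniform sampling are assumptions; real listening data is skewed toward popular items, so a real batch would tend to touch even fewer distinct rows.

```python
import random


def touched_fraction(num_items, batch_size, rng):
    """Sample a mini-batch of item IDs uniformly at random (an assumption) and
    return the fraction of embedding rows the batch actually touches."""
    batch = [rng.randrange(num_items) for _ in range(batch_size)]
    return len(set(batch)) / num_items


rng = random.Random(42)
frac = touched_fraction(num_items=10_000_000, batch_size=1024, rng=rng)
print(f"rows touched: {frac:.6%}")
```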

First let’s look at a standard implementation, without exploiting sparsity. Our model has a matrix of parameters, weights, in which each row is an embedding vector for an item in our catalog. The dimensionality of weights is number_of_items (on the order of tens of millions) by embedding_dimension. That's a really big matrix! At each iteration of training, we update weights by taking a step in the direction of the negative gradient (with respect to weights) of the loss and weight regularizer (implemented using weight decay). We illustrate this in the following code block, which is a simplification of the SGD implementation in MXNet.

weights -= learning_rate * (loss_grad + weight_decay * weights)

Because number_of_items is large, this update involves operations on matrices with billions of entries. Even with GPU computing, these calculations are quite expensive, and training a model takes multiple hours.
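A quick back-of-the-envelope calculation shows the scale involved. The catalog size, embedding dimension and batch size below are illustrative assumptions, not the values used for Amazon Music.

```python
# Illustrative sizes (assumptions, not Amazon Music's actual configuration).
number_of_items = 20_000_000  # "tens of millions" of catalog items
embedding_dimension = 128
bytes_per_float32 = 4
batch_rows = 1024  # a mini-batch of 1024 interactions touches at most 1024 item rows

entries = number_of_items * embedding_dimension
gib = entries * bytes_per_float32 / 2**30
print(f"{entries:,} entries, about {gib:.1f} GiB in float32")
# → 2,560,000,000 entries, about 9.5 GiB in float32

lazy_entries = batch_rows * embedding_dimension
print(f"lazy update touches at most {lazy_entries:,} entries, "
      f"about 1/{entries // lazy_entries:,} of the dense update")
# → lazy update touches at most 131,072 entries, about 1/19,531 of the dense update
```

This gap in touched entries is much larger than any end-to-end speedup one should expect, since wall-clock training time also includes the forward and backward passes, data loading and other work that a lazy update does not change.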

However, based on our earlier observation, we know that only a few items will appear in any given mini-batch, so the gradient of the loss, loss_grad, will be zero for all rows corresponding to items that don't appear in the mini-batch. Further, we can modify the weight decay to decay only the active embeddings, which is actually a smart way to prevent the weights from decaying too fast. We then have a “lazy” update step that only updates the weights corresponding to the non-zero rows of the loss gradient matrix.

# Only rows with a non-zero loss gradient (items in the mini-batch) are updated.
for row in loss_grad.indices:
    weights[row] -= learning_rate * (loss_grad[row] + weight_decay * weights[row])
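The difference between the two update rules can be sketched in plain Python, with the gradient stored as a dict from row index to gradient row (a stand-in for a row-sparse array; the function names are hypothetical). Rows with a gradient receive identical updates under both rules. The only difference is that the dense rule also decays every inactive row, so with weight_decay set to 0 the two produce exactly the same weights.

```python
import copy


def dense_update(weights, grad_rows, learning_rate, weight_decay):
    """Standard SGD step: every row is decayed; rows in grad_rows also get a
    gradient step. Missing rows have zero gradient. Returns rows touched."""
    dim = len(weights[0])
    zero = [0.0] * dim
    for r, row in enumerate(weights):
        g = grad_rows.get(r, zero)
        for d in range(dim):
            row[d] -= learning_rate * (g[d] + weight_decay * row[d])
    return len(weights)


def lazy_update(weights, grad_rows, learning_rate, weight_decay):
    """Lazy SGD step: only rows present in grad_rows are updated, and weight
    decay is applied only to those active rows. Returns rows touched."""
    for r, g in grad_rows.items():
        row = weights[r]
        for d in range(len(row)):
            row[d] -= learning_rate * (g[d] + weight_decay * row[d])
    return len(grad_rows)


weights = [[1.0, 1.0] for _ in range(5)]
grad_rows = {1: [0.5, -0.5], 3: [1.0, 0.0]}  # only items 1 and 3 are in the mini-batch

dense_w = copy.deepcopy(weights)
lazy_w = copy.deepcopy(weights)
print(dense_update(dense_w, grad_rows, learning_rate=0.1, weight_decay=0.01))  # → 5
print(lazy_update(lazy_w, grad_rows, learning_rate=0.1, weight_decay=0.01))  # → 2
```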

MXNet v1.3 introduced new sparse gradient functionality that leverages the advantages of sparse representations and the above optimization. To use this functionality, all you need to do is pass the argument sparse_grad=True to the constructor of an Embedding layer, as illustrated in the following code.

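import mxnet as mx

inputs = mx.sym.var(name='inputs')
weights = mx.sym.var(name='weights', shape=(number_of_items, embedding_dimension))
# sparse_grad=True makes the gradient w.r.t. weights a row_sparse array,
# so only the rows looked up in the mini-batch are stored.
item_embedding = mx.sym.Embedding(data=inputs,
                                  weight=weights,
                                  input_dim=number_of_items,
                                  output_dim=embedding_dimension,
                                  sparse_grad=True)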

For more details about how MXNet handles sparse arrays, see: https://mxnet.incubator.apache.org/tutorials/sparse/row_sparse.html.
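To see what a row-sparse gradient contains, the following plain-Python sketch computes the gradient of an embedding lookup with respect to the weight matrix as a pair of (indices, data), loosely mirroring the indices and data attributes of MXNet's row_sparse arrays. It is an illustration under those assumptions, not MXNet's implementation, and the function name is hypothetical. Repeated IDs in a batch have their gradient rows summed, and every row not listed in indices is implicitly zero.

```python
def embedding_row_sparse_grad(input_ids, output_grads):
    """Gradient of an embedding lookup w.r.t. its weight matrix, in row-sparse form.

    input_ids:    item IDs looked up in the mini-batch (may repeat)
    output_grads: one upstream gradient vector per lookup
    Returns (indices, data): sorted unique row IDs and, for each, the summed
    gradient row. All other rows of the gradient are implicitly zero.
    """
    summed = {}
    for row_id, g in zip(input_ids, output_grads):
        acc = summed.setdefault(row_id, [0.0] * len(g))
        for d, x in enumerate(g):
            acc[d] += x
    indices = sorted(summed)
    data = [summed[r] for r in indices]
    return indices, data


indices, data = embedding_row_sparse_grad(
    input_ids=[7, 2, 7],
    output_grads=[[0.1, 0.2], [0.5, 0.5], [0.3, -0.2]],
)
print(indices)  # → [2, 7]
print(data)
```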

Speedup

How much of a difference do these optimizations make in practice? In our experiments, we saw a 1.5x speedup when using MXNet’s sparse gradient functionality with lazy updates.

Time to train an embedding model with dense weights and sparse weights. We saw a 1.5x speedup by taking advantage of sparsity and lazy updates.

Implementation Details:
These tests were run with Python 3, CUDA 9.0 and MKL-DNN. Each model was run on a single Tesla K80 GPU on a p2.16xlarge instance.

Conclusion

MXNet’s sparse API was essential for efficiently training Amazon Music’s embedding-based recommendation model on billions of customer data points. We’re looking forward to trying out multi-GPU training to speed things up even further.

Acknowledgements: Jeff Schmitz, Santhosh Kasa, Ben London, Ted Sandler, Fabian Moerchen, Gert Lanckriet

Kat Ellis, Applied Scientist at Amazon Music Machine Learning