LSH as a feature extractor

Dinesh Ramasamy
3 min read · Aug 9, 2023


Locality Sensitive Hashing (LSH) is typically used for approximate nearest neighbor (ANN) search over vectors. The properties that make LSH useful for vector search can also be exploited in neural network models that take vectors as input (for example, content signals such as audio, video, and text embeddings in various settings). This is especially relevant in LLM settings where one wants to use embeddings from another modality as “soft prompts” for the LLM.

The manifolds that feed into a model in a given domain are typically complicated (far from i.i.d.). This complexity makes it very hard to disentangle them with multi-layer perceptrons, which are compute-intensive operations. A classic trick for learning complex mappings is to memorize outcomes rather than learn functions. How do we memorize vector mappings? The obvious answer is to use embedding tables. However, embedding lookups require discrete keys, and vectors are continuous. So how do we use embeddings for vector inputs? Again the answer seems obvious: hash the vectors, with the caveat that nearby points must remain “nearby” after hashing. This is exactly what LSH provides. So we propose to use embedding tables stacked on top of LSH ops as shallow feature extractors.
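To make the “nearby points stay nearby after hashing” intuition concrete, here is a minimal sketch (separate from the model below, names and values are illustrative) using classic random-hyperplane LSH: the sign pattern of a few random projections serves as a discrete key, and a small perturbation of a vector usually leaves most of that pattern unchanged.

import torch

torch.manual_seed(0)
inp_dim, n_proj = 32, 16
planes = torch.randn(inp_dim, n_proj)          # random hyperplanes

def lsh_code(v: torch.Tensor) -> torch.Tensor:
    # one bit per projection: which side of each hyperplane the vector falls on
    return (v @ planes > 0).long()

v = torch.randn(inp_dim)
v_near = v + 0.05 * torch.randn(inp_dim)       # small perturbation of v
v_far = torch.randn(inp_dim)                   # unrelated vector

print((lsh_code(v) == lsh_code(v_near)).float().mean())  # close to 1.0
print((lsh_code(v) == lsh_code(v_far)).float().mean())   # around 0.5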

The choice of LSH algorithm and the manner in which LSH buckets are converted into embeddings matter a great deal. We present a feature extractor that is only direction-aware (it ignores vector magnitudes, though this is easy to fix) and is based on the following simple LSH scheme:

import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineVectorEmbedding(nn.Module):
    """
    LSH-based vector indexer for highly non-linear ops.
    """

    def __init__(self, inp_dim: int, emb_dim: int, n_proj: int = 16, num_bins: int = 20):
        super().__init__()
        # Fixed (non-learned) unit-norm random projection directions
        self.register_buffer(
            'projection_mat',
            F.normalize(torch.randn((inp_dim, n_proj)), p=2.0, dim=0),
            persistent=True,
        )
        resolution = 2.0 / num_bins
        # Bucket boundaries covering the cosine-similarity range [-1, 1]
        self.register_buffer(
            'grid',
            torch.linspace(-1, 1, num_bins + 1)[:-1] + 0.5 * resolution,
            persistent=True,
        )
        # Offset so each projection indexes its own slice of the embedding table
        self.register_buffer(
            'pos_offset',
            ((num_bins + 1) * torch.arange(0, n_proj, dtype=torch.long)).reshape(-1, 1, 1),
            persistent=True,
        )
        # One embedding row per (projection, bucket) pair, summed across projections
        self.emb = nn.EmbeddingBag(
            (num_bins + 1) * n_proj,
            emb_dim,
            mode='sum',
        )
        self.emb_dim = emb_dim
        self.n_proj = n_proj

    def forward(self, x):
        # x: (batch, seq_len, inp_dim)
        bs, seq_len, _ = x.size()
        # Cosine similarities to the random directions: (batch, seq_len, n_proj)
        z = F.normalize(x, p=2.0, dim=-1) @ self.projection_mat
        # Quantize each similarity into a bucket id, then shift ids per projection
        z = torch.bucketize(z, self.grid).transpose(0, -1)
        z = (z + self.pos_offset).transpose(0, -1).contiguous()
        # Sum the n_proj bucket embeddings for each input vector
        return self.emb(z.view(-1, self.n_proj)).reshape(bs, seq_len, self.emb_dim)
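As a quick sanity check, here is a hypothetical usage sketch (batch size, sequence length, and hyperparameters are illustrative, not from the experiments below):

lsh_emb = CosineVectorEmbedding(inp_dim=32, emb_dim=512, n_proj=32, num_bins=20)
x = torch.randn(4, 10, 32)       # (batch, seq_len, inp_dim)
y = lsh_emb(x)
print(y.shape)                   # torch.Size([4, 10, 512])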

To illustrate its effectiveness, we use it to train a RecSys LLM whose input content embeddings are 32-dimensional. We use a cascade of independent low-to-high resolution LSH embeddings (inp_dim = 32, emb_dim = 512, n_proj = 32, num_bins ∈ {1, 2, 4, 8, 12, 16, 20}) and add their outputs, as sketched below. We compare this with a simple linear projector (i.e., nn.Linear(32, 512)).
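A minimal sketch of that multi-resolution cascade, assuming the hyperparameters above (the wrapper name MultiResolutionLSHEmbedding is mine, not from the original post):

class MultiResolutionLSHEmbedding(nn.Module):
    """Sum of independent LSH embeddings at increasing bin resolutions."""

    def __init__(self, inp_dim: int = 32, emb_dim: int = 512, n_proj: int = 32,
                 resolutions=(1, 2, 4, 8, 12, 16, 20)):
        super().__init__()
        self.levels = nn.ModuleList([
            CosineVectorEmbedding(inp_dim, emb_dim, n_proj=n_proj, num_bins=nb)
            for nb in resolutions
        ])

    def forward(self, x):
        # x: (batch, seq_len, inp_dim) -> (batch, seq_len, emb_dim)
        return sum(level(x) for level in self.levels)


# Baseline for comparison: a plain linear projector
baseline = nn.Linear(32, 512)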

Figure: validation in-batch recall@50 as a function of the number of user sequences seen. Blue: LSH (CosineVectorEmbedding); orange: Linear.

Clearly, CosineVectorEmbedding is a far better feature extractor than a simple linear transform, though the linear transform is of course more parameter- and compute-efficient.
