LSH-based sampling-bias correction for retrieval
Retrieval models are used to quickly generate candidates for more nuanced downstream processing in settings such as retrieval-augmented generation (RAG) and candidate generation for recommender systems, to name a few. They work by embedding documents and queries into the same n-dimensional space S.
Such models are normally trained with in-batch negatives: the positive documents of the other queries in the same mini-batch are used as negative documents for the query under consideration. Furthermore, it is commonplace to use a contrastive loss over all available in-batch negatives as the loss function. The code snippet below shows how one typically trains a retrieval model [see TensorFlow Recommenders for a reference implementation].
import torch
import torch.nn.functional as F

temperature = 0.05
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
query_model = QueryModel(...).to(device)
document_model = DocumentModel(...).to(device)
batch = {...}
# Embed queries and documents into the same space and L2-normalize.
query_embedding = F.normalize(query_model(batch), p=2.0, dim=-1)
document_embedding = F.normalize(document_model(batch), p=2.0, dim=-1)
batch_size = query_embedding.size(0)
# Scaled cosine similarities: row i holds the logits of query i against every in-batch document.
logits = query_embedding @ document_embedding.T / temperature
# The i-th document is the positive for the i-th query; all other in-batch documents act as negatives.
labels = torch.arange(0, batch_size, dtype=torch.int64, device=device)
loss = F.cross_entropy(logits, labels)

Since popular query-document pairs (frequently asked question-answer pairs, for example) appear often in the training data, the corresponding documents also appear often as negatives in the contrastive loss. However, the aim of the loss above is to approximate the full softmax over the complete document corpus, and in that ideal case every document appears as a negative the same number of times (exactly once per query) as any other document. This disparity between the actual loss and the ideal loss is known to hurt the performance of the learnt retrieval model.
Sampling-bias / log-Q correction
In the recommendation-systems literature there is a well-known method to correct this “bias.” This sampling-bias correction technique, also known as log-Q correction, is described in this Google paper. The basic idea is to obtain a good approximation of the softmax denominator (a.k.a. the partition function) using only the documents in the mini-batch:
q   = query embedding
d_k = k-th document embedding
d_0 = correct (positive) document embedding
T   = temperature

partition_function
  = sum over the document corpus k of exp(<q, d_k> / T)
  = exp(<q, d_0> / T) + sum over the negative documents k of exp(<q, d_k> / T) * (Pr[d_k in mini-batch] / Pr[d_k in mini-batch])
 ~= exp(<q, d_0> / T) + sum over the mini-batch negatives k of exp(<q, d_k> / T) / Pr[d_k in mini-batch]
  = exp(<q, d_0> / T) + sum over the mini-batch negatives k of exp(<q, d_k> / T - log(Pr[d_k in mini-batch]))

This is the derivation of the log-Q correction formula; the approximation step uses the simple first-moment identity that dividing each sampled term by its sampling probability yields an unbiased estimate of the sum over the full corpus. The updated training code looks like:
batch_size = query_embedding.size(0)
logits = query_embedding @ document_embedding.T / temperature
labels = torch.arange(0, batch_size, dtype=torch.int64, device=device)
# log(Pr[d_k in mini-batch]) for every in-batch document, broadcast to one row per query.
log_candidate_sampling_prob = batch['log_candidate_sampling_prob'].reshape(1, -1).repeat(batch_size, 1)
# The positive document (the diagonal) should not be corrected.
log_candidate_sampling_prob.fill_diagonal_(0)
loss = F.cross_entropy(logits - log_candidate_sampling_prob, labels)

This gives gains when the learnt embeddings are evaluated on the full corpus (using any standard ANN algorithm such as HNSW if the corpus is large). Let us now look at how one can estimate the log_candidate_sampling_prob column.
Estimating candidate sampling probability
The standard approach is to estimate the unigram probability of each document in the training corpus. If this probability is p and the mini-batch size is N, the probability that the document appears in a given mini-batch can be approximated by the formula 1 - (1 - p)^N.
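As a concrete illustration, here is a minimal sketch of this unigram-based estimate; the helper name and the assumption that the training data's document_id column is available as a plain Python list are hypothetical:

import math
from collections import Counter

def log_candidate_sampling_probs(document_ids, batch_size):
    """Approximate log Pr[document appears in a mini-batch] from unigram frequencies."""
    counts = Counter(document_ids)
    total = len(document_ids)
    log_probs = {}
    for doc_id, count in counts.items():
        p = count / total                  # unigram probability of the document
        q = 1.0 - (1.0 - p) ** batch_size  # probability it shows up in a batch of size N
        log_probs[doc_id] = math.log(q)
    return log_probs

The resulting values can then be joined onto the training examples as the log_candidate_sampling_prob column used in the corrected loss above.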
Another popular approach is to estimate log_candidate_sampling_prob in a streaming fashion using a count-min-sketch-style estimator keyed on the document_id column of the mini-batch: each hash bucket tracks a running estimate of the number of steps between consecutive appearances of the documents hashing to it, and the inverse of that interval approximates the per-batch sampling probability.
import torch
from torch import nn

class StreamingLogQCorrectionModule(nn.Module):
    def __init__(
        self,
        num_buckets: int,
        hash_offset: int,
        alpha: float,
        p_init: float,
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.hash_offset = hash_offset
        self.alpha = alpha
        # b[h]: running estimate of the number of steps between consecutive appearances of
        # bucket h, initialized to 1 / p_init so that the initial probability estimate is p_init.
        self.register_buffer('b', (1.0 / p_init) * torch.ones((num_buckets,), dtype=torch.float32))
        # a[h]: batch index at which bucket h was last seen.
        self.register_buffer('a', torch.zeros((num_buckets,), dtype=torch.long))

    def forward(self, document_ids: torch.LongTensor) -> torch.Tensor:
        # log Pr[document in mini-batch] ~= log(1 / estimated inter-arrival interval) = -log(b).
        h = self.hash_fn(document_ids.view(-1))
        return -self.b[h].log().reshape(*document_ids.shape)

    def hash_fn(self, document_ids: torch.LongTensor) -> torch.LongTensor:
        return (document_ids + self.hash_offset) % self.num_buckets

    def train_step(self, document_ids: torch.LongTensor, batch_idx: int) -> None:
        h = self.hash_fn(document_ids).unique()
        # Exponential moving average of the interval since each bucket was last seen.
        self.b[h] = (1 - self.alpha) * self.b[h] + self.alpha * (batch_idx - self.a[h]).float()
        self.a[h] = batch_idx

One typically cascades a bunch of these estimators like so:
from typing import Tuple

class CascadedStreamingLogQCorrectionModule(nn.Module):
    def __init__(
        self,
        num_buckets: int,
        hash_offsets: Tuple[int, ...],
        alpha: float,
        p_init: float,
    ):
        super().__init__()
        self.models = nn.ModuleList([
            StreamingLogQCorrectionModule(num_buckets, offset, alpha, p_init)
            for offset in hash_offsets
        ])

    def forward(self, document_ids: torch.LongTensor) -> torch.Tensor:
        # Hash collisions inflate each estimator's probability estimate, so, as in a
        # count-min sketch, keep the minimum (least inflated) estimate across the hashes.
        result = torch.empty((0,), device=document_ids.device)
        for i, mod in enumerate(self.models):
            if i == 0:
                result = mod(document_ids)
            else:
                result = torch.minimum(result, mod(document_ids))
        return result

    def train_step(self, document_ids: torch.LongTensor, batch_idx: int) -> None:
        for mod in self.models:
            mod.train_step(document_ids, batch_idx)

During training we feed the document_id column and the batch index to train_step to update the running statistics, and we read back the log of the candidate sampling probability by calling the cascaded estimator on the same document_id column.
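Putting the pieces together, a minimal training-loop sketch might look like the following; the hyperparameter values are illustrative, and dataloader, batch['document_id'] (assumed to be a LongTensor of in-batch document ids) and the loss computation are placeholders tied to the earlier snippets:

logq = CascadedStreamingLogQCorrectionModule(
    num_buckets=1_000_000,
    hash_offsets=(0, 17, 59),
    alpha=0.01,
    p_init=1e-4,
)

for batch_idx, batch in enumerate(dataloader):
    document_ids = batch['document_id']
    # Update the streaming frequency statistics first, then query the estimator.
    logq.train_step(document_ids, batch_idx)
    batch['log_candidate_sampling_prob'] = logq(document_ids)
    # ... compute query_embedding / document_embedding and the corrected loss as above ...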
Issues with log-Q correction
What happens when the document corpus is very large? What if the document embeddings are not free parameters (i.e. not learnt latent embeddings)? What if the document embeddings are the output of a content model such as an LLM's hidden state or a “sentence-transformer”? In these settings every individual document is rare, so the log-Q correction values take a very large negative, approximately constant value for all documents (for example, with a corpus of 10^8 documents and a batch size of 1024, the per-batch sampling probability is roughly 10^-5 and the correction is roughly -11.5 for essentially every document). This hurts the approximation so badly that a model trained with no log-Q correction tends to perform better than one trained with it.
So how do we fix this? We exploit the fact that in such settings the document embeddings are constrained to live in the n-dimensional space S (typically on the n-dimensional unit sphere). There are only so many “approximately unique” embeddings on the n-dimensional unit sphere, especially for reasonably small n. So, as far as the contrastive loss is concerned, the regular document-id based formula grossly over-estimates the log-Q correction to the partition function. We therefore propose to estimate the log-Q correction from the position of the document embedding on the unit sphere instead of from the document id.
Document embedding based log-Q correction
We use the same CascadedStreamingLogQCorrectionModule to estimate the log-Q correction factor; however, instead of feeding it the document_id column, we feed it a Locality Sensitive Hash (LSH) of the document embedding.
LSH codes (at different resolutions) for points on the n-dimensional unit sphere can be computed with the following module, using appropriate values of n_proj and num_bins:
class LocalitySensitiveHashingModule(nn.Module):
    def __init__(self, emb_dim: int, n_proj: int, num_bins: int):
        super().__init__()
        # Random unit-norm projection directions.
        self.register_buffer(
            'projection_mat',
            F.normalize(torch.randn((emb_dim, n_proj)), p=2.0, dim=0),
            persistent=True,
        )
        # Bin boundaries covering [-1, 1], shifted by half a bin width.
        resolution = 2.0 / num_bins
        self.register_buffer(
            'grid',
            torch.linspace(-1, 1, num_bins + 1)[:-1] + 0.5 * resolution,
            persistent=True,
        )
        self.num_bins = num_bins

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project the normalized embeddings onto the random directions and quantize into bins.
        z = F.normalize(x, p=2.0, dim=-1) @ self.projection_mat
        z = torch.bucketize(z, self.grid).long()
        # Combine the per-projection bin indices into a single integer hash code.
        result = torch.empty((0,), device=x.device, dtype=torch.long)
        for i, t in enumerate(z.unbind(dim=-1)):
            if i == 0:
                result = t
            else:
                result = result * (1 + self.num_bins) + t
        return result

This approach may yield even better ANN evaluation results on the whole corpus than the standard document_id based log-Q correction. It is not yet clear how the training dynamics evolve when the document model is trained jointly with the query model; in such cases the memory parameter alpha in CascadedStreamingLogQCorrectionModule may need to be tuned carefully.
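For completeness, here is a minimal sketch of how the LSH module can be plugged in front of the cascaded estimator; the hyperparameter values are illustrative, emb_dim must match the document embedding size, and dataloader / document_model are the placeholders from the earlier snippets:

lsh = LocalitySensitiveHashingModule(emb_dim=128, n_proj=8, num_bins=16)
logq = CascadedStreamingLogQCorrectionModule(
    num_buckets=1_000_000,
    hash_offsets=(0, 17, 59),
    alpha=0.01,
    p_init=1e-4,
)

for batch_idx, batch in enumerate(dataloader):
    document_embedding = F.normalize(document_model(batch), p=2.0, dim=-1)
    # Hash the embedding positions instead of the document ids (detach: the hash is non-differentiable).
    lsh_codes = lsh(document_embedding.detach())
    logq.train_step(lsh_codes, batch_idx)
    batch['log_candidate_sampling_prob'] = logq(lsh_codes)
    # The log-Q corrected contrastive loss is computed exactly as before.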
