Transformer-based embedding models like Sentence Transformers and the recent SFR-Embedding-Mistral have emerged as powerful tools for producing dense vector representations of text. While these pre-trained models work well out-of-the-box, there are often scenarios where further fine-tuning on a domain-specific corpus can boost performance.
In this blog post, we’ll explore how to fine-tune black-box embedding models using low-rank adaptation (LoRA) with the LlamaIndex library. LoRA is a technique that trains a small number of rank-decomposed weights to adapt a pre-trained model to a new task or domain. Compared to full model fine-tuning, LoRA is much more compute and memory efficient (you can often run it on a single consumer-grade GPU) while achieving comparable performance gains.
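To make that concrete before we dive into LlamaIndex, here is a minimal, hypothetical sketch of attaching LoRA adapters to an embedding model with Hugging Face's peft library. The linked notebook additionally quantizes the base model first; the target modules and ranks below are purely illustrative, not a recommendation.

# Minimal sketch of attaching LoRA adapters with Hugging Face's peft library.
# Hyperparameters and target_modules are illustrative only.
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

base_model = AutoModel.from_pretrained("Salesforce/SFR-Embedding-Mistral")
lora_config = LoraConfig(
    r=8,                                  # rank of the low-rank decomposition
    lora_alpha=16,                        # scaling factor
    target_modules=["q_proj", "v_proj"],  # which attention projections to adapt
    lora_dropout=0.05,
)
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()  # only the LoRA weights are trainable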
LlamaIndex comes with an API in the shape of the BaseAdapter and EmbeddingAdapterFinetuneEngine classes, which allow you to fine-tune any embedding model. These are severely limited, however, as they only allow you to put a new trainable layer (in general, any nn.Module) on top of your pre-trained embedding model. As LoRA typically adds adapters to every self-attention block in the model, this setup will simply not work for us. At the same time, we don’t want to reinvent the wheel, so we’d rather plug into LlamaIndex’s fine-tuning machinery somehow to make it do our bidding 🙃

We start by subclassing BaseAdapter:
class UniversalAdapter(torch.nn.Identity, BaseAdapter):
    """Adapter model that does nothing, but includes trainable parameters
    (e.g. LoRAs) of the embedding model, which the FinetuneEngine actually
    trains."""

    def __init__(self, embed_model):
        super().__init__()
        self.embed_model = embed_model

    def save(self, output_path):
        self.embed_model.save_pretrained(output_path, save_adapter=True, save_config=True)
As per the docstring, this adapter, when added on top of the embedding model, simply passes all inputs through because it inherits from nn.Identity. However, by storing embed_model inside the adapter class, all the parameters of embed_model get registered as parameters of our UniversalAdapter, courtesy of some PyTorch magic.
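To see that magic in action, here is a small illustrative check (a sketch; peft_model stands for whatever LoRA-wrapped model you pass in, such as the hypothetical one from the earlier snippet). Because PyTorch registers submodules assigned as attributes, named_parameters() on the adapter walks into embed_model and exposes its trainable LoRA weights to whatever optimizer the fine-tune engine builds.

# Illustrative sketch: the adapter "inherits" the wrapped model's parameters.
# peft_model is a hypothetical LoRA-wrapped embedding model (see earlier sketch).
adapter = UniversalAdapter(peft_model)
trainable = [name for name, p in adapter.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable tensors, e.g. {trainable[:2]}")  # the LoRA A/B matrices show up here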
Next up is the EmbeddingAdapterFinetuneEngine class. The current implementation is almost OK, with one small caveat: it strips the embeddings of any gradient information. We can’t have that, so we subclass again:
class UniversalEmbeddingFinetuneEngine(EmbeddingAdapterFinetuneEngine):
    """Finetune any parameters of embed_model with requires_grad set to True,
    e.g. LoRA adapters."""

    def __init__(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        embed_model: BaseEmbedding,
        batch_size: int = 10,
        epochs: int = 1,
        dim: Optional[int] = None,
        device: Optional[str] = None,
        model_output_path: str = "model_output",
        model_checkpoint_path: Optional[str] = None,
        checkpoint_save_steps: int = 100,
        verbose: bool = False,
        bias: bool = False,
        **train_kwargs: Any,
    ) -> None:
        super().__init__(
            dataset=dataset,
            embed_model=embed_model,
            batch_size=batch_size,
            epochs=epochs,
            adapter_model=UniversalAdapter(embed_model._model),
            dim=dim,
            device=device,
            model_output_path=model_output_path,
            model_checkpoint_path=model_checkpoint_path,
            checkpoint_save_steps=checkpoint_save_steps,
            verbose=verbose,
            bias=bias,
            **train_kwargs,
        )

    def smart_batching_collate(self, batch: List) -> Tuple[Any, Any]:
        """Smart batching collate."""
        import torch
        from torch import Tensor

        query_embeddings: List[Tensor] = []
        text_embeddings: List[Tensor] = []

        for query, text in batch:
            query_embedding = self.embed_model.get_query_embedding(query)
            text_embedding = self.embed_model.get_text_embedding(text)

            # was stripping gradients: query_embeddings.append(torch.tensor(query_embedding))
            query_embeddings.append(query_embedding)
            # was stripping gradients: text_embeddings.append(torch.tensor(text_embedding))
            text_embeddings.append(text_embedding)

        query_embeddings_t = torch.stack(query_embeddings)
        text_embeddings_t = torch.stack(text_embeddings)

        return query_embeddings_t, text_embeddings_t
The only changes with respect to the EmbeddingAdapterFinetuneEngine class are: (1) in the constructor, which doesn’t accept adapter_model as an argument anymore and uses UniversalAdapter instead, and (2) in the smart_batching_collate function, to keep them gradients alive!
We’re almost there. The last misbehaving class is HuggingFaceEmbedding, which once again strips the embeddings of gradient information by converting them from tensors to lists, effectively detaching them from the computational graph. We hence subclass once more, with very minimal changes:
class HuggingFaceEmbeddingWithGrad(HuggingFaceEmbedding):
    """HuggingFaceEmbedding with gradient support."""

    def __getattr__(self, name: str) -> Any:
        return getattr(self._model, name)

    def _embed(self, sentences: List[str]) -> torch.Tensor:
        """Embed sentences."""
        encoded_input = self._tokenizer(
            sentences,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # pop token_type_ids
        encoded_input.pop("token_type_ids", None)

        # move tokenizer inputs to device
        encoded_input = {
            key: val.to(self._device) for key, val in encoded_input.items()
        }

        model_output = self._model(**encoded_input)
        context_layer: "torch.Tensor" = model_output[0]

        if self.pooling == Pooling.CLS:
            embeddings = self.pooling.cls_pooling(context_layer)
        elif self.pooling == Pooling.LAST:
            embeddings = self.pooling.last_pooling(context_layer)
        else:
            embeddings = self._mean_pooling(
                token_embeddings=context_layer,
                attention_mask=encoded_input["attention_mask"],
            )

        if self.normalize:
            import torch

            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings  # was embeddings.tolist()
The final piece of the puzzle, which by the way wasn’t even required until some two weeks ago, is pydantic validation. Because we have changed the return type of the _embed method above, pydantic is gonna give us grief. In order to avoid it, we’re going to temporarily disable all pydantic validations with the below context manager:
# pydantic_fields below refers to pydantic's v1-style fields module,
# e.g. imported as: from pydantic import fields as pydantic_fields
class disable_pydantic:
    """Context manager to disable pydantic validation."""

    def __enter__(self) -> None:
        self.validate = pydantic_fields.ModelField.validate
        pydantic_fields.ModelField.validate = lambda *args, **kwargs: (args[1], None)

    def __exit__(self, *args) -> None:
        pydantic_fields.ModelField.validate = self.validate
All that’s left now is to fine-tune:
hf_qlora_model = HuggingFaceEmbeddingWithGrad(
    model=peft_model,
    tokenizer=embed_tokenizer,
    query_instruction=query_instruction,
    pooling=pooling,
    embed_batch_size=1,
)

finetune_engine = UniversalEmbeddingFinetuneEngine(
    train_dataset,
    embed_model=hf_qlora_model,
    dim=4096,
    model_output_path=lora_adapters_path,
    epochs=5,
    verbose=False,
)

with disable_pydantic():
    finetune_engine.finetune()
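When training finishes, UniversalAdapter.save writes the trained LoRA weights to model_output_path (here lora_adapters_path) via save_pretrained. Loading them back for inference could then look roughly like the sketch below, assuming base_model is the same (possibly quantized) base model and reusing the tokenizer, instruction and pooling settings from above:

# Sketch: re-attach the trained LoRA adapters to the base model for inference.
from peft import PeftModel

finetuned_model = PeftModel.from_pretrained(base_model, lora_adapters_path)
finetuned_embed_model = HuggingFaceEmbedding(
    model=finetuned_model,
    tokenizer=embed_tokenizer,
    query_instruction=query_instruction,
    pooling=pooling,
)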
A complete Jupyter notebook based on an example from the LlamaIndex repo can be found at https://github.com/marib00/llamaindex-embedding-lora. This notebook takes you through the whole process, starting from quantizing SFR-Embedding-Mistral, which is currently at the top of the Massive Text Embedding Benchmark (MTEB) leaderboard, to fine-tuning it with QLoRA on a synthetic dataset generated using Mixtral-8x7B-v0.1-GPTQ, and evaluating it against the base model.
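If you just want the gist of the dataset step, generating the synthetic question-document pairs looks roughly like the sketch below (the file path and the mixtral_llm object are placeholders, and import paths vary between LlamaIndex versions):

# Sketch: building an EmbeddingQAFinetuneDataset from your own documents.
from llama_index import SimpleDirectoryReader
from llama_index.node_parser import SentenceSplitter
from llama_index.finetuning import generate_qa_embedding_pairs

docs = SimpleDirectoryReader(input_files=["./data/my_corpus.pdf"]).load_data()
nodes = SentenceSplitter(chunk_size=512).get_nodes_from_documents(docs)
train_dataset = generate_qa_embedding_pairs(nodes, llm=mixtral_llm)  # mixtral_llm: your local LLM wrapper
train_dataset.save_json("train_dataset.json")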
Happy fine-tuning!