Fine-Tuning Sentence Transformer Models: A Case Study
Sentence Transformers are a type of Natural Language Processing (NLP) model that generates sentence embeddings. Sentence embedding techniques encode sentences into a fixed-size, dense vector space such that semantically similar sentences end up close to each other. These embeddings can be used for a variety of tasks, such as semantic similarity, information retrieval, text classification, and question answering.
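As a quick illustration, here is a minimal sketch of how a pre-trained Sentence Transformer encodes sentences and compares them by cosine similarity (the 'all-MiniLM-L6-v2' checkpoint is just an assumed example; any Sentence Transformer model works the same way):

from sentence_transformers import SentenceTransformer, util

# Load a pre-trained checkpoint (assumed example checkpoint)
model = SentenceTransformer('all-MiniLM-L6-v2')

sentences = ['A man is eating food.', 'A man is eating a piece of bread.']
embeddings = model.encode(sentences, convert_to_tensor=True)

# Semantically similar sentences get a high cosine similarity
print(util.cos_sim(embeddings[0], embeddings[1]))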
Traditionally, sentence transformers are trained with supervised learning, which requires a large amount of labeled data to learn good sentence embeddings. However, for most tasks and domains, creating labeled data is expensive, whereas unsupervised training approaches do not require labeled data at all.
In this blog we first go through two popular unsupervised learning methods, TSDAE (Using Transformer-based Sequential Denoising Auto-Encoder for Unsupervised Sentence Embedding Learning) and GPL (Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval), for fine-tuning Sentence Transformer models, and then implement both approaches in Python.
TSDAE (Transformer-based Sequential Denoising Auto-Encoder)
TSDAE (Transformer-based Sequential Denoising Auto-Encoder) is one of the most popular unsupervised training methods. The main idea is to reconstruct the original sentence from a corrupted version of it.
The TSDAE model consists of two parts: an encoder and a decoder. During training, TSDAE encodes corrupted sentences into fixed-sized vectors and requires the decoder to reconstruct the original sentences from this sentence embedding. For good reconstruction quality, the semantics must be captured well in the sentence embedding from the encoder.
According to the TSDAE paper, the best-performing noise is deletion-only, with a deletion ratio of 0.6, i.e. roughly 60% of the tokens in each sentence are removed before encoding.
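To make the noise step concrete, here is a minimal sketch of deletion noise, assuming a simple whitespace tokenizer; the DenoisingAutoEncoderDataset used in the code below implements the same idea with proper tokenization:

import random

def delete_noise(sentence, del_ratio=0.6):
    # Keep each token with probability (1 - del_ratio); assumes whitespace tokenization
    words = sentence.split()
    kept = [w for w in words if random.random() > del_ratio]
    if words and not kept:
        kept = [random.choice(words)]  # make sure at least one token survives
    return ' '.join(kept)

print(delete_noise('the encoder must capture the sentence semantics to reconstruct it'))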
Results on STS data (see the TSDAE paper for the full comparison).
Python Code -
from sentence_transformers import SentenceTransformer, models
from sentence_transformers.datasets import DenoisingAutoEncoderDataset
from sentence_transformers.losses import DenoisingAutoEncoderLoss
import torch
from torch.utils.data import DataLoader
import nltk
import re

nltk.download('punkt')  # the deletion noise in DenoisingAutoEncoderDataset relies on nltk's word tokenizer

data = ['TSDAE (Tranformer and Sequential Denoising AutoEncoder) is one of most popular unsupervised training method. The main idea is reconstruct the original sentence from a corrupted sentence. The TSDAE model consists of two parts: an encoder and a decoder.',
        ' TSDAE encodes corrupted sentences into fixed-sized vectors and requires the decoder to reconstruct the original sentences from this sentence embedding.'
        ]

sentences = []

def generate_sentence(jd):
    # split a document into sentences and keep only the reasonably long ones
    splitter = re.compile(r'\.\s?\n?')
    list_of_sentences = splitter.split(jd)
    if len(sentences) < 100_000:
        sentences.extend([i for i in list_of_sentences if len(i) > 30])

[generate_sentence(sent) for sent in data]
print('number of sentences', len(sentences))

def clean_sentence(text):
    # lowercase and strip characters outside a small whitelist
    text = text.lower()
    text = re.sub(r"[^ A-Za-z0-9.&,\-]", " ", text)
    text = re.sub(' +', ' ', text)
    return text

sentences = [clean_sentence(i) for i in sentences]

# The DenoisingAutoEncoderDataset returns InputExamples in the format: texts=[noise_fn(sentence), sentence],
# i.e. it adds deletion noise to the training data on the fly
train_data = DenoisingAutoEncoderDataset(sentences)
loader = DataLoader(train_data, batch_size=4, shuffle=True)

# Encoder: pre-trained transformer with CLS pooling on top
gte_model = models.Transformer('thenlper/gte-base')
pooling = models.Pooling(gte_model.get_word_embedding_dimension(), 'cls')
model = SentenceTransformer(modules=[gte_model, pooling])

# Decoder weights are tied to the encoder
loss = DenoisingAutoEncoderLoss(model, tie_encoder_decoder=True)

model.fit(
    train_objectives=[(loader, loss)],
    epochs=1,
    weight_decay=0,
    scheduler='constantlr',
    optimizer_params={'lr': 3e-5},
    show_progress_bar=True
)
model.save('output/gte-base-fine-tune')
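Once training finishes, the fine-tuned encoder can be loaded back like any other Sentence Transformer model and used to embed new text, for example:

from sentence_transformers import SentenceTransformer

# Load the model saved above
fine_tuned = SentenceTransformer('output/gte-base-fine-tune')
embeddings = fine_tuned.encode(['an example sentence to embed'])
print(embeddings.shape)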
GPL (Generative Pseudo Labeling)
GPL (Generative Pseudo Labeling) is an unsupervised domain adaptation method for training dense retrieval models.
GPL works in three steps:
- Query generation: For each passage in the unlabeled target corpus, queries are generated with a query generation model. In the GPL paper, a T5 encoder-decoder model is used to generate three queries per passage (see the sketch after this list).
- Negative mining: For each generated query, 50 negative passages are mined from the target corpus using dense retrieval.
- Pseudo labeling: For each (query, positive passage, negative passage) tuple we compute the margin δ = CE(Q, P+) − CE(Q, P−), where CE is the score predicted by a cross-encoder, Q is the query, P+ the positive passage, and P− the negative passage.
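The following is a minimal sketch of the query-generation and pseudo-labeling steps, using the same generator and cross-encoder checkpoints that are passed to gpl.train below. The negative-mining step is skipped here; a hand-picked negative passage is used purely for illustration, and gpl.train does all of this internally.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import CrossEncoder

passage = "TSDAE encodes corrupted sentences into fixed-sized vectors."
negative_passage = "GPL is an unsupervised domain adaptation method for dense retrieval."

# Step 1: generate a few queries for the passage with the T5-based query generator
tokenizer = AutoTokenizer.from_pretrained("BeIR/query-gen-msmarco-t5-base-v1")
generator = AutoModelForSeq2SeqLM.from_pretrained("BeIR/query-gen-msmarco-t5-base-v1")
inputs = tokenizer(passage, return_tensors="pt")
outputs = generator.generate(**inputs, max_length=64, do_sample=True, top_p=0.95, num_return_sequences=3)
queries = [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

# Step 3: pseudo label each (query, positive, negative) tuple with a cross-encoder:
# margin = CE(Q, P+) - CE(Q, P-)
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
for query in queries:
    pos_score, neg_score = cross_encoder.predict([(query, passage), (query, negative_passage)])
    print(query, pos_score - neg_score)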
The MarginMSE loss function is then used to train the dense retrieval model: it minimizes the mean squared error between the student model's dot-product margin and the cross-encoder margin δ.
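For reference, this is roughly how such pseudo-labeled tuples feed into MarginMSELoss from sentence-transformers; the query, passages, and margin value below are hand-written placeholders for illustration, and gpl.train (shown next) handles all of this internally:

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("thenlper/gte-base")

# texts = [query, positive passage, negative passage]; label = cross-encoder margin
train_examples = [
    InputExample(
        texts=["what does tsdae encode",
               "TSDAE encodes corrupted sentences into fixed-sized vectors.",
               "GPL is an unsupervised domain adaptation method for dense retrieval."],
        label=7.3,  # illustrative margin value
    )
]
train_loader = DataLoader(train_examples, batch_size=1, shuffle=True)
train_loss = losses.MarginMSELoss(model)

model.fit(train_objectives=[(train_loader, train_loss)], epochs=1, show_progress_bar=True)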
Python Code -
# https://github.com/UKPLab/gpl/tree/main
!pip install jsonlines
!pip install gpl
import gpl
import jsonlines
data = ['TSDAE (Tranformer and Sequential Denoising AutoEncoder) is one of most popular unsupervised training method. The main idea is reconstruct the original sentence from a corrupted sentence. The TSDAE model consists of two parts: an encoder and a decoder.',
' TSDAE encodes corrupted sentences into fixed-sized vectors and requires the decoder to reconstruct the original sentences from this sentence embedding.',
...
]
def prepare_data_for_gpl(data):
    # GPL expects a BeIR-style corpus.jsonl with _id, title, text and metadata fields
    gpl_data = []
    counter = 1
    for i in data:
        gpl_data.append({
            "_id": str(counter),
            "title": "",
            "text": i,
            "metadata": {}
        })
        counter += 1
    return gpl_data

gpl_data = prepare_data_for_gpl(data)

import os
os.makedirs('/content/jd-data', exist_ok=True)  # make sure the corpus directory exists
with jsonlines.open('/content/jd-data/corpus.jsonl', 'w') as writer:
    writer.write_all(gpl_data)
dataset = 'jd-data'
gpl.train(
    path_to_generated_data=f"generated/{dataset}",
    base_ckpt="thenlper/gte-base",  # the starting checkpoint to be domain-adapted
    gpl_score_function="dot",  # GPL uses MarginMSE loss, which works with dot-product scores
    batch_size_gpl=4,
    gpl_steps=3,  # kept very small for this demo
    new_size=-1,  # resize the corpus; -1 means take all data
    queries_per_passage=3,  # number of queries generated per passage (QPP)
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    retrievers=["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    retriever_score_functions=["cos_sim", "cos_sim"],  # these two retriever models work with cosine similarity
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    # This prefix appears in the (folder/file) names of the query-generation results,
    # e.g. "qgen-qrels/" and "qgen-queries.jsonl" by default.
    do_evaluation=False,
    # use_amp=True  # enable efficient float16 mixed precision
)
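After training, gpl.train writes the domain-adapted model under the configured output directory, so it can be loaded like any Sentence Transformer checkpoint. The sketch below assumes the final checkpoint lands directly under output_dir; depending on the gpl version it may instead be written to a step-numbered subfolder.

from sentence_transformers import SentenceTransformer

# Load the domain-adapted model written by gpl.train (path as configured above)
adapted_model = SentenceTransformer('output/jd-data')
embeddings = adapted_model.encode(['an example query about the target domain'])
print(embeddings.shape)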
Check out the GitHub repo for the complete code.