Fine-Tuning un Cross-Encoder

Nicola Procopio
4 min readJan 3, 2024

--

Cos’è un cross-encoder?

Un Cross-Encoder è un classificatore di coppie di frasi. Richiede in input una coppia di testi e ne classifica la similarità mediante un indice tra 0 e 1. A differenza dei Bi-Encoder non calcola gli embeddings della frase.

Differenza tra bi-encoder e cross-encoder

Un Bi-Encoder prende in input una frase alla volta e ne calcola l’embedding. Per capire la similarità tra frasi su questi embeddings va calcolata una misura di similarità come la similarità del coseno o il prodotto scalare.

Un Cross-Encoder prende in input una coppia di frasi simultaneamente, non ne calcola gli embeddings, ma le classifica secondo un indice di similarità compreso tra 0 e 1.

N.B. se si vuole applicare il prodotto scalare assicurarsi che gli embeddings siano normalizzati

Training di un Cross-Encoder usando SBERT

Di seguito un esempio della creazione di un Cross-Encoder per l’italiano con la libreria sentence-transformers. Come per i Bi-Encoder ci sono diversi metodi per farlo a seconda del dataset scelto, qui verrà usato sempre un dataset formato Semantic Textual Similarity (STS) benchmark.

!pip install -U sentence-transformers
!pip install datasets

Dataset

A seguito dello scaricamento del dataset da Huggingface Hub bisogna preparare le coppie di frasi per il modello.

Il processo è molto simile a quello del Bi-Encoder ma qui, nel dataset di train le coppie vanno inserite in entrambe le combinazione quindi:

  • ([frase1, frase2], sim)
  • ([frase2, frase1], sim)

perchè lo score deve essere simmetrico. Per il dataset di valutazione e/o test non c’è bisogno.

Per maggiori info sulle altre tecniche consultare gli esempi.


from datasets import load_dataset
from sentence_transformers import InputExample
from torch.utils.data import DataLoader

dataset_train = load_dataset("stsb_multi_mt", name="it", split="train")
dataset_test = load_dataset("stsb_multi_mt", name="it", split="test")

gold_samples = []
batch_size = 16
for df in dataset_train:
score = float(df['similarity_score'])/5.0
gold_samples.append(InputExample(texts=[df['sentence1'], df['sentence2']], label=score))
gold_samples.append(InputExample(texts=[df['sentence2'], df['sentence1']], label=score))

train_dataloader = DataLoader(gold_samples, shuffle=True, batch_size=batch_size)


from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
import math

evaluator = CECorrelationEvaluator([ [x['sentence1'], x['sentence2']] for x in dataset_test], [x/5.0 for x in dataset_test['similarity_score']])

Train del modello

Partendo sempre da un BERT per l’italiano creiamo il nostro cross-encoder inizializzando l’head mediante il numero di etichette da predire.

Nel nostro caso, modello STS abbiamo una solo etichetta, ovvero lo score. Se avessimo avuto un dataset formato NLI (“contradiction”, “entailment”, “neutral”) avremmo inserito il numero delle etichette presenti.

Scelte il numero di epoche di addestramento e i warmup steps possiamo iniziare l’addestramento.

from sentence_transformers.cross_encoder import CrossEncoder

model_checkpoint = "dbmdz/bert-base-italian-uncased"
cross_encoder = CrossEncoder(model_checkpoint, num_labels=1)

num_epochs = 4
evaluation_steps = 500

warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

cross_encoder.fit(train_dataloader=train_dataloader,
evaluator=evaluator,
epochs=num_epochs,
evaluation_steps=evaluation_steps,
warmup_steps=warmup_steps,
save_best_model=True,
output_path = "cross-encoder-italian-bert-stsb/")

Evaluation

Solitamente i Cross-Encoder hanno prestazioni migliori sull’inferenza rispetto ai Bi-Encoder, in questo caso non ci siamo focalizzati sulla metrica infatti il circa 81% di accuratezza non è altissimo per questo modello.

evaluator(cross_encoder)

A cosa servono i Cross-Encoder?

I cross-encoder possono essere usati ogni volta che bisogna classificare coppie di frasi, come detto hanno prestazioni migliori dei Bi-Encoder ma scalano male.

I Bi-Encoder a loro volta possono essere usati per tutti quei task come la ricerca, il clustering, … comunque quando abbiamo tante frasi da confrontare.

Per dare una misura della poca scalabilità dei cross-encoder si pensi che fare clustering di 10000 frasi potrebbe richiedere il calcolo di 50 milioni combinazioni, circa 65 ore. Lo stesso task il bi-encoder lo risolve in pochi secondi.

Retrieve & Re-Rank

I Cross-Encoder sono molto performanti se combinati con i Bi-Encoder, in particolare nella ricerca. Una particolare tecnica che li coinvolge è il retrieve & re-rank che in poche parole consiste in:

  • retrieve: applicare il Bi-Encoder su una vasta base di possibili risultati per fare una “scrematura”
  • re-rank: applicare il Cross-Encoder per ordinare un sottoinsieme di risultati e scartare quelli irrilevanti

Esempio, in un dataset di 1 milione di documenti viene applicato un bi-encoder che restituisce 100 risultati per la richiesta effettuata dall’utente, su questi 100 viene applicato il cross-encoder che restituirà i primi 10 in ordine di somiglianza.

In SBERT si possono trovare diversi cross-encoder pre-addestrati

Il notebook eseguibile è su kaggle.
Il cross-encoder creato potete scaricarlo dall’Hub di Hugging Face

--

--

Nicola Procopio

Data Scientist and Open Source enthusiast. After many years of indecision whether to open a medium profile or not, here I am. I write in English and Italian.