Fast Fine-tuning of Text Classification with SetFit

tiendh
4 min read · May 9, 2023


Text classification is a classical supervised task in natural language processing (NLP). Popular use cases include sentiment analysis, document classification, intent detection, fake news detection, spam detection, and topic labelling, to name a few. Recent advances in deep learning, especially pre-trained language models (PLMs), and the popularity of libraries such as transformers from HuggingFace (HF) have facilitated the task immensely. One can easily pick a PLM and fine-tune it on a new dataset, and the fine-tuning can complete in minutes (depending on the size of the dataset and the hardware).

Yet the data is not always available or complete. In reality, for many text classification tasks in industry, labelled datasets simply do not exist. In other cases, the dataset available for fine-tuning is very small or incomplete. Getting a decent-sized dataset (e.g., 50k labelled examples, as in the imdb dataset) takes a lot of time and effort in manual labelling, which can be impossible for companies with limited resources.

To address this challenge, researchers have developed approaches that leverage PLMs with small labelled datasets, collectively known as few-shot learning. In the literature, there are many proposed methods, such as ADAPET and T-FEW. This article discusses SetFit, a few-shot learning method with proven performance for text classification.

Quick Review

SetFit leverages Sentence Transformers (Sentence-BERT, or SBERT). In the original paper, the method has two steps:

  • Fine-tune an SBERT model
  • Train a classifier head

SBERT is a model based on either siamese or triplet networks. In essence, it strives to produce embeddings of text such that the embeddings of two similar text sequences (e.g., ones with the same label) have a small cosine distance, while the embeddings of two dissimilar text sequences have a larger cosine distance.
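To see this property in action, the following minimal sketch (the example sentences are made up, and the exact similarity values will depend on the model) encodes three sentences and compares them with cosine similarity:

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
embeddings = model.encode([
    "i loved the spiderman movie!",      # similar to the next sentence
    "the spiderman film was fantastic",  # similar to the previous sentence
    "pineapple on pizza is the worst",   # unrelated topic
])

# Similar sentences should have a noticeably higher cosine similarity
print(util.cos_sim(embeddings[0], embeddings[1]))  # high
print(util.cos_sim(embeddings[0], embeddings[2]))  # lower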

In SetFit, to fine-tune the SBERT model on a limited-size dataset with C classes, it samples triplets as follows:

  • Sample an anchor text sᵃ
  • Sample a positive text sᵖ from the same class as sᵃ
  • Sample a negative text sⁿ from a different class
  • Repeat these steps R times (e.g., R = 20)

This way, even with a small original dataset, the resulting dataset is much larger. The triplet network takes a triplet (sᵃ, sᵖ, sⁿ) as input and produces three embeddings. The network parameters are then learned by optimising a special loss called the “triplet loss”, which pulls the anchor embedding closer to the positive embedding while pushing it away from the negative one.
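As an illustration of the sampling step, here is a minimal sketch of how such triplets could be drawn from a small labelled dataset. The function is illustrative rather than taken from the setfit library, and it assumes every class contains at least two examples:

import random

def sample_triplets(texts, labels, R=20):
    # Group example indices by class label
    by_class = {}
    for i, label in enumerate(labels):
        by_class.setdefault(label, []).append(i)

    triplets = []
    for _ in range(R):
        for i, label in enumerate(labels):
            # Positive: another example from the same class as the anchor
            positive = random.choice([j for j in by_class[label] if j != i])
            # Negative: an example from a different class
            other = random.choice([c for c in by_class if c != label])
            negative = random.choice(by_class[other])
            triplets.append((texts[i], texts[positive], texts[negative]))
    return triplets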

Figure 1. SetFit model (source: original paper)

After fine-tuning the network, text embeddings are generated. These embeddings are then used, together with the original labels, to train a classification model. Classical classification heads, like an SVM from sklearn, can be used.
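As a rough sketch of this second step, assuming placeholder lists train_texts and train_labels, the head can be trained with plain scikit-learn:

from sentence_transformers import SentenceTransformer
from sklearn.svm import LinearSVC

# In practice this would be the encoder fine-tuned in step 1;
# the base model is used here only as a stand-in
encoder = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")

# Embed the (few) labelled examples
X_train = encoder.encode(train_texts)  # shape: (n_examples, embedding_dim)

# Fit a classical classifier head on top of the embeddings
head = LinearSVC()
head.fit(X_train, train_labels)

# Inference: embed new texts, then classify
predictions = head.predict(encoder.encode(["a new sentence to classify"]))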

There are a couple of interesting points we might note:

  • The method is not end-to-end, since two separate learning steps are used. However, in the released implementation these steps are nicely integrated, which makes it very easy to use.
  • Similar to a siamese network, the components of the triplet network share weights, so any of them can be used to generate embeddings.

Implementation

Dataset: On the HF hub, there are many datasets for text classification. We will pick one dataset for binary text classification, namely cola (from glue). The dataset has 8.55k rows in the training set and 1.04k rows in the validation set. We will use a small subset of the training set to fine-tune a sentence transformer model named paraphrase-mpnet-base-v2 with SetFit. Following is the code, which is mainly based on the original article.

from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("glue", "cola")

# Sample N examples per class from the training set
N = 8
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=N)
eval_dataset = dataset["validation"]

# Load an SBERT model from the HF hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20,  # Number of pairs to generate per example for contrastive learning
    num_epochs=1,
    column_mapping={"sentence": "text", "label": "label"},  # Map dataset columns to the names expected by the trainer
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
print("evaluation result: ", metrics)

# Inference
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst"])

Following are the results with different values of N.

Table 1. SetFit performance in terms of accuracy on CoLA with different numbers of training examples.

In this simple experiment, RoBERTa-base and the SBERT model paraphrase-mpnet-base-v2 are also fine-tuned on the entire training set. It can be seen that the performance on the entire dataset is not far from that obtained with very small subsets. Clearly, when a large and complete dataset is available, there is no reason to use few-shot learning, since one can achieve very good performance with standard training/fine-tuning. But as mentioned, that is not the case in many real-life situations. In addition, fine-tuning on the small subsets is much faster than on the entire dataset. Hence, SetFit is a very promising method in cases of limited resources.
