Multi-label Text Classification using Transformers (BERT)

Prasad Nageshkar
Published in Analytics Vidhya
12 min read · Mar 12, 2021

Predicting Tags for a Question posted on Stack Exchange using a pre-trained BERT model from Hugging Face and PyTorch Lightning


Stack Exchange is a network of 176 communities that are created and run by experts and enthusiasts who are passionate about a specific topic. The website primarily serves as a platform for users to ask and answer questions.

All questions are tagged with their subject areas. Each question can have up to 5 tags, since a question might be related to several subjects. Tags make it easy to find questions in your area of interest. You can then learn from the answers given, or answer the questions that fall in your area of expertise. Users may enter the tags when posting a question; otherwise, Stack Exchange has to predict them from the question text.

Example of Tags associated with a question

“Predicting and associating the correct tags with a question is important. It ensures that the question gets the attention of everyone who can answer it, based on the tagged subject areas. This increases the chances of a faster response and thus drives more engagement.”

Objective:

Develop a machine learning model that will accurately predict all the tags (one or more) that could be associated with a question.

It is assumed that the reader has a reasonable background in Natural Language Processing (NLP) and some familiarity with PyTorch and Transformers in general, and BERT in particular. This post is the outcome of my effort to solve a multi-label text classification problem using Transformers. I hope it helps a few readers!

Approach:

The task of predicting ‘tags’ is a multi-label text classification problem: each question can belong to several classes at once, rather than to exactly one. There are multiple ways to solve this problem. Our solution leverages the power of a pre-trained Transformer (BERT) model and the PyTorch Lightning framework.

High-Level Steps:

  1. Install & Import Libraries
  2. Load and Pre-process the Data
  3. Prepare PyTorch Dataset & Lightning DataModule
  4. Define the Model (BERT based Classifier)
  5. Train the Model using Lightning Trainer
  6. Evaluate Performance of the Model
  7. Model Inference

1. Install & Import Libraries

The main libraries we need are:

a) Hugging Face Transformers, for the BERT model and tokenizer.
b) PyTorch, as the deep learning framework and for dataset preparation.
c) PyTorch Lightning, for model definition and training.
d) scikit-learn, for splitting the dataset and computing metrics.
e) BeautifulSoup, for removing HTML tags from the raw text in the given data.
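The post does not list the install commands. The sketch below checks that the libraries above are importable and prints a `pip install` line for any that are missing. The mapping from pip package names to import names is my assumption of the usual names.

```python
import importlib.util

# pip package name -> module name used in `import` statements (assumed usual names)
REQUIRED = {
    "transformers": "transformers",
    "torch": "torch",
    "pytorch-lightning": "pytorch_lightning",
    "scikit-learn": "sklearn",
    "beautifulsoup4": "bs4",
    "pandas": "pandas",
}

def check_packages(required):
    """Return {pip_name: True/False} depending on whether the module can be found."""
    return {pip: importlib.util.find_spec(mod) is not None
            for pip, mod in required.items()}

if __name__ == "__main__":
    missing = [pip for pip, ok in check_packages(REQUIRED).items() if not ok]
    if missing:
        print("pip install " + " ".join(missing))
    else:
        print("All required packages are importable.")
```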

2. Load & Pre-process the Data

The dataset is available as two files, Questions.csv and Tags.csv, at Kaggle StatsQuestion. Load them into separate pandas dataframes.

Questions Data
Tags Data

Questions dataframe: The ‘Body’ column contains the question text in HTML format. Apart from the Id column, it is the only column in the dataset that is useful for our task.

Tags dataframe: This contains the tags associated with each question, one tag per row, linked to the question by its distinct Id. Notice that Id=1 has 3 tags.

For further processing, we need to join these two dataframes on the Id column that links them. Before that, we clean the text in the Body column in three steps. First, we remove the HTML tags using BeautifulSoup. Next, we remove all non-alphabetic characters using a regex. Finally, we convert all text to lower case.

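import re
from bs4 import BeautifulSoup

def pre_process(text):
    # strip HTML tags (an explicit parser avoids BeautifulSoup's "no parser specified" warning)
    text = BeautifulSoup(text, "html.parser").get_text()
    # keep alphabetic characters only
    text = re.sub("[^a-zA-Z]", " ", text)
    # convert text to lower case
    text = text.lower()
    # split into tokens and re-join to collapse repeated whitespace
    tokens = text.split()
    return " ".join(tokens)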
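BeautifulSoup may not be installed. In that case, a standard-library-only variant of the same cleaning can be sketched with `html.parser.HTMLParser`. This swaps in a different HTML parser than the post uses, so whitespace between adjacent tags may come out slightly differently. `pre_process_stdlib` is a hypothetical name.

```python
import re
from html.parser import HTMLParser

class _TextExtractor(HTMLParser):
    """Collects only the text content of an HTML document, dropping the tags."""
    def __init__(self):
        super().__init__()  # convert_charrefs=True by default, so &amp; becomes &
        self.parts = []

    def handle_data(self, data):
        self.parts.append(data)

def pre_process_stdlib(text):
    # strip HTML tags with the standard-library parser instead of BeautifulSoup
    extractor = _TextExtractor()
    extractor.feed(text)
    extractor.close()
    text = " ".join(extractor.parts)
    # keep alphabetic characters only, then lower-case
    text = re.sub("[^a-zA-Z]", " ", text).lower()
    # split/join collapses repeated whitespace
    return " ".join(text.split())

print(pre_process_stdlib("<b>Hello</b>, World!"))  # hello world
```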

There are about 85k questions and 1315 unique tags. Many of the tags have a very low count and don’t really matter. For the scope of this problem, we restrict ourselves to the top 10 tags. That leaves about 11k questions, which is decent enough given that we are using a pre-trained BERT model. Finally, we merge both dataframes into a single dataframe that contains only 3 columns: Id, Body and Tags.
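The filtering and merge steps can be sketched with pandas on hypothetical toy rows. This assumes Tags.csv has one `(Id, Tag)` pair per row. It also assumes that "restricting to the top tags" means two things: dropping every tag outside the top K, and dropping questions left with no tag. `TOP_K = 2` only keeps the toy example small; the post uses 10.

```python
import pandas as pd

# Hypothetical toy rows standing in for Questions.csv and Tags.csv
questions = pd.DataFrame({
    "Id": [1, 2, 3, 4, 5],
    "Body": ["q one", "q two", "q three", "q four", "q five"],
})
tags = pd.DataFrame({
    "Id": [1, 1, 2, 2, 3, 3, 4, 5],
    "Tag": ["r", "rare", "r", "regression", "r", "regression", "rare", "regression"],
})

TOP_K = 2  # the post keeps the top 10 tags

# 1. find the most frequent tags
top_tags = tags["Tag"].value_counts().nlargest(TOP_K).index
# 2. keep only rows whose tag is in the top list
kept = tags[tags["Tag"].isin(top_tags)]
# 3. collect the remaining tags of each question into a list
tags_per_q = kept.groupby("Id")["Tag"].apply(list).reset_index(name="Tags")
# 4. inner join: questions left with no top tag (Id 4 here) drop out
df = questions.merge(tags_per_q, on="Id", how="inner")[["Id", "Body", "Tags"]]
print(df)
```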

Here are the top 10 tags:

Top 10 Tags

Here is the data structure that will be used for training and testing the model: ‘Clean_Body’ (question) column contains the input for training and ‘tags’ column contains the label or the target. The multi-label structure of the tags is clearly evident below:

Final dataframe

Check the length of the text (word count per question):

Before we go ahead, we need to transform the text data into a numerical representation, because models understand only numbers. BERT can process at most 512 tokens at a time. These are WordPiece subword tokens (including the special [CLS] and [SEP] tokens), not words. A quick histogram plot reveals that most of the questions have a word count below 300. In general, that length is enough for the model to develop sufficient context to classify a narrowly scoped problem, so we cap the input length at 300 (MAX_LEN = 300). The tokenizer applies this cap, so it counts tokens rather than words. Since many words are split into several subword tokens, 300 tokens usually cover somewhat fewer than 300 words.
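The length check can be sketched in plain Python. The sample texts are hypothetical. `share_within` stands in for reading the histogram: it reports the fraction of questions at or under a word limit. `truncate_words` shows a word-level cut, which differs from the token-level `max_length` cut the tokenizer applies.

```python
import statistics

def word_counts(texts):
    # number of whitespace-separated words in each (already cleaned) question
    return [len(t.split()) for t in texts]

def share_within(texts, limit):
    """Fraction of questions whose word count is <= limit."""
    counts = word_counts(texts)
    return sum(c <= limit for c in counts) / len(counts)

def truncate_words(text, limit=300):
    # keep only the first `limit` words
    return " ".join(text.split()[:limit])

# Hypothetical cleaned questions
texts = ["how do i fit a glm in r", "word " * 350, "what is a p value"]
print(statistics.median(word_counts(texts)))  # 8
print(share_within(texts, 300))
print(len(truncate_words(texts[1]).split()))  # 300
```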

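Split the data into training, validation and test sets:

from sklearn.model_selection import train_test_split
# x: cleaned question texts (Clean_Body); yt: binarized tag matrix from MultiLabelBinarizer (Section 3)
# RANDOM_SEED: any fixed integer, so the splits are reproducible
# First split: train vs. test (10% held out for testing)
x_train, x_test, y_train, y_test = train_test_split(x, yt, test_size=0.1, random_state=RANDOM_SEED, shuffle=True)
# Next split: training vs. validation (20% of the remaining data)
x_tr, x_val, y_tr, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=RANDOM_SEED, shuffle=True)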
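The two-stage split leaves roughly 72% of the data for training, 18% for validation, and 10% for testing. The sketch below reproduces scikit-learn's rounding for a float `test_size`: the held-out part is `ceil(test_size * n)` and the rest goes to training. `split_sizes` is a hypothetical helper, and `n = 1000` is an illustrative size.

```python
import math

def split_sizes(n, test_size=0.1, val_size=0.2):
    """Sizes produced by the two train_test_split calls above."""
    n_test = math.ceil(test_size * n)          # first split: held-out test set
    n_train_full = n - n_test
    n_val = math.ceil(val_size * n_train_full) # second split: validation set
    n_train = n_train_full - n_val
    return n_train, n_val, n_test

print(split_sizes(1000))  # (720, 180, 100)
```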

3. Prepare the Dataset and DataModule

Since the machine learning model can only process numerical data, we need to encode both the tags (labels) and the Clean_Body text (question) in a numerical format.

Encoding tags: We use the MultiLabelBinarizer class from scikit-learn to transform the tags into a binary format. Each unique tag gets a fixed position. A 1 at that position indicates the tag is present, and a 0 indicates it is absent. Since we have only 10 tags, each label vector has length 10.

from sklearn.preprocessing import MultiLabelBinarizer
mlb = MultiLabelBinarizer()
yt = mlb.fit_transform(y)
# Getting a sense of how the tags data looks like
print(yt[0])
print(mlb.inverse_transform(yt[0].reshape(1,-1)))
print(mlb.classes_)
------------------------------------------
Output:
[0 0 0 0 0 0 1 0 0 1]
[('r', 'time series')]
['classification' 'distributions' 'hypothesis testing' 'logistic'
'machine learning' 'probability' 'r' 'regression' 'self study'
'time series']
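To make the encoding concrete, the sketch below reimplements what `MultiLabelBinarizer.fit_transform` and `inverse_transform` compute, in plain Python on hypothetical tags. It mirrors the sklearn behaviour of sorting the classes alphabetically. `binarize` and `inverse` are illustrative helpers, not sklearn functions.

```python
def binarize(tag_lists):
    """Plain-Python sketch of what MultiLabelBinarizer.fit_transform computes."""
    classes = sorted({tag for tags in tag_lists for tag in tags})
    index = {tag: i for i, tag in enumerate(classes)}
    rows = []
    for tags in tag_lists:
        row = [0] * len(classes)
        for tag in tags:
            row[index[tag]] = 1  # 1 = tag present, 0 = absent
        rows.append(row)
    return classes, rows

def inverse(classes, row):
    # map a 0/1 row back to a tuple of tag names, like inverse_transform
    return tuple(c for c, bit in zip(classes, row) if bit)

classes, rows = binarize([("r", "time series"), ("regression",), ("r",)])
print(classes)                   # ['r', 'regression', 'time series']
print(rows)                      # [[1, 0, 1], [0, 1, 0], [1, 0, 0]]
print(inverse(classes, rows[0])) # ('r', 'time series')
```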

Encoding input (question): We need to tokenize the text and encode it numerically, in the structured format BERT requires. The BertTokenizer class from the Hugging Face transformers library makes this simple: its encode_plus() method does it in a single call.

inputs = self.tokenizer.encode_plus(
    text,
    None,
    add_special_tokens=True,      # add [CLS] and [SEP] tokens
    max_length=self.max_len,
    padding='max_length',         # pad shorter texts up to max_length
    return_token_type_ids=False,
    return_attention_mask=True,   # distinguish real tokens from padding
    truncation=True,              # truncate text beyond max_length
    return_tensors='pt'           # return PyTorch tensors
)
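The effect of `add_special_tokens`, `truncation` and `padding='max_length'` on a single sequence can be sketched without the tokenizer. The special-token ids (101 for [CLS], 102 for [SEP], 0 for [PAD]) are assumed BERT vocabulary values. The word ids passed in are hypothetical, and `pad_and_mask` is an illustrative helper.

```python
def pad_and_mask(token_ids, max_len, cls_id=101, sep_id=102, pad_id=0):
    """Sketch of add_special_tokens + truncation + padding='max_length' (assumes max_len >= 2)."""
    body = token_ids[: max_len - 2]       # truncation leaves room for [CLS] and [SEP]
    ids = [cls_id] + body + [sep_id]
    mask = [1] * len(ids)                 # 1 = real token
    n_pad = max_len - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad  # 0 = padding

ids, mask = pad_and_mask([7592, 2088], max_len=6)
print(ids)   # [101, 7592, 2088, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

The attention mask is what lets BERT ignore the padded positions, so every example in a batch can share the same length.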

First, create a QTagDataset class, derived from PyTorch's Dataset class, that prepares the text in the format the BERT model needs.

import torch
from torch.utils.data import Dataset

class QTagDataset(Dataset):
    def __init__(self, quest, tags, tokenizer, max_len):
        self.tokenizer = tokenizer
        # list() gives positional indexing even if quest is a shuffled pandas Series
        self.text = list(quest)
        self.labels = tags
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, item_idx):
        text = self.text[item_idx]
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            truncation=True,
            return_tensors='pt'
        )
        # drop the batch dimension of 1 added by return_tensors='pt'
        input_ids = inputs['input_ids'].flatten()
        attn_mask = inputs['attention_mask'].flatten()

        return {
            'input_ids': input_ids,
            'attention_mask': attn_mask,
            # float targets, as required by BCEWithLogitsLoss
            'label': torch.tensor(self.labels[item_idx], dtype=torch.float)
        }

Since we are using PyTorch Lightning for model training, we set up a QTagDataModule class derived from LightningDataModule.

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class QTagDataModule(pl.LightningDataModule):

    def __init__(self, x_tr, y_tr, x_val, y_val, x_test, y_test, tokenizer,
                 batch_size=16, max_token_len=200):
        super().__init__()
        self.tr_text = x_tr
        self.tr_label = y_tr
        self.val_text = x_val
        self.val_label = y_val
        self.test_text = x_test
        self.test_label = y_test
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_token_len = max_token_len

    # Lightning calls setup(stage=...), so the stage argument is required
    def setup(self, stage=None):
        self.train_dataset = QTagDataset(quest=self.tr_text, tags=self.tr_label, tokenizer=self.tokenizer, max_len=self.max_token_len)
        self.val_dataset = QTagDataset(quest=self.val_text, tags=self.val_label, tokenizer=self.tokenizer, max_len=self.max_token_len)
        self.test_dataset = QTagDataset(quest=self.test_text, tags=self.test_label, tokenizer=self.tokenizer, max_len=self.max_token_len)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=16)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=16)

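Set up the DataModule:

from transformers import BertTokenizer

# Tokenizer matching the pre-trained model used in Section 4 ("bert-base-cased")
Bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
# Instantiate and set up the data_module
# (BATCH_SIZE and MAX_LEN are the hyperparameters set in Section 5)
QTdata_module = QTagDataModule(x_tr, y_tr, x_val, y_val, x_test, y_test,
                               Bert_tokenizer, BATCH_SIZE, MAX_LEN)
QTdata_module.setup()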

4. Define the Model (BERT-based Classifier)

The out-of-the-box BERT model has been pre-trained on English Wikipedia and BooksCorpus, so it has a good understanding of generic English text. However, this Stack Exchange dataset contains many technical terms that the BERT model may rarely have seen during pre-training.

Hence we need to fine-tune the model on our dataset, so that it adapts to our data and gets better at this classification task. To do that, we add a classification head on top of the core BERT model and then train the entire model on our dataset. This post uses BERT for multi-label classification, but BERT can also be used for other tasks such as question answering, named entity recognition, or keyword extraction. In NLP parlance, such tasks are called downstream tasks.

For this text classification task, we use the BERT Base model. It outputs a vector of length 768 for each token. It also produces a pooled output for the whole sequence: the final hidden state of the [CLS] token, passed through a dense layer with tanh activation. During fine-tuning, the pooled output learns to summarize the context the task needs, so we use it to make predictions. Since we only need scores for 10 labels (tags), we add a linear layer with 10 outputs on top of BERT's 768-dimensional pooled output.
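The shapes involved in the classification head can be sketched with NumPy. Random numbers stand in for BERT's pooled output and for the learned weights, so this is illustrative only. Note that PyTorch's `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ W.T + b`. The sketch uses the transposed layout, which is equivalent.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, hidden_size, n_classes = 4, 768, 10

# stand-in for BERT's pooler_output: one 768-dim vector per question
pooled = rng.standard_normal((batch_size, hidden_size))

# the classification head is a single linear (affine) layer: 768 -> 10
W = rng.standard_normal((hidden_size, n_classes)) * 0.02
b = np.zeros(n_classes)
logits = pooled @ W + b                 # shape (4, 10): one raw score per tag
probs = 1.0 / (1.0 + np.exp(-logits))   # independent sigmoid per tag (not a softmax)
print(logits.shape, probs.shape)
```

Each tag gets its own sigmoid, so a question's probabilities need not sum to 1. This is what allows several tags to be "on" at once.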

Since the output is multi-label (multiple tags can be associated with a question), each tag gets its own sigmoid probability, and the natural loss is binary cross-entropy. Instead of a plain Sigmoid followed by BCELoss, the PyTorch documentation recommends BCEWithLogitsLoss(). It combines the sigmoid and the BCE loss in a single class and is numerically more stable. The model therefore outputs raw logits.
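Why fusing the two steps helps can be shown for a single logit `x` and target `y`. The naive version computes a probability and then takes its log. The fused version uses the algebraically equivalent form `max(x, 0) - x*y + log(1 + exp(-|x|))`, which never takes the log of 0. This is a scalar sketch of the math, not PyTorch's implementation. `BCEWithLogitsLoss` additionally averages over all tags and examples by default.

```python
import math

def bce_sigmoid_then_log(x, y):
    # naive: explicit sigmoid followed by binary cross-entropy
    p = 1.0 / (1.0 + math.exp(-x))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_with_logits(x, y):
    # numerically stable form: max(x, 0) - x*y + log(1 + exp(-|x|))
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

print(bce_with_logits(0.0, 1.0))   # log(2), same as the naive version
print(bce_with_logits(50.0, 0.0))  # 50.0; the naive version hits log(0) here
```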

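import torch
from torch import nn
import pytorch_lightning as pl
from transformers import BertModel

# we will use the BERT base model (the smaller one)
BERT_MODEL_NAME = "bert-base-cased"
class QTagClassifier(pl.LightningModule):
    # Set up the classifier
    def __init__(self, n_classes=10, steps_per_epoch=None, n_epochs=3, lr=2e-5):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.steps_per_epoch = steps_per_epoch
        self.n_epochs = n_epochs
        self.lr = lr
        self.criterion = nn.BCEWithLogitsLoss()
    def forward(self, input_ids, attn_mask):
        output = self.bert(input_ids=input_ids, attention_mask=attn_mask)
        # raw logits: BCEWithLogitsLoss applies the sigmoid itself
        return self.classifier(output.pooler_output)

    # Sketch of the train/validation/test hooks (the notebook's exact code may differ)
    def _shared_step(self, batch, stage):
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.criterion(logits, batch['label'])
        self.log(f'{stage}_loss', loss, prog_bar=True)  # logs train_loss / val_loss / test_loss
        return loss
    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')
    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')
    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')
    def configure_optimizers(self):
        # AdamW at the configured learning rate (no LR scheduler in this sketch)
        return torch.optim.AdamW(self.parameters(), lr=self.lr)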

5. Train the Model (using PyTorch Lightning Trainer)

In a typical PyTorch training cycle, we implement the loop over epochs and iterate through the mini-batches. For each mini-batch, we perform a forward pass, compute the loss, and backpropagate to get the gradients. Finally, we update the weights using those gradients.
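The loop structure that Lightning takes off our hands can be sketched with NumPy on a tiny multi-label logistic model (sigmoid + BCE). There is no autograd here, so the "backpropagation" step is the hand-derived gradient of the mean BCE loss. The toy data, learning rate and epoch count are illustrative assumptions, and `train_logistic` is a hypothetical helper.

```python
import numpy as np

def train_logistic(X, Y, n_epochs=500, batch_size=2, lr=0.5, seed=0):
    """Explicit training loop for a multi-label logistic model (sigmoid + mean BCE)."""
    rng = np.random.default_rng(seed)
    W = np.zeros((X.shape[1], Y.shape[1]))
    b = np.zeros(Y.shape[1])
    for epoch in range(n_epochs):                      # loop over epochs
        order = rng.permutation(len(X))                # shuffle each epoch
        for start in range(0, len(X), batch_size):     # iterate over mini-batches
            idx = order[start:start + batch_size]
            xb, yb = X[idx], Y[idx]
            probs = 1 / (1 + np.exp(-(xb @ W + b)))    # forward pass
            grad = (probs - yb) / yb.size              # dLoss/dlogits for mean BCE
            W -= lr * (xb.T @ grad)                    # gradient w.r.t. W, then weight update
            b -= lr * grad.sum(axis=0)
    return W, b

# toy data: tag 0 is on when feature 0 is 1, tag 1 when feature 1 is 1
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
Y = X.copy()
W, b = train_logistic(X, Y)
pred = (1 / (1 + np.exp(-(X @ W + b))) > 0.5).astype(float)
print(pred)
```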

PyTorch Lightning restructures this loop and abstracts it away. We provide the configurable details, such as the optimizer, learning rate, and number of epochs, and Lightning takes care of the rest.

PyTorch Lightning is a high-level framework built on top of PyTorch. It adds structure and abstraction to the traditional way of writing deep learning code in PyTorch. It reduces the boilerplate we need to write, so we can focus on experimental questions such as hyperparameter tuning, finding the best model, and visualizing the results. Lightning can also take care of running the model on multiple GPUs and otherwise speeding up training.

The pl.LightningModule is similar to PyTorch's nn.Module but adds functionality, and our classifier model derives from it. The LightningModule defines the model structure along with the methods for training, validation, and configuring the optimizer. It organizes our PyTorch code into 6 sections:

  • Initialization (__init__)
  • Inference (forward)
  • Train loop (training_step)
  • Validation loop (validation_step)
  • Test loop (test_step)
  • Optimizers (configure_optimizers)

Lightning also lets us define a ModelCheckpoint callback that runs automatically alongside the training loop. A checkpoint is an intermediate dump of the training state (model weights, optimizer state, current epoch, hyperparameters, etc.) to a file on disk. Checkpointing therefore lets us:

  • resume training if it is interrupted,
  • fine-tune a model further, or
  • use a trained model for inference without retraining it.

It also lets us choose a monitored metric, such as validation loss or accuracy, and keep only the best-performing models.

from pytorch_lightning.callbacks import ModelCheckpoint

# saves files named like: QTag-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',   # monitored quantity (logged in validation_step)
    filename='QTag-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,         # keep the 3 best checkpoints
    mode='min',           # lower val_loss is better
)
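The bookkeeping behind `save_top_k=3, mode='min'` can be sketched in plain Python. As each epoch's monitored value arrives, keep it. Once more than k values are held, discard the worst. This is a simplified model of the idea, not Lightning's implementation: the real callback also writes and deletes files and may break ties differently. The loss values are hypothetical.

```python
import heapq

def keep_top_k(history, k=3, mode="min"):
    """Epochs whose checkpoints survive. history: list of (epoch, monitored_value)."""
    sign = 1 if mode == "min" else -1
    heap = []  # heap top is always the current worst checkpoint
    for epoch, value in history:
        heapq.heappush(heap, (-sign * value, epoch))
        if len(heap) > k:
            heapq.heappop(heap)  # drop the current worst checkpoint
    return sorted(epoch for _, epoch in heap)

val_losses = [(0, 0.40), (1, 0.33), (2, 0.30), (3, 0.31), (4, 0.35)]
print(keep_top_k(val_losses))  # [1, 2, 3]
```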

Initialize the Hyperparameters

# Initialize the parameters that will be used for training
N_EPOCHS = 12
BATCH_SIZE = 32
MAX_LEN = 300   # maximum sequence length, in tokens
LR = 2e-05
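These hyperparameters determine how many optimizer steps training takes, which is what the classifier's `steps_per_epoch` argument describes. A DataLoader with the default `drop_last=False` yields a final partial batch, so the count uses a ceiling. The training-set sizes below are hypothetical, and `schedule_lengths` is an illustrative helper.

```python
import math

def schedule_lengths(n_train, batch_size, n_epochs):
    # DataLoader's default drop_last=False keeps the last, partial batch
    steps_per_epoch = math.ceil(n_train / batch_size)
    return steps_per_epoch, steps_per_epoch * n_epochs

print(schedule_lengths(8000, 32, 12))  # (250, 3000)
print(schedule_lengths(8001, 32, 12))  # (251, 3012)
```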

After we organize our code into a LightningModule, the Trainer() automates everything else. Here is how we typically use the Trainer():

trainer = Trainer()
trainer.fit(model, datamodule)

Under the hood, the Lightning Trainer handles the training loop details for us. Here are some of the things it does:

  • Automatically enabling/disabling gradients
  • Running the training, validation, and test dataloaders
  • Calling the callbacks at the appropriate times
  • Putting batches and computations on the correct devices (GPU/CPU)
  • Displaying a progress indicator

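Here is our implementation:

import math
import pytorch_lightning as pl

# Instantiate the classifier; steps_per_epoch matches the number of training batches
steps_per_epoch = math.ceil(len(x_tr) / BATCH_SIZE)
model = QTagClassifier(n_classes=10, steps_per_epoch=steps_per_epoch, n_epochs=N_EPOCHS, lr=LR)
# Instantiate the Model Trainer (PyTorch Lightning 1.x arguments; in 2.x use
# accelerator='gpu', devices=1 and drop progress_bar_refresh_rate)
trainer = pl.Trainer(max_epochs=N_EPOCHS, gpus=1, callbacks=[checkpoint_callback], progress_bar_refresh_rate=30)
# Train the Classifier Model
trainer.fit(model, QTdata_module)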

6. Evaluate Performance on the Test Dataset

Lightning integrates with popular frameworks (TensorBoard, Comet, Weights & Biases, Neptune, ...) that help us log, track, and visualize the performance and results of machine learning experiments.

This requires a logger that records the configuration (parameters/hyperparameters), results, and metrics; we simply pass the logger to the Trainer(). Lightning uses TensorBoard by default, which is why no logger is passed explicitly to the Trainer in our code.

Here is how this is typically done:

# Example of using the logger from wandb.ai (Weights & Biases Inc.)
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)

Visualizing Performance using TensorBoard:

# Visualize the logs using TensorBoard (Jupyter/Colab magics)
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Training Loss Vs Training Steps
Validation Loss Vs Training Steps

# Evaluate the model performance on the test dataset
trainer.test(model, datamodule=QTdata_module)
Out[ ]: [{'test_loss': 0.2652013897895813}]

The classifier outputs a raw score (logit) for each tag. Applying a sigmoid gives a probability per tag, which we then convert to 1 or 0 depending on whether it is above or below a threshold. A threshold of 0.5 is the obvious choice, but we can also try values between 0.3 and 0.5 and see which maximizes prediction performance on the test set. For this problem, a threshold of 0.4 gives the best result. Take a look at the code for details.
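The threshold sweep can be sketched with NumPy, scoring each candidate by the F1 of the "tag present" class over the flattened label matrix. The choice of metric and the toy probabilities are assumptions for illustration. Choosing the threshold on the validation set rather than the test set avoids an optimistically biased test score. `f1_positive` and `best_threshold` are hypothetical helpers.

```python
import numpy as np

def f1_positive(y_true, y_pred):
    # F1 of the "tag present" class over the flattened label matrix
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    return 0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn)

def best_threshold(probs, y_true, thresholds=np.round(np.arange(0.30, 0.505, 0.01), 2)):
    # probs: sigmoid outputs, shape (n_questions, n_tags); returns (threshold, F1)
    scores = [(f1_positive(y_true, (probs > t).astype(int)), t) for t in thresholds]
    best_f1, best_t = max(scores)  # ties on F1 go to the higher threshold
    return best_t, best_f1

probs = np.array([[0.9, 0.333, 0.1], [0.455, 0.8, 0.2]])
y_true = np.array([[1, 1, 0], [1, 1, 0]])
print(best_threshold(probs, y_true))
```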

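Here are the performance metrics. The report treats each (question, tag) slot of the test set as one binary prediction. Class 1 means "tag present" and class 0 means "tag absent", and support counts tag slots rather than questions.

              precision    recall  f1-score   support

           0       0.94      0.93      0.94      8748
           1       0.76      0.77      0.77      2362

    accuracy                           0.90     11110
   macro avg       0.85      0.85      0.85     11110
weighted avg       0.90      0.90      0.90     11110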
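To put the 0.90 accuracy in context, the supports in the report give a trivial baseline. Assuming the report flattens a (questions × 10 tags) matrix, 11110 slots correspond to 1111 test questions. A model that never predicts any tag would already be right on every "absent" slot.

```python
# Supports taken from the report above
neg_support, pos_support = 8748, 2362
total = neg_support + pos_support        # label slots in the test set
n_questions = total // 10                # 10 tags per question
all_zero_accuracy = neg_support / total  # accuracy of always predicting "no tag"
print(total, n_questions, round(all_zero_accuracy, 3))  # 11110 1111 0.787
```

Slot-level accuracy therefore starts from about 0.79. The precision, recall and F1 of class 1 are the more informative numbers for tag prediction.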

Here is a table of predicted and actual tags of randomly sampled Test Data:

7. Model Inference

Using the best-performing trained model, we can now predict the tags for any relevant question.

question = "based on the following relationship between matthew s correlation coefficient mcc and chi square mcc is the pearson product moment correlation coefficient is it possible to conclude that by having imbalanced binary classification problem n and p df following mcc is significant mcc sqrt which is mcc when comparing two algorithms a b with trials of times if mean mcc a mcc a mean mcc b mcc b then a significantly outperforms b thanks in advance edit roc curves provide an overly optimistic view of the performance for imbalanced binary classification regarding threshold i m not a big fan of not using it as finally one have to decide for a threshold and quite frankly that person has no more information than me to decide upon hence providing pr or roc curves are just for the sake of circumventing the problem for publishing"

# predict() is the notebook's helper; it returns tags in the format of mlb.inverse_transform
tags = predict(question)
if not tags[0]:
    print('No Known Tags')
else:
    print(f'Following Tags are associated : \n{tags}')

Here is the output:

Following Tags are associated : 
[('classification', 'machine learning')]
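The last stage of such a predict() can be sketched in plain Python: map logits through a sigmoid, apply the 0.4 threshold, and return tag names in the `[tuple]` shape that `mlb.inverse_transform` produces. The earlier stages are omitted: tokenizing the question, switching the model to eval mode, and running it without gradients. `logits_to_tags` is a hypothetical helper, not the notebook's code, and the logits in the usage line are made up.

```python
import math

# Tag order as printed by mlb.classes_ above
CLASSES = ['classification', 'distributions', 'hypothesis testing', 'logistic',
           'machine learning', 'probability', 'r', 'regression', 'self study',
           'time series']

def logits_to_tags(logits, classes=CLASSES, threshold=0.4):
    """Sigmoid -> threshold -> tag names, shaped like mlb.inverse_transform output."""
    probs = [1 / (1 + math.exp(-z)) for z in logits]
    return [tuple(c for c, p in zip(classes, probs) if p > threshold)]

print(logits_to_tags([2.0, -3, -3, -3, 1.0, -3, -3, -3, -3, -3]))
# [('classification', 'machine learning')]
```

A probability threshold of 0.4 is the same as a logit threshold of ln(0.4/0.6) ≈ -0.405, so the sigmoid could be skipped. It is kept here for readability.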

Conclusion

In this article, we have built a Multi-label Text Classification Model using pre-trained BERT. We also wanted to get a sense of how PyTorch Lightning helps the training of the Model. I have tried explaining the important aspects in the article and here is the link to my code.

No serious attempt has been made to improve the performance, since the idea was primarily to build a POC. You may want to experiment with the hyperparameters to improve it. In the past, I used a CNN (Keras/TensorFlow) to build a model for this problem. Compared to that model, the BERT model takes significantly longer to train, and the gain in performance is extremely small.

You may want to restructure the code a bit and wrap the inference code in an API (for example, with FastAPI). You could then add a web interface using Streamlit, so users can use the model to predict tags.

You could run the notebook on Google Colab. Use the dataset from Kaggle (link provided in the notebook). Don’t forget to make changes in the code in order to set the correct path for the file containing the data.

Thank you for reading!

Please Clap if you learned something new in this post! It will motivate me to write more to help more people!

References

Fine-Tuning BERT with HuggingFace and PyTorch Lightning

Introduction to PyTorch Lightning

PyTorch Lightning Documentation

StackExchange Dataset on Kaggle
