Mastering Text Classification with BERT: A Comprehensive Guide

Furkan Ayık
10 min read · Dec 17, 2023


Introduction

Text classification is a ubiquitous task in NLP. Its applications span many fields, from categorizing documents by topic to detecting the language used in a customer conversation.

As an everyday example, consider your email service’s spam filter — it likely employs text classification to shield your inbox from lots of undesired content.

Another common use of text classification is sentiment analysis, an NLP technique that evaluates and interprets the emotions, opinions, or attitudes expressed in a piece of text. By analyzing the language used, it determines whether the sentiment conveyed is positive, negative, or neutral, providing insight into the emotional tone behind the words.

In this blog post, we will build a text classifier that identifies emotions in tweets. Here is the flow:

  1. Conventional NLP vs Era of Language Models
  2. A Little History of Language Models
  3. Architecture of BERT
  4. Exploring the Dataset
  5. Understanding Tokenization in BERT
  6. Training BERT with 2 Different Methods
  7. Error Analysis
  8. Conclusion and Future Directions

This blog post will walk you through building a text classifier with language models from the BERT family, starting from the fundamentals.

Before starting: you can find all the code in the notebook on my Kaggle account.

Have a cup of coffee (I would suggest an espresso, which sends roughly 63.6 mg of caffeine into your veins with a single shot) and enjoy :)

Conventional NLP vs Era of Language Models

In this post I only consider language models for text classification; however, if you face hardware limitations or need a baseline model, consider conventional NLP methods like the ones below:

Text Classification Steps in Conventional NLP Methods

Text classification stands as a foundational pillar within natural language processing (NLP), serving as the bedrock for various applications that involve understanding and organizing textual data. At its core, text classification involves the automatic categorization of text documents into predefined classes or categories based on their content.

In traditional NLP, the classification process relied on handcrafted features and machine learning algorithms, which often struggled to capture the complexity and nuances of language, especially with large and unstructured datasets.

Language models like BERT leverage transformer-based architectures to grasp the complexities of natural language more comprehensively.

BERT stands out due to its bidirectional nature, enabling it to consider the full context of a word by analyzing both its preceding and subsequent words in a sequence.

This bidirectional understanding of context allows BERT to capture:

  • subtle details
  • idiomatic expressions
  • contextual meanings of words within the text

which results in more accurate and context-aware text representations.

A Little History of Language Models

Language models (LMs) model the likelihood of word sequences, predicting the probability of the next word/token.

The development of LMs can be divided into four stages:

  • Statistical Language Models
  • Neural Language Models
  • Pre-Trained Language Models
  • Large Language Models

Statistical Language Models

These are models based on statistical learning methods that predict words under the Markov assumption. SLMs have a fixed context length, generally denoted by the letter 'n', and are also called n-gram models (bigram, trigram, etc.).

Drawback: the curse of dimensionality! A higher-order language model means an exponential number of transition probabilities must be estimated accurately. The research community came up with ideas like Good-Turing estimation, a smoothing strategy that mitigates the sparsity problem up to a point.
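To make the Markov assumption concrete, here is a minimal bigram sketch on a made-up toy corpus; for brevity it uses simple add-k smoothing instead of Good-Turing, and the vocabulary size is a placeholder:

from collections import Counter, defaultdict

# Toy corpus; a real SLM would be estimated from millions of sentences
corpus = [["i", "feel", "happy"], ["i", "feel", "sad"], ["you", "feel", "happy"]]

bigram_counts = defaultdict(Counter)
unigram_counts = Counter()
for sentence in corpus:
    for prev, curr in zip(sentence, sentence[1:]):
        bigram_counts[prev][curr] += 1
        unigram_counts[prev] += 1

def bigram_prob(prev, curr, k=1.0, vocab_size=1000):
    # Add-k smoothing so unseen bigrams do not get zero probability
    return (bigram_counts[prev][curr] + k) / (unigram_counts[prev] + k * vocab_size)

print(bigram_prob("i", "feel"))      # observed bigram: relatively high probability
print(bigram_prob("feel", "angry"))  # unseen bigram: only smoothing mass remains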

Neural language models (NLM)

In NLMs, transition probabilities are estimated by neural networks such as multi-layer perceptrons or recurrent neural networks. One of the most famous NLM methods is word2vec, which builds a shallow neural network to learn word representations.
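As a quick sketch (assuming the gensim library and, again, a toy corpus), training a word2vec model to obtain word representations can look like this:

from gensim.models import Word2Vec

# Toy tokenized corpus; real embeddings need far more data
sentences = [["i", "feel", "happy"], ["i", "feel", "sad"], ["you", "look", "happy"]]

# Shallow neural network (skip-gram) that learns dense word vectors
model = Word2Vec(sentences, vector_size=50, window=2, min_count=1, sg=1, epochs=50)

vector = model.wv["happy"]                      # 50-dimensional embedding
print(model.wv.most_similar("happy", topn=2))   # nearest words in the toy space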

Pre-trained language models (PLM)

An early example of PLMs is ELMo. ELMo captures context-aware word representations since it is pre-trained on a large corpus with a bidirectional LSTM backbone. This structure enables it to learn word representations without relying on a fixed context length.

Architecture of ELMo

Shortly after the ELMo paper, the Transformer architecture and its self-attention mechanism were used in a new language model: BERT.

Large Language Models (LLMs)

When PLMs are scaled up, both in data size and model parameters, model capacity improves and the models become able to understand more complex structures in the data. LLMs can accomplish tasks they were not directly trained on (a good paper about the topic). For instance, they perform well on zero-shot and few-shot predictions. The most well-known example is ChatGPT, which has the GPT-3 architecture as its backbone.

Architecture of BERT

BERT Architecture — Detailed
BERT Architecture — Overview

BERT was trained on large corpora such as Wikipedia (~2.5B words) and Google's BooksCorpus (~800M words). These large datasets helped BERT observe context deep within text. Parts of the research community have even started to see LMs as "world models", since they are also able to build abstract relationships.

BERT's architecture relies on 12 transformer encoder layers with around 110 million parameters. To train such a model, the research team employed 64 TPUs for 4 days!

BERT's training process consists of two steps:

Masked Language Model

In masked language modeling, 15% of the tokens in a sentence are hidden from BERT, and it must predict those tokens accurately. This way BERT learns the underlying structure of the language it is trained on (English, by default).

Masked Language Model Logic
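
To see masked language modeling in action, here is a minimal sketch using the HuggingFace fill-mask pipeline; the example sentence and the choice of bert-base-uncased are just for illustration:

from transformers import pipeline

# Let a pre-trained BERT fill in the hidden token
unmasker = pipeline("fill-mask", model="bert-base-uncased")
predictions = unmasker("The movie was not [MASK] at all.")
for p in predictions[:3]:
    print(p["token_str"], round(p["score"], 3))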

Next Sentence Prediction

NSP (Next Sentence Prediction) teaches BERT about relationships between sentences. The model must accurately predict whether a given sentence follows the previous one or not.
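
As a hedged sketch of how NSP scoring can be queried through the HuggingFace API (the sentence pair is made up for illustration):

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")

sentence_a = "I went to the cinema yesterday."
sentence_b = "The movie was not good."

# Encode the sentence pair; [SEP] separates the two segments
encoding = tokenizer(sentence_a, sentence_b, return_tensors="pt")
with torch.no_grad():
    logits = model(**encoding).logits

# Index 0 = "sentence B follows sentence A", index 1 = "it does not"
print(torch.softmax(logits, dim=-1))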

Exploring the Dataset

For this project we are going to use the 'emotions' dataset, which consists of six emotion categories:

  • anger
  • disgust
  • fear
  • joy
  • sadness
  • surprise

Our goal is to train a classifier model that can classify a tweet into one of those categories accurately.
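
Assuming the dataset is the emotion dataset hosted on the HuggingFace Hub (a hedged assumption; adjust the name if you use a different source), loading it is a one-liner:

from datasets import load_dataset

# Load the train/validation/test splits
emotions = load_dataset("emotion")
print(emotions)
print(emotions["train"][0])  # e.g. {'text': '...', 'label': 0}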

Class Distributions in Dataset

Training Set Class Distributions
Validation Set Class Distributions

Handling Imbalanced Class Distributions

There is an imbalanced class distribution in the dataset; samples from the 'joy' and 'sadness' classes are dominant. To deal with this, the Python library 'imbalanced-learn' is one way to handle the problem. It provides many methods to oversample minority classes (ADASYN, SMOTE, Random Over Sampler) or undersample the majority class (Random Under Sampler, Condensed Nearest Neighbour, Neighbourhood Cleaning Rule).
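
For example, a minimal oversampling sketch with RandomOverSampler; the feature matrix and labels below are toy stand-ins for the BERT features extracted later in this post:

import numpy as np
from collections import Counter
from imblearn.over_sampling import RandomOverSampler

# Toy feature matrix and imbalanced labels standing in for the real features
X = np.random.randn(100, 8)
y = np.array([0] * 70 + [1] * 20 + [2] * 10)

# Randomly duplicate minority-class samples until classes are balanced
ros = RandomOverSampler(random_state=42)
X_res, y_res = ros.fit_resample(X, y)
print(Counter(y), "->", Counter(y_res))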

Understanding Tokenization in BERT

Transformer-based models such as DistilBERT necessitate input in the form of tokenized and numerically encoded text rather than raw strings.

Tokenization involves the segmentation of a string into fundamental units recognized by the model. Various strategies for tokenization exist, with the ideal division of words into subunits typically being acquired from the corpus through learning.

BERT uses the 'WordPiece' tokenization algorithm, developed by the Google Research team. It is similar to the BPE (Byte Pair Encoding) algorithm in how the vocabulary is learned; however, tokenization itself is done differently.

Text to Tokens

Let’s say our input is “The movie was not good”

Tokens: [‘[CLS]’, ‘the’, ‘movie’, ‘was’, ‘not’, ‘good’, ‘[SEP]’]

Input Indices (encoded text) → [101, 1996, 3185, 2001, 2025, 2204, 102]

BERT or other LMs expect numerical representation of tokens as input like input indices above.

Note:

  • [CLS] Token: special classification token added at the start of the sequence
  • [SEP] Token: separator token marking the end of the sequence

from transformers import AutoTokenizer

# Load the DistilBERT tokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# Encode our example text
encoded_text = tokenizer("The movie was not good")
tokens = tokenizer.convert_ids_to_tokens(encoded_text.input_ids)
print(encoded_text, len(encoded_text.input_ids))
print(tokens)

Output:

{'input_ids': [101, 1996, 3185, 2001, 2025, 2204, 102],
'attention_mask': [1, 1, 1, 1, 1, 1, 1]} 7
['[CLS]', 'the', 'movie', 'was', 'not', 'good', '[SEP]']

Training BERT

Those who have gotten their hands dirty with pre-trained vision models will be familiar with models like ResNet, VGG16, Inception-v3, MobileNet, and EfficientNet.

These models are commonly and successfully used as feature extractors in many vision tasks, in two ways:

  1. Freeze all layers except the classification head, and train only the classification head (either from scratch or continuing from the existing head).
  2. Train all weights.

BERT and other LMs can also be used in those ways since they already learned high-level features and representations of the language (or languages) they are trained on.

Training Methodology-1: BERT as Feature Extractor

In this section I will show how to train BERT with the first methodology: freeze all layers of BERT, use it as a feature extractor, and train a classifier (or several classifiers) on those features with the existing labels.

This method uses BERT's last hidden states as extracted features. Once the features are extracted, the resulting feature matrix is used to create train, test, and validation sets.

Now let's get our hands dirty. Don't worry, I share all the code with more details on my Kaggle account, linked at the end of this post.

Load Model and Tokenizer

import torch
from transformers import AutoModel
from transformers import AutoTokenizer

model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ", device)
model = AutoModel.from_pretrained(model_ckpt).to(device)

Extract Hidden States

# Extract the [CLS] token's last hidden state for a batch of tokenized samples
def extract_hidden_states(batch):
    # Send inputs from CPU to GPU
    inputs = {k: v.to(device) for k, v in batch.items()
              if k in tokenizer.model_input_names}
    # Disable gradient calculation on the PyTorch side
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    # Return the [CLS] hidden state as a numpy matrix
    return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}

# Apply to the whole dataset in HuggingFace dataset format
df_source.set_format("torch", columns=["input_ids", "attention_mask", "label"])
df_hidden = df_source.map(extract_hidden_states, batched=True)

Create Train/Test/Validation Sets

import numpy as np

X_train = np.array(df_hidden["train"]["hidden_state"])
X_valid = np.array(df_hidden["valid"]["hidden_state"])
X_test = np.array(df_hidden["test"]["hidden_state"])

y_train = np.array(df_hidden["train"]["label"])
y_valid = np.array(df_hidden["valid"]["label"])
y_test = np.array(df_hidden["test"]["label"])

X_train.shape, y_train.shape

Train a Simple Classifier with Logistic Regression

from sklearn.linear_model import LogisticRegression

clf_lr = LogisticRegression(max_iter=3000)
clf_lr.fit(X_train, y_train)
clf_lr.score(X_valid, y_valid)

Model Evaluation

The logistic regression classifier achieves about 78% accuracy on the test set, which is an average rather than a good result.

It is also easy to see that the model performs poorly on classes with few samples, such as fear and surprise.

However, the results can be improved by training different classifiers and tuning hyperparameters.
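
To see where the classifier struggles, here is a quick sketch that plots a normalized confusion matrix on the validation set, reusing clf_lr, X_valid, and y_valid from above:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Row-normalized confusion matrix of the logistic regression classifier
y_pred = clf_lr.predict(X_valid)
cm = confusion_matrix(y_valid, y_pred, normalize="true")
ConfusionMatrixDisplay(cm).plot(cmap="Blues")
plt.show()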

Training Methodology-2: Fine-Tune BERT

Fine-tuning BERT is quite easy with HuggingFace Library.

Define Training Arguments

First, let’s define training arguments:

from transformers import TrainingArguments

# Set batch size and training hyperparameters
batch_size = 64
logging_steps = len(hf_dataset['train']) // batch_size
num_train_epochs = 30
lr_initial = 2e-5
weight_decay = 1e-3
output_dir = ""

training_args = TrainingArguments(output_dir=output_dir,
                                  num_train_epochs=num_train_epochs,
                                  learning_rate=lr_initial,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=weight_decay,
                                  evaluation_strategy="epoch",
                                  disable_tqdm=False,
                                  logging_steps=logging_steps,
                                  push_to_hub=False,
                                  log_level="error")

Note: You can also add parameters for early stop as below:

from transformers import EarlyStoppingCallback, IntervalStrategy

args = TrainingArguments(
    evaluation_strategy=IntervalStrategy.STEPS,  # "steps"
    eval_steps=50,               # evaluation and save happen every 50 steps
    save_total_limit=5,          # only the last 5 checkpoints are kept; older ones are deleted
    metric_for_best_model='f1',  # metric used to pick the best model
    load_best_model_at_end=True,
    ...
)
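
Note that the imported EarlyStoppingCallback itself is passed to the Trainer rather than to the arguments. A sketch (the patience value is arbitrary, and the model/dataset names mirror the Trainer setup shown further below):

from transformers import Trainer, EarlyStoppingCallback

trainer = Trainer(model=model, args=args,
                  train_dataset=df_encoded["train"],
                  eval_dataset=df_encoded["validation"],
                  compute_metrics=compute_metrics,
                  # stop if the metric does not improve for 3 evaluations
                  callbacks=[EarlyStoppingCallback(early_stopping_patience=3)])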

These training arguments tell the HuggingFace Trainer to evaluate the model on the validation set after each epoch.

Most metrics are automatically logged to TensorBoard locally.

Metrics for Model Evaluation During Training

from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}

Train the Model

from transformers import AutoModelForSequenceClassification, Trainer

# Create the model with a classification head
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))

model_name = f"models/{model_ckpt}-finetuned-bert-emotion-tweets"
training_args.output_dir = model_name

trainer = Trainer(model=model, args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=df_encoded["train"],
                  eval_dataset=df_encoded["validation"],
                  tokenizer=tokenizer)
trainer.train()

After training for 2 epochs, the results will be as follows:

{'eval_loss': 0.31904336810112, 'eval_accuracy': 0.9025, 'eval_f1': 0.9011653468824986, 'eval_runtime': 1.649, 'eval_samples_per_second': 1212.839, 'eval_steps_per_second': 19.405, 'epoch': 1.0}
{'loss': 0.2464, 'learning_rate': 0.0, 'epoch': 2.0}

{'eval_loss': 0.2177528589963913, 'eval_accuracy': 0.9215, 'eval_f1': 0.9215056955967161, 'eval_runtime': 1.8332, 'eval_samples_per_second': 1090.98, 'eval_steps_per_second': 17.456, 'epoch': 2.0}
{'train_runtime': 98.2367, 'train_samples_per_second': 325.744, 'train_steps_per_second': 5.09, 'train_loss': 0.5314080963134765, 'epoch': 2.0}

Model Evaluation

Confusion Matrix on Validation Set
Confusion Matrix on Test Set

After just 2 epochs, model performance is much better: 92% accuracy on the test set, and much better generalization on the classes with few samples.
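
The confusion matrices above can be reproduced with trainer.predict; a sketch, assuming the encoded test split is available as df_encoded["test"]:

import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Run the fine-tuned model on the test split
preds_output = trainer.predict(df_encoded["test"])
y_pred = np.argmax(preds_output.predictions, axis=1)
y_true = preds_output.label_ids

# Row-normalized confusion matrix, as in the figures above
cm = confusion_matrix(y_true, y_pred, normalize="true")
ConfusionMatrixDisplay(cm).plot(cmap="Blues")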

Push Model To HuggingFace Ecosystem

trainer.push_to_hub(commit_message="Fine-Tuned BERT For Sentiment Analysis in Tweets!")

Error Analysis

Let's sort the predictions by how badly the model got them wrong and analyze why it failed to classify them correctly (one way to do this is sketched below):
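
One way to surface these cases is to rank validation samples by their per-sample cross-entropy loss; a sketch, assuming the raw texts live in a 'text' column of hf_dataset:

import numpy as np
import pandas as pd
from scipy.special import softmax

# Per-sample cross-entropy loss from the validation predictions
val_output = trainer.predict(df_encoded["validation"])
probs = softmax(val_output.predictions, axis=1)
losses = -np.log(probs[np.arange(len(probs)), val_output.label_ids])

df_errors = pd.DataFrame({
    "text": hf_dataset["validation"]["text"],  # assumed raw-text column name
    "label": val_output.label_ids,
    "pred": val_output.predictions.argmax(axis=1),
    "loss": losses,
})
# The highest-loss samples are the ones the model got most wrong
print(df_errors.sort_values("loss", ascending=False).head(10))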

Focus on the first one:

i feel badly about reneging on my commitment to bring donuts to the faithful at holy family catholic church in columbus ohio

The model classifies the sentence as 'anger'; however, the dataset says the true label is 'joy'.

The person feels bad about going back on their commitment, so the sentence should arguably be classified as 'sadness'. Both the model and the dataset label seem wrong here.


Focus on the second one:

i feel that he was being overshadowed by the supporting characters

Actual Label: Joy

Predicted Label: Anger

The text reads more like anger than joy, so the label appears to be wrong here. Detecting mislabeled samples in both the training and test sets is crucial for deploying better models.

Conclusion and Future Directions

In conclusion, exploring the power of BERT in text classification reveals its transformative impact on natural language processing. BERT’s contextual understanding and pre-trained representations have elevated the accuracy and depth of text classification tasks across various domains.

Looking ahead, future blogs could focus on:

  1. BERT Variants and Transformers: Examining newer transformer architectures beyond BERT, like GPT (Generative Pre-trained Transformer) models, and their implications in text classification.
  2. Fairness in Language Models: Understanding and fixing unfairness or biases in models like BERT, especially when they are used to classify text.
  3. Domain-specific Applications: Focusing on specialized applications such as medical text classification or legal document analysis using BERT and similar models.

Stay tuned for more insights into the evolving landscape of NLP and the fascinating developments in the field of text classification!
