Text Classification with BERT

Khang Pham
May 9, 2023



What is Text Classification?

Text classification is a machine learning task in which a model learns to assign pieces of text to predefined categories. It is usually framed as supervised learning, which means that the algorithm is trained on a set of texts that have already been labeled with their respective categories. Once trained, the algorithm uses what it has learned to predict the categories of new, unlabeled texts.

The algorithm looks for patterns in the text to determine which category it belongs to. It’s like when learning to recognize a certain type of flower — we start to notice certain features that distinguish it from other types of flowers. With text classification, the algorithm is doing the same thing but with words and phrases instead.

Text classification is a versatile tool that is widely used in many real-world applications that you may have come across. For instance, an email that ended up in your spam folder is text classification at work. The model can differentiate between spam and non-spam emails by studying specific words or phrases that identify a given email as spam, such as “Congratulations, you have won” or “Today is your lucky day.”
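To make the supervised idea concrete, here is a toy spam filter trained on four hand-labeled emails. It is a minimal sketch that uses a classic bag-of-words multinomial Naive Bayes classifier, not BERT, and the emails, labels, and the helpers train_nb() and predict_nb() are hypothetical illustrations. The model counts how often each word appears in spam versus non-spam ("ham") messages, then scores new text by those counts.

```python
import math
from collections import Counter

def train_nb(texts, labels):
    """Fit a multinomial Naive Bayes model with add-one (Laplace) smoothing."""
    vocab = set()
    word_counts = {}            # label -> Counter of word frequencies
    doc_counts = Counter(labels)
    for text, label in zip(texts, labels):
        words = text.lower().split()
        vocab.update(words)
        word_counts.setdefault(label, Counter()).update(words)
    return vocab, word_counts, doc_counts

def predict_nb(text, vocab, word_counts, doc_counts):
    """Return the label with the highest log-probability for `text`."""
    total_docs = sum(doc_counts.values())
    scores = {}
    for label, counts in word_counts.items():
        total_words = sum(counts.values())
        score = math.log(doc_counts[label] / total_docs)  # log prior
        for word in text.lower().split():
            if word in vocab:  # words never seen in training are ignored
                score += math.log((counts[word] + 1) / (total_words + len(vocab)))
        scores[label] = score
    return max(scores, key=scores.get)

# Hypothetical labeled training data
emails = [
    "congratulations you have won a prize",
    "today is your lucky day claim your prize",
    "meeting moved to friday afternoon",
    "please review the attached project report",
]
email_labels = ["spam", "spam", "ham", "ham"]

model = train_nb(emails, email_labels)
print(predict_nb("you have won a free prize", *model))  # spam
```

Words like "won" and "prize" only ever appeared in spam during training, so they push a new message toward the spam label. BERT learns far richer signals than word counts, but the train-then-predict workflow is the same.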

Text classification is also valuable for analyzing the sentiment of social media posts and for flagging harmful content such as hate speech. A trained model can classify posts automatically, so offensive language and hate speech can be monitored at scale.

But text classification isn’t just for serious applications — it can also be used for fun things like categorizing news articles and videos by topic. This way, a user can easily find articles and videos that interest them without having to sift through irrelevant content. Text classification truly is a powerful tool with a variety of practical applications.

What is BERT for deep learning?

BERT, short for Bidirectional Encoder Representations from Transformers, is a powerful natural language processing (NLP) model developed by Google. It is built on the transformer, a deep neural network architecture that was state of the art when BERT was introduced.

The transformer differs from traditional NLP models such as recurrent neural networks, which process text one word at a time. Instead, a transformer processes the entire input at once and uses a mechanism called self-attention to weigh how strongly each word relates to every other word, which helps it capture the relationships between words and phrases more effectively.

How does the BERT model work for text classification?

BERT uses a stack of bidirectional transformer encoder layers to represent each token of the input text as a vector in a high-dimensional space. "Bidirectional" means that each word's representation is computed from the words on both its left and its right, so the model takes the entire context of every word into account, which helps it better understand the meaning of the text.
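The difference between bidirectional and left-to-right context comes down to which tokens each position may attend to. The sketch below computes scaled dot-product attention weights for a single head in plain Python. For brevity it assumes the token vectors are used directly as queries, keys, and values; real BERT first applies learned linear projections and runs many heads in parallel. The helper names are hypothetical.

```python
import math

def softmax(xs):
    m = max(xs)                                  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]         # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(queries, keys, causal=False):
    """Scaled dot-product attention weights for one head.

    causal=False: every token attends to every token (BERT-style, bidirectional).
    causal=True: tokens to the right are hidden (left-to-right language-model style).
    """
    d = len(keys[0])
    weights = []
    for i, q in enumerate(queries):
        scores = []
        for j, k in enumerate(keys):
            if causal and j > i:
                scores.append(float("-inf"))    # mask out future tokens
            else:
                scores.append(sum(a * b for a, b in zip(q, k)) / math.sqrt(d))
        weights.append(softmax(scores))
    return weights

def attend(weights, values):
    """Each output row is the attention-weighted average of the value vectors."""
    return [[sum(w * v[c] for w, v in zip(row, values)) for c in range(len(values[0]))]
            for row in weights]

# Three toy 2-d token vectors
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bidirectional = attention_weights(x, x)
left_to_right = attention_weights(x, x, causal=True)
print(bidirectional[0])   # token 0 attends to all three tokens
print(left_to_right[0])   # [1.0, 0.0, 0.0]: token 0 only sees itself
```

In the bidirectional case the first token's representation already mixes in information from the tokens after it, which is exactly what a left-to-right model cannot do.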

One of the most interesting things about BERT is that it's a pre-trained model. It is first trained on massive amounts of unlabeled text (the original release used books and English Wikipedia) and only afterwards fine-tuned for specific downstream NLP tasks, including text classification.

During pre-training, BERT learns general patterns of syntax and word meaning by predicting words that have been hidden from its input. Once pre-trained, BERT can be fine-tuned for a specific task, which adapts those general representations to the nuances of that task and typically needs far less labeled data than training a model from scratch.
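The "hidden words" objective is BERT's masked language modeling task. As described in the BERT paper, about 15% of token positions are selected for prediction; of those, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged. The sketch below applies that rule to a list of token ids. It is a simplification (real implementations, for example, never mask special tokens like [CLS]), and mask_tokens() is a hypothetical helper. The constants are the [MASK] id and vocabulary size of bert-base-uncased, and -100 is the label value PyTorch's CrossEntropyLoss ignores by default.

```python
import random

MASK_ID = 103        # [MASK] token id in the bert-base-uncased vocabulary
VOCAB_SIZE = 30522   # bert-base-uncased vocabulary size
IGNORE = -100        # default ignore_index of torch.nn.CrossEntropyLoss

def mask_tokens(token_ids, mask_prob=0.15, rng=random):
    """Return (corrupted inputs, labels) for masked language modeling."""
    inputs = list(token_ids)
    labels = [IGNORE] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if rng.random() >= mask_prob:
            continue                     # not selected: no prediction at this position
        labels[i] = tok                  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = MASK_ID          # 80%: replace with [MASK]
        elif roll < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return inputs, labels

tokens = list(range(1000, 1200))         # arbitrary token ids for illustration
inputs, labels = mask_tokens(tokens, rng=random.Random(0))
print(sum(label != IGNORE for label in labels), "of", len(tokens), "positions selected")
```

Because the model never knows which visible tokens might also be prediction targets, it has to build a useful representation for every position, using context from both sides.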

It's important to note that the original BERT comes in two sizes: BERT base (12 layers, hidden size 768, 12 attention heads, about 110 million parameters) and BERT large (24 layers, hidden size 1024, 16 attention heads, about 340 million parameters). For the remainder of this article, we'll use BERT base, the smaller of the two, which still captures context and linguistic nuance well. It demands far less memory and compute than BERT large, which makes it practical for a wide variety of text classification tasks.
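Where do those parameter counts come from? The sketch below tallies every weight and bias in a BERT encoder plus its pooler from the configuration alone. It assumes the bert-base-uncased / bert-large-uncased settings (vocabulary of 30,522 word pieces, 512 positions, 2 segment types, feed-forward size 4 × hidden); bert_param_count() is a hypothetical helper, and the number of attention heads does not change the total.

```python
def bert_param_count(hidden, layers, intermediate,
                     vocab=30522, max_positions=512, type_vocab=2):
    """Count trainable parameters of a BERT encoder plus its pooler (no task head)."""
    def linear(n_in, n_out):
        return n_in * n_out + n_out          # weight matrix + bias
    layer_norm = 2 * hidden                  # scale + shift

    # Word, position, and segment embeddings, followed by a LayerNorm
    embeddings = (vocab + max_positions + type_vocab) * hidden + layer_norm
    per_layer = (
        3 * linear(hidden, hidden)           # query, key, value projections
        + linear(hidden, hidden)             # attention output projection
        + layer_norm
        + linear(hidden, intermediate)       # feed-forward expansion
        + linear(intermediate, hidden)       # feed-forward projection back
        + layer_norm
    )
    pooler = linear(hidden, hidden)          # dense layer applied to the [CLS] vector
    return embeddings + layers * per_layer + pooler

base = bert_param_count(hidden=768, layers=12, intermediate=3072)
large = bert_param_count(hidden=1024, layers=24, intermediate=4096)
print(f"BERT base:  {base:,}")   # BERT base:  109,482,240
print(f"BERT large: {large:,}")  # BERT large: 335,141,888
```

Almost all of the growth comes from the per-layer terms, which scale with the square of the hidden size and linearly with depth. The classification head added later in this tutorial contributes only 768 × 2 + 2 = 1,538 more parameters.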

Tutorial on Text Classification using BERT

So why do we fine-tune BERT on the IMDB movie review dataset? Well, we want to tailor the already powerful BERT model for sentiment analysis tasks. BERT is excellent at understanding language structure and context, but it doesn’t naturally have sentiment analysis skills. By fine-tuning BERT for text classification with a labeled dataset, such as IMDB movie reviews, we give it the ability to accurately predict sentiments in the sentences it encounters.

If you would like to run the model below yourself, you can find the IMDB dataset on Kaggle:

IMDB Dataset of 50K Movie Reviews.

Step 1: Import the necessary libraries

This snippet of code is all about importing the tools we need for our project. We're using PyTorch for the deep learning functionality, the transformers library for BERT, and scikit-learn to split the data and check how well our model does.

import torch
from torch import nn
from torch.optim import AdamW  # transformers' own AdamW is deprecated and removed in recent releases
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd

Step 2: Import the IMDB data set and preprocess it

The code below defines a function, load_imdb_data(), that reads a CSV file containing IMDB movie reviews and their corresponding sentiments. It returns a list of review texts and a list of labels, where 1 represents a positive sentiment and 0 represents a negative sentiment.

def load_imdb_data(data_file):
    df = pd.read_csv(data_file)
    texts = df['review'].tolist()
    labels = [1 if sentiment == "positive" else 0 for sentiment in df['sentiment'].tolist()]
    return texts, labels
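One design note: load_imdb_data() maps every value that is not exactly "positive" to 0, so a stray capital letter or trailing space would silently become a negative label. A minimal stricter variant, load_imdb_data_strict() (a hypothetical helper, not part of the original tutorial), maps labels through a dictionary and fails loudly on anything unexpected. It is exercised here on a two-row in-memory CSV with the same review and sentiment columns as the real file; the rows themselves are made up.

```python
import io
import pandas as pd

LABEL_MAP = {"positive": 1, "negative": 0}

def load_imdb_data_strict(data_file):
    """Like load_imdb_data(), but raises on sentiment values other than positive/negative."""
    df = pd.read_csv(data_file)
    unknown = set(df['sentiment']) - set(LABEL_MAP)
    if unknown:
        raise ValueError(f"Unexpected sentiment values: {sorted(map(str, unknown))}")
    return df['review'].tolist(), df['sentiment'].map(LABEL_MAP).tolist()

# Hypothetical sample rows in the same format as IMDB Dataset.csv
SAMPLE_CSV = (
    "review,sentiment\n"
    '"A moving, beautifully shot film.",positive\n'
    '"Two hours I will never get back.",negative\n'
)
texts, labels = load_imdb_data_strict(io.StringIO(SAMPLE_CSV))
print(labels)  # [1, 0]
```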

We then store the path to the dataset file and pass it to the load_imdb_data() function.

data_file = "/kaggle/input/imdb-dataset-of-50k-movie-reviews/IMDB Dataset.csv"
texts, labels = load_imdb_data(data_file)

Step 3: Create a custom dataset class for text classification

This is a custom dataset class that helps organize movie reviews and their sentiments for our BERT model. It takes care of tokenizing the text, handling the sequence length, and providing a neat package with input IDs, attention masks, and labels for our model to learn from.

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        return {'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'label': torch.tensor(label)}
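What do padding='max_length' and truncation=True actually produce? The plain-Python sketch below imitates them for one sequence: it wraps the word-piece ids in [CLS] and [SEP], cuts overly long input so the total fits max_length, pads short input with [PAD], and builds the matching attention mask (1 for real tokens, 0 for padding). The special-token ids are those of the bert-base-uncased vocabulary; pad_and_truncate() is a hypothetical helper and the word-piece ids in the example are arbitrary. In the Dataset class, .flatten() simply drops the leading batch dimension of size 1 that return_tensors='pt' adds, so each item is a 1-D tensor of length max_length.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0   # special-token ids in the bert-base-uncased vocabulary

def pad_and_truncate(token_ids, max_length):
    """Mimic padding='max_length' plus truncation=True for a single sequence."""
    body = token_ids[: max_length - 2]          # leave room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)       # 1 = real token
    padding = max_length - len(input_ids)
    input_ids += [PAD_ID] * padding
    attention_mask += [0] * padding             # 0 = padding, ignored by attention
    return input_ids, attention_mask

ids, mask = pad_and_truncate([7592, 2088], max_length=6)
print(ids)   # [101, 7592, 2088, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

Every item ends up with exactly max_length positions, which is what lets the DataLoader stack items into rectangular batches.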

Step 4: Build our custom BERT classifier

In this step, we create our own custom BERT classifier. It is built on top of the pre-trained BERT model, which provides the text understanding. On top of BERT we add a dropout layer, which randomly zeroes 10% of the features during training to reduce overfitting, and a linear layer that maps BERT base's 768-dimensional pooled output to one score per class.

Our BERTClassifier takes in input IDs and attention masks and runs them through BERT and the extra layers we added. It uses BERT's pooled output, which is the final hidden state of the special [CLS] token passed through a dense layer with a tanh activation, and returns raw class scores (logits).

class BERTClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        logits = self.fc(x)
        return logits
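To check the shapes flowing through the classification head without downloading BERT, the sketch below stands in random 768-dimensional vectors for outputs.pooler_output (an assumption for illustration only; the real values come from BERT) and passes a batch of 16 through the same dropout and linear layers.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size, num_classes, batch_size = 768, 2, 16

# Stand-in for outputs.pooler_output: one 768-d vector per review (random, shape check only)
pooled_output = torch.randn(batch_size, hidden_size)

dropout = nn.Dropout(0.1)
fc = nn.Linear(hidden_size, num_classes)

logits = fc(dropout(pooled_output))
print(logits.shape)  # torch.Size([16, 2])

# Dropout is only active in training mode; after .eval() the head is deterministic
dropout.eval()
assert torch.equal(fc(dropout(pooled_output)), fc(dropout(pooled_output)))
```

Each row of logits holds one unnormalized score per class; the loss function and the prediction code turn those scores into probabilities or a predicted label.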

Step 5: Define the train() function

The train() function takes the model, data loader, optimizer, scheduler, and device as its arguments. It puts the model into training mode and then runs through each batch of data from the data loader. For each batch, it clears the optimizer's gradients, moves the input IDs, attention masks, and labels to the device, and feeds them to the model. It then computes the cross-entropy loss, backpropagates it, updates the weights with the optimizer, and advances the learning rate schedule with the scheduler.

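def train(model, data_loader, optimizer, scheduler, device):
    model.train()  # enable dropout
    loss_fn = nn.CrossEntropyLoss()
    for batch in data_loader:
        optimizer.zero_grad()  # clear gradients from the previous batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(outputs, labels)
        loss.backward()   # compute gradients
        optimizer.step()  # update the weights
        scheduler.step()  # update the learning rate once per batch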
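nn.CrossEntropyLoss, used in train() above, applies a softmax to each row of logits and takes the negative log-probability of the correct class, averaging over the batch by default. The plain-Python sketch below computes that per-example value (with the usual log-sum-exp trick for numerical stability) so you can see how the loss reacts to confident right and wrong answers; cross_entropy() is a hypothetical helper.

```python
import math

def cross_entropy(logits, label):
    """Loss for one example: -log(softmax(logits)[label]), computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

print(cross_entropy([0.0, 0.0], 1))   # log(2) ≈ 0.6931: the model is undecided
print(cross_entropy([2.0, -1.0], 0))  # small loss: confident and correct
print(cross_entropy([2.0, -1.0], 1))  # large loss: confident and wrong
```

Backpropagating this loss nudges the weights so the correct class's score rises relative to the others.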

Step 6: Build our evaluation method

The evaluate() function acts as our evaluation method. It switches the model to evaluation mode and, with gradient tracking turned off, feeds each batch's input IDs and attention masks to the model. The model's best predictions (the class with the highest score) are collected alongside the actual labels.

Finally, the function calculates the accuracy score and a classification report to let us know how well the model did in understanding movie reviews' sentiments.

def evaluate(model, data_loader, device):
    model.eval()  # disable dropout
    predictions = []
    actual_labels = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            _, preds = torch.max(outputs, dim=1)
            predictions.extend(preds.cpu().tolist())
            actual_labels.extend(labels.cpu().tolist())
    return accuracy_score(actual_labels, predictions), classification_report(actual_labels, predictions)

Step 7: Build our prediction method

The predict_sentiment() function classifies a single piece of text. It tokenizes the text with the same settings used for training, feeds the input IDs and attention mask to the model, and returns "positive" or "negative" depending on which class receives the higher score.

def predict_sentiment(text, model, tokenizer, device, max_length=128):
    model.eval()  # make sure dropout is off during inference
    encoding = tokenizer(text, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=1)
    return "positive" if preds.item() == 1 else "negative"
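predict_sentiment() returns only the label. If you also want to know how confident the model is, apply a softmax to the logits (inside predict_sentiment() that would be torch.softmax(outputs, dim=1)). The plain-Python sketch below shows the conversion for a single row of logits; logits_to_prediction() is a hypothetical helper and the logits are made-up values.

```python
import math

def logits_to_prediction(logits, labels=("negative", "positive")):
    """Turn one row of logits into (label, probability of that label)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]

label, confidence = logits_to_prediction([-1.2, 0.8])  # hypothetical model output
print(label, round(confidence, 3))  # positive 0.881
```

With two classes, the probability depends only on the gap between the two scores: a gap of 2.0 corresponds to roughly 88% confidence.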

Step 8: Define our model’s parameters

Here, we set up the essential parameters for fine-tuning the BERTClassifier: the BERT model name, number of classes, maximum input sequence length, batch size, number of training epochs, and learning rate. Note that many IMDB reviews are longer than 128 tokens, so max_length = 128 truncates them; raising it (up to BERT's limit of 512) can help accuracy at the cost of more memory and training time.

# Set up parameters
bert_model_name = 'bert-base-uncased'
num_classes = 2
max_length = 128
batch_size = 16
num_epochs = 4
learning_rate = 2e-5

Step 9: Split the data into training and validation sets

train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2, random_state=42)

Step 10: Initialize tokenizer, dataset, and data loader

tokenizer = BertTokenizer.from_pretrained(bert_model_name)
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
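The DataLoader's default collation stacks the per-item dictionaries returned by TextClassificationDataset into one dictionary of batched tensors. The sketch below checks this with a hypothetical FakeEncodedDataset that returns the same dictionary layout without needing a tokenizer (40 items of length 128, batch size 16).

```python
import torch
from torch.utils.data import DataLoader, Dataset

class FakeEncodedDataset(Dataset):
    """Hypothetical stand-in returning the same dict layout as TextClassificationDataset."""
    def __init__(self, n_items, max_length):
        self.n_items = n_items
        self.max_length = max_length

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        return {
            'input_ids': torch.zeros(self.max_length, dtype=torch.long),
            'attention_mask': torch.ones(self.max_length, dtype=torch.long),
            'label': torch.tensor(idx % 2),
        }

loader = DataLoader(FakeEncodedDataset(n_items=40, max_length=128), batch_size=16, shuffle=True)
batch = next(iter(loader))
print(batch['input_ids'].shape)  # torch.Size([16, 128])
print(batch['label'].shape)      # torch.Size([16])
print(len(loader))               # 3 batches: 16 + 16 + 8
```

This is why train() and evaluate() can index each batch with batch['input_ids'], batch['attention_mask'], and batch['label'].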

Step 11: Set up the device and model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTClassifier(bert_model_name, num_classes).to(device)

Step 12: Set up optimizer and learning rate scheduler

optimizer = AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
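With num_warmup_steps=0, the linear schedule starts at the full learning rate and decays it to zero by the last step. The sketch below reproduces the multiplier as we understand get_linear_schedule_with_warmup to compute it (linear warmup, then linear decay) and works through the step count for this tutorial, assuming the CSV holds 50,000 reviews so the 80% training split has 40,000. linear_warmup_decay_factor() is a hypothetical helper, not part of the transformers API.

```python
import math

def linear_warmup_decay_factor(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: linear warmup, then linear decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

train_size = 40_000                       # assumption: 80% of 50,000 reviews
batch_size, num_epochs, learning_rate = 16, 4, 2e-5
steps_per_epoch = math.ceil(train_size / batch_size)   # equals len(train_dataloader)
total_steps = steps_per_epoch * num_epochs

for step in (0, total_steps // 2, total_steps):
    lr = learning_rate * linear_warmup_decay_factor(step, 0, total_steps)
    print(f"step {step:>6}: lr = {lr:.1e}")
```

The schedule only ends at zero if scheduler.step() is called exactly once per batch, since total_steps was computed from the number of batches.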

Step 13: Training the model

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train(model, train_dataloader, optimizer, scheduler, device)
    accuracy, report = evaluate(model, val_dataloader, device)
    print(f"Validation Accuracy: {accuracy:.4f}")

Saving the final model

torch.save(model.state_dict(), "bert_classifier.pth")
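To reuse the saved file later, for example in a separate inference script, rebuild the model with the same architecture and load the weights into it. The sketch below uses a small nn.Linear as a stand-in so it runs without downloading BERT; with the real model you would create BERTClassifier(bert_model_name, num_classes) and call load_state_dict on it in the same way. The temporary directory is only for illustration.

```python
import os
import tempfile

import torch
from torch import nn

# Small stand-in module; the same calls apply to BERTClassifier
model = nn.Linear(768, 2)
path = os.path.join(tempfile.mkdtemp(), "bert_classifier.pth")
torch.save(model.state_dict(), path)

# Later: recreate the architecture, then load the saved weights
restored = nn.Linear(768, 2)
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # disable dropout (if any) before predicting
```

A state_dict stores only the tensors, not the class definition, which is why the architecture has to be recreated before loading.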

Step 14: Test our model on new sentences

# Test sentiment prediction
test_text = "The movie was great and I really enjoyed the performances of the actors."
sentiment = predict_sentiment(test_text, model, tokenizer, device)
print(test_text)
print(f"Predicted sentiment: {sentiment}")

Output: The movie was great and I really enjoyed the performances of the actors.

Predicted sentiment: positive

# Test sentiment prediction
test_text = "The movie was so bad and I would not recommend it to anyone."
sentiment = predict_sentiment(test_text, model, tokenizer, device)
print(test_text)
print(f"Predicted sentiment: {sentiment}")

Output: The movie was so bad and I would not recommend it to anyone.

Predicted sentiment: negative

# Test sentiment prediction
test_text = "Worst movie of the year."
sentiment = predict_sentiment(test_text, model, tokenizer, device)
print(test_text)
print(f"Predicted sentiment: {sentiment}")

Output: Worst movie of the year.

Predicted sentiment: negative

Final Thoughts on Text Classification Using BERT

To sum it up, BERT has seriously changed the game when it comes to text classification. Fine-tuned BERT models have substantially improved accuracy on tasks like sentiment analysis and topic classification, and because the expensive pre-training is done only once, adapting a model to a new project takes just a labeled dataset and a short fine-tuning run. By customizing those pre-trained models for our own projects, we get strong results that help us out in the real world.

As we keep experimenting with BERT and similar models, there’s no doubt we’ll see even more cool stuff happening in the world of AI and language understanding.



Khang Pham

Tech marketer by day, father and husband by night, mountain biker and snowboarder whenever possible. Currently at Exxact Corporation.