Building a Text Classification Model using DistilBERT

Prakash Ramu
4 min read · Mar 29, 2024


In this blog post, we’ll walk through the process of building a text classification model using the DistilBERT model. Text classification is a fundamental task in natural language processing (NLP), and it involves categorizing text into predefined categories or classes. We’ll use the Hugging Face Transformers library, which provides easy-to-use interfaces to various pre-trained language models, including DistilBERT.

Data Preparation

First, let’s take a glance at our dataset. The data.csv file contains text responses and their corresponding labels: a response column holding the raw text and a label column giving its category. With this structure in mind, we’re ready to move through the subsequent steps of data preparation, model training, and evaluation with clarity and purpose.

Before diving into model construction, it’s crucial to prepare the data. This involves loading the dataset, cleaning the text, encoding labels, and splitting the data into training and testing sets.

# Import necessary libraries
import pandas as pd
import re
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

# Load data from CSV file
data_path = "data.csv"
df = pd.read_csv(data_path)
  • Imports: Necessary libraries are imported.
  • Load Data: CSV data is loaded into a DataFrame.
class TextCleaner:
    def __init__(self):
        pass

    def clean_text(self, text):
        text = text.lower()                          # lowercase everything
        text = re.sub(r'<.*?>', '', text)            # strip HTML tags
        text = re.sub(r'http\S+', '', text)          # remove URLs
        text = re.sub(r"[^a-zA-Z0-9\s]", "", text)   # drop punctuation and special characters
        text = re.sub(r"\s+", " ", text).strip()     # collapse extra whitespace
        return text

cleaner = TextCleaner()
df['cleaned_text'] = df['response'].apply(cleaner.clean_text)
  • Text Cleaning: A class is defined to clean text data, removing HTML tags, URLs, and special characters.
  • Apply Cleaning: Text cleaning function is applied to each text response, and results are stored in a new column.
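
To sanity-check the cleaner, here is what it does to a made-up response (the sample string is purely illustrative, not taken from the dataset):

sample = "<p>Loved the new update!! Details at https://example.com :)</p>"
print(cleaner.clean_text(sample))
# Output: "loved the new update details at"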
label_encoder = preprocessing.LabelEncoder()
df['label'] = label_encoder.fit_transform(df['label'].tolist())

  • Encode Labels: Labels are encoded into numerical values for model training.
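
If you later need to map the numeric ids back to the original label names (for example when interpreting predictions), the fitted encoder keeps that mapping:

# Inspect the label -> id mapping learned by the encoder
for idx, name in enumerate(label_encoder.classes_):
    print(idx, name)

# Recover original label names from predicted ids, e.g.:
# label_encoder.inverse_transform([0, 1])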

train_df, test_df = train_test_split(df, test_size=0.2)

  • Train-Test Split: Data is split into training (80%) and testing (20%) sets for model evaluation.
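
An optional refinement (not used above) is to fix the random seed for a reproducible split and to stratify on the label column so both sets keep the same class proportions:

train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,       # reproducible split
    stratify=df['label']   # preserve class balance in both sets
)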

Model Training

With the data prepared, the next crucial step is to train our text classification model using the DistilBERT architecture. This section will be split into two parts: tokenization and model fine-tuning.

Tokenization

Before feeding the text data into the model, it needs to be converted into a format that the model can understand. This process is called tokenization. We utilize the Hugging Face AutoTokenizer to tokenize our text data.

from datasets import Dataset
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Convert datasets to tokenized format
train_dataset = Dataset.from_pandas(train_df)
test_dataset = Dataset.from_pandas(test_df)

def tokenize_data(examples):
    return tokenizer(examples["cleaned_text"], truncation=True)

tokenized_train = train_dataset.map(tokenize_data, batched=True)
tokenized_test = test_dataset.map(tokenize_data, batched=True)

In this snippet, we load the DistilBERT tokenizer and map it over our training and testing datasets, converting the text into tokenized sequences suitable for input to the model. We pass truncation=True to ensure that sequences longer than the model’s maximum input length (512 tokens for DistilBERT) are truncated appropriately.
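
To make the effect of tokenization concrete, here is what the tokenizer returns for a single made-up sentence (the exact ids depend on the DistilBERT vocabulary):

encoded = tokenizer("the product arrived late but support was helpful", truncation=True)
print(encoded["input_ids"])       # token ids, wrapped in [CLS] ... [SEP]
print(encoded["attention_mask"])  # 1 for every real (non-padding) token
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))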

Model Fine-Tuning

With the tokenized datasets prepared, we can proceed to fine-tune the DistilBERT model for our specific text classification task.

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding

# Load pre-trained DistilBERT model for sequence classification
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

# Prepare data collator for padding sequences
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    logging_strategy="epoch"
)

# Define Trainer object for training the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train the model
trainer.train()

# Save the trained model
trainer.save_model('model')

Here, we instantiate the DistilBERT model for sequence classification using AutoModelForSequenceClassification. We define training arguments such as the learning rate, batch size, and number of epochs. Then, we create a Trainer object, specifying the model, training and evaluation datasets, and training arguments. Finally, we train the model using the train() method and save the trained model for future use.
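
By default the Trainer only reports the loss at each evaluation. If you also want accuracy after every epoch, you can pass a compute_metrics function; here is a minimal sketch using scikit-learn that would be handed to the Trainer constructed above:

import numpy as np
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)   # pick the highest-scoring class
    return {"accuracy": accuracy_score(labels, predictions)}

# Supplied when constructing the Trainer:
# trainer = Trainer(..., compute_metrics=compute_metrics)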

Training the model involves iteratively adjusting the model’s parameters to minimize a predefined loss function. The goal is to optimize the model’s ability to classify text accurately based on the provided labels. Through this process, the model learns to extract relevant features from the input text and make predictions accordingly.
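
Once training finishes, the saved model can be reloaded for inference. Here is a minimal sketch using the Transformers text-classification pipeline; it assumes the tokenizer was saved alongside the model in the 'model' directory (the Trainer does this when a tokenizer is passed to it), and the example sentence is illustrative. Predictions come back as the generic LABEL_0 / LABEL_1 ids unless you configure id2label:

from transformers import pipeline

# Load the fine-tuned model (and tokenizer) saved by trainer.save_model('model')
classifier = pipeline("text-classification", model="model")

print(classifier("the service was quick and the staff were friendly"))
# e.g. [{'label': 'LABEL_1', 'score': 0.97}]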

Conclusion

In this tutorial, we’ve learned how to build a text classification model using DistilBERT. We started by getting our data ready, then trained our model to understand and classify text into different categories. With this model, we can analyze sentiments, detect spam, or categorize topics. Text classification is a handy tool in understanding and organizing large amounts of text data, and with DistilBERT, we can do it efficiently and accurately. Now, armed with this knowledge, you’re ready to dive into the world of text analysis and make sense of the vast amount of textual information available.

Thank you for joining us on this journey! Should you have any further questions or wish to connect, feel free to reach out to me via my LinkedIn profile: Prakash Ramu. Let’s continue advancing in NLP together!
