Fine-tuning BERT for Text Classification: A Step-by-Step Guide
BERT is a powerful pre-trained language model that can be fine-tuned for a variety of NLP tasks. In this article, I will provide a step-by-step guide to fine-tuning BERT for document classification and sentiment analysis.
Prerequisites
To follow along with this tutorial, you will need:
- Python 3.8+
- PyTorch 1.8+
- The Transformers library by Hugging Face (4.x)
- A dataset for your text classification task
We will be using the 20 Newsgroups dataset for document classification and the SST-2 dataset for sentiment analysis.
Preparing the Data
For the 20 Newsgroups dataset, download the data and extract the files. We will create a `data.csv` file with columns `text` and `label`.
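One convenient shortcut for building that file (an optional approach on my part, assuming you have scikit-learn and pandas installed) is to let scikit-learn download the dataset for you:

```python
import pandas as pd
from sklearn.datasets import fetch_20newsgroups

# Fetch the training split; stripping headers/footers/quotes removes
# metadata that would make classification artificially easy
newsgroups = fetch_20newsgroups(subset='train',
                                remove=('headers', 'footers', 'quotes'))
df = pd.DataFrame({'text': newsgroups.data, 'label': newsgroups.target})
df.to_csv('data.csv', index=False)
```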
For SST-2, download the dataset; you will get train and test files with columns `sentence` and `label`.
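If you are working from the GLUE distribution of SST-2, the files are tab-separated; here is a quick way to read them with pandas (the path is illustrative and depends on where you extracted the download):

```python
import pandas as pd

# GLUE's SST-2 files are TSV rather than comma-separated
train_df = pd.read_csv('SST-2/train.tsv', sep='\t')
train_texts = train_df['sentence'].tolist()
train_labels = train_df['label'].tolist()
```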
We will load the data into PyTorch Datasets and DataLoaders to feed into our model.
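As a minimal sketch of that step (it assumes the texts and labels are already in Python lists, as in the loading snippet above, and that a maximum length of 128 tokens is acceptable for your data):

```python
import torch
from torch.utils.data import TensorDataset
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def make_dataset(texts, labels, max_length=128):
    # Tokenize everything at once; padding/truncation yields fixed-size tensors
    enc = tokenizer(texts, padding=True, truncation=True,
                    max_length=max_length, return_tensors='pt')
    return TensorDataset(enc['input_ids'], enc['attention_mask'],
                         torch.tensor(labels))

train_set = make_dataset(train_texts, train_labels)
test_set = make_dataset(test_texts, test_labels)
```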
Fine-tuning BERT
- Choose a BERT model: We will use `bert-base-uncased` for this tutorial. Load it with the Transformers library:

```python
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
```
- Define a classifier head: We add a linear layer on top of BERT's pooled output, with one output unit per class (`num_labels = 2` for binary sentiment, `num_labels = 20` for 20 Newsgroups):

```python
import torch.nn as nn

classifier = nn.Linear(768, num_labels)  # 768 is bert-base's hidden size
```

- Combine the BERT encoder and classifier into a single model. Note that `nn.Sequential` does not work here, because BERT takes two inputs (`input_ids` and `attention_mask`) and returns a model-output object rather than a plain tensor, so we wrap the two pieces in a small `nn.Module`:

```python
class BertClassifier(nn.Module):
    def __init__(self, bert, classifier):
        super().__init__()
        self.bert = bert
        self.classifier = classifier

    def forward(self, input_ids, attention_mask):
        # Classify from the pooled [CLS] representation
        pooled = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask).pooler_output
        return self.classifier(pooled)

model = BertClassifier(model, classifier)
```
- Define an optimizer and loss function for your model. We use cross-entropy loss for the classifier:

```python
from torch.optim import AdamW

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=2e-5)
```
- Create DataLoader objects from your datasets to feed data into the model. We use a batch size of 16:

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
test_loader = DataLoader(test_set, batch_size=16)
```
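Optionally (this is an addition beyond the original recipe), fine-tuning often pairs AdamW with a linear warmup-and-decay learning-rate schedule from the Transformers library:

```python
from transformers import get_linear_schedule_with_warmup

# Total step count assumes the 3-epoch loop used later in this guide
num_training_steps = len(train_loader) * 3
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                            num_training_steps=num_training_steps)
```

If you use it, call `scheduler.step()` right after `optimizer.step()` in the training function below.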
- Define a training function to train your model. We train for 3–4 epochs:

```python
def train(model, optimizer, train_loader, criterion):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids, attention_mask, labels = batch
        # Forward pass through BERT and the classifier head
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Training loss: {total_loss / len(train_loader)}')
```
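Fine-tuning can be unstable at these learning rates; one common safeguard (an optional addition, not part of the original recipe) is to clip gradients before each optimizer step. Inside `train()`, between `loss.backward()` and `optimizer.step()`, add:

```python
# Cap the global gradient norm at 1.0 to avoid destabilizing updates
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```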
- Define an evaluation function to evaluate your model on the test set:

```python
import torch

def evaluate(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_examples = 0
    with torch.no_grad():
        for batch in test_loader:
            input_ids, attention_mask, labels = batch
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            # Predicted class = index of the largest logit
            predictions = torch.argmax(outputs, dim=1)
            total_correct += (predictions == labels).sum().item()
            total_examples += labels.size(0)
    print(f'Test loss: {total_loss / len(test_loader)} '
          f'Test acc: {total_correct / total_examples * 100}%')
```
- Train your model by calling the training function, and evaluate on the test set:

```python
for epoch in range(3):
    train(model, optimizer, train_loader, criterion)
    evaluate(model, test_loader, criterion)
```
- Save your model:

```python
torch.save(model.state_dict(), 'sentiment_model.pt')
```
You have fine-tuned BERT for your text classification task! You can now use the saved model to make predictions on new data.
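For example, here is a minimal inference sketch. It assumes the `BertClassifier` wrapper and `num_labels` defined above, plus an SST-2-style label mapping where 1 means positive; the sample sentence is just an illustration:

```python
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertClassifier(BertModel.from_pretrained('bert-base-uncased'),
                       nn.Linear(768, num_labels))
model.load_state_dict(torch.load('sentiment_model.pt'))
model.eval()

enc = tokenizer('A gripping, beautifully acted film.',
                truncation=True, max_length=128, return_tensors='pt')
with torch.no_grad():
    logits = model(enc['input_ids'], enc['attention_mask'])
print(torch.argmax(logits, dim=1).item())  # 1 = positive under SST-2 labels
```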