Fine-tune Your Own BERT Token Classification Model

Use Hugging Face and TensorFlow to build a model that identifies molecular biology terms!

Raj Pulapakura
6 min read · Jan 20, 2024
TensorFlow + HuggingFace + DNA: Molecular Biology Token Classification

Over the last week I’ve been diving into the Hugging Face NLP libraries. It’s been a fun adventure, and I want to give back to the HF community with a tutorial on token classification for molecular biology (this is also for my learning!).

Brief

Token classification is an NLP task in which we classify each word (more precisely, each token) in a sentence.

Basic Token Classification: Hugging Face Inference API

This model has identified Raj as a person, Australia as a location, and TensorFlow as an organization.

Using Hugging Face, we can effortlessly load a token classification model with the pipeline API.

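from transformers import pipeline

token_classifier = pipeline(
    "token-classification",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    # replaces the deprecated grouped_entities=True: merges B-/I- tokens into whole entities
    aggregation_strategy="simple",
)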

Using this model is simple:

token_classifier(
    "My name is Raj, I live in Australia, and I love TensorFlow."
)
>>> [{'entity_group': 'PER',
>>> 'score': 0.99407405,
>>> 'word': 'Raj',
>>> 'start': 11,
>>> 'end': 14},
>>> {'entity_group': 'LOC',
>>> 'score': 0.999833,
>>> 'word': 'Australia',
>>> 'start': 26,
>>> 'end': 35},
>>> {'entity_group': 'ORG',
>>> 'score': 0.94205034,
>>> 'word': 'TensorFlow',
>>> 'start': 48,
>>> 'end': 58}]

This particular model is only trained to identify people, locations and organizations. If this is all you need, then hooray, you’re done 🥳!

But what if we wanted a model that can identify 🐈 animals, 🤬 foul language, or the 🧬 names of proteins?

To do this, we need to fine-tune a base language model on a new dataset. In this tutorial, we’ll be fine-tuning a BERT language model to identify terms in the molecular biology field, including terms related to DNA, RNA, proteins, cell lines and cell types.

🧑‍💻👉 If you want to build the model yourself, here’s a Colab notebook that walks you through all the steps. You should still read this article, though, because it explains some of the concepts behind those steps.

Just for context, this is what our final model will be able to do:

Getting a Dataset

The first step in fine-tuning is to find a dataset to fine-tune on.

For our use case, we’ll be using the jnlpba dataset.

from datasets import load_dataset

raw_datasets = load_dataset("jnlpba")

To get a sense of the dataset, I’ve taken an example and labeled it.

As you can see, each word in the sentence has a corresponding label. There are 11 classes in total, and here’s what each one means:

  • O => outside any entity (an ordinary word)
  • B-DNA => beginning of a “DNA” term
  • I-DNA => continuation of a “DNA” term
  • B-RNA => beginning of an “RNA” term
  • I-RNA => continuation of an “RNA” term
  • B-protein => beginning of a “protein” term
  • I-protein => continuation of a “protein” term
  • B-cell_line => beginning of a “cell line” term
  • I-cell_line => continuation of a “cell line” term
  • B-cell_type => beginning of a “cell type” term
  • I-cell_type => continuation of a “cell type” term
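To make the B-/I- scheme concrete, here is a small plain-Python sketch that turns word-level BIO tags back into entity spans. The sentence and its tags are a made-up illustration, not a sample from jnlpba, and `bio_to_spans` is a hypothetical helper. The pipeline’s entity grouping does something similar in spirit, but it works on sub-word tokens and also tracks scores and character offsets.

```python
# Hypothetical helper: group word-level BIO tags into (entity_type, text) spans.


def bio_to_spans(words, tags):
    spans = []
    current_type, current_words = None, []
    for word, tag in zip(words, tags):
        entity_type = tag[2:] if tag != "O" else None
        # A B- tag always starts a new entity; an I- tag whose type differs
        # from the open entity is treated as a new entity too.
        starts_new = tag.startswith("B-") or (
            tag.startswith("I-") and entity_type != current_type
        )
        if tag.startswith("I-") and not starts_new:
            current_words.append(word)  # continuation of the open entity
            continue
        # Anything else closes the open entity, if there is one.
        if current_type is not None:
            spans.append((current_type, " ".join(current_words)))
        if starts_new:
            current_type, current_words = entity_type, [word]
        else:  # "O"
            current_type, current_words = None, []
    if current_type is not None:
        spans.append((current_type, " ".join(current_words)))
    return spans


# Made-up example sentence and tags.
words = ["Activation", "of", "protein", "kinase", "C", "in", "T", "cells"]
tags = ["O", "O", "B-protein", "I-protein", "I-protein", "O", "B-cell_type", "I-cell_type"]
print(bio_to_spans(words, tags))
# [('protein', 'protein kinase C'), ('cell_type', 'T cells')]
```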

Preprocessing the dataset

Tokenization

A tokenizer breaks a sentence down into individual tokens. You can think of tokens as words, but they’re actually closer to sub-words: small pieces of text that can be combined into longer words. Check out this comprehensive guide to subword tokenization to learn more.
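BERT’s tokenizer uses WordPiece, which splits each word greedily into the longest pieces found in its vocabulary and marks non-initial pieces with a `##` prefix. The sketch below illustrates that greedy longest-match-first idea with a tiny made-up vocabulary. It is not the Hugging Face implementation (which, among other things, first splits on whitespace and punctuation), and the real bert-base-cased vocabulary may split these words differently; `wordpiece_tokenize` and `toy_vocab` are hypothetical names.

```python
# Toy illustration of WordPiece-style greedy longest-match-first splitting.
toy_vocab = {"cyto", "##plasm", "##ic", "kin", "##ase"}  # made-up vocabulary


def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry ##
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matched: the whole word becomes [UNK]
        tokens.append(piece)
        start = end
    return tokens


print(wordpiece_tokenize("cytoplasmic", toy_vocab))  # ['cyto', '##plasm', '##ic']
print(wordpiece_tokenize("kinase", toy_vocab))       # ['kin', '##ase']
```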

We’re going to be using the bert-base-cased model, so we need to also use its corresponding tokenizer.

from transformers import AutoTokenizer

checkpoint = "bert-base-cased"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

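Along with tokenization, we have to do some additional preprocessing to align the labels with the tokens: the tokenizer can split a single word into several sub-word tokens and adds special tokens such as [CLS] and [SEP], so the dataset’s word-level labels no longer line up one-to-one with the tokens. If you’re following along with the Colab, I’ve provided detailed explanations there.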
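For readers not using the Colab, here is a minimal sketch of one common way to write that preprocessing function, modeled on the approach in the Hugging Face course rather than copied from the author’s notebook. It assumes the jnlpba examples expose `tokens` (a list of words) and `ner_tags` (integer label ids) columns, relies on the `tokenizer` loaded above, and uses the label numbering from the `id2label` mapping later in this article, where every B- id is odd and its I- id is one higher. `align_labels_with_tokens` is a hypothetical helper name.

```python
# Sketch of label alignment (assumptions: "tokens"/"ner_tags" columns,
# `tokenizer` defined earlier, B- ids odd and I- id = B- id + 1).


def align_labels_with_tokens(labels, word_ids):
    """Expand word-level labels to one label per sub-word token."""
    new_labels = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:
            # Special tokens such as [CLS] and [SEP]: -100 is ignored by the loss.
            new_labels.append(-100)
        elif word_id != previous_word:
            # First sub-word of a word keeps the word's label.
            new_labels.append(labels[word_id])
        else:
            # Later sub-words of the same word: turn a B- label into its I- label.
            label = labels[word_id]
            if label % 2 == 1:
                label += 1
            new_labels.append(label)
        previous_word = word_id
    return new_labels


def tokenize_function(examples):
    # With batched=True, examples["tokens"] is a list of word lists.
    tokenized = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    tokenized["labels"] = [
        align_labels_with_tokens(labels, tokenized.word_ids(i))
        for i, labels in enumerate(examples["ner_tags"])
    ]
    return tokenized


# Hypothetical example: two words, the first split into three sub-words.
word_ids = [None, 0, 0, 0, 1, None]
print(align_labels_with_tokens([1, 2], word_ids))  # [-100, 1, 2, 2, 2, -100]
```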

Once we have the preprocessing function, we can use the neat Dataset.map() method to apply it to all our datasets.

tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,  # tokenize_function receives many examples at once
    remove_columns=raw_datasets["train"].column_names,  # keep only the tokenized columns
)

Data Collation

Data collation means taking examples from our dataset and organizing them into mini-batches.

You may have noticed that we haven’t padded our dataset yet, even though every sample in a batch must have the same length before it can be fed to the model. Padding the entire dataset at once would be inefficient, because every sequence would be padded to the length of the longest sequence in the whole dataset.

Instead, we can pad each mini-batch separately, so every sequence is only padded up to the longest sequence in its own mini-batch. This saves memory and computation.
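Here is a plain-Python sketch of that per-batch padding idea. It assumes a pad id of 0 (BERT’s [PAD] id) for `input_ids` and -100 for `labels`, so padded positions are ignored by the loss; the other token ids are made up, and `pad_batch` is a hypothetical helper. The real collator also pads `token_type_ids` and returns TensorFlow tensors.

```python
# Plain-Python sketch of per-batch ("dynamic") padding; not the real collator.


def pad_batch(features, pad_token_id=0, label_pad_id=-100):
    # Pad only up to the longest sequence in *this* batch.
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_real = len(f["input_ids"])
        n_pad = max_len - n_real
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 = real token, 0 = padding the model should not attend to
        batch["attention_mask"].append([1] * n_real + [0] * n_pad)
        batch["labels"].append(f["labels"] + [label_pad_id] * n_pad)
    return batch


features = [
    {"input_ids": [101, 7, 8, 102], "labels": [-100, 1, 2, -100]},
    {"input_ids": [101, 9, 102], "labels": [-100, 0, -100]},
]
batch = pad_batch(features)
print(batch["input_ids"])  # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(batch["labels"])     # [[-100, 1, 2, -100], [-100, 0, -100, -100]]
```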

Using a data collator in Hugging Face is as easy as 1,2,3:

from transformers import DataCollatorForTokenClassification

# Pads each mini-batch on the fly; labels are padded with -100 by default.
data_collator = DataCollatorForTokenClassification(
    tokenizer, return_tensors="tf"
)

I forgot to mention, I’ve written this article using Hugging Face with TensorFlow (because I 💖 TensorFlow), but you can adjust it to use PyTorch if you like.

Fine-tuning!

Finally, the juicy part.

Real quick, let’s create our final datasets which will be passed into the model.

columns = ["attention_mask", "input_ids", "labels", "token_type_ids"]
batch_size = 16

tf_train_dataset = tokenized_datasets["train"].to_tf_dataset(
    columns=columns,
    collate_fn=data_collator,
    batch_size=batch_size,
    shuffle=True,
)

tf_eval_dataset = tokenized_datasets["validation"].to_tf_dataset(
    columns=columns,
    collate_fn=data_collator,
    batch_size=batch_size,
    shuffle=False,
)

Next, let’s load the bert-base-cased model (you can use a different model if you like, but you must remember to use its corresponding tokenizer).

from transformers import TFAutoModelForTokenClassification

checkpoint = "bert-base-cased"

model = TFAutoModelForTokenClassification.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
)

If you want the labels (DNA, protein, etc.) to show up in the Inference API on the Hugging Face website, you also need to pass id2label and label2id to your model, as shown above. Make sure these two dictionaries are defined before you call from_pretrained:

id2label = {
    0: 'O',
    1: 'B-DNA',
    2: 'I-DNA',
    3: 'B-RNA',
    4: 'I-RNA',
    5: 'B-cell_line',
    6: 'I-cell_line',
    7: 'B-cell_type',
    8: 'I-cell_type',
    9: 'B-protein',
    10: 'I-protein',
}

label2id = {
    'O': 0,
    'B-DNA': 1,
    'I-DNA': 2,
    'B-RNA': 3,
    'I-RNA': 4,
    'B-cell_line': 5,
    'I-cell_line': 6,
    'B-cell_type': 7,
    'I-cell_type': 8,
    'B-protein': 9,
    'I-protein': 10,
}
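Typing both mappings by hand invites typos, so a common alternative is to write the label names once and derive both dictionaries from them. This is a minimal sketch assuming the same label order as the dictionaries above; with the dataset loaded, the names can probably also be read via `raw_datasets["train"].features["ner_tags"].feature.names`, assuming `ner_tags` is stored as a sequence of `ClassLabel`s (worth checking against your copy of jnlpba).

```python
# Assumption: label order matches the id2label dictionary above.
label_names = [
    "O",
    "B-DNA", "I-DNA",
    "B-RNA", "I-RNA",
    "B-cell_line", "I-cell_line",
    "B-cell_type", "I-cell_type",
    "B-protein", "I-protein",
]

id2label = {i: name for i, name in enumerate(label_names)}
label2id = {name: i for i, name in id2label.items()}

print(label2id["B-protein"])  # 9
```

A side benefit of this layout is that every B- id is odd and its I- id is the next integer, which is exactly what a label-alignment rule like “turn a B- label into I- by adding 1” relies on.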

Let’s set up the optimizer with a learning rate decay:

from transformers import create_optimizer
import tensorflow as tf

num_epochs = 3
# batches per epoch times number of epochs
num_train_steps = len(tf_train_dataset) * num_epochs

# set up optimizer with learning rate decay
optimizer, schedule = create_optimizer(
    init_lr=2e-5,
    num_warmup_steps=0,
    num_train_steps=num_train_steps,
    weight_decay_rate=0.01,
)

# compile
model.compile(optimizer=optimizer)
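With `num_warmup_steps=0`, the schedule returned by `create_optimizer` should, as far as we understand its defaults, be a plain linear decay from `init_lr` down to zero over `num_train_steps`. The sketch below reproduces that arithmetic in plain Python, taking 1160 steps per epoch from the training output shown later in this article; `linear_decay` is a hypothetical helper, not part of Transformers, and you can compare it against the returned `schedule` by calling `schedule(step)`.

```python
# Plain-Python sketch of a linear learning-rate decay (assumed to match
# create_optimizer's behavior with no warmup and default settings).
steps_per_epoch = 1160  # taken from the training log later in this article
epochs = 3
total_steps = steps_per_epoch * epochs  # 3480


def linear_decay(step, init_lr=2e-5, total_steps=total_steps, end_lr=0.0):
    step = min(step, total_steps)  # stay at end_lr after the last step
    return (init_lr - end_lr) * (1 - step / total_steps) + end_lr


print(linear_decay(0))     # 2e-05 at the first step
print(linear_decay(1740))  # 1e-05 halfway through training
```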

Notice how we don’t have to specify a loss: when none is passed to compile(), the model falls back to its own internal loss (cross-entropy in this case), which skips positions whose label is -100.
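Conceptually, that internal loss is the average cross-entropy over the positions that carry a real label; positions labeled -100 (special tokens and padding) are skipped. The sketch below shows the idea in plain Python. The exact reduction inside Transformers’ TensorFlow loss code may differ in its details, so treat `masked_cross_entropy` (a hypothetical name) as an illustration rather than the library’s implementation.

```python
import math


def masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over token positions whose label != ignore_index.

    logits: one list of class scores per token position.
    labels: one integer class id (or ignore_index) per token position.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # special tokens / padding contribute nothing
        # -log softmax(scores)[label], computed stably via log-sum-exp
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[label]
        count += 1
    return total / count if count else 0.0


# Two positions: the first is scored, the second is padding (-100).
print(masked_cross_entropy([[0.0, 0.0], [9.0, 1.0]], [1, -100]))  # ≈ 0.6931 (log 2)
```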

If you want to push the model to your Hugging Face account, log in first (for example by running huggingface-cli login in a terminal, or notebook_login() from huggingface_hub in a notebook), then use the following callback:

from transformers.keras_callbacks import PushToHubCallback

callback = PushToHubCallback(
    output_dir="bert-finetuned-ner-medical",
    tokenizer=tokenizer,
)

Finally, let’s fine-tune our model. Fire away!

model.fit(
    tf_train_dataset,
    validation_data=tf_eval_dataset,
    callbacks=[callback],
    epochs=num_epochs,
)

Output after training for 3 epochs:

Epoch 1/3
1160/1160 [==============================] - 484s 402ms/step - loss: 0.3065 - val_loss: 0.2755
Epoch 2/3
1160/1160 [==============================] - 420s 362ms/step - loss: 0.1835 - val_loss: 0.2722
Epoch 3/3
1160/1160 [==============================] - 435s 375ms/step - loss: 0.1514 - val_loss: 0.2864

Try it out

Now we can load our fine-tuned model using the wonderful pipeline API.

from transformers import pipeline

token_classifier = pipeline(
    "token-classification",
    model="raj-p/bert-finetuned-ner-medical",  # replace "raj-p" with your own username
    aggregation_strategy="simple",  # replaces the deprecated grouped_entities=True
)

Let’s test it out!

token_classifier(
    "In contrast , IL-12 induction of IFN-gamma cytoplasmic mRNA appears to only partially depend on activation of protein kinase C."
)
>>> [{'entity_group': 'protein',
>>> 'score': 0.9925425,
>>> 'word': 'IL - 12',
>>> 'start': 14,
>>> 'end': 19},
>>> {'entity_group': 'RNA',
>>> 'score': 0.9626391,
>>> 'word': 'IFN - gamma cytoplasmic mRNA',
>>> 'start': 33,
>>> 'end': 59},
>>> {'entity_group': 'protein',
>>> 'score': 0.9795227,
>>> 'word': 'protein kinase C',
>>> 'start': 110,
>>> 'end': 126}]
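Notice that the `word` field shows 'IL - 12' with extra spaces. That string appears to be decoded from the sub-word tokens, so it doesn’t always match the input exactly, but the `start` and `end` character offsets index into the original sentence. A small plain-Python sketch, using the output above with the scores dropped, recovers the exact surface text:

```python
# The sentence and pipeline output from above (scores omitted).
text = (
    "In contrast , IL-12 induction of IFN-gamma cytoplasmic mRNA appears "
    "to only partially depend on activation of protein kinase C."
)
entities = [
    {"entity_group": "protein", "word": "IL - 12", "start": 14, "end": 19},
    {"entity_group": "RNA", "word": "IFN - gamma cytoplasmic mRNA", "start": 33, "end": 59},
    {"entity_group": "protein", "word": "protein kinase C", "start": 110, "end": 126},
]

# start/end are character offsets into the original text, so slicing
# recovers the exact surface form, hyphens and all.
for ent in entities:
    ent["text"] = text[ent["start"]:ent["end"]]
    print(ent["entity_group"], "->", ent["text"])
```

This prints IL-12, IFN-gamma cytoplasmic mRNA and protein kinase C exactly as they appear in the input sentence.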

You can also use the Inference API on Hugging Face to try out your model, or you can try out mine.

That’s it for this tutorial! I hope you learned something new (I definitely learned lots!)

Follow me to stay updated; I post a new ML/AI article every Saturday.

With that, have an absolutely fantastic day ✨

Written by Raj Pulapakura

Machine Learning Engineer and Full Stack Developer. Passionate about advancing human intelligence and solving problems.