Using BERT with PyTorch

A super-easy practical guide to building your own fine-tuned BERT-based architecture using PyTorch.

Noa Lubin
4 min read · Jun 10, 2019
Bert image — sesame street

In this post I assume you are familiar with the BERT model and its principles. If not, I highly encourage you to read the paper [1] and this post, or hear my lecture about contextualized embeddings. If you are still missing some background, you might need to read about positional embeddings and transformers.

In this post you will find a super-easy practical guide with code examples for building your own fine-tuned BERT-based architecture using PyTorch. We will be using the wonderful https://github.com/huggingface/pytorch-pretrained-BERT package.

If you understand BERT, you can already identify the two steps you will need to take in your code: tokenize the samples and build your own fine-tuned architecture.

  1. Tokenize the samples (WordPiece)

BERT uses a special subword tokenization (WordPiece). In addition, depending on your task, each sample is wrapped with special tokens: a [CLS] token at the beginning of the first sentence and a [SEP] token at the end of each sentence.
The [CLS] token is used mainly for classification tasks, and the [SEP] token separates multiple sentences for tasks such as SNLI or question answering.
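To get a feel for what the tokenizer does, here is a tiny example (the sentence and the exact subword split are illustrative and depend on the vocabulary of the pretrained model you load):

from pytorch_pretrained_bert.tokenization import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
tokens = tokenizer.tokenize("embeddings are contextual")
# rarer words are split into subword pieces, roughly: ['em', '##bed', '##ding', '##s', 'are', ...]
input_ids = tokenizer.convert_tokens_to_ids(tokens)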

BERT input presentation [1]
from pytorch_pretrained_bert.tokenization import BertTokenizer

# args is assumed to hold the script's arguments, e.g. args.bert_model = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

def get_tokenized_samples(samples, max_seq_length, tokenizer):
    """
    We assume a function label_map that maps each label to an index or
    vector encoding (it could also be a dictionary).
    :param samples: we assume a struct (.text, .label)
    :param max_seq_length: the maximal sequence length
    :param tokenizer: BERT tokenizer
    :return: list of features
    """
    features = []
    for sample in samples:
        textlist = sample.text.split(' ')
        labellist = sample.label
        tokens = []
        labels = []
        for i, word in enumerate(textlist):
            token = tokenizer.tokenize(word)  # tokenize the word according to BERT
            tokens.extend(token)
            label = labellist[i]
            # fit labels to the tokenized size of the word
            for m in range(len(token)):
                if m == 0:
                    labels.append(label)
                else:
                    labels.append("X")
        # if we exceed the max sequence length, cut the sample
        if len(tokens) >= max_seq_length - 1:
            tokens = tokens[0:(max_seq_length - 2)]
            labels = labels[0:(max_seq_length - 2)]

        ntokens = []
        segment_ids = []
        label_ids = []
        # start with the [CLS] token
        ntokens.append("[CLS]")
        segment_ids.append(0)
        label_ids.append(label_map("[CLS]"))
        for i, token in enumerate(tokens):
            # append the tokens
            ntokens.append(token)
            segment_ids.append(0)
            label_ids.append(label_map(labels[i]))
        # end with the [SEP] token
        ntokens.append("[SEP]")
        segment_ids.append(0)
        label_ids.append(label_map("[SEP]"))
        # convert tokens to vocabulary IDs
        input_ids = tokenizer.convert_tokens_to_ids(ntokens)
        # build a mask of the tokens to be accounted for
        input_mask = [1] * len(input_ids)
        while len(input_ids) < max_seq_length:
            # pad with zeros up to the maximal length
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            # label_list (the list of all possible labels) is assumed to be defined in scope
            label_ids.append([0] * (len(label_list) + 1))

        features.append((input_ids,
                         input_mask,
                         segment_ids,
                         label_ids))
    return features
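The function above relies on a label_map helper that is not shown in the post. A minimal sketch of one possible implementation, assuming a hypothetical label list for a tagging task and a one-hot encoding whose index 0 is reserved for padding, could look like this:

# hypothetical label list for a tagging task; adapt it to your data
label_list = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]

def label_map(label):
    # one-hot vector of length len(label_list) + 1; index 0 is reserved for padding
    vec = [0] * (len(label_list) + 1)
    vec[label_list.index(label) + 1] = 1
    return vec

With this encoding, the padded positions keep the all-zeros vector appended in the loop above, and num_labels in the next section would be len(label_list) + 1.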

2. Build your own architecture based on BERT

Unlike traditional embeddings, BERT embeddings are context-dependent, so we need to rely on a pretrained BERT architecture. In full-sentence classification tasks we add a classification layer on top of the output for the [CLS] token. In sequence tagging we need the full output of the sequence. The example below is a sequence tagging one (a sentence-classification variant is sketched right after it).

Fine Tune BERT pre-training to your task [1]
import torch
from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel

class MyBertBasedModel(BertPreTrainedModel):
    """
    MyBertBasedModel inherits from BertPreTrainedModel, an abstract class that
    handles weights initialization and provides a simple interface for
    downloading and loading pre-trained models.
    """

    def __init__(self, config, num_labels):
        super(MyBertBasedModel, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)  # basic BERT model
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
                                       output_all_encoded_layers=False)
        # now you can implement any architecture that receives the BERT sequence output
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss_fct = MyLoss()  # plug in your loss of choice here
            # it is important to compute the loss only on un-padded inputs
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)[active_loss]
            active_labels = labels.view(-1, self.num_labels)[active_loss]
            loss = loss_fct(active_logits, active_labels)
            return loss
        else:
            return logits
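For a full-sentence classification task, as mentioned above, you would use the pooled [CLS] output instead of the full sequence. Here is a minimal sketch of such a variant (not part of the original example; it mirrors BertForSequenceClassification from the same package and assumes integer class labels):

class MyBertSentenceClassifier(BertPreTrainedModel):
    """Hypothetical sentence-level variant: classifies the whole input from the [CLS] output."""

    def __init__(self, config, num_labels):
        super(MyBertSentenceClassifier, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        # pooled_output is the transformed hidden state of the [CLS] token
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                     output_all_encoded_layers=False)
        logits = self.classifier(self.dropout(pooled_output))
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits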

3. How it all comes together

train_tokenized_samples = get_tokenized_samples(
    train_samples, args.max_seq_length, tokenizer)

model = MyBertBasedModel.from_pretrained(args.bert_model,
                                         num_labels=num_labels)
model.train()
for epoch in range(n_epochs):
    for sample in train_tokenized_samples:
        # each feature should be wrapped in a torch tensor with a batch dimension
        input_ids, input_mask, segment_ids, label_ids = sample
        optimizer.zero_grad()
        loss = model(input_ids, segment_ids, input_mask, label_ids)
        loss.backward()
        optimizer.step()
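The loop above assumes that an optimizer and n_epochs are already defined and that the features have been converted to tensors. One possible setup (the hyperparameters are illustrative, not taken from the original post) uses the package's BertAdam optimizer:

import torch
from pytorch_pretrained_bert.optimization import BertAdam

n_epochs = 3  # illustrative value
num_train_steps = len(train_tokenized_samples) * n_epochs
optimizer = BertAdam(model.parameters(),
                     lr=5e-5,                  # typical fine-tuning learning rate
                     warmup=0.1,               # proportion of steps used for LR warm-up
                     t_total=num_train_steps)

# wrapping a single sample's features as tensors with a batch dimension of 1:
input_ids = torch.tensor([input_ids], dtype=torch.long)
input_mask = torch.tensor([input_mask], dtype=torch.long)
segment_ids = torch.tensor([segment_ids], dtype=torch.long)
label_ids = torch.tensor([label_ids], dtype=torch.float)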

I hope this makes working with pre-trained BERT models in PyTorch easier.

[1] Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://arxiv.org/abs/1810.04805, 2018.

