An Overview of the Various BERT Pre-Training Methods
If you are interested in machine learning, you have likely heard of the Transformer, the model that has revolutionized Natural Language Processing over the past few years.
A very popular variant of the Transformer is BERT, which uses Transformer encoders to learn text representations from unlabeled corpora. How does it learn from unlabeled data, you ask? It defines a set of pre-training tasks whose labels come from the text itself, namely Masked Language Modeling (MLM) and Next Sentence Prediction (NSP).
Note: In this article, I will assume you have background knowledge about BERT and that you are looking for more information on pre-training objectives.
BERT
Masked Language Modeling
This is the task of predicting a missing word in a sentence. As you can imagine, you don’t need labels for this learning objective, because you can mask out any word in an input sentence and use the original word as the target. Thus, if we have a sample in our dataset that looks like this:
Machine Learning is Super Cool
During training, BERT may instead receive something like this as input:
Machine [MASK] is Super Cool
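The goal of the model is to predict the missing word. Specifically, BERT selects a random subset of the input tokens (15% in the original paper) and asks the model to fill in the blanks. Of the selected tokens, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged, so that the model cannot rely on [MASK] appearing in its input (it never does during fine-tuning).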
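The masking step can be sketched in a few lines of Python. This is a minimal illustration that operates on whole words rather than WordPiece tokens and uses a toy vocabulary; the function name `mask_tokens` and the convention of using `None` for "no prediction here" are my own choices, not part of BERT's code.

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, vocab, mask_prob=0.15, rng=random):
    """Return (corrupted_tokens, labels) for Masked Language Modeling.

    labels[i] holds the original token at positions the model must predict,
    and None everywhere else (those positions contribute no loss).
    """
    corrupted = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue                          # not selected for prediction
        labels[i] = tok                       # the model must recover this token
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK               # 80%: replace with [MASK]
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return corrupted, labels

sentence = "Machine Learning is Super Cool".split()
vocab = ["apple", "runs", "blue", "Machine", "Cool"]
corrupted, labels = mask_tokens(sentence, vocab, mask_prob=0.3, rng=random.Random(0))
print(corrupted)
print(labels)
```

The loss is then computed only at positions whose label is not `None`.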
Next Sentence Prediction
Since many important downstream tasks involve the relationship between two sentences, BERT also pre-trains on something called Next Sentence Prediction.
Simply put, given two sentences A and B, the objective is to use the [CLS] token to predict whether B is in fact the sentence that follows A. When constructing the data, B is the actual next sentence 50% of the time and a random sentence from the corpus the other 50%, so the labels come for free. The input may look like this:
[CLS] Sentence A [SEP] Sentence B [SEP]
We then feed this into BERT and the goal is for the [CLS] token to learn a representation that allows us to identify if Sentence B should come after Sentence A!
The BERT authors claim that this simple idea plays a large part in BERT’s ability to perform tasks such as Question Answering and Natural Language Inference, which require comparing two sentences.
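Building NSP training pairs from unlabeled text can be sketched as follows. This is an illustrative sketch, assuming the corpus is already split into documents and sentences; `make_nsp_example` and `to_bert_input` are hypothetical helper names. Drawing the negative sentence from a different document is a simple way to avoid accidentally picking the true next sentence.

```python
import random

def make_nsp_example(documents, rng=random):
    """Build one Next Sentence Prediction example.

    documents: list of documents, each a list of sentence strings.
    Needs at least one document with 2+ sentences and at least two documents.
    Returns (sentence_a, sentence_b, is_next).
    """
    # Anchor sentence A: any sentence that has a successor in its document.
    doc = rng.choice([d for d in documents if len(d) >= 2])
    i = rng.randrange(len(doc) - 1)
    sentence_a = doc[i]
    if rng.random() < 0.5:
        return sentence_a, doc[i + 1], True   # IsNext: the true following sentence
    # NotNext: a random sentence taken from a different document.
    other_docs = [d for d in documents if d is not doc]
    return sentence_a, rng.choice(rng.choice(other_docs)), False

def to_bert_input(sentence_a, sentence_b):
    """Format the pair the way BERT sees it."""
    return f"[CLS] {sentence_a} [SEP] {sentence_b} [SEP]"

docs = [
    ["Machine learning is fun.", "It powers many apps."],
    ["The cat sat down.", "Then it slept."],
]
a, b, is_next = make_nsp_example(docs, random.Random(0))
print(to_bert_input(a, b), "->", "IsNext" if is_next else "NotNext")
```

The [CLS] output is then trained with a binary cross-entropy loss against `is_next`.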
Other Pre-Training Objectives
Above we have established what one might call the “baseline” pre-training objectives. In recent years, researchers have found even better ways to pre-train the BERT architecture to encourage the learning of better textual representations. The remainder of this article will highlight other popular pre-training methods.
RoBERTa: A Robustly Optimized BERT Pretraining Approach
In the RoBERTa paper, the authors highlight flaws in the baseline BERT pre-training objectives.
1. RoBERTa disagrees with BERT’s use of a static mask. In the original BERT implementation, masking is performed once during preprocessing. To get some variety, each sample is duplicated 10 times and each copy is masked differently, so over 40 epochs of training the model sees each masked version of a sample four times.
In RoBERTa, the authors propose dynamic masking, where a new mask is generated every time a sample is fed into the model. Results show performance comparable to, or slightly better than, static masking; however, dynamic masking is more efficient, as there is no need to duplicate the data 10x.
2. RoBERTa shows that the NSP task used in BERT is unnecessary: removing the NSP loss matches or slightly improves downstream performance. Specifically, this refers to the binary classification task asking “is sentence B the sentence that follows sentence A?”, whose loss BERT adds to the MLM loss during training.
Instead, RoBERTa packs each input with full sentences sampled contiguously from one or more documents, up to 512 tokens in total, e.g.
[CLS] Sentence A Sentence B Sentence C ... [SEP]
without asking the model to predict whether the sentences are consecutive.
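The difference between static and dynamic masking (point 1 above) can be made concrete with a small sketch. For brevity it masks each token with probability 0.15 and skips BERT's 80/10/10 split; `static_masking`, `dynamic_masking`, and `random_mask` are hypothetical names, and the data loader is reduced to a generator that yields one masked copy of the dataset per epoch.

```python
import random

def random_mask(tokens, rng, mask_prob=0.15):
    """Replace each token with [MASK] with probability mask_prob (simplified)."""
    return [("[MASK]" if rng.random() < mask_prob else t) for t in tokens]

def static_masking(dataset, num_epochs, dup_factor=10, seed=0):
    """BERT-style: dup_factor masked copies are made once during preprocessing
    and then reused cyclically across epochs."""
    rng = random.Random(seed)
    copies = [[random_mask(s, rng) for s in dataset] for _ in range(dup_factor)]
    for epoch in range(num_epochs):
        yield copies[epoch % dup_factor]

def dynamic_masking(dataset, num_epochs, seed=0):
    """RoBERTa-style: a fresh mask is drawn every time a sequence is fed to the model."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        yield [random_mask(s, rng) for s in dataset]

data = [["Machine", "Learning", "is", "Super", "Cool"]]
for epoch, batch in enumerate(dynamic_masking(data, num_epochs=3)):
    print(epoch, batch[0])
```

With 40 epochs, the static version can show a sample with at most 10 distinct masks, while the dynamic version draws a new one each time and stores nothing extra.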
BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
I was unsure whether to include BART in this article, as it differs from BERT in that it contains a decoder. At the risk of over-simplification, BART is BERT plus an autoregressive decoder (such as GPT-2). Given the presence of the decoder, the model has more flexibility in how its pre-training objectives can be formulated.
At a high level, BART is trained as follows: 1) corrupt the input text, 2) encode it with a bidirectional (BERT-like) encoder, 3) decode the encoder output with the autoregressive decoder, and 4) compare the decoded output to the original, uncorrupted text with a reconstruction (cross-entropy) loss. In other words, BART is a fancy denoising autoencoder.
Text Corruption Strategies (Pre-Training Objectives)
Token Masking
This is the same masking done in BERT. BART will simply mask some of the tokens in the input and try to have the decoder predict what is missing.
Token Deletion
Randomly delete tokens from the input. Unlike masking, this leaves no marker behind, so the model must not only predict the missing tokens but also decide where they were.
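Token deletion is the simplest corruption to sketch. The deletion probability of 0.15 below is an illustrative assumption, and `delete_tokens` is a hypothetical name; the decoder's target is always the original, uncorrupted sequence.

```python
import random

def delete_tokens(tokens, delete_prob=0.15, rng=random):
    """BART-style token deletion: drop each token independently with
    probability delete_prob, without leaving a [MASK] in its place."""
    return [t for t in tokens if rng.random() >= delete_prob]

original = "Machine Learning is Super Cool".split()
corrupted = delete_tokens(original, delete_prob=0.3, rng=random.Random(1))
print(corrupted, "->", original)  # encoder input -> decoder target
```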
Text Infilling
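Sample spans of text, with span lengths drawn from a Poisson distribution (λ = 3), and replace each span with a single [MASK] token; a span of length 0 simply inserts a [MASK]. This forces the model to learn how many tokens each [MASK] should be decoded into.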
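Here is a minimal sketch of text infilling using NumPy's Poisson sampler. The 30% removal budget, the rejection of overlapping spans, and the skipping of zero-length spans (which in BART insert an extra [MASK]) are simplifications of mine; `text_infilling` is a hypothetical name.

```python
import numpy as np

MASK = "[MASK]"

def text_infilling(tokens, mask_ratio=0.3, poisson_lambda=3.0, seed=0):
    """Replace random spans with a single [MASK] each until about
    mask_ratio of the original tokens have been removed."""
    rng = np.random.default_rng(seed)
    n = len(tokens)
    budget = int(round(mask_ratio * n))    # number of original tokens to remove
    masked = [False] * n
    span_starts = set()
    removed, attempts = 0, 0
    while removed < budget and attempts < 1000:
        attempts += 1
        length = int(rng.poisson(poisson_lambda))
        if length == 0:
            continue                       # simplification: skip zero-length spans
        length = min(length, budget - removed)
        start = int(rng.integers(0, n - length + 1))
        if any(masked[start:start + length]):
            continue                       # reject spans overlapping an earlier one
        for i in range(start, start + length):
            masked[i] = True
        span_starts.add(start)
        removed += length
    out = []
    for i, tok in enumerate(tokens):
        if i in span_starts:
            out.append(MASK)               # the whole span collapses to one [MASK]
        elif not masked[i]:
            out.append(tok)
    return out

print(text_infilling("the quick brown fox jumps over the lazy dog today".split()))
```

Because a [MASK] may stand for one token or several, the decoder has to infer the span length from context.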
Document Rotation
Choose a token uniformly at random, then rotate (wrap) the input so that it starts with that token. This trains the model to identify the start of the document.
In BART, one fine-tunes for classification tasks by feeding the same input into both the encoder and the decoder and using the final hidden state of an extra token appended to the end of the decoder input (the equivalent of BERT’s [CLS] token) for prediction.
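Document rotation amounts to a single list rotation. A short sketch, with `rotate_document` as a hypothetical helper name:

```python
import random

def rotate_document(tokens, rng=random):
    """Pick a token uniformly at random and rotate the sequence to start with it."""
    if not tokens:
        return list(tokens)
    k = rng.randrange(len(tokens))
    return tokens[k:] + tokens[:k]

doc = "Machine Learning is Super Cool".split()
print(rotate_document(doc, random.Random(3)))
```

To undo the rotation, the model has to recognize which token originally began the document.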
SpanBERT: Improving Pre-training by Representing and Predicting Spans
SpanBERT presents a simple idea that gives a substantial performance boost over the baseline pre-training objectives, especially on span-selection tasks such as question answering and coreference resolution. Consider the image below from the original paper:
SpanBERT does two novel things during pre-training:
- They mask out contiguous spans of text in the original sentence, with span lengths drawn from a geometric distribution that favors short spans, until 15% of the tokens are masked. In the graphic above, you can see a set of 4 consecutive tokens replaced with [MASK]. The model is tasked with predicting this missing information.
- They introduce a Span Boundary Objective (SBO) that forces the model to predict each masked token using only the representations of the tokens just outside the masked span (x4 and x9 in the image above), plus a position embedding for the target token’s location within the span.
Further, SpanBERT removes the NSP objective and trains on single contiguous segments of text (up to 512 tokens) rather than on pairs of segments.
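The span masking scheme can be sketched as below. It draws span lengths from a geometric distribution with p = 0.2, clipped at 10 words, which are the settings reported in the SpanBERT paper. Unlike the paper, this sketch always writes [MASK] (no 80/10/10 replacement) and simply rejects overlapping spans; `span_mask` is a hypothetical name.

```python
import numpy as np

MASK = "[MASK]"

def span_mask(words, mask_ratio=0.15, p=0.2, max_span=10, seed=0):
    """Mask contiguous spans of words until about mask_ratio of them are masked.

    Span lengths ~ Geometric(p), clipped at max_span.
    Returns (corrupted_words, sorted list of masked positions).
    """
    rng = np.random.default_rng(seed)
    n = len(words)
    budget = int(round(mask_ratio * n))
    masked = set()
    attempts = 0
    while len(masked) < budget and attempts < 1000:
        attempts += 1
        length = min(int(rng.geometric(p)), max_span, budget - len(masked))
        start = int(rng.integers(0, n - length + 1))
        span = range(start, start + length)
        if any(i in masked for i in span):
            continue                      # reject spans that overlap earlier ones
        masked.update(span)
    corrupted = [MASK if i in masked else w for i, w in enumerate(words)]
    return corrupted, sorted(masked)

words = "Super Bowl 50 was an American football game to determine the champion".split()
print(span_mask(words, seed=1)[0])
```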
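The Span Boundary Objective can be sketched with NumPy. This is a simplified illustration under stated assumptions: random weights stand in for learned parameters, the two-layer feed-forward network with GeLU and layer normalization (without learned scale and bias) is my reading of the paper's description, and `sbo_loss`, `pos_emb`, `W1`, `W2`, and `W_out` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, vocab_size, max_span = 16, 50, 10

# Hypothetical parameters (randomly initialised here, learned in practice).
pos_emb = rng.normal(size=(max_span, hidden))         # relative position embeddings
W1 = rng.normal(size=(3 * hidden, hidden)) * 0.1
W2 = rng.normal(size=(hidden, hidden)) * 0.1
W_out = rng.normal(size=(hidden, vocab_size)) * 0.1   # projection to vocabulary logits

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def sbo_loss(encoder_out, span_start, span_end, target_ids):
    """Predict each token of the span encoder_out[span_start:span_end + 1] from the
    two boundary vectors and its relative position; return mean cross-entropy."""
    left, right = encoder_out[span_start - 1], encoder_out[span_end + 1]
    losses = []
    for offset, target in enumerate(target_ids):
        h = np.concatenate([left, right, pos_emb[offset]])
        h = layer_norm(gelu(h @ W1))
        h = layer_norm(gelu(h @ W2))
        logits = h @ W_out
        m = logits.max()
        log_probs = logits - (m + np.log(np.exp(logits - m).sum()))  # log-softmax
        losses.append(-log_probs[target])
    return float(np.mean(losses))

encoder_out = rng.normal(size=(12, hidden))   # stand-in for Transformer outputs x1..x12
print(sbo_loss(encoder_out, 4, 7, [3, 10, 20, 7]))
```

Note that the states inside the span never enter the computation, which is what pushes span-level information into the boundary representations. In the paper, this loss is added to the regular MLM loss for every masked token.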
Taking Notes on the Fly Helps Language Pre-Training (TNF)
This paper points out that BERT training is expensive in part because the model has to learn good representations of extremely rare words. Consider the sentence “COVID-19 has cost thousands of lives”: COVID-19 is likely an uncommon term in a given training corpus. Yet if we need to predict “COVID-19 has cost thousands of [MASK]”, the rare word COVID-19 is the most informative clue for predicting the missing token.
TNF addresses this problem by keeping a dictionary whose keys are rare words and whose values are vector representations of each word’s historical context. In other words, each rare word has a “note” that is updated from its contexts in other sentences. When the rare word comes up again, this historical context embedding gives the model extra information to work with, even while the word’s own embedding is still poorly trained.
In practice, if a rare word is detected, one simply looks up its historical context vector in the note dictionary and adds it to the input embedding, i.e. word_embedding + position_embedding + TNF_embedding.
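A note dictionary along these lines can be sketched as follows. The exponential moving-average update and the rate `gamma = 0.1` are assumptions for illustration, as is the idea that the context vector is something like the mean encoder output around the rare word; `NoteDictionary` is a hypothetical class, not the paper's code.

```python
import numpy as np

class NoteDictionary:
    """Hypothetical sketch of a TNF-style note dictionary.

    Keys are rare words; values are moving averages of the context vectors
    observed around each occurrence of the word.
    """

    def __init__(self, rare_words, dim, gamma=0.1):
        self.notes = {w: np.zeros(dim) for w in rare_words}
        self.gamma = gamma  # update rate (assumed value)

    def update(self, word, context_vector):
        """Blend a new context vector into the word's note (rare words only)."""
        if word in self.notes:
            self.notes[word] = (1 - self.gamma) * self.notes[word] + self.gamma * context_vector

    def input_embedding(self, word, word_emb, pos_emb):
        """word_embedding + position_embedding (+ note embedding for rare words)."""
        emb = word_emb + pos_emb
        if word in self.notes:
            emb = emb + self.notes[word]
        return emb

notes = NoteDictionary({"COVID-19"}, dim=4)
context = np.ones(4)  # stand-in for a pooled encoder output around "COVID-19"
notes.update("COVID-19", context)
print(notes.input_embedding("COVID-19", np.zeros(4), np.zeros(4)))
```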
ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators
This pre-training approach also aims to make training BERT more efficient. The key idea here is to train using replaced token detection.
Replaced Token Detection
To perform this step, we need two Transformer models: 1) a generator and 2) a discriminator.
We first mask a few of the input tokens as shown above, and the generator (a small masked language model) fills the masked positions with sampled tokens. Next, the discriminator (the key innovation in ELECTRA) decides, for every input token, whether it is original or was replaced by the generator. If the generator happens to sample the original token, that position counts as original.
This is more efficient because, instead of computing a loss over only the masked tokens (about 15% of the input in BERT), the discriminator’s loss is computed over all input tokens. The model therefore learns from every position, which allows one to train a BERT-style encoder to a given level of performance with much less compute.
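The replaced token detection targets and loss can be sketched in NumPy. This is a simplified illustration with hypothetical helper names (`rtd_labels`, `rtd_loss`); the discriminator logits would come from a Transformer in practice.

```python
import numpy as np

def rtd_labels(original_ids, corrupted_ids):
    """1.0 where the generator replaced the token with a different one, else 0.0.
    A position where the generator sampled the original token counts as original."""
    return (np.asarray(original_ids) != np.asarray(corrupted_ids)).astype(np.float64)

def rtd_loss(discriminator_logits, labels):
    """Binary cross-entropy averaged over *all* positions, not only masked ones."""
    z = np.asarray(discriminator_logits, dtype=np.float64)
    # Numerically stable BCE with logits: max(z, 0) - z * y + log(1 + exp(-|z|))
    return float(np.mean(np.maximum(z, 0) - z * labels + np.log1p(np.exp(-np.abs(z)))))

original = [5, 8, 2, 9, 4]
corrupted = [5, 3, 2, 9, 4]  # positions 1 and 3 were masked; the generator resampled 9 at position 3
labels = rtd_labels(original, corrupted)
print(labels, rtd_loss(np.zeros(5), labels))
```

In the paper, the generator and discriminator are trained jointly by minimizing the generator's MLM loss plus a weighted discriminator loss; after pre-training, the generator is discarded and only the discriminator is fine-tuned.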
Span Selection Pre-training for Question Answering
I won’t go over this one in much detail, but the overarching idea is to use an external resource to supply the missing information during Masked Language Modeling. In the image above, we see that the [BLANK] token is not filled in from the model’s own vocabulary predictions (as is done in BERT); instead, the answer is selected as a span from a relevant passage. This model outperformed BERT on Machine Reading Comprehension tasks by 3 F1 points.
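Selecting a span from a passage typically works like extractive question answering: the model scores each passage token as a possible start and end of the answer, and the best valid pair is chosen. The sketch below shows that generic read-out step only; the scores are made up, and `best_span` and `max_len` are hypothetical names rather than details from the paper.

```python
import numpy as np

def best_span(start_logits, end_logits, max_len=30):
    """Return (start, end) maximising start_logits[start] + end_logits[end]
    subject to start <= end < start + max_len."""
    start_logits = np.asarray(start_logits)
    end_logits = np.asarray(end_logits)
    best, best_score = (0, 0), -np.inf
    for s in range(len(start_logits)):
        for e in range(s, min(s + max_len, len(end_logits))):
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best

passage = "the vaccine was approved in December 2020".split()
start_scores = [0, 0, 0, 0, 0, 5, 0]  # made-up scores from a hypothetical model
end_scores = [0, 0, 0, 0, 0, 0, 5]
s, e = best_span(start_scores, end_scores)
print(" ".join(passage[s:e + 1]))
```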
SenseBERT: Driving Some Sense into BERT
The pre-training idea proposed in this paper forces the model not only to perform Masked Language Modeling but also to predict the supersense of each masked word, i.e. its coarse WordNet category, such as noun.food or verb.motion.
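A word usually has several possible supersenses (“apple” can be a food or a plant), and the intended one is not labeled. One way to handle this, which I understand to be close to SenseBERT's approach, is to maximize the total probability of all supersenses WordNet allows for the word. In the sketch below, the five-label inventory and the `ALLOWED` table are hypothetical stand-ins for a real WordNet lookup.

```python
import numpy as np

# Hypothetical supersense inventory and WordNet-derived allowed sets (illustrative only).
SUPERSENSES = ["noun.food", "noun.plant", "noun.artifact", "verb.motion", "verb.consumption"]
ALLOWED = {
    "apple": {"noun.food", "noun.plant"},
    "run": {"verb.motion"},
}

def log_softmax(z):
    z = np.asarray(z, dtype=np.float64)
    m = z.max()
    return z - (m + np.log(np.exp(z - m).sum()))

def allowed_supersense_loss(logits, word):
    """-log of the total probability the model assigns to the supersenses allowed for `word`."""
    log_p = log_softmax(logits)
    lp = log_p[[SUPERSENSES.index(s) for s in ALLOWED[word]]]
    m = lp.max()
    return float(-(m + np.log(np.exp(lp - m).sum())))  # -logsumexp over allowed senses

logits = np.array([2.0, 0.5, -1.0, -1.0, 0.0])  # made-up supersense scores for a masked word
print(allowed_supersense_loss(logits, "apple"))
```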
CAPT: Contrastive Pre-Training for Learning Denoised Sequence Representations
If you are familiar with contrastive learning, you might be interested in CAPT. The goal of this work is to pull the representation of each input sentence as close as possible to the representation of its corrupted version, while pushing it away from the representations of other sentences. This helps the model become more invariant to noise and better capture the global semantics of a sentence.
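The idea can be illustrated with a generic InfoNCE-style contrastive loss over a batch of sentence vectors. This is a sketch of the technique, not necessarily CAPT's exact formulation; the temperature of 0.1, the use of in-batch negatives, and the name `contrastive_loss` are assumptions.

```python
import numpy as np

def contrastive_loss(original, corrupted, temperature=0.1):
    """InfoNCE-style loss: row i of `original` should be most similar to row i of
    `corrupted` (its own noisy version) and dissimilar to the other rows in the batch."""
    a = original / np.linalg.norm(original, axis=1, keepdims=True)
    b = corrupted / np.linalg.norm(corrupted, axis=1, keepdims=True)
    logits = a @ b.T / temperature               # cosine similarities, shape (batch, batch)
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))   # positives sit on the diagonal

rng = np.random.default_rng(0)
clean = rng.normal(size=(8, 16))                 # stand-ins for sentence representations
noisy = clean + 0.05 * rng.normal(size=clean.shape)
print(contrastive_loss(clean, noisy))
```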
LIMIT-BERT: Linguistics Informed Multi-Task BERT
LIMIT-BERT differs from traditional BERT in two ways: 1) instead of masking random tokens, it masks semantically meaningful tokens as identified by pre-trained linguistic models, and 2) it extends the BERT training objective into a multi-task loss that requires LIMIT-BERT not only to perform MLM but also to learn linguistic properties of the input (identifying parts of speech, identifying semantic roles, etc.).
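Linguistically informed masking can be sketched by masking whole units proposed by an external analyzer instead of independent random tokens. Here the candidate spans are hard-coded to stand in for the output of a hypothetical parser, and the 15% budget and `linguistic_mask` name are illustrative assumptions rather than LIMIT-BERT's exact procedure.

```python
import random

MASK = "[MASK]"

def linguistic_mask(tokens, spans, mask_ratio=0.15, rng=random):
    """Mask whole linguistically meaningful spans instead of random tokens.

    spans: list of (start, end) pairs (end exclusive), e.g. predicate or argument
    spans produced by an external linguistic model (assumed to be given).
    """
    budget = int(round(mask_ratio * len(tokens)))
    masked = set()
    for start, end in rng.sample(spans, len(spans)):  # visit spans in random order
        if len(masked) >= budget:
            break
        if len(masked) + (end - start) > budget:
            continue                                  # skip spans that would overshoot
        masked.update(range(start, end))
    return [MASK if i in masked else t for i, t in enumerate(tokens)]

tokens = "the committee approved the new budget on friday".split()
spans = [(0, 2), (3, 6), (6, 8)]  # "the committee", "the new budget", "on friday"
print(linguistic_mask(tokens, spans, mask_ratio=0.4, rng=random.Random(0)))
```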