Choosing the right parameters for pre-training BERT using TPU

Pooja Aggarwal
Jan 14, 2021



Pre-training a BERT model is not easy. Many articles give a great high-level overview of what BERT is and the amazing things it can do, or go into depth about a very small implementation detail. This often leaves aspiring data scientists, like me a while ago, looking at notebooks and thinking: “It looks great and it works, but why did the author choose this batch size or this sequence length instead of another?”

In this article, I want to give some intuition for the decisions involved in finding the right configuration to pre-train a BERT model.

The full article with code and outputs can be found on GitHub as a notebook.

Let us begin by looking at each input parameter.

1. Do Lower Case

You can pre-train an “uncased” or a “cased” version of the BERT model by setting do_lower_case to true or false, respectively.

Uncased means that the text has been lowercased before tokenization, e.g., John Smith becomes john smith. The Uncased model also strips out any accent markers. Cased means that the true case and accent markers are preserved.
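Here is a minimal Python sketch of this normalization. It assumes the steps used by the reference BERT tokenizer: lowercase the text, apply Unicode NFD decomposition, then drop combining marks of category Mn. uncased_normalize is a hypothetical helper name, not part of any library.

```python
import unicodedata

def uncased_normalize(text: str) -> str:
    """Hypothetical helper: approximate the 'uncased' preprocessing."""
    text = text.lower()
    # NFD splits a character like "é" into "e" plus a combining accent (category Mn).
    decomposed = unicodedata.normalize("NFD", text)
    # Drop every non-spacing combining mark, i.e. the accents.
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

print(uncased_normalize("John Smith"))  # john smith
print(uncased_normalize("Café Zoë"))    # cafe zoe
```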

You can decide its value based on the type of task and the language of the input data.

Type of task

For tasks like the following, case information is an important signal:

  1. Named Entity Recognition
  2. Part-of-Speech tagging
  3. Sentiment Detection

For example, consider the word “us”. Once lowercased, it could be the pronoun (as in “they told us”) or the country “US” (USA). For these tasks, a “cased” model is the better choice.

For other tasks, use an “uncased” model. Its vocabulary has no case duplicates. A “cased” vocabulary may contain both “dhoni” and “Dhoni”, while an uncased one needs only “dhoni”. As a result, a vocabulary of the same size can represent more of the language.
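The toy example below illustrates the duplication. It uses whole words rather than the WordPiece units a real BERT vocabulary contains, and both the vocab helper and the corpus are made up for this illustration.

```python
def vocab(words, lowercase):
    """Toy vocabulary: the set of distinct word forms (hypothetical helper)."""
    return {w.lower() if lowercase else w for w in words}

corpus = "Dhoni hit a six . dhoni fans cheered . Fans love Dhoni".split()
cased = vocab(corpus, lowercase=False)     # keeps "Dhoni"/"dhoni" and "Fans"/"fans" apart
uncased = vocab(corpus, lowercase=True)    # merges them
print(len(cased), len(uncased))  # 10 8
```

The two extra cased entries add no new words; they only spend vocabulary slots on a second spelling of words already present.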

Input Language

Stripping accents from Devanagari and other non-Latin scripts can change the meaning of a word, and hence of the sentence, because some vowel signs are treated as accent marks and removed.

Example: कुरान (kuran) becomes करान (karan), because the vowel sign ु is removed.
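You can see why this happens by inspecting the Unicode category of each character. The sketch assumes accents are stripped the way BERT's tokenizer does it: Unicode NFD decomposition, then dropping non-spacing marks (category Mn). The vowel sign ु (U+0941) is a non-spacing mark, so it is removed. The vowel sign ा (U+093E) is a spacing mark (Mc), so it survives.

```python
import unicodedata

word = "कुरान"
# Print code point, Unicode category and name for every character.
for ch in word:
    print(f"U+{ord(ch):04X} {unicodedata.category(ch)} {unicodedata.name(ch)}")

# Accent stripping: decompose, then drop non-spacing marks (category Mn).
stripped = "".join(ch for ch in unicodedata.normalize("NFD", word)
                   if unicodedata.category(ch) != "Mn")
print(stripped)  # करान
```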

It is therefore recommended to use a “cased” model for non-Latin languages.

2. Maximum Sequence Length

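The max_seq_length parameter specifies the maximum number of tokens in an input sequence. Longer inputs are truncated and shorter ones are padded to this length. It is usually set to a power of two, e.g., 64, 128, or 512. The value 512 is the upper limit for the standard BERT configurations, whose position embeddings cover 512 positions.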
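The sketch below fixes every input to max_seq_length. It assumes a padding id of 0 and a mask that marks real tokens with 1 and padding with 0, so attention can ignore the padding. pad_or_truncate is a hypothetical helper. BERT's own pre-training data script truncates sentence pairs more carefully, trimming the longer segment first.

```python
from typing import List, Tuple

def pad_or_truncate(token_ids: List[int], max_seq_length: int,
                    pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: fit one sequence to exactly max_seq_length."""
    ids = token_ids[:max_seq_length]          # truncate long inputs
    input_mask = [1] * len(ids)               # 1 = real token
    padding = max_seq_length - len(ids)
    ids = ids + [pad_id] * padding            # pad short inputs
    input_mask = input_mask + [0] * padding   # 0 = padding, ignored by attention
    return ids, input_mask

print(pad_or_truncate([5, 6, 7, 8], 8))
# ([5, 6, 7, 8, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0])
print(pad_or_truncate(list(range(1, 11)), 8))
# ([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 1, 1, 1, 1, 1, 1])
```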

Choose its value based on the end task: are you making predictions on articles, sentences, or phrases?

Use a larger value for articles and a smaller value for phrases. Do not set it larger than you need, because self-attention cost grows quadratically with sequence length.
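Here is one way to turn this advice into a number. It is a hypothetical heuristic, not the author's method: measure the token lengths of your inputs, take a high percentile, and round up to a power of two. The result is capped at 512, the position-embedding limit of the standard BERT configurations.

```python
import math

def suggest_max_seq_length(lengths, percentile=0.95, cap=512):
    """Hypothetical heuristic: cover most inputs, round up to a power of two."""
    ordered = sorted(lengths)
    # Nearest-rank percentile (0-based index).
    idx = min(len(ordered) - 1, math.ceil(percentile * len(ordered)) - 1)
    target = max(1, ordered[idx])
    power = 1 << (target - 1).bit_length()   # smallest power of two >= target
    return min(power, cap)

print(suggest_max_seq_length([3, 5, 8, 12, 6] * 20))  # phrases: 16
print(suggest_max_seq_length([300, 450, 700, 900]))   # articles: 512 (capped)
```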

3. Train Batch Size

The train batch size is the number of samples processed before the model's weights are updated. Larger batch sizes are preferred because they give a more stable estimate of the gradient over the full dataset. The batch size is usually set to a power of two, e.g., 512, 1024, or 2048.

After the embedding layer, the input has the shape (batch_size, max_seq_length, embedding_size). The embedding (hidden) size is 768 for BERT-Base and 1024 for BERT-Large. The sequence length is decided by the end task, as discussed above.
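The shape arithmetic can be sketched as follows. The helper names are hypothetical. The size computed is only that of the float32 embedding output. The full model keeps many activations of this size per layer, plus attention scores that grow with the square of the sequence length, so real memory use is far higher.

```python
def embedding_output_shape(batch_size, max_seq_length, hidden_size=768):
    # Token ids enter as (batch_size, max_seq_length); the embedding lookup
    # adds a hidden dimension of size hidden_size.
    return (batch_size, max_seq_length, hidden_size)

def tensor_mebibytes(shape, bytes_per_value=4):
    # Size of one float32 tensor with this shape, in MiB.
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_value / 2**20

shape = embedding_output_shape(32, 512)
print(shape, tensor_mebibytes(shape))  # (32, 512, 768) 48.0
```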

The goal is to use the available hardware fully. Set the train batch size to the largest value that fits in the available accelerator memory.
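The sketch below picks the largest power-of-two batch size that fits, under two assumptions. First, you have measured or estimated the memory cost of one example (bytes_per_example, a hypothetical input). Second, the global batch is split evenly across 8 TPU cores, as on a v2-8 or v3-8 device. In practice, the simplest way to find the limit is empirical: keep doubling the batch size until you hit an out-of-memory error, then step back once.

```python
def largest_global_batch(per_core_memory_bytes: int, bytes_per_example: int,
                         num_cores: int = 8) -> int:
    """Hypothetical sketch: largest power-of-two per-core batch that fits,
    multiplied by the number of cores to give the global batch size."""
    if bytes_per_example > per_core_memory_bytes:
        raise ValueError("a single example does not fit on one core")
    per_core = 1
    # Keep doubling while the doubled batch still fits in one core's memory.
    while 2 * per_core * bytes_per_example <= per_core_memory_bytes:
        per_core *= 2
    return per_core * num_cores

# Made-up numbers: 12 GiB usable per core, about 150 MiB per example.
print(largest_global_batch(12 * 2**30, 150 * 2**20))  # 512
```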

4. Maximum Predictions Per Sequence

The max_predictions_per_seq parameter is the maximum number of masked LM predictions per sequence. It is computed as follows:

max_predictions_per_seq = ceil(max_seq_length * masked_lm_prob)

For instance, with max_seq_length = 512 and masked_lm_prob = 0.15:
max_predictions_per_seq = ceil(76.8) = 77

Here, masked_lm_prob is the fraction of tokens in each sequence that are selected for masked LM prediction. In BERT, 80% of the selected tokens are replaced with the [MASK] token, 10% with a random token, and 10% are left unchanged.
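The formula can be checked in Python, together with a simplified sketch of how tokens are selected and masked. mask_tokens is a hypothetical helper modelled on BERT's 80/10/10 rule: 80% of selected tokens become [MASK], 10% become a random token, and 10% stay unchanged. BERT's own create_pretraining_data.py also handles options such as whole-word masking, so treat this as an illustration only.

```python
import math
import random

def max_predictions_per_seq(max_seq_length: int, masked_lm_prob: float) -> int:
    # Round up so a full-length sequence always has enough prediction slots.
    return math.ceil(max_seq_length * masked_lm_prob)

def mask_tokens(tokens, masked_lm_prob, max_predictions, vocab, rng):
    """Simplified BERT-style masking; a sketch, not the official script."""
    # Special tokens are never selected for prediction.
    candidates = [i for i, t in enumerate(tokens) if t not in ("[CLS]", "[SEP]")]
    rng.shuffle(candidates)
    # Select about masked_lm_prob of the tokens: at least 1, at most max_predictions.
    num_to_predict = min(max_predictions,
                         max(1, int(round(len(tokens) * masked_lm_prob))))
    positions = sorted(candidates[:num_to_predict])
    output = list(tokens)
    for i in positions:
        r = rng.random()
        if r < 0.8:
            output[i] = "[MASK]"            # 80%: replace with [MASK]
        elif r < 0.9:
            pass                            # 10%: keep the original token
        else:
            output[i] = rng.choice(vocab)   # 10%: replace with a random token
    labels = [tokens[i] for i in positions]  # original tokens the model must predict
    return output, positions, labels

print(max_predictions_per_seq(512, 0.15))  # 77
print(max_predictions_per_seq(128, 0.15))  # 20

tokens = ["[CLS]"] + [f"w{i}" for i in range(20)] + ["[SEP]"]
masked, positions, labels = mask_tokens(tokens, 0.15, 20, ["a", "b", "c"], random.Random(0))
print(masked)
```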

5. Number of Training Steps

The number of training steps is the number of batches the model is trained on, i.e., the number of weight updates. It is computed as follows:

steps = (epochs * examples) / batch_size

For instance, with epochs = 100, examples = 1000 and batch_size = 1000:
steps = 100

Here, epochs is the number of passes over the complete dataset. It should be greater than 1.
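The same formula in Python follows. num_train_steps is a hypothetical helper name. It rounds up with integer arithmetic, so a final partial batch still counts toward the requested number of epochs.

```python
def num_train_steps(epochs: int, num_examples: int, batch_size: int) -> int:
    """Hypothetical helper: number of batches (weight updates) needed."""
    # Integer ceiling division: a final partial batch still counts.
    return (epochs * num_examples + batch_size - 1) // batch_size

print(num_train_steps(100, 1000, 1000))     # 100
print(num_train_steps(3, 1_000_000, 256))   # 11719
```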

In practice, we usually set a very high value and stop training manually as soon as the loss stops decreasing.
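If you would rather not watch the loss by hand, a simple plateau check can automate the decision. This is a hypothetical rule, not a feature of the BERT scripts. The names patience and min_delta, and their default values, are assumptions chosen for the illustration.

```python
def should_stop(losses, patience=3, min_delta=1e-3):
    """Hypothetical rule: stop once the last `patience` evaluations have not
    beaten the earlier best loss by at least min_delta."""
    if len(losses) <= patience:
        return False
    best_before = min(losses[:-patience])
    recent_best = min(losses[-patience:])
    return recent_best > best_before - min_delta

print(should_stop([5.0, 4.0, 3.0, 2.0]))                        # False: still improving
print(should_stop([5.0, 4.0, 3.0, 2.5, 2.5005, 2.4998, 2.501]))  # True: bottomed out
```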

6. Learning Rate

The learning rate is a positive scalar that determines the size of each optimization step.

We should not use a learning rate that is too large or too small. When the learning rate is too large, gradient descent can inadvertently increase rather than decrease the training error. When it is too small, training is not only slower but may become permanently stuck with a high training error.
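A toy illustration of the first two effects uses plain gradient descent on f(x) = x² (gradient 2x). The function and learning rates are illustrative choices, not BERT settings. A convex toy problem cannot show the "stuck" behaviour, which needs a non-convex loss.

```python
def gradient_descent(learning_rate, steps=50, x0=1.0):
    """Minimise f(x) = x**2 with plain gradient descent; return the final loss."""
    x = x0
    for _ in range(steps):
        x -= learning_rate * 2 * x   # gradient of x**2 is 2x
    return x * x

# 1.1 overshoots and diverges, 0.001 barely moves, 0.1 converges quickly.
for lr in (1.1, 0.001, 0.1):
    print(lr, gradient_descent(lr))
```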

7. Number of Warm-up Steps

If your data set is highly differentiated, you can suffer from a sort of “early over-fitting”. If your shuffled data happens to include a cluster of related, strongly-featured observations, your model’s initial training can skew badly toward those features — or worse, toward incidental features that aren’t truly related to the topic at all.

Warm-up is a way to reduce the primacy effect of the early training examples. Without it, you may need to run a few extra epochs to get the convergence desired, as the model un-trains those early superstitions.

The learning rate is increased linearly over the warm-up period. If the target learning rate is p and the warm-up period is n, then the first batch iteration uses 1*p/n for its learning rate; the second uses 2*p/n, and so on: iteration i uses i*p/n, until we hit the nominal rate at iteration n.

The warm-up period is commonly set to about 1% of the total number of training steps.
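Here is a sketch of the linear warm-up rule, where step i (counting from 1) uses i * p / n, combined with the 1% guideline. The step counts and target rate are illustrative values, and warmup_learning_rate is a hypothetical helper. The sketch holds the rate constant after warm-up. BERT's reference optimizer additionally decays it linearly towards zero over the rest of training, which is omitted here.

```python
def warmup_learning_rate(step: int, target_lr: float, num_warmup_steps: int) -> float:
    """Linear warm-up: step i (1-based) uses i * p / n until step n, then p."""
    if step < num_warmup_steps:
        return step * target_lr / num_warmup_steps
    return target_lr

num_train_steps = 1_000_000                 # illustrative value
num_warmup_steps = num_train_steps // 100   # 1% of training -> 10,000 steps
target_lr = 1e-4                            # illustrative value

for step in (1, 5_000, 10_000, 20_000):
    print(step, warmup_learning_rate(step, target_lr, num_warmup_steps))
```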

Conclusion

In this article, I discussed the input parameters for pre-training a BERT model and explained the factors that help decide their values.

Cheers!
