Implementation of a simple Masked Language Model

A step-by-step PyTorch implementation from scratch

Ngieng Kianyew
9 min read, Sep 10, 2023

Are the probabilities of mask tokens, randomly replaced tokens, and unchanged tokens really what the paper says they are? Although the intuition behind masked language models (MLMs) is simple to understand, I was curious to see (1) what tricks researchers use, (2) how they ensure the specified probabilities, and (3) if they do not, what kind of deviations they are willing to accept.

In this article, I will describe BERT's masked language modeling, implement it from scratch in PyTorch, and summarise what I learned. The resource that I used is: https://nn.labml.ai/transformers/mlm/index.html

However, this resource does not provide a simple dataset to run the code with. What I learned from this is that the quickest way to understand a new NLP neural network architecture is to build a dummy dataset so that we can (1) run the code and (2) look at how the input and output change.

Describing BERT's Masked Language Modeling

Description of masked language modeling from paper

Here is my summary, with a bit more clarity:

  • (1) 15% of all tokens are randomly selected for prediction.
    - For example, if we have 100 tokens in a sequence, then 15% * 100 = 15 tokens will be chosen for prediction in our label.
  • (2) 10% of that 15% are replaced with a random token.
    - For example, 10% of 15 tokens = 1.5, which rounds up to 2, so 2 tokens will be replaced with random tokens.
  • (3) 10% of that 15% are left unchanged.
    - For example, 10% of 15 tokens = 1.5, which rounds up to 2, so 2 tokens will be unchanged.
  • (4) 80% of that 15% are replaced with [MASK].
    - For example, 80% of 15 tokens = 12 tokens will be replaced with [MASK].

Looking at this example, we can immediately see that even though only 15 tokens are selected for prediction, strictly following the rule with rounding does not add up:
- number_of_tokens = 12 (for mask) + 2 (for random) + 2 (for unchanged) = 16
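The arithmetic above can be checked in a few lines. This is a minimal sketch in plain Python; the helper `ceil_percent` is a name introduced here for illustration, and it uses integer arithmetic so that floating-point rounding does not affect the result.

```python
def ceil_percent(count: int, percent: int) -> int:
    """Return ceil(count * percent / 100) using integer arithmetic only."""
    return (count * percent + 99) // 100

num_tokens = 100
selected = ceil_percent(num_tokens, 15)   # tokens chosen for prediction
n_random = ceil_percent(selected, 10)     # 1.5 rounded up
n_unchanged = ceil_percent(selected, 10)  # 1.5 rounded up
n_mask = ceil_percent(selected, 80)       # exactly 12, no rounding needed

total = n_mask + n_random + n_unchanged
print(selected, n_mask, n_random, n_unchanged, total)  # 15 12 2 2 16
```

The three groups together need 16 positions, one more than the 15 that were selected, which is why exact per-sequence counts cannot all be honoured at once.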

A brute-force method would be, for each sequence, to find the indices of the selected 15% of tokens, then pick 10% of those indices for random replacement and another 10% to leave unchanged. However, with many sequences that can each be long, doing this per sequence seems too computationally expensive.

Therefore, there must be some trick involved.

PyTorch implementation of Masked Language Modeling

Following this resource: https://nn.labml.ai/transformers/mlm/index.html, I implemented it line by line and tried to understand how the masking rule is created.

Since the resource does not provide sample data, I wrote this code to generate a batch of variable-length sequences and then pad them with 0 up to max_len.

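import torch

vocab_size = 10  # including [MASK] and [PAD]
max_len = 5
num_seq = 5

def gen_sample_data(vocab_size, max_len, num_seq):
    """Generate a list of token-ID sequences with variable lengths."""
    def gen_single_sequence():
        # Length is drawn from [1, max_len); the upper bound is exclusive.
        length = torch.randint(1, max_len, size=(1,)).item()
        # IDs 0 ([PAD]) and 1 ([MASK]) are reserved, so real tokens start at 2.
        # The upper bound is exclusive: IDs fall in [2, vocab_size - 3).
        return torch.randint(2, vocab_size - 3, size=(length,))
    return [gen_single_sequence() for _ in range(num_seq)]

seqs = gen_sample_data(vocab_size, max_len, num_seq)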

# Named pad_batch so the function is not overwritten by the batch_data tensor below.
def pad_batch(data, max_len):
    """Stack variable-length sequences into one tensor, right-padded with 0 ([PAD])."""
    full_data = torch.zeros(len(data), max_len, dtype=torch.long)
    for i, sent in enumerate(data):
        length = min(len(sent), max_len)
        full_data[i, :length] = sent[:length]
    return full_data

batch_data = pad_batch(seqs, max_len)
batch_data
# Example output:
# tensor([[4, 0, 0, 0, 0],
# [5, 2, 0, 0, 0],
# [5, 2, 4, 0, 0],
# [2, 3, 6, 6, 0],
# [6, 0, 0, 0, 0]])

(1) Select 15% of tokens for prediction

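masking_prob = 0.15
# torch.rand draws uniformly from [0, 1), so each position is True with probability masking_prob
full_mask = torch.rand(batch_data.shape) < masking_prob
full_mask
# Example output (illustrative; a real draw at 0.15 is usually sparser):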
# tensor([[ True, False, True, False, False],
# [False, True, True, True, *True*],
# [ True, False, True, False, True],
# [False, False, False, True, True],
# [ True, True, False, False, True]])

** Note that we use masking_prob as an alias for the probability of being selected for prediction, as per the resource.

We call this the full_mask; True means that the token is selected for prediction. The asterisk (*) marks a padding position that has been selected.
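One detail worth checking is which sampler the threshold comparison uses. Comparing against masking_prob only gives a 15% selection rate when the samples are uniform on [0, 1), as with `torch.rand`; samples from a standard normal (`torch.randn`) fall below 0.15 far more often. The sketch below computes both rates with the standard library only; `normal_cdf` is a helper name introduced here.

```python
import math

def normal_cdf(x: float) -> float:
    """CDF of the standard normal distribution, via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

masking_prob = 0.15

# Uniform samples on [0, 1): P(u < p) is p itself.
uniform_rate = masking_prob
# Standard normal samples: P(z < p) is the normal CDF evaluated at p.
normal_rate = normal_cdf(masking_prob)

print(f"uniform: {uniform_rate:.3f}, normal: {normal_rate:.3f}")
```

With a normal sampler, more than half of the positions would be selected, so the choice of sampler matters more than it may first appear.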

Note that at this point, some of our padding tokens are chosen for prediction. Therefore, we need to remove them from the mask.

special_tokens = [0]  # [PAD]
for tk in special_tokens:
    # keep a position only if it was selected AND is not a special token
    full_mask = full_mask & (batch_data != tk)
full_mask
# tensor([[ True, False, False, False, False],
# [False, True, False, False, *False*],
# [ True, False, True, False, False],
# [False, False, False, True, False],
# [ True, False, False, False, False]])

We can see that the padding token is no longer selected for prediction (indicated by the asterisk *).

Learning points:

  • Special tokens such as [PAD] are removed from the mask even if the random draw selected them.
  • The interesting use of full_mask & (batch_data != tk) to modify a mask so that it excludes the special tokens.
  • Note that once the [PAD] positions are removed from the mask, the selected positions are no longer 15% of all positions in the batch. They are still roughly 15% of the real (non-padding) tokens in expectation.
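To make the last point concrete, the sketch below takes the example batch from above, written as plain Python lists, and computes how padding changes the expected share of selected positions. It assumes each position is selected independently with probability 0.15 before the padding positions are removed.

```python
batch = [
    [4, 0, 0, 0, 0],
    [5, 2, 0, 0, 0],
    [5, 2, 4, 0, 0],
    [2, 3, 6, 6, 0],
    [6, 0, 0, 0, 0],
]
pad_token = 0
masking_prob = 0.15

total_positions = sum(len(row) for row in batch)
real_tokens = sum(tok != pad_token for row in batch for tok in row)

# Only real tokens can stay selected once [PAD] positions are removed.
expected_selected = masking_prob * real_tokens
share_of_all_positions = expected_selected / total_positions

print(real_tokens, total_positions)      # 11 25
print(round(share_of_all_positions, 3))  # 0.066
```

In this heavily padded toy batch, only about 6.6% of all positions are expected to remain selected, even though the rate among real tokens is unchanged.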

(2) 10% out of the 15% are replaced with a random token

random_prob = 0.1
# torch.rand is uniform on [0, 1), so each position is True with probability random_prob
random_mask = torch.rand(batch_data.shape) < random_prob
# of the tokens selected for prediction, pick those to replace with a random token
full_mask_with_random = full_mask & random_mask
full_mask_with_random
# tensor([[ True, False, False, False, False],
# [False, True, False, False, False],
# [ True, False, False, False, False],
# [False, False, False, False, False],
# [ True, False, False, False, False]])

We call this full_mask_with_random: it marks the roughly 10% of the tokens selected for prediction that will be replaced with a random token.

Learning points:

  • We are deciding the mask for random tokens NOT by sampling within the full mask, but by drawing a second, independent random mask with the same shape as batch_data and intersecting it with the full mask: full_mask & random_mask.

This was initially puzzling to me, but after thinking for a bit, I reckon it works because full_mask and random_mask are independent. And because they are independent, we can use the product rule for independent events from our stats class:
If A and B are independent, P(A and B) = P(A) P(B), because P(A|B) = P(A) for independent events.

Example:
masking_prob = 0.15, random_prob = 0.1,
therefore masking_prob * random_prob = 0.15 * 0.1 = 0.015,
effectively meaning that, in expectation, 10% of the 15% of tokens chosen for prediction are replaced with a random token.
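The product rule can also be checked empirically. The sketch below is a plain-Python simulation (standard library only, with a fixed seed) that mirrors the two independent threshold comparisons; the variable names are chosen here for illustration.

```python
import random

rng = random.Random(0)  # fixed seed so the run is repeatable
masking_prob, random_prob = 0.15, 0.1
n = 200_000

n_selected = 0
n_selected_and_random = 0
for _ in range(n):
    selected = rng.random() < masking_prob  # plays the role of full_mask
    replace = rng.random() < random_prob    # plays the role of random_mask
    n_selected += selected
    n_selected_and_random += selected and replace

joint_rate = n_selected_and_random / n
conditional_rate = n_selected_and_random / n_selected
print(f"P(selected and random) ~ {joint_rate:.4f}")        # close to 0.015
print(f"P(random | selected)   ~ {conditional_rate:.4f}")  # close to 0.1
```

The rates hold only in expectation over many positions; for a single short sequence the actual counts can deviate a lot.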

(3) 10% out of the 15% are unchanged

unchanged_prob = 0.1
# torch.rand is uniform on [0, 1), so each position is True with probability unchanged_prob
unchanged_mask = torch.rand(batch_data.shape) < unchanged_prob
# of the tokens selected for prediction, pick those that should stay unchanged
full_mask_with_unchanged = full_mask & unchanged_mask
full_mask_with_unchanged
# tensor([[ True, False, False, False, False],
# [False, False, False, False, False],
# [False, False, False, False, False],
# [False, False, False, False, False],
# [False, False, False, False, False]])

We call this full_mask_with_unchanged: it marks the roughly 10% of the tokens selected for prediction that are left unchanged.

Learning points:

  • We are deciding the mask for unchanged tokens NOT by sampling within the full mask, but by drawing another independent random mask with the same shape as batch_data and intersecting it with the full mask: full_mask & unchanged_mask.
    (same reasoning for this implementation as the full_mask_with_random)
  • However, because unchanged_mask is drawn independently of random_mask, the same index can be selected both for random replacement and for staying unchanged. This is something that the implementation allows.
    (As indicated by the asterisks (*) below)
full_mask_with_unchanged
# tensor([[ *True*, False, False, False, False],
# [False, False, False, False, False],
# [False, False, False, False, False],
# [False, False, False, False, False],
# [False, False, False, False, False]])

full_mask_with_random
# tensor([[ *True*, False, False, False, False],
# [False, True, False, False, False],
# [ True, False, False, False, False],
# [False, False, False, False, False],
# [ True, False, False, False, False]])

(4) 80% out of the 15% are replaced with [MASK]

# get the mask for [mask] tokens
full_mask_with_mask = full_mask & (~full_mask_with_random) & (~full_mask_with_unchanged)
full_mask_with_mask
# tensor([[False, False, False, False, False],
# [False, False, False, False, False],
# [False, False, True, False, False],
# [False, False, False, True, False],
# [False, False, False, False, False]])

We call this full_mask_with_mask: it marks the roughly 80% of the tokens selected for prediction that will be replaced with [MASK].

Learning points:

  • We are deciding the mask for [MASK] NOT by sampling within the full mask but with full_mask & (~full_mask_with_random) & (~full_mask_with_unchanged).
  • The reasoning for this code:
    full_mask_with_mask = full_mask & (~full_mask_with_random) & (~full_mask_with_unchanged) is that every selected token that is neither left unchanged nor replaced with a random token should be replaced with [MASK] (the remaining roughly 80%). This guarantees that a [MASK] position never conflicts with a position chosen for random replacement or for staying unchanged.
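Because the random and unchanged masks are drawn independently, the expected split among the selected tokens can be worked out directly. The sketch below is plain arithmetic; it assumes the two 10% draws are independent and that a position appearing in both masks ends up with a random token.

```python
random_prob = 0.1
unchanged_prob = 0.1

# Replaced with [MASK]: in neither the random mask nor the unchanged mask.
p_mask = (1 - random_prob) * (1 - unchanged_prob)
# Replaced with a random token: in the random mask (overlap included).
p_random = random_prob
# Left unchanged: in the unchanged mask only.
p_unchanged = (1 - random_prob) * unchanged_prob
# In both masks at once.
p_overlap = random_prob * unchanged_prob

print(f"[MASK]: {p_mask:.2f}, random: {p_random:.2f}, unchanged: {p_unchanged:.2f}")
# [MASK]: 0.81, random: 0.10, unchanged: 0.09
```

Under these assumptions the expected split is about 81% / 10% / 9%, close to the 80% / 10% / 10% rule but not identical to it.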

(5) Fill the corresponding values in the corresponding masks

  • Filling in the tokens that should be replaced with random tokens
# Despite its name, final_mask holds the (corrupted) token IDs that are fed to the model.
final_mask = batch_data.clone()

num_random_tokens = full_mask_with_random.sum().item()
random_tokens = torch.randint(0, vocab_size, size=(num_random_tokens,))
# as_tuple=True returns one index tensor per dimension: (row_indices, col_indices)
indices = torch.nonzero(full_mask_with_random, as_tuple=True)
final_mask[indices] = random_tokens
  • Filling in the [MASK] positions with token index 1
mask_token = 1
final_mask = final_mask.masked_fill_(full_mask_with_mask, mask_token)

Learning points:

  • When we fill in the random tokens, we ignore the cases where the same position is also marked as unchanged. In other words, random replacement takes priority over leaving the token unchanged.
  • The random tokens can include special tokens, so a replacement can be [MASK] or even [PAD]!
  • The remaining selected tokens that are not replaced with [MASK] or a random token are automatically left unchanged!
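If drawing [PAD] or [MASK] as the random replacement is not desired, one option is to sample only from the non-special IDs. This is a hedged variation, not what the referenced implementation does. It assumes, as in this article's toy vocabulary, that the special tokens occupy IDs 0 and 1 so that real tokens form the contiguous range [2, vocab_size). The names `num_special`, `tokens`, and `corrupted` are introduced here for illustration.

```python
import torch

torch.manual_seed(0)
vocab_size = 10
num_special = 2  # assumption: IDs 0 ([PAD]) and 1 ([MASK]) are the only special tokens

# A toy mask marking three positions for random replacement.
full_mask_with_random = torch.tensor([[True, False, True],
                                      [False, True, False]])
tokens = torch.tensor([[4, 5, 6],
                       [2, 3, 0]])

num_random_tokens = int(full_mask_with_random.sum())
# The upper bound is exclusive, so the draws fall in [num_special, vocab_size).
random_tokens = torch.randint(num_special, vocab_size, size=(num_random_tokens,))

corrupted = tokens.clone()
corrupted[full_mask_with_random] = random_tokens  # boolean-mask assignment
print(corrupted)
```

Whether special tokens should be excluded is a design choice; the point of the sketch is only that the lower bound of the sampling range controls it.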

(6) Create the labels for training

y = batch_data.clone()
padding_token = 0
y = y.masked_fill_(~full_mask, padding_token)

Learning points:

  • We use the original data as the label, but for the tokens not selected for prediction we overwrite the label with padding_token = 0. The reason is that we only want to train the model on the roughly 15% of tokens that are selected for prediction (or modification). Later on, in our loss function, we need to set ignore_index = 0; see the ignore_index parameter in the PyTorch CrossEntropyLoss documentation.
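The effect of ignore_index can be seen on a tiny example. The sketch below uses random logits in place of a real model's output, so the tensors `logits` and `y` are hypothetical; it checks that the loss with ignore_index matches the loss computed by hand on the selected positions only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size = 10
padding_token = 0

# Hypothetical model output: one row of logits per position (4 positions).
logits = torch.randn(4, vocab_size)
# Labels built as in step (6): only positions 1 and 3 were selected for prediction.
y = torch.tensor([0, 5, 0, 2])

loss_fn = nn.CrossEntropyLoss(ignore_index=padding_token)
loss = loss_fn(logits, y)

# The same value, computed by hand on the selected positions only.
selected = y != padding_token
manual = nn.functional.cross_entropy(logits[selected], y[selected])
print(loss.item(), manual.item())
```

This works here because 0 is never a legitimate label: [PAD] positions were already removed from full_mask, so a label of 0 always means "not selected".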

Summary

  • Because the same position can be selected both for random replacement and for staying unchanged, the proportions of masked, random, and unchanged tokens do not strictly follow the rule of 80% mask token, 10% random token, and 10% unchanged token.
  • This particular implementation did not come from the paper's authors, so the deviation may be specific to it. Do let me know if there is a better implementation!
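One way to avoid the overlap, sketched here as an alternative and not taken from the referenced implementation, is to make a single uniform draw per position and split it into ranges, so that every selected token falls into exactly one of the three cases. The function name `split_selected` and its arguments are introduced here for illustration.

```python
import torch

def split_selected(full_mask, mask_prob=0.8, random_prob=0.1):
    """Split selected positions into three disjoint masks using one uniform draw."""
    r = torch.rand(full_mask.shape)
    to_mask = full_mask & (r < mask_prob)
    to_random = full_mask & (r >= mask_prob) & (r < mask_prob + random_prob)
    to_unchanged = full_mask & (r >= mask_prob + random_prob)
    return to_mask, to_random, to_unchanged

torch.manual_seed(0)
full_mask = torch.rand(1000, 100) < 0.15  # toy selection mask, no padding
to_mask, to_random, to_unchanged = split_selected(full_mask)

n_selected = full_mask.sum().item()
for name, m in [("[MASK]", to_mask), ("random", to_random), ("unchanged", to_unchanged)]:
    # Fractions should be near 0.8, 0.1 and 0.1 respectively.
    print(name, round(m.sum().item() / n_selected, 3))
```

The split is still only exact in expectation, but the three masks can no longer conflict, so no priority rule is needed when filling in the tokens.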

Collated learning points:

  • Special tokens such as [PAD] are removed from the mask even if the random draw selected them.
  • The interesting use of full_mask & (batch_data != tk) to modify a mask so that it excludes the special tokens.
  • Once the [PAD] positions are removed from the mask, the selected positions are no longer 15% of all positions in the batch.
  • We are deciding the mask for random tokens NOT by sampling within the full mask, but by drawing a second, independent random mask with the same shape as batch_data and intersecting it with the full mask: full_mask & random_mask.
  • We are deciding the mask for unchanged tokens NOT by sampling within the full mask, but by drawing another independent random mask with the same shape as batch_data and intersecting it with the full mask: full_mask & unchanged_mask.
    (same reasoning for this implementation as the full_mask_with_random)
  • However, because unchanged_mask is drawn independently of random_mask, the same index can be selected both for random replacement and for staying unchanged. This is something that the implementation allows.
  • We are deciding the mask for [MASK] NOT by sampling within the full mask but with full_mask & (~full_mask_with_random) & (~full_mask_with_unchanged).
  • The reasoning for this code:
    full_mask_with_mask = full_mask & (~full_mask_with_random) & (~full_mask_with_unchanged) is that every selected token that is neither left unchanged nor replaced with a random token should be replaced with [MASK] (the remaining roughly 80%). This guarantees that a [MASK] position never conflicts with a position chosen for random replacement or for staying unchanged.
  • When we fill in the random tokens, we ignore the cases where the same position is also marked as unchanged. In other words, random replacement takes priority over leaving the token unchanged.
  • The random tokens can include special tokens, so a replacement can be [MASK] or even [PAD]!
  • The remaining selected tokens that are not replaced with [MASK] or a random token are automatically left unchanged!
  • We use the original data as the label, but for the tokens not selected for prediction we overwrite the label with padding_token = 0. The reason is that we only want to train the model on the roughly 15% of tokens that are selected for prediction (or modification). Later, in our loss function, we need to set ignore_index = 0.
