Transformers: Attention is all you need — Teacher Forcing and Masked attention

Shravan Kumar
7 min read · Nov 6, 2023


Please refer to the blogs below before reading this:

Introduction to Transformer Architecture

Transformers: Attention is all you need — Overview on Self-attention

Transformers: Attention is all you need — Overview on Multi-headed attention

In the previous blogs we mainly focused on the Encoder block; now we are going to focus on the Decoder block.

What does the Decoder do?

For the task of translation, we will translate an English sentence into a local language (here, Telugu). The Decoder requires input from the Encoder, i.e., whatever the Encoder has computed needs to be passed to the Decoder. Initially the start symbol <Go> is given along with the Encoder output, and we get "nenu" as the first output word from the Decoder. For the next output word, the Encoder output along with the first Decoder output word is passed as input to the Decoder again.

This chain continues until the Decoder has produced all the output words.
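As a rough sketch of this loop, greedy decoding can be written as follows. The `encoder` and `decoder` callables, the token ids, and the `<Go>`/`<EOS>` symbols are hypothetical placeholders for illustration, not the actual implementation from this series.

```python
# Minimal greedy-decoding sketch. `encoder`, `decoder` and the token ids
# are assumed placeholders, not code from this blog series.
def greedy_translate(encoder, decoder, src_tokens, go_id, eos_id, max_len=50):
    enc_out = encoder(src_tokens)              # encoder output, reused at every step
    out_tokens = [go_id]                       # start with the <Go> symbol
    for _ in range(max_len):
        logits = decoder(enc_out, out_tokens)  # depends on encoder output + words so far
        next_id = int(logits[-1].argmax())     # pick the most probable next word
        out_tokens.append(next_id)
        if next_id == eos_id:                  # stop at the end-of-sentence symbol
            break
    return out_tokens[1:]                      # drop the <Go> symbol
```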

What will be the output dimension at each decoder step?

The output is a probability distribution over the vocabulary of the target language. Every word corresponds to one dimension of a V-dimensional vector. The output function used is softmax, and the inputs are word embeddings; a minimal sketch of this projection step follows.
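In the sketch below, the vocabulary size V, the projection matrix and the decoder output vector are made-up stand-ins, purely to show the shape of the computation.

```python
import numpy as np

V, d_model = 10000, 512                      # assumed vocabulary size and model width
W_out = np.random.randn(d_model, V) * 0.01   # output projection (learned during training)
h_t = np.random.randn(d_model)               # decoder output at the current position

logits = h_t @ W_out                         # one score per word in the vocabulary
probs = np.exp(logits - logits.max())        # numerically stable softmax
probs /= probs.sum()
predicted_word_id = int(probs.argmax())      # argmax picks the highest-probability word
```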

The decoder is a stack of N = 6 layers, and each layer is composed of three sublayers.

In the last layer we get the probability distribution, and we take the argmax to pick the word with the maximum probability. Let's zoom into one such decoder layer and observe what components exist there.

This is what a decoder layer looks like:

From the above, we need to understand two components:

  • Masked Multi-Head (Self) Attention
  • Multi-Head (Cross) Attention

In order to understand the Masked Multi-Head (Self) Attention — we need to understand a concept known as Teacher Forcing.

Let us imagine that you are training a transformer. It has not been trained yet, i.e., it is still learning, so let us check what happens during training. The input sentence is sent to the encoder, the encoder generates an output, and that output is sent to the decoder. We start by feeding <Go> to the decoder, and it generates the output "nenu". The next output word from the decoder is going to depend on the encoder output + <Go> + "nenu". Do we anticipate any challenges in predicting this word?

Now let's imagine that you are training the transformer and pass in the first batch of sentences, with all the weights in the encoder and decoder randomly initialized. What do we expect the decoder to predict as its first output word? It might not predict the right word during the initial stages of training. Note that the training data is already prepared in this form: the first column is the English sentence and the second column is the translated Telugu sentence. Suppose the decoder predicted the word "hindi" instead of "nenu". We have a loss function that compares the prediction with the actual ground truth "nenu", and the network backpropagates this error to correct its weights. What word should we now feed to the decoder as the second input word?

In this process, there are two important considerations:

  • We don't want to rely on the predictions made by the transformer during the initial stages of training, because everything is randomly initialized, so we cannot take "hindi" and feed it as the second word to the decoder.
  • There is no incentive if we simply send the full translated sentence to the decoder as input.

So we need to balance these two considerations, and this is where the concept of masking comes in. While predicting the second word, we do not want to see the words that come after it. Why should the complete set of future words be hidden? Why can't we just block the next word but still see the words after the blocked word? The reason is that we generate one word at a time: when we are generating the second word at inference (test) time, we simply do not have the remaining words. Assume we have a trained transformer and the sentence "I live in Chennai", and we are trying to decode the second word; we do not have the rest of the translation with us, hence we cannot look at it.

If the length of the actual translation is T, then at timestep t we can feed only the first t-1 words; that means we need to mask everything from timestep t onwards. In terms of a matrix, this is a lower triangular matrix, as sketched below.
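A minimal sketch of that matrix, using NumPy purely for illustration (the sentence length T is assumed):

```python
import numpy as np

T = 5                                   # length of the target sentence (assumed)
# 1 where position i may look at position j (j <= i), 0 where the future is masked
causal_mask = np.tril(np.ones((T, T), dtype=int))
print(causal_mask)
# Row t has ones only up to column t: at timestep t we see the words so far and nothing after.
```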

Takeaway — 1:

Teacher Forcing:

Feed in the actual outputs, i.e., forcefully feed the correct words instead of feeding the model's own garbage back to it. There is no mathematics involved here while feeding the input.
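As a minimal sketch of teacher forcing during training: the decoder input is the ground-truth target shifted right, and the labels are the same target shifted left. The token ids, vocabulary size, and the random logits standing in for the real model output are made up for illustration.

```python
import torch
import torch.nn.functional as F

V = 10000                                     # assumed vocabulary size
tgt = torch.tensor([[2, 17, 4, 9, 3]])        # <Go>, nenu, ..., <EOS> (made-up ids)

decoder_input = tgt[:, :-1]                   # decoder is fed the ground truth, shifted right
labels = tgt[:, 1:]                           # and is trained to predict the next word

logits = torch.randn(1, decoder_input.size(1), V)  # stand-in for model(src, decoder_input)
loss = F.cross_entropy(logits.reshape(-1, V), labels.reshape(-1))
print(loss.item())
```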

Takeaway — 2:

Masking:

While teacher forcing makes sense, we cannot take it to the extreme and feed the entire sentence; we have to mask out everything after timestep 't', where 't' stands for the position of the word we are currently predicting, and we cannot see any word after 't'. There is some mathematics involved here in deciding how we are going to do the masking.

Why is the target sentence fed as one of the inputs to the decoder?

Usually, we use only the decoder's previous prediction as input to make the next prediction in the sequence. However, the drawback of this approach is that if the first prediction goes wrong, there is a high chance the rest of the predictions will also go wrong (because of conditional probability). This leads to an accumulation of errors. Of course, the algorithm has to fix this as training progresses, but it takes a long time to train the model.

The other approach is to use the so-called "Teacher Forcing" algorithm.

Recall that in self-attention we computed the query, key and value vectors q, k and v by multiplying the word embeddings h1, h2, h3, ..., hT with the transformation matrices WQ, WK and WV respectively.

The same is repeated in the decoder layers. This time h1, h2, ..., hT are the word embeddings of the target sentence, but with one important difference: masking, which implements the teacher-forcing approach during training. Of course we cannot use teacher forcing during inference; instead, the decoder acts as an auto-regressor.
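Here is a minimal sketch of one head of masked (causal) self-attention over the target embeddings. The dimensions and random matrices are assumptions for illustration, not the exact implementation.

```python
import numpy as np

def masked_self_attention(H, W_Q, W_K, W_V):
    """One head of masked (causal) self-attention over target embeddings H of shape (T, d)."""
    Q, K, V = H @ W_Q, H @ W_K, H @ W_V
    d_k = K.shape[-1]
    A = Q @ K.T / np.sqrt(d_k)                  # T x T scores before softmax
    T = A.shape[0]
    M = np.triu(np.full((T, T), -np.inf), k=1)  # -inf above the diagonal, 0 elsewhere
    scores = A + M
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax -> lower triangular
    return weights @ V

# Example: 4 target words with 8-dimensional embeddings (made-up sizes)
H = np.random.randn(4, 8)
W_Q, W_K, W_V = (np.random.randn(8, 8) * 0.1 for _ in range(3))
print(masked_self_attention(H, W_Q, W_K, W_V).shape)   # (4, 8)
```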

Note: In practice, the encoder block also uses masking in its attention sublayer, to mask the padded tokens in sequences of length < T.

Here is an example of how this is done. Suppose sentences have a token length of 512 and the batch size is 64, so the input data has shape 64 x 512. If some sentences have fewer than 512 tokens, they are padded up to a length of 512, and while training the model we mask those padded tokens, as in the sketch below.
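A minimal sketch of such a padding mask, using a tiny toy batch and an assumed pad id of 0 instead of the full 64 x 512 batch:

```python
import torch

pad_id = 0                                   # assumed id of the <pad> token
batch = torch.tensor([[5, 7, 9, 0, 0],       # two toy sentences padded to length 5
                      [3, 6, 0, 0, 0]])

pad_mask = (batch != pad_id)                 # True where there is a real token
# Reshape to (batch, 1, 1, T) so it broadcasts over heads and query positions in attention
attn_pad_mask = pad_mask[:, None, None, :]
print(attn_pad_mask.shape)                   # torch.Size([2, 1, 1, 5])
```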

How is the masking implemented? Where should we incorporate it: at the input, at the output, or somewhere in between?

Let us assume we have the T x T matrix of attention scores before softmax, A = Q Kᵀ / √d_k.

The final attention weights are obtained by applying softmax to A, row by row.

If A is not a lower triangular matrix, then applying softmax will not make it a lower triangular matrix either. Hence, before applying softmax, we need to do something to convert the matrix into a lower triangular one.

Here A is a T x T matrix, so when another matrix M is added to it, that matrix must also be T x T. Now, what should the entries of M be? Since the weights in the lower triangular part of A should not change, the values of M in the lower part (including the diagonal) should all be 0. Making the upper triangular entries of M equal to 0 will not help, because after the softmax those values will not become 0; hence we need to make all the values in the upper triangular part of M equal to -infinity.

Adding these two matrices gives A + M.
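Here is a small numeric sketch of that sum followed by softmax (random scores, T = 4, purely for illustration):

```python
import numpy as np

T = 4
A = np.random.randn(T, T)                      # attention scores before softmax
M = np.triu(np.full((T, T), -np.inf), k=1)     # 0 on and below the diagonal, -inf above

scores = A + M
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(np.round(weights, 2))   # upper-triangular entries are 0, each row sums to 1
```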

When this matrix M is added to the matrix A (the attention scores before softmax) and we then apply the softmax, we get the final lower triangular attention matrix; this is exactly what the masking achieves. In the upcoming blogs, let's zoom further into the decoder layer and find more details about it.

Please do clap 👏 or comment if you find it helpful ❤️🙏

References:

Introduction to Large Language Models — Instructor: Mitesh M. Khapra

