Masking in Transformers’ self-attention mechanism
Masking is needed to prevent the decoder's attention mechanism from "cheating" during training (on a translation task, for instance). This kind of "cheating-proof" masking is not present on the encoder side.
I had a tough time understanding how masking was done in the decoder, that is why I’m writing this article. Hopefully it will help other people as well.
Before I start, an excellent introductory read for understanding transformers is Jay Alammar's article "The Illustrated Transformer": http://jalammar.github.io/illustrated-transformer/
Let’s start with an example:
If we have the following sequence as an input to our decoder: "I love it", then the expected prediction for the token at position one ("I") is the token at the next position ("love"). Similarly, the expected prediction for the tokens "I love" is "it".
We do not want the attention mechanism to share any information about tokens at future positions when producing a prediction from all the previous tokens.
To ensure that this is done, we mask future positions (setting them to -inf) before the softmax step in the self-attention calculation.
Below is a commented example summing up how masking works with a two-token sequence ("Hello there." for instance):
Zoom-In:
First section:
In the first section, I show how the Q matrix is created from X (the process is similar for the V and K matrices); a small code sketch of this step follows the shape descriptions below.
X has the following size:
- 2, which is the sequence length
- 4, which is the embedding dimension
Wq has the following size:
- 4 (embedding dimension)
- 4, because we assume that Wq is a square matrix for simplification purposes. The reality is slightly more complicated, but we do not need to know more about it to understand the masking mechanism.
The resulting Q, V and K matrices have the following size:
- 2, which is the sequence length
- 4, the embedding dimension
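To make these shapes concrete, here is a minimal NumPy sketch of this first step. The actual values of X, Wq, Wk and Wv from the figures are not reproduced here, so random stand-in matrices are assumed instead:

```python
import numpy as np

np.random.seed(0)

seq_len, d_model = 2, 4   # two tokens ("Hello", "there"), embedding dimension 4

X  = np.random.randn(seq_len, d_model)   # stand-in embeddings x1, x2 (2 x 4)
Wq = np.random.randn(d_model, d_model)   # assumed square (4 x 4), as above
Wk = np.random.randn(d_model, d_model)
Wv = np.random.randn(d_model, d_model)

Q = X @ Wq   # (2 x 4): q1 = x1·Wq, q2 = x2·Wq
K = X @ Wk   # (2 x 4)
V = X @ Wv   # (2 x 4)

print(Q.shape, K.shape, V.shape)   # (2, 4) (2, 4) (2, 4)
```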
Second section:
In the second section, I write the formula that sums up the steps in the self-attention mechanism (see http://jalammar.github.io/illustrated-transformer/ for a more thorough explanation).
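For readers who do not have the figure at hand, the formula in question is F = softmax(Q·Kt / dk^(1/2))·V, where dk is the dimension of the key vectors (4 here). A minimal sketch of the unmasked version, building on the Q, K and V computed above:

```python
def softmax(z):
    # row-wise softmax: each row becomes positive values that sum to 1
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(Q, K, V):
    d_k = K.shape[-1]                     # 4 in this example
    I = Q @ K.T                           # the matrix called "I" in the third section
    I_prime = softmax(I / np.sqrt(d_k))   # I' = softmax(I / dk^(1/2))
    return I_prime @ V                    # the matrix called "F" in the fourth section

F = self_attention(Q, K, V)
print(F.shape)   # (2, 4)
```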
Third section:
In the third section, I show the values A, B, C and D of the matrix "I", which is the result of the dot product of Q and Kt (the transpose of K).
One can see how A results from operations on values that come from the embedding of the token in the first position (q1 and k1).
This stands in contrast with B, whose value comes from the embeddings of the tokens at both the first and second positions (q1 and k2).
N.B.: q1 = x1·Wq and k2 = x2·Wk, where x1 and x2 are respectively the embeddings of the first and second tokens ("Hello" and "there").
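This is easy to verify numerically. Reusing the Q and K computed in the first sketch, the entry B of I really does mix information from both tokens:

```python
I = Q @ K.T          # 2 x 2 matrix [[A, B], [C, D]]

A, B = I[0, 0], I[0, 1]
assert np.isclose(A, Q[0] @ K[0])   # A only involves the first token (q1·k1)
assert np.isclose(B, Q[0] @ K[1])   # B mixes the first and second tokens (q1·k2)
```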
Fourth section:
In the fourth section, I do the same work of analysing the composition of our final values (matrix F).
It outlines how self-attention allows the decoder to peek at future positions if we do not add a masking mechanism.
The softmax operation normalizes the scores so they’re all positive and add up to 1.
We see that if we want to respect the rule of "no peeking ahead", we need to set the value B' of the I' matrix to zero (with I' = softmax(I / dk^(1/2))).
We obtain that result by adding a mask to I before the softmax operation, as follows:
where:
Then we have:
If we redo the operation to obtain F, with the “mask-on”:
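The same masked computation can be sketched in code, reusing Q, K, V and the softmax helper from the earlier sketches. The numbers differ from the figures (random stand-in values), but the structure of the result is the same: B' is zero, so the first row of F depends only on the first token.

```python
d_k = K.shape[-1]

# upper-triangular mask: 0 on and below the diagonal, -inf above it
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
# mask == [[0., -inf],
#          [0.,   0.]]

I_masked = Q @ K.T + mask                    # B becomes -inf
I_prime  = softmax(I_masked / np.sqrt(d_k))  # B' becomes 0
# I_prime == [[1., 0.],      first row ignores the second token
#             [c', d']]      second row may attend to both tokens

F = I_prime @ V   # f1 (the first row of F) now depends only on x1
```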
Finally, I'll use one of the images from Jay Alammar's article to sum up what is happening with the masking on a two-token sequence.
That is all. Please tell me in the comments section below if you find anything wrong, or if this helped you understand how masking is used to prevent "peeking"!