Transformers: Attention is all you need — Overview on Multi-headed attention
Please read the following posts before this one:
Introduction to Transformer Architecture
Transformers: Attention is all you need — Overview on Self-attention
The entire self-attention computation can be done in parallel for all the input words, and it all happens through matrix multiplications, as shown below. The entire purple block is called one “Scaled Dot Product Head”.
Let us take this one Scaled Dot Product Head and assume we have 5 input words, resulting in 5 outputs, one per word.
With multiple attention heads, e.g., two-head attention, each input word gets one output per head: for the first word, the two heads produce z11 and z12 respectively after the scaled dot product head calculations.
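To make the "it all happens through matrix multiplications" point concrete, here is a minimal NumPy sketch of one scaled dot product head for 5 input words. The names (`H`, `W_q`, `W_k`, `W_v`), the random weights, and the small sizes `d_model = 8` and `d_k = 4` are assumptions chosen only for illustration; they are not taken from the figure.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_head(H, W_q, W_k, W_v):
    """One attention head. H holds one row per input word."""
    Q = H @ W_q                          # queries, shape (T, d_k)
    K = H @ W_k                          # keys,    shape (T, d_k)
    V = H @ W_v                          # values,  shape (T, d_k)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    alpha = softmax(scores, axis=-1)     # attention weights, shape (T, T)
    return alpha @ V, alpha              # outputs, shape (T, d_k)

# Illustrative sizes (assumptions): 5 words, 8-dim inputs, 4-dim head.
rng = np.random.default_rng(0)
T, d_model, d_k = 5, 8, 4
H = rng.normal(size=(T, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

Z, alpha = scaled_dot_product_head(H, W_q, W_k, W_v)
print(Z.shape, alpha.shape)  # (5, 4) (5, 5)
```

All 5 words are processed at once: there is no loop over positions, only three projections and two more matrix products.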
Why do we need this Multi-Head attention?
The idea behind multi-head attention is similar to the computer vision technique of using multiple filters/kernels in a CNN layer.
Multiple filters/kernels in a CNN layer are used to learn more abstract representations and capture more meaningful interactions (e.g., blurs, edges, shapes, etc.) between inputs.
Similarly, we can have more than one self-attention head, each with its own parameter matrices (WiQ, WiK, WiV), in the hope that each head learns different, subtle contextual information.
This motivates “Multi-head Attention”, a simple extension of single-head attention: just as each kernel in a CNN independently learns its own feature, each head in a Transformer independently computes its own attention. (Parallel computation!)
Let us take an example:
In the example sentence above, suppose we want to learn a contextual representation for the word “it” and want to focus more on “was”, because we want to know the subject of “was”, and “it” acts as that subject. So we need to learn the alphas (attention weights) for each time step. Finding the contextual representation for “it” with a focus on “was” corresponds to alpha10, 11:
the attention that needs to be paid to the 11th word (“was”) when computing the contextual representation of the 10th word (“it”).
Now suppose we want to learn the contextual representation for the word “it” and want to focus on “animal” -> alpha10, 2
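For reference, the alphas can be written out explicitly. This is the standard scaled dot product form, assuming the usual notation where q, k and v are the query, key and value vectors of each word and d_k is the key dimension (this notation is an assumption here, not taken from the figure):

```latex
\alpha_{10,j} = \frac{\exp\left(q_{10}^{\top} k_j / \sqrt{d_k}\right)}
                     {\sum_{j'=1}^{T} \exp\left(q_{10}^{\top} k_{j'} / \sqrt{d_k}\right)},
\qquad
z_{10} = \sum_{j=1}^{T} \alpha_{10,j}\, v_j
```

Setting j = 11 gives the weight on “was”, and j = 2 gives the weight on “animal”.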
So here we have multiple cases: we want alpha10, 11 to be high and also alpha10, 2 to be high. Within a single head, the alphas for “it” are normalized by a softmax and sum to 1, so both cannot dominate at once. This leads to 2 self-attention blocks, which together are called a multi-headed attention block. Considering these scenarios, we observe from the chart that
- The word “it” is strongly connected to the word “was” in the first head
- The word “it” is strongly connected to the word “animal” in the second head
So it is evident (empirically) that adding more than one attention head helps in capturing different contextual information from the sentence. This is what two-headed attention looks like.
Now let’s take a look at the dimensions of each vector at each stage of the process. If the output of the concatenate stage should be a 512-dimensional vector, each of the 8 heads (8 different attention heads for each input word) must produce a 64-dimensional output, since 8 x 64 = 512.
Assume each input word vector in H is 512-dimensional. Within a head, it is multiplied by 64 x 512 matrices to get 64-dimensional vectors, and the output stays 64-dimensional through the scaled dot product attention layer. The same thing is repeated for all 8 heads, each with a 64-dimensional output, so concatenating them gives 512 dimensions. In other words, we start with an input word dimension of 512 and end with 512 dimensions at the concatenation output. Finally, we apply a linear transformation to get the final outputs z1 to zT.
Remember that this concatenation happens for every input word: e.g., with 5 input words, we have 8 different attention head outputs for z1, and similarly 8 for each of z2, z3, …, z5 (assuming T = 5). The output of each layer is a sequence of T embeddings.
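The dimension bookkeeping above can be checked with a short NumPy sketch. The sizes (T = 5, 512-dim inputs, 8 heads of 64 dims) come from the text; the random weights and the names `W_q`, `W_k`, `W_v`, `W_o` are assumptions for illustration. The sketch uses the row-vector convention, so each projection is stored as 512 x 64, the transpose of the 64 x 512 matrix mentioned above.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(H, W_q, W_k, W_v, W_o):
    """W_q, W_k, W_v: (n_heads, d_model, d_head); W_o: (n_heads * d_head, d_model)."""
    head_outputs = []
    for h in range(W_q.shape[0]):
        # Each head has its own projections: 512 -> 64.
        Q, K, V = H @ W_q[h], H @ W_k[h], H @ W_v[h]
        alpha = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))   # (T, T)
        head_outputs.append(alpha @ V)                    # (T, 64)
    concat = np.concatenate(head_outputs, axis=-1)        # (T, 8 * 64) = (T, 512)
    return concat @ W_o, concat                           # final linear transformation

rng = np.random.default_rng(0)
T, d_model, n_heads = 5, 512, 8
d_head = d_model // n_heads                               # 64

H = rng.normal(size=(T, d_model))
W_q, W_k, W_v = (rng.normal(size=(n_heads, d_model, d_head)) * 0.02 for _ in range(3))
W_o = rng.normal(size=(n_heads * d_head, d_model)) * 0.02

Z, concat = multi_head_attention(H, W_q, W_k, W_v, W_o)
print(d_head, concat.shape, Z.shape)  # 64 (5, 512) (5, 512)
```

The loop over heads is only for readability; since the heads are independent, they can be computed in parallel.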
Zooming into the Encoder Stack
This is what our basic encoder block looks like:
The inputs h1 to h5 are passed through (multi-head) self-attention, and its outputs s1 to s5 are passed to a Feed Forward Network (FFN). Now let us understand what happens in this FFN. The encoder is a stack that typically has 6 such encoder layers, or blocks.
The FFN applies projections to the data received from the self-attention layer: a d-dimensional input is projected to a d-dimensional output.
The projections for each FFN are shown below:
To get the z1 output, the projections from FFN1 are:
Similarly for z2 and so on:
The same non-linear equation applies to all the FFNs.
The encoder is composed of N identical layers, and each layer is composed of 2 sub-layers (self-attention and a Feed Forward Network).
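As a rough sketch of what such an FFN computes, the code below uses the position-wise form from the original Transformer paper, FFN(s) = max(0, s W1 + b1) W2 + b2, where the same weights are shared by every position. The inner size `d_ff = 2048` is the value used in that paper and is an assumption here, as are the variable names and random weights.

```python
import numpy as np

def ffn(s, W1, b1, W2, b2):
    """Position-wise feed-forward network: d_model -> d_ff -> d_model."""
    hidden = np.maximum(0.0, s @ W1 + b1)   # ReLU non-linearity
    return hidden @ W2 + b2                 # project back to d_model

rng = np.random.default_rng(0)
T, d_model, d_ff = 5, 512, 2048             # d_ff is an assumed value

S = rng.normal(size=(T, d_model))           # s1..s5 from self-attention
W1 = rng.normal(size=(d_model, d_ff)) * 0.02
b1 = np.zeros(d_ff)
W2 = rng.normal(size=(d_ff, d_model)) * 0.02
b2 = np.zeros(d_model)

Z = ffn(S, W1, b1, W2, b2)                  # all positions at once
z1_alone = ffn(S[0], W1, b1, W2, b2)        # only s1, computed on its own

print(Z.shape)                              # (5, 512)
print(np.allclose(Z[0], z1_alone))
```

Because the FFN treats every position independently, applying it to the whole matrix `S` gives the same z1 as applying it to s1 alone.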
The computation is parallelized in the horizontal direction of the encoder stack (i.e., across the tokens within a training sample), not along the vertical direction, since each layer needs the output of the layer below it. Let us denote the output sequence of vectors from the encoder as ej, for j = 1, 2, …, T (T is the number of tokens). For each token, we thus get a final refined representation with contextual information, obtained by passing it through the stack of layers shown above.
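Putting the two sub-layers together, a simplified encoder stack can be sketched as follows. This is an illustration under stated assumptions, not a full implementation: it includes only the two sub-layers described above and leaves out the residual connections and layer normalization of the original paper. It also uses small, hypothetical sizes (`d_model = 64`, `d_ff = 256`) so it runs quickly.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(H, p):
    # Multi-head self-attention over all T tokens at once (horizontal parallelism).
    heads = []
    for W_q, W_k, W_v in zip(p["W_q"], p["W_k"], p["W_v"]):
        Q, K, V = H @ W_q, H @ W_k, H @ W_v
        alpha = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))
        heads.append(alpha @ V)
    return np.concatenate(heads, axis=-1) @ p["W_o"]

def ffn(S, p):
    return np.maximum(0.0, S @ p["W1"] + p["b1"]) @ p["W2"] + p["b2"]

def init_layer(rng, d_model, n_heads, d_ff):
    d_head = d_model // n_heads
    # Random weights scaled by 1/sqrt(fan_in); purely illustrative.
    w = lambda *shape: rng.normal(size=shape) / np.sqrt(shape[-2])
    return {
        "W_q": w(n_heads, d_model, d_head),
        "W_k": w(n_heads, d_model, d_head),
        "W_v": w(n_heads, d_model, d_head),
        "W_o": w(n_heads * d_head, d_model),
        "W1": w(d_model, d_ff), "b1": np.zeros(d_ff),
        "W2": w(d_ff, d_model), "b2": np.zeros(d_model),
    }

def encoder(H, layers):
    # Layers run one after another (vertical direction is sequential).
    for p in layers:
        S = self_attention(H, p)   # sub-layer 1
        H = ffn(S, p)              # sub-layer 2
    return H

rng = np.random.default_rng(0)
T, d_model, n_heads, d_ff, N = 5, 64, 8, 256, 6
layers = [init_layer(rng, d_model, n_heads, d_ff) for _ in range(N)]
H0 = rng.normal(size=(T, d_model))

E = encoder(H0, layers)            # rows are e1..eT
print(E.shape)                     # (5, 64)
```

The only sequential loop that cannot be removed is the one over layers; everything inside a layer is expressed as matrix operations over all T tokens.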
References:
Introduction to Large Language Models — Instructor: Mitesh M. Khapra