Diving Deep into Naive Self-Attention: A Step Towards Understanding Transformers, Part 2

Luv Verma
May 13, 2023


In Part 1 of this series, we took a deep dive into the fundamental concepts of ‘attention’ in the context of Sequence-to-Sequence (Seq2Seq) models. Our goal was to explore how we could shift from the traditional Recurrent Neural Network (RNN) based Seq2Seq models to a new paradigm that leverages attention layers for generation tasks — tasks typically handled by RNNs.

(link to part 1)

As we continue our journey in Part 2, we turn our focus to naive self-attention, a pivotal step toward fully understanding the architecture of the revolutionary ‘Transformer’ model.

Question: Is Attention All We Need?

This question was answered by the famous paper ‘Attention Is All You Need’ (Vaswani et al., 2017).

In the last post, we saw attention applied to the Seq2Seq model (Figure 1). Note the horizontal connections: even though attention is added at each decoder state, information still moves sequentially.

Figure 1: Basics of key and query vectors and attention scores: attention applied to the sequential model

Now that we have attention, do we even need these sequential (recurrent) connections (Figure 1)? Attention can grab any information from the input sequence. So instead of relying on recurrence to capture information about the input, can we rely entirely on attention to capture it (Figure 2, sequential connections crossed out)?

Figure 2: Sequential connections crossed out. Getting rid of them leads to a fully attention-based network

Can we transform an RNN into an attention-based model?

For this, attention would have to:

  1. Access every time step (can do).
  2. Index into every time step of the input (can do).
  3. Index into every time step of the previous output, to see which words were predicted before (can do).

With these three capabilities, attention can do what recurrence does. However, to build a complete model, two problems need to be addressed:

  1. Problem 1: In Figure 1, every decoder state s(l) can access the encoder states h(t), but not the decoder states before it, such as s(l-1). This keeps the decoder from knowing what it has already produced as output.
  2. Problem 2: Each encoder state has become independent of the other encoder states (we have crossed out the connections and removed the dependencies). This does not hold for sentences, where each word depends on the words around it.

Let’s see how we can build a complete attention-based model, keeping in mind the problems described above.

Figure 3: Basics of the input layers generating keys, queries, and values, and the interaction of the query at time step 1 with the keys at all time steps

  • Let’s say we have a weight matrix W that is shared across all the layers at the input level. For example, at level 1, the three layers h(1), h(2), h(3) share W (Figure 3). This can also be seen from equation 1: h(t) varies with the time step, but W carries no time index, which means it is common to all three layers.
  • Another function, v, is applied to h(t) (equation 2), giving a value v(t) at time step t. This function is just a linear transformation from h(t) to v(t) (equation 3). Again, the matrix W(v) is constant and, just like W, it is applied to every h(t), which means W(v) is also shared.
  • We also output keys k(t), which are also a function of h(t). For simplicity, let’s say that function is a linear transformation from h(t) to k(t) given by a weight matrix W(k), which is again shared across all time steps (equation 4).
  • Instead of the decoder producing query vectors (as in Seq2Seq models), here every time step h(t) also produces a query vector q(t) (Figure 3). This is again a linear function of h(t) with the same structure as equations 3 and 4 and a shared weight matrix W(q) (equation 5).
  • As before with the Seq2Seq model (Figure 2), the query at every time step can index into (communicate with) the key at every time step, including its own (e.g. q1 with k1; see the connections in Figure 3).
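To make equations 1 to 5 concrete, here is a minimal NumPy sketch. The equation images are not reproduced in this text, so a few things are assumptions: equation 1 is taken to be h(t) = ReLU(W x(t)), the inputs x(t) and all weights are random, and the sizes (3 time steps, dimensions 4 and 8) are arbitrary. What it illustrates is the weight sharing: because W, W(v), W(k), and W(q) carry no time index, looping over t and doing one batched matrix product give the same result.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_in, d = 3, 4, 8  # 3 time steps; dimensions are arbitrary illustrative choices

x = rng.normal(size=(T, d_in))  # hypothetical inputs x(1), x(2), x(3), one per row

# Shared weights: no time index, so the same matrices are used at every t
W = rng.normal(size=(d, d_in))  # eq. 1 (assumed form: h(t) = ReLU(W x(t)))
W_v = rng.normal(size=(d, d))   # eq. 3: v(t) = W(v) h(t)
W_k = rng.normal(size=(d, d))   # eq. 4: k(t) = W(k) h(t)
W_q = rng.normal(size=(d, d))   # eq. 5: q(t) = W(q) h(t)

# Per-time-step loop, mirroring the equations one t at a time
h_loop = np.stack([np.maximum(0.0, W @ x[t]) for t in range(T)])
v_loop = np.stack([W_v @ h_loop[t] for t in range(T)])

# Because the weights are shared, all time steps fit in one batched product
h = np.maximum(0.0, x @ W.T)  # shape (T, d), row t is h(t)
v = h @ W_v.T                 # row t is v(t)
k = h @ W_k.T                 # row t is k(t)
q = h @ W_q.T                 # row t is q(t)
```

The batched form is what makes this layer parallel across time steps, unlike the recurrent connections it replaces.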

Figure 4: Visualization of the dot product between the query at time step 1 and the keys at all time steps

  • This means we can compute the dot product between every query q(l) and every key k(t) (equation 6, Figure 4), which gives us the attention scores. For 3 time steps, we get 3 × 3 = 9 attention scores, because the query time step l ranges from 1 to 3 and the key time step t ranges from 1 to 3.
  • Pass the scores through a softmax (equation 7). The softmax normalizes over the keys: for every query time step l, we normalize over all key time steps t, so the weights for each query are positive and sum to 1.
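The score-and-softmax step (equations 6 and 7) can be sketched in NumPy as below. The queries and keys are random stand-ins (an assumption for illustration). Subtracting the row maximum before exponentiating is a standard numerical-stability trick and does not change the softmax result.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 3, 8
q = rng.normal(size=(T, d))  # stand-in queries q(l), one per row
k = rng.normal(size=(T, d))  # stand-in keys k(t), one per row

# eq. 6: e[l, t] = q(l) . k(t), a T x T matrix of scores
e = q @ k.T

# eq. 7: softmax over the key index t, separately for each query l (row-wise)
e_shifted = e - e.max(axis=1, keepdims=True)  # stability shift, same softmax
alpha = np.exp(e_shifted) / np.exp(e_shifted).sum(axis=1, keepdims=True)

print(e.size)             # 9
print(alpha.sum(axis=1))  # each row sums to 1 (up to rounding)
```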

Figure 5: Visualization of the context-vector calculation for time step l = 1 (time steps for keys are denoted by ‘t’ and for queries by ‘l’)

  • Finally, to get the attention at time step 1 (equation 8, Figure 6), weight each value v(t) by its softmax score and sum over all time steps t. This gives the context vector (also called the attention) at time step l (Figure 5).

The calculation discussed above for time step l = 1 is done in the same way for every time step (l = 1, 2, or 3).
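Equation 8 and its extension to every time step can be sketched like this, again with random stand-in softmax weights and values (assumptions for illustration). Computing a(1) with an explicit sum and then all a(l) at once as a matrix product shows that the per-time-step calculation and the batched one agree.

```python
import numpy as np

rng = np.random.default_rng(1)
T, d = 3, 8

# Stand-in softmax weights alpha[l, t] (each row sums to 1) and values v(t)
logits = rng.normal(size=(T, T))
alpha = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
v = rng.normal(size=(T, d))

# eq. 8 for l = 1 (row index 0): weighted sum of the values over all key steps t
a_1 = sum(alpha[0, t] * v[t] for t in range(T))

# The same calculation for every l at once: row l of `a` is the context vector a(l)
a = alpha @ v
```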

Equations 1 to 8 and Figures 3 to 6 describe what is called the self-attention mechanism. It produces a context vector (attention) at every time step, so it can be thought of as a single layer that integrates information across time steps!!
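For reference, here is one way to write equations 1 to 8 compactly in the notation of this post, with W_v, W_k, W_q standing for W(v), W(k), W(q). The equation images are not reproduced here, so treat this as a reconstruction from the text: the form of equation 1 (σ being some nonlinearity) is an assumption, and the names e(l,t) for the scores and α(l,t) for the softmax weights are labels chosen here.

```latex
\begin{aligned}
h(t) &= \sigma\big(W\,x(t)\big) && \text{(1, assumed form)}\\
v(t) &= v\big(h(t)\big) && \text{(2)}\\
v(t) &= W_v\,h(t) && \text{(3)}\\
k(t) &= W_k\,h(t) && \text{(4)}\\
q(t) &= W_q\,h(t) && \text{(5)}\\
e(l,t) &= q(l)\cdot k(t) && \text{(6)}\\
\alpha(l,t) &= \frac{\exp\big(e(l,t)\big)}{\sum_{t'}\exp\big(e(l,t')\big)} && \text{(7)}\\
a(l) &= \sum_{t}\alpha(l,t)\,v(t) && \text{(8)}
\end{aligned}
```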

Can we build a complete network out of such self-attention layers?

Yes, we can. Stack them up (Figure 6)

Figure 6: Stacking up individual self-attention layers

Why stack them up? What do we get by stacking?

Each layer processes the sequence further, so a stack may be able to transform one kind of sentence, such as a sentence in French, into another, such as a sentence in English. At the end, the result still has to be decoded somehow to get an answer (discussed later).
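Stacking can be sketched by feeding the output of one self-attention layer into the next. This is a minimal NumPy sketch under a few assumptions: each layer gets its own W(q), W(k), W(v) (shared across time steps, but not across layers), the weights are random and scaled by 1/sqrt(d), and the helper name self_attention_layer is hypothetical. There is no decoding step here; that part comes later in the series.

```python
import numpy as np


def self_attention_layer(h, W_q, W_k, W_v):
    """One naive self-attention layer (eqs. 3 to 8); h has shape (T, d)."""
    q, k, v = h @ W_q.T, h @ W_k.T, h @ W_v.T
    e = q @ k.T                                   # (T, T) scores, eq. 6
    e = e - e.max(axis=1, keepdims=True)          # stability shift, same softmax
    alpha = np.exp(e) / np.exp(e).sum(axis=1, keepdims=True)  # eq. 7
    return alpha @ v                              # (T, d) context vectors, eq. 8


rng = np.random.default_rng(0)
T, d, n_layers = 3, 8, 4  # illustrative sizes
h = rng.normal(size=(T, d))  # hypothetical first-level inputs h(1), h(2), h(3)

# Assumption: each layer has its own (W_q, W_k, W_v), shared across time steps
layers = [tuple(rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
          for _ in range(n_layers)]

for W_q, W_k, W_v in layers:
    h = self_attention_layer(h, W_q, W_k, W_v)  # output of one layer feeds the next
```

Every layer keeps the shape (T, d), one vector per time step, which is what allows the layers to be stacked arbitrarily deep.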

That is the basic idea of self-attention!! But now there is another question: how do we move from self-attention to building transformers?

To build transformers, some fundamental limitations of self-attention layers have to be addressed. They are the following:

  • Self-attention has no notion of proximity in time: x1, x2, and x3 are processed entirely in parallel, without any regard for their order. If we switch their order, the self-attention layer does not care; its outputs are simply reordered in the same way. But in natural language, the order of the words matters. This problem is dealt with by something called positional encoding.
  • As Figure 6 shows, each self-attention layer has one key, value, and query for each time step. But why restrict ourselves to a single key, query, and value? They can be thought of like filters in a convolutional network: just as we use many filters per layer there, we can use many key, query, and value sets per layer here. This is addressed by something called multi-headed attention.
  • How can we add nonlinearities to the self-attention layers? The layers discussed above are essentially linear: a(l) (equation 8) is linear in the values v(t), and the values are linear functions of h(t) (equation 3). Apart from the softmax that produces the attention weights, nothing in the layer is nonlinear.
  • The attention mechanism can also look into the future: for example, the query at time step 1 can look at the keys at time steps 2 and 3. But when generating a sequence we are trying to predict the future, so the model should not be able to look at it. Naive self-attention cannot distinguish between the past and the future. This is solved by something called masked attention.
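Two of these limitations are easy to see numerically. The sketch below (NumPy, random weights, and a hypothetical helper self_attention_layer, all assumptions for illustration) feeds the inputs in a different order and checks that the outputs are just reordered the same way, then prints the attention weights that the query at time step 1 puts on the keys at time steps 2 and 3.

```python
import numpy as np


def self_attention_layer(h, W_q, W_k, W_v):
    """Naive self-attention (eqs. 3 to 8): no positional information, no mask."""
    q, k, v = h @ W_q.T, h @ W_k.T, h @ W_v.T
    e = q @ k.T                                   # (T, T) scores, eq. 6
    e = e - e.max(axis=1, keepdims=True)          # stability shift, same softmax
    alpha = np.exp(e) / np.exp(e).sum(axis=1, keepdims=True)  # eq. 7, row-wise
    return alpha @ v, alpha                       # eq. 8 outputs, plus the weights


rng = np.random.default_rng(0)
T, d = 3, 8
h = rng.normal(size=(T, d))  # hypothetical inputs h(1), h(2), h(3)
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

out, alpha = self_attention_layer(h, W_q, W_k, W_v)

# Limitation 1: feed the inputs in a different order (h3, h1, h2) ...
perm = np.array([2, 0, 1])
out_perm, _ = self_attention_layer(h[perm], W_q, W_k, W_v)
# ... and the outputs are the original outputs, reordered the same way
print(np.allclose(out_perm, out[perm]))  # True

# Limitation 4: the query at time step 1 gives nonzero weight to time steps 2 and 3
print(alpha[0, 1:])
```

Because every softmax weight is an exponential divided by a sum, it is always strictly positive, so without some form of masking each query necessarily attends a little to the "future".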

That’s all for Part 2!!

In the third part of this blog series, I delve deep into the problems with naive self-attention.

(link to part 1)

(link to part 3)

(link to part 4)

If you liked it, please read and clap!!

