Demystifying Attention Mechanisms in Sequence-to-Sequence Models, Transformers (Part 1)

Seq2seq models gained immense popularity but suffered from an information bottleneck, which was addressed by attention, the concept at the heart of the transformer model. In this blog, we aim to develop an understanding of the transformer's key components by examining the limitations of seq2seq models and building on them.

Luv Verma
May 9, 2023

Basic Introduction to Sequence-to-Sequence (seq2seq) models

The sequence-to-sequence (seq2seq) model is a popular approach for addressing various natural language processing tasks, such as translation, summarization, and dialogue systems.

The seq2seq model comprises two main components, an encoder and a decoder, both of which are typically recurrent neural networks (RNNs).

In a seq2seq model, the encoder processes the input sequence, one element at a time, and generates a fixed-length hidden state representation. This hidden state captures the information from the input sequence. The decoder, on the other hand, takes this hidden state and generates the output sequence element by element, conditioned on the previous output and the hidden state.

Figure 0: An example of a sequence-to-sequence encoder-decoder RNN network.
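To make the fixed-length hidden state concrete, here is a minimal NumPy sketch of a vanilla RNN encoder. It is an illustration under assumptions: real seq2seq systems usually use LSTM or GRU cells with trained weights, while the weights here are random, and the names rnn_encode, W_xh, W_hh and b_h are hypothetical. Whether the input has 3 elements or 50, the decoder receives a vector of the same size.

```python
import numpy as np


def rnn_encode(xs, W_xh, W_hh, b_h):
    """Vanilla (Elman) RNN encoder: returns the hidden states h(1..T)."""
    h = np.zeros(W_hh.shape[0])           # initial hidden state h(0)
    states = []
    for x in xs:                           # one input element per time step
        h = np.tanh(W_xh @ x + W_hh @ h + b_h)
        states.append(h)
    return np.stack(states)                # shape (T, hidden_dim)


rng = np.random.default_rng(0)
input_dim, hidden_dim = 4, 8
W_xh = rng.normal(scale=0.5, size=(hidden_dim, input_dim))   # random, untrained
W_hh = rng.normal(scale=0.5, size=(hidden_dim, hidden_dim))
b_h = np.zeros(hidden_dim)

short_seq = rng.normal(size=(3, input_dim))    # stand-in for a 3-token sentence
long_seq = rng.normal(size=(50, input_dim))    # stand-in for a 50-token sentence

# Without attention, the decoder only receives the LAST encoder hidden state.
summary_short = rnn_encode(short_seq, W_xh, W_hh, b_h)[-1]
summary_long = rnn_encode(long_seq, W_xh, W_hh, b_h)[-1]
print(summary_short.shape, summary_long.shape)  # (8,) (8,)
```

The 50-element sentence has to fit into the same 8 numbers as the 3-element one, which is exactly the bottleneck discussed next.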

Despite its success in many tasks, the seq2seq model has some limitations, particularly when dealing with long input sequences and distant context dependencies.

In the following sections, we will discuss the bottleneck problem in seq2seq models and explore a potential solution to address this issue.

Bottleneck problem with sequence-to-sequence models

Since the encoder is connected to the decoder only sequentially, through its final hidden state, all the information from the source sequence has to be contained in the activations at the beginning of the decoder. So the decoder knows nothing about the source sequence other than what the encoder puts into that hidden state.

Figure 1: Representative of issues with sequence-to-sequence model

Why is it a problem?

What if the sequence is very long, or two related words are far apart in the sentence? Compressing a long sequence into a single fixed-size hidden state and passing it through this bottleneck leads to loss of information and context (Figure 1). In addition, training difficulties such as vanishing gradients make it hard for the model to learn and retain long-range dependencies, which leads to inaccuracies in the generated target sequence.

For example, in the following sentence:

“In the beginning, Alice discovered the entrance to Wonderland, which was a hidden rabbit hole, and after facing numerous thrilling adventures with peculiar creatures, she ultimately learned that the key to returning back to her world was hidden in the pocket of her dress all along.”

In this sentence, “Alice” and “her world” are far apart, as are “hidden rabbit hole” and “key to returning back.” A plain sequential encoder-decoder is simply not sufficient to capture such long-range dependencies.

Can we peek at the input while decoding? And what does peeking at the input while decoding actually mean?

Peeking at the input while decoding simply means allowing the decoder to directly access and refer to the relevant parts of the input sequence encoded by the encoder (Figure 2). In other words, we connect the encoder to the decoder more directly so that the bottleneck is removed.

Figure 2: Representative of more interaction between the encoder and decoder networks

But how can this be achieved? Consider the following steps and Figure 3.

Figure 3: Basics of Key and Query Vectors

  1. At each step of the encoder (that is, at each encoder time step), produce a vector that summarizes the information present at that particular time step. These vectors are not semantically meaningful to the human eye. They are called key vectors.
  2. Generally, a learned function (e.g., a linear layer + ReLU) maps each encoder state to its key vector.
  3. Just as we did while encoding, we can output another set of vectors while decoding. However, these vectors describe what a particular decoder step is looking for, rather than what information that step possesses (as key vectors do for the encoder). They are called query vectors.
  4. Query vectors look for information; key vectors hold it. Let’s pick one query vector. Each query vector looks at all the key vectors to find the one with the most similar information. As a result, we know which time step of the input encoding process is most relevant to a particular decoder time step.
  5. Once we know which encoder state is most relevant, we send the hidden state from that encoder time step directly to the decoder.
  • For example, say the 2nd key vector is most relevant to the 3rd query vector. In this case, the attention mechanism helps the 3rd decoding step to focus on the information from the 2nd encoding step.
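Steps 1 to 5 can be sketched in a few lines of NumPy. This is a toy illustration, not a trained model: the encoder and decoder states are random stand-ins, and the projection matrices W_k and W_q as well as the helper hard_lookup are hypothetical names. The selection here is a hard "pick the single best key" choice, which mirrors the steps above.

```python
import numpy as np

rng = np.random.default_rng(1)
T, L, hidden_dim, key_dim = 5, 4, 8, 6

H = rng.normal(size=(T, hidden_dim))  # encoder states h(1..T), random stand-ins
S = rng.normal(size=(L, hidden_dim))  # decoder states s(1..L), random stand-ins

# Hypothetical learned maps (step 2): a linear layer followed by ReLU.
W_k = rng.normal(size=(key_dim, hidden_dim))
W_q = rng.normal(size=(key_dim, hidden_dim))
keys = np.maximum(H @ W_k.T, 0.0)      # one key per encoder step, shape (T, key_dim)
queries = np.maximum(S @ W_q.T, 0.0)   # one query per decoder step, shape (L, key_dim)


def hard_lookup(query, keys, H):
    """Steps 4-5: find the best-matching key and return that encoder state."""
    similarities = keys @ query             # dot product with every key
    t_best = int(np.argmax(similarities))   # most relevant encoder step
    return t_best, H[t_best]


for l in range(L):
    t_best, h_sent = hard_lookup(queries[l], keys, H)
    print(f"decoder step {l + 1} looks at encoder step {t_best + 1}")
```

Because the weights are random, the printed matches carry no linguistic meaning; in a real model, training shapes W_k and W_q so that queries and keys line up in useful ways.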

Crude intuition: say the 3rd query vector asks “what is the verb in the sentence?”, and the 2nd key vector is the closest match. What would the 2nd key vector represent?

In this case, the 2nd key vector would represent the information about the verb in the input sentence, encoded during the 2nd time step of the encoder. It captures the context and semantic information around the verb, making it the most relevant piece of information for the 3rd query vector, which is looking for the verb.

But how is the information populated in these key and query vectors? The functions that produce them are learned as part of the training process; we do not select them manually.

Let us look at how the keys and queries can be mathematically framed!!

Some terminology used:

  • h → hidden state of the encoder (RNN activations)
  • s → hidden state of the decoder (RNN activations)
  • k → for the key vector. Time step t is used to denote the time step while encoding
  • q → for the query vector. Time step l is used to denote the time step while decoding

Figure 3: Basics of key, query vectors, and attention scores

Key Vector:

For a particular encoder time step t, the key is a function k applied to the hidden state of the encoder h at that time step. Therefore, equation 1:

$$k_t = k(h(t)) \qquad (1)$$

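A very simple example (equation 2) is a linear layer followed by a non-linearity σ (sigmoid, tanh, ReLU, etc.), where W_k and b_k are a learned weight matrix and bias:

$$k_t = \sigma(W_k h(t) + b_k) \qquad (2)$$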

In practice, however, the function k is very often just a linear transformation.

Query Vector:

For a particular decoder time step l, the query is a function q applied to the hidden state of the decoder s at that time step. Therefore, equation 3:

$$q_l = q(s(l)) \qquad (3)$$

The similarity between the key vector at time step t and the query vector at time step l is given by their dot product, also called the attention score (equation 4):

$$e(t, l) = k_t \cdot q_l \qquad (4)$$
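Equations 1 to 4 can be computed for every (t, l) pair at once with a matrix product. The sketch below assumes random stand-in states and hypothetical parameters W_k, b_k and W_q; it shows both the equation-2 style key (linear layer + tanh) and the plain linear key that is common in practice.

```python
import numpy as np

rng = np.random.default_rng(2)
T, L, hidden_dim, key_dim = 5, 3, 8, 6
H = rng.normal(size=(T, hidden_dim))   # h(t), t = 1..T (random stand-ins)
S = rng.normal(size=(L, hidden_dim))   # s(l), l = 1..L (random stand-ins)

# Hypothetical parameters; in a real model these are learned during training.
W_k, b_k = rng.normal(size=(key_dim, hidden_dim)), np.zeros(key_dim)
W_q = rng.normal(size=(key_dim, hidden_dim))


def key_nonlinear(h):
    """Equation 2: linear layer followed by a non-linearity (tanh here)."""
    return np.tanh(W_k @ h + b_k)


def key_linear(h):
    """The common practical choice: k is just a linear map."""
    return W_k @ h


def query_linear(s):
    """Equation 3 with a linear q."""
    return W_q @ s


K = np.stack([key_linear(h) for h in H])          # keys k_t, shape (T, key_dim)
K_tanh = np.stack([key_nonlinear(h) for h in H])  # equation-2 keys, same shape
Q = np.stack([query_linear(s) for s in S])        # queries q_l, shape (L, key_dim)

# Equation 4 for every (t, l) pair at once: E[t, l] = k_t . q_l
E = K @ Q.T                                       # shape (T, L)
```

Column l of E holds the scores of every encoder step for decoder step l, which is the set of numbers the softmax in the next step operates on.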

Intuitively, we want to pull out the hidden state h(t) for the time step t whose attention score is the largest (i.e., for which the dot product in equation 4 is largest), and send that hidden state h(t) to the decoder at step l. Choosing that single time step t is an argmax operation: argmax returns the index (here, the time step t) of the largest score rather than the score itself. However, argmax is non-differentiable, so gradients cannot flow through it during training. Instead, for training purposes, we use softmax (equation 5):

$$\alpha(t, l) = \frac{\exp(e(t, l))}{\sum_{t'} \exp(e(t', l))} \qquad (5)$$

Now we have a softened approximation of argmax (called softmax), computed over the encoding time steps t for a particular decoding time step l. Now, since ‘l’ would like to

The α(t, l) are also known as the attention weights. They are used to create the context vector a(l), given by equation 6:

$$a(l) = \sum_{t} \alpha(t, l)\, h(t) \qquad (6)$$

In equation 6, we sum across ‘t’ to obtain a(l), the so-called context vector. By summing across all time steps ‘t’, we combine the information from all encoder time steps, weighted by the attention weights alpha(t, l). This gives the decoder a more comprehensive and contextually relevant piece of information. The attention weights alpha(t, l) determine the importance of each encoder hidden state h(t) for the current decoding step ‘l’: a higher weight for a particular time step ‘t’ means that the corresponding hidden state h(t) is more relevant to the current decoding step ‘l’.
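Equations 5 and 6 fit into one small function. The sketch below is a minimal NumPy version under stated assumptions: the states and keys are random stand-ins, and the helper names softmax and attend are hypothetical. It also shows that multiplying the scores by a large factor pushes the softmax toward a one-hot argmax while keeping it differentiable.

```python
import numpy as np


def softmax(scores):
    """Equation 5: numerically stable softmax over encoder time steps t."""
    z = scores - np.max(scores)   # shifting by a constant leaves softmax unchanged
    w = np.exp(z)
    return w / w.sum()


def attend(H, keys, query):
    """Return the attention weights alpha(., l) and the context vector a(l)."""
    e = keys @ query              # equation 4: e(t, l) = k_t . q_l for every t
    alpha = softmax(e)            # equation 5
    a = alpha @ H                 # equation 6: sum over t of alpha(t, l) * h(t)
    return alpha, a


rng = np.random.default_rng(3)
H = rng.normal(size=(5, 8))          # encoder states h(1..5), random stand-ins
keys = H @ rng.normal(size=(8, 6))   # linear keys k_t, shape (5, 6)
query = rng.normal(size=6)           # one decoder query q_l

alpha, a = attend(H, keys, query)
print(alpha.round(3), a.shape)

# Scaling the scores up sharpens the softmax toward a one-hot argmax,
# while it stays differentiable for any finite scale.
sharp = softmax(20.0 * (keys @ query))
```

Subtracting the maximum score before exponentiating is a standard trick to avoid overflow; it does not change the resulting weights.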

Explaining equations 5 and 6 with the help of an example:

Let’s consider a simple example with 5 time steps (t = 1 to t = 5) in the encoding process and a single decoding step (l = 3). Suppose the attention scores e(t, l=3) for each time step t are as follows:

  • e(1, 3) = 1
  • e(2, 3) = 5
  • e(3, 3) = 2
  • e(4, 3) = 1
  • e(5, 3) = 1

Now, we compute the softmax values for these attention scores:

  • α(1, 3) = exp(1) / (exp(1) + exp(5) + exp(2) + exp(1) + exp(1)) ≈ 0.017
  • α(2, 3) = exp(5) / (exp(1) + exp(5) + exp(2) + exp(1) + exp(1)) ≈ 0.905
  • α(3, 3) = exp(2) / (exp(1) + exp(5) + exp(2) + exp(1) + exp(1)) ≈ 0.045
  • α(4, 3) = exp(1) / (exp(1) + exp(5) + exp(2) + exp(1) + exp(1)) ≈ 0.017
  • α(5, 3) = exp(1) / (exp(1) + exp(5) + exp(2) + exp(1) + exp(1)) ≈ 0.017

α(2, 3) is clearly the highest of all the attention weights, which means that the hidden state at time step t = 2 has the most significant impact on decoding step l = 3.

Now, suppose we have the encoder hidden states h(t) for each time step t. To create the context vector a(3) for the decoding step l=3, we compute the weighted sum of the hidden states:

  • a(3) = α(1, 3) * h(1) + α(2, 3) * h(2) + α(3, 3) * h(3) + α(4, 3) * h(4) + α(5, 3) * h(5)

Since α(2, 3) is the highest weight, the contribution of h(2) to the context vector a(3) will be the most significant, and the context vector will be heavily influenced by the information from the encoder’s time step t=2.
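The worked example can be checked numerically. The scores below are the ones from the example; the encoder states are a hypothetical choice (one-hot vectors), picked only so that the context vector a(3) directly shows how much each h(t) contributes.

```python
import numpy as np

# Attention scores e(t, 3) for t = 1..5 from the example above.
e = np.array([1.0, 5.0, 2.0, 1.0, 1.0])

# Equation 5: softmax over the encoder time steps.
alpha = np.exp(e) / np.exp(e).sum()
print(alpha.round(3))   # [0.017 0.905 0.045 0.017 0.017]

# Hypothetical encoder states: one-hot vectors, so a(3) reproduces the weights
# and each h(t)'s contribution can be read off directly.
H = np.eye(5)

# Equation 6: a(3) = sum over t of alpha(t, 3) * h(t)
a3 = (alpha[:, None] * H).sum(axis=0)
print(int(np.argmax(alpha)) + 1)   # 2 -> encoder step t = 2 dominates
```

With real, dense hidden states the context vector is a blend rather than a copy, but the weight on h(2) still dominates in the same way.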

So, in short, we were able to pull out the RNN encoder state h(t) for the time step t (t = 2 in the above example) whose key forms the largest dot product with the query of the decoding step l (l = 3 in the above example).

How is this context vector a(l) used further?

Earlier, in a plain RNN decoder, we had equation 7 (y-hat(l) represents the output of the decoder at time step ‘l’, and f is the decoder's output function):

$$\hat{y}(l) = f(s(l)) \qquad (7)$$

With attention, we now have equation 8:

$$\hat{y}(l) = f(s(l), a(l)) \qquad (8)$$

Thus, to produce the output at decoder step ‘l’, we now use the context vector a(l) in addition to the decoder state s(l).

Equation 8 is one way to use the context vector, but there are several others; we will delve further into how this attention mechanism can be used in the next blog.
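The blog leaves the output function f unspecified, so the sketch below makes an explicit assumption: f concatenates s(l) and a(l) and applies a linear layer followed by a softmax over a toy vocabulary. This is only one possible choice, and the names W_plain and W_attn, the sizes, and the random values are all hypothetical.

```python
import numpy as np


def softmax(z):
    """Turn output scores into a probability distribution over the vocabulary."""
    z = z - np.max(z)
    w = np.exp(z)
    return w / w.sum()


rng = np.random.default_rng(4)
hidden_dim, vocab_size = 8, 10       # toy sizes
s_l = rng.normal(size=hidden_dim)    # decoder state s(l), random stand-in
a_l = rng.normal(size=hidden_dim)    # context vector a(l) from attention

# Equation 7 (no attention): y_hat(l) = f(s(l)), with f = linear layer + softmax.
W_plain = rng.normal(size=(vocab_size, hidden_dim))
y_hat_plain = softmax(W_plain @ s_l)

# Equation 8 (with attention), ASSUMING f concatenates s(l) and a(l)
# before the linear layer + softmax.
W_attn = rng.normal(size=(vocab_size, 2 * hidden_dim))
y_hat_attn = softmax(W_attn @ np.concatenate([s_l, a_l]))
print(int(np.argmax(y_hat_plain)), int(np.argmax(y_hat_attn)))
```

The only structural change between the two versions is that the output layer now sees the context vector as well as the decoder state, so its weight matrix doubles in width.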

Transformers will ring a bell!!

If you like, please read and clap!!
