In-Depth Understanding of Attention Mechanism (Part II) - Scaled Dot-Product Attention and Example

FunCry
Mar 1, 2023

Introduction

In the previous article, we discussed the challenges faced by machine translation and introduced the Attention mechanism proposed in “Neural Machine Translation by Jointly Learning to Align and Translate.”

In this article, we will focus on introducing the Scaled Dot-Product Attention behind the Transformer and explain its computational logic and design principles in detail.

At the end of the article, we will also walk through an example of using Attention, to give readers a more concrete understanding of how it works.

Scaled Dot-Product Attention

We have now seen the prototype of the attention mechanism; however, it does not solve the problem of slow input processing, because the model still has to read the input one token at a time. To speed up computation and take advantage of parallel hardware, we need to abandon this one-step-at-a-time approach.

The paper “Attention Is All You Need” introduced Scaled Dot-Product Attention to overcome this challenge. The formula is as follows:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The formula may look complex, but it can be broken down into simpler steps. Let’s explore each step to understand the principle behind it.

1. QKᵀ

In this step, we are working with two matrices: Q (Query) and K (Key). Suppose Q has three pieces of data and K has four pieces of data. The dimensions of the two matrices are then 3 × dₖ and 4 × dₖ, respectively.

It's important to note that the two matrices need to have the same number of columns. The interpretation is:

  • Q has three pieces of data, and each piece of data is represented by a vector of length dₖ.
  • K has four pieces of data, and each piece of data is represented by a vector of length dₖ.

If the input sequences used to build Q and K have different dimensions, or if you want to choose a specific dₖ, you can use Linear(input_q_dim, dₖ) and Linear(input_k_dim, dₖ) to linearly project the original queries and keys to dₖ dimensions.

The purpose of these two layers is to map the two sequences into the same vector space.
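A minimal NumPy sketch of this projection step follows. It assumes random matrices stand in for the learned weights of the two Linear layers (biases omitted), and the sizes input_q_dim = 6, input_k_dim = 5 and dₖ = 4 are purely hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 3 queries of width 6, 4 keys of width 5, projected to d_k = 4.
input_q_dim, input_k_dim, d_k = 6, 5, 4
queries = rng.standard_normal((3, input_q_dim))  # raw query sequence
keys = rng.standard_normal((4, input_k_dim))     # raw key sequence

# Random matrices stand in for the learned weights of Linear(input_q_dim, d_k)
# and Linear(input_k_dim, d_k); biases are omitted for brevity.
W_q = rng.standard_normal((input_q_dim, d_k))
W_k = rng.standard_normal((input_k_dim, d_k))

Q = queries @ W_q  # shape (3, d_k): queries now live in the shared d_k-dim space
K = keys @ W_k     # shape (4, d_k): keys live in the same space
print(Q.shape, K.shape)  # (3, 4) (4, 4)
```

Even though the raw queries and keys had different widths, both now have dₖ columns, which is what makes QKᵀ well defined.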

Now that we have Q and K matrices (assuming dₖ = 4), let’s take a closer look at what QKᵀ is doing.

The animation illustrates that the product of the query and key matrices, QKᵀ, yields a 3 × 4 matrix, since Q is 3 × dₖ and Kᵀ is dₖ × 4.

This stage corresponds to the “Dot-Product” part of the name: the (i, j) entry of the resulting matrix is the dot product of row i of Q and row j of K, and it represents how important Kⱼ is to Qᵢ.
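As a quick sanity check, the sketch below (NumPy, with random Q and K of the hypothetical sizes used above) computes QKᵀ and confirms that one entry equals the dot product of the corresponding query and key rows:

```python
import numpy as np

rng = np.random.default_rng(1)
d_k = 4
Q = rng.standard_normal((3, d_k))  # 3 queries, each a length-d_k vector
K = rng.standard_normal((4, d_k))  # 4 keys, each a length-d_k vector

A = Q @ K.T  # (3, d_k) @ (d_k, 4) -> (3, 4)

# Entry (i, j) is the dot product of query row i with key row j.
i, j = 2, 1
assert np.isclose(A[i, j], np.dot(Q[i], K[j]))
print(A.shape)  # (3, 4)
```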

In summary, the Scaled Dot-Product Attention mechanism performs the following steps up until now:

  1. It maps the input queries and keys to the same vector space where their inner product results in higher values for more relevant pairs (this transformation is learned by the model).
  2. It computes the attention table A by taking the inner product of the query and key matrices.

2. softmax(A / √dₖ)

This stage corresponds to the “Scaled” part of the name. To understand it, we first need to look at the softmax function.

For each row in A, the softmax function maps each element to a value between 0 and 1 such that the sum of the values in each row equals 1.
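A row-wise softmax can be sketched in a few lines of NumPy. The function name softmax_rows and the example scores are illustrative only; subtracting the row maximum is a common numerical-stability trick that leaves the result unchanged:

```python
import numpy as np

def softmax_rows(A):
    """Apply softmax independently to each row of A."""
    # Subtracting the row max does not change the result but avoids overflow in exp.
    shifted = A - A.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

A = np.array([[1.0, 2.0, 3.0, 4.0],
              [0.0, 0.0, 0.0, 0.0]])
P = softmax_rows(A)
print(P)  # second row: every entry is 0.25 (equal scores get equal weight)
```

Every entry of P lies strictly between 0 and 1, each row sums to 1, and larger scores in a row receive larger weights.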

If we fix a particular coordinate zᵢ and vary only that value while holding the other elements of the vector constant, the i-th softmax output traces an S-shaped curve. In fact, it is exactly a sigmoid of zᵢ, shifted horizontally by log Σⱼ≠ᵢ e^{zⱼ}.

Notably, the slope of this curve, ∂pᵢ/∂zᵢ = pᵢ(1 − pᵢ) where pᵢ = softmax(z)ᵢ, is almost zero at extreme values (pᵢ close to 0 or 1), which makes it difficult to update the corresponding parameters during training.
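The sketch below illustrates this numerically, under the assumption of three arbitrary fixed scores for the other coordinates. It estimates ∂pᵢ/∂zᵢ with central finite differences and compares it with the analytic form pᵢ(1 − pᵢ) as zᵢ grows:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def dpi_dzi(z, i, eps=1e-6):
    """Central finite-difference estimate of d softmax(z)[i] / d z[i]."""
    up, down = z.copy(), z.copy()
    up[i] += eps
    down[i] -= eps
    return (softmax(up)[i] - softmax(down)[i]) / (2 * eps)

others = np.array([0.0, 0.5, -0.5])  # arbitrary scores, held fixed
results = []
for zi in [0.0, 5.0, 10.0]:
    z = np.concatenate(([zi], others))
    p = softmax(z)[0]
    grad = dpi_dzi(z, 0)
    results.append((zi, p, grad))
    # Analytic slope p * (1 - p) shrinks toward 0 as p approaches 1.
    print(f"z_i={zi:5.1f}  p_i={p:.6f}  grad={grad:.2e}  p(1-p)={p * (1 - p):.2e}")
```

As zᵢ moves from 0 to 10, pᵢ approaches 1 and the slope drops by roughly three orders of magnitude.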

If dₖ, the dimension of the query and key vectors, is large, the dot-product scores in the attention matrix A tend to have large magnitudes and are therefore more likely to fall into these flat regions, since each score is a sum of dₖ products.

To mitigate this issue, the scaling factor of √dₖ is introduced to reduce the magnitude of the scores and decrease the chances of encountering vanishing gradients.

This technique helps improve the stability and convergence of the model during training.
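To see the effect of the scaling factor, the sketch below draws one random query and four random keys with a hypothetical large dₖ = 512 and compares the softmax of the raw scores with the softmax of the scaled scores. Dividing by √dₖ acts like raising the softmax temperature, so the largest weight can only shrink:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
d_k = 512                          # a large, hypothetical key dimension
q = rng.standard_normal(d_k)       # one query
K = rng.standard_normal((4, d_k))  # four keys

scores = K @ q                     # raw dot products; their spread grows with d_k
p_raw = softmax(scores)
p_scaled = softmax(scores / np.sqrt(d_k))

# Without scaling the row tends to be close to one-hot (deep in the flat region);
# with scaling the weights stay softer, so gradients are less likely to vanish.
print("max weight without scaling:", p_raw.max())
print("max weight with scaling:   ", p_scaled.max())
```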

We call the matrix produced by the softmax A’, that is, A’ = softmax(A / √dₖ).

Advanced (optional):

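Assume that each element of the query vector q and the key vector k is independently sampled from the standard normal distribution N(0, 1), and consider the dot product:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

Each term qᵢkᵢ has mean E[qᵢ]E[kᵢ] = 0 and variance E[qᵢ²]E[kᵢ²] = 1, and the dₖ terms are independent, so the dot product of q and k has mean 0 and variance dₖ. (It is not exactly normal, but by the central limit theorem it is approximately normal when dₖ is large.)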

Dividing the dot product by √dₖ scales it to have variance 1.
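A small Monte Carlo check of this argument is sketched below. It follows the stated assumption that all entries are i.i.d. N(0, 1); the sample count and the dₖ values are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 20_000
results = []

for d_k in [4, 64, 256]:
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)      # one dot product per sample
    scaled = dots / np.sqrt(d_k)    # divide by sqrt(d_k)
    results.append((d_k, dots.var(), scaled.var()))
    print(f"d_k={d_k:3d}  var(q.k)={dots.var():8.2f}  var(q.k/sqrt(d_k))={scaled.var():.3f}")
```

The unscaled variance should track dₖ, while the scaled variance should stay close to 1 regardless of dₖ.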

3. A’V

Before delving into this section, let’s recall the properties of A’:

  1. A’ has the same number of rows as the number of input queries.
  2. A’ has the same number of columns as the number of input keys.
  3. Each element in each row of A’ lies in the range of 0 to 1, where the (i, j) entry represents the significance of the j-th key for the i-th query.
  4. The sum of all the elements in a single row of A’ is equal to 1.

Suppose A’ looks like this:

Next, let’s discuss V, which represents the actual values behind the keys.

You can think of the keys as student IDs and the values as information such as names, classes, and grades. You use the student ID to find the person, but the relevant information is the value behind it. For the model, it is the same: the values are what it actually uses to calculate.

Each value vector has dimension dᵥ (so V has dᵥ columns), which can be set to any suitable value. You can also use Linear(input_k_dim, dᵥ) to project the input vectors into this dimension.

Note that the first argument above is input_k_dim. This is because keys and values are derived from the same source vectors: one projection maps them into the space of Q for comparison, and the other maps them into the actual values.

In other words, the keys and values have a one-to-one correspondence, with the first value corresponding to the first key, the second value corresponding to the second key, and so on.
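The sketch below shows this pairing in NumPy. It assumes two random matrices stand in for the learned weights of Linear(input_k_dim, dₖ) and Linear(input_k_dim, dᵥ); all sizes are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(2)
input_k_dim, d_k, d_v = 5, 4, 3  # hypothetical sizes

kv_input = rng.standard_normal((4, input_k_dim))  # 4 source vectors

# Two different (randomly initialised, hypothetical) projections of the SAME input.
W_k = rng.standard_normal((input_k_dim, d_k))  # stands in for Linear(input_k_dim, d_k)
W_v = rng.standard_normal((input_k_dim, d_v))  # stands in for Linear(input_k_dim, d_v)

K = kv_input @ W_k  # (4, d_k): compared against Q
V = kv_input @ W_v  # (4, d_v): the content that gets averaged

# Row j of K and row j of V both come from source vector j: a one-to-one pairing.
print(K.shape, V.shape)  # (4, 4) (4, 3)
```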

V looks like this:

Let’s take a look at what A’V does:

From the above equation, we can observe that A’V will:

  1. Produce a matrix whose number of rows equals the number of queries and whose number of columns is dᵥ.
  2. For each query, its final value is a weighted average of the rows of V.
  3. The weights come from that query’s row of A’, i.e. the scaled and softmax-normalized inner products between the query and each key.
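The weighted-average view can be checked directly. The A’ and V below are small made-up matrices (2 queries, 3 keys, dᵥ = 2) chosen only so the arithmetic is easy to follow:

```python
import numpy as np

# Hypothetical A' (2 queries x 3 keys; each row sums to 1) and V (3 keys x d_v = 2).
A_prime = np.array([[0.7, 0.2, 0.1],
                    [0.1, 0.1, 0.8]])
V = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [2.0, 2.0]])

out = A_prime @ V  # (2, 3) @ (3, 2) -> (2, 2): one row per query

# Row i of the output is the A'-weighted average of the rows of V.
manual = np.array([sum(A_prime[i, j] * V[j] for j in range(3)) for i in range(2)])
assert np.allclose(out, manual)
print(out)  # approximately [[0.9, 0.4], [1.7, 1.7]]
```

The second query puts most of its weight on the third key, so its output lands close to the third row of V.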

Conclusion of Scaled Dot-Product Attention

In short, the mechanism does two things:

  1. Creating a lookup table (inner product of Q and K, scaled, softmax).
  2. Calculating the final output by taking the weighted average of the rows in V based on the attention table.
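Putting the two steps together, a minimal single-example NumPy sketch of softmax(QKᵀ / √dₖ)V might look like the following. It omits batching and masking, and the function name is our own rather than any library’s API:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for one unbatched, unmasked set of inputs.

    Q: (n_queries, d_k), K: (n_keys, d_k), V: (n_keys, d_v)
    Returns the output (n_queries, d_v) and the attention table A' (n_queries, n_keys).
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                          # dot products, then scaling
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability only
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax -> A'
    return weights @ V, weights                              # weighted average of V rows

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4))  # 3 queries, d_k = 4
K = rng.standard_normal((4, 4))  # 4 keys,    d_k = 4
V = rng.standard_normal((4, 2))  # 4 values,  d_v = 2
out, A_prime = scaled_dot_product_attention(Q, K, V)
print(out.shape, A_prime.shape)  # (3, 2) (3, 4)
```

Step 1 (the lookup table) corresponds to the scores and weights lines; step 2 is the final matrix product.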

You can also think of the keys and values as a kind of database, and attention is actually finding a suitable vector to represent each query based on the content in the database.

In the explanation above, the queries and the key-value pairs come from two different sequences, but they can also come from the same sequence, in which case the mechanism is called Self-Attention.

Self-Attention is particularly useful in natural language processing, where the output can be seen as an embedding of each word that takes its context into account.
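A self-attention sketch differs only in where Q, K and V come from: all three are projections of the same sequence X. Below, X is a hypothetical batch of 5 token embeddings, and random matrices stand in for the learned projection weights:

```python
import numpy as np

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V

rng = np.random.default_rng(3)
seq_len, d_model, d_k, d_v = 5, 8, 4, 4      # hypothetical sizes
X = rng.standard_normal((seq_len, d_model))  # e.g. 5 token embeddings of one sentence

# Self-attention: queries, keys and values are all projections of the same X.
# Random matrices stand in for learned projection weights.
W_q = rng.standard_normal((d_model, d_k))
W_k = rng.standard_normal((d_model, d_k))
W_v = rng.standard_normal((d_model, d_v))

contextual = attention(X @ W_q, X @ W_k, X @ W_v)
print(contextual.shape)  # (5, 4): one context-aware vector per token
```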

Example

Finally, a practical example of using attention is provided to deepen your understanding.

Suppose we want to convert the Department of Computer Science and the Department of Philosophy (two queries) into vectors that the model can understand, and there are four subjects (four key-value pairs) in the database: English, Social Studies, Mathematics, and Physics. The attention table calculated by the model may look like this:

The interpretation is as follows: for the Department of XX, the importance of these four subjects is …

Next, we need the values of V. Here, I set dᵥ = 3, where the three columns of V represent memory, language ability, and logical reasoning ability.

(Note: In reality, we cannot know what the columns of Value represent, and the model will find out what they should represent on its own.)

The interpretation is as follows: for the subject of XX, the memory, language, and logical reasoning scores needed to do well are respectively …

(The scores are purely fictional.)

Multiplying the two matrices:

The final result is as follows: the memory, language, and logical reasoning scores needed to do well in XX Department are respectively …

The model then uses the two vectors (6.9, 7.5, 10.1) and (8.3, 8.7, 7.3) to represent the Department of Computer Science and the Department of Philosophy, respectively.

I hope this example has given you a clear picture of how attention operates and how its dimensions are set up.

Conclusion

This article has provided a detailed introduction to the computation details and design logic of Scaled Dot-Product Attention, as well as an example of using Attention.

If we carefully compare it with the model introduced in the previous article, we find that although the two architectures look very different, their overall logic is quite similar.

Besides attention, the Transformer architecture includes several other critical elements, such as the encoder and decoder structures, Multi-Head Attention, and Positional Encoding. Although this article cannot cover these aspects in depth, we hope to explore them in more detail in future write-ups.
