Demystifying the Attention Logic of Transformers: Unraveling the Intuition and Implementation by visualization

Souvik Mandal
8 min read · Jun 17, 2023


Transformers are on the rise in both the image and NLP domains. In this blog we will see where the idea of attention came from and how it works, and finally we will look at its implementation.

This is part of a series of blogs I am writing to understand the "Attention Is All You Need" paper: why the Transformer works across domains, and how learning with attention (the Transformer architecture) differs from learning with CNNs. I will update this section with links to the other blogs:

  1. Demystifying the Attention Logic of Transformers (This Blog)
  2. Transformer positional embeddings
  3. Attention is All You Need: Understanding Transformer Decoder and putting all pieces together.

Attention Please! 🎤

During the whole attention process, we learn three matrices: query, key and value. The idea of query, key and value comes from information retrieval systems, so let's understand it with a database query.

Let's say I have a database with information on all Bengali authors and their books. Now I want to read a book written by Rabindranath, so I run a database query like this: select books written by Rabindranath. Let's assume our database looks like the one below.

Database with information about books written by different authors.

In the database, the authors are analogous to keys and the books are analogous to values. "Rabindranath" is our query. So we first compute the similarity between the query and the keys (all the authors in the database). Then we return the values (books) of the most similar author; in this case, all the books by Rabindranath Tagore.
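To make the analogy concrete, here is a minimal sketch with made-up numbers (the key vectors, values and query are all hypothetical). A database does a "hard" lookup and returns only the best match; attention does a "soft" lookup in which every value contributes, weighted by a softmax over the query-key similarity scores, as described in the steps that follow.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical toy data: a 2-D key vector per "author" and a number
# per author standing in for that author's books (the values).
keys = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
values = [10.0, 20.0, 30.0]
query = [1.0, 0.1]  # most similar to the first author

scores = [dot(query, k) for k in keys]

# Database style (hard lookup): return only the best-matching key's value.
best = max(range(len(keys)), key=lambda i: scores[i])
hard_result = values[best]

# Attention style (soft lookup): every value contributes, weighted by a
# softmax over the query-key similarity scores.
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]
soft_result = sum(w * v for w, v in zip(weights, values))

print(hard_result, [round(w, 3) for w in weights], round(soft_result, 2))
```

The most similar key still gets the largest weight, but the result is a blend rather than a single row.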

Similarly, attention has three matrices: the query (Q), key (K) and value (V) matrices. Each is a square matrix whose side equals the input embedding dimension, and we learn their values during training.

We will discuss how the input embeddings are computed in a separate blog. For now, you can assume that for each word we create a 512-dimensional vector that the model can process.

So, in our case, all three matrices are 512x512 (since the word embedding dimension is 512). We multiply each token embedding by all three matrices (Q, K, V), which gives us three intermediate vectors of length 512 for each token.

Attention mechanism. The dot-product results in the image are made-up values to keep it readable.

We next compute the scores: the dot product between a word's query vector and the key vector of every word in the sentence. The score determines how much focus to place on other parts of the input sentence as we encode the word at a given position.

Next, we divide the dot products by the square root of the key vector dimension, √d_k (√512 in our single-head example). Without this scaling, dot products of long vectors can grow very large in magnitude (positive or negative), which pushes the softmax into regions where its gradients are tiny and makes training unstable. If the components of the query and key vectors are independent with mean 0 and variance 1, their dot product has variance d_k, so dividing by √d_k brings the variance back to approximately 1.
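The variance argument can be checked numerically. This sketch assumes, as the argument does, that every component of the query and key vectors is an independent standard normal number; real learned vectors will not follow this exactly.

```python
import math
import random

random.seed(0)
d_k = 512         # key dimension in our single-head example
n_samples = 1000  # number of random (query, key) pairs

def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

raw_scores = []
for _ in range(n_samples):
    # Assumption: every component of q and k is independent N(0, 1).
    q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
    k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
    raw_scores.append(sum(a * b for a, b in zip(q, k)))

scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

print(f"variance before scaling: {variance(raw_scores):.1f}")     # roughly d_k
print(f"variance after scaling:  {variance(scaled_scores):.2f}")  # roughly 1
```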

Then we pass the result through a softmax, which normalizes the scores so they are all positive and add up to 1. The softmax output determines how much information (how much of each value vector) we take from each word; in other words, we are computing the weights.
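A small sketch of the softmax step on hypothetical scaled scores for a 4-word sentence. Subtracting the maximum score first is a common stability trick (an addition here, not something the blog describes); it leaves the result unchanged.

```python
import math

def softmax(scores):
    # Subtracting the max does not change the result (it cancels in the ratio)
    # but keeps math.exp from overflowing on large scores.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scaled_scores = [2.0, 1.0, -1.0, 0.5]  # hypothetical scaled scores for 4 words
weights = softmax(scaled_scores)
print([round(w, 3) for w in weights])
print(round(sum(weights), 6))
```

Every weight is positive, they sum to 1, and a higher score always gets a larger weight.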

Why do we need information from other words at all? Take the sentence "The dog did not attack the old man because it was sleepy." If the model looks only at the word "it", it has no way of knowing whether "it" refers to the dog or to the old man.

Finally, we multiply each value vector by its softmax weight and sum the results. This weighted sum is the attention output for the word.
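Putting the four steps together for one word at a time, before any vectorization: the sketch below uses plain Python lists and a hypothetical 3-word sentence with d_k = 4 (the blog uses 512). The query, key and value vectors are made up and stand in for the projected vectors described above.

```python
import math

def attend(query, keys, values):
    """Attention output for ONE query vector (plain lists, no batching)."""
    d_k = len(query)
    # 1) score: dot product of the query with every key
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # 2) scale by sqrt(d_k)
    scaled = [s / math.sqrt(d_k) for s in scores]
    # 3) softmax turns the scaled scores into weights that sum to 1
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    weights = [e / total for e in exps]
    # 4) weighted sum of the value vectors
    dim_v = len(values[0])
    return [sum(w * value[j] for w, value in zip(weights, values)) for j in range(dim_v)]

# Hypothetical 3-word sentence with d_k = 4.
queries = [[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]]
keys    = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
values  = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

outputs = [attend(q, keys, values) for q in queries]
for word_idx, out in enumerate(outputs):
    print(word_idx, [round(x, 3) for x in out])
```

Because the weights are non-negative and sum to 1, each output is a blend of the value vectors; if all keys were identical, every word would receive the plain average of the values.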

Matrix is here 🚙

The logic above is correct, but implementing it word by word would be slow, so let's look at the vectorized implementation.

The query, key and value computation can be done as below:

Query vector computation

In the same way, we can compute the key and value vectors.

Key and value vector computation
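The vectorized form performs the same computation as projecting each token separately: stack the token embeddings into a matrix X (one row per token) and a single matrix multiplication produces all the query vectors at once. The sketch below checks this with random stand-in embeddings. Note that nn.Linear computes x @ W.T + b, so it also adds a bias term that the figures do not show.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model = 4, 512

X = torch.randn(seq_len, d_model)           # one row per token embedding (random stand-in)
query_matrix = nn.Linear(d_model, d_model)  # plays the role of the query matrix

# Per-token: project each embedding separately, then stack the rows.
q_loop = torch.stack([query_matrix(x) for x in X])

# Vectorized: one matrix multiplication over all tokens at once.
q_vec = X @ query_matrix.weight.T + query_matrix.bias

print(q_loop.shape, q_vec.shape)
print(torch.allclose(q_loop, q_vec, atol=1e-5))
```

In practice you can simply call query_matrix(X); nn.Linear applies the projection to every row of its input.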

Finally we compute the scores and the attention output.

Attention output computation
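Written as equations, roughly in the notation of the original paper, the vectorized computation in the figures above is shown below. X holds one 512-dim embedding per row (n rows for n words), and the bias terms that nn.Linear adds are ignored:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V, \qquad X \in \mathbb{R}^{n \times 512},\; W^Q, W^K, W^V \in \mathbb{R}^{512 \times 512}$$

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The softmax is applied to each row of the n x n score matrix QK^T / √d_k, so the output has shape n x 512. In our example, n = 4 and d_k = 512.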

Let's code 📝

import torch
import torch.nn as nn
from typing import List

def get_input_embeddings(words: List[str], embeddings_dim: int):
    # we create a random vector of embeddings_dim size for each word.
    # normally a tokenizer splits the text into tokens and a learned
    # embedding layer maps each token to a vector.
    # check the blog on tokenizers to learn about this part
    embeddings = [torch.randn(embeddings_dim) for _ in words]
    return embeddings


text = "I should sleep now"
words = text.split(" ")
len(words) # 4


embeddings_dim = 512 # 512 dim because the original paper uses it. we can use other dim also
embeddings = get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape # torch.Size([512])


# initialize the query, key and value matrices
query_matrix = nn.Linear(embeddings_dim, embeddings_dim)
key_matrix = nn.Linear(embeddings_dim, embeddings_dim)
value_matrix = nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape # torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])


# compute the query, key and value vectors for each word embedding
query_vectors = torch.stack([query_matrix(embedding) for embedding in embeddings])
key_vectors = torch.stack([key_matrix(embedding) for embedding in embeddings])
value_vectors = torch.stack([value_matrix(embedding) for embedding in embeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape # torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])


# compute the score
scores = torch.matmul(query_vectors, key_vectors.transpose(-2, -1)) / torch.sqrt(torch.tensor(embeddings_dim, dtype=torch.float32))
scores.shape # torch.Size([4, 4])


# compute the attention weights for each of the words with the other words
softmax = nn.Softmax(dim=-1)
attention_weights = softmax(scores)
attention_weights.shape # torch.Size([4, 4])


# attention output
output = torch.matmul(attention_weights, value_vectors)
output.shape # torch.Size([4, 512])
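As a sanity check on the manual steps above, PyTorch 2.0 and later ship a fused torch.nn.functional.scaled_dot_product_attention that performs the same score, scale, softmax and weighted-sum computation. The sketch below assumes PyTorch >= 2.0 and uses fresh random q, k and v tensors rather than the ones computed above:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 4, 512
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_k)

# Same steps as the manual code: scores, scale by sqrt(d_k), softmax, weighted sum.
manual = torch.softmax(q @ k.T / math.sqrt(d_k), dim=-1) @ v

# The fused version expects (..., seq_len, d_k) and scales by 1/sqrt(d_k) by default.
# unsqueeze(0) adds a batch dimension of 1; squeeze(0) removes it again.
fused = F.scaled_dot_product_attention(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)).squeeze(0)

print(fused.shape)
print(torch.allclose(manual, fused, atol=1e-4))
```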

Multi-Head attention 🙉

Because you can never have too much attention. 😛 — me

The attention I have described above is single-head attention. Multi-head attention has more than one head; the original paper uses 8.

The computation is the same for single-head and multi-head attention up to the intermediate query (q0-q3), key (k0-k3) and value (v0-v3) vectors.

Multi-head attention

After that, we split each query, key and value vector into as many equal parts as there are heads. In the image above we have 8 heads and the query, key and value vectors are 512-dimensional, so each vector is split into 8 vectors of 64 dimensions.

The first 64-dimensional chunk goes to the first head, the second chunk to the second head, and so on. In the image above I have only shown the computation for the first head.
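The split into heads is just a reshape. The sketch below uses a random stand-in for the 4x512 query vectors and checks that after view and transpose, head 0 holds dimensions 0-63 of every token, head 1 holds dimensions 64-127, and so on. The same reshape applies to the keys and values.

```python
import torch

torch.manual_seed(0)
seq_len, d_model, num_heads = 4, 512, 8
head_dim = d_model // num_heads  # 64

q = torch.randn(seq_len, d_model)  # stand-in for q0-q3, one row per token

# (seq_len, d_model) -> (seq_len, num_heads, head_dim) -> (num_heads, seq_len, head_dim)
q_heads = q.view(seq_len, num_heads, head_dim).transpose(0, 1)

print(q_heads.shape)  # torch.Size([8, 4, 64])
# Head 0 gets dims 0..63 of every token, head 1 gets dims 64..127, ...
print(torch.equal(q_heads[0], q[:, :64]), torch.equal(q_heads[1], q[:, 64:128]))
```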

Once a head has its mini queries, keys and values (the 64-dim ones), the rest of the computation is the same as in single-head attention. Each head therefore outputs 4 vectors (one per word) of 64 dimensions.

We concatenate the first output vector from each of the 8 heads (8 x 64 = 512) to get the final 512-dim output vector for the first word, and do the same for the remaining 3 words.

Combine the results from the heads
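Concatenating the head outputs is a single torch.cat along the feature dimension. One detail the figure leaves out: in the original paper, the concatenated vector is multiplied by one more learned matrix, W^O, before leaving the attention layer. The sketch below uses random stand-ins for the 8 head outputs and an nn.Linear as a stand-in for W^O:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, num_heads, head_dim = 4, 8, 64
d_model = num_heads * head_dim  # 512

# Random stand-ins for the per-head outputs: one (seq_len, head_dim) tensor per head.
head_outputs = [torch.randn(seq_len, head_dim) for _ in range(num_heads)]

# Concatenate along the feature dim: row i = [head0_i, head1_i, ..., head7_i]
concat = torch.cat(head_outputs, dim=-1)  # (4, 512)

# Stand-in for W^O: one more learned 512x512 projection.
output_projection = nn.Linear(d_model, d_model)
final = output_projection(concat)  # (4, 512)

print(concat.shape, final.shape)
```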

Transformers with multiple heads have a higher capacity to represent complex relationships in the data, because each head can learn different patterns. Multiple heads also let the model attend to different subspaces of the input representation (the 64-dim chunks of the original 512-dim vectors) simultaneously.

Implementation of multi-head attention

num_heads = 8
# batch dim is 1 since we are processing one text.
batch_size = 1

text = "I should sleep now"
words = text.split(" ")
len(words) # 4


embeddings_dim = 512
embeddings = get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape # torch.Size([512])


# initialize the query, key and value matrices
query_matrix = nn.Linear(embeddings_dim, embeddings_dim)
key_matrix = nn.Linear(embeddings_dim, embeddings_dim)
value_matrix = nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape # torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])


# compute the query, key and value vectors for each word embedding
query_vectors = torch.stack([query_matrix(embedding) for embedding in embeddings])
key_vectors = torch.stack([key_matrix(embedding) for embedding in embeddings])
value_vectors = torch.stack([value_matrix(embedding) for embedding in embeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape # torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])


# (batch_size, num_heads, seq_len, head_dim), where head_dim = embeddings_dim // num_heads
query_vectors_view = query_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
key_vectors_view = key_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
value_vectors_view = value_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
query_vectors_view.shape, key_vectors_view.shape, value_vectors_view.shape
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64])


# We split each vector across the 8 heads. Since we have one text
# (batch size of 1), each 512-dim vector is split into 8 parts
# of 64 dims, and each head takes one part. First, let's do
# this one head at a time.
head1_query_vector = query_vectors_view[0, 0, ...]
head1_key_vector = key_vectors_view[0, 0, ...]
head1_value_vector = value_vectors_view[0, 0, ...]
head1_query_vector.shape, head1_key_vector.shape, head1_value_vector.shape # torch.Size([4, 64]), torch.Size([4, 64]), torch.Size([4, 64])


# The vectors above have the same number of rows as before; only the feature dim changed from 512 to 64
# compute the score
scores_head1 = torch.matmul(head1_query_vector, head1_key_vector.permute(1, 0)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
scores_head1.shape # torch.Size([4, 4])


# compute the attention weights for each of the words with the other words
softmax = nn.Softmax(dim=-1)
attention_weights_head1 = softmax(scores_head1)
attention_weights_head1.shape # torch.Size([4, 4])

output_head1 = torch.matmul(attention_weights_head1, head1_value_vector)
output_head1.shape # torch.Size([4, 64])


# we can compute the output for all the heads
outputs = []
for head_idx in range(num_heads):
    head_idx_query_vector = query_vectors_view[0, head_idx, ...]
    head_idx_key_vector = key_vectors_view[0, head_idx, ...]
    head_idx_value_vector = value_vectors_view[0, head_idx, ...]
    scores_head_idx = torch.matmul(head_idx_query_vector, head_idx_key_vector.permute(1, 0)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))

    softmax = nn.Softmax(dim=-1)
    attention_weights_idx = softmax(scores_head_idx)
    output = torch.matmul(attention_weights_idx, head_idx_value_vector)
    outputs.append(output)

[out.shape for out in outputs]
# [torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64])]

# concatenate the results from all the heads for the first word
word0_outputs = torch.cat([out[0] for out in outputs])
word0_outputs.shape # torch.Size([512])

# let's do it for all the words
attn_outputs = []
for i in range(len(words)):
    attn_output = torch.cat([out[i] for out in outputs])
    attn_outputs.append(attn_output)
[attn_output.shape for attn_output in attn_outputs] # [torch.Size([512]), torch.Size([512]), torch.Size([512]), torch.Size([512])]


# Now let's do it in a vectorized way.
# We can swap the last two dimensions of the key tensor to transpose the keys of every head at once.
key_vectors_view.permute(0, 1, 3, 2).shape # torch.Size([1, 8, 64, 4])


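# Q * K^T for all heads at once, scaled by sqrt(head_dim) exactly as in the per-head loop
score = torch.matmul(query_vectors_view, key_vectors_view.permute(0, 1, 3, 2)) / torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
score = torch.softmax(score, dim=-1)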


# multiply the attention weights with the values
attention_results = torch.matmul(score, value_vectors_view)
attention_results.shape # torch.Size([1, 8, 4, 64])

# merge the heads back: (1, 8, 4, 64) -> (1, 4, 8, 64) -> (1, 4, 512)
attention_results = attention_results.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embeddings_dim)
attention_results.shape # torch.Size([1, 4, 512])
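To wrap up, the vectorized steps above can be bundled into one module. The class below is a hypothetical sketch, not code from the original paper or the notebook. It takes a (batch, seq_len, d_model) tensor, uses nn.Linear for the Q, K and V projections, scales by sqrt(head_dim), and adds the output projection W^O from the paper, which the step-by-step code above omits.

```python
import math
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    """Hypothetical helper bundling the vectorized multi-head steps into one module."""

    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)  # stand-in for W^O

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))
        # scaled dot-product scores for every head: (batch, num_heads, seq_len, seq_len)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        weights = torch.softmax(scores, dim=-1)
        heads = weights @ v  # (batch, num_heads, seq_len, head_dim)
        # merge the heads back into (batch, seq_len, d_model)
        merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.out(merged)

torch.manual_seed(0)
mha = MultiHeadSelfAttention()
x = torch.randn(1, 4, 512)  # batch of one 4-word sentence (random stand-in embeddings)
y = mha(x)
print(y.shape)  # torch.Size([1, 4, 512])
```

Because each head works on a 64-dim slice, the total cost of 8 heads is similar to one 512-dim head. If you copy the same weights into torch.nn.MultiheadAttention, the outputs should agree up to floating-point error, though it is worth checking that on your PyTorch version.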

All the code in this blog is collected in this notebook. Feel free to edit it and try things out.

Hope you have enjoyed this blog. 🤗 If you are interested in reading about vision transformers, check out this blog:

