Explained: Attention Mechanism in AI

In simple English for everyone

XQ
The Research Nest
Mar 9, 2024

Image created by me using DALL·E 3

Following a previous article on Transformers, we continue our exploration by focusing on attention — what it means and how it’s done.

Here’s a quick TL;DR on the basic intuition behind attention mechanisms.

The goal is to enable the AI model to selectively focus on different parts of the input data when producing output.

The attention mechanism generates scores (often using a function of the inputs), determining how much focus to place on each data part. These scores are used to create a weighted sum of the inputs, which feeds into the next network layer. This allows the model to capture context and relationships within the data that might be missed with traditional, fixed approaches to processing sequences.

Let’s go on a journey exploring the different concepts that led us to where we are today.

The Origin

Researchers wanted to translate text from one language to another using neural networks. Some of the early approaches consisted of the Encoder-Decoder architecture. The encoder creates a vector representation of the given sentence while the decoder tries to use it to generate the correct translation.

A major problem in those early networks was the size of the inputs they could handle. No matter how long the input sentence was, the encoder had to compress it into a fixed-length vector, which may not capture all the intricacies of the input. As sentences got longer, performance dropped.

Researchers knew they had to modify the basic encoder-decoder networks to solve this.

What can you do if the sentence is too long?

Let’s ask what a human would do. You would probably scan the long sentence to find what’s most relevant and use that information to translate it. And voilà! That’s exactly the idea that can be applied to the model architecture.

Instead of a single fixed-length vector, we can create a sequence of vectors, one for each word in the sentence. When we have to translate a long piece of text, we can use only the most important vectors.

Note: A vector in this context is basically an array of numbers created by applying some methods to convert data like text into a numeric format.

The next immediate question — How do we know which vectors are important? As a human, when you try to find out the important parts of a long paragraph, you naturally pay attention to the stuff that seems to have valuable information, like a fact, some key point, or a complex word.

Bingo! Make the machines do the same. As they go through the sentences, make them pay more attention to anything that looks important and provide a higher weight to it.

This is the origin of attention mechanisms.

Early Encoder-Decoder Networks

In 2014, we saw the use of recurrent neural networks as the encoder and decoder, as proposed by Cho et al. As discussed in my Transformers article, RNN-based networks had issues with very long sentences.

Around the same time, Ilya Sutskever and his team at Google (this was before OpenAI even came into existence) proposed sequence-to-sequence learning using LSTMs (a variation of RNNs) as encoders and decoders. Both of these models were applied to the task of translating text from English to French. LSTMs seemed to work well even with longer sequences. Even back then, they believed this approach could work much better with optimization and for tasks beyond translation.

The idea of using LSTM was inspired by the work of Alex Graves a year prior, who used RNNs to synthesize handwriting.

Early Signs of Attention

Fixed-length vectors in encoder-decoder models were a bottleneck. To solve this, Bahdanau et al., in their paper on neural machine translation, proposed a mechanism to make the model focus on the parts of a sentence that are contextually more relevant to predict the target word — the foundation of attention as we know it.

What was new in this architecture?

  1. When the decoder works on translating, it looks back at the original sentence to find the parts that best match what it’s trying to say next.
  2. For each word the decoder wants to produce, it builds a context vector. This vector is like a summary of the original sentence tailored to help translate that particular word.
  3. The decoder doesn’t treat the whole sentence equally. Instead, it calculates weights to determine which parts of the original sentence are most important for the current word it’s translating.
  4. The process of deciding these weights involves something called an alignment model, a small network that helps the decoder pay attention to the right parts of the original sentence when translating a word. (A minimal sketch of this idea follows the list.)
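Before moving on to a simplified end-to-end example, here is a minimal, hypothetical sketch of how such an alignment model could score the words of the original sentence and build a context vector. The matrices W_a, U_a, and v_a stand in for learned parameters (in a real model they are trained, not random), and the dimensions are made up for illustration.

import numpy as np

np.random.seed(0)
encoder_states = np.random.rand(5, 8)   # one vector per source word
decoder_state = np.random.rand(8)       # what the decoder is about to say next

# Alignment model parameters (hypothetical; learned during training in practice)
W_a = np.random.rand(8, 8)
U_a = np.random.rand(8, 8)
v_a = np.random.rand(8)

# Additive alignment scores: one score per source word
scores = np.tanh(decoder_state @ W_a + encoder_states @ U_a) @ v_a

# Softmax turns the scores into attention weights that sum to 1
weights = np.exp(scores - scores.max())
weights = weights / weights.sum()

# Context vector: a weighted summary of the source words
context = weights @ encoder_states

print("Attention weights:", np.round(weights, 3))
print("Context vector:", np.round(context, 3))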

Let’s logically understand what happens with code.

Do note that this is just an oversimplified version of what it really looks like.

In a real model, each step uses far more complex methods to produce its outputs.

The goal here is simply to build a basic intuition of how the data is processed.

# Step 1: Tokenize the sentence
sentence = "Awareness is power in a world where knowledge is everywhere"
words = sentence.split()
print(f"Step 1 - Tokenized Words: {words}")

# Step 2: Simulate embeddings (simple numerical representations)
# Each word becomes a list of numbers: a=1, b=2, ..., z=26
word_embeddings = {word: [ord(char) - 96 for char in word.lower()] for word in words}

print("\nStep 2 - Word Embeddings:")
for word, embedding in word_embeddings.items():
    print(f"{word}: {embedding}")

# Step 3: Automatically generate attention weights based on word length
# (a toy heuristic: longer words get proportionally more attention)
total_characters = sum(len(word) for word in words)
attention_weights = {word: len(word) / total_characters for word in words}

print("\nStep 3 - Attention Weights:")
for word, weight in attention_weights.items():
    print(f"{word}: {weight:.3f}")

# Step 4: Compute the weighted sum of embeddings
weighted_embeddings = {word: [attention_weights[word] * val for val in embedding]
                       for word, embedding in word_embeddings.items()}

# Note: zip() stops at the shortest embedding, so the final vector collapses
# to the length of the shortest word ("a") in this toy example
final_vector = [0] * len(max(word_embeddings.values(), key=len))
for embedding in weighted_embeddings.values():
    final_vector = [sum(x) for x in zip(final_vector, embedding)]

print("\nStep 4 - Final Vector Before Transformation:")
print(final_vector)

# Step 5: Apply a simple transformation as a simulated model process
processed_output = sum(final_vector)

print(f"\nStep 5 - Processed Output: {processed_output}")

Output:
Step 1 - Tokenized Words: ['Awareness', 'is', 'power', 'in', 'a', 'world', 'where', 'knowledge', 'is', 'everywhere']

Step 2 - Word Embeddings:
Awareness: [1, 23, 1, 18, 5, 14, 5, 19, 19]
is: [9, 19]
power: [16, 15, 23, 5, 18]
in: [9, 14]
a: [1]
world: [23, 15, 18, 12, 4]
where: [23, 8, 5, 18, 5]
knowledge: [11, 14, 15, 23, 12, 5, 4, 7, 5]
everywhere: [5, 22, 5, 18, 25, 23, 8, 5, 18, 5]

Step 3 - Attention Weights:
Awareness: 0.180
is: 0.040
power: 0.100
in: 0.040
a: 0.020
world: 0.100
where: 0.100
knowledge: 0.180
everywhere: 0.200

Step 4 - Final Vector Before Transformation:
[10.100000000000001]

Step 5 - Processed Output: 10.100000000000001

Don’t worry much about the code right now. We will explore this in detail as we go forward.

The Scaled Dot Product Attention

Attention can be done in many different ways. As the research area gained traction, people looked for better formulations. In a paper titled "Effective Approaches to Attention-based Neural Machine Translation," Luong et al. proposed dot-product attention, where the alignment function is simply a dot product.

A couple of years later, we arrive at the groundbreaking “Attention Is All You Need” paper, which introduces scaled dot-product attention.

Snippet from the attention paper.

Here’s a step-by-step approach to scaled dot product attention.

Step 1: Understand the Inputs

  • Q (Queries): Matrix containing the query vectors. These represent the elements that are seeking information. In the context of processing a sentence, a query is typically associated with the current word you’re focusing on. The model uses the query to seek out relevant information across the sequence.
  • K (Keys): Matrix containing the key vectors. Keys are paired with values and are used to retrieve information. Each key is associated with a value in a way that the model can use the similarity between a query and a key to determine how much attention to pay to the corresponding value.
  • V (Values): Matrix containing the value vectors. Values hold the actual information the model wants to retrieve. Once the model determines which keys (and thereby values) are most relevant to a given query, it aggregates these values, weighted by their relevance, to produce the output.

So, the attention mechanism requires three inputs: Q, K, and V. All three are generally derived from the input embeddings.

Step 2: Obtain Q, K, and V

  1. Starting from Input Embeddings: Assume you start with input embeddings representing your data (e.g., word or sentence embeddings in a text application). These embeddings capture the semantic meaning of your input elements. To learn more about them, refer to my previous article, Explained: Tokens and Embeddings in LLMs.
  2. Learned Linear Transformations: The model learns separate linear transformations (weight matrices) to project the input embeddings into query, key, and value spaces. These transformations are part of the model’s parameters and are optimized during training. In short, you apply these weight matrices to your input embeddings to get Q, K, and V.
  3. Purpose of Different Spaces: By projecting the input into three different spaces, the model can independently manipulate the aspects of the input that are used to calculate attention weights (via Q and K) and the aspects that are used to compute the output of the attention mechanism (via V).

If you want to see how this works, check the code sample below. For simplicity, we use random values.

import numpy as np

# Example input embeddings
X = np.random.rand(10, 16) # 10 elements, each is a 16-dimensional vector

# Initialize weight matrices for queries, keys, and values
W_Q = np.random.rand(16, 16) # Dimensions chosen for example purposes
W_K = np.random.rand(16, 16)
W_V = np.random.rand(16, 16)

# Compute queries, keys, and values
Q = np.dot(X, W_Q)
K = np.dot(X, W_K)
V = np.dot(X, W_V)

X here would normally come from the process of creating input embeddings; in this example, it’s just a random matrix standing in for them.

We initialize three weight matrices, one each for the queries, keys, and values.

We then take the dot product of the input X with each weight matrix.

If X has ten elements (rows), Q, K, and V will also have ten.

Each element is a vector of 16 numbers (or whatever dimension we define initially).
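A quick check of the shapes (reusing the arrays from the snippet above) confirms this:

print(X.shape, Q.shape, K.shape, V.shape)
# Prints: (10, 16) (10, 16) (10, 16) (10, 16)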

Step 3: Calculate Dot Products of Q and K Transpose

This is a very straightforward step once you have Q and K.

dot_product = np.dot(Q, K.T) ## Using numpy's inbuilt functions

Step 4: Get the Scaled Dot Product

Divide the dot products by the square root of the dimension of the keys (d_k) to keep the values from growing too large. Without this scaling, large dot products would push the softmax in the next step into regions where the gradients become extremely small.

d_k = K.shape[-1] ## This will be 16 in this case

## Basically dividing by square root of 16, i.e. 4 in this case
scaled_dot_product = dot_product/(np.sqrt(d_k))

Step 5: Apply Softmax to the Scaled Dot Product

Softmax is basically a math function that converts numbers in the vector into probabilities. In simpler terms, it transforms a given array of numbers (the vector) into an array of new numbers that add to 1.

Here’s an example in code.

import numpy as np

def softmax(z):
    exp_scores = np.exp(z)  # This is e^z, where e is a math constant
    probabilities = exp_scores / np.sum(exp_scores)
    return probabilities

# Example usage
vector_array = [2.0, 1.0, 0.1]
print("Softmax probabilities:", softmax(vector_array))
Softmax probabilities: [0.65900114 0.24243297 0.09856589]

So, the goal of this step is to convert the scores computed in the previous step into probabilities between 0 and 1. The output of this step is the set of attention weights, which essentially determine how important each input element is. Note that when the scores form a matrix (one row of scores per query), the softmax should be applied row-wise so that each row becomes its own probability distribution; the complete function at the end of this article handles this with the axis=-1 argument.

attention_weights = softmax(scaled_dot_product)

The higher the probability assigned to an input, the more important it is considered to be for the current query.

Step 6: Multiply by V

Multiply the attention weights with the value matrix V. This step aggregates the values based on the weights, essentially selecting which values to focus on. It’s basically connecting the probabilities back to the input matrix via V.

output = np.dot(attention_weights, V)

Putting everything together into a single math equation, we get:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) V

As a Python function, it would look something like this:

import numpy as np

def softmax(z):
    exp_scores = np.exp(z - np.max(z, axis=-1, keepdims=True))  # Subtract the max for numerical stability
    return exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

def scaled_dot_product_attention(X, W_Q, W_K, W_V):
    # Compute queries, keys, and values
    Q = np.dot(X, W_Q)
    K = np.dot(X, W_K)
    V = np.dot(X, W_V)

    # Calculate dot products of Q and K^T
    dot_product = np.dot(Q, K.T)

    # Get the scaled dot product
    d_k = K.shape[-1]
    scaled_dot_product = dot_product / np.sqrt(d_k)

    # Apply softmax (row-wise) to get attention weights
    attention_weights = softmax(scaled_dot_product)

    # Multiply by V to get the output
    output = np.dot(attention_weights, V)

    return output, attention_weights

# Example usage
# Define input embeddings and weight matrices
X = np.random.rand(10, 16) # 10 elements, each is a 16-dimensional vector
W_Q = np.random.rand(16, 16)
W_K = np.random.rand(16, 16)
W_V = np.random.rand(16, 16)

output, attention_weights = scaled_dot_product_attention(X, W_Q, W_K, W_V)
print("Output (Aggregated Embeddings):")
print(output)
print("\nAttention Weights (Relevance Scores):")
print(attention_weights)
Output (Aggregated Embeddings):
[[5.09599132 4.4742368 4.48008769 4.10447843 5.73438516 5.20663291
3.53378133 5.82415923 3.72478851 4.77225668 5.27221298 3.62251028
4.68724943 3.93792586 4.3472472 5.12591473]
[5.09621662 4.47427655 4.48007838 4.10450512 5.73436916 5.20667063
3.53370481 5.82419706 3.72482501 4.77241048 5.27219455 3.6225587
4.68727098 3.93783536 4.34720002 5.12597141]
[5.09563269 4.47417458 4.48010325 4.10443597 5.734411 5.20657353
3.53390455 5.82409992 3.72473137 4.77201189 5.2722428 3.6224335
4.68721553 3.9380711 4.34732342 5.12582479]
[5.09632807 4.47429524 4.48007309 4.10451827 5.7343609 5.2066886
3.5336657 5.82421494 3.72484219 4.77248653 5.27218496 3.62258243
4.6872813 3.93778954 4.34717566 5.12599916]
[5.09628431 4.47428739 4.4800748 4.10451308 5.73436398 5.20668119
3.5336804 5.8242075 3.72483498 4.77245663 5.2721885 3.62257301
4.68727707 3.93780698 4.3471847 5.12598815]
[5.0962137 4.47427585 4.48007837 4.10450476 5.7343693 5.20667001
3.53370556 5.82419641 3.72482437 4.77240847 5.2721947 3.62255803
4.68727063 3.93783633 4.34720044 5.12597062]
[5.09590612 4.47422339 4.48009238 4.10446839 5.73439171 5.20661976
3.53381233 5.82414622 3.72477619 4.77219864 5.27222067 3.62249228
4.68724184 3.93796181 4.34726669 5.12589365]
[5.0963771 4.47430327 4.48007062 4.10452404 5.73435721 5.20669637
3.53364826 5.82422266 3.72484957 4.77251996 5.27218067 3.62259284
4.68728578 3.93776918 4.34716475 5.12601133]
[5.09427791 4.47393653 4.48016038 4.10427489 5.7345063 5.2063466
3.53436561 5.8238718 3.72451289 4.77108808 5.27235308 3.6221418
4.68708645 3.93861577 4.34760721 5.12548251]
[5.09598424 4.47423751 4.48008936 4.1044777 5.73438636 5.20663313
3.53378627 5.82415973 3.72478915 4.77225191 5.27221451 3.62250919
4.68724941 3.93793083 4.34725075 5.12591352]]

Attention Weights (Relevance Scores):
[[1.39610340e-09 2.65875422e-07 1.06720407e-09 2.46473541e-04
3.30566624e-06 7.59082039e-07 9.47371303e-09 9.99749155e-01
6.23506419e-13 2.87692454e-08]
[9.62249688e-11 3.85672864e-08 6.13600771e-11 1.16144738e-04
6.40885383e-07 1.06810081e-07 7.39585064e-10 9.99883065e-01
1.63713108e-14 2.75331274e-09]
[3.69381681e-09 6.13144136e-07 2.80803461e-09 4.55134876e-04
6.64814118e-06 1.49670062e-06 2.41414264e-08 9.99536010e-01
2.48323405e-12 6.63936853e-08]
[9.79218477e-12 6.92919207e-09 7.68894361e-12 5.05418004e-05
1.23007726e-07 2.34823978e-08 1.24176332e-10 9.99949304e-01
9.59727501e-16 4.83111885e-10]
[1.24670936e-10 4.27972941e-08 1.21790065e-10 7.57169471e-05
7.82443047e-07 1.57462636e-07 1.16640444e-09 9.99923294e-01
2.35191256e-14 5.02281725e-09]
[1.35436961e-10 5.32213794e-08 1.10051728e-10 1.17621865e-04
8.32222943e-07 1.59229009e-07 1.31918356e-09 9.99881326e-01
3.69075253e-14 5.44039607e-09]
[1.24666668e-09 2.83110486e-07 8.25483229e-10 2.97601672e-04
2.85247687e-06 7.16442470e-07 8.11115147e-09 9.99698510e-01
4.61471570e-13 2.58350362e-08]
[2.94232175e-12 2.82720887e-09 2.06606788e-12 2.14674234e-05
7.58050062e-08 1.04137540e-08 3.39606998e-11 9.99978443e-01
1.00563466e-16 1.36133761e-10]
[3.29813507e-08 2.92401719e-06 2.06839303e-08 1.23927899e-03
2.00026214e-05 6.59439395e-06 1.48264494e-07 9.98730623e-01
4.90583470e-11 3.74539859e-07]
[3.26708157e-10 9.74857808e-08 2.53245979e-10 2.52864875e-04
1.76701970e-06 3.06926908e-07 2.62423409e-09 9.99744949e-01
1.07566811e-13 1.10075243e-08]]

Visually, the output is just another giant matrix of numbers, but these numbers now have more nuance and context in connection with the input embeddings used to create them.

Phew! That’s a long process, but in the grand scheme of things, it’s another small block in our journey.

If you read the transformer article, you’d know that this exact attention process is done in a “multi-head” manner within each layer of the transformer. The weight matrices differ for every head and every layer and are optimized via training.

When the queries, keys, and values all come from the same place, such as the output of the previous layer, it’s known as self-attention.
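To make the multi-head idea concrete, here is a minimal, hypothetical sketch built on the scaled_dot_product_attention function above. It assumes 4 heads with made-up dimensions, omits the final output projection a real transformer layer would apply, and, because Q, K, and V are all derived from the same X, it is also an example of self-attention.

import numpy as np

num_heads = 4
d_model = 16
d_head = d_model // num_heads  # 4 dimensions per head

X = np.random.rand(10, d_model)  # 10 elements, 16-dimensional embeddings

head_outputs = []
for _ in range(num_heads):
    # Each head gets its own (randomly initialized) projection matrices
    W_Q = np.random.rand(d_model, d_head)
    W_K = np.random.rand(d_model, d_head)
    W_V = np.random.rand(d_model, d_head)
    head_out, _ = scaled_dot_product_attention(X, W_Q, W_K, W_V)
    head_outputs.append(head_out)  # each head produces a (10, 4) output

# Concatenate the heads back into the full model dimension
multi_head_output = np.concatenate(head_outputs, axis=-1)
print(multi_head_output.shape)  # (10, 16)

In practice, the heads are computed as one batched tensor operation rather than a Python loop, but the idea is the same: each head learns its own projections and attends to the sequence independently.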

Beyond Scaled Dot Product Attention

Many modern-day LLMs use a version of attention derived from the above process, changing a few steps in between or optimizing for memory and speed.

For example, the Longformer paper uses sliding window attention, which was later adopted in the Mistral models. It restricts attention to a local neighborhood around each token, reducing the amount of computation and making long sequences more manageable.
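As a rough illustration of the idea (not the actual Longformer implementation), a sliding window can be expressed as a mask that blocks attention between positions more than a few steps apart, applied to the scores before the softmax. The window size and random matrices below are arbitrary.

import numpy as np

def sliding_window_attention(Q, K, V, window=2):
    # Standard scaled dot-product scores
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)

    # Mask out pairs of positions that are more than `window` steps apart
    n = scores.shape[0]
    positions = np.arange(n)
    too_far = np.abs(positions[:, None] - positions[None, :]) > window
    scores = np.where(too_far, -1e9, scores)  # large negative -> ~0 after softmax

    # Row-wise softmax, then weight the values as usual
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V

Q = np.random.rand(10, 16)
K = np.random.rand(10, 16)
V = np.random.rand(10, 16)
print(sliding_window_attention(Q, K, V).shape)  # (10, 16)

In a real implementation, the savings come from never computing the masked scores in the first place; the mask here only shows which positions each token is allowed to attend to.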

Another recent approach is FlashAttention, which computes exactly the same attention output but reorders the computation to make better use of GPU memory, so the full attention matrix never has to be stored at once.
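FlashAttention itself is tightly coupled to how GPUs move data between fast and slow memory, but the mathematical trick behind it, an “online” softmax that processes keys and values in small blocks without ever building the full score vector, can be sketched in plain NumPy. This is only a conceptual illustration for a single query vector, with an arbitrary block size.

import numpy as np

def streaming_attention_single_query(q, K, V, block_size=4):
    # Process K and V in blocks, keeping running softmax statistics
    d_k = K.shape[-1]
    running_max = -np.inf        # max score seen so far
    running_sum = 0.0            # sum of exp(score - running_max)
    acc = np.zeros(V.shape[-1])  # running weighted sum of values

    for start in range(0, K.shape[0], block_size):
        k_blk = K[start:start + block_size]
        v_blk = V[start:start + block_size]

        scores = k_blk @ q / np.sqrt(d_k)  # scores for this block only
        new_max = max(running_max, scores.max())

        # Rescale previous statistics to the new max (exp(-inf) = 0 on the first block)
        scale = np.exp(running_max - new_max)
        probs = np.exp(scores - new_max)
        running_sum = running_sum * scale + probs.sum()
        acc = acc * scale + probs @ v_blk
        running_max = new_max

    return acc / running_sum  # equals softmax(q·Kᵀ/√d_k) @ V

q = np.random.rand(16)
K = np.random.rand(10, 16)
V = np.random.rand(10, 16)
print(streaming_attention_single_query(q, K, V).shape)  # (16,)

The result is identical to running the softmax over all scores at once; the difference is that only one small block of scores exists in memory at any time.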

In case of any errata or follow-up questions, feel free to discuss them in the article responses section. I will update the content accordingly.

Loved the content and want me to write such in-depth articles for your startup website, blog, or documentation? Feel free to hit me up with a proposal at adityavivek.xq@gmail.com.
