PART 3 — Decoding Attention: The Magic Behind Generative & Contextual AI Models (Transformer, Llama, GPT, BERT, etc.)
Before diving into this article, you may want to read the previous two articles in this series for a better grasp of the world of GAI:
- PART 1 — The First GAI-LLM: Original “Transformer” Embodied the Foundational Concepts of Modern GAI(LLM/LVM/LAM/LTMs), RAG, Prompt Tuning, and More. Link: https://medium.com/@trishulchowdhury.23/part-1-the-first-gai-llm-original-transformer-embodied-the-foundational-concepts-of-modern-d3f8dea1f027
- PART 2 — The “Attention-Based” Model Chronicles: A Comedic Breakdown. Link: https://medium.com/@trishulchowdhury.23/part-2-the-attention-based-model-chronicles-a-comedic-breakdown-389adb869d25
Content:
1. 🌟 Intuitive understanding of the Attention Mechanism in the context of Generative and Contextual AI (GAI & CAI)
2. 🔍 What does the attention mechanism allow a model to do?
3. 📏 Mathematical and computational complexity
4. 🔢 Mathematical and Pythonic intuition of the different types of attention mechanisms used in GAI and CAI models:
   - 🤖 Self-Attention (Original Transformer & others)
   - 🌐 Multi-Head Attention (Original Transformer & others)
   - 🔄 Cross Attention (Original Transformer & others)
   - 🕵️‍♂️ Masked Self-Attention (Original Transformer & others)
   - 👥 Grouped Query Attention (GQA) (LLaMA 2 and 3.1)
   - 📊 Scaled Dot-Product Attention
   - 📏 Relative Positional Attention
   - 📐 Distance-Aware Attention
5. 🧠 Why Queries, Keys, and Values? (mimicking human cognitive processes)
6. 🔦 Attention with a real-life intuitive analogy: 🌟🔦 Flashlight in a Dark Room
Ready, Set, Go! Let’s Dive In!
Alright, let’s take a breather from the models, tasks, and objectives and dive into the heart of it all — the foundational algorithm behind the Transformer model. Yes, I’m talking about Attention!
Attention mechanisms have revolutionized AI, enabling models to focus on relevant parts of input data, thereby improving performance in tasks like language translation, image recognition, and more. In this article, we’ll explore various attention mechanisms, explain their intuition, and provide simple Python snippets to illustrate their workings.
Why do we need Attention?
In the fields of Natural Language Understanding (NLU) and Natural Language Processing (NLP), natural language must be converted into numbers. Each word (token) in a sentence is first mapped to an integer ID, and these IDs are stored in a tensor. The tensor is then passed through an embedding layer, which maps each ID to a dense vector in an embedding space. This embedding space helps the model capture the similarities and relationships between words.
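To make the two stages concrete, here is a minimal PyTorch sketch. The toy vocabulary `vocab` and the embedding size of 8 are assumptions for illustration only; real models use a trained tokenizer and much larger embeddings.

```python
import torch
import torch.nn as nn

# Hypothetical toy vocabulary; real models use a trained tokenizer (BPE, WordPiece, etc.)
vocab = {"the": 0, "animal": 1, "cannot": 2, "cross": 3, "road": 4,
         "because": 5, "it": 6, "is": 7, "too": 8, "tired": 9}
sentence = "the animal cannot cross the road because it is too tired"

# Stage 1: map each token to an integer ID, stored as a tensor
token_ids = torch.tensor([vocab[w] for w in sentence.split()])

# Stage 2: the embedding layer looks up one dense, learnable vector per ID
torch.manual_seed(0)
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=8)
vectors = embedding(token_ids)

print(token_ids.tolist())  # [0, 1, 2, 3, 0, 4, 5, 6, 7, 8, 9]
print(vectors.shape)       # torch.Size([11, 8])
```

Both occurrences of "the" get exactly the same vector: the lookup on its own carries no context, which is the gap attention fills.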
But when context is key (especially in complex sentences), are traditional methods (TF-IDF, BoW, etc.) or static word embeddings such as Word2vec (for example, Amazon SageMaker's BlazingText implementation) enough to make the model understand the context for further processing?
Consider the following sentences:
- The animal cannot cross the road because it is too tired.
- The animal cannot cross the road because it is too dark.
- The animal cannot cross the road because it is too wide.
- TF-IDF-style models assign a fixed importance to each word based on its frequency across documents, so they fail to capture contextual nuances. In the first sentence “it” refers to the animal (it is “tired”), in the second it refers to the surroundings (it is “dark”), and in the third it refers to the road (it is “wide”). Yet “it” gets the same weight in all three sentences, and a static embedding like Word2vec gives it the same vector, which is not ideal.
- Self-attention, by contrast, dynamically adjusts the importance of each token based on its role in the specific sentence.
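A small sketch of this difference, under simplifying assumptions: random vectors stand in for a trained static embedding table, and `self_attention` here is a hypothetical, untrained single-head version with no learned projections. It only shows that the attention output for “it” depends on the rest of the sentence, not that it resolves the reference correctly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
words = ["the", "animal", "cannot", "cross", "road", "because",
         "it", "is", "too", "tired", "dark", "wide"]
# Static (Word2vec-style) lookup table: one fixed vector per word, regardless of context
table = {w: torch.randn(d) for w in words}

def embed(sentence):
    return torch.stack([table[w] for w in sentence.split()])

def self_attention(x):
    # Untrained single-head self-attention with Q = K = V = x (no learned projections)
    scores = x @ x.T / d ** 0.5
    return F.softmax(scores, dim=-1) @ x

s1 = "the animal cannot cross the road because it is too tired"
s2 = "the animal cannot cross the road because it is too dark"
it_index = s1.split().index("it")  # position 7 in both sentences

static_1, static_2 = embed(s1)[it_index], embed(s2)[it_index]
context_1 = self_attention(embed(s1))[it_index]
context_2 = self_attention(embed(s2))[it_index]

print(torch.equal(static_1, static_2))       # static vector ignores context
print(torch.allclose(context_1, context_2))  # contextual vector depends on "tired" vs "dark"
```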
What does the attention mechanism allow a model to do?
- Focus on Relevant Parts: Weight different parts of the input sequence differently for each token.
- Contextual Understanding: For example, predicting “tired” in the first sentence requires understanding that “it” refers to “animal”, not “road”.
- Dynamic Adjustment: The model can dynamically adjust which words it focuses on, providing a more accurate and context-aware understanding of the sentence.
Key Message
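The original Transformer paper, “Attention Is All You Need”, introduced the following formula for attention. Note that the softmax term produces the attention weights of each token (word), which are not the same as the raw scores $QK^\top$:

$$\text{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$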
Let me reiterate: “attention scoring” is a fundamental algorithm. While its core principles remain the same across architectures such as the Transformer, BERT, GPT, and LLaMA families, the specific implementation may differ. Examples include self-attention, masked self-attention, grouped query attention, encoder-decoder (cross) attention, multi-head attention, scaled dot-product attention, relative positional attention, and distance-aware attention.
We will discuss everything in detail!
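Before diving into the core calculations, it’s essential to understand why the attention mechanism is considered complex both mathematically and computationally. The equation itself gives it away: the product $QK^\top$ compares every token with every other token, so its size grows with the square of the sequence length.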
Mathematical Complexity
- Self-Attention, Multi-Head Attention, and Positional Encoding: These concepts require a deep understanding of linear algebra, probability, and optimization to implement effectively.
Computational Complexity
- Implementing and Training Transformer Models: The quadratic complexity of self-attention makes them demanding, requiring substantial computational resources to handle large-scale data and extensive matrix operations.
- Attention Mechanism: This mechanism has a time complexity of $O(n^2 \cdot d)$, which is quadratic in the sequence length, where $n$ is the sequence length and $d$ is the embedding dimension.
- Multi-Head Attention: With $h$ heads, each head works on a $d/h$-dimensional slice, so the total matrix-multiplication cost stays roughly $O(n^2 \cdot d)$. However, every head produces its own $n \times n$ attention matrix, so the memory needed for attention weights grows with the number of heads.
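A quick back-of-the-envelope sketch of the quadratic growth. The helper `attention_matrix_elements` is hypothetical and counts only the n × n score matrices (here 8 heads, float32), ignoring everything else a real model stores.

```python
import torch

def attention_matrix_elements(seq_len, num_heads=1, batch_size=1):
    # One score per (query, key) pair, per head: batch * heads * n * n
    return batch_size * num_heads * seq_len * seq_len

for n in [512, 1024, 2048, 4096]:
    elems = attention_matrix_elements(n, num_heads=8)
    mib = elems * 4 / 2**20  # float32 = 4 bytes per score
    print(f"n={n:5d}: {elems:,} scores, ~{mib:.0f} MiB in float32")

# Sanity check with a real tensor: the score matrix Q @ K^T is (n, n)
q = k = torch.randn(1024, 64)
print((q @ k.T).shape)  # torch.Size([1024, 1024])
```

Doubling the sequence length quadruples the number of scores, which is why long contexts get expensive so quickly.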
Resource Requirements
- State-of-the-Art Performance: Achieving state-of-the-art performance with large datasets necessitates significant computational resources like GPUs or TPUs and high memory capacity to store intermediate representations and gradients.
Types of Attention Mechanisms
As I said earlier, the fundamental idea of “attention scoring” appears in different models with different implementations.
Let’s first understand the working of this Attention algorithm in the pioneer “Transformer model” and then various other forms of attention like GQA (used in LLaMA 2 and 3.1).
1. Original Transformer Self-Attention
📚 Step 1 - Input Embeddings: Each word in the input sentence is converted into a fixed-size vector (embedding). Let’s consider an example sentence: “The cat sat on the mat.”
🛠️ Step 2 - Creating Query, Key, and Value Matrices: For each word, we create three vectors: Query (Q), Key (K), and Value (V). These vectors are obtained by multiplying the input embeddings with learned weight matrices:

$$Q = E W_Q, \qquad K = E W_K, \qquad V = E W_V$$

where $E$ is the input embedding matrix and $W_Q$, $W_K$, $W_V$ are the learned weight matrices.
Intuition: The Query, Key, and Value transformations allow the model to project the input data into different spaces that capture various relationships and dependencies between words.
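Step 2 as a minimal sketch: random numbers stand in for the embeddings of the six words of “The cat sat on the mat” (punctuation ignored), and the sizes `d_model = 16` and `d_k = 8` are illustrative assumptions. `nn.Linear(..., bias=False)` plays the role of each learned weight matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model, d_k = 6, 16, 8  # illustrative sizes

E = torch.randn(seq_len, d_model)  # stand-in for the input embedding matrix

# Learned weight matrices W_Q, W_K, W_V (bias-free linear maps)
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)

Q, K, V = W_Q(E), W_K(E), W_V(E)
print(Q.shape, K.shape, V.shape)  # each torch.Size([6, 8])

# nn.Linear stores its weight as (out_features, in_features), so W_Q(E) == E @ W_Q.weight.T
assert torch.allclose(Q, E @ W_Q.weight.T)
```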
🔍 Step 3 - Calculating Attention Scores : The attention scores are calculated using the dot product of the Query vector of the current word with the Key vectors of all words in the sentence. This measures the similarity between the current word and other words.
Intuition: The dot product is used because it efficiently captures the degree of alignment between the Query and Key vectors. A higher dot product indicates stronger alignment, meaning the words are more related in the given context. The Query-Key score therefore measures relevance, while the Value vector carries the actual information that is transferred based on this relevance.
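Step 3 in code, with random stand-ins for the Query and Key vectors (the dimension `d_k = 8` is an assumption). Each entry of the score matrix is one Query-Key dot product.

```python
import torch

torch.manual_seed(0)
tokens = ["The", "cat", "sat", "on", "the", "mat"]
d_k = 8
Q = torch.randn(len(tokens), d_k)  # stand-in query vectors (one row per token)
K = torch.randn(len(tokens), d_k)  # stand-in key vectors

# scores[i, j] = dot(Q[i], K[j]): how strongly token i's query aligns with token j's key
scores = Q @ K.T
print(scores.shape)  # torch.Size([6, 6])

# The row for "cat" holds its raw (unscaled, unnormalized) scores against every token
cat = tokens.index("cat")
assert torch.allclose(scores[cat, 2], torch.dot(Q[cat], K[2]))  # "cat" vs "sat"
```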
📏✨ Step 4 - Scaling the Attention Scores: The dot products can become large, especially when the dimensionality of the vectors is high. To mitigate this, we divide the scores by $\sqrt{d_k}$, the square root of the dimensionality of the Key vectors.
Intuition: Scaling ensures that the variance of the dot product remains stable regardless of the dimensionality, leading to more balanced gradients and better convergence during training.
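The variance argument can be checked numerically. Assuming the Query and Key components are independent with mean 0 and variance 1, their dot product has variance d_k, and dividing by √d_k brings it back to about 1. This sketch samples random vectors to show it:

```python
import torch

torch.manual_seed(0)
num_samples = 10_000
for d_k in [16, 64, 256]:
    q = torch.randn(num_samples, d_k)  # entries with mean 0, variance 1
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)         # one dot product per sample
    scaled = dots / d_k ** 0.5
    print(f"d_k={d_k:3d}  var(raw)={dots.var():7.2f}  var(scaled)={scaled.var():.2f}")
```

The raw variance tracks d_k, while the scaled variance stays near 1, so the softmax inputs keep a similar spread whatever the head size.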
🎛️🔥 Step 5 - Applying the Softmax Function: The scaled attention scores are passed through a softmax function to obtain the attention weights. The softmax converts each token's row of scores into probabilities that sum to 1.
Intuition: The softmax function normalizes the attention scores into probabilities, making it easier to interpret the relative importance of each word in the context of the current word.
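A tiny sketch of Step 5 with hand-picked scores (two hypothetical query tokens, three keys each): softmax is applied row by row, so each query's weights sum to 1, the largest score gets the largest weight, and equal scores get equal weights.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 0.5, 0.5]])
weights = F.softmax(scores, dim=-1)  # normalize each row (one row per query token)
print(weights)              # first row is roughly [0.659, 0.242, 0.099]
print(weights.sum(dim=-1))  # each row sums to 1
```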
📊💡 Step 6 - Computing the Weighted Sum of Values: The final output for each word is obtained by computing the weighted sum of the Value vectors, where the weights are the attention weights calculated in the previous step.
Intuition: The weighted sum aggregates the information from all words in the sentence, with each word contributing according to its relevance. This allows the model to dynamically focus on the most important parts of the input.
"""
in the sentence "The cat sat on the mat,"
self-attention helps the word "cat" to relate to "sat" and "mat."
"""
import torch
import torch.nn.functional as F
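def self_attention(query, key, value):
    # Raw scores: dot product of every query with every key, scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(key.size(-1), dtype=torch.float32))
    # Softmax over the key dimension turns each row of scores into weights that sum to 1
    weights = F.softmax(scores, dim=-1)
    # Weighted sum of the value vectors
    output = torch.matmul(weights, value)
    return output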
query = key = value = torch.randn(1, 5, 64) # (batch_size, sequence_length, d_model)
output = self_attention(query, key, value)
print(output)
"""
Your input tensor query = key = value = torch.randn(1, 5, 64)
has a shape of (batch_size, sequence_length, d_model).
O/p : tensor([[[ 0.3885, 1.4492, -1.2677, -0.6434, -1.0158, 1.2113, -0.6044,
-1.0162, 2.4235, -1.0818, 0.6371, -1.2526, -0.1937, -0.3456,
...],
[-1.1771, 0.2208, 1.2248, -0.4764, 1.0884, -1.4994, -1.1418,
-0.7764, 1.0507, -1.1256, -0.4755, -1.7120, 0.0080, 0.9431,
...],
[ 0.3001, -1.9093, 0.3644, 2.1997, 1.4283, -0.1375, 1.0424,
-1.6352, -2.0777, 1.3198, -1.4479, 0.8534, 1.0697, -0.4432,
...],
[-0.1678, -0.6441, 0.1086, 1.1505, 1.3514, -1.1478, 0.6115,
-1.0266, -1.9608, 0.0865, -0.9920, 0.7417, -0.0379, -1.1117,
...],
[-0.1431, 0.3357, 0.1479, 3.2758, -0.9460, 0.4410, 1.2958,
-2.1141, -1.0537, -0.6751, 0.2261, 1.1235, 0.7113, 0.8712,
...]]])
Each of these 5 output vectors corresponds to the attention-weighted
sum of the value vectors for each token in the sequence, considering
its relationship (attention scores) with all other tokens """
This Self-Attention, sometimes called the Full-Attention mechanism, captures various relationships and patterns within the input data (without any mask), providing a richer and more nuanced representation.
2. Original Transformer Multi-Head Attention
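Multi-head attention is a key component in Transformer models, enabling the model to focus on different parts of the input sequence simultaneously. Each head applies attention to its own learned projections of Q, K, and V, and the head outputs are concatenated and projected back. It can be mathematically represented as follows:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W^O, \qquad \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$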
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, d_k):
    # softmax(Q K^T / sqrt(d_k)) V, computed independently for every head
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # Four projections: W_Q, W_K, W_V and the output projection W_O
        self.linear_layers = torch.nn.ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(4)])
    def forward(self, Q, K, V):
        batch_size = Q.size(0)
        # Project, then split d_model into num_heads slices: (batch, heads, seq, d_k)
        Q, K, V = [l(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) for l, x in zip(self.linear_layers, (Q, K, V))]
        # All heads are computed in parallel via batched matrix multiplication
        heads = scaled_dot_product_attention(Q, K, V, self.d_k)
        # Move the head dimension back next to d_k, then concatenate: (batch, seq, d_model)
        concat = heads.transpose(1, 2).contiguous().view(batch_size, -1, self.d_k * self.num_heads)
        output = self.linear_layers[-1](concat)  # output projection W_O
        return output
# Example usage
Q = K = V = torch.rand(1, 10, 512) # (batch_size, seq_length, d_model)
mha = MultiHeadAttention(d_model=512, num_heads=8)
output = mha(Q, K, V)
print(output.shape)
"""
O/P : torch.Size([1, 10, 512])
"""
print(output)
"""
tensor([[[-0.4728, 0.0103, 0.1501, ..., 0.1999, 0.0523, 0.2012],
[-0.2052, -0.0293, 0.1955, ..., 0.2577, 0.0355, 0.2996],
[-0.1817, 0.1691, -0.0383, ..., -0.0795, -0.2282, -0.2054],
...,
[ 0.0599, 0.2470, -0.0963, ..., 0.1081, 0.0790, -0.0865],
[ 0.2901, 0.2257, -0.0301, ..., -0.1064, 0.0507, -0.1555],
[-0.1240, 0.3088, -0.1013, ..., -0.0478, -0.2885, -0.4456]]],
grad_fn=<ViewBackward0>)
"""
3. Encoder-Decoder (Cross) Attention
Encoder-Decoder (Cross) Attention is an essential mechanism in Transformer models, allowing the decoder to attend to relevant parts of the encoded input sequence. The queries come from the decoder, while the keys and values come from the encoder's output. This process is critical for tasks like machine translation, where the model needs to generate output sequences based on the encoded input.
def cross_attention(query, key, value):
    # Same computation as self-attention; only the sources of Q, K and V differ
    return self_attention(query, key, value)
encoder_output = torch.randn(1, 5, 64) # Encoded source sentence
decoder_query = torch.randn(1, 5, 64) # Query from decoder
output = cross_attention(decoder_query, encoder_output, encoder_output)
print(output.shape)
"""
O/P : torch.Size([1, 5, 64])
"""
4. Masked Self-Attention
Masked self-attention is used in tasks like language modeling, where the model should not see future tokens. It masks out future tokens to prevent information leakage.
Intuition: A lower triangular matrix is used as a mask in the Transformer decoder. Each row corresponds to one query position: the token at that position can attend only to itself and the tokens before it. This ensures that no information is taken from future tokens, enforcing the autoregressive property needed for sequence generation.
"""
Here is a step-by-step implementation of
masked self-attention in a decoder using the example sentence
"I love Science". We will use PyTorch to demonstrate this process.
"""
import torch
import torch.nn.functional as F
#Define Masked Self-Attention Function
def masked_self_attention(query, key, value, mask):
    # Calculate the attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(key.size(-1), dtype=torch.float32))
    # Apply the mask (setting masked positions to a very large negative value)
    scores = scores.masked_fill(mask == 0, -1e9)
    # Calculate the attention weights
    weights = F.softmax(scores, dim=-1)
    # Apply the attention weights to the value
    output = torch.matmul(weights, value)
    return output
# Prepare Input Data and Mask
"""
Assume we have a sentence: "I love Science" represented as a tensor. For simplicity,
let's consider each word is represented by a random 64-dimensional vector.
"""
# Example sentence tensor representation (batch_size, sequence_length, d_model)
query = key = value = torch.randn(1, 3, 64) # "I love Science" has 3 tokens
# Mask to prevent attending to future tokens (upper triangular part filled with 0s)
mask = torch.tril(torch.ones(3, 3)).unsqueeze(0) # (batch_size, sequence_length, sequence_length)
#print("Query:", query)
print("Mask:", mask)
"""
O/p :
Mask: tensor([[[1., 0., 0.],
[1., 1., 0.],
[1., 1., 1.]]])
"""
#Apply Masked Self-Attention
output = masked_self_attention(query, key, value, mask)
print("Output shape:", output.shape)
5. Grouped Query Attention
Grouped Query Attention (GQA) is a variation of multi-head attention designed to improve efficiency, especially memory use during inference, while keeping quality close to standard multi-head attention.
The key idea behind GQA is to split the query heads into groups, where all query heads in a group share a single key head and a single value head. This reduces the number of key/value heads that must be computed and stored in the key-value (KV) cache.
Key Concepts to Remember
🧠🔍 Multi-Head Attention: Multi-head attention allows the model to focus on different parts of the input sequence simultaneously by using multiple sets of queries, keys, and values. However, this can be computationally expensive.
🔄🔑 Multi-Query Attention: Multi-query attention shares a single key head and a single value head across all query heads. This shrinks the KV cache and speeds up inference, but can sometimes lead to a loss in the quality of attention.
🌐📊 Grouped Query Attention (GQA): GQA strikes a balance by grouping the query heads and giving each group its own shared key and value head, thus achieving quality close to multi-head attention with speed close to multi-query attention.
Intuition: Multi-head attention has H query, key, and value heads. Multi-query attention shares single key and value heads across all query heads. Grouped-query attention instead shares single key and value heads for each group of query heads, interpolating between multi-head and multi-query attention.
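The steps are as follows. With $H$ query heads split into $G$ groups ($H$ divisible by $G$, heads indexed from 0), query head $i$ uses the shared key/value head of its group, $g(i) = \lfloor i \,/\, (H/G) \rfloor$:

$$\text{head}_i = \mathrm{softmax}\left(\frac{Q_i K_{g(i)}^\top}{\sqrt{d_k}}\right) V_{g(i)}, \qquad \text{GQA}(X) = \text{Concat}(\text{head}_0, \ldots, \text{head}_{H-1})\,W^O$$

Setting $G = H$ recovers multi-head attention, and $G = 1$ recovers multi-query attention.
Intuitions Behind Each Step: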
🔵 Partitioning Queries: The query heads are split into groups, and each group shares one key head and one value head. Only G key/value heads need to be computed and cached instead of H, which reduces memory and bandwidth while every query head keeps its own projection.
🔵 Dot-Product Attention: By computing the dot product between each query head and the keys of its group's shared key head (at all positions), we measure the relevance of each key to the query. The softmax function normalizes these relevance scores, ensuring they sum to 1 and can be interpreted as probabilities.
🔵 Concatenation: Merging the attention outputs from all groups ensures that the diverse attention patterns are combined, providing a comprehensive representation that incorporates multiple aspects of the input data.
import torch
import torch.nn as nn
# Define the Grouped Query Attention class
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_groups):
        super(GroupedQueryAttention, self).__init__()
        assert d_model % num_heads == 0
        assert num_heads % num_groups == 0  # each group holds num_heads // num_groups query heads
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_groups = num_groups  # number of shared key/value heads
        self.d_k = d_model // num_heads  # dimension per head
        self.query = nn.Linear(d_model, num_heads * self.d_k)
        # Keys and values are projected once per group, not once per query head
        self.key = nn.Linear(d_model, num_groups * self.d_k)
        self.value = nn.Linear(d_model, num_groups * self.d_k)
        self.out = nn.Linear(d_model, d_model)
    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        # Linear projections, reshaped to (batch, heads, seq_len, d_k)
        queries = self.query(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        keys = self.key(x).view(batch_size, seq_len, self.num_groups, self.d_k).transpose(1, 2)
        values = self.value(x).view(batch_size, seq_len, self.num_groups, self.d_k).transpose(1, 2)
        # Share each key/value head with every query head in its group
        heads_per_group = self.num_heads // self.num_groups
        keys = keys.repeat_interleave(heads_per_group, dim=1)
        values = values.repeat_interleave(heads_per_group, dim=1)
        # Scaled Dot-Product Attention
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        # Apply attention to values
        context = torch.matmul(attn, values)
        # Concatenate heads and put through final linear layer
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out(context)
        return output
# Example usage
d_model = 512 # dimension of the model
num_heads = 8 # number of query heads
num_groups = 2 # number of key/value groups in GQA (must divide num_heads)
seq_len = 10 # sequence length
batch_size = 2 # batch size
# Dummy input
x = torch.randn(batch_size, seq_len, d_model)
gqa_layer = GroupedQueryAttention(d_model, num_heads, num_groups)
output = gqa_layer(x)
print("Output shape:", output.shape)
"""
Output shape: torch.Size([2, 10, 512])
"""
6. Scaled Dot-Product Attention
This is the core computation used by the variants above: the dot product of query and key is divided by the square root of the key dimension before the softmax. The scaling keeps the softmax from saturating, which stabilizes gradients during training.
def scaled_dot_product_attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(key.size(-1), dtype=torch.float32))
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output
output = scaled_dot_product_attention(query, key, value)
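As a sanity check, recent PyTorch versions (2.0 and later) ship a built-in `torch.nn.functional.scaled_dot_product_attention`. This sketch assumes such a version is installed and compares the built-in with the manual function above on random inputs; the built-in also accepts `attn_mask` and `is_causal=True` for the masked case.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(key.size(-1), dtype=torch.float32))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

torch.manual_seed(0)
query, key, value = torch.randn(1, 3, 64), torch.randn(1, 3, 64), torch.randn(1, 3, 64)

manual = scaled_dot_product_attention(query, key, value)
builtin = F.scaled_dot_product_attention(query, key, value)  # default scale is 1/sqrt(64)
print(torch.allclose(manual, builtin, atol=1e-5))
```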
7. Relative Positional Attention
Incorporates information about the relative positions of tokens (how far apart token i and token j are) into the attention scores, helping the model understand word order, which is crucial for tasks like parsing and translation.
# Simplified illustration: adds a positional vector to each query (real relative schemes depend on the offset between positions)
def relative_positional_attention(query, key, value, pos_embedding):
    scores = torch.matmul(query + pos_embedding, key.transpose(-2, -1))
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output
# Match the current sequence length (query is (1, 3, 64) after the masked self-attention example)
pos_embedding = torch.randn(1, query.size(1), query.size(-1))
output = relative_positional_attention(query, key, value, pos_embedding)
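The snippet above adds a per-position vector to the queries, which is closer to absolute positional information. A sketch closer to relative attention looks up a vector for each offset j - i, so the positional term between two tokens depends only on how far apart they are. This is loosely in the style of the relative position representations used in some Transformer variants; the function name `relative_offset_attention`, the table `rel_embedding` (random here, learned in a real model), and the sizes are assumptions.

```python
import torch
import torch.nn.functional as F

def relative_offset_attention(query, key, value, rel_embedding):
    # rel_embedding: (2 * seq_len - 1, d_k), one vector per relative offset j - i
    seq_len, d_k = query.size(-2), query.size(-1)
    pos = torch.arange(seq_len)
    offsets = pos[None, :] - pos[:, None] + seq_len - 1  # (seq_len, seq_len), values in [0, 2*seq_len-2]
    rel = rel_embedding[offsets]                          # (seq_len, seq_len, d_k)
    content = query @ key.transpose(-2, -1)               # q_i . k_j
    position = torch.einsum("bid,ijd->bij", query, rel)   # q_i . r_(j-i)
    weights = F.softmax((content + position) / d_k ** 0.5, dim=-1)
    return weights @ value

torch.manual_seed(0)
x = torch.randn(1, 5, 64)
rel_embedding = torch.randn(2 * 5 - 1, 64)  # would be learned in a real model
output = relative_offset_attention(x, x, x, rel_embedding)
print(output.shape)  # torch.Size([1, 5, 64])
```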
8. Distance-Aware Attention
Adds a term based on the distance between elements of the sequence to the attention scores, which is useful for tasks where spatial or temporal distance matters.
def distance_aware_attention(query, key, value, distances):
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(key.size(-1), dtype=torch.float32))
    # Add a distance-dependent term to each (query, key) score
    scores += distances
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, value)
    return output
# Random stand-in for a distance term, shaped (batch, seq_len, seq_len) to match the scores
distances = torch.randn(1, query.size(1), query.size(1))
output = distance_aware_attention(query, key, value, distances)
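In practice the distance term is usually computed from token positions rather than drawn at random. One simple choice, loosely modeled on ALiBi (Attention with Linear Biases), penalizes each score in proportion to |i - j|. The helper names `linear_distance_bias` and `attention_with_distance_bias` and the slope of 0.5 are assumptions for illustration; ALiBi itself uses a different fixed slope per head.

```python
import torch
import torch.nn.functional as F

def linear_distance_bias(seq_len, slope=0.5):
    # bias[i, j] = -slope * |i - j|: tokens farther apart get lower scores
    pos = torch.arange(seq_len)
    return -slope * (pos[None, :] - pos[:, None]).abs().float()

def attention_with_distance_bias(query, key, value, bias):
    scores = query @ key.transpose(-2, -1) / query.size(-1) ** 0.5
    scores = scores + bias  # (seq_len, seq_len) bias broadcasts over the batch
    return F.softmax(scores, dim=-1) @ value

bias = linear_distance_bias(4)
print(bias)

torch.manual_seed(0)
x = torch.randn(1, 4, 8)
out = attention_with_distance_bias(x, x, x, bias)
print(out.shape)  # torch.Size([1, 4, 8])
```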
Cool! Now we know the intuition behind various attention mechanisms and have an understanding of why these are important in GAI (Generative AI) and CAI (Contextual AI) across all kinds of LLM, LVM, LAM, LTM, etc.
Now let’s end this article with an intuitive question and a real-life analogy for the attention mechanism.
Why Queries, Keys, and Values?
The use of queries, keys, and values in attention mechanisms loosely mimics human cognitive processes. When humans focus on a task, they decide what they are looking for (queries), compare it against the cues of what they see or already know (keys), and extract the useful details (values). This setup lets models perform a similar kind of selective focus.
The three matrices (Q, K, and V) are fundamental to the attention mechanism because they:
- Allow us to ask specific questions about the input (Query).
- Provide a means to determine the relevance of different parts of the input (Key).
- Retrieve the relevant information based on that relevance (Value).
Adding more matrices would complicate the mechanism without providing additional necessary functionality for the attention process. The three matrices efficiently encapsulate the necessary operations for attention. With these matrices, we effectively compute the attention weights and the context vectors, enabling the model to focus on relevant parts of the input sequence when making predictions.
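A toy sketch of Q, K, and V as a “soft” dictionary lookup. The hand-picked keys and values are assumptions chosen to make the effect obvious: when the query matches one key far better than the others, almost all of the weight lands on that key, and the output is essentially its value.

```python
import torch
import torch.nn.functional as F

# Toy "memory": three keys (what each item is about) and three values (what each item contains)
keys = torch.eye(3) * 10.0  # very distinct, easy-to-match keys
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [5.0, 5.0]])
query = keys[1]             # "I am looking for item 1"

weights = F.softmax(query @ keys.T / keys.size(-1) ** 0.5, dim=-1)
result = weights @ values
print(weights)  # almost all weight on index 1
print(result)   # close to [0, 1], the value stored under key 1
```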
CONCLUSION
Let’s conclude this without any technical jargon; rather, let’s draw a quick real-life intuitive analogy:
🌟🔦 Flashlight in a Dark Room
Let’s break down the attention mechanism with a real-world analogy:
Input Embeddings
- 🏠 Objects in the Room: Imagine you are in a dark room with various objects scattered around. Each object has certain characteristics (size, shape, color), representing the input sequence. These characteristics are encoded into embeddings.
Weight Matrices
Different lenses you can put on your flashlight to highlight different features of the objects:
- 🔍 Wq (Query): This lens helps you focus on certain characteristics you’re interested in (e.g., looking for rectangular shapes if you’re searching for a book).
- 🔑 Wk (Key): This lens adjusts how each object in the room can match the characteristics you’re looking for (e.g., how well each object matches the idea of being a book).
- 📖 Wv (Value): This lens lets you gather more detailed information about the objects that match your query.
Linear Transformation (Matrix Multiplication)
When you shine your flashlight through these lenses, you are transforming the view of the objects:
- ✨ Multiplying the Input Embeddings by Wq, Wk, and Wv: This changes how you perceive the objects in the room based on the specific lens you’re using.
Summary
- 🏠 Input Embeddings: The objects in the room.
- 🔍 Wq (Query): Lens focusing on characteristics of interest.
- 🔑 Wk (Key): Lens matching objects to the characteristics.
- 📖 Wv (Value): Lens gathering detailed information.
By applying these lenses, you can effectively use your flashlight to find a specific object (e.g., a book) in a dark room, highlighting the power and versatility of the attention mechanism in transforming and understanding input data.
“If you find this article helpful, a clap and following my profile will be highly appreciated.” Cheers!