RAG: Part 3: Embeddings

Mehul Jain
10 min read · Apr 5, 2024


Representing text as a vector is one of the most crucial parts of any NLP problem. In RAG, we convert document chunks into these vectors so that we can retrieve the chunks most similar to the user's query.
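To make the retrieval step concrete, here is a minimal sketch of similarity search over chunk vectors. It assumes cosine similarity as the similarity measure. The `embed` function is a hypothetical stand-in (hashed word counts) for a real embedding model. In a real RAG system you would call an actual embedding model and usually store the vectors in a vector database.

```python
import math
import re
import zlib

def embed(text, dim=64):
    """Hypothetical toy embedding: hashed bag-of-words counts (stand-in for a real model)."""
    vec = [0.0] * dim
    for token in re.findall(r"[a-z0-9]+", text.lower()):
        # crc32 is stable across Python runs, unlike the built-in hash()
        vec[zlib.crc32(token.encode()) % dim] += 1.0
    return vec

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def retrieve(query, chunks, k=2):
    """Return the k chunks whose vectors are most similar to the query vector."""
    q = embed(query)
    scored = sorted(((cosine(q, embed(c)), c) for c in chunks), reverse=True)
    return [c for _, c in scored[:k]]

chunks = [
    "The Eiffel Tower is in Paris.",
    "Python is a programming language.",
    "Paris is the capital of France.",
]
print(retrieve("What is the capital of France?", chunks, k=1))
```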

Getting an accurate embedding that captures context improves the downstream tasks. Over the past decade, a lot of research has gone into improving accuracy, moving from simple one-hot encodings to dense vector representations.

Photo by Armand Khoury on Unsplash

In this blog, I will discuss why embeddings are used in RAG, their various types, and the enhancements that have been made to them.

Embedding

Embeddings are dense, low-dimensional representations of words or entities in a continuous vector space. They capture semantic relationships between words and enable machine learning models to understand and process natural language more effectively.

Embeddings map each word to a vector of real numbers, where similar words are represented by similar vectors. Embeddings provide a more meaningful representation of words compared to one-hot encoding or other traditional methods.
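A toy comparison shows why this matters. With one-hot encoding, every pair of distinct words is orthogonal, so their cosine similarity is always 0. Dense vectors can place related words close together. The dense values below are made up by hand for illustration. They are not taken from any trained model.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

vocab = ["king", "queen", "apple"]

# One-hot: each word gets its own axis, so distinct words never look similar.
one_hot = {w: [1.0 if i == j else 0.0 for j in range(len(vocab))] for i, w in enumerate(vocab)}

# Dense: hypothetical hand-picked 3-d vectors standing in for learned embeddings.
dense = {
    "king": [0.9, 0.8, 0.1],
    "queen": [0.85, 0.9, 0.15],
    "apple": [0.1, 0.05, 0.95],
}

for name, table in [("one-hot", one_hot), ("dense", dense)]:
    print(name,
          "king~queen:", round(cosine(table["king"], table["queen"]), 2),
          "king~apple:", round(cosine(table["king"], table["apple"]), 2))
```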

Types of Embeddings

1. Word Embeddings

  • Word2Vec: A popular method that learns word embeddings by predicting neighbouring words in a large corpus.
  • GloVe: Global Vectors for Word Representation, which combines word co-occurrence statistics with matrix factorization techniques.
  • FastText: Embeddings that capture subword information, useful for handling out-of-vocabulary words and morphologically rich languages.
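FastText's subword idea can be sketched as follows. A word is wrapped in `<` and `>` boundary markers and broken into character n-grams (FastText's defaults are n = 3 to 6). The word's vector is then the sum of its n-gram vectors. The helper `subword_units` is a hypothetical name for this sketch, and the learned n-gram vectors themselves are omitted.

```python
def subword_units(word, min_n=3, max_n=6):
    """FastText-style subword units: character n-grams of the bracketed word,
    plus the whole bracketed word as one extra unit."""
    token = f"<{word}>"
    grams = []
    for n in range(min_n, max_n + 1):
        for i in range(len(token) - n + 1):
            grams.append(token[i:i + n])
    if token not in grams:  # the full word is also its own unit (avoid a duplicate)
        grams.append(token)
    return grams

print(subword_units("where", 3, 3))
```

An out-of-vocabulary word such as "whereabouts" still shares units like `<wh` and `her` with "where", which is how FastText can build a vector for it.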

2. Contextual Embeddings

  • ELMo: Embeddings from Language Models, which produces contextual embeddings by combining hidden states of a deep bidirectional LSTM.
  • BERT: Bidirectional Encoder Representations from Transformers, a pre-trained model that generates contextual word embeddings by considering the surrounding context.
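For reference, ELMo's "combination of hidden states" is a learned, task-specific weighted sum over the layers of the biLSTM language model. In the ELMo paper's notation, h_{k,j}^{LM} is the representation of token k at layer j (j = 0 is the context-independent token layer). The weights s^{task} are softmax-normalized, and γ^{task} is a learned scalar:

```latex
\mathrm{ELMo}_k^{task} = \gamma^{task} \sum_{j=0}^{L} s_j^{task} \, \mathbf{h}_{k,j}^{LM}
```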

BERT is built on the Transformer, whose core component is the attention mechanism, one of the most fundamental building blocks of modern neural network architectures. Attention enables a model to focus on specific parts of the input sequence when making predictions, weighing the importance of different elements dynamically.

Let’s see how researchers have enhanced the attention mechanism.

Types of Attention Mechanism

1. Self-Attention

Self-attention enables the model to capture relationships between any pair of tokens within the context window, allowing it to weigh the significance of each token in the context of the entire sequence.

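Let’s break it down using its formula (scaled dot-product attention):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where d_k is the dimension of the key vectors. Dividing by √d_k keeps the dot products from growing too large before the softmax.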

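Q is the query matrix: the query vectors of the tokens for which we want to calculate similarity scores

K is the key matrix: the key vectors of the tokens against which those scores are calculated

V is the value matrix: the value vectors of the same tokens as K

All three are obtained by multiplying the token embeddings by learned weight matrices.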

Step 1: Compute Similarities:

For each token in the sequence, compute a similarity score with every token (including itself). This is done by taking the dot product of the query (Q) vector of the current token with the key (K) vectors of all tokens in the sequence, then dividing by the square root of the key dimension.

Step 2: Calculate Attention Weights:

The similarity scores are normalized using a softmax function to obtain attention weights. These weights indicate the importance of each element relative to others in the sequence.

Step 3: Weight the Values:

Each token's value (V) vector is multiplied by its attention weight, so the tokens that are most relevant to the current token contribute the most.

Step 4: Compute the Context Vector:

Sum up the weighted value (V) vectors. This produces the output of the self-attention layer for the given token (its context vector), which captures how that token relates to the rest of the sequence and emphasizes the most relevant tokens. This output is then passed to a feedforward neural network.
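The four steps can be sketched in plain Python. To keep it short, this sketch takes Q, K and V as given. In a real Transformer, they are produced by multiplying the token embeddings by learned weight matrices. The function names are illustrative.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(Q, K, V):
    """Q, K, V: lists of row vectors, one per token. Returns (outputs, attention weights)."""
    d_k = len(K[0])
    outputs, all_weights = [], []
    for q in Q:
        # Step 1: scaled dot-product similarity with every token's key
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        # Step 2: softmax turns scores into attention weights that sum to 1
        weights = softmax(scores)
        # Steps 3-4: weight each value vector and sum them into the context vector
        out = [sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = self_attention(X, X, X)  # toy case: Q = K = V = X
for row in weights:
    print([round(w, 3) for w in row])
```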

2. Multi-Head Attention

It extends self-attention by running several attention operations ("heads") in parallel, each with its own learned projection weights, so it has more learnable weights.

Query, Key, and Value Representations:

Like standard attention mechanisms, multi-head attention operates on query, key, and value representations of the input sequence.

Multiple Attention Heads:

In multi-head attention, the query, key, and value representations are split into multiple heads. Each head learns a unique set of weights during training, allowing it to focus on different aspects of the input sequence.

Independent Attention Mechanisms:

Each attention head independently computes attention scores between the query and key representations of the input sequence. This enables the model to capture diverse patterns and relationships by attending to different parts of the input sequence simultaneously.

Concatenation and Linear Transformation:

The outputs of all attention heads are concatenated and linearly transformed to produce the final output of the multi-head attention layer. This transformation allows the model to combine information from different heads and generate a comprehensive representation of the input sequence.
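Here is a compact sketch of multi-head attention under simplifying assumptions. Random weight matrices stand in for trained ones, and no masking or bias terms are used. Each head's projection is taken as a slice of columns from one shared projection matrix, which is equivalent to giving every head its own smaller matrix.

```python
import math
import random

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention for a single head."""
    d_k = len(K[0])
    out = []
    for q in Q:
        w = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K])
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
    return out

def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads):
    d_model = len(W_q[0])
    d_head = d_model // num_heads
    # Project once, then give each head its own slice of columns (its own subspace).
    Q, K, V = matmul(X, W_q), matmul(X, W_k), matmul(X, W_v)
    heads = []
    for h in range(num_heads):
        cols = slice(h * d_head, (h + 1) * d_head)
        # Each head computes attention independently over its slice.
        heads.append(attention([r[cols] for r in Q], [r[cols] for r in K], [r[cols] for r in V]))
    # Concatenate the heads' outputs token by token, then apply the output projection.
    concat = [sum((head[i] for head in heads), []) for i in range(len(X))]
    return matmul(concat, W_o)

def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d_model, num_heads, seq_len = 4, 2, 3
X = random_matrix(seq_len, d_model)  # 3 token embeddings of size 4
W_q, W_k, W_v, W_o = (random_matrix(d_model, d_model) for _ in range(4))
out = multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads)
print(len(out), len(out[0]))  # 3 4
```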

3. Multi Query Attention

MQA is a technique used in the decoder portion of Transformer models. It is a variation of multi-head attention (MHA).

  • MQA uses the same basic attention mechanism, with one key difference: all query heads share a single key head and a single value head (one set of K and V projections).
  • In regular MHA, each head has its own K and V projections.
  • By sharing K and V across all heads, MQA shrinks the key-value (KV) cache and the memory bandwidth needed during decoding, which matters most for long sequences.
  • While MQA is faster and more memory-efficient, it can lead to a slight decrease in model performance compared to MHA.
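The memory saving is easy to estimate. The KV cache holds one key vector and one value vector per KV head, per layer, per token. The sketch below uses a hypothetical configuration (32 layers, 32 heads of dimension 128, a 4096-token context, 2-byte fp16 values) purely for illustration:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Memory for cached keys AND values (hence the factor 2) across all layers, one sequence."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value

# Hypothetical model configuration, chosen only to illustrate the arithmetic.
mha = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=4096)
mqa = kv_cache_bytes(num_layers=32, num_kv_heads=1, head_dim=128, seq_len=4096)
print(mha / 2**30, "GiB for MHA")  # 2.0 GiB for MHA
print(mqa / 2**20, "MiB for MQA")  # 64.0 MiB for MQA
```

Sharing a single K/V head shrinks the cache by a factor equal to the number of heads (32× in this example).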

4. Grouped Query Attention

GQA was introduced in 2023 and adopted by the larger Llama 2 models to improve efficiency during inference. It essentially balances the quality of Multi-Head Attention with the speed of Multi-Query Attention.

Source: Image by Author
  • Interpolation: GQA acts as an intermediate approach between MQA (with a single group) and MHA (with a number of groups equal to the number of query heads). This allows it to address the quality degradation issues of MQA while maintaining some level of efficiency.
  • Efficiency: By reducing the number of key and value heads compared to MHA, GQA achieves faster computation during inference.
  • Memory Bandwidth Optimization: In sharded models (very large models split across several devices), MQA's single key-value head has to be replicated on every device. GQA avoids much of this wasted memory and bandwidth.
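The grouping can be sketched as a mapping from query heads to shared key/value heads. This sketch assumes that query heads are split into equal, contiguous groups, as is common in implementations. The function name is hypothetical:

```python
def kv_head_for_query(query_head, num_query_heads, num_kv_heads):
    """Index of the shared key/value head used by a query head in GQA,
    assuming query heads are split into num_kv_heads equal, contiguous groups."""
    assert num_query_heads % num_kv_heads == 0, "query heads must divide evenly into groups"
    group_size = num_query_heads // num_kv_heads
    return query_head // group_size

# With 8 query heads: 1 KV head behaves like MQA, 8 KV heads behave like MHA.
for kv in (1, 2, 8):
    print(kv, [kv_head_for_query(h, 8, kv) for h in range(8)])
```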

5-A. Flash Attention

Flash Attention is a technique designed to improve the efficiency of transformer models.

  • Transformer models rely heavily on an attention mechanism to understand the relationships between different parts of the input text.
  • This attention mechanism, however, becomes computationally expensive for long sequences, because its time and memory cost grow quadratically (O(n^2)) with the sequence length n.
  • As sequences get longer, the n × n attention matrix has to be written to and read back from GPU memory, creating a memory bottleneck.

Let's break down the 2 main components of Flash Attention.

Tiling:

  • Standard attention implementations compute the full n × n attention matrix at once and store it in the GPU's high-bandwidth memory (HBM).
  • Flash Attention breaks down the large attention matrix into smaller, more manageable tiles.
  • This tiling process reduces memory usage and improves the efficiency of computations on the graphics processing unit (GPU).
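Tiling works because of the "online softmax" trick. Keys and values can be processed one tile at a time while keeping only a running maximum, a running softmax denominator and an unnormalized output. The sketch below shows this for a single query in plain Python. It is a simplification: real FlashAttention is a fused GPU kernel that also tiles the queries and keeps each tile in on-chip SRAM.

```python
import math

def attention_row_naive(q, K, V):
    """Reference: materialize all scores for one query, softmax them, weight the values."""
    scale = 1 / math.sqrt(len(q))
    s = [sum(a * b for a, b in zip(q, k)) * scale for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    return [sum(pi * v[j] for pi, v in zip(p, V)) / l for j in range(len(V[0]))]

def attention_row_tiled(q, K, V, tile=2):
    """Visit keys/values one tile at a time, keeping only a running max (m),
    a running softmax denominator (l) and an unnormalized output (acc)."""
    scale = 1 / math.sqrt(len(q))
    m, l = float("-inf"), 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), tile):
        K_t, V_t = K[start:start + tile], V[start:start + tile]
        s = [sum(a * b for a, b in zip(q, k)) * scale for k in K_t]
        m_new = max(m, max(s))
        correction = math.exp(m - m_new)  # rescale what was accumulated under the old max
        p = [math.exp(x - m_new) for x in s]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[j] for pi, v in zip(p, V_t)) for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

q = [0.5, -1.0, 2.0]
K = [[1.0, 0.0, 0.5], [0.2, 0.3, -0.1], [-1.0, 2.0, 0.0], [0.0, 0.5, 1.5], [2.0, -0.5, 0.3]]
V = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [0.5, 0.5], [-2.0, 1.0]]
print(attention_row_naive(q, K, V))
print(attention_row_tiled(q, K, V, tile=2))  # same values, up to floating-point rounding
```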

Recomputation:

  • During training (backward pass), standard attention mechanisms store intermediate results, which can consume a significant amount of memory.
  • Flash Attention avoids this by recomputing these intermediate results using previously stored outputs and calculated statistics.
  • This eliminates the need to store large intermediate matrices for the backward pass, at the cost of extra floating point operations. Because it cuts down on slow memory reads and writes, it is still faster overall.
Source: Image by huggingface

In Flash Attention, each block of data is loaded from HBM into on-chip SRAM once, and all the calculations for that block are done in SRAM (which is much faster to read from and write to than HBM). This reduces the number of HBM read-write operations.
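The recomputation idea can be sketched as follows. The forward pass keeps only the outputs plus one statistic per query row. Here that statistic is the log-sum-exp of the scores, which is what FlashAttention-2 stores; the original version keeps the row max and sum separately. The backward pass can then rebuild any block of attention probabilities from Q, K and that statistic. The function names are hypothetical, and the gradient math itself is omitted.

```python
import math

def forward_with_stats(Q, K, V):
    """Forward pass that keeps only the outputs O and one statistic per query row:
    the log-sum-exp L_i of its scaled scores (no n x n probability matrix is kept)."""
    scale = 1 / math.sqrt(len(K[0]))
    O, L = [], []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) * scale for k in K]
        m = max(s)
        lse = m + math.log(sum(math.exp(x - m) for x in s))
        p = [math.exp(x - lse) for x in s]  # softmax probabilities, discarded after use
        O.append([sum(pi * v[j] for pi, v in zip(p, V)) for j in range(len(V[0]))])
        L.append(lse)
    return O, L

def recompute_probs(Q, K, L, i, j_start, j_end):
    """Backward-pass helper: rebuild the block P[i, j_start:j_end] of attention
    probabilities from Q, K and the stored statistic instead of a saved matrix."""
    scale = 1 / math.sqrt(len(K[0]))
    return [math.exp(sum(a * b for a, b in zip(Q[i], K[j])) * scale - L[i])
            for j in range(j_start, j_end)]

Q = [[1.0, 0.0], [0.5, 0.5]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
O, L = forward_with_stats(Q, K, V)
print(recompute_probs(Q, K, L, i=0, j_start=0, j_end=3))  # row 0 of P, rebuilt on demand
```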

5-B. FlashAttention-2

FlashAttention-2 improves on the original FlashAttention algorithm and is designed to speed up the training of large language models (LLMs). It is roughly 2 times faster than its predecessor.

FlashAttention-2 focuses on 3 aspects:

Algorithm:

GPUs are specialized for matrix multiplication, so the other operations in attention (such as the rescaling steps of the softmax) are much slower per floating point operation. FlashAttention-2 reduces these non-matmul operations by restructuring the calculations, for example by rescaling the output only once at the end instead of at every step.

Parallelism:

FlashAttention-1 parallelizes the work over the batch size and the number of attention heads. FlashAttention-2 additionally parallelizes the computation across the sequence-length dimension, so that even long sequences with small batches keep the GPU busy. With more thread blocks working on smaller chunks of the sequence, the GPU can utilize its resources more efficiently.

Work Partitioning:

GPUs group threads into thread blocks, and each block is further divided into warps that work together on a task. While FlashAttention-1 achieved some level of parallelism, FlashAttention-2 optimizes how work is distributed between the warps within each block: it splits the queries, rather than the keys and values, across warps. This minimizes the need for warps to write intermediate results to shared memory (a small but fast memory space on the GPU) and to synchronize with each other, which reduces communication bottlenecks and lets threads work more efficiently in parallel.

6. Hyper Attention

HyperAttention achieves near-linear time complexity for long-context attention: the processing time grows roughly in proportion to the sequence length, rather than quadratically as in standard attention. While FlashAttention computes attention exactly, HyperAttention trades off some accuracy for speed.

It focuses on 3 aspects:

Hierarchical Encoding:

HyperAttention represents the input sequence using a hierarchical structure. It breaks down the sequence into smaller subsequences and then builds a hierarchy by progressively grouping these subsequences.

Sparse Interactions:

Instead of comparing every element in the sequence with every other element, HyperAttention focuses on relevant interactions within the hierarchical structure. This reduces the number of computations significantly, especially for long sequences. It employs random feature projections to efficiently capture long-range dependencies within the sequence.

Approximate Row Sums:

Normalizing the attention matrix (A) requires its row sums. Computing these exactly becomes expensive once A contains masked values (large values set to zero, found using random feature projections), and the masking skews the actual sum, which hinders proper normalization. HyperAttention instead approximates the row sums using squared L2 norms, which avoids the computational overhead and yields an unbiased estimate of the actual row sums.

7. Sliding window Attention

Sliding Window Attention was introduced in Longformer and was later also used in Mistral 7B. It is designed specifically to handle long sequences, tackling the computational challenges that arise when a model tries to consider all parts of a long sequence at once.

Source: Image by Author

Focusing on Local Context with Sliding Windows

  • Sliding Window Attention addresses this issue by introducing the concept of a “window.” This window focuses on a fixed-size segment of the sequence around the element of interest.
  • The model calculates attention scores only for elements within the window. This significantly reduces the computational cost compared to considering the entire sequence.
  • The window then “slides” along the sequence, one element at a time, allowing the model to gradually build up an understanding of the entire sequence while focusing on the local context at each step.
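  • This is similar to a convolutional neural network (CNN), where a fixed-size kernel looks at a local neighbourhood, and stacking layers lets information travel further than one window.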
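A sliding-window mask can be sketched as a boolean matrix. This version is causal, as in Mistral: each token sees itself and the previous window - 1 tokens. Longformer's window is instead symmetric around the token. Each query attends to at most `window` keys, so the cost grows as O(n × w) rather than O(n^2). The function name is illustrative.

```python
def sliding_window_mask(seq_len, window):
    """mask[i][j] is True when token i may attend to token j: j is not in the future
    (causal) and lies within the last `window` positions up to and including i."""
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]

for row in sliding_window_mask(seq_len=6, window=3):
    print("".join("x" if allowed else "." for allowed in row))
```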

Sliding Window Attention in Mistral 7B:

  • Attention Span: Although each layer only attends within a fixed window (W = 4096 tokens), Mistral 7B can theoretically “attend” to a much larger span of the sequence. Because every layer can move information forward by up to W tokens, after k layers a token can be influenced by tokens up to k × W positions back. With 32 layers, this gives a theoretical span of about 131K tokens.
  • Rolling Buffer Cache: Since each layer only looks at the last W tokens, the KV-Cache (which stores the key-value pairs of previous tokens) only needs W slots.
  • The keys and values for position i are stored in slot i mod W, so once the sequence is longer than W, old entries are overwritten as the window slides.
  • This keeps the cache memory fixed, no matter how long the sequence grows.
  • Pre-filling and Chunking: The prompt is known in advance, so the cache can be pre-filled with the prompt's key-value pairs before generation starts. Very long prompts can be split into chunks (for example, of window size) and pre-filled chunk by chunk.
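The rolling buffer can be sketched as a fixed-size list indexed by position mod W. The class name `RollingKVCache` is hypothetical, and plain strings stand in for the key and value tensors:

```python
class RollingKVCache:
    """Hypothetical fixed-size KV cache: the entry for position i goes to slot i % window,
    so memory stays at `window` entries however long the sequence grows."""

    def __init__(self, window):
        self.window = window
        self.slots = [None] * window
        self.length = 0  # total number of tokens seen so far

    def append(self, key, value):
        # Overwrites the oldest entry once more than `window` tokens have been seen.
        self.slots[self.length % self.window] = (key, value)
        self.length += 1

    def entries_in_order(self):
        """Cached (key, value) pairs from oldest to newest."""
        start = max(0, self.length - self.window)
        return [self.slots[pos % self.window] for pos in range(start, self.length)]

cache = RollingKVCache(window=4)
for pos in range(6):
    cache.append(f"k{pos}", f"v{pos}")
print(cache.entries_in_order())  # only positions 2..5 remain
```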

8. Ring Attention

Ring Attention offers a novel approach that enables handling near-infinite context.

Source: Image by Author

The authors compute the original Transformer block by block. Each host is responsible for one iteration of the query's outer loop, while the key-value blocks rotate among the hosts. As visualized, a device starts with the first query block on the left and then iterates over the sequence of key-value blocks positioned horizontally. The query block, combined with the key-value blocks, is used to compute self-attention (yellow box), whose output is passed to the feedforward network (cyan box).

  • Blockwise Computation: It breaks down both self-attention and feedforward layers of a Transformer into smaller blocks. This allows for processing long sequences in manageable chunks.
  • Distributed Processing: The sequence is distributed across multiple devices (e.g., GPUs) in a ring-like fashion.
  • Overlapping Communication and Computation: While one device computes attention for its assigned block, it simultaneously communicates key-value blocks with its neighbouring devices. This overlapping process avoids idle time and optimizes computation.
  • Full Attention: Unlike chunking methods, Ring Attention still computes exact attention over the whole sequence (every query block eventually meets every key-value block), so accuracy is preserved.
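The ring schedule can be simulated in a single process to check that it gives exact attention. The sketch makes several assumptions: "devices" are just list indices, and the key-value block held by device d after s hops is the one that started on device (d - s) mod P. Attention is non-causal, and the overlap of communication with computation is not modeled. Each device merges incoming blocks with an online-softmax update, as in blockwise attention. All function names are illustrative.

```python
import math
import random

def online_update(state, q, K_b, V_b):
    """Merge one key/value block into a query's running (max, denominator, output) state."""
    m, l, acc = state
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K_b]
    m_new = max(m, max(s))
    c = math.exp(m - m_new)  # rescale earlier contributions to the new max
    p = [math.exp(x - m_new) for x in s]
    acc = [a * c + sum(pi * v[j] for pi, v in zip(p, V_b)) for j, a in enumerate(acc)]
    return (m_new, l * c + sum(p), acc)

def ring_attention(Q_blocks, K_blocks, V_blocks):
    P = len(Q_blocks)
    d_v = len(V_blocks[0][0])
    # state[d][i]: running state of query i on device d (each device keeps its query block)
    state = [[(float("-inf"), 0.0, [0.0] * d_v) for _ in Qb] for Qb in Q_blocks]
    for step in range(P):  # after P steps every KV block has visited every device
        for d in range(P):
            src = (d - step) % P  # KV block that device d holds at this step
            state[d] = [online_update(st, q, K_blocks[src], V_blocks[src])
                        for st, q in zip(state[d], Q_blocks[d])]
    return [[[a / l for a in acc] for (_, l, acc) in dev] for dev in state]

def full_attention(Q, K, V):
    """Reference: ordinary (non-causal) attention over the whole sequence."""
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        out.append([sum(pi * v[j] for pi, v in zip(p, V)) / sum(p) for j in range(len(V[0]))])
    return out

def split(rows, parts):
    size = len(rows) // parts
    return [rows[i * size:(i + 1) * size] for i in range(parts)]

random.seed(1)
n, d, P = 6, 3, 3  # 6 tokens, head dimension 3, 3 simulated devices
Q, K, V = ([[random.uniform(-1, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
ring_out = [row for dev in ring_attention(split(Q, P), split(K, P), split(V, P)) for row in dev]
ref = full_attention(Q, K, V)
print(max(abs(a - b) for r1, r2 in zip(ring_out, ref) for a, b in zip(r1, r2)))  # ~0 (rounding)
```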

9. BurstAttention

BurstAttention is one of the latest approaches for handling attention in large language models (LLMs) that deal with extremely long sequences.

In the techniques above, we have seen that to increase the context length we can either distribute the attention computation or process subsequences. But both have limitations:

  • Distributed Attention: Distributes the computation across multiple devices in a cluster. However, this introduces communication overhead and additional memory usage.
  • Subsequence processing: Divides the sequence into smaller subsequences and processes them independently. This might miss long-range dependencies between elements in different subsequences.

BurstAttention aims to optimize both memory access and communication for attention in distributed settings with extremely long sequences.

Here’s how it works:

  1. Sequence Partitioning: BurstAttention first divides the long sequence into partitions based on the number of devices in the distributed cluster. Each partition is assigned to a specific device.
  2. Local Embedding Projections: Each device independently projects its assigned sequence partition into query, key, and value embeddings.
  3. Pinned Queries: Query embeddings from all partitions are pinned in memory, meaning they are accessible by all devices.
  4. Distributed Attention Score Calculation: Each device iterates through its assigned key-value partitions and calculates attention scores with all pinned queries from other devices. This avoids the need to send the entire key-value data across devices.
  5. Local Attention: Additionally, each device performs a separate local attention computation within its own assigned sequence partition.

Challenges

  • Bias and Fairness: Embeddings may reflect biases present in the training data, leading to biased predictions and unfair outcomes.
  • Embedding selection: Deciding whether to use word embeddings or contextual embeddings depends on several factors, including the nature of your NLP task, the availability of training data, computational resources, and the desired level of performance.
  • Dynamic Embeddings: Developing embeddings that adapt to changes in language and context over time.

Conclusion

  • Embeddings play a fundamental role in NLP tasks by providing compact and semantically meaningful representations of words and entities. Understanding different types of embeddings and their applications is essential for building robust and effective NLP systems.

Thanks for spending your time on this blog. I am open to suggestions and improvements. Please let me know if I missed any details in this article.
