Exploring Multi-Head Attention: Why More Heads Are Better Than One

Hassaan Idrees
4 min read · Jul 30, 2024


Understanding the Power and Benefits of Multi-Head Attention in Transformer Models

Introduction

Multi-head attention is a crucial innovation in Transformer models that has significantly enhanced their performance in natural language processing (NLP) tasks. For students and practitioners in AI and machine learning, grasping the concept and benefits of multi-head attention is essential for leveraging the full potential of Transformer models. This blog post explores the workings of multi-head attention, its advantages, and why having multiple heads is beneficial for model performance.

What is Multi-Head Attention?

Multi-head attention is an extension of the self-attention mechanism used in Transformer models. It allows the model to focus on different parts of the input sequence simultaneously, providing a richer and more nuanced understanding of the data.

Key Concepts:

  • Self-Attention: Computes, for each token, how relevant every other token in the sequence is to it.
  • Multiple Heads: Instead of a single attention mechanism, the model uses several attention mechanisms (heads) in parallel.

Mathematical Formulation: For each head ‘h’, the attention mechanism is scaled dot-product attention:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) V

where Q, K, and V are the query, key, and value matrices, respectively, and ‘d_k’ is the dimension of the key vectors.
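
To make this formula concrete, here is a minimal sketch of scaled dot-product attention for a single head in PyTorch (the batch size, sequence length, and d_k below are illustrative assumptions, not values from the formula):

import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 1, sequence length of 5, key dimension d_k = 64
Q = torch.rand(1, 5, 64)
K = torch.rand(1, 5, 64)
V = torch.rand(1, 5, 64)

d_k = Q.shape[-1]
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (1, 5, 5): each query scored against each key
weights = F.softmax(scores, dim=-1)            # each row sums to 1
output = weights @ V                           # (1, 5, 64): weighted sum of the values
print(output.shape)                            # torch.Size([1, 5, 64])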

In multi-head attention, these heads are computed in parallel and their outputs are combined:

MultiHead(Q, K, V) = Concat(head_1, …, head_h) W_o, where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

where each head ‘head_i’ is an attention function applied to its own learned projections of Q, K, and V, and ‘W_o’ is a learned linear transformation.
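
And here is a minimal sketch of the combination step alone, assuming the per-head outputs have already been computed (the number of heads, sequence length, and dimensions are illustrative assumptions):

import torch
import torch.nn as nn

heads, head_dim, seq_len = 8, 64, 5
embed_size = heads * head_dim  # 512

# Stand-ins for the outputs of 8 independent attention heads
head_outputs = [torch.rand(1, seq_len, head_dim) for _ in range(heads)]

W_o = nn.Linear(heads * head_dim, embed_size)   # the learned output projection W_o
concatenated = torch.cat(head_outputs, dim=-1)  # (1, 5, 512)
multihead_output = W_o(concatenated)            # (1, 5, 512)
print(multihead_output.shape)                   # torch.Size([1, 5, 512])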

How Multi-Head Attention Works

1. Linear Projections:

  • Input sequences are linearly projected into queries, keys, and values for each head.

2. Parallel Attention:

  • Each head independently performs the attention mechanism on its projected queries, keys, and values.

3. Concatenation:

  • The outputs of all attention heads are concatenated.

4. Final Linear Transformation:

  • The concatenated outputs are linearly transformed to produce the final output (a shape walkthrough follows below).
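
To make the four steps concrete, here is a rough shape walkthrough, using the same illustrative sizes as the full implementation later in the post (a batch of 64 sequences of length 10, an embedding size of 512, and 8 heads):

import torch

batch, seq_len, embed_size, heads = 64, 10, 512, 8
head_dim = embed_size // heads                            # 64

x = torch.rand(batch, seq_len, embed_size)

# 1. Linear projections (simulated here by reshaping into per-head chunks)
q = x.reshape(batch, seq_len, heads, head_dim)            # (64, 10, 8, 64)
k = x.reshape(batch, seq_len, heads, head_dim)
v = x.reshape(batch, seq_len, heads, head_dim)

# 2. Parallel attention: one (seq_len x seq_len) weight map per head
scores = torch.einsum("nqhd,nkhd->nhqk", q, k) / head_dim ** 0.5
weights = torch.softmax(scores, dim=-1)                   # (64, 8, 10, 10)
per_head = torch.einsum("nhqk,nkhd->nqhd", weights, v)    # (64, 10, 8, 64)

# 3. Concatenation: merge the heads back into the embedding dimension
concat = per_head.reshape(batch, seq_len, heads * head_dim)  # (64, 10, 512)

# 4. A final learned linear layer (W_o) would then map (64, 10, 512) -> (64, 10, 512)
print(weights.shape, concat.shape)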

Advantages of Multi-Head Attention

1. Richer Representation:

  • Diverse Perspectives: Each head can focus on different parts of the sequence, capturing various aspects of the data.
  • Complex Relationships: Enables the model to learn and represent complex relationships between tokens.

2. Improved Performance:

  • Parallel Computation: Because the heads are independent of one another, they can be computed in parallel, keeping multi-head attention efficient on modern hardware.
  • Enhanced Generalization: Helps the model generalize better to unseen data by learning multiple types of dependencies.

3. Flexibility and Adaptability:

  • Customizable: The number of heads can be adjusted to balance model complexity and computational resources (see the quick sketch after this list).
  • Task-Specific Adaptation: Different tasks can benefit from different numbers of heads, allowing for tailored models.
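
As a quick, purely illustrative sketch of this trade-off: splitting the same embedding size across more heads shrinks the dimension each head works with, while the total projected width stays the same.

# Illustrative: how the per-head dimension changes with the number of heads
embed_size = 512
for heads in (4, 8, 16):
    head_dim = embed_size // heads
    print(f"heads={heads:2d}  head_dim={head_dim:3d}  total width={heads * head_dim}")
# heads= 4  head_dim=128  total width=512
# heads= 8  head_dim= 64  total width=512
# heads=16  head_dim= 32  total width=512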

Why More Heads Are Better Than One

1. Capturing Diverse Features:

  • Variety in Focus: Different heads can learn to focus on different parts of the input, capturing a wide range of features and relationships.
  • Redundancy: Overlapping heads add robustness, so the model is less sensitive to any single head learning an uninformative pattern.

2. Enhanced Attention Distribution:

  • Balanced Attention: Distributes attention more evenly across the sequence, avoiding over-reliance on specific tokens.
  • Comprehensive Understanding: Leads to a more comprehensive understanding of the input data.

3. Mitigating Overfitting:

  • Regularization Effect: Multiple heads can act as a mild form of regularization, making it harder for the model to overfit to the training data.
  • Robust Learning: Encourages the model to learn more general patterns rather than memorizing specific details.

Implementing Multi-Head Attention

Let’s implement multi-head attention in PyTorch to see how it works in practice.

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads  # d_k: the dimension of each head

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # Per-head linear projections for values, keys, and queries
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Final linear transformation (W_o) applied to the concatenated heads
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding dimension into (heads, head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Attention scores for every head: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), the per-head key dimension, as in the formula above
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of the values, then merge heads back into embed_size
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out


# Example usage
embed_size = 512
heads = 8
values = torch.rand((64, 10, embed_size))
keys = torch.rand((64, 10, embed_size))
query = torch.rand((64, 10, embed_size))
mask = None

multihead_attention = MultiHeadAttention(embed_size, heads)
output = multihead_attention(values, keys, query, mask)
print(output.shape)  # Expected output shape: (64, 10, 512)
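
For comparison, PyTorch also ships a built-in nn.MultiheadAttention module that implements the same mechanism. A minimal sketch, assuming a recent PyTorch version that supports the batch_first argument:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.rand(64, 10, 512)               # (batch, seq_len, embed_dim)
attn_output, attn_weights = mha(x, x, x)  # self-attention: query = key = value = x
print(attn_output.shape)                  # torch.Size([64, 10, 512])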

Applications of Multi-Head Attention

Multi-head attention is a key component in various Transformer-based models, enabling them to excel in diverse tasks:

1. Machine Translation:

  • Example: Translating text from one language to another, such as in Google’s NMT system.

2. Text Summarization:

  • Example: Generating concise summaries of long documents, as seen in BERTSUM.

3. Text Generation:

  • Example: Producing human-like text, as demonstrated by OpenAI’s GPT-3.

4. Question Answering:

  • Example: Answering questions based on a given context, as in BERT-based QA systems.

5. Sentiment Analysis:

  • Example: Determining the sentiment of text, such as customer reviews.

Conclusion

Multi-head attention is a powerful mechanism that significantly enhances the performance of Transformer models by allowing them to capture diverse features, distribute attention effectively, and improve generalization. Understanding and leveraging multi-head attention is essential for building state-of-the-art models in NLP and beyond. Experiment with different configurations and observe the impact on your models. Share your thoughts and questions in the comments below, and stay tuned for more insights into the world of machine learning and AI.

Written by Hassaan Idrees

Hassaan is an aspiring Data Scientist with a passion for self-directed learning in ML, eager to showcase his proficiency in NLP, CV, and various ML techniques.
