Intuition for Multi-headed Attention.

A follow-up on the intuition behind the attention mechanism

Ngieng Kianyew
7 min readSep 1, 2023

In my previous article ‘Attention Distilled’, I explained the intuition behind the attention mechanism. In this article, we take it a step further and improve the attention mechanism with a very simple idea: increase the number of attention heads.

There are three parts to this article:

(1) The first part explains what “Multi-Head” means

(2) The second part explains the need for Multi-Head Attention and the limitations of single-head attention

(3) The third part is the implementation in PyTorch

In case my article is not enough:

I want to give a huge shoutout to a very underrated YouTube video on the attention mechanism by ‘Algorithmic Simplicity’: https://www.youtube.com/watch?v=kWLed8o5M2Y&t=343s&ab_channel=AlgorithmicSimplicity

This is the best video I know of so far that explains the intuition of the attention mechanism; highly recommended.

What is Multi-head

Single-head attention (left) and Multi-Head Attention (right). Image from the ‘Attention Is All You Need’ paper

Recall from the summary of my article ‘Attention Distilled’ that the attention mechanism takes into account the following (a minimal sketch follows this list):

  • (1) How similar or relevant one word is to other words with the dot product
  • (2) How similar a word should be based on previous timesteps with the causal mask
  • (3) How the network should learn to handle less important words with the attention dropout
  • (4) How the input vector representation should be based on past, present and future timesteps with the projection layer
  • (5) How the final input vector representation should also learn to handle less important words with the residual dropout
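For concreteness, here is a minimal, illustrative sketch of those five steps for a single head (the sizes and the dropout probability are made up for the example; this is a sketch, not a full implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, T, E = 2, 3, 4                                    # batch, timesteps, embedding size (illustrative)
x = torch.randn(B, T, E)
q = k = v = x                                        # self-attention: q, k, v all come from the same input

# (1) how similar each word is to every other word, via the dot product
sim = q @ k.transpose(-1, -2) / (E ** 0.5)           # (B, T, T)
# (2) causal mask: a word only attends to the present and previous timesteps
causal_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
sim = sim.masked_fill(causal_mask, float('-inf'))
# (3) attention dropout on the softmax-ed weights
attn = F.dropout(F.softmax(sim, dim=-1), p=0.1)
out = attn @ v                                       # (B, T, E)
# (4) projection layer to produce the final vector representation
out = nn.Linear(E, E)(out)
# (5) residual dropout on the final representation
out = F.dropout(out, p=0.1)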

In Transformer network terminology, a ‘head’ is a single attention mechanism.

  • The left side of the picture depicts ‘Scaled Dot-Product Attention’. This is essentially the attention mechanism explained in my previous article, and it is considered a single-head attention mechanism.
  • The right side of the picture depicts Multi-Head Attention, built from several ‘Scaled Dot-Product Attention’ blocks. We concatenate the outputs from all the attention heads and pass the result through a linear projection layer. This is considered Multi-Head Attention.

Referring to my summary above, we essentially repeat steps (1) to (3) “h” (the number of heads) times, and then concatenate the results before step (4). So if we have three heads, h = 3, we repeat steps (1) to (3) three times, and then concatenate the three outputs (usually a concat followed by `torch.flatten` over the head dimension) before step (4), as sketched below.
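A rough, illustrative sketch of that bookkeeping (each per-head attention is stubbed out here as a plain linear layer purely to show the shapes; the names and sizes are made up):

import torch
import torch.nn as nn

B, T, E = 2, 3, 8                                    # illustrative sizes
h, dim_head = 3, 4                                   # h = number of heads
x = torch.randn(B, T, E)

# stand-ins for steps (1) to (3): one "head" per entry, each producing (B, T, dim_head)
heads = nn.ModuleList([nn.Linear(E, dim_head, bias=False) for _ in range(h)])
head_outputs = [head(x) for head in heads]

# concatenate (flatten) the head outputs before the projection layer in step (4)
multi_head = torch.cat(head_outputs, dim=-1)         # (B, T, h * dim_head)
out = nn.Linear(h * dim_head, E)(multi_head)         # (B, T, E)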

In summary:

  • “Multi-head” refers to the number of attention layers (heads) we run before the projection layer
  • Multi-head means that we are using more than one attention layer

The Need for Multi-head Attention

Recall that the single-head attention mechanism’s purpose is just to get “a” better vector representation of each word in the sentence.

Note that it is ‘a’ (single) vector representation. Keep this in mind because it is important for the example below.

  • However, a question to ask ourselves is: does a word always have the same semantic meaning? The answer is no!

A word can have more than one definition/meaning

Consider the following two sentences with the word ‘bank’:

Example showing that a word can have two different meanings
  • In sentence A, the word “bank” refers to a place where people deposit money
  • In sentence B, the word “bank” refers to the sloping land alongside a river.

Think about the complications of using only a single-head attention layer.

  • We know that a single-head attention layer returns only “a” vector representation (or meaning) of the word ‘bank’.
  • This means that the single-head attention layer will not be able to differentiate the meaning of the word ‘bank’ between the two contexts (sentence A and sentence B).

Either the single-head attention layer thinks that ‘bank’ is the sloping land along the river or ‘bank’ is a place for people to deposit their money. Not both.

So how do we handle words that can have more than one definition/meaning?

  • We know that a single-head attention mechanism’s purpose is just to get “a” vector representation (meaning/definition) of each word in the sentence
  • If we want to handle words that have more than one meaning, intuitively we just need to increase the number of attention layers!

Referring to the two sentences above, where ‘bank’ has two different meanings, we will need two attention layers:

  • One attention layer for the word ‘bank’ which means a place for people to deposit their money.
  • Another attention layer for the word ‘bank’ which means a sloping land along the river

Therefore, we will need 2 attention heads to handle words that have more than one meaning in different contexts, and because we are using more than one attention head, we are using multi-headed attention.

In summary:

  • single-head attention mechanisms cannot handle words that can have different meanings in different contexts.
  • We then require multi-headed attention to handle words that can have different meanings in different contexts.

Implementation in Pytorch

  • Honestly, there is not much difference between single-head attention and multi-head attention.

The things to take note of are:

(1) The use of a linear layer to project the embedding of size (B,T,E) to (B,T, 3 * dim_head * heads) (see the sketch after this list).

  • Where ‘3’ is for the query tensor, key tensor, and value tensor.
  • Each of the tensors is shaped (B, T, dim_head * heads).
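A minimal sketch of this projection (the sizes below are illustrative; the layer name to_qkv matches the implementation further down):

import torch
import torch.nn as nn

B, T, E = 2, 3, 4
heads, dim_head = 8, 16
to_qkv = nn.Linear(E, 3 * dim_head * heads, bias=False)

x = torch.randn(B, T, E)
qkv = to_qkv(x)                          # (B, T, E) -> (B, T, 3 * dim_head * heads)
q, k, v = qkv.chunk(3, dim=-1)           # each is (B, T, dim_head * heads)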

(2) Rearranging the ‘heads’ dimension from (B,T, dim_head * heads) to (B, heads, T, dim_head) (see the sketch after this list)

  • This is done by reshaping tensor from (B,T, dim_head * heads) to (B,T, heads, dim_head) and then transposing to (B, heads, T, dim_head)
  • We need to do these because for a single-head attention layer, the query, key, and value tensors are of size (B, T, dim_head), or equivalently (B, 1, T, dim_head).
  • Multi-head means we have more than one head, so we essentially stack the heads along dim=1.
  • Single-head: (B, 1, T, dim_head), multi-head: (B, h, T, dim_head)
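A minimal sketch of that rearrangement for one of the three tensors (illustrative sizes):

import torch

B, T, heads, dim_head = 2, 3, 8, 16
q = torch.randn(B, T, heads * dim_head)               # e.g. the query tensor from step (1)
q = q.reshape(B, T, heads, dim_head)                  # (B, T, dim_head * heads) -> (B, T, heads, dim_head)
q = q.transpose(1, 2)                                 # (B, T, heads, dim_head)  -> (B, heads, T, dim_head)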

(3) Before passing the output of the attention on the value tensor to the projection layer, we need to concatenate the heads (see the sketch after this list)

  • The output from the attention on value tensor is of shape (B, h, T, dim_head).
  • To concatenate the heads, we reshape (B, h, T, dim_head) -> (B, T, h, dim_head), and then flatten from (B, T, h, dim_head) -> (B, T, h * dim_head)
  • This is then used as input to the projection layer
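And a minimal sketch of that concatenation before the projection layer (illustrative sizes; dim=4 is made up for the example):

import torch
import torch.nn as nn

B, heads, T, dim_head = 2, 8, 3, 16
attention = torch.randn(B, heads, T, dim_head)          # output of the attention applied to the value tensor
attention = attention.transpose(1, 2)                   # (B, heads, T, dim_head) -> (B, T, heads, dim_head)
attention = attention.reshape(B, T, heads * dim_head)   # flatten the heads: (B, T, heads * dim_head)
out = nn.Linear(heads * dim_head, 4)(attention)         # projection back to the embedding dimension (dim=4 here)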

To facilitate my learning of torch.einsum and einops, I implemented the multi-head attention using einops and torch.einsum as well.

(1) # Implementation using `.reshape` and `.transpose`

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=16, dropout=0):
        super().__init__()
        inner_dim = dim_head * heads
        self.inner_dim = inner_dim
        self.scale = dim_head ** -0.5  # divide by sqrt of the head dimension

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

        self.heads = heads
        self.dim_head = dim_head

        self.residual_dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, E = x.shape

        # get q, k, v
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)  # (B, T, 3 * inner_dim) -> 3 * (B, T, inner_dim)
        # separate the heads
        q = q.reshape(B, T, self.heads, self.dim_head).transpose(1, 2)  # (B, T, inner_dim) -> (B, H, T, dim_head)
        k = k.reshape(B, T, self.heads, self.dim_head).transpose(1, 2)  # (B, T, inner_dim) -> (B, H, T, dim_head)
        v = v.reshape(B, T, self.heads, self.dim_head).transpose(1, 2)  # (B, T, inner_dim) -> (B, H, T, dim_head)
        # calculate similarity score
        sim = q @ k.transpose(-1, -2)  # (B, H, T, dim_head), (B, H, dim_head, T) -> (B, H, T, T)
        # scale
        sim = sim * self.scale

        # calculate probabilities
        sim = F.softmax(sim, dim=-1)
        post_softmax_attn = sim

        # apply attention dropout
        sim = self.dropout(sim)

        # calculate attention on V
        attention = sim @ v  # (B, H, T, T), (B, H, T, dim_head) -> (B, H, T, dim_head)
        # flatten the heads
        attention = attention.transpose(1, 2).reshape(B, T, -1)  # (B, H, T, dim_head) -> (B, T, inner_dim)
        # apply the output projection layer
        attention = self.to_out(attention)  # (B, T, inner_dim) -> (B, T, dim)
        attention = self.residual_dropout(attention)
        return attention, post_softmax_attn

x = torch.randn(2,3,4)
mha = MultiHeadAttention(dim=4, heads=8, dim_head=16, dropout=0)
mha_out = mha(x)
mha_out[0].shape

(2) # Implementation using einops and einsum

from einops import rearrange

class att_einops(nn.Module):
    def __init__(self, dim, heads=8, dim_head=16, dropout=0):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5  # divide by sqrt of the head dimension
        # self.to_qkv = nn.Linear(dim, inner_dim * 3)  # 3 because q, k, v
        self.to_qkv = mha.to_qkv  # reuse the weights from the first implementation so the outputs match

        # linear layer to project back to the original dimension
        # self.to_out = nn.Linear(inner_dim, dim)
        self.to_out = mha.to_out  # reuse the weights from the first implementation so the outputs match

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)  # (B, T, dim) -> 3 * (B, T, inner_dim)
        h = self.heads
        # separate the heads
        q, k, v = map(lambda t: rearrange(t, 'b t (h d) -> b h t d', h=h), (q, k, v))

        # calculate similarity score
        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        # scale
        sim = sim * self.scale

        # normalize
        sim = F.softmax(sim, dim=-1)

        # apply attention dropout
        sim = self.dropout(sim)

        # calculate attention on V
        attention = torch.einsum('b h i j, b h j d -> b h i d', sim, v)

        # concatenate the heads and project back to the original dimension
        attention = rearrange(attention, 'b h t d -> b t (h d)')
        attention = self.to_out(attention)
        return attention

att_einops = att_einops(dim=4)
einops_out = att_einops(x)
einops_out.shape

## show that both implementations produce the same output
assert torch.allclose(einops_out, mha_out[0])

Conclusion

(1) We know what multi-head attention means

  • “Multi-head” refers to the number of attention layers (heads) we run before the projection layer
  • Multi-head means that we are using more than one attention layer

(2) We know the limitations of single-head attention

  • Single-head attention mechanisms cannot handle words that can have different meanings in different contexts.

(3) We know why there is a need for Multi-Head Attention

  • We require multi-headed attention to handle words that can have different meanings in different contexts.
