Sparse Transformers Explained | Part 1

Mahtab
6 min read · Apr 29, 2024


Capturing long-range dependencies in text, audio, or images requires a large context length. Sparse Transformers¹ reduce the computational complexity of the Transformer network. GPT-3² uses sparse attention patterns similar to those of the Sparse Transformer in its layers.

Typical Transformer³ networks constrain the context length because the computational complexity rises quadratically with the context/sequence length. This complexity arises from the attention calculation over the query (Q), key (K), and value (V) matrices.

Figure 1. Attention Calculation Formula. Source: https://arxiv.org/abs/1706.03762
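
For reference, the scaled dot-product attention formula shown in Figure 1 is:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
```

where d_k is the key/head dimension (E in the notation used below).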

The attention formula mainly involves two matrix multiplications: first between Q and K, and second between the resulting attention weights and V. If these matrices are of shape B×L×E, where B is the batch size, L is the sequence/context length, and E is the embedding/head size, then the computational complexity of these two matrix multiplications can be worked out as below.

Note: For two matrices A (M×N) and B (N×P), the cost of computing the product AB is O(MNP).

Figure 2. Computational complexity Calculation of Attention Formula
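
Working this out per batch element, with Q, K, and V each of shape L×E:

```latex
QK^{T}: (L \times E)(E \times L) \;\Rightarrow\; O(L \cdot E \cdot L) = O(L^{2}E)
\qquad
\mathrm{softmax}(QK^{T})\,V: (L \times L)(L \times E) \;\Rightarrow\; O(L \cdot L \cdot E) = O(L^{2}E)
```

so the overall attention cost is O(L²E) per batch element.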

If L ≫ E, then the complexity is dominated by the quadratic term in L, the sequence/context length mentioned before.

The Sparse Transformer splits the attention computation across p attention heads, each of which attends only to a small sub-span of the preceding positions rather than all of them, thus shrinking the quadratic term. This is denoted as factorized attention in the paper. When calculating attention for the i-th index, instead of choosing all indices j ∈ {j : j ≤ i}, factorized attention uses p separate attention heads, where the m-th head defines a subset of the indices A(m)ᵢ ⊂ {j : j ≤ i} with |A(m)ᵢ| ∝ n^(1/p), where n is the sequence length.

In the paper¹, they choose p = 2, so that |A(m)ᵢ| ∝ √n. The original set of indices for calculating attention for index i, {j : j ≤ i}, is split into two subsets: A(1)ᵢ and A(2)ᵢ.
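
With p = 2 and |A(m)ᵢ| ∝ √n, each of the n query positions attends to only O(√n) keys per head instead of O(n), so the total attention cost (ignoring the embedding dimension) drops from the full-attention figure to the O(n√n) reported in the paper:

```latex
\sum_{i=1}^{n} \left|\{j : j \le i\}\right| = O(n^{2})
\quad\longrightarrow\quad
\sum_{i=1}^{n} \left|A^{(m)}_{i}\right| \;\propto\; n\sqrt{n} = n^{3/2}
```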

Selection of these subsets must follow the criterion below:

For every j ≤ i pair, we choose the sets Aᵢ such that i can attend to j through a path of locations with maximum length p + 1.

This signifies that the token at index i should be able to receive information from any index j ≤ i, including index 0, through a path of at most p + 1 locations, i.e., within p attention steps. Earlier, all indices ≤ i were attended to in one attention step, since they were directly connected to index i. Now, index i is only connected to a subset of indices, denoted Aᵢ as in the paper. If p = 2, then by the second attention block, information from index 0 (the lowest) will have reached index i. The paper achieves this by choosing Aᵢ efficiently.

Why is (p + 1) chosen as the maximum path length?

We can imagine attention as a directed graph. In a normal (causal) attention head, each position is connected to every earlier position, since it attends to all of them directly. In the modified sparse attention layer, we break this single attention head into p attention heads. To recover the original head's behaviour, information from the leftmost index (index 0), and indeed any index j ≤ i, should reach the i-th index after these p attention steps. For example, {j, a, i} is a path of length 3, where p = 2 and 0 ≤ j ≤ a ≤ i. In the first attention block, position a attends over j, so a now carries information about j. In the second attention block, position i attends to a, which carries that information about j. So the i-th index now has information about the j-th index.

This is achieved by choosing the index sets such that j ∈ Aₐ and a ∈ Aᵢ. The i-th index then gets information about the j-th index in the second attention block via the attention path j → a → i, whose length is 3 = 2 + 1 = p + 1. In this way, the original attention head's behaviour is replaced with p separate attention heads, called factorized attention heads in the paper.
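
A minimal Python sketch of this two-step information flow, using a hypothetical assignment of the index sets (the concrete strided and fixed choices are described next):

```python
# Toy illustration of the (p + 1) path criterion with p = 2.
# The index sets below are hypothetical; the strided/fixed patterns described
# in the next section show how the paper actually chooses them.
p = 2
j, a, i = 0, 3, 15              # 0 <= j <= a <= i

# A[x] = set of indices that position x attends to (union over its heads).
A = {a: {j}, i: {a}}

# After the first attention block, position a has mixed in information from A[a].
info_at_a = {a} | A[a]

# After the second attention block, position i has mixed in information from
# A[i] and from everything those positions already carried.
info_at_i = {i} | A[i]
for x in A[i]:
    info_at_i |= info_at_a if x == a else {x}

print(j in info_at_i)           # True: j -> a -> i, path length 3 = p + 1
```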

The paper provides two mechanisms for selecting these subsets:

Figure 3. Source: https://arxiv.org/abs/1904.10509
  1. Strided
  • A(1)ᵢ: the first attention head attends to the previous l locations. Formally, A(1)ᵢ = {t, t + 1, …, i} for t = max(0, i − l).
  • A(2)ᵢ: the second attention head attends to every l-th location preceding i. Formally, A(2)ᵢ = {j : (i − j) mod l = 0}.

Figure 3 shows the attention pattern for index i (drawn in deep blue), where A(1)ᵢ attends to the indices in blue and A(2)ᵢ attends to the indices in light blue. l is chosen to be a value close to √n (typically 128 or 256).

How does this follow the criterion described before for choosing the subsets?

Let l = 4.
A path from index 15 back to index 0 (the first index) has the form {0, a, 15}, where 0 ∈ Aₐ and a ∈ A₁₅ (via either of the two heads).
This condition must hold to satisfy the criterion.

1. i = 15 (see the last row of the second big block of Figure 3):
A(1)ᵢ = {11, 12, 13, 14, 15}
A(2)ᵢ = {3, 7, 11}, since (15 − 3) = 12 ≡ 0 (mod 4), (15 − 7) = 8 ≡ 0 (mod 4), and (15 − 11) = 4 ≡ 0 (mod 4)

2. i = 3 (see the 4th row of the second big block of Figure 3):
A(1)ᵢ = {0, 1, 2, 3}
A(2)ᵢ = {3}, since (3 − 3) = 0 ≡ 0 (mod 4)

Here 0 ∈ A(1)₃ and 3 ∈ A(2)₁₅, with a = 3.
So the strided attention formulation satisfies the criterion of a maximum path length of p + 1.
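
A minimal Python sketch (assuming l = 4 and a sequence of 16 positions) that builds the two strided index sets and checks the criterion by composing two attention steps:

```python
def strided_heads(i, l):
    """Return (A1, A2) for position i under the strided pattern."""
    a1 = set(range(max(0, i - l), i + 1))                 # previous l locations
    a2 = {j for j in range(i + 1) if (i - j) % l == 0}    # every l-th location
    return a1, a2

l = 4
a1_15, a2_15 = strided_heads(15, l)
print(sorted(a1_15))   # [11, 12, 13, 14, 15]
print(sorted(a2_15))   # [3, 7, 11, 15] (15 itself is included since (15 - 15) % 4 == 0)

# Check the criterion for i = 15: every j <= i should be reachable in <= p = 2 steps.
def attends(i):
    a1, a2 = strided_heads(i, l)
    return a1 | a2

i = 15
one_step = attends(i)
two_step = one_step | set().union(*(attends(a) for a in one_step))
print(set(range(i + 1)) <= two_step)   # True: all j <= 15 are reached within 2 steps
```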

Why did they choose such a pattern for formulating the attention head?

Figure 4: In layers 19 and 20, the network learned to split the attention across a row attention and column attention. Source: https://arxiv.org/abs/1904.10509

Figure 4 above shows the learned attention patterns from a 128-layer network trained on CIFAR-10 with full attention. White highlights denote the attention weights for a head. You can see that the white patterns form row and column attention patterns, which match the strided attention pattern in Figure 3 (see the small blocks above).

2. Fixed

Figure 5. Source: https://arxiv.org/abs/1904.10509

This pattern is used especially for text, where the learned attention patterns do not form clean row or column structures.

  • A(1)ᵢ = {j : ⌊j/l⌋ = ⌊i/l⌋}, where the brackets denote the floor operation. The first head attends within the block of length l that contains i.
  • A(2)ᵢ = {j : j mod l ∈ {t, t + 1, …, l}}, where t = l − c and c is a hyperparameter. These cells summarize the information from the previous cells in the block.
Let l = 4, c = 1.
mod_elems = {t, t + 1, …, l}, where t = l − c = 3
mod_elems = {3, 4} = {3}, since j mod 4 can never equal 4

i = 15 (see the last row of the second big block of Figure 5):
A(1)ᵢ = {12, 13, 14, 15}, since 12//4 = 13//4 = 14//4 = 15//4 = 3
A(2)ᵢ = {3, 7, 11}, since 3 mod 4 = 7 mod 4 = 11 mod 4 = 3

The calculation above illustrates how the index sets are computed under this method.

Concretely, if the stride is 128 and c = 8, then all future positions greater than 128 can attend to positions 120–128, all positions greater than 256 can attend to 248–256, and so forth.
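
A similar Python sketch for the fixed pattern, using the l = 4, c = 1 values from the worked example above:

```python
def fixed_heads(i, l, c):
    """Return (A1, A2) for position i under the fixed pattern."""
    # A1: positions in the same block of length l as position i.
    a1 = {j for j in range(i + 1) if j // l == i // l}
    # A2: "summary" positions, i.e. the last c positions of every block
    # (residues l - c, ..., l - 1, since j mod l can never equal l).
    mod_elems = set(range(l - c, l))
    a2 = {j for j in range(i + 1) if j % l in mod_elems}
    return a1, a2

l, c = 4, 1
a1_15, a2_15 = fixed_heads(15, l, c)
print(sorted(a1_15))   # [12, 13, 14, 15]
print(sorted(a2_15))   # [3, 7, 11, 15]

# With l = 128 and c = 8 (the values quoted above), the summary residues in this
# sketch would be 120, ..., 127, i.e. the last 8 positions of every block of 128.
```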

Part 2 will describe how these attention heads (two per index, as in the paper) are combined within the attention layer.

Reference

  1. Child, Rewon, et al. “Generating long sequences with sparse transformers.” arXiv preprint arXiv:1904.10509 (2019).
  2. Brown, Tom, et al. “Language models are few-shot learners.” Advances in neural information processing systems 33 (2020): 1877–1901.
  3. Vaswani, Ashish, et al. “Attention is all you need.” Advances in neural information processing systems 30 (2017).
