Sparse Transformers and Longformers: A comprehensive summary of space and time optimizations on Transformers (Part 1)

Priya Shree
Walmart Global Tech Blog
Sep 29, 2021

Note: This article assumes familiarity with the transformer architecture and the basics of Natural Language Processing. Transformers can be understood in depth through this paper or blog.

Transformer models have revolutionized the field of Natural Language Processing (NLP) by introducing self-attention as a powerful mechanism to capture context and dependencies in sequences. Through self-attention, each word in a sequence can attend to all other words, determining how related any two words are in the context of the sequence. This capability has enabled models with the transformer architecture to achieve state-of-the-art results in various generative and discriminative NLP tasks.

Figure 1: Equation for the scaled dot-product (self-attention) operation in transformers
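The figure is not reproduced here, so the scaled dot-product attention of Vaswani et al. (2017) is written out below for reference. Q and K are n × d_k matrices of queries and keys, V is the n × d_v matrix of values, and d_k is the key dimension. The product QKᵀ is therefore an n × n matrix.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V,
\qquad Q, K \in \mathbb{R}^{n \times d_k},\; V \in \mathbb{R}^{n \times d_v},\; QK^{\top} \in \mathbb{R}^{n \times n}
```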

However, with great power comes great ̶r̶e̶s̶p̶o̶n̶s̶i̶b̶i̶l̶i̶t̶y̶ computational complexity and memory requirements. For a sequence of length n, the QKᵀ operation in self-attention takes O(n²) time and produces an n × n matrix. This quadratic time and space complexity becomes a limiting factor as the sequence length grows.
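The pure-Python sketch below makes the quadratic cost concrete. It implements dense scaled dot-product attention without batching, masking or multiple heads. The function names are illustrative and do not come from any library. The n × n `weights` list it builds is the matrix whose size grows as O(n²).

```python
import math

def softmax(row):
    # Subtract the max before exponentiating for numerical stability
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def full_self_attention(Q, K, V):
    """Dense scaled dot-product attention on plain Python lists.

    Q, K: n x d_k, V: n x d_v. Returns the n x d_v output and the
    n x n attention-weight matrix.
    """
    n, d_k = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d_k)
    # Every query is scored against every key: n * n entries
    scores = [[scale * sum(q[t] * k[t] for t in range(d_k)) for k in K] for q in Q]
    weights = [softmax(row) for row in scores]
    d_v = len(V[0])
    # Each output row is a weighted average of all n value rows
    out = [[sum(w[j] * V[j][c] for j in range(n)) for c in range(d_v)] for w in weights]
    return out, weights
```

Doubling n quadruples the number of entries in `weights`. This is the growth that the models below try to avoid.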

Generative Pre-Training (GPT) and other deep transformer models, which use enormous amounts of text to train powerful language models, rely on several heuristics and very high-bandwidth machines with multiple GPUs to avoid running out of memory. The high compute needed to train transformers on very long sequences has motivated the development of models that combat the O(n²) complexity. This article and a subsequent one will discuss a few of those models and how they achieve complexities lower than O(n²). I will be covering techniques discussed in the following papers:

1. Generating Long Sequences with Sparse Transformers
2. Longformer: The Long-Document Transformer
3. Reformer: The Efficient Transformer
4. Rethinking Attention with Performers

These models will be discussed over two articles. In this article (the first in the series), I will discuss the optimizations used in Sparse Transformers and Longformers. The second article will cover Reformers and Performers. The motivation for these articles is to explore different techniques that can be leveraged to cut down the complexity of models, rather than to identify the state-of-the-art model with minimum complexity. Thus, I will discuss only the major optimizations that brought down the complexity of each model.

It is important to note that many of these models also used several heuristics and changes to the transformer architecture, along with optimizations to self-attention, to reduce complexity. However, the scope of this article is limited to the most important optimizations.

Before we move on to look at these models, let’s discuss an important technique called Gradient Checkpointing that is used in deep learning models to reduce memory requirements.

Gradient Checkpointing

Computing the weight updates of a layer ‘l’ through backpropagation requires the activations produced during the forward pass, together with the gradients flowing back from all subsequent layers connected to ‘l’. The backward pass runs from the last layer to the first. Every activation computed in the forward pass therefore has to stay in memory until the backward pass reaches it. In other words, the activation of a node Q can be freed only after every gradient computation that uses it has run. Thus, the memory required for backpropagation grows linearly with the number of layers and becomes significant for deeper models.

Gradient checkpointing counters this memory constraint by keeping only the activations of a few layers, marked as checkpoints, in memory. Activations of the remaining layers are recomputed from the checkpoints when the backward pass needs them. With a good choice of checkpoints, this method significantly reduces the memory requirement at the cost of roughly one additional forward pass.

For example, suppose the activations of every √L-th layer are stored in memory, L being the total number of layers. The memory needed for backpropagation then becomes O(√L) instead of O(L): O(√L) checkpoints, plus the O(√L) activations of the one segment being recomputed at any time. Each layer between checkpoints is recomputed only once, which adds up to a single extra forward pass. In Figure 2 below, activations A⁽¹⁾, A⁽⁴⁾ and A⁽⁷⁾ can be checkpointed in memory. They are then used to recompute activations A⁽²⁻³⁾, A⁽⁵⁻⁶⁾ and A⁽⁸⁻⁹⁾ respectively.

Figure 2: Checkpoints marked at every √L-th layer (L = 9 in this figure)
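The bookkeeping can be illustrated with a toy chain of scalar layers, in the same spirit as Figure 2. This sketch rests on simplifying assumptions: each hypothetical layer computes tanh(w·x), and the loss is simply the final output. It stores the input of every ⌊√L⌋-th layer as a checkpoint and recomputes each segment during the backward pass. It then checks that the gradients match those obtained when all activations are stored.

```python
import math

def layer_forward(w, x):
    # Toy layer (an assumption for illustration): y = tanh(w * x)
    return math.tanh(w * x)

def layer_backward(w, x, grad_y):
    # Needs the layer's input x; returns (dLoss/dw, dLoss/dx).
    # For simplicity this toy backward recomputes tanh from x.
    y = math.tanh(w * x)
    g = grad_y * (1.0 - y * y)
    return g * x, g * w

def grads_store_all(weights, x0):
    # Standard backprop: keep every activation, O(L) memory
    acts = [x0]
    for w in weights:
        acts.append(layer_forward(w, acts[-1]))
    grads, grad = [0.0] * len(weights), 1.0  # loss = final output
    for i in reversed(range(len(weights))):
        grads[i], grad = layer_backward(weights[i], acts[i], grad)
    return grads, len(acts)

def grads_checkpointed(weights, x0):
    L = len(weights)
    k = max(1, math.isqrt(L))  # segment length of about sqrt(L)
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):  # forward pass keeps only segment inputs
        if i % k == 0:
            checkpoints[i] = x
        x = layer_forward(w, x)
    grads, grad = [0.0] * L, 1.0
    peak, recomputed = 0, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, L)
        seg = [checkpoints[start]]  # seg[j] is the input to layer start + j
        for i in range(start, end - 1):
            seg.append(layer_forward(weights[i], seg[-1]))
            recomputed += 1
        # Activations alive now: all checkpoints plus this segment's recomputed ones
        peak = max(peak, len(checkpoints) + len(seg) - 1)
        for i in reversed(range(start, end)):
            grads[i], grad = layer_backward(weights[i], seg[i - start], grad)
    return grads, peak, recomputed
```

With L = 9, the checkpointed version holds at most 5 activations instead of 10. It recomputes 6 layer outputs, which is less than one extra forward pass, and produces the same gradients.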

Gradient Checkpointing in Transformer Models: As discussed earlier, a single self-attention matrix takes O(n²) space. With a very large number of layers, keeping the attention matrix of every layer in memory becomes infeasible on systems with moderate compute and memory. Deep attention models therefore keep only a few checkpointed activations, such as the inputs to certain layers. During the backward pass, they recompute the attention matrices of the other layers from those checkpoints. This brings down the memory requirement significantly, at the cost of one additional forward pass.

Let’s now look at the techniques used in the first two papers mentioned above. To make the discussion easier to follow, I divide it into three segments for each paper:

- the feasibility and motivation for the paper,
- the methods used to reduce complexity, along with an explanation of that reduction,
- a performance summary for the model.

Throughout this article, ‘n’ refers to the number of words in a sequence.

Generating Long Sequences with Sparse Transformers (OpenAI)

1. Feasibility and Motivation: As we saw in the previous section, gradient checkpointing can considerably reduce the memory requirements of transformers. However, for very long sequences, computing even a single attention matrix becomes infeasible.
To address this issue, the authors proposed sparse or factorized attention patterns. They analyzed the attention patterns learnt by different layers of a transformer. Although a few layers had a global attention span, the rest mostly attended to a small number of fixed positions. This motivated the authors to train sparse transformers, in which each position attends to only a few other positions instead of all other words in the sequence, thereby mitigating the O(n²) complexity.

2. Method and Explanation for Reduction in Complexity: The authors introduced two methods to factorize attention patterns for training sparse transformers. In general, a factorized pattern can use ‘p’ separate kernels (attention heads), each specifying a subset of positions to attend to. The paper uses two heads for each type of attention pattern. Let’s look at them below:

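a. Strided Attention: Lay the sequence out as a 2D grid with rows of length ‘l’. In this type of attention, each position ‘i’ roughly attends to the other positions in its own row and column. The paper uses the following two kernels to implement strided attention. Each kernel is denoted by Aᵢ, the set of positions that ‘i’ attends to.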

(i). Aᵢ = {t, t+1, t+2, …, i}, where t=max(0,i-l) and ‘l’ is the stride length. Through this kernel, each position ‘i’ attends to ‘l’ contiguous positions preceding it.
Complexity: Each position attends to at most l + 1 positions, so the total cost is O(n·l). Choosing the stride length l close to √n brings this down to O(n√n).

(ii). Aᵢ = {j: (i-j)%l = 0}, where ‘l’ is the stride length. Let’s understand this kernel through an example. If i = 83 and l = 8, then position ‘i’ attends to every position j for which i-j is a multiple of 8. For an autoregressive model (j ≤ i), this gives j = {3+(0*8), 3+(1*8), 3+(2*8), …, 3+(10*8)} = {3, 11, 19, …, 83}. That is, ‘i’ attends to positions spaced ‘l’ apart.
Complexity: The set Aᵢ for position ‘i’ contains about ⌊i/l⌋ elements. Summing over all ‘n’ words in the sequence gives ∑ᵢ₌₁ⁿ ⌊i/l⌋ ≤ (∑ᵢ₌₁ⁿ i)/l = n(n+1)/(2l). This is O(n√n) when l is close to √n.
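The two strided kernels can be written as index sets. This sketch assumes 0-indexed positions and an autoregressive (causal) model, so that j ≤ i. The function names are illustrative and are not taken from the paper's code.

```python
def strided_row(i, l):
    # Kernel (i): the previous l positions plus i itself, t = max(0, i - l)
    t = max(0, i - l)
    return list(range(t, i + 1))

def strided_column(i, l):
    # Kernel (ii): every j <= i with (i - j) % l == 0
    return [j for j in range(i + 1) if (i - j) % l == 0]

def total_connections(n, l, kernel):
    # Number of (i, j) pairs the kernel scores over a length-n sequence
    return sum(len(kernel(i, l)) for i in range(n))

print(strided_column(83, 8))
print(total_connections(1024, 32, strided_row), total_connections(1024, 32, strided_column))
```

Take n = 1024 and l = 32 (= √n). The row kernel scores 33,264 pairs and the column kernel 16,896. Dense causal attention would score n(n+1)/2 = 524,800 pairs.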

b. Fixed Attention: In this attention pattern, each position ‘i’ attends to positions in its own block and to a fixed set of positions in the sequence. Let’s understand the two kernels proposed for fixed attention.

(i). Aᵢ = {j: ⌊i/l⌋ = ⌊j/l⌋}, where ‘l’ is the stride length. Let’s take an example to understand this kernel too. If i = 83 and l = 8, then ⌊i/l⌋ = ⌊83/8⌋ = 10. So, for ⌊j/l⌋ = 10, j = {80, 81, 82, …, 87}. In autoregressive models, each position can attend only to positions preceding it. There, j = {80, 81, 82, 83} for i = 83 and j = {80, 81, …, 87} for i = 87. Through this kernel too, a position attends to contiguous positions preceding it, namely those in its own block of length ‘l’.
Complexity: The set Aᵢ for each position ‘i’ contains at most ‘l’ elements, namely the block that contains ‘i’. Summing over all ‘n’ words in the sequence gives at most n·l, which is O(n√n) when l is close to √n.

(ii). Aᵢ = {j: j%l ∈ {t, t+1, …, l}}, where t = l-c, ‘l’ is the stride length and c is a hyperparameter. Since j%l is always less than l, this amounts to the last ‘c’ positions of every block of length ‘l’. Through this kernel, a position attends to certain fixed positions, independent of its own position ‘i’. Note the absence of i in the definition of the kernel.

To understand this, let’s assume i = 130, l = 128 and c = 8. Then t = (l-c) = 120 and Aᵢ = {j: j%128 ∈ {120, 121, …, 127}}, i.e., Aᵢ = {120, 121, …, 127, 248, 249, …, 255, …}. Autoregressive or causal models attend only to preceding positions, so in our example ‘i’ attends only to positions 120 to 127. Bidirectional models can attend to future positions too.
Thus, a position attends to contiguous blocks of ‘c’ positions repeated every ‘l’ positions throughout the sequence. This form of attention resembles dilated sliding window attention, although, unlike a sliding window, the attended positions do not move with ‘i’.
Complexity: In an autoregressive model, each position ‘i’ attends to about c·⌊i/l⌋ positions. Thus, ∑ᵢ₌₁ⁿ c·⌊i/l⌋ = c·∑ᵢ₌₁ⁿ ⌊i/l⌋ ≤ c·n(n+1)/(2l). This is O(n√n) for a constant c and l close to √n.
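The two fixed-attention kernels can be sketched the same way. Positions are 0-indexed, and the `causal` flag switches between autoregressive and bidirectional models. The function names and the explicit sequence length `n` are assumptions for illustration, not the paper's code.

```python
def fixed_block(i, l, n, causal=True):
    # Kernel (i): every j in the same length-l block as i (j // l == i // l)
    start = (i // l) * l
    end = i + 1 if causal else min(start + l, n)
    return list(range(start, end))

def fixed_summary(i, l, c, n, causal=True):
    # Kernel (ii): positions whose offset within their block is in {l-c, ..., l-1};
    # this set does not depend on i except through the causal cut-off
    t = l - c
    limit = i + 1 if causal else n
    return [j for j in range(limit) if j % l >= t]

print(fixed_block(83, 8, 100))
print(fixed_summary(130, 128, 8, 256))
```

With the examples from the text, `fixed_block(83, 8, 100)` gives positions 80 to 83. `fixed_summary(130, 128, 8, 256)` gives positions 120 to 127. The bidirectional call also returns 248 to 255.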

3. Performance and Summary: Sparse Transformers performed well empirically on density estimation tasks and achieved state-of-the-art results on the CIFAR-10, enwik8 and ImageNet 64×64 datasets. The model reported 2.80 bits per dimension on CIFAR-10, 0.99 bits per byte on enwik8 and 3.44 bits per dimension on ImageNet 64×64. These results were comparable to or better than those of dense transformer models.
The paper reported lower losses than dense transformers and shorter training times due to the sparse connectivity patterns. The model also performed well on tasks like image generation and raw audio waveform generation. This showed its capability to learn what comes next in a sequence even with sparse attention patterns. However, the model did not perform as well at generating very high-resolution images and audio, which motivates the exploration of better sparsity patterns in attention matrices.

Longformer: The Long-Document Transformer (Allen Institute for AI)

1. Feasibility and Motivation: Longformer aimed to introduce a notion of global attention, where a token attends to all other tokens in the sequence. This global attention comes in addition to the local and sparse attention introduced by Sparse Transformers. The paper suggested that a combination of local and global attention could act as a “drop-in replacement” for the full self-attention matrix used by transformers.

2. Method and Explanation for Reduction in Complexity: The paper proposes the following attention mechanisms to reduce complexity to O(n).

a. Sliding Window Attention: In this mechanism, each token in the sequence attends to ‘w/2’ tokens on each side of it, ‘w’ being the window size. The window size does not remain constant across layers; it increases as we move deeper into the network. To keep the notation simple, assume that layer ‘l’ has a window size of l*w, giving deeper layers a wider local attention span.
Complexity: For a sequence of length ‘n’, the complexity of sliding window attention in layer ‘l’ is O(n*l*w), which is linear in ‘n’.
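Below is a minimal, naive pure-Python sketch of sliding window attention. It is not an efficient implementation. Each query is scored only against keys within w//2 positions on either side, so the number of scores grows with n·w rather than n². Here `w` stands for the window of a single layer, and the function name is illustrative.

```python
import math

def sliding_window_attention(Q, K, V, w):
    """Each query i attends only to keys within w // 2 positions on either side.

    Q, K: n x d_k, V: n x d_v (plain lists). Returns the outputs and the number
    of query-key scores computed, which is at most n * (w + 1) instead of n * n.
    """
    n, d_k = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d_k)
    half = w // 2
    out, n_scores = [], 0
    for i in range(n):
        # The window is clipped at the sequence boundaries and includes i itself
        window = range(max(0, i - half), min(n, i + half + 1))
        scores = [scale * sum(a * b for a, b in zip(Q[i], K[j])) for j in window]
        n_scores += len(scores)
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        out.append([sum(e * V[j][c] for e, j in zip(exps, window)) / z
                    for c in range(len(V[0]))])
    return out, n_scores
```

For n = 10 and w = 4, the function computes 44 scores instead of 100. Positions near the ends see fewer neighbours because the window is clipped.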

b. Dilated Sliding Window Attention: Gaps of size ‘d’ are introduced in the windows of each layer to increase the attention span of each token without increasing complexity. A token still attends to w/2 tokens on either side, but those tokens are no longer contiguous; they have gaps of ‘d’ between them. This increases the attention span of a token significantly, as words can now attend to far-off words. ‘d’ can be set to a different value for each layer.
Complexity: Each token still attends to the same number of tokens, so the cost is the same as for the undilated window and remains linear in ‘n’. Dilation only widens the span that each token covers. The paper notes that, with a fixed d and w for all layers, the receptive field after l layers is l*d*w.
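The sketch below treats the dilation `d` as the distance between consecutive attended positions, so d = 1 gives the ordinary contiguous window. This follows the dilated-convolution convention and matches the l*d*w receptive field; under a literal "gap of d" reading the step would be d + 1. The function name is illustrative.

```python
def dilated_window(i, n, w, d):
    # Assumption: attended positions are d apart (d = 1 is the plain sliding window)
    half = w // 2
    offsets = [d * k for k in range(-half, half + 1)]
    # Keep only positions that fall inside the sequence
    return [i + o for o in offsets if 0 <= i + o < n]

print(dilated_window(50, 100, 4, 1))
print(dilated_window(50, 100, 4, 3))
```

Both calls return five positions, so the cost per token is unchanged. The span covered grows from 4 positions (48 to 52) to 12 positions (44 to 56), i.e. d * w.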

c. Global Attention: Longformer adds global attention for only a few tokens in the sequence. This keeps the complexity linear while still letting a few important tokens attend to all tokens in the sequence. The tokens that receive global attention are chosen based on the task. For example, in question answering, the question tokens are chosen for global attention. Similarly, in classification tasks, the classification token ([CLS]) attends to all tokens in the sequence. To improve performance, the global attention operation is made symmetric: tokens chosen for global attention attend to all tokens in the sequence, and all tokens in the sequence attend to them.
Complexity: If ‘g’ tokens are chosen for global attention, this operation costs O(g*n). Only a small number of tokens is chosen, so ‘g’ is much smaller than ‘n’ and the overall complexity remains linear in ‘n’.
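The following naive sketch enumerates the (query, key) pairs scored when a sliding window is combined with symmetric global attention. The function name and the choice of position 0 as the single global token are assumptions for illustration. The number of pairs grows as O(n·(w + g)) rather than n².

```python
def longformer_pattern(n, w, global_tokens):
    """Set of (query, key) pairs for sliding window + symmetric global attention."""
    half = w // 2
    pairs = set()
    for i in range(n):
        # Local part: w // 2 neighbours on each side, clipped at the boundaries
        for j in range(max(0, i - half), min(n, i + half + 1)):
            pairs.add((i, j))
    for g in global_tokens:
        for j in range(n):
            pairs.add((g, j))  # the global token attends to every position
            pairs.add((j, g))  # every position attends to the global token
    return pairs

pairs = longformer_pattern(512, 8, [0])
print(len(pairs), 512 * 512)
```

With n = 512, w = 8 and one global token, 5,602 pairs are scored, compared with 262,144 for full attention.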

d. Linear Projections for Global Attention: Unlike transformers, which use a single set of Q, K and V projections, Longformer uses two separate sets of linear projections. One set is for sliding window attention (Qₛ, Kₛ and Vₛ) and the other for global (universal) attention (Qᵤ, Kᵤ and Vᵤ). This gives the model the flexibility to learn different types of attention patterns within the same network, which improves its performance.
Complexity: The extra projections add parameters and a constant factor of computation, but no additional asymptotic complexity; the cost remains linear in ‘n’.

3. Performance and Summary: Longformer established state-of-the-art results on the text8 and enwik8 datasets, achieving test BPC of 1.10 and 1.00 respectively. The model outperformed Transformer-XL and matched the performance of Sparse Transformers. It performed only slightly worse than, or on par with, other language models that have roughly twice as many parameters. Longformer was also fine-tuned for several downstream tasks, such as coreference resolution and document classification, and performed well on those too.

Ending Note

One might wonder what makes these models work. How are they able to perform comparably to transformers without computing the full attention matrix? The authors of both papers point to their assumptions about sparsity, or about combining local and global attention. These assumptions serve as an inductive bias that helps the models perform on par with or better than transformers.

To summarize, both models performed on par with or better than transformers while cutting down on time and memory requirements. In the next article, we will discuss two more models, Reformers and Performers. We will look at the optimizations they use to curb quadratic complexity while maintaining on-par or better performance.

Appendix

1. Dilated sliding window attention: A position attends to other positions through a sliding window, but gaps are introduced between the positions inside the window so that it can attend to farther locations. Such attention is known as dilated sliding window attention.

2. Density estimation or density modeling: Generative modeling or density modeling refers to the class of algorithms that aim to learn the distribution underlying the data, i.e., p(x; θ), where θ are the model parameters.

3. Bits per dimension / Bits per pixel (bpp) / Bits per character (bpc) / Bits per byte: These are evaluation metrics used to compare models on density estimation tasks. For models trained on image datasets we mostly use bits per pixel (or per dimension). For language models or density models trained on text data we use bits per character. Bits per byte or bits per dimension can also be used generically to report the performance of generative models.
Simply put, bits per pixel is the average number of bits needed per pixel to encode discrete input images using the model's predicted distribution. For example, a model that reports 0.99 bpp needs about 0.99 bits per pixel on average to encode the input images; a lower value indicates a better model. Similarly, bits per character is the average number of bits required to encode each character of the input. To understand this concept in detail and how these metrics relate to the loss achieved by a model, please refer to this blog and this paper.
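Here is a short sketch of how such numbers are usually obtained from a model's loss. It assumes the loss is a negative log-likelihood summed over all dimensions in nats (natural logarithm), which is common in deep learning frameworks. The function name is illustrative. Dividing by the number of dimensions and by ln 2 converts nats to bits per dimension.

```python
import math

def bits_per_dim(total_nll_nats, num_dims):
    # Summed negative log-likelihood (nats) -> average bits per dimension
    return total_nll_nats / (num_dims * math.log(2))

# A model that assigns probability 1/256 to every 8-bit value of a
# 32 x 32 x 3 image has an NLL of ln(256) nats per dimension
print(bits_per_dim(3072 * math.log(256), 3072))
```

As expected, a uniform model over 256 values needs 8 bits per dimension. A trained model scores well below that.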

4. Inductive Bias : Inductive bias refers to the set of assumptions that a model makes to learn patterns in training data and make predictions on test data. More on inductive bias can be read here.

References

1. Child, Rewon, et al. “Generating long sequences with sparse transformers.” arXiv preprint arXiv:1904.10509 (2019).
2. Beltagy, Iz, Matthew E. Peters, and Arman Cohan. “Longformer: The long-document transformer.” arXiv preprint arXiv:2004.05150 (2020).
3. Theis, Lucas, Aäron van den Oord, and Matthias Bethge. “A note on the evaluation of generative models.” arXiv preprint arXiv:1511.01844 (2015).
4. Chen, Tianqi, et al. “Training deep nets with sublinear memory cost.” arXiv preprint arXiv:1604.06174 (2016).
5. Vaswani, Ashish, et al. “Attention is all you need.” Advances in neural information processing systems. 2017.

Note: For the purpose of brevity, the links to blogs have not been repeated in the references.
