Reformers and Performers: A comprehensive summary of space and time optimizations on Transformers (Part 2)

Priya Shree
Walmart Global Tech Blog
Oct 1, 2021

In the previous article, we discussed the optimizations and heuristics used by two models, Sparse Transformers and Longformers, to overcome the quadratic time and space complexity of Transformer models. With these optimizations, the authors could not only match Transformer performance using less space and time, but also achieve better performance on certain density estimation tasks. In this article, we will look at two more such models and discuss the optimizations and techniques introduced in the following papers:

1. Reformer: The Efficient Transformer
2. Rethinking Attention with Performers

Let us discuss these models one by one. For better readability, I have divided the discussion of each model into three sections: the feasibility of and motivation for the paper; the methods used to reduce complexity, with a derivation or explanation of that reduction; and a summary of the model's performance.

Reformer: The Efficient Transformer

1. Feasibility and Motivation: The authors of this paper identified the following three compute-intensive areas in Transformers that could be optimized:

  • The O(n²) time and space complexity of the scaled dot-product (self-attention) operation, where ‘n’ is the length of the sequence.
  • High memory requirements due to the large number of layers: for deep learning models like Transformers, the activations of all layers must be kept in memory for backpropagation, so the memory cost grows in proportion to the number of layers.
  • Memory requirements of the feed-forward layers: the feed-forward layers in Transformers have a much larger depth (the number of hidden units in a layer) than the attention layers, so they consume a lot of memory for long sequences.

2. Method and Explanation for Reduction in Complexity: Let us discuss the techniques the paper uses to address the issues stated above. I will cover three of them: Locality Sensitive Hashing (LSH), Reversible Residual Networks (RevNets), and Chunking. For each, I will first explain the technique and then describe how the paper uses it.

A. Locality Sensitive Hashing (LSH)

(i) Technique: Locality Sensitive Hashing is a technique for finding or grouping similar data points. It is widely used to reduce the complexity of algorithms like k-Nearest Neighbors (kNN). Finding the nearest neighbor of a data point ‘x’ using LSH does not require comparing it to all ‘n’ available data points. Instead, ‘x’ is compared only to the few data points that fall in the same “bucket” or “region” as ‘x’. This reduces the complexity of finding the nearest neighbor of a single data point from O(n) to O(number of points in the same hash bucket as ‘x’). Let’s look at LSH in a little more detail.
In LSH, we divide the d-dimensional space containing all the data points using ‘k’ random hyperplanes, which together define a hash function. Each hyperplane splits the space in two, so the space is divided into roughly 2ᵏ regions, or buckets, into which the data points fall. We want similar data points to land in the same bucket. However, two similar points can still be separated by one of the random hyperplanes and end up in different buckets. To reduce this chance, we don’t partition the space just once, but ‘r’ times, with a fresh set of hyperplanes each time, and treat two points as candidates if they share a bucket in any of the ‘r’ partitions.

(ii) Complexity of LSH: To find the nearest neighbor of a point ‘x’ using LSH, we compare the d-dimensional point ‘x’ with all ‘k’ hyperplanes to decide the region in which it lands, which costs O(dk). We then compare ‘x’ to all points in that region. On average, a region contains n/2ᵏ points, so this step costs O(number of points in a region × cost of one comparison) = O((n/2ᵏ) · d). Since we repeat this process ‘r’ times, the overall complexity is O(rdk) + O(r · (n/2ᵏ) · d). If we choose k = log₂(n), each region holds about one point on average, and the cost becomes O(rd · log(n)). Treating r and d as constants, finding the nearest neighbor of a single point costs O(log(n)), an improvement over the O(n) complexity of vanilla kNN.
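To make the bucketing and the ‘r’ repetitions concrete, here is a minimal sketch of random-hyperplane LSH for approximate nearest-neighbor search in plain Python. It assumes roughly centered data (the hyperplanes pass through the origin), and all function names (`build_tables`, `lsh_nearest`, etc.) are hypothetical, not from any library. Hashing costs O(dk) per table, and only the points sharing a bucket are compared exactly.

```python
import math
import random


def make_hyperplanes(k, d, rng):
    """Draw k random hyperplanes through the origin, each given by a Gaussian normal vector."""
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(k)]


def hash_point(x, planes):
    """Return a k-bit bucket id; bit i records which side of hyperplane i the point x lies on."""
    bucket = 0
    for i, normal in enumerate(planes):
        if sum(w * v for w, v in zip(normal, x)) >= 0.0:
            bucket |= 1 << i
    return bucket


def build_tables(points, k, r, rng):
    """Partition the space r independent times; each table maps bucket id -> point indices."""
    d = len(points[0])
    tables = []
    for _ in range(r):
        planes = make_hyperplanes(k, d, rng)
        buckets = {}
        for idx, p in enumerate(points):
            buckets.setdefault(hash_point(p, planes), []).append(idx)
        tables.append((planes, buckets))
    return tables


def lsh_nearest(x, points, tables):
    """Compare x only with points that share its bucket in at least one of the r tables."""
    candidates = set()
    for planes, buckets in tables:
        candidates.update(buckets.get(hash_point(x, planes), []))
    if not candidates:
        return None  # no candidate found; a real system might fall back to a wider search
    return min(candidates, key=lambda i: math.dist(x, points[i]))


if __name__ == "__main__":
    rng = random.Random(0)
    points = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(1024)]
    k = int(math.log2(len(points)))  # k = log2(n): about one point per bucket on average
    tables = build_tables(points, k=k, r=8, rng=rng)
    query = [v + 0.01 for v in points[42]]  # a slightly shifted copy of point 42
    print(lsh_nearest(query, points, tables))
```

Using several independent tables (r > 1) is what protects against a similar point being cut off by a single unlucky hyperplane.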

(iii) Locality Sensitive Hashing in Reformers: The quadratic complexity of Transformers comes from comparing each query to all the keys in the sequence, i.e., each word to every other word. Reformers overcome this by comparing each query only to a few similar keys, which the authors identify using LSH. This reduces the complexity of the attention operation from O(n²) to O(n log(n)), where ‘n’ is the length of the sequence. To make LSH attention work well and run efficiently in Transformers, the authors use the following two heuristics:

  • Shared-QK attention: The keys and queries are made identical for LSH attention by using the same linear layer to obtain the query and key vectors, and a separate layer for the value vectors. This is needed because queries in a hash bucket attend only to keys in that same bucket. If queries and keys were computed separately, a bucket could end up with many queries but no keys; sharing the vectors guarantees that every query lands in the same bucket as its own key.
  • Chunking: LSH narrows down the set of points that each point attends to. However, points landing in the same bucket might be far apart in the sequence, and locating them before performing attention adds overhead. To overcome this, the authors sort the queries by bucket number and, within each bucket, by position in the sequence. After sorting, the queries are divided into chunks of size ‘m’. A query ‘q’ belonging to hash bucket ‘b’ and chunk ‘c’ can attend only to:
    - queries in the same chunk ‘c’ that belong to hash bucket ‘b’, and
    - queries in the preceding chunk ‘c-1’ that belong to hash bucket ‘b’.
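The two heuristics can be sketched together in plain Python. The bucket function follows the random-rotation scheme described in the Reformer paper, h(x) = argmax([xR ; -xR]); everything else (function names, a single hashing round, no causal mask, and no special handling of a position attending to itself) is a simplifying assumption of this sketch, not the paper's implementation.

```python
import random


def rotation_hash(x, R):
    """Angular LSH as described in the Reformer paper: bucket = argmax([xR ; -xR]).
    R is a d x (n_buckets / 2) random Gaussian matrix."""
    proj = [sum(x[i] * R[i][j] for i in range(len(x))) for j in range(len(R[0]))]
    scores = proj + [-p for p in proj]
    return max(range(len(scores)), key=scores.__getitem__)


def lsh_attention_pairs(qk, n_buckets, chunk_size, rng):
    """For each position, list the positions it may attend to under bucketed, chunked LSH attention.
    qk holds the shared query/key vectors (Shared-QK), so queries and keys hash identically."""
    d = len(qk[0])
    R = [[rng.gauss(0.0, 1.0) for _ in range(n_buckets // 2)] for _ in range(d)]
    buckets = [rotation_hash(v, R) for v in qk]
    # Sort by bucket, then by sequence position inside each bucket: the O(n log n) step.
    order = sorted(range(len(qk)), key=lambda i: (buckets[i], i))
    chunks = [order[s:s + chunk_size] for s in range(0, len(order), chunk_size)]
    allowed = {}
    for c, chunk in enumerate(chunks):
        # A query sees its own chunk and the preceding one, restricted to its own bucket.
        window = chunk + (chunks[c - 1] if c > 0 else [])
        for i in chunk:
            allowed[i] = sorted(j for j in window if buckets[j] == buckets[i])
    return buckets, allowed


if __name__ == "__main__":
    rng = random.Random(1)
    qk = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(16)]
    buckets, allowed = lsh_attention_pairs(qk, n_buckets=4, chunk_size=4, rng=rng)
    print(buckets)
    print(allowed[0])
```

Each position ends up with at most 2 × chunk_size candidates, all from its own bucket, which is what keeps the attention cost linear in n once the positions are sorted.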

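(iv) Reduction in Complexity of Reformers due to LSH: Each query attends only to a limited set of queries (or keys) in the sequence: at most 2m positions from its own chunk and the preceding one. Sorting the positions by bucket costs O(n log(n)), and the chunked attention itself costs O(n · m), so for a fixed chunk size the complexity of attention drops from O(n²) to O(n log(n)), where n is the length of the sequence.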

B. Reversible Residual Networks (RevNets)

As discussed earlier in this article, deep learning models with many layers require more memory, because the activations of all layers need to be kept in memory for backpropagation. In the previous article, we discussed a technique called Gradient Checkpointing, which counters this to some extent by reducing the memory requirement of a model with ‘L’ layers from O(L) to O(√L). Reformers use another technique, RevNets, which makes the activation memory of a model independent of the number of layers. Let’s understand how below.

(i) Technique: RevNets restructure forward propagation so that the activations of layer ‘l’ can be reconstructed exactly from the activations of layer ‘l+1’ alone during backpropagation. Thus, to perform backpropagation on a deep network, only the last layer’s activations need to be stored. From them, we can reconstruct the activations of all preceding layers, one layer at a time, at the cost of roughly one additional forward computation. This makes the memory (space) complexity of Reformers independent of the number of layers in the model. To achieve this, RevNets partition the input ‘x’ into two groups, ‘x₁’ and ‘x₂’. There are several ways to partition the input; more about this can be found in the original RevNet paper. Each reversible block in a RevNet takes inputs (x₁, x₂) and produces outputs (y₁, y₂) using the following equations:

$$y_1 = x_1 + F(x_2) \tag{1}$$
$$y_2 = x_2 + G(y_1) \tag{2}$$

RevNets can then reconstruct the inputs from the outputs using the following equations:

$$x_2 = y_2 - G(y_1) \tag{3}$$
$$x_1 = y_1 - F(x_2) \tag{4}$$

Notice that x₂ is reconstructed before x₁, because reconstructing x₁ depends on x₂. Thus, with RevNets, the activations of all layers need not be in memory at once; they can be reconstructed one layer at a time as backpropagation needs them.
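A minimal plain-Python sketch of a reversible block shows why only the final outputs need to be stored. `F` and `G` here are hypothetical toy stand-ins for the attention and feed-forward sublayers; note that neither of them has to be invertible for the block as a whole to be invertible.

```python
import math


def rev_forward(x1, x2, F, G):
    """One reversible block: y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = [a + b for a, b in zip(x1, F(x2))]
    y2 = [a + b for a, b in zip(x2, G(y1))]
    return y1, y2


def rev_inverse(y1, y2, F, G):
    """Invert the block: x2 = y2 - G(y1) first, then x1 = y1 - F(x2)."""
    x2 = [a - b for a, b in zip(y2, G(y1))]
    x1 = [a - b for a, b in zip(y1, F(x2))]
    return x1, x2


def F(v):
    """Hypothetical stand-in for the attention sublayer; it need not be invertible."""
    return [math.tanh(t) for t in v]


def G(v):
    """Hypothetical stand-in for the feed-forward sublayer; it need not be invertible either."""
    return [0.5 * math.sin(t) + 0.1 * t for t in v]


if __name__ == "__main__":
    x1, x2 = [0.3, -1.2, 0.7], [1.0, 0.5, -0.4]
    y1, y2 = x1, x2
    for _ in range(6):  # forward through six blocks, keeping only the latest outputs
        y1, y2 = rev_forward(y1, y2, F, G)
    for _ in range(6):  # walk backward, rebuilding each block's inputs from its outputs
        y1, y2 = rev_inverse(y1, y2, F, G)
    print(y1, y2)  # equal to x1, x2 up to floating-point rounding
```

In a real network the backward pass would interleave this reconstruction with the gradient computation for each block, rather than reconstructing everything first.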

(ii) RevNets in Reformers: RevNets are used in Transformers by combining the attention and feed-forward layers into a single reversible block. Layer normalization is also moved inside the reversible block. In equations 1 and 2, F represents the attention layer and G represents the feed-forward layer.

(iii) Reduction in Complexity of Reformers due to RevNets: Using RevNets makes the activation memory of Reformers independent of the number of layers, which is a significant optimization.

C. Chunking

(i) Technique: Feed-forward layers in Transformers have a large number of hidden units, which consume a lot of memory. With state-of-the-art Transformers processing very long inputs as a single sequence, the computations in the feed-forward layers become very costly. To counter this, Reformers exploit the fact that the feed-forward layer processes each position independently: instead of applying it to the entire sequence at once, they split the positions into chunks and apply it to one chunk at a time. The model also chunks the reverse computations and the backward pass.

(ii) Reduction in Complexity of Reformers due to Chunking: Chunking reduces memory requirements considerably, since the feed-forward activations of only one chunk occupy memory at a time instead of those of the entire sequence.
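Because the feed-forward layer treats each position independently, applying it chunk by chunk gives exactly the same result as applying it to the whole sequence. The plain-Python sketch below illustrates this; the names are hypothetical, biases are omitted, and lists stand in for the tensors a real framework would use.

```python
import random


def matmul(A, B):
    """Plain-Python matrix product of A (n x k) and B (k x m)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def ffn(X, W1, W2):
    """Position-wise feed-forward layer relu(X W1) W2 on a block of positions X (one row each).
    The hidden matrix H holds len(X) rows of d_ff values at once."""
    H = [[max(0.0, h) for h in row] for row in matmul(X, W1)]
    return matmul(H, W2)


def ffn_chunked(X, W1, W2, chunk_size):
    """Apply ffn to chunk_size positions at a time, so H never holds more than chunk_size rows.
    This is valid because each output row depends only on the matching input row."""
    out = []
    for start in range(0, len(X), chunk_size):
        out.extend(ffn(X[start:start + chunk_size], W1, W2))
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    d_model, d_ff, n = 4, 16, 10
    W1 = [[rng.gauss(0.0, 1.0) for _ in range(d_ff)] for _ in range(d_model)]
    W2 = [[rng.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(d_ff)]
    X = [[rng.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(n)]
    print(ffn(X, W1, W2) == ffn_chunked(X, W1, W2, chunk_size=3))  # True
```

The peak size of the hidden activations drops from n × d_ff to chunk_size × d_ff, at the cost of running the layer in several smaller steps.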

3. Performance and Summary: The paper reports that Reformers perform comparably to Transformers on density estimation and machine translation tasks while using far less memory and running faster on long sequences. Techniques like shared query-key attention and reversible layers for recomputing activations do not compromise metrics like accuracy or perplexity; instead, they help Reformers match Transformer performance with less time and space. The sparsity assumption that Reformers make through LSH may not be sufficient for extremely long sequences, and Reformers might need more heuristics to achieve state-of-the-art performance on them. Nevertheless, Reformers provide a significant optimisation over Transformers.

Rethinking Attention with Performers

1. Feasibility and Motivation: The models we have discussed so far approximate the attention matrix computed by Transformers by assuming that it is sparse or low-rank. Although these approximations reduce the complexity of Transformers, they are biased, because they depend on priors and assumptions. Performers overcome this limitation with an algorithm that produces an unbiased or nearly-unbiased estimate of the attention matrix with low variance, without any assumption of sparsity. In fact, the paper states that models based on sparsity assumptions do not approximate attention as implemented by Transformers; instead, they propose different, simpler attention mechanisms. Let’s discuss the algorithm used by Performers in the next section.

2. Method and Explanation for Reduction in Complexity: Performers approximate attention not by assuming sparsity, but by decomposing the attention matrix. However, the attention operation in Transformers applies a softmax non-linearity over the dot product of query and key vectors, which makes the attention matrix hard to decompose. This is why most models resort to sparsity assumptions to approximate attention. In this paper, the authors propose a novel approach called Fast Attention Via positive Orthogonal Random features (FAVOR+), with which they approximate softmax kernels (and thereby decompose the attention matrix) in an unbiased or nearly-unbiased manner, in linear time and with low variance.

The paper discusses the formulations and theoretical guarantees of these approximations in depth, with many mathematical equations. However, I’ll touch only on the concepts, without delving into much math. I will discuss the FAVOR+ algorithm in two subsections: the first explains the Fast Attention part of FAVOR+, and the second explains the Positive Orthogonal Random Features part.

A. Fast Approximation of the Attention Matrix in FAVOR+

Let’s first review the kernel trick, which is the premise for understanding fast approximation of attention.

Kernel Trick: Many models use the kernel trick to compute the dot product of two vectors x and y in a high-dimensional space without explicitly mapping x and y into that space. Instead, they use a function, called a kernel function, that computes the dot product of the mapped vectors directly from x and y, i.e.
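
$$K(x, y) = \langle \phi(x), \phi(y) \rangle \tag{5}$$

where φ denotes the (implicit) mapping into the high-dimensional space.
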
Thus, by applying the appropriate kernel function to x and y, we obtain their dot product in the high-dimensional space without ever computing their high-dimensional representations.
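A standard textbook illustration (not taken from the paper) is the degree-2 polynomial kernel K(x, y) = (x · y)², whose explicit feature map sends a d-dimensional vector to all d² pairwise products. The plain-Python sketch below, with hypothetical names, checks that both routes give the same number.

```python
from itertools import product


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def poly_kernel(x, y):
    """Degree-2 polynomial kernel K(x, y) = (x . y)^2, evaluated in the original d dimensions."""
    return dot(x, y) ** 2


def phi(x):
    """Explicit feature map into d*d dimensions: all ordered products x_i * x_j."""
    return [a * b for a, b in product(x, repeat=2)]


if __name__ == "__main__":
    x, y = [1, 2, 3], [4, 5, 6]
    print(poly_kernel(x, y))    # 1024, using only a 3-dimensional dot product
    print(dot(phi(x), phi(y)))  # 1024, after mapping both vectors into 9 dimensions
```

The kernel route costs O(d) while the explicit route costs O(d²); FAVOR+, discussed next, goes in the opposite direction and builds an explicit (approximate) feature map on purpose.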

Let’s now move on to the fast approximation of attention.

Fast Approximation: For Transformers, attention A is given by

$$A = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V \tag{6}$$

where Q, K and V are the query, key, and value matrices of dimension batch_size × sequence_length × model_dimension (BS × N × d). If we ignore the normalization terms, each entry of the softmax term in the attention equation can be written as
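
$$\exp\left(\frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d}}\right) = K_{\mathrm{SM}}(\mathbf{q}_i, \mathbf{k}_j) \approx \phi(\mathbf{q}_i)^\top \phi(\mathbf{k}_j) \tag{7}$$

where qᵢ is the i-th query, kⱼ is the j-th key, and K_SM is the softmax kernel.
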
Whereas the kernel trick uses a kernel function to avoid computing the mapping of vectors into another space (the right-hand side of the equation), FAVOR+ does the opposite: it approximates the mapping into that other space in order to estimate the kernel. FAVOR+ is a generalized framework for kernelizable attention, since it can approximate not only the softmax-attention kernel but also other attention kernels. This removes the restriction to softmax attention kernels in Transformers and opens avenues for experimenting with other attention kernels, such as ReLU, which can also be approximated using FAVOR+.

Reduction in Complexity of Performers due to FAVOR+ (Fast Approximation): FAVOR+ decomposes the softmax-attention kernel given in equation 7 into matrices Q' and K' of dimensions (BS × N × r), where ‘r’ is the number of random features, i.e., the dimension of the new feature space. After this decomposition, we can reframe the attention equation as follows:
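
$$A \approx \hat{D}^{-1}\left(Q'\left((K')^\top V\right)\right), \qquad \hat{D} = \mathrm{diag}\left(Q'\left((K')^\top \mathbf{1}_N\right)\right) \tag{8}$$

where row i of Q' is φ(qᵢ), row j of K' is φ(kⱼ), and 1_N is the all-ones vector of length N.
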
This reformulation lets us multiply (K')ᵀ and V first and only then multiply the result by Q', so the computation drops from O(N² · d) in regular softmax attention to O(N · r · d) in Performers, and no N × N matrix is ever stored. Thus, Performers can approximate softmax attention in linear time without any sparsity assumptions or priors.
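The effect of reordering the matrix products can be checked with a small plain-Python sketch. Here `Qp` and `Kp` are hypothetical positive stand-ins for Q' and K' (random numbers rather than actual FAVOR+ features); the point is only that both orders give the same normalized attention output, while the linear order never builds the N × N matrix.

```python
import random


def matmul(A, B):
    """Plain-Python matrix product of A (n x k) and B (k x m)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def attention_quadratic(Qp, Kp, V):
    """Form the full N x N matrix A = Q' K'^T, then return D^-1 A V (D = row sums of A).
    Cost: O(N^2 r) for A plus O(N^2 d) for A V."""
    A = matmul(Qp, transpose(Kp))
    AV = matmul(A, V)
    return [[v / sum(a_row) for v in av_row] for a_row, av_row in zip(A, AV)]


def attention_linear(Qp, Kp, V):
    """Same result, computed as Q' (K'^T V) with normalizer Q' (K'^T 1): O(N r d), no N x N matrix."""
    KtV = matmul(transpose(Kp), V)                                # r x d
    k_sum = [sum(col) for col in zip(*Kp)]                        # K'^T 1_N, length r
    num = matmul(Qp, KtV)                                         # N x d
    den = [sum(q * s for q, s in zip(row, k_sum)) for row in Qp]  # diagonal of D
    return [[v / z for v in n_row] for n_row, z in zip(num, den)]


if __name__ == "__main__":
    rng = random.Random(0)
    N, r, d = 6, 3, 2
    # Positive stand-ins for the feature maps Q' = phi(Q) and K' = phi(K).
    Qp = [[rng.random() + 0.1 for _ in range(r)] for _ in range(N)]
    Kp = [[rng.random() + 0.1 for _ in range(r)] for _ in range(N)]
    V = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(N)]
    print(attention_quadratic(Qp, Kp, V)[0])
    print(attention_linear(Qp, Kp, V)[0])  # same values up to floating-point rounding
```

This sketch covers only bidirectional (non-causal) attention; a causal variant would need prefix sums over the key-value products instead of a single K'ᵀV.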

B. Positive Orthogonal Random Features in FAVOR+

FAVOR+ approximates kernels with feature maps that send an input x to φ(x). Each entry of φ(x) is a simple function of the projection of x onto a randomly drawn vector; these entries are called random features. To approximate the softmax-attention kernel in Transformers, FAVOR+ specifically uses random features that are positive and built from orthogonal random vectors. The following points summarize why positive and orthogonal random features are essential for estimating the softmax kernel:

  • Using Positive Random Features: The softmax kernel always produces positive attention values, so it is important to decompose it with feature maps whose outputs are positive. Moreover, a data point often attends to very few other points in the sequence, so many entries of the attention matrix are close to zero. Feature maps that can produce negative values, such as the trigonometric (sin/cos) random features used in earlier random-feature methods, give high-variance and unstable estimates for these near-zero entries, sometimes even negative ones. Thus, positive random features are essential for approximating the softmax kernel robustly.
  • Using Orthogonal Features: To estimate the softmax kernel, the paper draws the random vectors from an isotropic Gaussian distribution and makes them exactly orthogonal to each other. Exact orthogonality reduces the variance of the approximated kernel. Lower variance in turn means that fewer random features, i.e., a smaller dimension r of the new space, are needed for a given accuracy.
    The paper provides theoretical proofs and results showing that these choices help FAVOR+ obtain an unbiased or nearly-unbiased estimate of the attention kernel with lower variance.
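A plain-Python sketch of positive random features for the softmax kernel exp(x · y): φ(x) = exp(-‖x‖²/2)/√m · [exp(w₁ · x), …, exp(w_m · x)] with Gaussian wᵢ, whose expected dot product is exactly exp(x · y). The orthogonal variant orthonormalizes each block of d vectors with Gram-Schmidt and rescales them with Gaussian-like norms, which is one common way to build such features and is an assumption of this sketch. The function names are hypothetical, and a real implementation would also need numerical-stability measures this sketch omits.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def gaussian_vec(d, rng):
    return [rng.gauss(0.0, 1.0) for _ in range(d)]


def iid_features(m, d, rng):
    """m independent random vectors w_i ~ N(0, I_d)."""
    return [gaussian_vec(d, rng) for _ in range(m)]


def orthogonal_features(m, d, rng):
    """m random vectors, exactly orthogonal within each block of d. Each direction is
    orthonormalized with Gram-Schmidt and then rescaled by the norm of a fresh Gaussian
    vector, so every w_i keeps the length distribution of an N(0, I_d) sample."""
    ws = []
    while len(ws) < m:
        block = []
        for _ in range(d):
            v = gaussian_vec(d, rng)
            for u in block:  # remove components along directions already in the block
                p = dot(v, u)
                v = [a - p * b for a, b in zip(v, u)]
            norm = math.sqrt(dot(v, v))
            block.append([a / norm for a in v])
        for u in block:
            g = gaussian_vec(d, rng)
            length = math.sqrt(dot(g, g))
            ws.append([length * a for a in u])
    return ws[:m]


def positive_features(x, ws):
    """phi(x) = exp(-|x|^2 / 2) / sqrt(m) * [exp(w_1 . x), ..., exp(w_m . x)]; every entry is > 0."""
    scale = math.exp(-0.5 * dot(x, x)) / math.sqrt(len(ws))
    return [scale * math.exp(dot(w, x)) for w in ws]


def softmax_kernel_estimate(x, y, ws):
    """Random-feature estimate of exp(x . y) as phi(x) . phi(y)."""
    return dot(positive_features(x, ws), positive_features(y, ws))


if __name__ == "__main__":
    rng = random.Random(0)
    x, y = [0.3, -0.2, 0.1, 0.4], [0.1, 0.2, -0.3, 0.2]
    print("exact     ", math.exp(dot(x, y)))
    print("iid       ", softmax_kernel_estimate(x, y, iid_features(1024, 4, rng)))
    print("orthogonal", softmax_kernel_estimate(x, y, orthogonal_features(1024, 4, rng)))
```

Since every feature is a scaled exponential, the estimate can never be negative, which is the property the positive-features argument relies on.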

3. Performance and Summary: The capabilities of the Performer model can be summarized as follows:

  • Performers are fully compatible with Transformer architectures, since they modify only the attention computation and leave all other blocks and components unchanged. Because of this, Performers can reuse the pretrained weights of Transformers; the paper reports that a small amount of fine-tuning quickly recovers the original accuracy.
  • Performers reduce the time and memory requirements of Transformers significantly, from quadratic to linear in the sequence length.
  • The reduced memory and time complexity allows Performers to be trained on longer sequences and with larger batch sizes, and on less powerful machines. This removes the need for very expensive compute when training on documents with very long sequences.
  • Performers achieve performance comparable to Reformers and Linformers on common datasets like ImageNet64 with fewer layers. The authors attribute this to Performers approximating the complete attention matrix without sparsity assumptions. Performers are also faster to train than other Transformer optimizations such as Reformers and Linformers.

The ability of Performers to train with fewer resources and lower complexity has opened avenues for training deep Transformer models on large text corpora, making them even more powerful and efficient. Since their linear complexity lets Performers train on longer sequences, they can be used for complex tasks such as protein sequence modelling. The paper also mentions other benefits of Performers: they are more environmentally friendly due to their lower compute requirements, they expedite research on Transformers thanks to their lower complexity, and they are backward compatible with Transformers, since they can be used for inference on pretrained Transformer models.

Ending Note

With this, we conclude our discussion of optimisations of Transformer models. As mentioned at the beginning of this series, its purpose was to learn about several techniques that can be leveraged when optimising models, and to appreciate the efforts and heuristics of the authors of the papers discussed. There are other optimisations of Transformers that I haven’t covered in this series and that are worth exploring, such as Routing Transformers, BigBird and Linformers. Contributions like these play a significant role in the development of efficient and powerful models in deep learning and NLP.

References

1. Kitaev, Nikita, Łukasz Kaiser, and Anselm Levskaya. “Reformer: The efficient transformer.” arXiv preprint arXiv:2001.04451 (2020).
2. Choromanski, Krzysztof, et al. “Rethinking attention with performers.” arXiv preprint arXiv:2009.14794 (2020).
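3. Google AI Blog. “Reformer: The Efficient Transformer” (2020). https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html
4. Google AI Blog. “Rethinking Attention with Performers” (2020). https://ai.googleblog.com/2020/10/rethinking-attention-with-performers.html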

Link to Previous Article: https://medium.com/walmartglobaltech/sparse-transformers-and-longformers-a-comprehensive-summary-of-space-and-time-optimizations-on-4caa5c388693
