Retentive Networks (RetNet) Explained: The much-awaited Transformers-killer is here

Shantanu Chandra · Published in AI FUSION LABS · 17 min read · Aug 4, 2023

Transformers have become the de-facto architecture for LLMs, as they efficiently overcome the sequential training issues of recurrent neural networks (RNNs). However, Transformers are not perfect either, as they solve for just two arms of the so-called "impossible triangle". Well, RetNet from Microsoft claims to sit right at the dead center of this impossible triangle, trumping all the methods that have tried but failed to achieve this feat. Just to emphasize the magnitude of the breakthrough we are talking about here:

· RetNet has BETTER language modeling performance

· 3.4x lower memory consumption

· 8.4x higher throughput

· 15.6x lower latency

These are orders-of-magnitude improvements over the current SOTA while also giving better performance! This is going to be huge if other teams can reproduce it and it hits the open-source arena, but for now Microsoft is definitely sitting on a golden egg!

Fig 1: RetNet makes the "impossible triangle" possible, achieving training parallelism, good performance, and low inference cost simultaneously.

But the question begging to be asked is: what makes it so great? We will uncover the answer in this blog post. We will slice open each equation to dive deeper and visualize what is happening. We will walk through RetNet with a worked-out example to see how it dethrones Transformers and shows huge promise to be the new king of the North. Buckle up and put your thinking (read: basic algebra) hats on. Lots of really interesting stuff coming your way!

Motivation — the impossible triangle

The "impossible triangle" represents the fact that current sequence models/LLMs fail to achieve all three desired properties simultaneously: training parallelism, low-cost inference, and strong performance. The methods placed along each arm achieve the two properties at its endpoints while missing the desired property of the opposite vertex. RetNet, however, manages to achieve all three under a single framework.

Fig 2: RetNet achieves training parallelization, constant inference cost, linear long-sequence memory complexity, and good performance

Let us understand this in more detail, as this is the core motivation for developing the architecture.

Training Parallelism

As the name suggests, RNNs process the sequence recurrently, i.e., one token after the other in order. The processing of the input at a given time step depends on the hidden state of the previous time step, and hence cannot be computed until all the previous steps have been processed, ruling out parallelism. This slows down training significantly.

Since Transformers deploy the highly parallelizable self-attention mechanism, the outputs at all time steps can be processed in parallel using the Q, K, V matrices. However, the very self-attention that helps Transformers parallelize so well on GPUs becomes their greatest enemy at inference time, as we will see later.

RetNet borrows the best of both worlds, as it is equipped with three processing paradigms — parallel training, recurrent inference, and chunk-wise recurrent inference. It adopts the parallelizable self-attention-style mechanism of Transformers, albeit with some very neat tricks that help it ditch self-attention's shortcomings (more on this later)!

Inference Cost + Memory Complexity

Inference cost (per time step) refers to GPU memory, throughput, and latency, while memory complexity refers to how the memory footprint scales with sequence length. Since RNNs use simple, cheap operations (essentially just matrix multiplications against a fixed-size hidden state), their per-step inference cost does not scale with sequence length but stays constant (i.e., O(1)). At the same time, their memory complexity scales linearly with sequence length.

On the other hand, since Transformers use self-attention blocks, they need to maintain the NxN attention matrix at inference time; as you can see, this makes the per-step inference cost scale linearly (O(N)) and the memory complexity scale quadratically (O(N²)) with sequence length.

Fig 3: RetNet achieves orders of magnitude better inference cost than transformers. Results of inference cost are reported with 8k as input length (image from original paper)

While RetNet uses Transformer-style attention blocks to parallelize training and achieve state-of-the-art performance, it does not suffer from the aforementioned inference cost and memory complexity issues. This is because it replaces the self-attention module with a tweaked version, the retention module, and uses a recurrent inference paradigm that mimics self-attention at inference time.

Performance

The main advantage of Transformers over RNNs was their ability to process longer sequences without catastrophic forgetting thanks to their self-attention heads. RetNet achieves similar or better performance than Transformers and we will soon see how.

RetNet: overview

The main contributions of RetNet can be summed up in two broad points. However, the beauty lies in the details of how they get from point A to point B, as we will discuss subsequently:

1. RetNet introduces a multi-scale retention mechanism to substitute multi-head attention. This is the key to doing away with the one component of the self-attention mechanism that was the devil. That said, this retention mechanism has a minor theoretical drawback that I could think of, based on an assumption the authors make. Hope someone can confirm/clarify that!

2. RetNet works with three computation paradigms, compared to just one for Transformers, which use the same sequence-processing paradigm during both training and inference.

a. Parallel representation empowers training parallelism to utilize GPU devices fully.

b. Recurrent representation enables efficient O(1) inference in terms of memory and computation. The deployment cost and latency can be significantly reduced. Moreover, the implementation is greatly simplified without key-value cache tricks.

c. Chunk-wise recurrent representation can perform efficient long-sequence modeling: each local block is encoded in parallel for computation speed, while the global blocks are encoded recurrently to save GPU memory.

RetNet vs Transformers

RetNet proposes to leverage the best of both worlds and shows how we can make this work. It uses the parallelizable training paradigm of Transformers instead of the inefficient and slow auto-regressive steps of RNNs. We will see in a minute the subtle but neat changes it makes to the original self-attention computation to replace it with retention during training. At inference time, however, RetNet smoothly adopts the more memory- and compute-efficient recurrent paradigm of RNNs, thanks to the retention mechanism that replaces self-attention.

Step 1: Parallel Representations for training

RetNet deploys the parallel representation learning of the original Transformer during training to move away from the restrictive auto-regressive sequence processing of RNNs. However, it makes a few changes to the overall process. Let's see if we can spot them in the summary diagrams that I drew below:

Fig 4: Transformer computation on the left and RetNet on the right. Can you spot the difference?

We can see that RetNet ditches the softmax operation in favor of a Hadamard product with a newly introduced D-matrix, followed by a GroupNorm operation. Isn't that strange? The softmax operation was the entire basis of self-attention, from which Transformers derived their state-of-the-art performance: softmax assigns relative attention weights to each token in the input sequence, which helps the model learn and retain long-term dependencies. However, if you recall, this softmax computation is also the exact reason for the poor inference-time performance of Transformers, since they have to keep softmax(Q.KT) in memory, an NxN matrix that grows quadratically with the sequence length! The one thing that gives Transformers their superior edge during training and on downstream performance is their biggest enemy during inference!

Many previous works have tried to circumvent this step by introducing ways to approximate the softmax operation, but the resulting architectures end up suffering in performance. Then RetNet comes in with its magical D-matrix and GroupNorm, performs on par with or better than Transformers, is orders of magnitude faster and more memory efficient during inference, AND can still be trained efficiently in parallel!

So what is this D-matrix + GroupNorm operation and how does it help?

We will explore this in detail in the coming sections, but the short answer, according to me, is (please correct me if my reasoning is wrong or incomplete):

Softmax in Transformers, in my view, achieved two objectives:

1. Weighting the different time steps differently. This helped the model "attend" to different parts of the sequence and pick up on the right signals, and was one of the important contributors to Transformers' superior performance over RNNs. The proposed D-matrix takes care of this part, but with a limiting assumption (in my opinion). The D-matrix is essentially a causal mask with a pre-defined weighting scheme baked in. Specifically, it prevents each time step from attending to future steps, and at the same time it relatively weighs all the previous time steps, BUT in a pre-defined exponential way. The D-matrix assumes that more recent time steps are exponentially more important than older ones, and thus applies an exponentially decaying weight to previous steps. So while softmax was flexible enough to weigh different steps differently, the D-matrix weighs all of them in a fixed, pre-defined way (exponential decay). While this is intuitive and may even be true for MOST sequential cases, it is still not as flexible as softmax. But the trade-off is efficient O(1) inference and O(N) memory complexity. And looking at the results, it seems this is indeed a very good approximation of the softmax operation in real-world use cases! Interesting!

2. Introducing non-linearity. In the absence of softmax, the Q.KT operation is just an affine transformation, which would vastly restrict the model's learning ability no matter how many of these layers you stack. The GroupNorm operation introduces the much-needed non-linearity. Why GroupNorm specifically? Well, I guess the authors tried different things and this worked best? I still do not have a solid explanation for this specific choice. Let me know your thoughts!

The new Retention mechanism — details

The retention mechanism is in essence an amalgamation of the core principles of RNNs and Transformers: REcurrent + self-attenTION = RETENTION

Let us now look at the differences and similarities between RetNet and Transformers, as seen in Fig 4, in more detail.

If you recall, the original Transformer outputs are generated by first applying affine transformations to the input embedding X with the WQ, WK and WV matrices, then taking the softmax of the resultant (Q.KT), and finally multiplying the result with the V matrix. It looks something like this, where O is the output matrix containing the contextualized embeddings of the input matrix X:

Equation 1: The well-known Transformer self-attention computation. Nothing new here
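Written out, that computation is simply:

$$O = \mathrm{softmax}\big(QK^{\top}\big)\,V, \qquad Q = XW_Q,\quad K = XW_K,\quad V = XW_V$$

(the usual $1/\sqrt{d}$ scaling inside the softmax is omitted here for brevity).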

Since RetNet operates in both recurrent and parallelized paradigms, the authors first motivate the RetNet "retention" block in a recurrent setting (i.e., processing each n-th input element individually). They then vectorize the proposed recurrent retention block. Thus, the initial recurrent formulation looks something like this:

Equation 2: RetNet retention computation (recurrent). Refers to Eq1 of the original paper

We can clearly see that this looks very similar to the original Transformer formulation, albeit with a few changes. We see that the softmax has been replaced with a positional embedding term (pos). What exactly is this pos concept and what is it doing? Keep those questions on hold just a little longer; we will dive deeper into them very soon. For now, what is important to note is that RetNet replaces the softmax of the original Transformer with the pos term. The above equation can be expanded as follows to be more informative about what pos is doing:

Equation 3: RetNet retention computation expanded with the pos and pos’ definitions. Refers to Eq3 of the original paper
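For reference, the corresponding formula in the paper (its Eq. 3) reads roughly as follows, with the pos/pos' terms arising from splitting the exponential between the Q and K projections:

$$o_n = \sum_{m=1}^{n} Q_n \big(\gamma\, e^{i\theta}\big)^{\,n-m} K_m^{\dagger}\, v_m$$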

where pos' is the complex conjugate of pos. On further simplifying the above equation, with γ treated as a scalar value, we can easily parallelize this computation during training as follows:

Equation 4: RetNet retention computation (parallel). Refers to Eq5 of the original paper
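Written out (this is Eq. 5 of the paper), the parallel form is:

$$\mathrm{Retention}(X) = \big(QK^{\top} \odot D\big)\,V, \qquad Q = (XW_Q) \odot \Theta,\quad K = (XW_K) \odot \bar{\Theta},\quad V = XW_V$$

where $\Theta_n = e^{in\theta}$, $\bar{\Theta}$ is its complex conjugate, and $D$ is the decay/causal-mask matrix described below.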

We can clearly see that the first step of getting Q, K and V is the same as in the original Transformer, except that now the pos/pos' embeddings are multiplied element-wise with the Q and K matrices. We will get into the details of pos/pos' and the D-matrix in a second. But looking at the final step of this parallel training formulation, we can see that it closely resembles the original Transformer computation (albeit with the softmax → D-matrix substitution), and is thus completely parallelizable (the D-matrix can be pre-computed, as it is just a relative positional embedding + causal mask representation)!

Great! Now we know that RetNet with its minor changes can be trained in a parallel paradigm. But we still need to dive deeper into the pos/pos’ and D-matrix details.

Relative positional embeddings (pos/pos’)

We do not need to go into too much detail on these positional embeddings, as they borrow their intuition and functionality from the original positional embeddings of Transformers/LLMs. Still, to get an idea of what exactly is happening in that equation, let us dive in a little.

We already know from Euler’s formula:

Equation 5: Euler’s formula in the complex plane
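As a reminder, Euler's formula states:

$$e^{i\theta} = \cos\theta + i\,\sin\theta$$

i.e., multiplying by $e^{i\theta}$ rotates a vector in the complex plane by an angle $\theta$.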

The Θ in Eq 4 above therefore encodes "relative positional information" into each of the vectors of the Q and K matrices via vector rotation. This essentially makes them "position aware", and is achieved via a Hadamard product between the Q, K vectors and their respective position-specific vector rotations, as seen below:

Fig 5: The Hadamard product between (Q.KT) and Θ looks like this. The red arrows at each position are the rotation vectors as per Euler’s formula

The Qn and Km vectors at each position are rotated by the rotation vectors denoted by the red arrows. You can see from the accompanying vector rotation diagram in Eq 5 above that e^inθ / e^imθ for n=m=1 correspond to a single rotation; these are the rotation vectors at the Q1 and K1 positions. Similarly, for the n=2, m=2 positions the vectors have a double rotation. For vectors with the same rotation (i.e., all the positions on the diagonal), the rotation factors cancel out (their product is 1), so the dot product is unaffected. When n=1, m=2, however, the dot product is between two differently rotated vectors and picks up a value specific to that relative position. Note that as the relative distance grows (e.g., m=2 and n moving further away), the dot products tend towards 0 as the vectors become increasingly orthogonal to one another.
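A tiny numerical check makes this concrete (a sketch with made-up sizes and a per-dimension θ vector; the actual implementation uses an equivalent real-valued rotation): rotating Q and K by their position-dependent angles leaves the diagonal (n = m) entries of the score matrix untouched, because the rotations cancel there.

```python
import numpy as np

def rotate(x, theta):
    # x: (N, d) complex token vectors; theta: (d,) per-dimension angles.
    # Position n gets the element-wise rotation e^{i * n * theta}.
    n = np.arange(1, x.shape[0] + 1)[:, None]
    return x * np.exp(1j * n * theta)

rng = np.random.default_rng(0)
N, d = 3, 4
theta = rng.random(d)
Q = rng.random((N, d)) + 0j
K = rng.random((N, d)) + 0j

rotated_scores = rotate(Q, theta) @ rotate(K, theta).conj().T  # "position-aware" scores
plain_scores = Q @ K.conj().T                                  # ordinary Q.KT

# True: on the diagonal (n == m) the rotations cancel and the dot product is unchanged
print(np.allclose(np.diag(rotated_scores), np.diag(plain_scores)))
```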

Why do we care only about the bottom triangle in the figure above? This will become clearer in the coming sections, but the short answer is that at each time step we care about, and want to "attend to", only the information from past time steps.

OK, so this is the part where Θ in Eq 4 is multiplied element-wise with each vector in Q and K to make them "position aware". Next, let us check out what the proposed D-matrix does.

Causal masking and exponential decay (D)

The D-matrix acts as a causal mask as well as an exponential decay weighting scheme for past positions.

Equation 6: When the key position is in the past (n ≥ m), an exponential decay weighting is applied via γ; for key positions in the future (n < m), the weight is 0 and hence those time steps are not attended to
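In symbols (with n indexing the current/query position and m the key position):

$$D_{nm} = \begin{cases} \gamma^{\,n-m}, & n \ge m \\ 0, & n < m \end{cases}$$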

From the definition of D in Eq6 above, we can see that D in retention computation achieves the tasks that masked attention and softmax did in self-attention.

Masked-attention — causal mask: for positions where n < m (future key positions), the entries of (Q.KT) are multiplied by 0 to make sure that the causal assumption of sequence processing is in place. This ensures that information from future time steps is not leaked.

Softmax — exponential decay: for positions where n ≥ m (past key positions), the entries of (Q.KT) are weighted by an exponentially decaying factor γ^(n−m). This means that the further a token lies in the past, the less important it is for the current time step. This achieves the differential weighting of previous time steps that softmax performed in self-attention. While this is more constrained and inflexible than the softmax operation due to its restrictive assumption, the authors show that it works equally well!

Thus, the D matrix ends up looking something like this:

Fig 6: Each future position is weighted 0, while the past time steps carry exponentially decaying weights
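To see this pattern concretely, you can build a small D yourself (a toy sketch with arbitrarily chosen γ = 0.9 and N = 4):

```python
import numpy as np

gamma, N = 0.9, 4
n = np.arange(N)[:, None]  # current (query) positions as rows
m = np.arange(N)[None, :]  # key positions as columns
D = np.where(n >= m, gamma ** (n - m), 0.0)
print(np.round(D, 3))
# [[1.    0.    0.    0.   ]
#  [0.9   1.    0.    0.   ]
#  [0.81  0.9   1.    0.   ]
#  [0.729 0.81  0.9   1.   ]]
```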

The next step is to see how the “position-aware” (Q.KT) and D matrix come together to give the final output embedding of each input token in X.

Putting them together

We can now combine the operations of Fig 5 and Fig 6 above via the Hadamard product to obtain the final step of the parallel operation detailed in Eq 4:

Fig 7: The position-aware (Q.KT) (left) combined with the causal-mask + exponential decay D (right)

You can now see why we do not care about the upper triangle of the "position-aware" (Q.KT): those values are set to 0 by this operation with D! You can also clearly see how this entire operation is completely parallelizable during training. But before we jump into the recurrent inference procedure, let us work this out with an example to see how it pans out in practice!
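Before the worked example, here is a minimal sketch of the whole parallel form (it assumes the Θ rotation has already been folded into Q and K, and it skips GroupNorm and the multi-scale/multi-head details):

```python
import numpy as np

def decay_mask(N, gamma):
    # D[n, m] = gamma**(n - m) for past positions (n >= m), 0 for future ones
    n, m = np.arange(N)[:, None], np.arange(N)[None, :]
    return np.where(n >= m, gamma ** (n - m), 0.0)

def parallel_retention(Q, K, V, gamma=0.9):
    # (Q.KT ⊙ D) V: fully parallel over the sequence, used at training time
    D = decay_mask(Q.shape[0], gamma)
    return (Q @ K.T * D) @ V
```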

Parallel training — working example

Let us assume we have just a two-token sequence, i.e., N=2, with embedding size D=3. Let's say that in this example this gives us the following Q, K and V matrices of NxD dimensions (the first row is the first token in each, etc.):

Step 1: Q.KT

Step 2: Hadamard product between Q.KT and D

Step 3: Multiplying (Q.KT ⊙ D) with V

Equation 7: The final embeddings of the 2 tokens of the sequence obtained via the parallel training paradigm. We will match this with the recurrent retention later

There! We have our final contextualized embeddings of the 2 input tokens using RetNet's parallel paradigm, which is used during training! Remember this final result, as we will have to match it against the recurrent paradigm, which should technically yield the exact same output, in the subsequent sections. Let us dive into the recurrent inference paradigm!

Step 2: Recurrent Retention for inference

RetNet's recurrent retention paradigm is obtained by de-constructing the parallel computation such that the recurrent representation works exactly the same during inference, but with a fraction of the memory complexity. This is one of the main contributions of this work, and also very interesting. Let us see how:

Fig 8a: RetNet’s recurrent computation for inference. Doesn’t it look familiar yet different?

This looks quite familiar: it has the general flow of an RNN, but the operations inside the cell are Transformer-like! Let us look at what is happening here in more detail, with some annotations to make things clearer:

Fig 8b: RetNet’s recurrent computation reveals some interesting details when we look closer

The first thing we notice is that the Q, K, V matrices are now time-step indexed (the n subscripts) and thus are vectors of dimension 1xD, not matrices of NxD as earlier. This makes sense, since it is a recurrent setting and the given block shows the processing of a specific token. The second thing we notice is that there is a state matrix S that is carried forward from the previous time step to convey the temporal/positional information. This Sn-1 is multiplied by the exponential decay/discount factor γ at each time step, achieving recurrently what D did in the parallel setting. This controls what information is retained in the state S for future steps.

The third and most interesting part: if you look at the computation within the cell, you see that the first operation in the recurrent setting is now KT.V, not Q.KT, and Q gets multiplied in later. These matrices were combined in a different order during the parallel training phase, yet at inference time they are computed in a different order and we still expect this to work?? This is the neat trick and one of the major contributions of this paper: it shows how the Transformer-style computation can be de-constructed into a recurrent paradigm with some unintuitive modifications. Let us see the exact operations of the recurrent paradigm:

Equation 8: The state matrix is updated, and then the retention output of that time step is calculated using the updated state.
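Written out (Eq. 6 of the paper), the recurrent form is:

$$S_n = \gamma\, S_{n-1} + K_n^{\top} V_n, \qquad \mathrm{Retention}(X_n) = Q_n S_n$$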

The operations in Eq 8 summarize what we see in Fig 8. First, the state matrix is updated by applying the discount factor to the previous state and adding the KT.V term. Then the updated state is multiplied with Q to get the final output of this step. All the outputs are later collated to form the final output matrix. Since we already know how γ and the KT.V-style operations work from the worked-out example above, this is quite intuitive already. Just one nagging question remains: how does KT.V, instead of Q.KT, achieve the same results?? Time to work this out with an example!
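As a minimal sketch of this single step (again ignoring the Θ rotation and GroupNorm, and treating the state S as a DxD matrix):

```python
import numpy as np

def retention_step(q_n, k_n, v_n, S_prev, gamma=0.9):
    # q_n, k_n, v_n: (d,) vectors of the current token; S_prev: (d, d) state
    S_n = gamma * S_prev + np.outer(k_n, v_n)  # decay the old state, add the KT.V outer product
    o_n = q_n @ S_n                            # element-wise multiply by Q_n, then sum column-wise
    return o_n, S_n
```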

Recurrent inference — working example

We carry forward our two token sequence (N=2), embedding size D=3 example from earlier for consistency. Thus, we have our old Q,K,V matrices as:

Step 1: Compute KT.V for n=1. If you haven't been paying attention: KT.V is not a dot product like Q.KT, but rather an outer product between two vectors, which gives a matrix instead of a scalar! A subtle but important change that makes this mystery work. Also, we will now be working through the tokens iteratively. So, for n=1:

Step 2: Get S1. Since there is no S0, S1 is simply the result of the previous step, as nothing gets added to it:

Step 3: Multiply Q1 with S1 to get the final output. There is a catch here. While the diagram and equation do not mention this explicitly, the pseudo-code suggests that we need to do an element-wise multiplication followed by a column-wise addition to obtain the final output vector of each time step, as highlighted below:

So, after struggling with matching the shapes for a while, the pseudo code finally helped and we get the following output for the first token:

Voila! Do you notice that the first token embedding obtained here via recurrent retention is the same as the one from the parallel training computation earlier in Eq 7? This is awesome! So even with unintuitive changes to the recurrent block computation, the results match exactly for the first step. But let us work through another step to see how the S1 computed here gets used there.

Step 4: Compute KT.V for n=2. Repeating the same outer product procedure, we get:

Step 5: Get S2. The computation here is a little more involved as we have to multiply S1 with the discount factor, γ, before adding it to the result of the above step:

Step 6: Get the final output. Now that we have carried over the previous state information and updated it with the K, V information of the current step, we can multiply it with Q to get the final output:

Voila! We again see that the embedding of the second token comes out exactly the same as that of the parallel retention in Eq 7. The subtle changes made to the computation have enabled RetNet to de-construct the parallel training computation into a recurrent one exactly, without any approximation!
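If you want to convince yourself beyond this two-token example, here is a small self-contained check (with made-up random Q, K, V; Θ rotation and GroupNorm omitted as before) that the parallel and recurrent paradigms produce identical outputs:

```python
import numpy as np

rng = np.random.default_rng(42)
N, d, gamma = 5, 3, 0.9  # arbitrary toy sizes
Q, K, V = rng.random((N, d)), rng.random((N, d)), rng.random((N, d))

# parallel form: (Q.KT ⊙ D) V
n, m = np.arange(N)[:, None], np.arange(N)[None, :]
D = np.where(n >= m, gamma ** (n - m), 0.0)
O_parallel = (Q @ K.T * D) @ V

# recurrent form: S_n = gamma * S_{n-1} + K_n^T V_n ;  o_n = Q_n S_n
S, outputs = np.zeros((d, d)), []
for t in range(N):
    S = gamma * S + np.outer(K[t], V[t])
    outputs.append(Q[t] @ S)
O_recurrent = np.stack(outputs)

print(np.allclose(O_parallel, O_recurrent))  # True: the two paradigms match exactly
```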

To me personally, this was very exciting and beautiful to see everything fall into place so neatly. Kudos to the authors for achieving this feat! We now have an architecture that combines the training supremacy of Transformers with the inference efficiency of RNNs.

Final words

Phew! That was a long read with dense information, and I hope you did not fall asleep mid-way. While we dived deep into the inner workings of each component of RetNet to understand the intuition with working examples, this is still not the complete story. There are MANY more interesting details and components deliberately left out of this blog that you can find in the original paper. This blog was aimed at providing you with all the necessary knowledge and math required to read the paper in more detail yourself and to share your learnings and thoughts with everyone. I hope by now you can clearly see all the hidden details behind this diagram and appreciate it even more:

This work is not getting the attention it deserves, and I hope more people will feel enabled to take this work and idea forward after reading this blog. Thank you for reading, and I will see you in the next one!

About the author: Shantanu is an AI Research Scientist at the AI Center of Excellence lab at ZS. He did his Bachelor's in Computer Science Engineering and his Master's in Artificial Intelligence (cum laude) at the University of Amsterdam, with his thesis at the intersection of geometric deep learning and NLP, in collaboration with Facebook AI, London and King's College London. His research areas include Graph Neural Networks (GNNs), NLP, multi-modal AI, deep generative models and meta-learning.
