The Rise of RNN? Review of “Retentive Network: A Successor to Transformer for Large Language Models”

Sehyun Choi
Jul 21, 2023 · 7 min read


TLDR;

Retentive Network (RetNet) achieves performance comparable to a same-sized Transformer, can be trained in parallel, and additionally supports a recurrent mode that allows O(1) inference complexity per token.

The unofficial yet complete implementation can be found in my repo below:

The “Impossible Triangle” for Generative Sequence Models

For sequence models, especially generative ones, we have three desiderata: fast inference, parallel training, and strong performance. (In my opinion, there is one more dimension: sequence-length extrapolation. RetNet may support this as well, but the paper reports no explicit experiment on it.)

RNNs have fast inference but cannot be trained in parallel, linear transformers trade away performance, and standard Transformers pay O(n) per-token inference cost. RetNet claims all three corners: parallel training, O(1) per-token inference, and performance that matches or beats the Transformer.

Quick History

There have been multiple approaches to mitigating the expensive inference of generative Transformers. Notable works include Linear Transformers, the Attention-Free Transformer (AFT; from Apple), and RWKV (from BlinkDL, built on AFT).

These deserve a separate post, so I won't go into detail here. In my opinion, they are all mathematically elegant, especially the derivations of how the RNN form can be parallelized. Still, I find RetNet a bit more interesting, as it also comes with a chunkwise representation and some nifty tricks like xPos.

So How Does This Work?

RetNet is a more or less plug-and-play substitution of "attention" with "retention" inside an otherwise standard Transformer architecture.

I will go through them in a top-down manner.

1. Each RetNet Block

The equation for each RetNet Block.

At the highest level, RetNet consists of a stack of identical blocks, each containing MultiScaleRetention (MSR) and a FeedForwardNetwork (FFN). The blocks also have layer norms and skip connections, just like Transformers. The FFN is almost identical to the Transformer's as well: a 2-layer MLP with hidden dimension equal to 2x the embedding size and GELU activation.

If we substituted MSR with MultiHeadAttention, this would simply be a Transformer, so all of the differences live inside MSR.
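To make the block structure concrete, here is a minimal PyTorch sketch of one block. The MultiScaleRetention module is sketched in the next section, and the pre-norm placement is my assumption for illustration; the paper only fixes the MSR + FFN + layer-norm + residual structure.

import torch.nn as nn

class RetNetBlock(nn.Module):
    # one RetNet block: MSR and FFN, each wrapped in a layer norm and a skip connection
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.msr = MultiScaleRetention(d_model, n_heads)  # sketched in the next section
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 2 * d_model),  # hidden dim = 2 x embedding size
            nn.GELU(),
            nn.Linear(2 * d_model, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # same residual skeleton as a Transformer block
        x = x + self.msr(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x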

2. Gated Multi-Scale Retention

Multi-Scale is analogous to Multi-Head. In the equation above, γ is the decay hyperparameter used inside retention, and it is defined separately for each head (hence "multi-scale"). Up to the group norm, this is plain old multi-head attention, except that each head computes retention instead of attention.

Gated MSR adds a group norm, a swish gate, and an output projection, which can be considered auxiliary design choices. (The group norm's scale invariance is what permits rescaling the retention scores, but that is not important for now.) The most important distinction, the retention module itself, is still to come.
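Here is a rough sketch of gated MSR. The per-head decay schedule (γ_h = 1 - 2^(-5 - h)) is the one given in the paper as far as I can tell, but the Retention per-head module is hypothetical here (the retention math itself is sketched functionally in the next section), and the exact layout may differ from the official code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiScaleRetention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.d_model = d_model
        head_dim = d_model // n_heads
        # per-head decays: gamma_h = 1 - 2^(-5 - h), one "scale" per head
        gammas = (1 - 2.0 ** (-5 - torch.arange(n_heads, dtype=torch.float))).tolist()
        # Retention is a hypothetical per-head module holding its own W_Q, W_K, W_V
        self.heads = nn.ModuleList(Retention(d_model, head_dim, g) for g in gammas)
        self.group_norm = nn.GroupNorm(n_heads, d_model)     # normalizes each head separately
        self.w_g = nn.Linear(d_model, d_model, bias=False)   # swish gate
        self.w_o = nn.Linear(d_model, d_model, bias=False)   # output projection

    def forward(self, x):
        bsz, seq_len, _ = x.shape
        # each head runs retention with its own decay gamma
        y = torch.cat([head(x) for head in self.heads], dim=-1)
        y = self.group_norm(y.reshape(-1, self.d_model)).reshape(bsz, seq_len, self.d_model)
        # gate with swish(X W_G), then project back out
        return self.w_o(F.silu(self.w_g(x)) * y)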

3. Retention

Finally, let’s look at what is retention. Retention has 3 paradigms: Parallel, Recurrent, and Chunkwise-Recurrent. Let’s look at them one by one.

Parallel Retention

Parallel Representation of Retention

Focus on the last line. Ignoring D, this is again dot-product attention without the softmax, so the important details are in D and Θ.
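Since the equation image does not reproduce well here, this is my paraphrase of the paper's parallel form (LaTeX notation mine):

Q = (X W_Q) \odot \Theta, \quad K = (X W_K) \odot \bar{\Theta}, \quad V = X W_V

\Theta_n = e^{i n \theta}, \qquad D_{nm} = \begin{cases} \gamma^{\,n-m}, & n \ge m \\ 0, & n < m \end{cases}

\mathrm{Retention}(X) = \left(Q K^\top \odot D\right) V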

  • Θ (and bar(Θ), its complex conjugate) is the complex-valued form of the xPos encoding. xPos builds on rotary embeddings (RoPE) so that the model extrapolates better to longer sequences; there is an equivalent formulation in non-complex space, which is precisely xPos built on RoPE.

Refer to the xPos paper. I also found this lecture note helpful for understanding it.

  • D combines causal masking and the decay matrix.

If you write out D for a sequence of length 4 with γ = 0.9, it looks like the following:

import torch

gamma = 0.9
# exponent[n][m] = n - m: how many times token m's contribution has been
# decayed by the time we reach token n
exponent = torch.tensor([[0., 0., 0., 0.],
                         [1., 0., 0., 0.],
                         [2., 1., 0., 0.],
                         [3., 2., 1., 0.]])

D = torch.tril(gamma ** exponent)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
#         [0.9000, 1.0000, 0.0000, 0.0000],
#         [0.8100, 0.9000, 1.0000, 0.0000],
#         [0.7290, 0.8100, 0.9000, 1.0000]])
  • The upper triangle is 0 → causal masking.
  • The exponent is the number of times the earlier token's representation has been decayed by the time we reach the current token. This will become clearer when we see the recurrent representation. (A small end-to-end sketch of parallel retention follows below.)
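Putting Q, K, and D together, here is a minimal single-head sketch of parallel retention in PyTorch. It takes already-projected q, k, v and omits the Θ/xPos rotation and per-head details, so treat it as an illustration of the decay logic rather than the paper's exact formulation.

import torch

def parallel_retention(q, k, v, gamma):
    # single-head parallel retention: (Q K^T ⊙ D) V, i.e. attention without softmax
    seq_len = q.shape[1]
    # D[n, m] = gamma^(n - m) for n >= m, 0 otherwise (causal mask + decay)
    idx = torch.arange(seq_len)
    decay = gamma ** (idx[:, None] - idx[None, :]).clamp(min=0).float()
    d_matrix = torch.tril(decay)
    return (q @ k.transpose(-1, -2) * d_matrix) @ v

# usage sketch: q, k, v would come from X @ W_Q, X @ W_K, X @ W_V
q, k, v = (torch.randn(1, 4, 8) for _ in range(3))
print(parallel_retention(q, k, v, gamma=0.9).shape)  # torch.Size([1, 4, 8])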

Recurrent Retention

Recurrent Retention

S_n is analogous to the KV-cache in Transformers. Instead of concatenating past keys and values along the sequence, RetNet folds them into a single d x d matrix via the recurrence in the first line, S_n = γ S_{n-1} + K_n^T V_n. This state is then multiplied by the current step's query.

This is exactly equivalent to parallel retention.

Informal Proof Sketch:

Let S_0 = 0. If we unroll the recurrence of S_n, we get S_n = Σ_{i=1..n} γ^(n-i) K_i^T V_i, and therefore Q_n S_n = Σ_{i=1..n} γ^(n-i) Q_n K_i^T V_i.

Recall the last row of the exponent matrix for D in the parallel representation, which was [3, 2, 1, 0], and notice that n = 4. When we compute the retention of the 4th token against the 1st token, we decay it 3 times, which is exactly n - i = 3 in the equation above. Since everything else is the same, the parallel and recurrent representations are identical.
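The same thing in code: a sketch of the recurrent form, reusing the parallel_retention sketch from above to check the equivalence numerically. As before, q, k, v are assumed to be already projected, and the Θ rotation is omitted.

import torch

def recurrent_retention(q, k, v, gamma):
    # single-head recurrent retention: S_n = gamma * S_{n-1} + K_n^T V_n, o_n = Q_n S_n
    batch, seq_len, d = q.shape
    s = torch.zeros(batch, d, d)  # S_0 = 0: the whole "KV-cache" folded into one d x d matrix
    outputs = []
    for n in range(seq_len):
        s = gamma * s + k[:, n, :, None] * v[:, n, None, :]     # outer product K_n^T V_n
        outputs.append(torch.einsum("bd,bde->be", q[:, n], s))  # Q_n S_n
    return torch.stack(outputs, dim=1)

# equivalence check against the parallel sketch (up to floating-point error)
q, k, v = (torch.randn(1, 6, 8) for _ in range(3))
assert torch.allclose(recurrent_retention(q, k, v, 0.9),
                      parallel_retention(q, k, v, 0.9), atol=1e-4)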

Chunkwise Retention

This looks complicated, but it is really just parallel computation within each chunk plus a recurrent connection between chunks. The only important bookkeeping is, once again, the number of times each decay is applied.

Mistake in the paper

Actually, the paper’s chunkwise representation (equation above) for Ri is wrong! In fact, it should be

where the X operator is the cross product and D_B is the last row of the D matrix. Intuitively, this follows from the decay multiplication of the parallel representation and the recurrent representation.
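To make the decay bookkeeping concrete, here is a chunkwise sketch derived directly from the recurrent form (so its state update follows the corrected R_i above). The function name and shapes are mine, the Θ rotation and per-head details are again omitted, and seq_len is assumed divisible by the chunk size.

import torch

def chunkwise_retention(q, k, v, gamma, chunk_size):
    # parallel retention inside each chunk + a recurrent state carried across chunks
    batch, seq_len, d = q.shape
    B = chunk_size
    idx = torch.arange(B)
    # every factor below is just "how many times has this contribution been decayed?"
    D = torch.tril(gamma ** (idx[:, None] - idx[None, :]).clamp(min=0).float())  # within-chunk decay
    d_b = gamma ** (B - 1 - idx).float()  # last row of D: decay from position j to the end of the chunk
    xi = gamma ** (idx + 1).float()       # decay from the previous chunk's state to position j

    r = torch.zeros(batch, d, d)  # cross-chunk state, same role as S in the recurrent form
    outputs = []
    for start in range(0, seq_len, B):
        qc, kc, vc = (t[:, start:start + B] for t in (q, k, v))
        inner = (qc @ kc.transpose(-1, -2) * D) @ vc               # within-chunk, parallel
        cross = torch.einsum("bnd,bde->bne", qc, r) * xi[:, None]  # contribution of earlier chunks
        outputs.append(inner + cross)
        # R_i = K_i^T (V_i ⊙ D_B) + gamma^B * R_{i-1}
        r = kc.transpose(-1, -2) @ (vc * d_b[:, None]) + (gamma ** B) * r
    return torch.cat(outputs, dim=1)

# matches the recurrent (and hence parallel) form
q, k, v = (torch.randn(1, 8, 4) for _ in range(3))
assert torch.allclose(chunkwise_retention(q, k, v, 0.9, chunk_size=4),
                      recurrent_retention(q, k, v, 0.9), atol=1e-4)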

Schematic Diagram

That’s it! Above is the summary diagram of the two representations.

Why Decay?

So basically, the most important detail is the decay: applying it the correct number of times is what lets the recurrent form be computed in parallel. But we should understand the motivation behind such a decay. The high-level derivation is pretty simple.

  1. We define the recurrent state s_n as a KV-cache-like matrix; the recurrence relation is the first line in the figure above.
  2. We define the output at time n as o_n = Q_n s_n. The second line writes this out and unrolls the recurrence to expose the full dependence on past tokens; notice that the transition matrix A is applied repeatedly.

  3. We then diagonalize the A matrix as A = Λ diag(γ e^{iθ}) Λ^{-1}.

  4. The Λ factors can then be absorbed into other learnable parameters (Q_n = X_n W_Q, so Λ folds into W_Q, and likewise Λ^{-1} into W_K). Therefore, we are left with only the middle part. (The full chain is written out below.)
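Since the equation images are not reproduced here, this is my paraphrase of that chain, following the paper's derivation with the notation lightly simplified:

s_n = A\, s_{n-1} + K_n^\top v_n

o_n = Q_n s_n = \sum_{m=1}^{n} Q_n\, A^{\,n-m} K_m^\top v_m

A = \Lambda\, \mathrm{diag}(\gamma e^{i\theta})\, \Lambda^{-1}
\;\Rightarrow\;
Q_n A^{\,n-m} K_m^\top = (X_n W_Q \Lambda)\, \mathrm{diag}(\gamma e^{i\theta})^{\,n-m}\, (\Lambda^{-1} W_K^\top X_m^\top)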

The middle part is precisely the γ (decay) and theta we observed before.

Intuitively, γ and θ act as a kind of "closed-form positional encoding" that also has a recurrent form, so the encoding at time n can be computed ahead of time, which is what enables parallelization.

Empirical Findings

  • RetNet beats the Transformer as model size grows. (Critique: it is not clear whether this trend continues at even larger scales.)
  • RetNet beats other linear-time transformers in performance.
  • RetNet is fast. (Critique: this is obvious from the architecture; showing three figures to emphasize it feels redundant. Honestly, you barely need to run experiments to draw these plots.)

Critiques

  • There are a few missing details in the paper, which won't be clear until the official code is released.
  • RWKV also supports training parallelization, but the paper misrepresents it as not doing so.
  • There is a bit of bragging that RetNet is fast, with three figures saying the same thing. :-)
  • Curious whether this trend will hold for larger models.
  • Not sure if they will release pre-trained weights.
  • Not sure if it will beat models like LLaMA.

Pros

  • FAST! (I criticized the bragging, but it is indeed fast, which is good.)
  • Comparable performance. If this trend continues and there is no performance drop at larger model sizes, RetNet could become the de facto architecture for LLMs, since it is much cheaper to serve.

For those interested, please take a look at my implementation of RetNet too:

