Understanding YaRN: Extending Context Window of LLMs

RAJAT CHAWLA
6 min read · Nov 15, 2023


YaRN: Yet another RoPE extensioN method

In this post, I’ll go through the main ideas of the YaRN fine-tuning approach, introduced in the November 2023 paper “YaRN: Efficient Context Window Extension of Large Language Models.” This approach aims to efficiently extend the context sizes of large pre-trained language models (LLMs) without incurring significant computational costs.

Introduction

Rotary Position Embeddings (RoPE) are an effective way of encoding positional information in transformer-based language models. But these models struggle when it comes to sequences longer than the context they were trained on. YaRN is a compute-efficient way to extend that context window: compared with previous methods, it requires 10× fewer tokens and 2.5× fewer training steps. With YaRN, LLaMA models can handle and make sense of much longer contexts than they were originally trained for.
And here’s the cool part: YaRN isn’t just good at interpolating within the context lengths seen during fine-tuning; it can also extrapolate beyond the limited context of the fine-tuning data.

Sliding window perplexity (S = 256) of ten 128k Proof-pile documents truncated to evaluation context window size

Positional Encodings

Position encodings play a crucial role in addressing the context window limitation. The original Transformer architecture used absolute sinusoidal position encodings, later improved to learnable absolute position encodings. Relative positional encoding schemes, such as T5 Relative Bias, RoPE, XPos, and ALiBi, have further improved transformer performance. However, positional encodings still struggle to generalize beyond the context window seen during training.

YaRN is an improved method to efficiently extend the context window of models trained with Rotary Position Embeddings (RoPE). It applies to the LLaMA, GPT-NeoX, Mistral, and PaLM model families. YaRN reaches state-of-the-art performance in context window extension after fine-tuning on less than ∼0.1% of the original pre-training data, and Dynamic-YaRN, which combines YaRN with Dynamic Scaling at inference time, allows for more than a 2× context window extension without any fine-tuning.

Let’s go through the methodology needed to understand YaRN.

1. “NTK-aware” Interpolation

In order to resolve the problem of losing high-frequency information when interpolating the RoPE embeddings, “NTK-aware” interpolation came into the picture. Instead of scaling every dimension of RoPE equally by a factor s, we spread out the interpolation pressure across multiple dimensions by scaling high frequencies less and low frequencies more. One can obtain such a transformation in many ways, but the simplest is to perform a base change on the value of θ.

This method performs much better than PI at extending the context size of non-fine-tuned models.
However, one major disadvantage is that, since it is not purely an interpolation scheme, some dimensions are slightly extrapolated to “out-of-bound” values, so fine-tuning with “NTK-aware” interpolation yields inferior results to PI. Furthermore, due to these “out-of-bound” values, the theoretical scale factor s does not accurately describe the true context extension scale. In practice, s has to be set higher than the expected scale for a given context length extension.
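
To make the base change concrete, here is a minimal sketch in PyTorch (the helper name is illustrative): instead of dividing every frequency by s, the RoPE base b is replaced by b·s^(|D|/(|D|−2)), where |D| is the head dimension.

```python
import torch

def ntk_aware_inv_freq(head_dim: int, scale: float, base: float = 10000.0) -> torch.Tensor:
    # "NTK-aware" interpolation: rather than dividing every RoPE frequency by s,
    # change the base b -> b * s^(|D| / (|D| - 2)), which stretches low-frequency
    # (long-wavelength) dimensions much more than high-frequency ones.
    new_base = base * scale ** (head_dim / (head_dim - 2))
    # Standard RoPE inverse frequencies, computed with the adjusted base.
    return 1.0 / (new_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

# Example: extending a model with 128-dim attention heads by a factor of s = 4.
inv_freq = ntk_aware_inv_freq(head_dim=128, scale=4.0)
```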

2. “NTK-by-parts” Interpolation

When stretching RoPE dimensions uniformly (by a scale s or using a base change b’), tokens become closer, impairing the model’s ability to understand small and local relationships between internal embeddings. This compression leads to confusion about the positional order of close-by tokens, harming the model’s abilities.

“NTK-by-parts” interpolation proposes a solution: do not interpolate the higher-frequency dimensions at all, while always interpolating the lower-frequency dimensions.

It introduces a ratio r = L/λ to decide whether or not to interpolate a given dimension, based on its wavelength λ.

The conditions for interpolation are:

  • If the wavelength λ is much smaller than the context size L, do not interpolate.
  • If λ is equal to or bigger than L, only interpolate (and avoid extrapolation).
  • Dimensions in between get a bit of both, similar to “NTK-aware” interpolation.

It also introduces two extra parameters, α and β, that define the boundaries between these interpolation regimes, and a ramp function γ that determines the interpolation behaviour based on the ratio r (see the sketch below).

The “NTK-by-parts” interpolation method performs better than both PI and “NTK-aware” interpolation, with non-fine-tuned as well as fine-tuned models.
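
A minimal sketch of how this could be implemented (in PyTorch, with an illustrative function name; α = 1 and β = 32 are the values the paper suggests for LLaMA-family models):

```python
import math
import torch

def ntk_by_parts_inv_freq(head_dim: int, scale: float, orig_ctx_len: int,
                          base: float = 10000.0,
                          alpha: float = 1.0, beta: float = 32.0) -> torch.Tensor:
    dims = torch.arange(0, head_dim, 2).float()
    inv_freq = 1.0 / (base ** (dims / head_dim))   # theta_d
    wavelen = 2 * math.pi / inv_freq               # lambda_d = 2*pi / theta_d
    r = orig_ctx_len / wavelen                     # r(d) = L / lambda_d

    # Ramp function gamma(r): 0 below alpha (full interpolation),
    # 1 above beta (no interpolation), linear in between.
    gamma = torch.clamp((r - alpha) / (beta - alpha), 0.0, 1.0)

    # Blend the interpolated frequency (theta_d / s) with the original (theta_d).
    return (1 - gamma) * (inv_freq / scale) + gamma * inv_freq

# Example: a model pre-trained on L = 4096 tokens, extended by s = 4.
inv_freq = ntk_by_parts_inv_freq(head_dim=128, scale=4.0, orig_ctx_len=4096)
```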

3. Dynamic Scaling — “Dynamic NTK” interpolation

Position Interpolation (PI), “NTK-aware”, and “NTK-by-parts” are all interpolation methods that use a scale factor s. There are two ways to apply this scale factor at inference time:

  • Keep the embedding layer fixed with the scale factor s = L’/L throughout the entire inference cycle, where L’ is the fixed extended context size.
  • Update the scale factor s = max(1, l’/L) in the position embedding on each forward pass, where l’ is the sequence length of the current sequence.

With the first (fixed) approach, the model may experience a performance drop at lengths shorter than L and an abrupt degradation when the sequence length grows past L’. The second approach is Dynamic Scaling, which lets the model degrade gracefully instead of breaking immediately once it hits the trained context limit L’.

Dynamic Scaling is therefore an inference-time method where the scale factor is updated dynamically on each forward pass based on the current sequence length. When combined with “NTK-aware” interpolation, it is called “Dynamic NTK” interpolation.

“Dynamic NTK” interpolation works exceptionally well on models pretrained on context length L without any fine-tuning (L’ = L).
Note that in repeated forward passes, kv-caching (key-value caching) is often applied to reuse previous key-value vectors and improve overall efficiency.

It is important to handle RoPE embeddings correctly when implementing Dynamic Scaling with kv-caching: the kv-embeddings should be cached before applying RoPE, because the RoPE embedding of every token changes when s changes.
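
A minimal sketch of “Dynamic NTK” scaling (illustrative helper name, reusing the “NTK-aware” base change from earlier), where the scale factor is recomputed from the current sequence length on every forward pass:

```python
import torch

def dynamic_ntk_inv_freq(head_dim: int, seq_len: int, orig_ctx_len: int,
                         base: float = 10000.0) -> torch.Tensor:
    # Recompute the scale factor from the current sequence length on every
    # forward pass: s = max(1, l'/L), so nothing changes while the prompt
    # still fits inside the original window.
    scale = max(1.0, seq_len / orig_ctx_len)
    # Then apply the "NTK-aware" base change with that dynamic scale.
    new_base = base * scale ** (head_dim / (head_dim - 2))
    return 1.0 / (new_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

# Because every token's RoPE rotation changes whenever s changes, cache keys and
# values *before* RoPE and re-apply the rotation with the current frequencies.
```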

YaRN (Yet another RoPE extensioN method)

In addition to the previous interpolation techniques, we also observe that introducing a temperature t on the logits before the attention softmax has a uniform impact on perplexity regardless of the data sample and the token position over the extended context window. More precisely, we modify the computation of attention weights into

softmax(q_mᵀ k_n / (t · √|D|))

where |D| is the dimension of the attention head. The recommended value for LLaMA and Llama 2 models is √(1/t) = 0.1 ln(s) + 1.

The reparametrization of RoPE as a set of 2D rotation matrices has a clear benefit for implementing this attention scaling: we can instead use a “length scaling” trick that scales both q_m and k_n by a constant factor √(1/t), simply by scaling the complex RoPE embeddings by that same amount. With this, YaRN can effectively alter the attention mechanism without modifying its code. Furthermore, it has zero overhead during both inference and training, as the RoPE embeddings are generated in advance and reused for all forward passes. Combining this with the “NTK-by-parts” interpolation gives us the YaRN method.

The YaRN method combines all our findings and surpasses all previous methods in both fine-tuned and non-fine-tuned scenarios. Thanks to its low footprint, YaRN allows for direct compatibility with libraries that modify the attention mechanism such as Flash Attention 2.
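
Putting the pieces together, here is a minimal sketch of how the YaRN RoPE tables could be precomputed (illustrative function names; it reuses the ntk_by_parts_inv_freq helper from the earlier sketch): the “NTK-by-parts” frequencies are combined with the recommended √(1/t) length-scaling factor, baked directly into the cached cos/sin tables.

```python
import math
import torch

def yarn_attention_scale(scale: float) -> float:
    # Recommended length scaling sqrt(1/t) = 0.1 * ln(s) + 1 for LLaMA / Llama 2.
    return 0.1 * math.log(scale) + 1.0

def yarn_rope_cache(head_dim: int, scale: float, orig_ctx_len: int,
                    max_position: int, base: float = 10000.0):
    # "NTK-by-parts" frequencies (helper from the earlier sketch).
    inv_freq = ntk_by_parts_inv_freq(head_dim, scale, orig_ctx_len, base)
    positions = torch.arange(max_position).float()
    freqs = torch.outer(positions, inv_freq)       # (max_position, head_dim // 2)
    # Bake sqrt(1/t) into the cached cos/sin tables so both q and k are scaled
    # without touching the attention code itself.
    mscale = yarn_attention_scale(scale)
    return freqs.cos() * mscale, freqs.sin() * mscale

cos, sin = yarn_rope_cache(head_dim=128, scale=4.0, orig_ctx_len=4096, max_position=16384)
```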

Evaluation

  1. Long Sequence Language Modeling
Sliding window perplexity (S = 256) of ten 128k Proof-pile documents truncated to evaluation context window size

The model exhibits strong performance across the entire targeted context size.
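
For reference, a sliding-window perplexity evaluation along these lines can be sketched as follows (a hypothetical helper that assumes a Hugging Face-style causal LM exposing a mean cross-entropy .loss; this is not the paper’s exact evaluation code):

```python
import torch

@torch.no_grad()
def sliding_window_ppl(model, input_ids, window: int, stride: int = 256):
    # Score each block of `stride` new tokens while conditioning on up to
    # `window` previous tokens of context.
    nlls, n_scored = [], 0
    for begin in range(0, input_ids.size(1), stride):
        end = min(begin + stride, input_ids.size(1))
        ids = input_ids[:, max(0, end - window):end]
        labels = ids.clone()
        labels[:, : ids.size(1) - (end - begin)] = -100   # mask the context tokens
        loss = model(input_ids=ids, labels=labels).loss   # mean NLL of scored tokens
        nlls.append(loss * (end - begin))
        n_scored += end - begin
    return torch.exp(torch.stack(nlls).sum() / n_scored)
```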

2. Passkey Retrieval

Ten iterations of the passkey retrieval task were performed, with the passkey placed at a position sampled uniformly at random within the evaluation context window, for context window sizes ranging from 8k to 128k. Both the 7B and 13B models fine-tuned with YaRN at a 128k context size pass the passkey retrieval task with very high accuracy (>99%) across the entire context window.
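
For illustration, a passkey-retrieval example can be constructed roughly like this (a hypothetical sketch of the task setup; the exact filler text and prompt wording may differ from the paper’s):

```python
import random

def make_passkey_prompt(n_filler_lines: int) -> tuple[str, str]:
    # Hide a random 5-digit passkey at a uniformly random position inside
    # filler text; the model is then asked to repeat the passkey.
    passkey = str(random.randint(10000, 99999))
    filler = ("The grass is green. The sky is blue. The sun is yellow. "
              "Here we go. There and back again.")
    lines = [filler] * n_filler_lines
    lines.insert(random.randint(0, n_filler_lines),
                 f"The pass key is {passkey}. Remember it. {passkey} is the pass key.")
    prompt = ("There is important info hidden inside a lot of irrelevant text. "
              "Find it and memorize it.\n" + "\n".join(lines) + "\nWhat is the pass key?")
    return prompt, passkey

# Accuracy = fraction of generations that contain the correct passkey.
```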

3. Standardized Benchmarks

Performance of context window extensions methods on the Hugging Face Open LLM benchmark suite compared with original Llama 2 baselines

Conclusion

YaRN improves upon all existing RoPE interpolation methods and can act as a drop-in replacement for PI, with no downsides and minimal implementation effort. The fine-tuned models preserve their original abilities on multiple benchmarks while being able to attend to a very large context size. Furthermore, YaRN allows efficient extrapolation when fine-tuning on shorter datasets and can take advantage of transfer learning for faster convergence, both of which are crucial in compute-constrained scenarios. Finally, we have seen the effectiveness of extrapolation with YaRN, which is able to “train short, and test long”.
