DEV Community

jidonglab


RoPE Scaling: How LLMs Stretch From 8K to 128K Context

A model trained on 8K-token sequences does not magically read 128K. Someone changed three or four numbers in the rotary embedding config and ran a short fine-tune. That edit is RoPE scaling, and if you have ever loaded a long-context checkpoint, set a rope_scaling field, and watched perplexity either hold or explode, you have already depended on it without seeing the mechanism.

This post is the mechanism: why Rotary Position Embeddings break past their training length, and how linear interpolation, NTK-aware scaling, and YaRN each fix that break differently.

TL;DR

  • RoPE encodes position by rotating each pair of query/key dimensions at a fixed frequency. Past the trained context length, the slow (low-frequency) pairs reach rotation angles the model never saw, and attention collapses.
  • Linear position interpolation (PI) divides every position index by a scale factor, squeezing new positions into the trained angle range. It works but smears fine local resolution because it slows the high-frequency dimensions too.
  • NTK-aware scaling changes the RoPE base instead of the positions, so high-frequency (local) dimensions stay almost untouched while only the low-frequency (global) dimensions get stretched.
  • YaRN combines NTK-by-parts interpolation with an attention-temperature correction, and reaches long context with the least fine-tuning of the three.
  • All three need a short fine-tune at the target length to be reliable; the NTK variants partly work zero-shot, but don't ship that.

What does RoPE actually do to a token's position?

RoPE injects position by rotating pairs of dimensions in the query and key vectors by an angle proportional to the token's absolute position. The dot product between a query at position m and a key at position n then depends only on the relative offset m - n. That relative property is why RoPE generalizes better than learned absolute embeddings — until it doesn't.
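A minimal sketch of that relative property in plain Python, on a single dimension pair (the names rotate_pair and rope_dot, the vectors, and theta = 0.05 are made up for illustration): rotating the query by m·theta and the key by n·theta leaves their dot product depending only on m - n, because the two rotations combine into one rotation by (n - m)·theta.

```python
import math

def rotate_pair(x, y, pos, theta):
    # rotate one 2-D pair (x, y) by the angle pos * theta, as RoPE does per pair
    a = pos * theta
    return (x * math.cos(a) - y * math.sin(a),
            x * math.sin(a) + y * math.cos(a))

def rope_dot(q, k, m, n, theta):
    # dot product of a query rotated to position m and a key rotated to position n
    qx, qy = rotate_pair(q[0], q[1], m, theta)
    kx, ky = rotate_pair(k[0], k[1], n, theta)
    return qx * kx + qy * ky

q, k, theta = (0.3, -1.2), (0.8, 0.5), 0.05

# same offset m - n = 7 at very different absolute positions
near = rope_dot(q, k, 10, 3, theta)
far = rope_dot(q, k, 5010, 5003, theta)
print(abs(near - far) < 1e-9)  # True
```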

Each pair of dimensions i (out of d/2 pairs) rotates at its own frequency:

theta_i = base ^ (-2i / d),   base = 10000 (typical)

Pair 0 rotates fast (high frequency, short wavelength). The last pair rotates slowly (low frequency, very long wavelength). The wavelength of pair i, meaning how many tokens that rotation takes to complete a full turn, is:

lambda_i = 2 * pi * base ^ (2i / d)

For the slowest dimensions, the wavelength can exceed the entire training context. Those dimensions never complete even one rotation during training. That detail is the whole reason naive extrapolation fails.
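The two formulas above can be checked numerically. This sketch assumes head dimension d = 128 and base = 10000 (typical values, not tied to any specific model) and counts how many of the 64 pairs have a wavelength longer than an 8192-token training window.

```python
import math

def wavelengths(dim, base=10000.0):
    # lambda_i = 2 * pi * base^(2i/d) for each of the d/2 dimension pairs
    return [2 * math.pi * base ** (2 * i / dim) for i in range(dim // 2)]

lams = wavelengths(dim=128)
train_len = 8192
# pairs whose wavelength exceeds the training window never finish one turn
never_full_turn = [i for i, lam in enumerate(lams) if lam > train_len]

print(f"fastest pair: {lams[0]:.2f} tokens per turn")
print(f"slowest pair: {lams[-1]:.0f} tokens per turn")
print(f"pairs that never complete a turn in {train_len} tokens: {len(never_full_turn)}")
```

Under these assumptions the fastest pair turns roughly every 6.3 tokens, the slowest needs about 54K tokens, and 14 of the 64 pairs never complete a rotation inside 8K.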

Why does context blow up past the training length?

Because at position 9000 on a model trained to 8192, the low-frequency dimensions already produce rotation angles the model literally never observed during training, and attention scores go out of distribution.

The slow dimensions are the problem, not the fast ones. A high-frequency pair completes a full rotation every few tokens to a few dozen tokens, so within an 8K window it has seen every angle many times; at token 50000 it simply revisits angles it already knows. A low-frequency pair whose wavelength exceeds 8192 tokens swept only part of its circle during training. Feed it token 50000 and its angle lands on a stretch of the circle the trained attention pattern was never calibrated for, and query-key offsets larger than anything seen in training hit the same wall. Empirically you see perplexity stay flat up to the trained length and then climb almost vertically a few hundred tokens past it. The model isn't "forgetting"; its position signal has gone off the edge of the manifold it learned on.

So the goal of every scaling method: keep the rotation angles the model sees at long positions inside (or near) the range it was trained on.
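A rough way to test "inside the trained range" per pair, under the same assumed d = 128, base = 10000 and 8192-token training window: wrap the angle at a new position onto the circle and compare it with the largest angle that pair reached during training. The helper names are hypothetical.

```python
import math

DIM, BASE, TRAIN_LEN = 128, 10000.0, 8192  # assumed example values

def theta(i):
    # per-pair frequency theta_i = base^(-2i/d)
    return BASE ** (-2 * i / DIM)

def max_trained_angle(i):
    # largest angle pair i reached in training (positions 0..TRAIN_LEN-1),
    # capped at one full turn because angles repeat every 2*pi
    return min(2 * math.pi, (TRAIN_LEN - 1) * theta(i))

def in_trained_range(i, pos):
    # wrap the angle at `pos` onto [0, 2*pi) and compare with what training covered
    wrapped = (pos * theta(i)) % (2 * math.pi)
    return wrapped <= max_trained_angle(i)

slowest = DIM // 2 - 1
print(in_trained_range(0, 50_000))        # True: the fast pair already covered the circle
print(in_trained_range(slowest, 50_000))  # False: the slow pair lands on an unseen angle
```

The slowest pair at token 50000 sits near 5.8 rad, far beyond the roughly 0.95 rad it reached during training.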

How does linear position interpolation work, and what does it cost?

Linear PI divides every position index by a scale factor s = L_new / L_train, so position 128000 in a 128K window gets the same rotation angle that position 8000 had during 8K training. You compress the new positions into the old, in-distribution angle range.

import torch

def rope_inv_freq(dim, base=10000.0):
    # one inverse frequency per dimension-pair
    i = torch.arange(0, dim, 2).float()
    return 1.0 / (base ** (i / dim))

def linear_pi_freqs(dim, scale, base=10000.0):
    # PI: slow EVERY dimension by 1/scale (equivalently, divide positions by scale)
    return rope_inv_freq(dim, base) / scale

# 8K -> 32K means scale = 4
inv_freq = linear_pi_freqs(dim=128, scale=4.0)

It works, and it was the first method that made long-context fine-tunes cheap. The cost is uniform compression: PI slows the fast dimensions by the same factor as the slow ones. Those high-frequency pairs encode fine-grained local ordering — "is this token 2 or 3 positions back." Compress them 4x or 32x and you blur exactly the short-range resolution the model uses for local syntax. PI-scaled models recover with fine-tuning but tend to lose a little local sharpness, and at large scale factors (16x, 32x) that loss is measurable.
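The local-resolution cost can be put in numbers. Under linear PI the angle between two adjacent tokens in pair i shrinks from theta_i to theta_i / s. This sketch (function name illustrative, base = 10000 and d = 128 assumed) prints that step for the fastest pair, where theta_0 = 1 rad, together with its cosine; a cosine near 1.0 means the pair barely distinguishes neighboring tokens.

```python
import math

def angle_step(i, dim=128, base=10000.0, scale=1.0):
    # rotation between two adjacent tokens for pair i under linear PI with this scale
    return base ** (-2 * i / dim) / scale

for s in (1.0, 4.0, 32.0):
    step = angle_step(0, scale=s)
    # cos(step) near 1.0 means positions p and p+1 look almost identical to this pair
    print(f"scale={s:>4}: step={step:.5f} rad, cos={math.cos(step):.4f}")
```

The step drops from 1 rad unscaled to 0.25 rad at 4x and about 0.031 rad at 32x.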

How is NTK-aware scaling different, and why is it usually better?

NTK-aware scaling stretches context by raising the RoPE base instead of dividing the positions. The effect: low-frequency (long-wavelength, global) dimensions get interpolated heavily, while high-frequency (local) dimensions barely move. You spend the distortion where it hurts least.

The intuition comes from the Neural Tangent Kernel observation that networks struggle to learn high-frequency information when inputs are squeezed — so don't squeeze the high-frequency dimensions. Instead of scaling positions, scale the base so that the slowest dimension ends up interpolated by the full factor s, and the fastest is essentially unchanged:

def ntk_aware_freqs(dim, scale, base=10000.0):
    # raise the base so low-freq dims stretch and high-freq dims stay put
    new_base = base * (scale ** (dim / (dim - 2)))
    return rope_inv_freq(dim, new_base)

inv_freq = ntk_aware_freqs(dim=128, scale=4.0)

A useful property: NTK-aware scaling degrades more gracefully without fine-tuning than linear PI, which is why early Llama community extensions reached for it zero-shot. But "graceful" still means worse than a proper fine-tune. The flaw is that the per-dimension treatment is a smooth function of frequency, so some middle dimensions get a scaling that is neither "leave alone" nor "fully interpolate" — and those in-between dimensions can still drift out of distribution at extreme lengths.
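To see that smooth per-dimension treatment, this sketch reuses the NTK-aware base formula and reports how much each pair's frequency gets divided by. Algebraically the ratio is s^(2i/(d-2)), so it runs smoothly from 1 for the fastest pair to s for the slowest; d = 128 and s = 4 are assumed example values.

```python
def ntk_effective_scale(dim, scale, base=10000.0):
    # how much each pair's frequency is divided by under NTK-aware scaling
    new_base = base * scale ** (dim / (dim - 2))
    ratios = []
    for i in range(dim // 2):
        orig = base ** (-2 * i / dim)
        ntk = new_base ** (-2 * i / dim)
        ratios.append(orig / ntk)  # equals scale ** (2i / (d - 2))
    return ratios

r = ntk_effective_scale(dim=128, scale=4.0)
print(f"fastest {r[0]:.2f}, middle {r[32]:.2f}, slowest {r[-1]:.2f}")
```

This prints roughly 1.00, 2.02 and 4.00: the middle pairs get a partial stretch that is neither "leave alone" nor "fully interpolate".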

What does YaRN add on top of NTK?

YaRN ("Yet another RoPE extensioN") fixes two things NTK leaves on the table: it interpolates dimensions by parts based on their wavelength, and it adds an attention temperature correction. It hits target context with notably less fine-tuning data than PI or plain NTK.

The by-parts idea splits dimensions into three groups using each one's wavelength relative to the original context:

  • High-frequency dimensions (wavelength much smaller than the context): leave them alone — no interpolation. These carry local resolution.
  • Low-frequency dimensions (wavelength larger than the context, never completed a rotation in training): apply full linear interpolation. They were never going to extrapolate anyway.
  • Middle dimensions: blend smoothly between the two via a ramp function.
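A minimal sketch of that by-parts rule, written in plain Python rather than taken from any library: r is how many full rotations a pair completes in the original window (train_len / wavelength), gamma is the ramp, and each frequency becomes (1 - gamma) * theta / s + gamma * theta. The thresholds alpha = 1 and beta = 32 are values commonly used for Llama-family models; treat them, and the function name, as assumptions.

```python
import math

def yarn_inv_freq(dim, scale, train_len, base=10000.0, alpha=1.0, beta=32.0):
    # NTK-by-parts frequencies; alpha/beta defaults are assumed values
    out = []
    for i in range(dim // 2):
        inv = base ** (-2 * i / dim)        # original frequency theta_i
        wavelength = 2 * math.pi / inv
        r = train_len / wavelength          # full rotations completed in training
        if r < alpha:
            gamma = 0.0                     # low frequency: full interpolation
        elif r > beta:
            gamma = 1.0                     # high frequency: leave alone
        else:
            gamma = (r - alpha) / (beta - alpha)  # linear ramp in between
        out.append((1 - gamma) * inv / scale + gamma * inv)
    return out

freqs = yarn_inv_freq(dim=128, scale=4.0, train_len=8192)
orig = [10000.0 ** (-2 * i / 128) for i in range(64)]
print(orig[0] / freqs[0])    # fastest pair untouched (ratio 1)
print(orig[-1] / freqs[-1])  # slowest pair fully interpolated (ratio 4)
```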

The second piece is the one people forget. Stretching positions changes the average magnitude of attention logits, which shifts the entropy of the softmax. YaRN divides the attention logits by a temperature t chosen as a function of the scale factor s, with sqrt(1/t) ≈ 0.1 * ln(s) + 1. That factor multiplies both the queries and the keys, so the logits end up scaled by its square (1/t), which pushes the softmax distribution back toward where the model was trained. Because it folds into the query/key scaling (in practice, into the precomputed RoPE cos/sin tables), it costs nothing at inference:

import math

def yarn_attention_scale(scale):
    # sqrt(1/t) from YaRN: 1.0 at scale=1, grows slowly with ln(scale)
    return 0.1 * math.log(scale) + 1.0

# multiply BOTH q and k by m (e.g. fold it into the RoPE cos/sin tables),
# so the attention logits end up multiplied by m**2 = 1/t
m = yarn_attention_scale(scale=4.0)   # ~1.14
logit_scale = m ** 2                  # ~1.30

That temperature term is part of why YaRN reaches a given perplexity with fewer training tokens: it corrects a distribution shift that PI and NTK silently leave for the fine-tune to absorb.
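The effect of that temperature on the softmax shows up in a toy example. The logits here are made up; the point is only that multiplying them by m**2 > 1 (m being the YaRN factor for s = 4) lowers the softmax entropy, i.e. sharpens the attention distribution.

```python
import math

def softmax_entropy(logits):
    # Shannon entropy (in nats) of softmax(logits), shifted by the max for stability
    mx = max(logits)
    exps = [math.exp(x - mx) for x in logits]
    z = sum(exps)
    return -sum((e / z) * math.log(e / z) for e in exps)

logits = [2.0, 1.0, 0.5, 0.0, -1.0]   # made-up attention scores for one query
m = 0.1 * math.log(4.0) + 1.0         # YaRN factor for s = 4, applied to q and k
scaled = [x * m * m for x in logits]  # so the logits are multiplied by m**2

print(f"{softmax_entropy(logits):.3f} -> {softmax_entropy(scaled):.3f} nats")
```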

Which RoPE scaling method should you actually use?

Use what the model card tells you to use, and never mix scaling at inference with a checkpoint trained under a different scheme. The methods are not interchangeable knobs; the weights were adapted to one specific position geometry.

Practical guidance:

  • Loading a published long-context checkpoint (Llama 3.1 at 128K, Qwen long-context variants, and similar): the config.json, or the model card's instructions, specify the right rope_scaling block, which pairs a rope_type such as linear, dynamic (dynamic NTK), yarn, or llama3 (Llama 3.1's own by-parts scheme) with a factor. Respect it. Overriding the factor to push past the trained length usually wrecks the model long before you reach the new target.
  • Extending a base model yourself: prefer YaRN for large scale factors (8x and up) because of its fine-tuning efficiency; NTK-aware is a reasonable, simpler choice for modest extensions.
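  • Zero-shot, no fine-tune available: dynamic NTK, which recomputes the NTK scale from the current sequence length so that inputs within the trained window see unscaled RoPE, degrades the most gracefully, but treat anything beyond ~2x as a demo, not production.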
  • Watch the failure signature: flat loss up to the trained length, then a near-vertical climb, means your effective scaling factor is too aggressive for the fine-tune (or the data) you gave it.
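A simplified sketch of the dynamic NTK idea: the scale is recomputed from the current sequence length as max(1, seq_len / train_len) and fed into the NTK-aware base formula, so nothing changes until the input outgrows the trained window. This is an assumption-level illustration with a hypothetical function name; library implementations (for example the dynamic rope_type in Hugging Face Transformers) use their own exact formula.

```python
def dynamic_ntk_base(seq_len, train_len, dim, base=10000.0):
    # simplified dynamic NTK: no scaling until the sequence outgrows the training length
    scale = max(1.0, seq_len / train_len)
    # NTK-aware base formula with a length-dependent scale
    return base * scale ** (dim / (dim - 2))

for n in (4096, 8192, 16384, 32768):
    print(n, round(dynamic_ntk_base(n, train_len=8192, dim=128)))
```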

One more production reality: extending context with RoPE scaling changes position handling, not attention cost. The KV cache and the quadratic attention term grow with the real sequence length regardless of how cleverly you scaled the angles. A 128K context that decodes correctly can still be too slow or too memory-hungry to serve — scaling solves the math, not the bill.
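A back-of-the-envelope check of that bill: the KV cache stores one key and one value vector per layer, per KV head, per position, so it grows linearly with length, while the attention-score work grows with its square. The model shape below (32 layers, 8 KV heads, head_dim 128, 16-bit values) is an assumed example of a mid-sized grouped-query-attention model, not a quoted spec.

```python
def kv_cache_bytes(seq_len, layers, kv_heads, head_dim, bytes_per_value=2):
    # 2 tensors (K and V) per layer, per KV head, per position
    return 2 * layers * kv_heads * head_dim * bytes_per_value * seq_len

for n in (8_192, 131_072):
    gib = kv_cache_bytes(n, layers=32, kv_heads=8, head_dim=128) / 2**30
    print(f"{n:>7} tokens: {gib:.1f} GiB of KV cache per sequence")

# attention scores scale with n^2, so 16x the length is 256x the score work
print((131_072 / 8_192) ** 2)
```

Under these assumptions that is 1 GiB at 8K and 16 GiB at 128K for a single sequence, before any batching.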

So how do LLMs stretch from 8K to 128K?

RoPE scaling is how LLMs stretch from 8K to 128K context: every method keeps the rotation angles at long positions inside the range the model was trained on, and they differ mainly in where they spend the unavoidable distortion. Linear position interpolation compresses every dimension uniformly and blurs local resolution. NTK-aware scaling raises the RoPE base so only the slow, global dimensions stretch while fast, local ones stay sharp. YaRN refines that with by-parts interpolation plus an attention-temperature correction, reaching the target length with the least fine-tuning. None of them is free: they all want a short fine-tune at the new length, and they extend the position math without touching the quadratic attention and KV-cache costs that make long context expensive to serve. Pick the scheme the checkpoint was trained with, and read the rope_scaling block before you trust the window.
