DEV Community

Chaitany

PolarQuant: Quantizing KV Caches with Polar Transformation

A deep dive into how PolarQuant compresses LLM key caches by 4x using polar coordinates, and why it works so well.


If you have ever tried running a large language model on long contexts (32K, 64K, or 128K tokens), you have hit the wall: the KV cache. It grows linearly with sequence length, eating up GPU memory and becoming the dominant bottleneck during inference.

PolarQuant, introduced by researchers from KAIST, Google Research, and Yale (arXiv:2502.02617), offers an elegant solution. Instead of quantizing key embeddings the usual way (in Cartesian space), it converts them to polar coordinates (angle and radius) and quantizes those instead. The result is a ~4x compression of the key cache with near-lossless quality on long-context benchmarks.

Let's break down exactly how it works.


What Is PolarQuant?

Every time an LLM generates a token, it needs to attend to all previous tokens. To avoid recomputing everything from scratch, the model stores Key and Value embeddings in a cache. This cache grows with every token generated.
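To put rough numbers on that growth, here is a back-of-the-envelope sizing sketch. The default configuration (32 layers, 8 KV heads, head dimension 128, FP16 storage) is an assumption chosen to resemble a Llama-3.1-8B-style model, and `kv_cache_bytes` is a hypothetical helper name, not something from the PolarQuant code.

```python
def kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, head_dim=128, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence.

    The defaults are an assumed Llama-3.1-8B-style configuration with FP16 storage.
    """
    # Factor 2 = one key vector plus one value vector per layer and KV head.
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_value
    return seq_len * per_token


for tokens in (32_768, 65_536, 131_072):
    gib = kv_cache_bytes(tokens) / 2**30
    print(f"{tokens:>7} tokens -> {gib:.1f} GiB")
```

Under these assumptions the cache costs 128 KiB per token, which works out to 4, 8, and 16 GiB at the three context lengths, for a single sequence.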

PolarQuant compresses this cache by:

  1. Converting key embeddings from Cartesian coordinates into polar coordinates (angle + radius).
  2. Quantizing the angles and radii to low-bit integers (e.g., 4 bits each).
  3. Computing attention directly from the quantized representation, without ever reconstructing the full keys, using a custom GPU kernel.

Why Not Just Quantize Normally?

Traditional methods like KIVI or GPTQ-style quantization work in Cartesian space. They need per-block normalization constants (zero-points, scales) stored in full precision, which adds significant memory overhead, often over 1 extra bit per quantized value.

PolarQuant sidesteps this problem with a mathematical insight:

  • RoPE (Rotary Position Embeddings), used by most modern LLMs, already operates on pairs of dimensions as 2D rotations. This makes polar coordinates a natural fit.
  • After applying a random preconditioning matrix, the angles in polar coordinates follow a predictable, concentrated distribution. This means we can quantize them with pre-computed codebooks, and no per-block normalization is needed.
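The first point can be checked with a few lines of plain Python. This is only an illustrative sketch: `rotate_pair`, `to_polar`, and the `position_angle` value are hypothetical names and numbers, not taken from the PolarQuant code. It shows that a RoPE-style 2D rotation leaves the radius untouched and simply shifts the polar angle.

```python
import math


def rotate_pair(x, y, angle):
    """Apply a 2D rotation, which is what RoPE does to each dimension pair."""
    c, s = math.cos(angle), math.sin(angle)
    return x * c - y * s, x * s + y * c


def to_polar(x, y):
    """Return (phi, r) with phi in [0, 2*pi)."""
    phi = math.atan2(y, x)
    if phi < 0:
        phi += 2 * math.pi
    return phi, math.hypot(x, y)


x, y = 0.6, 0.8
position_angle = 0.3  # hypothetical RoPE angle (position times frequency)

phi0, r0 = to_polar(x, y)
phi1, r1 = to_polar(*rotate_pair(x, y, position_angle))

print(f"radius before/after: {r0:.6f} / {r1:.6f}")
print(f"angle shift: {phi1 - phi0:.6f}")
```

In polar form, position only ever touches the angle, which is part of why the representation fits keys that have already gone through RoPE.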

How It Works

Step 1: Cartesian to Polar Conversion

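Every key vector of dimension d (e.g., d = 128) is split into d/2 two-dimensional pairs, and each pair (x, y) forms a point in a 2D plane. This walkthrough describes the pairs as adjacent dimensions; note that the reference code excerpted later in this post reshapes the head dimension with view(..., 2, D), which pairs dimension i with dimension i + d/2 instead. The math is identical either way.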

The conversion to polar coordinates is straightforward. For each pair, we compute the angle as phi = atan2(y, x), which gives the direction of the 2D vector. We also compute the radius as r = sqrt(x^2 + y^2), which gives the magnitude. If the angle comes out negative, we shift it by adding 2 * pi so that all angles fall in the range [0, 2*pi).

This conversion is lossless. The same information is represented, just in a different coordinate system. You can always go back: x = r * cos(phi), y = r * sin(phi).

Each pair of FP16 values (x, y) becomes an angle phi and a radius r.
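As a quick sanity check on the "lossless" claim, here is a minimal round-trip sketch in plain Python. The helper names `key_to_polar` and `polar_to_key` are made up for this example, and it uses the adjacent-pair layout described above.

```python
import math


def key_to_polar(key):
    """Split a key vector into adjacent (x, y) pairs and convert each to (phi, r)."""
    assert len(key) % 2 == 0
    pairs = []
    for i in range(0, len(key), 2):
        x, y = key[i], key[i + 1]
        phi = math.atan2(y, x)
        if phi < 0:
            phi += 2 * math.pi  # shift into [0, 2*pi)
        pairs.append((phi, math.hypot(x, y)))
    return pairs


def polar_to_key(pairs):
    """Inverse transform: x = r * cos(phi), y = r * sin(phi)."""
    key = []
    for phi, r in pairs:
        key.extend((r * math.cos(phi), r * math.sin(phi)))
    return key


key = [0.5, -1.2, 3.0, 0.25, -0.7, -0.1, 0.0, 2.0]
pairs = key_to_polar(key)
restored = polar_to_key(pairs)
max_err = max(abs(a - b) for a, b in zip(key, restored))
print(f"max round-trip error: {max_err:.2e}")
```

The only error left is floating-point rounding; all of the real loss comes later, at the quantization step.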

You can see this in the source code at modeling_llama_polar.py:

phi = torch.atan2(key_states[:, :, :, :, 1, :], key_states[:, :, :, :, 0, :])
phi = torch.where(phi < 0, phi + 2 * torch.math.pi, phi)
radii = torch.norm(key_states, p=2, dim=-2)

The Paper's Recursive Polar Transformation

The paper goes deeper with a recursive version. Instead of stopping at one level of pairing, it recursively applies the polar transform to the radii themselves.

First, at Level 1, we pair adjacent dimensions (x1, x2), (x3, x4), ... into d/2 polar pairs. Each pair gives an angle and a radius.

Then, at Level 2, we take the d/2 radii from Level 1 and pair them again into d/4 pairs. Each pair of radii gives a new angle and a new radius.

We keep repeating this process at Level 3, Level 4, and so on, until a single scalar radius remains.

The final output is one radius plus (d - 1) angles organized across log2(d) levels. In practice, the implementation only recurses for 4 levels (not all log2(d)), leaving d/16 radii (one per block of 16 coordinates) stored in full precision.
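The recursion is easier to follow in code. The sketch below is a plain-Python illustration of the idea as described here, not the paper's implementation; `recursive_polar` and its `levels` argument are names invented for this example, and it assumes the vector length is a power of two.

```python
import math


def recursive_polar(x, levels=None):
    """Recursively convert a vector into per-level angles plus leftover radii.

    len(x) must be a power of two. With levels=None the recursion runs until a
    single radius remains; a smaller value stops early and leaves several radii.
    """
    d = len(x)
    assert d >= 2 and d & (d - 1) == 0, "length must be a power of two"
    max_levels = d.bit_length() - 1  # log2(d)
    levels = max_levels if levels is None else min(levels, max_levels)

    angles_per_level = []
    radii = list(x)  # at the start, the "radii" are the signed coordinates themselves
    for level in range(levels):
        angles, next_radii = [], []
        for i in range(0, len(radii), 2):
            a, b = radii[i], radii[i + 1]
            phi = math.atan2(b, a)
            if level == 0 and phi < 0:
                phi += 2 * math.pi  # level 1 spans [0, 2*pi); higher levels stay in [0, pi/2]
            angles.append(phi)
            next_radii.append(math.hypot(a, b))
        angles_per_level.append(angles)
        radii = next_radii
    return angles_per_level, radii


x = [0.5, -1.2, 3.0, 0.25, -0.7, -0.1, 0.0, 2.0,
     1.1, 0.4, -2.2, 0.9, 0.3, -0.6, 1.7, -1.3]
angles, radii = recursive_polar(x)
print([len(a) for a in angles], len(radii))  # [8, 4, 2, 1] 1

# Stopping after 4 levels on a 32-dimensional vector leaves 32 / 16 = 2 radii.
_, partial = recursive_polar(x + x, levels=4)
```

For a 16-dimensional input this yields 8 + 4 + 2 + 1 = 15 angles and a single radius equal to the L2 norm of the vector.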

Why Does This Help?

Here is the key mathematical insight from the paper. After applying a random preconditioning (rotation) matrix to the key vectors:

  • Level 1 angles are approximately uniformly distributed over [0, 2*pi).
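  • Level 2+ angles lie in [0, pi/2] and follow a density proportional to sin^(m-1)(2*theta), where m is the number of original coordinates folded into each of the two radii being paired (m = 2 at Level 2, m = 4 at Level 3, and so on). This density is concentrated around pi/4, and the higher the level, the tighter the concentration.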

A concentrated distribution means fewer quantization bins are needed to represent it accurately. The paper uses 4 bits for Level 1 (wide range) and just 2 bits for higher levels (concentrated range).
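You can see the concentration effect empirically with a small simulation. This sketch rests on one loud assumption: that after random preconditioning the coordinates behave like i.i.d. Gaussians, so it samples Gaussian vectors directly instead of applying a preconditioning matrix. The helper `level_angles` is a hypothetical name.

```python
import math
import random
import statistics

random.seed(0)


def level_angles(x):
    """Return one list of angles per level of the recursive polar transform."""
    per_level, radii = [], list(x)
    while len(radii) > 1:
        angles = [math.atan2(radii[i + 1], radii[i]) for i in range(0, len(radii), 2)]
        if not per_level:  # level 1 only: shift into [0, 2*pi)
            angles = [a + 2 * math.pi if a < 0 else a for a in angles]
        per_level.append(angles)
        radii = [math.hypot(radii[i], radii[i + 1]) for i in range(0, len(radii), 2)]
    return per_level


d, n_samples = 16, 2000
samples = [[] for _ in range(4)]  # log2(16) = 4 levels
for _ in range(n_samples):
    # Assumption: preconditioned coordinates look like i.i.d. standard Gaussians.
    x = [random.gauss(0.0, 1.0) for _ in range(d)]
    for level, angles in enumerate(level_angles(x)):
        samples[level].extend(angles)

means = [statistics.fmean(v) for v in samples]
spread = [statistics.pstdev(v) for v in samples]
for level, (m, s) in enumerate(zip(means, spread), start=1):
    print(f"level {level}: mean={m:.3f}  std={s:.3f}")
```

The level-1 angles should spread over the whole circle, while the higher levels should cluster around pi/4 with a standard deviation that shrinks at each level, which is the property that lets them get away with fewer bits.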


Step 2: How the Quantization Happens

Once we have polar coordinates, we quantize angles and radii to low-bit integers. This is where the lossy compression happens.

Grouping the tokens. First, we organize the keys along the sequence dimension into groups of G tokens (default G = 128). All the quantization statistics are computed independently within each group and each 2D sub-dimension.

Quantizing the angle. Within each group, we find the minimum and maximum angle values across the G tokens. The range between them is divided into 2^tbits equal-width bins (with tbits = 4, that is 16 bins). We compute the step size as delta_phi = (phi_max - phi_min) / 16. Then, for each token's angle, we figure out which bin it falls into by computing floor((phi - phi_min) / delta_phi) and clamping the result to the range 0 through 15. This gives us an integer index theta between 0 and 15.

Quantizing the radius. We do the exact same thing for the radius values. Find the min and max radius within the group, divide the range into 2^rbits bins (again 16 bins with rbits = 4), and map each radius to an integer index rho between 0 and 15.

Packing into one byte. Since both theta and rho are 4-bit values, we combine them into a single uint8 byte by shifting the radius index left by 4 bits and adding the angle index: indices = (rho << 4) + theta. This means every 2D sub-dimension per token is stored in just one byte.
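Before looking at the tensor version, here is the same logic for a single sub-dimension of a single group, written as a plain-Python sketch. The function names and the sample values are made up for illustration, and the zero-range guard is an addition of this sketch, not something described above.

```python
import math


def quantize_minmax(values, bits):
    """Uniform min/max quantization of one group; returns (indices, scale, vmin)."""
    vmin, vmax = min(values), max(values)
    n_bins = 1 << bits
    scale = (vmax - vmin) / n_bins
    if scale == 0.0:
        # Guard added for this sketch (all values identical); not from the article.
        return [0] * len(values), 0.0, vmin
    # floor() sends vmax itself to bin n_bins, so clamp to the last valid bin.
    indices = [min(n_bins - 1, math.floor((v - vmin) / scale)) for v in values]
    return indices, scale, vmin


def pack(rho, theta, tbits=4):
    """Radius index in the high bits, angle index in the low bits."""
    return (rho << tbits) + theta


angles = [0.10, 0.90, 1.70, 2.20, 3.10]  # one sub-dimension, five tokens of a group
radii = [1.0, 1.4, 2.1, 2.6, 3.0]

theta, tscale, tmn = quantize_minmax(angles, bits=4)
rho, rscale, rmn = quantize_minmax(radii, bits=4)
packed = [pack(r, t) for r, t in zip(rho, theta)]
print(theta, rho, packed)
```

Note how the largest value in each group would land in bin 16 without the clamp, which is why the clamping step matters.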

You can see the full implementation in the quantize_and_pack_nbit method at modeling_llama_polar.py:

def quantize_and_pack_nbit(self, key_states):
    B, N, L, D = key_states.shape
    assert D % 2 == 0 and L % self.group_size == 0 and self.rbits + self.tbits <= 8
    D, G = D // 2, self.group_size

    key_states = key_states.view(B, N, L // G, G, 2, D)

    phi = torch.atan2(key_states[:, :, :, :, 1, :], key_states[:, :, :, :, 0, :])
    phi = torch.where(phi < 0, phi + 2 * torch.math.pi, phi)

    tmx, tmn = phi.max(-2, keepdim=True)[0], phi.min(-2, keepdim=True)[0]
    tscale = (tmx - tmn) / (2 ** self.tbits)
    theta = torch.clamp(torch.floor((phi - tmn) / tscale).to(torch.uint8), 0, 2 ** self.tbits - 1)

    radii = torch.norm(key_states, p=2, dim=-2)
    rmx, rmn = radii.max(-2, keepdim=True)[0], radii.min(-2, keepdim=True)[0]
    rscale = (rmx - rmn) / (2 ** self.rbits)
    rho = torch.clamp(torch.floor((radii - rmn) / rscale).to(torch.uint8), 0, 2 ** self.rbits - 1)

    indices = (rho << self.tbits) + theta

    return indices, rscale, rmn, tscale, tmn

What gets stored per group: the packed indices (uint8), the angle step size and minimum (tscale, tmn in float16), and the radius step size and minimum (rscale, rmn in float16).

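Compression math: Each 2D pair originally takes 4 bytes (two FP16 values). After quantization, it takes 1 byte (packed uint8) plus its share of the per-group statistics. With G = 128 and four float16 statistics per group and sub-dimension, that share is 8 bytes spread over 128 tokens, so a group costs 136 bytes instead of 512. That is about 3.8x smaller (roughly 4.25 bits per coordinate), which is where the "roughly 4x" figure comes from.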

The Paper's Codebook Approach

The paper goes further than simple min/max quantization. Since random preconditioning makes the angle distribution predictable and analytically known, they use 1-D k-means clustering on the distribution to find optimal bin boundaries and centroids that minimize the mean squared quantization error.
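To get a feel for what a k-means codebook buys you, here is a small sketch using plain 1-D Lloyd's algorithm on sampled level-2 angles. It is an illustration under stated assumptions, not the paper's procedure: it fits the codebook to random samples (Gaussian coordinates assumed), and names like `lloyd_1d` and `sample_level2_angles` are invented for this example.

```python
import math
import random

random.seed(0)


def sample_level2_angles(n):
    """Angles between two radii, each the norm of a 2-D Gaussian pair (level 2)."""
    out = []
    for _ in range(n):
        r1 = math.hypot(random.gauss(0, 1), random.gauss(0, 1))
        r2 = math.hypot(random.gauss(0, 1), random.gauss(0, 1))
        out.append(math.atan2(r2, r1))
    return out


def nearest(centroids, v):
    """Index of the centroid closest to v."""
    return min(range(len(centroids)), key=lambda k: abs(v - centroids[k]))


def mse(samples, centroids):
    """Mean squared error when each sample is replaced by its nearest centroid."""
    return sum((v - centroids[nearest(centroids, v)]) ** 2 for v in samples) / len(samples)


def lloyd_1d(samples, centroids, iters=20):
    """Plain 1-D k-means (Lloyd's algorithm) starting from the given centroids."""
    centroids = list(centroids)
    for _ in range(iters):
        buckets = [[] for _ in centroids]
        for v in samples:
            buckets[nearest(centroids, v)].append(v)
        # Keep the old centroid if a bucket happens to be empty.
        centroids = [sum(b) / len(b) if b else c for b, c in zip(buckets, centroids)]
    return centroids


n_bins = 1 << 2  # 2 bits
samples = sample_level2_angles(3000)
uniform = [(k + 0.5) * (math.pi / 2) / n_bins for k in range(n_bins)]  # equal-width midpoints
codebook = lloyd_1d(samples, uniform)
mse_uniform = mse(samples, uniform)
mse_codebook = mse(samples, codebook)
print(f"uniform MSE:  {mse_uniform:.5f}")
print(f"codebook MSE: {mse_codebook:.5f}")
```

Because Lloyd's algorithm starts from the equal-width midpoints and never increases the error, the fitted codebook can only match or beat the uniform bins on these samples.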

They also allocate bits differently across levels:

  • Level 1 angles (range [0, 2*pi)): 4 bits = 16 bins
  • Level 2+ angles (range [0, pi/2], concentrated): 2 bits = 4 bins

This gives: 16 bits (FP radius) + 32 bits (8 level-1 angles * 4 bits) + 14 bits (7 higher angles * 2 bits) = 62 bits per block of 16 coordinates = 3.875 bits per coordinate.
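The same arithmetic as a tiny calculator, in case you want to try other bit allocations. This is just a sketch; `bits_per_coordinate` is a hypothetical helper, and it assumes the number of angles halves at each level.

```python
def bits_per_coordinate(block=16, radius_bits=16, level_bits=(4, 2, 2, 2)):
    """Average bits per coordinate for one block under a per-level bit allocation."""
    total = radius_bits
    n_angles = block // 2
    for bits in level_bits:
        total += n_angles * bits  # 8, 4, 2, 1 angles at levels 1..4 when block=16
        n_angles //= 2
    return total / block


print(bits_per_coordinate())  # 62 / 16 = 3.875
```

Spending 4 bits at every level instead would give `bits_per_coordinate(level_bits=(4, 4, 4, 4))`, i.e. 76 / 16 = 4.75 bits per coordinate, which shows how much the concentrated higher levels save.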


Step 3: Polar to Cartesian (Dequantization)

To reconstruct a key from its quantized form, we reverse the process. There are four steps to this.

First, we unpack the byte by separating it back into the angle index and the radius index. The lower 4 bits give us the angle index theta (extracted using a bitwise AND with 0xF), and the upper 4 bits give us the radius index rho (extracted by right-shifting by 4).

Second, we reconstruct the angle by converting the bin index back to an actual angle value. Instead of using the left edge of the bin, we use the midpoint by adding 0.5 to the index before multiplying by the step size and adding the minimum: phi_hat = delta_phi * (theta + 0.5) + phi_min. Using the midpoint halves the maximum quantization error compared to using the bin edge.

Third, we reconstruct the radius the same way: r_hat = delta_r * (rho + 0.5) + r_min.

Finally, we convert back to Cartesian coordinates using the standard polar-to-Cartesian formulas: x_hat = r_hat * cos(phi_hat) and y_hat = r_hat * sin(phi_hat).
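Here are the four steps for a single packed byte, as a plain-Python sketch. The group statistics below are hypothetical values picked for the example, and `unpack` / `dequantize_pair` are illustrative names.

```python
import math


def unpack(byte, tbits=4):
    """Split a packed byte into (rho, theta)."""
    return byte >> tbits, byte & ((1 << tbits) - 1)


def dequantize_pair(byte, tscale, tmn, rscale, rmn, tbits=4):
    """Reconstruct (x_hat, y_hat) from one packed byte using bin midpoints."""
    rho, theta = unpack(byte, tbits)
    phi_hat = tscale * (theta + 0.5) + tmn
    r_hat = rscale * (rho + 0.5) + rmn
    return r_hat * math.cos(phi_hat), r_hat * math.sin(phi_hat)


# Hypothetical group statistics: angles span [0, 2*pi), radii span [0, 4).
tscale, tmn = 2 * math.pi / 16, 0.0
rscale, rmn = 4.0 / 16, 0.0

byte = (9 << 4) + 2  # rho = 9, theta = 2
x_hat, y_hat = dequantize_pair(byte, tscale, tmn, rscale, rmn)
print(unpack(byte), (x_hat, y_hat))
```

With midpoint reconstruction, the recovered angle is off by at most half a bin (tscale / 2) and the recovered radius by at most rscale / 2.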

You can see this reconstruction in the PyTorch reference code at kernel4group.py:

phi = theta * tscale + 0.5 * tscale + tmn
radii = rho * rscale + 0.5 * rscale + rmn
key_states_reconstruct = torch.stack([radii * phi.cos(), radii * phi.sin()], dim=-2)

The Paper's Recursive Dequantization

For the multi-level recursive transformation, dequantization works top-down, starting from the single radius at the highest level and splitting downward.

At each level, we take each radius value and split it into two values by looking up the stored angle centroid for that position. The first value is the radius multiplied by the cosine of the centroid angle, and the second is the radius multiplied by the sine. This doubles the number of coordinates at each level. We repeat this from the top level all the way down to level 1, at which point we have the full reconstructed vector. If preconditioning was applied, we multiply by the transpose of the preconditioning matrix to undo it.
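The top-down splitting can be sketched in a few lines. This illustration takes already-dequantized angles as input (in the paper's scheme they would come from codebook centroids) and leaves out the preconditioning step; `inverse_recursive_polar` is a hypothetical name.

```python
import math


def inverse_recursive_polar(radii, angles_per_level):
    """Rebuild a vector from top-level radii and per-level angles (level 1 first)."""
    values = list(radii)
    for angles in reversed(angles_per_level):  # top level down to level 1
        assert len(angles) == len(values)
        expanded = []
        for r, phi in zip(values, angles):
            # Each radius splits into two children: r*cos(phi) and r*sin(phi).
            expanded.extend((r * math.cos(phi), r * math.sin(phi)))
        values = expanded
    return values


# Toy input: one radius, one level-2 angle, two level-1 angles (d = 4).
angles_per_level = [[0.0, math.pi / 2], [math.pi / 4]]
x_hat = inverse_recursive_polar([2.0], angles_per_level)
print(x_hat)
```

Each pass doubles the number of values, and the L2 norm of the result always equals the top-level radius, since every split satisfies cos^2 + sin^2 = 1.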


Step 4: The Efficient Decode Kernel (Skip Reconstruction Entirely)

Here is the cleverest part of the implementation. During decoding, we never actually reconstruct the keys. Instead, we compute the attention scores (the dot product Q * K^T) directly from the packed byte codes.

The mathematical trick:

The dot product between a query q and a reconstructed key, for each 2D sub-dimension, can be written as:

q_x * r_hat * cos(phi_hat) + q_y * r_hat * sin(phi_hat)

If we factor out the radius, this becomes:

r_hat * (q_x * cos(phi_hat) + q_y * sin(phi_hat))

For a fixed query and a fixed group, the expression inside the parentheses depends only on the angle bin index. Since phi_hat can only take 2^tbits = 16 possible values (one per bin), we can precompute that expression for all 16 possibilities once per group and sub-dimension, and then just look up the right one using each token's stored index.

Here is how the kernel works, step by step:

First, for each group of quantized keys, the kernel precomputes the query-angle dot product for all 16 possible angle bins. For each bin, it computes the midpoint angle and then calculates q_x * cos(angle) + q_y * sin(angle). This creates a small lookup table of 16 values per sub-dimension.

Next, for each token in the group, the kernel extracts the angle index from the lower 4 bits of the packed byte and looks up the corresponding precomputed dot product from the table. No cos/sin computation is needed per token, just a table lookup.

Then, the kernel extracts the radius index from the upper 4 bits, computes the radius midpoint value, and multiplies it with the looked-up dot product. This gives the contribution of that sub-dimension to the total attention score.

Finally, the kernel sums these contributions across all sub-dimensions to produce the final attention logit for that token.
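To convince yourself the lookup trick gives the same answer as reconstructing keys, here is a CPU-only sketch in plain Python. It mirrors the four steps above for one group but is not the Triton kernel; the function names and all numeric inputs are made up for the example.

```python
import math


def lut_scores(query, packed, tscale, tmn, rscale, rmn, tbits=4, rbits=4):
    """Attention logits for one group straight from packed bytes.

    query:  list of (q_x, q_y), one per sub-dimension
    packed: packed[t][j] is the byte for token t, sub-dimension j
    tscale, tmn, rscale, rmn: per-sub-dimension group statistics (lists)
    """
    n_angle, n_radius, mask = 1 << tbits, 1 << rbits, (1 << tbits) - 1
    angle_lut, radius_lut = [], []
    for j, (qx, qy) in enumerate(query):
        # Tables are built once per group and shared by every token in it.
        mids = [tscale[j] * (k + 0.5) + tmn[j] for k in range(n_angle)]
        angle_lut.append([qx * math.cos(p) + qy * math.sin(p) for p in mids])
        radius_lut.append([rscale[j] * (k + 0.5) + rmn[j] for k in range(n_radius)])
    return [sum(radius_lut[j][b >> tbits] * angle_lut[j][b & mask]
                for j, b in enumerate(row))
            for row in packed]


def reference_scores(query, packed, tscale, tmn, rscale, rmn, tbits=4):
    """Same logits the slow way: dequantize every key, then take the dot product."""
    mask = (1 << tbits) - 1
    scores = []
    for row in packed:
        s = 0.0
        for j, b in enumerate(row):
            phi = tscale[j] * ((b & mask) + 0.5) + tmn[j]
            r = rscale[j] * ((b >> tbits) + 0.5) + rmn[j]
            s += query[j][0] * r * math.cos(phi) + query[j][1] * r * math.sin(phi)
        scores.append(s)
    return scores


query = [(0.3, -1.1), (0.7, 0.2)]                    # two sub-dimensions
packed = [[0x52, 0xA7], [0x0F, 0xF0], [0x88, 0x31]]  # three tokens
tscale, tmn = [0.39, 0.20], [0.0, 1.0]
rscale, rmn = [0.25, 0.10], [0.5, 0.0]

fast = lut_scores(query, packed, tscale, tmn, rscale, rmn)
slow = reference_scores(query, packed, tscale, tmn, rscale, rmn)
print(fast)
print(slow)
```

The two lists agree up to floating-point rounding, but the lookup version evaluates cos and sin only 16 times per sub-dimension no matter how many tokens are in the group.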

You can see the Triton kernel implementation at kernel4group.py:

phi = tscale * (tl.arange(0, 1 << tbits)[None, None, :, None] + 0.5) + tmn

tp = tl.sum(query * tl.interleave(tl.cos(phi), tl.sin(phi)), axis=-1)

attn = tl.gather(tp, tl.broadcast_to(indices[None, :, :] & (2 ** tbits - 1),
                 (N // Nk, D, G)), axis=-1)

radii = rscale * (tl.arange(0, 1 << rbits)[None, None, :] + 0.5) + rmn

attn *= tl.gather(radii, indices[None, :, :] >> tbits, axis=-1)

attn = tl.sum(attn, axis=1)

Why is this fast?

| | Traditional (reconstruct + matmul) | PolarQuant kernel |
|---|---|---|
| Memory reads/token | d floats = 2d bytes | d/2 bytes (packed uint8) |
| Compute | Full matmul | 16 cos/sin (shared) + 2 lookups + 1 multiply per sub-dim |
| Memory bandwidth | Bottleneck on long seqs | ~4x less traffic |

The kernel reads roughly 4x less key data from memory and replaces expensive per-token math with cheap table lookups.


Practical Implementation

Cache Architecture: Hybrid Prefill + Decode

The implementation does not quantize everything. It uses a hybrid strategy, keeping a short full-precision tail alongside the quantized prefix.

During prefill (when processing the full prompt, where multiple tokens are processed at once):

The model first runs full-precision FlashAttention on the entire prompt to compute the attention output. After that, it splits the key cache into two parts. The longest prefix whose length is a multiple of residual_length (default 128) gets polar-quantized into packed indices and scales. The remaining tail (the last few tokens that don't fill a complete block of 128) stays in full precision. If the entire sequence is shorter than 128 tokens, nothing gets quantized at all.

During decode (when generating one token at a time):

Each new token's key is appended to the full-precision tail. The model then computes attention in two parts. For the quantized prefix, it uses the Triton kernel to compute approximate attention scores directly from the packed bytes. For the full-precision tail, it uses standard matrix multiplication. The two sets of scores are concatenated, divided by sqrt(head_dim), and passed through softmax together. Whenever the full-precision tail grows to a length that is a multiple of residual_length, the entire tail gets quantized and merged into the packed prefix, and the tail is reset to empty.

This logic lives in the forward method of LlamaPolarGroupAttention at modeling_llama_polar.py. The hybrid approach keeps the tokens added since the last flush (the most recent ones, which often matter most for generation quality) at full precision.
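The bookkeeping is simple enough to model with a toy class. This is a sketch of the behaviour described above, not the repository's code: `HybridKeyCache` is a hypothetical name, the quantizer is a placeholder, and it assumes group_size equals residual_length (both default to 128). Whether the real code flushes before or after the current step's attention is not covered here.

```python
class HybridKeyCache:
    """Toy model of the quantized-prefix plus full-precision-tail bookkeeping."""

    def __init__(self, residual_length=128, quantize=None):
        self.residual_length = residual_length
        # Placeholder quantizer (hypothetical): a real one returns packed indices and scales.
        self.quantize = quantize or (lambda keys: ("packed", len(keys)))
        self.quantized_groups = []  # one entry per flushed block of keys
        self.tail = []              # recent keys kept in full precision

    def prefill(self, keys):
        """Quantize the longest prefix that is a multiple of residual_length."""
        n_quant = (len(keys) // self.residual_length) * self.residual_length
        for start in range(0, n_quant, self.residual_length):
            block = keys[start:start + self.residual_length]
            self.quantized_groups.append(self.quantize(block))
        self.tail = list(keys[n_quant:])

    def append(self, key):
        """Add one decoded token's key; flush the tail once it fills a block."""
        self.tail.append(key)
        if len(self.tail) == self.residual_length:
            self.quantized_groups.append(self.quantize(self.tail))
            self.tail = []

    def lengths(self):
        """Return (tokens in quantized prefix, tokens in full-precision tail)."""
        return len(self.quantized_groups) * self.residual_length, len(self.tail)


cache = HybridKeyCache()
cache.prefill(list(range(300)))  # 300-token prompt
after_prefill = cache.lengths()  # (256, 44)

for step in range(84):           # 44 + 84 = 128, so the tail flushes on the last step
    cache.append(step)
after_decode = cache.lengths()   # (384, 0)
print(after_prefill, after_decode)
```

A 300-token prompt leaves 256 tokens quantized and 44 in the tail; 84 decode steps later the tail reaches 128 and gets folded into the quantized prefix.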

Default Hyperparameters

| Parameter | Value | What it controls |
|---|---|---|
| rbits | 4 | Bits for radius quantization (16 bins) |
| tbits | 4 | Bits for angle quantization (16 bins) |
| group_size | 128 | Tokens per quantization group |
| residual_length | 128 | Size of the full-precision tail |
| Constraint | rbits + tbits <= 8 | Must fit in one byte |

These are set in the __init__ method of LlamaPolarGroupAttention at modeling_llama_polar.py.

Performance Results

On LongBench with Llama-3.1-8B-Instruct at ~4x compression:

| Method | Average Score |
|---|---|
| Exact (16-bit, no compression) | 48.63 |
| PolarQuant-R (online codebook) | 48.37 |
| PolarQuant-R (offline codebook) | 48.29 |
| PolarQuant (no preconditioning) | 48.11 |
| KIVI | 46.70 |
| HeadKV | 45.34 |
| SnapKV | 44.57 |
| PyramidKV | 44.03 |
| StreamingLLM | 38.36 |

PolarQuant achieves the best quality among the compression methods listed, coming within about 0.5 points of the uncompressed baseline, while generating tokens 14% faster than KIVI.

On the Needle-in-a-Haystack test (finding a specific sentence buried in a 104K token document), PolarQuant scores 0.991 vs the uncompressed 0.995, virtually indistinguishable.


Putting It All Together

Here is the full pipeline of what happens to a key vector:

  1. Start with the original key in FP16: [x1, y1, x2, y2, ..., x_{d/2}, y_{d/2}]
  2. Pair adjacent dimensions into 2D pairs: [(x1,y1), (x2,y2), ..., (x_{d/2}, y_{d/2})]
  3. Convert each pair to polar form using atan2 and L2 norm: [(phi1, r1), (phi2, r2), ...]
  4. Quantize each angle and radius to 4-bit integers: [(theta1, rho1), (theta2, rho2), ...]
  5. Pack both into one byte per pair: (rho << 4) + theta, giving [byte1, byte2, ..., byte_{d/2}]

Storage: d values at 2 bytes each = 2d bytes becomes d/2 bytes + small per-group scales. That is roughly a 4x compression.

PolarQuant works because:

  1. Polar coordinates are natural for RoPE-transformed key vectors, since RoPE already operates on 2D rotations.
  2. Random preconditioning makes angle distributions predictable and concentrated, enabling efficient quantization without per-block normalization.
  3. The dot-product factorization trick lets us compute attention directly from packed bytes using table lookups instead of reconstructing full keys.
  4. The hybrid cache keeps recent tokens at full precision while progressively compressing older ones.

The end result is roughly 4x less key-cache memory, 14% faster generation than KIVI, and near-lossless quality, a compelling combination for deploying LLMs on long contexts.


Paper: PolarQuant: Quantizing KV Caches with Polar Transformation by Insu Han, Praneeth Kacham, Amin Karbasi, Vahab Mirrokni, and Amir Zandieh.

Code: github.com/ericshwu/PolarQuant. The implementation supports Llama and Qwen2 models, with a custom Triton GPU kernel for efficient decode-time attention over quantized keys.
