
João André Gomes Marques


The Math Behind E8 Lattice Quantization (with Code)

Standard scalar quantization, which is what most LLM quantizers (GPTQ and AWQ included) do, rounds each number independently to the nearest representable value. E8 lattice quantization rounds groups of 8 numbers jointly to the nearest point on a mathematical lattice. The difference sounds subtle. It isn't.

This post is a complete walkthrough of how E8 quantization works, why it beats scalar quantization in distortion at the same point density (about 14% lower mean squared error in the high-resolution limit), and exactly what the algorithm does line by line.

Why Lattices?

The core problem in quantization is closely related to sphere packing. You want to place representable points in n-dimensional space so that, for a fixed number of points per unit volume, any real vector is as close as possible to at least one codebook entry.

For 1D scalar quantization, you're placing points on a number line. Easy — evenly space them.
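To make the 1D baseline concrete, here is a minimal sketch of uniform scalar quantization using only the standard library. The helper name `scalar_quantize` is illustrative and not part of any library. With inputs spread evenly over each cell, the mean squared error comes out near step²/12, which is the reference number lattice quantizers are compared against.

```python
import random

def scalar_quantize(v: float, step: float = 1.0) -> float:
    """Round v to the nearest multiple of `step` (hypothetical helper)."""
    return round(v / step) * step

random.seed(0)
samples = [random.uniform(-4.0, 4.0) for _ in range(200_000)]

# Empirical mean squared error of rounding to the integer grid (step = 1)
mse = sum((v - scalar_quantize(v)) ** 2 for v in samples) / len(samples)
print(f"empirical MSE: {mse:.4f}   theory step^2/12: {1 / 12:.4f}")
```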

For 8D vector quantization, you want to pack 8D balls as densely as possible. The densest packing in 8 dimensions is the E8 root lattice, proven optimal by Maryna Viazovska in 2016 (she won the Fields Medal for it in 2022). Its packing density, π⁴/384 ≈ 0.2537, is 16 times that of the plain integer grid ℤ⁸.

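What this means for quantization: at the same number of points per unit volume, E8's mean squared error is about 14% lower than that of the integer grid used by uniform scalar quantization (normalized second moment ≈ 0.0717 versus 1/12 ≈ 0.0833), a gain of about 0.65 dB in signal-to-noise ratio. A ~30% (1.53 dB) reduction is the theoretical ceiling for lattices of arbitrarily high dimension; E8 captures a bit under half of that in dB terms.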
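As a sanity check on the distortion figures, the high-resolution gain of a lattice over scalar rounding can be computed from normalized second moments. The sketch below takes the E8 constant as an input (the value reported by Conway and Sloane) and does not derive it; the function name `gain_over_scalar` is illustrative.

```python
import math

# Normalized second moments: mean squared error per dimension at unit point density.
G_INTEGER_GRID = 1 / 12                  # scalar rounding, i.e. the lattice Z
G_E8 = 929 / 12960                       # published value for E8, about 0.0717
G_LIMIT = 1 / (2 * math.pi * math.e)     # best achievable as the dimension grows

def gain_over_scalar(g: float) -> tuple[float, float]:
    """Return (fractional MSE reduction, gain in dB) relative to scalar rounding."""
    return 1 - g / G_INTEGER_GRID, 10 * math.log10(G_INTEGER_GRID / g)

e8_reduction, e8_db = gain_over_scalar(G_E8)
limit_reduction, limit_db = gain_over_scalar(G_LIMIT)
print(f"E8:    {e8_reduction:.1%} lower MSE, {e8_db:.2f} dB")
print(f"limit: {limit_reduction:.1%} lower MSE, {limit_db:.2f} dB")
```

These are asymptotic figures for fine quantization of smooth densities; at 2 to 4 bits on real data the measured gap can differ.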

The E8 Lattice, Briefly

The E8 lattice is the set of all 8D vectors that satisfy either of these two conditions:

  1. All coordinates are integers, and their sum is even
  2. All coordinates are half-integers (x.5 values), and their sum is even

That's the whole definition. A lattice, not a codebook. No learned vectors, no storage.
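Because membership is just a parity rule, it can be checked in a few lines. The sketch below tests the strict definition (even coordinate sum in both the integer and the half-integer case). The name `is_e8_point` is illustrative, and `Fraction` is used only to keep the half-integer checks exact.

```python
from fractions import Fraction

def is_e8_point(v) -> bool:
    """Strict E8 membership test for a sequence of 8 numbers (hypothetical helper)."""
    if len(v) != 8:
        return False
    doubled = [Fraction(c) * 2 for c in v]
    if any(d.denominator != 1 for d in doubled):
        return False                       # some coordinate is not a multiple of 1/2
    parities = {int(d) % 2 for d in doubled}
    if len(parities) != 1:
        return False                       # integers and half-integers are mixed
    return sum(Fraction(c) for c in v) % 2 == 0

print(is_e8_point([1, 1, 0, 0, 0, 0, 0, 0]))     # integers, sum 2
print(is_e8_point([1, 0, 0, 0, 0, 0, 0, 0]))     # integers, sum 1
print(is_e8_point([0.5] * 8))                    # half-integers, sum 4
print(is_e8_point([0.5] * 7 + [-0.5]))           # half-integers, sum 3
```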

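The packing radius is 1/√2: the shortest nonzero lattice vectors have length √2, so balls of radius 1/√2 centered on lattice points do not overlap. The covering radius is 1: every point in ℝ⁸ is within distance 1 of some E8 lattice point.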
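The packing radius can be checked by brute force. The sketch below enumerates strict E8 points with coordinates in [-1, 1] (any lattice point outside that box has a coordinate of magnitude at least 1.5, so it is longer than √2), finds the shortest nonzero vectors, and halves their length. Variable names are illustrative.

```python
import math
from itertools import product

def norm2(v):
    return sum(c * c for c in v)

# Nonzero integer points with even sum, and half-integer points with even sum
int_pts = [v for v in product((-1, 0, 1), repeat=8) if any(v) and sum(v) % 2 == 0]
half_pts = [v for v in product((-0.5, 0.5), repeat=8) if sum(v) % 2 == 0]

shortest = min(norm2(v) for v in int_pts + half_pts)
roots = [v for v in int_pts + half_pts if norm2(v) == shortest]
packing_radius = math.sqrt(shortest) / 2

print(len(roots))          # 240 shortest vectors (112 integer + 128 half-integer)
print(packing_radius)      # half the minimum distance
```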

The Conway-Sloane Nearest-Point Algorithm

Finding the nearest E8 point to an arbitrary input x is what you need for quantization. The algorithm, from Conway and Sloane (1982), is:

  1. Find the nearest point in each of the two cosets separately
  2. Return whichever is closer to x

Coset 1: Integer lattice with even coordinate sum (D8)

Round each coordinate to the nearest integer. If the sum is already even, done. If odd, take the coordinate where x[i] was farthest from its rounded value and round it the other way (that coordinate pays the smallest extra cost to make the sum even).
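The reason the farthest coordinate is the cheapest one to change: if a coordinate has rounding error g, re-rounding it the other way changes its squared error from g² to (1 - g)², an increase of 1 - 2g, which is smallest when g is largest. A minimal pure-Python sketch of this step follows; `nearest_d8` and `flip_costs` are illustrative names, not functions from NexusQuant.

```python
def nearest_d8(x):
    """Nearest point of D8 (integer coordinates, even sum) to an 8D input."""
    r = [round(c) for c in x]
    if sum(r) % 2 != 0:
        gaps = [abs(a - b) for a, b in zip(x, r)]
        i = gaps.index(max(gaps))             # coordinate with the largest rounding error
        r[i] += 1 if x[i] >= r[i] else -1     # move to the integer on the other side of x[i]
    return r

def flip_costs(x):
    """Extra squared error from re-rounding each coordinate the other way."""
    return [1 - 2 * abs(c - round(c)) for c in x]

x = [0.9, 0.1, 0.2, -0.1, 0.4, 0.0, 0.1, -0.2]
costs = flip_costs(x)
print(nearest_d8(x))               # [1, 0, 0, 0, 1, 0, 0, 0]
print(costs.index(min(costs)))     # 4, the coordinate that was 0.4 away from its rounding
```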

Coset 2: Half-integer lattice (D8 + (0.5,...,0.5))

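Subtract 0.5 from x, round each coordinate, add 0.5 back. Strict E8 then applies the same even-sum correction; the implementation below deliberately skips it (see "The Relaxed Parity Discovery").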

Pick the winner:

Compute squared distance to each coset's candidate. Return the closer one.
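Put together, the strict version of the procedure fits in a few lines of dependency-free Python. This is a sketch of the textbook two-coset search with the parity correction applied to both cosets, not the NexusQuant code; `sq_dist`, `nearest_d8` and `nearest_e8` are illustrative names.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((p - q) ** 2 for p, q in zip(a, b))

def nearest_d8(x):
    """Nearest integer vector with even coordinate sum."""
    r = [round(c) for c in x]
    if sum(r) % 2 != 0:
        i = max(range(len(x)), key=lambda k: abs(x[k] - r[k]))
        r[i] += 1 if x[i] >= r[i] else -1
    return r

def nearest_e8(x):
    """Nearest strict-E8 point: best of the D8 and D8 + (1/2, ..., 1/2) candidates."""
    cand_int = nearest_d8(x)
    cand_half = [c + 0.5 for c in nearest_d8([c - 0.5 for c in x])]
    return cand_half if sq_dist(x, cand_half) < sq_dist(x, cand_int) else cand_int

print(nearest_e8([0.4] * 8))                                   # all coordinates 0.5
print(nearest_e8([0.9, 0.1, 0.2, -0.1, 0.4, 0.0, 0.1, -0.2]))  # [1, 0, 0, 0, 1, 0, 0, 0]
```

Shifting by 0.5 and reusing `nearest_d8` works because adding 0.5 to all eight coordinates changes the sum by 4, which leaves its parity unchanged.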

Python Implementation

Here's the full nearest_point function from NexusQuant's e8_lattice.py:

import torch

def nearest_point(x: torch.Tensor) -> torch.Tensor:
    """Find nearest E8 lattice point to each 8D vector.

    Args:
        x: (..., 8) tensor
    Returns:
        Nearest E8 lattice point, same shape
    """
    # --- Coset 1: integer lattice with even coordinate sum ---
    r_int = x.round()
    sums = r_int.sum(dim=-1)
    odd = (sums % 2 != 0)

    if odd.any():
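        # Re-round the coordinate with the LARGEST rounding error: moving it one
        # step the other way adds 1 - 2*gap to the squared error, so the largest
        # gap is the cheapest fix
        gaps = (x - r_int).abs()
        idx = gaps[odd].argmax(dim=-1)          # cheapest coordinate to adjust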
        fix = torch.zeros_like(r_int[odd])
        fix.scatter_(-1, idx.unsqueeze(-1), 1.0)
        # Which direction to flip? Toward x, not away
        sign = ((x[odd] - r_int[odd]).gather(-1, idx.unsqueeze(-1)) >= 0).float() * 2 - 1
        r_int[odd] = r_int[odd] + fix * sign

    # --- Coset 2: half-integer lattice ---
    r_half = (x - 0.5).round() + 0.5
    sums_h = r_half.sum(dim=-1)

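    # NOTE: strict E8 requires an even coordinate sum for the half-integer coset too.
    # We intentionally relax this; see "The Relaxed Parity Discovery" below.
    # sums_h is always an integer (eight half-integers), so sums_h * 2 is always
    # even, odd_h is all False, and the correction branch below never runs.
    odd_h = ((sums_h * 2).round() % 2 != 0)

    if odd_h.any():
        gaps_h = (x - r_half).abs()
        idx_h = gaps_h[odd_h].argmax(dim=-1)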
        fix_h = torch.zeros_like(r_half[odd_h])
        fix_h.scatter_(-1, idx_h.unsqueeze(-1), 1.0)
        sign_h = ((x[odd_h] - r_half[odd_h]).gather(-1, idx_h.unsqueeze(-1)) >= 0).float() * 2 - 1
        r_half[odd_h] = r_half[odd_h] + fix_h * sign_h

    # --- Pick closer coset ---
    d_int  = ((x - r_int)  ** 2).sum(dim=-1)
    d_half = ((x - r_half) ** 2).sum(dim=-1)

    res = r_int.clone()
    res[d_half < d_int] = r_half[d_half < d_int]
    return res

The full quantizer wraps this with per-group scaling:

def quantize(x: torch.Tensor, levels: int = 8) -> torch.Tensor:
    """Quantize tensor with E8 lattice VQ and per-group scaling.

    levels: 4 = 2-bit, 8 = 3-bit, 16 = 4-bit
    """
    shape = x.shape
    pad = (8 - shape[-1] % 8) % 8
    if pad > 0:
        x = torch.nn.functional.pad(x, (0, pad))

    flat = x.reshape(-1, 8)
    amax = flat.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    sc = amax / (levels / 2)          # scale each 8D group to [-levels/2, levels/2]

    lp = nearest_point(flat / sc)     # find nearest E8 point in normalized space
    lp = lp.clamp(-levels / 2, levels / 2)
    result = (lp * sc).reshape(x.shape)

    if pad > 0:
        result = result[..., :shape[-1]]
    return result.reshape(shape)

Why Per-Group Scaling?

Raw KV values span a wide dynamic range. A single global scale would waste most of the representable range on the quiet regions and saturate the loud ones.
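A toy example makes the saturation problem visible. The sketch below uses plain scalar rounding as a stand-in for the lattice step (an assumption made to keep it short), with the same `amax / (levels / 2)` scale rule as `quantize`. The data and the helper names `quantize_with_scale` and `relative_error` are made up for illustration.

```python
def quantize_with_scale(values, scale):
    """Round each value to the nearest multiple of `scale` (scalar stand-in)."""
    return [round(v / scale) * scale for v in values]

def relative_error(values, recon):
    """Squared reconstruction error divided by signal energy."""
    num = sum((a - b) ** 2 for a, b in zip(values, recon))
    return num / sum(a * a for a in values)

levels = 8
quiet = [0.011, -0.019, 0.014, 0.002, -0.008, 0.020, -0.013, 0.006]   # a low-energy group
loud = [10.0, -7.5, 3.0, -9.0, 6.0, -2.0, 8.0, -4.0]                  # a high-energy group

global_scale = max(abs(v) for v in quiet + loud) / (levels / 2)   # set by the loud group
group_scale = max(abs(v) for v in quiet) / (levels / 2)           # set by the quiet group

err_global = relative_error(quiet, quantize_with_scale(quiet, global_scale))
err_group = relative_error(quiet, quantize_with_scale(quiet, group_scale))
print(f"quiet group, global scale:    {err_global:.1%} relative error")
print(f"quiet group, per-group scale: {err_group:.1%} relative error")
```

With the global scale every value in the quiet group rounds to zero, so the group is lost entirely; with its own scale it keeps most of its signal.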

Per-group scaling (one FP16 scale per 8 values) costs 16 bits / 8 values = 2 bits per value of overhead. That is as much as a 2-bit lattice code itself, so at levels=4 it doubles the effective bit rate. But the Hadamard rotation applied upstream equalizes variance across dimensions, so within each 8D group the dynamic range is small and a single scale fits all eight values well. The overhead is worth it.

Alternatively, quantize_perhead uses one scale per full 128D head vector. This cuts scale overhead to 0.125 bits/dim at a cost of ~0.07% more PPL degradation. The headroom depends on how well Hadamard equalizes your specific data.

The Relaxed Parity Discovery

This is the part that surprised us.

Strict E8 theory says the half-integer coset must also have an even sum. When we first implemented this and tested on actual KV cache data, the strict version performed worse than the relaxed version — by 0.3-0.4% PPL.

Why? KV cache distributions are sub-Gaussian (lighter tails than Gaussian, more mass near zero). The strict parity constraint over-regularizes the half-integer coset for this regime. Relaxing it lets the quantizer reach half-integer points with odd coordinate sum, which are not in E8 but happen to sit closer to the true KV values. Note that the relaxed point set, D8 ∪ (ℤ⁸ + ½), is no longer a lattice and has 1.5 points per unit volume against E8's 1, so part of the gain comes from a denser codebook.

The original E8 construction targets uniform or Gaussian distributions. Real transformer attention data is neither. The relaxed constraint buys about 0.3% in quality just by removing three lines of enforcement code.
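One way to see what relaxing the parity does to the codebook is to count points per unit volume. Both the strict and the relaxed point sets repeat with period 2 along every axis, so counting inside the box [0, 2)⁸ is enough. The sketch below does that count; it says nothing about perplexity, only about how many code points each variant has to index.

```python
import math
from itertools import product

def even_sum(v):
    return sum(v) % 2 == 0

# One period of each coset
int_even = [v for v in product((0, 1), repeat=8) if even_sum(v)]   # D8
half_all = list(product((0.5, 1.5), repeat=8))                     # all of Z^8 + 1/2
half_even = [v for v in half_all if even_sum(v)]                   # strict half coset

box_volume = 2 ** 8
strict_density = (len(int_even) + len(half_even)) / box_volume
relaxed_density = (len(int_even) + len(half_all)) / box_volume
extra_bits_per_dim = math.log2(relaxed_density / strict_density) / 8

print(strict_density, relaxed_density)                # 1.0 1.5
print(f"{extra_bits_per_dim:.3f} extra bits per dimension to index the relaxed set")
```

So a like-for-like comparison would give strict E8 a slightly finer scale, or charge the relaxed variant roughly 0.07 bits per dimension more.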

This is the kind of result that only shows up when you measure empirically rather than trusting the textbook.

Bit Rate Accounting

At 3 bits per dimension with per-group scales (group size 8):

  • Lattice codes: 3 bits × 8 dims = 24 bits per group
  • Scale (FP16): 16 bits per group
  • Total: 40 bits / 8 dims = 5 bits per dimension raw storage

After delta coding + zstd compression on the lattice indices (consecutive tokens produce similar codes; deltas compress 2-3x):

  • Effective rate: ~2 bits per dimension

That's 16 bits FP16 → ~2 bits effective = 8x compression from quantization plus entropy coding. Combined with 2.5x token eviction, that multiplies out to 20x; the "balanced" preset is quoted at ~17x total.
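The accounting above is simple enough to script, which helps when trying other group sizes or bit widths. This is a sketch of the arithmetic only; the ~2 bits effective rate and the 2.5x eviction factor are taken from the text as inputs, and the function name `raw_bits_per_dim` is illustrative.

```python
FP16_BITS = 16

def raw_bits_per_dim(code_bits: float, group_size: int = 8, scale_bits: int = 16) -> float:
    """Lattice code bits plus the share of one scale spread over the group."""
    return code_bits + scale_bits / group_size

per_group = raw_bits_per_dim(3, group_size=8)      # one FP16 scale per 8 values
per_head = raw_bits_per_dim(3, group_size=128)     # one FP16 scale per 128D head

effective_bits = 2.0       # after delta coding + zstd, as quoted in the text
eviction_factor = 2.5      # token eviction, as quoted in the text

print(per_group, per_head)                             # 5.0 3.125
print(FP16_BITS / per_group)                           # 3.2x before entropy coding
print(FP16_BITS / effective_bits)                      # 8.0x after entropy coding
print(FP16_BITS / effective_bits * eviction_factor)    # 20.0x with eviction
```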

Running It

pip install nexusquant-kv
from nexusquant.core.e8_lattice import E8Lattice
import torch

# Quantize any tensor — works for arbitrary shape
x = torch.randn(100, 128)  # 100 KV vectors, 128-dim heads
q = E8Lattice.quantize(x, levels=8)  # 3-bit E8

print(f"MSE: {((x - q)**2).mean():.6f}")
print(f"Relative error: {((x - q)**2).sum() / (x**2).sum() * 100:.2f}%")

# Or per-head scaling (less overhead, slightly more error)
q2 = E8Lattice.quantize_perhead(x, levels=8)

What Makes This Hard in Practice

Three things the textbook doesn't tell you:

  1. The Hadamard rotation is load-bearing. Without it, KV vectors have heavy outliers in specific dimensions (the same phenomenon that makes LLM quantization hard). E8 performs poorly on unrotated data. Hadamard spreads the energy uniformly so the per-group scale assumption holds.

  2. RoPE removal matters. Rotary position embeddings make adjacent tokens' keys look "rotated" relative to each other. Removing RoPE before quantization undoes that position-dependent rotation, so adjacent keys line up again, and cuts the quantization penalty by ~0.7% PPL.

  3. GPU kernels are unfinished. The current implementation dequantizes immediately for compatibility. Real memory savings require keeping the compressed representation on-device, which needs a Triton kernel to pack/unpack E8 codes on the fly. That's the main engineering gap.
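To illustrate the first point, here is a minimal sketch of an orthonormal Walsh-Hadamard rotation in pure Python. It is a plain, unrandomized transform written for clarity, not the NexusQuant implementation (which may add sign flips or use a fused kernel); `hadamard_rotate` is an illustrative name. A single outlier coordinate is spread evenly across the group, so a per-group scale no longer has to be set by one extreme value.

```python
import math

def hadamard_rotate(v):
    """Orthonormal fast Walsh-Hadamard transform; len(v) must be a power of two."""
    n = len(v)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    out = list(v)
    h = 1
    while h < n:
        # Butterfly pass: combine entries that are h apart
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1 / math.sqrt(n)              # makes the transform norm-preserving
    return [c * scale for c in out]

spike = [8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]    # one outlier dimension
rotated = hadamard_rotate(spike)
print(max(abs(c) for c in spike), max(abs(c) for c in rotated))
print(hadamard_rotate(rotated))                     # applying it twice undoes it
```

The transform is its own inverse, so dequantization can apply the same function to rotate back.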

Repo: github.com/jagmarques/nexusquant

Best regards, João Marques
