DEV Community

Papers Mache
Papers Mache

Posted on

Aligning Hidden States Stabilizes LLM Distillation

Hidden‑representation alignment drives KL variance to exactly 0, turning on‑policy LLM distillation into a noise‑free optimization problem. The deterministic MSE loss over rollouts removes the stochasticity that has long plagued policy‑gradient distillation pipelines. [1]

Before OPRD, on‑policy distillation (OPD) sampled KL divergence over vocabularies of roughly 150 k tokens, inflating Monte Carlo variance and causing late‑stage stagnation. The teacher was treated as a black box, and only output logits received supervision, leaving the bulk of the model’s internal dynamics untouched. [1]

OPRD’s deterministic MSE objective yields a 1.44× speedup over top‑k OPD while closing the performance gap on AIME 2024/2025 and AIMO benchmarks. Empirically the student reaches teacher‑level scores where pure OPD plateaus, confirming that hidden‑state supervision supplies richer gradients than logits alone. “OPRD’s MSE objective is a deterministic function of the rollout; its gradient carries zero additional sampling variance, eliminating the late‑stage signal‑to‑noise collapse of OPD by construction.” [1]

Memory consumption drops by 54 % because OPRD discards the LM head and stores only selected hidden layers, avoiding the massive top‑k token caches required by OPD. This reduction enables training on commodity GPUs without sacrificing batch size, a practical gain for teams without large‑scale clusters. The codebase offers a REP_DISTILLATION_ONLY=True flag that removes the KL term entirely, confirming the memory savings stem from the representation‑only regime. [1]

The paper’s analysis also reveals why OPD collapses: “The OPD variance in (8) does not vanish as , and through the score‑function term it dominates the policy gradient late in training; this is the mechanism behind the late‑stage stagnation of pure OPD (Section˜1). OPRD adds zero conditional variance and therefore provides a stable optimization signal even after the output distribution has nearly converged.” The paper demonstrates that OPRD eliminates sampling variance for the selected layers it aligns, but does not evaluate effects on deeper or cross‑layer interactions; the method still requires generating teacher rollouts for each student update. One open question is whether a lightweight teacher‑free proxy could preserve the zero‑variance property at scale.

If hidden‑state alignment truly nullifies KL variance, future compression pipelines should replace vanilla OPD with OPRD as the default distillation routine. Benchmarks such as AIME and AIMO ought to be re‑run with the representation‑only loss, expecting both higher student accuracy and lower compute budgets.

References

  1. OPRD: On-Policy Representation Distillation

Top comments (0)