Recap. Part 1 framed the problem (trajectory reward is too coarse for multi-step agents) and SDAR's fix (a privileged teacher gives dense token-level guidance, filtered through a gate). Part 2 put the four-model system on AWS and counted the GPU cost. This part is the payoff: the actual gate, in PyTorch.
Honest label up front: what follows is a reference implementation - faithful to the paper's mechanism, written to be read and reasoned about. It is not a benchmarked run. I have no convergence curves to sell you; I have the machinery. (Part 4 designs the verification that money would buy.)
The claim: it's basically one sigmoid
Strip away the framework scaffolding and SDAR's entire contribution is a weighting coefficient on a distillation loss. Three moves:
- Measure how much the teacher likes a token versus the student - the gap.
- Squash that gap through a sigmoid to get a per-token weight in (0, 1).
- Use that weight to scale a token-level distillation loss, then add it to the ordinary GRPO loss.
Positive gap (teacher more confident than student → genuine endorsement) → weight near 1 → distill hard. Negative gap (teacher less confident → a rejection that might just be noise) → weight near 0 → soften, don't obey. That asymmetry is the whole idea.
The math, slowly
For each generated token t, the teacher (with privileged context) and the student each assign a probability to the token that was actually produced. Define the gap as the difference in their log-probabilities:
gap_t = log p_teacher(token_t | privileged_context) − log p_student(token_t | context)
Pass it through a sigmoid to get the gate:
gate_t = σ(gap_t / τ)
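To make the asymmetry concrete, here is the gate evaluated at τ = 1 on a few made-up log-probability pairs for a single realized token. The numbers are illustrative only and are not taken from the paper.

```python
import math

def gate(logp_teacher: float, logp_student: float, tau: float = 1.0) -> float:
    """Per-token gate: sigmoid of the teacher-minus-student log-prob gap."""
    gap = logp_teacher - logp_student
    return 1.0 / (1.0 + math.exp(-gap / tau))

# Hypothetical probabilities that teacher and student assign to the same token.
endorse = gate(math.log(0.9), math.log(0.3))  # teacher more confident -> gate above 0.5
reject = gate(math.log(0.1), math.log(0.6))   # teacher less confident -> gate below 0.5
neutral = gate(math.log(0.5), math.log(0.5))  # equal confidence -> exactly 0.5
```

At τ = 1 the gate reduces to p_teacher / (p_teacher + p_student), so 0.9 vs 0.3 gives 0.75 and 0.1 vs 0.6 gives 1/7. Note that the rejected token keeps a small, non-zero weight.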
τ is a temperature that stops the sigmoid from snapping to a hard 0/1 (more on that in the traps). The distillation signal itself is a per-token KL that pulls the student toward the teacher's full distribution:
KL_t = Σ_v p_teacher(v) · ( log p_teacher(v) − log p_student(v) )
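A tiny worked instance of KL_t over a hypothetical three-token vocabulary. This is plain Python rather than the batched tensor version, and the distributions are invented for illustration:

```python
import math

def forward_kl(p_teacher, p_student):
    """KL(teacher || student) = sum_v p_t(v) * (log p_t(v) - log p_s(v))."""
    # Terms with p_t(v) == 0 are skipped, following the 0 * log 0 = 0 convention.
    return sum(pt * (math.log(pt) - math.log(ps))
               for pt, ps in zip(p_teacher, p_student) if pt > 0)

teacher = [0.7, 0.2, 0.1]
student_far = [0.1, 0.2, 0.7]   # puts its mass where the teacher doesn't
student_same = [0.7, 0.2, 0.1]  # identical to the teacher

kl_far = forward_kl(teacher, student_far)    # 0.6 * ln 7, about 1.17
kl_same = forward_kl(teacher, student_same)  # 0.0: nothing left to distill
```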
And the combined objective:
L_total = L_GRPO + λ · mean_t( gate_t · KL_t )
RL stays primary. The gated distillation is an auxiliary nudge whose strength per token is set by how much we trust the teacher on that token.
The code
Framework-agnostic PyTorch, written to drop into the actor-update step of a verl-agent/OpenRLHF loss function. student_logits come from the policy being trained; teacher_logits come from a frozen, privileged-context forward pass done under no_grad.
import torch
import torch.nn.functional as F


def gated_distillation_loss(
    student_logits,    # [B, T, V] - requires grad (the policy)
    teacher_logits,    # [B, T, V] - from a no_grad privileged forward pass
    actions,           # [B, T] - the token ids actually generated (long)
    response_mask,     # [B, T] - 1 on generated tokens, 0 elsewhere
    tau: float = 1.0,  # gate temperature
):
    # Assumes both logit tensors are already aligned so that position t scores actions[:, t]
    # (shifted by one, with the teacher's privileged prefix sliced away).
    student_logp = F.log_softmax(student_logits, dim=-1)  # [B, T, V]
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)  # [B, T, V]
    response_mask = response_mask.to(student_logp.dtype)  # accept bool/int masks

    # --- 1. token-level gap on the realized action ---
    idx = actions.unsqueeze(-1)
    s_tok = student_logp.gather(-1, idx).squeeze(-1)  # [B, T]
    t_tok = teacher_logp.gather(-1, idx).squeeze(-1)  # [B, T]
    gap = (t_tok - s_tok).detach()  # DETACH: this is a weight, not a loss

    # --- 2. the gate: positive gap -> ~1 (distill), negative -> ~0 (soften) ---
    gate = torch.sigmoid(gap / tau)  # [B, T], in (0, 1), already detached

    # --- 3. per-token forward KL (teacher || student), pulls student toward teacher ---
    teacher_p = teacher_logp.exp()
    kl_per_tok = (teacher_p * (teacher_logp - student_logp)).sum(-1)  # [B, T]

    # --- 4. gate-weighted, masked mean over generated tokens only ---
    weighted = gate * kl_per_tok * response_mask
    return weighted.sum() / response_mask.sum().clamp(min=1.0)
And where it joins the primary objective:
rl_loss = grpo_policy_loss(...) # your existing GRPO term + KL-to-reference
distill = gated_distillation_loss(student_logits, teacher_logits,
                                  actions, response_mask, tau=tau)
total_loss = rl_loss + lam * distill # lam scheduled, see below
total_loss.backward()
That's it. The teacher forward pass and the gather are the only real additions to a working GRPO step.
The four traps that turn this into a NaN
The code above is short. Getting it right is where the time goes.
1. Detach the gap, and run the teacher under no_grad.
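The gate is a weight, not part of the loss surface. If gap keeps its graph, gradients flow through the weighting itself, and the student can lower its loss by raising its own log-probability on the sampled token (which shrinks its gate) instead of actually moving toward the teacher. If the teacher forward pass also builds a graph, gradients reach the teacher branch, which should never update. The fix is gap.detach() plus a with torch.no_grad(): around the teacher forward pass. Forget either and you'll spend an evening confused.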
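A toy check of both halves of this trap, using random tensors in place of real models. teacher_param is a hypothetical stand-in for teacher weights that would receive gradients if the no_grad guard were missing, and gated_kl is an illustrative inline version of the loss, not the function above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 5
student_logits = torch.randn(1, 3, V, requires_grad=True)
teacher_param = torch.randn(1, 3, V, requires_grad=True)  # hypothetical teacher weights
actions = torch.tensor([[0, 3, 1]])

def gated_kl(detach_gap: bool) -> torch.Tensor:
    with torch.no_grad():  # teacher forward pass builds no graph
        teacher_logp = F.log_softmax(teacher_param, dim=-1)
    student_logp = F.log_softmax(student_logits, dim=-1)
    idx = actions.unsqueeze(-1)
    gap = teacher_logp.gather(-1, idx).squeeze(-1) - student_logp.gather(-1, idx).squeeze(-1)
    if detach_gap:
        gap = gap.detach()
    gate = torch.sigmoid(gap)
    kl = (teacher_logp.exp() * (teacher_logp - student_logp)).sum(-1)
    return (gate * kl).mean()

grad_detached, = torch.autograd.grad(gated_kl(True), student_logits)
grad_attached, = torch.autograd.grad(gated_kl(False), student_logits)
# Non-zero difference = the extra "shrink your own gate" gradient an attached gap adds.
extra = (grad_attached - grad_detached).abs().max().item()

gated_kl(True).backward()
teacher_got_grad = teacher_param.grad is not None  # False: no_grad kept the teacher out
```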
2. Mind the KL direction.
Forward KL KL(teacher‖student) is mode-covering - the student tries to put mass everywhere the teacher does. Reverse KL KL(student‖teacher) is mode-seeking - the student collapses onto the teacher's peak. Distillation usually wants forward KL (the code above). Swapping them silently changes what your agent learns; it won't crash, it'll just quietly behave differently.
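A common way to swap the direction by accident is PyTorch's built-in F.kl_div. It takes (input, target) and computes KL(target ‖ input), so the student's log-probs go first. This sketch on random distributions checks the built-in against the manual formula used in the loss above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher_logp = F.log_softmax(torch.randn(2, 4, 7), dim=-1)
student_logp = F.log_softmax(torch.randn(2, 4, 7), dim=-1)

# Manual forward KL(teacher || student), per token.
manual_fwd = (teacher_logp.exp() * (teacher_logp - student_logp)).sum(-1)

# F.kl_div(input, target) = KL(target || input): input must be the *student* log-probs.
builtin_fwd = F.kl_div(student_logp, teacher_logp, reduction="none", log_target=True).sum(-1)

# Swapping the arguments runs fine but silently gives reverse KL(student || teacher).
builtin_rev = F.kl_div(teacher_logp, student_logp, reduction="none", log_target=True).sum(-1)
```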
3. Watch the gate saturate.
If gaps are large in magnitude, σ pins to 0 or 1 and your "soft" gate becomes a hard binary mask - you've thrown away the nuance that justified the sigmoid. The temperature τ is the fix: raise it to keep the gate responsive. Log the gate's distribution during training; if it's bimodal at the extremes, τ is too low.
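One way to log this is a single scalar per step: the fraction of generated tokens whose gate sits within some ε of 0 or 1. The helper name, the ε = 0.05 threshold, and the synthetic gaps below are all assumptions for illustration:

```python
import torch

def saturation_fraction(gate: torch.Tensor, mask: torch.Tensor, eps: float = 0.05) -> float:
    """Hypothetical helper: share of masked tokens with gate < eps or gate > 1 - eps."""
    g = gate[mask.bool()]
    saturated = (g < eps) | (g > 1 - eps)
    return saturated.float().mean().item()

torch.manual_seed(0)
gaps = 4.0 * torch.randn(8, 32)  # synthetic, deliberately large-magnitude gaps
mask = torch.ones(8, 32)

fracs = {}
for tau in (0.5, 1.0, 4.0):
    fracs[tau] = saturation_fraction(torch.sigmoid(gaps / tau), mask)
    print(f"tau={tau}: {fracs[tau]:.2f} of tokens saturated")
```

For a fixed set of gaps, raising τ can only shrink the saturated set, so this number should fall as τ goes up.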
4. Soften negatives - don't zero them.
The reason this is σ(gap) and not relu(gap) or a hard threshold: a teacher rejection might come from bad skill retrieval, not a bad token (this was the whole motivation in Part 1). A sigmoid leaves a small non-zero weight on rejected tokens, so a noisy "no" can't fully erase a token that was actually fine. Zeroing them throws that hedge away.
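The contrast is easiest to see on a handful of gap values. This is a minimal illustration, not code from the paper:

```python
import torch

gaps = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])

soft = torch.sigmoid(gaps)   # sigmoid gate: rejected tokens keep a small weight (~0.047 at -3)
hard = (gaps > 0).float()    # hard threshold: every rejected token gets exactly 0
relu = torch.relu(gaps)      # relu: zero for rejections, and unbounded above 1
```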
One more, not a NaN but a stability killer: schedule λ. Start it low (or at zero) and warm it up. Let GRPO establish a competent policy first; ramp the distillation in afterward. Cranking λ from step zero hands control to the teacher's noisiest early signals - which is exactly the naive-GRPO+OPSD instability SDAR exists to avoid.
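The schedule shape isn't specified here, so this is a minimal sketch that assumes a hold-then-linear-ramp. The function and its parameters (lam_max, hold_steps, warmup_steps) are hypothetical, and any values you plug in are for you to tune:

```python
def lambda_schedule(step: int, lam_max: float, hold_steps: int, warmup_steps: int) -> float:
    """Hypothetical schedule: lam = 0 while GRPO warms up, then a linear ramp to lam_max."""
    if step < hold_steps:
        return 0.0  # pure GRPO: no distillation yet
    progress = (step - hold_steps) / max(warmup_steps, 1)
    return lam_max * min(progress, 1.0)  # ramp, then hold at lam_max
```

With lam_max=0.1, hold_steps=100, and warmup_steps=400, λ is 0 until step 100, 0.05 at step 300, and 0.1 from step 500 onward. Because λ depends only on the step, a checkpointed step is enough to resume the schedule exactly.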
Where it lives on AWS
Mapping back to Part 2's system:
- The function above sits inside the actor update on your single GPU node (p4d/p5).
- The teacher forward pass assembles privileged context from the DynamoDB skill store, then runs a frozen model - ideally on spot, since it never trains and is fully restartable.
- Checkpoint aggressively to S3: actor weights, optimizer state, and the λ schedule position. Spot gives a ~2-minute reclaim notice; you want a job that resumes mid-schedule, not one that restarts from λ=0 and wastes the warm-up you already paid for.
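A minimal local sketch of a resumable checkpoint, assuming λ is computed from the training step so that saving the step captures the schedule position. A toy Linear model stands in for the actor, and the helper names are hypothetical. Copying the file to S3 (for example with aws s3 cp or boto3's upload_file) is left out.

```python
import os
import tempfile

import torch

def save_checkpoint(path, model, optimizer, step):
    # Weights, optimizer state, and the step that drives the lambda schedule.
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step}, path)

def load_checkpoint(path, model, optimizer):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["step"]  # resume the schedule here, not at step 0

# Toy stand-in for the actor: take one optimizer step so there is state to save.
model = torch.nn.Linear(4, 2)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
opt.step()

path = os.path.join(tempfile.mkdtemp(), "ckpt.pt")
save_checkpoint(path, model, opt, step=1234)

# Simulate a fresh process after a spot reclaim.
model2 = torch.nn.Linear(4, 2)
opt2 = torch.optim.AdamW(model2.parameters(), lr=1e-3)
resumed_step = load_checkpoint(path, model2, opt2)
```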
Optional: the near-free "it runs" experiment
If you want a single screenshot of the gate behaving - not convergence, just proof the plumbing is sound - you can do it on free compute:
- Colab/Kaggle free tier (one T4, 16 GB), Qwen2.5-0.5B + LoRA, a handful of ALFWorld episodes.
- Goal: the loss doesn't NaN, and the gate-value histogram shifts as the student learns. That's it.
- It will almost certainly not learn the task at 0.5B on a toy slice - and that's fine. You're validating the mechanism, not the result. Label any plot from this as toy-scale, or a commenter rightly will.
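For that screenshot, a per-step check along these lines would cover both goals. The helper is a hypothetical sketch: it reports whether the loss is finite and bins the gate values into a fixed [0, 1] histogram you can compare across steps.

```python
import torch

def sanity_report(loss: torch.Tensor, gate: torch.Tensor, mask: torch.Tensor, bins: int = 10):
    """Hypothetical per-step check: is the loss finite, and how are gate values spread?"""
    finite = bool(torch.isfinite(loss).all())
    g = gate[mask.bool()].float()
    hist = torch.histc(g, bins=bins, min=0.0, max=1.0)  # token counts per gate bin
    return finite, hist

torch.manual_seed(0)
gate = torch.sigmoid(torch.randn(4, 16))  # stand-in for one step's gate values
mask = torch.ones(4, 16)
finite, hist = sanity_report(torch.tensor(0.37), gate, mask)
```

Logging hist every few steps and watching its mass move is the "histogram shifts" evidence; a False from finite is the NaN alarm.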
What's left
We have the mechanism. What we don't have - by design and by budget - is proof it beats the baselines and proof it's more stable than naive GRPO+OPSD, which is SDAR's real selling point. Part 4 designs exactly that: the three-way comparison, the stability instrumentation most reproductions skip, and the FinOps reality of running it for real.
Next: "Evaluation, Stability & FinOps" - how you'd prove the gate earns its keep, and what proving it costs.
If you've implemented gated or weighted distillation losses, I'd genuinely like to know how you handled the detach boundary and the KL direction - comments are open.