Table of Contents
- Motivation
- Show me the code: Jupyter notebook
- Peering into the GRPO equation
- Part 1: the sampling process and the expectation
- The rollout phase
- Part 2: the probability ratio
- Part 3: clipping
- Part 4: the min function
- Part 5: the KL penalty
- Part 6: the advantage
- Bringing it all together
- Concrete example of how to train using GRPO
- Conclusion
Motivation
Group Relative Policy Optimisation (GRPO) is a method developed by DeepSeek to improve "language model reasoning capabilities using pure reinforcement learning (RL)", with the specific goal to "develop reasoning capabilities without any supervised data, focusing on their self-evolution through a pure RL process" (source: DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning). GRPO was the method used to train DeepSeek-R1, released in January 2025, which triggered a sharp sell-off in tech stocks such as Nvidia, and it served as the basis of subsequent reasoning models such as Mistral's Magistral (source: Magistral).
This article seeks to explain the math of GRPO and how to implement GRPO from scratch to train LLMs using Apple's MLX framework. Hence, if you are an Apple silicon user, you are in luck: you can run the Jupyter notebook right on your laptop.
This article is created with the assistance of Google Gemini 2.5 Pro.
Show me the code: Jupyter notebook
For those who are keen to dive right into running the code, you may access it here. If you discover any mistakes or have any improvements to suggest, please feel free to make a pull request! I will look into all requests as soon as I can.
(Note: this code currently only has a dummy reward function. I will add a proper reward function to the notebook later and update this article when done, but please feel free to file a pull request if you would like to contribute in any way.)
Peering into the GRPO equation
We will dissect this scary-looking (at least to me!) equation:

$$J_{GRPO}(\theta) = \mathbb{E}_{q \sim P(Q),\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)} \left[ \frac{1}{G} \sum_{i=1}^{G} \left( \min\!\left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}\, A_i,\ \text{clip}\!\left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)},\ 1-\epsilon,\ 1+\epsilon \right) A_i \right) - \beta\, \mathbb{D}_{KL}\big(\pi_\theta \,\|\, \pi_{ref}\big) \right) \right]$$

The expectation $\mathbb{E}[\,\cdot\,]$ is the theoretical objective: it theoretically considers every possible prompt, generates a group of responses for each, and averages the improvement we get from the objective function (see below) over all these possibilities.

The expression inside the expectation, evaluated on sampled data, is the sample-based objective function (the "estimator"). Since we cannot possibly compute the true expectation over all prompts and outputs, we approximate it during training: we take one batch (a prompt $q$ and its $G$ generated outputs $\{o_i\}_{i=1}^{G}$) and compute a numerical estimate of our objective using the objective function. We then use this numerical estimate to calculate a gradient and update our model's weights, i.e. $\theta$ becomes $\theta_{new}$.
I will break down the equations into the following parts:
- Part 1: $q \sim P(Q),\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)$ (the sampling process and the expectation)
- Part 2: $\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}$ (the probability ratio)
- Part 3: $\text{clip}\!\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)},\ 1-\epsilon,\ 1+\epsilon\right) A_i$ (clipping)
- Part 4: $\min(\cdot\,,\,\cdot)$ (combining the unclipped and clipped terms)
- Part 5: $\beta\, \mathbb{D}_{KL}\big(\pi_\theta \,\|\, \pi_{ref}\big)$ (the KL penalty)
- Part 6: $A_i$ (the advantage)
Part 1
The equation that will be dissected in this part is: $q \sim P(Q),\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)$
- The tilde (~): It means "distributed according to" or "is sampled from". So:
  - $q \sim P(Q)$: This represents questions ($q$) sampled from the overall distribution of questions ($P(Q)$). This is a standard practice where models learn to respond to various prompts during training.
  - $\{o_i\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)$: This means the group of $G$ outputs is sampled from the old policy $\pi_{\theta_{old}}$ given question $q$.
- The expectation function, i.e. $\mathbb{E}_{q \sim P(Q),\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)}[\,\cdot\,]$, signifies a joint expectation over multiple random variables. One commonly seen example is $\mathbb{E}_{X \sim P(X),\ Y \sim P(Y|X)}[f(X, Y)]$, which means taking the expectation over the combined process of first sampling $X$ from $P(X)$, and then sampling $Y$ from $P(Y|X)$. The expectation then applies to the function of both $X$ and $Y$. In probability theory, it can be written as $\sum_{x}\sum_{y} P(x)\, P(y|x)\, f(x, y)$ for discrete variables, or $\int\!\!\int p(x)\, p(y|x)\, f(x, y)\, dy\, dx$ for continuous variables. The comma notation is a shorthand for this sequential or joint sampling process.
- Applying this to our equation:
  - $q \sim P(Q)$: First, a question $q$ is randomly chosen from the pool of all possible questions.
  - $\{o_i\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)$: Given that specific question $q$, a group of $G$ outputs is then generated by the old policy $\pi_{\theta_{old}}$.
  - The expression that follows (the GRPO objective function) then depends on both $q$ and the generated outputs $\{o_i\}_{i=1}^{G}$.
  - So, the expectation is taken over the entire data collection process: first randomly picking a question from the bank of available questions, and then generating multiple responses for that question using the old policy.
In summary:
- The tilde (~) tells you how the data is being generated (which distribution).
- The comma (,) separates independent (or conditionally independent) sampling steps that define the full set of random variables over which the expectation is taken. It implies a joint probability distribution, often constructed sequentially.
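To make this concrete, here is a tiny standalone sketch (toy discrete distributions, unrelated to the LLM) showing how a joint expectation like $\mathbb{E}_{X \sim P(X),\ Y \sim P(Y|X)}[f(X, Y)]$ is approximated by exactly this sequential sampling: draw $X$, then draw $Y$ given $X$, then average $f$ over many samples:

```python
# Monte Carlo approximation of a joint expectation E_{X~P(X), Y~P(Y|X)}[f(X, Y)]
# using toy discrete distributions (illustrative values only).
import numpy as np

rng = np.random.default_rng(0)

p_x = np.array([0.3, 0.7])                 # P(X) over x in {0, 1}
p_y_given_x = np.array([[0.9, 0.1],        # P(Y|X=0) over y in {0, 1}
                        [0.2, 0.8]])       # P(Y|X=1)
f = lambda x, y: x + 2 * y                 # any function of both variables

# Exact value: sum_x sum_y P(x) P(y|x) f(x, y)
exact = sum(p_x[x] * p_y_given_x[x, y] * f(x, y)
            for x in range(2) for y in range(2))

# Monte Carlo: sample X first, then Y given X, then average f
n = 100_000
total = 0.0
for _ in range(n):
    x = rng.choice(2, p=p_x)               # X ~ P(X)
    y = rng.choice(2, p=p_y_given_x[x])    # Y ~ P(Y|X)
    total += f(x, y)

print(exact, total / n)                    # the two numbers should be close
```

The GRPO training loop below does the same thing: each iteration samples prompts ($q \sim P(Q)$), then samples outputs from the old policy, and the running average of the per-batch objective approximates the expectation.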
Why $q \sim P(Q)$ instead of simply $P(Q)$?
- The notation $q \sim P(Q)$ explicitly states that $q$ is the random variable being sampled, and $P(Q)$ is its probability distribution.
- Similarly, $\{o_i\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)$ states that $\{o_i\}_{i=1}^{G}$ are the random variables (the sampled outputs), and $\pi_{\theta_{old}}(O|q)$ is the conditional probability distribution from which they are drawn (conditioned on $q$).
Without $q \sim$ and $\{o_i\}_{i=1}^{G} \sim$, the expression would be ambiguous, as it would not be clear which variables are being sampled or how they relate to the function being averaged. $P(Q)$ and $\pi_{\theta_{old}}(O|q)$ are distributions, not actual values that vary and contribute to the average.
Show me the code
To implement the expectation operator $\mathbb{E}[\,\cdot\,]$, we can use a loop where each iteration processes a new, randomly sampled mini-batch of prompts, calculates the loss for that batch, and performs an update. Over many iterations, this process approximates the expected value of the objective over the entire data distribution.
Do not fret that the code is long. I will break it down and explain each piece accordingly.
# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
def grpo_train_loop(
model, model_old, model_ref, tokenizer, optimizer, train_set,
iters=200, group_size=4, batch_size=2, epsilon=0.2, beta=0.01,
update_every=10, max_ans_len=4
):
# Create a grad function for the trainable model
loss_and_grad_fn = nn.value_and_grad(model, grpo_loss_fn)
losses = []
all_rewards = []
# Start training
pbar = tqdm.tqdm(range(iters))
for it in pbar:
batch_prompts = []
batch_answers = []
# 1. Sample a batch of prompts
indices = np.random.randint(0, len(train_set), batch_size)
for i in indices:
# The last word of the output is the ground truth answer (e.g., "ending4")
prompt_text, answer_text = train_set[i]["output"].rsplit(" ", maxsplit=1)
full_prompt = [
{"role": "user", "content": train_set[i]["instruction"]},
{"role": "assistant", "content": prompt_text}
]
batch_prompts.append(full_prompt)
batch_answers.append(answer_text)
# 2. Rollout: Generate G responses for each prompt using the old model
rollout_sequences = []
rollout_rewards = []
rollout_log_probs = []
rollout_a_toks = []
for i in range(batch_size):
prompt_tokens = tokenizer.apply_chat_template(batch_prompts[i], continue_final_message=True)
group_rewards = []
for _ in range(group_size):
# Generate a response
response = generate(model_old, tokenizer, prompt_tokens, max_tokens=max_ans_len)
answer_tokens = tokenizer.encode(response, add_special_tokens=False)
# 3. Get Reward
reward = 1.0 if batch_answers[i] in response else 0.0
group_rewards.append(reward)
# Store data for the optimization step
full_sequence = mx.array(prompt_tokens + answer_tokens)
rollout_sequences.append(full_sequence)
rollout_a_toks.append(mx.array(answer_tokens))
all_rewards.extend(group_rewards)
rollout_rewards.append(mx.array(group_rewards))
# 4. Compute Advantages
advantages = []
for rewards in rollout_rewards:
mean_reward = mx.mean(rewards)
std_reward = mx.sqrt(mx.var(rewards)) + 1e-8 # Add epsilon for stability
adv = (rewards - mean_reward) / std_reward
advantages.append(adv)
advantages = mx.concatenate(advantages)
sequences = pad_sequences(rollout_sequences, tokenizer.pad_token_id)
a_toks = pad_sequences(rollout_a_toks, tokenizer.pad_token_id)
# Calculate log_probs with the old model for the ratio calculation
old_log_probs = calculate_log_probs(model_old, sequences, a_toks)
# 5. Optimization Step
(loss, policy_reward, kl_div), grads = loss_and_grad_fn(
model, model_ref, sequences, a_toks, advantages, old_log_probs, beta, epsilon
)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
losses.append(loss.item())
pbar.set_description(f"Loss: {np.mean(losses[-10:]):.3f}, Mean Reward: {np.mean(all_rewards[-20:]):.3f}")
# Sync old model weights
if (it + 1) % update_every == 0:
model_old.update(model.parameters())
print(f"\nIter {it+1}: Synced old model weights.")
# Final save of adapter weights
model.save_weights(str(adapter_path / "adapters.safetensors"))
print("Saved final weights to adapters/adapters.safetensors.")
return losses, all_rewards
The rollout phase:
For each prompt $q$ in our batch, we need to generate a group of $G$ possible outputs $\{o_1, o_2, \dots, o_G\}$ using the fixed, old policy $\pi_{\theta_{old}}$. This is the data collection or "rollout" phase.
The code is implemented with a nested for loop within `grpo_train_loop` that calls `generate`. In the code, `model_old` is $\pi_{\theta_{old}}$. The outer loop iterates through prompts in the batch, and the inner loop runs `group_size` ($G$) times to generate each output $o_i$ for that prompt.
# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
# Rollout: Generate G responses for each prompt using the old model
...
for i in range(batch_size):
prompt_tokens = tokenizer.apply_chat_template(batch_prompts[i], continue_final_message=True)
group_rewards = []
for _ in range(group_size):
# Generate a response
response = generate(model_old, tokenizer, prompt_tokens, max_tokens=max_ans_len)
answer_tokens = tokenizer.encode(response, add_special_tokens=False)
# 3. Get Reward
reward = 1.0 if batch_answers[i] in response else 0.0
group_rewards.append(reward)
# Store data for the optimization step
full_sequence = mx.array(prompt_tokens + answer_tokens)
rollout_sequences.append(full_sequence)
rollout_a_toks.append(mx.array(answer_tokens))
all_rewards.extend(group_rewards)
rollout_rewards.append(mx.array(group_rewards))
Part 2: $\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}$
The concept behind this equation is importance sampling:
- $\pi_{\theta_{old}}(o_i|q)$ (The Denominator): This represents the probability of generating a specific output $o_i$ (a complete reasoning trace and answer) given the input question $q$, using the old policy. The "old" policy is the version of the model that was used to generate the batch of data for the current training step. It is frozen during this step.
- $\pi_\theta(o_i|q)$ (The Numerator): This represents the probability of generating that exact same output $o_i$ given the question $q$, but using the current policy $\pi_\theta$. This is the policy we are actively training and updating in this step.
- The Ratio: The ratio $r_i(\theta) = \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}$ is the likelihood ratio or importance weight.
  - If $r_i(\theta) > 1$, the new policy is more likely to generate output $o_i$ than the old one was.
  - If $r_i(\theta) < 1$, the new policy is less likely to generate output $o_i$.
  - If $r_i(\theta) = 1$, the policy has not changed with respect to this specific output.
What is its purpose? (Off-Policy Learning)
The primary purpose of this ratio is to enable off-policy learning.
In reinforcement learning, the ideal way to evaluate a policy is to use it to generate actions and see what rewards you get. However, generating new outputs ($o_i$) from the model for every single gradient update is computationally very expensive.
Off-policy methods solve this. We can:
- Generate a large batch of experiences (the outputs $o_i$) using the old policy $\pi_{\theta_{old}}$.
- Then, perform several steps of optimization on our policy using that same batch of data.
The importance sampling ratio is the mathematical "correction factor" that allows us to estimate how good an action is under the new policy ($\pi_\theta$) using data that was collected by the old policy ($\pi_{\theta_{old}}$).
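To see why the ratio really is a "correction factor", here is a small standalone NumPy sketch (the toy policies and rewards are made up for illustration, nothing to do with the notebook) that estimates the expected reward under a new policy using only samples drawn from an old policy:

```python
# Importance sampling: estimate the new policy's expected reward from old-policy samples.
import numpy as np

rng = np.random.default_rng(0)
outcomes = np.array([0, 1, 2])           # three possible "outputs"
rewards = np.array([0.0, 1.0, 0.5])      # reward of each output
p_old = np.array([0.5, 0.3, 0.2])        # old policy probabilities
p_new = np.array([0.2, 0.6, 0.2])        # new policy probabilities

# Ground truth: expected reward under the NEW policy
true_value = np.sum(p_new * rewards)     # 0.7

# Sample outputs from the OLD policy only (this is our "rollout" data)
samples = rng.choice(outcomes, size=100_000, p=p_old)

# Naive average (no correction) estimates the OLD policy's value instead
naive_estimate = rewards[samples].mean()                  # ≈ 0.40

# Importance-weighted average: weight each sample by p_new / p_old
weights = p_new[samples] / p_old[samples]
is_estimate = (weights * rewards[samples]).mean()         # ≈ 0.70

print(true_value, naive_estimate, is_estimate)
```

The reweighted estimate recovers the new policy's performance without ever sampling from the new policy, which is exactly what the ratio in the GRPO objective does.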
For ease of discussion, I will write this ratio as: $r_i(\theta) = \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}$
The objective function multiplies this ratio by the advantage $A_i$ (how good the output was). So, the update logic is:
- If $o_i$ was a good output ($A_i > 0$), we want to increase its probability. Maximising $r_i(\theta)\, A_i$ will push $r_i(\theta)$ to be greater than 1, which in turn pushes $\pi_\theta(o_i|q)$ to increase.
- If $o_i$ was a bad output ($A_i < 0$), we want to decrease its probability. Maximising $r_i(\theta)\, A_i$ (a negative number) will push $r_i(\theta)$ to be less than 1, making the overall term less negative and thus decreasing $\pi_\theta(o_i|q)$.
Why is it designed that way? (Stability and PPO)
While the ratio allows for efficient learning, it is also a source of instability. If the new policy $\pi_\theta$ becomes very different from the old one $\pi_{\theta_{old}}$, the ratio can become extremely large or close to zero. A very large ratio would lead to a massive, noisy gradient update, potentially destroying all the learning the model has already done.
This is the problem that Proximal Policy Optimization (PPO), from which GRPO's objective is derived, was designed to solve. The design in Equation 1 is a direct implementation of the PPO-Clip objective (source: Proximal Policy Optimization).
The goal is to keep the new policy "proximal" (i.e., close) to the old policy. This creates a "trust region" where we can be confident the update is beneficial.
Part 3: $\text{clip}\big(r_i(\theta),\ 1-\epsilon,\ 1+\epsilon\big)\, A_i$
This is the core of the PPO algorithm, which encourages making the new policy more likely to produce high-advantage outputs, but "clips" the update to prevent it from changing too drastically and destabilising training.
Example of clipping
Take for instance that we set $\epsilon$ to be 0.2. We then get the following clipping function: $\text{clip}\big(r_i(\theta),\ 0.8,\ 1.2\big)$.
Examples of how the clipping function works (with illustrative ratio values) are below:
- $\text{clip}(1.1,\ 0.8,\ 1.2) = 1.1$ because the value is within the range.
- $\text{clip}(1.5,\ 0.8,\ 1.2) = 1.2$ because the value is beyond the range and is clipped down to the maximum value of 1.2.
We can see that with clipping, when the optimiser gets "greedy" and suggests a huge change, the model is still encouraged to make the output more likely, but is prevented from making a dangerously large jump in the policy.
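As a quick numerical check of the clipping step, here is a minimal sketch using `mx.clip` (the ratio values are illustrative, not from the notebook):

```python
# Minimal sketch of the clipping step with made-up ratio values.
import mlx.core as mx

epsilon = 0.2
ratios = mx.array([0.5, 0.9, 1.1, 1.5])   # hypothetical probability ratios r_i(θ)
clipped = mx.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
print(clipped)  # [0.8, 0.9, 1.1, 1.2]: values inside the range pass through, others are capped
```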
Code Implementation
This logic is encapsulated within `grpo_loss_fn`. The probability ratio $r_i(\theta)$ is calculated in log-space for numerical stability: $r_i(\theta) = \exp\big(\log \pi_\theta(o_i|q) - \log \pi_{\theta_{old}}(o_i|q)\big)$.
# --- File: MLX LM GRPO.ipynb, Cell: grpo-helpers ---
def grpo_loss_fn(model, model_ref, sequences, a_toks, advantages, old_log_probs, beta, epsilon):
"""The GRPO loss function."""
# Get log probs from the trainable model (π_θ)
log_probs = calculate_log_probs(model, sequences, a_toks)
# Get log probs from the reference model (π_ref) for KL penalty
log_probs_ref = calculate_log_probs(model_ref, sequences, a_toks)
# PPO-clip objective
ratio = mx.exp(log_probs - old_log_probs)
clipped_ratio = mx.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
policy_reward = mx.minimum(ratio * advantages, clipped_ratio * advantages)
# KL penalty
# Step 1: Calculate log(r) where r = π_ref / π_θ
# log(r) = log(π_ref) - log(π_θ)
log_ratio_for_kl = log_probs_ref - log_probs
# Step 2: Calculate r itself by exponentiating log(r)
# r = exp(log(r))
ratio_for_kl = mx.exp(log_ratio_for_kl)
# Step 3: Apply the paper's full formula: r - log(r) - 1
kl_div = ratio_for_kl - log_ratio_for_kl - 1
# The objective is to maximize this, so we return the negative for minimization
loss = -mx.mean(policy_reward - beta * kl_div)
return loss, mx.mean(policy_reward), mx.mean(kl_div)
The helper function calculate_log_probs
is responsible for computing log P(o_i | q) for a given policy.
def calculate_log_probs(model, sequences, a_toks):
"""Calculates the log probabilities of the generated answer tokens."""
# Pass the full sequence (prompt + answer) to the model
logits = model(sequences)
# Convert to log probabilities
log_probs_full = nn.log_softmax(logits, axis=-1)
## Find the actual positions where answer tokens should be extracted
# This assumes a_toks contains the actual token IDs that were generated
batch_size, seq_len = sequences.shape
_, ans_len = a_toks.shape
    # Calculate the starting position for answer tokens (assuming they're at the end)
    start_pos = seq_len - ans_len
    # Logits at position t predict the token at position t + 1, so shift the slice
    # back by one to align each answer token with the logits that predicted it
    answer_log_probs = log_probs_full[:, start_pos - 1:start_pos + ans_len - 1, :]
# Create indices for gathering - ensure proper shape alignment
indices = a_toks[:, :, None]
# Extract log probabilities for the actual answer tokens
selected_log_probs = mx.take_along_axis(answer_log_probs, indices, axis=-1).squeeze(-1)
# Sum log probabilities across the answer sequence
return mx.sum(selected_log_probs, axis=-1)
Part 4: $\min\big(r_i(\theta)\, A_i,\ \text{clip}(r_i(\theta),\ 1-\epsilon,\ 1+\epsilon)\, A_i\big)$
The purpose of this min function is to always take the more pessimistic (lower) of the unclipped and clipped terms, i.e.:
- When an output is good (positive $A_i$), it caps the reward, preventing the update from becoming too rewarding.
- When an output is bad (negative $A_i$), it keeps the harsher of the two penalties, so the penalty is never softened.
In both cases, the min function prevents the model from making a large policy change too quickly.
Example
Using the same clipping setup as in Part 3, where we set $\epsilon = 0.2$ (so the clip range is $[0.8, 1.2]$), and assuming some advantage value $A_i$, the term we evaluate is $\min\big(r_i(\theta)\, A_i,\ \text{clip}(r_i(\theta),\ 0.8,\ 1.2)\, A_i\big)$. For instance, with an illustrative $r_i(\theta) = 1.5$ and $A_i = 1$, this gives $\min(1.5 \times 1,\ 1.2 \times 1) = 1.2$: the update is capped at the clipped value.
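Here is a minimal MLX sketch of this pessimistic `min` (the ratios and advantages below are made-up illustrative values, not from the notebook), covering both positive and negative advantages:

```python
# Illustrative values only: show how min(r*A, clip(r)*A) behaves for good and bad outputs.
import mlx.core as mx

epsilon = 0.2
ratios = mx.array([1.5, 1.5, 0.5, 0.5])        # hypothetical probability ratios r_i(θ)
advantages = mx.array([1.0, -1.0, 1.0, -1.0])  # hypothetical advantages A_i

clipped = mx.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
surrogate = mx.minimum(ratios * advantages, clipped * advantages)
print(surrogate)
# [ 1.2, -1.5, 0.5, -0.8]
# A>0, r too high -> reward capped at 1.2 (clipped)
# A<0, r too high -> full -1.5 penalty kept (not softened to -1.2)
# A>0, r too low  -> only the smaller 0.5 reward (unclipped)
# A<0, r too low  -> the harsher -0.8 penalty kept (clipped)
```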
Part 5: $\beta\, \mathbb{D}_{KL}\big(\pi_\theta \,\|\, \pi_{ref}\big)$
This term acts as a regulariser, penalising the policy for deviating from a reference policy $\pi_{ref}$. (Note: when we say policy, we are actually referring to the LLM and its weights, so $\pi_\theta$ refers to the LLM with updated weights, while $\pi_{ref}$ usually refers to the original stock LLM.)
The paper defines the term as:
$$\mathbb{D}_{KL}\big(\pi_\theta \,\|\, \pi_{ref}\big) = \frac{\pi_{ref}(o_i|q)}{\pi_\theta(o_i|q)} - \log\frac{\pi_{ref}(o_i|q)}{\pi_\theta(o_i|q)} - 1$$
What is key is to recognise that this equation is not the standard Kullback-Leibler (KL) divergence, but a more specific, per-sample approximation chosen for being computationally cheaper. This compares to the standard KL divergence, which can be expressed as:
$$\mathbb{D}_{KL}\big(\pi_\theta \,\|\, \pi_{ref}\big) = \mathbb{E}_{o \sim \pi_\theta(\cdot|q)}\left[\log\frac{\pi_\theta(o|q)}{\pi_{ref}(o|q)}\right]$$
Compared to the standard definition, there are two main differences:
- The standard definition is an expectation over the entire distribution $\pi_\theta(\cdot|q)$, whereas the paper's variant is an expression for a single sample $o_i$.
- The functional form of the paper's variant is different from the term inside the standard expectation.
Deconstruction and Analysis of the variant of KL divergence
For simplicity, let the probability ratio be $r$:
$$r = \frac{\pi_{ref}(o_i|q)}{\pi_\theta(o_i|q)}$$
Then Equation (2) defines a function $f(r)$:
$$f(r) = r - \log r - 1$$
This function is evaluated for a single sample $o_i$, which itself was drawn from the old policy, $\pi_{\theta_{old}}$.
In order for $f(r)$ to be a valid divergence measure, it must satisfy two properties:
- Non-negativity: $f(r) \ge 0$ for all $r > 0$.
- Identity of Indiscernibles: $f(r) = 0$ if and only if $r = 1$ (i.e. $\pi_\theta(o_i|q) = \pi_{ref}(o_i|q)$).
Proof that $f(r)$ satisfies the two properties is in Appendix A.
Show me the code
This term is a regularizer. It penalizes the objective if the trainable policy π_θ strays too far from the original, trusted reference policy π_ref, helping to prevent catastrophic forgetting.
Code Implementation: Also within `grpo_loss_fn`.
# --- File: MLX LM GRPO.ipynb, Cell: grpo-helpers ---
def grpo_loss_fn(...):
...
# Get log probs from the reference model (π_ref)
log_probs_ref = calculate_log_probs(model_ref, ...)
# KL penalty
# Step 1: Calculate log(r) where r = π_ref / π_θ
# log(r) = log(π_ref) - log(π_θ)
log_ratio_for_kl = log_probs_ref - log_probs
# Step 2: Calculate r itself by exponentiating log(r)
# r = exp(log(r))
ratio_for_kl = mx.exp(log_ratio_for_kl)
# Step 3: Apply the paper's full formula: r - log(r) - 1
kl_div = ratio_for_kl - log_ratio_for_kl - 1
...
Part 6: $A_i$
The Advantage Function is a central component in modern policy gradient methods. Intuitively, the advantage tells us not just if an output was "good" (positive reward), but if it was "better than average". It is designed to reduce sensitivity to reward scaling, and it stabilises training by preventing outlier rewards from dominating the update. A more technical discussion of the advantage function is available in Appendix B.
Given that the advantage tells us how much better or worse a specific output was compared to the average of its group, computing it requires two steps:
- Calculate the raw reward $r_i$; and
- Normalise it against the rest of the group.
In short, the mathematical equation is:
$$A_i = \frac{r_i - \text{mean}\big(\{r_1, r_2, \dots, r_G\}\big)}{\text{std}\big(\{r_1, r_2, \dots, r_G\}\big)}$$
We can implement the code as such:
- Reward calculation as a simple rule-based reward.
# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
# Get Reward (r_i)
reward = 1.0 if batch_answers[i] in response else 0.0
group_rewards.append(reward)
- Normalising to get the Advantage
# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
# Compute Advantages
advantages = []
for rewards in rollout_rewards:
mean_reward = mx.mean(rewards)
std_reward = mx.sqrt(mx.var(rewards)) + 1e-8 # Epsilon for stability
# This line directly implements the advantage formula
adv = (rewards - mean_reward) / std_reward
advantages.append(adv)
advantages = mx.concatenate(advantages)
Bringing it all together
Finally, we combine all the pieces, average over the batch, and negate the result to create a loss that can be minimised by the optimiser.
Mathematical Component:
$$\mathcal{L}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \Big( \min\big(r_i(\theta)\, A_i,\ \text{clip}(r_i(\theta),\ 1-\epsilon,\ 1+\epsilon)\, A_i\big) - \beta\, \mathbb{D}_{KL}\big(\pi_\theta \,\|\, \pi_{ref}\big) \Big)$$
where N is the total batch size (batch_size * group_size).
Code Implementation: The final lines of grpo_loss_fn
and the optimizer.update call in the training loop. The code was displayed in Part 3 above, with the relevant abridged segments reproduced below for ease of reference.
# --- File: MLX LM GRPO.ipynb, Cell: grpo-helpers ---
def grpo_loss_fn(...):
...
# The objective is to maximize this, so we return the negative for minimization
# mx.mean() handles the averaging over the batch
loss = -mx.mean(policy_reward - beta * kl_div)
return loss, ...
# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
# 5. Optimization Step
(loss, ...), grads = loss_and_grad_fn(...)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
Concrete example of how to train using GRPO
- Prompt (q): "What is the capital of Malaysia?"
- Policy ($\pi_\theta$): The existing LLM we are trying to improve.
- Old policy ($\pi_{\theta_{old}}$): A frozen copy of the LLM before this training step.
- Reference policy ($\pi_{ref}$): The original, pre-trained base LLM (e.g. DeepSeek-V3-Base).
- Hyperparameters:
  - Group size (G): 4
  - Clipping epsilon ($\epsilon$): 0.2
  - KL penalty beta ($\beta$): 0.05
Step 1: sample generation and reward calculation
We use the old policy $\pi_{\theta_{old}}$ to generate G = 4 different outputs for prompt q. Then, we use a rule-based reward model to score them.
| i | Output (`o_i`) | Reward (`r_i`) | Notes on Reward |
|---|---|---|---|
| 1 | "The capital of Malaysia is Kuala Lumpur." | 1.0 | Correct answer. |
| 2 | "The capital of Malaysia is Johor." | 0.0 | Incorrect answer. |
| 3 | "Malaysia's capital city is Kuala Lumpur." | 1.0 | Correct answer, different wording. |
| 4 | "The capital of Malaysia is Selangor." | 0.9 | Incorrect, but Kuala Lumpur is surrounded by Selangor. |
Step 2: calculate normalised advantage
First, calculate the mean and standard deviation of the rewards:
- `mean(r) = (1.0 + 0.0 + 1.0 + 0.9) / 4 = 0.725`
- `std(r) ≈ 0.485` (using the sample standard deviation)
Now, we compute the advantage for each sample:
- `A₁ = (1.0 - 0.725) / 0.485 ≈ 0.567` (Positive Advantage: this output was better than average)
- `A₂ = (0.0 - 0.725) / 0.485 ≈ -1.495` (Negative Advantage: this output was worse than average)
- `A₃ = (1.0 - 0.725) / 0.485 ≈ 0.567` (Positive Advantage)
- `A₄ = (0.9 - 0.725) / 0.485 ≈ 0.361` (Positive Advantage, but smaller)
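If you would like to verify these numbers yourself, a small standalone NumPy sketch reproduces the advantages above (the tiny differences in the last digit come from rounding the standard deviation to 0.485 in the text):

```python
# Reproduce the worked example's advantages from the four rewards.
import numpy as np

rewards = np.array([1.0, 0.0, 1.0, 0.9])
mean_r = rewards.mean()          # 0.725
std_r = rewards.std(ddof=1)      # ≈ 0.486 (sample standard deviation)
advantages = (rewards - mean_r) / std_r
print(mean_r, std_r)
print(advantages)                # ≈ [ 0.566, -1.493,  0.566,  0.360]
```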
Step 3: Calculate Policy Probabilities and Ratios
For each of our 4 samples, we need to compute its probability under the old policy (`π_θ_old`), the current policy (`π_θ`), the reference policy (`π_ref`), and the ratio `r_i(θ) = π_θ / π_θ_old`. These are hypothetical values for our example. It is important to note that these are not probabilities that sum to 1, as they only represent 4 outputs out of a near-infinite number of possibilities.
| i | Old policy `π_θ_old(o_i|q)` | Current policy `π_θ(o_i|q)` | Reference policy `π_ref(o_i|q)` | Ratio `r_i(θ) = π_θ / π_θ_old` |
|---|---|---|---|---|
| 1 | 0.25 | 0.30 | 0.28 | 1.20 |
| 2 | 0.30 | 0.21 | 0.15 | 0.70 |
| 3 | 0.20 | 0.28 | 0.26 | 1.40 |
| 4 | 0.25 | 0.21 | 0.22 | 0.84 |

Interpretation: For sample 1, our new policy `π_θ` is more confident (0.30) than the old one (0.25), so the ratio is > 1. For sample 2, the new policy is less confident, so the ratio is < 1.
Step 4: Calculate the Clipped Surrogate Objective for Each Sample
Now we apply the `min(..., clip(...))` formula for each sample. The clip range is `[1 - ε, 1 + ε] = [0.8, 1.2]`.

Sample 1 (A₁ ≈ 0.567 > 0):
- Unclipped term: `r₁(θ) * A₁ = 1.20 * 0.567 ≈ 0.680`
- Clipped ratio: `clip(1.20, 0.8, 1.2) = 1.2`
- Clipped term: `1.2 * 0.567 ≈ 0.680`
- `min(0.680, 0.680) = 0.680`. The ratio was within the clip bounds.

Sample 2 (A₂ ≈ -1.495 < 0):
- Unclipped term: `r₂(θ) * A₂ = 0.70 * -1.495 ≈ -1.047`
- Clipped ratio: `clip(0.70, 0.8, 1.2) = 0.8`
- Clipped term: `0.8 * -1.495 ≈ -1.196`
- `min(-1.047, -1.196) = -1.196`. The value is clipped. This is a subtle but crucial point. The optimizer's goal is to maximize the objective. An objective of -1.047 is better than -1.196. By forcing the optimizer to take the `min`, we are selecting the worse of the two possible objectives. This limits the size of the policy update, preventing the model from making a large, potentially unstable change even when correcting a mistake.

Sample 3 (A₃ ≈ 0.567 > 0):
- Unclipped term: `r₃(θ) * A₃ = 1.40 * 0.567 ≈ 0.794`
- Clipped ratio: `clip(1.40, 0.8, 1.2) = 1.2`
- Clipped term: `1.2 * 0.567 ≈ 0.680`
- `min(0.794, 0.680) = 0.680`. The policy update is clipped to prevent it from getting too greedy on this good sample.

Sample 4 (A₄ ≈ 0.361 > 0):
- Unclipped term: `r₄(θ) * A₄ = 0.84 * 0.361 ≈ 0.303`
- Clipped ratio: `clip(0.84, 0.8, 1.2) = 0.84`
- Clipped term: `0.84 * 0.361 ≈ 0.303`
- `min(0.303, 0.303) = 0.303`. The ratio was within the clip bounds.
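The whole of Step 4 can be reproduced in a few lines of standalone NumPy, using the hypothetical ratios from the table and the advantages from Step 2:

```python
# Reproduce Step 4: the clipped surrogate objective for each of the four samples.
import numpy as np

ratios = np.array([1.20, 0.70, 1.40, 0.84])           # π_θ / π_θ_old from the table
advantages = np.array([0.567, -1.495, 0.567, 0.361])  # from Step 2
epsilon = 0.2

unclipped = ratios * advantages
clipped = np.clip(ratios, 1 - epsilon, 1 + epsilon) * advantages
surrogate = np.minimum(unclipped, clipped)
print(surrogate)  # ≈ [ 0.680, -1.196,  0.680,  0.303]
```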
Step 5: Calculate the KL Penalty for Each Sample (Using Equation 2)
Now we calculate the penalty term `D_KL(π_θ || π_ref)` for each sample. Let `r_ref = π_ref / π_θ`. The formula is `r_ref - log(r_ref) - 1`.
- Sample 1: `r_ref = 0.28 / 0.30 ≈ 0.933`. Penalty = `0.933 - log(0.933) - 1 ≈ 0.933 - (-0.069) - 1 = 0.002`
- Sample 2: `r_ref = 0.15 / 0.21 ≈ 0.714`. Penalty = `0.714 - log(0.714) - 1 ≈ 0.714 - (-0.337) - 1 = 0.051`
- Sample 3: `r_ref = 0.26 / 0.28 ≈ 0.929`. Penalty = `0.929 - log(0.929) - 1 ≈ 0.929 - (-0.074) - 1 = 0.003`
- Sample 4: `r_ref = 0.22 / 0.21 ≈ 1.048`. Penalty = `1.048 - log(1.048) - 1 ≈ 1.048 - 0.047 - 1 = 0.001`
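Again, a few lines of standalone NumPy reproduce these penalties from the hypothetical probabilities in the Step 3 table:

```python
# Reproduce Step 5: the per-sample KL penalty r_ref - log(r_ref) - 1.
import numpy as np

p_theta = np.array([0.30, 0.21, 0.28, 0.21])  # current policy π_θ(o_i|q)
p_ref = np.array([0.28, 0.15, 0.26, 0.22])    # reference policy π_ref(o_i|q)

r_ref = p_ref / p_theta
kl_penalty = r_ref - np.log(r_ref) - 1
print(kl_penalty)  # ≈ [0.0023, 0.0508, 0.0027, 0.0011]
```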
Step 6: Combine Everything to Get the Final Value
The final objective value for our batch is the average over the 4 samples (the training loss is simply its negative). For each sample `i`, the value is `Clipped_Objective_i - β * KL_Penalty_i`.
- Sample 1 Value: `0.680 - (0.05 * 0.002) = 0.680 - 0.0001 = 0.6799`
- Sample 2 Value: `-1.196 - (0.05 * 0.051) = -1.196 - 0.00255 = -1.19855`
- Sample 3 Value: `0.680 - (0.05 * 0.003) = 0.680 - 0.00015 = 0.67985`
- Sample 4 Value: `0.303 - (0.05 * 0.001) = 0.303 - 0.00005 = 0.30295`

Total Objective `J_GRPO` (for this one prompt):
`J_GRPO = (1/4) * (0.6799 - 1.19855 + 0.67985 + 0.30295) = (1/4) * 0.46415 ≈ 0.116`
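Putting Steps 2 to 6 together, the entire worked example fits in one short standalone sketch (all probabilities and rewards are the hypothetical values used above):

```python
# End-to-end reproduction of the worked example with its hypothetical values.
import numpy as np

rewards = np.array([1.0, 0.0, 1.0, 0.9])
p_old = np.array([0.25, 0.30, 0.20, 0.25])
p_new = np.array([0.30, 0.21, 0.28, 0.21])
p_ref = np.array([0.28, 0.15, 0.26, 0.22])
epsilon, beta = 0.2, 0.05

advantages = (rewards - rewards.mean()) / rewards.std(ddof=1)
ratios = p_new / p_old
surrogate = np.minimum(ratios * advantages,
                       np.clip(ratios, 1 - epsilon, 1 + epsilon) * advantages)
r_ref = p_ref / p_new
kl = r_ref - np.log(r_ref) - 1

objective = np.mean(surrogate - beta * kl)
print(objective)  # ≈ 0.116
```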
Final Action
The value `J_GRPO ≈ 0.116` is the number we want to maximize. The optimizer (like Adam) will compute the gradient of this objective with respect to the LLM's parameters (`∇_θ J_GRPO`) and take a small step in that gradient's direction. This single step will slightly adjust the millions of weights in `π_θ` to:
- Increase the probability of outputs like `o₁` and `o₃` (the good ones).
- Decrease the probability of the bad output `o₂`.
- Do this while being constrained by the clipping mechanism and pulled slightly back towards the original `π_ref` model to avoid forgetting how to form coherent sentences.
Conclusion
Congratulations on making it this far. The full Jupyter notebook to train your LLM on your Apple silicon computer is accessible here. If you discover any mistakes or have any improvements to suggest, please feel free to make a pull request! I will look into all requests as soon as I can.
This being my third article, I have covered:
- Building softmax self-attention from scratch
- The math behind linearised self-attention
- Fine-tuning with GRPO
My future articles will continue to revolve around these topics:
- Building LLM from scratch (because why not?)
- Fine tuning
- LLM evaluations
If you have any interesting topics related to LLMs or machine learning in general that you would like me to explore, please let me know. I am open to ideas.
Appendix A: Proof that $f(r) = r - \log r - 1$ is a valid divergence measure
The two properties to satisfy are:
- Non-negativity: $f(r) \ge 0$ for all $r > 0$.
- Identity of Indiscernibles: $f(r) = 0$ if and only if $r = 1$.
1. Proof of Identity of Indiscernibles
We must show that $f(r) = 0$ if and only if $r = 1$. The condition $r = 1$ implies $\pi_{ref}(o_i|q) = \pi_\theta(o_i|q)$, which means the policies are identical for this specific output.
- If $r = 1$: $f(1) = 1 - \log(1) - 1 = 0$. This part of the proof is trivial.
- If $f(r) = 0$: $f(r) = 0$ implies $r - 1 = \log r$. Consider the graphs of $y = r - 1$ (a straight line) and $y = \log r$. They are tangent at the point $r = 1$. To prove this formally, let $g(r) = r - \log r - 1$. We want to find the roots of $g(r) = 0$. The derivative is $g'(r) = 1 - \frac{1}{r}$. Setting $g'(r) = 0$ gives $r = 1$. This is the only extremum. Since $g(1) = 0$, the function has a minimum value of 0 at $r = 1$. Therefore, the only real solution to $g(r) = 0$ is $r = 1$. This completes the proof that $f(r) = 0 \Rightarrow r = 1$.
2. Proof of Non-Negativity
We must show that $f(r) \ge 0$ for all $r > 0$ (since $r$ is a ratio of probabilities, it must be positive).
- Let's use calculus again on $g(r) = r - \log r - 1$.
- First Derivative: $g'(r) = 1 - \frac{1}{r}$.
- Critical Point: $g'(r) = 0 \Rightarrow r = 1$.
- Second Derivative: $g''(r) = \frac{1}{r^2}$.
- Since $r > 0$, we have $g''(r) > 0$ for all $r$ in the domain. This proves that $g(r)$ is a strictly convex function.
- A strictly convex function has a unique global minimum at its critical point. We found this critical point to be $r = 1$.
- The value of the function at this global minimum is $g(1) = 1 - \log(1) - 1 = 0$.
- Since the function's global minimum value is 0, it must be that $f(r) \ge 0$ for all $r > 0$.
This completes the proof of non-negativity.
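As a numerical sanity check of both properties, here is a small standalone sketch that evaluates $f(r) = r - \log r - 1$ over a grid of ratios:

```python
# Numerically check that f(r) = r - log(r) - 1 is non-negative and minimised at r = 1.
import numpy as np

r = np.linspace(0.05, 5.0, 1000)
f = r - np.log(r) - 1

print(f.min())            # ≈ 0: the smallest value on the grid
print(r[np.argmin(f)])    # ≈ 1.0: the ratio at which the minimum occurs
print(np.all(f >= 0))     # True: non-negativity holds everywhere on the grid
```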
Appendix B - A more technical discussion on the advantage function
The Advantage Function, $A_i$, is a central component in modern policy gradient methods. In reinforcement learning, the simplest policy gradient update rule uses the total reward to scale the gradient $\nabla_\theta \log \pi_\theta(o|q)$. However, this approach suffers from high variance, meaning the gradient estimates can fluctuate wildly from one batch of samples to another, leading to unstable training.
The core idea to reduce this variance is to subtract a baseline from the reward. The baseline $b(s)$ should ideally be an estimate of the average reward from state $s$. This leads to the Advantage Function:
$$A(s, a) = r(s, a) - b(s)$$
Intuitively, the advantage tells us not just if an action was "good" (positive reward), but if it was "better than average". If $A > 0$, the action was better than expected, and its probability should be increased. If $A < 0$, the action was worse than expected, and its probability should be decreased.
Key Theorem (Baseline Invariance of Policy Gradient):
The introduction of a baseline $b$ that depends only on the state $s$ (or, in our case, the prompt $q$) does not introduce bias into the gradient estimate.
Proof:
We need to show that $\mathbb{E}_{o \sim \pi_\theta(\cdot|q)}\big[\nabla_\theta \log \pi_\theta(o|q)\, b(q)\big] = 0$.
$$\mathbb{E}_{o \sim \pi_\theta}\big[\nabla_\theta \log \pi_\theta(o|q)\, b(q)\big] = b(q) \sum_{o} \pi_\theta(o|q)\, \nabla_\theta \log \pi_\theta(o|q) = b(q) \sum_{o} \nabla_\theta \pi_\theta(o|q) = b(q)\, \nabla_\theta \sum_{o} \pi_\theta(o|q) = b(q)\, \nabla_\theta 1 = 0$$
This proves that subtracting a baseline does not change the expected gradient: $\mathbb{E}\big[\nabla_\theta \log \pi_\theta(o|q)\,(r - b(q))\big] = \mathbb{E}\big[\nabla_\theta \log \pi_\theta(o|q)\, r\big]$. While the expectation is the same, the variance of the gradient estimator is significantly reduced.
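Here is a small standalone sketch (a toy softmax "policy" over five outputs, not the actual LLM) that checks this zero-expectation identity numerically:

```python
# Check numerically that E_{o~π}[∇ log π(o) * b] = 0 for a toy softmax policy.
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=5)            # toy policy parameters (one logit per "output")
probs = np.exp(logits) / np.exp(logits).sum()
b = 3.7                                # any constant baseline

# For a softmax, d log π(o) / d logit_k = 1[k == o] - π(k)
grad_log_pi = np.eye(5) - probs        # row o holds ∇ log π(o) w.r.t. the 5 logits

# Expectation over o ~ π of ∇ log π(o) * b: weight row o by π(o) and sum over o
expected = (probs[:, None] * grad_log_pi * b).sum(axis=0)
print(expected)                        # ≈ [0, 0, 0, 0, 0]
```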
The Equation
The paper's specific implementation of the advantage function for a group of $G$ outputs is:
$$A_i = \frac{r_i - \text{mean}\big(\{r_1, r_2, \dots, r_G\}\big)}{\text{std}\big(\{r_1, r_2, \dots, r_G\}\big)}$$
where $\{r_1, r_2, \dots, r_G\}$ are the rewards for the corresponding outputs $\{o_1, o_2, \dots, o_G\}$.
Step-by-Step Component Breakdown
Let's deconstruct the formula for the advantage of the `i`-th sample, $A_i$.
1. The Rewards: $\{r_1, r_2, \dots, r_G\}$
- $r_j$: This is the numerical reward assigned to the `j`-th output $o_j$, which was generated for a given prompt $q$.
- Source of Rewards: The paper specifies (in Section 2.2.2) that these are rule-based rewards.
  - Accuracy Rewards: A binary or continuous score evaluating if the final answer in $o_j$ is correct. For a math problem, this could be checking if the result matches the known solution. For a coding problem, it could be the percentage of test cases passed.
  - Format Rewards: A score evaluating if the output $o_j$ adheres to a desired format (e.g., using `<think>` and `</think>` tags).
- Axiom: The existence of a reward function $r(q, o)$ is a fundamental axiom of the reinforcement learning framework.
2. The Baseline: $\text{mean}\big(\{r_1, r_2, \dots, r_G\}\big)$
- Definition: This is the empirical mean (or sample average) of the rewards obtained from the `G` outputs generated for the same prompt `q`.
- Role as a Baseline: This term serves as the baseline `b(q)`. It's an estimate of the expected reward for the given prompt `q` under the current (old) policy $\pi_{\theta_{old}}$. Instead of using a separate, learned "critic" network to predict the expected reward, the GRPO algorithm uses this simple and efficient empirical estimate from the group of samples.
- The Numerator: The numerator $r_i - \text{mean}\big(\{r_1, \dots, r_G\}\big)$ is the raw, unnormalized advantage. It measures whether the `i`-th output was better or worse than the average performance within its group.
3. The Normalization Factor: $\text{std}\big(\{r_1, r_2, \dots, r_G\}\big)$
- Definition: This is the empirical standard deviation of the rewards from the group.
(Note: sometimes the denominator inside the standard deviation is `G` for the biased estimator, and `G-1` for the unbiased estimator. In practice, for large `G`, the difference is negligible. We will assume the standard definition.)
- Purpose of Normalization: Dividing the raw advantage by the standard deviation is a form of data standardization. It rescales the advantages for a given prompt so that their distribution has a standard deviation of 1.
- Mathematical Justification: Let the set of raw advantages for a group be $D = \big\{r_j - \text{mean}(\{r_1, \dots, r_G\})\big\}_{j=1}^{G}$.
  - The mean of this set is $\frac{1}{G}\sum_{j}\big(r_j - \text{mean}(\{r_1, \dots, r_G\})\big) = 0$. The advantages are centered at zero.
  - The standard deviation of this set is $\text{std}(\{r_1, \dots, r_G\})$, since subtracting a constant does not change the spread. By dividing each element of $D$ by $\text{std}(\{r_1, \dots, r_G\})$, the resulting set of normalized advantages will have a mean of 0 and a standard deviation of 1.
Why is this Normalization Important?
Reduces Sensitivity to Reward Scaling: Imagine two different tasks. In Task 1, rewards are either 0 or 1. The advantages will be small fractions. In Task 2, rewards are 0 or 1000. The advantages will be large numbers. Without normalization, the gradient updates for Task 2 would be 1000 times larger than for Task 1, potentially destabilizing learning when training on a mix of tasks. Normalization ensures that the scale of the advantage signal is consistent across different prompts and reward schemes.
Stabilizes Training: It prevents outlier rewards (a single very high or very low reward in a group) from generating excessively large gradients that could harm the policy. By scaling everything relative to the variation within the group, the updates become more measured and stable.
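A tiny standalone sketch of the scale-invariance point: multiplying all rewards by 1000 leaves the normalized advantages, and hence the scale of the gradient signal, unchanged:

```python
# Normalized advantages are invariant to the scale of the rewards.
import numpy as np

def advantages(rewards):
    rewards = np.asarray(rewards, dtype=float)
    return (rewards - rewards.mean()) / (rewards.std() + 1e-8)

print(advantages([0.0, 1.0, 1.0, 0.0]))        # Task 1: rewards in {0, 1}
print(advantages([0.0, 1000.0, 1000.0, 0.0]))  # Task 2: rewards in {0, 1000}
# Both print ≈ [-1., 1., 1., -1.]
```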
Synthesis and Conclusion
Equation (3) defines a specific form of the advantage function: a baseline-subtracted advantage in its simplest form, with an additional normalization (standardization) step.
- Computes Raw Advantage: It first calculates a raw advantage for each sample by subtracting a baseline from its reward $r_i$.
- Uses an Empirical Baseline: The baseline is not a learned value but is efficiently estimated as the mean reward of all samples generated for the same prompt $q$. This conforms to the requirement that the baseline depends only on the prompt (and the policy that generated the samples), thus not introducing bias into the policy gradient.
- Normalizes the Advantage: The raw advantage is then divided by the standard deviation of the rewards within the group. This standardizes the advantages, making them have a mean of 0 and a standard deviation of 1 for that group.
- Overall Effect: This process results in a well-behaved, normalized advantage signal that robustly indicates whether an output was better or worse than average, independent of the absolute scale of the rewards for that particular task. This standardized signal is then used in Equation (1) to provide stable and effective gradient updates for the policy $\pi_\theta$.