Lewis Won

Implementing DeepSeek-R1 GRPO in Apple MLX framework


Motivation

Group Relative Policy Optimisation (GRPO) is a method developed by DeepSeek to improve "language model reasoning capabilities using pure reinforcement learning (RL)", with the specific goal to "develop reasoning capabilities without any supervised data, focusing on their self-evolution through a pure RL process" (source: DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning). GRPO was the method used to train DeepSeek-R1, released in January 2025, which sent tech stocks such as Nvidia tumbling, and it served as the basis of subsequent reasoning models such as Mistral's Magistral (source: Magistral).

This article seeks to explain the math of GRPO and how to implement GRPO from scratch to train LLMs using the Apple MLX framework. Hence, if you are an Apple silicon user, you are in luck: you can run the Jupyter notebook right on your laptop.

This article was created with the assistance of Google Gemini 2.5 Pro.

Show me the code: Jupyter notebook

For those who are keen to dive right into running the code, you may access it here. If you discover any mistakes or have any improvements to suggest, please feel free to make a pull request! I will look into all requests as soon as I can.

(Note: this code currently only has a dummy reward function. I will add a proper reward function to the notebook later and update this post when done, but please feel free to file a pull request if you would like to contribute in any way.)

Peering into the GRPO equation

We will dissect this scary-looking (at least to me!) equation:

$$J_{\text{GRPO}}(\theta) = \mathbb{E}\left[q \sim P(Q),\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(O|q)\right]$$

This is the theoretical expectation: it considers every possible prompt, generates a group of responses for each, and averages the value of the objective function (see below) over all these possibilities.

$$\frac{1}{G} \sum_{i=1}^{G} \left( \min\left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i,\ \text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\epsilon, 1+\epsilon\right) A_i \right) - \beta\, D_{KL}(\pi_\theta || \pi_{\text{ref}}) \right)$$

This is the sample-based objective function (the "estimator"). Since we cannot possibly compute the true expectation over all prompts and outputs, we approximate it during training: we take one batch (a prompt $q$ and its $G$ generated outputs $\{o_i\}$) and compute a numerical estimate of our objective using the objective function. We then use this estimate to calculate a gradient and update our model's weights, i.e. $\theta_{\text{old}}$ becomes $\theta$.

I will break down the equations into the following parts:

  • Part 1: $\mathbb{E}\left[q \sim P(Q),\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(O|q)\right]$

  • Part 2: $\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}$

  • Part 3: $\text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\epsilon, 1+\epsilon\right)$

  • Part 4: $\min\left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i,\ \text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\epsilon, 1+\epsilon\right) A_i \right)$

  • Part 5: $D_{KL}(\pi_\theta || \pi_{\text{ref}})$

  • Part 6: $A_i$

Part 1

The equation that will be dissected in this part is:

$$\mathbb{E}\left[q \sim P(Q),\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(O|q)\right]\left[\text{function}(Q, O)\right]$$

  • The tilde (~): It means "distributed according to" or "is sampled from". So:

    • $q \sim P(Q)$: This represents questions ($q$) sampled from the overall distribution of questions ($P(Q)$). This is a standard practice where models learn to respond to various prompts during training.
    • $\{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(O|q)$: This means the group of $G$ outputs $\{o_i\}_{i=1}^G$ is sampled from the policy $\pi_{\theta_{\text{old}}}$ given question $q$.
  • The expectation function, i.e.

    $\mathbb{E}[\,\bullet\,,\,\bullet\,]$

    signifies a joint expectation over multiple random variables. One commonly seen example is

    $\mathbb{E}[X \sim P(X),\ Y \sim P(Y|X)][\text{function}(X, Y)]$

    which means taking the expectation over the combined process of first sampling X from P(X), and then sampling Y from P(Y|X). The expectation then applies to the function of both X and Y.

  • In probability theory,

    $\mathbb{E}_{X,Y}[f(X,Y)]$

    can be written as

    $\sum_X P(X) \sum_Y P(Y|X)\, f(X,Y)$

    for discrete variables, or

    $\int_X \int_Y P(X,Y)\, f(X,Y)\, dX\, dY$

    for continuous variables. The comma notation is a shorthand for this sequential or joint sampling process.

  • Applying to our equation:

    • q ~ P(Q): First, a question q is randomly chosen from the pool of all possible questions.
    • $\{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(O|q)$: Given that specific question $q$, a group of $G$ outputs $\{o_i\}_{i=1}^G$ is then generated by the old policy $\pi_{\theta_{\text{old}}}$.
    • The expression $[\text{function}(Q, O)]$ that follows (the GRPO objective function) then depends on both $q$ and the generated outputs $\{o_i\}_{i=1}^G$.
    • So, the expectation is taken over the entire data-collection process: first randomly picking a question from the bank of available questions, and then generating multiple responses for that question using the $\pi_{\theta_{\text{old}}}$ policy.

In summary:

  • The tilde (~) tells you how the data is being generated (which distribution).
  • The comma (,) separates independent (or conditionally independent) sampling steps that define the full set of random variables over which the expectation is taken. It implies a joint probability distribution, often constructed sequentially.

Why $\mathbb{E}[q \sim P(Q),\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(O|q)][\text{function}(Q, O)]$ instead of simply $\mathbb{E}[P(Q),\ \pi_{\theta_{\text{old}}}(O|q)][\text{function}(Q, O)]$?

  • The notation $q \sim P(Q)$ explicitly states that $q$ is the random variable being sampled, and $P(Q)$ is its probability distribution.
  • Similarly, $\{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(O|q)$ states that $\{o_i\}_{i=1}^{G}$ are the random variables (the sampled outputs), and $\pi_{\theta_{\text{old}}}(O|q)$ is the conditional probability distribution from which they are drawn (conditioned on $q$).

Without $q \sim$ and $\{o_i\}_{i=1}^{G}$, the expression $\mathbb{E}[P(Q),\ \pi_{\theta_{\text{old}}}(O|q)][\text{function}(Q, O)]$ is ambiguous: it is not clear which variables are being sampled or how they relate to $[\text{function}(Q, O)]$. $P(Q)$ and $\pi_{\theta_{\text{old}}}(O|q)$ are distributions, not actual values that vary and contribute to the average.
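
Before turning to the actual training loop, here is a tiny, self-contained Monte-Carlo sketch of this nested sampling process (toy question bank, toy output distributions, and a hypothetical scoring function f; none of these come from the notebook). It simply samples a question, samples a group of outputs for that question, scores them, and averages:

# A toy Monte-Carlo estimate of E[q ~ P(Q), {o_i} ~ pi_old(O|q)][f(q, o)]
import numpy as np

rng = np.random.default_rng(0)
questions = ["q1", "q2", "q3"]                  # toy question bank, P(Q) uniform
outputs_given_q = {                             # toy pi_old(O|q): outputs and their probabilities
    "q1": (["a", "b"], [0.9, 0.1]),
    "q2": (["a", "b"], [0.4, 0.6]),
    "q3": (["a", "b"], [0.2, 0.8]),
}
f = lambda q, o: 1.0 if o == "a" else 0.0       # hypothetical per-sample score

G, iters, total = 4, 50_000, 0.0
for _ in range(iters):
    q = rng.choice(questions)                              # q ~ P(Q)
    outs, probs = outputs_given_q[q]
    group = rng.choice(outs, size=G, p=probs)              # {o_i} ~ pi_old(O|q)
    total += np.mean([f(q, o) for o in group])

print(total / iters)   # ~0.5, i.e. the average of 0.9, 0.4 and 0.2 over the three questions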

Show me the code

To implement the expectation operator $\mathbb{E}[\ldots]$, we can use a loop where each iteration processes a new, randomly sampled mini-batch of prompts, calculates the loss for that batch, and performs an update. Over many iterations, this process approximates the expected value of the objective over the entire data distribution.

Do not fret that the code is long. I will break it down and explain each piece accordingly.

# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
def grpo_train_loop(
    model, model_old, model_ref, tokenizer, optimizer, train_set,
    iters=200, group_size=4, batch_size=2, epsilon=0.2, beta=0.01,
    update_every=10, max_ans_len=4
):
    # Create a grad function for the trainable model
    loss_and_grad_fn = nn.value_and_grad(model, grpo_loss_fn)

    losses = []
    all_rewards = []

    # Start training
    pbar = tqdm.tqdm(range(iters))
    for it in pbar:
        batch_prompts = []
        batch_answers = []

        # 1. Sample a batch of prompts
        indices = np.random.randint(0, len(train_set), batch_size)
        for i in indices:
            # The last word of the output is the ground truth answer (e.g., "ending4")
            prompt_text, answer_text = train_set[i]["output"].rsplit(" ", maxsplit=1)
            full_prompt = [
                {"role": "user", "content": train_set[i]["instruction"]},
                {"role": "assistant", "content": prompt_text}
            ]
            batch_prompts.append(full_prompt)
            batch_answers.append(answer_text)

        # 2. Rollout: Generate G responses for each prompt using the old model
        rollout_sequences = []
        rollout_rewards = []
        rollout_log_probs = []
        rollout_a_toks = []

        for i in range(batch_size):
            prompt_tokens = tokenizer.apply_chat_template(batch_prompts[i], continue_final_message=True)
            group_rewards = []

            for _ in range(group_size):
                # Generate a response
                response = generate(model_old, tokenizer, prompt_tokens, max_tokens=max_ans_len)
                answer_tokens = tokenizer.encode(response, add_special_tokens=False)

                # 3. Get Reward
                reward = 1.0 if batch_answers[i] in response else 0.0
                group_rewards.append(reward)

                # Store data for the optimization step
                full_sequence = mx.array(prompt_tokens + answer_tokens)
                rollout_sequences.append(full_sequence)
                rollout_a_toks.append(mx.array(answer_tokens))

            all_rewards.extend(group_rewards)
            rollout_rewards.append(mx.array(group_rewards))

        # 4. Compute Advantages
        advantages = []
        for rewards in rollout_rewards:
            mean_reward = mx.mean(rewards)
            std_reward = mx.sqrt(mx.var(rewards)) + 1e-8 # Add epsilon for stability
            adv = (rewards - mean_reward) / std_reward
            advantages.append(adv)

        advantages = mx.concatenate(advantages)
        sequences = pad_sequences(rollout_sequences, tokenizer.pad_token_id)
        a_toks = pad_sequences(rollout_a_toks, tokenizer.pad_token_id)

        # Calculate log_probs with the old model for the ratio calculation
        old_log_probs = calculate_log_probs(model_old, sequences, a_toks)

        # 5. Optimization Step
        (loss, policy_reward, kl_div), grads = loss_and_grad_fn(
            model, model_ref, sequences, a_toks, advantages, old_log_probs, beta, epsilon
        )

        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)

        losses.append(loss.item())
        pbar.set_description(f"Loss: {np.mean(losses[-10:]):.3f}, Mean Reward: {np.mean(all_rewards[-20:]):.3f}")

        # Sync old model weights
        if (it + 1) % update_every == 0:
            model_old.update(model.parameters())
            print(f"\nIter {it+1}: Synced old model weights.")

    # Final save of adapter weights
    model.save_weights(str(adapter_path / "adapters.safetensors"))
    print("Saved final weights to adapters/adapters.safetensors.")
    return losses, all_rewards

The rollout phase: $\{o_i\}_{i=1}^{G}$

For each prompt $q$ in our batch, we need to generate a group of $G$ possible outputs ($\{o_i\}_{i=1}^{G}$) using the fixed, old policy $\pi_{\theta_{\text{old}}}$. This is the data collection or "rollout" phase.

The code is implemented with a nested for loop within grpo_train_loop that calls generate. In the code, model_old is $\pi_{\theta_{\text{old}}}$. The outer loop iterates through the prompts in the batch, and the inner loop runs group_size ($G$) times to generate each output $o_i$ for that prompt.

# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
# Rollout: Generate G responses for each prompt using the old model
...
for i in range(batch_size):
    prompt_tokens = tokenizer.apply_chat_template(batch_prompts[i], continue_final_message=True)
    group_rewards = []

    for _ in range(group_size):
        # Generate a response
        response = generate(model_old, tokenizer, prompt_tokens, max_tokens=max_ans_len)
        answer_tokens = tokenizer.encode(response, add_special_tokens=False)

        # 3. Get Reward
        reward = 1.0 if batch_answers[i] in response else 0.0
        group_rewards.append(reward)

        # Store data for the optimization step
        full_sequence = mx.array(prompt_tokens + answer_tokens)
        rollout_sequences.append(full_sequence)
        rollout_a_toks.append(mx.array(answer_tokens))

    all_rewards.extend(group_rewards)
    rollout_rewards.append(mx.array(group_rewards))

Part 2: $\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}$

The concept behind this equation is importance sampling:

  • $\pi_{\theta_{\text{old}}}(o_i|q)$ (The Denominator): This represents the probability of generating a specific output $o_i$ (a complete reasoning trace and answer) given the input question $q$, using the old policy. The "old" policy is the version of the model that was used to generate the batch of data for the current training step. It is frozen during this step.

  • $\pi_\theta(o_i|q)$ (The Numerator): This represents the probability of generating that exact same output $o_i$ given the question $q$, but using the current policy $\pi_\theta$. This is the policy we are actively training and updating in this step.

  • The Ratio: The ratio $\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}$ is the likelihood ratio or importance weight.

    • If $r(\theta) > 1$, the new policy $\pi_\theta$ is more likely to generate output $o_i$ than the old one was.
    • If $r(\theta) < 1$, the new policy $\pi_\theta$ is less likely to generate output $o_i$.
    • If $r(\theta) = 1$, the policy has not changed with respect to this specific output.

What is its purpose? (Off-Policy Learning)

The primary purpose of this ratio is to enable off-policy learning.

In reinforcement learning, the ideal way to evaluate a policy is to use it to generate actions and see what rewards you get. However, generating new outputs ($o_i$) from the model for every single gradient update is computationally very expensive.

Off-policy methods solve this. We can:

  1. Generate a large batch of experiences (the outputs $o_i$) using the $\pi_{\theta_{\text{old}}}$ policy.
  2. Then, perform several steps of optimization on our $\pi_\theta$ policy using that same batch of data.

The importance sampling ratio is the mathematical "correction factor" that allows us to estimate how good an action is under the new policy ($\pi_\theta$) using data that was collected by the old policy ($\pi_{\theta_{\text{old}}}$).
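
Here is a minimal, standalone sketch of that correction factor (plain NumPy, toy two-output "policies"; nothing here is from the notebook). We sample only from the old policy, yet recover the expectation under the new policy by re-weighting each sample with the ratio p_new / p_old:

# Importance sampling: estimate an expectation under pi_new using samples from pi_old
import numpy as np

rng = np.random.default_rng(0)

p_old = np.array([0.7, 0.3])    # probabilities of outputs {0, 1} under pi_theta_old
p_new = np.array([0.5, 0.5])    # probabilities under pi_theta
reward = np.array([0.0, 1.0])   # hypothetical reward for each output

samples = rng.choice(2, size=100_000, p=p_old)     # drawn from the OLD policy only
weights = p_new[samples] / p_old[samples]          # importance weights r = p_new / p_old
estimate = np.mean(weights * reward[samples])

print(estimate, np.sum(p_new * reward))            # ~0.5 vs. the exact value 0.5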

For ease of discussion, I will write:

$$r(\theta) = \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}$$

The objective function multiplies this ratio by the advantage $A_i$ (how good the output $o_i$ was). So, the update logic is:

  • If $o_i$ was a good output ($A_i > 0$), we want to increase its probability. Maximising $r(\theta) \times A_i$ will push $r(\theta)$ to be greater than 1, which in turn pushes $\pi_\theta(o_i|q)$ to increase.
  • If $o_i$ was a bad output ($A_i < 0$), we want to decrease its probability. Maximising $r(\theta) \times A_i$ (a negative number) will push $r(\theta)$ to be less than 1, making the overall term less negative and thus decreasing $\pi_\theta(o_i|q)$.

Why is it designed that way? (Stability and PPO)

While the ratio allows for efficient learning, it is also a source of instability. If the new policy $\pi_\theta$ becomes very different from the old one $\pi_{\theta_{\text{old}}}$, the ratio $r(\theta)$ can become extremely large or close to zero. A very large ratio would lead to a massive, noisy gradient update, potentially destroying all the learning the model has already done.

This is the problem that Proximal Policy Optimization (PPO), from which GRPO's objective is derived, was designed to solve. The design in Equation 1 is a direct implementation of the PPO-Clip objective (source: Proximal Policy Optimization).

The goal is to keep the new policy "proximal" (i.e., close) to the old policy. This creates a "trust region" where we can be confident the update is beneficial.

Part 3: $\text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\epsilon, 1+\epsilon\right)$

This is the core of the PPO algorithm, which encourages making the new policy $\pi_\theta$ more likely to produce high-advantage outputs, but "clips" the update to prevent the policy from changing too drastically and destabilising training.

Example of clipping

Suppose, for instance, that we set $\epsilon$ to 0.2. We then get the following clipping expression:

$$\text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 0.8, 1.2\right)$$

Examples of how the clipping function works are below:

  • $\text{clip}(1.15, 0.8, 1.2) = 1.15$ because the value is within the range.
  • $\text{clip}(1.5, 0.8, 1.2) = 1.2$ because the value is beyond the range and is clipped down to the maximum value of 1.2.

We can see that with clipping, when the optimiser gets "greedy" and suggests a huge change, the model is still encouraged to make the output more likely, but is prevented from making a dangerously large jump in the policy.
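
As a quick sanity check, the two examples above can be reproduced directly with mx.clip, the same function used in the loss below:

import mlx.core as mx

eps = 0.2
print(mx.clip(mx.array(1.15), 1 - eps, 1 + eps))   # 1.15: inside the range, untouched
print(mx.clip(mx.array(1.5), 1 - eps, 1 + eps))    # 1.2: above the range, clipped down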

Code Implementation

This logic is encapsulated within the grpo_loss_fn.

The probability ratio $r_i(\theta)$ is calculated in log-space for numerical stability: $r_i(\theta) = \exp(\text{log prob}_{\theta} - \text{log prob}_{\theta_{\text{old}}})$.

# --- File: MLX LM GRPO.ipynb, Cell: grpo-helpers ---
def grpo_loss_fn(model, model_ref, sequences, a_toks, advantages, old_log_probs, beta, epsilon):
    """The GRPO loss function."""
    # Get log probs from the trainable model (π_θ)
    log_probs = calculate_log_probs(model, sequences, a_toks)

    # Get log probs from the reference model (π_ref) for KL penalty
    log_probs_ref = calculate_log_probs(model_ref, sequences, a_toks)

    # PPO-clip objective
    ratio = mx.exp(log_probs - old_log_probs)
    clipped_ratio = mx.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
    policy_reward = mx.minimum(ratio * advantages, clipped_ratio * advantages)

    # KL penalty
    # Step 1: Calculate log(r) where r = π_ref / π_θ
    # log(r) = log(π_ref) - log(π_θ)
    log_ratio_for_kl = log_probs_ref - log_probs

    # Step 2: Calculate r itself by exponentiating log(r)
    # r = exp(log(r))
    ratio_for_kl = mx.exp(log_ratio_for_kl)

    # Step 3: Apply the paper's full formula: r - log(r) - 1
    kl_div = ratio_for_kl - log_ratio_for_kl - 1

    # The objective is to maximize this, so we return the negative for minimization
    loss = -mx.mean(policy_reward - beta * kl_div)
    return loss, mx.mean(policy_reward), mx.mean(kl_div)

The helper function calculate_log_probs is responsible for computing log P(o_i | q) for a given policy.

def calculate_log_probs(model, sequences, a_toks):
    """Calculates the log probabilities of the generated answer tokens."""
    # Pass the full sequence (prompt + answer) to the model
    logits = model(sequences)

    # Convert to log probabilities
    log_probs_full = nn.log_softmax(logits, axis=-1)

    ## Find the actual positions where answer tokens should be extracted
    # This assumes a_toks contains the actual token IDs that were generated
    batch_size, seq_len = sequences.shape
    _, ans_len = a_toks.shape

    # Calculate the starting position for answer tokens (assuming they're at the end)
    start_pos = seq_len - ans_len

    # Logits at position t predict the token at position t + 1, so the answer
    # tokens are scored by the logits one position earlier (shift by one)
    answer_log_probs = log_probs_full[:, start_pos - 1:start_pos - 1 + ans_len, :]

    # Create indices for gathering - ensure proper shape alignment
    indices = a_toks[:, :, None]

    # Extract log probabilities for the actual answer tokens
    selected_log_probs = mx.take_along_axis(answer_log_probs, indices, axis=-1).squeeze(-1)

    # Sum log probabilities across the answer sequence
    return mx.sum(selected_log_probs, axis=-1)

Part 4: $\min\left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i,\ \text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1-\epsilon, 1+\epsilon\right) A_i \right)$

The purpose of this min function is to make the objective a pessimistic (lower-bound) estimate:

  • When an output is good (positive $A_i$), it caps the reward, so the policy is not pushed too aggressively towards that output.
  • When an output is bad (negative $A_i$), it keeps the lower (more pessimistic) of the two values, so the objective never under-penalises a bad output.

In both cases, the min function prevents the model from making a large policy change too quickly.

Example

Reusing the same clipping setup from Part 3, suppose $\epsilon = 0.2$, $r(\theta) = 1.5$, and assume an advantage value $A_1 = 10$. The expression then evaluates as:

$$\min\left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} \times 10,\ \text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 0.8, 1.2\right) \times 10 \right) = \min\left( 1.5 \times 10,\ 1.2 \times 10 \right) = 12$$
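
The same arithmetic in a few lines of plain Python (a numeric check only, not part of the notebook), including a second, hypothetical case with a negative advantage:

# min(r * A, clip(r, 1 - eps, 1 + eps) * A) for the example values above
def clipped_objective(ratio, advantage, epsilon=0.2):
    clipped = min(max(ratio, 1 - epsilon), 1 + epsilon)   # clip(ratio, 1 - eps, 1 + eps)
    return min(ratio * advantage, clipped * advantage)

print(clipped_objective(1.5, 10))     # 12.0: the good output's reward is capped at 1.2 * 10
print(clipped_objective(0.7, -1.5))   # about -1.2: for a bad output, the more pessimistic value is kept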

Part 5: $D_{KL}\left(\pi_\theta || \pi_{\text{ref}}\right)$

This term acts as a regulariser, penalising the policy $\pi_\theta$ for deviating from a reference policy $\pi_{\text{ref}}$. (Note: when we say policy, we are actually referring to the LLM weights, so $\pi_\theta$ refers to the LLM with updated weights, while $\pi_{\text{ref}}$ usually refers to the original stock LLM.)

The paper defines the term as:

$$D_{KL}(\pi_\theta || \pi_{\text{ref}}) = \frac{\pi_{\text{ref}}(o_i|q)}{\pi_\theta(o_i|q)} - \log\left(\frac{\pi_{\text{ref}}(o_i|q)}{\pi_\theta(o_i|q)}\right) - 1$$

What is key is to recognise that this is not the standard Kullback-Leibler (KL) divergence, but a specific, per-sample approximation chosen for being computationally cheaper. Compare it with the standard KL divergence, which can be expressed as:

$$D_{KL}(\pi_\theta || \pi_{\text{ref}}) = \mathbb{E}_{o \sim \pi_\theta}\left[\log\left(\frac{\pi_\theta(o|q)}{\pi_{\text{ref}}(o|q)}\right)\right]$$

Compared to the standard definition, there are two main differences:

  1. The standard definition is an expectation over the entire distribution $\pi_\theta$, whereas the paper's variant is an expression for a single sample $o_i$.
  2. The functional form of the paper's variant is different from the term inside the standard expectation.

Deconstruction and Analysis of the variant of KL divergence

For simplicity, let the probability ratio be r:

$$r = \frac{\pi_{\text{ref}}(o_i|q)}{\pi_\theta(o_i|q)}$$

Then Equation (2) defines a function $f(r)$:

$$f(r) = r - \log(r) - 1$$

This function is evaluated for a single sample $o_i$, which itself was drawn from the old policy $\pi_{\theta_{\text{old}}}$.

In order for $f(r)$ to be a valid divergence measure, it must satisfy two properties:

  1. Non-negativity: $D(p || q) \geq 0$.
  2. Identity of Indiscernibles: $D(p || q) = 0$ if and only if $p = q$.

The proof that $f(r)$ satisfies the two properties is in Appendix A.

Show me the code

This term is a regularizer. It penalizes the objective if the trainable policy π_θ strays too far from the original, trusted reference policy π_ref, helping to prevent catastrophic forgetting.

Code Implementation: Also within grpo_loss_fn.

# --- File: MLX LM GRPO.ipynb, Cell: grpo-helpers ---
def grpo_loss_fn(...):
    ...
    # Get log probs from the reference model (π_ref)
    log_probs_ref = calculate_log_probs(model_ref, ...)

    # KL penalty
    # Step 1: Calculate log(r) where r = π_ref / π_θ
    # log(r) = log(π_ref) - log(π_θ)
    log_ratio_for_kl = log_probs_ref - log_probs

    # Step 2: Calculate r itself by exponentiating log(r)
    # r = exp(log(r))
    ratio_for_kl = mx.exp(log_ratio_for_kl)

    # Step 3: Apply the paper's full formula: r - log(r) - 1
    kl_div = ratio_for_kl - log_ratio_for_kl - 1
    ...

Part 6: $A_i$

The advantage function $A_i$ is a central component in modern policy gradient methods. Intuitively, the advantage tells us not just whether an action was "good" (positive reward), but whether it was "better than average". It is designed to reduce sensitivity to reward scaling, and it stabilises training by preventing outlier rewards from dominating the update. A more technical discussion of the advantage function is available in Appendix B.

Given that the advantage $A_i$ tells us how much better or worse a specific output $o_i$ was compared to the average of its group, computing it requires two steps:

  1. Calculate the raw reward $r_i$; and
  2. Normalise it within its group.

In short, the mathematical equation is:

$$A_i = \frac{r_i - \text{mean}(r_{\text{group}})}{\text{std}(r_{\text{group}})}$$

We can implement this as follows:

  1. Calculate the reward $r_i$ with a simple rule-based reward.
# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
# Get Reward (r_i)
reward = 1.0 if batch_answers[i] in response else 0.0
group_rewards.append(reward)
  2. Normalise the rewards to get the advantage $A_i$.
# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
# Compute Advantages
advantages = []
for rewards in rollout_rewards:
    mean_reward = mx.mean(rewards)
    std_reward = mx.sqrt(mx.var(rewards)) + 1e-8 # Epsilon for stability
    # This line directly implements the advantage formula
    adv = (rewards - mean_reward) / std_reward 
    advantages.append(adv)

advantages = mx.concatenate(advantages)

Bringing it all together

Finally, we combine all the pieces, average over the batch, and negate the result to create a loss that can be minimised by the optimiser.

Mathematical Component:

$$\text{loss} = -\frac{1}{N} \sum_{i=1}^N \left( L_i^{\text{CLIP}}(\theta) - \beta\, D_{\text{KL}}(\ldots) \right)$$

where N is the total batch size (batch_size * group_size).

Code Implementation: The final lines of grpo_loss_fn and the optimizer.update call in the training loop. The code was displayed in Part 3 above, with the relevant abridged segments reproduced below for ease of reference.

# --- File: MLX LM GRPO.ipynb, Cell: grpo-helpers ---
def grpo_loss_fn(...):
    ...
    # The objective is to maximize this, so we return the negative for minimization
    # mx.mean() handles the averaging over the batch
    loss = -mx.mean(policy_reward - beta * kl_div)
    return loss, ...

# --- File: MLX LM GRPO.ipynb, Cell: grpo-loop ---
# 5. Optimization Step
(loss, ...), grads = loss_and_grad_fn(...)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)

Concrete example of how to train using GRPO

  • Prompt ($q$): "What is the capital of Malaysia?"
  • Policy ($\pi_\theta$): The existing LLM we are trying to improve.
  • Old policy ($\pi_{\theta_{\text{old}}}$): A frozen copy of the LLM before this training step.
  • Reference policy ($\pi_{\text{ref}}$): The original, pre-trained base LLM (e.g. DeepSeek-V3-Base).
  • Hyperparameters:
    • Group size ($G$): 4
    • Clipping epsilon ($\epsilon$): 0.2
    • KL penalty beta ($\beta$): 0.05

Step 1: sample generation and reward calculation

We use the old policy $\pi_{\theta_{\text{old}}}$ to generate $G = 4$ different outputs for the prompt $q$. Then, we use a rule-based reward model to score them.

| i | Output ($o_i$) | Reward ($r_i$) | Notes on Reward |
|---|---|---|---|
| 1 | "The capital of Malaysia is Kuala Lumpur." | 1.0 | Correct answer. |
| 2 | "The capital of Malaysia is Johor." | 0.0 | Incorrect answer. |
| 3 | "Malaysia's capital city is Kuala Lumpur." | 1.0 | Correct answer, different wording. |
| 4 | "The capital of Malaysia is Selangor." | 0.9 | Incorrect, but Kuala Lumpur is surrounded by Selangor. |

Step 2: calculate normalised advantage

First, calculate the mean and standard deviation of the rewards:

$$\text{mean}(r_i) = (1.0 + 0.0 + 1.0 + 0.9)/4 = 2.9 / 4 = 0.725$$

$$\begin{align*} \text{std}(r_i) &= \sqrt{\frac{(1.0 - 0.725)^2 + (0.0 - 0.725)^2 + (1.0 - 0.725)^2 + (0.9 - 0.725)^2}{4-1}} \\ &= \sqrt{\frac{0.0756 + 0.5256 + 0.0756 + 0.0306}{3}} \\ &= \sqrt{0.7074/3} \\ &\approx 0.485 \end{align*}$$

Now, we compute the advantage $A_i$ for each sample (a quick NumPy check follows this list):

  • $A_1 = (1.0 - 0.725) / 0.485 = 0.275 / 0.485 \approx 0.567$ (Positive advantage: this output was better than average)
  • $A_2 = (0.0 - 0.725) / 0.485 = -0.725 / 0.485 \approx -1.495$ (Negative advantage: this output was worse than average)
  • $A_3 = (1.0 - 0.725) / 0.485 = 0.275 / 0.485 \approx 0.567$ (Positive advantage)
  • $A_4 = (0.9 - 0.725) / 0.485 = 0.175 / 0.485 \approx 0.361$ (Positive advantage, but smaller)
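
These hand calculations can be reproduced with a few lines of NumPy (using the sample standard deviation, ddof=1, as in the working above; the results match up to rounding):

import numpy as np

rewards = np.array([1.0, 0.0, 1.0, 0.9])
mean_r = rewards.mean()                  # 0.725
std_r = rewards.std(ddof=1)              # ~0.486
advantages = (rewards - mean_r) / std_r
print(mean_r, round(std_r, 3), advantages.round(3))   # [ 0.566 -1.493  0.566  0.36 ]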

Step 3: Calculate Policy Probabilities and Ratios

For each of our 4 samples, we need to compute its probability under the old policy ($\pi_{\theta_{\text{old}}}$), the current policy ($\pi_\theta$), the reference policy ($\pi_{\text{ref}}$), and the ratio $r_i(\theta) = \pi_\theta / \pi_{\theta_{\text{old}}}$. These are hypothetical values for our example. It is important to note that these are not probabilities that sum to 1, as they only represent 4 outputs out of a near-infinite number of possibilities.

| i | Old policy | Current policy | Reference policy | Ratio (current / old) |
|---|---|---|---|---|
| 1 | 0.25 | 0.30 | 0.28 | 1.20 |
| 2 | 0.30 | 0.21 | 0.15 | 0.70 |
| 3 | 0.20 | 0.28 | 0.26 | 1.40 |
| 4 | 0.25 | 0.21 | 0.22 | 0.84 |

Interpretation: For sample 1, our new policy π_θ is more confident (0.30) than the old one (0.25), so the ratio is > 1. For sample 2, the new policy is less confident, so the ratio is < 1.

Step 4: Calculate the Clipped Surrogate Objective for Each Sample

Now we apply the min(..., clip(...)) formula for each sample. The clip range is [1 - ε, 1 + ε] = [0.8, 1.2].

Sample 1 (A₁ ≈ 0.567 > 0):

  • Unclipped term: r₁(θ) * A₁ = 1.20 * 0.567 ≈ 0.680
  • Clipped ratio: clip(1.20, 0.8, 1.2) = 1.2
  • Clipped term: 1.2 * 0.567 ≈ 0.680
  • min(0.680, 0.680) = 0.680. The ratio was within the clip bounds.

Sample 2 (A₂ ≈ -1.495 < 0):

  • Unclipped term: r₂(θ) * A₂ = 0.70 * -1.495 ≈ -1.047
  • Clipped ratio: clip(0.70, 0.8, 1.2) = 0.8
  • Clipped term: 0.8 * -1.495 ≈ -1.196
  • min(-1.047, -1.196) = -1.196. The value is clipped. This is a subtle but crucial point. The optimizer's goal is to maximize the objective. An objective of -1.047 is better than -1.196. By forcing the optimizer to take the min, we are selecting the worse of the two possible objectives. This limits the size of the policy update, preventing the model from making a large, potentially unstable change even when correcting a mistake.

Sample 3 (A₃ ≈ 0.567 > 0):

  • Unclipped term: r₃(θ) * A₃ = 1.40 * 0.567 ≈ 0.794
  • Clipped ratio: clip(1.40, 0.8, 1.2) = 1.2
  • Clipped term: 1.2 * 0.567 ≈ 0.680
  • min(0.794, 0.680) = 0.680. The policy update is clipped to prevent it from getting too greedy on this good sample.

Sample 4 (A₄ ≈ 0.361 > 0):

  • Unclipped term: r₄(θ) * A₄ = 0.84 * 0.361 ≈ 0.303
  • Clipped ratio: clip(0.84, 0.8, 1.2) = 0.84
  • Clipped term: 0.84 * 0.361 ≈ 0.303
  • min(0.303, 0.303) = 0.303. The ratio was within the clip bounds.

Step 5: Calculate the KL Penalty for Each Sample (Using Equation 2)

Now we calculate the penalty term $D_{KL}(\pi_\theta || \pi_{\text{ref}})$ for each sample. Let $r_{\text{ref}} = \pi_{\text{ref}} / \pi_\theta$. The formula is $r_{\text{ref}} - \log(r_{\text{ref}}) - 1$.

  • Sample 1: r_{ref} = 0.28 / 0.30 ≈ 0.933. Penalty = 0.933 - log(0.933) - 1 ≈ 0.933 - (-0.069) - 1 = 0.002
  • Sample 2: r_{ref} = 0.15 / 0.21 ≈ 0.714. Penalty = 0.714 - log(0.714) - 1 ≈ 0.714 - (-0.337) - 1 = 0.051
  • Sample 3: r_{ref} = 0.26 / 0.28 ≈ 0.929. Penalty = 0.929 - log(0.929) - 1 ≈ 0.929 - (-0.074) - 1 = 0.003
  • Sample 4: r_{ref} = 0.22 / 0.21 ≈ 1.048. Penalty = 1.048 - log(1.048) - 1 ≈ 1.048 - (0.047) - 1 = 0.001

Step 6: Combine Everything to Get the Final Value

The final loss for our batch is the average over the 4 samples. For each sample i, the value is (Clipped_Objective_i - β * KL_Penalty_i).

  • Sample 1 Value: 0.680 - (0.05 * 0.002) = 0.680 - 0.0001 = 0.6799
  • Sample 2 Value: -1.196 - (0.05 * 0.051) = -1.196 - 0.00255 = -1.19855
  • Sample 3 Value: 0.680 - (0.05 * 0.003) = 0.680 - 0.00015 = 0.67985
  • Sample 4 Value: 0.303 - (0.05 * 0.001) = 0.303 - 0.00005 = 0.30295

Total Objective J (for this one prompt):
J_GRPO = (1/4) * (0.6799 - 1.19855 + 0.67985 + 0.30295) = (1/4) * 0.46415 ≈ 0.116
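Steps 3 to 6 can be checked end to end with a short NumPy script (the probabilities are the hypothetical values from the table in Step 3; this is a verification aid, not notebook code):

import numpy as np

advantages = np.array([0.567, -1.495, 0.567, 0.361])   # from Step 2
p_old = np.array([0.25, 0.30, 0.20, 0.25])
p_new = np.array([0.30, 0.21, 0.28, 0.21])
p_ref = np.array([0.28, 0.15, 0.26, 0.22])
beta, epsilon = 0.05, 0.2

ratio = p_new / p_old                                   # [1.2, 0.7, 1.4, 0.84]
clipped = np.clip(ratio, 1 - epsilon, 1 + epsilon)
policy_term = np.minimum(ratio * advantages, clipped * advantages)

r_ref = p_ref / p_new                                   # ratio used by the paper's KL variant
kl_penalty = r_ref - np.log(r_ref) - 1

j_grpo = np.mean(policy_term - beta * kl_penalty)
print(policy_term.round(3))   # [ 0.68  -1.196  0.68   0.303]
print(kl_penalty.round(3))    # [0.002 0.051 0.003 0.001]
print(round(j_grpo, 3))       # ~0.116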

Final Action

The value J_GRPO ≈ 0.116 is the number we want to maximize. The optimizer (like Adam) will compute the gradient of this objective with respect to the LLM's parameters (∇_θ J_GRPO) and take a small step in that gradient's direction. This single step will slightly adjust the millions of weights in π_θ to:

  1. Increase the probability of outputs like o₁ and o₃ (the good ones).
  2. Decrease the probability of the bad output o₂.
  3. Do this while being constrained by the clipping mechanism and pulled slightly back towards the original π_{ref} model to avoid forgetting how to form coherent sentences.

Conclusion

Congratulations on making it this far. The full Jupyter notebook to train your LLM on your Apple silicon computer is accessible here. If you discover any mistakes or have any improvements to suggest, please feel free to make a pull request! I will look into all requests as soon as I can.

This being my third article, I have covered:

  1. Building softmax self-attention from scratch
  2. The math behind linearised self-attention
  3. Fine-tuning with GRPO

My future articles will continue to revolve around these topics:

  1. Building an LLM from scratch (because why not?)
  2. Fine-tuning
  3. LLM evaluations

If you have any interesting topics related to LLMs or machine learning in general that you would like me to explore, please let me know. I am open to ideas.

Appendix A: Proof that $f(r)$ is a valid divergence measure

The two properties to satisfy are:

  1. Non-negativity: $D(p || q) \geq 0$.
  2. Identity of Indiscernibles: $D(p || q) = 0$ if and only if $p = q$.

1. Proof of Identity of Indiscernibles

We must show that $f(r) = 0$ if and only if $r = 1$. The condition $r = 1$ implies $\pi_{\text{ref}}(o_i|q) = \pi_\theta(o_i|q)$, which means the policies are identical for this specific output.

  • If $r = 1$: $f(1) = 1 - \log(1) - 1 = 1 - 0 - 1 = 0$. This part of the proof is trivial.
  • If $f(r) = 0$: $r - \log(r) - 1 = 0$ implies $r - 1 = \log(r)$. Consider the graphs of $y = x - 1$ (a straight line) and $y = \log(x)$. They are tangent at the point $x = 1$. To prove this formally, let $g(r) = r - 1 - \log(r)$. We want to find the roots of $g(r) = 0$. The derivative is $g'(r) = 1 - 1/r$. Setting $g'(r) = 0$ gives $r = 1$. This is the only extremum. Since $g(1) = 1 - 1 - \log(1) = 0$, the function $g(r)$ has a minimum value of 0 at $r = 1$. Therefore, the only real solution to $g(r) = 0$ is $r = 1$. This completes the proof that $f(r) = 0 \iff r = 1$.

2. Proof of Non-Negativity

We must show that $f(r) \geq 0$ for all $r > 0$ (since $r$ is a ratio of probabilities, it must be positive).

  • Let's use calculus again on $f(r) = r - \log(r) - 1$.
  • First Derivative: $f'(r) = 1 - 1/r$.
  • Critical Point: $f'(r) = 0 \implies 1 - 1/r = 0 \implies r = 1$.
  • Second Derivative: $f''(r) = 1/r^2$.
  • Since $r > 0$, we have $f''(r) > 0$ for all $r$ in the domain. This proves that $f(r)$ is a strictly convex function.
  • A strictly convex function has a unique global minimum at its critical point. We found this critical point to be $r = 1$.
  • The value of the function at this global minimum is $f(1) = 1 - \log(1) - 1 = 0$.
  • Since the function's global minimum value is 0, it must be that $f(r) \geq 0$ for all $r > 0$.

This completes the proof of non-negativity.
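
For readers who prefer a numerical check over calculus, here is a tiny NumPy sketch that evaluates $f(r) = r - \log(r) - 1$ on a grid and confirms the two properties:

import numpy as np

r = np.linspace(0.05, 5.0, 1000)
f = r - np.log(r) - 1

print(f.min() >= 0)                # True: f is never negative on r > 0
print(round(r[np.argmin(f)], 2))   # ~1.0: the minimum sits at r = 1
print(1.0 - np.log(1.0) - 1.0)     # 0.0: and the value there is exactly 0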

Appendix B - A more technical discussion on the advantage function

The advantage function, $A_i$, is a central component in modern policy gradient methods. In reinforcement learning, the simplest policy gradient update rule uses the total reward $R$ to scale the gradient $\nabla_\theta \log \pi_\theta(a|s)$. However, this approach suffers from high variance, meaning the gradient estimates can fluctuate wildly from one batch of samples to another, leading to unstable training.

The core idea to reduce this variance is to subtract a baseline $b(s)$ from the reward. The baseline should ideally be an estimate of the average reward from state $s$. This leads to the advantage function:

$$A(s, a) = R(s, a) - b(s)$$

Intuitively, the advantage tells us not just if an action was "good" (positive reward), but if it was "better than average". If $A(s, a) > 0$, the action $a$ was better than expected, and its probability should be increased. If $A(s, a) < 0$, the action was worse than expected, and its probability should be decreased.

Key Theorem (Baseline Invariance of Policy Gradient):
The introduction of a baseline $b(s)$ that depends only on the state $s$ (or in our case, the prompt $q$) does not introduce bias into the gradient estimate.

Proof:
We need to show that $\mathbb{E}_{a \sim \pi}\left[\nabla_\theta \log \pi_\theta(a|s) \cdot b(s)\right] = 0$.

$$\begin{align*} \mathbb{E}_{a \sim \pi}\left[ \nabla_\theta \log \pi_\theta(a|s) \cdot b(s) \right] &= \int \pi_\theta(a|s) \left( \frac{\nabla_\theta \pi_\theta(a|s)}{\pi_\theta(a|s)} \right) b(s)\, da \\ &= \int \nabla_\theta \pi_\theta(a|s) \cdot b(s)\, da \\ &= b(s) \int \nabla_\theta \pi_\theta(a|s)\, da \quad (\text{since } b(s) \text{ does not depend on } a) \\ &= b(s) \nabla_\theta \int \pi_\theta(a|s)\, da \quad (\text{swapping integral and gradient}) \\ &= b(s) \nabla_\theta (1) \quad (\text{since } \pi_\theta \text{ is a probability distribution, it integrates to 1}) \\ &= b(s) \cdot 0 \\ &= 0 \end{align*}$$

This proves that subtracting a baseline does not change the expected gradient, $\mathbb{E}[\nabla J] = \nabla J$. While the expectation is the same, the variance of the gradient estimator $\text{Var}[\nabla J]$ is significantly reduced.
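
The variance-reduction claim is easy to see in a small simulation (a toy three-action softmax policy with hypothetical rewards; nothing here comes from the paper or the notebook). Both estimators have the same mean gradient, but the baselined one has a noticeably smaller variance:

import numpy as np

rng = np.random.default_rng(0)
logits = np.array([0.2, -0.1, 0.4])
probs = np.exp(logits) / np.exp(logits).sum()
rewards = np.array([1.0, 0.0, 0.5])        # hypothetical per-action rewards
baseline = rewards @ probs                 # expected reward under the policy

def grad_log_prob(a):
    # For a softmax policy, d/d(logits) of log pi(a) = one_hot(a) - probs
    return np.eye(3)[a] - probs

samples = rng.choice(3, size=50_000, p=probs)
g_plain = np.stack([rewards[a] * grad_log_prob(a) for a in samples])
g_base = np.stack([(rewards[a] - baseline) * grad_log_prob(a) for a in samples])

print(g_plain.mean(axis=0).round(3), g_base.mean(axis=0).round(3))             # ~identical means
print(g_plain.var(axis=0).sum().round(3), g_base.var(axis=0).sum().round(3))   # variance drops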

The Equation

The paper's specific implementation of the advantage function for a group of $G$ outputs $\{o_1, ..., o_G\}$ is:

$$A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \ldots, r_G\})}{\text{std}(\{r_1, r_2, \ldots, r_G\})}$$

where $\{r_1, ..., r_G\}$ are the rewards for the corresponding outputs.

Step-by-Step Component Breakdown

Let's deconstruct the formula for the advantage of the $i$-th sample, $A_i$.

1. The Rewards: $\{r_1, r_2, ..., r_G\}$

  • $r_j$: This is the numerical reward assigned to the $j$-th output $o_j$, which was generated for a given prompt $q$.
  • Source of Rewards: The paper specifies (in Section 2.2.2) that these are rule-based rewards.
    • Accuracy Rewards: A binary or continuous score evaluating if the final answer in $o_j$ is correct. For a math problem, this could be checking if the result matches the known solution. For a coding problem, it could be the percentage of test cases passed.
    • Format Rewards: A score evaluating if the output $o_j$ adheres to a desired format (e.g., using <think> and </think> tags).
  • Axiom: The existence of a reward function $R: (Q, O) \rightarrow \mathbb{R}$ is a fundamental axiom of the reinforcement learning framework; $r_j = R(q, o_j)$.

2. The Baseline: $\text{mean}(\{r_1, ..., r_G\})$

  • Definition: This is the empirical mean (or sample average) of the rewards obtained from the $G$ outputs generated for the same prompt $q$:
    $$\text{mean}(\{r_j\}_{j=1}^G) = \bar{r} = \frac{1}{G} \sum_{j=1}^{G} r_j$$
  • Role as a Baseline: This term serves as the baseline $b(q)$. It is an estimate of the expected reward for the given prompt $q$ under the current (old) policy $\pi_{\theta_{\text{old}}}$. Instead of using a separate, learned "critic" network to predict the expected reward, the GRPO algorithm uses this simple and efficient empirical estimate from the group of samples.
  • The Numerator: The numerator $r_i - \text{mean}(\{r_j\}_{j=1}^G)$ is the raw, unnormalized advantage. It measures whether the $i$-th output was better or worse than the average performance within its group.

3. The Normalization Factor: $\text{std}(\{r_1, ..., r_G\})$

  • Definition: This is the empirical standard deviation of the rewards from the group:
    $$\text{std}(\{r_j\}_{j=1}^G) = \sqrt{\frac{1}{G-1} \sum_{j=1}^{G} (r_j - \bar{r})^2}$$

(Note: Sometimes the denominator is G for the biased estimator, but G-1 for the unbiased estimator. In practice, for large G, the difference is negligible; a quick numerical check follows this list. We will assume the standard definition.)

  • Purpose of Normalization: Dividing the raw advantage by the standard deviation is a form of data standardization. It rescales the advantages for a given prompt so that their distribution has a standard deviation of 1.
  • Mathematical Justification: Let the set of raw advantages for a group be $S = \{r_1-\bar{r}, r_2-\bar{r}, ..., r_G-\bar{r}\}$.
    • The mean of this set is $E[S] = E[r_j - \bar{r}] = E[r_j] - E[\bar{r}] = \bar{r} - \bar{r} = 0$. The advantages are centered at zero.
    • The standard deviation of this set is $\text{Std}[S] = \text{Std}[r_j - \bar{r}] = \text{Std}[r_j]$. By dividing each element of $S$ by $\text{Std}[r_j]$, the resulting set of normalized advantages $\{A_1, ..., A_G\}$ will have a mean of 0 and a standard deviation of 1.
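
The note above about $G$ versus $G-1$ is easy to check numerically; with the group of rewards from the worked example earlier, the two estimators differ only slightly:

import numpy as np

rewards = np.array([1.0, 0.0, 1.0, 0.9])
print(round(np.std(rewards, ddof=0), 3))   # 0.421: biased estimator, divides by G
print(round(np.std(rewards, ddof=1), 3))   # 0.486: unbiased estimator, divides by G - 1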

Why is this Normalization Important?

  1. Reduces Sensitivity to Reward Scaling: Imagine two different tasks. In Task 1, rewards are either 0 or 1. The advantages will be small fractions. In Task 2, rewards are 0 or 1000. The advantages will be large numbers. Without normalization, the gradient updates for Task 2 would be 1000 times larger than for Task 1, potentially destabilizing learning when training on a mix of tasks. Normalization ensures that the scale of the advantage signal is consistent across different prompts and reward schemes.

  2. Stabilizes Training: It prevents outlier rewards (a single very high or very low reward in a group) from generating excessively large gradients that could harm the policy. By scaling everything relative to the variation within the group, the updates become more measured and stable.

Synthesis and Conclusion

Equation (3) defines a specific form of the advantage function: a group-relative advantage that combines a simple empirical baseline with an additional normalization step.

  1. Computes Raw Advantage: It first calculates a raw advantage for each sample $o_i$ by subtracting a baseline from its reward $r_i$.
  2. Uses an Empirical Baseline: The baseline is not a learned value but is efficiently estimated as the mean reward of all samples $o_1, ..., o_G$ generated for the same prompt $q$. This conforms to the requirement that the baseline depends only on the prompt $q$ (and the policy that generated the samples), thus not introducing bias into the policy gradient.
  3. Normalizes the Advantage: The raw advantage is then divided by the standard deviation of the rewards within the group. This standardizes the advantages, giving them a mean of 0 and a standard deviation of 1 for that group.
  4. Overall Effect: This process results in a well-behaved, normalized advantage signal $A_i$ that robustly indicates whether an output was better or worse than average, independent of the absolute scale of the rewards for that particular task. This standardized signal is then used in Equation (1) to provide stable and effective gradient updates for the policy $\pi_\theta$.
