Lewis Won

Routing and balancing losses with Mixture of Experts


Motivation

This article explains by hand how to (i) route tokens and (ii) balance losses in routing when training Mixture of Experts (MoEs) large language models.

For those following the Stanford CS 336 class, this could be used as a study guide that complements Lecture 4 on Mixture of Experts.

This article does not seek to carry out a comprehensive review of the entire MoE literature. Hence, I will only focus on (i) the most popular routing variant, where tokens choose experts with k = 2; and (ii) heuristic balancing losses.

This article was written with the assistance of Google Gemini 2.5.


What are mixtures of experts

MoE is a machine learning technique that divides a complex problem among multiple specialized models, known as "experts". A "gating network" or "router" then selects the most suitable expert or combination of experts to handle a given input. (Fun fact: sparse expert models are a thirty-year-old concept.)

The core idea behind MoEs is to replace dense feed-forward network (FFN) layers with sparse MoE layers. In a traditional dense model, the entire network is activated to process any input. In contrast, MoE models use conditional computation, meaning only a select few "expert" sub-networks are activated for each input token.

In the context of large language models, the most popular implementation is to replace the multi-layer perceptron (MLP) layers with mixture-of-experts layers. The diagram below compares a dense model with a mixture-of-experts model. The blue FFN layers are the MLPs.

mixture of experts structure

Source: A Review of Sparse Expert Models in Deep Learning by Fedus et al., 2022.

Note that there are also implementations of LLMs that split the attention heads in the self-attention mechanism (the red layer) into experts. However, these implementations are not common because they are difficult to train consistently.


Why study mixture of experts

MoEs are getting popular because:

  • Selective activation, or sparsity, allows models to have a significantly larger number of parameters without a proportional increase in computational cost during pre-training and inference.
  • Fedus et al. showed that models with more experts reach lower test loss and better perplexity faster.

moe test loss

Source: Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity by Fedus et al., 2022

  • MoEs also provide a natural way to parallelise training and inference at the expert level, i.e. by putting each FFN on a different device.

Routing taxonomy

Broadly, there are four major variants of routing tokens to experts:

  • Top-k: pick the top-k most highly activated experts to process the token. This strategy is used in models from Qwen, DeepSeek, Mixtral, Grok etc.

top-k

  • Hashing: Route each token based on its hash value. This has low compute cost and is relatively easy to implement compared to the other routing strategies.

hashing

  • Reinforcement learning to learn routes. This is uncommon nowadays due to its significantly higher compute cost and its contribution to training instability.

reinforcement learning

  • Solve a matching problem, such as linear assignment problems.

solve matching problem

Source of diagrams: A Review of Sparse Expert Models in Deep Learning by Fedus et al., 2022.

I will focus only on the routing strategy top-k as it is the most popular amongst modern LLMs.


Top-k routing strategy

The "choose top-k" strategy has three main variants:

  1. Token chooses expert
  2. Expert chooses token
  3. Global routing via optimization

Routing taxonomy

Source: A Review of Sparse Expert Models in Deep Learning by Fedus et al., 2022.

Illustration of routing by hand

The core of the routing process is a matrix of weights, often generated by a gating network, which represents the affinity between each token and each expert. Let us consider a scenario with 4 tokens (T1, T2, T3, T4) and 4 experts (E1, E2, E3, E4), where the routing score matrix $S$ contains the score for each token-expert pair:

$$S = \begin{pmatrix} & T1 & T2 & T3 & T4 \\ E1 & 3.1 & 1.2 & 0.5 & 1.8 \\ E2 & 0.8 & 2.5 & 1.9 & 1.1 \\ E3 & 2.2 & 0.7 & 3.5 & 0.4 \\ E4 & 1.5 & 2.8 & 1.3 & 3.2 \end{pmatrix}$$

The value $S_{ij}$ represents the score for sending token $j$ to expert $i$. A higher score indicates a better fit.

We now go through each of the 3 routing strategies, assuming k = 2, as it is the more popular setting for MoE LLMs.

1. Token Chooses Expert

In this method, each token independently selects the two experts with the highest affinity scores. This approach allows each token to be processed by multiple specialized experts.

Calculation:

For each token column, we identify the two experts (rows) with the highest scores.

  • T1: The top two scores are 3.1 (E1) and 2.2 (E3).
  • T2: The top two scores are 2.8 (E4) and 2.5 (E2).
  • T3: The top two scores are 3.5 (E3) and 1.9 (E2).
  • T4: The top two scores are 3.2 (E4) and 1.8 (E1).

Resulting Expert Load:

We tally the tokens assigned to each expert:

  • E1 processes: T1, T4
  • E2 processes: T2, T3
  • E3 processes: T1, T3
  • E4 processes: T2, T4

With k=2, each expert is assigned exactly two tokens, resulting in a balanced load for this specific example. However, token-choice routing can often lead to load imbalance where some experts receive a disproportionate number of tokens.
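As a quick sanity check, here is a minimal numpy sketch (my own, not from the lecture or paper) that reproduces the token-choice assignment above from the example score matrix:

```python
import numpy as np

# Token-chooses-expert routing with k = 2 on the example score matrix
# (rows = experts E1..E4, columns = tokens T1..T4).
S = np.array([[3.1, 1.2, 0.5, 1.8],
              [0.8, 2.5, 1.9, 1.1],
              [2.2, 0.7, 3.5, 0.4],
              [1.5, 2.8, 1.3, 3.2]])
k = 2

# For each token (column), pick the k experts (rows) with the highest scores.
top_experts = np.argsort(-S, axis=0)[:k, :]
for t in range(S.shape[1]):
    chosen = sorted((top_experts[:, t] + 1).tolist())   # 1-indexed expert ids
    print(f"T{t + 1} -> experts {chosen}")
# Expected: T1 -> [1, 3], T2 -> [2, 4], T3 -> [2, 3], T4 -> [1, 4]
```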

2. Expert Chooses Token

With expert-choice routing, the decision-making is inverted: each expert selects the top k tokens it is best suited to process. This method is designed to enforce a balanced load, as each expert is assigned a fixed number of tokens.

Calculation:

For each expert row, we identify the two tokens (columns) with the highest scores.

  • E1: The top two scores are 3.1 (T1) and 1.8 (T4).
  • E2: The top two scores are 2.5 (T2) and 1.9 (T3).
  • E3: The top two scores are 3.5 (T3) and 2.2 (T1).
  • E4: The top two scores are 3.2 (T4) and 2.8 (T2).

Resulting Token Assignment:

  • E1 chooses: T1, T4
  • E2 chooses: T2, T3
  • E3 chooses: T1, T3
  • E4 chooses: T2, T4

In this particular case, the outcome is identical to the "Token Chooses Expert" method. Each expert processes two tokens, and each token is processed by two experts. This symmetry is a feature of this specific matrix and not a general rule. In many real-world scenarios, the two methods would produce different assignments.
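The same check can be run for expert choice by taking the top-k along the other axis of the matrix; again a small numpy sketch of my own, using the same example scores:

```python
import numpy as np

# Expert-chooses-token routing with k = 2 on the same example matrix.
S = np.array([[3.1, 1.2, 0.5, 1.8],
              [0.8, 2.5, 1.9, 1.1],
              [2.2, 0.7, 3.5, 0.4],
              [1.5, 2.8, 1.3, 3.2]])
k = 2

# For each expert (row), pick the k tokens (columns) with the highest scores.
top_tokens = np.argsort(-S, axis=1)[:, :k]
for e in range(S.shape[0]):
    chosen = sorted((top_tokens[e] + 1).tolist())        # 1-indexed token ids
    print(f"E{e + 1} -> tokens {chosen}")
# Expected: E1 -> [1, 4], E2 -> [2, 3], E3 -> [1, 3], E4 -> [2, 4]
```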

3. Global Routing via Optimization

Global routing seeks the globally optimal assignment across the entire matrix, maximizing the total affinity score. The constraints for this problem with k=2 are:

  1. Each token must be assigned to exactly two experts.
  2. Each expert must process exactly two tokens.

This is a balanced assignment problem. It can be formulated as a maximum-weight bipartite matching (more precisely, a b-matching, since each side has degree 2) and solved with linear-programming or Hungarian-style algorithms.

Calculation:

The goal is to select eight scores from the matrix S such that every row and every column is chosen exactly twice, and the sum of these eight scores is maximized.

We can find this optimal assignment by selecting the eight highest scores in the matrix, as long as they satisfy the constraints. The eight highest scores are: 3.5, 3.2, 3.1, 2.8, 2.5, 2.2, 1.9, and 1.8.

Let's check if this selection meets the constraints:

  • Selected Pairs: (E3, T3), (E4, T4), (E1, T1), (E4, T2), (E2, T2), (E3, T1), (E2, T3), (E1, T4).

  • Token Assignments (2 per token):

    • T1: assigned to E1, E3
    • T2: assigned to E4, E2
    • T3: assigned to E3, E2
    • T4: assigned to E4, E1
  • Expert Assignments (2 per expert):

    • E1: assigned T1, T4
    • E2: assigned T2, T3
    • E3: assigned T3, T1
    • E4: assigned T4, T2

The constraints are perfectly met. The resulting assignment is identical to the previous two methods for this specific example. The total maximized score is the sum of these values:
$$3.5 + 3.2 + 3.1 + 2.8 + 2.5 + 2.2 + 1.9 + 1.8 = 21.0$$

While all three methods produced the same outcome here, this highlights a scenario of perfect alignment. In practice, especially with larger and more varied score matrices, these methods would yield different assignments, each with its own trade-offs between computational cost, load balancing, and model performance.

A note on the complexity of global routing

Global routing frames the assignment as an optimization problem: find the assignment of tokens to experts that maximizes the total affinity score, subject to a set of global constraints. For our example with k = 2, the constraints are:

  1. Each token must be assigned to exactly two experts.
  2. Each expert must be assigned exactly two tokens.

This is a balanced assignment problem, which is computationally more complex than the local token or expert choice methods. While algorithms like the Hungarian algorithm or other linear programming solvers can find the optimal solution, they are often too slow to be practical for routing in LLMs.

In our specific example matrix, it happens that simply selecting the eight highest affinity scores—3.5, 3.2, 3.1, 2.8, 2.5, 2.2, 1.9, and 1.8—coincidentally satisfies these constraints. However, this is not a general strategy and would fail on most other matrices.
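For completeness, here is a small sketch (my own formulation, not from the paper or lecture) that solves the k = 2 balanced assignment exactly as a tiny integer program with SciPy's milp (available in SciPy ≥ 1.9); on the example matrix it recovers the same eight pairs and the total score of 21.0:

```python
import numpy as np
from scipy.optimize import Bounds, LinearConstraint, milp

# Example score matrix: rows = experts E1..E4, columns = tokens T1..T4.
S = np.array([[3.1, 1.2, 0.5, 1.8],
              [0.8, 2.5, 1.9, 1.1],
              [2.2, 0.7, 3.5, 0.4],
              [1.5, 2.8, 1.3, 3.2]])
n_experts, n_tokens = S.shape
k = 2

# Binary variable x[i, j] = 1 if expert i processes token j, flattened row-major.
# milp minimizes, so negate the scores to maximize total affinity.
c = -S.flatten()

# Row sums == k (each expert takes exactly k tokens);
# column sums == k (each token goes to exactly k experts).
row_sum = np.kron(np.eye(n_experts), np.ones(n_tokens))
col_sum = np.kron(np.ones(n_experts), np.eye(n_tokens))
constraints = [LinearConstraint(row_sum, k, k), LinearConstraint(col_sum, k, k)]

res = milp(c=c, constraints=constraints,
           integrality=np.ones(c.size), bounds=Bounds(0, 1))
x = res.x.reshape(n_experts, n_tokens).round().astype(int)
print(x)
# Expected assignment (rows E1..E4):
# [[1 0 0 1]
#  [0 1 1 0]
#  [1 0 1 0]
#  [0 1 0 1]]
print("total affinity:", (S * x).sum().round(2))   # 21.0, matching the hand calculation
```

Note that a naive trick of duplicating each row and column twice and running the Hungarian algorithm would not be equivalent here, because it would allow an expert to take the same token twice; the explicit degree constraints above rule that out.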


Deeper dive into the math of "token chooses expert" strategy

A natural question is how the numbers in the routing-by-hand example are derived. As "token chooses expert" is the most common strategy in modern MoE LLMs, let us work out this strategy by hand with an example, assuming k = 2.

We will be working through the following equations specifically:

top-k routing in detail

Source: Stanford CS336 Language Modeling from Scratch, Lecture 4 on Mixture of Experts, 24 min 52 sec

Scenario Setup

Let's assume we are at a specific layer l in the model. We will focus on routing a single token, t.

  • Number of Experts (N): Let's say we have N=4 experts.
  • Top-K: We want to route our token to the top K=2 experts.
  • Token Representation ($u_t^l$): This is the input vector for our token. For simplicity, let's assume it's a 3-dimensional vector.

$$u_t^l = \begin{pmatrix} 0.5 \\ 1.0 \\ -0.2 \end{pmatrix}$$

  • Expert Embeddings ($e_i^l$): Each expert has a learnable weight vector (embedding) of the same dimension, which is used by the router to calculate affinity.

$$e_1^l = \begin{pmatrix} 2.0 \\ 0.8 \\ -0.5 \end{pmatrix}, \quad e_2^l = \begin{pmatrix} -1.0 \\ 0.1 \\ 0.9 \end{pmatrix}, \quad e_3^l = \begin{pmatrix} 1.5 \\ 1.2 \\ 0.3 \end{pmatrix}, \quad e_4^l = \begin{pmatrix} -0.4 \\ -1.1 \\ 1.4 \end{pmatrix}$$

Step 1: Calculate Raw Affinity Scores (Logits)

The first step is to determine how well our token $u_t^l$ aligns with each expert $e_i^l$. This is done by calculating the dot product between the token's vector and each expert's embedding vector. This corresponds to the inner part of the third equation, $(u_t^l)^T e_i^l$.

  • Logit for Expert 1: $(u_t^l)^T e_1^l = (0.5 \times 2.0) + (1.0 \times 0.8) + (-0.2 \times -0.5) = 1.0 + 0.8 + 0.1 = 1.9$
  • Logit for Expert 2: $(u_t^l)^T e_2^l = (0.5 \times -1.0) + (1.0 \times 0.1) + (-0.2 \times 0.9) = -0.5 + 0.1 - 0.18 = -0.58$
  • Logit for Expert 3: $(u_t^l)^T e_3^l = (0.5 \times 1.5) + (1.0 \times 1.2) + (-0.2 \times 0.3) = 0.75 + 1.2 - 0.06 = 1.89$
  • Logit for Expert 4: $(u_t^l)^T e_4^l = (0.5 \times -0.4) + (1.0 \times -1.1) + (-0.2 \times 1.4) = -0.2 - 1.1 - 0.28 = -1.58$

Our vector of raw logits is: $[1.9, -0.58, 1.89, -1.58]$

Step 2: Normalize Scores with Softmax

The next step is to normalize these logits into a probability distribution using the Softmax function. This gives us the scores $s_{i,t}$, which represent the router's confidence for each expert.

Equation: $s_{i,t} = \text{Softmax}_i\left((u_t^l)^T e_i^l\right)$

The Softmax formula is: $s_{i,t} = \frac{\exp(\text{logit}_i)}{\sum_{j=1}^{N} \exp(\text{logit}_j)}$

  • Exponentials:

    • $\exp(1.9) \approx 6.686$
    • $\exp(-0.58) \approx 0.560$
    • $\exp(1.89) \approx 6.619$
    • $\exp(-1.58) \approx 0.206$
  • Sum of exponentials:

    $6.686 + 0.560 + 6.619 + 0.206 = 14.071$

  • Normalized scores $s_{i,t}$:

    • $s_{1,t} = 6.686 / 14.071 \approx 0.475$
    • $s_{2,t} = 0.560 / 14.071 \approx 0.040$
    • $s_{3,t} = 6.619 / 14.071 \approx 0.470$
    • $s_{4,t} = 0.206 / 14.071 \approx 0.015$

Our vector of normalized scores is: $s_t = [0.475, 0.040, 0.470, 0.015]$. Note that these values sum to 1.
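The two steps above can be verified with a few lines of numpy (a sketch of my own, using the example vectors):

```python
import numpy as np

# Steps 1 and 2 of the worked example: router logits via dot products, then softmax.
u_t = np.array([0.5, 1.0, -0.2])                 # token representation u_t^l
E = np.array([[ 2.0,  0.8, -0.5],                # expert embeddings e_1^l .. e_4^l
              [-1.0,  0.1,  0.9],
              [ 1.5,  1.2,  0.3],
              [-0.4, -1.1,  1.4]])

logits = E @ u_t                                  # one dot product per expert
print(logits)                                     # [ 1.9  -0.58  1.89 -1.58]

s = np.exp(logits) / np.exp(logits).sum()         # softmax over the N experts
print(s.round(3))                                 # [0.475 0.04  0.47  0.015]
```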

Step 3: Select Top-K and Determine Gating Weights

Now we apply the Top-K logic to determine the final gating weights, $g_{i,t}$. The gate value is equal to the score $s_{i,t}$ if it's in the Top-K, otherwise it's zero.

Equation: $g_{i,t} = \begin{cases} s_{i,t}, & \text{if } s_{i,t} \in \text{TopK}(\{s_{j,t}\}, K=2) \\ 0, & \text{otherwise} \end{cases}$

  • Our scores are [0.475, 0.040, 0.470, 0.015].
  • The two highest scores (K=2) are 0.475 (for Expert 1) and 0.470 (for Expert 3).
  • The final gating weights are:
    • $g_{1,t} = s_{1,t} = 0.475$
    • $g_{2,t} = 0$
    • $g_{3,t} = s_{3,t} = 0.470$
    • $g_{4,t} = 0$

This means only Expert 1 and Expert 3 will be activated for this token.

Step 4: Calculate Final Output

Finally, the output of the MoE layer ($h_t^l$) is a weighted sum of the outputs from the selected experts, added to the original token representation (a residual connection).

Equation: $h_t^l = \sum_{i=1}^{N} \left( g_{i,t} \, \text{FFN}_i(u_t^l) \right) + u_t^l$

Plugging in our gating values:

$h_t^l = (0.475 \cdot \text{FFN}_1(u_t^l)) + (0 \cdot \text{FFN}_2(u_t^l)) + (0.470 \cdot \text{FFN}_3(u_t^l)) + (0 \cdot \text{FFN}_4(u_t^l)) + u_t^l$

This simplifies to:

$h_t^l = 0.475 \cdot \text{FFN}_1(u_t^l) + 0.470 \cdot \text{FFN}_3(u_t^l) + u_t^l$

The final output is the combination of the outputs from the two chosen experts, weighted by their routing scores, plus the original input token vector.
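Steps 3 and 4 can be sketched the same way. The "experts" below are hypothetical placeholders (a simple per-expert scaling), not real feed-forward networks; the point is only the gating and the weighted residual sum:

```python
import numpy as np

u_t = np.array([0.5, 1.0, -0.2])                 # token representation from the example
s = np.array([0.475, 0.040, 0.470, 0.015])       # softmax scores from Step 2
k = 2

# Step 3: keep the scores of the K highest-scoring experts, zero out the rest.
gates = np.zeros_like(s)
top_k = np.argsort(-s)[:k]
gates[top_k] = s[top_k]
print(gates)                                     # [0.475 0.    0.47  0.   ]

def placeholder_ffn(i, x):
    # Hypothetical stand-in for FFN_i, just so the sum below is runnable.
    return (i + 1) * x

# Step 4: weighted sum of the (placeholder) expert outputs plus the residual.
h_t = sum(gates[i] * placeholder_ffn(i, u_t) for i in range(len(gates))) + u_t
print(h_t)
```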

A Note on Variants (Mixtral, DBRX)

Models like Mixtral and DBRX use a slightly different approach where they apply Softmax after selecting the Top-K.

Using our example, that would mean:

  1. Select Top-K Logits: Take the raw logits [1.9, -0.58, 1.89, -1.58]. The Top-2 are 1.9 and 1.89.
  2. Apply Softmax to only the Top-K:
    • $g_{1,t} = \frac{\exp(1.9)}{\exp(1.9) + \exp(1.89)} = \frac{6.686}{6.686 + 6.619} \approx 0.503$
    • $g_{3,t} = \frac{\exp(1.89)}{\exp(1.9) + \exp(1.89)} = \frac{6.619}{6.686 + 6.619} \approx 0.497$
    • $g_{2,t} = 0$, $g_{4,t} = 0$

In this variant, the final gating weights for the chosen experts are re-normalized to sum to 1.
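A minimal sketch of this variant (again my own, using the logits from Step 1):

```python
import numpy as np

# Mixtral/DBRX-style gating: take the top-k on the raw logits first, then softmax
# over only the selected logits so that the active gates sum to 1.
logits = np.array([1.9, -0.58, 1.89, -1.58])      # from Step 1
k = 2

top_k = np.argsort(-logits)[:k]
gates = np.zeros_like(logits)
gates[top_k] = np.exp(logits[top_k]) / np.exp(logits[top_k]).sum()
print(gates.round(3))                              # [0.503 0.    0.497 0.   ]
```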


Heuristic balancing losses

System efficiency requires that we use experts evenly, so that computation is distributed across all experts. In the worst case, where all tokens are routed to a single expert, the model is effectively a single dense model with significantly higher memory overhead from the unused experts.

To encourage a balanced load, an auxiliary loss is often added to the total model loss during training. The following section illustrates the load balancing loss introduced in the Switch Transformers paper. It is important to note that this specific formulation was designed for a top-1 routing scenario (k = 1), where each token is dispatched to the single expert with the highest score (argmax). We use a top-1 example here to faithfully explain the paper's original equations. See Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.

balancing loss

1. Equation Explanations

Here is a breakdown of each equation:

Equation (4): The Auxiliary Loss

This equation defines the overall auxiliary loss that is added to the model's main training loss (e.g., cross-entropy).

$$\text{loss} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i \quad (4)$$
  • loss: The final calculated auxiliary loss value.
  • α (alpha): A small, constant hyperparameter (the paper suggests $10^{-2}$) used to scale this auxiliary loss. It ensures that load balancing doesn't overwhelm the primary goal of the main loss function.
  • N: The total number of experts in the MoE layer.
  • f_i: The fraction of tokens in the batch that are actually dispatched to expert i.
  • P_i: The average fraction of the router's probability (or "confidence") allocated to expert i across all tokens in the batch.

Meaning: This loss is essentially a scaled dot product between the vector of token fractions (f) and the vector of average router probabilities (P). This loss is minimized when the tokens are distributed uniformly, meaning each expert i receives 1/N of the tokens (f_i = 1/N) and the router assigns an average probability of 1/N to each expert (P_i = 1/N). At that uniform point the loss equals α, since $N \cdot \sum_i (1/N)(1/N) = 1$.

Equation (5): Fraction of Tokens Dispatched

This equation calculates f_i, the actual proportion of tokens assigned to each expert.

$$f_i = \frac{1}{T} \sum_{x \in B} \mathbb{1}\{\text{argmax } p(x) = i\} \quad (5)$$
  • T: The total number of tokens in the batch B.
  • p(x): The output of the router for a single token x. It's a probability vector of size N, where each element is the probability of sending the token to the corresponding expert.
  • argmax p(x): This finds the index of the expert that received the highest probability for token x.
  • 1{...}: This is an indicator function. It returns 1 if the condition inside is true (i.e., if expert i was chosen for token x) and 0 otherwise.

Meaning: This formula counts how many tokens in the batch are assigned to expert i (based on the highest router probability) and then divides by the total number of tokens to get the fraction. As noted in the paper, the use of argmax makes this function non-differentiable, which has implications for the backward pass.

Equation (6): Fraction of Router Probability

This equation calculates P_i, the average router probability assigned to each expert.

$$P_i = \frac{1}{T} \sum_{x \in B} p_i(x) \quad (6)$$
  • p_i(x): The specific probability assigned to expert i for token x. This is the i-th element of the vector p(x).

Meaning: This formula calculates the average "vote" or probability mass the router gives to expert i across all tokens in the batch. Unlike f_i, this calculation is based on the soft probabilities from the router and is fully differentiable.

2. Illustrative Example

Let's use a simple example to make these equations concrete.

  • Number of experts N: 3
  • Number of tokens in batch T: 4
  • Hyperparameter α: 0.01

Assume our router network processes the 4 tokens and outputs the following probability distributions (p(x)). Each row corresponds to a token, and each column corresponds to an expert.

|         | Expert 1 | Expert 2 | Expert 3 |
|---------|----------|----------|----------|
| Token 1 | 0.7      | 0.2      | 0.1      |
| Token 2 | 0.1      | 0.8      | 0.1      |
| Token 3 | 0.6      | 0.3      | 0.1      |
| Token 4 | 0.2      | 0.2      | 0.6      |

This matrix represents the p_i(x) values.

3. Forward Pass Demonstration

Using our example matrix, we will now calculate the auxiliary loss.

Step 1: Calculate f (Fraction of Tokens)

We apply Equation (5) by finding the argmax for each token (row):

  • Token 1: argmax([0.7, 0.2, 0.1]) = Expert 1
  • Token 2: argmax([0.1, 0.8, 0.1]) = Expert 2
  • Token 3: argmax([0.6, 0.3, 0.1]) = Expert 1
  • Token 4: argmax([0.2, 0.2, 0.6]) = Expert 3

Now, we count the assignments for each expert:

  • Expert 1 was chosen 2 times.
  • Expert 2 was chosen 1 time.
  • Expert 3 was chosen 1 time.

Finally, we calculate the fractions (f_i = count / T):

  • $f_1 = 2 / 4 = 0.5$
  • $f_2 = 1 / 4 = 0.25$
  • $f_3 = 1 / 4 = 0.25$

So, our f vector is [0.5, 0.25, 0.25].

Step 2: Calculate P (Fraction of Probability)

We apply Equation (6) by averaging the probabilities in each column of our matrix:

  • $P_1 = (0.7 + 0.1 + 0.6 + 0.2) / 4 = 1.6 / 4 = 0.4$
  • $P_2 = (0.2 + 0.8 + 0.3 + 0.2) / 4 = 1.5 / 4 = 0.375$
  • $P_3 = (0.1 + 0.1 + 0.1 + 0.6) / 4 = 0.9 / 4 = 0.225$

So, our P vector is [0.4, 0.375, 0.225].

Step 3: Calculate the Final Auxiliary Loss

We use Equation (4) with the f and P vectors we just calculated:

$$\begin{aligned} \text{loss} &= \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i \\ &= 0.01 \cdot 3 \cdot \left( (0.5 \cdot 0.4) + (0.25 \cdot 0.375) + (0.25 \cdot 0.225) \right) \\ &= 0.03 \cdot (0.2 + 0.09375 + 0.05625) \\ &= 0.03 \cdot 0.35 = 0.0105 \end{aligned}$$

This value, 0.0105, is the auxiliary loss for this batch, which will be added to the main model loss before the backward pass.
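The whole forward computation fits in a few lines of numpy; this sketch (mine, not from the paper) reproduces the f, P, and loss values above:

```python
import numpy as np

# Forward pass of the Switch Transformer auxiliary loss (Equations 4-6) for the example batch.
alpha, N, T = 0.01, 3, 4
p = np.array([[0.7, 0.2, 0.1],                    # router probabilities, one row per token
              [0.1, 0.8, 0.1],
              [0.6, 0.3, 0.1],
              [0.2, 0.2, 0.6]])

f = np.bincount(p.argmax(axis=1), minlength=N) / T   # Eq. (5): fraction of tokens per expert
P = p.mean(axis=0)                                   # Eq. (6): mean router probability per expert
loss = alpha * N * np.dot(f, P)                      # Eq. (4)

print(f)                                             # [0.5  0.25 0.25]
print(P)                                             # [0.4   0.375 0.225]
print(loss)                                          # ≈ 0.0105
```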

4. Backward Pass Demonstration

During backpropagation, we need to calculate the gradient of the loss with respect to the model's parameters, specifically the weights of the router network. The gradient flows backward from the loss through the calculations we performed.

A key detail from the paper is that the f vector is not differentiable due to the argmax function. In practice, this is handled by treating f as a constant during the backward pass, i.e. applying a stop-gradient to it (in spirit similar to a straight-through estimator). As a result, the gradient signal only flows back through the differentiable part of the loss equation, which is the P_i term (the average router probability).

Step 1: Gradient of the Loss w.r.t. P_i

We start by finding the derivative of the loss function with respect to each element of the P vector.

$$\frac{\partial \text{loss}}{\partial P_i} = \frac{\partial}{\partial P_i} \left( \alpha \cdot N \cdot \sum_{j=1}^{N} f_j \cdot P_j \right)$$

Since f is treated as a constant, the derivative is:

$$\frac{\partial \text{loss}}{\partial P_i} = \alpha \cdot N \cdot f_i$$

Let's calculate this for our example:

  • $\frac{\partial \text{loss}}{\partial P_1} = 0.01 \cdot 3 \cdot f_1 = 0.03 \cdot 0.5 = 0.015$
  • $\frac{\partial \text{loss}}{\partial P_2} = 0.01 \cdot 3 \cdot f_2 = 0.03 \cdot 0.25 = 0.0075$
  • $\frac{\partial \text{loss}}{\partial P_3} = 0.01 \cdot 3 \cdot f_3 = 0.03 \cdot 0.25 = 0.0075$

Step 2: Gradient of P_i w.r.t. p_i(x)

Next, we need the gradient of P_i (the average probability) with respect to the individual router output p_i(x) for a specific token x.

$$P_i = \frac{1}{T} \sum_{x \in B} p_i(x) \implies \frac{\partial P_i}{\partial p_i(x)} = \frac{1}{T}$$

For our example, this is 1 / 4 = 0.25.

Step 3: Combine Gradients (Chain Rule)

Using the chain rule, we find the gradient of the loss with respect to each individual router probability p_i(x).

$$\frac{\partial \text{loss}}{\partial p_i(x)} = \frac{\partial \text{loss}}{\partial P_i} \cdot \frac{\partial P_i}{\partial p_i(x)} = (\alpha \cdot N \cdot f_i) \cdot \frac{1}{T}$$

This gradient value is the same for all tokens x in the batch. Let's calculate it:

  • $\frac{\partial \text{loss}}{\partial p_1(x)} = 0.015 \cdot \frac{1}{4} = 0.00375$
  • $\frac{\partial \text{loss}}{\partial p_2(x)} = 0.0075 \cdot \frac{1}{4} = 0.001875$
  • $\frac{\partial \text{loss}}{\partial p_3(x)} = 0.0075 \cdot \frac{1}{4} = 0.001875$

This gives us a gradient vector [0.00375, 0.001875, 0.001875]. This gradient is then backpropagated to the router's softmax layer for every token. The optimizer will use this gradient to update the router's weights, encouraging it to produce probabilities that lead to a more balanced load in the next iteration. In this case, since expert 1 was over-utilized (f_1 was high), it receives the largest gradient signal, which will push the router to assign it lower probabilities in the future.
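Since the gradient formula is so simple, it can be checked directly; this continues the sketch above, with f treated as a constant:

```python
import numpy as np

# Gradient of the auxiliary loss w.r.t. each router probability p_i(x), with f held
# constant: d(loss)/d(p_i(x)) = alpha * N * f_i / T, identical for every token in the batch.
alpha, N, T = 0.01, 3, 4
f = np.array([0.5, 0.25, 0.25])                   # token fractions from the forward pass

grad_p = alpha * N * f / T
print(grad_p)                                     # [0.00375  0.001875 0.001875]
```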
