Table of Contents
- Motivation
- What is a mixture of experts
- Why study mixture of experts
- Routing taxonomy
- Top-k routing strategy
- Deeper dive into the math of "token chooses expert" strategy
- Heuristic balancing losses
Motivation
This article explains by hand how to (i) route tokens and (ii) compute load-balancing losses when training Mixture of Experts (MoEs) large language models.
For those following the Stanford CS 336 class, this could be used as a study guide that complements Lecture 4 on Mixture of Experts.
This article does not seek to carry out a comprehensive review of the entire MoEs literature. Hence, I will only focus on the most popular variants: (i) "tokens choose experts" routing with k = 2; and (ii) heuristic balancing losses.
This article was written with the assistance of Google Gemini 2.5.
What is a mixture of experts
Mixture of experts is a machine learning technique that divides a complex problem among multiple specialized models, known as "experts". A "gating network" or "router" then selects the most suitable expert or combination of experts to handle a given input. (Fun fact: sparse expert models are a thirty-year-old concept.)
The core idea behind MoEs is to replace dense feed-forward network (FFN) layers with sparse MoE layers. In a traditional dense model, the entire network is activated to process any input. In contrast, MoE models use conditional computation, meaning only a select few "expert" sub-networks are activated for each input token.
In the context of large language models, the more popular implementation is to replace the multi-layer perceptron (MLP) layers with mixture-of-experts layers. The diagram below compares a dense model with a mixture-of-experts model. The blue FFN layers are the MLP.
Source: A Review of Sparse Expert Models in Deep Learning by Fedus et al., 2022.
Note that there are also implementations of LLMs that split the attention heads in the self-attention mechanism (the red layer) into experts. However, these implementations are not common because they are difficult to train consistently.
Why study mixture of experts
MoEs are getting popular because:
- Selective activation, or sparsity, allows models to have a significantly larger number of parameters without a proportional increase in computational cost during pre-training and inference.
- Fedus et al. showed that models with more experts achieve lower test loss and reach a given perplexity faster.
Source: Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity by Fedus et al., 2022
- MoEs also provide a natural way to parallelise training and inference at the expert level, i.e. by putting each FFN on a different device.
Routing taxonomy
Broadly, there are four major variants of routing tokens to experts:
- Top-k: pick the top-k most highly activated experts to process the token. This strategy is used in models from Qwen, DeepSeek, Mixtral, Grok, etc.
- Hashing: Route each token based on its hash value. It has low compute cost and is easier to implement than the other routing strategies.
- Reinforcement learning: learn routes via RL. This is uncommon nowadays due to its significantly higher compute cost and its tendency to destabilize training.
- Matching: solve a matching problem, such as a linear assignment problem.
Source of diagrams: A Review of Sparse Expert Models in Deep Learning by Fedus et al., 2022.
I will focus only on the top-k routing strategy, as it is the most popular amongst modern LLMs.
Top-k routing strategy
The "choose top-k" strategy has three main variants:
1. Token chooses expert
2. Expert chooses token
3. Global routing via optimization
Source: A Review of Sparse Expert Models in Deep Learning by Fedus et al., 2022.
Illustration of routing by hand
The core of the routing process is a matrix of weights, often generated by a gating network, which represents the affinity between each token and each expert. Let us consider a scenario with 4 tokens (T1, T2, T3, T4) and 4 experts (E1, E2, E3, E4), with a score matrix S holding the score for each token-expert pair. Only the two highest scores per token and per expert matter for the walkthrough below; the lower scores are shown as "–":

|    | T1  | T2  | T3  | T4  |
|----|-----|-----|-----|-----|
| E1 | 3.1 | –   | –   | 1.8 |
| E2 | –   | 2.5 | 1.9 | –   |
| E3 | 2.2 | –   | 3.5 | –   |
| E4 | –   | 2.8 | –   | 3.2 |

The value in row Ei, column Tj represents the score for sending token Tj to expert Ei. A higher score indicates a better fit.
We now go through each of the three routing strategies, assuming k = 2, as it is the more popular setting for MoE LLMs.
1. Token Chooses Expert
In this method, each token independently selects the two experts with the highest affinity scores. This approach allows each token to be processed by multiple specialized experts.
Calculation:
For each token column, we identify the two experts (rows) with the highest scores.
- T1: The top two scores are 3.1 (E1) and 2.2 (E3).
- T2: The top two scores are 2.8 (E4) and 2.5 (E2).
- T3: The top two scores are 3.5 (E3) and 1.9 (E2).
- T4: The top two scores are 3.2 (E4) and 1.8 (E1).
Resulting Expert Load:
We tally the tokens assigned to each expert:
- E1 processes: T1, T4
- E2 processes: T2, T3
- E3 processes: T1, T3
- E4 processes: T2, T4
With k = 2, each expert is assigned exactly two tokens, resulting in a balanced load for this specific example. However, token-choice routing can often lead to load imbalance, where some experts receive a disproportionate number of tokens.
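To make the tally above concrete, here is a minimal NumPy sketch of token-choice routing. The eight large scores are the ones from the walkthrough; the remaining entries (the dashes in the table) are placeholder low values I chose only so the matrix is complete.

```python
import numpy as np

# Rows = experts E1..E4, columns = tokens T1..T4.
# The eight large scores come from the example above; the small
# entries are placeholder values standing in for the omitted scores.
S = np.array([
    [3.1, 0.4, 0.9, 1.8],  # E1
    [0.6, 2.5, 1.9, 0.7],  # E2
    [2.2, 0.3, 3.5, 0.5],  # E3
    [0.8, 2.8, 1.0, 3.2],  # E4
])
k = 2

# Token-choice: each token (column) picks its k highest-scoring experts.
token_choice = np.argsort(-S, axis=0)[:k, :]  # shape (k, num_tokens)
for t in range(S.shape[1]):
    experts = sorted((token_choice[:, t] + 1).tolist())
    print(f"T{t+1} -> experts {experts}")
# T1 -> experts [1, 3], T2 -> [2, 4], T3 -> [2, 3], T4 -> [1, 4]
```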
2. Expert Chooses Token
With expert-choice routing, the decision-making is inverted: each expert selects the top k tokens it is best suited to process. This method is designed to enforce a balanced load, as each expert is assigned a fixed number of tokens.
Calculation:
For each expert row, we identify the two tokens (columns) with the highest scores.
- E1: The top two scores are 3.1 (T1) and 1.8 (T4).
- E2: The top two scores are 2.5 (T2) and 1.9 (T3).
- E3: The top two scores are 3.5 (T3) and 2.2 (T1).
- E4: The top two scores are 3.2 (T4) and 2.8 (T2).
Resulting Token Assignment:
- E1 chooses: T1, T4
- E2 chooses: T2, T3
- E3 chooses: T1, T3
- E4 chooses: T2, T4
In this particular case, the outcome is identical to the "Token Chooses Expert" method. Each expert processes two tokens, and each token is processed by two experts. This symmetry is a feature of this specific matrix and not a general rule. In many real-world scenarios, the two methods would produce different assignments.
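Expert-choice is the same top-k operation applied along the other axis of the score matrix. Continuing the NumPy sketch above:

```python
# Expert-choice: each expert (row) picks its k highest-scoring tokens.
expert_choice = np.argsort(-S, axis=1)[:, :k]  # shape (num_experts, k)
for e in range(S.shape[0]):
    tokens = sorted((expert_choice[e] + 1).tolist())
    print(f"E{e+1} -> tokens {tokens}")
# E1 -> tokens [1, 4], E2 -> [2, 3], E3 -> [1, 3], E4 -> [2, 4]
```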
3. Global Routing via Optimization
Global routing seeks the optimal assignment across the entire matrix to maximize the total affinity score. The constraints for this problem with k = 2 are:
- Each token must be assigned to exactly two experts.
- Each expert must process exactly two tokens.
This is a balanced assignment problem that can be solved with algorithms like the Hungarian algorithm or by formulating it as a maximum weight bipartite matching problem.
Calculation:
The goal is to select eight scores from the matrix S such that every row and every column is chosen exactly twice, and the sum of these eight scores is maximized.
We can find this optimal assignment by selecting the eight highest scores in the matrix, as long as they satisfy the constraints. The eight highest scores are: 3.5, 3.2, 3.1, 2.8, 2.5, 2.2, 1.9, and 1.8.
Let's check if this selection meets the constraints:
Selected Pairs: (E3, T3), (E4, T4), (E1, T1), (E4, T2), (E2, T2), (E3, T1), (E2, T3), (E1, T4).
- Token Assignments (2 per token):
  - T1: assigned to E1, E3
  - T2: assigned to E4, E2
  - T3: assigned to E3, E2
  - T4: assigned to E4, E1
- Expert Assignments (2 per expert):
  - E1: assigned T1, T4
  - E2: assigned T2, T3
  - E3: assigned T3, T1
  - E4: assigned T4, T2
The constraints are perfectly met. The resulting assignment is identical to the previous two methods for this specific example. The total maximized score is the sum of these values: 3.5 + 3.2 + 3.1 + 2.8 + 2.5 + 2.2 + 1.9 + 1.8 = 21.0.
While all three methods produced the same outcome here, this highlights a scenario of perfect alignment. In practice, especially with larger and more varied score matrices, these methods would yield different assignments, each with its own trade-offs between computational cost, load balancing, and model performance.
A note on the complexity of global routing
Global routing frames the assignment as an optimization problem: find the assignment of tokens to experts that maximizes the total affinity score, subject to a set of global constraints. For our example with k = 2, the constraints are:
1. Each token must be assigned to exactly two experts.
2. Each expert must be assigned exactly two tokens.
This is a balanced assignment problem, which is computationally more complex than the local token or expert choice methods. While algorithms like the Hungarian algorithm or other linear programming solvers can find the optimal solution, they are often too slow to be practical for routing in LLMs.
In our specific example matrix, it happens that simply selecting the eight highest affinity scores—3.5, 3.2, 3.1, 2.8, 2.5, 2.2, 1.9, and 1.8—coincidentally satisfies these constraints. However, this is not a general strategy and would fail on most other matrices.
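For completeness, here is a small sketch of the global formulation as a linear program, using the same score matrix (with the placeholder low entries) as the earlier snippets. Because the constraint matrix of this transportation-style problem is totally unimodular, an optimal vertex solution is integral (0/1), so an off-the-shelf LP solver recovers the balanced assignment:

```python
import numpy as np
from scipy.optimize import linprog

S = np.array([
    [3.1, 0.4, 0.9, 1.8],
    [0.6, 2.5, 1.9, 0.7],
    [2.2, 0.3, 3.5, 0.5],
    [0.8, 2.8, 1.0, 3.2],
])
n_experts, n_tokens = S.shape
k = 2

# Variables x[i, j] in [0, 1]: route token j to expert i.
# Each expert row must sum to k; each token column must sum to k.
A_eq, b_eq = [], []
for i in range(n_experts):              # expert capacity constraints
    row = np.zeros(S.size)
    row[i * n_tokens:(i + 1) * n_tokens] = 1
    A_eq.append(row); b_eq.append(k)
for j in range(n_tokens):               # token assignment constraints
    col = np.zeros(S.size)
    col[j::n_tokens] = 1
    A_eq.append(col); b_eq.append(k)

# Maximize total affinity = minimize its negation.
res = linprog(-S.flatten(), A_eq=np.array(A_eq), b_eq=np.array(b_eq),
              bounds=(0, 1), method="highs")
x = res.x.reshape(S.shape).round().astype(int)
print(x)                                # 0/1 assignment matrix
print("total score:", (S * x).sum())    # 21.0, matching the hand calculation
```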
Deeper dive into the math of "token chooses expert" strategy
A natural question is how the numbers in the routing-by-hand example are derived. As "token chooses expert" is the most common strategy in modern MoE LLMs, let us work out this strategy by hand with an example, assuming k = 2.
We will be working through the following equations specifically:

1. Output: h_t = Σ_{i=1..N} g_i · FFN_i(u_t) + u_t
2. Gating: g_i = s_i if s_i ∈ Top-K({s_1, ..., s_N}, K), else 0
3. Scores: s_i = Softmax_i(u_t · e_i)
Source: Stanford CS336 Language Modeling from Scratch, Lecture 4 on Mixture of Experts, 24 min 52 sec
Scenario Setup
Let's assume we are at a specific layer l in the model. We will focus on routing a single token, t.

- Number of Experts (N): Let's say we have N = 4 experts.
- Top-K: We want to route our token to the top K = 2 experts.
- Token Representation (u_t): This is the input vector for our token. For simplicity, let's assume it's a 3-dimensional vector; its exact values don't matter here, as we will work directly with the resulting logits.
- Expert Embeddings (e_1, e_2, e_3, e_4): Each expert has a learnable weight vector (embedding) of the same dimension, which is used by the router to calculate affinity.
Step 1: Calculate Raw Affinity Scores (Logits)
The first step is to determine how well our token t aligns with each expert i. This is done by calculating the dot product between the token's vector and each expert's embedding vector. This corresponds to the inner part of the third equation, u_t · e_i.

- Logit for Expert 1: u_t · e_1 = 1.9
- Logit for Expert 2: u_t · e_2 = -0.58
- Logit for Expert 3: u_t · e_3 = 1.89
- Logit for Expert 4: u_t · e_4 = -1.58

Our vector of raw logits is: [1.9, -0.58, 1.89, -1.58]
Step 2: Normalize Scores with Softmax
The next step is to normalize these logits into a probability distribution using the Softmax function. This gives us the scores s_i, which represent the router's confidence for each expert.

Equation: s_i = exp(z_i) / Σ_j exp(z_j), where z is the vector of raw logits.

- Exponentials: exp(1.9) ≈ 6.686, exp(-0.58) ≈ 0.560, exp(1.89) ≈ 6.619, exp(-1.58) ≈ 0.206
- Sum of exponentials: 6.686 + 0.560 + 6.619 + 0.206 ≈ 14.071
- Normalized scores: s = [6.686, 0.560, 6.619, 0.206] / 14.071

Our vector of normalized scores is: [0.475, 0.040, 0.470, 0.015]. Note that these values sum to 1.
Step 3: Select Top-K and Determine Gating Weights
Now we apply the Top-K logic to determine the final gating weights, g_i. The gate value is equal to the score s_i if it's in the Top-K, otherwise it's zero.

Equation: g_i = s_i if s_i ∈ Top-K({s_1, ..., s_N}, K), else 0

- Our scores are [0.475, 0.040, 0.470, 0.015].
- The two highest scores (K = 2) are 0.475 (for Expert 1) and 0.470 (for Expert 3).
- The final gating weights are: g = [0.475, 0, 0.470, 0]
This means only Expert 1 and Expert 3 will be activated for this token.
Step 4: Calculate Final Output
Finally, the output of the MoE layer (h_t) is a weighted sum of the outputs from the selected experts, added to the original token representation (a residual connection).

Equation: h_t = Σ_{i=1..N} g_i · FFN_i(u_t) + u_t

Plugging in our gating values:

h_t = 0.475 · FFN_1(u_t) + 0 · FFN_2(u_t) + 0.470 · FFN_3(u_t) + 0 · FFN_4(u_t) + u_t

This simplifies to:

h_t = 0.475 · FFN_1(u_t) + 0.470 · FFN_3(u_t) + u_t
The final output is the combination of the outputs from the two chosen experts, weighted by their routing scores, plus the original input token vector.
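Here is a minimal NumPy sketch of Steps 1 through 4. It starts from the logits computed in Step 1, since the underlying token and expert vectors are not needed for the gating arithmetic; the expert FFNs and the token vector at the end are random placeholders, just to show the shape of the final combination.

```python
import numpy as np

def softmax(z):
    z = z - z.max()            # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([1.9, -0.58, 1.89, -1.58])   # Step 1: u_t . e_i
K = 2

s = softmax(logits)            # Step 2: [0.475, 0.040, 0.470, 0.015]

g = np.zeros_like(s)           # Step 3: keep only the top-K scores
topk = np.argsort(-s)[:K]      # indices of the K largest scores -> [0, 2]
g[topk] = s[topk]              # g = [0.475, 0, 0.470, 0]

# Step 4: weighted sum of expert outputs plus the residual connection.
# Placeholder experts: one random linear map per expert (illustration only).
rng = np.random.default_rng(0)
d = 3
W = [rng.normal(size=(d, d)) for _ in range(4)]
u_t = rng.normal(size=d)       # placeholder token vector

h_t = sum(g[i] * (W[i] @ u_t) for i in topk) + u_t
print(np.round(s, 3), np.round(g, 3))
```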
A Note on Variants (Mixtral, DBRX)
Models like Mixtral and DBRX use a slightly different approach where they apply Softmax after selecting the Top-K.
Using our example, that would mean:
1. Select Top-K Logits: Take the raw logits [1.9, -0.58, 1.89, -1.58]. The Top-2 are 1.9 (Expert 1) and 1.89 (Expert 3).
2. Apply Softmax to only the Top-K: g_1 = exp(1.9) / (exp(1.9) + exp(1.89)) ≈ 0.502 and g_3 = exp(1.89) / (exp(1.9) + exp(1.89)) ≈ 0.498.
In this variant, the final gating weights for the chosen experts are re-normalized to sum to 1.
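Continuing the sketch above, this variant changes only Step 3:

```python
# Mixtral/DBRX-style: softmax over only the top-K logits,
# instead of masking a softmax over all N logits.
topk = np.argsort(-logits)[:K]        # Experts 1 and 3
g = np.zeros_like(logits)
g[topk] = softmax(logits[topk])       # g = [0.502, 0, 0.498, 0], sums to 1
```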
Heuristic balancing losses
System efficiency requires that experts be used evenly, so that computation is distributed across all of them. In the worst case, where all tokens are routed to a single expert, the model is as slow as a single dense model while carrying significantly higher memory overhead from the unused experts.
To encourage a balanced load, an auxiliary loss is often added to the total model loss during training. The following section illustrates the load balancing loss introduced in the Switch Transformers paper. It is important to note that this specific formulation was designed for a top-1 routing scenario (k = 1), where each token is dispatched to the single expert with the highest score (argmax). We use a top-1 example here to faithfully explain the paper's original equations. See Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.
1. Equation Explanations
Here is a breakdown of each equation:
Equation (4): The Auxiliary Loss
This equation defines the overall auxiliary loss that is added to the model's main training loss (e.g., cross-entropy):

loss = α · N · Σ_{i=1..N} f_i · P_i
- `loss`: The final calculated auxiliary loss value.
- `α` (alpha): A small, constant hyperparameter (the text suggests `10^-2`) used to scale this auxiliary loss. It ensures that load balancing doesn't overwhelm the primary goal of the main loss function.
- `N`: The total number of experts in the MoE layer.
- `f_i`: The fraction of tokens in the batch that are actually dispatched to expert `i`.
- `P_i`: The average fraction of the router's probability (or "confidence") allocated to expert `i` across all tokens in the batch.
Meaning: This loss is essentially a scaled dot product between the vector of token fractions (`f`) and the vector of average router probabilities (`P`). This loss is minimized when the tokens are distributed uniformly, meaning each expert `i` receives `1/N` of the tokens (`f_i = 1/N`) and the router assigns an average probability of `1/N` to each expert (`P_i = 1/N`).
Equation (5): Fraction of Tokens Dispatched
This equation calculates `f_i`, the actual proportion of tokens assigned to each expert:

f_i = (1/T) · Σ_{x ∈ B} 1{argmax p(x) = i}
- `T`: The total number of tokens in the batch `B`.
- `p(x)`: The output of the router for a single token `x`. It's a probability vector of size `N`, where each element is the probability of sending the token to the corresponding expert.
- `argmax p(x)`: This finds the index of the expert that received the highest probability for token `x`.
- `1{...}`: This is an indicator function. It returns `1` if the condition inside is true (i.e., if expert `i` was chosen for token `x`) and `0` otherwise.
Meaning: This formula counts how many tokens in the batch are assigned to expert `i` (based on the highest router probability) and then divides by the total number of tokens to get the fraction. As noted in the text, the use of `argmax` makes this function non-differentiable, which has implications for the backward pass.
Equation (6): Fraction of Router Probability
This equation calculates `P_i`, the average router probability assigned to each expert:

P_i = (1/T) · Σ_{x ∈ B} p_i(x)
- `p_i(x)`: The specific probability assigned to expert `i` for token `x`. This is the `i`-th element of the vector `p(x)`.
Meaning: This formula calculates the average "vote" or probability mass the router gives to expert `i` across all tokens in the batch. Unlike `f_i`, this calculation is based on the soft probabilities from the router and is fully differentiable.
2. Illustrative Example
Let's use a simple example to make these equations concrete.
- Number of experts `N`: `3`
- Number of tokens in batch `T`: `4`
- Hyperparameter `α`: `0.01`
Assume our router network processes the 4 tokens and outputs the following probability distributions (p(x)
). Each row corresponds to a token, and each column corresponds to an expert.
|         | Expert 1 | Expert 2 | Expert 3 |
|---------|----------|----------|----------|
| Token 1 | 0.7      | 0.2      | 0.1      |
| Token 2 | 0.1      | 0.8      | 0.1      |
| Token 3 | 0.6      | 0.3      | 0.1      |
| Token 4 | 0.2      | 0.2      | 0.6      |
This matrix represents the `p_i(x)` values.
3. Forward Pass Demonstration
Using our example matrix, we will now calculate the auxiliary loss.
Step 1: Calculate `f` (Fraction of Tokens)
We apply Equation (5) by finding the `argmax` for each token (row):
- Token 1: `argmax([0.7, 0.2, 0.1])` = Expert 1
- Token 2: `argmax([0.1, 0.8, 0.1])` = Expert 2
- Token 3: `argmax([0.6, 0.3, 0.1])` = Expert 1
- Token 4: `argmax([0.2, 0.2, 0.6])` = Expert 3
Now, we count the assignments for each expert:
- Expert 1 was chosen 2 times.
- Expert 2 was chosen 1 time.
- Expert 3 was chosen 1 time.
Finally, we calculate the fractions (`f_i = count / T`):

- `f_1 = 2 / 4 = 0.5`
- `f_2 = 1 / 4 = 0.25`
- `f_3 = 1 / 4 = 0.25`

So, our `f` vector is `[0.5, 0.25, 0.25]`.
Step 2: Calculate `P` (Fraction of Probability)
We apply Equation (6) by averaging the probabilities in each column of our matrix:

- `P_1 = (0.7 + 0.1 + 0.6 + 0.2) / 4 = 0.4`
- `P_2 = (0.2 + 0.8 + 0.3 + 0.2) / 4 = 0.375`
- `P_3 = (0.1 + 0.1 + 0.1 + 0.6) / 4 = 0.225`

So, our `P` vector is `[0.4, 0.375, 0.225]`.
Step 3: Calculate the Final Auxiliary Loss
We use Equation (4) with the `f` and `P` vectors we just calculated:

loss = α · N · (f_1·P_1 + f_2·P_2 + f_3·P_3) = 0.01 · 3 · (0.5·0.4 + 0.25·0.375 + 0.25·0.225) = 0.01 · 3 · 0.35 = 0.0105

This value, `0.0105`, is the auxiliary loss for this batch, which will be added to the main model loss before the backward pass.
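The whole forward pass fits in a few lines of NumPy; here is a minimal sketch reproducing the numbers above:

```python
import numpy as np

alpha, N, T = 0.01, 3, 4

# Router probabilities p(x): rows = tokens, columns = experts.
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.6],
])

# Equation (5): fraction of tokens dispatched to each expert (argmax routing).
assignments = probs.argmax(axis=1)              # [0, 1, 0, 2]
f = np.bincount(assignments, minlength=N) / T   # [0.5, 0.25, 0.25]

# Equation (6): average router probability per expert.
P = probs.mean(axis=0)                          # [0.4, 0.375, 0.225]

# Equation (4): the auxiliary load-balancing loss.
loss = alpha * N * np.dot(f, P)
print(f, P, loss)                               # loss = 0.0105
```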
4. Backward Pass Demonstration
During backpropagation, we need to calculate the gradient of the loss with respect to the model's parameters, specifically the weights of the router network. The gradient flows backward from the loss through the calculations we performed.
A key detail from the text is that the `f` vector is not differentiable due to the `argmax` function. In practice, this is handled by treating `f` as a constant during the backward pass, similar in spirit to a straight-through estimator (STE). As a result, the gradient signal only flows back through the differentiable part of the loss equation, which is the `P_i` term (the average router probability).
Step 1: Gradient of the Loss w.r.t. P_i
We start by finding the derivative of the loss function with respect to each element of the `P` vector. Since `f` is treated as a constant, the derivative is:

∂loss / ∂P_i = α · N · f_i

Let's calculate this for our example:

- ∂loss / ∂P_1 = 0.01 · 3 · 0.5 = 0.015
- ∂loss / ∂P_2 = 0.01 · 3 · 0.25 = 0.0075
- ∂loss / ∂P_3 = 0.01 · 3 · 0.25 = 0.0075
Step 2: Gradient of `P_i` w.r.t. `p_i(x)`
Next, we need the gradient of `P_i` (the average probability) with respect to the individual router output `p_i(x)` for a specific token `x`. Since `P_i` is an average over the `T` tokens in the batch:

∂P_i / ∂p_i(x) = 1 / T

For our example, this is `1 / 4 = 0.25`.
Step 3: Combine Gradients (Chain Rule)
Using the chain rule, we find the gradient of the loss with respect to each individual router probability `p_i(x)`:

∂loss / ∂p_i(x) = (∂loss / ∂P_i) · (∂P_i / ∂p_i(x)) = α · N · f_i · (1/T)

This gradient value is the same for all tokens `x` in the batch. Let's calculate it:

- ∂loss / ∂p_1(x) = 0.015 · 0.25 = 0.00375
- ∂loss / ∂p_2(x) = 0.0075 · 0.25 = 0.001875
- ∂loss / ∂p_3(x) = 0.0075 · 0.25 = 0.001875

This gives us a gradient vector `[0.00375, 0.001875, 0.001875]`. This gradient is then backpropagated to the router's softmax layer for every token. The optimizer will use this gradient to update the router's weights, encouraging it to produce probabilities that lead to a more balanced load in the next iteration. In this case, since expert 1 was over-utilized (`f_1` was high), it receives the largest gradient signal, which will push the router to assign it lower probabilities in the future.