A vanilla approach to penalizing imbalanced routing in MoEs.
This auxiliary loss encourages all experts to receive a roughly equal share of tokens during training.
import torch

def compute_aux_loss(router_probs, num_experts):
    # Average routing probability per expert across all tokens:
    # [num_tokens, num_experts] -> [num_experts]
    expert_usage = router_probs.mean(dim=0)
    # Penalize the squared deviation from the ideal uniform distribution,
    # in which every expert receives 1/num_experts of the routing mass.
    uniform_prob = 1.0 / num_experts
    balance_loss = torch.sum((expert_usage - uniform_prob) ** 2)
    return balance_loss
In the sample code above, router_probs is the [num_tokens, num_experts] tensor of per-token routing probabilities produced by the router.
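As a minimal sketch of how this fits together (the tensor sizes and the linear-plus-softmax router here are illustrative assumptions, not any specific model's API):

import torch
import torch.nn.functional as F

num_tokens, hidden_dim, num_experts = 8, 16, 4

# Hypothetical router: a linear layer followed by a softmax over experts.
router = torch.nn.Linear(hidden_dim, num_experts)
tokens = torch.randn(num_tokens, hidden_dim)

router_logits = router(tokens)                   # [num_tokens, num_experts]
router_probs = F.softmax(router_logits, dim=-1)  # each row sums to 1

aux_loss = compute_aux_loss(router_probs, num_experts)
print(aux_loss)  # scalar tensor; zero only when usage is exactly uniform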
Note: the cross-entropy task loss pushes the router to send tokens to the experts that make the most accurate predictions, while the auxiliary loss tries to even out token routing across experts. Their relative influence can be controlled with weighting factors.
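A minimal sketch of that combination, where alpha is an assumed weighting hyperparameter (the right value is model-specific) and the model outputs are dummy stand-ins:

import torch
import torch.nn.functional as F

num_tokens, vocab_size, num_experts = 8, 100, 4

# Dummy stand-ins for a real model's outputs (illustrative only).
model_logits = torch.randn(num_tokens, vocab_size, requires_grad=True)
targets = torch.randint(0, vocab_size, (num_tokens,))
router_probs = F.softmax(
    torch.randn(num_tokens, num_experts, requires_grad=True), dim=-1
)

alpha = 0.01  # assumed weighting factor for the auxiliary loss

ce_loss = F.cross_entropy(model_logits, targets)
aux_loss = compute_aux_loss(router_probs, num_experts)

total_loss = ce_loss + alpha * aux_loss  # gradients flow to both objectives
total_loss.backward()

Larger alpha values balance expert usage more aggressively at the cost of routing accuracy; smaller values let the task loss dominate.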