Topics

Router z-loss, introduced in the ST-MoE paper (a follow-up to Switch Transformer), significantly improves the training stability of MoE models without degrading quality. It penalizes large router logits, i.e., the raw output scores of the gating network (the router) before the softmax is applied.

Because this loss pushes the logits toward smaller absolute values, roundoff errors are reduced. This matters a great deal for exponential functions such as the softmax used in routing.

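import torch

# gate_logits: raw router outputs, with the experts along the last dimension
router_z_loss = torch.logsumexp(gate_logits, dim=-1)  # one value per token
router_z_loss = torch.square(router_z_loss)
router_z_loss = router_z_loss.mean()  # average over all tokens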

Here gate_logits refers to the raw router logits, i.e., the scores fed into the gating softmax.
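
Written as a formula, the three lines above compute the following. The symbols are chosen here for illustration and do not come from the snippet: B is the number of tokens, N the number of experts, and x^{(i)}_j the logit of token i for expert j.

```latex
\[
L_z \;=\; \frac{1}{B} \sum_{i=1}^{B} \left( \log \sum_{j=1}^{N} e^{x^{(i)}_j} \right)^{2}
\]
```

The inner log-sum-exp corresponds to torch.logsumexp over the last dimension, the square to torch.square, and the outer average to .mean().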

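The implementation relies on the log-sum-exp trick: torch.logsumexp evaluates log(sum(exp(x))) in a numerically stable way, so large logits do not overflow when exponentiated.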
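
To make the trick concrete, here is a minimal pure-Python sketch (no torch) of the identity log(sum(exp(x_j))) = m + log(sum(exp(x_j - m))) with m = max(x). The function names and the list-of-lists layout for the logits are assumptions made for this illustration. It shows the idea behind a stable log-sum-exp, not the internals of torch.logsumexp.

```python
import math


def naive_logsumexp(xs):
    # Direct evaluation: math.exp overflows once an input is large enough.
    return math.log(sum(math.exp(x) for x in xs))


def stable_logsumexp(xs):
    # Log-sum-exp trick: subtract the maximum so every exponent is <= 0.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def router_z_loss(gate_logits):
    # gate_logits: one list of per-expert logits per token (hypothetical layout).
    per_token = [stable_logsumexp(row) ** 2 for row in gate_logits]
    return sum(per_token) / len(per_token)


small = [[0.1, -0.2, 0.3], [0.0, 0.0, 0.0]]
large = [[1000.0, 999.0, 998.0]]

print(router_z_loss(small))  # small logits, small penalty
print(router_z_loss(large))  # large logits, much larger penalty

try:
    naive_logsumexp(large[0])
except OverflowError:
    print("naive version overflows on large logits")
```

Both versions agree on small inputs. Only the stable one survives logits around 1000, which is the regime the z-loss is meant to discourage in the first place.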