Topics

In BitNet b1.58, we quantize the activations to 8 bits (per-token absmax quantization) to reduce memory consumption and improve speed. A concise implementation in PyTorch looks something like:

def activation_quant(x):
    # Per-token absmax quantization to 8 bits: scale each row so its largest
    # absolute value maps to 127, round, then rescale back to the input range.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y
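
As a quick illustration (a toy example, not from the paper), the quantizer returns a tensor with the same shape and dtype as its input, with each row snapped onto its own 8-bit grid:

import torch

x = torch.randn(2, 4)
x_q = activation_quant(x)
print(x_q.shape)              # torch.Size([2, 4]) -- same shape as the input
print((x - x_q).abs().max())  # small, bounded by half a quantization step per row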

Note that because of the round(), this operation is non-differentiable, so we use the detach trick to implement a straight-through estimator (STE).

STE: in the backward pass, the gradient is passed through this layer unchanged, as if the quantization function were the identity.
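
A minimal sketch of the trick in isolation (plain torch.round here, not part of BitNet, just to show the mechanics): the forward value is the rounded one, but the gradient is that of the identity.

import torch

x = torch.tensor([0.3, 1.7], requires_grad=True)
y = x + (x.round() - x).detach()  # forward: rounded values; backward: identity
y.sum().backward()
print(y)       # tensor([0., 2.], grad_fn=<AddBackward0>)
print(x.grad)  # tensor([1., 1.]) -- round() did not block the gradient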

class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight
        # STE via detach(): the forward pass sees activation_quant(x),
        # while the backward pass sees the identity on x
        x_quant = x + (activation_quant(x) - x).detach()

        ...
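
To make the picture complete, here is a self-contained sketch of how the forward pass can continue: the same detach trick applied to a ternary weight quantizer, followed by a standard F.linear call. The weight_quant below (per-tensor absmean scaling to {-1, 0, +1}) follows my reading of the BitNet b1.58 recipe; treat it as an illustration rather than the exact upstream code, which also normalizes x (RMSNorm) before quantizing.

import torch
import torch.nn as nn
import torch.nn.functional as F


def weight_quant(w):
    # Per-tensor absmean scaling, then round to the ternary set {-1, 0, +1}
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u


class BitLinear(nn.Linear):
    """
    Training-only sketch using fake quantization with STE.
    """
    def forward(self, x):
        w = self.weight
        # Same STE/detach trick for both activations and weights
        x_quant = x + (activation_quant(x) - x).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        return F.linear(x_quant, w_quant, self.bias)

In practice such a layer is swapped in for nn.Linear inside the Transformer blocks; the quantization is "fake" in the sense that everything stays in floating point, so standard optimizers work unchanged.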