Topics
Masked attention is a modification of the standard attention mechanism used primarily in sequence generation tasks. It ensures that when computing the output for a position in a sequence, the model can attend only to that position and the positions before it, never to future positions.
This is crucial for autoregressive models where each output token depends only on the preceding tokens. Without masking, the model could “cheat” by looking at the target output sequence it is trying to predict.
Implementation involves adding a mask matrix to the raw attention scores (often computed as the scaled dot product of the Query and Key matrices, $\frac{QK^\top}{\sqrt{d_k}}$) before applying the softmax function.
- The mask matrix typically has values of zero in positions corresponding to allowed connections (past and current positions)
- It has very large negative values (like $-\infty$) in positions corresponding to disallowed connections (future positions)
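A minimal sketch of building such a mask with NumPy (the function name `causal_mask` is illustrative):

```python
import numpy as np

def causal_mask(seq_len: int) -> np.ndarray:
    """Return a (seq_len, seq_len) mask: 0 for allowed connections
    (past and current positions), -inf for future positions."""
    # Fill the whole matrix with -inf, then keep only the strict
    # upper triangle (k=1), zeroing the diagonal and below.
    return np.triu(np.full((seq_len, seq_len), -np.inf), k=1)

print(causal_mask(3))
```

Row $i$ of the result is added to the attention scores for output position $i$, so every column $j > i$ carries $-\infty$.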
When the softmax function is applied after adding the mask, the large negative values result in attention weights close to zero for the masked positions. This effectively prevents the model from attending to those future elements.
$$\text{MaskedAttention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$
where the mask $M$ has $-\infty$ values for future positions. For sequence length 3, the mask is:
$$M = \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{bmatrix}$$
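The full computation can be sketched in NumPy as follows (random $Q$ and $K$ stand in for real projections; `seq_len` and `d_k` are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_k = 3, 4
Q = rng.standard_normal((seq_len, d_k))  # query vectors, one per position
K = rng.standard_normal((seq_len, d_k))  # key vectors, one per position

# Scaled dot-product scores, then add -inf above the diagonal.
scores = Q @ K.T / np.sqrt(d_k)
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
masked_scores = scores + mask

# Row-wise softmax: exp(-inf) = 0, so future positions get zero weight.
weights = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(weights)  # upper triangle is exactly 0; each row sums to 1
```

Note that position 0 can attend only to itself, so its entire weight (1.0) lands on the first column.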
This mechanism is a core component in the decoder part of the transformer architecture, enabling it to generate sequences one element at a time based on previously generated output.
Note
Masked attention is often referred to as “causal attention” because it enforces causality: the output at time step $t$ depends only on inputs from time steps $\le t$.