Topics
Multi-Head Attention (MHA) is a core component of the transformer model, applied in both self-attention and cross-attention. Instead of performing a single attention operation over the full-dimensional Q, K, and V, MHA splits the computation into multiple parallel “heads”.
Mechanism:
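First, the input Q, K, and V matrices (each of feature dimension $d_{\text{model}}$) are linearly projected $h$ times using different learned weight matrices:
$$Q_i = Q W_i^Q, \quad K_i = K W_i^K, \quad V_i = V W_i^V$$
for each head $i = 1, \dots, h$, where $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$. This results in $h$ sets of projected queries ($Q_i$), keys ($K_i$), and values ($V_i$). The dimensions $d_k$ and $d_v$ are typically set to $d_{\text{model}} / h$.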
Note
The key and query dimensions must match (both $d_k$) so that the dot product between queries and keys is defined. The value dimension $d_v$ may differ.
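As a worked example of these dimensions, assume the base configuration from the original Transformer paper, $d_{\text{model}} = 512$ and $h = 8$ (an assumed setting for illustration, not one stated above):
$$d_k = d_v = \frac{d_{\text{model}}}{h} = \frac{512}{8} = 64, \qquad h \cdot d_v = 8 \cdot 64 = 512 = d_{\text{model}}$$
Each per-head projection matrix then holds $512 \times 64 = 32{,}768$ weights, so the $h$ query projections together hold $8 \times 32{,}768 = 262{,}144$ weights, the same as a single $512 \times 512$ projection. Under this choice, splitting into heads does not increase the projection parameter count.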
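Next, the scaled dot-product attention function is applied independently and in parallel to each projected set:
$$\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i$$
Each $\text{head}_i$ captures attention information from a different projected subspace. Finally, the outputs of the $h$ heads, $\text{head}_1, \dots, \text{head}_h$ (each of dimension $d_v$), are concatenated along the feature dimension. This concatenated output (dimension $h \cdot d_v$) is then passed through a final linear projection layer with a weight matrix $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ to produce the final MHA output:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O$$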
This final projection integrates the information learned across the different heads and ensures the output dimension matches the model’s expected dimension.
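The mechanism described above (per-head projection, scaled dot-product attention, concatenation, output projection) can be sketched in NumPy as follows. This is a minimal, unbatched illustration rather than a reference implementation: the function names, the stacked weight layout `(h, d_model, d_k)`, and the random weights are assumptions made for the example, and masking, biases, and dropout are omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    # q: (n_q, d_k), k: (n_k, d_k), v: (n_k, d_v)
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)        # (n_q, n_k)
    weights = softmax(scores, axis=-1)     # each row sums to 1
    return weights @ v                     # (n_q, d_v)

def multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o):
    # Assumed layout: W_q, W_k are (h, d_model, d_k); W_v is (h, d_model, d_v);
    # W_o is (h * d_v, d_model).
    heads = []
    for i in range(W_q.shape[0]):
        Q_i, K_i, V_i = Q @ W_q[i], K @ W_k[i], V @ W_v[i]
        heads.append(scaled_dot_product_attention(Q_i, K_i, V_i))
    concat = np.concatenate(heads, axis=-1)  # (n_q, h * d_v)
    return concat @ W_o                      # (n_q, d_model)

# Illustrative sizes and random weights (hypothetical values).
rng = np.random.default_rng(0)
d_model, h, n = 16, 4, 5
d_k = d_v = d_model // h
W_q = rng.normal(size=(h, d_model, d_k))
W_k = rng.normal(size=(h, d_model, d_k))
W_v = rng.normal(size=(h, d_model, d_v))
W_o = rng.normal(size=(h * d_v, d_model))

X = rng.normal(size=(n, d_model))
out = multi_head_attention(X, X, X, W_q, W_k, W_v, W_o)  # self-attention
print(out.shape)  # (5, 16)
```

Passing the same `X` as Q, K, and V gives self-attention; for cross-attention, Q would come from one sequence and K, V from another, and the output keeps the query sequence length. Practical implementations usually fuse the per-head loop into a single batched matrix multiplication, which is what makes the heads cheap to compute in parallel.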
Pros:
- Allows the model to jointly attend to information from different representation subspaces at different positions
- Multiple heads can specialize in capturing different dependency types (local vs global, syntactic vs semantic)
- Facilitates parallel computation across heads
Cons:
- Head redundancy (“attention collapse”): multiple heads may learn similar functions or patterns
- Not all heads contribute equally; some might be prunable
- True specialization across all heads is not guaranteed
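One simple way to look for the redundancy mentioned above is to compare the attention maps of different heads. The sketch below computes pairwise cosine similarity between flattened per-head attention matrices. The function name `head_similarity` and the synthetic attention weights are hypothetical, and this is only one possible diagnostic, not a method prescribed by the text.

```python
import numpy as np

def head_similarity(attn):
    # attn: (h, n_q, n_k) attention weights, one matrix per head.
    flat = attn.reshape(attn.shape[0], -1)
    flat = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    # (h, h) cosine similarities; values near 1 suggest redundant heads.
    return flat @ flat.T

# Synthetic attention weights for illustration (hypothetical data).
rng = np.random.default_rng(0)
h, n = 4, 6
logits = rng.normal(size=(h, n, n))
logits[1] = logits[0]                      # make head 1 a copy of head 0
attn = np.exp(logits)
attn /= attn.sum(axis=-1, keepdims=True)   # row-wise softmax

sim = head_similarity(attn)
print(np.round(sim, 2))
```

In this toy setup the entry for heads 0 and 1 is 1 because one is a copy of the other. On a trained model, consistently high off-diagonal values would hint that some heads are candidates for pruning, though similar attention maps alone do not prove that the heads' outputs are interchangeable.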
Note
Research explores methods for encouraging diversity among heads, e.g., regularization or dynamic head selection such as Mixture-of-Head attention.