Sparse attention methods modify the standard self-attention mechanism to improve computational efficiency. Standard self-attention has quadratic complexity, $O(n^2)$, with respect to the sequence length $n$. Sparse attention aims to reduce this cost.

The core idea is to restrict each query token to attend to only a subset of the key tokens, rather than all keys. If each query attends to a fixed, small number of keys $k$, the complexity can be reduced to $O(nk)$. This is much more efficient when $k \ll n$. The motivation often comes from observing that in vanilla attention, many query-key interactions have very small weights and contribute little to the final output; these can potentially be ignored. There are various techniques for achieving sparsity.
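
As a concrete illustration, here is a minimal sketch (assuming PyTorch; the single-head, unbatched setting and all names and shapes are illustrative) of attention where each query is restricted to a precomputed set of key indices, so the cost scales with the number of allowed keys per query rather than with $n^2$:

```python
import torch
import torch.nn.functional as F

def sparse_attention(q, k, v, key_idx):
    """q, k, v: (n, d); key_idx: (n, num_keys) long tensor of allowed key positions per query."""
    d = q.shape[-1]
    k_sel = k[key_idx]                                         # (n, num_keys, d): gather only the allowed keys
    v_sel = v[key_idx]                                         # (n, num_keys, d)
    scores = torch.einsum("nd,nkd->nk", q, k_sel) / d ** 0.5   # (n, num_keys), cost ~ O(n * num_keys * d)
    weights = F.softmax(scores, dim=-1)
    return torch.einsum("nk,nkd->nd", weights, v_sel)          # (n, d)

n, d, num_keys = 1024, 64, 32
q, k, v = (torch.randn(n, d) for _ in range(3))
key_idx = torch.randint(0, n, (n, num_keys))   # placeholder pattern; the methods below differ in how this is chosen
out = sparse_attention(q, k, v, key_idx)
```

In practice the speedup depends on hardware-friendly layouts and custom kernels; the gather-based formulation here is only meant to make the $O(nk)$ cost explicit.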

Fixed Patterns

These methods use predefined, static patterns to determine which keys each query attends to. Common examples include sliding-window (local) attention, strided or dilated patterns, block-local attention, and a small set of designated global tokens; a sketch of a sliding-window mask follows.
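
For instance, a sliding-window (local) pattern can be expressed as a static boolean mask. The sketch below (PyTorch, illustrative names) builds one in which each query may attend only to keys within a radius `w` of its own position:

```python
import torch

def sliding_window_mask(n, w):
    # True where |i - j| <= w, i.e. query i may attend to key j.
    pos = torch.arange(n)
    return (pos[None, :] - pos[:, None]).abs() <= w   # (n, n) boolean mask

mask = sliding_window_mask(n=8, w=1)
# A dense implementation would apply it as scores.masked_fill(~mask, float("-inf"))
# before the softmax; an efficient kernel skips the masked entries entirely.
```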

Content-Based / Learnable Patterns

Sparsity patterns are determined dynamically based on the input sequence content.

  • Locality-Sensitive Hashing: Uses hashing to group similar queries and keys into buckets, restricting the attention computation to within each bucket; employed by Reformer
  • Clustering: Models like Routing Transformer use online k-means clustering of token representations, so queries attend only to keys in the same cluster. Another strategy is to group queries and compute attention only with cluster centroids (e.g., Clustered Attention)
  • Learnable Positions: The model learns which positions are most relevant to attend to (e.g., A2-Former, Sparsefinder)
  • Top-k Selection: Selects a constant number of the highest-scoring keys for each query (e.g., SparseK Attention). Requires differentiable top-k operators for training; see the sketch after this list
  • Adaptive Patterns: Dynamically adjust the sparsity pattern and/or computational budget based on the input characteristics or the needs of individual attention heads (e.g., FlexPrefill, Mixture of Attention (MoA))

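A minimal sketch of the top-k idea from the list above (generic top-k gating, not the SparseK algorithm itself; a hard top-k is shown, whereas training would typically use a differentiable relaxation):

```python
import torch
import torch.nn.functional as F

def topk_attention(q, k, v, top_k):
    d = q.shape[-1]
    scores = q @ k.T / d ** 0.5                          # (n, n): note the scoring step is still dense here
    idx = scores.topk(top_k, dim=-1).indices             # keep only the top_k highest-scoring keys per query
    weights = F.softmax(scores.gather(-1, idx), dim=-1)  # (n, top_k)
    return torch.einsum("nk,nkd->nd", weights, v[idx])   # (n, d)

n, d = 256, 64
q, k, v = (torch.randn(n, d) for _ in range(3))
out = topk_attention(q, k, v, top_k=16)
```

The scoring step in this sketch is still quadratic; methods in this family typically also cheapen or approximate the selection itself, which is where the real designs differ.
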
Fixed patterns are simpler but potentially suboptimal. Learnable or adaptive patterns offer more flexibility but increase model complexity and may require more sophisticated training procedures.

Random Patterns

Each query attends to a randomly selected set of keys. This helps ensure some global context while maintaining sparsity. It is used as a component in BigBird to facilitate information flow across the sequence.
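
A minimal sketch of a random pattern (illustrative only; BigBird combines random keys with block-local and global attention rather than using them alone):

```python
import torch

def random_key_indices(n, r, generator=None):
    # For each of the n queries, sample r distinct key positions uniformly at random.
    return torch.stack([torch.randperm(n, generator=generator)[:r] for _ in range(n)])

key_idx = random_key_indices(n=1024, r=8)   # pluggable into the gather-based sketch above
```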

Hierarchical Methods

Tree-Based View: A tree structure is imposed on the input tokens, and attention is restricted according to parent-child-sibling relationships (e.g., Sparse Transformer, Sequoia). Natively Trainable Sparse Attention (NSA) uses hierarchical token modeling with compressed coarse-grained tokens, fine-grained tokens, and sliding windows.

Graph-Based View: Treats tokens as nodes and the attention mask as the edge set, applying graph algorithms for efficient computation; see the sketch below.
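
As a sketch of the graph view (PyTorch, illustrative), the boolean attention mask is the adjacency matrix of a directed graph over token positions; storing it in a sparse format such as CSR exposes the per-query neighbour lists that graph or sparse-matrix routines operate on:

```python
import torch

n, w = 8, 1
pos = torch.arange(n)
mask = (pos[None, :] - pos[:, None]).abs() <= w   # any sparse pattern works; a sliding window is used here
adj = mask.to(torch.float32).to_sparse_csr()      # edges = allowed query -> key pairs
print(adj.crow_indices(), adj.col_indices())      # row pointers and neighbour lists per query node
```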

The sparsity techniques discussed above can also be combined, for example:

  • BigBird’s block + random + global
  • Strided + block patterns: full attention within each block, while across blocks each query attends only to every $s$-th block or token (see the sketch after this list)
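
A minimal sketch of one such combination, the strided + block pattern above (PyTorch; block size and stride are arbitrary illustrative choices, and the "every $s$-th block" rule is read here as making every $s$-th block visible to all queries):

```python
import torch

def block_strided_mask(n, block_size, stride):
    blk = torch.arange(n) // block_size            # block id of every position
    same_block = blk[:, None] == blk[None, :]      # full attention within each block
    strided = (blk[None, :] % stride) == 0         # plus every stride-th block is attendable from anywhere
    return same_block | strided                    # (n, n) boolean mask

mask = block_strided_mask(n=16, block_size=4, stride=2)
```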