
Causal linear attention can be expressed as an RNN. This applies when a token at position $t$ only attends to tokens at positions $s \le t$.

This formulation, shown by Katharopoulos et al. (2020), defines the state at time $t$ using an accumulated key-value (context) summary $S_t$ and an accumulated key summary $z_t$:

$$S_t = \sum_{s=1}^{t} \phi(k_s)\, v_s^\top, \qquad z_t = \sum_{s=1}^{t} \phi(k_s)$$
For more details on the notation, refer to the linear attention math formulation. A cool observation is that these states update recurrently:

$$S_t = S_{t-1} + \phi(k_t)\, v_t^\top, \qquad z_t = z_{t-1} + \phi(k_t)$$
The output at time $t$ is computed using the current query $q_t$ and the updated state $(S_t, z_t)$:

$$y_t = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}$$
Standard causal attention with KV-caching (storing all previous keys and values) needs $O(Td)$ memory for a sequence of length $T$. Causal linear attention in the RNN formulation only needs $O(d^2)$ memory (specifically $O(d \times d)$ for $S_t$ and $O(d)$ for $z_t$ at each step during inference) for the fixed-size state $(S_t, z_t)$, making it efficient for long sequences.
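
To make the recurrence concrete, here is a minimal single-head, unbatched sketch in PyTorch. It assumes the $\phi(x) = \mathrm{elu}(x) + 1$ feature map from Katharopoulos et al. (2020); the function names and the small `eps` for numerical stability are illustrative choices, not from any particular library:

```python
import torch
import torch.nn.functional as F

def elu_feature_map(x):
    # phi(x) = elu(x) + 1: a simple positive feature map (Katharopoulos et al., 2020)
    return F.elu(x) + 1.0

def linear_attention_step(q_t, k_t, v_t, S, z, eps=1e-6):
    """One recurrent inference step for causal linear attention.

    q_t, k_t, v_t: (d,) tensors for the current token.
    S: (d, d) accumulated key-value summary; z: (d,) accumulated key summary.
    Returns the output y_t and the updated state (S, z).
    """
    phi_q, phi_k = elu_feature_map(q_t), elu_feature_map(k_t)
    S = S + torch.outer(phi_k, v_t)        # S_t = S_{t-1} + phi(k_t) v_t^T
    z = z + phi_k                          # z_t = z_{t-1} + phi(k_t)
    y_t = (phi_q @ S) / (phi_q @ z + eps)  # y_t = phi(q_t)^T S_t / (phi(q_t)^T z_t)
    return y_t, S, z

# Token-by-token decoding: the state stays (d, d) + (d,) regardless of sequence length.
d, T = 64, 128
S, z = torch.zeros(d, d), torch.zeros(d)
q, k, v = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)
outputs = [None] * T
for t in range(T):
    outputs[t], S, z = linear_attention_step(q[t], k[t], v[t], S, z)
```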

Pros:

  • $O(1)$ time and $O(1)$ space complexity per step during inference (with respect to sequence length)
  • Good for long sequences

Cons:

  • The sequential nature of the RNN formulation hinders parallelization during training
  • The kernel feature map $\phi$ only approximates the softmax kernel, so model quality typically trails standard attention
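
For reference, the feature map used in Katharopoulos et al. (2020) is the simple positive map

$$\phi(x) = \operatorname{elu}(x) + 1,$$

which keeps the attention weights positive but does not reproduce the exponential (softmax) kernel of standard attention.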

Trends:

  • To mitigate the training slowdown, use chunkwise processing: the input sequence is divided into non-overlapping chunks. Within each chunk, computations are parallelized; the recurrent state update is performed sequentially between chunks. This strikes a nice balance between parallelization and the sequential dependency required for causal modeling across chunks (see the sketch after this list)
  • Hardware-aware implementations of linear attention using tools like Triton
  • Gaining traction through libraries like flash-linear-attention
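
A hedged sketch of the chunkwise idea, again in plain PyTorch with the elu + 1 feature map (function names and the fixed chunk size are assumptions for illustration; libraries such as flash-linear-attention provide fused, hardware-aware kernels for this pattern):

```python
import torch
import torch.nn.functional as F

def elu_feature_map(x):
    return F.elu(x) + 1.0

def chunkwise_linear_attention(q, k, v, chunk_size=64, eps=1e-6):
    """Chunkwise-parallel causal linear attention (single head, unbatched).

    q, k, v: (T, d) tensors; T is assumed divisible by chunk_size for brevity.
    Work inside a chunk is dense matmuls (parallel); only the (d, d) state S
    and the (d,) normalizer z are carried sequentially across chunks.
    """
    T, d = q.shape
    phi_q, phi_k = elu_feature_map(q), elu_feature_map(k)
    S, z = torch.zeros(d, d), torch.zeros(d)
    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size))
    outputs = []
    for start in range(0, T, chunk_size):
        Qc = phi_q[start:start + chunk_size]   # (C, d)
        Kc = phi_k[start:start + chunk_size]   # (C, d)
        Vc = v[start:start + chunk_size]       # (C, d)
        # Intra-chunk term: causal attention inside the chunk, fully parallel.
        A = (Qc @ Kc.T) * causal_mask          # (C, C)
        num = A @ Vc + Qc @ S                  # inter-chunk term uses the carried state
        den = A.sum(dim=-1) + Qc @ z + eps     # matching normalizer
        outputs.append(num / den.unsqueeze(-1))
        # Sequential state update between chunks.
        S = S + Kc.T @ Vc
        z = z + Kc.sum(dim=0)
    return torch.cat(outputs, dim=0)

# Usage: produces the same outputs as the token-by-token recurrence above.
out = chunkwise_linear_attention(torch.randn(256, 64), torch.randn(256, 64), torch.randn(256, 64))
```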