Attention variants — MQA, GQA, FlashAttention, linear & sparse

Standard scaled dot-product attention is the workhorse of every transformer, but it has two ugly cost terms: O(n^2 · d) compute and an O(n · h · d) KV cache at inference. A whole zoo of variants — MQA, GQA, FlashAttention, linear attention, sliding-window and sparse attention — exists to attack one or both of those terms without breaking the inductive bias too badly.

TL;DR. Vanilla attention costs O(n^2 · d) compute and materialises an n x n score matrix. FlashAttention keeps the exact same math but tiles the computation so the score matrix never hits HBM, making long-context training 2-4x faster with no quality loss. MQA / GQA share K and V across heads to shrink the KV cache during autoregressive decoding — Llama-2 70B uses GQA-8. Linear attention rewrites softmax(QK^T)V as φ(Q)(φ(K)^T V) for O(n · d^2) cost; sliding-window and sparse attention drop most of the score matrix entries. In IOAI/USAAIO tasks you meet these whenever a long document, a long audio sequence, or a memory-tight inference budget shows up.

1. The intuition

Standard self-attention computes a similarity score between every pair of tokens. For a sequence of length n that's n^2 scores, each costing d multiplications, so compute is O(n^2 · d) and memory is O(n^2). At n = 8k the score matrix alone is 64M entries per head per layer; at n = 32k it is about 1G entries (2 GB per head in fp16), which no longer fits on a single GPU once multiplied across heads and batch.

Variants attack the cost along two roughly orthogonal axes. The first is exact-but-cheaper: FlashAttention reorders the computation so the n x n score block is never written to high-bandwidth memory (HBM); the answer is mathematically identical to vanilla attention (up to floating-point rounding, since sums are accumulated in a different order), just faster and lower-memory. The second is approximate-or-restricted: linear attention swaps softmax for a kernel feature map so you can reassociate the matmul; sliding-window attention only attends to a local neighbourhood; sparse attention picks a structured subset of token pairs. These change the function the model can express.

A third concern lives entirely at inference time: the KV cache. When a transformer decodes one token at a time, every previous K and V tensor must be kept around. For an L-layer model with h heads and head dim d_h the cache is 2 · L · n · h · d_h elements per sequence. For a model with Llama-2 70B's shape (L = 80, h = 64, d_h = 128) and full multi-head attention, that is about 10 GB per sequence at 4k context in fp16, and it grows linearly with batch size. MQA and GQA cut this directly by sharing K and V across heads.
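The cache formula above is easy to turn into a back-of-the-envelope calculator. This is a minimal sketch: the helper name kv_cache_bytes is made up for illustration, and the shapes (80 layers, 64 query heads, head dim 128, fp16) are the commonly quoted Llama-2-70B-style configuration, used here as an assumption.

```python
def kv_cache_bytes(n_layers, seq_len, n_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to cache K and V for one sequence (the factor 2 is K plus V)."""
    return 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem

GIB = 2 ** 30
mha = kv_cache_bytes(80, 4096, 64, 128)   # every query head keeps its own K, V
gqa = kv_cache_bytes(80, 4096, 8, 128)    # 8 shared K, V heads (assumed GQA-8)
print(mha / GIB, gqa / GIB)               # 10.0 1.25
```

Multiply by the batch size to get the footprint for concurrent sequences.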

2. The math

Scaled dot-product attention

Let Q, K, V in R^(n x d_k) (we'll fold the batch axis in later). The core operation is:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

The 1/sqrt(d_k) factor keeps the pre-softmax logits at unit variance when Q and K have independent, zero-mean, unit-variance entries; without it the logit variance grows like d_k, softmax saturates and gradients vanish. Compute cost: the matmul S = Q K^T costs n · n · d_k multiplications, and softmax(S) V costs another n · n · d_v, so the total is O(n^2 · d). Memory: the score matrix S alone is n^2 entries.
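The variance argument can be checked empirically. The sketch below is only an illustration under the stated assumption (independent standard-normal entries); the sample size and seed are arbitrary choices.

```python
import math
import random

rng = random.Random(0)
d_k, trials = 64, 2000

def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

# Dot products of vectors whose entries are independent N(0, 1) samples
dots = []
for _ in range(trials):
    q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
    k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
    dots.append(sum(a * b for a, b in zip(q, k)))

var_raw = variance(dots)                                    # close to d_k
var_scaled = variance([s / math.sqrt(d_k) for s in dots])   # close to 1
print(round(var_raw, 1), round(var_scaled, 2))
```

Dividing by sqrt(d_k) divides the variance by d_k, which is exactly what brings the logits back to unit scale.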

Multi-head attention

Split the model dim d_model into h heads of size d_h = d_model / h. Each head has its own (W_Q^i, W_K^i, W_V^i) and runs attention in parallel; outputs concatenate and project through W_O:

head_i = Attention(X W_Q^i, X W_K^i, X W_V^i)
MHA(X) = concat(head_1, ..., head_h) W_O

Total compute is unchanged — splitting into heads is just a tensor reshape. Storage of K and V for inference, however, is 2 · n · h · d_h per layer.

Multi-Query Attention (MQA)

Single K, V shared across all query heads (Shazeer, 2019). Q still has h heads, but W_K, W_V project to a single d_h-dim space:

Q in R^(n x h x d_h), K, V in R^(n x 1 x d_h)
head_i = softmax(Q_i K^T / sqrt(d_h)) V

KV cache shrinks from 2 · n · h · d_h to 2 · n · d_h, a factor of h smaller (typically 16-32x). Decoding gets faster because it is memory-bandwidth bound and each generated token rereads the whole cache; the end-to-end gain is smaller than h-fold, since the model weights still have to be read at every step. Quality drops slightly vs full MHA because heads can no longer specialise their key/value projections.

Grouped-Query Attention (GQA)

A middle ground (Ainslie et al., 2023). Partition h query heads into g groups; each group shares one K, V head:

Q in R^(n x h x d_h), K, V in R^(n x g x d_h), 1 <= g <= h

g = h recovers vanilla MHA; g = 1 recovers MQA. Llama-2 70B uses h = 64, g = 8 — KV cache 8x smaller than MHA, quality nearly indistinguishable.
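The grouping rule comes down to an index map from query heads to K/V heads. This is a minimal sketch, assuming consecutive query heads share a group (the layout that repeat_interleave along the head axis produces); the helper name is hypothetical.

```python
def kv_head_for_query_head(n_heads, n_kv_heads):
    """Map each query head index to the K/V head it reads from."""
    assert n_heads % n_kv_heads == 0
    group_size = n_heads // n_kv_heads
    return [i // group_size for i in range(n_heads)]

print(kv_head_for_query_head(8, 8))  # [0, 1, 2, 3, 4, 5, 6, 7]  (MHA)
print(kv_head_for_query_head(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]  (GQA)
print(kv_head_for_query_head(8, 1))  # [0, 0, 0, 0, 0, 0, 0, 0]  (MQA)
```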

FlashAttention — IO-aware exact attention

Dao et al. (2022). Same math as vanilla attention, but the n x n score matrix is never materialised in HBM. Tile Q into blocks of size B_r and K, V into blocks of size B_c; for each pair of tiles compute a partial softmax in SRAM and accumulate the output using the online-softmax recurrence:

m_new = max(m_old, rowmax(S_tile))
l_new = exp(m_old - m_new) · l_old + rowsum(exp(S_tile - m_new))
O_new = (l_old / l_new) · exp(m_old - m_new) · O_old + (1 / l_new) · exp(S_tile - m_new) · V_tile

Here m is the running rowwise max and l the running rowwise denominator. HBM accesses drop from O(n · d + n^2) to O(n^2 · d^2 / M), where M is the SRAM size; since d^2 is typically much smaller than M, this translates to a 2-4x wall-clock speedup at n >= 2k. In PyTorch 2.x you just call torch.nn.functional.scaled_dot_product_attention and the FlashAttention backend is picked automatically when the shapes and dtypes are eligible.
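The online-softmax recurrence is easiest to trust after seeing it reproduce an ordinary softmax on one row. The sketch below is a scalar illustration only: one query row, values of dimension 1, and a made-up tile size. It follows the recurrence as written above and is not how a real kernel is structured.

```python
import math

def streaming_softmax_average(scores, values, tile=4):
    """Softmax-weighted average of `values`, reading `scores` one tile at a time."""
    m = float("-inf")   # running max of the scores seen so far
    l = 0.0             # running softmax denominator, relative to m
    o = 0.0             # running normalised output
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        p = [math.exp(s - m_new) for s in s_tile]
        l_new = math.exp(m - m_new) * l + sum(p)
        # Rescale the old output to the new max and denominator, then add this tile
        o = (l / l_new) * math.exp(m - m_new) * o \
            + sum(pi * vi for pi, vi in zip(p, v_tile)) / l_new
        m, l = m_new, l_new
    return o

def two_pass_softmax_average(scores, values):
    """Reference: needs the whole row in memory to find the max first."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.5, -2.0, 4.0, 0.25, 1.0]
values = [1.0, -2.0, 0.5, 3.0, 2.0, -1.0, 4.0, 0.0, 1.5, -0.5]
a = streaming_softmax_average(scores, values)
b = two_pass_softmax_average(scores, values)
print(abs(a - b) < 1e-12)   # True
```

The streaming version never holds more than one tile of scores, which is the property that lets the real kernel keep everything in SRAM.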

Linear / kernelized attention

Replace softmax with a kernel feature map φ : R^d → R^m such that the similarity sim(q, k) ≈ φ(q)^T φ(k). Then by associativity:

softmax(Q K^T) V ≈ φ(Q) (φ(K)^T V)

Compute φ(K)^T V first (an m x d matrix), then multiply by φ(Q); the rowwise normaliser φ(Q) (φ(K)^T 1) is obtained the same way. Total cost O(n · m · d), linear in n. Linear Transformer (Katharopoulos et al., 2020) uses φ(x) = elu(x) + 1; Performer (Choromanski et al., 2020) uses positive random features (FAVOR+) that approximate the softmax kernel in expectation. Trade-off: random-feature approximations add variance, and the model loses some of softmax's sharp attending behaviour.
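The reassociation step can be verified on a toy example. The sketch below uses plain Python lists and the elu(x) + 1 feature map, and compares only the unnormalised numerators; the helper names and the tiny matrices are invented for illustration.

```python
import math

def matmul(A, B):
    """Plain-Python matrix product of two list-of-rows matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def phi(A):
    """Feature map elu(x) + 1, applied elementwise."""
    return [[x + 1.0 if x > 0 else math.exp(x) for x in row] for row in A]

# n = 3 tokens, d = 2
Q = [[0.2, -1.0], [1.5, 0.3], [-0.7, 0.8]]
K = [[1.0, 0.5], [-0.4, 2.0], [0.1, -1.2]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]

quadratic = matmul(matmul(phi(Q), transpose(phi(K))), V)   # builds an n x n matrix
linear = matmul(phi(Q), matmul(transpose(phi(K)), V))      # builds a d x d matrix

max_diff = max(abs(x - y) for r1, r2 in zip(quadratic, linear) for x, y in zip(r1, r2))
print(max_diff < 1e-12)   # True
```

Both orderings give the same result, but only the second avoids the n x n intermediate.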

Sliding-window attention

Each query only attends to w tokens on either side. The score matrix becomes band-diagonal with bandwidth 2w + 1:

M_ij = 1 if |i - j| <= w else 0
Attention_local(Q, K, V) = softmax((Q K^T / sqrt(d)) + log M) V

Compute is O(n · w · d). After L stacked layers the effective receptive field is L · w, so deep stacks recover long-range dependencies indirectly (used in Longformer, Mistral 7B with w = 4096).
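To see where the O(n · w) count comes from, it helps to count the surviving score entries directly. A small brute-force sketch (the function name is hypothetical) for a symmetric window:

```python
def band_pairs(n, w):
    """Count (i, j) pairs with |i - j| <= w by brute force."""
    return sum(1 for i in range(n) for j in range(n) if abs(i - j) <= w)

n, w = 16, 2
kept = band_pairs(n, w)
print(kept, n * n)                              # 74 256
print(kept == n * (2 * w + 1) - w * (w + 1))    # True: rows near the edges are truncated
```

For w much smaller than n the count is essentially n · (2w + 1), versus n^2 for full attention.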

Sparse attention

Sparse Transformer (Child et al., 2019) combines a local band of width w ~ sqrt(n) with a strided pattern that hops every sqrt(n) tokens. Two layers of this composition reach any pair, and total cost is O(n · sqrt(n) · d). BigBird adds a random sparse pattern plus a few global tokens and proves the result is a universal sequence-to-sequence approximator.
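The two-hop claim can be checked by brute force. The sketch below assumes a causal variant of the pattern (each token sees the previous sqrt(n) positions plus every sqrt(n)-th earlier position); it is an illustration of the idea rather than the exact layout of any released implementation.

```python
import math

def strided_pattern(n):
    """Causal local + strided pattern; attend[i] is the set of visible j <= i."""
    stride = math.isqrt(n)
    attend = []
    for i in range(n):
        local = set(range(max(0, i - stride + 1), i + 1))   # last `stride` positions
        strided = set(range(i, -1, -stride))                 # i, i - stride, i - 2*stride, ...
        attend.append(local | strided)
    return attend

n = 64
attend = strided_pattern(n)
# Positions reachable from i by following the pattern twice (two stacked layers)
two_hop = [set().union(*(attend[j] for j in attend[i])) for i in range(n)]

print(all(two_hop[i] == set(range(i + 1)) for i in range(n)))   # True
print(max(len(a) for a in attend))                               # 15
```

Each row keeps at most about 2 · sqrt(n) entries, yet two layers cover every earlier position.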

Rotary positional embedding (RoPE)

Not a complexity variant: RoPE modifies Q and K before the dot product so that relative position is encoded in the angle of a 2D rotation pair. For each pair of dimensions (2i, 2i+1), rotate by angle m · θ_i at position m, with θ_i = 10000^(-2i/d). The dot product q_m · k_n then depends on position only through the offset m - n. RoPE composes cleanly with the softmax-based variants above because it acts on each position of Q and K independently, before any scores are computed.
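The relative-position property is easy to confirm numerically. This is a minimal pure-Python sketch of the rotation described above; the function names and the example vectors are made up for illustration.

```python
import math

def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[2i], x[2i+1]) by the angle pos * theta_i."""
    d = len(x)
    out = []
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.8, 0.5, -0.4, 1.1, 0.9, -0.7]
k = [1.0, 0.2, -0.6, 0.4, 0.7, -0.3, 0.1, 1.5]

near = dot(rope(q, 5), rope(k, 2))        # offset 3
far = dot(rope(q, 105), rope(k, 102))     # same offset, both positions shifted by 100
other = dot(rope(q, 5), rope(k, 3))       # offset 2

print(abs(near - far) < 1e-9)    # True: only the offset matters
print(abs(near - other) > 1e-6)  # True: a different offset gives a different score
```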

Cost summary

Variant            | Compute            | Score-matrix memory | KV cache        | Exact?
Vanilla MHA        | O(n^2 · d)         | O(n^2)              | 2 · n · h · d_h | yes
FlashAttention     | O(n^2 · d)         | O(n)                | 2 · n · h · d_h | yes
MQA                | O(n^2 · d)         | O(n^2)              | 2 · n · d_h     | yes (different params)
GQA-g              | O(n^2 · d)         | O(n^2)              | 2 · n · g · d_h | yes (different params)
Sliding-window w   | O(n · w · d)       | O(n · w)            | 2 · w · h · d_h | no (local only)
Sparse (strided)   | O(n · sqrt(n) · d) | O(n · sqrt(n))      | same            | no
Linear / Performer | O(n · m · d)       | O(m · d)            | O(m · d) state  | approx

3. PyTorch reference implementation

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# 1. Vanilla scaled dot-product attention from scratch
# ---------------------------------------------------------------------------
def vanilla_attention(Q, K, V, mask=None):
    """Q, K, V: (B, h, n, d_h).  mask: (B, 1, n, n) additive, 0 or -inf."""
    d_h = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_h)   # (B, h, n, n)
    if mask is not None:
        scores = scores + mask
    attn = scores.softmax(dim=-1)
    return attn @ V                                       # (B, h, n, d_h)


# ---------------------------------------------------------------------------
# 2. Multi-head / MQA / GQA via a single module (g controls grouping)
# ---------------------------------------------------------------------------
class GroupedQueryAttention(nn.Module):
    """g = h  -> standard MHA.   g = 1  -> MQA.   1 < g < h  -> GQA."""
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads  % n_kv_heads == 0
        self.h, self.g, self.d_h = n_heads, n_kv_heads, d_model // n_heads
        self.W_Q = nn.Linear(d_model, n_heads    * self.d_h, bias=False)
        self.W_K = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False)
        self.W_V = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False)
        self.W_O = nn.Linear(d_model, d_model,                bias=False)

    def forward(self, x, mask=None):
        B, n, _ = x.shape
        Q = self.W_Q(x).view(B, n, self.h, self.d_h).transpose(1, 2)   # (B,h,n,d_h)
        K = self.W_K(x).view(B, n, self.g, self.d_h).transpose(1, 2)   # (B,g,n,d_h)
        V = self.W_V(x).view(B, n, self.g, self.d_h).transpose(1, 2)   # (B,g,n,d_h)
        # Repeat K, V across each group's query heads
        K = K.repeat_interleave(self.h // self.g, dim=1)               # (B,h,n,d_h)
        V = V.repeat_interleave(self.h // self.g, dim=1)
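        # Use the fused FlashAttention backend when available.
        # With no explicit mask this defaults to causal (decoder-style) masking;
        # pass an additive mask to override it.
        out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask,
                                             is_causal=(mask is None))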
        out = out.transpose(1, 2).contiguous().view(B, n, self.h * self.d_h)
        return self.W_O(out)


# ---------------------------------------------------------------------------
# 3. Sliding-window mask
# ---------------------------------------------------------------------------
def sliding_window_mask(n, w, device=None):
    """Additive mask: 0 where |i - j| <= w, else -inf.  Shape (1, 1, n, n)."""
    i = torch.arange(n, device=device).unsqueeze(1)
    j = torch.arange(n, device=device).unsqueeze(0)
    keep = (i - j).abs() <= w
    return torch.where(keep, 0.0, float("-inf")).view(1, 1, n, n)


# ---------------------------------------------------------------------------
# 4. Linear attention with feature map phi(x) = elu(x) + 1
# ---------------------------------------------------------------------------
def linear_attention(Q, K, V):
    """O(n * d^2) instead of O(n^2 * d).  Q,K,V: (B, h, n, d_h)."""
    phi_Q = F.elu(Q) + 1.0
    phi_K = F.elu(K) + 1.0
    KV    = phi_K.transpose(-2, -1) @ V          # (B, h, d_h, d_h)
    Z     = phi_Q @ phi_K.sum(dim=-2, keepdim=True).transpose(-2, -1)  # normaliser
    return (phi_Q @ KV) / (Z + 1e-6)


# ---------------------------------------------------------------------------
# 5. FlashAttention — conceptual tiled outer loop.
#    Real kernels live in CUDA; PyTorch dispatches to them via SDPA.
# ---------------------------------------------------------------------------
def flash_attention_reference(Q, K, V, B_r=64, B_c=64):
    """Numerically equivalent to vanilla_attention, computed tile by tile."""
    B, h, n, d_h = Q.shape
    scale = 1.0 / math.sqrt(d_h)
    O = torch.zeros_like(Q)
    # The running max and denominator are tracked per query tile (mi, li) below.
    for i in range(0, n, B_r):
        Qi = Q[:, :, i:i + B_r]
        Oi = torch.zeros_like(Qi)
        li = torch.zeros(B, h, Qi.size(2), 1, device=Q.device)
        mi = torch.full((B, h, Qi.size(2), 1), float("-inf"), device=Q.device)
        for j in range(0, n, B_c):
            Kj = K[:, :, j:j + B_c]
            Vj = V[:, :, j:j + B_c]
            Sij  = (Qi @ Kj.transpose(-2, -1)) * scale
            m_new = torch.maximum(mi, Sij.max(dim=-1, keepdim=True).values)
            P     = torch.exp(Sij - m_new)
            # Rescale the old statistics to the new max, then add this tile
            l_new = torch.exp(mi - m_new) * li + P.sum(dim=-1, keepdim=True)
            Oi    = torch.exp(mi - m_new) * Oi + P @ Vj   # unnormalised accumulator
            mi, li = m_new, l_new
        O[:, :, i:i + B_r] = Oi / li                       # normalise once per query tile
    return O


# Use torch.compile in real code to fuse the linear / GQA variants:
#     fast_gqa = torch.compile(GroupedQueryAttention(512, 8, 2))


if __name__ == "__main__":
    B, n, d_model, h, g = 2, 64, 256, 8, 2
    x = torch.randn(B, n, d_model)
    block = GroupedQueryAttention(d_model, h, g)
    block.eval()              # inference mode (a no-op here: no dropout/BN in this block)
    with torch.no_grad():
        y = block(x)
    print(y.shape)            # torch.Size([2, 64, 256])

F.scaled_dot_product_attention is the right call in production: it picks the FlashAttention CUDA backend automatically on Ampere or newer GPUs when dtypes are fp16/bf16 and the shapes are eligible, and otherwise falls back to the memory-efficient kernel or the plain math implementation. The hand-rolled flash_attention_reference above is for teaching only; it is not faster than vanilla because Python loops can't beat a fused CUDA kernel.

4. Common USAAIO / IOAI applications

5. Drills

D1 · Prove standard attention is O(n^2 · d)

Given Q, K, V in R^(n x d), count the multiplications and additions in softmax(Q K^T / sqrt(d)) V.

Solution

Q K^T multiplies an (n, d) matrix by a (d, n) matrix: the result is n x n, with each entry costing d multiplies and d - 1 adds, so O(n^2 · d). Softmax is O(n^2). Multiplying the n x n attention matrix by V in R^(n x d) is another O(n^2 · d). Total O(n^2 · d) compute and O(n^2) peak memory for the score matrix.

D2 · MQA KV-cache savings

A model has L = 32 layers, h = 32 heads, head dim d_h = 128, sequence length n = 4096, fp16. What is the KV cache size for MHA vs MQA vs GQA-4?

Solution

Cache = 2 · L · n · h_kv · d_h · 2 bytes.

  • MHA: 2 · 32 · 4096 · 32 · 128 · 2 = 2.0 GB
  • GQA-4: 2 · 32 · 4096 · 4 · 128 · 2 = 256 MB (8x smaller)
  • MQA: 2 · 32 · 4096 · 1 · 128 · 2 = 64 MB (32x smaller)

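On a 24 GB consumer GPU, where the fp16 weights of a model this size (roughly 6-7B parameters) already take about 13 GB, the remaining memory holds around five 4k sequences with MHA but around forty with GQA-4.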

D3 · Sliding-window effective receptive field

A 12-layer transformer uses sliding-window attention with w = 128. What is the maximum distance between two tokens that can still influence each other in the output?

Solution

Each layer extends reach by w tokens on each side, so after L layers reach is L · w. Here 12 · 128 = 1536 tokens. That's enough for paragraph-level dependencies but not document-level. Mistral 7B uses w = 4096 over 32 layers, for a theoretical reach of 32 · 4096 ≈ 131k tokens, which comfortably covers the 32k context it claims.
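The L · w reach can be confirmed by propagating influence layer by layer. A small sketch, assuming a symmetric window and a hypothetical helper name:

```python
def reach_after_layers(n, w, n_layers, source=0):
    """Positions whose output can depend on `source` after stacked windowed layers."""
    influenced = {source}
    for _ in range(n_layers):
        # Every position within w of an influenced position becomes influenced
        influenced = {j for i in influenced
                        for j in range(max(0, i - w), min(n, i + w + 1))}
    return influenced

reached = reach_after_layers(n=64, w=3, n_layers=4)
print(max(reached))   # 12, i.e. n_layers * w
```

The same function with w = 128 and 12 layers gives a reach of 1536, matching the drill.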

D4 · Linear-attention kernel choice

You try linear attention with φ(x) = x (the identity). Training loss diverges. Why? What does φ(x) = elu(x) + 1 fix?

Solution

For the attention output to be a convex combination of values, the implicit similarity φ(q)^T φ(k) must be non-negative; otherwise the normaliser Z = φ(Q) · Σ φ(K) can be zero or negative and the output blows up. The identity map produces signed dot products. The elu(x) + 1 map is always strictly positive (range (0, ∞)) while staying smooth and gradient-friendly, which guarantees a positive denominator and stable training.
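A two-dimensional example makes the sign problem concrete. The vectors below are chosen by hand to trigger it, and the helper names are invented for illustration.

```python
import math

def elu_plus_one(x):
    return x + 1.0 if x > 0 else math.exp(x)

def normaliser(q, keys, phi):
    """Z = phi(q) . sum_j phi(k_j) for a single query."""
    k_sum = [sum(phi(k[i]) for k in keys) for i in range(len(q))]
    return sum(phi(qi) * ki for qi, ki in zip(q, k_sum))

q = [1.0, -1.0]
keys = [[-1.0, 1.0], [-0.5, 0.5]]

z_identity = normaliser(q, keys, lambda x: x)   # -3.0: dividing by this flips signs
z_elu = normaliser(q, keys, elu_plus_one)       # strictly positive

print(z_identity, z_elu > 0)   # -3.0 True
```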

D5 · When NOT to use FlashAttention

Your model trains fine with vanilla attention. You switch to FlashAttention and accuracy drops a little. What likely happened?

Solution

FlashAttention is mathematically equivalent to vanilla attention, but reductions happen in a different order, so fp16/bf16 rounding differs slightly. If you trained with vanilla and only switched at fine-tune time, those rounding shifts can move logits enough to matter. Either retrain end-to-end with FlashAttention, or make sure the softmax statistics are accumulated in fp32 inside the kernel (check what your backend and version actually do rather than assuming). It's not the algorithm; it's the numerics.
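The effect of reduction order is visible even in double precision with three numbers. This toy example is not attention, only an illustration of why two mathematically equal sums can round differently:

```python
import math

terms = [1e16, 1.0, -1e16]

left_to_right = (terms[0] + terms[1]) + terms[2]   # the 1.0 is rounded away first
reordered = (terms[0] + terms[2]) + terms[1]       # the big terms cancel first

print(left_to_right, reordered)   # 0.0 1.0
print(math.fsum(terms))           # 1.0, the correctly rounded sum
```

In fp16 or bf16 the same effect appears at far smaller magnitudes, which is why a tiled kernel and a naive one can disagree in the last few bits.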

Next step

Loop back to Transformers for the base architecture that all of these variants modify, and to Round 2 theory for short-answer drills on attention complexity, KV cache sizing, and RoPE. If you're building a long-context model end to end, the Deep Learning page covers the training-loop and mixed-precision plumbing you'll need underneath.