Grouped Query Attention

Balancing KV-cache memory efficiency with model quality

The three variants differ in how query (Q) heads map to key (K) and value (V) heads:

MHA (Multi-Head Attention)
Every query head (Q₀-Q₃) has its own key and value head (K₀-K₃, V₀-V₃).
KV cache: 8 heads (4 K + 4 V).

GQA (Grouped Query Attention)
Query heads share key/value heads in groups: Q₀-Q₁ use K₀/V₀, and Q₂-Q₃ use K₁/V₁.
KV cache: 4 heads (2 K + 2 V).

MQA (Multi-Query Attention)
All four query heads share a single key head and a single value head.
KV cache: 2 heads (1 K + 1 V).
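The three cache footprints above differ only in the number of K/V heads that must be stored per layer. A minimal sketch, using the diagram's 4-query-head configuration:

```python
def kv_cache_heads(n_kv_heads: int) -> int:
    """Cached heads per layer: one K head and one V head per KV head."""
    return 2 * n_kv_heads

print(kv_cache_heads(4))  # MHA: n_kv = n_heads = 4 → 8
print(kv_cache_heads(2))  # GQA: n_kv = 2          → 4
print(kv_cache_heads(1))  # MQA: n_kv = 1          → 2
```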
How GQA Works
The input hidden states are projected by WQ, WK, and WV into per-head queries, keys, and values:

  Q₀ Q₁ Q₂ Q₃   (n_heads = 4)
  K₀ K₁         (n_kv = 2)
  V₀ V₁         (n_kv = 2)

Query heads are then assigned to KV groups:

  Group 0: Q₀, Q₁ → K₀, V₀
  Group 1: Q₂, Q₃ → K₁, V₁

Each query head attends over its group's shared K/V, and the per-head outputs are concatenated to form the attention output.
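The grouping step above reduces to a simple index mapping: consecutive query heads share one KV head. A sketch in plain Python (assuming, as in most GQA implementations, that n_kv evenly divides n_heads):

```python
def kv_head_for(q_head: int, n_heads: int, n_kv: int) -> int:
    """Return the KV head index that query head `q_head` attends with."""
    group_size = n_heads // n_kv  # query heads per KV head
    return q_head // group_size

# GQA as in the diagram: 4 query heads, 2 KV heads
print([kv_head_for(q, 4, 2) for q in range(4)])  # → [0, 0, 1, 1]
# MHA is the special case n_kv = n_heads (one KV head per query head);
# MQA is the special case n_kv = 1 (all queries share KV head 0).
print([kv_head_for(q, 4, 4) for q in range(4)])  # → [0, 1, 2, 3]
print([kv_head_for(q, 4, 1) for q in range(4)])  # → [0, 0, 0, 0]
```

Viewed this way, GQA interpolates between MHA and MQA by choosing n_kv between 1 and n_heads.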
Why GQA Matters
💾
Smaller KV Cache
Shrinks the cache by a factor of n_heads / n_kv_heads, which is critical for long-context inference.
⚡
Faster Decoding
Less K/V data to load from HBM per generated token, which directly improves throughput in the memory-bound decode phase.
🎯
Quality Preserved
GQA tracks MHA quality far more closely than MQA does. Llama 2 70B uses GQA with 8 KV heads serving 64 query heads.
KV Cache Size (per sequence) = 2 × n_layers × n_kv_heads × seq_len × head_dim × dtype_size, where the factor of 2 accounts for storing both K and V.
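As a worked example of this formula, consider Llama 2 70B's published shape (80 layers, 8 KV heads, head_dim 128) in fp16 (2 bytes per element) at a 4096-token context:

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, seq_len: int,
                   head_dim: int, dtype_size: int = 2) -> int:
    """KV cache size per sequence; 2x for K and V, dtype_size=2 assumes fp16."""
    return 2 * n_layers * n_kv_heads * seq_len * head_dim * dtype_size

gqa = kv_cache_bytes(80, 8, 4096, 128)    # Llama 2 70B's actual GQA config
mha = kv_cache_bytes(80, 64, 4096, 128)   # hypothetical MHA variant (64 KV heads)
print(gqa / 2**30, mha / 2**30)  # → 1.25 10.0 (GiB)
```

GQA with 8 KV heads cuts the cache from 10 GiB to 1.25 GiB per sequence, an 8x reduction matching the 64/8 grouping ratio.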