💨
[Attention] Grouped-Query Attention
Key Contributions
- Key and value projections can be shared across heads
  -> reduces projection weights
  -> reduces key/value (KV cache) storage
- The amount of computation inside attention doesn't change; the saving comes from loading shared key/value tensors instead of a separate set per head (see the sketch below)
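A quick back-of-the-envelope comparison of K/V projection parameters and per-token KV-cache size; the sizes below are illustrative, not taken from the paper:

```python
# Illustrative sizes (not from the paper): 8 query heads, head dim 128, 4 KV groups
h, d_h, g = 8, 128, 4
d_model = h * d_h

def kv_proj_params(n_kv):
    # W_k and W_v are each [d_model, n_kv * d_h]
    return 2 * d_model * n_kv * d_h

def kv_cache_per_token(n_kv):
    # one K and one V vector of size d_h per KV head, per token (per layer)
    return 2 * n_kv * d_h

for name, n_kv in [("MHA", h), ("GQA", g), ("MQA", 1)]:
    print(f"{name}: proj params = {kv_proj_params(n_kv):,}, "
          f"KV cache/token = {kv_cache_per_token(n_kv)} values")
```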
Multi-Head Attention (MHA)
Key and value are different for each head
Q_proj[N, d_h*8] = [Q_h1, Q_h2, Q_h3, Q_h4, Q_h5, Q_h6, Q_h7, Q_h8]
K_proj[N, d_h*8] = [K_h1, K_h2, K_h3, K_h4, K_h5, K_h6, K_h7, K_h8]
V_proj[N, d_h*8] = [V_h1, V_h2, V_h3, V_h4, V_h5, V_h6, V_h7, V_h8]
                 = X[N, d_h*h] @ W_v[d_h*h, d_h*h]
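A minimal NumPy sketch of the MHA projections above (toy sizes and variable names are assumptions for illustration):

```python
import numpy as np

N, h, d_h = 4, 8, 16                  # tokens, heads, head dim (toy sizes)
X = np.random.randn(N, d_h * h)       # layer input

# MHA: every head has its own Q/K/V slice, so all three weights are [d_h*h, d_h*h]
W_q = np.random.randn(d_h * h, d_h * h)
W_k = np.random.randn(d_h * h, d_h * h)
W_v = np.random.randn(d_h * h, d_h * h)

Q = (X @ W_q).reshape(N, h, d_h)      # Q_h1 .. Q_h8
K = (X @ W_k).reshape(N, h, d_h)      # K_h1 .. K_h8
V = (X @ W_v).reshape(N, h, d_h)      # V_h1 .. V_h8
```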
Multi-Query Attention (MQA)
Key and value are shared across all heads
Q_proj[N, d_h*8] = [Q_h1, Q_h2, Q_h3, Q_h4, Q_h5, Q_h6, Q_h7, Q_h8]
K_proj[N, d_h*1] = [K_h1]
V_proj[N, d_h*1] = [V_h1]
                 = X[N, d_h*h] @ W_v[d_h*h, d_h*1]
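The same sketch with a single shared K/V head (again, sizes and names are illustrative):

```python
import numpy as np

N, h, d_h = 4, 8, 16                      # tokens, query heads, head dim (toy sizes)
X = np.random.randn(N, d_h * h)

W_q = np.random.randn(d_h * h, d_h * h)   # Q still has one slice per head
W_k = np.random.randn(d_h * h, d_h * 1)   # K shrinks to a single head
W_v = np.random.randn(d_h * h, d_h * 1)   # V shrinks to a single head

Q = (X @ W_q).reshape(N, h, d_h)          # Q_h1 .. Q_h8
K = X @ W_k                               # K_h1, reused by all 8 query heads
V = X @ W_v                               # V_h1
```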
Grouped-Query Attention (GQA)
Key and value are shared within each head group
Q_proj[N, d_h*8] = [Q_h1, Q_h2, Q_h3, Q_h4, Q_h5, Q_h6, Q_h7, Q_h8]
K_proj[N, d_h*4] = [K_g1, K_g2, K_g3, K_g4]
V_proj[N, d_h*4] = [V_g1, V_g2, V_g3, V_g4]
                 = X[N, d_h*h] @ W_v[d_h*h, d_h*4]
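Putting it together, a minimal end-to-end GQA sketch with 8 query heads and 4 KV groups (toy sizes; setting g = h recovers MHA, g = 1 recovers MQA):

```python
import numpy as np

N, h, g, d_h = 4, 8, 4, 16                # tokens, query heads, KV groups, head dim (toy sizes)
X = np.random.randn(N, d_h * h)

W_q = np.random.randn(d_h * h, d_h * h)   # one Q slice per head
W_k = np.random.randn(d_h * h, d_h * g)   # one K slice per group
W_v = np.random.randn(d_h * h, d_h * g)   # one V slice per group

Q = (X @ W_q).reshape(N, h, d_h)          # Q_h1 .. Q_h8
K = (X @ W_k).reshape(N, g, d_h)          # K_g1 .. K_g4
V = (X @ W_v).reshape(N, g, d_h)          # V_g1 .. V_g4

# Each group of h//g query heads attends to the same K/V; repeating just broadcasts
# the shared tensors -- it adds no parameters and no extra KV cache.
K_b = np.repeat(K, h // g, axis=1)        # [N, h, d_h]
V_b = np.repeat(V, h // g, axis=1)

scores = np.einsum('nhd,mhd->hnm', Q, K_b) / np.sqrt(d_h)
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)
out = np.einsum('hnm,mhd->nhd', probs, V_b).reshape(N, d_h * h)
print(out.shape)                          # (4, 128)
```

Note that the attention FLOPs are the same as MHA once the shared K/V are broadcast; the savings show up only in the projection weights and the stored/loaded KV tensors.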
Reference
Ainslie et al., 2023. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
Discussion