
[Attention] Grouped-Query Attention


Key Contributions

  • Key and value projections can be shared across heads
    -> fewer projection weights
    -> less key/value cache storage

  • The amount of computation inside attention does not change; only the shared key/value tensors need to be loaded (see the size comparison sketched below)
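
As a rough illustration (not from the paper), this sketch compares the K/V projection-weight count and the per-layer KV-cache size for MHA, MQA, and GQA; d_h, h, g, and N are assumed example values.

d_h, h, g, N = 128, 8, 4, 1024       # head dim, query heads, GQA groups, sequence length
d_model = d_h * h

for name, kv_heads in [("MHA", h), ("MQA", 1), ("GQA", g)]:
    proj_params = 2 * d_model * d_h * kv_heads   # K and V projection weights
    kv_cache = 2 * N * d_h * kv_heads            # cached K and V activations per layer
    print(f"{name}: K/V projection params = {proj_params:,}, KV-cache elements = {kv_cache:,}")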

Multi-Head Attention

Each head has its own Key and Value projections.

Q_proj[N, d_h*8] = [Q_h1, Q_h2, Q_h3, Q_h4, Q_h5, Q_h6, Q_h7, Q_h8]
K_proj[N, d_h*8] = [K_h1, K_h2, K_h3, K_h4, K_h5, K_h6, K_h7, K_h8]
V_proj[N, d_h*8] = [V_h1, V_h2, V_h3, V_h4, V_h5, V_h6, V_h7, V_h8]
                 = X[N, d_h*h] @ W_v[d_h*h, d_h*h]
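
A minimal PyTorch sketch of the shapes above, assuming example values N = 16, d_h = 64, h = 8; X denotes the input hidden states and all names are illustrative.

import torch

N, d_h, h = 16, 64, 8
d_model = d_h * h

X = torch.randn(N, d_model)
W_q = torch.randn(d_model, d_h * h)
W_k = torch.randn(d_model, d_h * h)   # a separate K slice per head
W_v = torch.randn(d_model, d_h * h)   # a separate V slice per head

Q = (X @ W_q).view(N, h, d_h)   # [N, 8, d_h] -> Q_h1..Q_h8
K = (X @ W_k).view(N, h, d_h)   # [N, 8, d_h] -> K_h1..K_h8
V = (X @ W_v).view(N, h, d_h)   # [N, 8, d_h] -> V_h1..V_h8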

Multi-Query Attention

A single Key and Value projection is shared across all heads.

Q_proj[N, d_h*8] = [Q_h1, Q_h2, Q_h3, Q_h4, Q_h5, Q_h6, Q_h7, Q_h8]
K_proj[N, d_h*1] = [K_h1]
V_proj[N, d_h*1] = [V_h1]
                 = X[N, d_h*h] @ W_v[d_h*h, d_h*1]
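
The same sketch for MQA: one K head and one V head are computed and broadcast to all 8 query heads (assumed example values as before).

import torch

N, d_h, h = 16, 64, 8
d_model = d_h * h

X = torch.randn(N, d_model)
Q = (X @ torch.randn(d_model, d_h * h)).view(N, h, d_h)  # 8 query heads
K = X @ torch.randn(d_model, d_h)                        # [N, d_h], single shared K
V = X @ torch.randn(d_model, d_h)                        # [N, d_h], single shared V

# every query head attends against the same K/V
scores = torch.einsum("nhd,md->hnm", Q, K) / d_h ** 0.5  # [8, N, N]
out = torch.softmax(scores, dim=-1) @ V                  # [8, N, d_h]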

Grouped-Query Attention

Key and Value projections are shared within each group of heads (here, 4 groups for 8 query heads).

Q_proj[N, d_h*8] = [Q_h1, Q_h2, Q_h3, Q_h4, Q_h5, Q_h6, Q_h7, Q_h8]
K_proj[N, d_h*4] = [K_g1, K_g2, K_g3, K_g4]
V_proj[N, d_h*4] = [V_g1, V_g2, V_g3, V_g4]
                 = X[N, d_h*h] @ W_v[d_h*h, d_h*4]
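
And for GQA with 4 groups: each K/V group is computed once and reused by the 2 query heads it serves; expanding it adds no parameters, only reads of the shared tensors (assumed example values as before).

import torch

N, d_h, h, g = 16, 64, 8, 4
d_model = d_h * h

X = torch.randn(N, d_model)
Q = (X @ torch.randn(d_model, d_h * h)).view(N, h, d_h)  # [N, 8, d_h]
K = (X @ torch.randn(d_model, d_h * g)).view(N, g, d_h)  # [N, 4, d_h]
V = (X @ torch.randn(d_model, d_h * g)).view(N, g, d_h)  # [N, 4, d_h]

# expand each group to the query heads it serves: group i -> heads 2i, 2i+1
K = K.repeat_interleave(h // g, dim=1)  # [N, 8, d_h]
V = V.repeat_interleave(h // g, dim=1)  # [N, 8, d_h]

scores = torch.einsum("nhd,mhd->hnm", Q, K) / d_h ** 0.5              # [8, N, N]
out = torch.einsum("hnm,mhd->nhd", torch.softmax(scores, dim=-1), V)  # [N, 8, d_h]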

Reference

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)
