
[Attention] Flash Attention


Key Contributions

Standard Attention Problem

  • If N (the input sequence length) is large relative to d (the channel/head dimension), the intermediate matrices S [N, N] and P [N, N] become very large.
  • This incurs a large read/write cost to HBM (High Bandwidth Memory, the GPU's VRAM), as the sketch below illustrates.
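A minimal NumPy sketch (not the paper's code; `standard_attention` is an illustrative name) of why this hurts: standard attention materializes the full [N, N] matrices S and P, which on a GPU would be written to and re-read from HBM.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full [N, N] score matrices.

    Q, K, V: [N, d] arrays. On a GPU, S and P would live in HBM, and
    reading/writing them dominates the cost when N >> d.
    """
    S = Q @ K.T                                   # [N, N] scores
    P = np.exp(S - S.max(axis=1, keepdims=True))  # [N, N] stabilized exponentials
    P = P / P.sum(axis=1, keepdims=True)          # [N, N] softmax
    return P @ V                                  # [N, d] output

N, d = 1024, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)  # (1024, 64); each [N, N] intermediate holds over a million floats
```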

FlashAttention

  • Split Q@K^T [N, N] into smaller sub-blocks, so that each block fits on-chip (the number of blocks Tr × Tc, i.e. the block sizes, is configurable).
  • Decompose the softmax so it can be computed block by block (a quick numerical check follows this list).
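A quick numerical check of the softmax decomposition this relies on: the softmax of a full row can be assembled from per-block exponentials by tracking only the running row sum l. (The max-based rescaling the paper adds for numerical stability is omitted here, matching the simplified equations in this note.)

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(8)           # one row of scores S[i, :]
s1, s2 = s[:4], s[4:]                # split into two column blocks

full = np.exp(s) / np.exp(s).sum()   # softmax over the whole row at once

l = np.exp(s1).sum() + np.exp(s2).sum()                 # running denominator
blockwise = np.concatenate([np.exp(s1), np.exp(s2)]) / l

print(np.allclose(full, blockwise))  # True
```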

Inner Loop (iterating over the Query and Output blocks; the outer loop iterates over the Key/Value blocks, j = 1 to Tc):

for i = 1 to Tr do

$$
\begin{aligned}
P_{ij}^{\text{curr}} &= \exp(Q_i K_j^\top) \\
l_i^{\text{new}} &= l_i^{\text{prev}} + \text{rowsum}(P_{ij}^{\text{curr}}) \\
O_i^{\text{new}} &= \frac{l_i^{\text{prev}} \, O_i^{\text{prev}} + P_{ij}^{\text{curr}} V_j}{l_i^{\text{new}}} \\
l_i^{\text{prev}} &\leftarrow l_i^{\text{new}}, \qquad O_i^{\text{prev}} \leftarrow O_i^{\text{new}}
\end{aligned}
$$
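A plain-NumPy sketch of the tiled loop following the update rules above (this is only an illustration of the tiling scheme, not the paper's kernel; `flash_attention_blocked`, `block_q`, and `block_kv` are assumed names, and the max-based rescaling is again omitted to match the simplified equations):

```python
import numpy as np

def flash_attention_blocked(Q, K, V, block_q=64, block_kv=64):
    """Block-wise exact attention: never materializes the full [N, N] matrix."""
    N, d = Q.shape
    O = np.zeros((N, d))    # running (already normalized) outputs
    l = np.zeros((N, 1))    # running softmax denominators

    for j in range(0, N, block_kv):              # outer loop: K/V blocks (j = 1..Tc)
        Kj, Vj = K[j:j + block_kv], V[j:j + block_kv]
        for i in range(0, N, block_q):           # inner loop: Q/O blocks (i = 1..Tr)
            Qi = Q[i:i + block_q]
            P_curr = np.exp(Qi @ Kj.T)                               # P_ij^curr
            l_new = l[i:i + block_q] + P_curr.sum(axis=1, keepdims=True)
            O[i:i + block_q] = (l[i:i + block_q] * O[i:i + block_q]
                                + P_curr @ Vj) / l_new               # O_i^new
            l[i:i + block_q] = l_new                                 # l_i^prev <- l_i^new
    return O

# Agrees with the naive [N, N] implementation: the attention is exact, just tiled.
N, d = 256, 32
rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
S = Q @ K.T
ref = (np.exp(S) / np.exp(S).sum(axis=1, keepdims=True)) @ V
print(np.allclose(flash_attention_blocked(Q, K, V), ref))  # True
```

Only one Q/K/V block pair and the small running statistics (l, and the current O block) need to be resident at any time, which is what cuts the HBM traffic.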

Reference

Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." arXiv:2205.14135.
