[Attention] Flash Attention
Key Contributions
Problem with Standard Attention
- If N (input sequence length) is large compared to d (head dimension / channels), the intermediate matrices S = QK^T [N, N] and P = softmax(S) [N, N] become very large
- They must be repeatedly written to and read back from HBM (the GPU's high-bandwidth memory, i.e. VRAM), so memory traffic rather than compute dominates the runtime (see the sketch after this list)
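A minimal NumPy sketch of the standard implementation, to make the bottleneck concrete: S and P are full [N, N] intermediates, so for N ≫ d the kernel is dominated by moving O(N²) values to and from HBM rather than by the matmul FLOPs. Function and variable names here are illustrative, not from the paper.

```python
import numpy as np

def standard_attention(Q, K, V):
    """O = softmax(Q K^T / sqrt(d)) V with full [N, N] intermediates."""
    N, d = Q.shape
    S = Q @ K.T / np.sqrt(d)                      # S: [N, N] score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))  # numerically stable softmax
    P /= P.sum(axis=1, keepdims=True)             # P: [N, N] attention weights
    return P @ V                                  # O: [N, d]
```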
FlashAttention
- Split Q@K^T [N, N] into [Br, Bc] tiles S_ij = Q_i @ K_j^T: Q (and O) are split into Tr row blocks of size Br, K and V into Tc blocks of size Bc (the block sizes are configurable, chosen so the blocks fit in on-chip SRAM)
- Decompose the softmax so each tile can be processed with only O(N) running row statistics (max m and sum ℓ), never materializing the full [N, N] matrix (identity shown below)
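The decomposition is the standard "online" safe-softmax identity (in the paper's m, ℓ notation): the softmax of a row split into blocks x = [x⁽¹⁾ x⁽²⁾] can be assembled from per-block maxima and sums, so only O(N) statistics are carried between tiles.

$$
\begin{aligned}
m(x) &= \max\!\big(m(x^{(1)}),\, m(x^{(2)})\big), \qquad
f(x^{(k)}) = e^{x^{(k)} - m(x^{(k)})}, \qquad
\ell(x^{(k)}) = \textstyle\sum_i f(x^{(k)})_i,\\
f(x) &= \big[\, e^{m(x^{(1)}) - m(x)}\, f(x^{(1)}) \;\; e^{m(x^{(2)}) - m(x)}\, f(x^{(2)}) \,\big],\\
\ell(x) &= e^{m(x^{(1)}) - m(x)}\,\ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\,\ell(x^{(2)}),\qquad
\operatorname{softmax}(x) = f(x)/\ell(x).
\end{aligned}
$$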
Inner loop (iterating over the Query and Output blocks, nested inside an outer loop over the Tc Key/Value blocks; see the sketch below):
for i = 1 to Tr do   # load Q_i, O_i, m_i, ℓ_i; compute the tile S_ij and update the running softmax
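A runnable NumPy sketch of this forward pass, following the FlashAttention-1 loop order (outer loop over K/V blocks, inner loop over Q/O blocks). The block sizes and variable names are illustrative; in the real kernel each tile lives in on-chip SRAM and only O, m, ℓ are written back to HBM.

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    """Tiled exact attention O = softmax(Q K^T / sqrt(d)) V, computed block by
    block with running row-wise softmax statistics m (max) and l (sum)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)

    O = np.zeros((N, d))          # output accumulator (O_i blocks)
    l = np.zeros(N)               # running softmax denominators
    m = np.full(N, -np.inf)       # running row maxima

    Tc = -(-N // Bc)              # number of K/V blocks
    Tr = -(-N // Br)              # number of Q/O blocks

    for j in range(Tc):                           # outer loop: K_j, V_j loaded once
        Kj = K[j * Bc:(j + 1) * Bc]
        Vj = V[j * Bc:(j + 1) * Bc]
        for i in range(Tr):                       # inner loop: update Q_i's output block
            rows = slice(i * Br, (i + 1) * Br)
            Qi, Oi, mi, li = Q[rows], O[rows], m[rows], l[rows]

            Sij = (Qi @ Kj.T) * scale             # [Br, Bc] tile, never the full [N, N]
            mij = Sij.max(axis=1)                 # per-row max of this tile
            Pij = np.exp(Sij - mij[:, None])      # unnormalized tile softmax
            lij = Pij.sum(axis=1)

            mi_new = np.maximum(mi, mij)          # merge running statistics
            li_new = np.exp(mi - mi_new) * li + np.exp(mij - mi_new) * lij

            # rescale the old partial output and add this tile's contribution
            O[rows] = ((li * np.exp(mi - mi_new))[:, None] * Oi
                       + np.exp(mij - mi_new)[:, None] * (Pij @ Vj)) / li_new[:, None]
            m[rows], l[rows] = mi_new, li_new
    return O


if __name__ == "__main__":
    # sanity check against naive attention (small sizes, float64)
    rng = np.random.default_rng(0)
    Q, K, V = (rng.standard_normal((256, 32)) for _ in range(3))
    S = Q @ K.T / np.sqrt(32)
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    assert np.allclose(flash_attention_forward(Q, K, V), P @ V)
```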
Reference
Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. arXiv:2205.14135.
Discussion