Open6

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

pofipofi

HBMからのAttention matrixの読み込みと書き込みを防止することが目的

SRAM>HBM>DRAMの順に読み込みや書き込みの速度が速い、ただし、サイズも小さい。
HBMやDRAMにアクセスを少なくして計算することが学習や推論速度向上につながる。

pofipofi

目的の達成には、2つが必要

  1. HBMにアクセスするためには入力全体にアクセスせずにAttention部分のsoftmaxを計算する
  2. backward pasの計算のためにattention matrixの巨大な中間出力を保存しないようにする

これらを解決するために、

  1. 入力をブロックに分割し、そのブロックに対して数回のパスを作ることでAttentionのsoftmaxの計算を段階的に行う。
  2. backwardでのAttentionを迅速に計算するために、forwardでのsoftmax normalization factorを保存する。
pofipofi

FLASHATTENTIONはHBMへのO\left(N^2 d^2 M^{-1}\right)のアクセスが必要
dはhead dimension, MはSRAMのサイズ
一方、standardのAttentionはHBMへ\Omega\left(N d+N^2\right)アクセスしている。
最大9倍HBMへのアクセスが減る

pofipofi

入力は、\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d},
出力は、\mathbf{O} \in \mathbb{R}^{N \times d}
Attentionの計算は、
\mathbf{S}=\mathbf{Q K}^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d}
となる。
一般的なAttentionの実装では、\mathbf{S}, \mathbf{P}をHBMに保存し、O\left(N^2\right)のメモリ容量がかかる。

pofipofi

FlashAttentionでは、
\mathbf{Q}, \mathbf{K}, \mathbf{V}をブロックに分割し、HBMからSRAMへ読み込み、ブロックごとにAttentionを計算する。
計算したAttentionの出力はブロックごとにNormalizationしてから足し合わせる。