🧵

GPU と FlashAttension をちゃんと理解したい

2024/01/23に公開

はじめに

ChatGPT をはじめてとして、多くの LLM が世の中に送り出された 2023 年でした。OSSとして公開されているモデルも多く試すだけであれば非常に Colab などで試せて感動しています。
とはいえ、やはり一度は LLM を自分で学習させてみたい、ただ効率的な学習をさせないとお金が溶けるだけ...。そんな中見つけた記事がこちらです。

さまざまな tips が載っています。
https://huggingface.co/docs/transformers/v4.35.2/en/perf_train_gpu_one
npaka san がこちらを日本語でまとめて下さっています。
https://note.com/npaka/n/n04c493394e07

この記事では、上に挙げられている技術の1つである FlashAttension についてみていきます。特に、どのような改善が行われているのかを追います。(結果的にどれくらい高速になるかは詳しく述べないため他の記事を参照してください)

実は FlashAttension(2022) のさらなる改善として FlashAttenstion2(2023) も既に出ているのですが、FlashAttension2 に関しては別の記事で解説できればと思います。

FlashAttension とは

FlashAttension とは、Attension を高速化するための手法の一つです。多くの既存研究が計算量を小さくするため近似手法を用いて高速化や扱えるトークン数を長くするアプローチをとる中、FlashAttension は IO に注目して高速化したというのが面白いです。

ゴールは、HBM へのアクセス(読み書き)をなるべく減らすことです。

GPU とその周辺のメンタルモデル

まずは GPU とその周辺のデバイスに関して最低限の知識が自分になかったので、調べたことをメモします。

GPU の内部

まずは GPU の内部の構造を見てみましょう。
a100

一番下が SM の内部となっています。SM とは、Streaming Multiprocessor の略で、実行ユニットです。この SM が 128 個あるのがわかります。 ただし、実際に使われるのは 108 個の SM で故障などのために冗長に作られています。

SM は L1 cache を持つのですが、それが Shared memory になっており、SM 内の全ての core から高速にアクセスできるようになっています。逆にいうと、L2 cache などは存在するものの L1 と比較するとアクセス速度などはかなり落ちるようです。

SIMT (Single Instruction Multiple Treads) が流行りであり、複数(32 とか)の Threads をまとめたものを(Nividia の GPU では)Warp といい、これが SM 上で実行されます。

GPU の構成

Nvidia の A100 GPU を参考に以下の図は書いています。

pc_single_gpu

CPU から GPU に転送するのが律速になりそうなことがわかります。

以下のツイートの話も直感的には理解できるでしょう。
https://x.com/dc1394/status/1746021281097765083?s=20

1つの node に複数 GPU が存在する場合は、以下のようになります。
pc_multi_gpu

Nvidia 独自の NVLink という技術を用いて GPU 間は非常に高速な伝送速度が実現されていることがわかります。

GPU の速度

  • Compute
    SM 単位で見ると、FP32 64 CUDA core があり、クロック周波数が 1410 MHz ほどのようなので、64 * 1410 = 90 GFlops ほどです。
    GPU 全体で見ると、SM が 108 個あるので、90 GFlops * 108 = 9.7 TFlops です。

a100_perf

あれ、spec を見ると、19.5 TFLOPS と書いてありますね。倍のスピードが出るようです。あ、ピーク性能では、A * B + C が1クロックで行われるので、2 演算 / クロックからか。

  • Memory
    gpu_memory_hierarchy
    GPU Memory Hierarchy[1]

GPU の現状

最近 GPU の性能は計算速度の方がメモリ速度よりも圧倒的に高速であり、ボトルネックとなるのはメモリアクセスです。
特に、画像処理、element-wise (activation, dropout) や reduction (sum, softmax, batchnorm, layernorm) 行列演算などで特にこの傾向があるとのこと。

近年のチップは計算能力の向上に対し、メモリ帯域向上/レイテンシ短縮が追いつかない“Compute Gap”が問題となっている。計算能力の向上が著しいMN-CoreではCompute Gap問題が顕著となっている。

https://xtech.nikkei.com/atcl/nxt/mag/rob/18/00007/00029/

fusion can also reduce GPU device memory reads/writes (by composing pointwise kernels) and help improve hardware utilization.

https://pytorch.org/blog/training-production-ai-models/?utm_content=275807566&utm_medium=social&utm_source=linkedin&hss_channel=lcp-78618366

逆に、計算速度がボトルネックとなる処理は何があるのでしょう。
行列の掛け算(特に inner dimention が大きい場合)や convolution(特に channel が大きい場合)のようです。

Attention とは

N を系列長とし、\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d} を入力とするときに、以下のようにして \mathbf{O} を計算するモジュールが attention です。多くの場合、N >> d で、N は 1k ~ 100k、d は 10 ~ 1000 といったところでしょうか。

\mathbf{S}=\mathbf{Q} \mathbf{K}^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d},

どこが メモリ bound なのでしょう。計算処理を有効利用できる行列積が2回もありますが。一つ目は、softmax です。softmax は element-wise に近い処理なので IO bound になりやすいです。加えて、演算の前後で計算結果をメモリに書き出し再度読み出されるということが起こることが多いことも原因の一つです。また、masking のような処理(element-wise)があればこちらもボトルネックになります。

alg0

FlashAttention の貢献

  • 部分的に attention を計算する(tiling とも呼ぶ)ことで、attention の softmax operation の際に行列全体にアクセスする必要を無くし、メモリ(HBM)にアクセスする回数を削減した。
  • gradient checkpointing を行った。

FlashAttention の説明

端的に説明すると、以下のように \mathbf{Q}, \mathbf{K}, \mathbf{V} を tiling して個別に計算してあげれば良いだけです。

flash_attention

しかし、row-wise の softmax の計算が少し厄介です。が、平均や分散をオンラインで計算したい場合などと同様に softmax についてもオンラインで計算できます。詳細を見ていきましょう。

簡単のため、まずは \mathbf{Q} が 1 row block、\mathbf{K} が 2 column block しかない場合を考えます。
\mathbf{S} = \left[\begin{array}{ll}\mathbf{S}^{(1)} & \mathbf{S}^{(2)}\end{array}\right], \mathbf{V} = \left[\begin{array}{l} \mathbf{V}^{(1)} \\ \mathbf{V}^{(2)} \end{array}\right] と表せます。
この表記を用いて、通常の attention

\mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d}

は、

\begin{aligned} m & =\max \left(\operatorname{rowmax}\left(\mathbf{S}^{(1)}\right), \operatorname{rowmax}\left(\mathbf{S}^{(2)}\right)\right) \in \mathbb{R}^{B_r} \\ \ell & =\operatorname{rowsum}\left(e^{\mathbf{S}^{(1)}-m}\right)+\operatorname{rowsum}\left(e^{\mathbf{S}^{(2)}-m}\right) \in \mathbb{R}^{B_r} \\ \mathbf{P} & =\left[\begin{array}{ll} \mathbf{P}^{(1)} & \mathbf{P}^{(2)} \end{array}\right]=\operatorname{diag}(\ell)^{-1}\left[\begin{array}{ll} e^{\mathbf{S}^{(1)}-m} & e^{\mathbf{S}^{(2)}-m} \end{array}\right] \in \mathbb{R}^{B_r \times 2 \boldsymbol{B}_c} \\ \mathbf{O} & =\left[\begin{array}{ll} \mathbf{P}^{(1)} & \mathbf{P}^{(2)} \end{array}\right]\left[\begin{array}{l} \mathbf{V}^{(1)} \\ \mathbf{V}^{(2)} \end{array}\right]=\operatorname{diag}(\ell)^{-1} e^{\mathbf{S}^{(1)}-m} \mathbf{V}^{(1)}+e^{\mathbf{S}^{(2)}-m} \mathbf{V}^{(2)} \in \mathbb{R}^{B_r \times d} \end{aligned}

e^{\mathbf{S}^{(1)}-m} などの -mは、overflow, underflow を防ぐためです。当たり前ですが、まだ、これは tile ごとに(オンライン)計算できていません。オンライン計算では、まず index-1 に対応する計算をします。

\begin{aligned} m^{(1)} & =\operatorname{rowmax}\left(\mathbf{S}^{(1)}\right) \in \mathbb{R}^{B_r} \\ \ell^{(1)} & =\operatorname{rowsum}\left(e^{\mathbf{S}^{(1)}-m^{(1)}}\right) \in \mathbb{R}^{B_r} \\ \tilde{\mathbf{P}}^{(1)} & =\operatorname{diag}\left(\ell^{(1)}\right)^{-1} e^{\mathbf{S}^{(1)}-m^{(1)}} \in \mathbb{R}^{B_r \times B_c} \\ \mathbf{O}^{(1)} & =\tilde{\mathbf{P}}^{(1)} \mathbf{V}^{(1)}=\operatorname{diag}\left(\ell^{(1)}\right)^{-1} e^{\mathbf{S}^{(1)}-m^{(1)}} \mathbf{V}^{(1)} \in \mathbb{R}^{B_r \times d} \\ \end{aligned}

ですね。次に、index-2 に対応するデータが得られたとして、どう最終的に \mathbf{O} を計算するかです。

\begin{aligned} m^{(2)} & =\max \left(m^{(1)}, \operatorname{rowmax}\left(\mathbf{S}^{(2)}\right)\right)=m \\ \ell^{(2)} & =e^{m^{(1)}-m^{(2)}} \ell^{(1)}+\operatorname{rowsum}\left(e^{\mathbf{S}^{(2)}-m^{(2)}}\right)=\operatorname{rowsum}\left(e^{\mathbf{S}^{(1)}-m}\right)+\operatorname{rowsum}\left(e^{\mathbf{S}^{(2)}-m}\right)=\ell \\ \tilde{\mathbf{P}}^{(2)} & =\operatorname{diag}\left(\ell^{(2)}\right)^{-1} e^{\mathbf{S}^{(2)}-m^{(2)}} \\ \mathbf{O}^{(2)} & =\operatorname{diag}\left(\ell^{(1)} / \ell^{(2)}\right)^{-1} \mathbf{O}^{(1)}+\tilde{\mathbf{P}}^{(2)} \mathbf{V}^{(2)}=\operatorname{diag}\left(\ell^{(2)}\right)^{-1} e^{\boldsymbol{s}^{(1)}-m} \mathbf{V}^{(1)}+\operatorname{diag}\left(\ell^{(2)}\right)^{-1} e^{s^{(2)}-m} \mathbf{V}^{(2)}=\mathbf{O} . \end{aligned}

最後の式で、ちゃんと \mathbf{O}^{(2)} = \mathbf{O} となっており、逐次的に計算できていることがわかります。つまり、m^{(n-1)}, l^{(n-1)}, \mathbf{O}^{(n-1)} を持っておけば、n 番目のデータを考慮した新しい attention が逐次的に計算できるということです。

この時のメモリの使い方も見ておきます。

flash_attention_math
FlashAttention diagram[2]

まず、\mathbf{Q} を HBM から SRAM にロードします。(block 数が多い場合は現在計算したい block のみロードします。)次に、\mathbf{S}^{(1)} を SRAM にロードし、index-1 に関する計算を行います。ここで、\mathbf{O}^{(1)} を HBM に書き戻さずに、次の index-2 の計算に移ります。そのようにすることで、HBM への値の書き戻しが一度もないため、メモリを可能な限り避けることができています。

定性的にみたので、定量的にもみておきましょう。

thorem_2

通常の attention が Nd + N^2 なのに対して、FlashAttention が N^2 d^2 / M のようです。一見、HBM のアクセス回数が増えているように見えますが、多くの場合、d は 64 ~ 128 であり、M は 100 KB ほどなので、 d^2 / M は 1 より小さくなります。

実際、以下の表のように、HBM のアクセスの削減がランタイムの削減に大きく寄与していることが実験により確認されています。また block size をギリギリまで大きくすることで、HBM アクセスが減りランタイムが減ることも確認できています。ただし、512 のところでは、メモリアクセスではなく別の部分で律速となっている可能性が高いです。
result1

実装

FlashAttention を実装するにはどうするのがいいのでしょう。

CUDA で書く

僕自身全く CUDA を直接触ったことがないのですが、きっとかなり面倒(難易度が高い)と思います。CUDA を書ける猛者になりたい。
もちろん、論文の著者は CUDA で実装してます。
https://github.com/Dao-AILab/flash-attention/tree/main

もっと簡単に

FlashAttention は PyTorch 2.0 の Transformer API を用いて

model = torch.compile(model)

とすることで使うことができるようです。ただ、複数の手法がサポートされておりそこからデフォルトでは自動的に最も最良そうなものが選ばれるため、必ずしも FlashAttention が使われるわけではないことに注意です。強制的に FlashAttention 使うようにすることもできそうです。

In addition to automatic kernel selection, a context manager enables developers to override the kernel selection algorithm

https://pytorch.org/blog/accelerated-pytorch-2/

新しい記事に詳しく使い方書いてあったので、こちらもぜひ参考にしてみてください。
https://zenn.dev/kaeru39/articles/1ea73bfa40c7df

torch.compile を用いた性能評価の記事

最近の動向

PyTorch 2.0 で JIT compile が一番の目玉だったこともありますが、深層学習のコンパイラがアツいと最近感じてます。(いや、もう乗り遅れている)

深層学習のコンパイラに関しては、以下の記事が非常によくまとまっていておすすめです。

https://zenn.dev/acd1034/articles/230325-dl-compiler-overview

まとめ

  • hardware で殴れる世界線じゃなくなっている(あと10年ぐらいしたらまた逆転する可能性も全然ある)
  • 現状、GPU のボトルネックは多くの場合 IO にある。
  • PyTorch 2.0 では簡単に FlashAttention を使える。

最後に

もっと ChatGPT-4.0 の回答速度が上がると嬉しいなぁ。
この記事を書くために色々調べていて気づいたのですが、世はまさに大 compiler 時代なんですね。完全に出遅れました。

参考文献

https://amzn.asia/d/jjSP2us
https://zenn.dev/selllous/articles/transformers_pretrain_to_ft
https://speakerdeck.com/hpprc/zi-yuan-tositejian-rushi-yan-puroguramu?slide=31
https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21730-inside-the-nvidia-ampere-architecture.pdf
https://tech.preferred.jp/ja/blog/mncore-compiler-1/
https://tech.preferred.jp/ja/blog/mn-core-tensor-layout/

脚注
  1. https://arxiv.org/pdf/2205.14135.pdf より ↩︎

  2. https://arxiv.org/pdf/2307.08691.pdf より ↩︎

GitHubで編集を提案

Discussion