
Understanding GPU and FlashAttention


Introduction

2023 was the year when many LLMs, starting with ChatGPT, were released to the world. I am truly impressed that I can try so many models released as OSS just by using Colab.
Still, I want to try training an LLM myself at least once, but training inefficiently would just burn money... That's when I found this article.

It contains various tips.
https://huggingface.co/docs/transformers/v4.35.2/en/perf_train_gpu_one
npaka-san has summarized this in Japanese here.
https://note.com/npaka/n/n04c493394e07

In this article, we will look at FlashAttention, one of the techniques listed there, and follow the improvements it makes. (I won't describe in detail how much faster it actually is, so please refer to other articles for benchmark numbers.)

Actually, FlashAttention-2 (2023) has already been released as a further improvement of FlashAttention (2022), but I hope to explain FlashAttention-2 in a separate article.

What is FlashAttention?

FlashAttention is a method for accelerating attention. While many existing studies approximate attention to reduce computational complexity, either for speed or to handle longer sequences, FlashAttention is interesting in that it computes exact attention and achieves its speed-up by focusing on I/O.

The goal is to minimize access (reads/writes) to HBM (High Bandwidth Memory, the GPU's main DRAM) as much as possible.

Mental Model of GPU and its Surroundings

First, since I didn't have the minimum knowledge regarding GPUs and their surrounding devices, I'll note down what I've researched.

Inside a GPU

First, let's look at the internal structure of a GPU.
(Figure: Nvidia A100 internal architecture)

The bottom part of the figure shows the inside of an SM. SM stands for Streaming Multiprocessor and is the GPU's execution unit. You can see that the chip has 128 SMs; however, only 108 of them are actually enabled, since the extras provide redundancy against manufacturing defects.

An SM has an L1 cache that also serves as shared memory, allowing high-speed access from all cores within the SM. An L2 cache exists as well, but its access speed is significantly slower than L1's.

GPUs today use the SIMT (Single Instruction, Multiple Threads) execution model. A group of threads (32 on Nvidia GPUs) is called a warp, and warps are executed on an SM.

GPU Configuration

The following diagram is drawn with reference to the Nvidia A100 GPU.

(Figure: single-GPU system configuration)

It is clear that transferring data from the CPU to the GPU is likely to be the bottleneck.

The point made in the following tweet is also intuitively understandable.
https://x.com/dc1394/status/1746021281097765083?s=20

When there are multiple GPUs in a single node, it looks like this:
(Figure: multi-GPU node configuration)

You can see that very high transmission speeds are achieved between GPUs using Nvidia's proprietary NVLink technology.

GPU Speed

  • Compute
    Per SM, there are 64 FP32 CUDA cores running at a clock of around 1410 MHz, which gives roughly 64 * 1410 MHz ≈ 90 GFLOPS.
    For the entire GPU, with 108 SMs, that is 90 GFLOPS * 108 ≈ 9.7 TFLOPS.

(Figure: A100 performance specifications)

Wait, the spec sheet says 19.5 TFLOPS, twice that figure. The reason is that at peak, each core performs a fused multiply-add (A * B + C) per clock, which counts as 2 floating-point operations.
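The arithmetic above can be checked in a few lines. This is only a back-of-the-envelope sketch; the core count, clock, and SM count are the A100 figures quoted in the text.

```python
# Back-of-the-envelope check of the A100 FP32 peak discussed above.
# 64 FP32 cores per SM, ~1410 MHz boost clock, 108 active SMs;
# the factor 2 counts a fused multiply-add (A * B + C) as two FLOPs.
cores_per_sm = 64
clock_hz = 1410e6
sms = 108
flops_per_clock = 2  # FMA = multiply + add

per_sm_gflops = cores_per_sm * clock_hz / 1e9                       # one op/clock
total_tflops = cores_per_sm * clock_hz * sms * flops_per_clock / 1e12

print(round(per_sm_gflops, 1))  # ≈ 90.2 GFLOPS per SM
print(round(total_tflops, 1))   # ≈ 19.5 TFLOPS, matching the spec sheet
```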

  • Memory
    (Figure: GPU memory hierarchy[1])

Current Status of GPUs

Recently, GPU compute speed has grown much faster than memory speed, making memory access the bottleneck.
This trend is particularly evident in the element-wise operations (activation, dropout) and reduction operations (sum, softmax, batchnorm, layernorm) that surround the matrix calculations.

Recent chips face a "Compute Gap" problem, where improvements in memory bandwidth and latency cannot keep up with the growth in compute performance. The Compute Gap is particularly prominent in MN-Core, whose compute performance has improved dramatically.

https://xtech.nikkei.com/atcl/nxt/mag/rob/18/00007/00029/

Kernel fusion can also reduce GPU device memory reads/writes (by composing pointwise kernels) and help improve hardware utilization.

https://pytorch.org/blog/training-production-ai-models/

Conversely, which operations are bottlenecked by compute speed? Mainly matrix multiplication (especially with a large inner dimension) and convolution (especially with a large channel count).

What is Attention?

When N is the sequence length and \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d} are inputs, the module that calculates \mathbf{O} as follows is called attention. In many cases, N \gg d, with N ranging from 1k to 100k and d from 10 to 1000.

\mathbf{S}=\mathbf{Q} \mathbf{K}^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d},

Where is the memory-bound part? The two matrix multiplications can keep the compute units busy; the problem is softmax. Since softmax is close to an element-wise operation, it tends to be IO-bound. Another cause is that the intermediate results \mathbf{S} and \mathbf{P} are written to memory and then read back between steps. Processes like masking (also element-wise) can become bottlenecks too.
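The standard (non-tiled) computation above can be sketched in a few lines of NumPy. The sizes are toy values chosen for illustration; note how the full N × N matrices S and P are materialized, which is exactly the memory traffic FlashAttention avoids.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference attention: S = Q K^T, P = softmax(S), O = P V.
    The full (N, N) matrices S and P are materialized here."""
    S = Q @ K.T                            # (N, N) score matrix
    m = S.max(axis=1, keepdims=True)       # row-wise max, for numerical stability
    P = np.exp(S - m)
    P /= P.sum(axis=1, keepdims=True)      # row-wise softmax
    return P @ V                           # (N, d) output

rng = np.random.default_rng(0)
N, d = 16, 4                               # toy sizes; in practice N is much larger than d
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)  # (16, 4)
```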

(Figure: standard attention algorithm)

FlashAttention's Contributions

  • By calculating attention block by block (tiling), the softmax no longer needs access to the entire matrix at once, which reduces the number of memory (HBM) accesses.
  • It uses recomputation (in the spirit of gradient checkpointing), so the backward pass does not need the N \times N attention matrix to be stored.

Explanation of FlashAttention

To put it simply, all you need to do is tile \mathbf{Q}, \mathbf{K}, \text{and } \mathbf{V} and calculate them individually as shown below.

flash_attention

However, calculating the row-wise softmax is a bit tricky. Still, just as mean and variance can be computed online, so can softmax. Let's look at the details.
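Before the blocked matrix version, here is the scalar idea in pure Python: keep a running maximum m and a running normalizer ℓ, and whenever a new maximum appears, rescale the old normalizer by e^{m_old − m_new}. This is a minimal sketch of the online-softmax trick, not the paper's blocked algorithm.

```python
import math

def online_softmax(xs):
    """Softmax over a stream, keeping only the running max m and
    running normalizer l = sum of exp(x - m) over the inputs seen so far."""
    m = float("-inf")  # running max
    l = 0.0            # running normalizer
    for x in xs:
        m_new = max(m, x)
        # rescale the old normalizer when the max changes, then add the new term
        l = l * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / l for x in xs]

probs = online_softmax([1.0, 3.0, 2.0])
print(probs)  # same values as exp(x) / sum(exp(x)), term by term
```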

For simplicity, let's first consider the case where \mathbf{Q} has only 1 row block and \mathbf{K} has only 2 column blocks.
It can be expressed as \mathbf{S} = \left[\begin{array}{ll}\mathbf{S}^{(1)} & \mathbf{S}^{(2)}\end{array}\right] and \mathbf{V} = \left[\begin{array}{l} \mathbf{V}^{(1)} \\ \mathbf{V}^{(2)} \end{array}\right].
Using this notation, standard attention

\mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d}

is,

\begin{aligned} m & =\max \left(\operatorname{rowmax}\left(\mathbf{S}^{(1)}\right), \operatorname{rowmax}\left(\mathbf{S}^{(2)}\right)\right) \in \mathbb{R}^{B_r} \\ \ell & =\operatorname{rowsum}\left(e^{\mathbf{S}^{(1)}-m}\right)+\operatorname{rowsum}\left(e^{\mathbf{S}^{(2)}-m}\right) \in \mathbb{R}^{B_r} \\ \mathbf{P} & =\left[\begin{array}{ll} \mathbf{P}^{(1)} & \mathbf{P}^{(2)} \end{array}\right]=\operatorname{diag}(\ell)^{-1}\left[\begin{array}{ll} e^{\mathbf{S}^{(1)}-m} & e^{\mathbf{S}^{(2)}-m} \end{array}\right] \in \mathbb{R}^{B_r \times 2 B_c} \\ \mathbf{O} & =\left[\begin{array}{ll} \mathbf{P}^{(1)} & \mathbf{P}^{(2)} \end{array}\right]\left[\begin{array}{l} \mathbf{V}^{(1)} \\ \mathbf{V}^{(2)} \end{array}\right]=\operatorname{diag}(\ell)^{-1} e^{\mathbf{S}^{(1)}-m} \mathbf{V}^{(1)}+\operatorname{diag}(\ell)^{-1} e^{\mathbf{S}^{(2)}-m} \mathbf{V}^{(2)} \in \mathbb{R}^{B_r \times d} \end{aligned}

The -m in e^{\mathbf{S}^{(1)}-m}, etc., is there to prevent overflow and underflow. Clearly, this is not yet computed per tile (online), since m and \ell depend on both blocks at once. For the online computation, we first perform the calculation corresponding to index 1.

\begin{aligned} m^{(1)} & =\operatorname{rowmax}\left(\mathbf{S}^{(1)}\right) \in \mathbb{R}^{B_r} \\ \ell^{(1)} & =\operatorname{rowsum}\left(e^{\mathbf{S}^{(1)}-m^{(1)}}\right) \in \mathbb{R}^{B_r} \\ \tilde{\mathbf{P}}^{(1)} & =\operatorname{diag}\left(\ell^{(1)}\right)^{-1} e^{\mathbf{S}^{(1)}-m^{(1)}} \in \mathbb{R}^{B_r \times B_c} \\ \mathbf{O}^{(1)} & =\tilde{\mathbf{P}}^{(1)} \mathbf{V}^{(1)}=\operatorname{diag}\left(\ell^{(1)}\right)^{-1} e^{\mathbf{S}^{(1)}-m^{(1)}} \mathbf{V}^{(1)} \in \mathbb{R}^{B_r \times d} \\ \end{aligned}

Next, assuming data for index-2 is obtained, here is how we finally calculate \mathbf{O}.

\begin{aligned} m^{(2)} & =\max \left(m^{(1)}, \operatorname{rowmax}\left(\mathbf{S}^{(2)}\right)\right)=m \\ \ell^{(2)} & =e^{m^{(1)}-m^{(2)}} \ell^{(1)}+\operatorname{rowsum}\left(e^{\mathbf{S}^{(2)}-m^{(2)}}\right)=\operatorname{rowsum}\left(e^{\mathbf{S}^{(1)}-m}\right)+\operatorname{rowsum}\left(e^{\mathbf{S}^{(2)}-m}\right)=\ell \\ \tilde{\mathbf{P}}^{(2)} & =\operatorname{diag}\left(\ell^{(2)}\right)^{-1} e^{\mathbf{S}^{(2)}-m^{(2)}} \\ \mathbf{O}^{(2)} & =\operatorname{diag}\left(\ell^{(1)} e^{m^{(1)}-m^{(2)}} / \ell^{(2)}\right) \mathbf{O}^{(1)}+\tilde{\mathbf{P}}^{(2)} \mathbf{V}^{(2)}=\operatorname{diag}\left(\ell^{(2)}\right)^{-1} e^{\mathbf{S}^{(1)}-m} \mathbf{V}^{(1)}+\operatorname{diag}\left(\ell^{(2)}\right)^{-1} e^{\mathbf{S}^{(2)}-m} \mathbf{V}^{(2)}=\mathbf{O} . \end{aligned}

In the final equation, \mathbf{O}^{(2)} = \mathbf{O} holds true, showing that it can be calculated sequentially. This means that if you keep m^{(n-1)}, \ell^{(n-1)}, \text{and } \mathbf{O}^{(n-1)}, you can sequentially calculate a new attention that takes the n-th data into account.
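The two-block derivation can be verified numerically. This NumPy sketch (toy shapes, one row block and two column blocks of width Bc, all names illustrative) computes index 1, then folds in index 2, rescaling the running output by ℓ^(1) e^{m^(1)−m^(2)} / ℓ^(2), and checks that O^(2) matches full attention.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, Bc = 8, 4, 4                      # one row block, two column blocks
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))

# Reference: full softmax(Q K^T) V.
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
O_ref = (P / P.sum(axis=1, keepdims=True)) @ V

# Index 1: statistics and output for the first column block only.
S1 = Q @ K[:Bc].T
m1 = S1.max(axis=1)
l1 = np.exp(S1 - m1[:, None]).sum(axis=1)
O1 = (np.exp(S1 - m1[:, None]) / l1[:, None]) @ V[:Bc]

# Index 2: update max and normalizer, rescale the running output,
# then add the contribution of the second block.
S2 = Q @ K[Bc:].T
m2 = np.maximum(m1, S2.max(axis=1))
l2 = np.exp(m1 - m2) * l1 + np.exp(S2 - m2[:, None]).sum(axis=1)
O2 = ((l1 * np.exp(m1 - m2)) / l2)[:, None] * O1 \
     + (np.exp(S2 - m2[:, None]) / l2[:, None]) @ V[Bc:]

assert np.allclose(O2, O_ref)  # O^(2) recovers the exact attention output
```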

Let's also look at how memory is used at this time.

(Figure: FlashAttention diagram[2])

First, \mathbf{Q} is loaded from HBM into SRAM. (If there are many blocks, only the block currently being processed is loaded.) Next, \mathbf{K}^{(1)} and \mathbf{V}^{(1)} are loaded into SRAM, \mathbf{S}^{(1)} is computed there, and the index-1 calculations are performed. Then, instead of writing \mathbf{O}^{(1)} back to HBM, we move straight on to the index-2 calculation. This way no intermediate value is ever written back to HBM; only the final output has to be, so HBM access is kept to a minimum.

Now that we've looked at it qualitatively, let's also look at it quantitatively.

(Figure: Theorem 2 of the paper, counting HBM accesses)

While standard attention requires \Theta(Nd + N^2) HBM accesses, FlashAttention requires \Theta(N^2 d^2 / M), where M is the SRAM size. At first glance, it might look like the number of HBM accesses has increased, but in many cases d is 64 to 128 and M is about 100 KB, so d^2 / M is smaller than 1.
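The claim that d^2 / M < 1 is quick to check. One assumption here: M is taken as the number of float32 elements that fit in roughly 100 KB of SRAM (the paper counts accesses in elements, not bytes).

```python
# Rough check that d^2 / M < 1 for typical head dimensions, so the
# N^2 d^2 / M access count beats the N^2 term of standard attention.
M = 100 * 1024 // 4          # ~100 KB of SRAM as float32 elements (assumption)
for d in (64, 128):
    print(d, d * d / M)      # 0.16 and 0.64, both below 1
```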

In fact, as shown in the table below, experiments confirm that reducing HBM access contributes significantly to reducing runtime. They also confirm that increasing the block size decreases HBM accesses and runtime. Beyond a block size of 512, however, something other than memory access is most likely the bottleneck, as runtime stops improving.
(Figure: measured HBM accesses and runtime versus block size)

Implementation

What is the best way to implement FlashAttention?

Writing in CUDA

I have never personally touched CUDA directly, but I imagine it must be quite tedious (high difficulty). I want to become an expert who can write CUDA.
Of course, the authors of the paper have implemented it in CUDA.
https://github.com/Dao-AILab/flash-attention/tree/main

Even More Simply

It seems that FlashAttention can be used through the PyTorch 2.0 Transformer API, as follows:

model = torch.compile(model)

However, note that multiple kernels are supported and the fastest one is selected automatically by default, so FlashAttention is not necessarily the one used. It is also possible to force the use of FlashAttention.

In addition to automatic kernel selection, a context manager enables developers to override the kernel selection algorithm

https://pytorch.org/blog/accelerated-pytorch-2/
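A minimal sketch of the kernel-selection API mentioned in the quote, using `torch.nn.functional.scaled_dot_product_attention` and the PyTorch 2.0-era `torch.backends.cuda.sdp_kernel` context manager (later versions move this to `torch.nn.attention.sdpa_kernel`). The tensor shapes are illustrative; FlashAttention itself only runs on CUDA tensors in fp16/bf16.

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); sizes are illustrative.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Default behaviour: PyTorch picks the fastest available kernel automatically.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])

# Forcing the FlashAttention kernel (CUDA tensors in fp16/bf16 only):
if torch.cuda.is_available():
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        out_flash = F.scaled_dot_product_attention(
            q.cuda().half(), k.cuda().half(), v.cuda().half()
        )
```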

Detailed usage is written in a newer article, so please refer to that as well.
https://zenn.dev/kaeru39/articles/1ea73bfa40c7df

Performance evaluation articles using torch.compile

PyTorch 2.0 highlighted JIT compilation as its main feature, but I feel that deep learning compilers have become a hot topic lately (though I might already be late to the party).

Regarding deep learning compilers, the following articles are very well-organized and highly recommended:

https://zenn.dev/acd1034/articles/230325-dl-compiler-overview
https://huyenchip.com/2021/09/07/a-friendly-introduction-to-machine-learning-compilers-and-optimizers.html

Summary

  • We are no longer in a world where you can simply brute-force everything with hardware (though this could completely reverse again in about 10 years).
  • Currently, GPU bottlenecks are often found in I/O.
  • FlashAttention can be easily used in PyTorch 2.0.

Final Thoughts

I hope the response speed of ChatGPT-4.0 improves even more.
While researching for this article, I realized that we are truly in the "Great Compiler Era." I have completely fallen behind.

References

https://amzn.asia/d/jjSP2us
https://zenn.dev/selllous/articles/transformers_pretrain_to_ft
https://speakerdeck.com/hpprc/zi-yuan-tositejian-rushi-yan-puroguramu?slide=31
https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21730-inside-the-nvidia-ampere-architecture.pdf
https://tech.preferred.jp/ja/blog/mncore-compiler-1/
https://tech.preferred.jp/ja/blog/mn-core-tensor-layout/

Footnotes
  1. From https://arxiv.org/pdf/2205.14135.pdf ↩︎

  2. From https://arxiv.org/pdf/2307.08691.pdf ↩︎
