Translated by AI
Understanding GPU and FlashAttention
Introduction
2023 was the year when many LLMs, starting with ChatGPT, were released to the world. I am truly impressed that I can try many models released as OSS just by using Colab.
However, I still want to try training an LLM myself at least once—but if I don't train it efficiently, it will just burn money... That's when I found this article.
It contains various tips. npaka-san has summarized this in Japanese here.
In this article, we will look at FlashAttention, which is one of the techniques listed above. Specifically, we will follow what improvements have been made. (I won't describe in detail how much faster it actually becomes, so please refer to other articles for that.)
Actually, FlashAttention-2 (2023) has already been released as a further improvement of FlashAttention (2022), but I hope to explain FlashAttention-2 in a separate article.
What is FlashAttention?
FlashAttention is a method for accelerating Attention. While many existing studies take an approach of using approximation methods to reduce computational complexity for speed-up or to handle longer token counts, FlashAttention is interesting in that it achieves speed-up by focusing on I/O.
The goal is to minimize access (reads/writes) to HBM as much as possible.
Mental Model of GPU and its Surroundings
First, since I didn't have the minimum knowledge regarding GPUs and their surrounding devices, I'll note down what I've researched.
Inside a GPU
First, let's look at the internal structure of a GPU.

The bottom part shows the inside of an SM. SM stands for Streaming Multiprocessor and is the GPU's execution unit. You can see that the chip physically contains 128 SMs; however, only 108 of them are actually enabled, since extras are built in redundantly to tolerate defects and failures.
Each SM has an L1 cache that also serves as shared memory, allowing high-speed access from all cores within that SM. An L2 cache exists as well, but it is shared across SMs and its access is significantly slower than L1.
SIMT (Single Instruction, Multiple Threads) is the execution model used on current GPUs. A group of multiple threads (32 on Nvidia GPUs) is called a warp, and warps are executed on an SM.
GPU Configuration
The following diagram is drawn with reference to the Nvidia A100 GPU.

It is clear that transferring data from the CPU to the GPU is likely to be the bottleneck.
The point made in the following tweet is also intuitively understandable.
When there are multiple GPUs in a single node, it looks like this:

You can see that very high transmission speeds are achieved between GPUs using Nvidia's proprietary NVLink technology.
GPU Speed
- Compute
Looking per SM, there are 64 FP32 CUDA cores and the clock frequency is around 1410 MHz, which gives approximately 64 × 1410 MHz ≈ 90 GFLOPS.
Looking at the entire GPU, with 108 SMs, that is 90 GFLOPS × 108 ≈ 9.7 TFLOPS.

Wait, the spec sheet says 19.5 TFLOPS, twice that estimate. The reason is that at peak, each core performs a fused multiply-add (A × B + C) in a single clock, which counts as 2 floating-point operations per clock.
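The arithmetic can be checked with a back-of-the-envelope script using the figures quoted above:

```python
# Back-of-the-envelope peak-FLOPS estimate for an Nvidia A100,
# using the figures from the text; an FMA (A*B + C) counts as 2 FLOPs.

CORES_PER_SM = 64         # FP32 CUDA cores per SM
CLOCK_HZ = 1410e6         # ~1410 MHz boost clock
NUM_SMS = 108             # enabled SMs
FLOPS_PER_CORE_CLOCK = 2  # one fused multiply-add per clock = 2 FLOPs

per_sm_gflops = CORES_PER_SM * CLOCK_HZ * FLOPS_PER_CORE_CLOCK / 1e9
total_tflops = per_sm_gflops * NUM_SMS / 1e3

print(f"per-SM: {per_sm_gflops:.1f} GFLOPS")  # ~180.5 GFLOPS with FMA
print(f"total : {total_tflops:.1f} TFLOPS")   # ~19.5 TFLOPS, matching the spec
```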
- Memory

GPU Memory Hierarchy[1]
Current Status of GPUs
Recently, GPU computing speed has become significantly faster than memory speed, making memory access the bottleneck.
This trend is particularly evident in the element-wise operations (activation, dropout) and reduction operations (sum, softmax, batch norm, layer norm) that surround matrix computations, as well as in image processing.
Recent chips face a "Compute Gap" problem, where improvements in memory bandwidth and latency cannot keep up with the increase in computing power. This Compute Gap problem is particularly prominent in MN-Core, where computing power has improved significantly.
Kernel fusion can also reduce GPU device memory reads/writes (by composing pointwise kernels into a single kernel) and help improve hardware utilization.
Conversely, what kind of processes have computational speed as their bottleneck?
It seems to be matrix multiplication (especially when the inner dimension is large) or convolutions (especially when the channel size is large).
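Whether an operation is memory-bound or compute-bound can be estimated from its arithmetic intensity (FLOPs per byte of HBM traffic) relative to the hardware's ratio of compute to bandwidth. A rough sketch, assuming FP32 and illustrative A100-class figures (~19.5 TFLOPS, ~1.5 TB/s HBM bandwidth):

```python
# Rough arithmetic-intensity (FLOPs per byte of HBM traffic) comparison.
# Assumes FP32 (4 bytes/element) and illustrative A100-class figures:
# ~19.5 TFLOPS compute, ~1.5 TB/s HBM bandwidth -> ridge point ~13 FLOPs/byte.

BYTES = 4

def intensity_elementwise(n):
    # y = f(x): read n elements, write n elements, ~1 FLOP per element
    flops = n
    traffic = 2 * n * BYTES
    return flops / traffic

def intensity_matmul(n):
    # C = A @ B for n x n matrices: 2*n^3 FLOPs; read A and B, write C
    flops = 2 * n ** 3
    traffic = 3 * n * n * BYTES
    return flops / traffic

ridge = 19.5e12 / 1.5e12  # ~13 FLOPs/byte

print(intensity_elementwise(10**6))  # 0.125 -> far below the ridge: memory-bound
print(intensity_matmul(4096))        # ~683  -> far above the ridge: compute-bound
```

Element-wise kernels sit far below the ridge point no matter how large the tensor, while matmul's intensity grows with the inner dimension, which is exactly the pattern described above.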
What is Attention?
Given inputs \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d} (sequence length N, head dimension d), attention first computes the score matrix \mathbf{S}=\mathbf{Q K}^{\top} \in \mathbb{R}^{N \times N}, applies a row-wise softmax to obtain \mathbf{P}, and multiplies by \mathbf{V} to get the output \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d}.
Where is the memory-bound part? The two matrix multiplications (\mathbf{Q}\mathbf{K}^{\top} and \mathbf{P}\mathbf{V}) can make effective use of the compute units. The problem is softmax: since softmax is close to an element-wise operation, it tends to be I/O-bound. Another cause is that intermediate results are written to memory and read back again before and after each operation. Furthermore, element-wise processes like masking can also become bottlenecks.
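To make those intermediates concrete, here is a minimal pure-Python sketch of standard attention; \mathbf{S} and \mathbf{P} are full N × N matrices that, on a GPU, would be written to and read back from HBM between kernels:

```python
# Minimal standard attention in pure Python, making the intermediates
# explicit: S and P are full N x N matrices materialized between steps.
import math
import random

def softmax(row):
    m = max(row)                       # subtract the max for numerical stability
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]

def attention(Q, K, V):
    N, d = len(Q), len(Q[0])
    # S = Q K^T  (N x N scores)
    S = [[sum(Q[i][k] * K[j][k] for k in range(d)) for j in range(N)]
         for i in range(N)]
    # P = row-wise softmax(S)  (N x N probabilities)
    P = [softmax(row) for row in S]
    # O = P V  (N x d output)
    return [[sum(P[i][j] * V[j][k] for j in range(N)) for k in range(d)]
            for i in range(N)]

random.seed(0)
N, d = 4, 2
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
O = attention(Q, K, V)
```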

FlashAttention's Contributions
- By computing attention in blocks (tiling), the softmax no longer needs to access the entire matrix at once, which reduces the number of memory (HBM) accesses.
- Instead of storing the N × N attention matrix for the backward pass, it is recomputed from stored softmax statistics, an idea similar to gradient checkpointing.
Explanation of FlashAttention
To put it simply, all you need to do is tile the computation: split \mathbf{Q}, \mathbf{K}, \mathbf{V} into blocks small enough to fit in SRAM, and compute attention block by block.

However, calculating the row-wise softmax is a bit tricky. But just like when calculating mean or variance online, softmax can also be calculated online. Let's look at the details.
For simplicity, let's first consider the case where a row of scores x is split into two blocks, x = [x^{(1)} \; x^{(2)}]. For a vector x \in \mathbb{R}^{B}, the softmax can be expressed as

m(x):=\max _{i} x_{i}, \quad f(x):=\left[e^{x_{1}-m(x)}, \ldots, e^{x_{B}-m(x)}\right], \quad \ell(x):=\sum_{i} f(x)_{i}, \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}

Using this notation, standard attention

\mathbf{S}=\mathbf{Q K}^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d}

applies this softmax to each row of \mathbf{S}.
The statistics for the first block, m(x^{(1)}) and \ell(x^{(1)}), can be computed without seeing the second block. Next, assuming the data for index 2 is obtained, here is how we finally calculate the softmax of the full vector:

m(x)=\max \left(m\left(x^{(1)}\right), m\left(x^{(2)}\right)\right), \quad f(x)=\left[e^{m\left(x^{(1)}\right)-m(x)} f\left(x^{(1)}\right), \; e^{m\left(x^{(2)}\right)-m(x)} f\left(x^{(2)}\right)\right]

\ell(x)=e^{m\left(x^{(1)}\right)-m(x)} \ell\left(x^{(1)}\right)+e^{m\left(x^{(2)}\right)-m(x)} \ell\left(x^{(2)}\right), \quad \operatorname{softmax}(x)=\frac{f(x)}{\ell(x)}

In the final equation, the per-block statistics are rescaled by the correction factors e^{m(x^{(i)})-m(x)}, so blocks processed earlier never have to be read again.
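The two-block merge can be sketched in pure Python; merging the per-block statistics reproduces the softmax of the concatenated vector:

```python
# Online softmax: compute per-block statistics (m, f, l) and merge them,
# matching the softmax of the concatenated vector.
import math

def block_stats(x):
    m = max(x)
    f = [math.exp(v - m) for v in x]
    return m, f, sum(f)

def merge_softmax(x1, x2):
    m1, f1, l1 = block_stats(x1)
    m2, f2, l2 = block_stats(x2)
    m = max(m1, m2)                               # m(x) = max(m1, m2)
    c1, c2 = math.exp(m1 - m), math.exp(m2 - m)   # rescaling factors
    l = c1 * l1 + c2 * l2                         # l(x) = c1*l1 + c2*l2
    f = [c1 * v for v in f1] + [c2 * v for v in f2]
    return [v / l for v in f]

def softmax(x):
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]

x1, x2 = [1.0, 3.0, 2.0], [0.5, 4.0]
assert all(abs(a - b) < 1e-12
           for a, b in zip(merge_softmax(x1, x2), softmax(x1 + x2)))
```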
Let's also look at how memory is used at this time.

FlashAttention diagram[2]
First, blocks of \mathbf{K} and \mathbf{V} are loaded from HBM into SRAM; for each block of \mathbf{Q}, the partial attention output and the running statistics m and \ell are computed entirely in SRAM, and only the output block and statistics are written back to HBM. The full N × N matrices \mathbf{S} and \mathbf{P} are never materialized in HBM.
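Extending the two-block merge to a running loop over key/value blocks gives the heart of the algorithm. A pure-Python sketch for a single query row (variable names are mine, not the paper's):

```python
# FlashAttention-style tiling for one query row, in pure Python:
# K and V are consumed one block at a time; only a running max m, a
# running normalizer l, and an unnormalized output accumulator are kept
# (these are what would live in SRAM). No N x N matrix is ever built.
import math

def naive_attention_row(q, K, V):
    d = len(q)
    s = [sum(q[k] * kj[k] for k in range(d)) for kj in K]
    m = max(s)
    p = [math.exp(v - m) for v in s]
    z = sum(p)
    return [sum(p[j] * V[j][k] for j in range(len(K))) / z for k in range(d)]

def flash_attention_row(q, K, V, block=2):
    d = len(q)
    m, l = -math.inf, 0.0        # running max and running normalizer
    acc = [0.0] * d              # unnormalized output accumulator
    for start in range(0, len(K), block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = [sum(q[k] * kj[k] for k in range(d)) for kj in Kb]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)      # rescale previously accumulated stats
        l *= scale
        acc = [a * scale for a in acc]
        for sj, vj in zip(s, Vb):
            p = math.exp(sj - m_new)
            l += p
            acc = [a + p * v for a, v in zip(acc, vj)]
        m = m_new
    return [a / l for a in acc]

q = [0.3, -1.2]
K = [[0.5, 1.0], [-0.7, 0.2], [1.5, -0.3], [0.0, 0.8], [0.9, 0.1]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 2.0]]
assert all(abs(a - b) < 1e-9
           for a, b in zip(flash_attention_row(q, K, V),
                           naive_attention_row(q, K, V)))
```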
Now that we've looked at it qualitatively, let's also look at it quantitatively.

While standard attention requires \Theta(N d + N^2) HBM accesses, FlashAttention requires only \Theta(N^2 d^2 M^{-1}), where M is the SRAM size; since typically d^2 \ll M, this is a significant reduction.
In fact, as shown in the table below, experiments confirm that reducing HBM accesses contributes significantly to reducing runtime. They also confirm that increasing the block size as much as possible decreases both HBM accesses and runtime. However, at block size 512, it is highly likely that something other than memory access becomes the bottleneck.
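The FlashAttention paper bounds HBM accesses at Θ(Nd + N²) for standard attention versus Θ(N²d²/M) for FlashAttention, with M the SRAM size. A quick numeric sketch, where the value of M is an illustrative assumption rather than a measured figure:

```python
# HBM-access bounds from the FlashAttention paper, with illustrative numbers.
# Standard attention:  Theta(N*d + N^2)
# FlashAttention:      Theta(N^2 * d^2 / M), M = SRAM size (in elements)

def standard_accesses(N, d):
    return N * d + N * N

def flash_accesses(N, d, M):
    return N * N * d * d // M

N, d = 4096, 64
M = 100_000  # illustrative on-chip SRAM capacity in elements (assumption)

print(standard_accesses(N, d))   # 17,039,360
print(flash_accesses(N, d, M))   # 687,194 -> roughly 25x fewer accesses
```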

Implementation
What is the best way to implement FlashAttention?
Writing in CUDA
I have never personally touched CUDA directly, but I imagine it must be quite tedious (high difficulty). I want to become an expert who can write CUDA.
Of course, the authors of the paper have implemented it in CUDA.
Even More Simply
It seems that FlashAttention can be used by using the PyTorch 2.0 Transformer API as follows:
model = torch.compile(model)
However, please note that multiple attention backends are supported, and since the fastest one is selected automatically by default, FlashAttention is not necessarily the one used. It is also possible to force the use of FlashAttention.
In addition to automatic kernel selection, a context manager enables developers to override the kernel selection algorithm
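As a sketch of this API (assuming PyTorch ≥ 2.0; the explicit backend override shown in the comment is the newer torch.nn.attention interface, and FLASH_ATTENTION itself requires a CUDA GPU, so this sketch runs on the fallback math backend on CPU):

```python
# scaled_dot_product_attention picks the fastest backend automatically;
# a context manager can override that choice (PyTorch >= 2.0 assumed).
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Default: backend selected automatically (math backend on CPU).
out = F.scaled_dot_product_attention(q, k, v)

# On a CUDA GPU, one could force FlashAttention (newer PyTorch API):
#   from torch.nn.attention import sdpa_kernel, SDPBackend
#   with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
#       out = F.scaled_dot_product_attention(q.cuda(), k.cuda(), v.cuda())

print(out.shape)  # torch.Size([1, 8, 128, 64])
```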
Detailed usage is written in a newer article, so please refer to that as well.
Performance evaluation articles using torch.compile
Recent Trends
PyTorch 2.0 highlighted JIT compilation as its main feature, but I feel that deep learning compilers have become a hot topic lately (though I might already be late to the party).
Regarding deep learning compilers, the following articles are very well-organized and highly recommended:
Summary
- We are no longer in a world where you can simply brute-force everything with hardware (though this could completely reverse again in about 10 years).
- Currently, GPU bottlenecks are often found in I/O.
- FlashAttention can be easily used in PyTorch 2.0.
Final Thoughts
I hope the response speed of GPT-4 improves even more.
While researching for this article, I realized that we are truly in the "Great Compiler Era." I have completely fallen behind.
References