Understanding FlashAttention-3: One of the Most Important Algorithms to Make Transformers Fast
The new version takes full advantage of H100 capabilities to improve attention in transformer models.
I recently started an AI-focused educational newsletter that already has over 170,000 subscribers. TheSequence is a no-BS (meaning no hype, no news, etc.) ML-oriented newsletter that takes 5 minutes to read. The goal is to keep you up to date with machine learning projects, research papers, and concepts. Please give it a try by subscribing below:
Few algorithms have had as much impact on the recent generation of transformer architectures as FlashAttention. Originally developed by researchers at Stanford University, including the renowned Tri Dao (now at Princeton), FlashAttention and its successor FlashAttention-2 improved the performance of attention mechanisms on GPUs by minimizing memory reads and writes. FlashAttention was rapidly adopted by the new generation of transformers almost immediately after the original publication. There were not many complaints about FlashAttention, but one of the few was that it could not take full advantage of new hardware architectures. For instance, FlashAttention-2 achieves only 35% utilization of the maximum FLOPS on H100 GPUs.
But now we have a new version.
Last week, a group of AI researchers from Meta, Princeton University, NVIDIA, and other AI labs published the paper and open-source code for FlashAttention-3. The new version of the method uses several techniques to speed up attention on H100 GPUs, exploiting the asynchrony of the Tensor Cores. The result is simple: FlashAttention-3 is blazing fast. The new algorithm achieves 75% of the theoretical maximum FLOPS utilization on the H100, which translates into practical 1.5–2x performance improvements. It can also use lower-precision numbers, which reduces the memory footprint.
Let’s dive into some of the details, but first, let’s recap how FlashAttention works.
FlashAttention
FlashAttention is designed to optimize the computation of attention mechanisms by reordering the steps and using tiling and recomputation. This approach significantly accelerates processing and reduces memory usage from quadratic to linear in the sequence length. The algorithm uses tiling to load blocks of inputs from GPU memory (HBM) into a faster cache (SRAM), processes attention within that block, and writes the output back to GPU memory. By avoiding the storage of large intermediate matrices in HBM, FlashAttention reduces memory read/write operations, resulting in a 2–4x wall-clock speedup.
In the FlashAttention forward pass, tiling and softmax rescaling allow the algorithm to operate by blocks. This method avoids extensive read/write operations from HBM, ensuring accurate output without approximations.
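To make the tiling and softmax-rescaling idea concrete, here is a minimal PyTorch sketch of a block-wise attention forward pass with an online softmax. It illustrates the idea only and is not the real FlashAttention kernel (which is written in CUDA and manages the HBM/SRAM data movement explicitly); the function name and block size below are arbitrary choices for the example.

```python
import torch

def blockwise_attention(q, k, v, block_size=128):
    """Block-wise attention forward pass with online softmax rescaling.

    q, k, v: (seq_len, head_dim) tensors for a single head. K/V are processed
    in tiles so the full (seq_len x seq_len) score matrix is never materialized.
    Illustrative only: the real kernel runs in CUDA and manages HBM/SRAM moves.
    """
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5

    out = torch.zeros_like(q)                          # running, unnormalized output
    row_max = torch.full((seq_len, 1), float("-inf"))  # running max per query row
    row_sum = torch.zeros(seq_len, 1)                  # running softmax denominator

    for start in range(0, seq_len, block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]

        scores = (q @ k_blk.T) * scale                 # scores for this tile only
        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)

        # Rescale previously accumulated statistics to the new running max,
        # then fold in this tile's contribution.
        correction = torch.exp(row_max - new_max)
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum                               # final softmax normalization

# Sanity check against the naive quadratic-memory implementation.
q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-4)
```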
H100 GPUs and Attention
The magic of FlashAttention-3 lies in taking advantage of the latest H100 features to improve attention performance and address some of the limitations of its predecessor.
Although FlashAttention-2 achieves up to 70% of the theoretical maximum FLOPS on Ampere (A100) GPUs, it does not fully leverage the new capabilities of Hopper GPUs. Here are some of the key features of Hopper GPUs and their significance:
· WGMMA (Warpgroup Matrix Multiply-Accumulate): Utilizes new Tensor Cores on Hopper GPUs, offering much higher throughput compared to the older mma.sync instruction in Ampere GPUs.
· TMA (Tensor Memory Accelerator): This hardware unit speeds up data transfer between global memory and shared memory, handling index calculation and out-of-bounds predication. It frees up registers, allowing larger tile sizes and better efficiency.
· Low-precision with FP8: This feature doubles the throughput of Tensor Cores (e.g., from 989 TFLOPS with FP16 to 1978 TFLOPS with FP8) by using fewer bits to represent floating-point numbers, trading some accuracy for speed.
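To get a feel for the accuracy side of that FP8 trade-off, here is a small sketch that round-trips a tensor through FP8 (e4m3) and measures the precision cost. The throughput gain itself only materializes on H100 Tensor Cores; the snippet assumes a PyTorch build (2.1+) that exposes the torch.float8_e4m3fn dtype and is purely a numerical illustration.

```python
import torch

# Round-trip a tensor through FP8 (e4m3) to see the precision cost that the
# doubled Tensor Core throughput is traded against. This only illustrates the
# numerics; real FP8 GEMMs run on H100 Tensor Cores.
x = torch.randn(4096, 4096)

# Scale into FP8's representable range (e4m3 tops out around 448),
# cast down to 8 bits, then cast back up and undo the scale.
scale = x.abs().max() / 448.0
x_fp8 = (x / scale).to(torch.float8_e4m3fn)
x_restored = x_fp8.to(torch.float32) * scale

rel_err = (x - x_restored).abs().mean() / x.abs().mean()
print(f"mean relative error after FP8 round trip: {rel_err.item():.4f}")
```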
FlashAttention-3
FlashAttention-3 incorporates these new Hopper features using abstractions from NVIDIA’s CUTLASS library. Work such as ThunderKittens and cuDNN 9 has demonstrated that these hardware features can significantly accelerate attention computation. By adapting FlashAttention to use them, its performance improves dramatically (e.g., from 350 TFLOPS in the FlashAttention-2 FP16 forward pass to about 540–570 TFLOPS). The asynchronous instructions on Hopper (WGMMA and TMA) further provide opportunities for algorithmic optimizations.
FlashAttention-3 introduces three key techniques to enhance performance on modern GPU architectures:
1. Producer-Consumer Asynchrony: This method employs warp-specialized software pipelining, splitting data producers and consumers into separate warps. This separation exploits asynchronous execution to better hide memory and instruction issue latencies.
2. Hiding Softmax Under Asynchronous Block-wise GEMMs: By overlapping low-throughput softmax operations with asynchronous WGMMA instructions, FlashAttention-3 can circumvent sequential dependencies between softmax and GEMMs. For example, in a 2-stage version, while softmax processes one block of the scores matrix, WGMMA computes the next block.
3. Hardware-accelerated Low-precision GEMM: This adaptation targets the FP8 Tensor Cores for GEMM, nearly doubling the measured TFLOPS. It requires handling the different layout requirements of FP32 accumulators and FP8 operand matrices, and uses block quantization and incoherent processing to mitigate the accuracy loss from reduced precision.
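The block quantization idea in the third technique is easy to illustrate in isolation. The sketch below compares a single per-tensor FP8 scale against one scale per block of rows; it is a hypothetical helper written for this post, not FlashAttention-3’s actual kernel logic, and it again assumes a PyTorch build that exposes torch.float8_e4m3fn.

```python
import torch

def fp8_round_trip(x, block_size=None):
    """Quantize to FP8 (e4m3) and back, using either one scale for the whole
    tensor or one scale per block of rows. Per-block scales keep outliers in
    one block from crushing the precision of every other block.
    Illustrative helper only, not FlashAttention-3's implementation."""
    blocks = [x] if block_size is None else list(x.split(block_size, dim=0))
    restored = []
    for blk in blocks:
        scale = blk.abs().max().clamp(min=1e-12) / 448.0   # e4m3 max ~448
        q = (blk / scale).to(torch.float8_e4m3fn)          # 8-bit representation
        restored.append(q.to(torch.float32) * scale)
    return torch.cat(restored, dim=0)

# Attention-like activations with a few large outlier rows.
x = torch.randn(4096, 128)
x[::1024] *= 50

err = lambda a, b: (a - b).abs().mean().item()
print("per-tensor scale error:", err(x, fp8_round_trip(x)))
print("per-block scale error :", err(x, fp8_round_trip(x, block_size=128)))
```

Incoherent processing attacks the same problem from a different angle, multiplying queries and keys by a random orthogonal matrix to spread outliers out before quantization; it is not shown in this sketch.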
The Results
The team behind FlashAttention-3 measured its runtime across various sequence lengths and compared it against a standard PyTorch implementation, FlashAttention-2, FlashAttention-2 in Triton (which uses H100-specific instructions), and a vendor’s H100-optimized FlashAttention-2 from cuDNN. FlashAttention-3 is up to 2x faster than FlashAttention-2 and 1.5x faster than FlashAttention-2 in Triton, reaching up to 740 TFLOPS, or 75% of the theoretical maximum on H100 GPUs.
FlashAttention-3 is an exciting development in generative AI algorithms. This method will almost certainly help LLMs handle longer context windows and deliver better inference performance on modern GPU architectures. Impressive progress!