FlashAttention-2
A modern GPU consists of a scalable array of multithreaded streaming multiprocessors (SMs). Each SM operates as a general-purpose processor, independent of the others, with compute cores that can execute several thread blocks concurrently, shared memory for communication between the threads of a block, and warp schedulers. A thread block is a group of up to 1024 threads on modern GPU architectures [Cuda], and all threads of a block execute concurrently on a single SM. Thread blocks are in turn organized into a one-, two-, or three-dimensional grid. The SM creates, manages, schedules, and executes threads in groups of 32 parallel threads called warps. The individual threads of a warp start together at the same program address, but each has its own instruction address counter and register state and is therefore free to branch and execute independently. The SM's warp schedulers select which warps issue instructions, and a warp executes one common instruction at a time. A new thread block is launched on an SM only when sufficient registers and shared memory are available. [https://en.wikipedia.org/wiki/Thread_block_(CUDA_programming)] [https://docs.nvidia.com/cuda/cuda-c-programming-guide/]
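As a rough illustration of these resource limits, the following Python sketch estimates how many thread blocks can be resident on one SM at the same time. The per-SM limits are assumed values for an A100 (compute capability 8.0), the function names are hypothetical, and a real occupancy calculation also accounts for register and shared-memory allocation granularity, which this sketch ignores.

```python
import math

# Assumed per-SM limits for an A100 (compute capability 8.0).
A100_LIMITS = {
    "max_threads_per_sm": 2048,
    "max_blocks_per_sm": 32,
    "registers_per_sm": 65536,        # 32-bit registers
    "smem_bytes_per_sm": 164 * 1024,  # maximum shared-memory carve-out
}
WARP_SIZE = 32


def warps_per_block(threads_per_block):
    """A block runs as ceil(threads / 32) warps."""
    if not 1 <= threads_per_block <= 1024:
        raise ValueError("a thread block holds between 1 and 1024 threads")
    return math.ceil(threads_per_block / WARP_SIZE)


def resident_blocks_per_sm(threads_per_block, regs_per_thread, smem_per_block,
                           limits=A100_LIMITS):
    """Simplified estimate: another block fits on the SM only if enough
    thread slots, registers and shared memory remain."""
    if regs_per_thread < 1 or smem_per_block < 0:
        raise ValueError("invalid resource usage")
    # A partial warp still occupies 32 thread slots.
    threads = warps_per_block(threads_per_block) * WARP_SIZE
    by_threads = limits["max_threads_per_sm"] // threads
    by_regs = limits["registers_per_sm"] // (threads * regs_per_thread)
    by_smem = (limits["smem_bytes_per_sm"] // smem_per_block
               if smem_per_block else limits["max_blocks_per_sm"])
    return min(limits["max_blocks_per_sm"], by_threads, by_regs, by_smem)


print(warps_per_block(100))
print(resident_blocks_per_sm(256, 64, 48 * 1024))
```

For a 256-thread block using 64 registers per thread and 48 KiB of shared memory, shared memory is the binding limit in this model (3 blocks), ahead of registers (4 blocks) and thread slots (8 blocks).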
Reduction of non-matmul operations
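Modern GPUs have specialized Tensor Cores that speed up matrix multiplication (matmul). For example, the theoretical peak FP16 Tensor Core throughput is $312$ TFLOPs on an A100 SXM4 and about $989$ TFLOPs on an H100 SXM5 ($1979$ TFLOPs with structured sparsity), whereas regular FP32 throughput is only $19.5$ TFLOPs on the A100 and $67$ TFLOPs on the H100. In other words, a matmul FLOP is much cheaper than a non-matmul FLOP: on the A100, each non-matmul FLOP costs roughly $16\times$ as much as a matmul FLOP. FlashAttention-2 therefore modifies the algorithm to reduce the number of non-matmul FLOPs, so that as much of the runtime as possible is spent on matmul.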
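The relative cost of a non-matmul FLOP can be estimated as the ratio of the two peak throughputs quoted above. The minimal sketch below does this arithmetic, assuming dense peak figures throughout, so for the H100 it uses the dense FP16 Tensor Core figure of about 989 TFLOPs (1979 TFLOPs is the structured-sparsity figure). The dictionary and function names are illustrative.

```python
# Assumed dense peak throughputs in TFLOPs/s.
PEAK_TFLOPS = {
    "A100": {"fp16_tensor": 312.0, "fp32": 19.5},
    "H100": {"fp16_tensor": 989.0, "fp32": 67.0},
}


def nonmatmul_cost_ratio(gpu):
    """Number of matmul FLOPs that take as long as one non-matmul FLOP."""
    peaks = PEAK_TFLOPS[gpu]
    return peaks["fp16_tensor"] / peaks["fp32"]


for gpu in PEAK_TFLOPS:
    print(f"{gpu}: 1 non-matmul FLOP ~ {nonmatmul_cost_ratio(gpu):.1f} matmul FLOPs")
```

The ratio is about 16 on the A100 and about 15 on the H100, which is why even a modest number of extra exponentials, divisions, or rescalings can dominate runtime.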
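Reduction of $d_j$ rescaling. For a row block $Q_i$ and key/value block $j$, let $m_j = \max(m_{j-1}, \text{rowmax}(S_{ij}))$ be the running row maximum and $d_j = d_{j-1}e^{m_{j-1}-m_j} + \text{rowsum}(e^{S_{ij}-m_j})$ the running softmax denominator (exponentials and divisions are applied row-wise). FlashAttention keeps the output normalized after every block, $O_{j} = O_{j-1}\frac{d_{j-1}}{d_{j}}e^{m_{j-1}-m_{j}}+\frac{e^{S_{ij}-m_j}}{d_j}V_{j}$, which rescales by $d_j$ at every iteration. FlashAttention-2 instead maintains an unscaled output $\tilde{O}_{j} = \tilde{O}_{j-1}e^{m_{j-1}-m_{j}}+e^{S_{ij}-m_j}V_{j}$ (so that $\tilde{O}_j = d_j O_j$) and divides by the denominator only once, after the last block $b$: $O = \tilde{O}_{b} / d_{b}$.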
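Reduction of $m_b$ and $d_b$ to the logsumexp. For the backward pass, instead of saving both $m_b$ and $d_b$, FlashAttention-2 stores only the logsumexp $L = m_b + \log(d_b)$, from which the softmax probabilities can be recomputed as $P_{ij} = e^{S_{ij} - L}$.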
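The pure-Python sketch below traces the forward recurrence for a single query row. It checks that keeping an unscaled output and dividing by $d_b$ only at the end gives the same result as an exact softmax, and that the logsumexp alone is enough to recompute the probabilities. It illustrates the recurrence only and is not FlashAttention-2's CUDA kernel; the function names are hypothetical.

```python
import math


def attention_row_reference(s, v):
    """Exact softmax(s) @ V for one query row: s holds scores, v the rows of V."""
    m = max(s)
    p = [math.exp(x - m) for x in s]
    d = sum(p)
    return [sum(p[k] * v[k][c] for k in range(len(s))) / d
            for c in range(len(v[0]))]


def attention_row_online(s, v, block_size):
    """Loop over key/value blocks keeping the running max m_j, the running
    denominator d_j and an unscaled output; divide by d_b once at the end.
    Also returns the logsumexp L = m_b + log(d_b)."""
    dim = len(v[0])
    m, d = -math.inf, 0.0
    o = [0.0] * dim  # unscaled output
    for start in range(0, len(s), block_size):
        s_blk = s[start:start + block_size]
        v_blk = v[start:start + block_size]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # e^{m_{j-1} - m_j}; 0.0 on the first block
        p = [math.exp(x - m_new) for x in s_blk]
        d = d * alpha + sum(p)
        o = [o[c] * alpha + sum(p[k] * v_blk[k][c] for k in range(len(p)))
             for c in range(dim)]
        m = m_new
    return [x / d for x in o], m + math.log(d)  # single rescale by 1/d_b


s = [0.5, 2.0, -1.0, 3.0, 0.0]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
out, lse = attention_row_online(s, v, block_size=2)
ref = attention_row_reference(s, v)
probs = [math.exp(x - lse) for x in s]  # softmax recovered from L alone
print(out, ref, sum(probs))
```

The division by $d_j$ disappears from the loop body, which is exactly the non-matmul work this change saves.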
Parallelization over the sequence length
FlashAttention parallelizes over the batch size $B$ and the number of heads $nh$, assigning one thread block to each (batch, head) pair, i.e., $B \times nh$ thread blocks in total. SM utilization therefore depends on $B \times nh$: the larger it is relative to the number of SMs (108 on an A100), the more efficient the scheduling. For small $B \times nh$, as with long sequences (which usually force small batch sizes) or in distributed settings where each GPU processes only a few heads, many SMs sit idle. FlashAttention-2 keeps the parallelization over $B \times nh$ but additionally parallelizes over the sequence length: in the forward pass, the loop over row blocks of $Q$ is scheduled on different thread blocks, which need no communication because different rows of the attention matrix are computed independently. The backward pass uses a similar strategy, but over column blocks (blocks of $K$ and $V$) instead, with atomic adds to accumulate $dQ$ across thread blocks.
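To see why small $B \times nh$ leaves SMs idle, the sketch below counts thread blocks under both schemes and the fraction of an A100's 108 SMs that receive at least one block. It relies on a simplifying assumption of ours, not from the paper, that one block occupies one SM; the 128-row $Q$ block in the example is also an assumed value, and the function names are hypothetical.

```python
import math

A100_NUM_SMS = 108


def num_thread_blocks(batch, heads, seq_len, block_rows, split_seq):
    """FlashAttention: one block per (batch, head) pair.
    FlashAttention-2 forward: additionally one block per row block of Q."""
    blocks = batch * heads
    if split_seq:
        blocks *= math.ceil(seq_len / block_rows)
    return blocks


def fraction_of_sms_busy(num_blocks, num_sms=A100_NUM_SMS):
    """Share of SMs with work, assuming one resident block per SM."""
    return min(num_blocks, num_sms) / num_sms


# Long-sequence example: batch 1, 16 heads, 16k tokens, assumed 128-row blocks.
for split in (False, True):
    n = num_thread_blocks(batch=1, heads=16, seq_len=16384,
                          block_rows=128, split_seq=split)
    print(f"split over sequence={split}: {n} blocks, "
          f"{fraction_of_sms_busy(n):.0%} of SMs busy")
```

With only 16 (batch, head) pairs, about 15% of the SMs get work in this model, while splitting over the sequence yields 2048 blocks and keeps every SM busy.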
Partitioning between warps
Within each thread block, FlashAttention uses 4 or 8 warps in a "split-K" scheme: $K$ and $V$ are split across warps while $Q$ is accessible to all warps. Each warp multiplies $Q$ with its slice of $K^\top$ and then with its slice of $V$, so its result is only a partial sum; the warps must write these intermediate results to shared memory, synchronize, and add them up, and the resulting shared-memory reads and writes slow down the computation. FlashAttention-2 instead adopts a "split-Q" scheme: $Q$ is split across warps while $K$ and $V$ are shared by all warps. Each warp computes its slice of $QK^\top$ and multiplies it directly with the shared $V$ to obtain its own slice of the output, so no communication between warps is needed and shared-memory traffic is reduced. The backward pass also avoids split-K, but still requires some synchronization because of the more complex dependencies between the inputs and the gradients.
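The difference between the two schemes can be illustrated with plain Python lists, where each "warp" is just a loop iteration. To keep the sketch short it drops the softmax and computes $(QK^\top)V$, so it shows only the data flow: with split-K each warp produces a partial sum over its keys that must be reduced across warps, while with split-Q each warp's output rows are already final. With softmax included, split-K would additionally have to reconcile per-warp maxima and denominators. The helper names are hypothetical.

```python
def matmul(a, b):
    """Naive matrix product of two lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def split_k(q, k, v, num_warps):
    """Each warp owns a slice of the keys/values and produces a partial
    output for all query rows; the partials then need a cross-warp sum
    (shared-memory writes plus a synchronization on the GPU)."""
    step = -(-len(k) // num_warps)  # ceiling division
    partials = []
    for w in range(num_warps):
        ks, vs = k[w * step:(w + 1) * step], v[w * step:(w + 1) * step]
        if ks:
            partials.append(matmul(matmul(q, transpose(ks)), vs))
    # Cross-warp reduction.
    return [[sum(p[i][j] for p in partials) for j in range(len(v[0]))]
            for i in range(len(q))]


def split_q(q, k, v, num_warps):
    """Each warp owns a slice of the query rows and reads all of K and V,
    so its output rows are final: no cross-warp reduction."""
    step = -(-len(q) // num_warps)
    out = []
    for w in range(num_warps):
        qs = q[w * step:(w + 1) * step]
        if qs:
            out.extend(matmul(matmul(qs, transpose(k)), v))
    return out


q = [[1, 2], [3, 4], [5, 6], [7, 8]]
k = [[1, 0], [0, 1], [1, 1], [2, -1]]
v = [[1, 2], [3, 4], [5, 6], [7, 8]]
full = matmul(matmul(q, transpose(k)), v)
print(split_k(q, k, v, 4) == full, split_q(q, k, v, 4) == full)
```

Both functions return the same matrix; the difference is only in how much intermediate data has to be exchanged between warps.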
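Tuning block sizes. Although a thread block can contain up to 1024 threads on recent GPUs, the relevant choice here is the size of the tiles (blocks) of $Q$, $K$, and $V$. [Dao, 2023] found that larger block sizes reduce shared-memory loads and stores, but increase the number of registers and the total amount of shared memory required; past a certain size, register spilling causes significant slowdowns, or the kernel needs more shared memory than the GPU provides and cannot run at all. They therefore manually tuned the block sizes for each head dimension $d$, typically choosing tiles of $\{64, 128\} \times \{64, 128\}$.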
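A back-of-the-envelope estimate shows why tile sizes have to depend on the head dimension. The sketch counts only one FP16 tile each of $Q$ ($B_r \times d$), $K$, and $V$ ($B_c \times d$) and compares the total with an assumed A100 limit of 163 KB of shared memory per thread block. It ignores registers, double buffering, and any other buffers, so the real footprint is larger; the names are hypothetical.

```python
A100_MAX_SMEM_PER_BLOCK = 163 * 1024  # bytes; assumed per-block limit


def tile_smem_bytes(block_rows, block_cols, head_dim, bytes_per_elem=2):
    """Shared memory for one Q tile (Br x d) plus one K and one V tile
    (Bc x d each), in FP16 by default."""
    return (block_rows + 2 * block_cols) * head_dim * bytes_per_elem


for br, bc in [(64, 64), (128, 64), (128, 128)]:
    for d in (64, 128, 256):
        used = tile_smem_bytes(br, bc, d)
        fits = used <= A100_MAX_SMEM_PER_BLOCK
        print(f"Br={br:3d} Bc={bc:3d} d={d:3d}: {used / 1024:6.1f} KiB, fits={fits}")
```

Even this lower bound rules out $128 \times 128$ tiles at $d = 256$, while the same tiles fit comfortably at $d = 128$.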
FlashAttention-2 is about $2\times$ faster than FlashAttention. On an A100 GPU it raises utilization from $30 \sim 50\%$ to up to $73\%$ of the theoretical maximum throughput in the forward pass, and from $25 \sim 35\%$ to up to $63\%$ in the backward pass [Dao, 2023]. Consequently, FlashAttention-2 can train with roughly twice the context length for the same cost as FlashAttention. Despite these improvements, its utilization remains below that of optimized GEMM kernels, especially on newer GPUs such as the H100, where FlashAttention-2 reaches only about $35\%$ utilization [Dao, 2024].
FlashAttention-3
FlashAttention-3 [Dao, 2024] extends FlashAttention-2 by optimizing the attention computation for the Hopper architecture. FlashAttention-2 relies on a simplified synchronous execution model with FP16 precision and makes no explicit use of asynchrony or low-precision formats. Although it is efficient on Ampere GPUs such as the A100, it achieves only partial utilization on H100 GPUs because it does not exploit Hopper-specific features, as noted in [Dao, 2024]. FlashAttention-3 addresses this limitation by leveraging the H100's fourth-generation Tensor Cores, the Tensor Memory Accelerator (TMA), and native FP8 support. Its modifications focus on three techniques built on these hardware features: producer-consumer asynchrony, intra-warpgroup overlapping, and low-precision computation.
[Figure: SMEM, GMEM, and warps]
Producer-consumer asynchrony
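Hopper introduces the TMA, a dedicated hardware unit for asynchronous memory copies between global memory (GMEM, i.e., HBM) and shared memory (SMEM, i.e., on-chip). In addition, warpgroup-wide matrix multiply-accumulate instructions (WGMMA) are also asynchronous and can read their operands directly from SMEM. FlashAttention-3 exploits this with warp-specialized kernels in which the warps of a thread block (CTA) take distinct roles: producer warps issue TMA loads from GMEM into SMEM, and consumer warps perform the WGMMA computations. Producers and consumers communicate through a circular buffer in SMEM, so the next tiles are loaded while the current ones are being processed; this overlaps data movement with computation, hides memory latency, and reduces GPU idle time. FlashAttention-3 additionally uses "ping-pong" scheduling between two consumer warpgroups, so that the softmax of one warpgroup overlaps with the GEMMs of the other.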
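The producer-consumer pattern can be illustrated with a CPU analogy in Python. In this sketch a bounded `queue.Queue` stands in for the SMEM circular buffer, a producer thread for the TMA-issuing warps, and a consumer thread for the WGMMA warps. On the GPU the synchronization uses hardware barriers rather than a blocking queue, so this shows only the control flow, not the mechanism; `run_pipeline` and `num_stages` are hypothetical names.

```python
import queue
import threading


def run_pipeline(tiles, num_stages=2):
    """Producer 'loads' tiles into a bounded buffer while the consumer
    'computes' on them; the producer can run at most num_stages tiles
    ahead of the consumer."""
    buf = queue.Queue(maxsize=num_stages)  # stand-in for the SMEM circular buffer
    results = []
    done = object()  # sentinel marking the end of the stream

    def producer():
        for tile in tiles:
            buf.put(list(tile))  # blocks while every stage is full
        buf.put(done)

    def consumer():
        while True:
            tile = buf.get()  # blocks until a tile has been loaded
            if tile is done:
                break
            results.append(sum(x * x for x in tile))  # stand-in for WGMMA work

    workers = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    return results


print(run_pipeline([[1, 2], [3, 4], [5]]))
```

The bounded buffer is the key design choice: it lets loads run ahead of compute without letting the producer overwrite tiles the consumer has not finished with.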
Intra-warpgroup overlapping
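Within one iteration of the inner loop, softmax depends on the output of the first GEMM ($S_{ij} = Q_iK_j^\top$), and the second GEMM (with $V_j$) depends on the output of softmax, so FlashAttention-2 executes them sequentially. This is costly because the exponentials in softmax run on much slower units: the H100 SXM5 offers $989$ TFLOPs of FP16 matmul but only $3.9$ TFLOPs of special functions such as exponentiation [Dao, 2024]. FlashAttention-3 introduces a two-stage GEMM-softmax pipeline within a warpgroup: while softmax is computed for one block, the asynchronous WGMMA for the next block is already running, so the two overlap and utilization improves.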
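A toy timing model makes the benefit of this overlap concrete. It assumes that each block needs one GEMM of duration `gemm` on the Tensor Cores and one softmax of duration `softmax` on other units, and that at most two score blocks are in flight (the two-stage limit). The second GEMM and all real hardware latencies are ignored, durations are in arbitrary units, and the function names are hypothetical.

```python
def sequential_time(num_blocks, gemm, softmax):
    """No overlap: GEMM, then softmax, then the next block's GEMM."""
    return num_blocks * (gemm + softmax)


def pipelined_time(num_blocks, gemm, softmax):
    """Two-stage pipeline: the GEMM for block j+1 runs while softmax for
    block j runs, with at most two score blocks in flight."""
    if num_blocks == 0:
        return 0.0
    gemm_done, softmax_done = [], []
    for j in range(num_blocks):
        start = gemm_done[j - 1] if j >= 1 else 0.0  # Tensor Cores busy until then
        if j >= 2:
            start = max(start, softmax_done[j - 2])  # wait for a free score buffer
        gemm_done.append(start + gemm)
        softmax_start = max(gemm_done[j], softmax_done[j - 1] if j >= 1 else 0.0)
        softmax_done.append(softmax_start + softmax)
    return softmax_done[-1]


for g, s in [(2.0, 1.0), (1.0, 2.0), (1.0, 1.0)]:
    print(g, s, sequential_time(8, g, s), pipelined_time(8, g, s))
```

When GEMM and softmax take comparable time, the pipelined schedule approaches half of the sequential one, since the slower stage then sets the pace.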
Low precision with FP8
The H100 architecture offers native support for the FP8 (e4m3) format, which reduces memory and bandwidth costs and increases throughput compared with FP16/FP32. FlashAttention-3 applies block quantization to $Q$, $K$, and $V$: because it already operates on small blocks, each block can be given its own scale, and the quantization can be fused with preceding operations (e.g., RoPE) at no additional cost. To further mitigate quantization error, especially from outliers, it uses incoherent processing: $Q$ and $K$ are multiplied by an orthogonal matrix $M$, the product of random diagonal matrices of $\pm 1$ and a Hadamard matrix. Since $MM^\top = I$, we have $(QM)(KM)^\top = QMM^\top K^\top = QK^\top$, so the attention scores are unchanged, while the rotation distributes magnitude more uniformly across elements and reduces quantization error. The Hadamard transform takes $O(d \log d)$ time per vector and can also be fused with RoPE without additional overhead, since both operations are memory-bandwidth bound.
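The following sketch checks the two properties of the rotation on a small example: the random-sign Hadamard matrix is orthogonal, so the dot products behind $QK^\top$ are preserved, and it spreads an outlier channel across all coordinates. It uses the Sylvester construction normalized by $1/\sqrt{d}$ and an explicit quadratic-time product for clarity rather than a fast Walsh-Hadamard transform; the head dimension, vectors, seed, and function names are arbitrary, illustrative choices.

```python
import math
import random


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n a power of two)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = [[1]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-x for x in row] for row in h]
    return h


def random_hadamard_rotation(n, seed=0):
    """M = D H / sqrt(n), with D a random diagonal of +/-1; M is orthogonal."""
    rng = random.Random(seed)
    signs = [rng.choice((-1, 1)) for _ in range(n)]
    h = hadamard(n)
    scale = 1 / math.sqrt(n)
    return [[signs[i] * h[i][j] * scale for j in range(n)] for i in range(n)]


def vecmat(x, m):
    """Row vector times matrix."""
    return [sum(x[i] * m[i][j] for i in range(len(x))) for j in range(len(m[0]))]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


d = 8
M = random_hadamard_rotation(d)
q = [0.1, -0.2, 0.05, 12.0, 0.3, -0.1, 0.0, 0.2]  # one outlier channel
k = [0.2, 0.1, -0.3, 0.5, 0.0, 0.4, -0.2, 0.1]
qm, km = vecmat(q, M), vecmat(k, M)
print(dot(q, k), dot(qm, km))                        # equal up to rounding
print(max(map(abs, q)), max(map(abs, qm)))           # outlier is spread out
```

After the rotation the outlier of 12 contributes about $12/\sqrt{8} \approx 4.2$ to every coordinate instead of sitting in one, so a shared quantization scale wastes far less of the FP8 range.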
With FP16, FlashAttention-3 achieves a $1.5 \sim 2\times$ speedup over FlashAttention-2 in the forward pass, reaching up to 740 TFLOPs/s (about $75\%$ utilization), and a $1.5 \sim 1.75\times$ speedup in the backward pass. With FP8 it reaches close to 1.2 PFLOPs/s, and its numerical error is $2.6\times$ lower than that of a baseline FP8 attention using standard per-tensor quantization. However, these gains are specific to the Hopper architecture. Future work could extend these optimizations to the Blackwell architecture, leveraging hardware capabilities such as support for the FP4 format [Dao, 2024].
Despite the significant performance gains offered by FlashAttention and its successors, FlashAttention-2 and FlashAttention-3, these methods have limited support for novel or composable attention variants, such as Sliding Window Attention (SWA), ALiBi, softcapping, or PagedAttention. Implementing each variant or their combinations typically requires developing a separate, custom fused kernel tailored to the specific attention mechanism, a process that is both time-consuming and resource-intensive. This lack of modularity and extensibility creates a substantial development bottleneck, particularly for researchers exploring diverse attention configurations. These constraints motivated the development of FlexAttention, a compiler-driven PyTorch API that enables the implementation of a broad range of attention mechanisms through programmable abstractions.