Flash attention is a fast, memory-efficient algorithm that speeds up the attention mechanism in transformer models by minimizing data transfers between different levels of GPU memory.
How Flash Attention works
- Traditional attention: The standard approach computes the full attention score matrix S = QKᵀ, which is an N×N matrix for a sequence of length N, and then applies a softmax to each row. Storing this matrix takes memory that grows quadratically with N, and repeatedly moving it between the GPU's large but relatively slow HBM and its small, fast on-chip SRAM becomes the bottleneck.
- Flash attention:
  - Tiling: Divides the input matrices (Query, Key, Value) into smaller blocks that fit in the fast on-chip SRAM.
  - Block processing: Computes attention for one block at a time, keeping only the necessary intermediate results in SRAM.
  - Reduced I/O: Instead of writing the full attention matrix to HBM, it keeps a running maximum and a running softmax sum for each query row, rescales the partial output whenever a new block changes these statistics, and accumulates the final output incrementally. This drastically reduces reads and writes to HBM.
  - IO-awareness: It is an “IO-aware” algorithm, meaning it is designed around the cost of moving data between the GPU's memory levels (HBM and SRAM, which differ greatly in size and speed), not only around the number of arithmetic operations.
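The tiling and incremental-update steps above can be sketched in plain Python. This is a minimal, illustrative sketch, not the actual GPU kernel: it assumes single-head attention computed as softmax(QKᵀ/√d)·V with no masking or dropout, and the names `standard_attention`, `tiled_attention`, `block_q`, and `block_k` are hypothetical. The Python lists stand in for tiles that a real kernel would load into SRAM; the per-row running maximum and running sum are the "online softmax" statistics that let earlier partial results be rescaled when a new key block raises the maximum.

```python
import math
import random


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def softmax(row):
    """Numerically stable softmax of one row of scores."""
    m = max(row)
    e = [math.exp(x - m) for x in row]
    total = sum(e)
    return [x / total for x in e]


def standard_attention(Q, K, V):
    """Reference: materialize the full N x N score matrix, then softmax(S) @ V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    S = [[scale * dot(q, k) for k in K] for q in Q]  # the large matrix flash attention avoids storing
    P = [softmax(row) for row in S]
    return [[sum(p * v[j] for p, v in zip(row, V)) for j in range(len(V[0]))] for row in P]


def tiled_attention(Q, K, V, block_q=2, block_k=2):
    """Hypothetical block-by-block attention with an online softmax; never builds the full score matrix."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dv = len(V[0])
    out = []
    for qs in range(0, len(Q), block_q):
        Q_blk = Q[qs:qs + block_q]                # a tile of queries (would live in SRAM)
        row_max = [-math.inf] * len(Q_blk)        # running max score per query row
        row_sum = [0.0] * len(Q_blk)              # running sum of exp(score - row_max) per row
        acc = [[0.0] * dv for _ in Q_blk]         # running unnormalized output per row
        for ks in range(0, len(K), block_k):
            K_blk = K[ks:ks + block_k]            # a tile of keys and the matching values
            V_blk = V[ks:ks + block_k]
            for i, q in enumerate(Q_blk):
                scores = [scale * dot(q, k) for k in K_blk]
                new_max = max(row_max[i], max(scores))
                corr = math.exp(row_max[i] - new_max)  # rescales earlier partial results to the new max
                probs = [math.exp(s - new_max) for s in scores]
                row_sum[i] = row_sum[i] * corr + sum(probs)
                acc[i] = [a * corr + sum(p * v[j] for p, v in zip(probs, V_blk))
                          for j, a in enumerate(acc[i])]
                row_max[i] = new_max
        # Normalize once per row, after all key blocks have been processed.
        out.extend([a / s for a in row] for row, s in zip(acc, row_sum))
    return out


if __name__ == "__main__":
    random.seed(0)
    N, d = 5, 3
    Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
    K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
    V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
    ref = standard_attention(Q, K, V)
    tiled = tiled_attention(Q, K, V, block_q=2, block_k=2)
    err = max(abs(a - b) for r1, r2 in zip(ref, tiled) for a, b in zip(r1, r2))
    print(f"max abs difference: {err:.2e}")
```

The printed difference should be at the level of floating-point rounding, which is the sense in which the tiled result is exact. Across key blocks, each query row carries only its running maximum, running sum, and output accumulator, so no N×N buffer is ever allocated.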
Why it is important
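- Faster models: Flash attention gives a substantial wall-clock speedup for both training and inference; speedups of roughly 2-4x over standard attention are commonly reported, depending on sequence length, head dimension, and hardware.
- Longer contexts: Because it never stores the N×N attention matrix, its extra memory grows linearly rather than quadratically with sequence length, so models can process much longer inputs without running out of GPU memory.
- Larger models: Lower memory use and faster attention make it more practical to train and serve the large language models (LLMs) that are becoming standard in AI.
- Exact result: It computes exact attention, not an approximation (unlike sparse or low-rank attention methods); its output matches standard attention up to floating-point rounding.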
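To put the memory argument in numbers, the sketch below compares the size of the N×N score matrix with the small per-row state a tiled kernel carries. It assumes fp16 scores (2 bytes each), one sequence and one attention head, and two fp32 statistics per query row (a running max and a running sum); the function names are hypothetical, and real implementations keep additional buffers.

```python
GIB = 2 ** 30
KIB = 2 ** 10


def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one N x N attention score matrix (one sequence, one head); fp16 = 2 bytes."""
    return seq_len * seq_len * bytes_per_elem


def row_stats_bytes(seq_len: int, stats_per_row: int = 2, bytes_per_elem: int = 4) -> int:
    """Bytes for per-row softmax statistics (assumed: running max and sum, stored in fp32)."""
    return seq_len * stats_per_row * bytes_per_elem


if __name__ == "__main__":
    for n in (1024, 8192, 32768):
        print(f"N={n:>6}: score matrix {score_matrix_bytes(n) / GIB:7.3f} GiB, "
              f"row statistics {row_stats_bytes(n) / KIB:5.0f} KiB")
```

At N = 32,768 the fp16 score matrix alone is 2^31 bytes (2 GiB) per head per sequence, and it quadruples each time N doubles, whereas the per-row statistics grow only linearly (256 KiB at the same length under these assumptions).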
What is HBM? (High-Bandwidth Memory)
HBM (High Bandwidth Memory) is a type of high-performance DRAM built from memory dies stacked vertically in a 3D architecture and placed in the same package as the processor, which gives it very high bandwidth. On modern data-center GPUs, HBM is the main memory: it holds tens of gigabytes, but it is still much slower to access than the small on-chip SRAM (on the order of a couple hundred kilobytes per streaming multiprocessor). Flash attention exploits this gap through tiling: it processes attention in small blocks that stay in fast SRAM and minimizes slow reads and writes to HBM.
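The benefit of avoiding HBM traffic can be stated more precisely. The original FlashAttention paper (Dao et al., 2022) counts HBM accesses for sequence length N, head dimension d, and SRAM size M, assuming d ≤ M ≤ Nd; the bounds below are quoted from that analysis, not derived here:

$$
\underbrace{\Theta\!\left(Nd + N^2\right)}_{\text{standard attention}}
\qquad\text{vs.}\qquad
\underbrace{\Theta\!\left(\frac{N^2 d^2}{M}\right)}_{\text{FlashAttention}},
\qquad d \le M \le Nd
$$

For typical head dimensions (d of 64 or 128) and on-chip SRAM on the order of 100 KB, d² is smaller than M, so the N² term is scaled down and the tiled algorithm needs fewer HBM accesses than standard attention.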
