FlashMLA is a high-performance decoding kernel library for Multi-Head Latent Attention (MLA), targeting NVIDIA Hopper GPUs. Its kernels are optimized for MLA decoding with variable-length sequences, reducing latency and increasing throughput when serving models that use this attention mechanism. The library supports both BF16 and FP16, and includes a paged KV cache with a block size of 64 to manage memory efficiently during decoding. On H800 SXM5 GPUs it reaches up to ~660 TFLOPS in compute-bound configurations and sustains up to ~3000 GB/s of memory throughput in memory-bound configurations. The kernels are updated periodically; for example, a 2025 release reports 5% to 15% speedups on compute-bound workloads while remaining API-compatible with earlier versions.
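The compute-bound and memory-bound figures can be related with a simple roofline estimate. The sketch below is a rough illustration, not a FlashMLA tool: it treats the two quoted numbers as the compute and bandwidth ceilings, which is an assumption (they are measured throughputs from different configurations, not H800 specifications), and all helper names are hypothetical. Dividing 660 TFLOPS by 3000 GB/s gives a ridge point of about 220 FLOP/byte: work that performs fewer FLOPs per byte moved is limited by bandwidth, work that performs more is limited by compute.

```python
# Back-of-envelope roofline built from the two throughput figures quoted above.
# Assumption: using these achieved numbers as the "roofs" is a simplification;
# they come from different benchmark configurations, not from hardware specs.
COMPUTE_ROOF_TFLOPS = 660.0  # ~660 TFLOPS, compute-bound setting on H800 SXM5
MEMORY_ROOF_GBPS = 3000.0    # ~3000 GB/s, memory-bound setting on H800 SXM5


def ridge_point() -> float:
    """Arithmetic intensity (FLOP/byte) at which the two roofs meet."""
    return (COMPUTE_ROOF_TFLOPS * 1e12) / (MEMORY_ROOF_GBPS * 1e9)


def attainable_tflops(intensity: float) -> float:
    """Roofline estimate: min(compute roof, intensity * bandwidth), in TFLOPS."""
    bandwidth_limited = intensity * MEMORY_ROOF_GBPS * 1e9 / 1e12
    return min(COMPUTE_ROOF_TFLOPS, bandwidth_limited)


def regime(intensity: float) -> str:
    """Label an arithmetic intensity as memory-bound or compute-bound."""
    return "memory-bound" if intensity < ridge_point() else "compute-bound"


if __name__ == "__main__":
    print(f"ridge point: {ridge_point():.0f} FLOP/byte")
    for ai in (50.0, 220.0, 400.0):
        print(f"{ai:>6.0f} FLOP/byte -> {attainable_tflops(ai):6.1f} TFLOPS ({regime(ai)})")
```

For instance, at 50 FLOP/byte this model caps throughput at 150 TFLOPS, well below the compute roof, which is why bandwidth rather than FLOPS is the relevant metric for low-intensity decoding configurations.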
Features
- Decoding kernels optimized for Multi-Head Latent Attention (MLA), including variable-length sequences
- Support for BF16 and FP16 precision (BF16 offers a wider dynamic range, FP16 more mantissa precision)
- Paged KV cache with a block size of 64 to handle varying sequence lengths efficiently
- CUDA implementation targeting NVIDIA Hopper (SM90) GPUs
- Python/PyTorch interface exposed through functions such as flash_mla_with_kvcache
- Ongoing performance improvements (e.g., a 5–15% uplift on compute-bound workloads in a 2025 release)
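The paged KV cache idea can be pictured with a small model. This is a minimal pure-Python sketch of general KV cache paging, assuming 64-token blocks drawn from a shared pool; it is not FlashMLA's implementation, and PagedKVAllocator, its methods, and the -1 padding value are hypothetical. Each sequence owns an ordered list of physical block indices, so a token at logical position p lives in block table[p // 64] at offset p % 64, and sequences of different lengths consume only as many blocks as they need. The padded per-batch block table and the list of sequence lengths it produces are analogous to the block_table and cache_seqlens arguments in the project's flash_mla_with_kvcache usage example, though the real inputs are GPU tensors whose exact layout should be checked against the FlashMLA documentation.

```python
from collections import deque

BLOCK_SIZE = 64  # tokens per KV cache block (page), matching the block size above


class PagedKVAllocator:
    """Hypothetical allocator mapping variable-length sequences onto fixed-size blocks."""

    def __init__(self, num_blocks: int) -> None:
        self.free = deque(range(num_blocks))    # physical block ids not yet in use
        self.tables: dict[int, list[int]] = {}  # seq id -> physical block ids, in order
        self.lengths: dict[int, int] = {}       # seq id -> tokens currently cached

    def append_tokens(self, seq_id: int, n: int) -> None:
        """Reserve room for n more tokens, allocating new blocks only when needed."""
        table = self.tables.setdefault(seq_id, [])
        new_len = self.lengths.get(seq_id, 0) + n
        needed = -(-new_len // BLOCK_SIZE) - len(table)  # ceil(new_len / 64) - owned
        if needed > len(self.free):
            raise MemoryError("KV cache pool exhausted")
        for _ in range(needed):
            table.append(self.free.popleft())
        self.lengths[seq_id] = new_len

    def locate(self, seq_id: int, pos: int) -> tuple[int, int]:
        """Return (physical block, offset within block) for a logical token position."""
        if not 0 <= pos < self.lengths.get(seq_id, 0):
            raise IndexError(f"position {pos} not cached for sequence {seq_id}")
        return self.tables[seq_id][pos // BLOCK_SIZE], pos % BLOCK_SIZE

    def block_table(self, seq_ids: list[int], pad: int = -1) -> list[list[int]]:
        """Rectangular batch view: one row per sequence, padded to the longest row."""
        width = max(len(self.tables[s]) for s in seq_ids)
        return [self.tables[s] + [pad] * (width - len(self.tables[s])) for s in seq_ids]

    def seqlens(self, seq_ids: list[int]) -> list[int]:
        """Number of cached tokens per sequence, in batch order."""
        return [self.lengths[s] for s in seq_ids]


if __name__ == "__main__":
    alloc = PagedKVAllocator(num_blocks=8)
    alloc.append_tokens(0, 100)  # 100 tokens -> 2 blocks
    alloc.append_tokens(1, 30)   # 30 tokens  -> 1 block
    alloc.append_tokens(0, 40)   # 140 tokens -> 3 blocks
    print(alloc.block_table([0, 1]))  # [[0, 1, 3], [2, -1, -1]]
    print(alloc.seqlens([0, 1]))      # [140, 30]
    print(alloc.locate(0, 130))       # (3, 2)
```

Because blocks are fixed-size, the waste per sequence in this model is bounded by the unused slots of its last block (at most 63 tokens), and a sequence's blocks need not be contiguous in the pool.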