FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention

  • Hi, I'm one of the authors of this blog post (Horace He), along with Driss Guessous, Yanbo Liang, and Joy Dong.

    We’re quite happy with this abstraction - happy to answer any questions about it!

  • It's interesting that optimizing a computation that can be described in a single line of math takes so much work. It took forever even to discover FlashAttention, and in the 6 years since transformers were invented, thousands of papers have worked on making it faster.

    Attention(Q,K,V) = Softmax(Q*K^T/sqrt(d_k))*V

    FlexAttention seems to have found the right abstraction for the task.
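
    For reference, a minimal sketch of that one line in plain PyTorch (unfused, so it materializes the full score matrix that FlashAttention is designed to avoid):

      import torch
      import torch.nn.functional as F

      def naive_attention(q, k, v):
          # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
          d_k = q.size(-1)
          scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # full (seq, seq) score matrix in memory
          return F.softmax(scores, dim=-1) @ v

      # toy shapes: (batch, heads, seq_len, head_dim)
      q = k = v = torch.randn(1, 8, 128, 64)
      out = naive_attention(q, k, v)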

  • For most LLM workloads today (short text chats), hundreds or a couple thousand tokens suffice, so attention mechanisms don't dominate (< 30% of compute). But as the modalities inevitably grow, work on attention approximation/compression is going to be paramount.

    Nice to see PyTorch already elegantly supporting this next step in research.
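
    For anyone curious what that looks like in code, here's a rough sketch using the API from the post, with an arbitrary sliding-window causal mask standing in for the kind of variant this research produces (assumes a recent nightly/2.5+ build with torch.nn.attention.flex_attention and a CUDA device):

      import torch
      from torch.nn.attention.flex_attention import flex_attention, create_block_mask

      WINDOW = 256  # illustrative lookback window, not a recommendation

      def sliding_window_causal(b, h, q_idx, kv_idx):
          # causal attention restricted to the last WINDOW tokens
          return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)

      q = k = v = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
      block_mask = create_block_mask(sliding_window_causal, B=None, H=None,
                                     Q_LEN=2048, KV_LEN=2048)
      out = flex_attention(q, k, v, block_mask=block_mask)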

  • I didn't see any mention of this being CUDA-only (like FlashAttention). I tried running it on my Mac M3, Python 3.11.8, following the quickstart (with the deviation of running it in a new venv), and got the following error:

    /attention-gym/.venv/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
      cpu = _conversion_method_template(device=torch.device("cpu"))
    Traceback (most recent call last):
      File "/attention-gym/attn_gym/masks/document_mask.py", line 7, in <module>
        from torch.nn.attention.flex_attention import _mask_mod_signature
    ModuleNotFoundError: No module named 'torch.nn.attention.flex_attention'
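
    In case it helps anyone else hitting this, a quick sketch for checking whether a given install has the module at all (my understanding is that flex_attention only ships in recent nightlies / 2.5+ builds):

      import torch

      print(torch.__version__)
      try:
          from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
          print("flex_attention is available in this build")
      except ImportError:
          print("this torch build predates flex_attention; a recent nightly / 2.5+ is needed")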

  • > FlexAttention achieves 90% of FlashAttention2’s performance in the forward pass and 85% in the backward pass.

    It's very good, but note that FlashAttention-3 is 1.5x-2x faster than FlashAttention-2.

  • I've always been curious to put something together with PyTorch, but there always seemed to be either a steep learning curve or no big motivator (a project, a problem to solve, something in my daily routine to optimize).

    Does anybody have a good starting point for learning with hands-on projects, one that could also accommodate FlexAttention?

  • This is so cool. I want to try to implement something with this right now.

  • Can someone do a short summary or TL;DR for this?