The official implementation of SpargeAttn, a universal, training-free sparse attention method that accelerates language, image, and video models.
- The Sparse SageAttention1 and Sparse SageAttention2 APIs can compute attention with any block-sparse pattern very quickly.
- SpargeAttn based on SageAttention2++ will be released around June 25.
- [2025-05-11]: Added a very simple usage that requires no tuning or calibration: o = spas_sage2_attn_meansim_topk_cuda(q, k, v).
- [2025-05-02]: 🎉SpargeAttn and SageAttention2 are accepted by ICML 2025!
- [2025-01-24]: 🎉SageAttention is accepted by ICLR 2025!
python>=3.9, torch>=2.3.0
CUDA: >=12.8 for Blackwell, >=12.4 for fp8 support on Ada, >=12.3 for fp8 support on Hopper, >=12.0 for Ampere
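To make the version matrix above concrete, here is a small, hypothetical checker (not part of spas_sage_attn). It encodes the minimum CUDA versions listed above and compares them with a version string, for example the one printed by nvcc --version. The target labels such as "ada_fp8" are our own naming, not identifiers used by the project.

```python
# Hypothetical helper (not part of spas_sage_attn): minimum CUDA versions
# transcribed from the requirements listed above.
MIN_CUDA = {
    "blackwell": (12, 8),
    "ada_fp8": (12, 4),     # only needed for fp8 support on Ada
    "hopper_fp8": (12, 3),  # only needed for fp8 support on Hopper
    "ampere": (12, 0),
}


def parse_version(text: str) -> tuple:
    """Turn '12', '12.4' or '12.4.1' into a (major, minor) tuple of ints."""
    parts = text.strip().split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    return (major, minor)


def cuda_ok(target: str, cuda_version: str) -> bool:
    """True if cuda_version meets the minimum listed for the given target."""
    return parse_version(cuda_version) >= MIN_CUDA[target]
```

Versions are compared as integer tuples rather than strings, so "12.10" correctly ranks above "12.8".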
pip install ninja # for parallel compilation
python setup.py install # or pip install -e .
- spas_sage2_attn_meansim_topk_cuda: SpargeAttn based on SageAttention2, which we recommend using.
- spas_sage2_attn_meansim_cuda: SpargeAttn based on SageAttention2, which we do not recommend.
- spas_sage_attn_meansim_topk_cuda: SpargeAttn based on SageAttention, which we recommend using.
- spas_sage_attn_meansim_cuda: SpargeAttn based on SageAttention, which we do not recommend.
Simply replace the torch.nn.functional.scaled_dot_product_attention API with spas_sage2_attn_meansim_topk_cuda:
from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda
- attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False) # is_causal can be True
+ attn_output = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False) # is_causal can be True

from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda
attn_output = spas_sage2_attn_meansim_topk_cuda(q, k, v, simthreshd1=-0.1, topk=0.5, pvthreshd=15, is_causal=False)

You can adjust topk to balance attention accuracy (a higher topk is more accurate) against sparsity (a lower topk is sparser).
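As a rough mental model of the topk trade-off, the sketch below assumes that each query block keeps the fraction topk of key blocks with the highest estimated block scores and skips the rest. The function names, the floor rounding, and the at-least-one-block rule are assumptions made for illustration; the CUDA kernels may select blocks differently.

```python
def topk_block_mask(scores, topk):
    """Keep, per query block (row), the top `topk` fraction of key blocks by score.

    scores[i][j] is an estimated importance of key block j for query block i.
    Returns a 0/1 mask of the same shape (1 = compute, 0 = skip).
    """
    mask = []
    for row in scores:
        n = len(row)
        keep = max(1, int(topk * n))  # assumed rounding: floor, but at least one block
        # Stable sort by descending score; ties keep their original order.
        ranked = sorted(range(n), key=lambda j: row[j], reverse=True)
        kept = set(ranked[:keep])
        mask.append([1 if j in kept else 0 for j in range(n)])
    return mask


def sparsity(mask):
    """Fraction of blocks that are skipped."""
    total = sum(len(row) for row in mask)
    skipped = sum(row.count(0) for row in mask)
    return skipped / total


scores = [[0.9, 0.1, 0.5, 0.2],
          [0.3, 0.8, 0.1, 0.6]]
for topk in (1.0, 0.5, 0.25):
    print(topk, sparsity(topk_block_mask(scores, topk)))  # 0.0, 0.5, 0.75 skipped
```

Lowering topk skips more blocks (more speed) at the cost of dropping attention mass that lower-scored blocks would have contributed.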
Top-K selection is also supported as an alternative to Top-P selection via cdfthreshd. We find that Top-K can sometimes achieve better performance than Top-P in SpargeAttn. Call the API as follows and set the fraction of top elements with the topk parameter:
from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda
attn_output = spas_sage2_attn_meansim_topk_cuda(q, k, v, simthreshd1=-0.1, topk=0.5, pvthreshd=15, is_causal=False)

Note: Automatic tuning for topk is not currently supported.
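For contrast between Top-K and Top-P, the following sketch shows one way a Top-P rule driven by cdfthreshd could behave: block scores are treated as logits, normalized with a softmax, and blocks are kept in descending order until their cumulative probability reaches cdfthreshd. This is a simplified illustration under those assumptions, not the library's kernel logic; the helper names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of block logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def topp_blocks(row, cdfthreshd):
    """Smallest set of blocks whose softmax mass reaches cdfthreshd (Top-P)."""
    probs = softmax(row)
    ranked = sorted(range(len(row)), key=lambda j: probs[j], reverse=True)
    kept, mass = [], 0.0
    for j in ranked:
        kept.append(j)
        mass += probs[j]
        if mass >= cdfthreshd:
            break
    return sorted(kept)


def topk_blocks(row, topk):
    """Fixed fraction of blocks with the largest logits (Top-K)."""
    n = len(row)
    keep = max(1, int(topk * n))  # assumed rounding: floor, at least one block
    ranked = sorted(range(n), key=lambda j: row[j], reverse=True)
    return sorted(ranked[:keep])


peaked = [10.0, 0.0, 0.0, 0.0]  # one key block dominates
flat = [1.0, 1.0, 1.0, 1.0]     # attention spread evenly

print(topp_blocks(peaked, 0.9), topk_blocks(peaked, 0.5))  # [0] [0, 1]
print(topp_blocks(flat, 0.9), topk_blocks(flat, 0.5))      # [0, 1, 2, 3] [0, 1]
```

Top-P adapts the number of kept blocks to how concentrated each row is, while Top-K keeps a fixed fraction regardless of the score distribution, which makes its sparsity (and cost) predictable.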
from spas_sage_attn import block_sparse_sage2_attn_cuda
block_sparse_sage2_attn_cuda(q, k, v, mask_id=None, scale=None, pvthreshd=20, attention_sink=False, tensor_layout="HND", return_sparsity=False)

In this API, we support computing attention with any block-sparse mask per attention head. The mask, mask_id, has shape (batch_size, num_qo_heads, qo_seq_len // BLOCK_M, kv_seq_len // BLOCK_N). Currently, the supported block size is aligned with that of SpargeAttention: BLOCK_M = 128, BLOCK_N = 64. The lower pvthreshd is, the more sparsity the PV matmul has and the faster the attention.
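To illustrate what mask_id controls, here is a pure-Python reference for a single head (no batching, no quantization, no pvthreshd). It assumes that a nonzero mask entry means "compute this query-block/key-block tile", that masked tiles are excluded from the softmax, that scale=None means 1/sqrt(head_dim), and that a query row with no kept blocks returns zeros. These semantics are our assumptions, not guarantees about block_sparse_sage2_attn_cuda. mask_shape reproduces the documented shape; block_sparsity is a hypothetical helper reporting the fraction of skipped tiles.

```python
import math


def mask_shape(batch_size, num_qo_heads, qo_seq_len, kv_seq_len, block_m=128, block_n=64):
    """Shape of mask_id as documented above."""
    return (batch_size, num_qo_heads, qo_seq_len // block_m, kv_seq_len // block_n)


def block_sparse_attention(q, k, v, block_mask, block_m, block_n, scale=None):
    """Single-head reference: attend only to key blocks whose mask entry is nonzero.

    q: qo_len query vectors; k, v: kv_len key/value vectors (lists of floats).
    block_mask[i][j] != 0 means query block i attends to key block j.
    Assumes sequence lengths are multiples of block_m / block_n.
    """
    head_dim = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)  # assumed default, as in standard SDPA
    out = []
    for t, qv in enumerate(q):
        row = block_mask[t // block_m]
        cols = [s for s in range(len(k)) if row[s // block_n]]
        if not cols:
            out.append([0.0] * len(v[0]))  # assumption: fully masked rows give zeros
            continue
        logits = [scale * sum(a * b for a, b in zip(qv, k[s])) for s in cols]
        m = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(x - m) for x in logits]
        total = sum(weights)
        out.append([sum(w * v[s][c] for w, s in zip(weights, cols)) / total
                    for c in range(len(v[0]))])
    return out


def block_sparsity(block_mask):
    """Hypothetical helper: fraction of query/key tiles that are skipped."""
    cells = [x for row in block_mask for x in row]
    return sum(1 for x in cells if not x) / len(cells)


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
diag = [[1, 0], [0, 1]]  # with 1x1 blocks, each query sees only its own key
print(block_sparse_attention(q, k, v, diag, block_m=1, block_n=1))  # [[1.0, 2.0], [3.0, 4.0]]
print(mask_shape(2, 8, 4096, 4096))  # (2, 8, 32, 64)
```

With an all-ones mask the function reduces to ordinary dense softmax attention, which makes it a convenient correctness oracle when experimenting with custom masks on small inputs.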
Note: All experiments in the table above and in our paper used SpargeAttn based on SageAttention. An updated implementation based on SageAttention2 is now available and offers a further 30% speedup.
The quality of video generation on Mochi.
End-to-end performance of NIAH.
If you use this code or find our work valuable, please cite:
@inproceedings{zhang2025spargeattn,
title={Spargeattn: Accurate sparse attention accelerating any model inference},
author={Zhang, Jintao and Xiang, Chendong and Huang, Haofeng and Wei, Jia and Xi, Haocheng and Zhu, Jun and Chen, Jianfei},
booktitle={International Conference on Machine Learning (ICML)},
year={2025}
}
@inproceedings{zhang2025sageattention,
title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration},
author={Zhang, Jintao and Wei, Jia and Zhang, Pengle and Zhu, Jun and Chen, Jianfei},
booktitle={International Conference on Learning Representations (ICLR)},
year={2025}
}
@inproceedings{zhang2024sageattention2,
title={Sageattention2: Efficient attention with thorough outlier smoothing and per-thread int4 quantization},
author={Zhang, Jintao and Huang, Haofeng and Zhang, Pengle and Wei, Jia and Zhu, Jun and Chen, Jianfei},
booktitle={International Conference on Machine Learning (ICML)},
year={2025}
}




