
SpargeAttention

The official implementation of SpargeAttn, a universal training-free sparse attention method that accelerates language, image, and video models.

SpargeAttention: Accurate and Training-free Sparse Attention
Accelerating Any Model Inference

Paper: arXiv:2502.18137 | HuggingFace Daily Papers

[Figure: speed comparison]

[Figure: overview]

Project Updates

Installation

Base environment

  • python>=3.9, torch>=2.3.0
  • CUDA:
    • >=12.8 for Blackwell, >=12.4 for fp8 support on Ada, >=12.3 for fp8 support on Hopper, >=12.0 for Ampere
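
A quick way to check what your environment provides before installing (a minimal sketch; mapping the reported compute capability to the GPU generations listed above is left to the reader):

import torch

# Print the CUDA toolkit version PyTorch was built against and the GPU's compute capability,
# then compare against the requirements above (e.g., sm_90 = Hopper, sm_89 = Ada, sm_80/86 = Ampere).
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print("GPU:", torch.cuda.get_device_name(0), f"(sm_{major}{minor})")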

Install Package

pip install ninja   # for parallel compilation
python setup.py install   # or pip install -e .
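
After installation, a quick smoke test confirms the kernels are importable and runnable (a minimal sketch; the fp16 dtype, head dimension, and the (batch, heads, seq_len, head_dim) layout are assumptions chosen to match the usage examples below):

import torch
from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda

# Tiny fp16 tensors on the GPU, shaped (batch, heads, seq_len, head_dim).
q = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")

out = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False)
print(out.shape)  # expected: torch.Size([1, 8, 1024, 64])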

Available APIs

  • spas_sage2_attn_meansim_topk_cuda: SpargeAttn based on SageAttention2; this is the variant we recommend using.

  • spas_sage2_attn_meansim_cuda: SpargeAttn based on SageAttention2; not recommended.

  • spas_sage_attn_meansim_topk_cuda: SpargeAttn based on SageAttention; recommended if you use the SageAttention backend.

  • spas_sage_attn_meansim_cuda: SpargeAttn based on SageAttention; not recommended.

Usage Examples

Plug-and-Play Usage

Just replace the torch.nn.functional.scaled_dot_product_attention API with spas_sage2_attn_meansim_topk_cuda:

from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda

- attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)  # is_causal can be True

+ attn_output = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False) # is_causal can be True
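
For example, in a typical attention module the change is a one-line swap (a minimal sketch; SimpleAttention and its shapes are illustrative, not part of this repo, and the kernel is assumed to expect half-precision CUDA tensors in (batch, heads, seq_len, head_dim) layout):

import torch
import torch.nn as nn
from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda

class SimpleAttention(nn.Module):  # hypothetical module for illustration
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (batch, heads, seq_len, head_dim)
        q, k, v = (t.view(b, n, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # before: out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
        out = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.proj(out)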

A Simple Usage Without Tuning for Any Model

from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda

attn_output = spas_sage2_attn_meansim_topk_cuda(q, k, v, simthreshd1=-0.1, topk=0.5, pvthreshd=15, is_causal=False)

You can adjust topk to trade off attention accuracy (higher topk is more accurate) against sparsity and speed (lower topk is sparser and faster).
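
One way to choose topk is to compare the sparse output against the dense torch.nn.functional.scaled_dot_product_attention output on a few representative activations (a minimal sketch; the random inputs, cosine-similarity criterion, and candidate topk values are illustrative assumptions):

import torch
import torch.nn.functional as F
from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda

q = torch.randn(1, 8, 4096, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 4096, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 4096, 64, dtype=torch.float16, device="cuda")

# Dense reference output.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# Sweep topk and measure how close the sparse output stays to the dense one.
for topk in (0.3, 0.5, 0.7):
    out = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=topk, is_causal=False)
    sim = F.cosine_similarity(out.flatten().float(), ref.flatten().float(), dim=0)
    print(f"topk={topk}: cosine similarity to dense attention = {sim.item():.4f}")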

SpargeAttn API based on Top-K selection

Top-K selection is also supported as an alternative to cdfthreshd. We find that Top-K can sometimes achieve better performance than Top-P in SpargeAttn. You can call the API as follows and set the fraction of top elements via the topk parameter:

from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda

attn_output = spas_sage2_attn_meansim_topk_cuda(q, k, v, simthreshd1=-0.1, topk=0.5, pvthreshd=15, is_causal=False)

Note: Automatic tuning for topk is not currently supported.

Sparge+SageAttention2++ with Any Block-Sparse Pattern

from spas_sage_attn import block_sparse_sage2_attn_cuda

block_sparse_sage2_attn_cuda(q, k, v, mask_id=None, scale=None, pvthreshd=20, attention_sink=False, tensor_layout="HND", return_sparsity=False)

This API supports computing $S = QK^T$ with an arbitrary block-sparse pattern per attention head, and further accelerates the $PV$ multiplication. Specifically, the per-head attention mask, mask_id, has shape (batch_size, num_qo_heads, qo_seq_len // BLOCK_M, kv_seq_len // BLOCK_N). Currently, the supported block size matches that of SpargeAttention: (BLOCK_M = 128, BLOCK_N = 64). The lower pvthreshd is, the sparser the PV matmul and the faster the attention.
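
For illustration, the sketch below builds a block-level mask of the required shape and passes it as mask_id (a minimal sketch; the random 50% mask, its boolean dtype, and the input shapes are assumptions, so check the repo for the exact mask format expected):

import torch
from spas_sage_attn import block_sparse_sage2_attn_cuda

BLOCK_M, BLOCK_N = 128, 64
b, h, seq_q, seq_k, d = 1, 8, 4096, 4096, 64

# tensor_layout="HND": (batch, heads, seq_len, head_dim)
q = torch.randn(b, h, seq_q, d, dtype=torch.float16, device="cuda")
k = torch.randn(b, h, seq_k, d, dtype=torch.float16, device="cuda")
v = torch.randn(b, h, seq_k, d, dtype=torch.float16, device="cuda")

# One entry per (query block, key block) pair and per head:
# (batch_size, num_qo_heads, qo_seq_len // BLOCK_M, kv_seq_len // BLOCK_N)
mask_id = torch.rand(b, h, seq_q // BLOCK_M, seq_k // BLOCK_N, device="cuda") < 0.5

out = block_sparse_sage2_attn_cuda(q, k, v, mask_id=mask_id, pvthreshd=20, tensor_layout="HND")
print(out.shape)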

Performance

[Performance table]

Note: All experiments in the table above and in our paper used SpargeAttn based on SageAttention. An updated implementation based on SageAttention2 is now available and offers a further 30% speedup.


[Figure: End-to-end video generation on Mochi]
[Figure: The quality of video generation on Mochi]
[Figure: End-to-end performance of NIAH]

Citation

If you use this code or find our work valuable, please cite:

@inproceedings{zhang2025spargeattn,
  title={Spargeattn: Accurate sparse attention accelerating any model inference},
  author={Zhang, Jintao and Xiang, Chendong and Huang, Haofeng and Wei, Jia and Xi, Haocheng and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2025}
}

@inproceedings{zhang2025sageattention,
  title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration}, 
  author={Zhang, Jintao and Wei, Jia and Zhang, Pengle and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}

@inproceedings{zhang2024sageattention2,
  title={Sageattention2: Efficient attention with thorough outlier smoothing and per-thread int4 quantization},
  author={Zhang, Jintao and Huang, Haofeng and Zhang, Pengle and Wei, Jia and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2025}
}
