Commit df80a0f

[5/n] Add VSA for Video Diffusion (#1053)
### What does this PR do?

Type of change: New feature.

Adds Video Sparse Attention (VSA) as a new sparse attention method in ModelOpt. VSA implements a two-branch architecture (compression + sparse) using 3D block tiling for video diffusion models. VSA integrates with HuggingFace models by registering as `attn_implementation="modelopt_vsa"` in HF's `ALL_ATTENTION_FUNCTIONS`, the same pattern used by the existing Triton FA backend. After `sparsify()`, HF dispatches Q, K, V directly to the VSA kernel with no monkey-patching needed.

### Usage

```python
# Load any HuggingFace model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# Define VSA config
vsa_config = {
    "sparse_cfg": {
        "*attn*": {
            "method": "vsa",
            "block_size_3d": (4, 4, 4),  # 3D tile dimensions (T, H, W)
            "top_k_ratio": 0.5,          # keep top 50% of blocks
            "video_shape": (8, 16, 16),  # video dims after patchification
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# Apply — registers modelopt_vsa with HF automatically
model = mtsa.sparsify(model, vsa_config)
```

### Testing

`pytest tests/unit/torch/sparsity/attention_sparsity/test_vsa.py`

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ / ❌ / N/A
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: ✅ / ❌ / N/A
- Did you write any new necessary tests?: ✅ / ❌ / N/A
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A

## Summary by CodeRabbit

* **New Features**
  * Added Video Sparse Attention (VSA) as a new sparse attention method for attention optimization.
  * Introduced VSA configuration support with customizable block sizes and sparsity ratios.
  * Integrated VSA with HuggingFace transformers for enhanced model compatibility.

Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 18ce04f commit df80a0f

11 files changed: 1248 additions & 26 deletions


.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -101,6 +101,7 @@ repos:
           examples/speculative_decoding/server_generate.py|
           experimental/dms/models/qwen3/configuration_qwen3_dms.py|
           experimental/dms/models/qwen3/modeling_qwen3_dms.py|
+          modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py|
           )$
 
       # Default hook for Apache 2.0 in c/c++/cuda files
```

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -9,6 +9,7 @@ NVIDIA Model Optimizer Changelog
 - Added iterator interface using CalibrationDataReader in ONNX quantization workflow.
 - Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 - Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
+- Add Video Sparse Attention (VSA) method for video diffusion models (``modelopt.torch.sparsity.attention_sparsity``). VSA uses 3D block tiling with a two-branch architecture for attention speedup.
 - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.
 - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
```

modelopt/torch/kernels/__init__.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -35,10 +35,9 @@
 
     attention = _attention
     IS_AVAILABLE = True
-    with import_plugin("transformers"):
-        from .hf_triton_attention import register_triton_attention as _register_triton_attention
+    from .hf_triton_attention import register_triton_attention as _register_triton_attention
 
-        register_triton_attention = _register_triton_attention
+    register_triton_attention = _register_triton_attention
 
 __all__ = [
     "IS_AVAILABLE",
```

modelopt/torch/sparsity/attention_sparsity/config.py

Lines changed: 134 additions & 1 deletion

```diff
@@ -498,7 +498,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
 
 
 # Configuration with RULER calibration
-# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length
+# Note: threshold field is omitted - calibration determines dynamic threshold lambda = a / length
 # The calibrated threshold adapts to sequence length for optimal sparsity
 SKIP_SOFTMAX_CALIB = {
     "sparse_cfg": {
@@ -521,6 +521,136 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
 }
 
 
+class VSAAttributeConfig(ModeloptBaseConfig):
+    """Video Sparse Attention (VSA) attribute configuration.
+
+    VSA uses a two-branch architecture optimized for video diffusion models:
+    1. Compression branch: Block-averaged coarse attention
+    2. Sparse branch: Top-K block selection for fine-grained attention
+    """
+
+    method: str = ModeloptField(
+        default="vsa",
+        title="Sparse attention method.",
+        description="Must be 'vsa' for Video Sparse Attention.",
+    )
+
+    enable: bool = ModeloptField(
+        default=True,
+        title="Enable VSA.",
+        description="If True, enables Video Sparse Attention. If False, bypasses sparsity.",
+    )
+
+    block_size_3d: tuple[int, int, int] | list[int] = ModeloptField(
+        default=(4, 4, 4),
+        title="3D block size.",
+        description=(
+            "Video block dimensions (T, H, W) for spatial-temporal tiling. "
+            "Default (4, 4, 4) creates 64-token blocks."
+        ),
+    )
+
+    top_k_ratio: float = ModeloptField(
+        default=0.5,
+        title="Top-K selection ratio.",
+        description=(
+            "Ratio of blocks to keep in sparse branch (0.0 to 1.0). "
+            "Lower values mean more sparsity. Default 0.5 keeps 50% of blocks."
+        ),
+    )
+
+    video_shape: tuple[int, int, int] | list[int] | None = ModeloptField(
+        default=None,
+        title="Video shape.",
+        description=(
+            "Video dimensions (T, H, W) after patchification. "
+            "Required for VSA — set via config or call set_video_shape() at runtime."
+        ),
+    )
+
+    collect_stats: bool = ModeloptField(
+        default=False,
+        title="Collect statistics.",
+        description="Whether to collect sparsity statistics during forward pass.",
+    )
+
+    @field_validator("method")
+    @classmethod
+    def validate_vsa_method(cls, v):
+        """Validate method is 'vsa'."""
+        if v != "vsa":
+            raise ValueError(f"VSAAttributeConfig method must be 'vsa', got '{v}'")
+        return v
+
+    @field_validator("block_size_3d")
+    @classmethod
+    def validate_block_size_3d(cls, v):
+        """Validate 3D block size."""
+        if isinstance(v, list):
+            v = tuple(v)
+        if len(v) != 3:
+            raise ValueError(f"block_size_3d must have 3 elements (T, H, W), got {len(v)}")
+        if any(x <= 0 for x in v):
+            raise ValueError(f"All block_size_3d values must be positive, got {v}")
+        return v
+
+    @field_validator("top_k_ratio")
+    @classmethod
+    def validate_top_k_ratio(cls, v):
+        """Validate top-K ratio is in valid range."""
+        if not 0.0 < v <= 1.0:
+            raise ValueError(f"top_k_ratio must be in range (0, 1], got {v}")
+        return v
+
+    @field_validator("video_shape")
+    @classmethod
+    def validate_video_shape(cls, v):
+        """Validate video shape if provided."""
+        if v is None:
+            return v
+        if isinstance(v, list):
+            v = tuple(v)
+        if len(v) != 3:
+            raise ValueError(f"video_shape must have 3 elements (T, H, W), got {len(v)}")
+        if any(x <= 0 for x in v):
+            raise ValueError(f"All video_shape values must be positive, got {v}")
+        return v
+
+
+class VSAConfig(SparseAttentionConfig):
+    """Configuration for Video Sparse Attention optimization."""
+
+    sparse_cfg: SparseAttentionCfgType = ModeloptField(
+        default={
+            "*attn*": {
+                "method": "vsa",
+                "block_size_3d": (4, 4, 4),
+                "top_k_ratio": 0.5,
+                "enable": True,
+            },
+            "default": {"enable": False},
+        },
+        title="VSA configuration",
+        description="Pattern-based configuration for Video Sparse Attention.",
+        validate_default=True,
+    )
+
+
+# Pre-defined VSA Configuration for video diffusion models.
+# Pattern "*attn*" matches attention module names by convention.
+VSA_DEFAULT = {
+    "sparse_cfg": {
+        "*attn*": {
+            "method": "vsa",
+            "block_size_3d": (4, 4, 4),
+            "top_k_ratio": 0.5,
+            "enable": True,
+        },
+        "default": {"enable": False},
+    },
+}
+
+
 # Default N:M sparse softmax configuration
 SPARSE_SOFTMAX_DEFAULT = {
     "sparse_cfg": {
@@ -557,10 +687,13 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
     "SKIP_SOFTMAX_DEFAULT",
     "SKIP_SOFTMAX_TRITON_DEFAULT",
     "SPARSE_SOFTMAX_DEFAULT",
+    "VSA_DEFAULT",
     "CalibrationConfig",
     "FlashSkipSoftmaxConfig",
     "SparseAttentionAttributeConfig",
     "SparseAttentionCfgType",
     "SparseAttentionConfig",
     "SparseAttributeConfig",
+    "VSAAttributeConfig",
+    "VSAConfig",
 ]
```
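Taken together with the PR description, a typical way to consume these additions would look like the sketch below. It assumes `mtsa` aliases `modelopt.torch.sparsity.attention_sparsity` (as the PR's usage snippet implies) and uses a hypothetical loader as a stand-in for a real video diffusion transformer; `VSA_DEFAULT` and the `video_shape` field come from the diff above.

```python
import copy

import modelopt.torch.sparsity.attention_sparsity as mtsa  # assumed alias, per the PR description
from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT

# Start from the preset and fill in the one field it leaves unset: the post-patchification
# video dimensions (T, H, W). Without it, VSA cannot map flattened tokens onto 3D blocks.
vsa_cfg = copy.deepcopy(VSA_DEFAULT)
vsa_cfg["sparse_cfg"]["*attn*"]["video_shape"] = (8, 16, 16)
vsa_cfg["sparse_cfg"]["*attn*"]["top_k_ratio"] = 0.25  # keep only the top 25% of blocks

model = load_video_diffusion_transformer()  # hypothetical loader for an HF video DiT
model = mtsa.sparsify(model, vsa_cfg)
```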

modelopt/torch/sparsity/attention_sparsity/conversion.py

Lines changed: 26 additions & 12 deletions

```diff
@@ -33,42 +33,57 @@
 
 
 def _set_attn_implementation(model: nn.Module, config: SparseAttentionConfig) -> None:
-    """Set the correct attn_implementation based on the sparse attention backend.
+    """Set the correct attn_implementation based on the sparse attention method/backend.
 
     - ``backend="triton"``: registers the Triton kernel with HF and sets
       ``attn_implementation="modelopt_triton"``.
     - ``backend="pytorch"`` (default): sets ``attn_implementation="eager"`` so that
       softmax-patching methods (e.g. skip-softmax) work correctly. FlashAttention
       and SDPA bypass ``F.softmax``, so eager is required.
+    - ``method="vsa"``: no-op. VSA patches ``F.scaled_dot_product_attention``
+      directly in ``SparseAttentionModule.forward()``, so no ``attn_implementation``
+      change is needed.
 
     This is called automatically during ``mtsa.sparsify()`` so users never need
     to manually set ``attn_implementation``.
     """
     sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {}
 
-    # Collect backends only from layer configs (identified by having a "method" key).
+    # Collect methods and backends only from layer configs (identified by having a "method" key).
     # Other dict entries (e.g. "calibration") are not layer configs.
-    backends = {
-        v.get("backend", "pytorch")
-        for v in sparse_cfg.values()
-        if isinstance(v, dict) and "method" in v
-    }
+    layer_cfgs = [v for v in sparse_cfg.values() if isinstance(v, dict) and "method" in v]
+    methods = {v.get("method") for v in layer_cfgs}
+    backends = {v.get("backend", "pytorch") for v in layer_cfgs}
+
+    # VSA patches F.scaled_dot_product_attention directly — it does not change
+    # attn_implementation. Skip the rest for VSA-only configs.
+    if methods == {"vsa"}:
+        return
+
+    # Reject mixed VSA + non-VSA configs (VSA patches SDPA globally per-module,
+    # while softmax-patching methods need attn_implementation="eager").
+    non_vsa_methods = methods - {"vsa"}
+    if "vsa" in methods and non_vsa_methods:
+        raise ValueError(
+            f"Cannot mix VSA with other sparse attention methods ({non_vsa_methods}). "
+            f"VSA patches F.scaled_dot_product_attention, which is incompatible "
+            f"with softmax-patching or triton methods."
+        )
+
+    model_config = getattr(model, "config", None)
 
     if "triton" in backends and "pytorch" in backends:
         raise ValueError(
             "Mixed backends ('triton' and 'pytorch') in the same model are not "
             "supported. All sparse attention layers must use the same backend."
         )
 
-    model_config = getattr(model, "config", None)
-
     if "triton" in backends:
         from .kernels import register_triton_attention
 
         if register_triton_attention is None:
             raise ImportError(
-                "Triton backend requires 'triton' and 'transformers' packages. "
-                "Install with: pip install triton transformers"
+                "Triton backend requires 'triton' package. Install with: pip install triton"
             )
         if not register_triton_attention():
             raise RuntimeError(
@@ -83,7 +98,6 @@ def _set_attn_implementation(model: nn.Module, config: SparseAttentionConfig) ->
             model_config._attn_implementation = "modelopt_triton"
         elif model_config is not None:
             # For pytorch backend, force eager for softmax patching.
-            # TODO: Add the triton backend support for skip-softmax.
             model_config._attn_implementation = "eager"
 
 
```
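To make the new dispatch concrete, the sketch below shows the three kinds of `sparse_cfg` the function now distinguishes. The dict shapes follow the presets in `config.py`; the non-VSA method name is illustrative, and in practice these configs are passed to `mtsa.sparsify()` rather than to this helper directly.

```python
# Case 1: VSA only. _set_attn_implementation returns early; SDPA is patched per module instead.
vsa_only = {
    "sparse_cfg": {
        "*attn*": {"method": "vsa", "enable": True},
        "default": {"enable": False},
    }
}

# Case 2: a softmax-patching method on the default PyTorch backend.
# attn_implementation is forced to "eager" so the patched softmax is actually hit.
skip_softmax_only = {
    "sparse_cfg": {
        "*attn*": {"method": "flash_skip_softmax", "enable": True},  # illustrative method name
    }
}

# Case 3: VSA mixed with a softmax-patching method. Rejected with a ValueError, since the two
# interception points (SDPA patching vs. eager softmax) cannot coexist in one model.
mixed = {
    "sparse_cfg": {
        "*self_attn*": {"method": "vsa", "enable": True},
        "*cross_attn*": {"method": "flash_skip_softmax", "enable": True},
    }
}
```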

modelopt/torch/sparsity/attention_sparsity/methods/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -24,4 +24,5 @@
 ]
 
 # Import method implementations to trigger registration
-from . import flash_skip_softmax, triton_skip_softmax, triton_sparse_softmax
+# Note: vsa imports no external deps at module level; fastvideo_kernel is imported lazily at runtime.
+from . import flash_skip_softmax, triton_skip_softmax, triton_sparse_softmax, vsa
```
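The lazy-import note keeps `from . import vsa` importable on machines without the kernel available. A minimal sketch of that pattern follows, assuming a deferred import inside a helper; `video_sparse_attn` is a hypothetical entry-point name, and only the `fastvideo_kernel` name comes from the comment above.

```python
def _load_vsa_kernel():
    """Defer the heavy kernel import until attention is actually computed.

    Keeping this out of module scope means registering the VSA method never
    requires the kernel to be installed; only running it does.
    """
    try:
        from fastvideo_kernel import video_sparse_attn  # hypothetical entry point
    except ImportError as err:
        raise ImportError("VSA needs the fastvideo kernel available at runtime.") from err
    return video_sparse_attn
```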

modelopt/torch/sparsity/attention_sparsity/methods/registry.py

Lines changed: 25 additions & 0 deletions

```diff
@@ -37,6 +37,31 @@ def __init__(self):
         self.calibration_params: dict[str, dict[str, float]] | None = None
         # Target sparsity ratio per phase: {"prefill": 0.5, "decode": 0.5}
         self.target_sparse_ratio: dict[str, float] | None = None
+        # Video shape for VSA (T, H, W). None for non-VSA methods.
+        self.video_shape: tuple[int, int, int] | None = None
+
+    def forward_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        **kwargs,
+    ) -> tuple[torch.Tensor, dict]:
+        """Compute full attention replacement (e.g. VSA).
+
+        Default: raises NotImplementedError. Override for methods that replace
+        the entire attention computation rather than patching softmax.
+
+        Args:
+            query: Query tensor [batch, heads, seq_len, dim].
+            key: Key tensor [batch, heads, seq_len, dim].
+            value: Value tensor [batch, heads, seq_len, dim].
+            **kwargs: Method-specific arguments.
+
+        Returns:
+            Tuple of (attention_output, stats_dict).
+        """
+        raise NotImplementedError(f"{type(self).__name__} does not implement forward_attention.")
 
     def calculate_sparsity(
         self,
```