Skip to content

Commit a8e41b7

Browse files
committed
Support diffusers models
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 92d988a commit a8e41b7

8 files changed

Lines changed: 115 additions & 191 deletions

File tree

modelopt/torch/kernels/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,8 @@
3939

4040
register_triton_attention = _register_triton_attention
4141

42-
from .hf_vsa_attention import register_vsa_attention # noqa: E402
43-
4442
__all__ = [
4543
"IS_AVAILABLE",
4644
"attention",
4745
"register_triton_attention",
48-
"register_vsa_attention",
4946
]

modelopt/torch/kernels/hf_vsa_attention.py

Lines changed: 0 additions & 120 deletions
This file was deleted.

modelopt/torch/sparsity/attention_sparsity/conversion.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535
def _set_attn_implementation(model: nn.Module, config: SparseAttentionConfig) -> None:
3636
"""Set the correct attn_implementation based on the sparse attention method/backend.
3737
38-
- ``method="vsa"``: registers the VSA kernel with HF and sets
39-
``attn_implementation="modelopt_vsa"``. HF calls VSA directly via the
40-
registered attention function — no monkey-patching needed.
4138
- ``backend="triton"``: registers the Triton kernel with HF and sets
4239
``attn_implementation="modelopt_triton"``.
4340
- ``backend="pytorch"`` (default): sets ``attn_implementation="eager"`` so that
4441
softmax-patching methods (e.g. skip-softmax) work correctly. FlashAttention
4542
and SDPA bypass ``F.softmax``, so eager is required.
43+
- ``method="vsa"``: no-op. VSA patches ``F.scaled_dot_product_attention``
44+
directly in ``SparseAttentionModule.forward()``, so no ``attn_implementation``
45+
change is needed.
4646
4747
This is called automatically during ``mtsa.sparsify()`` so users never need
4848
to manually set ``attn_implementation``.
@@ -55,31 +55,23 @@ def _set_attn_implementation(model: nn.Module, config: SparseAttentionConfig) ->
5555
methods = {v.get("method") for v in layer_cfgs}
5656
backends = {v.get("backend", "pytorch") for v in layer_cfgs}
5757

58-
# VSA uses attn_implementation="modelopt_vsa", which is incompatible
59-
# with softmax-patching methods that need "eager" or triton methods that need
60-
# "modelopt_triton". Reject mixed configs.
58+
# VSA patches F.scaled_dot_product_attention directly — it does not change
59+
# attn_implementation. Skip the rest for VSA-only configs.
60+
if methods == {"vsa"}:
61+
return
62+
63+
# Reject mixed VSA + non-VSA configs (VSA patches SDPA globally per-module,
64+
# while softmax-patching methods need attn_implementation="eager").
6165
non_vsa_methods = methods - {"vsa"}
6266
if "vsa" in methods and non_vsa_methods:
6367
raise ValueError(
6468
f"Cannot mix VSA with other sparse attention methods ({non_vsa_methods}). "
65-
f"VSA sets attn_implementation='modelopt_vsa' model-wide, which is incompatible "
69+
f"VSA patches F.scaled_dot_product_attention, which is incompatible "
6670
f"with softmax-patching or triton methods."
6771
)
6872

6973
model_config = getattr(model, "config", None)
7074

71-
if "vsa" in methods:
72-
from .kernels import register_vsa_attention
73-
74-
if not register_vsa_attention():
75-
raise RuntimeError(
76-
"Failed to register VSA attention with HuggingFace. "
77-
"Check that your transformers version supports ALL_ATTENTION_FUNCTIONS."
78-
)
79-
if model_config is not None:
80-
model_config._attn_implementation = "modelopt_vsa"
81-
return
82-
8375
if "triton" in backends and "pytorch" in backends:
8476
raise ValueError(
8577
"Mixed backends ('triton' and 'pytorch') in the same model are not "

modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,10 @@
1515

1616
"""Re-exports from modelopt.torch.kernels for backward compatibility."""
1717

18-
from modelopt.torch.kernels import (
19-
IS_AVAILABLE,
20-
attention,
21-
register_triton_attention,
22-
register_vsa_attention,
23-
)
18+
from modelopt.torch.kernels import IS_AVAILABLE, attention, register_triton_attention
2419

2520
__all__ = [
2621
"IS_AVAILABLE",
2722
"attention",
2823
"register_triton_attention",
29-
"register_vsa_attention",
3024
]

modelopt/torch/sparsity/attention_sparsity/methods/registry.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,29 @@ def __init__(self):
3838
# Target sparsity ratio per phase: {"prefill": 0.5, "decode": 0.5}
3939
self.target_sparse_ratio: dict[str, float] | None = None
4040

41+
def forward_attention(
42+
self,
43+
query: torch.Tensor,
44+
key: torch.Tensor,
45+
value: torch.Tensor,
46+
**kwargs,
47+
) -> tuple[torch.Tensor, dict]:
48+
"""Compute full attention replacement (e.g. VSA).
49+
50+
Default: raises NotImplementedError. Override for methods that replace
51+
the entire attention computation rather than patching softmax.
52+
53+
Args:
54+
query: Query tensor [batch, heads, seq_len, dim].
55+
key: Key tensor [batch, heads, seq_len, dim].
56+
value: Value tensor [batch, heads, seq_len, dim].
57+
**kwargs: Method-specific arguments.
58+
59+
Returns:
60+
Tuple of (attention_output, stats_dict).
61+
"""
62+
raise NotImplementedError(f"{type(self).__name__} does not implement forward_attention.")
63+
4164
def calculate_sparsity(
4265
self,
4366
attention_scores: torch.Tensor,

modelopt/torch/sparsity/attention_sparsity/methods/vsa.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,10 @@
2222
Uses the optimized Triton kernel from fastvideo_kernel.
2323
2424
Integration:
25-
For HuggingFace models, VSA registers as ``attn_implementation="modelopt_vsa"``
26-
via ``ALL_ATTENTION_FUNCTIONS`` (same pattern as the Triton FA backend). HF
27-
dispatches Q, K, V directly to the VSA kernel — no monkey-patching needed.
28-
This is set up automatically by ``mtsa.sparsify()``.
29-
30-
For non-HF models, call ``forward_attention(q, k, v, ...)`` directly::
31-
32-
for module in model.modules():
33-
if isinstance(module, SparseAttentionModule):
34-
vsa = module._sparse_method_instance
35-
vsa.set_video_shape((T, H, W))
36-
output, stats = vsa.forward_attention(q, k, v)
25+
After ``mtsa.sparsify(model, VSA_DEFAULT)``, each attention layer's
26+
``F.scaled_dot_product_attention`` call is intercepted and replaced by the VSA
27+
kernel. Cross-attention (Q/K have different seq_len) is automatically skipped.
28+
This works with HF transformers and diffusers.
3729
"""
3830

3931
import math
@@ -302,11 +294,11 @@ def forward_attention(
302294
# Kernel operates on tiled tensors in [batch, heads, padded_seq, dim] format
303295
try:
304296
from fastvideo_kernel import video_sparse_attn as triton_vsa_kernel
305-
except ModuleNotFoundError:
306-
raise ModuleNotFoundError(
297+
except ImportError as e:
298+
raise ImportError(
307299
"VSA requires the 'fastvideo_kernel' package for its Triton sparse attention "
308-
"kernel. Install it with: pip install fastvideo_kernel"
309-
) from None
300+
f"kernel. Install it with: pip install fastvideo_kernel (error: {e})"
301+
) from e
310302
output_tiled = triton_vsa_kernel(
311303
query_tiled,
312304
key_tiled,

modelopt/torch/sparsity/attention_sparsity/sparse_attention.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,21 +176,22 @@ def _setup(self):
176176
def forward(self, *args, **kwargs):
177177
"""Forward with selected sparse attention method.
178178
179-
- VSA: dispatched by HF via ``ALL_ATTENTION_FUNCTIONS["modelopt_vsa"]``
180-
inside the original forward — just pass through.
179+
- VSA: patches ``F.scaled_dot_product_attention`` to intercept the SDPA
180+
call inside the original forward. Cross-attention is skipped.
181181
- Softmax-patching methods (e.g. ``flash_skip_softmax``): use the
182182
context manager path below.
183183
"""
184184
# Pass through if sparse attention is disabled
185185
if not self.is_enabled:
186186
return super().forward(*args, **kwargs)
187187

188-
# VSA is dispatched by HF via ALL_ATTENTION_FUNCTIONS["modelopt_vsa"]
189-
# inside the original forward — pass through and let HF call our
190-
# registered vsa_attention_forward().
188+
# VSA: patch F.scaled_dot_product_attention so the VSA kernel intercepts
189+
# the SDPA call inside the original forward. This works for diffusers models
190+
# since SDPA is the common attention primitive.
191+
# Only self-attention is replaced. Cross-attention (Q/K have different seq_len) is skipped.
191192
if self._method == "vsa":
192-
result = super().forward(*args, **kwargs)
193-
# Collect stats set by vsa_attention_forward
193+
result = self._forward_with_vsa_sdpa_patch(args, kwargs)
194+
194195
if self._stats_manager is not None and self._last_stats is not None:
195196
self._stats_manager.collect(self._last_stats)
196197
self._last_stats = None
@@ -210,6 +211,52 @@ def forward(self, *args, **kwargs):
210211

211212
return result
212213

214+
def _forward_with_vsa_sdpa_patch(self, args, kwargs):
215+
"""Run forward with F.scaled_dot_product_attention patched for VSA.
216+
217+
Replaces SDPA with the VSA kernel for self-attention calls (Q and K/V
218+
have the same seq_len). Cross-attention calls fall through to the
219+
original SDPA. Warns if SDPA was never called.
220+
"""
221+
import torch.nn.functional as F
222+
223+
from modelopt.torch.quantization.utils import replace_function
224+
225+
vsa = self._sparse_method_instance
226+
original_sdpa = F.scaled_dot_product_attention
227+
self._vsa_sdpa_called = False
228+
229+
def _patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kw):
230+
self._vsa_sdpa_called = True
231+
# Skip VSA for cross-attention (Q and K/V have different seq_len)
232+
if query.shape[2] != key.shape[2]:
233+
return original_sdpa(
234+
query,
235+
key,
236+
value,
237+
attn_mask=attn_mask,
238+
dropout_p=dropout_p,
239+
is_causal=is_causal,
240+
**kw,
241+
)
242+
output, stats = vsa.forward_attention(query, key, value)
243+
self._last_stats = stats
244+
return output
245+
246+
with replace_function(F, "scaled_dot_product_attention", _patched_sdpa):
247+
result = super().forward(*args, **kwargs)
248+
249+
if not self._vsa_sdpa_called:
250+
import warnings
251+
252+
warnings.warn(
253+
f"VSA: F.scaled_dot_product_attention was not called during "
254+
f"{type(self).__name__}.forward(). The attention layer may use a "
255+
f"custom kernel that bypasses SDPA. VSA had no effect on this layer.",
256+
)
257+
258+
return result
259+
213260
def _get_sparse_context(self):
214261
"""Get the context manager for applying sparse attention.
215262

0 commit comments

Comments
 (0)