diff --git a/examples/diffusers/README.md b/examples/diffusers/README.md index 6af226752d..a64ea2b1b4 100644 --- a/examples/diffusers/README.md +++ b/examples/diffusers/README.md @@ -13,6 +13,7 @@ Cache Diffusion is a technique that reuses cached outputs from previous diffusio | Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | | | Getting Started | Learn how to optimize your models using quantization/cache diffusion to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | | Support Matrix | View the support matrix to see quantization/cahce diffusion compatibility and feature availability across different models | \[[Link](#support-matrix)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | +| Sparse Attention (Skip-Softmax) | Skip-softmax sparse attention for diffusion models | \[[Link](#sparse-attention-skip-softmax)\] | | | Cache Diffusion | Caching technique to accelerate inference without compromising quality | \[[Link](#cache-diffusion)\] | | | Post Training Quantization (PTQ) | Example scripts on how to run PTQ on diffusion models | \[[Link](#post-training-quantization-ptq)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | | Quantization Aware Training (QAT) | Example scripts on how to run QAT on diffusion models | \[[Link](#quantization-aware-training-qat)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | @@ -276,6 +277,67 @@ mto.restore(pipe.unet, your_quantized_ckpt) By following these steps, your PEFT LoRA model should be efficiently quantized using ModelOpt, ready for deployment while maximizing performance. +## Sparse Attention (Skip-Softmax) + +Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration. + +### Getting Started + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa + +# 1. Define config with calibration +config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": {"prefill": 0.5}, + "threshold_trials": [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, + 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1, + 8e-1, 9e-1, 9.9e-1], + }, + "*.attn1": { + "method": "triton_skip_softmax", + "backend": "triton", + "is_causal": False, + "collect_stats": True, + "enable": True, + }, + "*.attn2": {"enable": False}, + "default": {"enable": False}, + }, +} + +# 2. Provide a calibration forward loop +def forward_loop(model): + pipeline(prompt="a cat", num_frames=81, num_inference_steps=40, ...) + +# 3. Sparsify + calibrate +mtsa.sparsify(transformer, config, forward_loop=forward_loop) + +# 4. Generate as usual — sparsity is applied automatically +output = pipeline(prompt="a dog on the beach", ...) +``` + +### Example Scripts + +#### Wan 2.2 [Script](./sparsity/wan22_skip_softmax.py) + +The 14B model automatically sparsifies both `transformer` and `transformer_2`. 
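+
+Under the hood, the example script simply calls `mtsa.sparsify()` once per transformer module. A minimal sketch of that loop, assuming the `pipeline`, `config`, and `forward_loop` objects from the Getting Started snippet above:
+
+```python
+# The 5B pipeline exposes a single `transformer`; the 14B pipeline also has `transformer_2`.
+for name in ("transformer", "transformer_2"):
+    module = getattr(pipeline, name, None)
+    if module is not None:
+        mtsa.sparsify(module, config, forward_loop=forward_loop)
+```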
+ +```bash +# 5B model — calibrate + generate (4 prompts from OpenVid-1M, 151 frames, 40 steps) +python sparsity/wan22_skip_softmax.py \ + --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --calibrate --target-sparsity 0.5 --calib-size 4 \ + --prompt "A sunset over mountains" --output out.mp4 + +# 14B model (both transformers sparsified) +python sparsity/wan22_skip_softmax.py \ + --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --calibrate --target-sparsity 0.5 --calib-size 4 \ + --prompt "A sunset over mountains" --output out.mp4 +``` + ## Cache Diffusion Cache Diffusion methods, such as [DeepCache](https://arxiv.org/abs/2312.00858), [Block Caching](https://arxiv.org/abs/2312.03209) and [T-Gate](https://arxiv.org/abs/2404.02747), optimize performance by reusing cached outputs from previous steps instead of recalculating them. This **training-free** caching approach is compatible with a variety of models, like **DiT** and **UNet**, enabling considerable acceleration without compromising quality. diff --git a/examples/diffusers/quantization/wan2_sage_attention.py b/examples/diffusers/quantization/wan2_sage_attention.py new file mode 100644 index 0000000000..d80ee6cb5f --- /dev/null +++ b/examples/diffusers/quantization/wan2_sage_attention.py @@ -0,0 +1,923 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wan2.2 Text-to-Video inference with SageAttention and FP8/NVFP4 attention quantization. + +Attention kernel variants supported via ``--kernel``: + +``sage1`` + ``sageattn`` — SageAttention v1. INT8 QK, FP16 PV. Ampere/Ada/Hopper. + +``sage2-fp16`` + ``sageattn_qk_int8_pv_fp16_cuda`` — SageAttention2. INT8 QK, FP16 PV, per-thread. + +``sage2-fp8`` (default SageAttention kernel) + ``sageattn_qk_int8_pv_fp8_cuda`` — SageAttention2++. INT8 QK, FP8 PV. + Fastest on Ada (SM89) / Hopper (SM90). + +``fp8`` (default for accuracy testing, always available) + Python-level FP8 E4M3 attention. Inspired by SageAttention2: + - Q, K, V are channel-smoothed (per-channel mean subtraction) before quantization + - Per-token FP8 E4M3 quantization with max-based scales + - Dequantize back to BF16, then run standard SDPA + - No CUDA kernel required. Use for accuracy verification on any GPU. + +``triton-sparse`` (requires triton + modelopt) + ModelOpt Triton flash-attention kernel with N:M sparse softmax (2:4 by default). + Applied via ``mtsa.sparsify()`` to the WAN transformer using the ``triton`` + backend with the modelopt_triton diffusers attention backend. For every 4 K + positions, keeps top-2 attention scores; the other 2 are set to -inf before + softmax. + +``triton-skip`` (requires triton + modelopt) + ModelOpt Triton flash-attention kernel with skip-softmax tile pruning. + Tiles whose attention mass is below a threshold (default 0.1) are skipped entirely. 
+ Applied via ``mtsa.sparsify()`` using the ``triton`` backend with the modelopt_triton + diffusers attention backend. + +NVFP4 P-matrix quantization (``--quantize-p``) is a **SageAttention** feature — +an independent quantization pass applied via +``modelopt.torch.quantization.apply_sage_attention()``. It quantizes the +post-softmax P tile to NVFP4 E2M1 inside the Triton kernel (per-tile max scaling). +``--quantize-p`` can be combined with any Triton sparse kernel or used standalone. + +Requirements:: + + pip install sageattention diffusers transformers accelerate ftfy + +Usage:: + + # FP8 accuracy check vs baseline (CLIP score + pixel metrics) + python wan2_sage_attention.py --prompt "..." --compare + + # Single run with FP8 attention + python wan2_sage_attention.py --prompt "..." --kernel fp8 + + # Compare a specific kernel vs baseline with accuracy metrics + python wan2_sage_attention.py --prompt "..." --kernel sage2-fp16 --compare + + # Baseline — standard SDPA + python wan2_sage_attention.py --prompt "..." --baseline + + # Benchmark all kernels (timing only) + python wan2_sage_attention.py --prompt "..." --benchmark + + # ModelOpt Triton N:M sparse attention + python wan2_sage_attention.py --prompt "..." --kernel triton-sparse + + # ModelOpt Triton skip-softmax attention + python wan2_sage_attention.py --prompt "..." --kernel triton-skip + + # SageAttention standalone — NVFP4 P-matrix quantization only (no sparsity) + python wan2_sage_attention.py --prompt "..." --kernel nvfp4 + + # SageAttention v3 — per-group MX NVFP4 on Q, K, V, and P (arxiv 2505.11594) + python wan2_sage_attention.py --prompt "..." --kernel nvfp4-v3 + + # ModelOpt Triton sparse + NVFP4 P-matrix quantization + python wan2_sage_attention.py --prompt "..." --kernel triton-sparse --quantize-p + + # ModelOpt Triton skip-softmax + NVFP4 P-matrix quantization + python wan2_sage_attention.py --prompt "..." --kernel triton-skip --quantize-p + + # Smaller 5B model (fits on a single 24 GB GPU) + python wan2_sage_attention.py \\ + --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \\ + --prompt "Two cats boxing on a stage" +""" + +import argparse +import os +import time +from contextlib import contextmanager + +import numpy as np +import torch +import torch.nn.functional as F + +# Model IDs available on HuggingFace Hub +MODEL_T2V_14B = "Wan-AI/Wan2.2-T2V-A14B-Diffusers" +MODEL_TI2V_5B = "Wan-AI/Wan2.2-TI2V-5B-Diffusers" + +DEFAULT_MODEL = MODEL_TI2V_5B +DEFAULT_NEGATIVE_PROMPT = "low quality, blurry, distorted, watermark, text, cropped, overexposed" + +# Kernel choices +KERNEL_FP8 = "fp8" +KERNEL_SAGE1 = "sage1" +KERNEL_SAGE2_FP16 = "sage2-fp16" +KERNEL_SAGE2_FP8 = "sage2-fp8" +KERNEL_TRITON_SPARSE = "triton-sparse" +KERNEL_TRITON_SKIP = "triton-skip" +KERNEL_NVFP4 = "nvfp4" +KERNEL_NVFP4_V3 = "nvfp4-v3" +KERNEL_CHOICES = [ + KERNEL_FP8, + KERNEL_SAGE1, + KERNEL_SAGE2_FP16, + KERNEL_SAGE2_FP8, + KERNEL_TRITON_SPARSE, + KERNEL_TRITON_SKIP, + KERNEL_NVFP4, + KERNEL_NVFP4_V3, +] + +# Kernels that modify pipe.transformer in-place via ModelOpt APIs (not SDPA patching). 
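+# These are applied with apply_triton_sparse_kernel() below; enable_attention_kernel() rejects them.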
+_TRITON_MODELOPT_KERNELS = { + KERNEL_TRITON_SPARSE, + KERNEL_TRITON_SKIP, +} + +_KERNEL_DESCRIPTIONS = { + KERNEL_FP8: "FP8 E4M3 QKV (Python-level, SA2-inspired smoothing, no CUDA kernel required)", + KERNEL_SAGE1: "sageattn (SA1, INT8 QK + FP16 PV, auto-select)", + KERNEL_SAGE2_FP16: "sageattn_qk_int8_pv_fp16_cuda (SA2, INT8 QK + FP16 PV, per-thread)", + KERNEL_SAGE2_FP8: "sageattn_qk_int8_pv_fp8_cuda (SA2++, INT8 QK + FP8 PV, fp32+fp16 accum)", + KERNEL_TRITON_SPARSE: "ModelOpt Triton flash-attn + N:M sparse softmax (2:4) via mtsa.sparsify()", + KERNEL_TRITON_SKIP: "ModelOpt Triton flash-attn + skip-softmax tile pruning via mtsa.sparsify()", + KERNEL_NVFP4: "ModelOpt SageAttention NVFP4 E2M1 P-matrix quantization via mtq.apply_sage_attention()", +} + +# SageAttention CUDA kernel support by GPU compute capability: +# SM80 Ampere (A100, RTX 3090) sage1, sage2-fp16 +# SM89 Ada (RTX 4090, RTX PRO 6000 Ada) sage1, sage2-fp16, sage2-fp8 +# SM90 Hopper (H100) sage1, sage2-fp16, sage2-fp8 +# SM100 Blackwell datacenter (B100/B200) sage1, sage2-fp16, sage2-fp8 +# SM120 Blackwell consumer (RTX 50-series, +# RTX PRO 6000 Blackwell) NOT supported by SA 2.2.0 +# fp8 kernel always works (pure Python) +_SUPPORTED_SM = {80, 86, 89, 90, 100} + +# FP8 max value for float8_e4m3fn +_FP8_MAX = 448.0 + + +# --------------------------------------------------------------------------- +# GPU detection +# --------------------------------------------------------------------------- + + +def _get_gpu_sm() -> int | None: + if not torch.cuda.is_available(): + return None + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def _fp8_available() -> bool: + return hasattr(torch, "float8_e4m3fn") + + +def _detect_available_kernels() -> list[str]: + """Return kernels available given the installed packages and GPU.""" + available = [] + + # fp8 is pure Python — available on any GPU / PyTorch version + if _fp8_available(): + available.append(KERNEL_FP8) + else: + print( + "[FP8] WARNING: torch.float8_e4m3fn not found. " + "Upgrade to PyTorch >= 2.1 to use the fp8 kernel." + ) + + try: + import sageattention as _sa + except ImportError: + _sa = None + + if _sa is not None: + sm = _get_gpu_sm() + if sm is not None and sm not in _SUPPORTED_SM: + print( + f"[SageAttention] WARNING: GPU SM{sm} not officially supported by SA 2.2.0 " + f"(supported: SM{sorted(_SUPPORTED_SM)}). CUDA kernels may fail. " + "Try: TORCH_CUDA_ARCH_LIST='8.9+PTX' pip install --no-cache-dir sageattention" + ) + + if hasattr(_sa, "sageattn"): + available.append(KERNEL_SAGE1) + if hasattr(_sa, "sageattn_qk_int8_pv_fp16_cuda"): + available.append(KERNEL_SAGE2_FP16) + if hasattr(_sa, "sageattn_qk_int8_pv_fp8_cuda"): + available.append(KERNEL_SAGE2_FP8) + + # Triton ModelOpt kernels require: triton + modelopt sparse attention + try: + import triton # noqa: F401 + + import modelopt.torch.sparsity.attention_sparsity # noqa: F401 + + available.append(KERNEL_TRITON_SPARSE) + available.append(KERNEL_TRITON_SKIP) + available.append(KERNEL_NVFP4) + except ImportError: + pass + + return available + + +AVAILABLE_KERNELS: list[str] = _detect_available_kernels() + + +# --------------------------------------------------------------------------- +# FP8 attention — Python-level, SA2-inspired +# --------------------------------------------------------------------------- + + +def _smooth_quantize_fp8(x: torch.Tensor) -> torch.Tensor: + """Smooth + FP8-quantize + dequantize a Q, K, or V tensor. 
+ + Implements the channel-wise mean-subtraction smoothing from SageAttention2 + (arXiv 2411.10958): + + 1. Subtract per-channel mean across the token dimension (removes systematic + outliers in each head-dim channel, compressing the dynamic range). + 2. Quantize the zero-centred tensor to FP8 E4M3 with a per-token max scale. + 3. Dequantize back to the original dtype. + 4. Add the channel mean back so the result is mathematically equivalent + to the original, up to FP8 rounding error. + + Args: + x: Attention tensor with layout ``(B, H, N, D)`` — batch, heads, + tokens, head-dim. + + Returns: + Tensor of the same shape and dtype as ``x``, simulating FP8 precision. + """ + orig_dtype = x.dtype + + # Step 1 — channel smoothing: mean over token dimension → (B, H, 1, D) + mean = x.mean(dim=-2, keepdim=True) + x_smooth = x - mean + + # Step 2 — per-token scale: max over head-dim → (B, H, N, 1) + scale = x_smooth.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / _FP8_MAX + + # Step 3 — quantize to FP8 E4M3, then immediately dequantize + x_fp8 = (x_smooth.float() / scale).clamp(-_FP8_MAX, _FP8_MAX).to(torch.float8_e4m3fn) + x_dq = x_fp8.to(orig_dtype) * scale.to(orig_dtype) + + # Step 4 — restore mean + return x_dq + mean + + +def _fp8_sdpa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool, + scale: float | None, +) -> torch.Tensor: + """FP8 E4M3 attention with SA2-inspired Q/K/V smoothing. + + All three matrices are independently smoothed and FP8-quantized before + being passed to standard SDPA. This simulates the precision loss of a + true FP8 attention kernel without requiring any compiled CUDA code. + """ + q_dq = _smooth_quantize_fp8(query) + k_dq = _smooth_quantize_fp8(key) + v_dq = _smooth_quantize_fp8(value) + return _orig_sdpa(q_dq, k_dq, v_dq, is_causal=is_causal, scale=scale) + + +# --------------------------------------------------------------------------- +# SDPA patching +# --------------------------------------------------------------------------- + +_orig_sdpa = F.scaled_dot_product_attention +_active_kernel: str = KERNEL_FP8 +_sage_calls: int = 0 +_fallback_calls: int = 0 + + +def _run_kernel( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool, + scale: float | None, +) -> torch.Tensor: + if _active_kernel == KERNEL_FP8: + return _fp8_sdpa(query, key, value, is_causal=is_causal, scale=scale) + + if _active_kernel == KERNEL_SAGE1: + from sageattention import sageattn + + return sageattn(query, key, value, tensor_layout="HND", is_causal=is_causal, sm_scale=scale) + + if _active_kernel == KERNEL_SAGE2_FP16: + from sageattention import sageattn_qk_int8_pv_fp16_cuda + + return sageattn_qk_int8_pv_fp16_cuda( + query, + key, + value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran="per_thread", + sm_scale=scale, + smooth_k=True, + ) + + # KERNEL_SAGE2_FP8 + from sageattention import sageattn_qk_int8_pv_fp8_cuda + + return sageattn_qk_int8_pv_fp8_cuda( + query, + key, + value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran="per_thread", + sm_scale=scale, + pv_accum_dtype="fp32+fp16", + smooth_k=True, + ) + + +def _patched_sdpa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float | None = None, + **kwargs, +) -> torch.Tensor: + global _sage_calls, _fallback_calls + # Fall back to standard SDPA for unsupported cases + if ( + attn_mask is not None + or 
dropout_p > 0.0 + or query.dtype not in (torch.float16, torch.bfloat16) + ): + _fallback_calls += 1 + return _orig_sdpa( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + + _sage_calls += 1 + try: + return _run_kernel(query, key, value, is_causal=is_causal, scale=scale) + except (AssertionError, RuntimeError) as e: + print(f"[Attention] WARNING: kernel={_active_kernel!r} failed ({e}). Falling back to SDPA.") + return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale) + + +def enable_attention_kernel(kernel: str = KERNEL_FP8) -> None: + """Patch ``F.scaled_dot_product_attention`` with the selected kernel. + + Args: + kernel: Kernel name from ``KERNEL_CHOICES``. + """ + global _active_kernel, _sage_calls, _fallback_calls + + if kernel not in KERNEL_CHOICES: + raise ValueError(f"Unknown kernel {kernel!r}. Choose from {KERNEL_CHOICES}") + if kernel in _TRITON_MODELOPT_KERNELS: + raise ValueError( + f"Kernel {kernel!r} cannot be activated via enable_attention_kernel(). " + "Use apply_triton_sparse_kernel() instead." + ) + if kernel not in AVAILABLE_KERNELS: + raise RuntimeError(f"Kernel {kernel!r} is not available. Available: {AVAILABLE_KERNELS}") + + _active_kernel = kernel + _sage_calls = 0 + _fallback_calls = 0 + + F.scaled_dot_product_attention = _patched_sdpa + import torch.nn.functional as _F + + _F.scaled_dot_product_attention = _patched_sdpa + + print(f"[Attention] kernel={kernel} {_KERNEL_DESCRIPTIONS[kernel]}") + + +def disable_attention_kernel() -> None: + F.scaled_dot_product_attention = _orig_sdpa + import torch.nn.functional as _F + + _F.scaled_dot_product_attention = _orig_sdpa + + +@contextmanager +def attention_kernel_ctx(kernel: str = KERNEL_FP8): + enable_attention_kernel(kernel) + try: + yield + finally: + disable_attention_kernel() + + +# --------------------------------------------------------------------------- +# ModelOpt Triton sparse attention — applied via mtsa.sparsify() +# --------------------------------------------------------------------------- + +_TRITON_SPARSE_CONFIG = { + "sparse_cfg": { + "*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 0, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + } +} + +_TRITON_SKIP_DEFAULT_THRESHOLD = 0.1 + +_TRITON_SKIP_CONFIG = { + "sparse_cfg": { + "*": { + "method": "triton_skip_softmax", + "skip_softmax_threshold": _TRITON_SKIP_DEFAULT_THRESHOLD, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + } +} + +_TRITON_KERNEL_CONFIGS = { + KERNEL_TRITON_SPARSE: _TRITON_SPARSE_CONFIG, + KERNEL_TRITON_SKIP: _TRITON_SKIP_CONFIG, +} + + +def apply_triton_sparse_kernel( + transformer: torch.nn.Module, + kernel: str, + skip_threshold: float | None = None, +) -> None: + """Apply a ModelOpt Triton sparse attention kernel to the WAN transformer. + + Calls ``mtsa.sparsify()`` with the ``triton`` backend, which activates the + modelopt_triton diffusers attention backend for every attention forward pass. + + This modifies the model in-place. To additionally apply NVFP4 P-matrix + quantization (SageAttention), call ``apply_sage_attention(transformer)`` + **after** this function. + + Args: + transformer: The ``pipe.transformer`` WAN model. + kernel: One of the ``KERNEL_TRITON_*`` constants. + skip_threshold: Override ``skip_softmax_threshold`` for skip-softmax kernels. + ``None`` uses the kernel's built-in default. 
+ Lower = better quality, less speedup. Typical range: 0.001–0.1. + """ + import copy + + import modelopt.torch.sparsity.attention_sparsity as mtsa + + config = copy.deepcopy(_TRITON_KERNEL_CONFIGS[kernel]) + if skip_threshold is not None and kernel == KERNEL_TRITON_SKIP: + star: dict = config["sparse_cfg"]["*"] + star["skip_softmax_threshold"] = skip_threshold + + mtsa.sparsify(transformer, config) + thr = config["sparse_cfg"].get("*", {}).get("skip_softmax_threshold", "n/a") + print(f"[Attention] Applied {kernel}: {_KERNEL_DESCRIPTIONS[kernel]}") + if kernel == KERNEL_TRITON_SKIP: + print(f"[Attention] skip_softmax_threshold={thr}") + + +def print_kernel_stats() -> None: + total = _sage_calls + _fallback_calls + print(f"[Attention] calls: {_sage_calls} quantized, {_fallback_calls} fallback (total {total})") + + +# --------------------------------------------------------------------------- +# Accuracy metrics +# --------------------------------------------------------------------------- + + +def _frames_to_uint8(frames: list) -> np.ndarray: + """Convert a list of PIL images to a uint8 numpy array of shape (N, H, W, 3).""" + import numpy as np + + arrays = [] + for f in frames: + if isinstance(f, np.ndarray): + arr = f if f.dtype == np.uint8 else (f * 255).clip(0, 255).astype(np.uint8) + else: + arr = np.array(f.convert("RGB"), dtype=np.uint8) + arrays.append(arr) + return np.stack(arrays, axis=0) + + +def compute_video_metrics( + frames_ref: list, + frames_quant: list, +) -> dict[str, float]: + """Compute frame-level accuracy metrics between two video frame sequences. + + Metrics: + psnr Peak Signal-to-Noise Ratio (dB). Higher = better. + >40 dB: excellent (barely noticeable). + 30-40 dB: good. + 20-30 dB: noticeable but acceptable. + mae_pct Mean Absolute Error as % of max pixel value (255). Lower = better. + cos_sim Mean cosine similarity of flattened frames. Closer to 1 = better. + + Args: + frames_ref: List of PIL images from the baseline run. + frames_quant: List of PIL images from the quantized run. + + Returns: + Dict with keys ``"psnr"``, ``"mae_pct"``, ``"cos_sim"``. + """ + ref = _frames_to_uint8(frames_ref).astype(np.float32) # (N, H, W, 3) + quant = _frames_to_uint8(frames_quant).astype(np.float32) + + # PSNR + mse_per_frame = ((ref - quant) ** 2).mean(axis=(1, 2, 3)) # (N,) + # Avoid log(0) for identical frames + psnr_per_frame = np.where( + mse_per_frame < 1e-10, + 100.0, + 10.0 * np.log10(255.0**2 / mse_per_frame), + ) + psnr = float(psnr_per_frame.mean()) + + # MAE as % of 255 + mae_pct = float(np.abs(ref - quant).mean() / 255.0 * 100.0) + + # Cosine similarity: flatten each frame to a vector + ref_flat = ref.reshape(ref.shape[0], -1) + quant_flat = quant.reshape(quant.shape[0], -1) + dot = (ref_flat * quant_flat).sum(axis=1) + norm_ref = np.linalg.norm(ref_flat, axis=1) + norm_quant = np.linalg.norm(quant_flat, axis=1) + cos_sim = float((dot / (norm_ref * norm_quant + 1e-12)).mean()) + + return {"psnr": psnr, "mae_pct": mae_pct, "cos_sim": cos_sim} + + +def compute_clip_score( + frames: list, + prompt: str, + clip_model_id: str = "openai/clip-vit-large-patch14", + device: str = "cuda", + max_frames: int = 16, +) -> float: + """Compute mean CLIP score (text-image cosine similarity) over video frames. + + Samples up to ``max_frames`` evenly from the sequence and returns the + average cosine similarity between the CLIP text embedding of ``prompt`` + and the CLIP image embedding of each frame. Higher = more semantically + aligned with the prompt. 
+ + Args: + frames: List of PIL images (the generated video). + prompt: The text prompt used to generate the video. + clip_model_id: HuggingFace model ID or local path for CLIP. + Pass ``HF_TOKEN`` env var for authenticated downloads. + device: Device for the CLIP model (``"cuda"`` or ``"cpu"``). + max_frames: Maximum number of frames to score (evenly sampled). + + Returns: + Mean CLIP cosine similarity score in ``[-1, 1]``. + Typical values for good text-video alignment: ~0.15-0.30 + (varies by model and prompt; compare baseline vs quantized delta). + """ + from transformers import CLIPModel, CLIPProcessor + + token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + processor = CLIPProcessor.from_pretrained(clip_model_id, token=token) + clip_model = CLIPModel.from_pretrained(clip_model_id, token=token).to(device) + clip_model.eval() + + # Evenly sample frames + indices = np.linspace(0, len(frames) - 1, min(max_frames, len(frames)), dtype=int) + sampled = [frames[int(i)] for i in indices] + + with torch.no_grad(): + text_inputs = processor(text=[prompt], return_tensors="pt", padding=True).to(device) + text_feat = clip_model.get_text_features(**text_inputs) + text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) + + scores = [] + for frame in sampled: + img_inputs = processor(images=frame, return_tensors="pt").to(device) + img_feat = clip_model.get_image_features(**img_inputs) + img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) + scores.append((text_feat * img_feat).sum().item()) + + # Free CLIP model from GPU memory + del clip_model + torch.cuda.empty_cache() + + return float(np.mean(scores)) + + +def print_metrics(metrics: dict[str, float], label: str = "") -> None: + prefix = f"[{label}] " if label else "" + print(f"\n{prefix}Accuracy vs baseline:") + print(f" PSNR: {metrics['psnr']:.2f} dB (>40 excellent, 30-40 good, <30 noticeable)") + print(f" MAE: {metrics['mae_pct']:.4f}% of max pixel value") + print(f" Cosine sim: {metrics['cos_sim']:.6f} (1.0 = identical)") + + +# --------------------------------------------------------------------------- +# Pipeline helpers +# --------------------------------------------------------------------------- + + +def load_pipeline(model_id: str): + """Load the Wan2.2 pipeline (VAE in FP32, transformer + text encoder in BF16).""" + from diffusers import AutoencoderKLWan, WanPipeline + + print(f"[Pipeline] Loading VAE (fp32) from {model_id}...") + vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + print("[Pipeline] Loading transformer + text encoder (bf16)...") + pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + pipe.to("cuda") + return pipe + + +def run_inference(pipe, args, label: str = "") -> tuple[float, list]: + """Run one generation pass. + + Returns: + (elapsed_seconds, frames) where frames is the list of PIL images. 
+ """ + generator = torch.Generator("cuda").manual_seed(args.seed) + + if label: + print(f"\n[{label}] Generating {args.num_frames} frames @ {args.height}x{args.width}...") + + torch.cuda.synchronize() + t0 = time.perf_counter() + + frames = pipe( + prompt=args.prompt, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + height=args.height, + width=args.width, + num_frames=args.num_frames, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_steps, + generator=generator, + ).frames[0] + + torch.cuda.synchronize() + elapsed = time.perf_counter() - t0 + + from diffusers.utils import export_to_video + + out_path = args.output if not label else args.output.replace(".mp4", f"_{label}.mp4") + export_to_video(frames, out_path, fps=16) + print(f"[{label or 'result'}] Saved to {out_path} ({elapsed:.1f}s)") + return elapsed, frames + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Wan2.2 T2V with quantized attention (FP8, SageAttention)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--prompt", type=str, required=True, help="Text prompt") + parser.add_argument( + "--model", + type=str, + default=DEFAULT_MODEL, + choices=[MODEL_T2V_14B, MODEL_TI2V_5B], + help="HuggingFace model ID", + ) + parser.add_argument("--output", type=str, default="output.mp4", help="Output video path") + parser.add_argument("--height", type=int, default=480) + parser.add_argument("--width", type=int, default=832) + parser.add_argument("--num-frames", type=int, default=81) + parser.add_argument("--num-steps", type=int, default=40, help="Denoising steps") + parser.add_argument("--guidance-scale", type=float, default=4.0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--kernel", + type=str, + default=KERNEL_FP8, + choices=KERNEL_CHOICES, + help=( + "fp8: Python-level FP8 E4M3 (SA2 smoothing, no CUDA kernel, accuracy testing); " + "sage1: SA1 INT8+FP16; " + "sage2-fp16: SA2 INT8+FP16; " + "sage2-fp8: SA2++ INT8+FP8; " + "triton-sparse: ModelOpt Triton 2:4 N:M sparse softmax (requires triton + modelopt); " + "triton-skip: ModelOpt Triton skip-softmax tile pruning (requires triton + modelopt); " + "nvfp4: ModelOpt SageAttention NVFP4 P-matrix quantization standalone (requires triton + modelopt); " + "nvfp4-v3: SageAttention v3 per-group MX NVFP4 on Q/K/V/P (requires triton + modelopt)" + ), + ) + parser.add_argument( + "--quantize-p", + action="store_true", + default=False, + help=( + "Apply SageAttention NVFP4 E2M1 P-matrix quantization via " + "modelopt.torch.quantization.apply_sage_attention(). " + "Quantizes the post-softmax P tile inside the Triton kernel (per-tile max scaling). " + "Can be used standalone or combined with any Triton sparse kernel: " + "--kernel triton-sparse --quantize-p" + ), + ) + parser.add_argument( + "--baseline", + action="store_true", + help="Run with standard SDPA, no quantization", + ) + parser.add_argument( + "--compare", + action="store_true", + help=( + "Run baseline + selected kernel, then report accuracy metrics " + "(PSNR, MAE, cosine similarity). Default kernel is fp8." 
+ ), + ) + parser.add_argument( + "--benchmark", + action="store_true", + help="Run baseline + all available kernels, report timing table", + ) + parser.add_argument( + "--skip-threshold", + type=float, + default=None, + metavar="LAMBDA", + help=( + "Override skip_softmax_threshold for the triton-skip kernel. " + f"Default: {_TRITON_SKIP_DEFAULT_THRESHOLD}. " + "A tile is skipped when exp(tile_max - running_max) < LAMBDA " + "(equivalently: tile_max < running_max + log(LAMBDA)). " + "Lower = better quality, less speedup. " + "Typical sweep: 0.1 (aggressive), 0.01 (moderate), 0.001 (conservative)." + ), + ) + parser.add_argument( + "--clip-model", + type=str, + default="openai/clip-vit-large-patch14", + help=( + "CLIP model ID or local path for --compare CLIP scoring. " + "Set HF_TOKEN env var for authenticated HuggingFace downloads." + ), + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + pipe = load_pipeline(args.model) + + if args.compare: + # --- Baseline --- + _, frames_base = run_inference(pipe, args, label="baseline") + + # --- Quantized --- + if args.kernel == KERNEL_NVFP4: + from modelopt.torch.quantization import apply_sage_attention + + apply_sage_attention(pipe.transformer) + elif args.kernel in _TRITON_MODELOPT_KERNELS: + apply_triton_sparse_kernel( + pipe.transformer, + args.kernel, + skip_threshold=args.skip_threshold, + ) + if args.quantize_p: + from modelopt.torch.quantization import apply_sage_attention + + apply_sage_attention(pipe.transformer) + else: + enable_attention_kernel(args.kernel) + _, frames_quant = run_inference(pipe, args, label=args.kernel) + if args.kernel not in _TRITON_MODELOPT_KERNELS and args.kernel != KERNEL_NVFP4: + print_kernel_stats() + disable_attention_kernel() + + # --- CLIP scores (per-video semantic alignment with prompt) --- + # Use CPU to avoid OOM: the WAN pipeline already occupies GPU memory. 
+ print("\nComputing CLIP scores (prompt-video semantic alignment)...") + try: + clip_base = compute_clip_score( + frames_base, args.prompt, clip_model_id=args.clip_model, device="cpu" + ) + clip_quant = compute_clip_score( + frames_quant, args.prompt, clip_model_id=args.clip_model, device="cpu" + ) + print(f" baseline CLIP: {clip_base:.4f}") + print(f" {args.kernel} CLIP: {clip_quant:.4f} (delta {clip_quant - clip_base:+.4f})") + print( + " (absolute value varies by model; focus on the delta between baseline and quantized)" + ) + print( + " Tip: set HF_TOKEN env var or use --clip-model to avoid rate limits" + ) + except (OSError, RuntimeError) as e: + print(f" WARNING: CLIP scoring failed ({e})") + print(" To fix: set HF_TOKEN env var or pass --clip-model ") + + # --- Pixel-level metrics --- + metrics = compute_video_metrics(frames_base, frames_quant) + print_metrics(metrics, label=args.kernel) + + elif args.benchmark: + timing: dict[str, float] = {} + + timing["baseline"], _ = run_inference(pipe, args, label="baseline") + + for kernel in KERNEL_CHOICES: + if kernel not in AVAILABLE_KERNELS: + print(f"\n[{kernel}] Skipped — not available") + continue + if kernel in _TRITON_MODELOPT_KERNELS or kernel == KERNEL_NVFP4: + print( + f"\n[{kernel}] Skipped in --benchmark (ModelOpt kernels modify the model " + f"in-place; run separately with --kernel {kernel})" + ) + continue + enable_attention_kernel(kernel) + timing[kernel], _ = run_inference(pipe, args, label=kernel) + print_kernel_stats() + disable_attention_kernel() + + t_base = timing["baseline"] + print(f"\n{'=' * 55}") + print(f" {'Kernel':<20} {'Time':>8} {'Speedup':>8}") + print(f" {'-' * 40}") + print(f" {'baseline (SDPA)':<20} {t_base:>7.1f}s {'1.00x':>8}") + for kernel in KERNEL_CHOICES: + if kernel in _TRITON_MODELOPT_KERNELS or kernel == KERNEL_NVFP4: + print(f" {kernel:<20} {'N/A':>8} {'N/A':>8} (run separately)") + continue + if kernel not in timing: + print(f" {kernel:<20} {'N/A':>8} {'N/A':>8} (not available)") + continue + t = timing[kernel] + print(f" {kernel:<20} {t:>7.1f}s {t_base / t:>7.2f}x") + print(f"{'=' * 55}") + + elif args.baseline: + run_inference(pipe, args, label="baseline") + + elif args.kernel == KERNEL_NVFP4: + from modelopt.torch.quantization import apply_sage_attention + + apply_sage_attention(pipe.transformer) + run_inference(pipe, args, label=args.kernel) + + elif args.kernel == KERNEL_NVFP4_V3: + from modelopt.torch.quantization import apply_sage_attention_v3 + + apply_sage_attention_v3(pipe.transformer) + run_inference(pipe, args, label=args.kernel) + + elif args.kernel in _TRITON_MODELOPT_KERNELS: + apply_triton_sparse_kernel( + pipe.transformer, + args.kernel, + skip_threshold=args.skip_threshold, + ) + if args.quantize_p: + from modelopt.torch.quantization import apply_sage_attention + + apply_sage_attention(pipe.transformer) + run_inference(pipe, args, label=args.kernel) + + else: + enable_attention_kernel(args.kernel) + run_inference(pipe, args, label=args.kernel) + print_kernel_stats() + disable_attention_kernel() + + +if __name__ == "__main__": + main() diff --git a/examples/diffusers/sparsity/README.md b/examples/diffusers/sparsity/README.md new file mode 100644 index 0000000000..e4aea18f0d --- /dev/null +++ b/examples/diffusers/sparsity/README.md @@ -0,0 +1,123 @@ +# Skip-Softmax Sparse Attention for Diffusion Models + +Skip-softmax sparse attention (BLASST) skips KV tiles whose attention scores +are negligible during the FlashAttention computation, reducing FLOPs without +retraining. 
An exponential model (`scale_factor = a * exp(b * target_sparsity)`) +is calibrated once, then the target sparsity can be adjusted at runtime without +recalibration. + +## Changes from Main Branch + +### Core Triton Kernel (`modelopt/torch/kernels/`) + +| File | Change | +|------|--------| +| `triton_fa.py` | Added `_attn_fwd_calibrate` kernel: computes full attention while measuring skip decisions for multiple thresholds via atomic counters. Added `attention_calibrate()` Python API. | +| `__init__.py` | Export `attention_calibrate` alongside `attention`. | + +The kernel has two modes: +- **Inference** (`_attn_fwd`): Autotuned, single threshold, actual tile skipping. +- **Calibration** (`_attn_fwd_calibrate`): Fixed block sizes (128×64), multi-threshold measurement, no skipping (full attention output). + +### Sparse Attention Methods (`modelopt/torch/sparsity/attention_sparsity/methods/`) + +| File | Change | +|------|--------| +| `triton_skip_softmax.py` | Extended with calibration support: `_triton_calibration_context()` sets Triton calibration mode and collects counters; `_triton_inference_context()` activates diffusers backend with calibrated threshold; `_get_diffusers_backend_context()` activates `modelopt_triton` attention backend. | +| `flash_skip_softmax.py` | Enhanced `get_sparse_context()` with `ExitStack` to also activate diffusers eager backend for calibration. | +| `registry.py` | Added `set_calibration_mode()` to base `SparseAttentionMethod` class. | +| `__init__.py` | Updated imports. | + +### Kernel Backends (`modelopt/torch/sparsity/attention_sparsity/kernels/`) + +| File | Change | +|------|--------| +| `__init__.py` | Added thread-local context (`set_skip_softmax_context` / `get_skip_softmax_context`), lazy imports for diffusers/LTX backends with `contextlib.suppress(ImportError, RuntimeError)`. | +| `diffusers_triton_attention.py` | **New.** Registers `modelopt_triton` backend in diffusers. Two modes: inference calls `attention()`, calibration calls `attention_calibrate()`. Accumulates counters across attention calls. | +| `diffusers_eager_attention.py` | **New.** Registers `modelopt_skip_softmax` eager backend for LLM calibration (explicit `F.softmax` for patching). | +| `ltx_triton_attention.py` | **New.** Patches `ltx_core.Attention` modules for Triton dispatch. Supports calibration and inference modes. | +| `ltx_eager_attention.py` | **New.** Patches `ltx_core.Attention` for eager attention calibration. | + +### Calibration (`modelopt/torch/sparsity/attention_sparsity/calibration/`) + +| File | Change | +|------|--------| +| `calibrate.py` | Skip RULER dataset generation when user provides `forward_loop` (required for diffusion models). Guard `from transformers import AutoTokenizer` as lazy import. | +| `calibrator.py` | `_set_thresholds()` detects method type — sets `_threshold_trials` for `triton_skip_softmax`, `thresholds` for `flash_skip_softmax`. | + +### Conversion & Config + +| File | Change | +|------|--------| +| `conversion.py` | Added `_register_diffusers_backends_if_needed()` — auto-registers diffusers/LTX backends on `sparsify()`. Updated export config and summary display. | +| `config.py` | Added `skip_softmax_threshold` field to `SparseAttentionAttributeConfig`. | +| `plugins/huggingface.py` | Added diffusers `ModelMixin` support in `_is_supported_model()`. Lazy `import transformers`. | +| `stats_manager.py` | Made `sparse_blocks` optional in `collect()`. Preserve `normalized_gaps` in calibration stats. 
| +| `sparse_attention.py` | (Changes from main for VSA support also present.) | + +### Example Scripts + +| File | Description | +|------|-------------| +| `wan22_skip_softmax.py` | **New.** Wan 2.2 text-to-video with skip-softmax. Supports 5B (single transformer) and 14B (dual transformer). Uses `triton_skip_softmax` with Triton calibration kernel. Calibration prompts from OpenVid-1M. | + +### Tests + +| File | Description | +|------|-------------| +| `test_kernel_backends.py` | **New.** Unit tests for diffusers kernel backends with mocked dependencies (no GPU required). | + +## Usage + +```bash +# Wan 2.2 5B — calibrate + generate +python wan22_skip_softmax.py \ + --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --calibrate --target-sparsity 0.5 --calib-size 4 \ + --calib-frames 151 --calib-steps 40 \ + --prompt "A cat sitting on a windowsill" --output out.mp4 + +# Wan 2.2 14B — both transformers sparsified +python wan22_skip_softmax.py \ + --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --calibrate --target-sparsity 0.5 --calib-size 4 \ + --calib-frames 151 --calib-steps 40 \ + --prompt "A sunset over mountains" --output out.mp4 + +# Calibrate only (no video generation) +python wan22_skip_softmax.py \ + --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --calibrate --target-sparsity 0.5 --calib-size 4 +``` + +## Architecture + +```text +mtsa.sparsify(transformer, config, forward_loop) + │ + ├─ apply_mode() → replace attention with SparseAttentionModule + │ + └─ calibrate() + │ + ├─ DynamicThresholdCalibrator._set_thresholds() + │ └─ sets method._threshold_trials = [1e-6, ..., 9.9e-1] + │ + ├─ forward_loop(model) + │ │ + │ └─ SparseAttentionModule.forward() + │ │ + │ └─ triton_skip_softmax._triton_calibration_context() + │ ├─ set_triton_skip_softmax_config(calibration_mode=True) + │ ├─ attention_backend("modelopt_triton") + │ ├─ _diffusers_triton_attention() → attention_calibrate() + │ │ └─ _attn_fwd_calibrate kernel (full attn + atomic counters) + │ └─ _collect_calibration_stats() → module._last_stats + │ + ├─ Fit: scale_factor = a * exp(b * sparsity) + │ + └─ Apply a, b to all modules + │ + └─ Inference: triton_skip_softmax._triton_inference_context() + ├─ threshold = a * exp(b * target) / seqlen + └─ attention() with skip_softmax_threshold → actual tile skipping +``` diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_skip_softmax.py new file mode 100644 index 0000000000..2170f46249 --- /dev/null +++ b/examples/diffusers/sparsity/wan22_skip_softmax.py @@ -0,0 +1,355 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wan 2.2 inference with skip-softmax sparse attention. + +This example applies skip-softmax sparse attention to the Wan 2.2 video +generation model (text-to-video) using exponential model calibration +(``scale_factor = a * exp(b * target_sparsity)``). 
+ +During calibration, ``flash_skip_softmax`` with the eager attention backend +collects sparsity statistics across multiple threshold trials. The fitted +exponential model then allows runtime control of the target sparsity ratio +without recalibration. + +The Wan 2.2 5B model has 40 transformer blocks with self-attention (attn1) +and cross-attention (attn2). Only self-attention is sparsified. + +Usage:: + + # With calibration (recommended) + python wan22_skip_softmax.py --prompt "A cat playing piano" --output out.mp4 \\ + --calibrate --target-sparsity 0.25 + + # Custom model path + python wan22_skip_softmax.py --model-path /path/to/Wan2.2-T2V-5B \\ + --prompt "A sunset over mountains" --output sunset.mp4 --calibrate +""" + +import argparse +import os + +import torch +from diffusers import AutoencoderKLWan, WanPipeline +from diffusers.utils import export_to_video + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +DEFAULT_MODEL_PATH = os.environ.get("WAN22_MODEL_PATH", "Wan-AI/Wan2.2-TI2V-5B-Diffusers") + +# fmt: off +# ruff: noqa: RUF001 +DEFAULT_NEGATIVE_PROMPT = ( # Official Wan 2.2 negative prompt (Chinese) + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰," + "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面," + "杂乱的背景,三条腿,背景人很多,倒着走" +) +# fmt: on + +# Default threshold trials for calibration +DEFAULT_THRESHOLD_TRIALS = [ + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 2e-2, + 5e-2, + 1e-1, + 2e-1, + 3e-1, + 5e-1, + 7e-1, + 8e-1, + 9e-1, + 9.9e-1, +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Wan 2.2 video generation with skip-softmax sparse attention" + ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help="Text prompt for generation (optional, skips generation if not set)", + ) + parser.add_argument("--output", type=str, default="output.mp4", help="Output video path") + parser.add_argument( + "--model-path", type=str, default=DEFAULT_MODEL_PATH, help="Wan 2.2 model path or HF ID" + ) + parser.add_argument( + "--num-frames", type=int, default=81, help="Number of frames (must be 4k+1)" + ) + parser.add_argument("--height", type=int, default=480, help="Video height") + parser.add_argument("--width", type=int, default=832, help="Video width") + parser.add_argument("--num-steps", type=int, default=40, help="Number of inference steps") + parser.add_argument( + "--guidance-scale", type=float, default=4.0, help="Classifier-free guidance scale" + ) + parser.add_argument( + "--guidance-scale-2", + type=float, + default=3.0, + help="Second guidance scale for 14B dual-transformer model (ignored by 5B)", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default=DEFAULT_NEGATIVE_PROMPT, + help="Negative prompt", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + + # Sparse attention options + parser.add_argument( + "--skip-first-last", + type=int, + default=2, + help="Number of first/last transformer layers to keep dense (default: 2)", + ) + + # Calibration options + parser.add_argument( + "--calibrate", + action="store_true", + help="Calibrate threshold via exponential model (recommended)", + ) + parser.add_argument( + "--target-sparsity", + type=float, + default=0.5, + help="Target sparsity ratio for calibration (0.0-1.0)", + ) + parser.add_argument( + "--calib-steps", + type=int, + default=40, + help="Inference steps for 
calibration", + ) + parser.add_argument( + "--calib-frames", + type=int, + default=151, + help="Number of frames for calibration", + ) + parser.add_argument( + "--calib-size", + type=int, + default=4, + help="Number of calibration prompts from OpenVid-1M dataset", + ) + return parser.parse_args() + + +def build_pipeline(model_path: str) -> WanPipeline: + """Build the Wan 2.2 text-to-video pipeline.""" + vae = AutoencoderKLWan.from_pretrained(model_path, subfolder="vae", torch_dtype=torch.float32) + pipe = WanPipeline.from_pretrained(model_path, vae=vae, torch_dtype=torch.bfloat16) + pipe.to("cuda") + return pipe + + +def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: + """Build sparse attention config from CLI args. + + Uses triton_skip_softmax with the Triton FA kernel for both calibration + and inference. Calibration collects multi-threshold sparsity statistics + via the Triton calibration kernel, then fits an exponential model: + scale_factor = a * exp(b * sparsity). + """ + attn_cfg: dict = { + "method": "triton_skip_softmax", + "skip_softmax_threshold": 0.1, + "backend": "triton", + "is_causal": False, # Diffusion = bidirectional attention + "collect_stats": True, + "enable": True, + } + + sparse_cfg: dict = { + "*.attn1*": attn_cfg, # Self-attention only + "*.attn2*": {"enable": False}, # Text cross-attention + "default": {"enable": False}, + } + + # Keep first/last N layers dense for quality + for i in range(args.skip_first_last): + sparse_cfg[f"*blocks.{i}.attn*"] = {"enable": False} + sparse_cfg[f"*blocks.{num_blocks - 1 - i}.attn*"] = {"enable": False} + + config: dict = {"sparse_cfg": sparse_cfg} + + # Add calibration config with threshold trials + if args.calibrate: + sparse_cfg["calibration"] = { + "target_sparse_ratio": {"prefill": args.target_sparsity}, + "samples": 1, + "threshold_trials": DEFAULT_THRESHOLD_TRIALS, + } + + return config + + +def load_calib_prompts(calib_size: int) -> list[str]: + """Load calibration prompts from OpenVid-1M dataset.""" + from datasets import load_dataset + + dataset = load_dataset("nkp37/OpenVid-1M", split="train") + prompts = list(dataset["caption"][:calib_size]) + print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M") + return prompts + + +def build_calibration_forward_loop( + pipe: WanPipeline, + calib_size: int = 4, + num_steps: int = 40, + num_frames: int = 151, + height: int = 480, + width: int = 832, + seed: int = 42, + guidance_scale: float = 4.0, + guidance_scale_2: float | None = 3.0, + negative_prompt: str = "", +): + """Build a forward loop for exponential model calibration. + + Uses prompts from OpenVid-1M dataset (same as quantization examples). + Each prompt is run individually (batch_size=1). 
+ """ + calib_prompts = load_calib_prompts(calib_size) + + def forward_loop(model): + for i, prompt in enumerate(calib_prompts): + print(f"Calibration [{i + 1}/{len(calib_prompts)}]: {prompt[:60]}...") + kw: dict = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "num_frames": num_frames, + "height": height, + "width": width, + "num_inference_steps": num_steps, + "guidance_scale": guidance_scale, + "generator": torch.Generator(device="cuda").manual_seed(seed), + } + if guidance_scale_2 is not None: + kw["guidance_scale_2"] = guidance_scale_2 + pipe(**kw) + + return forward_loop + + +def print_sparsity_summary(model: torch.nn.Module) -> None: + """Print per-module sparsity statistics.""" + enabled, disabled = [], [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + if module.is_enabled: + enabled.append((name, module)) + else: + disabled.append(name) + + print(f"\nSparse attention: {len(enabled)} enabled, {len(disabled)} disabled") + for name, module in enabled: + info = module.get_threshold_info() + print(f" {name}: {info}") + + +def _get_num_blocks(transformer: torch.nn.Module) -> int: + """Count transformer blocks by looking for *.blocks.N.* submodules.""" + max_idx = -1 + for name, _ in transformer.named_modules(): + parts = name.split(".") + for i, part in enumerate(parts): + if part == "blocks" and i + 1 < len(parts) and parts[i + 1].isdigit(): + max_idx = max(max_idx, int(parts[i + 1])) + return max_idx + 1 + + +def main() -> None: + args = parse_args() + + # ---- Build pipeline ---- + print(f"Loading Wan 2.2 from {args.model_path}...") + pipe = build_pipeline(args.model_path) + + # ---- Collect transformers to sparsify ---- + # Wan 2.2 5B has one transformer; 14B has two (transformer + transformer_2) + transformers_to_sparsify = [] + if pipe.transformer is not None: + transformers_to_sparsify.append(("transformer", pipe.transformer)) + if getattr(pipe, "transformer_2", None) is not None: + transformers_to_sparsify.append(("transformer_2", pipe.transformer_2)) + + # ---- Build calibration forward loop (shared across transformers) ---- + forward_loop = None + if args.calibrate: + forward_loop = build_calibration_forward_loop( + pipe, + calib_size=args.calib_size, + num_steps=args.calib_steps, + num_frames=args.calib_frames, + height=args.height, + width=args.width, + seed=args.seed, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_2, + negative_prompt=args.negative_prompt, + ) + + # ---- Sparsify each transformer ---- + for name, transformer in transformers_to_sparsify: + num_blocks = _get_num_blocks(transformer) + print(f"Applying skip-softmax to {name} ({num_blocks} blocks)...") + config = build_sparse_config(args, num_blocks=num_blocks) + mtsa.sparsify(transformer, config, forward_loop=forward_loop) + + # ---- Generate (optional) ---- + if args.prompt: + print(f"Generating: {args.prompt[:80]}...") + pipe_kwargs: dict = { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "num_frames": args.num_frames, + "height": args.height, + "width": args.width, + "num_inference_steps": args.num_steps, + "guidance_scale": args.guidance_scale, + "generator": torch.Generator(device="cuda").manual_seed(args.seed), + } + if args.guidance_scale_2 is not None: + pipe_kwargs["guidance_scale_2"] = args.guidance_scale_2 + output = pipe(**pipe_kwargs) + + export_to_video(output.frames[0], args.output, fps=16) + print(f"Saved to {args.output}") + + # ---- Print stats ---- + for name, transformer in 
transformers_to_sparsify: + print(f"\n{name}:") + print_sparsity_summary(transformer) + + +if __name__ == "__main__": + main() diff --git a/examples/vllm_serve/sparse_attn_worker.py b/examples/vllm_serve/sparse_attn_worker.py new file mode 100644 index 0000000000..c84b8a7eeb --- /dev/null +++ b/examples/vllm_serve/sparse_attn_worker.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom vLLM workers for sparse attention. + +``SparseAttnWorker``: Replaces ``FlashAttentionImpl`` with +``ModelOptSparseAttentionImpl`` on each Attention module after model loading. +The sparse impl uses the ModelOpt Triton kernel for both prefill and decode. + +``SparseQuantWorker``: Applies quantization first, then sparse attention via +direct module walk (registry stacking does not work due to ``_DMRegistryCls`` +forward identity check). + +Usage: + SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\ + meta-llama/Llama-3.1-8B --enforce-eager +""" + +import fnmatch +import json +import os +from typing import Any + +from fakequant_worker import disable_compilation +from vllm.attention.layer import Attention as VLLMAttention +from vllm.v1.worker.gpu_worker import Worker as BaseWorker + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import ModelOptSparseAttentionImpl + +# --------------------------------------------------------------------------- +# Configuration from environment variables +# --------------------------------------------------------------------------- + +sparse_config: dict[str, Any] = { + "sparse_cfg": os.environ.get("SPARSE_ATTN_CFG", None), + "calib_config_path": os.environ.get("SPARSE_CALIB_CONFIG_PATH", None), +} + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +_DEFAULT_SPARSE_CFG = { + "sparse_cfg": { + "*attn*": { + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 1, + "enable": True, + }, + "default": {"enable": False}, + }, +} + + +def _build_sparse_config(env_config: dict[str, Any]) -> dict | None: + """Build sparse_cfg dict from env vars.""" + cfg_name = env_config["sparse_cfg"] + if cfg_name is None: + return None + # Try looking up preset from mtsa, fall back to default + cfg = getattr(mtsa, cfg_name, None) + if cfg is not None: + return cfg + # Use built-in default if name matches + if cfg_name in ("SPARSE_SOFTMAX_DEFAULT", "default"): + return _DEFAULT_SPARSE_CFG + raise ValueError( + f"Unknown sparse config: {cfg_name}. Set SPARSE_ATTN_CFG to 'default' or a valid preset name." 
+ ) + + +def _load_sparse_config(path: str) -> dict: + """Load offline calibration config JSON.""" + with open(path) as f: + calib_cfg = json.load(f) + + sparse_cfg = {} + for pattern, layer_cfg in calib_cfg.items(): + if pattern == "calibration": + sparse_cfg[pattern] = layer_cfg + continue + layer_cfg.setdefault("method", "triton_sparse_softmax") + layer_cfg.setdefault("backend", "triton") + layer_cfg.setdefault("enable", True) + sparse_cfg[pattern] = layer_cfg + sparse_cfg["default"] = {"enable": False} + + return {"sparse_cfg": sparse_cfg} + + +def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None: + """Match a module name against sparse_cfg patterns.""" + cfg = sparse_cfg.get("sparse_cfg", sparse_cfg) + for pattern, layer_cfg in cfg.items(): + if pattern in ("default", "calibration"): + continue + if fnmatch.fnmatch(module_name, pattern): + return layer_cfg + return None + + +def _replace_attention_impl(worker, config: dict): + """Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers. + + Shared by SparseAttnWorker and SparseQuantWorker. + """ + if config["calib_config_path"]: + cfg = _load_sparse_config(config["calib_config_path"]) + else: + cfg = _build_sparse_config(config) + + if cfg is None: + return + + model = worker.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() + + patched = 0 + for name, module in model.named_modules(): + if not isinstance(module, VLLMAttention): + continue + + # Match per-layer sparse config using name-based patterns + layer_cfg = _match_sparse_config(name, cfg) + if layer_cfg is None or not layer_cfg.get("enable", True): + continue + + method = layer_cfg.get("method", "triton_sparse_softmax") + backend = layer_cfg.get("backend", "triton") + if backend != "triton" or method not in {"triton_sparse_softmax", "triton_skip_softmax"}: + raise ValueError( + f"{name}: unsupported sparse config for vLLM worker " + f"(backend={backend!r}, method={method!r}). " + "Only backend='triton' with method='triton_sparse_softmax' or " + "'triton_skip_softmax' is supported." + ) + + # Build per-layer sparse kwargs + sparse_kw = {} + sparsity_n = layer_cfg.get("sparsity_n", 0) + if sparsity_n > 0: + sparse_kw["sparsity_n"] = sparsity_n + sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4) + sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0) + sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1) + threshold = layer_cfg.get("skip_softmax_threshold") + if threshold: + sparse_kw["skip_softmax_threshold"] = threshold + + # Replace impl and store per-layer config + old_impl = module.impl + new_impl = ModelOptSparseAttentionImpl( + num_heads=old_impl.num_heads, + head_size=old_impl.head_size, + scale=old_impl.scale, + num_kv_heads=old_impl.num_kv_heads, + alibi_slopes=old_impl.alibi_slopes, + sliding_window=None, # overwritten below + kv_cache_dtype=old_impl.kv_cache_dtype, + logits_soft_cap=old_impl.logits_soft_cap, + attn_type=old_impl.attn_type, + kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name, + ) + # Copy the already-transformed sliding_window tuple directly, + # since __init__ transforms int -> (sw-1, 0) and we can't reverse it. 
+ new_impl.sliding_window = old_impl.sliding_window + # Store per-layer sparse kwargs on the impl for forward() to read + new_impl.sparse_kw = sparse_kw + module.impl = new_impl + patched += 1 + print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers") + + +# --------------------------------------------------------------------------- +# Workers +# --------------------------------------------------------------------------- + + +class SparseAttnWorker(BaseWorker): + """vLLM worker that uses the ModelOpt sparse attention backend. + + Replaces FlashAttentionImpl with ModelOptSparseAttentionImpl on each + Attention module right after model loading — before any forward pass + (including determine_available_memory profiling). + """ + + def load_model(self, *args, **kwargs) -> None: + """Load model, then replace attention impl with sparse variant.""" + super().load_model(*args, **kwargs) + _replace_attention_impl(self, sparse_config) + + +class SparseQuantWorker(BaseWorker): + """vLLM worker that applies quantization + sparse attention. + + Quantization uses the standard registry-based ``mtq.quantize()``. + Sparse attention replaces FlashAttentionImpl with ModelOptSparseAttentionImpl + (same approach as SparseAttnWorker). + """ + + def load_model(self, *args, **kwargs) -> None: + """Load model, then replace attention impl with sparse variant.""" + super().load_model(*args, **kwargs) + _replace_attention_impl(self, sparse_config) + + def compile_or_warm_up_model(self) -> None: + """Apply quantization before warm-up.""" + from fakequant_worker import _fakequant_run_prolog_worker, quant_config + + model = self.model_runner.model + if hasattr(model, "unwrap"): + model = model.unwrap() + + with disable_compilation(model): + if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: + _fakequant_run_prolog_worker(self) + + super().compile_or_warm_up_model() diff --git a/examples/vllm_serve/vllm_serve_sparse_attn.py b/examples/vllm_serve/vllm_serve_sparse_attn.py new file mode 100644 index 0000000000..157936e657 --- /dev/null +++ b/examples/vllm_serve/vllm_serve_sparse_attn.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Launch vLLM with sparse attention. 
+ +Usage: + SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT python vllm_serve_sparse_attn.py \\ + meta-llama/Llama-3.1-8B --max-model-len 8192 + +Combined with quantization: + QUANT_CFG=INT8_SMOOTHQUANT_CFG SPARSE_ATTN_CFG=SPARSE_SOFTMAX_DEFAULT \\ + python vllm_serve_sparse_attn.py meta-llama/Llama-3.1-8B +""" + +import os +import sys +from pathlib import Path + +import uvloop +import vllm +from packaging import version +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import make_arg_parser + +vllm_version = version.parse(vllm.__version__) +if vllm_version <= version.parse("0.11.0"): + from vllm.utils import FlexibleArgumentParser +else: + from vllm.utils.argparse_utils import FlexibleArgumentParser + +# Pass sparse attention env vars to ray workers (if supported by this vLLM version) +additional_env_vars = { + "SPARSE_ATTN_CFG", + "SPARSE_CALIB_CONFIG_PATH", + "QUANT_DATASET", + "QUANT_CALIB_SIZE", + "QUANT_CFG", + "AMAX_FILE_PATH", + "KV_QUANT_CFG", +} + +try: + if vllm_version <= version.parse("0.11.0"): + from vllm.executor.ray_distributed_executor import RayDistributedExecutor + else: + from vllm.v1.executor.ray_executor import RayDistributedExecutor + if hasattr(RayDistributedExecutor, "ADDITIONAL_ENV_VARS"): + RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) +except ImportError: + pass # Ray not installed, single-node only + + +def main(): + """Launch vLLM with sparse attention worker.""" + parser = FlexibleArgumentParser(description="vLLM model server with sparse attention") + parser.add_argument("model", type=str, help="The path or name of the model to serve") + parser = make_arg_parser(parser) + + # Ensure workers can import our custom worker module + repo_root = str(Path(__file__).resolve().parent) + if repo_root not in sys.path: + sys.path.insert(0, repo_root) + existing = os.environ.get("PYTHONPATH") + parts = [p for p in [existing, repo_root] if p] + os.environ["PYTHONPATH"] = os.pathsep.join(parts) + + # Select worker based on env vars + has_quant = os.environ.get("QUANT_CFG") or os.environ.get("KV_QUANT_CFG") + has_sparse = os.environ.get("SPARSE_ATTN_CFG") or os.environ.get("SPARSE_CALIB_CONFIG_PATH") + + if has_quant and has_sparse: + worker_cls = "sparse_attn_worker.SparseQuantWorker" + elif has_sparse: + worker_cls = "sparse_attn_worker.SparseAttnWorker" + else: + print("Warning: No SPARSE_ATTN_CFG or QUANT_CFG set. 
Running standard vLLM.") + worker_cls = None + + if worker_cls: + parser.set_defaults(worker_cls=worker_cls) + + args = parser.parse_args() + uvloop.run(run_server(args)) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/kernels/__init__.py b/modelopt/torch/kernels/__init__.py index 24d27a1ba2..fa07b06e20 100644 --- a/modelopt/torch/kernels/__init__.py +++ b/modelopt/torch/kernels/__init__.py @@ -21,6 +21,7 @@ IS_AVAILABLE = False attention = None +attention_calibrate = None register_triton_attention = None if torch.cuda.is_available(): @@ -32,8 +33,10 @@ ), ): from .triton_fa import attention as _attention + from .triton_fa import attention_calibrate as _attention_calibrate attention = _attention + attention_calibrate = _attention_calibrate IS_AVAILABLE = True from .hf_triton_attention import register_triton_attention as _register_triton_attention @@ -42,5 +45,6 @@ __all__ = [ "IS_AVAILABLE", "attention", + "attention_calibrate", "register_triton_attention", ] diff --git a/modelopt/torch/kernels/triton_fa.py b/modelopt/torch/kernels/triton_fa.py index 8d3b11f1af..aa91dcc70a 100644 --- a/modelopt/torch/kernels/triton_fa.py +++ b/modelopt/torch/kernels/triton_fa.py @@ -186,6 +186,95 @@ def _is_dense_region( return is_sink or is_local +# --------------------------------------------------------------------------- +# Paged KV cache helpers +# --------------------------------------------------------------------------- +@triton.jit +def _load_paged_k_tile( + K_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + Block_table, # [batch, max_blocks_per_seq] + batch_idx, + kv_head_idx, + kv_start, + kv_pos, # [BLOCK_N] relative positions + dim_pos, # [BLOCK_D] + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HEAD_DIM: tl.constexpr, + max_blocks_per_seq, +): + """Load K^T tile [BLOCK_D, BLOCK_N] from paged KV cache.""" + d_mask = dim_pos < HEAD_DIM + kv_abs = kv_start + kv_pos # absolute token positions + kv_valid = kv_abs < seq_len_kv + + # Translate token positions -> (page_id, offset_in_page) + page_local = kv_abs // PAGE_SIZE + offset_in_page = kv_abs % PAGE_SIZE + page_global = tl.load( + Block_table + batch_idx * max_blocks_per_seq + page_local, + mask=kv_valid, + other=0, + ) + + # Load K values: K_cache[page_global, offset_in_page, kv_head_idx, dim] + # K^T layout [BLOCK_D, BLOCK_N] for Q @ K^T matmul + k_ptrs = ( + page_global[None, :] * stride_kc_block + + offset_in_page[None, :] * stride_kc_pos + + kv_head_idx * stride_kc_head + + dim_pos[:, None] + ) + return tl.load(K_cache + k_ptrs, mask=kv_valid[None, :] & d_mask[:, None], other=0.0) + + +@triton.jit +def _load_paged_v_tile( + V_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + Block_table, # [batch, max_blocks_per_seq] + batch_idx, + kv_head_idx, + kv_start, + kv_pos, # [BLOCK_N] relative positions + dim_pos, # [BLOCK_D] + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HEAD_DIM: tl.constexpr, + max_blocks_per_seq, +): + """Load V tile [BLOCK_N, BLOCK_D] from paged KV cache.""" + d_mask = dim_pos < HEAD_DIM + kv_abs = kv_start + kv_pos + kv_valid = kv_abs < seq_len_kv + + page_local = kv_abs // PAGE_SIZE + offset_in_page = kv_abs % PAGE_SIZE + page_global = tl.load( + Block_table + batch_idx * max_blocks_per_seq + page_local, + mask=kv_valid, + other=0, + ) + + # V layout [BLOCK_N, BLOCK_D] + v_ptrs = ( + 
page_global[:, None] * stride_vc_block + + offset_in_page[:, None] * stride_vc_pos + + kv_head_idx * stride_vc_head + + dim_pos[None, :] + ) + return tl.load(V_cache + v_ptrs, mask=kv_valid[:, None] & d_mask[None, :], other=0.0) + + # --------------------------------------------------------------------------- # Masking helper # --------------------------------------------------------------------------- @@ -212,6 +301,79 @@ def _apply_mask( return scores +# --------------------------------------------------------------------------- +# NVFP4 E2M1 per-tile P-matrix quantization helper +# --------------------------------------------------------------------------- +@triton.jit +def _quantize_p_nvfp4(p, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + """Quantize a post-softmax p tile to NVFP4 E2M1 (straight-through estimator). + + P values are non-negative (softmax output). Per-tile scaling: + ``scale = max(p) / 6.0`` then boundary-compare to the 8 positive NVFP4 + E2M1 levels {0, 0.5, 1, 1.5, 2, 3, 4, 6}. The backward pass uses the + straight-through estimator (quantization not re-applied during backward). + """ + p_max = tl.max(tl.max(p, 1), 0) # per-tile scalar maximum + scale = tl.maximum(p_max, 1e-12) / 6.0 # NVFP4 max representable level = 6.0 + p_scaled = p / scale # normalize to [0, 6] + # Progressive boundary-compare: walk through NVFP4 rounding boundaries + # Boundaries: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + q = tl.full((BLOCK_M, BLOCK_N), 0.0, dtype=tl.float32) + q = tl.where(p_scaled >= 0.25, 0.5, q) + q = tl.where(p_scaled >= 0.75, 1.0, q) + q = tl.where(p_scaled >= 1.25, 1.5, q) + q = tl.where(p_scaled >= 1.75, 2.0, q) + q = tl.where(p_scaled >= 2.50, 3.0, q) + q = tl.where(p_scaled >= 3.50, 4.0, q) + q = tl.where(p_scaled >= 5.00, 6.0, q) + return q * scale # dequantize back to original scale + + +# NVFP4 E2M1 per-group microscaling helper (SageAttention v3 style) +# --------------------------------------------------------------------------- +@triton.jit +def _quantize_vec_mx_nvfp4(x, ROWS: tl.constexpr, COLS: tl.constexpr, MX_BLOCK: tl.constexpr): + """Per-group NVFP4 E2M1 microscaling quantization (SageAttention v3). + + Quantizes ``x [ROWS, COLS]`` in groups of ``MX_BLOCK`` consecutive elements + along the last axis (i.e. along the head-dimension for Q/K/V, or along the + KV axis for P). Each group of ``MX_BLOCK`` elements shares one scale: + ``scale = max(|x|) / 6.0`` — a float32 approximation of the paper's FP8-E8M0 + per-block scale. Supports both positive and negative values (Q, K, V are + not bounded to [0, 1] like P). + + The quantized levels are the 15 symmetric NVFP4 E2M1 values: + ``{±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6, 0}``. + + Usage: + * Q ``[BLOCK_M, BLOCK_D]``: call directly. + * K^T ``[BLOCK_D, BLOCK_N]``: call as + ``tl.trans(_quantize_vec_mx_nvfp4(tl.trans(k), BLOCK_N, BLOCK_D, 16))`` + so that the 16-element groups run along BLOCK_D within each key vector. + * V ``[BLOCK_N, BLOCK_D]``: call directly. + * P ``[BLOCK_M, BLOCK_N]``: call directly (groups of 16 along KV axis). 
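+
+    Worked example (illustrative): for a group whose largest magnitude is
+    ``3.0``, ``scale = 3.0 / 6.0 = 0.5``. An element ``x = 1.3`` normalizes to
+    ``2.6``, crosses the ``2.5`` boundary, quantizes to level ``3.0``, and
+    dequantizes to ``1.5``; an element ``x = 0.05`` normalizes to ``0.1``,
+    stays below ``0.25``, and flushes to ``0.0``.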
+ """ + NUM_GROUPS: tl.constexpr = ROWS * COLS // MX_BLOCK + x_grouped = tl.reshape(x, (NUM_GROUPS, MX_BLOCK)) + # Per-group scale: max(|x|) / 6.0 + x_max = tl.max(tl.abs(x_grouped), axis=1) # [NUM_GROUPS] + scale = tl.maximum(x_max, 1e-12) / 6.0 # [NUM_GROUPS] + x_norm = x_grouped / scale[:, None] # [NUM_GROUPS, MX_BLOCK] + # Symmetric NVFP4 E2M1 boundary-compare + # Boundaries: ±0.25, ±0.75, ±1.25, ±1.75, ±2.5, ±3.5, ±5.0 + x_abs = tl.abs(x_norm) + sign = tl.where(x_norm >= 0.0, 1.0, -1.0) + q = tl.full((NUM_GROUPS, MX_BLOCK), 0.0, dtype=tl.float32) + q = tl.where(x_abs >= 0.25, 0.5 * sign, q) + q = tl.where(x_abs >= 0.75, 1.0 * sign, q) + q = tl.where(x_abs >= 1.25, 1.5 * sign, q) + q = tl.where(x_abs >= 1.75, 2.0 * sign, q) + q = tl.where(x_abs >= 2.50, 3.0 * sign, q) + q = tl.where(x_abs >= 3.50, 4.0 * sign, q) + q = tl.where(x_abs >= 5.00, 6.0 * sign, q) + return tl.reshape(q * scale[:, None], (ROWS, COLS)) + + # --------------------------------------------------------------------------- # Forward kernel # --------------------------------------------------------------------------- @@ -252,6 +414,20 @@ def _attn_fwd( DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent) APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores + QUANTIZE_P: tl.constexpr = False, # Quantize post-softmax p tile to NVFP4 E2M1 (per-tile, STE) + QUANTIZE_QKV: tl.constexpr = False, # SageAttn-v3: per-group MX NVFP4 for Q/K/V + finer P + IS_PAGED: tl.constexpr = False, # Whether K/V are in paged cache + K_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged K + V_cache=None, # [num_blocks, page_size, num_kv_heads, head_dim] paged V + Block_table=None, # [batch, max_blocks_per_seq] page table + stride_kc_block=0, + stride_kc_pos=0, + stride_kc_head=0, + stride_vc_block=0, + stride_vc_pos=0, + stride_vc_head=0, + PAGE_SIZE: tl.constexpr = 16, + max_blocks_per_seq=0, ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -280,6 +456,9 @@ def _attn_fwd( # --- Load Q tile [BLOCK_M, BLOCK_D]: stays in SRAM for the entire KV loop --- q_ptrs = (q_offset + q_pos[:, None]) * stride_qbs + head_idx * stride_qh + dim_pos[None, :] q = tl.load(Q + q_ptrs, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :], other=0.0) + if QUANTIZE_QKV: + # SageAttn-v3: per-group MX NVFP4 quantization of Q (groups of 16 along head_dim) + q = _quantize_vec_mx_nvfp4(q, BLOCK_M, BLOCK_D, 16) # Base pointers for K and V at this KV head (per-tile offset added in loop) k_base = K + kv_head_idx * stride_kh @@ -298,12 +477,37 @@ def _attn_fwd( kv_start = tl.multiple_of(kv_start, BLOCK_N) # Compiler hint for alignment # Load K^T [BLOCK_D, BLOCK_N] (transposed layout for Q @ K^T matmul) - k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] - k = tl.load( - k_base + k_offs, - mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], - other=0.0, - ) + if IS_PAGED: + k = _load_paged_k_tile( + K_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_kc_block, + stride_kc_pos, + stride_kc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] + k = tl.load( + k_base + k_offs, + mask=((kv_start 
+ kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], + other=0.0, + ) + + if QUANTIZE_QKV: + # SageAttn-v3: per-group MX NVFP4 for K (groups of 16 along head_dim). + # K^T is [BLOCK_D, BLOCK_N]; transpose so groups run along BLOCK_D per key vector. + k = tl.trans(_quantize_vec_mx_nvfp4(tl.trans(k), BLOCK_N, BLOCK_D, 16)) # scores = Q @ K^T * scale [BLOCK_M, BLOCK_N] scores = tl.dot(q, k) * qk_scale @@ -355,12 +559,39 @@ def _attn_fwd( row_sum = row_sum * correction + l_new acc = acc * correction[:, None] - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) + if IS_PAGED: + v = _load_paged_v_tile( + V_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[ + None, : + ] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) + if QUANTIZE_QKV: + v = _quantize_vec_mx_nvfp4(v, BLOCK_N, BLOCK_D, 16) + p = _quantize_vec_mx_nvfp4(p, BLOCK_M, BLOCK_N, 16) + elif QUANTIZE_P: + p = _quantize_p_nvfp4(p, BLOCK_M, BLOCK_N) acc = tl.dot(p.to(v.dtype), v, acc) row_max = m_new # else: tile skipped: no softmax computation, V load, and BMM2 computation @@ -375,12 +606,37 @@ def _attn_fwd( acc = acc * correction[:, None] # Load V and accumulate - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) + if IS_PAGED: + v = _load_paged_v_tile( + V_cache, + Block_table, + batch_idx, + kv_head_idx, + kv_start, + kv_pos, + dim_pos, + seq_len_kv, + stride_vc_block, + stride_vc_pos, + stride_vc_head, + PAGE_SIZE, + BLOCK_N, + BLOCK_D, + HEAD_DIM, + max_blocks_per_seq, + ) + else: + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) + if QUANTIZE_QKV: + v = _quantize_vec_mx_nvfp4(v, BLOCK_N, BLOCK_D, 16) + p = _quantize_vec_mx_nvfp4(p, BLOCK_M, BLOCK_N, 16) + elif QUANTIZE_P: + p = _quantize_p_nvfp4(p, BLOCK_M, BLOCK_N) acc = tl.dot(p.to(v.dtype), v, acc) row_max = m_new @@ -768,6 +1024,12 @@ def forward( num_sink_tokens, dense_window_size, skip_softmax_threshold, + k_cache, + v_cache, + block_table, + page_size, + quantize_p, + quantize_qkv, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -775,6 +1037,8 @@ def forward( kv_group_num = num_q_heads // num_kv_heads batch = b_seq_len.shape[0] + is_paged = k_cache is not None + # Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V. # Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata. if b_seq_len_k is None: @@ -782,17 +1046,33 @@ def forward( b_start_loc_k = b_start_loc max_input_len_k = max_input_len + if is_paged: + if v_cache is None or block_table is None: + raise ValueError("k_cache, v_cache, and block_table must all be provided together") + # Paged mode: b_start_loc_k is never dereferenced, but Triton still needs a tensor. 
+ if b_start_loc_k is None: + b_start_loc_k = torch.zeros_like(b_start_loc) + elif b_start_loc_k is None and b_seq_len_k is not None: + raise ValueError( + "b_start_loc_k is required when K/V are passed as a separate packed tensor" + ) + # Pre-multiply scale by log2(e) so the kernel can use exp2() # exp(score * sm_scale) = exp2(score * sm_scale * log2(e)) qk_scale = sm_scale * LOG2E # Triton tiles must be powers of 2; pad head dim BLOCK_D = triton.next_power_of_2(HEAD_DIM) - # Skip-softmax: convert threshold to scaled log2 space for the kernel. - # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks - # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space - # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we - # pre-scale: threshold_scaled = log2(lambda) * sm_scale. + # Skip-softmax: convert lambda threshold to log2 space for the kernel. + # The threshold is scaled by sm_scale to control sparsity relative to + # head dimension: larger head_dim → smaller sm_scale → more aggressive + # skipping for the same lambda value. + if (quantize_p or quantize_qkv) and (q.requires_grad or k.requires_grad or v.requires_grad): + raise NotImplementedError( + "quantize_p / quantize_qkv support inference only; " + "backward does not model the quantized path" + ) + apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0 if apply_skip: skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale @@ -839,6 +1119,20 @@ def grid(META): DENSE_WINDOW_SIZE=dense_window_size, APPLY_SKIP_SOFTMAX=apply_skip, SKIP_THRESHOLD_LOG2=skip_threshold_log2, + QUANTIZE_P=quantize_p, + QUANTIZE_QKV=quantize_qkv, + IS_PAGED=is_paged, + K_cache=k_cache, + V_cache=v_cache, + Block_table=block_table, + stride_kc_block=k_cache.stride(0) if is_paged else 0, + stride_kc_pos=k_cache.stride(1) if is_paged else 0, + stride_kc_head=k_cache.stride(2) if is_paged else 0, + stride_vc_block=v_cache.stride(0) if is_paged else 0, + stride_vc_pos=v_cache.stride(1) if is_paged else 0, + stride_vc_head=v_cache.stride(2) if is_paged else 0, + PAGE_SIZE=page_size, + max_blocks_per_seq=block_table.shape[1] if is_paged else 0, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) @@ -859,6 +1153,8 @@ def grid(META): ctx.dense_window_size = dense_window_size ctx.apply_skip = apply_skip ctx.skip_threshold_log2 = skip_threshold_log2 + ctx.quantize_p = quantize_p + ctx.quantize_qkv = quantize_qkv return o @staticmethod @@ -972,19 +1268,25 @@ def backward(ctx, grad_output): dq, dk, dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, + None, # b_start_loc + None, # b_seq_len + None, # max_input_len + None, # is_causal + None, # sm_scale + None, # b_start_loc_k + None, # b_seq_len_k + None, # max_input_len_k + None, # sparsity_n + None, # sparsity_m + None, # num_sink_tokens + None, # dense_window_size + None, # skip_softmax_threshold + None, # k_cache + None, # v_cache + None, # block_table + None, # page_size + None, # quantize_p + None, # quantize_qkv ) @@ -1006,8 +1308,14 @@ def attention( num_sink_tokens: int = 0, dense_window_size: int = 64, skip_softmax_threshold: float | None = None, + quantize_p: bool = False, + quantize_qkv: bool = False, + k_cache: torch.Tensor | None = None, + v_cache: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + page_size: int = 16, ) -> torch.Tensor: - """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax. 
+ """Variable-length flash attention with GQA, autograd, optional sparsity, and paged KV. Args: q: [total_q_tokens, num_q_heads, head_dim] @@ -1030,13 +1338,27 @@ def attention( (attention sinks). Absolute token count, BLOCK_N-independent. dense_window_size: Tokens near the query diagonal kept dense (local attention window). Absolute token count, BLOCK_N-independent. - Default 64 (one reference block). + Default 64 tokens. skip_softmax_threshold: BLASST threshold lambda (https://arxiv.org/pdf/2512.12087). Skip KV tiles where ``exp(tile_max - running_max) < lambda``, meaning the tile's softmax contribution is negligible. Tiles are skipped entirely (no softmax, V load, or BMM2). The threshold is applied on unscaled scores. Set to ``None`` or ``0`` to disable. + quantize_p: If ``True``, quantize the post-softmax P tile to NVFP4 + E2M1 before the P @ V matmul (per-tile max scaling, STE). + Default ``False``. + quantize_qkv: If ``True``, apply SageAttention-v3-style per-group + microscaling NVFP4 to Q, K, V (groups of 16 along head_dim) and + per-group NVFP4 to P (groups of 16 along the KV axis). Supersedes + ``quantize_p`` when both are set. Default ``False``. + k_cache: Paged K cache [num_blocks, page_size, num_kv_heads, head_dim]. + When provided, K/V are read from paged cache via block_table + instead of from contiguous k/v tensors. + v_cache: Paged V cache [num_blocks, page_size, num_kv_heads, head_dim]. + block_table: Page table [batch, max_blocks_per_seq] mapping sequence + block indices to global page IDs. + page_size: Number of tokens per page in the KV cache. Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. @@ -1059,7 +1381,232 @@ def attention( num_sink_tokens, dense_window_size, skip_softmax_threshold, + k_cache, + v_cache, + block_table, + page_size, + quantize_p, + quantize_qkv, ) -__all__ = ["attention"] +# --------------------------------------------------------------------------- +# Calibration kernel: collect multi-threshold skip-softmax sparsity stats +# --------------------------------------------------------------------------- +@triton.jit +def _attn_fwd_calibrate( + Q, + K, + V, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + Out, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + Threshold_trials, # [NUM_THRESHOLDS] float32 — pre-scaled to log2 space + Sparsity_counters, # [NUM_THRESHOLDS * 2] int64 — [total, skipped] per threshold + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + NUM_THRESHOLDS: tl.constexpr, +): + """Forward kernel with multi-threshold sparsity measurement. + + Computes full attention (no skipping) while counting how many KV tiles + would be skipped at each threshold. Statistics are collected via atomic + adds to ``Sparsity_counters[t*2]`` (total tiles) and + ``Sparsity_counters[t*2+1]`` (skipped tiles). 
+ """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + tile_q = tl.program_id(2) + kv_head_idx = head_idx // kv_group_num + + seq_len_q = tl.load(b_seq_len + batch_idx) + seq_len_kv = tl.load(b_seq_len_k + batch_idx) + q_offset = tl.load(b_start_loc + batch_idx) + kv_offset = tl.load(b_start_loc_k + batch_idx) + + if tile_q * BLOCK_M >= seq_len_q: + return + + q_pos = tile_q * BLOCK_M + tl.arange(0, BLOCK_M) + kv_pos = tl.arange(0, BLOCK_N) + dim_pos = tl.arange(0, BLOCK_D) + d_mask = dim_pos < HEAD_DIM + + q_ptrs = (q_offset + q_pos[:, None]) * stride_qbs + head_idx * stride_qh + dim_pos[None, :] + q = tl.load(Q + q_ptrs, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :], other=0.0) + + k_base = K + kv_head_idx * stride_kh + v_base = V + kv_head_idx * stride_vh + + row_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + row_sum = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) + + kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv) + + for kv_start in range(0, kv_bound, BLOCK_N): + kv_start = tl.multiple_of(kv_start, BLOCK_N) + + k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] + k = tl.load( + k_base + k_offs, + mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], + other=0.0, + ) + + scores = tl.dot(q, k) * qk_scale + scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + tile_row_max = tl.max(scores, 1) + + # --- Multi-threshold sparsity measurement --- + for t in range(NUM_THRESHOLDS): + thresh = tl.load(Threshold_trials + t) + can_skip = tile_row_max < (row_max + thresh) + skip_tile = tl.min(can_skip.to(tl.int32)) == 1 + tl.atomic_add(Sparsity_counters + t * 2, 1) # total tiles + if skip_tile: + tl.atomic_add(Sparsity_counters + t * 2 + 1, 1) # skipped tiles + + # --- Always compute full attention (no skipping) --- + m_new = tl.maximum(row_max, tile_row_max) + p = tl.math.exp2(scores - m_new[:, None]) + l_new = tl.sum(p, 1) + correction = tl.math.exp2(row_max - m_new) + row_sum = row_sum * correction + l_new + acc = acc * correction[:, None] + + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) + acc = tl.dot(p.to(v.dtype), v, acc) + row_max = m_new + + acc = acc / row_sum[:, None] + o_ptrs = (q_offset + q_pos[:, None]) * stride_obs + head_idx * stride_oh + dim_pos[None, :] + tl.store(Out + o_ptrs, acc, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :]) + + +def attention_calibrate( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + max_input_len: int, + is_causal: bool = True, + softmax_scale: float | None = None, + b_start_loc_k: torch.Tensor | None = None, + b_seq_len_k: torch.Tensor | None = None, + max_input_len_k: int | None = None, + *, + threshold_trials: list[float] | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Flash attention with multi-threshold skip-softmax sparsity measurement. + + Computes full attention (identical output to dense attention) while + measuring how many KV tiles would be skipped at each threshold in + ``threshold_trials``. No autograd — forward only. + + Args: + q, k, v, b_start_loc, b_seq_len, max_input_len, is_causal, + softmax_scale, b_start_loc_k, b_seq_len_k, max_input_len_k: + Same as :func:`attention`. 
+ threshold_trials: List of threshold values to measure sparsity for. + Each value is converted to log2-scaled space for the kernel. + + Returns: + Tuple of (output, sparsity_counters): + - output: ``[total_q_tokens, num_q_heads, head_dim]`` + - sparsity_counters: ``[num_thresholds, 2]`` int64 tensor where + ``[:, 0]`` = total tile evaluations, ``[:, 1]`` = skipped tiles. + Sparsity per threshold = ``counters[:, 1] / counters[:, 0]``. + """ + if threshold_trials is None or len(threshold_trials) == 0: + raise ValueError("threshold_trials must be a non-empty list") + + HEAD_DIM = q.shape[2] + num_q_heads = q.shape[1] + num_kv_heads = k.shape[1] + kv_group_num = num_q_heads // num_kv_heads + batch = b_seq_len.shape[0] + sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale + qk_scale = sm_scale * LOG2E + BLOCK_D = triton.next_power_of_2(HEAD_DIM) + BLOCK_M = 128 + BLOCK_N = 64 + + if b_seq_len_k is None: + b_seq_len_k = b_seq_len + b_start_loc_k = b_start_loc + + num_thresholds = len(threshold_trials) + + # Convert thresholds to log2-scaled space: log2(lambda) * sm_scale + threshold_tensor = torch.tensor( + [math.log2(t) * sm_scale for t in threshold_trials], + dtype=torch.float32, + device=q.device, + ) + + # Atomic counters: [num_thresholds * 2] — flat layout [total_0, skip_0, total_1, skip_1, ...] + sparsity_counters = torch.zeros(num_thresholds * 2, dtype=torch.int64, device=q.device) + + o = torch.empty_like(q) + + grid = (batch, num_q_heads, triton.cdiv(max_input_len, BLOCK_M)) + + _attn_fwd_calibrate[grid]( + q, + k, + v, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + threshold_tensor, + sparsity_counters, + kv_group_num=kv_group_num, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_CAUSAL=is_causal, + HEAD_DIM=HEAD_DIM, + NUM_THRESHOLDS=num_thresholds, + num_warps=4, + num_stages=1, + ) + + # Reshape to [num_thresholds, 2] + return o, sparsity_counters.view(num_thresholds, 2) + + +__all__ = ["attention", "attention_calibrate"] diff --git a/modelopt/torch/quantization/__init__.py b/modelopt/torch/quantization/__init__.py index 87dbf30bb5..a8a2fc6a72 100644 --- a/modelopt/torch/quantization/__init__.py +++ b/modelopt/torch/quantization/__init__.py @@ -24,4 +24,5 @@ from .conversion import * from .model_quant import * from .nn.modules.quant_module import QuantModuleRegistry +from .sage_attention import apply_sage_attention, apply_sage_attention_v3 from .utils import update_quant_cfg_with_kv_cache_quant diff --git a/modelopt/torch/quantization/plugins/diffusion/diffusers.py b/modelopt/torch/quantization/plugins/diffusion/diffusers.py index f9ae55b3e2..c5b1cd61c2 100644 --- a/modelopt/torch/quantization/plugins/diffusion/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusion/diffusers.py @@ -197,8 +197,8 @@ def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) QuantModuleRegistry.register({FluxAttention: "FluxAttention"})(_QuantAttentionModuleMixin) - QuantModuleRegistry.register({WanAttention: "WanAttention"})(_QuantAttentionModuleMixin) QuantModuleRegistry.register({LTXAttention: "LTXAttention"})(_QuantAttentionModuleMixin) + QuantModuleRegistry.register({WanAttention: "WanAttention"})(_QuantAttentionModuleMixin) if Flux2Attention is not None: QuantModuleRegistry.register({Flux2Attention: "Flux2Attention"})(_QuantAttentionModuleMixin) if Flux2ParallelSelfAttention is not 
None: diff --git a/modelopt/torch/quantization/sage_attention/__init__.py b/modelopt/torch/quantization/sage_attention/__init__.py new file mode 100644 index 0000000000..1909b12374 --- /dev/null +++ b/modelopt/torch/quantization/sage_attention/__init__.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SageAttention-style attention quantization for diffusers models. + +``apply_sage_attention`` patches a diffusers transformer to quantize the +post-softmax P tile to NVFP4 E2M1 inside ModelOpt's Triton flash-attention +kernel (``quantize_p=True``). This is purely a **quantization** feature — +it is independent of, and can be freely combined with, the sparse attention +methods in ``modelopt.torch.sparsity.attention_sparsity``. + +Design +------ +SageAttention wraps the transformer's ``forward`` once: + +1. Before the forward, it sets ``quantize_p=True`` in a thread-local store + that the Triton kernel reads. +2. It activates the ``modelopt_triton`` diffusers attention backend for the + duration of the forward pass so that attention calls are routed to the + ModelOpt Triton kernel. +3. After the forward (``finally`` block), it resets ``quantize_p=False``. + +Sparse attention methods (skip-softmax / N:M sparse softmax) manage their +own thread-local params (threshold, sparsity_n/m, …) and deliberately **do +not touch** ``quantize_p``, enabling transparent combination: + +.. code-block:: python + + import modelopt.torch.sparsity.attention_sparsity as mtsa + from modelopt.torch.quantization import apply_sage_attention + + # SageAttention standalone — NVFP4 P-matrix quantization only + apply_sage_attention(transformer) + + # Combined with N:M sparse softmax + mtsa.sparsify(transformer, mtsa.SPARSE_SOFTMAX_DEFAULT) + apply_sage_attention(transformer) + + # Combined with skip-softmax tile pruning + mtsa.sparsify(transformer, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT) + apply_sage_attention(transformer) + +Supported models +---------------- +Currently targets **diffusers** transformer models (WAN, LTX, …) that use +the diffusers attention-dispatch mechanism. The ``modelopt_triton`` backend +is registered in ``diffusers._AttentionBackendRegistry`` on first call. + +Requirements +------------ +- CUDA GPU + Triton installed +- ``modelopt.torch.sparsity.attention_sparsity`` (provides the Triton kernel + and diffusers backend registration) +""" + +import torch + +__all__ = ["apply_sage_attention", "apply_sage_attention_v3"] + + +def apply_sage_attention( + transformer: torch.nn.Module, + quantize_p: bool = True, +) -> None: + """Patch a diffusers transformer to use NVFP4 P-matrix quantization. + + Wraps ``transformer.forward`` so that every call activates the + ``modelopt_triton`` diffusers attention backend with ``quantize_p=True`` + inside the Triton flash-attention kernel. 
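+
+    Minimal usage sketch (pipeline construction elided; assumes a diffusers
+    pipeline whose transformer routes attention through ``attention_dispatch``)::
+
+        apply_sage_attention(pipe.transformer)
+        out = pipe(prompt="...", num_inference_steps=40)  # quantize_p active for each forward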
+ + This is a standalone quantization feature and does not depend on or + conflict with ``mtsa.sparsify()``. Both can be applied to the same + transformer — sparsity parameters and quantization parameters are stored + in independent thread-local slots. + + Args: + transformer: A diffusers transformer module (e.g. ``pipe.transformer`` + for WAN2.2 / LTX Video). + quantize_p: If True (default), quantize the post-softmax P tile to + NVFP4 E2M1 with per-tile max scaling inside the Triton kernel. + + Raises: + ImportError: If ``modelopt.torch.sparsity.attention_sparsity`` is not + installed (required for the Triton kernel and diffusers backend). + """ + try: + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + clear_sage_attention_config, + get_triton_attention_backend, + register_diffusers_triton_attention, + set_sage_attention_config, + ) + except ImportError as exc: + raise ImportError( + "apply_sage_attention requires modelopt.torch.sparsity.attention_sparsity " + "(Triton kernel + diffusers backend). Install modelopt with the [all] extra " + "or ensure triton is available." + ) from exc + + register_diffusers_triton_attention() + + original_forward = transformer.forward + + def _sage_forward(*args, **kwargs): + set_sage_attention_config(quantize_p=quantize_p) + with get_triton_attention_backend(): + try: + return original_forward(*args, **kwargs) + finally: + clear_sage_attention_config() + + transformer.forward = _sage_forward + transformer._modelopt_sage_attention = True # mark for inspection + + q_str = "NVFP4 E2M1" if quantize_p else "disabled" + print( + f"[ModelOpt] SageAttention applied: quantize_p={quantize_p} ({q_str} P-tile quantization)" + ) + + +def apply_sage_attention_v3(transformer: torch.nn.Module) -> None: + """Patch a diffusers transformer with SageAttention v3 microscaling NVFP4. + + Wraps ``transformer.forward`` to quantize **Q, K, V, and P** to NVFP4 E2M1 + using per-group microscaling (groups of 16 elements along the head dimension), + following the SageAttention v3 paper (arxiv 2505.11594). + + Compared to :func:`apply_sage_attention` (which only quantizes P with per-tile + scaling), this also quantizes Q, K, and V with finer per-group scales, targeting + Blackwell / Ada GPUs where FP4 tensor cores provide maximum throughput. + + Args: + transformer: A diffusers transformer module (e.g. ``pipe.transformer``). + + Raises: + ImportError: If ``modelopt.torch.sparsity.attention_sparsity`` is not + installed (required for the Triton kernel and diffusers backend). + """ + try: + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + clear_sage_attention_config, + get_triton_attention_backend, + register_diffusers_triton_attention, + set_sage_attention_config, + ) + except ImportError as exc: + raise ImportError( + "apply_sage_attention_v3 requires modelopt.torch.sparsity.attention_sparsity " + "(Triton kernel + diffusers backend). Install modelopt with the [all] extra " + "or ensure triton is available." 
+ ) from exc + + register_diffusers_triton_attention() + + original_forward = transformer.forward + + def _sage_v3_forward(*args, **kwargs): + set_sage_attention_config(quantize_p=False, quantize_qkv=True) + with get_triton_attention_backend(): + try: + return original_forward(*args, **kwargs) + finally: + clear_sage_attention_config() + + transformer.forward = _sage_v3_forward + transformer._modelopt_sage_attention_v3 = True # mark for inspection + print("[ModelOpt] SageAttention v3 applied: per-group MX NVFP4 on Q, K, V, and P") diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index dbc4d5bc27..3215a7530b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -21,7 +21,6 @@ import torch import torch.nn as nn -from transformers import AutoTokenizer from modelopt.torch.utils import get_module_device @@ -32,8 +31,10 @@ from .ruler_dataset import RulerDatasetBuilder -def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer": +def _load_tokenizer(tokenizer_name_or_path: str): """Load tokenizer and ensure pad_token is set.""" + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) if not tokenizer.pad_token: tokenizer.pad_token = tokenizer.eos_token @@ -255,11 +256,14 @@ def calibrate_sparse_attention( print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") - # Extract tokenizer and build calibration data if needed - tokenizer = _extract_tokenizer_from_model(model) + # Extract tokenizer and build calibration data only if no forward_loop is provided. + # When the user supplies their own forward_loop (e.g. for diffusion models), + # RULER dataset generation is skipped entirely. 
+ tokenizer = None calibration_data = None - if calibrate_prefill or calibrate_decode: + if forward_loop is None and (calibrate_prefill or calibrate_decode): + tokenizer = _extract_tokenizer_from_model(model) builder = RulerDatasetBuilder( samples=calib_config.samples, max_seqlen=calib_config.max_seqlen, @@ -280,11 +284,15 @@ def calibrate_sparse_attention( print("PREFILL PHASE CALIBRATION") print("=" * 60) - if calibration_data is None: + if forward_loop is None and calibration_data is None: raise RuntimeError("calibration_data must be built before prefill") - prefill_forward_loop = forward_loop or create_calibration_forward_loop( - calibration_data, tokenizer, chunk_size=calib_config.chunk_size - ) + if forward_loop is not None: + prefill_forward_loop = forward_loop + else: + assert calibration_data is not None and tokenizer is not None + prefill_forward_loop = create_calibration_forward_loop( + calibration_data, tokenizer, chunk_size=calib_config.chunk_size + ) prefill_calibrator = DynamicThresholdCalibrator( threshold_trials=calib_config.threshold_trials, @@ -302,8 +310,8 @@ def calibrate_sparse_attention( print("DECODE PHASE CALIBRATION") print("=" * 60) - if calibration_data is None: - raise RuntimeError("calibration_data must be built before decode") + if calibration_data is None or tokenizer is None: + raise RuntimeError("calibration_data and tokenizer must be built before decode") decode_forward_loop = create_decode_calibration_forward_loop( calibration_data, tokenizer, num_decode_tokens=calib_config.num_decode_tokens ) diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index 6821206937..f2bb8831d7 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -333,6 +333,16 @@ def _extract_calibration_stats( return aggregated_stats def _set_thresholds(self, modules: list[nn.Module], thresholds: list[float]): - """Set thresholds list on sparse attention modules.""" + """Set thresholds list on sparse attention modules. + + Supports both flash_skip_softmax (sets ``thresholds`` attribute) and + triton_skip_softmax (sets ``_threshold_trials`` attribute). + """ for module in modules: - module._sparse_method_instance.thresholds = thresholds + method = module._sparse_method_instance + if hasattr(method, "_threshold_trials"): + # triton_skip_softmax: calibration uses Triton calibration kernel + method._threshold_trials = thresholds + else: + # flash_skip_softmax: calibration uses F.softmax patching + method.thresholds = thresholds diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index fa415b322b..ff3ca2d8fc 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -155,7 +155,8 @@ def validate_backend(cls, v): raise ValueError( f"Invalid backend: {v}. Supported backends: 'pytorch' (requires " f"attn_implementation='eager'), 'triton' (requires " - f"attn_implementation='modelopt_triton')." + f"attn_implementation='modelopt_triton' for HuggingFace models or the " + f"modelopt_triton diffusers backend for diffusers models)." 
) return v diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index ccd69e1195..3a1d52fbd4 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -115,6 +115,45 @@ def is_attn_sparsified(model: nn.Module) -> bool: return any(isinstance(module, SparseAttentionModule) for module in model.modules()) +def _register_diffusers_backends_if_needed(model: nn.Module) -> None: + """Register diffusers/LTX attention backends if the model needs them. + + Called before plugin registration so that the backends are available + when ``SparseAttentionModule.forward()`` activates the skip-softmax context. + """ + # Register the diffusers eager and Triton backends if the model is a diffusers ModelMixin + try: + from diffusers.models.modeling_utils import ModelMixin + + if isinstance(model, ModelMixin): + from .kernels import ( + register_diffusers_eager_attention, + register_diffusers_triton_attention, + ) + + if register_diffusers_eager_attention is not None: + register_diffusers_eager_attention() + if register_diffusers_triton_attention is not None: + register_diffusers_triton_attention() + except (ImportError, Exception): + pass + + # Patch ltx_core Attention modules if present (independent of diffusers) + import contextlib + + try: + from .kernels import register_ltx_eager_attention, register_ltx_triton_attention + except (ImportError, RuntimeError): + return + + if register_ltx_eager_attention is not None: + with contextlib.suppress(Exception): + register_ltx_eager_attention(model) + if register_ltx_triton_attention is not None: + with contextlib.suppress(Exception): + register_ltx_triton_attention(model) + + def convert_to_sparse_attention_model( model: ModelLikeModule, config: SparseAttentionConfig ) -> ConvertReturnType: @@ -130,6 +169,9 @@ def convert_to_sparse_attention_model( # Initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + # Register diffusers backends for diffusion models + _register_diffusers_backends_if_needed(model) + # Set the correct attn_implementation for the chosen backend _set_attn_implementation(model, config) @@ -484,6 +526,8 @@ def print_sparse_attention_summary(model: nn.Module): # Group by (method, threshold) groups: dict[tuple[str, str], int] = {} for _, module in sparse_modules: + if not module.is_enabled: + continue method = getattr(module, "_method", "unknown") threshold = _format_threshold(module.get_threshold_info()) groups[(method, threshold)] = groups.get((method, threshold), 0) + 1 diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py index dee1bc472a..81f4295bb4 100644 --- a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py @@ -13,12 +13,61 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Re-exports from modelopt.torch.kernels for backward compatibility.""" +"""Kernel integrations for sparse attention: Triton FA and diffusers backends.""" +import contextlib +import threading + +# --------------------------------------------------------------------------- +# Triton FA kernel re-exports (for HuggingFace LLM integration) +# --------------------------------------------------------------------------- from modelopt.torch.kernels import IS_AVAILABLE, attention, register_triton_attention +# --------------------------------------------------------------------------- +# Thread-local context: shared by diffusers eager and Triton backends +# --------------------------------------------------------------------------- +_thread_local = threading.local() + + +def set_skip_softmax_context(active: bool) -> None: + """Set thread-local flag indicating skip-softmax eager attention is active.""" + _thread_local.skip_softmax_active = active + + +def get_skip_softmax_context() -> bool: + """Return True if skip-softmax eager attention is active in this thread.""" + return getattr(_thread_local, "skip_softmax_active", False) + + +# --------------------------------------------------------------------------- +# Optional backend registrations (depend on diffusers / ltx_core) +# --------------------------------------------------------------------------- +register_diffusers_eager_attention = None +register_diffusers_triton_attention = None +register_ltx_eager_attention = None +register_ltx_triton_attention = None + +# Suppress ImportError (missing package) and RuntimeError (triton without GPU driver) +with contextlib.suppress(ImportError, RuntimeError): + from .diffusers_eager_attention import register_diffusers_eager_attention + +with contextlib.suppress(ImportError, RuntimeError): + from .diffusers_triton_attention import register_diffusers_triton_attention + +with contextlib.suppress(ImportError, RuntimeError): + from .ltx_eager_attention import register_ltx_eager_attention + +with contextlib.suppress(ImportError, RuntimeError): + from .ltx_triton_attention import register_ltx_triton_attention + __all__ = [ "IS_AVAILABLE", "attention", + "get_skip_softmax_context", + "register_diffusers_eager_attention", + "register_diffusers_triton_attention", + "register_ltx_eager_attention", + "register_ltx_triton_attention", "register_triton_attention", + "set_skip_softmax_context", ] diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py new file mode 100644 index 0000000000..16dd895f27 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Eager attention backend for diffusers skip-softmax sparse attention. 
+ +Registers a ``modelopt_skip_softmax`` backend in diffusers' +``_AttentionBackendRegistry`` that computes attention eagerly with an explicit +``F.softmax`` call. This allows the existing softmax-patching mechanism in +``SparseAttentionModule`` to intercept and apply block-wise sparsity. + +Used during **calibration only** — inference uses the Triton FA kernel. +""" + +import inspect +import math + +import torch +import torch.nn.functional as F +from diffusers.models.attention_dispatch import ( + AttentionBackendName, + _AttentionBackendRegistry, + attention_backend, +) + +_BACKEND_NAME = "modelopt_skip_softmax" +_BACKEND_REGISTERED = False + + +# --------------------------------------------------------------------------- +# Eager attention implementation +# --------------------------------------------------------------------------- + + +def _diffusers_eager_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float | None = None, + enable_gqa: bool = False, +) -> torch.Tensor: + """Compute attention eagerly on diffusers layout ``[B, S, H, D]``. + + The explicit ``F.softmax`` call is what the skip-softmax patch intercepts. + """ + # Diffusers convention: [B, S, H, D] → transpose to [B, H, S, D] + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # Handle GQA: repeat K/V heads to match Q heads + if enable_gqa and query.shape[1] != key.shape[1]: + num_heads_q = query.shape[1] + num_heads_kv = key.shape[1] + n_rep = num_heads_q // num_heads_kv + key = key.repeat_interleave(n_rep, dim=1) + value = value.repeat_interleave(n_rep, dim=1) + + if scale is None: + scale = 1.0 / math.sqrt(query.shape[-1]) + + # Q @ K^T * scale + scores = torch.matmul(query, key.transpose(-2, -1)) * scale + + # Apply attention mask if provided + if attn_mask is not None: + scores = scores + attn_mask + + # Apply causal mask if needed + if is_causal: + seq_q, seq_k = scores.shape[-2], scores.shape[-1] + causal_mask = torch.triu( + torch.full((seq_q, seq_k), float("-inf"), device=scores.device, dtype=scores.dtype), + diagonal=seq_k - seq_q + 1, + ) + scores = scores + causal_mask + + # F.softmax — this is where the skip-softmax patch intercepts + scores = F.softmax(scores, dim=-1) + + if dropout_p > 0.0: + scores = F.dropout(scores, p=dropout_p, training=True) + + # scores @ V + out = torch.matmul(scores, value) + + # Transpose back: [B, H, S, D] → [B, S, H, D] + out = out.transpose(1, 2) + return out + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + + +def register_diffusers_eager_attention() -> None: + """Register ``modelopt_skip_softmax`` backend in diffusers. + + Safe to call multiple times; registration happens only once. 
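+
+    Typical use (illustrative): register once, then activate the backend around
+    a forward pass via :func:`get_skip_softmax_attention_backend`::
+
+        register_diffusers_eager_attention()
+        with get_skip_softmax_attention_backend():
+            out = transformer(**inputs)  # "inputs" is an illustrative placeholder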
+ """ + global _BACKEND_REGISTERED + if _BACKEND_REGISTERED: + return + + # Extend the AttentionBackendName enum with our custom value + new_member = str.__new__(AttentionBackendName, _BACKEND_NAME) + new_member._name_ = "MODELOPT_SKIP_SOFTMAX" + new_member._value_ = _BACKEND_NAME + AttentionBackendName._member_map_["MODELOPT_SKIP_SOFTMAX"] = new_member + AttentionBackendName._value2member_map_[_BACKEND_NAME] = new_member + + # Register the backend function + _AttentionBackendRegistry._backends[new_member] = _diffusers_eager_attention + _AttentionBackendRegistry._constraints[new_member] = [] + _AttentionBackendRegistry._supported_arg_names[new_member] = set( + inspect.signature(_diffusers_eager_attention).parameters.keys() + ) + + _BACKEND_REGISTERED = True + + +def get_skip_softmax_attention_backend(): + """Return a context manager that activates the modelopt_skip_softmax backend. + + Raises RuntimeError if the backend has not been registered yet. + """ + if not _BACKEND_REGISTERED: + raise RuntimeError( + "modelopt_skip_softmax backend not registered. " + "Call register_diffusers_eager_attention() first." + ) + return attention_backend(_BACKEND_NAME) diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py new file mode 100644 index 0000000000..ed71cb4cf7 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton flash attention backend for diffusers models. + +Registers a ``modelopt_triton`` backend in diffusers' ``_AttentionBackendRegistry`` +that converts the diffusers [B, S, H, D] layout to the Triton FA kernel's varlen +[total_tokens, H, D] format. + +Two modes: +- **Inference**: Calls ``attention()`` with optional skip-softmax tile skipping, + N:M sparse softmax, and/or NVFP4 P-matrix quantization. +- **Calibration**: Calls ``attention_calibrate()`` to collect multi-threshold + sparsity statistics without skipping any tiles. +""" + +import inspect +import math +import threading + +import torch +from diffusers.models.attention_dispatch import ( + AttentionBackendName, + _AttentionBackendRegistry, + attention_backend, +) + +from modelopt.torch.kernels import attention, attention_calibrate + +_BACKEND_NAME = "modelopt_triton" +_BACKEND_REGISTERED = False + +# Thread-local storage for per-forward skip-softmax configuration. +_thread_local = threading.local() + + +def set_triton_skip_softmax_config( + threshold: float | None = None, + calibration_mode: bool = False, + threshold_trials: list[float] | None = None, + sparsity_n: int = 0, + sparsity_m: int = 4, + num_sink_tokens: int = 0, + dense_window_size: int = 64, +) -> None: + """Set thread-local sparse attention config for the next Triton attention call. 
+ + This controls skip-softmax tile skipping and N:M sparsity. It does NOT touch + ``quantize_p`` — that is managed independently by :func:`set_sage_attention_config`. + + Args: + threshold: Skip-softmax threshold for inference mode. + calibration_mode: If True, use the calibration kernel to collect + multi-threshold sparsity stats instead of skipping tiles. + threshold_trials: List of thresholds to measure sparsity for + (only used when calibration_mode=True). + sparsity_n: Keep top-N of every M attention scores (0 to disable N:M sparsity). + sparsity_m: Group size for N:M sparsity (4 or 8). + num_sink_tokens: KV positions before this index kept dense (attention sinks). + dense_window_size: Tokens near the diagonal kept dense (absolute token count). + """ + _thread_local.skip_threshold = threshold + _thread_local.calibration_mode = calibration_mode + _thread_local.threshold_trials = threshold_trials + _thread_local.sparsity_n = sparsity_n + _thread_local.sparsity_m = sparsity_m + _thread_local.num_sink_tokens = num_sink_tokens + _thread_local.dense_window_size = dense_window_size + # Accumulated counters across all attention calls in one forward pass + _thread_local.calibration_counters = None + + +def clear_triton_skip_softmax_config() -> None: + """Clear thread-local sparse attention config. + + Only clears skip-softmax / N:M sparsity params. Does NOT reset ``quantize_p`` + so that :func:`set_sage_attention_config` remains active across attention layers. + """ + _thread_local.skip_threshold = None + _thread_local.calibration_mode = False + _thread_local.threshold_trials = None + _thread_local.sparsity_n = 0 + _thread_local.sparsity_m = 4 + _thread_local.num_sink_tokens = 0 + _thread_local.dense_window_size = 64 + _thread_local.calibration_counters = None + + +def set_sage_attention_config( + quantize_p: bool = True, + quantize_qkv: bool = False, +) -> None: + """Set NVFP4 quantization flags for SageAttention. + + Manages ``quantize_p`` and ``quantize_qkv`` independently of sparse-attention + params so either can be combined with skip-softmax / N:M sparsity without + clobbering the other's state. + + Args: + quantize_p: If True, quantize the post-softmax P tile to NVFP4 E2M1 + (per-tile max scaling, SageAttn v1/v2 style). Default True. + quantize_qkv: If True, apply SageAttn-v3-style per-group microscaling NVFP4 + to Q, K, V (groups of 16 along head_dim) and finer per-group NVFP4 to P. + Supersedes ``quantize_p`` when set. Default False. 
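+
+    A typical call pattern (illustrative; pair every ``set`` with a ``clear``)::
+
+        set_sage_attention_config(quantize_p=True)   # enable NVFP4 P-matrix quantization
+        ...                                          # run the transformer forward pass
+        clear_sage_attention_config()                # reset the flags afterwards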
+ """ + _thread_local.quantize_p = quantize_p + _thread_local.quantize_qkv = quantize_qkv + + +def clear_sage_attention_config() -> None: + """Clear NVFP4 quantization flags.""" + _thread_local.quantize_p = False + _thread_local.quantize_qkv = False + + +def get_calibration_counters() -> "torch.Tensor | None": + """Return accumulated calibration counters ``[num_thresholds, 2]`` or None.""" + return getattr(_thread_local, "calibration_counters", None) + + +# --------------------------------------------------------------------------- +# Triton attention implementation for diffusers layout +# --------------------------------------------------------------------------- + + +def _diffusers_triton_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float | None = None, + enable_gqa: bool = False, +) -> torch.Tensor: + """Compute attention via Triton FA kernel on diffusers layout ``[B, S, H, D]``.""" + batch, seq_q, num_heads_q, head_dim = query.shape + seq_k = key.shape[1] + device = query.device + + # Reshape from diffusers [B, S, H, D] -> flat [B*S, H, D] + q = query.reshape(batch * seq_q, num_heads_q, head_dim).contiguous() + k = key.reshape(batch * seq_k, key.shape[2], head_dim).contiguous() + v = value.reshape(batch * seq_k, value.shape[2], head_dim).contiguous() + + # Build varlen metadata + b_start_loc_q = torch.arange(batch, device=device, dtype=torch.int32) * seq_q + b_seq_len_q = torch.full((batch,), seq_q, device=device, dtype=torch.int32) + + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + kw: dict = { + "b_start_loc": b_start_loc_q, + "b_seq_len": b_seq_len_q, + "max_input_len": seq_q, + "is_causal": is_causal, + "softmax_scale": scale, + } + + if seq_q != seq_k: + b_start_loc_k = torch.arange(batch, device=device, dtype=torch.int32) * seq_k + b_seq_len_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32) + kw["b_start_loc_k"] = b_start_loc_k + kw["b_seq_len_k"] = b_seq_len_k + kw["max_input_len_k"] = seq_k + + # --- Calibration mode: collect multi-threshold stats --- + calib_mode = getattr(_thread_local, "calibration_mode", False) + if calib_mode: + trials = getattr(_thread_local, "threshold_trials", None) + if trials and attention_calibrate is not None: + o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials) + + # Accumulate counters across all attention calls in this forward pass + prev = getattr(_thread_local, "calibration_counters", None) + if prev is None: + _thread_local.calibration_counters = counters + else: + _thread_local.calibration_counters = prev + counters + + return o.view(batch, seq_q, num_heads_q, head_dim) + + # --- Inference mode: optional skip-softmax, N:M sparsity, and/or NVFP4 quantization --- + threshold = getattr(_thread_local, "skip_threshold", None) + if threshold is not None and threshold > 0.0: + kw["skip_softmax_threshold"] = threshold + + sparsity_n = getattr(_thread_local, "sparsity_n", 0) + if sparsity_n > 0: + kw["sparsity_n"] = sparsity_n + kw["sparsity_m"] = getattr(_thread_local, "sparsity_m", 4) + num_sink_tokens = getattr(_thread_local, "num_sink_tokens", 0) + if num_sink_tokens > 0: + kw["num_sink_tokens"] = num_sink_tokens + dense_window_size = getattr(_thread_local, "dense_window_size", 64) + if dense_window_size > 0: + kw["dense_window_size"] = dense_window_size + + quantize_qkv = getattr(_thread_local, "quantize_qkv", False) + if quantize_qkv: + kw["quantize_qkv"] = True 
+ elif getattr(_thread_local, "quantize_p", False): + kw["quantize_p"] = True + + assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)" + o = attention(q, k, v, **kw) + return o.view(batch, seq_q, num_heads_q, head_dim) + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + + +def register_diffusers_triton_attention() -> None: + """Register ``modelopt_triton`` backend in diffusers. + + Safe to call multiple times; registration happens only once. + """ + global _BACKEND_REGISTERED + if _BACKEND_REGISTERED: + return + + new_member = str.__new__(AttentionBackendName, _BACKEND_NAME) + new_member._name_ = "MODELOPT_TRITON" + new_member._value_ = _BACKEND_NAME + AttentionBackendName._member_map_["MODELOPT_TRITON"] = new_member + AttentionBackendName._value2member_map_[_BACKEND_NAME] = new_member + + _AttentionBackendRegistry._backends[new_member] = _diffusers_triton_attention + _AttentionBackendRegistry._constraints[new_member] = [] + _AttentionBackendRegistry._supported_arg_names[new_member] = set( + inspect.signature(_diffusers_triton_attention).parameters.keys() + ) + + _BACKEND_REGISTERED = True + + +def get_triton_attention_backend(): + """Return a context manager that activates the modelopt_triton backend.""" + if not _BACKEND_REGISTERED: + raise RuntimeError( + "modelopt_triton backend not registered. " + "Call register_diffusers_triton_attention() first." + ) + return attention_backend(_BACKEND_NAME) diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py new file mode 100644 index 0000000000..6c082ee588 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Eager attention wrapper for LTX-2 (ltx_core) skip-softmax sparse attention. + +Patches ``Attention`` modules from ``ltx_core`` so that when the skip-softmax +thread-local flag is active, attention is computed eagerly with an explicit +``F.softmax`` call that the softmax-patching mechanism can intercept. + +Used during **calibration only** — inference uses the Triton FA kernel via +the diffusers Triton backend. +""" + +import math + +import torch +import torch.nn.functional as F +from ltx_core.model.transformer.attention import Attention + +from . import get_skip_softmax_context + + +def _ltx_eager_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, +) -> torch.Tensor: + """Eager attention on LTX-2 layout ``[B, T, H*D]``. + + Mirrors the ``PytorchAttention`` class in ltx_core but uses an explicit + ``F.softmax`` instead of ``scaled_dot_product_attention``. 
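+
+    Args:
+        q, k, v: Projected tensors in the packed LTX-2 layout ``[B, T, H*D]``.
+        heads: Number of attention heads used to unpack the last dimension.
+        mask: Optional additive mask, broadcast to ``[B, H, Tq, Tk]`` before softmax.
+
+    Returns:
+        Attention output in the same ``[B, T, H*D]`` layout as the inputs.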
+ """ + b, _, dim_total = q.shape + dim_head = dim_total // heads + + # Reshape to [B, T, H, D] then transpose to [B, H, T, D] + q = q.view(b, -1, heads, dim_head).transpose(1, 2) + k = k.view(b, -1, heads, dim_head).transpose(1, 2) + v = v.view(b, -1, heads, dim_head).transpose(1, 2) + + scale = 1.0 / math.sqrt(dim_head) + + # Q @ K^T * scale + scores = torch.matmul(q, k.transpose(-2, -1)) * scale + + # Apply mask if provided + if mask is not None: + # Expand mask dimensions to match scores [B, H, Sq, Sk] + if mask.ndim == 2: + mask = mask.unsqueeze(0) + if mask.ndim == 3: + mask = mask.unsqueeze(1) + scores = scores + mask + + # F.softmax — intercepted by skip-softmax patch + scores = F.softmax(scores, dim=-1) + + # scores @ V + out = torch.matmul(scores, v) + + # [B, H, T, D] → [B, T, H*D] + out = out.transpose(1, 2).reshape(b, -1, heads * dim_head) + return out + + +class _SkipSoftmaxLTXAttentionWrapper: + """Wraps an ``attention_function`` callable from ltx_core. + + When the thread-local skip-softmax flag is active, routes to the eager + attention path. Otherwise calls the original function. + """ + + def __init__(self, original_fn): + self._original_fn = original_fn + + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if get_skip_softmax_context(): + return _ltx_eager_attention(q, k, v, heads, mask) + return self._original_fn(q, k, v, heads, mask) + + +def register_ltx_eager_attention(model: torch.nn.Module) -> None: + """Walk *model* and patch all ``ltx_core.model.transformer.attention.Attention`` modules. + + Patches modules so their ``attention_function`` is routed through the eager wrapper. + Safe to call multiple times on the same model — already-wrapped modules are + skipped. + """ + for module in model.modules(): + if isinstance(module, Attention): + fn = module.attention_function + if not isinstance(fn, _SkipSoftmaxLTXAttentionWrapper): + module.attention_function = _SkipSoftmaxLTXAttentionWrapper(fn) diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py new file mode 100644 index 0000000000..8ef2569d5b --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton flash attention wrapper for LTX-2 (ltx_core) skip-softmax sparse attention. + +Two modes: +- **Inference**: ``attention()`` with skip-softmax tile skipping. +- **Calibration**: ``attention_calibrate()`` to collect multi-threshold stats. 
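+
+A calibration-time sketch (``model`` and the threshold list are placeholders for an
+ltx_core transformer and the trial thresholds you want to measure)::
+
+    register_ltx_triton_attention(model)
+    set_ltx_triton_context(active=True, calibration_mode=True,
+                           threshold_trials=[1e-4, 1e-3, 1e-2])
+    model(...)                                  # forward pass accumulates counters
+    counters = get_calibration_counters()       # [num_thresholds, 2] total/skipped tiles
+    clear_ltx_triton_context()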
+""" + +import math +import threading + +import torch +from ltx_core.model.transformer.attention import Attention + +from modelopt.torch.kernels import attention, attention_calibrate + +# Thread-local storage for skip-softmax configuration +_thread_local = threading.local() + + +def set_ltx_triton_context( + active: bool, + threshold: float | None = None, + calibration_mode: bool = False, + threshold_trials: list[float] | None = None, +) -> None: + """Set thread-local Triton config for LTX-2 attention.""" + _thread_local.active = active + _thread_local.threshold = threshold + _thread_local.calibration_mode = calibration_mode + _thread_local.threshold_trials = threshold_trials + if not calibration_mode: + _thread_local.calibration_counters = None + + +def clear_ltx_triton_context() -> None: + """Clear thread-local Triton config.""" + _thread_local.active = False + _thread_local.threshold = None + _thread_local.calibration_mode = False + _thread_local.threshold_trials = None + _thread_local.calibration_counters = None + + +def _get_ltx_triton_context() -> tuple[bool, float | None]: + """Return (active, threshold).""" + return ( + getattr(_thread_local, "active", False), + getattr(_thread_local, "threshold", None), + ) + + +def get_calibration_counters() -> "torch.Tensor | None": + """Return accumulated calibration counters ``[num_thresholds, 2]`` or None.""" + return getattr(_thread_local, "calibration_counters", None) + + +def _ltx_triton_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, + threshold: float | None = None, +) -> torch.Tensor: + """Triton FA attention on LTX-2 layout ``[B, T, H*D]``.""" + b, seq_q, dim_total = q.shape + dim_head = dim_total // heads + seq_k = k.shape[1] + device = q.device + + q_flat = q.view(b, seq_q, heads, dim_head).reshape(b * seq_q, heads, dim_head).contiguous() + k_flat = k.view(b, seq_k, heads, dim_head).reshape(b * seq_k, heads, dim_head).contiguous() + v_flat = v.view(b, seq_k, heads, dim_head).reshape(b * seq_k, heads, dim_head).contiguous() + + b_start_loc_q = torch.arange(b, device=device, dtype=torch.int32) * seq_q + b_seq_len_q = torch.full((b,), seq_q, device=device, dtype=torch.int32) + + scale = 1.0 / math.sqrt(dim_head) + + kw: dict = { + "b_start_loc": b_start_loc_q, + "b_seq_len": b_seq_len_q, + "max_input_len": seq_q, + "is_causal": False, + "softmax_scale": scale, + } + + if seq_q != seq_k: + b_start_loc_k = torch.arange(b, device=device, dtype=torch.int32) * seq_k + b_seq_len_k = torch.full((b,), seq_k, device=device, dtype=torch.int32) + kw["b_start_loc_k"] = b_start_loc_k + kw["b_seq_len_k"] = b_seq_len_k + kw["max_input_len_k"] = seq_k + + # --- Calibration mode --- + calib_mode = getattr(_thread_local, "calibration_mode", False) + if calib_mode: + trials = getattr(_thread_local, "threshold_trials", None) + if trials and attention_calibrate is not None: + o, counters = attention_calibrate(q_flat, k_flat, v_flat, **kw, threshold_trials=trials) + + prev = getattr(_thread_local, "calibration_counters", None) + if prev is None: + _thread_local.calibration_counters = counters + else: + _thread_local.calibration_counters = prev + counters + + return o.view(b, seq_q, heads * dim_head) + + # --- Inference mode --- + if threshold is not None and threshold > 0.0: + kw["skip_softmax_threshold"] = threshold + + assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)" + o = attention(q_flat, k_flat, v_flat, **kw) + return o.view(b, seq_q, 
heads * dim_head) + + +class _TritonLTXAttentionWrapper: + """Wraps ltx_core attention_function for Triton dispatch.""" + + def __init__(self, original_fn): + self._original_fn = original_fn + + def __call__(self, q, k, v, heads, mask=None): + active, threshold = _get_ltx_triton_context() + if active: + return _ltx_triton_attention(q, k, v, heads, mask, threshold) + return self._original_fn(q, k, v, heads, mask) + + +def register_ltx_triton_attention(model: torch.nn.Module) -> None: + """Patch all ``ltx_core.Attention`` modules for Triton dispatch.""" + for module in model.modules(): + if isinstance(module, Attention): + fn = module.attention_function + if not isinstance(fn, _TritonLTXAttentionWrapper): + module.attention_function = _TritonLTXAttentionWrapper(fn) diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 2501b58f65..aab399292a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -20,6 +20,7 @@ """ import math +from contextlib import ExitStack from typing import Any import numpy as np @@ -369,7 +370,11 @@ def get_threshold_info(self) -> dict[str, Any]: } def get_sparse_context(self, module: torch.nn.Module): - """Return a context manager that patches F.softmax with sparse masking.""" + """Return a context manager that patches F.softmax with sparse masking. + + Also registers the diffusers eager backend so that diffusion models + (which don't call F.softmax directly) route through the patched path. + """ original_softmax = F.softmax def sparse_softmax(input, dim=-1, *args, **kwargs): @@ -379,7 +384,21 @@ def sparse_softmax(input, dim=-1, *args, **kwargs): input = self.apply_sparsity(input, sparse_mask) return original_softmax(input, dim, *args, **kwargs) - return replace_function(torch.nn.functional, "softmax", sparse_softmax) + from ..kernels import set_skip_softmax_context + + stack = ExitStack() + set_skip_softmax_context(True) + stack.callback(set_skip_softmax_context, False) + + try: + from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend + + stack.enter_context(get_skip_softmax_attention_backend()) + except (ImportError, RuntimeError): + pass + + stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax)) + return stack @property def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 8037146643..3cb4f9010e 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -40,6 +40,10 @@ def __init__(self): # Video shape for VSA (T, H, W). None for non-VSA methods. 
self.video_shape: tuple[int, int, int] | None = None + def set_calibration_mode(self, enabled: bool) -> None: + """Enable or disable calibration mode (called by DynamicThresholdCalibrator).""" + self._calibration_mode = enabled + def forward_attention( self, query: torch.Tensor, diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index 4db51e894e..b55af1f01f 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -13,10 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Skip-softmax method for attention via Triton kernel tile skipping.""" +"""Skip-softmax method for attention via Triton kernel tile skipping. + +Supports two modes: +- **Inference**: KV tiles with negligible scores are skipped in-kernel. +- **Calibration**: The Triton calibration kernel collects multi-threshold + sparsity statistics without skipping any tiles. +""" from contextlib import contextmanager +import torch + from .registry import SparseAttentionMethod, register_sparse_method @@ -39,21 +47,176 @@ def __init__(self, method_config=None): super().__init__() method_config = method_config or {} self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1) + # Calibration state + self._threshold_trials: list[float] | None = None @property def name(self) -> str: """Method name identifier.""" return "triton_skip_softmax" + def calculate_sparsity(self, attention_scores): + """Return a no-op mask (skip decision is made inside the Triton kernel).""" + mask = torch.ones_like(attention_scores, dtype=torch.bool) + return mask, {} + + def apply_sparsity(self, attention_scores, sparse_mask=None): + """Not supported — tile skipping is fused into the Triton kernel.""" + raise NotImplementedError( + "triton_skip_softmax applies tile skipping inside the Triton kernel. " + "Use backend='triton', not backend='pytorch'." + ) + def get_sparse_context(self, module): - """Return context manager that activates skip-softmax during forward.""" + """Return context manager that activates skip-softmax during forward. + + In calibration mode, configures the Triton backend to use the + calibration kernel which collects multi-threshold sparsity stats. + In inference mode, sets the skip threshold for tile skipping. 
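+
+        Returns:
+            A context manager that sets ``module._apply_skip_softmax`` and the
+            thread-local Triton backend config for the duration of the forward,
+            and restores both on exit.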
+ """ + if self._calibration_mode and self._threshold_trials: + return self._triton_calibration_context(module) + return self._triton_inference_context(module) + + @contextmanager + def _triton_inference_context(self, module): + """Inference: activate skip-softmax with calibrated or fixed threshold.""" + module._apply_skip_softmax = True + + # Set threshold on Triton backends + threshold = self._get_effective_threshold(module) + self._set_triton_backends(threshold=threshold) + with self._get_diffusers_backend_context(): + try: + yield + finally: + module._apply_skip_softmax = False + self._clear_triton_backends() - @contextmanager - def _skip_softmax_context(): - module._apply_skip_softmax = True + @contextmanager + def _triton_calibration_context(self, module): + """Calibration: collect multi-threshold sparsity stats via Triton kernel.""" + module._apply_skip_softmax = True + self._set_triton_backends(calibration_mode=True, threshold_trials=self._threshold_trials) + with self._get_diffusers_backend_context(): try: yield + # After forward pass, extract counters and build stats + self._collect_calibration_stats(module) finally: module._apply_skip_softmax = False + self._clear_triton_backends() + + def _get_effective_threshold(self, module) -> float: + """Compute threshold from calibration params or use fixed value.""" + if self.calibration_params and self.target_sparse_ratio: + import math + + params = self.calibration_params.get("prefill", {}) + a = params.get("a", 0) + b = params.get("b", 0) + target = self.target_sparse_ratio.get("prefill", 0.5) + if a > 0 and b > 0: + # scale_factor = a * exp(b * target_sparsity) + # threshold = scale_factor / seqlen + # For diffusion with fixed seqlen, use a representative value. + # The actual seqlen adaptation happens at the kernel level. + scale_factor = a * math.exp(b * target) + # Use a default seqlen estimate; the kernel threshold is in + # absolute space so we just pass the raw threshold. 
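+                # For illustration, with hypothetical calibrated values a=0.02,
+                # b=3.0 and target=0.5: scale_factor = 0.02 * exp(1.5) ~ 0.0896,
+                # threshold ~ 0.0896 / 4224 ~ 2.1e-5.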
+ return scale_factor / 4224 # TODO: pass actual seqlen at runtime + return self.skip_softmax_threshold + + @staticmethod + @contextmanager + def _get_diffusers_backend_context(): + """Activate the modelopt_triton diffusers backend if registered.""" + try: + from ..kernels.diffusers_triton_attention import get_triton_attention_backend + + with get_triton_attention_backend(): + yield + except (ImportError, RuntimeError): + yield + + def _set_triton_backends(self, **kwargs): + """Set config on both diffusers and LTX Triton backends.""" + try: + from ..kernels.diffusers_triton_attention import set_triton_skip_softmax_config + + set_triton_skip_softmax_config(**kwargs) + except ImportError: + pass + try: + from ..kernels.ltx_triton_attention import set_ltx_triton_context + + set_ltx_triton_context(active=True, **kwargs) + except ImportError: + pass + + def _clear_triton_backends(self): + """Clear config on both Triton backends.""" + try: + from ..kernels.diffusers_triton_attention import clear_triton_skip_softmax_config + + clear_triton_skip_softmax_config() + except ImportError: + pass + try: + from ..kernels.ltx_triton_attention import clear_ltx_triton_context + + clear_ltx_triton_context() + except ImportError: + pass + + def _collect_calibration_stats(self, module): + """Read Triton calibration counters and store as stats on the module.""" + counters = None + + try: + from ..kernels.diffusers_triton_attention import get_calibration_counters + + counters = get_calibration_counters() + except ImportError: + pass + + if counters is None: + try: + from ..kernels.ltx_triton_attention import get_calibration_counters + + counters = get_calibration_counters() + except ImportError: + pass + + if counters is None or self._threshold_trials is None: + return + + # counters: [num_thresholds, 2] — [:, 0]=total, [:, 1]=skipped + total = counters[:, 0].float() + skipped = counters[:, 1].float() + sparsity_list = (skipped / total.clamp(min=1)).tolist() + + # Estimate sample_length from total tiles: + # total_tiles = num_heads * num_q_tiles * num_kv_tiles * batch + # For simplicity, use total[0] as a proxy for sequence length scaling + sample_length = int(total[0].item()) + + module._last_stats = { + "sparsity": sparsity_list, + "sample_length": sample_length, + "phase": "prefill", + } - return _skip_softmax_context() + def get_threshold_info(self) -> dict: + """Get threshold information for debugging/display.""" + if self.calibration_params and self.target_sparse_ratio: + return { + "type": "dynamic_calibrated", + "formula": "threshold = a * exp(b * target_sparsity) / seqlen", + "calibration_params": self.calibration_params, + "target_sparse_ratio": self.target_sparse_ratio, + } + return { + "type": "static", + "value": self.skip_softmax_threshold, + } diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py index c0639ed0b5..c017599f2d 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py @@ -59,9 +59,47 @@ def get_sparse_context(self, module): @contextmanager def _sparse_nm_context(): module._apply_sparse_nm = True - try: - yield - finally: - module._apply_sparse_nm = False + self._set_triton_backends( + sparsity_n=self.sparsity_n, + sparsity_m=self.sparsity_m, + num_sink_tokens=self.num_sink_tokens, + dense_window_size=self.dense_window_size, + ) + with 
self._get_diffusers_backend_context(): + try: + yield + finally: + module._apply_sparse_nm = False + self._clear_triton_backends() return _sparse_nm_context() + + @staticmethod + @contextmanager + def _get_diffusers_backend_context(): + """Activate the modelopt_triton diffusers backend if registered.""" + try: + from ..kernels.diffusers_triton_attention import get_triton_attention_backend + + with get_triton_attention_backend(): + yield + except (ImportError, RuntimeError): + yield + + def _set_triton_backends(self, **kwargs): + """Set config on the diffusers Triton backend.""" + try: + from ..kernels.diffusers_triton_attention import set_triton_skip_softmax_config + + set_triton_skip_softmax_config(**kwargs) + except ImportError: + pass + + def _clear_triton_backends(self): + """Clear config on the diffusers Triton backend.""" + try: + from ..kernels.diffusers_triton_attention import clear_triton_skip_softmax_config + + clear_triton_skip_softmax_config() + except ImportError: + pass diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 599832943d..d26b73f0b4 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -16,7 +16,6 @@ """Dynamic sparse attention registration for HuggingFace models.""" import torch.nn as nn -import transformers from modelopt.torch.opt.dynamic import DynamicModule @@ -112,11 +111,22 @@ def _is_supported_model(model: nn.Module) -> bool: """ # Check for HuggingFace PreTrainedModel try: + import transformers + if isinstance(model, transformers.PreTrainedModel): return True except ImportError: pass + # Check for diffusers ModelMixin + try: + from diffusers.models.modeling_utils import ModelMixin + + if isinstance(model, ModelMixin): + return True + except ImportError: + pass + # Support any PyTorch model with attention modules return isinstance(model, nn.Module) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py new file mode 100644 index 0000000000..5903985b30 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ModelOpt sparse attention backend for vLLM. + +Registers a custom vLLM attention backend that uses the ModelOpt Triton kernel +with paged KV cache support. 
Integration approach: + +- No module replacement — the Attention module stays intact with all its state +- Only ``impl`` is swapped from FlashAttentionImpl to ModelOptSparseAttentionImpl +- KV cache update is handled by vLLM (inherited ``do_kv_cache_update``) +- Only ``forward()`` is overridden to call our Triton kernel for both prefill and decode +""" + +import torch +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, +) + +from modelopt.torch.kernels.triton_fa import attention as triton_attention + + +class ModelOptSparseAttentionImpl(FlashAttentionImpl): + """Attention implementation that uses the ModelOpt Triton kernel. + + Inherits from FlashAttentionImpl to reuse: + - __init__ (all configuration) + - do_kv_cache_update (KV cache writing) + Only overrides forward() to replace the attention computation. + """ + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward with ModelOpt Triton sparse attention kernel.""" + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run + return output.fill_(0) + + num_actual_tokens = attn_metadata.num_actual_tokens + is_prefill = attn_metadata.max_query_len > 1 + + # Unpack paged KV cache: [2, num_blocks, page_size, num_kv_heads, head_dim] + key_cache, value_cache = kv_cache.unbind(0) + page_size = key_cache.shape[1] + + # Per-layer sparse kwargs (set by _replace_attention_impl in the worker) + sparse_kw = getattr(self, "sparse_kw", {}) + + # Prepare metadata for our kernel + q = query[:num_actual_tokens].contiguous() + cu_seqlens_q = attn_metadata.query_start_loc + seq_lens = attn_metadata.seq_lens + batch = seq_lens.shape[0] + + b_start_loc = cu_seqlens_q[:batch] + b_seq_len = cu_seqlens_q[1 : batch + 1] - cu_seqlens_q[:batch] + + # Dummy K/V for paged mode: not used by the kernel (KV are read from + # k_cache/v_cache via block_table), but shape[1] must be num_kv_heads + # so the kernel computes the correct GQA ratio (num_q_heads // num_kv_heads). + k_dummy = torch.empty(0, self.num_kv_heads, self.head_size, device=q.device, dtype=q.dtype) + + # Call ModelOpt Triton kernel with paged KV. + # b_seq_len is the query length (e.g., 6 for prefill, 1 for decode). + # b_seq_len_k is the total KV length including cache (e.g., 6 for first + # prefill, 7/8/... for subsequent decode steps). 
+ triton_out = triton_attention( + q, + k=k_dummy, + v=k_dummy, + # Query metadata + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=attn_metadata.max_query_len, + is_causal=is_prefill, # causal for prefill, non-causal for decode + softmax_scale=self.scale, + # KV metadata + b_start_loc_k=None, # paged mode: KV offsets not needed + b_seq_len_k=seq_lens, # total KV length per sequence + max_input_len_k=attn_metadata.max_seq_len, + # Paged KV cache + k_cache=key_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + v_cache=value_cache, # [num_blocks, page_size, num_kv_heads, head_dim] + block_table=attn_metadata.block_table, # [batch, max_blocks] + page_size=page_size, # tokens per page in the KV cache + **sparse_kw, + ) + + output[:num_actual_tokens] = triton_out + return output + + +class ModelOptSparseAttentionBackend(FlashAttentionBackend): + """Attention backend that uses ModelOpt's sparse Triton kernel. + + Inherits everything from FlashAttentionBackend except get_impl_cls and get_name. + """ + + @staticmethod + def get_name() -> str: + """Return backend name.""" + return "MODELOPT_SPARSE" + + @staticmethod + def get_impl_cls() -> type: + """Return the attention implementation class.""" + return ModelOptSparseAttentionImpl diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py index 1eabdfe358..3b8d9e2b92 100644 --- a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -66,12 +66,13 @@ def collect(self, stats: dict): self.aggregated_stats["total_calls"] += 1 self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) - incoming = stats["sparse_blocks"] - if "sparse_blocks" not in self.aggregated_stats: - self.aggregated_stats["sparse_blocks"] = list(incoming) - else: - for i, val in enumerate(incoming): - self.aggregated_stats["sparse_blocks"][i] += val + incoming = stats.get("sparse_blocks") + if incoming is not None: + if "sparse_blocks" not in self.aggregated_stats: + self.aggregated_stats["sparse_blocks"] = list(incoming) + else: + for i, val in enumerate(incoming): + self.aggregated_stats["sparse_blocks"][i] += val phase = stats.get("phase", "unknown") if phase in self.aggregated_stats["phase_counts"]: @@ -79,14 +80,15 @@ def collect(self, stats: dict): # In calibration mode, store per-sample stats if self.calibration_mode: - self.per_sample_stats.append( - { - "module": self.module_name, - "sparsity": stats.get("sparsity", 0.0), - "sample_length": stats.get("sample_length", 0), - "phase": phase, - } - ) + sample_stat = { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + if "normalized_gaps" in stats: + sample_stat["normalized_gaps"] = stats["normalized_gaps"] + self.per_sample_stats.append(sample_stat) def get_summary(self) -> dict: """Get aggregated statistics summary. 
diff --git a/pyproject.toml b/pyproject.toml index bb8c72a4b4..ee4398e3d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -221,6 +221,7 @@ convention = "google" [tool.ruff.lint.isort] known-first-party = ["modelopt"] +known-third-party = ["vllm"] split-on-trailing-comma = false diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py new file mode 100644 index 0000000000..f342bcd9ad --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for paged KV cache mode of the Triton flash attention kernel.""" + +import pytest +import torch +from conftest import make_qkv, make_varlen_meta + +pytestmark = [ + pytest.mark.filterwarnings("ignore::UserWarning"), + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + pytest.mark.filterwarnings("ignore::DeprecationWarning"), +] + +from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE + +if TRITON_KERNEL_AVAILABLE: + from modelopt.torch.kernels import attention + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _scatter_to_paged_cache(k, v, b_start_loc, b_seq_len, num_kv_heads, head_dim, page_size): + """Scatter contiguous K/V into a paged KV cache + block table. 
+ + Args: + k: [total_kv, num_kv_heads, head_dim] contiguous keys + v: [total_kv, num_kv_heads, head_dim] contiguous values + b_start_loc: [batch] start offsets + b_seq_len: [batch] sequence lengths + num_kv_heads: number of KV heads + head_dim: head dimension + page_size: tokens per page + + Returns: + k_cache: [num_blocks, page_size, num_kv_heads, head_dim] + v_cache: [num_blocks, page_size, num_kv_heads, head_dim] + block_table: [batch, max_blocks_per_seq] + """ + batch = b_seq_len.shape[0] + device = k.device + dtype = k.dtype + + # Calculate blocks needed per sequence + blocks_per_seq = [] + for b in range(batch): + slen = int(b_seq_len[b].item()) + blocks_per_seq.append((slen + page_size - 1) // page_size) + + max_blocks = max(blocks_per_seq) + num_blocks = sum(blocks_per_seq) + + k_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + v_cache = torch.zeros(num_blocks, page_size, num_kv_heads, head_dim, device=device, dtype=dtype) + block_table = torch.zeros(batch, max_blocks, device=device, dtype=torch.int32) + + global_block = 0 + for b in range(batch): + start = int(b_start_loc[b].item()) + slen = int(b_seq_len[b].item()) + for blk in range(blocks_per_seq[b]): + block_table[b, blk] = global_block + tok_start = blk * page_size + tok_end = min(tok_start + page_size, slen) + n_toks = tok_end - tok_start + k_cache[global_block, :n_toks] = k[start + tok_start : start + tok_end] + v_cache[global_block, :n_toks] = v[start + tok_start : start + tok_end] + global_block += 1 + + return k_cache, v_cache, block_table + + +# --------------------------------------------------------------------------- +# Paged KV cache tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") +class TestPagedKV: + """Paged KV cache mode tests — verify paged output matches contiguous.""" + + def test_paged_matches_contiguous(self): + """Paged mode produces same output as contiguous mode with identical data.""" + batch = 2 + seq_len = 128 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(42) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + # Contiguous reference + out_contig = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + + # Build paged cache from the same K/V + locs_k, lens_k = locs, lens + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs_k, lens_k, num_kv_heads, head_dim, page_size + ) + + # Paged mode + out_paged = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_start_loc_k=locs_k, + b_seq_len_k=lens_k, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + def test_paged_no_nan(self): + """Paged mode output is finite.""" + batch = 2 + seq_len = 256 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(55) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out = attention( + q, + k, + v, + locs, + lens, + seq_len, + 
softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + assert not torch.isnan(out).any(), "NaN in paged output" + assert not torch.isinf(out).any(), "Inf in paged output" + + def test_paged_variable_length(self): + """Paged mode works with variable-length sequences.""" + seq_lens = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = sum(seq_lens) + + torch.manual_seed(77) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta(seq_lens) + + # Contiguous reference + out_contig = attention(q, k, v, locs, lens, max(seq_lens), softmax_scale=scale) + + # Paged + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged = attention( + q, + k, + v, + locs, + lens, + max(seq_lens), + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=max(seq_lens), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("page_size", [16, 32, 64]) + def test_paged_different_page_sizes(self, page_size): + """Paged mode works with different page sizes.""" + batch = 2 + seq_len = 128 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(88) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + out_contig = attention(q, k, v, locs, lens, seq_len, softmax_scale=scale) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + torch.testing.assert_close(out_paged, out_contig, rtol=1e-2, atol=1e-2) + + def test_paged_with_sparsity(self): + """Paged mode works with N:M sparsity enabled.""" + batch = 2 + seq_len = 256 + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total = batch * seq_len + + torch.manual_seed(99) + q, k, v = make_qkv(total, num_heads, num_kv_heads, head_dim) + locs, lens = make_varlen_meta([seq_len] * batch) + + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k, v, locs, lens, num_kv_heads, head_dim, page_size + ) + + out_paged_sparse = attention( + q, + k, + v, + locs, + lens, + seq_len, + softmax_scale=scale, + b_seq_len_k=lens, + max_input_len_k=seq_len, + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + sparsity_n=2, + sparsity_m=4, + ) + + assert not torch.isnan(out_paged_sparse).any(), "NaN in paged + sparse output" + assert not torch.isinf(out_paged_sparse).any(), "Inf in paged + sparse output" + assert out_paged_sparse.shape == q.shape + + def test_paged_decode(self): + """Paged mode works for decode (single Q token, long KV context).""" + batch = 2 + seq_lens_k = [64, 128] + num_heads, num_kv_heads, head_dim = 4, 2, 64 + page_size = 16 + scale = 1.0 / (head_dim**0.5) + total_kv = sum(seq_lens_k) + + torch.manual_seed(33) + q_flat = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16) + k_flat = torch.randn(total_kv, num_kv_heads, 
head_dim, device="cuda", dtype=torch.float16) + v_flat = torch.randn(total_kv, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + + b_start_loc_q = torch.arange(batch, device="cuda", dtype=torch.int32) + b_seq_len_q = torch.ones(batch, device="cuda", dtype=torch.int32) + cumsum = [0] + for sl in seq_lens_k: + cumsum.append(cumsum[-1] + sl) + b_start_loc_k = torch.tensor(cumsum[:-1], device="cuda", dtype=torch.int32) + b_seq_len_k = torch.tensor(seq_lens_k, device="cuda", dtype=torch.int32) + + # Build paged cache + k_cache, v_cache, block_table = _scatter_to_paged_cache( + k_flat, v_flat, b_start_loc_k, b_seq_len_k, num_kv_heads, head_dim, page_size + ) + + out = attention( + q_flat, + k_flat, + v_flat, + b_start_loc_q, + b_seq_len_q, + 1, + is_causal=False, + softmax_scale=scale, + b_start_loc_k=b_start_loc_k, + b_seq_len_k=b_seq_len_k, + max_input_len_k=max(seq_lens_k), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_table, + page_size=page_size, + ) + + assert out.shape == q_flat.shape + assert not torch.isnan(out).any(), "NaN in paged decode output" diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py b/tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py new file mode 100644 index 0000000000..bc8539a9a8 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the diffusers WAN sparse attention via the modelopt_triton backend. 
+ +Tests cover: +- Config validation: "triton" backend is accepted; "diffusers_triton" is rejected +- quantize_p is NOT in SparseAttentionAttributeConfig (it moved to quantization.sage_attention) +- triton_sparse_softmax and triton_skip_softmax do NOT have quantize_p attribute +- diffusers_triton_attention: set_triton_skip_softmax_config does NOT touch quantize_p +- diffusers_triton_attention: set_sage_attention_config / clear_sage_attention_config +- clear_triton_skip_softmax_config does NOT reset quantize_p (composability) +""" + +import pytest + +pytest.importorskip("diffusers") + +import torch.nn as nn + +# --------------------------------------------------------------------------- +# Configs (no quantize_p — that's now a quantization feature) +# --------------------------------------------------------------------------- + + +_SPARSE_CFG = { + "sparse_cfg": { + "*": { + "method": "triton_sparse_softmax", + "sparsity_n": 2, + "sparsity_m": 4, + "num_sink_tokens": 0, + "dense_window_size": 0, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + } +} + +_SKIP_CFG = { + "sparse_cfg": { + "*": { + "method": "triton_skip_softmax", + "skip_softmax_threshold": 0.1, + "backend": "triton", + "enable": True, + }, + "default": {"enable": False}, + } +} + + +# --------------------------------------------------------------------------- +# Tests: config validation +# --------------------------------------------------------------------------- + + +class TestTritonBackend: + """Validate that the "triton" backend is accepted and "diffusers_triton" is rejected.""" + + def test_backend_accepted_sparse(self): + from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig + + cfg = SparseAttentionAttributeConfig(backend="triton", method="triton_sparse_softmax") + assert cfg.backend == "triton" + + def test_backend_accepted_skip(self): + from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig + + cfg = SparseAttentionAttributeConfig(backend="triton", method="triton_skip_softmax") + assert cfg.backend == "triton" + + def test_diffusers_triton_backend_rejected(self): + from pydantic import ValidationError + + from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig + + with pytest.raises(ValidationError): + SparseAttentionAttributeConfig(backend="diffusers_triton") + + def test_invalid_backend_still_rejected(self): + from pydantic import ValidationError + + from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig + + with pytest.raises(ValidationError): + SparseAttentionAttributeConfig(backend="unknown_backend") + + def test_quantize_p_not_in_sparse_config(self): + """quantize_p moved to quantization.sage_attention — should not appear in sparse config.""" + from modelopt.torch.sparsity.attention_sparsity.config import SparseAttentionAttributeConfig + + cfg = SparseAttentionAttributeConfig(backend="triton", method="triton_sparse_softmax") + assert not hasattr(cfg, "quantize_p"), ( + "quantize_p should NOT be a field on SparseAttentionAttributeConfig; " + "it belongs to modelopt.torch.quantization.sage_attention" + ) + + +# --------------------------------------------------------------------------- +# Tests: sparse methods do NOT have quantize_p +# --------------------------------------------------------------------------- + + +class TestMethodNoQuantizeP: + """triton_skip_softmax and triton_sparse_softmax must NOT expose quantize_p.""" + + def 
test_skip_softmax_no_quantize_p(self): + from modelopt.torch.sparsity.attention_sparsity.methods.triton_skip_softmax import ( + TritonSkipSoftmaxMethod, + ) + + m = TritonSkipSoftmaxMethod(method_config={"skip_softmax_threshold": 0.05}) + assert not hasattr(m, "quantize_p"), ( + "TritonSkipSoftmaxMethod must not have a quantize_p attribute; " + "NVFP4 quantization is managed by modelopt.torch.quantization.sage_attention" + ) + assert m.skip_softmax_threshold == pytest.approx(0.05) + + def test_sparse_softmax_no_quantize_p(self): + from modelopt.torch.sparsity.attention_sparsity.methods.triton_sparse_softmax import ( + TritonSparseSoftmaxMethod, + ) + + m = TritonSparseSoftmaxMethod(method_config={"sparsity_n": 2, "sparsity_m": 4}) + assert not hasattr(m, "quantize_p"), ( + "TritonSparseSoftmaxMethod must not have a quantize_p attribute; " + "NVFP4 quantization is managed by modelopt.torch.quantization.sage_attention" + ) + assert m.sparsity_n == 2 + + +# --------------------------------------------------------------------------- +# Tests: diffusers_triton_attention thread-local config +# --------------------------------------------------------------------------- + + +class TestDiffusersTritonAttentionConfig: + """set/clear functions for sparse params and sage_attention params work correctly.""" + + def test_set_and_get_threshold(self): + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + _thread_local, + clear_triton_skip_softmax_config, + set_triton_skip_softmax_config, + ) + + set_triton_skip_softmax_config(threshold=0.05) + assert _thread_local.skip_threshold == pytest.approx(0.05) + clear_triton_skip_softmax_config() + + def test_set_triton_skip_softmax_config_no_quantize_p_param(self): + """set_triton_skip_softmax_config must NOT accept quantize_p.""" + import inspect + + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + set_triton_skip_softmax_config, + ) + + sig = inspect.signature(set_triton_skip_softmax_config) + assert "quantize_p" not in sig.parameters, ( + "set_triton_skip_softmax_config must not have a quantize_p parameter; " + "use set_sage_attention_config() instead" + ) + + def test_set_and_get_sparsity_params(self): + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + _thread_local, + clear_triton_skip_softmax_config, + set_triton_skip_softmax_config, + ) + + set_triton_skip_softmax_config( + sparsity_n=2, + sparsity_m=4, + num_sink_tokens=16, + dense_window_size=128, + ) + assert _thread_local.sparsity_n == 2 + assert _thread_local.sparsity_m == 4 + assert _thread_local.num_sink_tokens == 16 + assert _thread_local.dense_window_size == 128 + clear_triton_skip_softmax_config() + assert _thread_local.sparsity_n == 0 + assert _thread_local.sparsity_m == 4 + assert _thread_local.num_sink_tokens == 0 + assert _thread_local.dense_window_size == 64 + + def test_clear_sparse_does_not_reset_quantize_p(self): + """clear_triton_skip_softmax_config must NOT reset quantize_p. + + This is the key composability guarantee: SageAttention sets quantize_p=True + once for the whole transformer forward; each per-layer sparsity context + can clear its own sparse params without clobbering the outer quantize_p. 
+ """ + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + _thread_local, + clear_sage_attention_config, + clear_triton_skip_softmax_config, + set_sage_attention_config, + set_triton_skip_softmax_config, + ) + + # SageAttention outer wrapper sets quantize_p=True + set_sage_attention_config(quantize_p=True) + assert _thread_local.quantize_p is True + + # Sparsity per-layer context sets threshold then clears + set_triton_skip_softmax_config(threshold=0.1) + clear_triton_skip_softmax_config() + + # quantize_p must survive the sparsity clear + assert _thread_local.quantize_p is True, ( + "clear_triton_skip_softmax_config() must not reset quantize_p; " + "SageAttention controls quantize_p independently" + ) + + # SageAttention outer wrapper clears quantize_p + clear_sage_attention_config() + assert _thread_local.quantize_p is False + + def test_set_sage_attention_config(self): + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + _thread_local, + clear_sage_attention_config, + set_sage_attention_config, + ) + + set_sage_attention_config(quantize_p=True) + assert _thread_local.quantize_p is True + clear_sage_attention_config() + assert _thread_local.quantize_p is False + + +# --------------------------------------------------------------------------- +# Tests: apply_sage_attention API +# --------------------------------------------------------------------------- + + +class TestApplySageAttention: + """apply_sage_attention wraps the transformer forward and marks the module.""" + + def _make_dummy_transformer(self): + """A minimal nn.Module whose forward returns a tensor.""" + + class DummyTransformer(nn.Module): + def forward(self, x): + return x * 2 + + return DummyTransformer() + + def test_marks_transformer(self): + """apply_sage_attention sets _modelopt_sage_attention=True on the module.""" + pytest.importorskip("triton") + + from modelopt.torch.quantization import apply_sage_attention + + model = self._make_dummy_transformer() + assert not hasattr(model, "_modelopt_sage_attention") + apply_sage_attention(model) + assert getattr(model, "_modelopt_sage_attention", False) is True + + def test_wraps_forward(self): + """apply_sage_attention replaces forward with a wrapper function.""" + pytest.importorskip("triton") + + from modelopt.torch.quantization import apply_sage_attention + + model = self._make_dummy_transformer() + original = model.forward + apply_sage_attention(model) + assert model.forward is not original + + def test_import_from_mtq(self): + """apply_sage_attention is accessible via modelopt.torch.quantization.""" + import modelopt.torch.quantization as mtq + + assert hasattr(mtq, "apply_sage_attention") + assert callable(mtq.apply_sage_attention) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py b/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py new file mode 100644 index 0000000000..b8685b8410 --- /dev/null +++ b/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for diffusers kernel backends and thread-local context.""" + +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn + + +def _mock_diffusers(): + """Mock diffusers.models.attention_dispatch for testing without real diffusers.""" + m = types.ModuleType("diffusers.models.attention_dispatch") + + class FakeBackendName(str): + _member_map_: dict = {} + _value2member_map_: dict = {} + + m.AttentionBackendName = FakeBackendName + + class FakeReg: + _backends: dict = {} + _constraints: dict = {} + _supported_arg_names: dict = {} + + m._AttentionBackendRegistry = FakeReg + m.attention_backend = MagicMock() + return { + "diffusers": types.ModuleType("diffusers"), + "diffusers.models": types.ModuleType("diffusers.models"), + "diffusers.models.attention_dispatch": m, + } + + +# --------------------------------------------------------------------------- +# Tests: thread-local skip-softmax context +# --------------------------------------------------------------------------- + + +class TestSkipSoftmaxContext: + def test_default_is_false(self): + from modelopt.torch.sparsity.attention_sparsity.kernels import get_skip_softmax_context + + assert get_skip_softmax_context() is False + + def test_set_and_get(self): + from modelopt.torch.sparsity.attention_sparsity.kernels import ( + get_skip_softmax_context, + set_skip_softmax_context, + ) + + set_skip_softmax_context(True) + assert get_skip_softmax_context() is True + set_skip_softmax_context(False) + assert get_skip_softmax_context() is False + + +# --------------------------------------------------------------------------- +# Tests: diffusers eager attention +# --------------------------------------------------------------------------- + + +class TestDiffusersEagerAttention: + @pytest.fixture(autouse=True) + def _setup(self): + with patch.dict(sys.modules, _mock_diffusers()): + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention import ( + _diffusers_eager_attention, + get_skip_softmax_attention_backend, + register_diffusers_eager_attention, + ) + + self._fn = _diffusers_eager_attention + self._register = register_diffusers_eager_attention + self._get_backend = get_skip_softmax_attention_backend + + import modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention as mod + + mod._BACKEND_REGISTERED = False + yield + + def test_basic_shape(self): + q = torch.randn(2, 8, 4, 16) + assert self._fn(q, q, q).shape == (2, 8, 4, 16) + + def test_cross_attention(self): + q = torch.randn(1, 4, 2, 8) + k = torch.randn(1, 12, 2, 8) + assert self._fn(q, k, k).shape == (1, 4, 2, 8) + + def test_causal(self): + q = torch.randn(1, 4, 1, 8) + assert self._fn(q, q, q, is_causal=True).shape == (1, 4, 1, 8) + + def test_gqa(self): + q = torch.randn(1, 4, 8, 16) + k = torch.randn(1, 4, 2, 16) + assert self._fn(q, k, k, enable_gqa=True).shape == (1, 4, 8, 16) + + def test_register_idempotent(self): + self._register() + self._register() + + def test_get_backend_before_register_raises(self): + with pytest.raises(RuntimeError, match="not 
registered"): + self._get_backend() + + +# --------------------------------------------------------------------------- +# Tests: diffusers triton attention +# --------------------------------------------------------------------------- + + +class TestDiffusersTritonAttention: + @pytest.fixture(autouse=True) + def _setup(self): + mocks = _mock_diffusers() + mk = types.ModuleType("modelopt.torch.kernels") + mk.attention = lambda q, k, v, **kw: q + mk.IS_AVAILABLE = True + mk.register_triton_attention = None + mocks["modelopt.torch.kernels"] = mk + + with patch.dict(sys.modules, mocks): + from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + _diffusers_triton_attention, + clear_triton_skip_softmax_config, + get_triton_attention_backend, + register_diffusers_triton_attention, + set_triton_skip_softmax_config, + ) + + self._fn = _diffusers_triton_attention + self._set = set_triton_skip_softmax_config + self._clear = clear_triton_skip_softmax_config + self._register = register_diffusers_triton_attention + self._get_backend = get_triton_attention_backend + + import modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention as mod + + mod._BACKEND_REGISTERED = False + yield + + def test_set_clear_config(self): + self._set(threshold=0.1) + self._clear() + + def test_register_idempotent(self): + self._register() + self._register() + + def test_get_backend_before_register_raises(self): + with pytest.raises(RuntimeError, match="not registered"): + self._get_backend() + + +# --------------------------------------------------------------------------- +# Tests: conversion.py _register_diffusers_backends_if_needed +# --------------------------------------------------------------------------- + + +class TestRegisterDiffusersBackends: + def test_no_diffusers_no_error(self): + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + _register_diffusers_backends_if_needed, + ) + + _register_diffusers_backends_if_needed(nn.Linear(10, 10)) + + def test_with_diffusers_model(self): + from modelopt.torch.sparsity.attention_sparsity.conversion import ( + _register_diffusers_backends_if_needed, + ) + + mock_mixin = type("ModelMixin", (nn.Module,), {}) + mock_utils = types.ModuleType("diffusers.models.modeling_utils") + mock_utils.ModelMixin = mock_mixin + + with ( + patch.dict(sys.modules, {"diffusers.models.modeling_utils": mock_utils}), + patch( + "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_eager_attention", + MagicMock(), + ) as mock_eager, + patch( + "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_triton_attention", + MagicMock(), + ) as mock_triton, + ): + _register_diffusers_backends_if_needed(mock_mixin()) + mock_eager.assert_called_once() + mock_triton.assert_called_once()