Commit 8bd89ba

feat: support _sdpa_cudnn backend for cp (#504)
* feat: fast rope for z-image
* chore: update notes
* chore: update z-image cp example
* feat: allow cudnn attn w/ attn mask for cp
* feat: support _sdpa_cudnn backend for cp
1 parent: 9a5cdd9 · commit: 8bd89ba

File tree: 5 files changed, +205 −26 lines


examples/parallelism/run_zimage_cp.py

Lines changed: 34 additions & 5 deletions
@@ -18,9 +18,9 @@
 
 import cache_dit
 
-# NOTE: Only support context parallelism with 'native' attention backend
-# for ZImage due to the attention mask in ZImage is not None. Please use:
-# --parallel ulysses --attn native
+# NOTE: Context parallelism is only supported with the 'native'/'_sdpa_cudnn' attn
+# backends for Z-Image, because the attention mask in Z-Image is not None. Please use:
+# `--parallel ulysses --attn native` or `--attn _sdpa_cudnn`.
 
 args = get_args()
 print(args)
@@ -59,12 +59,34 @@
 # Only warmup 4 steps (total 9 steps) for distilled models
 args.max_warmup_steps = min(4, args.max_warmup_steps)
 
-cachify(args, pipe)
+cachify(
+    args,
+    pipe,
+    # Total 9 steps for distilled Z-Image-Turbo,
+    # e.g., 111110101; 1: compute, 0: dynamic cache
+    steps_computation_mask=(
+        cache_dit.steps_mask(
+            compute_bins=[5, 1, 1],  # 7 steps compute
+            cache_bins=[1, 1],  # max 2 steps cache
+        )
+        if args.steps_mask
+        else None
+    ),
+)
 
 pipe.to(device)
 
 assert isinstance(pipe.transformer, ZImageTransformer2DModel)
 
+# Allow customizing the attention backend for single-GPU inference
+if args.parallel_type is None:
+    # native, flash, _native_cudnn, sage, etc.
+    # _native_cudnn is faster than native (sdpa) on NVIDIA L20 with CUDA 12.9+.
+    # '_sdpa_cudnn' exists only in cache-dit, to support context parallelism
+    # with attn masks, e.g., Z-Image. It is not in diffusers yet.
+    if args.attn is not None:
+        pipe.transformer.set_attention_backend(args.attn)
+
 pipe.set_progress_bar_config(disable=rank != 0)
 
 # Set default prompt
@@ -94,7 +116,14 @@ def run_pipe(warmup: bool = False):
 
 if args.compile:
     cache_dit.set_compile_configs()
-    pipe.transformer = torch.compile(pipe.transformer)
+    if args.compile_repeated_blocks:
+        pipe.transformer.compile_repeated_blocks(
+            mode="max-autotune-no-cudagraphs" if args.max_autotune else "default"
+        )
+    else:
+        pipe.transformer = torch.compile(
+            pipe.transformer, mode="max-autotune-no-cudagraphs" if args.max_autotune else "default"
+        )
 
 # warmup
 _ = run_pipe(warmup=True)
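A note on the `steps_computation_mask` comment in the hunk above: the pattern `111110101` can be read as interleaving the compute bins with the cache bins, step by step. The sketch below only illustrates that reading on the values from the example; it is an assumption about the semantics, not the actual implementation of `cache_dit.steps_mask`, and `expand_bins` is a hypothetical helper.

# Illustration only: a hypothetical expand_bins() showing how the bins used in the
# example above could yield the "111110101" pattern mentioned in the comment.
def expand_bins(compute_bins, cache_bins):
    mask = ""
    for i, n_compute in enumerate(compute_bins):
        mask += "1" * n_compute  # 1: compute this step
        if i < len(cache_bins):
            mask += "0" * cache_bins[i]  # 0: serve this step from the dynamic cache
    return mask


print(expand_bins(compute_bins=[5, 1, 1], cache_bins=[1, 1]))  # 111110101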

examples/utils.py

Lines changed: 9 additions & 0 deletions
@@ -61,6 +61,8 @@ def get_args(
     parser = argparse.ArgumentParser()
     parser.add_argument("--cache", action="store_true", default=False)
     parser.add_argument("--compile", action="store_true", default=False)
+    parser.add_argument("--compile-repeated-blocks", action="store_true", default=False)
+    parser.add_argument("--max-autotune", action="store_true", default=False)
     parser.add_argument("--fuse-lora", action="store_true", default=False)
     parser.add_argument("--steps", type=int, default=None)
     parser.add_argument("--Fn", type=int, default=8)
@@ -72,6 +74,7 @@ def get_args(
     parser.add_argument("--max-continuous-cached-steps", "--mcc", type=int, default=-1)
     parser.add_argument("--taylorseer", action="store_true", default=False)
    parser.add_argument("--taylorseer-order", "-order", type=int, default=1)
+    parser.add_argument("--steps-mask", "--scm", action="store_true", default=False)
     parser.add_argument("--height", type=int, default=None)
     parser.add_argument("--width", type=int, default=None)
     parser.add_argument("--quantize", "-q", action="store_true", default=False)
@@ -113,6 +116,9 @@ def get_args(
             # Based on this fix: https://github.com/huggingface/diffusers/pull/12563
             "native",  # native pytorch attention: sdpa
             "_native_cudnn",
+            # '_sdpa_cudnn' is only in cache-dit, to support context parallelism
+            # with attn masks, e.g., ZImage. It is not in diffusers yet.
+            "_sdpa_cudnn",
             "sage",  # Need install sageattention: https://github.com/thu-ml/SageAttention
         ],
     )
@@ -220,6 +226,7 @@ def cachify(
             max_continuous_cached_steps=args.max_continuous_cached_steps,
             residual_diff_threshold=args.rdt,
             enable_separate_cfg=kwargs.get("enable_separate_cfg", None),
+            steps_computation_mask=kwargs.get("steps_computation_mask", None),
         )
         if cache_config is None and args.cache
         else cache_config
@@ -262,6 +269,8 @@ def strify(args, pipe_or_stats):
         base_str += "_ulysses_anything"
     if args.ulysses_async_qkv_proj:
         base_str += "_ulysses_async_qkv_proj"
+    if args.attn is not None:
+        base_str += f"_{args.attn.strip('_')}"
     return base_str
 
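For reference, argparse maps the new dashed flags to underscored attribute names, which is how `args.compile_repeated_blocks`, `args.max_autotune`, `args.steps_mask`, and `args.attn` are read back in the Z-Image example above. The sketch below is trimmed down and partly assumed: it keeps only the options touched by this commit, and the `--attn` definition (type and choices) is abbreviated relative to the real `get_args()`.

import argparse

# Trimmed-down sketch of the flags touched in this commit; the real parser in
# examples/utils.py defines many more options, and --attn has a longer choices list.
parser = argparse.ArgumentParser()
parser.add_argument("--compile", action="store_true", default=False)
parser.add_argument("--compile-repeated-blocks", action="store_true", default=False)
parser.add_argument("--max-autotune", action="store_true", default=False)
parser.add_argument("--steps-mask", "--scm", action="store_true", default=False)
parser.add_argument(
    "--attn",
    type=str,
    default=None,
    choices=["native", "_native_cudnn", "_sdpa_cudnn", "sage"],
)

args = parser.parse_args(["--compile", "--steps-mask", "--attn", "_sdpa_cudnn"])
print(args.compile_repeated_blocks, args.steps_mask, args.attn)
# False True _sdpa_cudnn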

src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 def maybe_resigter_native_attention_backend():
     """Maybe re-register native attention backend to enable context parallelism."""
     # Import custom attention backend ensuring registration
-    from ._attention_dispatch import _native_attention
+    from ._attention_dispatch import _native_attention, _sdpa_cudnn_attention
 
 
     from ._templated_ulysses_anything import (

src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py

Lines changed: 150 additions & 17 deletions
@@ -28,6 +28,7 @@
 
 __all__ = [
     "_native_attention",
+    "_sdpa_cudnn_attention",
 ]
 
 # Enable custom native attention backend with context parallelism
@@ -52,25 +53,33 @@ def _is_native_attn_supported_context_parallel() -> bool:
     )
 
 
-if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
-    logger.warning(
-        "Re-registering NATIVE attention backend to enable context parallelism. "
-        "This is a temporary workaround and should be removed after the native "
-        "attention backend supports context parallelism natively. Please check: "
-        "https://github.com/huggingface/diffusers/pull/12563 for more details. "
-        "Or, you can disable this behavior by setting the environment variable "
-        "`CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0`."
-    )
-    _AttentionBackendRegistry._backends.pop(AttentionBackendName.NATIVE)
-    _AttentionBackendRegistry._constraints.pop(AttentionBackendName.NATIVE)
-    _AttentionBackendRegistry._supported_arg_names.pop(AttentionBackendName.NATIVE)
+def _registry_pop_attn_backend(attn_backend: AttentionBackendName):
+    _AttentionBackendRegistry._backends.pop(attn_backend)
+    _AttentionBackendRegistry._constraints.pop(attn_backend)
+    _AttentionBackendRegistry._supported_arg_names.pop(attn_backend)
     if _is_native_attn_supported_context_parallel():
         if isinstance(_AttentionBackendRegistry._supports_context_parallel, dict):
-            _AttentionBackendRegistry._supports_context_parallel.pop(AttentionBackendName.NATIVE)
+            _AttentionBackendRegistry._supports_context_parallel.pop(attn_backend)
         else:
-            _AttentionBackendRegistry._supports_context_parallel.remove(
-                AttentionBackendName.NATIVE.value
-            )
+            _AttentionBackendRegistry._supports_context_parallel.remove(attn_backend.value)
+
+
+def _set_new_attn_backend(member: str, value: str):
+    # e.g., _set_new_attn_backend("_SDPA_CUDNN", "_sdpa_cudnn")
+    new_member = str.__new__(AttentionBackendName, value)
+    new_member._name_ = member
+    new_member._value_ = value
+    setattr(AttentionBackendName, member, new_member)
+    AttentionBackendName._member_map_[member] = new_member
+    AttentionBackendName._member_names_.append(member)
+    AttentionBackendName._value2member_map_[value] = new_member
+
+
+if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
+    _ATTENTION_OPS_ALLOW_ATTN_MASK = [
+        "_native_attention_forward_op",
+        "_sdpa_cudnn_attention_forward_op",
+    ]
 
     # Re-define templated context parallel attention to support attn mask
     def _templated_context_parallel_attention_v2(
@@ -91,7 +100,7 @@ def _templated_context_parallel_attention_v2(
         if attn_mask is not None:
             # NOTE(DefTruth): Check if forward_op is native attention forward op
             forward_op_name = forward_op.__name__
-            if not forward_op_name == "_native_attention_forward_op":
+            if forward_op_name not in _ATTENTION_OPS_ALLOW_ATTN_MASK:
                 raise ValueError(
                     "Templated context parallel attention with attn_mask "
                     "is only supported for native attention backend, "
@@ -239,6 +248,9 @@ def _native_attention_backward_op(
 
         return grad_query, grad_key, grad_value
 
+    # Re-register the NATIVE attention backend to allow attn mask while using context parallelism
+    _registry_pop_attn_backend(AttentionBackendName.NATIVE)
+
     @_AttentionBackendRegistry.register(
         AttentionBackendName.NATIVE,
         constraints=[_check_device, _check_shape],
@@ -288,9 +300,130 @@ def _native_attention(
             )
         return out
 
+    logger.warning(
+        "Re-registered NATIVE attention backend to enable context parallelism "
+        "with attn mask. You can disable this behavior by exporting: "
+        "CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0."
+    )
+
+    def _sdpa_cudnn_attention_forward_op(
+        ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+        _save_ctx: bool = True,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ):
+        # cuDNN attention via SDPA cannot return the log-sum-exp
+        if return_lse:
+            raise ValueError("cudnn attention with sdpa does not support return_lse=True")
+
+        # used for backward pass
+        if _save_ctx:
+            ctx.save_for_backward(query, key, value)
+            ctx.attn_mask = attn_mask
+            ctx.dropout_p = dropout_p
+            ctx.is_causal = is_causal
+            ctx.scale = scale
+            ctx.enable_gqa = enable_gqa
+
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+            out = torch.nn.functional.scaled_dot_product_attention(
+                query=query,
+                key=key,
+                value=value,
+                attn_mask=attn_mask,
+                dropout_p=dropout_p,
+                is_causal=is_causal,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            )
+        out = out.permute(0, 2, 1, 3)
+
+        return out
+
+    def _sdpa_cudnn_attention_backward_op(
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_out: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        raise NotImplementedError("Backward for cudnn attention with sdpa is not implemented yet.")
+
+    # Register the _sdpa_cudnn attention backend to allow attn mask while using context parallelism
+    _set_new_attn_backend("_SDPA_CUDNN", "_sdpa_cudnn")
+    assert hasattr(AttentionBackendName, "_SDPA_CUDNN")
+
+    @_AttentionBackendRegistry.register(
+        AttentionBackendName._SDPA_CUDNN,  # type: AttentionBackendName
+        constraints=[_check_device, _check_shape],
+        supports_context_parallel=True,
+    )
+    def _sdpa_cudnn_attention(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ) -> torch.Tensor:
+        lse = None
+        if _parallel_config is None and not return_lse:
+            query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
+            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+                out = torch.nn.functional.scaled_dot_product_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    attn_mask=attn_mask,
+                    dropout_p=dropout_p,
+                    is_causal=is_causal,
+                    scale=scale,
+                    enable_gqa=enable_gqa,
+                )
+            out = out.permute(0, 2, 1, 3)
+        else:
+            out = _templated_context_parallel_attention_v2(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op=_sdpa_cudnn_attention_forward_op,
+                backward_op=_sdpa_cudnn_attention_backward_op,
+                _parallel_config=_parallel_config,
+            )
+            if return_lse:
+                out, lse = out
+
+        return (out, lse) if return_lse else out
+
+    logger.info(
+        "Registered new attention backend: _SDPA_CUDNN, to enable "
+        "context parallelism with attn mask. You can disable it by exporting: "
+        "CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0."
+    )
+
 else:
     from diffusers.models.attention_dispatch import (
         _native_attention,
     )  # noqa: F401
 
+    _sdpa_cudnn_attention = None  # type: ignore[assignment]
+
     logger.info("Native attention backend already supports context parallelism.")

src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_zimage.py

Lines changed: 11 additions & 3 deletions
@@ -3,6 +3,7 @@
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers import ZImageTransformer2DModel
 
+
 try:
     from diffusers.models._modeling_parallel import (
         ContextParallelInput,
@@ -51,9 +52,7 @@ def apply(
         # hooks in each block/layer in the initialization of DBCache.
         # Issue: https://github.com/vipshop/cache-dit/issues/498
         maybe_patch_cp_find_submodule_by_name()
-        # Otherwise, use the custom CP plan defined here, this maybe
-        # a little different from the native diffusers implementation
-        # for some models.
+        # TODO: Patch the rotary embedding function to avoid complex number ops
         n_noise_refiner_layers = len(transformer.noise_refiner)  # 2
         n_context_refiner_layers = len(transformer.context_refiner)  # 2
         # num_layers = len(transformer.layers)  # 30
@@ -93,3 +92,12 @@ def apply(
             # f"layers.{num_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3),
         }
         return _cp_plan
+
+
+# TODO: The original implementation uses complex numbers, which torch.compile does not support yet.
+# References:
+# - https://github.com/triple-Mu/Z-Image-TensorRT/blob/4efc5749e9a0d22344e6c4b8a09d2223dd0a7e17/step_by_step/2-remove-complex-op.py#L26C1-L36C25
+# - https://github.com/huggingface/diffusers/pull/12725
+
+
+# TODO: Support Async Ulysses QKV projection for Z-Image
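On the complex-number TODO above: the usual workaround described by the linked references is to carry `cos`/`sin` tensors instead of a complex `freqs_cis` and apply the rotation with real arithmetic, which torch.compile can trace. The sketch below only demonstrates that the two formulations agree on dummy tensors; the function names, shapes, and pairing convention are illustrative assumptions, not the patch cache-dit will ship.

import torch


def rope_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Complex-number formulation (what the TODO refers to): not torch.compile friendly.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).flatten(-2).type_as(x)


def rope_real(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Equivalent real-arithmetic form: rotate each (even, odd) pair by the same angle.
    x0, x1 = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    out = torch.stack((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)
    return out.flatten(-2).type_as(x)


# Quick numerical check of the equivalence on dummy tensors.
x = torch.randn(1, 4, 8)    # (batch, seq, head_dim)
angles = torch.randn(4, 4)  # one angle per (position, pair)
freqs_cis = torch.polar(torch.ones_like(angles), angles)
torch.testing.assert_close(
    rope_complex(x, freqs_cis), rope_real(x, angles.cos(), angles.sin())
)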
