
 __all__ = [
     "_native_attention",
+    "_sdpa_cudnn_attention",
 ]

 # Enable custom native attention backend with context parallelism
@@ -52,25 +53,33 @@ def _is_native_attn_supported_context_parallel() -> bool:
     )


-if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
-    logger.warning(
-        "Re-registering NATIVE attention backend to enable context parallelism. "
-        "This is a temporary workaround and should be removed after the native "
-        "attention backend supports context parallelism natively. Please check: "
-        "https://github.com/huggingface/diffusers/pull/12563 for more details. "
-        "Or, you can disable this behavior by setting the environment variable "
-        "`CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0`."
-    )
-    _AttentionBackendRegistry._backends.pop(AttentionBackendName.NATIVE)
-    _AttentionBackendRegistry._constraints.pop(AttentionBackendName.NATIVE)
-    _AttentionBackendRegistry._supported_arg_names.pop(AttentionBackendName.NATIVE)
+def _registry_pop_attn_backend(attn_backend: AttentionBackendName):
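+    # Drop the backend from every registry table so it can be re-registered below.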
+    _AttentionBackendRegistry._backends.pop(attn_backend)
+    _AttentionBackendRegistry._constraints.pop(attn_backend)
+    _AttentionBackendRegistry._supported_arg_names.pop(attn_backend)
     if _is_native_attn_supported_context_parallel():
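+        # _supports_context_parallel is a dict in some diffusers versions and a
+        # plain collection of backend values in others; handle both layouts.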
         if isinstance(_AttentionBackendRegistry._supports_context_parallel, dict):
-            _AttentionBackendRegistry._supports_context_parallel.pop(AttentionBackendName.NATIVE)
+            _AttentionBackendRegistry._supports_context_parallel.pop(attn_backend)
         else:
-            _AttentionBackendRegistry._supports_context_parallel.remove(
-                AttentionBackendName.NATIVE.value
-            )
+            _AttentionBackendRegistry._supports_context_parallel.remove(attn_backend.value)
+
+
+def _set_new_attn_backend(member: str, value: str):
+    # e.g., _set_new_attn_backend("_SDPA_CUDNN", "_sdpa_cudnn")
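+    # Inject a new member into the AttentionBackendName enum at runtime so the
+    # registry can look the new backend up by name and by value like a built-in one.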
+    new_member = str.__new__(AttentionBackendName, value)
+    new_member._name_ = member
+    new_member._value_ = value
+    setattr(AttentionBackendName, member, new_member)
+    AttentionBackendName._member_map_[member] = new_member
+    AttentionBackendName._member_names_.append(member)
+    AttentionBackendName._value2member_map_[value] = new_member
+
+
+if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
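+    # Forward ops that are allowed to receive an explicit attn_mask under context parallelism.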
+    _ATTENTION_OPS_ALLOW_ATTN_MASK = [
+        "_native_attention_forward_op",
+        "_sdpa_cudnn_attention_forward_op",
+    ]

     # Re-define templated context parallel attention to support attn mask
     def _templated_context_parallel_attention_v2(
@@ -91,7 +100,7 @@ def _templated_context_parallel_attention_v2(
         if attn_mask is not None:
             # NOTE(DefTruth): Check if forward_op is native attention forward op
             forward_op_name = forward_op.__name__
-            if not forward_op_name == "_native_attention_forward_op":
+            if forward_op_name not in _ATTENTION_OPS_ALLOW_ATTN_MASK:
                 raise ValueError(
                     "Templated context parallel attention with attn_mask "
                     "is only supported for native attention backend, "
@@ -239,6 +248,9 @@ def _native_attention_backward_op(

         return grad_query, grad_key, grad_value

+    # Re-register NATIVE attention backend to allow attn mask while using context parallelism
+    _registry_pop_attn_backend(AttentionBackendName.NATIVE)
+
     @_AttentionBackendRegistry.register(
         AttentionBackendName.NATIVE,
         constraints=[_check_device, _check_shape],
@@ -288,9 +300,130 @@ def _native_attention(
             )
         return out

+    logger.warning(
+        "Re-registered NATIVE attention backend to enable context parallelism "
+        "with attn mask. You can disable this behavior by setting the environment "
+        "variable: export CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0."
+    )
+
+    def _sdpa_cudnn_attention_forward_op(
+        ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+        _save_ctx: bool = True,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ):
+        # This backend does not support returning the log-sum-exp (LSE)
+        if return_lse:
+            raise ValueError("cudnn attention with sdpa does not support return_lse=True")
+
+        # used for backward pass
+        if _save_ctx:
+            ctx.save_for_backward(query, key, value)
+            ctx.attn_mask = attn_mask
+            ctx.dropout_p = dropout_p
+            ctx.is_causal = is_causal
+            ctx.scale = scale
+            ctx.enable_gqa = enable_gqa
+
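+        # SDPA expects (batch, heads, seq_len, head_dim); inputs arrive as
+        # (batch, seq_len, heads, head_dim), so permute before and after the call.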
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+            out = torch.nn.functional.scaled_dot_product_attention(
+                query=query,
+                key=key,
+                value=value,
+                attn_mask=attn_mask,
+                dropout_p=dropout_p,
+                is_causal=is_causal,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            )
+        out = out.permute(0, 2, 1, 3)
+
+        return out
+
+    def _sdpa_cudnn_attention_backward_op(
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_out: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
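+        # Only reached when gradients flow through context-parallel attention (training).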
+        raise NotImplementedError("Backward for cudnn attention with sdpa is not implemented yet.")
+
+    # Register _sdpa_cudnn_attention backend to allow attn mask while using context parallelism
+    _set_new_attn_backend("_SDPA_CUDNN", "_sdpa_cudnn")
+    assert hasattr(AttentionBackendName, "_SDPA_CUDNN")
+
+    @_AttentionBackendRegistry.register(
+        AttentionBackendName._SDPA_CUDNN,  # type: AttentionBackendName
+        constraints=[_check_device, _check_shape],
+        supports_context_parallel=True,
+    )
+    def _sdpa_cudnn_attention(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ) -> torch.Tensor:
+        lse = None
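+        # No context parallelism and no LSE requested: run cuDNN SDPA directly on this rank.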
+        if _parallel_config is None and not return_lse:
+            query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
+            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+                out = torch.nn.functional.scaled_dot_product_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    attn_mask=attn_mask,
+                    dropout_p=dropout_p,
+                    is_causal=is_causal,
+                    scale=scale,
+                    enable_gqa=enable_gqa,
+                )
+            out = out.permute(0, 2, 1, 3)
+        else:
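+            # Context-parallel (or LSE) path: dispatch through the templated attention with the cuDNN ops.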
+            out = _templated_context_parallel_attention_v2(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op=_sdpa_cudnn_attention_forward_op,
+                backward_op=_sdpa_cudnn_attention_backward_op,
+                _parallel_config=_parallel_config,
+            )
+            if return_lse:
+                out, lse = out
+
+        return (out, lse) if return_lse else out
+
+    logger.info(
+        "Registered new attention backend: _SDPA_CUDNN, to enable "
+        "context parallelism with attn mask. You can disable it by setting: "
+        "export CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0."
+    )
+
 else:
     from diffusers.models.attention_dispatch import (
         _native_attention,
     )  # noqa: F401

+    _sdpa_cudnn_attention = None  # type: ignore[assignment]
+
     logger.info("Native attention backend already supports context parallelism.")