Fix QwenImage txt_seq_lens handling #12702
Open
kashif wants to merge 47 commits into huggingface:main from kashif:txt_seq_lens
+506 −163
Commits (47)
- b547fcf: Fix QwenImage txt_seq_lens handling (kashif)
- 72a80c6: formatting (kashif)
- 88cee8b: formatting (kashif)
- ac5ac24: remove txt_seq_lens and use bool mask (kashif)
- 0477526: Merge branch 'main' into txt_seq_lens (kashif)
- 18efdde: use compute_text_seq_len_from_mask (kashif)
- 6a549d4: add seq_lens to dispatch_attention_fn (kashif)
- 2d424e0: use joint_seq_lens (kashif)
- 30b5f98: remove unused index_block (kashif)
- 588dc04: Merge branch 'main' into txt_seq_lens (kashif)
- f1c2d99: WIP: Remove seq_lens parameter and use mask-based approach (kashif)
- ec52417: Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in… (kashif)
- beeb020: fix formatting (kashif)
- 5c6f8e3: undo sage changes (kashif)
- 5d434f6: xformers support (kashif)
- 71ba603: hub fix (kashif)
- babf490: Merge branch 'main' into txt_seq_lens (kashif)
- afad335: fix torch compile issues (kashif)
- 2d5ab16: Merge branch 'main' into txt_seq_lens (sayakpaul)
- c78a1e9: fix tests (kashif)
- d6d4b1d: use _prepare_attn_mask_native (kashif)
- e999b76: proper deprecation notice (kashif)
- 8115f0b: add deprecate to txt_seq_lens (kashif)
- 3b1510c: Update src/diffusers/models/transformers/transformer_qwenimage.py (kashif)
- 3676d8e: Update src/diffusers/models/transformers/transformer_qwenimage.py (kashif)
- 9ed0ffd: Only create the mask if there's actual padding (kashif)
- abec461: Merge branch 'main' into txt_seq_lens (kashif)
- e26e7b3: fix order of docstrings (kashif)
- 59e3882: Adds performance benchmarks and optimization details for QwenImage (cdutr)
- 0cb2138: Merge branch 'main' into txt_seq_lens (kashif)
- 60bd454: rope_text_seq_len = text_seq_len (kashif)
- a5abbb8: rename to max_txt_seq_len (kashif)
- 8415c57: Merge branch 'main' into txt_seq_lens (kashif)
- afff5b7: Merge branch 'main' into txt_seq_lens (kashif)
- 8dc6c3f: Merge branch 'main' into txt_seq_lens (kashif)
- 22cb03d: removed deprecated args (kashif)
- 125a3a4: undo unrelated change (kashif)
- b5b6342: Updates QwenImage performance documentation (cdutr)
- 61f5265: Updates deprecation warnings for txt_seq_lens parameter (cdutr)
- 2ef38e2: fix compile (kashif)
- 270c63f: Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in… (kashif)
- 35efa06: formatting (kashif)
- 50c4815: fix compile tests (kashif)
- c88bc06: Merge branch 'main' into txt_seq_lens (kashif)
- 1433783: rename helper (kashif)
- 8de799c: remove duplicate (kashif)
- fc93747: smaller values (kashif)
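
The core change in this PR is to drop the explicit `txt_seq_lens` argument and derive what is needed from a boolean text attention mask instead. As a minimal illustration of that idea, the sketch below shows how per-sample text lengths could be recovered from such a mask; the helper name `compute_text_seq_len_from_mask` is taken from the commit messages above, but its signature and internals here are assumptions, not the PR's actual code.

```python
import torch


def compute_text_seq_len_from_mask(encoder_attention_mask: torch.Tensor) -> tuple[torch.Tensor, int]:
    # Hypothetical sketch: the real helper in the PR may differ.
    # encoder_attention_mask: [batch_size, text_seq_len] bool, True for real tokens, False for padding.
    per_sample_lens = encoder_attention_mask.sum(dim=1)  # [batch_size] number of real text tokens
    max_txt_seq_len = int(per_sample_lens.max().item())  # longest text sequence in the batch
    return per_sample_lens, max_txt_seq_len


mask = torch.tensor([[True, True, True, False], [True, True, False, False]])
lens, max_len = compute_text_seq_len_from_mask(mask)
print(lens.tolist(), max_len)  # [3, 2] 3
```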
```diff
@@ -1881,6 +1881,42 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
     return out
 
 
+def _prepare_additive_attn_mask(
+    attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
+) -> torch.Tensor:
+    """
+    Convert a 2D boolean attention mask to an additive mask, optionally reshaping to 4D for SDPA.
+
+    This helper is used by both the native SDPA and xformers backends to convert boolean masks to the additive
+    format they require.
+
+    Args:
+        attn_mask: 2D boolean tensor [batch_size, seq_len_k] where True means attend, False means mask out.
+        target_dtype: The dtype to convert the mask to (usually query.dtype).
+        reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting.
+
+    Returns:
+        Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
+        reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
+    """
+    # Ensure it's boolean
+    if attn_mask.dtype != torch.bool:
+        attn_mask = attn_mask.bool()
+
+    # Convert boolean to additive: True -> 0.0, False -> -inf
+    attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
+
+    # Convert to target dtype
+    attn_mask = attn_mask.to(dtype=target_dtype)
+
+    # Optionally reshape to 4D for broadcasting in attention mechanisms
+    if reshape_4d:
+        batch_size, seq_len_k = attn_mask.shape
+        attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
+
+    return attn_mask
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.NATIVE,
     constraints=[_check_device, _check_shape],
```
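
As a usage illustration (not part of the diff), the following sketch assumes the `_prepare_additive_attn_mask` helper defined above is in scope and shows how a 2D boolean padding mask becomes the additive form that SDPA and xformers expect; the tensor values are made up.

```python
import torch

# Assumes _prepare_additive_attn_mask from the diff above is importable / in scope.
# Two sequences of length 4; the second one has two padded key positions.
bool_mask = torch.tensor([[True, True, True, True],
                          [True, True, False, False]])

additive = _prepare_additive_attn_mask(bool_mask, target_dtype=torch.float16)
print(additive.shape)      # torch.Size([2, 1, 1, 4])
print(additive[1, 0, 0])   # tensor([0., 0., -inf, -inf], dtype=torch.float16)

# With reshape_4d=False the mask keeps its 2D shape [batch_size, seq_len_k].
additive_2d = _prepare_additive_attn_mask(bool_mask, target_dtype=torch.float32, reshape_4d=False)
print(additive_2d.shape)   # torch.Size([2, 4])
```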
```diff
@@ -1900,6 +1936,14 @@ def _native_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Native attention backend does not support setting `return_lse=True`.")
 
+    # Convert 2D boolean mask to 4D additive mask for SDPA
+    if attn_mask is not None and attn_mask.ndim == 2:
+        # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out
+        # SDPA expects [batch_size, 1, 1, seq_len_k] additive mask: 0.0 for attend, -inf for mask out
+        # Use helper to convert boolean to additive mask and reshape to 4D
+        attn_mask = _prepare_additive_attn_mask(attn_mask, target_dtype=query.dtype)
+
     if _parallel_config is None:
         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
         out = torch.nn.functional.scaled_dot_product_attention(
```
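
A self-contained sketch of what the branch above accomplishes, independent of the diffusers dispatch machinery: convert a 2D boolean key-padding mask to a 4D additive mask and pass it to `scaled_dot_product_attention`. Shapes and values are illustrative only.

```python
import torch
import torch.nn.functional as F

batch, heads, seq_q, seq_k, dim = 2, 4, 6, 6, 8
query = torch.randn(batch, heads, seq_q, dim)
key = torch.randn(batch, heads, seq_k, dim)
value = torch.randn(batch, heads, seq_k, dim)

# Key-padding mask: True = attend, False = padding.
bool_mask = torch.ones(batch, seq_k, dtype=torch.bool)
bool_mask[1, 4:] = False  # last two keys of the second sample are padding

# Equivalent of _prepare_additive_attn_mask(bool_mask, target_dtype=query.dtype)
additive_mask = torch.where(bool_mask, 0.0, float("-inf")).to(query.dtype).view(batch, 1, 1, seq_k)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=additive_mask)
print(out.shape)  # torch.Size([2, 4, 6, 8])
```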
```diff
@@ -2423,10 +2467,36 @@ def _xformers_attention(
         attn_mask = xops.LowerTriangularMask()
     elif attn_mask is not None:
         if attn_mask.ndim == 2:
-            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+            # Convert 2D boolean mask to 4D for xformers
+            # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out
+            # xformers requires a 4D additive bias [batch, heads, seq_q, seq_k]: 0.0 for attend, -inf for mask
+            # xformers also needs memory alignment, so create a larger tensor and slice it
+            original_seq_len = attn_mask.size(1)
+            aligned_seq_len = ((original_seq_len + 7) // 8) * 8  # Round up to multiple of 8
+
+            # Create aligned 4D tensor and slice to ensure proper memory layout
+            aligned_mask = torch.zeros(
+                (batch_size, num_heads_q, seq_len_q, aligned_seq_len),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            # Use helper to convert the 2D boolean mask into a [batch, 1, 1, seq_len_k] additive mask
+            mask_additive = _prepare_additive_attn_mask(attn_mask, target_dtype=query.dtype)
+            # Broadcast to [batch, heads, seq_q, seq_len_k] and mask out the alignment padding columns
+            aligned_mask[:, :, :, :original_seq_len] = mask_additive
+            aligned_mask[:, :, :, original_seq_len:] = float("-inf")
+
+            # Slice to actual size with proper alignment
+            attn_mask = aligned_mask[:, :, :, :seq_len_kv]
         elif attn_mask.ndim != 4:
             raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
-        attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+        elif attn_mask.ndim == 4:
+            attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
 
     if enable_gqa:
         if num_heads_q % num_heads_kv != 0:
```
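
The alignment trick in the xformers branch can be shown in isolation: the key length is rounded up to a multiple of 8, the bias is allocated at the aligned size, and a view sliced back to the real length is handed to the attention call so the underlying storage stays 8-aligned. The sizes below are made up for illustration and this is only a sketch of the idea, not the exact code path.

```python
import torch

batch, heads, seq_q, seq_k = 2, 4, 16, 77   # e.g. 77 text tokens
aligned_k = ((seq_k + 7) // 8) * 8          # 77 -> 80

bool_mask = torch.ones(batch, seq_k, dtype=torch.bool)
bool_mask[1, 70:] = False                   # padding in the second sample

bias = torch.zeros(batch, heads, seq_q, aligned_k, dtype=torch.float16)
# Fill the real key positions with the additive form of the boolean mask
bias[:, :, :, :seq_k] = torch.where(bool_mask, 0.0, float("-inf")).to(torch.float16).view(batch, 1, 1, seq_k)
bias[:, :, :, seq_k:] = float("-inf")       # never attend to the alignment padding

attn_bias = bias[:, :, :, :seq_k]           # sliced view over an 8-aligned buffer
print(attn_bias.shape, attn_bias.is_contiguous())  # torch.Size([2, 4, 16, 77]) False
```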
```diff
@@ -2442,3 +2512,6 @@ def _xformers_attention(
     out = out.flatten(2, 3)
 
     return out
+
+
+_maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend)
```
Member: Why do we need this?

Member: Cc: @kashif
Review comment: Instead of specifying performance numbers on torch.compile and other attention backends, maybe we could highlight this point and include with and without torch.compile numbers? @cdutr WDYT?

Reply: Good point! I've simplified the Performance section to focus on torch.compile with the before/after numbers, and removed the attention backend tables since the differences between backends are minimal compared to the torch.compile gains.
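
For context on the numbers being discussed, a minimal sketch of how before/after torch.compile timings for the QwenImage pipeline could be collected is shown below; the model id, step count, and prompt are assumptions and not taken from this PR.

```python
import time

import torch
from diffusers import DiffusionPipeline

# Assumed model id and settings for illustration only.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
prompt = "a photo of a cat surfing"


def timed(pipeline):
    # Warm-up run (also triggers compilation when the transformer is compiled)
    pipeline(prompt, num_inference_steps=20)
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipeline(prompt, num_inference_steps=20)
    torch.cuda.synchronize()
    return time.perf_counter() - start


baseline = timed(pipe)
pipe.transformer = torch.compile(pipe.transformer)
compiled = timed(pipe)
print(f"eager: {baseline:.2f}s, compiled: {compiled:.2f}s")
```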