Changes from all commits (47 commits)
b547fcf
Fix QwenImage txt_seq_lens handling
kashif Nov 23, 2025
72a80c6
formatting
kashif Nov 23, 2025
88cee8b
formatting
kashif Nov 23, 2025
ac5ac24
remove txt_seq_lens and use bool mask
kashif Nov 29, 2025
0477526
Merge branch 'main' into txt_seq_lens
kashif Nov 29, 2025
18efdde
use compute_text_seq_len_from_mask
kashif Nov 30, 2025
6a549d4
add seq_lens to dispatch_attention_fn
kashif Nov 30, 2025
2d424e0
use joint_seq_lens
kashif Nov 30, 2025
30b5f98
remove unused index_block
kashif Nov 30, 2025
588dc04
Merge branch 'main' into txt_seq_lens
kashif Dec 6, 2025
f1c2d99
WIP: Remove seq_lens parameter and use mask-based approach
kashif Dec 6, 2025
ec52417
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 6, 2025
beeb020
fix formatting
kashif Dec 7, 2025
5c6f8e3
undo sage changes
kashif Dec 7, 2025
5d434f6
xformers support
kashif Dec 7, 2025
71ba603
hub fix
kashif Dec 8, 2025
babf490
Merge branch 'main' into txt_seq_lens
kashif Dec 8, 2025
afad335
fix torch compile issues
kashif Dec 8, 2025
2d5ab16
Merge branch 'main' into txt_seq_lens
sayakpaul Dec 9, 2025
c78a1e9
fix tests
kashif Dec 9, 2025
d6d4b1d
use _prepare_attn_mask_native
kashif Dec 9, 2025
e999b76
proper deprecation notice
kashif Dec 9, 2025
8115f0b
add deprecate to txt_seq_lens
kashif Dec 9, 2025
3b1510c
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
3676d8e
Update src/diffusers/models/transformers/transformer_qwenimage.py
kashif Dec 10, 2025
9ed0ffd
Only create the mask if there's actual padding
kashif Dec 10, 2025
abec461
Merge branch 'main' into txt_seq_lens
kashif Dec 10, 2025
e26e7b3
fix order of docstrings
kashif Dec 10, 2025
59e3882
Adds performance benchmarks and optimization details for QwenImage
cdutr Dec 11, 2025
0cb2138
Merge branch 'main' into txt_seq_lens
kashif Dec 12, 2025
60bd454
rope_text_seq_len = text_seq_len
kashif Dec 12, 2025
a5abbb8
rename to max_txt_seq_len
kashif Dec 12, 2025
8415c57
Merge branch 'main' into txt_seq_lens
kashif Dec 15, 2025
afff5b7
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
8dc6c3f
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
22cb03d
removed deprecated args
kashif Dec 17, 2025
125a3a4
undo unrelated change
kashif Dec 17, 2025
b5b6342
Updates QwenImage performance documentation
cdutr Dec 17, 2025
61f5265
Updates deprecation warnings for txt_seq_lens parameter
cdutr Dec 17, 2025
2ef38e2
fix compile
kashif Dec 17, 2025
270c63f
Merge branch 'txt_seq_lens' of https://github.com/kashif/diffusers in…
kashif Dec 17, 2025
35efa06
formatting
kashif Dec 17, 2025
50c4815
fix compile tests
kashif Dec 17, 2025
c88bc06
Merge branch 'main' into txt_seq_lens
kashif Dec 17, 2025
1433783
rename helper
kashif Dec 17, 2025
8de799c
remove duplicate
kashif Dec 17, 2025
fc93747
smaller values
kashif Dec 18, 2025
38 changes: 36 additions & 2 deletions docs/source/en/api/pipelines/qwenimage.md
@@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
num_inference_steps=50
).images[0]
```

## Performance

### torch.compile

Using `torch.compile` on the transformer provides a ~2.4x speedup (A100 80GB: 4.70s → 1.93s):

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)

# First call triggers compilation (~7s overhead)
# Subsequent calls run ~2.4x faster
image = pipe("a cat", num_inference_steps=50).images[0]
```

### Batched Inference with Variable-Length Prompts
Member: Instead of specifying performance numbers on torch.compile and other attention backends, maybe we could highlight this point and include with and without torch.compile numbers? @cdutr WDYT?

Contributor: Good point! I've simplified the Performance section to focus on torch.compile with the before/after numbers, and removed the attention backend tables since the differences between backends are minimal compared to the torch.compile gains.


When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.

```python
# CFG with different prompt lengths works correctly
image = pipe(
prompt="A cat",
negative_prompt="blurry, low quality, distorted",
true_cfg_scale=3.5,
num_inference_steps=50,
).images[0]
```
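
Batches of prompts with different lengths can be passed in the same way; shorter prompts are padded and the padding is masked out in attention. A minimal sketch, assuming `prompt` accepts a list of strings as in other Diffusers pipelines:

```python
# Two prompts of very different lengths in one batch; the attention mask keeps
# the padded tokens of the short prompt from influencing its image.
images = pipe(
    prompt=["A cat", "A highly detailed oil painting of a lighthouse at dusk under a dramatic sky"],
    num_inference_steps=50,
).images
```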

For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).

## QwenImagePipeline

[[autodoc]] QwenImagePipeline
2 changes: 0 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_qwen_image.py
@@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
height=model_input.shape[3],
width=model_input.shape[4],
)
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
model_pred = transformer(
hidden_states=packed_noisy_model_input,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
timestep=timesteps / 1000,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)[0]
model_pred = QwenImagePipeline._unpack_latents(
77 changes: 75 additions & 2 deletions src/diffusers/models/attention_dispatch.py
@@ -1881,6 +1881,42 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
return out


def _prepare_additive_attn_mask(
attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
) -> torch.Tensor:
"""
Convert a 2D boolean attention mask to an additive mask, optionally reshaping to 4D for SDPA.

This helper is used by both native SDPA and xformers backends to convert boolean masks to the additive format they
require.

Args:
attn_mask: 2D boolean tensor [batch_size, seq_len_k] where True means attend, False means mask out
target_dtype: The dtype to convert the mask to (usually query.dtype)
reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting

Returns:
Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
"""
# Ensure it's boolean
if attn_mask.dtype != torch.bool:
attn_mask = attn_mask.bool()

# Convert boolean to additive: True -> 0.0, False -> -inf
attn_mask = torch.where(attn_mask, 0.0, float("-inf"))

# Convert to target dtype
attn_mask = attn_mask.to(dtype=target_dtype)

# Optionally reshape to 4D for broadcasting in attention mechanisms
if reshape_4d:
batch_size, seq_len_k = attn_mask.shape
attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)

return attn_mask
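
For illustration only (not part of the diff), this is the conversion the helper performs on a toy mask: a 2D boolean key-padding mask becomes a 4D additive bias that broadcasts over heads and query positions.

```python
import torch

# Toy [batch=1, seq_len_k=4] boolean mask: the last key position is padding.
bool_mask = torch.tensor([[True, True, True, False]])
additive = torch.where(bool_mask, 0.0, float("-inf")).to(torch.float16).view(1, 1, 1, 4)
# additive == [[[[0., 0., 0., -inf]]]]; adding it to the attention scores leaves
# valid keys untouched and removes the padded key entirely.
```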


@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
@@ -1900,6 +1936,14 @@ def _native_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")

# Convert 2D boolean mask to 4D additive mask for SDPA
if attn_mask is not None and attn_mask.ndim == 2:
# attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out
# SDPA expects [batch_size, 1, 1, seq_len_k] additive mask: 0.0 for attend, -inf for mask out
# Use helper to convert boolean to additive mask and reshape to 4D
attn_mask = _prepare_additive_attn_mask(attn_mask, target_dtype=query.dtype)

if _parallel_config is None:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
@@ -2423,10 +2467,36 @@ def _xformers_attention(
attn_mask = xops.LowerTriangularMask()
elif attn_mask is not None:
if attn_mask.ndim == 2:
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
# Convert 2D boolean mask to 4D for xformers
# attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask
# xformers requires 4D masks [batch, heads, seq_q, seq_k]
# xformers expects additive bias: 0.0 for attend, -inf for mask
# Need memory alignment - create larger tensor and slice for alignment
original_seq_len = attn_mask.size(1)
aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8

# Create aligned 4D tensor and slice to ensure proper memory layout
aligned_mask = torch.zeros(
(batch_size, num_heads_q, seq_len_q, aligned_seq_len),
dtype=query.dtype,
device=query.device,
)
# Fill in the actual mask values (converting boolean to additive)
# Use helper to convert 2D boolean -> 4D additive mask
mask_additive = _prepare_additive_attn_mask(
attn_mask, target_dtype=query.dtype
) # [batch, 1, 1, seq_len_k]
# Broadcast to [batch, heads, seq_q, seq_len_k]
aligned_mask[:, :, :, :original_seq_len] = mask_additive
# Mask out the alignment padding columns beyond the original sequence length
aligned_mask[:, :, :, original_seq_len:] = float("-inf")

# Slice to actual size with proper alignment
attn_mask = aligned_mask[:, :, :, :seq_len_kv]
elif attn_mask.ndim != 4:
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
elif attn_mask.ndim == 4:
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)

if enable_gqa:
if num_heads_q % num_heads_kv != 0:
Expand All @@ -2442,3 +2512,6 @@ def _xformers_attention(
out = out.flatten(2, 3)

return out
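
The alignment logic in `_xformers_attention` above can be seen in isolation in this small sketch (illustrative, not part of the diff): the key length is rounded up to the next multiple of 8, the extra columns are filled with -inf so they can never be attended to, and the bias is then sliced back to the real key length so xformers receives an aligned buffer.

```python
import torch

seq_len_k = 77                          # real key length
aligned = ((seq_len_k + 7) // 8) * 8    # 80, the next multiple of 8
bias = torch.zeros(1, 1, 1, aligned)
bias[..., seq_len_k:] = float("-inf")   # alignment padding is fully masked
bias = bias[..., :seq_len_k]            # slice back; underlying storage stays aligned
```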


_maybe_download_kernel_for_backend(_AttentionBackendRegistry._active_backend)
Member: Why do we need this?

Member: Cc: @kashif

55 changes: 42 additions & 13 deletions src/diffusers/models/controlnets/controlnet_qwenimage.py
@@ -20,7 +20,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
Expand All @@ -31,6 +31,7 @@
QwenImageTransformerBlock,
QwenTimestepProjEmbeddings,
RMSNorm,
compute_text_seq_len_from_mask,
)


@@ -136,7 +137,7 @@ def forward(
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
The [`QwenImageControlNetModel`] forward method.

Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -147,24 +148,39 @@ def forward(
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
Image shapes for RoPE computation.
txt_seq_lens (`List[int]`, *optional*):
**Deprecated**. No longer needed; the text sequence length is now inferred from `encoder_hidden_states` and `encoder_hidden_states_mask`.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
the first element is the controlnet block samples.
"""
# Handle deprecated txt_seq_lens parameter
if txt_seq_lens is not None:
deprecate(
"txt_seq_lens",
"0.39.0",
"Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
"version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
"and `encoder_hidden_states_mask`.",
standard_warn=False,
)

if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
Expand All @@ -186,14 +202,19 @@ def forward(

temb = self.time_text_embed(timestep, hidden_states)

image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
encoder_hidden_states, encoder_hidden_states_mask
)

image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)

timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)

block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
@@ -267,6 +288,15 @@ def forward(
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[QwenImageControlNetOutput, Tuple]:
if txt_seq_lens is not None:
deprecate(
"txt_seq_lens",
"0.39.0",
"Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
"removed in version 0.39.0. The text sequence length is now automatically inferred from "
"`encoder_hidden_states` and `encoder_hidden_states_mask`.",
standard_warn=False,
)
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1:
@@ -281,7 +311,6 @@
encoder_hidden_states_mask=encoder_hidden_states_mask,
timestep=timestep,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
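
For callers, the migration mirrors the training-script change above: drop `txt_seq_lens` and let the model infer the text length from the mask. A minimal before/after sketch using the argument names from the DreamBooth example (tensor preparation omitted):

```python
# Before (deprecated): per-sample text lengths were passed explicitly.
# model_pred = transformer(
#     hidden_states=packed_noisy_model_input,
#     encoder_hidden_states=prompt_embeds,
#     encoder_hidden_states_mask=prompt_embeds_mask,
#     timestep=timesteps / 1000,
#     img_shapes=img_shapes,
#     txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
#     return_dict=False,
# )[0]

# After: the length is inferred from `encoder_hidden_states_mask` inside the model.
model_pred = transformer(
    hidden_states=packed_noisy_model_input,
    encoder_hidden_states=prompt_embeds,
    encoder_hidden_states_mask=prompt_embeds_mask,
    timestep=timesteps / 1000,
    img_shapes=img_shapes,
    return_dict=False,
)[0]
```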