diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md index b3dd3dd93618..7ea6a5ddfb60 100644 --- a/docs/source/en/api/pipelines/qwenimage.md +++ b/docs/source/en/api/pipelines/qwenimage.md @@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained( image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg") image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png") image = pipe( - image=[image_1, image_2], - prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', + image=[image_1, image_2], + prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', num_inference_steps=50 ).images[0] ``` +## Performance + +### torch.compile + +Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s): + +```python +import torch +from diffusers import QwenImagePipeline + +pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda") +pipe.transformer = torch.compile(pipe.transformer) + +# First call triggers compilation (~7s overhead) +# Subsequent calls run at ~2.4x faster +image = pipe("a cat", num_inference_steps=50).images[0] +``` + +### Batched Inference with Variable-Length Prompts + +When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output. + +```python +# CFG with different prompt lengths works correctly +image = pipe( + prompt="A cat", + negative_prompt="blurry, low quality, distorted", + true_cfg_scale=3.5, + num_inference_steps=50, +).images[0] +``` + +For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f). + ## QwenImagePipeline [[autodoc]] QwenImagePipeline diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 53b01bf0cfc8..ea9b137b0acd 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): height=model_input.shape[3], width=model_input.shape[4], ) - print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}") model_pred = transformer( hidden_states=packed_noisy_model_input, encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, timestep=timesteps / 1000, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, )[0] model_pred = QwenImagePipeline._unpack_latents( diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 310c44457c27..a94afb1f2fbc 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1881,6 +1881,42 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): return out +def _prepare_additive_attn_mask( + attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True +) -> torch.Tensor: + """ + Convert a 2D boolean attention mask to an additive mask, optionally reshaping to 4D for SDPA. + + This helper is used by both native SDPA and xformers backends to convert boolean masks to the additive format they + require. 
+ + Args: + attn_mask: 2D boolean tensor [batch_size, seq_len_k] where True means attend, False means mask out + target_dtype: The dtype to convert the mask to (usually query.dtype) + reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting + + Returns: + Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if + reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. + """ + # Ensure it's boolean + if attn_mask.dtype != torch.bool: + attn_mask = attn_mask.bool() + + # Convert boolean to additive: True -> 0.0, False -> -inf + attn_mask = torch.where(attn_mask, 0.0, float("-inf")) + + # Convert to target dtype + attn_mask = attn_mask.to(dtype=target_dtype) + + # Optionally reshape to 4D for broadcasting in attention mechanisms + if reshape_4d: + batch_size, seq_len_k = attn_mask.shape + attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) + + return attn_mask + + @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], @@ -1900,6 +1936,14 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") + + # Convert 2D boolean mask to 4D additive mask for SDPA + if attn_mask is not None and attn_mask.ndim == 2: + # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask out + # SDPA expects [batch_size, 1, 1, seq_len_k] additive mask: 0.0 for attend, -inf for mask out + # Use helper to convert boolean to additive mask and reshape to 4D + attn_mask = _prepare_additive_attn_mask(attn_mask, target_dtype=query.dtype) + if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( @@ -2423,10 +2467,36 @@ def _xformers_attention( attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + # Convert 2D boolean mask to 4D for xformers + # attn_mask is [batch_size, seq_len_k] boolean: True means attend, False means mask + # xformers requires 4D masks [batch, heads, seq_q, seq_k] + # xformers expects additive bias: 0.0 for attend, -inf for mask + # Need memory alignment - create larger tensor and slice for alignment + original_seq_len = attn_mask.size(1) + aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8 + + # Create aligned 4D tensor and slice to ensure proper memory layout + aligned_mask = torch.zeros( + (batch_size, num_heads_q, seq_len_q, aligned_seq_len), + dtype=query.dtype, + device=query.device, + ) + # Fill in the actual mask values (converting boolean to additive) + # Use helper to convert 2D boolean -> 4D additive mask + mask_additive = _prepare_additive_attn_mask( + attn_mask, target_dtype=query.dtype + ) # [batch, 1, 1, seq_len_k] + # Broadcast to [batch, heads, seq_q, seq_len_k] + aligned_mask[:, :, :, :original_seq_len] = mask_additive + # Mask out the padding (already -inf from zeros -> where with default) + aligned_mask[:, :, :, original_seq_len:] = float("-inf") + + # Slice to actual size with proper alignment + attn_mask = aligned_mask[:, :, :, :seq_len_kv] elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + elif 
attn_mask.ndim == 4: + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) if enable_gqa: if num_heads_q % num_heads_kv != 0: diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 86971271788f..4c85b6ec3852 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module @@ -31,6 +31,7 @@ QwenImageTransformerBlock, QwenTimestepProjEmbeddings, RMSNorm, + compute_text_seq_len_from_mask, ) @@ -136,7 +137,7 @@ def forward( return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The [`FluxTransformer2DModel`] forward method. + The [`QwenImageControlNetModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): @@ -147,24 +148,39 @@ def forward( The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*): + **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence + length. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. 
+ If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where + the first element is the controlnet block samples. """ + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in " + "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` " + "and `encoder_hidden_states_mask`.", + standard_warn=False, + ) + if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -186,14 +202,19 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) block_samples = () - for index_block, block in enumerate(self.transformer_blocks): + for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, @@ -267,6 +288,15 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[QwenImageControlNetOutput, Tuple]: + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be " + "removed in version 0.39.0. 
The text sequence length is now automatically inferred from " + "`encoder_hidden_states` and `encoder_hidden_states_mask`.", + standard_warn=False, + ) # ControlNet-Union with multiple conditions # only load one ControlNet for saving memories if len(self.nets) == 1: @@ -281,7 +311,6 @@ def forward( encoder_hidden_states_mask=encoder_hidden_states_mask, timestep=timestep, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 1229bab169b2..217c64bf4b51 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -24,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward @@ -142,6 +142,32 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) +def compute_text_seq_len_from_mask( + encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor] +) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask. + """ + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if encoder_hidden_states_mask is None: + return text_seq_len, None, None + + if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): + raise ValueError( + f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match " + f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})." + ) + + if encoder_hidden_states_mask.dtype != torch.bool: + encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool) + + position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) + active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) + has_active = encoder_hidden_states_mask.any(dim=1) + per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) + return text_seq_len, per_sample_len, encoder_hidden_states_mask + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, use_additional_t_cond=False): super().__init__() @@ -207,21 +233,50 @@ def rope_params(self, index, dim, theta=10000): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - txt_seq_lens: List[int], - device: torch.device, + txt_seq_lens: Optional[List[int]] = None, + device: torch.device = None, + max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. - txt_seq_lens (`List[int]`): - A list of integers of length batch_size representing the length of each text prompt. 
- device: (`torch.device`): + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used. + device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. + max_txt_seq_len (`int` or `torch.Tensor`, *optional*): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `max_txt_seq_len` instead. " + "The new parameter accepts a single int or tensor value representing the maximum text sequence length.", + standard_warn=False, + ) + if max_txt_seq_len is None: + # Use max of txt_seq_lens for backward compatibility + max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + + if max_txt_seq_len is None: + raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + + # Validate batch inference with variable-sized images + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if all instances have the same size + first_fhw = video_fhw[0] + if not all(fhw == first_fhw for fhw in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. " + "All images in the batch should have the same dimensions (frame, height, width). " + f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -233,8 +288,7 @@ def forward( for idx, fhw in enumerate(video_fhw): frame, height, width = fhw # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs - video_freq = self._compute_video_freqs(frame, height, width, idx) - video_freq = video_freq.to(device) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -242,17 +296,23 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] 
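+        # Note: `max_txt_seq_len` is the full (padded) encoder_hidden_states length, so padded prompts still get
+        # full-length text RoPE; padding tokens are instead excluded through the attention mask.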
vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=128) - def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor: + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -304,14 +364,35 @@ def rope_params(self, index, dim, theta=10000): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - def forward(self, video_fhw, txt_seq_lens, device): + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + max_txt_seq_len: Union[int, torch.Tensor], + device: torch.device = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: - txt_length: [bs] a list of 1 integers representing the length of the text + Args: + video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer + structures. + max_txt_seq_len (`int` or `torch.Tensor`): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). + device: (`torch.device`, *optional*): + The device on which to perform the RoPE computation. """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Validate batch inference with variable-sized images + # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if this is batch inference (list of layer lists/tuples) + first_entry = video_fhw[0] + if not all(entry == first_entry for entry in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. " + "All images in the batch should have the same layer structure. " + f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." 
+ ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -324,11 +405,10 @@ def forward(self, video_fhw, txt_seq_lens, device): for idx, fhw in enumerate(video_fhw): frame, height, width = fhw if idx != layer_num: - video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) else: ### For the condition image, we set the layer index to -1 - video_freq = self._compute_condition_freqs(frame, height, width) - video_freq = video_freq.to(device) + video_freq = self._compute_condition_freqs(frame, height, width, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -337,17 +417,21 @@ def forward(self, video_fhw, txt_seq_lens, device): max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) - def _compute_video_freqs(self, frame, height, width, idx=0): + def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -363,10 +447,13 @@ def _compute_video_freqs(self, frame, height, width, idx=0): return freqs.clone().contiguous() @functools.lru_cache(maxsize=None) - def _compute_condition_freqs(self, frame, height, width): + def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -454,6 +541,35 @@ def __call__( joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) + # If an encoder_hidden_states_mask is provided, create a joint attention mask. + # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding. + # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend). + # Only create the mask if there's actual padding, otherwise keep attention_mask=None for better SDPA performance. 
+ if encoder_hidden_states_mask is not None and attention_mask is None: + batch_size, image_seq_len = hidden_states.shape[:2] + text_seq_len = encoder_hidden_states.shape[1] + + if encoder_hidden_states_mask.shape[0] != batch_size: + raise ValueError( + f"`encoder_hidden_states_mask` batch size ({encoder_hidden_states_mask.shape[0]}) " + f"must match hidden_states batch size ({batch_size})." + ) + if encoder_hidden_states_mask.shape[1] != text_seq_len: + raise ValueError( + f"`encoder_hidden_states_mask` sequence length ({encoder_hidden_states_mask.shape[1]}) " + f"must match encoder_hidden_states sequence length ({text_seq_len})." + ) + + # Create joint attention mask + # torch.compile compatible: always create mask when encoder_hidden_states_mask is provided + text_attention_mask = encoder_hidden_states_mask.bool() + image_attention_mask = torch.ones( + (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device + ) + # Create 2D joint mask [batch_size, text_seq_len + image_seq_len] + # The attention dispatch will normalize this and extract sequence lengths + attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1) + # Compute joint attention joint_hidden_states = dispatch_attention_fn( joint_query, @@ -762,14 +878,25 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): - Mask of the input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be + used to compute RoPE sequence length. + guidance (`torch.Tensor`, *optional*): + Guidance tensor for conditional generation. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_block_samples (*optional*): + ControlNet block samples to add to the transformer blocks. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. @@ -778,6 +905,15 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `encoder_hidden_states_mask` instead. 
" + "The mask-based approach is more flexible and supports variable-length sequences.", + standard_warn=False, + ) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -810,6 +946,11 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -819,7 +960,7 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index bd92d403539e..f0d36ce00b6c 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -525,18 +525,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -551,14 +539,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ) ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) self.set_block_state(state, block_state) @@ -592,18 +572,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -625,15 +593,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if 
block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) - self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index 49acd2dc0295..2faa34ada329 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -149,7 +149,7 @@ def inputs(self) -> List[InputParam]: kwargs_type="denoiser_input_fields", description=( "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens." + "It should contain prompt_embeds/negative_prompt_embeds." ), ), ] @@ -176,7 +176,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState img_shapes=block_state.img_shapes, encoder_hidden_states=block_state.prompt_embeds, encoder_hidden_states_mask=block_state.prompt_embeds_mask, - txt_seq_lens=block_state.txt_seq_lens, return_dict=False, ) @@ -247,10 +246,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) @@ -345,10 +340,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 33dc2039b986..bc3ce84e1019 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -672,11 +672,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -695,7 +690,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -709,7 +703,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 5111096d93c1..ce6fc974a56e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -909,7 +909,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -920,7 +919,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -935,7 +933,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 102a813ab582..77d78a5ca7a1 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -852,7 +852,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -863,7 +862,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -878,7 +876,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ed37b238c8c9..dd723460a59e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -793,11 +793,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - 
negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -821,7 +816,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -836,7 +830,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index d54d1881fa4e..cf467203a9d2 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -1008,11 +1008,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1035,7 +1030,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -1050,7 +1044,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..942ee348508c 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -777,11 +777,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -805,7 +800,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -820,7 +814,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index cb4c5d8016bb..e0b41b8b8799 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -775,11 +775,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -797,7 +792,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -811,7 +805,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1915c27eb2bb..83f02539b1ba 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -944,11 +944,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -966,7 +961,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -980,7 +974,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 7bb12c26baa4..6ad972f8774d 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -781,10 +781,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long) # 6. Denoising loop self.scheduler.set_begin_index(0) @@ -809,7 +805,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -825,7 +820,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index b24fa90503ef..384954dfbad7 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -15,10 +15,10 @@ import unittest -import pytest import torch from diffusers import QwenImageTransformer2DModel +from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -68,7 +68,6 @@ def prepare_dummy_input(self, height=4, width=4): "encoder_hidden_states_mask": encoder_hidden_states_mask, "timestep": timestep, "img_shapes": img_shapes, - "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(), } def prepare_init_args_and_inputs_for_common(self): @@ -91,6 +90,180 @@ def test_gradient_checkpointing_is_applied(self): expected_set = {"QwenImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_infers_text_seq_len_from_mask(self): + """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + 
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid + + rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + + # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) + self.assertIsInstance(rope_text_seq_len, int) + + # Verify per_sample_len is computed correctly (max valid position + 1 = 2) + self.assertIsInstance(per_sample_len, torch.Tensor) + self.assertEqual(int(per_sample_len.max().item()), 2) + + # Verify mask is normalized to bool dtype + self.assertTrue(normalized_mask.dtype == torch.bool) + self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values + + # Verify rope_text_seq_len is at least the sequence length + self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) + + # Test 2: Verify model runs successfully with inferred values + inputs["encoder_hidden_states_mask"] = normalized_mask + with torch.no_grad(): + output = model(**inputs) + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 3: Different mask pattern (padding at beginning) + encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding + encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + + rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask2 + ) + + # Max valid position is 6 (last token), so per_sample_len should be 7 + self.assertEqual(int(per_sample_len2.max().item()), 7) + self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + + # Test 4: No mask provided (None case) + rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], None + ) + self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(rope_text_seq_len_none, int) + self.assertIsNone(per_sample_len_none) + self.assertIsNone(normalized_mask_none) + + def test_non_contiguous_attention_mask(self): + """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. 
+ encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + # Pattern: [True, False, True, False, True, False, False] + encoder_hidden_states_mask[:, 1] = 0 + encoder_hidden_states_mask[:, 3] = 0 + encoder_hidden_states_mask[:, 5:] = 0 + + inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + self.assertEqual(int(per_sample_len.max().item()), 5) + self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(inferred_rope_len, int) + self.assertTrue(normalized_mask.dtype == torch.bool) + + inputs["encoder_hidden_states_mask"] = normalized_mask + + with torch.no_grad(): + output = model(**inputs) + + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_txt_seq_lens_deprecation(self): + """Test that passing txt_seq_lens raises a deprecation warning.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Prepare inputs with txt_seq_lens (deprecated parameter) + txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] + + # Remove encoder_hidden_states_mask to use the deprecated path + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens + + # Test that deprecation warning is raised + with self.assertWarns(FutureWarning) as warning_context: + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + # Verify the warning message mentions the deprecation + warning_message = str(warning_context.warning) + self.assertIn("txt_seq_lens", warning_message) + self.assertIn("deprecated", warning_message) + self.assertIn("encoder_hidden_states_mask", warning_message) + + # Verify the model still works correctly despite the deprecation + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_layered_model_with_mask(self): + """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" + # Create layered model config + init_dict = { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 3, + "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) + "use_layer3d_rope": True, # Enable layered RoPE + "use_additional_t_cond": True, # Enable additional time conditioning + } + + model = self.model_class(**init_dict).to(torch_device) + + # Verify the model uses QwenEmbedLayer3DRope + from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope + + self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + + # Test single generation with layered structure + batch_size = 1 + text_seq_len = 7 + img_h, img_w = 4, 4 + layers = 4 + + # For layered model: (layers + 1) because we have N layers + 1 combined image + hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) + encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) + + # Create mask with some padding + encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) + encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens + + timestep = torch.tensor([1.0]).to(torch_device) + + # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) + addition_t_cond = torch.tensor([0], 
dtype=torch.long).to(torch_device) + + # Layer structure: 4 layers + 1 condition image + img_shapes = [ + [ + (1, img_h, img_w), # layer 0 + (1, img_h, img_w), # layer 1 + (1, img_h, img_w), # layer 2 + (1, img_h, img_w), # layer 3 + (1, img_h, img_w), # condition image (last one gets special treatment) + ] + ] + + with torch.no_grad(): + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + timestep=timestep, + img_shapes=img_shapes, + additional_t_cond=addition_t_cond, + ) + + self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel @@ -101,6 +274,5 @@ def prepare_init_args_and_inputs_for_common(self): def prepare_dummy_input(self, height, width): return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) - @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break()