From 53d56028faca0a145e09cdaff7a7171f6449a526 Mon Sep 17 00:00:00 2001 From: "yueyang.hyy" Date: Wed, 9 Jul 2025 16:12:24 +0800 Subject: [PATCH 1/2] support flux fbcache --- diffsynth_engine/models/flux/__init__.py | 2 + .../models/flux/flux_dit_fbcache.py | 206 ++++++++++++++ diffsynth_engine/models/sd/sd_controlnet.py | 252 ++++++++++++------ diffsynth_engine/models/sdxl/__init__.py | 2 +- .../models/sdxl/sdxl_controlnet.py | 191 ++++++++----- diffsynth_engine/models/sdxl/sdxl_unet.py | 3 +- .../pipelines/controlnet_helper.py | 6 +- diffsynth_engine/pipelines/flux_image.py | 17 +- diffsynth_engine/pipelines/sd_image.py | 35 +-- diffsynth_engine/pipelines/sdxl_image.py | 63 +++-- tests/test_pipelines/test_flux_bfl_image.py | 2 +- tests/test_pipelines/test_sd_controlnet.py | 2 +- tests/test_pipelines/test_sdxl_controlnet.py | 1 + 13 files changed, 580 insertions(+), 202 deletions(-) create mode 100644 diffsynth_engine/models/flux/flux_dit_fbcache.py diff --git a/diffsynth_engine/models/flux/__init__.py b/diffsynth_engine/models/flux/__init__.py index 69e5c24b..ff285ff2 100644 --- a/diffsynth_engine/models/flux/__init__.py +++ b/diffsynth_engine/models/flux/__init__.py @@ -4,6 +4,7 @@ from .flux_controlnet import FluxControlNet from .flux_ipadapter import FluxIPAdapter from .flux_redux import FluxRedux +from .flux_dit_fbcache import FluxDiTFBCache __all__ = [ "FluxRedux", @@ -14,6 +15,7 @@ "FluxTextEncoder2", "FluxVAEDecoder", "FluxVAEEncoder", + "FluxDiTFBCache", "flux_dit_config", "flux_text_encoder_config", "flux_vae_config", diff --git a/diffsynth_engine/models/flux/flux_dit_fbcache.py b/diffsynth_engine/models/flux/flux_dit_fbcache.py new file mode 100644 index 00000000..da6f49af --- /dev/null +++ b/diffsynth_engine/models/flux/flux_dit_fbcache.py @@ -0,0 +1,206 @@ +import torch +import numpy as np +from typing import Dict, Optional + +from diffsynth_engine.models.utils import no_init_weights +from diffsynth_engine.utils.gguf import gguf_inference +from diffsynth_engine.utils.fp8_linear import fp8_inference +from diffsynth_engine.utils.parallel import ( + cfg_parallel, + cfg_parallel_unshard, + sequence_parallel, + sequence_parallel_unshard, +) +from diffsynth_engine.utils import logging +from diffsynth_engine.models.flux.flux_dit import FluxDiT + +logger = logging.get_logger(__name__) + + +class FluxDiTFBCache(FluxDiT): + def __init__( + self, + in_channel: int = 64, + attn_impl: Optional[str] = None, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + relative_l1_threshold: float = 0.05, + ): + super().__init__() + self.relative_l1_threshold = relative_l1_threshold + self.step_count = 0 + self.num_inference_steps = 0 + + def is_relative_l1_below_threshold(self, prev_residual, residual, threshold): + if threshold <= 0.0: + return False + + if prev_residual.shape != residual.shape: + return False + + mean_diff = (prev_residual - residual).abs().mean() + mean_prev_residual = prev_residual.abs().mean() + diff = mean_diff / mean_prev_residual + return diff.item() < threshold + + def refresh_cache_status(self, num_inference_steps): + self.step_count = 0 + self.num_inference_steps = num_inference_steps + + def forward( + self, + hidden_states, + timestep, + prompt_emb, + pooled_prompt_emb, + image_emb, + guidance, + text_ids, + image_ids=None, + controlnet_double_block_output=None, + controlnet_single_block_output=None, + **kwargs, + ): + h, w = hidden_states.shape[-2:] + if image_ids is None: + image_ids = self.prepare_image_ids(hidden_states) + 
controlnet_double_block_output = (
+            controlnet_double_block_output if controlnet_double_block_output is not None else ()
+        )
+        controlnet_single_block_output = (
+            controlnet_single_block_output if controlnet_single_block_output is not None else ()
+        )
+
+        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
+        use_cfg = hidden_states.shape[0] > 1
+        with (
+            fp8_inference(fp8_linear_enabled),
+            gguf_inference(),
+            cfg_parallel(
+                (
+                    hidden_states,
+                    timestep,
+                    prompt_emb,
+                    pooled_prompt_emb,
+                    image_emb,
+                    guidance,
+                    text_ids,
+                    image_ids,
+                    *controlnet_double_block_output,
+                    *controlnet_single_block_output,
+                ),
+                use_cfg=use_cfg,
+            ),
+        ):
+            # warning: keep the order of time_embedding + guidance_embedding + pooled_text_embedding;
+            # floating-point addition is not associative, so reordering changes the result
+            conditioning = self.time_embedder(timestep, hidden_states.dtype)
+            if self.guidance_embedder is not None:
+                guidance = guidance * 1000
+                conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
+            conditioning += self.pooled_text_embedder(pooled_prompt_emb)
+            rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+            text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
+            image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
+            hidden_states = self.patchify(hidden_states)
+
+            with sequence_parallel(
+                (
+                    hidden_states,
+                    prompt_emb,
+                    text_rope_emb,
+                    image_rope_emb,
+                    *controlnet_double_block_output,
+                    *controlnet_single_block_output,
+                ),
+                seq_dims=(
+                    1,
+                    1,
+                    2,
+                    2,
+                    *(1 for _ in controlnet_double_block_output),
+                    *(1 for _ in controlnet_single_block_output),
+                ),
+            ):
+                hidden_states = self.x_embedder(hidden_states)
+                prompt_emb = self.context_embedder(prompt_emb)
+                rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)
+
+                # first block
+                original_hidden_states = hidden_states
+                hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                first_hidden_states_residual = hidden_states - original_hidden_states
+                del original_hidden_states
+                (first_hidden_states_residual,) = sequence_parallel_unshard(
+                    (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,)
+                )
+
+                if self.step_count == 0 or self.step_count == (self.num_inference_steps - 1):
+                    should_calc = True
+                else:
+                    skip = self.is_relative_l1_below_threshold(
+                        self.prev_first_hidden_states_residual,
+                        first_hidden_states_residual,
+                        threshold=self.relative_l1_threshold,
+                    )
+                    should_calc = not skip
+                self.step_count += 1
+
+                if not should_calc:
+                    hidden_states += self.previous_residual
+                else:
+                    self.prev_first_hidden_states_residual = first_hidden_states_residual
+
+                    ori_hidden_states = hidden_states.clone()
+                    for i, block in enumerate(self.blocks[1:]):
+                        hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                        if len(controlnet_double_block_output) > 0:
+                            interval_control = len(self.blocks) / len(controlnet_double_block_output)
+                            interval_control = int(np.ceil(interval_control))
+                            hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
+                    hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+                    for i, block in enumerate(self.single_blocks):
+                        hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                        if len(controlnet_single_block_output) > 0:
+                            interval_control = len(self.single_blocks) / len(controlnet_single_block_output)
+                            interval_control = int(np.ceil(interval_control))
+                            hidden_states = hidden_states + 
controlnet_single_block_output[i // interval_control] + + hidden_states = hidden_states[:, prompt_emb.shape[1] :] + + previous_residual = hidden_states - ori_hidden_states + self.previous_residual = previous_residual + + hidden_states = self.final_norm_out(hidden_states, conditioning) + hidden_states = self.final_proj_out(hidden_states) + (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,)) + + hidden_states = self.unpatchify(hidden_states, h, w) + (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg) + + return hidden_states + + @classmethod + def from_state_dict( + cls, + state_dict: Dict[str, torch.Tensor], + device: str, + dtype: torch.dtype, + in_channel: int = 64, + attn_impl: Optional[str] = None, + ): + with no_init_weights(): + model = torch.nn.utils.skip_init( + cls, + device=device, + dtype=dtype, + in_channel=in_channel, + attn_impl=attn_impl, + ) + model = model.requires_grad_(False) # for loading gguf + model.load_state_dict(state_dict, assign=True) + model.to(device=device, dtype=dtype, non_blocking=True) + return model + + def get_fsdp_modules(self): + return ["blocks", "single_blocks"] diff --git a/diffsynth_engine/models/sd/sd_controlnet.py b/diffsynth_engine/models/sd/sd_controlnet.py index c0e4b70b..362b31f5 100644 --- a/diffsynth_engine/models/sd/sd_controlnet.py +++ b/diffsynth_engine/models/sd/sd_controlnet.py @@ -12,18 +12,29 @@ DownSampler, ) + class ControlNetConditioningLayer(nn.Module): - def __init__(self, channels = (3, 16, 32, 96, 256, 320), device = "cuda:0", dtype=torch.float16): + def __init__(self, channels=(3, 16, 32, 96, 256, 320), device="cuda:0", dtype=torch.float16): super().__init__() self.blocks = torch.nn.ModuleList([]) - self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, device=device, dtype=dtype)) + self.blocks.append( + torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, device=device, dtype=dtype) + ) self.blocks.append(torch.nn.SiLU()) for i in range(1, len(channels) - 2): - self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1, device=device, dtype=dtype)) + self.blocks.append( + torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1, device=device, dtype=dtype) + ) self.blocks.append(torch.nn.SiLU()) - self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2, device=device, dtype=dtype)) + self.blocks.append( + torch.nn.Conv2d( + channels[i], channels[i + 1], kernel_size=3, padding=1, stride=2, device=device, dtype=dtype + ) + ) self.blocks.append(torch.nn.SiLU()) - self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1, device=device, dtype=dtype)) + self.blocks.append( + torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1, device=device, dtype=dtype) + ) def forward(self, conditioning): for block in self.blocks: @@ -38,15 +49,73 @@ def __init__(self): def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # architecture block_types = [ - 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', - 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', - 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock', - 'ResnetBlock', 'PushBlock', 'ResnetBlock', 
'PushBlock', - 'ResnetBlock', 'AttentionBlock', 'ResnetBlock', - 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler', - 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', - 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler', - 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock' + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "DownSampler", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "DownSampler", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "DownSampler", + "PushBlock", + "ResnetBlock", + "PushBlock", + "ResnetBlock", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "ResnetBlock", + "PopBlock", + "ResnetBlock", + "PopBlock", + "ResnetBlock", + "PopBlock", + "ResnetBlock", + "UpSampler", + "PopBlock", + "ResnetBlock", + "AttentionBlock", + "PopBlock", + "ResnetBlock", + "AttentionBlock", + "PopBlock", + "ResnetBlock", + "AttentionBlock", + "UpSampler", + "PopBlock", + "ResnetBlock", + "AttentionBlock", + "PopBlock", + "ResnetBlock", + "AttentionBlock", + "PopBlock", + "ResnetBlock", + "AttentionBlock", + "UpSampler", + "PopBlock", + "ResnetBlock", + "AttentionBlock", + "PopBlock", + "ResnetBlock", + "AttentionBlock", + "PopBlock", + "ResnetBlock", + "AttentionBlock", ] # controlnet_rename_dict @@ -66,7 +135,7 @@ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torc "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight", "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias", "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight", - "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias", + "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias", } # Rename each parameter @@ -91,7 +160,12 @@ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torc elif names[0] in ["down_blocks", "mid_block", "up_blocks"]: if names[0] == "mid_block": names.insert(1, "0") - block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]] + block_type = { + "resnets": "ResnetBlock", + "attentions": "AttentionBlock", + "downsamplers": "DownSampler", + "upsamplers": "UpSampler", + }[names[2]] block_type_with_id = ".".join(names[:4]) if block_type_with_id != last_block_type_with_id[block_type]: block_id[block_type] += 1 @@ -102,9 +176,9 @@ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torc names = ["blocks", str(block_id[block_type])] + names[4:] if "ff" in names: ff_index = names.index("ff") - component = ".".join(names[ff_index:ff_index+3]) + component = ".".join(names[ff_index : ff_index + 3]) component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component] - names = names[:ff_index] + [component] + names[ff_index+3:] + names = names[:ff_index] + [component] + names[ff_index + 3 :] if "to_out" in names: names.pop(names.index("to_out") + 1) else: @@ -117,13 +191,21 @@ def _from_diffusers(self, state_dict: 
Dict[str, torch.Tensor]) -> Dict[str, torc if ".proj_in." in name or ".proj_out." in name: param = param.squeeze() if rename_dict[name] in [ - "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias", - "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias" + "controlnet_blocks.1.bias", + "controlnet_blocks.2.bias", + "controlnet_blocks.3.bias", + "controlnet_blocks.5.bias", + "controlnet_blocks.6.bias", + "controlnet_blocks.8.bias", + "controlnet_blocks.9.bias", + "controlnet_blocks.10.bias", + "controlnet_blocks.11.bias", + "controlnet_blocks.12.bias", ]: continue state_dict_[rename_dict[name]] = param return state_dict_ - + def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: rename_dict = { "control_model.time_embed.0.weight": "time_embedding.timestep_embedder.0.weight", @@ -496,69 +578,71 @@ def __init__( self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype) self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype) - self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype) + self.controlnet_conv_in = ControlNetConditioningLayer( + channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype + ) - self.blocks = torch.nn.ModuleList([ - # CrossAttnDownBlock2D - ResnetBlock(320, 320, 1280, device=device, dtype=dtype), - AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype), - PushBlock(), - ResnetBlock(320, 320, 1280, device=device, dtype=dtype), - AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype), - PushBlock(), - DownSampler(320, device=device, dtype=dtype), - PushBlock(), - # CrossAttnDownBlock2D - ResnetBlock(320, 640, 1280, device=device, dtype=dtype), - AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype), - PushBlock(), - ResnetBlock(640, 640, 1280, device=device, dtype=dtype), - AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype), - PushBlock(), - DownSampler(640, device=device, dtype=dtype), - PushBlock(), - # CrossAttnDownBlock2D - ResnetBlock(640, 1280, 1280, device=device, dtype=dtype), - AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), - PushBlock(), - ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), - AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), - PushBlock(), - DownSampler(1280, device=device, dtype=dtype), - PushBlock(), - # DownBlock2D - ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), - PushBlock(), - ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), - PushBlock(), - # UNetMidBlock2DCrossAttn - ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), - AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), - ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), - PushBlock() - ]) + self.blocks = torch.nn.ModuleList( + [ + # CrossAttnDownBlock2D + ResnetBlock(320, 320, 1280, device=device, dtype=dtype), + AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype), + PushBlock(), + ResnetBlock(320, 320, 1280, device=device, dtype=dtype), + AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype), + PushBlock(), + DownSampler(320, device=device, dtype=dtype), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(320, 640, 1280, device=device, dtype=dtype), + AttentionBlock(8, 80, 640, 1, 768, 
device=device, dtype=dtype), + PushBlock(), + ResnetBlock(640, 640, 1280, device=device, dtype=dtype), + AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype), + PushBlock(), + DownSampler(640, device=device, dtype=dtype), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(640, 1280, 1280, device=device, dtype=dtype), + AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), + PushBlock(), + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), + AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), + PushBlock(), + DownSampler(1280, device=device, dtype=dtype), + PushBlock(), + # DownBlock2D + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), + PushBlock(), + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), + PushBlock(), + # UNetMidBlock2DCrossAttn + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), + AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), + PushBlock(), + ] + ) - self.controlnet_blocks = torch.nn.ModuleList([ - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), - ]) + self.controlnet_blocks = torch.nn.ModuleList( + [ + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), + ] + ) - def forward( - self, - sample, timestep, encoder_hidden_states, conditioning, - **kwargs - ): + def forward(self, sample, timestep, 
encoder_hidden_states, conditioning, **kwargs): # 1. time time_emb = self.time_embedding(timestep, dtype=sample.dtype) @@ -585,9 +669,7 @@ def from_state_dict( attn_impl: Optional[str] = None, ): with no_init_weights(): - model = torch.nn.utils.skip_init( - cls, attn_impl=attn_impl, device=device, dtype=dtype - ) + model = torch.nn.utils.skip_init(cls, attn_impl=attn_impl, device=device, dtype=dtype) model.load_state_dict(state_dict) model.to(device=device, dtype=dtype, non_blocking=True) - return model \ No newline at end of file + return model diff --git a/diffsynth_engine/models/sdxl/__init__.py b/diffsynth_engine/models/sdxl/__init__.py index c383e45e..58f59bd6 100644 --- a/diffsynth_engine/models/sdxl/__init__.py +++ b/diffsynth_engine/models/sdxl/__init__.py @@ -9,7 +9,7 @@ "SDXLUNet", "SDXLVAEDecoder", "SDXLVAEEncoder", - "SDXLControlNetUnion", + "SDXLControlNetUnion", "sdxl_text_encoder_config", "sdxl_unet_config", ] diff --git a/diffsynth_engine/models/sdxl/sdxl_controlnet.py b/diffsynth_engine/models/sdxl/sdxl_controlnet.py index ef09ad2f..b597d828 100644 --- a/diffsynth_engine/models/sdxl/sdxl_controlnet.py +++ b/diffsynth_engine/models/sdxl/sdxl_controlnet.py @@ -12,23 +12,27 @@ from collections import OrderedDict -class QuickGELU(torch.nn.Module): +class QuickGELU(torch.nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) -class ResidualAttentionBlock(torch.nn.Module): +class ResidualAttentionBlock(torch.nn.Module): def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, device="cuda:0", dtype=torch.float16): super().__init__() self.attn = torch.nn.MultiheadAttention(d_model, n_head, device=device, dtype=dtype) self.ln_1 = torch.nn.LayerNorm(d_model, device=device, dtype=dtype) - self.mlp = torch.nn.Sequential(OrderedDict([ - ("c_fc", torch.nn.Linear(d_model, d_model * 4, device=device, dtype=dtype)), - ("gelu", QuickGELU()), - ("c_proj", torch.nn.Linear(d_model * 4, d_model, device=device, dtype=dtype)) - ])) + self.mlp = torch.nn.Sequential( + OrderedDict( + [ + ("c_fc", torch.nn.Linear(d_model, d_model * 4, device=device, dtype=dtype)), + ("gelu", QuickGELU()), + ("c_proj", torch.nn.Linear(d_model * 4, d_model, device=device, dtype=dtype)), + ] + ) + ) self.ln_2 = torch.nn.LayerNorm(d_model, device=device, dtype=dtype) self.attn_mask = attn_mask @@ -49,10 +53,30 @@ def __init__(self): def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # architecture block_types = [ - "ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock", - "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock", - "ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", - "ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock" + "ResnetBlock", + "PushBlock", + "ResnetBlock", + "PushBlock", + "DownSampler", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "DownSampler", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "PushBlock", + "ResnetBlock", + "AttentionBlock", + "ResnetBlock", + "PushBlock", ] # controlnet_rename_dict @@ -107,7 +131,12 @@ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torc elif names[0] in ["down_blocks", "mid_block", "up_blocks"]: if names[0] == "mid_block": names.insert(1, "0") - block_type = {"resnets": 
"ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]] + block_type = { + "resnets": "ResnetBlock", + "attentions": "AttentionBlock", + "downsamplers": "DownSampler", + "upsamplers": "UpSampler", + }[names[2]] block_type_with_id = ".".join(names[:4]) if block_type_with_id != last_block_type_with_id[block_type]: block_id[block_type] += 1 @@ -118,9 +147,9 @@ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torc names = ["blocks", str(block_id[block_type])] + names[4:] if "ff" in names: ff_index = names.index("ff") - component = ".".join(names[ff_index:ff_index+3]) + component = ".".join(names[ff_index : ff_index + 3]) component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component] - names = names[:ff_index] + [component] + names[ff_index+3:] + names = names[:ff_index] + [component] + names[ff_index + 3 :] if "to_out" in names: names.pop(names.index("to_out") + 1) else: @@ -137,19 +166,20 @@ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torc param = param.squeeze() state_dict_[rename_dict[name]] = param return state_dict_ - + # TODO: check civitai def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return self._from_diffusers(state_dict) - def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return self._from_diffusers(state_dict) + class SDXLControlNetUnion(PreTrainedModel): converter = SDXLControlNetUnionStateDictConverter() - def __init__(self, + def __init__( + self, attn_impl: Optional[str] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, @@ -157,68 +187,78 @@ def __init__(self, super().__init__() self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype) - self.add_time_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype) + self.add_time_proj = TemporalTimesteps( + 256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype + ) self.add_time_embedding = torch.nn.Sequential( torch.nn.Linear(2816, 1280, device=device, dtype=dtype), torch.nn.SiLU(), - torch.nn.Linear(1280, 1280, device=device, dtype=dtype) + torch.nn.Linear(1280, 1280, device=device, dtype=dtype), + ) + self.control_type_proj = TemporalTimesteps( + 256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype ) - self.control_type_proj = TemporalTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0, device=device, dtype=dtype) self.control_type_embedding = torch.nn.Sequential( torch.nn.Linear(256 * 8, 1280, device=device, dtype=dtype), torch.nn.SiLU(), - torch.nn.Linear(1280, 1280, device=device, dtype=dtype) + torch.nn.Linear(1280, 1280, device=device, dtype=dtype), ) self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype) - self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype) + self.controlnet_conv_in = ControlNetConditioningLayer( + channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype + ) self.controlnet_transformer = ResidualAttentionBlock(320, 8, device=device, dtype=dtype) self.task_embedding = torch.nn.Parameter(torch.randn(8, 320)) self.spatial_ch_projs = torch.nn.Linear(320, 320, device=device, dtype=dtype) - self.blocks = torch.nn.ModuleList([ - # DownBlock2D - ResnetBlock(320, 320, 1280, device=device, dtype=dtype), - PushBlock(), - ResnetBlock(320, 320, 1280, device=device, 
dtype=dtype), - PushBlock(), - DownSampler(320, device=device, dtype=dtype), - PushBlock(), - # CrossAttnDownBlock2D - ResnetBlock(320, 640, 1280, device=device, dtype=dtype), - AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype), - PushBlock(), - ResnetBlock(640, 640, 1280, device=device, dtype=dtype), - AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype), - PushBlock(), - DownSampler(640, device=device, dtype=dtype), - PushBlock(), - # CrossAttnDownBlock2D - ResnetBlock(640, 1280, 1280, device=device, dtype=dtype), - AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype), - PushBlock(), - ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), - AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype), - PushBlock(), - # UNetMidBlock2DCrossAttn - ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), - AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype), - ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), - PushBlock() - ]) - - self.controlnet_blocks = torch.nn.ModuleList([ - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), - ]) + self.blocks = torch.nn.ModuleList( + [ + # DownBlock2D + ResnetBlock(320, 320, 1280, device=device, dtype=dtype), + PushBlock(), + ResnetBlock(320, 320, 1280, device=device, dtype=dtype), + PushBlock(), + DownSampler(320, device=device, dtype=dtype), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(320, 640, 1280, device=device, dtype=dtype), + AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype), + PushBlock(), + ResnetBlock(640, 640, 1280, device=device, dtype=dtype), + AttentionBlock(10, 64, 640, 2, 2048, device=device, dtype=dtype), + PushBlock(), + DownSampler(640, device=device, dtype=dtype), + PushBlock(), + # CrossAttnDownBlock2D + ResnetBlock(640, 1280, 1280, device=device, dtype=dtype), + AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype), + PushBlock(), + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), + AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype), + PushBlock(), + # UNetMidBlock2DCrossAttn + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), + AttentionBlock(20, 64, 1280, 10, 2048, device=device, dtype=dtype), + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), + PushBlock(), + ] + ) + + self.controlnet_blocks = torch.nn.ModuleList( + [ + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, 
dtype=dtype), + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), + ] + ) # 0 -- openpose # 1 -- depth @@ -236,10 +276,9 @@ def __init__(self, "lineart": 3, "lineart_anime": 3, "tile": 6, - "inpaint": 7 + "inpaint": 7, } - def fuse_condition_to_input(self, hidden_states, task_id, conditioning): controlnet_cond = self.controlnet_conv_in(conditioning) feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) @@ -247,19 +286,25 @@ def fuse_condition_to_input(self, hidden_states, task_id, conditioning): x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1) x = self.controlnet_transformer(x) - alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1) + alpha = self.spatial_ch_projs(x[:, 0]).unsqueeze(-1).unsqueeze(-1) controlnet_cond_fuser = controlnet_cond + alpha hidden_states = hidden_states + controlnet_cond_fuser return hidden_states - def forward( self, - sample, timestep, encoder_hidden_states, - conditioning, processor_name, add_time_id, add_text_embeds, - tiled=False, tile_size=64, tile_stride=32, - **kwargs + sample, + timestep, + encoder_hidden_states, + conditioning, + processor_name, + add_time_id, + add_text_embeds, + tiled=False, + tile_size=64, + tile_stride=32, + **kwargs, ): task_id = self.task_id[processor_name] diff --git a/diffsynth_engine/models/sdxl/sdxl_unet.py b/diffsynth_engine/models/sdxl/sdxl_unet.py index 61dc525e..c8cdf0bd 100644 --- a/diffsynth_engine/models/sdxl/sdxl_unet.py +++ b/diffsynth_engine/models/sdxl/sdxl_unet.py @@ -268,13 +268,12 @@ def forward(self, x, timestep, context, y, controlnet_res_stack=None, **kwargs): text_emb, res_stack, ) - + # 3.2 Controlnet if i == controlnet_insert_block_id and controlnet_res_stack is not None: hidden_states += controlnet_res_stack.pop() res_stack = [res + controlnet_res for res, controlnet_res in zip(res_stack, controlnet_res_stack)] - # 4. 
output hidden_states = self.conv_norm_out(hidden_states) hidden_states = self.conv_act(hidden_states) diff --git a/diffsynth_engine/pipelines/controlnet_helper.py b/diffsynth_engine/pipelines/controlnet_helper.py index fa8c110d..d219bcde 100644 --- a/diffsynth_engine/pipelines/controlnet_helper.py +++ b/diffsynth_engine/pipelines/controlnet_helper.py @@ -6,6 +6,7 @@ ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]] + @dataclass class ControlNetParams: image: ImageType @@ -14,11 +15,12 @@ class ControlNetParams: mask: Optional[ImageType] = None control_start: float = 0 control_end: float = 1 - processor_name: Optional[str] = None # only used for sdxl controlnet union now + processor_name: Optional[str] = None # only used for sdxl controlnet union now + def accumulate(result, new_item): if result is None: return new_item for i, item in enumerate(new_item): result[i] += item - return result \ No newline at end of file + return result diff --git a/diffsynth_engine/pipelines/flux_image.py b/diffsynth_engine/pipelines/flux_image.py index 93ec408b..060a8535 100644 --- a/diffsynth_engine/pipelines/flux_image.py +++ b/diffsynth_engine/pipelines/flux_image.py @@ -6,7 +6,7 @@ import math from einops import rearrange from enum import Enum -from typing import Callable, Dict, List, Tuple, Optional +from typing import Callable, Dict, List, Tuple, Optional, Union from tqdm import tqdm from PIL import Image from dataclasses import dataclass @@ -16,6 +16,7 @@ FluxVAEDecoder, FluxVAEEncoder, FluxDiT, + FluxDiTFBCache, flux_dit_config, flux_text_encoder_config, ) @@ -429,6 +430,7 @@ def get_in_channel(self): elif self == ControlType.bfl_fill: return 384 + @dataclass class FluxModelConfig: dit_path: str | os.PathLike @@ -460,7 +462,7 @@ def __init__( tokenizer_2: T5TokenizerFast, text_encoder_1: FluxTextEncoder1, text_encoder_2: FluxTextEncoder2, - dit: FluxDiT, + dit: Union[FluxDiT, FluxDiTFBCache], vae_decoder: FluxVAEDecoder, vae_encoder: FluxVAEEncoder, load_text_encoder: bool = True, @@ -518,6 +520,8 @@ def from_pretrained( offload_mode: str | None = None, parallelism: int = 1, use_cfg_parallel: bool = False, + use_fb_cache: bool = False, + fb_cache_relative_l1_threshold: float = 0.05, ) -> "FluxImagePipeline": model_config = ( model_path_or_config @@ -561,8 +565,13 @@ def from_pretrained( vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype) vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype) + if use_fb_cache: + dit_class = FluxDiTFBCache + else: + dit_class = FluxDiT + with LoRAContext(): - dit = FluxDiT.from_state_dict( + dit = dit_class.from_state_dict( dit_state_dict, device=init_device, dtype=model_config.dit_dtype, @@ -968,6 +977,8 @@ def __call__( controlnet_params: List[ControlNetParams] | ControlNetParams = [], progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status) ): + if isinstance(self.dit, FluxDiTFBCache): + self.dit.refresh_cache_status(num_inference_steps) if not isinstance(controlnet_params, list): controlnet_params = [controlnet_params] if self.control_type != ControlType.normal: diff --git a/diffsynth_engine/pipelines/sd_image.py b/diffsynth_engine/pipelines/sd_image.py index 4ec0f1c6..f5ed7368 100644 --- a/diffsynth_engine/pipelines/sd_image.py +++ b/diffsynth_engine/pipelines/sd_image.py @@ -291,7 +291,7 @@ def predict_multicontrolnet( current_step: int, total_step: int, ): - controlnet_res_stack = 
None + controlnet_res_stack = None for param in controlnet_params: current_scale = param.scale if not ( @@ -303,15 +303,10 @@ def predict_multicontrolnet( if self.offload_mode is not None: empty_cache() param.model.to(self.device) - controlnet_res = param.model( - latents, - timestep, - prompt_emb, - param.image - ) + controlnet_res = param.model(latents, timestep, prompt_emb, param.image) controlnet_res = [res * current_scale for res in controlnet_res] if self.offload_mode is not None: - param.model.to("cpu") + param.model.to("cpu") empty_cache() controlnet_res_stack = accumulate(controlnet_res_stack, controlnet_res) return controlnet_res_stack @@ -324,16 +319,22 @@ def predict_noise_with_cfg( negative_prompt_emb: torch.Tensor, controlnet_params: List[ControlNetParams], current_step: int, - total_step: int, + total_step: int, cfg_scale: float, batch_cfg: bool = True, ): if cfg_scale <= 1.0: - return self.predict_noise(latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step) + return self.predict_noise( + latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step + ) if not batch_cfg: # cfg by predict noise one by one - positive_noise_pred = self.predict_noise(latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step) - negative_noise_pred = self.predict_noise(latents, timestep, negative_prompt_emb, controlnet_params, current_step, total_step) + positive_noise_pred = self.predict_noise( + latents, timestep, positive_prompt_emb, controlnet_params, current_step, total_step + ) + negative_noise_pred = self.predict_noise( + latents, timestep, negative_prompt_emb, controlnet_params, current_step, total_step + ) noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred) return noise_pred else: @@ -341,12 +342,16 @@ def predict_noise_with_cfg( prompt_emb = torch.cat([positive_prompt_emb, negative_prompt_emb], dim=0) latents = torch.cat([latents, latents], dim=0) timestep = torch.cat([timestep, timestep], dim=0) - positive_noise_pred, negative_noise_pred = self.predict_noise(latents, timestep, prompt_emb, controlnet_params, current_step, total_step).chunk(2) + positive_noise_pred, negative_noise_pred = self.predict_noise( + latents, timestep, prompt_emb, controlnet_params, current_step, total_step + ).chunk(2) noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred) return noise_pred def predict_noise(self, latents, timestep, prompt_emb, controlnet_params, current_step, total_step): - controlnet_res_stack = self.predict_multicontrolnet(latents, timestep, prompt_emb, controlnet_params, current_step, total_step) + controlnet_res_stack = self.predict_multicontrolnet( + latents, timestep, prompt_emb, controlnet_params, current_step, total_step + ) noise_pred = self.unet( x=latents, @@ -433,7 +438,7 @@ def __call__( cfg_scale=cfg_scale, controlnet_params=controlnet_params, current_step=i, - total_step=len(timesteps), + total_step=len(timesteps), batch_cfg=self.batch_cfg, ) # Denoise diff --git a/diffsynth_engine/pipelines/sdxl_image.py b/diffsynth_engine/pipelines/sdxl_image.py index ecc23e67..c77505fc 100644 --- a/diffsynth_engine/pipelines/sdxl_image.py +++ b/diffsynth_engine/pipelines/sdxl_image.py @@ -31,6 +31,7 @@ logger = logging.get_logger(__name__) + class SDXLLoRAConverter(LoRAStateDictConverter): def _replace_kohya_te1_key(self, key): key = key.replace("lora_te1_text_model_encoder_layers_", "encoders.") @@ -91,7 +92,7 @@ def _from_kohya(self, 
lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic else: raise ValueError(f"Unsupported key: {key}") # clip skip - te1_dict = {k: v for k, v in te1_dict.items() if not k.startswith('encoders.11')} + te1_dict = {k: v for k, v in te1_dict.items() if not k.startswith("encoders.11")} return {"unet": unet_dict, "text_encoder": te1_dict, "text_encoder_2": te2_dict} def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: @@ -279,10 +280,7 @@ def prepare_controlnet_params(self, controlnet_params: List[ControlNetParams], h condition = self.preprocess_control_image(param.image).to(device=self.device, dtype=self.dtype) results.append( ControlNetParams( - model=param.model, - scale=param.scale, - image=condition, - processor_name=param.processor_name + model=param.model, scale=param.scale, image=condition, processor_name=param.processor_name ) ) return results @@ -307,13 +305,13 @@ def predict_multicontrolnet( latents: torch.Tensor, timestep: torch.Tensor, prompt_emb: torch.Tensor, - add_text_embeds: torch.Tensor, - add_time_id: torch.Tensor, + add_text_embeds: torch.Tensor, + add_time_id: torch.Tensor, controlnet_params: List[ControlNetParams], current_step: int, total_step: int, ): - controlnet_res_stack = None + controlnet_res_stack = None for param in controlnet_params: current_scale = param.scale if not ( @@ -338,8 +336,8 @@ def predict_multicontrolnet( ) controlnet_res = [res * current_scale for res in controlnet_res] if self.offload_mode is not None: - param.model.to("cpu") - empty_cache() + param.model.to("cpu") + empty_cache() controlnet_res_stack = accumulate(controlnet_res_stack, controlnet_res) return controlnet_res_stack @@ -353,20 +351,36 @@ def predict_noise_with_cfg( negative_add_text_embeds: torch.Tensor, controlnet_params: List[ControlNetParams], current_step: int, - total_step: int, + total_step: int, add_time_id: torch.Tensor, cfg_scale: float, batch_cfg: bool = True, ): if cfg_scale <= 1.0: - return self.predict_noise(latents, timestep, positive_prompt_emb, add_time_id, controlnet_params, current_step, total_step) + return self.predict_noise( + latents, timestep, positive_prompt_emb, add_time_id, controlnet_params, current_step, total_step + ) if not batch_cfg: # cfg by predict noise one by one positive_noise_pred = self.predict_noise( - latents, timestep, positive_prompt_emb, positive_add_text_embeds, add_time_id, controlnet_params, current_step, total_step + latents, + timestep, + positive_prompt_emb, + positive_add_text_embeds, + add_time_id, + controlnet_params, + current_step, + total_step, ) negative_noise_pred = self.predict_noise( - latents, timestep, negative_prompt_emb, negative_add_text_embeds, add_time_id, controlnet_params, current_step, total_step + latents, + timestep, + negative_prompt_emb, + negative_add_text_embeds, + add_time_id, + controlnet_params, + current_step, + total_step, ) noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred) return noise_pred @@ -378,14 +392,25 @@ def predict_noise_with_cfg( latents = torch.cat([latents, latents], dim=0) timestep = torch.cat([timestep, timestep], dim=0) positive_noise_pred, negative_noise_pred = self.predict_noise( - latents, timestep, prompt_emb, add_text_embeds, add_time_ids, controlnet_params, current_step, total_step + latents, + timestep, + prompt_emb, + add_text_embeds, + add_time_ids, + controlnet_params, + current_step, + total_step, ) noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred) 
return noise_pred - def predict_noise(self, latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step): + def predict_noise( + self, latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step + ): y = self.prepare_add_embeds(add_text_embeds, add_time_id, self.dtype) - controlnet_res_stack = self.predict_multicontrolnet(latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step) + controlnet_res_stack = self.predict_multicontrolnet( + latents, timestep, prompt_emb, add_text_embeds, add_time_id, controlnet_params, current_step, total_step + ) noise_pred = self.unet( x=latents, @@ -433,7 +458,7 @@ def __call__( width: int = 1024, num_inference_steps: int = 20, seed: int | None = None, - controlnet_params: List[ControlNetParams] | ControlNetParams = [], + controlnet_params: List[ControlNetParams] | ControlNetParams = [], progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status) ): if not isinstance(controlnet_params, list): @@ -491,7 +516,7 @@ def __call__( cfg_scale=cfg_scale, controlnet_params=controlnet_params, current_step=i, - total_step=len(timesteps), + total_step=len(timesteps), batch_cfg=self.batch_cfg, ) # Denoise diff --git a/tests/test_pipelines/test_flux_bfl_image.py b/tests/test_pipelines/test_flux_bfl_image.py index 2b9d28dc..19facf53 100644 --- a/tests/test_pipelines/test_flux_bfl_image.py +++ b/tests/test_pipelines/test_flux_bfl_image.py @@ -115,7 +115,7 @@ def setUpClass(cls): "black-forest-labs/FLUX.1-Kontext-dev", revision="master", path="flux1-kontext-dev.safetensors" ) cls.pipe = FluxImagePipeline.from_pretrained(kontext_model_path, control_type=ControlType.bfl_kontext).eval() - + def test_kontext_image(self): image = self.pipe( prompt="Make the wall color to red", diff --git a/tests/test_pipelines/test_sd_controlnet.py b/tests/test_pipelines/test_sd_controlnet.py index 1d3d06bc..3bb2e49e 100644 --- a/tests/test_pipelines/test_sd_controlnet.py +++ b/tests/test_pipelines/test_sd_controlnet.py @@ -36,6 +36,6 @@ def test_canny(self): # TODO: replace image self.assertImageEqualAndSaveFailed(output_image, "flux/flux_union_pro_canny.png", threshold=0.7) - + if __name__ == "__main__": unittest.main() diff --git a/tests/test_pipelines/test_sdxl_controlnet.py b/tests/test_pipelines/test_sdxl_controlnet.py index 90bb8960..51aa3f03 100644 --- a/tests/test_pipelines/test_sdxl_controlnet.py +++ b/tests/test_pipelines/test_sdxl_controlnet.py @@ -34,5 +34,6 @@ def test_canny(self): ) self.assertImageEqualAndSaveFailed(output_image, "flux/flux_union_pro_canny.png", threshold=0.7) + if __name__ == "__main__": unittest.main() From 5578dc00ae0626cb7c0dc3fb4689687211658871 Mon Sep 17 00:00:00 2001 From: "yueyang.hyy" Date: Thu, 10 Jul 2025 10:19:15 +0800 Subject: [PATCH 2/2] fix fb cache dit init --- .../models/flux/flux_dit_fbcache.py | 13 ++++----- diffsynth_engine/pipelines/flux_image.py | 29 +++++++++++-------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/diffsynth_engine/models/flux/flux_dit_fbcache.py b/diffsynth_engine/models/flux/flux_dit_fbcache.py index da6f49af..60e58ae3 100644 --- a/diffsynth_engine/models/flux/flux_dit_fbcache.py +++ b/diffsynth_engine/models/flux/flux_dit_fbcache.py @@ -26,7 +26,7 @@ def __init__( dtype: torch.dtype = torch.bfloat16, relative_l1_threshold: float = 0.05, ): - super().__init__() + super().__init__(in_channel=in_channel, attn_impl=attn_impl, device=device, 
dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -130,7 +130,7 @@ def forward(
                 original_hidden_states = hidden_states
                 hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
                 first_hidden_states_residual = hidden_states - original_hidden_states
-                del original_hidden_states
+
                 (first_hidden_states_residual,) = sequence_parallel_unshard(
                     (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,)
                 )
@@ -151,7 +151,7 @@ def forward(
                 else:
                     self.prev_first_hidden_states_residual = first_hidden_states_residual
 
-                    ori_hidden_states = hidden_states.clone()
+                    first_hidden_states = hidden_states.clone()
                     for i, block in enumerate(self.blocks[1:]):
                         hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
                         if len(controlnet_double_block_output) > 0:
@@ -168,7 +168,7 @@ def forward(
 
                     hidden_states = hidden_states[:, prompt_emb.shape[1] :]
 
-                    previous_residual = hidden_states - ori_hidden_states
+                    previous_residual = hidden_states - first_hidden_states
                     self.previous_residual = previous_residual
 
                 hidden_states = self.final_norm_out(hidden_states, conditioning)
@@ -188,6 +188,7 @@ def from_state_dict(
         dtype: torch.dtype,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
+        relative_l1_threshold: float = 0.05,
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -196,11 +197,9 @@ def from_state_dict(
                 dtype=dtype,
                 in_channel=in_channel,
                 attn_impl=attn_impl,
+                relative_l1_threshold=relative_l1_threshold,
             )
             model = model.requires_grad_(False)  # for loading gguf
             model.load_state_dict(state_dict, assign=True)
             model.to(device=device, dtype=dtype, non_blocking=True)
             return model
-
-    def get_fsdp_modules(self):
-        return ["blocks", "single_blocks"]
diff --git a/diffsynth_engine/pipelines/flux_image.py b/diffsynth_engine/pipelines/flux_image.py
index 060a8535..06f51eb0 100644
--- a/diffsynth_engine/pipelines/flux_image.py
+++ b/diffsynth_engine/pipelines/flux_image.py
@@ -565,19 +565,24 @@ def from_pretrained(
         vae_decoder = FluxVAEDecoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
         vae_encoder = FluxVAEEncoder.from_state_dict(vae_state_dict, device=init_device, dtype=model_config.vae_dtype)
 
-        if use_fb_cache:
-            dit_class = FluxDiTFBCache
-        else:
-            dit_class = FluxDiT
-
         with LoRAContext():
-            dit = dit_class.from_state_dict(
-                dit_state_dict,
-                device=init_device,
-                dtype=model_config.dit_dtype,
-                in_channel=control_type.get_in_channel(),
-                attn_impl=model_config.dit_attn_impl,
-            )
+            if use_fb_cache:
+                dit = FluxDiTFBCache.from_state_dict(
+                    dit_state_dict,
+                    device=init_device,
+                    dtype=model_config.dit_dtype,
+                    in_channel=control_type.get_in_channel(),
+                    attn_impl=model_config.dit_attn_impl,
+                    relative_l1_threshold=fb_cache_relative_l1_threshold,
+                )
+            else:
+                dit = FluxDiT.from_state_dict(
+                    dit_state_dict,
+                    device=init_device,
+                    dtype=model_config.dit_dtype,
+                    in_channel=control_type.get_in_channel(),
+                    attn_impl=model_config.dit_attn_impl,
+                )
 
         if model_config.use_fp8_linear:
             enable_fp8_linear(dit)
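
Usage sketch (illustrative, not part of the diff): how the new FBCache switch added to FluxImagePipeline.from_pretrained might be used. The import path, checkpoint path, prompt, and the exact pipe() call signature are assumptions; the keyword names follow the from_pretrained signature introduced in this patch.

    from diffsynth_engine.pipelines.flux_image import FluxImagePipeline  # assumed import path

    # use_fb_cache=True swaps FluxDiT for FluxDiTFBCache; a step whose first-block
    # residual changes by less than the relative-L1 threshold reuses the cached
    # residual instead of running the remaining double/single blocks.
    pipe = FluxImagePipeline.from_pretrained(
        "models/flux1-dev.safetensors",  # hypothetical local checkpoint
        use_fb_cache=True,
        fb_cache_relative_l1_threshold=0.05,  # larger values skip more steps at some quality cost
    )
    image = pipe(prompt="a lighthouse at dusk", num_inference_steps=20)
    image.save("output.png")

For intuition, the skip criterion in isolation, a self-contained sketch mirroring is_relative_l1_below_threshold above (tensor shapes are illustrative only):

    import torch

    def relative_l1_change(prev_residual: torch.Tensor, residual: torch.Tensor) -> float:
        # mean |residual - prev_residual|, normalized by mean |prev_residual|
        return ((residual - prev_residual).abs().mean() / prev_residual.abs().mean()).item()

    prev = torch.randn(1, 4096, 3072)            # previous step's first-block residual
    curr = prev + 0.01 * torch.randn_like(prev)  # nearly unchanged this step
    print(relative_l1_change(prev, curr) < 0.05)  # True: remaining blocks could be skipped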