@@ -1,5 +1,5 @@
 import torch
-from typing import Optional, Union, List
+from typing import Optional
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers import ZImageTransformer2DModel
 
@@ -19,6 +19,7 @@
     ContextParallelismPlanner,
     ContextParallelismPlannerRegister,
 )
+from ..utils import maybe_patch_cp_find_submodule_by_name
 
 from cache_dit.logger import init_logger
 
@@ -44,12 +45,18 @@ def apply(
         if transformer._cp_plan is not None:
             return transformer._cp_plan
 
+        # NOTE: This is only a temporary workaround to make ZImage context parallelism
+        # compatible with DBCache FnB0. The better way is to make DBCache fully
+        # compatible with diffusers' native context parallelism, e.g., by checking
+        # the split/gather hooks on each block/layer during DBCache initialization.
+        # Issue: https://github.com/vipshop/cache-dit/issues/498
+        maybe_patch_cp_find_submodule_by_name()
         # Otherwise, use the custom CP plan defined here; this may be
         # a little different from the native diffusers implementation
         # for some models.
         n_noise_refiner_layers = len(transformer.noise_refiner)  # 2
         n_context_refiner_layers = len(transformer.context_refiner)  # 2
-        num_layers = len(transformer.layers)  # 30
+        # num_layers = len(transformer.layers)  # 30
         _cp_plan = {
             # 0. Hooks for noise_refiner layers, 2
             "noise_refiner.0": {
@@ -78,50 +85,11 @@ def apply(
             "layers.*": {
                 "freqs_cis": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
             },
+            # NEED: call maybe_patch_cp_find_submodule_by_name to support a ModuleDict like 'all_final_layer'
+            "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
             # NOTE: 'all_final_layer' is a ModuleDict of several final layers,
             # each for a specific patch-size combination, so we do not add hooks
             # for it here; instead, we have to gather the output of the last transformer layer.
-            # "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
-            f"layers.{num_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3),
+            # f"layers.{num_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3),
         }
         return _cp_plan
-
-
-# TODO: Add this utility function to diffusers to support ModuleDict, such as 'all_final_layer' in ZImage
-# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/hooks/context_parallel.py#L283
-def _find_submodule_by_name(
-    model: torch.nn.Module, name: str
-) -> Union[torch.nn.Module, List[torch.nn.Module]]:
-    if name == "":
-        return model
-    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
-    if first_atom == "*":
-        if not isinstance(model, torch.nn.ModuleList):
-            raise ValueError("Wildcard '*' can only be used with ModuleList")
-        submodules = []
-        for submodule in model:
-            subsubmodules = _find_submodule_by_name(submodule, remaining_name)
-            if not isinstance(subsubmodules, list):
-                if isinstance(subsubmodules, torch.nn.ModuleDict):
-                    subsubmodules = list(subsubmodules.values())
-                else:
-                    subsubmodules = [subsubmodules]
-            submodules.extend(subsubmodules)
-        return submodules
-    else:
-        if hasattr(model, first_atom):
-            submodule = getattr(model, first_atom)
-            if isinstance(submodule, torch.nn.ModuleDict):
-                if remaining_name == "":
-                    return list(submodule.values())
-                else:
-                    raise ValueError(
-                        f"Cannot access submodule '{remaining_name}' of ModuleDict '{first_atom}' directly. "
-                        f"Please specify the key of the ModuleDict first."
-                    )
-            return _find_submodule_by_name(submodule, remaining_name)
-        else:
-            raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
-
-
-# TODO: Add async Ulysses QKV proj for ZImage model
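
For reference, a minimal sketch of what `maybe_patch_cp_find_submodule_by_name` (now imported from `..utils`) might look like. This is an assumption reconstructed from the helper removed above, not the actual cache-dit implementation: it idempotently rebinds diffusers' private `_find_submodule_by_name` in `diffusers/hooks/context_parallel.py` to a ModuleDict-aware variant, so a plan key like "all_final_layer" resolves to all values of that ModuleDict.

# Hedged sketch (an assumption, not the actual cache-dit utils code).
from typing import List, Union

import torch
from diffusers.hooks import context_parallel

_CP_FIND_PATCHED = False


def _find_submodule_by_name_moduledict(
    model: torch.nn.Module, name: str
) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    # Same traversal as diffusers' helper, but a ModuleDict (e.g. ZImage's
    # 'all_final_layer') is expanded into the list of its values.
    if name == "":
        return model
    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
    if first_atom == "*":
        if not isinstance(model, torch.nn.ModuleList):
            raise ValueError("Wildcard '*' can only be used with ModuleList")
        submodules: List[torch.nn.Module] = []
        for submodule in model:
            found = _find_submodule_by_name_moduledict(submodule, remaining_name)
            if isinstance(found, torch.nn.ModuleDict):
                found = list(found.values())
            submodules.extend(found if isinstance(found, list) else [found])
        return submodules
    if not hasattr(model, first_atom):
        raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
    submodule = getattr(model, first_atom)
    if isinstance(submodule, torch.nn.ModuleDict):
        if remaining_name == "":
            return list(submodule.values())
        raise ValueError(
            f"Cannot access submodule '{remaining_name}' of ModuleDict '{first_atom}' directly. "
            f"Please specify the key of the ModuleDict first."
        )
    return _find_submodule_by_name_moduledict(submodule, remaining_name)


def maybe_patch_cp_find_submodule_by_name() -> None:
    # Idempotent monkey-patch; a no-op after the first call. Rebinding the
    # module attribute is sufficient because the hook-registration code in
    # that module resolves the name through module globals at call time.
    global _CP_FIND_PATCHED
    if _CP_FIND_PATCHED:
        return
    context_parallel._find_submodule_by_name = _find_submodule_by_name_moduledict
    _CP_FIND_PATCHED = True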
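
The plan entries follow diffusers' context-parallel convention: a `ContextParallelInput(split_dim=1, ...)` hook shards a tensor's sequence dimension across ranks before the module runs, and a `ContextParallelOutput(gather_dim=1, ...)` hook all-gathers it afterwards, which is why hooking "all_final_layer" replaces the old gather on the last transformer layer. A conceptual sketch of those two collectives is shown below; the helper names are hypothetical, not diffusers APIs, and it assumes an initialized process group with a sequence length that divides evenly across ranks.

import torch
import torch.distributed as dist


def split_for_rank(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # What a ContextParallelInput(split_dim=dim) hook effectively does:
    # keep only this rank's shard of the sequence dimension.
    world, rank = dist.get_world_size(), dist.get_rank()
    return x.chunk(world, dim=dim)[rank].contiguous()


def gather_from_ranks(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # What a ContextParallelOutput(gather_dim=dim) hook effectively does:
    # all-gather every rank's shard (all shards must share one shape) and
    # concatenate them back along the sequence dimension.
    world = dist.get_world_size()
    parts = [torch.empty_like(x) for _ in range(world)]
    dist.all_gather(parts, x)
    return torch.cat(parts, dim=dim)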