
Commit 9a5cdd9

feat: support FnB0 for z-image w/ cp (#503)
* feat: support FnB0 for z-image w/ cp
1 parent 879e85c commit 9a5cdd9

3 files changed: +95 -47 lines changed
examples/parallelism/run_zimage_cp.py

Lines changed: 0 additions & 2 deletions

@@ -58,8 +58,6 @@
     if args.cache:
         # Only warmup 4 steps (total 9 steps) for distilled models
        args.max_warmup_steps = min(4, args.max_warmup_steps)
-        # Temp workaround for issue: https://github.com/vipshop/cache-dit/issues/498
-        args.Bn = max(1, args.Bn)
 
     cachify(args, pipe)
 
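For context, a minimal, self-contained sketch of what this hunk changes; the CLI flags below are hypothetical stand-ins for the example's real argument parser, and `Bn` follows DBCache's naming for trailing compute blocks (FnB0 means `Bn == 0`):

import argparse

# Hypothetical stand-in for the example's CLI; flag names are assumptions.
parser = argparse.ArgumentParser()
parser.add_argument("--cache", action="store_true")
parser.add_argument("--Bn", type=int, default=0)  # DBCache trailing compute blocks
parser.add_argument("--max_warmup_steps", type=int, default=8)
args = parser.parse_args(["--cache", "--Bn", "0"])

if args.cache:
    # Only warmup 4 steps (total 9 steps) for distilled models.
    args.max_warmup_steps = min(4, args.max_warmup_steps)
    # Removed by this commit: the clamp below forced Bn >= 1 under context
    # parallelism, which made FnB0 (Bn == 0) unreachable.
    # args.Bn = max(1, args.Bn)

assert args.Bn == 0  # FnB0 now reaches cachify() unchanged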
src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_zimage.py

Lines changed: 12 additions & 44 deletions

@@ -1,5 +1,5 @@
 import torch
-from typing import Optional, Union, List
+from typing import Optional
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers import ZImageTransformer2DModel
 
@@ -19,6 +19,7 @@
     ContextParallelismPlanner,
     ContextParallelismPlannerRegister,
 )
+from ..utils import maybe_patch_cp_find_submodule_by_name
 
 from cache_dit.logger import init_logger
 
@@ -44,12 +45,18 @@ def apply(
         if transformer._cp_plan is not None:
             return transformer._cp_plan
 
+        # NOTE: This is only a temporary workaround for ZImage to make context
+        # parallelism work with DBCache FnB0. The better way is to make DBCache
+        # fully compatible with diffusers native context parallelism, e.g., check
+        # the split/gather hooks in each block/layer in the initialization of DBCache.
+        # Issue: https://github.com/vipshop/cache-dit/issues/498
+        maybe_patch_cp_find_submodule_by_name()
         # Otherwise, use the custom CP plan defined here; this may be a little
         # different from the native diffusers implementation for some models.
         n_noise_refiner_layers = len(transformer.noise_refiner)  # 2
         n_context_refiner_layers = len(transformer.context_refiner)  # 2
-        num_layers = len(transformer.layers)  # 30
+        # num_layers = len(transformer.layers)  # 30
         _cp_plan = {
             # 0. Hooks for noise_refiner layers, 2
             "noise_refiner.0": {
@@ -78,50 +85,11 @@ def apply(
             "layers.*": {
                 "freqs_cis": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
             },
+            # NEED: call maybe_patch_cp_find_submodule_by_name to support ModuleDict like 'all_final_layer'
+            "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
             # NOTE: The 'all_final_layer' is a ModuleDict of several final layers,
             # each for a specific patch size combination, so we do not add hooks for it here.
             # So, we have to gather the output of the last transformer layer.
-            # "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
-            f"layers.{num_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3),
+            # f"layers.{num_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3),
         }
         return _cp_plan
-
-
-# TODO: Add this utility function to diffusers to support ModuleDict, such as 'all_final_layer' in ZImage
-# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/hooks/context_parallel.py#L283
-def _find_submodule_by_name(
-    model: torch.nn.Module, name: str
-) -> Union[torch.nn.Module, List[torch.nn.Module]]:
-    if name == "":
-        return model
-    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
-    if first_atom == "*":
-        if not isinstance(model, torch.nn.ModuleList):
-            raise ValueError("Wildcard '*' can only be used with ModuleList")
-        submodules = []
-        for submodule in model:
-            subsubmodules = _find_submodule_by_name(submodule, remaining_name)
-            if not isinstance(subsubmodules, list):
-                if isinstance(subsubmodules, torch.nn.ModuleDict):
-                    subsubmodules = list(subsubmodules.values())
-                else:
-                    subsubmodules = [subsubmodules]
-            submodules.extend(subsubmodules)
-        return submodules
-    else:
-        if hasattr(model, first_atom):
-            submodule = getattr(model, first_atom)
-            if isinstance(submodule, torch.nn.ModuleDict):
-                if remaining_name == "":
-                    return list(submodule.values())
-                else:
-                    raise ValueError(
-                        f"Cannot access submodule '{remaining_name}' of ModuleDict '{first_atom}' directly. "
-                        f"Please specify the key of the ModuleDict first."
-                    )
-            return _find_submodule_by_name(submodule, remaining_name)
-        else:
-            raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
-
-
-# TODO: Add async Ulysses QKV proj for ZImage model
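For illustration, a minimal sketch of what the patched resolver returns for a ModuleDict plan key such as 'all_final_layer'. The toy module is hypothetical, the absolute import path of the utils module is assumed from the relative `..utils` import above, and a diffusers build that exposes `diffusers.hooks.context_parallel` is assumed:

import torch
import diffusers.hooks.context_parallel as cp_hooks

# Assumed absolute path of the '..utils' module imported in the hunk above.
from cache_dit.parallelism.backends.native_diffusers.utils import (
    maybe_patch_cp_find_submodule_by_name,
)


class ToyZImage(torch.nn.Module):
    # Hypothetical stand-in: 'all_final_layer' mirrors ZImage's ModuleDict of
    # per-patch-size final layers.
    def __init__(self):
        super().__init__()
        self.all_final_layer = torch.nn.ModuleDict(
            {"1x1": torch.nn.Linear(8, 8), "2x2": torch.nn.Linear(8, 8)}
        )


maybe_patch_cp_find_submodule_by_name()

# The stock resolver stops at the ModuleDict module itself; the patched one
# returns its values, so the ContextParallelOutput gather hook can land on
# every final layer instead of only the last transformer block.
modules = cp_hooks._find_submodule_by_name(ToyZImage(), "all_final_layer")
assert isinstance(modules, list) and len(modules) == 2

This is what lets the plan entry "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3) replace the previous gather on f"layers.{num_layers - 1}".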
src/cache_dit/parallelism/backends/native_diffusers/utils.py

Lines changed: 83 additions & 1 deletion

@@ -1,11 +1,93 @@
+import torch
+import functools
+import diffusers
+from typing import List, Union
+
 try:
-    from diffusers import ContextParallelConfig
+    from diffusers import ContextParallelConfig  # noqa: F401
+    from diffusers.hooks.context_parallel import (
+        _find_submodule_by_name as _find_submodule_by_name_for_context_parallel,
+    )
 
     def native_diffusers_parallelism_available() -> bool:
         return True
 
 except ImportError:
     ContextParallelConfig = None
+    _find_submodule_by_name_for_context_parallel = None
 
     def native_diffusers_parallelism_available() -> bool:
         return False
+
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+# NOTE: Add this utility function to diffusers to support ModuleDict, such as
+# 'all_final_layer' in ZImage.
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/hooks/context_parallel.py#L283
+# This function is only used when diffusers native context parallelism is
+# enabled, and it is compatible with the original one.
+if (
+    native_diffusers_parallelism_available()
+    and _find_submodule_by_name_for_context_parallel is not None
+):
+
+    @functools.wraps(_find_submodule_by_name_for_context_parallel)
+    def _patch_find_submodule_by_name(
+        model: torch.nn.Module, name: str
+    ) -> Union[torch.nn.Module, List[torch.nn.Module]]:
+        if name == "":
+            return model
+        first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
+        if first_atom == "*":
+            if not isinstance(model, torch.nn.ModuleList):
+                raise ValueError("Wildcard '*' can only be used with ModuleList")
+            submodules = []
+            for submodule in model:
+                subsubmodules = _patch_find_submodule_by_name(submodule, remaining_name)
+                if not isinstance(subsubmodules, list):
+                    if isinstance(subsubmodules, torch.nn.ModuleDict):
+                        subsubmodules = list(subsubmodules.values())
+                    else:
+                        subsubmodules = [subsubmodules]
+                submodules.extend(subsubmodules)
+            return submodules
+        else:
+            if hasattr(model, first_atom):
+                submodule = getattr(model, first_atom)
+                if isinstance(submodule, torch.nn.ModuleDict):  # e.g., 'all_final_layer' in ZImage
+                    if remaining_name == "":
+                        submodule = list(submodule.values())
+                        # Make sure all values are Modules; other, more complex cases are not supported.
+                        for v in submodule:
+                            if not isinstance(v, torch.nn.Module):
+                                raise ValueError(
+                                    f"Value '{v}' in ModuleDict '{first_atom}' is not a Module"
+                                )
+                        return submodule
+                    else:
+                        raise ValueError(
+                            f"Cannot access submodule '{remaining_name}' of ModuleDict '{first_atom}' directly. "
+                            f"Please specify the key of the ModuleDict first."
+                        )
+                return _patch_find_submodule_by_name(submodule, remaining_name)
+            else:
+                raise ValueError(
+                    f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'"
+                )
+
+    def maybe_patch_cp_find_submodule_by_name():
+        if (
+            diffusers.hooks.context_parallel._find_submodule_by_name
+            != _patch_find_submodule_by_name
+        ):
+            diffusers.hooks.context_parallel._find_submodule_by_name = _patch_find_submodule_by_name
+            logger.info(
+                "Patched diffusers.hooks.context_parallel._find_submodule_by_name for ModuleDict support."
+            )
+
+else:

+    def maybe_patch_cp_find_submodule_by_name():
+        pass
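To make the lookup contract concrete, a small sketch (the toy modules and the import path are assumptions, as above) of the three path forms the patched resolver handles, plus the idempotent patching:

import torch
import diffusers.hooks.context_parallel as cp_hooks

from cache_dit.parallelism.backends.native_diffusers.utils import (  # assumed path
    maybe_patch_cp_find_submodule_by_name,
)


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = torch.nn.Linear(4, 4)


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([Block(), Block()])
        self.heads = torch.nn.ModuleDict({"a": torch.nn.Linear(4, 4)})


maybe_patch_cp_find_submodule_by_name()
maybe_patch_cp_find_submodule_by_name()  # second call is a no-op (already patched)

find = cp_hooks._find_submodule_by_name
m = Toy()
assert isinstance(find(m, "layers.0"), Block)  # plain attribute path
assert len(find(m, "layers.*.attn")) == 2      # '*' fans out over a ModuleList
assert len(find(m, "heads")) == 1              # trailing ModuleDict -> its values
try:
    find(m, "heads.a")  # traversing *through* a ModuleDict is rejected
except ValueError:
    pass

The functools.wraps decorator keeps the original function's metadata on the replacement, and the equality guard in maybe_patch_cp_find_submodule_by_name makes repeated calls (e.g., one per planner invocation) safe.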
