Skip to content

Commit 8bc0c80

Browse files
authored
Refactor/parallel (#157)
* compile repeated blocks
* pin memory
* refactor parallelism & model init
* fix
* fix cycle import
* fix qwen image
* fix examples
* check wan2.2 image shape
* fix lora
1 parent 6047ee1 commit 8bc0c80

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

54 files changed

+716
-721
lines changed

diffsynth_engine/__init__.py

Lines changed: 20 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -4,24 +4,33 @@
44
FluxPipelineConfig,
55
WanPipelineConfig,
66
QwenImagePipelineConfig,
7+
HunyuanPipelineConfig,
78
SDStateDicts,
89
SDXLStateDicts,
910
FluxStateDicts,
11+
WanStateDicts,
1012
QwenImageStateDicts,
1113
ControlNetParams,
1214
ControlType,
1315
)
1416
from .pipelines import (
15-
FluxImagePipeline,
16-
SDXLImagePipeline,
1717
SDImagePipeline,
18+
SDXLImagePipeline,
19+
FluxImagePipeline,
1820
WanVideoPipeline,
1921
QwenImagePipeline,
2022
Hunyuan3DShapePipeline,
2123
)
2224
from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
2325
from .models.sd import SDControlNet
2426
from .models.sdxl import SDXLControlNetUnion
27+
from .tools import (
28+
FluxInpaintingTool,
29+
FluxOutpaintingTool,
30+
FluxIPAdapterRefTool,
31+
FluxReduxRefTool,
32+
FluxReplaceByControlTool,
33+
)
2534
from .utils.download import (
2635
fetch_model,
2736
fetch_modelscope_model,
@@ -30,32 +39,29 @@
3039
reset_fetch_modelscope_model,
3140
)
3241
from .utils.video import load_video, save_video
33-
from .tools import (
34-
FluxInpaintingTool,
35-
FluxOutpaintingTool,
36-
FluxIPAdapterRefTool,
37-
FluxReduxRefTool,
38-
FluxReplaceByControlTool,
39-
)
4042

4143
__all__ = [
4244
"SDPipelineConfig",
4345
"SDXLPipelineConfig",
4446
"FluxPipelineConfig",
4547
"WanPipelineConfig",
48+
"QwenImagePipelineConfig",
49+
"HunyuanPipelineConfig",
4650
"SDStateDicts",
4751
"SDXLStateDicts",
4852
"FluxStateDicts",
53+
"WanStateDicts",
4954
"QwenImageStateDicts",
55+
"ControlNetParams",
56+
"ControlType",
57+
"SDImagePipeline",
58+
"SDControlNet",
59+
"SDXLImagePipeline",
60+
"SDXLControlNetUnion",
5061
"FluxImagePipeline",
51-
"QwenImagePipelineConfig",
5262
"FluxControlNet",
5363
"FluxIPAdapter",
5464
"FluxRedux",
55-
"SDControlNet",
56-
"SDXLControlNetUnion",
57-
"SDXLImagePipeline",
58-
"SDImagePipeline",
5965
"WanVideoPipeline",
6066
"QwenImagePipeline",
6167
"Hunyuan3DShapePipeline",
@@ -64,8 +70,6 @@
6470
"FluxIPAdapterRefTool",
6571
"FluxReplaceByControlTool",
6672
"FluxReduxRefTool",
67-
"ControlNetParams",
68-
"ControlType",
6973
"fetch_model",
7074
"fetch_modelscope_model",
7175
"register_fetch_modelscope_model",

diffsynth_engine/configs/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,12 @@
88
FluxPipelineConfig,
99
WanPipelineConfig,
1010
QwenImagePipelineConfig,
11+
HunyuanPipelineConfig,
1112
BaseStateDicts,
1213
SDStateDicts,
1314
SDXLStateDicts,
1415
FluxStateDicts,
16+
WanStateDicts,
1517
QwenImageStateDicts,
1618
)
1719
from .controlnet import ControlType, ControlNetParams
@@ -26,11 +28,13 @@
2628
"FluxPipelineConfig",
2729
"WanPipelineConfig",
2830
"QwenImagePipelineConfig",
29-
"ControlType",
30-
"ControlNetParams",
31+
"HunyuanPipelineConfig",
3132
"BaseStateDicts",
3233
"SDStateDicts",
3334
"SDXLStateDicts",
3435
"FluxStateDicts",
36+
"WanStateDicts",
3537
"QwenImageStateDicts",
38+
"ControlType",
39+
"ControlNetParams",
3640
]

diffsynth_engine/configs/pipeline.py

Lines changed: 43 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import os
22
import torch
33
from dataclasses import dataclass, field
4-
from typing import List, Tuple, Optional, Dict
4+
from typing import List, Dict, Tuple, Optional
55

66
from diffsynth_engine.configs.controlnet import ControlType
77

@@ -127,7 +127,7 @@ def basic_config(
127127
model_path=model_path,
128128
device=device,
129129
parallelism=parallelism,
130-
use_fsdp=True,
130+
use_fsdp=True if parallelism > 1 else False,
131131
offload_mode=offload_mode,
132132
offload_to_disk=offload_to_disk,
133133
)
@@ -174,8 +174,8 @@ def basic_config(
174174
image_encoder_path=image_encoder_path,
175175
device=device,
176176
parallelism=parallelism,
177-
use_cfg_parallel=True,
178-
use_fsdp=True,
177+
use_cfg_parallel=True if parallelism > 1 else False,
178+
use_fsdp=True if parallelism > 1 else False,
179179
offload_mode=offload_mode,
180180
offload_to_disk=offload_to_disk,
181181
)
@@ -184,16 +184,6 @@ def __post_init__(self):
184184
init_parallel_config(self)
185185

186186

187-
@dataclass
188-
class HunyuanPipelineConfig(BaseConfig):
189-
model_path: str | os.PathLike | List[str | os.PathLike]
190-
model_dtype: torch.dtype = torch.float16
191-
vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
192-
vae_dtype: torch.dtype = torch.float16
193-
image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
194-
image_encoder_dtype: torch.dtype = torch.float16
195-
196-
197187
@dataclass
198188
class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
199189
model_path: str | os.PathLike | List[str | os.PathLike]
@@ -228,8 +218,8 @@ def basic_config(
228218
encoder_path=encoder_path,
229219
vae_path=vae_path,
230220
parallelism=parallelism,
231-
use_cfg_parallel=True,
232-
use_fsdp=True,
221+
use_cfg_parallel=True if parallelism > 1 else False,
222+
use_fsdp=True if parallelism > 1 else False,
233223
offload_mode=offload_mode,
234224
offload_to_disk=offload_to_disk,
235225
)
@@ -238,32 +228,57 @@ def __post_init__(self):
238228
init_parallel_config(self)
239229

240230

231+
@dataclass
232+
class HunyuanPipelineConfig(BaseConfig):
233+
model_path: str | os.PathLike | List[str | os.PathLike]
234+
model_dtype: torch.dtype = torch.float16
235+
vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
236+
vae_dtype: torch.dtype = torch.float16
237+
image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
238+
image_encoder_dtype: torch.dtype = torch.float16
239+
240+
241241
@dataclass
242242
class BaseStateDicts:
243-
model: Optional[Dict[str, torch.Tensor]] = None
244-
vae: Optional[Dict[str, torch.Tensor]] = None
243+
pass
244+
245+
246+
@dataclass
247+
class SDStateDicts:
248+
model: Dict[str, torch.Tensor]
249+
clip: Dict[str, torch.Tensor]
250+
vae: Dict[str, torch.Tensor]
245251

246252

247253
@dataclass
248-
class SDStateDicts(BaseStateDicts):
249-
clip: Optional[Dict[str, torch.Tensor]] = None
254+
class SDXLStateDicts:
255+
model: Dict[str, torch.Tensor]
256+
clip_l: Dict[str, torch.Tensor]
257+
clip_g: Dict[str, torch.Tensor]
258+
vae: Dict[str, torch.Tensor]
250259

251260

252261
@dataclass
253-
class SDXLStateDicts(BaseStateDicts):
254-
clip_l: Optional[Dict[str, torch.Tensor]] = None
255-
clip_g: Optional[Dict[str, torch.Tensor]] = None
262+
class FluxStateDicts:
263+
model: Dict[str, torch.Tensor]
264+
t5: Dict[str, torch.Tensor]
265+
clip: Dict[str, torch.Tensor]
266+
vae: Dict[str, torch.Tensor]
256267

257268

258269
@dataclass
259-
class FluxStateDicts(BaseStateDicts):
260-
t5: Optional[Dict[str, torch.Tensor]] = None
261-
clip: Optional[Dict[str, torch.Tensor]] = None
270+
class WanStateDicts:
271+
model: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]
272+
t5: Dict[str, torch.Tensor]
273+
vae: Dict[str, torch.Tensor]
274+
image_encoder: Optional[Dict[str, torch.Tensor]] = None
262275

263276

264277
@dataclass
265-
class QwenImageStateDicts(BaseStateDicts):
266-
encoder: Optional[Dict[str, torch.Tensor]] = None
278+
class QwenImageStateDicts:
279+
model: Dict[str, torch.Tensor]
280+
encoder: Dict[str, torch.Tensor]
281+
vae: Dict[str, torch.Tensor]
267282

268283

269284
def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig):

diffsynth_engine/models/base.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
from typing import Dict, Union, List, Any
55
from diffsynth_engine.utils.loader import load_file
66
from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
7-
from diffsynth_engine.models.utils import no_init_weights
87

98

109
class StateDictConverter:
@@ -33,10 +32,9 @@ def from_pretrained(
3332

3433
@classmethod
3534
def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype, **kwargs):
36-
with no_init_weights():
37-
model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype, **kwargs)
38-
model.to_empty(device=device)
39-
model.load_state_dict(state_dict)
35+
model = cls(device="meta", dtype=dtype, **kwargs)
36+
model.requires_grad_(False)
37+
model.load_state_dict(state_dict, assign=True)
4038
model.to(device=device, dtype=dtype, non_blocking=True)
4139
return model
4240

diffsynth_engine/models/basic/lora.py

Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -74,14 +74,13 @@ def __init__(
7474

7575
@staticmethod
7676
def from_linear(linear: nn.Linear):
77-
lora_linear = torch.nn.utils.skip_init(
78-
LoRALinear,
77+
lora_linear = LoRALinear(
7978
linear.in_features,
8079
linear.out_features,
8180
linear.bias is not None,
82-
device=linear.weight.device,
81+
device="meta",
8382
dtype=linear.weight.dtype,
84-
)
83+
).to_empty(device=linear.weight.device)
8584
lora_linear.weight = linear.weight
8685
lora_linear.bias = linear.bias
8786
return lora_linear
@@ -98,12 +97,20 @@ def add_lora(
9897
dtype: torch.dtype,
9998
**kwargs,
10099
):
101-
up_linear = torch.nn.utils.skip_init(
102-
nn.Linear, up.shape[1], up.shape[0], bias=False, device=device, dtype=dtype
103-
)
104-
down_linear = torch.nn.utils.skip_init(
105-
nn.Linear, down.shape[0], down.shape[1], bias=False, device=device, dtype=dtype
106-
)
100+
up_linear = nn.Linear(
101+
up.shape[1],
102+
up.shape[0],
103+
bias=False,
104+
device="meta",
105+
dtype=dtype,
106+
).to_empty(device=device)
107+
down_linear = nn.Linear(
108+
down.shape[0],
109+
down.shape[1],
110+
bias=False,
111+
device="meta",
112+
dtype=dtype,
113+
).to_empty(device=device)
107114
up_linear.weight.data = up
108115
down_linear.weight.data = down
109116
lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
@@ -182,8 +189,7 @@ def __init__(
182189

183190
@staticmethod
184191
def from_conv2d(conv2d: nn.Conv2d):
185-
lora_conv2d = torch.nn.utils.skip_init(
186-
LoRAConv2d,
192+
lora_conv2d = LoRAConv2d(
187193
conv2d.in_channels,
188194
conv2d.out_channels,
189195
conv2d.kernel_size,
@@ -193,9 +199,9 @@ def from_conv2d(conv2d: nn.Conv2d):
193199
conv2d.groups,
194200
conv2d.bias is not None,
195201
conv2d.padding_mode,
196-
device=conv2d.weight.device,
202+
device="meta",
197203
dtype=conv2d.weight.dtype,
198-
)
204+
).to_empty(device=conv2d.weight.device)
199205
lora_conv2d.weight = conv2d.weight
200206
lora_conv2d.bias = conv2d.bias
201207
return lora_conv2d
@@ -211,31 +217,29 @@ def _construct_lora(
211217
device: str,
212218
dtype: torch.dtype,
213219
):
214-
down_conv = torch.nn.utils.skip_init(
215-
nn.Conv2d,
220+
down_conv = nn.Conv2d(
216221
self.in_channels,
217222
rank,
218223
kernel_size=self.kernel_size,
219224
stride=self.stride,
220225
padding=self.padding,
221226
bias=False,
222-
device=device,
227+
device="meta",
223228
dtype=dtype,
224-
)
229+
).to_empty(device=device)
225230
down_conv.weight.data = down
226231
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
227232
# see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
228233
# refer from diffusers
229-
up_conv = torch.nn.utils.skip_init(
230-
nn.Conv2d,
234+
up_conv = nn.Conv2d(
231235
rank,
232236
self.out_channels,
233237
kernel_size=(1, 1),
234238
stride=(1, 1),
235239
bias=False,
236-
device=device,
240+
device="meta",
237241
dtype=dtype,
238-
)
242+
).to_empty(device=device)
239243
up_conv.weight.data = up
240244

241245
lora = LoRA(scale, rank, alpha, up_conv, down_conv, device, dtype)

diffsynth_engine/models/flux/flux_controlnet.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
RoPEEmbedding,
99
TimestepEmbeddings,
1010
)
11-
from diffsynth_engine.models.utils import no_init_weights
1211

1312

1413
class FluxControlNetStateDictConverter(StateDictConverter):
@@ -164,10 +163,13 @@ def from_state_dict(
164163
else:
165164
condition_channels = 64
166165

167-
with no_init_weights():
168-
model = torch.nn.utils.skip_init(
169-
cls, condition_channels=condition_channels, attn_kwargs=attn_kwargs, device=device, dtype=dtype
170-
)
171-
model.load_state_dict(state_dict)
166+
model = cls(
167+
condition_channels=condition_channels,
168+
attn_kwargs=attn_kwargs,
169+
device="meta",
170+
dtype=dtype,
171+
)
172+
model.requires_grad_(False)
173+
model.load_state_dict(state_dict, assign=True)
172174
model.to(device=device, dtype=dtype, non_blocking=True)
173175
return model

0 commit comments

Comments
 (0)