6 changes: 4 additions & 2 deletions diffsynth_engine/models/flux/flux_vae.py
@@ -53,7 +53,8 @@ def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
     def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         with no_init_weights():
             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
-        model.load_state_dict(state_dict)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
         return model


@@ -74,5 +75,6 @@ def __init__(self, device: str, dtype: torch.dtype = torch.float32):
     def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         with no_init_weights():
             model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
-        model.load_state_dict(state_dict)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
         return model
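Note: both VAE loaders now pass assign=True so that load_state_dict rebinds the skip-initialized parameters to the checkpoint tensors instead of copying into freshly allocated storage, and the follow-up non_blocking .to() moves them to the target device asynchronously. A minimal sketch of the skip_init + assign=True interaction (Tiny and the tensors are illustrative, not from this repo):

import torch

class Tiny(torch.nn.Module):
    # skip_init requires the module to accept and forward a `device` kwarg
    def __init__(self, device=None, dtype=None):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4, device=device, dtype=dtype)

# Parameters are allocated but deliberately left uninitialized
model = torch.nn.utils.skip_init(Tiny)
state_dict = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.zeros(4)}

# assign=True (PyTorch >= 2.1) rebinds each parameter to the checkpoint
# tensor instead of copying into the uninitialized storage
model.load_state_dict(state_dict, assign=True)
assert model.linear.weight.data_ptr() == state_dict["linear.weight"].data_ptr()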
13 changes: 10 additions & 3 deletions diffsynth_engine/pipelines/base.py
@@ -43,6 +43,7 @@ def __init__(
         self.dtype = dtype
         self.offload_mode = None
         self.model_names = []
+        self._models_offload_params = {}

     @classmethod
     def from_pretrained(
@@ -288,6 +289,10 @@ def _enable_model_cpu_offload(self):
             model = getattr(self, model_name)
             if model is not None:
                 model.to("cpu")
+                self._models_offload_params[model_name] = {}
+                for name, param in model.named_parameters(recurse=True):
+                    param.data = param.data.pin_memory()
+                    self._models_offload_params[model_name][name] = param.data
         self.offload_mode = "cpu_offload"

     def _enable_sequential_cpu_offload(self):
@@ -321,12 +326,14 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
         for model_name in self.model_names:
             if model_name not in load_model_names:
                 model = getattr(self, model_name)
-                if model is not None and (p := next(model.parameters(), None)) is not None and p.device != "cpu":
-                    model.to("cpu")
+                if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device("cpu"):
+                    param_cache = self._models_offload_params[model_name]
+                    for name, param in model.named_parameters(recurse=True):
+                        param.data = param_cache[name]
         # load the needed models to device
         for model_name in load_model_names:
             model = getattr(self, model_name)
-            if model is not None and (p := next(model.parameters(), None)) is not None and p.device != self.device:
+            if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device(self.device):
                 model.to(self.device)
         # fresh the cuda cache
         empty_cache()
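Note: the cpu-offload path now keeps a per-model cache of pinned (page-locked) CPU tensors, so offloading a model is just rebinding each param.data to its cached pinned copy, with no device-to-host transfer, while onloading can use an asynchronous host-to-device copy. A minimal sketch of that round trip (assumes a CUDA device; names are illustrative):

import torch

param = torch.nn.Parameter(torch.randn(1024, 1024))

# Pin once: page-locked host memory enables async host-to-device copies
cpu_cache = param.data.pin_memory()

# Onload: asynchronous copy from the pinned cache to the GPU
param.data = cpu_cache.to("cuda", non_blocking=True)

# ... run the forward pass ...

# Offload: no copy back, just rebind to the cached pinned tensor
param.data = cpu_cache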
3 changes: 2 additions & 1 deletion diffsynth_engine/pipelines/flux_image.py
@@ -604,7 +604,8 @@ def from_pretrained(
             device=device,
             dtype=model_config.dit_dtype,
         )
-        pipe.enable_cpu_offload(offload_mode)
+        if offload_mode is not None:
+            pipe.enable_cpu_offload(offload_mode)
         if model_config.dit_dtype == torch.float8_e4m3fn:
             pipe.dtype = torch.bfloat16  # running dtype
             pipe.enable_fp8_autocast(
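Note: from_pretrained previously called pipe.enable_cpu_offload(offload_mode) unconditionally, even when no offload mode was requested; the guard makes a plain on-device pipeline the no-offload default. Hypothetical usage, assuming FluxImagePipeline is the class defined in this file and offload_mode defaults to None:

# Stays fully resident on the device; enable_cpu_offload is never called
pipe = FluxImagePipeline.from_pretrained(model_config, device="cuda:0")

# Opts in to offloading explicitly
pipe = FluxImagePipeline.from_pretrained(model_config, offload_mode="cpu_offload")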
2 changes: 1 addition & 1 deletion diffsynth_engine/utils/offload.py
@@ -18,7 +18,7 @@ def add_cpu_offload_hook(module: nn.Module, device: str = "cuda", recurse: bool
     def _forward_pre_hook(module: nn.Module, input):
         offload_params = {}
         for name, param in module.named_parameters(recurse=recurse):
-            offload_params[name] = param.data
+            offload_params[name] = param.data.pin_memory()
             param.data = param.data.to(device=device)
         setattr(module, "_offload_params", offload_params)
         return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input)
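Note: with this change the sequential-offload pre-hook stashes a pinned copy of each parameter before moving it to the device, so the matching post-hook (outside this hunk) can restore CPU-resident weights without repaying the pinning cost on every forward. A self-contained sketch of the pattern (not the library's exact API; assumes a CUDA device; it also copies from the pinned tensor so the H2D transfer can be non_blocking):

import torch
import torch.nn as nn

def make_pre_hook(device: str = "cuda"):
    def _forward_pre_hook(module: nn.Module, args):
        offload_params = {}
        for name, param in module.named_parameters(recurse=False):
            pinned = param.data.pin_memory()             # page-locked CPU copy
            offload_params[name] = pinned
            param.data = pinned.to(device=device, non_blocking=True)  # async H2D
        module._offload_params = offload_params          # kept for the post-hook
        return tuple(a.to(device=device) if isinstance(a, torch.Tensor) else a for a in args)
    return _forward_pre_hook

layer = nn.Linear(8, 8)
layer.register_forward_pre_hook(make_pre_hook())
out = layer(torch.randn(2, 8))  # weights and inputs both end up on the GPU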