Skip to content

Commit 68bd57c

Browse files
Support Diffusers Flux LoRA (#95)
* support diffusers lora * support flux diffusers lora
1 parent 30c836e commit 68bd57c

File tree

6 files changed

+39
-10
lines changed

6 files changed

+39
-10
lines changed

diffsynth_engine/pipelines/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,14 @@ def prepare_latents(
196196
# if you have any questions about this, please ask @dizhipeng.dzp for more details
197197
latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5)
198198
init_latents = latents.clone()
199-
sigmas, timesteps = sigmas.to(device=self.device, dtype=self.dtype), timesteps.to(device=self.device, dtype=self.dtype)
200-
init_latents, latents = init_latents.to(device=self.device, dtype=self.dtype), latents.to(device=self.device, dtype=self.dtype)
199+
sigmas, timesteps = (
200+
sigmas.to(device=self.device, dtype=self.dtype),
201+
timesteps.to(device=self.device, dtype=self.dtype),
202+
)
203+
init_latents, latents = (
204+
init_latents.to(device=self.device, dtype=self.dtype),
205+
latents.to(device=self.device, dtype=self.dtype),
206+
)
201207
return init_latents, latents, sigmas, timesteps
202208

203209
def eval(self):

diffsynth_engine/pipelines/flux_image.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,8 @@ def _from_diffusers(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str,
163163
dit_dict = {}
164164
for key, param in lora_state_dict.items():
165165
origin_key = key
166-
if ".alpha" not in key:
166+
if "lora_A.weight" not in key or "lora_up.weight" not in key:
167167
continue
168-
key = key.replace(".alpha", ".weight")
169168
key = key.replace("transformer.", "")
170169
if "single_transformer_blocks" in key: # transformer.single_transformer_blocks.0.attn.to_k.weight
171170
key = key.replace(
@@ -208,10 +207,17 @@ def _from_diffusers(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str,
208207
else:
209208
raise ValueError(f"Unsupported key: {key}")
210209
lora_args = {}
211-
lora_args["alpha"] = param
212-
lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
213-
lora_args["down"] = lora_state_dict[origin_key.replace(".alpha", ".lora_down.weight")]
210+
lora_args["up"] = param
211+
lora_args["down"] = lora_state_dict[
212+
origin_key.replace("lora_A.weight", "lora_B.weight").replace("lora_up.weight", "lora_down.weight")
213+
]
214214
lora_args["rank"] = lora_args["up"].shape[1]
215+
alpha_key = origin_key.replace("lora_A.weight", "alpha").replace("lora_up.weight", "alpha")
216+
if alpha_key in lora_state_dict:
217+
alpha = lora_state_dict[alpha_key]
218+
else:
219+
alpha = lora_args["rank"] # 如果alpha不存在,则取alpha/rank = 1
220+
lora_args["alpha"] = alpha
215221
key = key.replace(".weight", "")
216222
dit_dict[key] = lora_args
217223
return {"dit": dit_dict}

diffsynth_engine/tools/flux_inpainting_tool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from diffsynth_engine import fetch_model, FluxControlNet, ControlNetParams, FluxImagePipeline
2-
from typing import List, Tuple, Optional
2+
from typing import List, Tuple, Optional, Callable
33
from PIL import Image
44
import torch
55

@@ -34,6 +34,7 @@ def __call__(
3434
inpainting_scale: float = 0.9,
3535
seed: int = 42,
3636
num_inference_steps: int = 20,
37+
progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
3738
):
3839
assert image.size == mask.size
3940
return self.pipe(
@@ -49,4 +50,5 @@ def __call__(
4950
mask=mask,
5051
scale=inpainting_scale,
5152
),
53+
progress_callback=progress_callback,
5254
)

tests/common/test_case.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import unittest
22
import os
3-
import time
43
import numpy as np
54
import torch
65
from pathlib import Path
@@ -80,7 +79,7 @@ def assertImageEqualAndSaveFailed(self, input_image: Image.Image, expect_image_p
8079
self.assertImageEqual(input_image, expect_image, threshold=threshold)
8180
except Exception as e:
8281
name = expect_image_path.split("/")[-1]
83-
input_image.save(f"save_{time.time()}_{name}")
82+
input_image.save(f"{name}")
8483
raise e
8584

8685

1.53 MB
Loading

tests/test_pipelines/test_flux_image.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
34
from tests.common.test_case import ImageTestCase
45
from diffsynth_engine.pipelines import FluxImagePipeline, FluxModelConfig
56
from diffsynth_engine import fetch_model
@@ -47,6 +48,21 @@ def test_unfused_lora(self):
4748
self.pipe.unload_loras()
4849
self.assertImageEqualAndSaveFailed(image, "flux/flux_lora.png", threshold=0.98)
4950

51+
def test_diffusers_lora_patch(self):
52+
lora_model_path = fetch_model(
53+
"InstantX/FLUX.1-dev-LoRA-Ghibli", revision="master", path="ghibli_style.safetensors"
54+
)
55+
self.pipe.load_loras([(lora_model_path, 0.8)], fused=True, save_original_weight=True)
56+
image = self.pipe(
57+
prompt="ghibli style, a shepherd boy floating on a wooly cloud-whale, holding a glowing dandelion staff to guide sheep-shaped cumulus, miniature storm clouds grazing nearby, his patched jacket flapping in high-altitude winds, aurora-like ribbons in peach and lavender stretching across the sky",
58+
width=960,
59+
height=1280,
60+
num_inference_steps=24,
61+
seed=42,
62+
)
63+
self.pipe.unload_loras()
64+
self.assertImageEqualAndSaveFailed(image, "flux/flux_diffusers_lora.png", threshold=0.99)
65+
5066

5167
class TestFLUXGGUF(ImageTestCase):
5268
@classmethod

0 commit comments

Comments
 (0)