Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion diffsynth_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux
from .models.sd import SDControlNet
from .models.sdxl import SDXLControlNetUnion
from .utils.download import fetch_model, fetch_modelscope_model, fetch_civitai_model
from .utils.download import (
fetch_model,
fetch_modelscope_model,
fetch_civitai_model,
register_fetch_modelscope_model,
reset_fetch_modelscope_model,
)
from .utils.video import load_video, save_video
from .tools import (
FluxInpaintingTool,
Expand Down Expand Up @@ -52,6 +58,8 @@
"ControlType",
"fetch_model",
"fetch_modelscope_model",
"register_fetch_modelscope_model",
"reset_fetch_modelscope_model",
"fetch_civitai_model",
"load_video",
"save_video",
Expand Down
30 changes: 30 additions & 0 deletions diffsynth_engine/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@

MODEL_SOURCES = ["modelscope", "civitai"]

# Global registry for custom fetch function
_CUSTOM_MODELSCOPE_FETCHER = None


def register_fetch_modelscope_model(fetch_func):
"""
Register a global custom fetch function for ModelScope models.

Args:
fetch_func (callable): Custom fetch function that should accept the same parameters
as fetch_modelscope_model and return the model path(s)
"""
global _CUSTOM_MODELSCOPE_FETCHER
_CUSTOM_MODELSCOPE_FETCHER = fetch_func
logger.info("Registered global custom ModelScope fetcher")


def reset_fetch_modelscope_model():
"""
Reset the global custom fetch function for ModelScope models.
"""
global _CUSTOM_MODELSCOPE_FETCHER
_CUSTOM_MODELSCOPE_FETCHER = None
logger.info("Reset global custom ModelScope fetcher")


def fetch_model(
model_uri: str,
Expand All @@ -43,6 +68,11 @@ def fetch_modelscope_model(
access_token: Optional[str] = None,
fetch_safetensors: bool = True,
) -> str:
# Check if there's a global custom fetcher registered
if _CUSTOM_MODELSCOPE_FETCHER is not None:
logger.info(f"Using global custom fetcher for model: {model_id}")
return _CUSTOM_MODELSCOPE_FETCHER(model_id, revision, path, access_token, fetch_safetensors)

lock_file_name = f"modelscope.{model_id.replace('/', '--')}.{revision if revision else '__version'}.lock"
lock_file_path = os.path.join(DIFFSYNTH_FILELOCK_DIR, lock_file_name)
ensure_directory_exists(lock_file_path)
Expand Down