Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions mellea/formatters/granite/intrinsics/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,41 @@ def make_config_dict(
return result_dict


def adapter_subpath(
intrinsic_name: str, target_model_name: str, repo_id: str, /, alora: bool = False
) -> str:
"""Return the Hugging Face Hub subpath where an intrinsic's adapter lives.

Encapsulates the layout convention used by the Granite Intrinsics Library and
related repositories so callers don't replicate the rules. Both
:func:`obtain_lora` and out-of-tree consumers (e.g. drift checks in tests)
should call this function rather than building the path themselves.

Args:
intrinsic_name: Short name of the intrinsic model, such as `"certainty"`.
target_model_name: Name of the base model for the LoRA or aLoRA adapter.
May be a raw HF repo ID; canonical normalization is applied.
repo_id: Hugging Face Hub repository containing the adapter collection.
Used to select between old and new directory layouts.
alora: If `True`, return the path for the aLoRA variant; otherwise LoRA.

Returns:
Subpath relative to the repo root, e.g. `"certainty/granite-4.1-3b/lora"`.
"""
# Normalize target model name if a normalization exists.
target_model_name = BASE_MODEL_TO_CANONICAL_NAME.get(
target_model_name, target_model_name
)

lora_str = "alora" if alora else "lora"

if repo_id in OLD_LAYOUT_REPOS:
# Old repository layout.
return f"{intrinsic_name}/{lora_str}/{target_model_name}"

return f"{intrinsic_name}/{target_model_name}/{lora_str}"


def obtain_lora(
intrinsic_name: str,
target_model_name: str,
Expand Down Expand Up @@ -136,20 +171,10 @@ def obtain_lora(
# Third Party
import huggingface_hub

# Normalize target model name if a normalization exists.
target_model_name = BASE_MODEL_TO_CANONICAL_NAME.get(
target_model_name, target_model_name
lora_subdir_name = adapter_subpath(
intrinsic_name, target_model_name, repo_id, alora=alora
)

lora_str = "alora" if alora else "lora"

if repo_id in OLD_LAYOUT_REPOS:
# Old repository layout
lora_subdir_name = f"{intrinsic_name}/{lora_str}/{target_model_name}"
else:
# Assume new layout otherwise
lora_subdir_name = f"{intrinsic_name}/{target_model_name}/{lora_str}"

# Download just the files for this LoRA if not already present
local_root_path = huggingface_hub.snapshot_download(
repo_id=repo_id,
Expand Down
Loading
Loading