28 commits
4c6d15f
Tests and inital implementation for embed_tokens
romitjain Oct 29, 2025
4b91220
Minor fixes
romitjain Oct 30, 2025
46b803e
Fixed all tests and made updates to logic
romitjain Oct 31, 2025
37b1e06
Nit
romitjain Oct 31, 2025
8388aa8
Added contigious check for export
romitjain Nov 4, 2025
cd6c6d0
Apply suggestion from @BenjaminBossan
romitjain Nov 4, 2025
0cb44e8
Addressed PR comments
romitjain Nov 5, 2025
628ce10
Update src/peft/tuners/lora/model.py
romitjain Nov 7, 2025
602ce10
Update src/peft/tuners/lora/model.py
romitjain Nov 7, 2025
e2d0345
Apply suggestions from code review
romitjain Nov 7, 2025
7880032
Removed redundant change
romitjain Nov 7, 2025
f73af50
Merge branch 'enh/tie-target-modules' of github.com:romitjain/peft in…
romitjain Nov 7, 2025
46cca1e
Handling target_modules as str
romitjain Nov 7, 2025
2267a48
Update src/peft/tuners/tuners_utils.py
romitjain Nov 10, 2025
5d5b8e4
Updated regex matching
romitjain Nov 12, 2025
c7cfe40
Apply suggestion from @BenjaminBossan
romitjain Nov 13, 2025
8294ec7
Added find layer by tensor
romitjain Nov 13, 2025
7370a21
Merge branch 'main' of github.com:romitjain/peft into enh/tie-target-…
romitjain Nov 13, 2025
1da895f
Fixed tests
romitjain Nov 14, 2025
d86ff7d
Nit
romitjain Nov 18, 2025
dc03dd4
Small fix to ensure correct layer name gets saved for target modules
romitjain Nov 19, 2025
c79a64c
Merge branch 'main' of github.com:huggingface/peft into enh/tie-targe…
romitjain Nov 20, 2025
0715451
Merge branch 'main' of github.com:huggingface/peft into enh/tie-targe…
romitjain Dec 15, 2025
dbb0096
Apply suggestions from code review
romitjain Dec 15, 2025
06d4b7f
Merge branch 'enh/tie-target-modules' of github.com:romitjain/peft in…
romitjain Dec 15, 2025
2ea03e3
Small fixed on comments
romitjain Dec 15, 2025
d4427e8
Update src/peft/peft_model.py
romitjain Dec 15, 2025
d3c0099
Small fixes
romitjain Dec 15, 2025
9 changes: 9 additions & 0 deletions src/peft/peft_model.py
@@ -328,6 +328,15 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
output_state_dict = save_mutated_as_lora(
peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs
)

# Before exporting the parameters we need to make sure
# all the tensors are contiguous. Tensors can become non-contiguous
# if they are a transpose view of another tensor. This can happen
# during adapter tying or parameter sharing.
for k, v in output_state_dict.items():
if not v.is_contiguous():
output_state_dict[k] = v.contiguous()

safe_save_file(
output_state_dict,
os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
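The hunk above exists because `safe_save_file` (safetensors) cannot serialize non-contiguous tensors, and adapter tying produces exactly those: a transposed view shares storage with its source but is not contiguous. A minimal sketch of the failure mode and the fix, using only standard torch behaviour and a stand-in dictionary for `output_state_dict`:

```python
import torch

# Stand-in for output_state_dict: the value is a transposed view of a shared
# tensor, which is the situation created by adapter tying.
shared = torch.randn(768, 8)
output_state_dict = {"lora_A.weight": shared.t()}

print(output_state_dict["lora_A.weight"].is_contiguous())  # False: a transpose is a view

# Same loop as in the diff: materialize a contiguous copy before saving.
for k, v in output_state_dict.items():
    if not v.is_contiguous():
        output_state_dict[k] = v.contiguous()

print(output_state_dict["lora_A.weight"].is_contiguous())  # True
```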
8 changes: 7 additions & 1 deletion src/peft/tuners/lora/config.py
@@ -382,6 +382,11 @@ class LoraConfig(PeftConfig):
`target_parameters`. As an example, for Llama4, you can pass:
`target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj']`. Passing a
string for regex matching is not implemented yet.
ensure_weight_tying (`bool`, *optional*):
Whether to tie weights after PEFT initialization. This ensures that the adapters added to the
tied layers are also tied. This is only applicable to layers passed via `modules_to_save` and
`target_modules`.

"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -670,7 +675,7 @@ class LoraConfig(PeftConfig):
"Whether to tie weights or not after peft initialization. "
"This will ensure that the adapters added to the tied layers "
"are also tied. This is only applicable for layers passed via "
"`modules_to_save`."
"`modules_to_save` and and `target_modules`."
)
},
)
@@ -695,6 +700,7 @@ def __post_init__(self):

if self.ensure_weight_tying:
self.modules_to_tie = None
self.target_modules_to_tie = None

if isinstance(self.target_parameters, str):
raise TypeError("`target_parameters` must be a list of strings or None.")
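As a usage sketch of the new flag: the model checkpoint and module names below are placeholders (any model with tied input/output embeddings would exercise this path); only `ensure_weight_tying` itself comes from this diff.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Example checkpoint with tied embeddings; swap in any tied-embedding model.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj", "lm_head"],  # lm_head is tied to embed_tokens
    ensure_weight_tying=True,  # tie the adapters on lm_head to those on embed_tokens
)
peft_model = get_peft_model(model, config)
```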
13 changes: 13 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -156,6 +156,7 @@ def update_layer(
arrow_config: ArrowConfig = None,
qalora_group_size: int = 32,
inference_mode: bool = False,
tied_adapters: Optional[dict[str, nn.Parameter]] = None,
**kwargs,
):
# collect the kwargs
@@ -195,6 +196,17 @@
# Actual trainable parameters
self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)

# Tying adapters is only implemented for Linear layers
# where the source is the embedding layer.
# Currently, this is the most prevalent way of tying layers (weight tying).
if tied_adapters:
lora_A_params = tied_adapters["lora_A"]
lora_B_params = tied_adapters["lora_B"]

self.lora_A[adapter_name].weight = torch.nn.Parameter(lora_A_params)
self.lora_B[adapter_name].weight = torch.nn.Parameter(lora_B_params)

self.lora_bias[adapter_name] = lora_bias

if use_rslora:
@@ -631,6 +643,7 @@ def __init__(
use_alora=use_alora,
lora_bias=lora_bias,
arrow_config=arrow_config,
tied_adapters=kwargs.get("tied_adapters"),
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

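The `update_layer` change above re-points the Linear adapters at Parameters built from transposed views of the embedding adapters. A small sketch of the mechanics, assuming nothing beyond standard torch semantics (`nn.Parameter` wraps the data it is given without copying, and `.t()` is a view) and using placeholder shapes:

```python
import torch
from torch import nn

r, embed_dim = 8, 768

# Placeholder for lora_embedding_B of the input embedding layer, shape (embed_dim, r).
emb_B = nn.Parameter(torch.randn(embed_dim, r))

# Placeholder for lora_A of the tied Linear layer, whose weight has shape (r, embed_dim).
lora_A = nn.Linear(embed_dim, r, bias=False)
lora_A.weight = nn.Parameter(emb_B.t())

# The two parameters share the same underlying storage...
print(lora_A.weight.data_ptr() == emb_B.data_ptr())  # True

# ...but the transposed view is not contiguous, which is why peft_model.py
# now calls .contiguous() on the state dict before saving with safetensors.
print(lora_A.weight.is_contiguous())  # False
```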
82 changes: 73 additions & 9 deletions src/peft/tuners/lora/model.py
@@ -26,11 +26,7 @@
from transformers.modeling_layers import GradientCheckpointingLayer

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
replicate_layers,
)
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, find_parameter_name_by_tensor, replicate_layers
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
AuxiliaryTrainingWrapper,
@@ -201,6 +197,17 @@ def _create_and_replace(
r = lora_config.rank_pattern.get(r_key, lora_config.r)
alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)

# Check if the target is marked as a tied layer.
# If so, add a reference to the LoRA adapters of the embedding layer via `tied_adapters`.
is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or [])
tied_adapters = {}
if is_tied:
tied_module = self.model.get_input_embeddings()
emb_A = tied_module.lora_embedding_A[adapter_name]
emb_B = tied_module.lora_embedding_B[adapter_name]

tied_adapters = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}

kwargs = {
"r": r,
"lora_alpha": alpha,
@@ -218,6 +225,7 @@
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
"parameter_name": parameter_name,
"tied_adapters": tied_adapters,
}

# for torchao merging, we need the get_apply_tensor_subclass from the quantization config
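The transposes in the `tied_adapters` dict above (`emb_B.t()` for `lora_A`, `emb_A.t()` for `lora_B`) are shape bookkeeping between the Embedding-style and Linear-style LoRA conventions. A sketch of the dimensions involved, assuming peft's usual layout where `lora_embedding_A` has shape `(r, num_embeddings)` and `lora_embedding_B` has shape `(embedding_dim, r)`, with placeholder sizes:

```python
import torch

vocab_size, embed_dim, r = 32000, 768, 8

# Embedding-side adapters (layout assumed above).
lora_embedding_A = torch.randn(r, vocab_size)
lora_embedding_B = torch.randn(embed_dim, r)

# The tied Linear (an lm_head-style projection) maps embed_dim -> vocab_size,
# so its lora_A weight must be (r, embed_dim) and its lora_B weight (vocab_size, r).
lora_A_weight = lora_embedding_B.t()   # (r, embed_dim)
lora_B_weight = lora_embedding_A.t()   # (vocab_size, r)

assert lora_A_weight.shape == (r, embed_dim)
assert lora_B_weight.shape == (vocab_size, r)
```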
@@ -263,6 +271,7 @@ def _create_and_replace(
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)

self._replace_module(parent, target_name, new_module, target)

def _replace_module(self, parent, child_name, new_module, child):
@@ -861,8 +870,63 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap

return tensors_lora

def _add_modules_to_tie(self, peft_config, tied_weight_keys):
modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
missing_keys = set(tied_weight_keys) - modules_to_save
def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
Review comment (Collaborator): Now that we have _add_target_modules as well I'm wondering if we should refactor this to _add_modules_to_save_to_tie for clarity (it is verbose, yes). Same goes for the config key modules_to_tie.

"""
`tied_weight_keys` contains the layers tied to the embedding layer. Add the embedding layer to, and remove the
rest of the tied layers from, `modules_to_save`. Maintain a separate set of layers to be tied.

Args:
peft_config (LoraConfig)
tied_weight_keys (list[str])
"""
tied_weight_keys = set(tied_weight_keys)
peft_config.modules_to_tie = tied_weight_keys

modules_to_save = getattr(peft_config, "modules_to_save", []) or []

embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight)
Review comment (Collaborator): I think there is no guarantee that this will return the name of the embedding layer. It could also return the name of a layer tied to the embedding layer. It is probably safer to compare module identity instead (even though for transformers <5 this will also be flaky for models like T5).

# find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "")
Review comment (Collaborator): Not sure if replacing these strings is a good idea. encoder_model.embed_tokens would be turned into encoder_embed_tokens. Maybe using a more restricted approach (only one replacement, only if the key is found) would be better? .weight for example could be dropped by using .removesuffix.


if embed_layer_name not in modules_to_save:
modules_to_save.append(embed_layer_name)

for m in tied_weight_keys:
if m in modules_to_save:
modules_to_save.remove(m)
Review comment (Collaborator), on lines +903 to +904: I'm not sure how often this will generate a match. If I understand correctly, tied_weight_keys are fully-qualified keys. So this check will only match if the keys in modules_to_save are also fully-qualified. I don't think this happens often. cc @BenjaminBossan


peft_config.modules_to_save = modules_to_save

def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
"""
`tied_weight_keys` contains the layers tied to the embedding layer. Add the embedding layer to, and remove the
rest of the tied layers from, `target_modules`. Maintain a separate set of layers to be tied.

Args:
peft_config (LoraConfig)
tied_weight_keys (list[str])
"""
tied_weight_keys = set(tied_weight_keys)
peft_config.target_modules_to_tie = tied_weight_keys

raw_target_modules = getattr(peft_config, "target_modules", None)
Review comment (Contributor Author): @BenjaminBossan Please review this logic. I know this is a bit hacky! I am open to suggestions.

Review comment (Member): Hmm yeah, this is rough. We can't really operate on the string like this, as there are too many possible ways that the regex could be formed. I wonder if we should just leave it be and deal with the tied module edge case in inject_adapter directly. I haven't fully thought this through, perhaps you already tried that and there is a caveat that I'm missing?

Review comment (Contributor Author), re #2879 (comment): It should be possible, it would just make the flow very convoluted.

Review comment (Contributor Author): I redid this a bit. We just need to make sure that embed_tokens is present in the target_modules.

embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight)
# find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "")

if isinstance(raw_target_modules, str):
# The way weight tying is handled for adapters, we always want to add
# lora adapters to the input embedding layer (embed_tokens)
# instead of the output embedding layer.
raw_target_modules = rf"(?:{raw_target_modules}|.*{embed_layer_name}$)"
peft_config.target_modules = raw_target_modules
return

target_modules = set(raw_target_modules or [])
target_modules.add(embed_layer_name)

for m in tied_weight_keys:
if m in target_modules:
target_modules.remove(m)
Review comment (Collaborator), on lines +936 to +938: This will also only occasionally match, right? Only if users supply the fully-qualified module names.


peft_config.modules_to_tie = missing_keys
peft_config.target_modules = target_modules
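When `target_modules` is a string, the branch above widens the regex rather than editing a list. A small sketch of the resulting match behaviour, using made-up module names and assuming the same full-match style of key matching that peft applies to string `target_modules`:

```python
import re

original = r".*q_proj"             # user-supplied string target_modules
embed_layer_name = "embed_tokens"  # resolved name of the input embedding module
augmented = rf"(?:{original}|.*{embed_layer_name}$)"

keys = ["model.embed_tokens", "model.layers.0.self_attn.q_proj", "lm_head"]
print([k for k in keys if re.fullmatch(augmented, k)])
# ['model.embed_tokens', 'model.layers.0.self_attn.q_proj']
```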