ENH: Tie weights for target_modules in Lora (#2864) #2879
@@ -26,11 +26,7 @@
 from transformers.modeling_layers import GradientCheckpointingLayer

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
-from peft.tuners.tuners_utils import (
-    BaseTuner,
-    BaseTunerLayer,
-    replicate_layers,
-)
+from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, find_parameter_name_by_tensor, replicate_layers
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
     AuxiliaryTrainingWrapper,
@@ -201,6 +197,17 @@ def _create_and_replace(
         r = lora_config.rank_pattern.get(r_key, lora_config.r)
         alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)

+        # Check whether the target is marked as a tied layer.
+        # If so, add a reference to the LoRA adapters of the embedding layer in `tied_adapters`.
+        is_tied = target_name in (getattr(lora_config, "target_modules_to_tie", []) or [])
+
+        tied_adapters = {}
+        if is_tied:
+            tied_module = self.model.get_input_embeddings()
+            emb_A = tied_module.lora_embedding_A[adapter_name]
+            emb_B = tied_module.lora_embedding_B[adapter_name]
+
+            tied_adapters = {"lora_A": emb_B.t(), "lora_B": emb_A.t()}

         kwargs = {
             "r": r,
             "lora_alpha": alpha,
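Assuming PEFT's usual shape conventions (for `lora.Embedding` the update is the transpose of `lora_embedding_B @ lora_embedding_A`, while for `lora.Linear` it is `lora_B @ lora_A`), the swap-and-transpose above makes the tied layer's update identical to the embedding's update. A minimal shape check with plain tensors, not real PEFT modules, and made-up sizes:

```python
import torch

# Made-up sizes, for illustration only.
num_embeddings, embedding_dim, r = 32, 8, 4

# Embedding-side LoRA factors, following the shapes PEFT uses for lora.Embedding:
# lora_embedding_A: (r, num_embeddings), lora_embedding_B: (embedding_dim, r).
emb_A = torch.randn(r, num_embeddings)
emb_B = torch.randn(embedding_dim, r)

# Update applied to the embedding weight (num_embeddings, embedding_dim).
delta_embedding = (emb_B @ emb_A).T

# Tied lm_head factors as built above: lora_A = emb_B.t(), lora_B = emb_A.t().
lora_A = emb_B.t()  # (r, in_features=embedding_dim)
lora_B = emb_A.t()  # (out_features=num_embeddings, r)

# Update applied to the (tied) lm_head weight (num_embeddings, embedding_dim).
delta_lm_head = lora_B @ lora_A

# With these assignments the two updates coincide, which is what keeps the weights tied.
assert torch.allclose(delta_embedding, delta_lm_head)
```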
@@ -218,6 +225,7 @@ def _create_and_replace(
             "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
             "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
             "parameter_name": parameter_name,
+            "tied_adapters": tied_adapters,
         }

         # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
@@ -263,6 +271,7 @@ def _create_and_replace(
         if adapter_name not in self.active_adapters:
             # adding an additional adapter: it is not automatically trainable
             new_module.requires_grad_(False)
+
         self._replace_module(parent, target_name, new_module, target)

     def _replace_module(self, parent, child_name, new_module, child):
@@ -861,8 +870,63 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap

         return tensors_lora

-    def _add_modules_to_tie(self, peft_config, tied_weight_keys):
-        modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
-        missing_keys = set(tied_weight_keys) - modules_to_save
-        peft_config.modules_to_tie = missing_keys
+    def _add_modules_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
+        """
+        `tied_weight_keys` contains the layers tied to the embedding layer. Add the embedding layer to
+        `modules_to_save` and remove the rest of the tied layers from it. Maintain a separate set for the layers
+        to be tied.
+
+        Args:
+            peft_config (LoraConfig)
+            tied_weight_keys (list[str])
+        """
+        tied_weight_keys = set(tied_weight_keys)
+        peft_config.modules_to_tie = tied_weight_keys
+
+        modules_to_save = getattr(peft_config, "modules_to_save", []) or []
+
+        embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight)

Collaborator: I think there is no guarantee that this will return the name of the embedding layer. It could also return the name of a layer tied to the embedding layer. It is probably safer to compare module identity instead (even though for transformers <5 this will also be flaky for models like T5).
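A sketch of the module-identity lookup suggested here; `find_embedding_module_name` is a hypothetical helper, not part of this PR or of PEFT:

```python
def find_embedding_module_name(model) -> str | None:
    # Hypothetical helper: locate the input embedding by module identity
    # instead of by comparing weight tensors, so a layer that merely shares
    # the tied weight cannot be returned by mistake.
    input_embeddings = model.get_input_embeddings()
    for name, module in model.named_modules():
        if module is input_embeddings:
            return name
    return None
```

The returned name could then replace the `find_parameter_name_by_tensor` call plus the string-stripping step below.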
+        # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
+        embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "")
+
+        if embed_layer_name not in modules_to_save:
+            modules_to_save.append(embed_layer_name)
+
+        for m in tied_weight_keys:
+            if m in modules_to_save:
+                modules_to_save.remove(m)

Collaborator (on lines +903 to +904): I'm not sure how often this will generate a match. If I understand correctly, …
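To illustrate the concern: `modules_to_save` entries are typically short suffixes (e.g. `"lm_head"`), while tied weight keys tend to be fully qualified (e.g. `"lm_head.weight"`), so an exact membership test rarely fires. A hedged sketch of a suffix-aware check, shown only to make the reviewer's point concrete (hypothetical helper, not what the PR does):

```python
def drop_tied_entries(modules_to_save: list[str], tied_weight_keys: set[str]) -> list[str]:
    # Hypothetical suffix-aware variant of the exact membership test above:
    # modules_to_save entries are usually short suffixes such as "lm_head",
    # while tied weight keys tend to be fully qualified such as "lm_head.weight".
    def is_tied(entry: str) -> bool:
        return any(
            key == entry or key.startswith(entry + ".") or key.endswith("." + entry)
            for key in tied_weight_keys
        )

    return [entry for entry in modules_to_save if not is_tied(entry)]


# An exact check would keep "lm_head"; the suffix-aware one drops it.
print(drop_tied_entries(["lm_head", "embed_tokens"], {"lm_head.weight"}))  # ['embed_tokens']
```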
+
+        peft_config.modules_to_save = modules_to_save
+
+    def _add_targets_to_tie(self, peft_config: LoraConfig, tied_weight_keys: list[str]):
+        """
+        `tied_weight_keys` contains the layers tied to the embedding layer. Add the embedding layer to
+        `target_modules` and remove the rest of the tied layers from it. Maintain a separate set for the layers
+        to be tied.
+
+        Args:
+            peft_config (LoraConfig)
+            tied_weight_keys (list[str])
+        """
+        tied_weight_keys = set(tied_weight_keys)
+        peft_config.target_modules_to_tie = tied_weight_keys
+
+        raw_target_modules = getattr(peft_config, "target_modules", None)

Contributor (author): @BenjaminBossan Please review this logic. I know this is a bit hacky! I am open to suggestions.

Member: Hmm yeah, this is rough. We can't really operate on the string like this, as there are too many possible ways that the regex could be formed. I wonder if we should just leave it be and deal with the tied module edge case in …

Contributor (author): It should be possible, it would just make the flow very convoluted.

Contributor (author): I redid this a bit. We just need to make sure that …
+        embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight)
+        # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
+        embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "")
+
+        if isinstance(raw_target_modules, str):
+            # The way weight tying is handled for adapters, we always want to add
+            # LoRA adapters to the input embedding layer (embed_tokens)
+            # instead of the output embedding layer.
+            raw_target_modules = rf"(?:{raw_target_modules}|.*{embed_layer_name}$)"
+            peft_config.target_modules = raw_target_modules
+            return
+
+        target_modules = set(raw_target_modules or [])
+        target_modules.add(embed_layer_name)
+
+        for m in tied_weight_keys:
+            if m in target_modules:
+                target_modules.remove(m)
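To make the regex branch concrete: PEFT full-matches a string `target_modules` against fully qualified module names, so the extended pattern keeps matching the original targets and additionally matches the input embedding layer. The pattern and module names below are assumptions for the example, not values taken from the PR:

```python
import re

# Assumed user-supplied regex and embedding-layer name; purely illustrative.
raw_target_modules = r".*\.(q_proj|v_proj)"
embed_layer_name = "embed_tokens"

# Same construction as in _add_targets_to_tie above.
pattern = rf"(?:{raw_target_modules}|.*{embed_layer_name}$)"

# The extended pattern still matches the original targets ...
assert re.fullmatch(pattern, "model.layers.0.self_attn.q_proj")
# ... and now also matches the input embedding layer ...
assert re.fullmatch(pattern, "model.embed_tokens")
# ... while unrelated modules are still excluded.
assert re.fullmatch(pattern, "model.layers.0.mlp.down_proj") is None
```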
Collaborator (on lines +936 to +938): This will also only occasionally match, right? Only if users supply the fully-qualified module names.
+
+        peft_config.target_modules = target_modules
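As background on where the tied weight keys come from, here is a minimal sketch of inspecting weight tying on a transformers model. The checkpoint name is only an example, and `_tied_weights_keys` is a private transformers attribute that may differ across versions:

```python
from transformers import AutoModelForCausalLM

# Example checkpoint; any model with tie_word_embeddings=True behaves the same way.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

print(model.config.tie_word_embeddings)            # True for tied checkpoints
print(getattr(model, "_tied_weights_keys", None))  # e.g. ["lm_head.weight"]

# The tie itself: input and output embeddings share a single parameter.
print(model.get_input_embeddings().weight is model.get_output_embeddings().weight)
```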