Add add_weighted_adapter to IA3 adapters (#1701)
* Add add_weighted_adapter to IA3 adapters

* Refactor to simplify code

* refactor test

* Add IA3 merging docs

* Update docs/source/developer_guides/model_merging.md

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update docs/source/developer_guides/model_merging.md

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* address PR feedback

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
alexrs and BenjaminBossan committed May 17, 2024
1 parent 4e32679 commit fb7f279
Showing 7 changed files with 176 additions and 25 deletions.
17 changes: 17 additions & 0 deletions docs/source/developer_guides/model_merging.md
@@ -138,3 +138,20 @@ print(tokenizer.decode(outputs[0]))

</hfoption>
</hfoptions>


## Merging (IA)³ Models
The (IA)³ models facilitate linear merging of adapters. To merge adapters in an (IA)³ model, utilize the `add_weighted_adapter` method from the `IA3Model` class. This method is analogous to the `add_weighted_adapter` method used in `LoraModel`, with the key difference being the absence of the `combination_type` parameter. For example, to merge three (IA)³ adapters into a PEFT model, you would proceed as follows:

```py
adapters = ["adapter1", "adapter2", "adapter3"]
weights = [0.4, 0.3, 0.3]
adapter_name = "merge"
model.add_weighted_adapter(adapters, weights, adapter_name)
```
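
The snippet above assumes the three adapters are already loaded on the model. A minimal sketch of that setup, using hypothetical base model and adapter checkpoint paths, might look like this:

```py
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Hypothetical identifiers and paths -- substitute your own.
base_model = AutoModelForCausalLM.from_pretrained("base-model-id")
model = PeftModel.from_pretrained(base_model, "path/to/adapter1", adapter_name="adapter1")
model.load_adapter("path/to/adapter2", adapter_name="adapter2")
model.load_adapter("path/to/adapter3", adapter_name="adapter3")
```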

It is recommended that the weights sum to 1.0 to preserve the scale of the model. The merged model can then be set as the active model using the `set_adapter` method:

```py
model.set_adapter("merge")
```
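
Once the merged adapter is active, the model can be used for inference as usual. A brief sketch, assuming a tokenizer and a prompt of your own:

```py
inputs = tokenizer("A sample prompt", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```
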
107 changes: 101 additions & 6 deletions src/peft/tuners/ia3/model.py
@@ -15,7 +15,7 @@

import re
import warnings
from dataclasses import asdict
from dataclasses import asdict, replace
from enum import Enum
from typing import Optional

@@ -29,6 +29,7 @@
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_freeze_adapter,
_get_submodules,
)

@@ -279,17 +280,20 @@ def set_adapter(self, adapter_name: str | list[str]) -> None:
module.set_adapter(adapter_name)
self.active_adapter = adapter_name

def _prepare_adapter_config(self, peft_config, model_config):
@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
peft_config.target_modules = set(
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
if peft_config.feedforward_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING:
raise ValueError("Please specify `feedforward_modules` in `peft_config`")
peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[
model_config["model_type"]
]
peft_config.feedforward_modules = set(
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config

def _unload_and_optionally_merge(
@@ -393,3 +397,94 @@ def delete_adapter(self, adapter_name: str) -> None:
new_adapter = target.active_adapters[:]

self.active_adapter = new_adapter or []

def _check_add_weighted_adapter(self, adapters: list[str]) -> tuple[str, str]:
"""
Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
model.
"""
# Validate existence of adapters
for adapter in adapters:
if adapter not in self.peft_config:
raise ValueError(f"Adapter {adapter} does not exist")

# Check for conflicting modules_to_save
modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
if any(
sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 for wrapper in modules_to_save_wrappers
):
raise ValueError("Cannot add weighted adapters targeting the same module with modules_to_save.")

# Ensure all adapters have compatible target and feedforward module types
target_module_types = {type(self.peft_config[adapter].target_modules) for adapter in adapters}
feedforward_module_types = {type(self.peft_config[adapter].feedforward_modules) for adapter in adapters}
if len(target_module_types) > 1 or len(feedforward_module_types) > 1:
raise ValueError("All adapter configs should have the same type for target and feedforward modules.")

# Combine target and feedforward modules
if str in target_module_types:
new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
else:
new_target_modules = set.union(*(self.peft_config[adapter].target_modules for adapter in adapters))

if str in feedforward_module_types:
new_feedforward_modules = "|".join(
f"({self.peft_config[adapter].feedforward_modules})" for adapter in adapters
)
else:
new_feedforward_modules = set.union(
*(self.peft_config[adapter].feedforward_modules for adapter in adapters)
)

return new_target_modules, new_feedforward_modules
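
For intuition, the combination logic above treats string specs as regex patterns joined into a single alternation, while set specs are merged with a set union. A standalone sketch of the same idea, using illustrative module names only:

```py
# Regex-style specs are joined into one alternation pattern:
specs = ["k_proj|v_proj", "down_proj"]
combined_regex = "|".join(f"({s})" for s in specs)   # "(k_proj|v_proj)|(down_proj)"

# Set-style specs are merged with a set union:
set_specs = [{"k_proj", "v_proj"}, {"down_proj"}]
combined_set = set.union(*set_specs)                 # {"k_proj", "v_proj", "down_proj"}
```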

def add_weighted_adapter(
self,
adapters: list[str],
weights: list[float],
adapter_name: str,
) -> None:
"""
This method adds a new adapter by merging the given adapters with the given weights.
Args:
adapters (`list`):
List of adapter names to be merged.
weights (`list`):
List of weights for each adapter.
adapter_name (`str`):
Name of the new adapter.
"""
if adapter_name in list(self.peft_config.keys()):
return

new_target_modules, new_feedforward_modules = self._check_add_weighted_adapter(
adapters=adapters,
)

self.peft_config[adapter_name] = replace(
self.peft_config[adapters[0]],
target_modules=new_target_modules,
feedforward_modules=new_feedforward_modules,
)
self.inject_adapter(self.model, adapter_name)

# Do we really need that?
_freeze_adapter(self.model, adapter_name)

key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, IA3Layer):
if adapter_name in target.ia3_l:
target_ia3_l = target.ia3_l[adapter_name]
else:
continue

target_ia3_l.data = target_ia3_l.data.zero_()
for adapter, weight in zip(adapters, weights):
if adapter in target.ia3_l:
current_adapter_ia3_l = target.ia3_l[adapter]
else:
continue
target_ia3_l.data += current_adapter_ia3_l.data * weight
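
The merge itself is a per-module weighted sum of the source adapters' `ia3_l` vectors. A minimal, self-contained sketch of that arithmetic with made-up tensors (not the PEFT API):

```py
import torch

# Hypothetical learned (IA)³ scaling vectors for two source adapters.
ia3_l = {
    "adapter1": torch.tensor([1.0, 0.8, 1.2]),
    "adapter2": torch.tensor([0.9, 1.1, 1.0]),
}
weights = {"adapter1": 0.6, "adapter2": 0.4}

# The new adapter starts from zeros and accumulates the weighted contributions.
merged = torch.zeros(3)
for name, weight in weights.items():
    merged += ia3_l[name] * weight
print(merged)  # tensor([0.9600, 0.9200, 1.1200])
```
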
3 changes: 0 additions & 3 deletions src/peft/tuners/lora/model.py
@@ -591,9 +591,6 @@ def add_weighted_adapter(

if adapter_name in list(self.peft_config.keys()):
return
for adapter in adapters:
if adapter not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter} does not exist")

combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
adapters=adapters,
1 change: 1 addition & 0 deletions tests/test_decoder_models.py
@@ -296,6 +296,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"boft_kwargs": {"init_weights": [False]},
"task_type": "CAUSAL_LM",
},
1 change: 1 addition & 0 deletions tests/test_encoder_decoder_models.py
@@ -184,6 +184,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"task_type": "SEQ_2_SEQ_LM",
},
)
1 change: 1 addition & 0 deletions tests/test_feature_extraction_models.py
@@ -174,6 +174,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"boft_kwargs": {"init_weights": [False]},
"task_type": "FEATURE_EXTRACTION",
},
71 changes: 55 additions & 16 deletions tests/testing_common.py
@@ -1139,22 +1139,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
assert not torch.allclose(logits_with_adapter, logits_unload, atol=1e-10, rtol=1e-10)
assert torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4)

def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, AdaLoraConfig):
# AdaLora does not support adding more than 1 adapter
return pytest.skip(f"Test not applicable for {config_cls}")

adapter_list = ["adapter1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
if not isinstance(config, LoraConfig):
return pytest.skip(f"Test not applicable for {config}")

model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])
def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_list, weight_list):
model.add_adapter(adapter_list[1], config)
model.add_adapter(adapter_list[2], replace(config, r=20))
model = model.to(self.torch_device)
@@ -1338,6 +1323,60 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
assert model.active_adapters == [adapter_name]
model(**dummy_input)[0]

def _test_weighted_combination_of_adapters_ia3(self, model, config, adapter_list, weight_list):
model.add_adapter(adapter_list[1], config)
model.add_adapter(adapter_list[2], config)
model = model.to(self.torch_device)

# test re-weighting single adapter
model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")

# test re-weighting with multiple adapters
model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_reweighting")

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_reweighting",
]
for new_adapter in new_adapters:
assert new_adapter in model.peft_config

dummy_input = self.prepare_inputs_for_testing()
model.eval()
for adapter_name in new_adapters:
# ensuring new adapters pass the forward loop
model.set_adapter(adapter_name)
assert model.active_adapter == adapter_name
assert model.active_adapters == [adapter_name]
model(**dummy_input)[0]

def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, AdaLoraConfig):
# AdaLora does not support adding more than 1 adapter
return pytest.skip(f"Test not applicable for {config_cls}")

adapter_list = ["adapter1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
# Initialize the config
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)

if not isinstance(config, (LoraConfig, IA3Config)):
# This test is only applicable for Lora and IA3 configs
return pytest.skip(f"Test not applicable for {config}")

model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, adapter_list[0])

if isinstance(config, LoraConfig):
self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, weight_list)
elif isinstance(config, IA3Config):
self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, weight_list)
else:
pytest.skip(f"Test not applicable for {config}")

def _test_disable_adapter(self, model_id, config_cls, config_kwargs):
task_type = config_kwargs.get("task_type")
if (task_type == "SEQ_2_SEQ_LM") and (config_cls in (PromptTuningConfig, PromptEncoderConfig)):
