Fix warning messages about config.json when the base model_id is local. #1668
Conversation
…ile is available on the filesystem. When tuning a model with peft, sometimes the user might wish to use a local base model. In such cases, `model_id` points to a local directory instead of a remote repository. This commit adds a check on the local directory to address this issue.
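A minimal sketch of the kind of check the commit describes; the helper name and overall structure are illustrative only, not the actual PEFT implementation:

```python
import os

from huggingface_hub import file_exists  # Hub-side existence check


def base_model_config_exists(model_id: str) -> bool:
    # If model_id is a local directory, look for config.json on the filesystem;
    # otherwise treat it as a repo id and ask the Hub.
    if os.path.isdir(model_id):
        return os.path.exists(os.path.join(model_id, "config.json"))
    return file_exists(model_id, "config.json")
```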
Thanks for the PR, what you describe sounds reasonable. Do you have a small example where this change would apply? Ideally, we can use that to create a unit test. Also, could you please remove the …
Sure, a simple example is to create a LoRA adapter for a local base model and save it. In that case the following warning is raised:

```python
warnings.warn(
    f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
)
```

and the configuration is therefore not checked.
```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel
from peft import prepare_model_for_kbit_training, get_peft_model

local_dir = 'path/to/model'  # base model stored on the local filesystem
base_model = AutoModelForCausalLM.from_pretrained(local_dir)
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
)
peft_model = get_peft_model(base_model, peft_config)
peft_model.save_pretrained("test")
```
Is this sufficient? @BenjaminBossan
Nice, thanks for providing the example. I used it to create a test:

```python
class TestLocalModel:
    def test_local_model_saving_no_warning(self, recwarn, tmp_path):
        model_id = "facebook/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        local_dir = tmp_path / model_id
        model.save_pretrained(local_dir)
        del model

        base_model = AutoModelForCausalLM.from_pretrained(local_dir)
        peft_config = LoraConfig()
        peft_model = get_peft_model(base_model, peft_config)
        peft_model.save_pretrained(local_dir)

        for warning in recwarn.list:
            assert "Could not find a config file" not in warning.message.args[0]
```

We could for instance put it into …
…e `PeftModel.save_pretrained` method. When the base model is loaded from a local directory, we should be able to find the `config.json` there.
Certainly, I followed the syntax in …
The way you added the test, it's not executed. You would have to add corresponding methods that call this method in …
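A rough sketch of that pattern, with file and class names that are only illustrative of how shared test helpers are typically wired up (the actual PEFT test layout may differ):

```python
# testing_common.py (illustrative): shared logic lives in a helper method that pytest
# does not collect on its own, because its name does not start with test_
class PeftCommonTester:
    def _test_local_model_saving_no_warning(self, model_id, tmp_path, recwarn):
        ...  # the shared assertions from the test above would go here


# test_decoder_models.py (illustrative): pytest only collects methods named test_*,
# so each concrete test class has to call the helper explicitly
class TestDecoderModels(PeftCommonTester):
    def test_local_model_saving_no_warning(self, tmp_path, recwarn):
        self._test_local_model_saving_no_warning("facebook/opt-125m", tmp_path, recwarn)
```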
…using the `PeftModel.save_pretrained` method." This reverts commit da94f3b.
…and no warning is issued when saving a model and checking for vocab changes.
The test case is fixed as advised and comments are added to explain the issue.
Thanks, the test will now be run as part of the test suite. However, you forgot some imports for the test.
```python
peft_config = LoraConfig()
peft_model = get_peft_model(base_model, peft_config)
```
`LoraConfig` and `get_peft_model` need to be imported.
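Concretely, the missing imports would presumably mirror the earlier example (the test also uses `AutoModelForCausalLM`):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
```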
Resolved.
@elementary-particle Thanks for the update. Could you please run …
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@elementary-particle This PR is almost good to go, just a small merge conflict, could you please check it out?
Thanks for keeping up with this PR. The merge conflict is resolved.
Thanks a lot for fixing this warning, LGTM.
It should be possible for the user to specify a local directory as the base model in the library. However, the library currently only checks for the remote presence of `config.json` and fails to check the actual `config.json` when using a local repo. This PR adds a check for a local `model_id` and fixes the behavior.