Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support peft package for loading lora-adapted models #19546

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/app/app.txt
Expand Up @@ -27,3 +27,4 @@ urllib3 <2.1.0
uvicorn <0.24.0
websocket-client <1.7.0
websockets <11.1.0
peft <0.8.2
86 changes: 86 additions & 0 deletions src/lightning/app/utilities/peft.py
@@ -0,0 +1,86 @@
import importlib
import os
from typing import Dict, Optional, Union

from packaging import version

ADAPTER_CONFIG_NAME = "adapter_config.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"


def find_adapter_config_file(
model_id: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
_commit_hash: Optional[str] = None,
) -> Optional[str]:
r"""Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path of the
adapter config file if it is, None otherwise.

Args:
model_id (`str`):
The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.

<Tip>

To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>".

</Tip>

local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.

"""
adapter_cached_filename = None
if model_id is None:
return None
if os.path.isdir(model_id):
list_remote_files = os.listdir(model_id)
if ADAPTER_CONFIG_NAME in list_remote_files:
adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
return adapter_cached_filename


def check_peft_version(min_version: str) -> None:
r"""Checks if the version of PEFT is compatible.

Args:
version (`str`):
The version of PEFT to check against.

"""
is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version)

if not is_peft_version_compatible:
raise ValueError(
f"The version of PEFT you are using is not compatible, please use a version that is greater"
f" than {min_version}"
)
19 changes: 19 additions & 0 deletions src/lightning/pytorch/core/saving.py
Expand Up @@ -81,6 +81,25 @@ def _load_from_checkpoint(

# TODO: make this a migration:
# for past checkpoint need to add the new key
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
if is_peft_available():
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)

if _adapter_model_path is None:
_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
_commit_hash=commit_hash,
**adapter_kwargs,
)
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, encoding="utf-8") as f:
_adapter_model_path = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
# override the hparams with values that were passed in
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
Expand Down