Right now, we don't have specific tests for torch.compile. Instead, we have a "hack" that allows running _all_ tests with torch.compile if we set the environment variable PEFT_DEBUG_WITH_TORCH_COMPILE=1. This is not very practical because it takes a lot of time to run all these tests with compilation enabled. Also, currently hundreds of tests are failing, which makes it impossible to understand more closely what does or does not work.

This PR removes the aforementioned "hack" and replaces it with a list of explicit torch.compile tests. Currently, these tests cover training/inference with a bunch of different tuner types, as well as more advanced features with LoRA (e.g. quantization, multiple adapters, etc.). Some of these tests pass and some of them fail. This is documented now, so that users can quickly look up whether their use case is compatible with torch.compile. This is very useful to have, because sometimes torch.compile may appear to work but actually returns the wrong result. For users, it's not immediately obvious when this happens.

The test suite is not exhaustive; there are many combinations of features that could be added. However, it should be a good starting point and can be extended later. The test suite does _not_ cover whether torch.compile actually accelerates the code. This may not be the case even if it works correctly (e.g. because of graph breaks). Testing this would require bigger models and more data, which is prohibitively slow to test.

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent 3f7aacd · commit 4e32679 · 6 changed files with 648 additions and 25 deletions.
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# torch.compile
In PEFT, [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) works for some but not all features. It won't always work because PEFT is highly dynamic in certain places (loading and switching between multiple adapters, for instance), which can cause trouble for `torch.compile`. In other places, `torch.compile` may work, but won't be as fast as expected because of graph breaks.

If you don't see an error, it doesn't necessarily mean that `torch.compile` worked correctly. It might give you an output, but the output is incorrect. This guide describes what works with `torch.compile` and what doesn't.

> [!TIP]
> Unless indicated otherwise, the default `torch.compile` settings were used.
## Training and inference with `torch.compile`

These features **work** with `torch.compile`. Everything listed below was tested with a causal LM:

- Training with `Trainer` from 🤗 transformers
- Training with a custom PyTorch loop
- Inference
- Generation
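As a rough illustration of the second item, a custom PyTorch training loop over a compiled model could look like the following minimal sketch (a toy regression model stands in for a real PEFT model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Compile the model once; the training loop then uses the compiled module.
model = torch.compile(nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# Synthetic regression task: predict the sum of the input features.
x = torch.randn(64, 8)
y = x.sum(dim=1, keepdim=True)

losses = []
for _ in range(20):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# The loss decreases, i.e. backward and optimizer steps work through compile.
assert losses[-1] < losses[0]
```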
The following adapters were tested successfully:

- AdaLoRA
- BOFT
- IA³
- Layer Norm Tuning
- LoHa
- LoRA
- LoRA + DoRA
- OFT
- VeRA
The following adapters **don't work** correctly for training or inference when using `torch.compile`:

- LoKr
- LoRA targeting embedding layers
## Advanced PEFT features with `torch.compile`

Below are some of the more advanced PEFT features that **work**. They were all tested with LoRA.

- `modules_to_save` (i.e. `config = LoraConfig(..., modules_to_save=...)`)
- Merging adapters (one or multiple)
- Merging multiple adapters into one adapter (i.e. calling `model.add_weighted_adapter(...)`)

Generally, we can expect that if a feature works correctly with LoRA and is also supported by other adapter types, it should also work for that adapter type.
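To make the merging idea concrete, the sketch below folds a LoRA-style delta into a base weight before compiling. This is a plain-PyTorch stand-in, not PEFT's actual merge implementation, but it shows the underlying arithmetic (W ← W + scaling · B @ A):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

base = nn.Linear(8, 8, bias=False)
r, scaling = 2, 0.5
lora_A = torch.randn(r, 8) * 0.01  # down-projection
lora_B = torch.randn(8, r) * 0.01  # up-projection

x = torch.randn(4, 8)
with torch.no_grad():
    # Unmerged forward: base output plus the low-rank delta.
    unmerged = base(x) + (x @ lora_A.t() @ lora_B.t()) * scaling

    # Merge the delta into the base weight, then compile the merged module.
    base.weight += (lora_B @ lora_A) * scaling
    merged_model = torch.compile(base)
    merged = merged_model(x)

# Merged and unmerged paths produce the same result.
assert torch.allclose(unmerged, merged, atol=1e-5)
```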
The more advanced PEFT features below **don't work** in conjunction with `torch.compile`. Tests were run with LoRA:

- Using PEFT adapters with quantization (bitsandbytes)
- Inference with multiple adapters
- Unloading (i.e. calling `model.merge_and_unload()`)
- Disabling adapters (i.e. using `with model.disable_adapter()`)
- Mixed adapter batches (i.e. calling `model(batch, adapter_names=["__base__", "default", "other", ...])`)
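When a feature fails only silently, compiling with `fullgraph=True` can help surface incompatibilities, since any graph break then raises an error instead of silently falling back to eager execution. The data-dependent branch below is a hypothetical stand-in for the kind of dynamic control flow that causes trouble in the features listed above:

```python
import torch

def dynamic_fn(x):
    # Data-dependent Python branching on a tensor value forces a graph break.
    if x.sum() > 0:
        return x * 2
    return x - 1

# With fullgraph=True, a graph break raises instead of falling back to eager.
compiled = torch.compile(dynamic_fn, fullgraph=True)

try:
    compiled(torch.ones(3))
    broke = False
except Exception:
    broke = True

assert broke  # the graph break was surfaced as an error
```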
## Test cases

All the use cases listed above are tested inside of [`peft/tests/test_torch_compile.py`](https://github.com/huggingface/peft/blob/main/tests/test_torch_compile.py). If you want to check in more detail how we tested a certain feature, please go to that file and check the test that corresponds to your use case.

> [!TIP]
> If you have another use case where you know that `torch.compile` does or does not work with PEFT, please contribute by letting us know or by opening a PR to add this use case to the covered test cases.
The commit removes the following debugging hack:

```python
import os


if os.environ.get("PEFT_DEBUG_WITH_TORCH_COMPILE") == "1":
    # This is a hack purely for debugging purposes. If the environment variable PEFT_DEBUG_WITH_TORCH_COMPILE is set to
    # 1, get_peft_model() will return a compiled model. This way, all unit tests that use peft.get_peft_model() will
    # use a compiled model. See .github/workflows/torch_compile_tests.yml.
    import torch

    import peft
    from peft.mapping import get_peft_model as get_peft_model_original

    # TODO: Experimental dynamo feature that should allow correct compilation of more PEFT modules. This should be
    # removed once PyTorch has found a better solution, as this incurs a performance penalty.
    # https://github.com/pytorch/pytorch/issues/124717#issuecomment-2083235776
    torch._dynamo.config.guard_nn_modules = True

    def get_peft_model_new(*args, **kwargs):
        """Make get_peft_model() return a compiled model."""
        peft_model = get_peft_model_original(*args, **kwargs)
        peft_model = torch.compile(peft_model)
        return peft_model

    peft.get_peft_model = get_peft_model_new
```