
Fix position embeddings for GPT-J and CodeGen #22069

Merged
gante merged 17 commits into huggingface:main from njhill:fix_pos_embeds on Mar 22, 2023

Conversation

njhill
Contributor

@njhill njhill commented Mar 10, 2023

What does this PR do?

Identical inputs to GPT-J and CodeGen models will currently generate different outputs if they are padded differently (for example in a batch of variable sequence lengths).

This PR reverts the recent change #21869 that removed GPT-J position_ids, and then applies changes similar to those made for GPT-J XLA in #17986.

One copy of the precomputed position embeddings is shared between all of the layers.

Related issue: #21080
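
For context, here is a minimal sketch (not part of the PR itself) of how position_ids can be built from the attention mask for a left-padded batch; with correct position handling this is what makes the outputs independent of the amount of padding. The same recipe appears in the verification snippet later in this thread.

import torch

# Toy left-padded batch: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])

# Count only the real tokens; give padded slots a harmless dummy position (1).
position_ids = attention_mask.cumsum(dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[0, 1, 2, 3], [1, 1, 0, 1]])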

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante

@njhill
Contributor Author

njhill commented Mar 10, 2023

I have tested this with my own code/use case but wanted to check that there is interest in the contribution before also updating any applicable unit tests.

I also wonder whether there should be a universal test, applied to all models, that runs the same input with different amounts of padding and checks that the output is identical.
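
A rough sketch of what such a generic check could look like (the tiny checkpoint name and tolerances below are placeholder assumptions, not an existing transformers test):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def check_padding_invariance(model_name="hf-internal-testing/tiny-random-GPTJForCausalLM", extra_pad=4):
    # model_name is a placeholder; any small causal LM checkpoint would do.
    tok = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    text = ["The brown fox"]
    plain = tok(text, return_tensors="pt")
    padded = tok(text, return_tensors="pt", padding="max_length",
                 max_length=plain.input_ids.shape[1] + extra_pad)

    # Position ids for the padded batch must skip the pad slots.
    position_ids = padded.attention_mask.cumsum(dim=-1) - 1
    position_ids.masked_fill_(padded.attention_mask == 0, 1)

    with torch.no_grad():
        logits_plain = model(**plain).logits[:, -1, :]
        logits_padded = model(**padded, position_ids=position_ids).logits[:, -1, :]

    # With correct position handling, the last-token logits should match up to numerical noise.
    torch.testing.assert_close(logits_plain, logits_padded, atol=1e-4, rtol=1e-4)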

Member

@gante gante left a comment

A few comments on the GPT-J side. Once we're both happy with GPT-J, Codegen should be copy/paste :)

3 review threads on src/transformers/models/gptj/modeling_gptj.py (outdated, resolved)
@gante
Member

gante commented Mar 10, 2023

@njhill and yes, the contribution is deeply appreciated! 🙏

Be mindful that this will not make the outputs fully left-padding agnostic. As in all models, the padding is a numerical mask: in FP32 it is almost left-padding agnostic, but in FP16/BF16/INT8 left-padding may introduce changes :)

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 10, 2023

The documentation is not available anymore as the PR was closed or merged.

Member

@gante gante left a comment

Looks great!

Two additional requests before merging:
1 - Can you confirm that existing slow tests pass for GPT-J and Codegen? They are tests with batch_size=1, so they should see no changes.
2 - Can you add a slow integration test on Codegen with batch_size=2? (GPT-J only has a massive model, so our CI doesn't run its integration tests :( ). Preferably with an example whose output changes as a result of this PR, if you can find one 💛

Review thread on src/transformers/models/gptj/modeling_gptj.py (outdated, resolved)
@gante gante requested a review from sgugger March 11, 2023 11:53
@ydshieh
Collaborator

ydshieh commented Mar 11, 2023

@gante I didn't review this PR, but I see it is related to issue #21080, and therefore indirectly to PR #21853, which was reverted in #22093 due to some unexpected test failures (PT/TF, PT/Flax).

So before merging this PR, it's better to verify the cross tests, as well as the slow tests (always better).

The PR CI for #21853 was green (and also green when merged to main), but some tests started to fail in subsequent PRs. It's still unclear to us why we didn't catch these failures in the PR CI.

@gante
Member

gante commented Mar 11, 2023

Thanks for the heads up @ydshieh! 🙏

I'll make sure all related slow tests (and the tests that failed after merging #21853) are passing before merging.

@njhill
Contributor Author

njhill commented Mar 12, 2023

Thanks @gante ... I'm kind of new to this but will figure out how to verify/update the tests per your request.

The main problem I've run into, though, is newly failing torch.fx tracing tests:

FAILED tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_torch_fx - AssertionError: Couldn't trace module: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors
FAILED tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_torch_fx_output_loss - AssertionError: Couldn't trace module: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors

I've tried several variations of the logic but always end up with similar kinds of errors. I think it may stem from the index_select operation. Any pointers/ideas would be appreciated!
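
For anyone hitting this later, here is a stand-alone toy repro of this class of failure (a sketch, not the actual modeling code): unpacking the result of torch.split on a traced value tries to iterate a torch.fx Proxy, which plain symbolic tracing cannot do.

import torch
import torch.fx

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for the precomputed sin/cos table.
        self.register_buffer("table", torch.arange(12.0).view(6, 2))

    def forward(self, x, position_ids):
        sincos = self.table[position_ids]          # tensor-based (dynamic) indexing
        sin, cos = torch.split(sincos, 1, dim=-1)  # unpacking iterates a Proxy during tracing
        return x * sin + x * cos

torch.fx.symbolic_trace(Toy())  # raises TraceError: Proxy object cannot be iterated. ...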

Collaborator

@sgugger sgugger left a comment

Thanks for working on this. Before merging, let's make sure the cross tests pass this time though ;-)

@gante
Member

gante commented Mar 13, 2023

Hey @njhill 👋

I've tried to fix the issue you mentioned, with no success. It seems like we are between a rock and a hard place -- the changes you made, by design, make sincos dependent on the values of position_ids. In other words, sincos becomes a tensor that is impossible to predict at compile time with torch.fx, i.e. a dynamic tensor. Ultimately, no matter how we rewrite the code (AFAIK), we will hit this barrier, causing the test to fail.

@sgugger @fxmarty is there a way we can make torch.fx ignore a function? (or do you have suggestions?) The change in this PR makes GPT-J correct in the presence of left-padding, but breaks compatibility with torch.fx 🙈

(Pastebin containing the code with modifications, working through the exceptions until I got stuck: https://pastebin.com/T0HpD07C)

@sgugger
Collaborator

sgugger commented Mar 13, 2023

Also cc @michaelbenayoun for torch fx.

@njhill
Contributor Author

njhill commented Mar 13, 2023

Thanks @gante, it sounds like you followed a similar path to me in trying different arrangements of the logic to get around this. I was guessing this couldn't be the only occurrence of this dynamic-tensor issue in the library -- is dynamic slicing done elsewhere, and if so, how does it work with torch.fx?

@michaelbenayoun
Member

Hi @njhill,

The issue here (from what I could understand) seems to be that during tracing we do not have regular tensors but rather symbolic "proxies".

In the following code we are trying to call __iter__ on sincos, which is symbolic, so we do not know its length (again, not 100% sure, just guessing).

sincos = [t.contiguous() for t in sincos]

But the previous line is:

sincos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

Meaning that the list has:

  • 2 elements if sincos.shape[-1] is an even number
  • 3 elements if sincos.shape[-1] is an odd number.

So could you try this:

len_sincos = 2 + torch.remainder(torch.tensor(sincos.shape[-1]), 2)
sincos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
sincos = [sincos[idx].contiguous() for idx in torch.arange(len_sincos)]

Tell me if this works!

@njhill
Contributor Author

njhill commented Mar 14, 2023

Thanks @michaelbenayoun. You are right that a symbolic proxy tensor is being introduced somewhere; however, I think it stems from the tensor-based indexing here:

sincos = embed_positions[position_ids]

The proxy iterator errors are easy to circumvent but just move the problem to later, where (inevitably?) the size of the proxy tensor is used for flow control. I've pushed a couple of small updates to the PR to demonstrate this; you can see the latest error in the tests here. As @gante pointed out above:

Ultimately, no matter how we rewrite the code (AFAIK), we will hit this barrier, causing the test to fail.

Could we at least make this path conditional such that it isn't followed in the torch.fx case, i.e. declare that variable padding is unsupported in that case?

@gante
Member

gante commented Mar 15, 2023

Hey @njhill -- I think the conditional path is a sensible idea, at least for now (we can always revisit it later). #22161 reports a similar problem on another in-demand model, so I would like to merge the fix as soon as possible 🤗

For context, other places in transformers use this sort of conditional path for torch.fx. Check here for an example.

@michaelbenayoun
Member

@njhill The HF tracer is supposed to keep track of "concrete" metadata during tracing to allow for that.
In this case, either this does not work with len, which is possible (I do not remember tbh), or it means that an op does not support the meta device, hence breaking the concrete metadata accumulation.

Since in this case you are trying to check the rank of the tensor, could you try replacing len(tensor.shape) by tensor.ndim?

@njhill
Contributor Author

njhill commented Mar 18, 2023

Thanks @michaelbenayoun... the len problem can be avoided by adding torch.fx.wrap('len'), which I'd done in the prior commit but removed in this latest commit since it seemed futile (it just moved the error slightly later). So I was instead attempting to bypass the position_ids fix in the torch.fx case per this comment (so far unsuccessfully).

The problem encountered after working around the len problem can be seen here:

>       if len(tensor.shape) == 5:

AssertionError: Couldn't trace module: symbolically traced variables cannot be used as inputs to control flow

Basically, this traced length value is then used in a control-flow condition.

Member

@michaelbenayoun michaelbenayoun left a comment

LGTM

src/transformers/utils/fx.py Outdated Show resolved Hide resolved
Member

@gante gante left a comment

Great work sorting the torch.fx issues! 💛

I left a final question/suggestion, but happy to merge it as is. @njhill I leave the final decision up to you

Comment on lines 213 to 216
if is_torch_fx_proxy(position_ids):
    embed_positions = get_embed_positions(self.embed_positions, position_ids)
else:
    embed_positions = self._get_embed_positions(position_ids)
Member

@gante gante Mar 21, 2023

Couldn't the first branch (embed_positions = get_embed_positions(self.embed_positions, position_ids)) work in both cases?

It would result in simpler code :) In other words, ditch the if/else to always call embed_positions = get_embed_positions(self.embed_positions, position_ids), then remove the bits that are no longer needed.

(ditto for codegen)

Contributor Author

Thanks @gante, my concern is that the cached position embeddings will initially be in CPU memory, so they would need to be copied to the GPU every time. The logic in _get_embed_positions copies them the first time and then no longer needs to, but I couldn't get that to work for the torch.fx case, hence this if/else.

I'll add a comment to the code to make that clear. Codegen is slightly simpler because it doesn't support torch.fx anyhow.
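
For readers following along, here is a rough, self-contained sketch of the caching behaviour described above (illustrative only; the real helper lives in modeling_gptj.py and may differ in detail):

import torch

class PositionEmbeddingCache(torch.nn.Module):
    # Illustrative sketch, not the actual GPT-J code.
    def __init__(self, table: torch.Tensor):
        super().__init__()
        self.embed_positions = table  # precomputed sin/cos table, initially on CPU

    def _get_embed_positions(self, position_ids: torch.Tensor) -> torch.Tensor:
        embed_positions = self.embed_positions
        if embed_positions.device != position_ids.device:
            # First call on a new device: copy the table once and keep it for later calls.
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions
        # One copy of the table per batch entry; the caller then gathers rows by position id.
        return embed_positions.repeat(position_ids.shape[0], 1, 1)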

@gante
Member

gante commented Mar 21, 2023

For our future reference, here's a snippet that shows that left-padding is fixed with these changes:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16).to(0)
tok.pad_token = tok.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

inputs_1 = tok(["The brown fox"], return_tensors="pt", padding=True).to(0)
out_1 = model(**inputs_1)
out_2 = model(**inputs_1)

position_ids = torch.cumsum(inputs_1.attention_mask, dim=-1) - 1
out_3 = model(**inputs_1, position_ids=position_ids + 8)

inputs_2 = tok(["The brown fox"], return_tensors="pt", padding="max_length", max_length=10).to(0)
out_4 = model(**inputs_2)

position_ids = torch.cumsum(inputs_2.attention_mask, dim=-1) - 1
position_ids.masked_fill_(inputs_2.attention_mask == 0, 1)
out_5 = model(**inputs_2, position_ids=position_ids)

# calls with the same inputs get the same logits
print(torch.max(torch.abs(out_1.logits[:, -1, :] - out_2.logits[:, -1, :]))) # tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)

# changing the position_ids changes the logits
print(torch.max(torch.abs(out_1.logits[:, -1, :] - out_3.logits[:, -1, :]))) # tensor(0.0625, device='cuda:0', grad_fn=<MaxBackward1>)

# padding and not passing position ids -> incorrect position ids -> output differences
print(torch.max(torch.abs(out_1.logits[:, -1, :] - out_4.logits[:, -1, :]))) # tensor(0.0625, device='cuda:0', grad_fn=<MaxBackward1>)

# left-padding has a much smaller impact (NOTE: setting e.g. `max_length=20` will cause the next diff to be non-zero.
# Numerical masking is not perfect :) )
print(torch.max(torch.abs(out_1.logits[:, -1, :] - out_5.logits[:, -1, :]))) # tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)

@gante
Member

gante commented Mar 22, 2023

The failing CI was fixed in this merged PR, so merging.

@gante gante merged commit 4e94c6c into huggingface:main Mar 22, 2023
3 of 4 checks passed
@gante
Member

gante commented Mar 22, 2023

@njhill fantastic work with the torch.fx issues, I really appreciated your effort 🤗

@njhill njhill deleted the fix_pos_embeds branch March 22, 2023 14:48
@njhill
Contributor Author

njhill commented Mar 22, 2023

Thanks @gante, glad I was able to contribute. Thank you for your fast responses and for all the great work you and team do.

@stas00
Contributor

stas00 commented Mar 22, 2023

This PR isn't backward compatible. It breaks with pytorch-1.8:

E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/gptj/modeling_gptj.py", line 63, in <module>
E               @torch.fx.wrap
E           AttributeError: module 'torch' has no attribute 'fx'

not sure if you want to revert this or have an idea how to overcome this quickly.

@ydshieh
Collaborator

ydshieh commented Mar 22, 2023

This PR isn't backward compatible. It breaks with pytorch-1.8:

E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/gptj/modeling_gptj.py", line 63, in <module>
E               @torch.fx.wrap
E           AttributeError: module 'torch' has no attribute 'fx'

not sure if you want to revert this or have an idea how to overcome this quickly.

@stas00

FYI, see #22291, although that PR and this one were not directly related when they were opened.

@stas00
Contributor

stas00 commented Mar 22, 2023

ok, the deepspeed CI is running pt-1.8 - how do we solve that then?

@ydshieh
Collaborator

ydshieh commented Mar 22, 2023

ok, the deepspeed CI is running pt-1.8 - how do we solve that then?

I just saw

microsoft/DeepSpeed#3082

opened 2 hours ago. I am not sure how it will go, but I will try to follow up tomorrow morning.

@stas00
Contributor

stas00 commented Mar 22, 2023

oh, ok, I guess everything is fine then. thank you for the heads up, @ydshieh

@stas00
Contributor

stas00 commented Mar 22, 2023

it still fails with pt-1.9.1

  1. you need import torch.fx (thanks @mrwyattii)

  2. it then fails with:

E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/gptj/modeling_gptj.py", line 61, in create_sinusoidal_positions
E               return torch.concat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
E           AttributeError: module 'torch' has no attribute 'concat'

@njhill
Contributor Author

njhill commented Mar 22, 2023

Oops, I guess we should use torch.cat() instead

@stas00
Contributor

stas00 commented Mar 22, 2023

and it fails w/o import torch.fx

E             File "/mnt/nvme0/code/huggingface/transformers-master/examples/pytorch/language-modeling/run_clm.py", line 412, in main
E               model = AutoModelForCausalLM.from_pretrained(
E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/auto/auto_factory.py", line 470, in from_pretrained
E               model_class = _get_model_class(config, cls._model_mapping)
E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/auto/auto_factory.py", line 360, in _get_model_class
E               supported_models = model_mapping[type(config)]
E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/auto/auto_factory.py", line 602, in __getitem__
E               return self._load_attr_from_module(model_type, model_name)
E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/auto/auto_factory.py", line 616, in _load_attr_from_module
E               return getattribute_from_module(self._modules[module_name], attr)
E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/auto/auto_factory.py", line 561, in getattribute_from_module
E               if hasattr(module, attr):
E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/utils/import_utils.py", line 1109, in __getattr__
E               module = self._get_module(self._class_to_module[name])
E             File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/utils/import_utils.py", line 1121, in _get_module
E               raise RuntimeError(
E           RuntimeError: Failed to import transformers.models.gptj.modeling_gptj because of the following error (look up to see its traceback):
E           module 'torch' has no attribute 'fx'

so 2 fixes at least. thank you!

@stas00
Contributor

stas00 commented Mar 23, 2023

I confirm that it works with torch.cat

perhaps use torch.concat but add an alias:

# bc for pt<1.10
if not hasattr(torch, "concat"):
    torch.concat = torch.cat

stashed somewhere in utils?

@stas00
Contributor

stas00 commented Mar 23, 2023

import torch.fx is a must - even with pt-1.10 it won't work w/o it.
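
Putting the two breakages together, a minimal sketch of the kind of compatibility fix being discussed (the helper name is mine and this is an assumption about the shape of the fix, not necessarily what ended up in #22325):

import torch
import torch.fx  # explicit import: on older PyTorch, "import torch" alone does not expose torch.fx

def concat_sin_cos(sinusoid_inp: torch.Tensor) -> torch.Tensor:
    # torch.cat works on every supported version; torch.concat is only an alias added in torch 1.10.
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)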

@stas00
Contributor

stas00 commented Mar 23, 2023

@njhill, are you on top of fixing this?

This is a bit urgent since the DeepSpeed CI uses our bleeding edge to test DeepSpeed's bleeding edge live, and currently their CI is broken because of this.

@njhill
Contributor Author

njhill commented Mar 23, 2023

@stas00 apologies, I am AFK right now but could do it in a few hours. Feel free to do it in the meantime if you like!

I don't see any downside to just using torch.cat, since torch.concat is merely an alias for it.

Where is it that we need to add the extra torch.fx import?

@stas00
Contributor

stas00 commented Mar 23, 2023

sure, I will fire off a PR - thank you for letting me know your preferences, @njhill

@stas00
Contributor

stas00 commented Mar 23, 2023

the fix is here #22325

@stas00
Contributor

stas00 commented Mar 23, 2023

the fix has been merged.

raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
* Revert "[GPT-J] add deprecation warning (huggingface#21869)"

This reverts commit fb76994.

* Fix position embeddings for GPT-J and CodeGen

* Address review comments from @gante

* Fix "Copied from" comment referencing wrong function

* Fix copy/paste mistake

* Fix training path

* Hopefully make torch.fx happy

* Move position_ids long cast

* Revert "Hopefully make torch.fx happy"

This reverts commit e41a6f4.

* Changes to help with torch.fx tracing

* Linter fix

* Correct position_ids tensor type hint

* Work-around torch.fx tracing issue

* Get the changes to work with torch.fx

* Address review comment from @michaelbenayoun

* Another small adjustment

* Add explanatory comment; small code tidyup
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023