
Deberta_v2 tf #13120

Merged
merged 7 commits into huggingface:master from kamalkraj:deberta_v2-tf on Aug 31, 2021

Conversation

@kamalkraj (Contributor) commented Aug 13, 2021

What does this PR do?

Deberta-v2 TF


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@LysandreJik @patrickvonplaten

@kamalkraj (Contributor, Author)

@Rocketknight1
#12972 (comment)
The gather function fails while running run_glue.py from the examples.
[Screenshot of the error: 2021-08-13 at 9:55:59 PM]

If I replace the gather function with the experimental NumPy take_along_axis, it works: https://gist.github.com/kamalkraj/73ad5fa2b84de7e201e05464e11a4fec
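For context, a minimal sketch (not the code from the linked gist; the shapes below are illustrative) of the substitution being described: emulating a torch.gather along dim=-2 with tf.experimental.numpy.take_along_axis.

import tensorflow as tf

# Illustrative shapes: a [batch, seq_len, dim] tensor and an index tensor of the
# same rank that selects rows along the sequence (second-to-last) axis.
x = tf.random.normal((2, 8, 4))
idx = tf.random.uniform((2, 3, 4), maxval=8, dtype=tf.int32)

# torch.gather(x, dim=-2, index=idx) picks x[b, idx[b, i, j], j];
# take_along_axis performs the same per-element lookup along the given axis.
out = tf.experimental.numpy.take_along_axis(x, idx, axis=-2)
print(out.shape)  # (2, 3, 4)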

@Rocketknight1 (Member)

Hi @kamalkraj, do you know what shape the inputs are to the gather/take_along_axis? I'm going to try to construct a small test case that fails for my gather function but not for take_along_axis. If you can find a simple test case that fails, feel free to send that too so I can fix the function!

@kamalkraj (Contributor, Author)

Hi @Rocketknight1
I tried a few tests for torch.gather when you initially shared the function. Notebook link: https://colab.research.google.com/drive/1ujI6zKTuuryAO2Nfw9U1ZftyZyC4VUVS?usp=sharing

@Rocketknight1 (Member)

In all of those cases, it looks like the TF torch_gather function gets the same results as the actual torch.gather, right? Is there a difference?

@kamalkraj (Contributor, Author) commented Aug 16, 2021

No, the TF torch_gather function gets the same output as torch.gather.

Actually, at runtime this branch is never reached:

if query_layer.size(-2) != key_layer.size(-2):
    p2c_att = torch.gather(
        p2c_att,
        dim=-2,
        index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))),
    )

because both query_layer and key_layer are the same size:

self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
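For illustration, here is a minimal runnable sketch (the shapes and the to_heads helper are made up for this example, not the library code) of why that branch is dead: both projections act on the same hidden_states and map to all_head_size, so the two layers always agree on the sequence dimension that the condition compares.

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 2, 16, 32
num_heads, head_size = 4, 8  # all_head_size = num_heads * head_size

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
query_proj = nn.Linear(hidden_size, num_heads * head_size, bias=True)
key_proj = nn.Linear(hidden_size, num_heads * head_size, bias=True)

def to_heads(x):
    # [batch, seq, all_head_size] -> [batch * num_heads, seq, head_size]
    x = x.view(batch_size, seq_len, num_heads, head_size).permute(0, 2, 1, 3)
    return x.reshape(batch_size * num_heads, seq_len, head_size)

query_layer = to_heads(query_proj(hidden_states))
key_layer = to_heads(key_proj(hidden_states))

# Both are derived from the same hidden_states, so the compared dimension matches
# and `query_layer.size(-2) != key_layer.size(-2)` is never True.
assert query_layer.size(-2) == key_layer.size(-2)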

@kamalkraj (Contributor, Author) commented Aug 16, 2021

Hi @BigBird01,

I was going through the deberta-v2 implementation inside Hugging Face and, as per my understanding, for deberta-v2 the branch below will never be executed:

if query_layer.size(-2) != key_layer.size(-2):

because query_layer and key_layer both have shape
[batch_size * num_attention_heads, sequence_length, attention_head_size].

The condition may be needed for DeBERTa, but Hugging Face has separate implementations for DeBERTa and DeBERTa-v2.
If my assumption is correct, we can remove these never-executed control-flow branches from the deberta-v2 code.

@BigBird01 (Contributor) commented Aug 16, 2021 via email

@Rocketknight1 (Member)

> @Rocketknight1
> #12972 (comment)
> The gather function fails while running run_glue.py from the examples.
> If I replace the gather function with the experimental NumPy take_along_axis, it works: https://gist.github.com/kamalkraj/73ad5fa2b84de7e201e05464e11a4fec

Hi @kamalkraj, can you share the exact glue task / command you used? I still can't reproduce the bug - I tried this:

python run_glue.py --model_name_or_path kamalkraj/deberta-v2-xlarge --task_name mnli --do_train --do_eval --do_predict --output_dir output

This seemed to work fine with torch_gather.

@kamalkraj (Contributor, Author) commented Aug 17, 2021

@Rocketknight1
The issue is solved with commit 90c122d.

The torch_gather function under those if conditions was causing the issue. I removed the conditions since they were unnecessary.
You can see the discussion in #13120 (comment).

I also opened another pull request, #13145, to remove them from the PyTorch model as well.

@kamalkraj (Contributor, Author)

Hi @Rocketknight1,
#13145 is merged to master. The TF implementation is now the same as the torch implementation and runs without any issues.

@patrickvonplaten (Contributor) left a comment:

Very nice work @kamalkraj! I only left some nits - looks good to me to merge!

The only thing that confused me a bit was the if-else logic that depends on whether hidden_states is a sequence type, e.g. here: https://github.com/huggingface/transformers/pull/13120/files#r694681300. When would that be the case?

@kamalkraj (Contributor, Author)

Hi @patrickvonplaten,
thanks for the review. I committed the changes.

@LysandreJik (Member) left a comment:

Tremendous work putting this all together @kamalkraj!

Most of my comments are regarding the # Copied from statements that could be added to most classes here; it seems even the actual model classes like TFDebertaV2ForSequenceClassification could benefit from them.
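For readers unfamiliar with the convention being referenced, a # Copied from statement is a one-line comment that the repository's consistency checks use to keep duplicated code in sync across models. A rough illustration (the class names here are illustrative, not necessarily the exact lines added in this PR):

# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaSelfOutput with Deberta->DebertaV2
class TFDebertaV2SelfOutput(tf.keras.layers.Layer):
    ...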

docs/source/index.rst (review thread resolved)

_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-v2-xlarge"

TF_DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "kamalkraj/deberta-v2-xlarge",
A Member reviewer commented:
We would need to migrate the TF checkpoint to the official one in microsoft/deberta-v2-xlarge.

@kamalkraj (Contributor, Author) replied:
Okay

@kamalkraj (Contributor, Author)

Hi @LysandreJik,
I committed the changes.

@Rocketknight1 (Member) left a comment:

Overall, this looks like a very solid PR now, and if tests are passing with good performance then I think it should be just about ready to go, assuming everyone else is in agreement!

@LysandreJik (Member) left a comment:

Thanks for your work @kamalkraj!

@LysandreJik LysandreJik merged commit 3efcfea into huggingface:master Aug 31, 2021
@kamalkraj kamalkraj deleted the deberta_v2-tf branch September 11, 2021 10:30
@sh0416 commented May 12, 2022

Is this code compatible with model.fit?
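For context, a minimal sketch of what Keras model.fit training with the TF DeBERTa-v2 classes could look like, assuming a transformers release recent enough that compiling without an explicit loss falls back to the model's internal loss; the checkpoint, data, and hyperparameters are illustrative, and from_pt=True may be needed if the checkpoint hosts no TF weights.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xlarge")
model = TFAutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-v2-xlarge", num_labels=2
)

# A tiny illustrative dataset: tokenized text features plus integer labels.
texts = ["a tiny positive example", "a tiny negative example"]
labels = [1, 0]
enc = tokenizer(texts, padding=True, truncation=True, return_tensors="np")
dataset = tf.data.Dataset.from_tensor_slices((dict(enc), labels)).batch(2)

# With no loss passed to compile(), recent transformers versions train against
# the model's internal loss when labels are provided by the dataset.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5))
model.fit(dataset, epochs=1)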
