
TF: XLA generation not working properly in some models #17935

Open · 8 of 12 tasks
gante opened this issue Jun 29, 2022 · 12 comments
@gante
Member

gante commented Jun 29, 2022

This issue is used to track TensorFlow XLA generation issues, arising from #17857. There are three categories of issues, sorted in descending order by severity:

Key model issues

These are heavily-used models, whose quality should be prioritized.

  • T5 -- The quality of the results decreases with max_length. See here.
  • GPT-J -- fails simple generate tests with numerical issues

Models failing basic tests

These models are failing test_xla_generate_fast -- a short greedy generation.

  • LED
  • Speech2Text
  • XLNet
  • XGLM

Models failing complex tests

These models are failing test_xla_generate_slow -- a long beam search generation.

  • Bart
  • Blenderbot
  • Marian
  • mbart
  • OPT
  • Pegasus
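
For context, the gist of these tests looks roughly like the sketch below -- this is illustrative only, not the actual transformers test code, and the checkpoint and generation settings are placeholders:

```python
# Rough sketch of what the XLA generation tests exercise -- illustrative only.
# Checkpoint and generation settings are assumptions, not the exact test setup.
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

# XLA-compiled generate: same call as eager generate, traced with jit_compile=True
xla_generate = tf.function(model.generate, jit_compile=True)

inputs = tokenizer(["translate English to German: I love TensorFlow"], return_tensors="tf")

# "fast" tests do a short greedy generation; "slow" tests do a longer beam search
eager_out = model.generate(**inputs, max_new_tokens=16, num_beams=4)
xla_out = xla_generate(**inputs, max_new_tokens=16, num_beams=4)

# The expectation: XLA and eager generation should agree on the generated text
assert tokenizer.batch_decode(eager_out, skip_special_tokens=True) == \
       tokenizer.batch_decode(xla_out, skip_special_tokens=True)
```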
@gante added the Good Second Issue and TensorFlow labels on Jun 29, 2022
@gante self-assigned this on Jun 29, 2022
@anmolsjoshi
Contributor

@gante do you require any help with this issue? Happy to contribute

@gante
Member Author

gante commented Jul 12, 2022

Hi @anmolsjoshi 👋

If you are comfortable with debugging XLA, absolutely :) My recommendation would be to pick a model from "Models failing complex tests" (the others might require significant architecture changes) and start debugging. The number one suspect is always the position embeddings, which may not handle the case where the past is padded -- see the sketch below. Let me know if you are up for it, and which model you would like to take!
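
To make the "padded past" point concrete, here is a rough sketch of the usual robust approach (the function name is hypothetical, not taken from any model's code): derive position ids from the attention mask rather than from the cache length, so left-padding does not shift positions.

```python
# Hedged sketch: computing position ids when the past/cache is padded, as it is
# under XLA generation. Counting cache entries over-counts once padding exists;
# counting non-padded tokens via the attention mask does not.
import tensorflow as tf

def position_ids_from_attention_mask(attention_mask: tf.Tensor, past_length: int = 0) -> tf.Tensor:
    # Count only the non-padded tokens seen so far; positions start at 0.
    position_ids = tf.cumsum(attention_mask, axis=-1) - 1
    position_ids = tf.where(attention_mask == 0, tf.zeros_like(position_ids), position_ids)
    # During cached decoding we only need the positions of the new tokens.
    return position_ids[:, past_length:]

# Example: batch of 1, left-padded prompt of length 5 with 2 pad tokens.
attention_mask = tf.constant([[0, 0, 1, 1, 1]])
print(position_ids_from_attention_mask(attention_mask))  # values: [[0, 0, 0, 1, 2]]
```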

@dsuess
Contributor

dsuess commented Jul 25, 2022

Hi @gante, I did have a bit of a poke around. I think the complex tests all fail for the same reason: those models have a max_position_embeddings setting that defaults to 20 during testing, which is too short for the “slow” tests. Here’s a simple fix for those: dsuess@4a3e271. I’ll give the other ones a shot now
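
For anyone following along, this is a minimal illustration of the kind of change being described -- the linked commit is the authoritative fix; the tiny config values below (with BART as an example) are just assumptions:

```python
# Minimal illustration of the idea: make the tiny test config allow sequences at
# least as long as what the "slow" beam-search test generates. Values are illustrative.
from transformers import BartConfig, TFBartForConditionalGeneration

config = BartConfig(
    vocab_size=99,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=64,  # the tiny-test default of 20 is too short for long generation
)
model = TFBartForConditionalGeneration(config)
```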

@JuheonChu
Contributor

Hello @gante, may I ask if there is anything I can contribute to?

@gante
Member Author

gante commented Feb 11, 2023

Hi @JuheonChu 👋 Actually, yes! I have a few unchecked models at the top, but I wouldn't recommend spending time there unless you plan to use those architectures -- they are infrequently used.

However, two popular models are currently failing their XLA tests with beam search:

  • Marian
  • OPT

You can see the failing test if you install from main (`pip install --upgrade git+https://github.com/huggingface/transformers.git`) and run it, e.g. for OPT: `NVIDIA_TF32_OVERRIDE=0 RUN_SLOW=1 py.test -vv tests/models/opt/test_modeling_tf_opt.py::TFOPTModelTest::test_xla_generate_slow`

I haven't dived in yet, so I don't know the cause of the failure. You'll have to hop into debug mode and see what is breaking :)

@JuheonChu
Contributor

Can @katiele47 and I try working on them?

@gante
Member Author

gante commented Feb 15, 2023

@JuheonChu of course!

@JuheonChu
Contributor

JuheonChu commented Feb 17, 2023

> @JuheonChu of course!
@gante Are we figuring out the cause of the test failures based on the following clues?

Error 1
Error 2
Error 3

@gante
Member Author

gante commented Feb 17, 2023

@JuheonChu yes. My suggestion would be to find where the numerical differences start (between the XLA and non-XLA versions) using a debugger. Please note that you can't print variables with jit_compile=True, so you should set it to False. From there, the root cause is typically apparent.

Be warned, these sorts of tasks are sometimes very time-consuming to complete :)
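
As a very rough starting point (the checkpoint below is just an example, not the exact test setup), you can compare the same forward pass with and without XLA at the function boundary and work inwards from there:

```python
# Hedged sketch of the debugging loop described above -- illustrative, not a
# prescribed transformers workflow. Run the same forward pass with and without
# XLA and bisect where outputs diverge. Remember you cannot print intermediate
# tensors inside a jit_compile=True function, so flip it to False when digging in.
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = TFAutoModelForCausalLM.from_pretrained("facebook/opt-125m")
inputs = tokenizer("Debugging XLA is fun", return_tensors="tf")

def forward(input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask).logits

eager_logits = forward(**inputs)
xla_logits = tf.function(forward, jit_compile=True)(**inputs)

# Small numerical noise is expected under XLA; large differences point at a bug.
max_diff = tf.reduce_max(tf.abs(eager_logits - xla_logits))
print(f"max |eager - xla| logit difference: {max_diff.numpy():.6f}")
```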

@JuheonChu
Contributor

> @JuheonChu yes. My suggestion would be to find where the numerical differences start (between the XLA and non-XLA versions) using a debugger. Please note that you can't print variables with jit_compile=True, so you should set it to False. From there, the root cause is typically apparent.
>
> Be warned, these sorts of tasks are sometimes very time-consuming to complete :)

Thank you very much for your valuable guidance! We will try and keep you updated!

@katiele47
Contributor

Hi @gante, I've attempted to reproduce the failing XLA test on the OPT model using your suggested commands. The error I got was somewhat different from @JuheonChu's. Would you be able to verify whether the following is the expected failing test output? If not, I assume it could be due to my local repo. Thanks!
[Screenshots of the failing test output attached]

@soma2000-lang
Contributor

@gante working on XLNet
