Can't run (TF)BartForConditionalGeneration.generate on GPU; its speed is very slow #17411
Comments
Hey @TheHonestBob 👋 We are aware of the generate speed problems with TensorFlow, and will be releasing an update very soon. It is not a bug, but rather how Eager Execution works, sadly. Stay tuned 🤞
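For context on the Eager Execution point: in eager mode, every op in the token-by-token generation loop is dispatched from Python, and that dispatch overhead dominates runtime; compiling the work with tf.function removes most of it. A minimal, illustrative sketch (a toy function standing in for generate, not the library's actual code path):

import time
import tensorflow as tf

def step(x):
    # Toy per-step compute, standing in for the per-token work inside generate
    for _ in range(100):
        x = tf.matmul(x, x) / 100.0
    return x

compiled_step = tf.function(step)  # traced once, then runs as a single graph

x = tf.random.normal((128, 128))
step(x)           # eager warm-up
compiled_step(x)  # compiled warm-up (triggers tracing)

t0 = time.time(); step(x);          print("eager:   ", time.time() - t0)
t0 = time.time(); compiled_step(x); print("compiled:", time.time() - t0)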
Thanks for your reply. What can I do to solve it before the update?
My advice would be to go with the PyTorch version if performance is a bottleneck for you and you need something working in the next ~2 weeks. If you can afford to wait ~2 weeks, then you can have a look at the guides we are writing up at the moment :)
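In case it helps while waiting, here is a rough sketch of the PyTorch route suggested above, using the same checkpoint and inputs (the device handling is my assumption, not from the thread):

import torch
from transformers import BertTokenizer, BartForConditionalGeneration

tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed: run on GPU when available
model = model.to(device)

batch_data = ['北京是[MASK]的首都'] * 64  # "Beijing is the capital of [MASK]"
batch_dict = tokenizer(batch_data, return_token_type_ids=False, return_tensors="pt").to(device)

with torch.no_grad():
    result = model.generate(**batch_dict, max_length=20)
print(tokenizer.batch_decode(result, skip_special_tokens=True))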
OK, I will continue to pay attention to it.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@TheHonestBob -- some of the functionality needed to speed generation up has been merged recently. If you run this modified version of your script on a GPU, you will see it is much, much faster:

import tensorflow as tf
from transformers import BertTokenizer, TFBartForConditionalGeneration

tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = TFBartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese", from_pt=True)
batch_data = ['北京是[MASK]的首都'] * 64  # "Beijing is the capital of [MASK]"

# Compile generate with XLA: the first call is slow (tracing), subsequent calls are fast
xla_generate = tf.function(model.generate, jit_compile=True)

for i in range(20):
    batch_dict = tokenizer.batch_encode_plus(batch_data, return_token_type_ids=False, return_tensors='tf')
    result = xla_generate(**batch_dict, max_length=20, no_repeat_ngram_size=0, num_beams=1)
    result = tokenizer.batch_decode(result, skip_special_tokens=True)
    print(result)

To enable bigger values of […]
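One practical caveat with the tf.function + XLA approach above (covered in more depth in the guides mentioned earlier): XLA recompiles whenever the input shape changes, so padding every batch to a fixed length keeps compilation down to a single trace. A sketch of that, where pad_length is an assumed value for illustration:

# Pad all batches to one fixed shape so xla_generate compiles only once;
# a new sequence length would otherwise trigger a slow retrace.
pad_length = 32  # assumption: pick any value >= your longest input
batch_dict = tokenizer(
    batch_data,
    padding="max_length",
    max_length=pad_length,
    return_token_type_ids=False,
    return_tensors="tf",
)
result = xla_generate(**batch_dict, max_length=20, no_repeat_ngram_size=0, num_beams=1)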
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@TheHonestBob The newest release (v4.21) fixes this issue. Check our recent blog post -- https://huggingface.co/blog/tf-xla-generate |
Thanks a lot, I'll try it.
System Info
Who can help?
@patil-suraj, @patrickvonplaten, @Narsil, @gante, @Rocketknight1
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
from transformers import BertTokenizer, TFBartForConditionalGeneration

tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = TFBartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese", from_pt=True)
batch_data = ['北京是[MASK]的首都'] * 64  # "Beijing is the capital of [MASK]"
for i in range(20):
    batch_dict = tokenizer.batch_encode_plus(batch_data, return_token_type_ids=False, return_tensors='tf')
    result = model.generate(**batch_dict, max_length=20)
    result = tokenizer.batch_decode(result, skip_special_tokens=True)
    print(result)
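To put a number on the reported slowness, one straightforward way to time a single generate call (reusing tokenizer, model, and batch_data from the snippet above; the timing code itself is my addition):

import time

batch_dict = tokenizer.batch_encode_plus(batch_data, return_token_type_ids=False, return_tensors='tf')
model.generate(**batch_dict, max_length=20)  # warm-up call

start = time.time()
result = model.generate(**batch_dict, max_length=20)
print(f"generate: {time.time() - start:.2f}s for a batch of {len(batch_data)}")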
Expected behavior