Fixing #13381 (#13400)

Merged
34 changes: 26 additions & 8 deletions src/transformers/pipelines/zero_shot_classification.py
@@ -88,7 +88,7 @@ def _parse_and_tokenize(
         hypothesis_template,
         padding=True,
         add_special_tokens=True,
-        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
+        truncation=TruncationStrategy.ONLY_FIRST,
         **kwargs
     ):
         """
@@ -113,13 +113,31 @@ def _parse_and_tokenize(
                 )
                 inputs.append(model_input)
         else:
-            inputs = self.tokenizer(
-                sequence_pairs,
-                add_special_tokens=add_special_tokens,
-                return_tensors=return_tensors,
-                padding=padding,
-                truncation=truncation,
-            )
+            try:
+                inputs = self.tokenizer(
+                    sequence_pairs,
+                    add_special_tokens=add_special_tokens,
+                    return_tensors=return_tensors,
+                    padding=padding,
+                    truncation=truncation,
+                )
+            except Exception as e:
+                if "too short" in str(e):
+                    # tokenizers might yell that we want to truncate
+                    # to a value that is not even reached by the input.
+                    # In that case we don't want to truncate.
+                    # It seems there's not a really better way to catch that
+                    # exception.
Comment on lines +125 to +130

Member:
Could you provide a reproducer for this issue? I can't seem to find where that would come from.

Contributor Author:
Running the basic pipeline tests with RUN_PIPELINE_TESTS=1 pytest -sv tests/test_pipelines_zero_shot_classification.py will do (it shows up on the LED test).
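
For reference, a minimal reproducer sketch adapted from the regression test added in this PR; it assumes PyTorch is installed and that the tiny sshleifer checkpoint used by the test is reachable on the Hub:

from transformers import pipeline

zero_shot_classifier = pipeline(
    "zero-shot-classification",
    model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
    framework="pt",
)

# Without the ONLY_FIRST truncation restored by this patch, an input far longer
# than the model's maximum length makes this call fail instead of being truncated.
zero_shot_classifier(
    "Who are you voting for in 2020?" * 100,
    candidate_labels=["politics", "public health", "science"],
)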


+                    inputs = self.tokenizer(
+                        sequence_pairs,
+                        add_special_tokens=add_special_tokens,
+                        return_tensors=return_tensors,
+                        padding=padding,
+                        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
+                    )
+                else:
+                    raise e

         return inputs
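
Condensed, the change amounts to the retry pattern below. This is a sketch against a bare tokenizer rather than the pipeline internals; the encode_pairs helper and the example checkpoint are illustrative, not part of the patch:

def encode_pairs(tokenizer, sequence_pairs, return_tensors="pt"):
    # Truncate only the premise (first sequence) so the hypothesis built from
    # the candidate label is never cut off.
    try:
        return tokenizer(
            sequence_pairs,
            padding=True,
            truncation="only_first",
            return_tensors=return_tensors,
        )
    except Exception as e:
        # Some tokenizers raise when the requested truncation length is never
        # reached by the input; the patch detects this by matching "too short"
        # in the message and retries without truncating.
        if "too short" in str(e):
            return tokenizer(
                sequence_pairs,
                padding=True,
                truncation="do_not_truncate",
                return_tensors=return_tensors,
            )
        raise

# Example usage (downloads a checkpoint):
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
# encode_pairs(tokenizer, [["Who are you voting for in 2020?", "This example is about politics."]])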

14 changes: 14 additions & 0 deletions tests/test_pipelines_zero_shot.py
@@ -105,6 +105,20 @@ def run_entailment_id(self, zero_shot_classifier: Pipeline):
         zero_shot_classifier.model.config.label2id = original_label2id
         self.assertEqual(original_entailment, zero_shot_classifier.entailment_id)

+    @require_torch
+    def test_truncation(self):
+        zero_shot_classifier = pipeline(
+            "zero-shot-classification",
+            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
+            framework="pt",
+        )
+        # There was a regression in 4.10 for this
+        # Adding a test so we don't make the mistake again.
+        # https://github.com/huggingface/transformers/issues/13381#issuecomment-912343499
+        zero_shot_classifier(
+            "Who are you voting for in 2020?" * 100, candidate_labels=["politics", "public health", "science"]
+        )
+
     @require_torch
     def test_small_model_pt(self):
         zero_shot_classifier = pipeline(