Skip to content

Commit

Permalink
Fixing backward compatiblity for non prefixed tokens (B-, I-). (#13493)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored and patrickvonplaten committed Sep 10, 2021
1 parent 60eb416 commit 4afbd7e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/pipelines/token_classification.py
Expand Up @@ -411,7 +411,8 @@ def get_tag(self, entity_name: str) -> Tuple[str, str]:
tag = entity_name[2:]
else:
# It's not in B-, I- format
bi = "B"
# Default to I- for continuation.
bi = "I"
tag = entity_name
return bi, tag

Expand Down
53 changes: 53 additions & 0 deletions tests/test_pipelines_token_classification.py
Expand Up @@ -318,6 +318,59 @@ def test_aggregation_strategy_byte_level_tokenizer(self):
],
)

@require_torch
def test_aggregation_strategy_no_b_i_prefix(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
# Just to understand scores indexes in this test
token_classifier.model.config.id2label = {0: "O", 1: "MISC", 2: "PER", 3: "ORG", 4: "LOC"}
example = [
{
# fmt : off
"scores": np.array([0, 0, 0, 0, 0.9968166351318359]),
"index": 1,
"is_subword": False,
"word": "En",
"start": 0,
"end": 2,
},
{
# fmt : off
"scores": np.array([0, 0, 0, 0, 0.9957635998725891]),
"index": 2,
"is_subword": True,
"word": "##zo",
"start": 2,
"end": 4,
},
{
# fmt: off
"scores": np.array([0, 0, 0, 0.9986497163772583, 0]),
# fmt: on
"index": 7,
"word": "UN",
"is_subword": False,
"start": 11,
"end": 13,
},
]
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
[
{"end": 2, "entity": "LOC", "score": 0.997, "start": 0, "word": "En", "index": 1},
{"end": 4, "entity": "LOC", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
{"end": 13, "entity": "ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
],
)
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
[
{"entity_group": "LOC", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
],
)

@require_torch
def test_aggregation_strategy(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
Expand Down

0 comments on commit 4afbd7e

Please sign in to comment.