Skip to content

Commit

Permalink
Fixing GPU for token-classification in a better way. (huggingface#13856)
Browse files Browse the repository at this point in the history
Co-authored-by:  Pierre Snell <pierre.snell@botpress.com>

Co-authored-by: Pierre Snell <pierre.snell@botpress.com>
  • Loading branch information
2 people authored and Alberto B茅gu茅 committed Jan 27, 2022
1 parent 0fcbee7 commit be3f60c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/transformers/pipelines/base.py
Expand Up @@ -791,7 +791,7 @@ def _ensure_tensor_on_device(self, inputs, device):
elif isinstance(inputs, tuple):
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
elif isinstance(inputs, torch.Tensor):
return inputs.to(self.device)
return inputs.to(device)
else:
return inputs

Expand Down
7 changes: 4 additions & 3 deletions src/transformers/pipelines/token_classification.py
Expand Up @@ -204,9 +204,10 @@ def _forward(self, model_inputs):
offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence")
if self.framework == "tf":
outputs = self.model(model_inputs.data)[0][0].numpy()
outputs = self.model(model_inputs.data)[0][0]
else:
outputs = self.model(**model_inputs)[0][0].numpy()
outputs = self.model(**model_inputs)[0][0]

return {
"outputs": outputs,
"special_tokens_mask": special_tokens_mask,
Expand All @@ -216,7 +217,7 @@ def _forward(self, model_inputs):
}

def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
outputs = model_outputs["outputs"]
outputs = model_outputs["outputs"].numpy()
sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
Expand Down
22 changes: 21 additions & 1 deletion tests/test_pipelines_token_classification.py
Expand Up @@ -25,7 +25,14 @@
pipeline,
)
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_torch_gpu,
slow,
)

from .test_pipelines_common import ANY, PipelineTestCaseMeta

Expand Down Expand Up @@ -246,6 +253,19 @@ def test_spanish_bert(self):
],
)

@require_torch_gpu
@slow
def test_gpu(self):
sentence = "This is dummy sentence"
ner = pipeline(
"token-classification",
device=0,
aggregation_strategy=AggregationStrategy.SIMPLE,
)

output = ner(sentence)
self.assertEqual(nested_simplify(output), [])

@require_torch
@slow
def test_dbmdz_english(self):
Expand Down

0 comments on commit be3f60c

Please sign in to comment.