Fixing GPU for token-classification in a better way. #13856

Merged
2 changes: 1 addition & 1 deletion src/transformers/pipelines/base.py
@@ -791,7 +791,7 @@ def _ensure_tensor_on_device(self, inputs, device):
        elif isinstance(inputs, tuple):
            return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
        elif isinstance(inputs, torch.Tensor):
-            return inputs.to(self.device)
+            return inputs.to(device)
        else:
            return inputs

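For context, a minimal standalone sketch of the helper this one-line fix repairs (simplified names, not the actual transformers code): the method takes an explicit device argument, but the old code moved tensors to self.device, silently ignoring the argument, so forward could never use it to move outputs back to the CPU.

```python
import torch

def ensure_tensor_on_device(inputs, device):
    """Recursively move every tensor in a nested container onto `device`."""
    if isinstance(inputs, dict):
        return {k: ensure_tensor_on_device(v, device) for k, v in inputs.items()}
    if isinstance(inputs, tuple):
        return tuple(ensure_tensor_on_device(item, device) for item in inputs)
    if isinstance(inputs, torch.Tensor):
        return inputs.to(device)  # the fix: honor `device` instead of always using self.device
    return inputs
```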
7 changes: 4 additions & 3 deletions src/transformers/pipelines/token_classification.py
@@ -204,9 +204,10 @@ def _forward(self, model_inputs):
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")
        if self.framework == "tf":
-            outputs = self.model(model_inputs.data)[0][0].numpy()
+            outputs = self.model(model_inputs.data)[0][0]
        else:
-            outputs = self.model(**model_inputs)[0][0].numpy()
+            outputs = self.model(**model_inputs)[0][0]

        return {
            "outputs": outputs,
            "special_tokens_mask": special_tokens_mask,
@@ -216,7 +217,7 @@ def _forward(self, model_inputs):
        }

    def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE):
-        outputs = model_outputs["outputs"]
+        outputs = model_outputs["outputs"].numpy()
Contributor:
I'm not as familiar with the pipeline as you are, but can model_outputs["outputs"] still be on the GPU? Maybe all preprocessing is first cast to CPU, but I would rather double-check.

Otherwise, should we use model_outputs["outputs"].cpu().numpy() instead?
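For reference, calling .numpy() directly on a CUDA tensor is exactly what fails here; a minimal repro (assumes a machine with a CUDA device):

```python
import torch

t = torch.ones(2, 2, device="cuda")
try:
    t.numpy()  # fails: NumPy arrays live in host memory
except TypeError as err:
    # "can't convert cuda:0 device type tensor to numpy.
    #  Use Tensor.cpu() to copy the tensor to host memory first."
    print(err)

arr = t.cpu().numpy()  # works: copy to the CPU first, then convert
```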

Contributor Author:

Hi @ierezell,

Just as this file never mentions .to(device), it shouldn't mention .cpu(): it's not the task of the individual pipeline to encapsulate that logic. It used to be done that way, which meant many pipelines didn't actually support some features.

The logic is now in the Pipeline.forward method.

Roughly:
preprocess: generic Python objects -> tensors
_forward: tensors -> tensors
postprocess: tensors -> generic Python objects

forward encapsulates the classic inference logic (inference mode, moving tensors to and from the GPU, and other steps if needed, like batching).
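In other words, a rough sketch of what Pipeline.forward does around _forward (an illustration of the idea, not the exact transformers implementation):

```python
import torch

def forward(self, model_inputs, **forward_params):
    with torch.no_grad():  # inference mode: no autograd bookkeeping
        # Move the preprocessed tensors onto the pipeline's device (e.g. cuda:0).
        model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
        model_outputs = self._forward(model_inputs, **forward_params)
        # Move outputs back to CPU so postprocess can safely call .numpy().
        model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
    return model_outputs
```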

Contributor:

OK, then everything will be in the right place at the right time, so that's perfect! Thanks for the clarification.

        sentence = model_outputs["sentence"]
        input_ids = model_outputs["input_ids"][0]
        offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
22 changes: 21 additions & 1 deletion tests/test_pipelines_token_classification.py
@@ -25,7 +25,14 @@
    pipeline,
)
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
+from transformers.testing_utils import (
+    is_pipeline_test,
+    nested_simplify,
+    require_tf,
+    require_torch,
+    require_torch_gpu,
+    slow,
+)

from .test_pipelines_common import ANY, PipelineTestCaseMeta

@@ -246,6 +253,19 @@ def test_spanish_bert(self):
            ],
        )

+    @require_torch_gpu
+    @slow
+    def test_gpu(self):
+        sentence = "This is dummy sentence"
+        ner = pipeline(
+            "token-classification",
+            device=0,
+            aggregation_strategy=AggregationStrategy.SIMPLE,
+        )
+
+        output = ner(sentence)
+        self.assertEqual(nested_simplify(output), [])
+
    @require_torch
    @slow
    def test_dbmdz_english(self):
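For illustration, the user-facing behavior the new test locks in (using the pipeline's default checkpoint, as the test does):

```python
from transformers import pipeline

# Place the model on GPU 0; forward now moves inputs to the GPU and
# outputs back to the CPU, so postprocess can call .numpy() safely.
ner = pipeline("token-classification", device=0, aggregation_strategy="simple")
print(ner("This is dummy sentence"))
```

Previously, with .numpy() inside _forward, this call crashed on CUDA tensors; after this change the device round-trip is handled entirely by Pipeline.forward.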