From 9e26cae4f03585ba795c722ebd7d7890601598a0 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 7 Jun 2022 16:06:22 +0200
Subject: [PATCH] Better error message on wrong device (#1056)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Apply suggestions from code review

Co-authored-by: Jirka Borovec
Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md               |  4 ++++
 tests/bases/test_metric.py | 12 ++++++++++++
 torchmetrics/metric.py     | 15 ++++++++++++++-
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6fcd68d4b8d..418944d93d0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,8 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added specific `RuntimeError` when metric object is on wrong device ([#1056](https://github.com/PyTorchLightning/metrics/pull/1056))
+
+
 - Added an option to specify own n-gram weights for `BLEUScore` and `SacreBLEUScore` instead of using uniform weights only. ([#1075](https://github.com/PyTorchLightning/metrics/pull/1075))
 
+
 -
 
 
diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index f27c1b2a2ea..e131afcf97f 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -26,6 +26,7 @@
 from tests.helpers import seed_all
 from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
 from tests.helpers.utilities import no_warning_call
+from torchmetrics import PearsonCorrCoef
 from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
 
 seed_all(42)
@@ -426,6 +427,17 @@ class UnsetProperty(metric_class):
         UnsetProperty()
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
+def test_specific_error_on_wrong_device():
+    metric = PearsonCorrCoef()
+    preds = torch.tensor(range(10), device="cuda", dtype=torch.float)
+    target = torch.tensor(range(10), device="cuda", dtype=torch.float)
+    with pytest.raises(
+        RuntimeError, match="This could be due to the metric class not being on the same device as input"
+    ):
+        _ = metric(preds, target)
+
+
 @pytest.mark.parametrize("metric_class", [DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum])
 def test_no_warning_on_custom_forward(metric_class):
     """If metric is using custom forward, full_state_update is irrelevant."""
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index bfcc05d877a..edecf0fa72e 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -378,7 +378,20 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None:
             self._computed = None
             self._update_count += 1
             with torch.set_grad_enabled(self._enable_grad):
-                update(*args, **kwargs)
+                try:
+                    update(*args, **kwargs)
+                except RuntimeError as err:
+                    if "Expected all tensors to be on" in str(err):
+                        raise RuntimeError(
+                            "Encountered different devices in metric calculation"
+                            " (see stacktrace for details)."
+                            " This could be due to the metric class not being on the same device as input."
+                            f" Instead of `metric={self.__class__.__name__}(...)` try to do"
+                            f" `metric={self.__class__.__name__}(...).to(device)` where"
+                            " device corresponds to the device of the input."
+                        ) from err
+                    raise err
+
             if self.compute_on_cpu:
                 self._move_list_states_to_cpu()
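
For reference, a minimal usage sketch of the error path this patch introduces (illustrative only, not part of the patch; it assumes a CUDA device is available and uses `PearsonCorrCoef` purely as an example metric):

    import torch
    from torchmetrics import PearsonCorrCoef

    preds = torch.arange(10, dtype=torch.float, device="cuda")
    target = torch.arange(10, dtype=torch.float, device="cuda")

    # Metric states are created on CPU by default, so updating with CUDA tensors
    # hits PyTorch's generic "Expected all tensors to be on the same device" error;
    # with this patch it is re-raised with a hint to move the metric first.
    metric = PearsonCorrCoef()
    try:
        metric(preds, target)
    except RuntimeError as err:
        print(err)  # now suggests `metric=PearsonCorrCoef(...).to(device)`

    # The fix the new message points to: put the metric on the device of the inputs.
    metric = PearsonCorrCoef().to("cuda")
    score = metric(preds, target)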