From c86772a07e640623a9dc2ba8ab9a9840eb16b7f7 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 31 May 2022 09:15:44 +0200
Subject: [PATCH 1/8] better error

---
 tests/bases/test_metric.py | 12 ++++++++++++
 torchmetrics/metric.py     | 15 ++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index 07c0ccc6bbf..03654675026 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -25,6 +25,7 @@
 
 from tests.helpers import seed_all
 from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
+from torchmetrics import PearsonCorrCoef
 from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
 
 seed_all(42)
@@ -423,3 +424,14 @@ class UnsetProperty(metric_class):
         match="Torchmetrics v0.9 introduced a new argument class property called.*",
     ):
         UnsetProperty()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
+def test_specific_error_on_wrong_device():
+    metric = PearsonCorrCoef()
+    preds = torch.tensor(range(10), device="cuda", dtype=torch.float)
+    target = torch.tensor(range(10), device="cuda", dtype=torch.float)
+    with pytest.raises(
+        RuntimeError, match="This could be due to the metric class not being on the same device as input"
+    ):
+        _ = metric(preds, target)
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 67e53d4130e..ddbc0408b9a 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import functools
 import inspect
+import traceback
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from copy import deepcopy
@@ -377,7 +378,19 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None:
             self._computed = None
             self._update_count += 1
             with torch.set_grad_enabled(self._enable_grad):
-                update(*args, **kwargs)
+                try:
+                    update(*args, **kwargs)
+                except RuntimeError as e:
+                    if "Expected all tensors to be on" in str(e):
+                        traceback.print_exc()
+                        raise RuntimeError(
+                            f"{str(e)}. \n"
+                            "This could be due to the metric class not being on the same device as input.\n"
+                            "Instead of `metric=Metric()` try to do `metric=Metric().to(device)` where"
+                            " device corresponds to the device of the input."
+                        )
+                    raise e
+
             if self.compute_on_cpu:
                 self._move_list_states_to_cpu()

From a5ed3977433ff02ebd74e59fcf3702945592fd5c Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Wed, 1 Jun 2022 23:15:18 +0200
Subject: [PATCH 2/8] Apply suggestions from code review

---
 torchmetrics/metric.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index ddbc0408b9a..f3d6e920de2 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -380,16 +380,16 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None:
             with torch.set_grad_enabled(self._enable_grad):
                 try:
                     update(*args, **kwargs)
-                except RuntimeError as e:
-                    if "Expected all tensors to be on" in str(e):
+                except RuntimeError as err:
+                    if "Expected all tensors to be on" in str(err):
                         traceback.print_exc()
                         raise RuntimeError(
-                            f"{str(e)}. \n"
+                            f"{str(err)}. \n"
                             "This could be due to the metric class not being on the same device as input.\n"
                             "Instead of `metric=Metric()` try to do `metric=Metric().to(device)` where"
                             " device corresponds to the device of the input."
                         )
-                    raise e
+                    raise err
 
             if self.compute_on_cpu:
                 self._move_list_states_to_cpu()

From eab570239d2540a2e97dfebe85f1728f9036694d Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Fri, 3 Jun 2022 20:29:43 +0200
Subject: [PATCH 3/8] suggestions

---
 CHANGELOG.md           | 2 +-
 torchmetrics/metric.py | 4 +---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ee1bc72d827..0e708a4bfd0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
--
+- Added specific runtime error when metric object is on wrong device ([#1056](https://github.com/PyTorchLightning/metrics/pull/1056))
 
 -
 
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index f3d6e920de2..3d6bc2ecb75 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import functools
 import inspect
-import traceback
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from copy import deepcopy
@@ -382,13 +381,12 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None:
                     update(*args, **kwargs)
                 except RuntimeError as err:
                     if "Expected all tensors to be on" in str(err):
-                        traceback.print_exc()
                         raise RuntimeError(
                             f"{str(err)}. \n"
                             "This could be due to the metric class not being on the same device as input.\n"
                             "Instead of `metric=Metric()` try to do `metric=Metric().to(device)` where"
                             " device corresponds to the device of the input."
-                        )
+                        ) from err
                     raise err
 
             if self.compute_on_cpu:

From 7c723c1bdded91569e783ddfa041eb9975a921b0 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Mon, 6 Jun 2022 11:29:40 +0200
Subject: [PATCH 4/8] update based on suggestions

---
 torchmetrics/metric.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 3d6bc2ecb75..d270fc71895 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -382,9 +382,10 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None:
                 except RuntimeError as err:
                     if "Expected all tensors to be on" in str(err):
                         raise RuntimeError(
-                            f"{str(err)}. \n"
+                            "Encountered different devices in metric calculation.\n"
                             "This could be due to the metric class not being on the same device as input.\n"
-                            "Instead of `metric=Metric()` try to do `metric=Metric().to(device)` where"
+                            f"Instead of `metric={self.__class__.__name__}(...)` try to do"
+                            f" `metric={self.__class__.__name__}(...).to(device)` where"
                             " device corresponds to the device of the input."
                         ) from err
                     raise err

From 1af0fed5918554f0741c2776ad71b10b2ebe75a9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 7 Jun 2022 08:56:36 +0000
Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/bases/test_metric.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index 8d039e8eb09..e131afcf97f 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -25,8 +25,8 @@
 
 from tests.helpers import seed_all
 from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
-from torchmetrics import PearsonCorrCoef
 from tests.helpers.utilities import no_warning_call
+from torchmetrics import PearsonCorrCoef
 from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
 
 seed_all(42)

From 0985c28e9806440274b114327f2df56b911744ff Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 7 Jun 2022 11:12:03 +0200
Subject: [PATCH 6/8] add error

---
 torchmetrics/metric.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index ed46beb6628..5354d4642f7 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -383,6 +383,7 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None:
                 except RuntimeError as err:
                     if "Expected all tensors to be on" in str(err):
                         raise RuntimeError(
+                            f"{str(err)}. \n"
                             "Encountered different devices in metric calculation.\n"
                             "This could be due to the metric class not being on the same device as input.\n"
                             f"Instead of `metric={self.__class__.__name__}(...)` try to do"
                             f" `metric={self.__class__.__name__}(...).to(device)` where"
                             " device corresponds to the device of the input."

From f5328e86ccb3d91ac264bf2e72d695b7faac4a34 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 7 Jun 2022 12:38:44 +0200
Subject: [PATCH 7/8] remove linebreaks

---
 torchmetrics/metric.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 5354d4642f7..16770c03f46 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -383,9 +383,9 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None:
                 except RuntimeError as err:
                     if "Expected all tensors to be on" in str(err):
                         raise RuntimeError(
-                            f"{str(err)}. \n"
-                            "Encountered different devices in metric calculation.\n"
-                            "This could be due to the metric class not being on the same device as input.\n"
+                            f"{str(err)}."
+                            "Encountered different devices in metric calculation."
+                            "This could be due to the metric class not being on the same device as input."
                             f"Instead of `metric={self.__class__.__name__}(...)` try to do"
                             f" `metric={self.__class__.__name__}(...).to(device)` where"
                             " device corresponds to the device of the input."

From 86633e0e9dbeeb47815a59cbee0b88e1efbf7f41 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 7 Jun 2022 13:47:39 +0200
Subject: [PATCH 8/8] remove err

---
 torchmetrics/metric.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 16770c03f46..edecf0fa72e 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -383,8 +383,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None:
                 except RuntimeError as err:
                     if "Expected all tensors to be on" in str(err):
                         raise RuntimeError(
-                            f"{str(err)}."
-                            "Encountered different devices in metric calculation."
+                            "Encountered different devices in metric calculation"
+                            " (see stacktrace for details)."
                             "This could be due to the metric class not being on the same device as input."
                             f"Instead of `metric={self.__class__.__name__}(...)` try to do"
                             f" `metric={self.__class__.__name__}(...).to(device)` where"