From 8d741aa43b3027db44e1cbc27c45872d911f7e7c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 16 May 2022 08:32:46 +0200 Subject: [PATCH 01/20] func --- torchmetrics/utilities/__init__.py | 1 + torchmetrics/utilities/checks.py | 63 +++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/torchmetrics/utilities/__init__.py b/torchmetrics/utilities/__init__.py index 5049a9d6901..3ec9e96b675 100644 --- a/torchmetrics/utilities/__init__.py +++ b/torchmetrics/utilities/__init__.py @@ -1,3 +1,4 @@ +from torchmetrics.utilities.checks import check_forward_no_full_state # noqa: F401 from torchmetrics.utilities.data import apply_to_collection # noqa: F401 from torchmetrics.utilities.distributed import class_reduce, reduce # noqa: F401 from torchmetrics.utilities.prints import _future_warning, rank_zero_debug, rank_zero_info, rank_zero_warn # noqa: F401 diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index f688a873c80..e78208638ec 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from time import perf_counter +from typing import Any, Dict, Optional, Sequence, Tuple import torch from torch import Tensor @@ -604,3 +605,63 @@ def _check_retrieval_target_and_prediction_types( preds = preds.float() return preds.flatten(), target.flatten() + + +def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8): + """Utility function for recursively asserting that two results are within a certain tolerance.""" + # single output compare + if isinstance(res1, Tensor): + return torch.allclose(res1, res2, atol=atol) + elif isinstance(res1, Sequence): + return all(_allclose_recursive(r1, r2) for r1, r2 in zip(res1, res2)) + elif isinstance(res1, Dict): + return all(_allclose_recursive(res1[k], res2[k]) for k in res1.keys()) + else: + raise ValueError("Unknown format for comparison") + + +def check_forward_no_full_state( + metric_class: object, + init_args: Dict[str, Any], + *input_args, + num_update_to_compare: Sequence[int] = [10, 100, 1000], + reps: int = 5, +) -> bool: + """""" + + class FullState(metric_class): + full_state_update = True + + class PartState(metric_class): + full_state_update = False + + fullstate = FullState(**init_args) + partstate = PartState(**init_args) + + equal = True + for _ in range(10): + out1 = fullstate(*input_args) + out2 = fullstate(*input_args) + equal = equal | _allclose_recursive(out1, out2) + + if not equal: # we can stop early because the states did not match + return False + + res = torch.zeros(2, len(num_update_to_compare), reps) + for i, metric in enumerate([fullstate, partstate]): + for j, t in enumerate(num_update_to_compare): + for r in range(reps): + start = perf_counter() + for _ in range(t): + _ = metric(*input_args) + end = perf_counter() + res[i, j, r] = end - start + metric.reset() + + mean = torch.mean(res, -1) + std = torch.std(res, -1) + + for t in range(len(num_update_to_compare)): + print(f"Full state for {num_update_to_compare[t]} steps took: {mean[0, t]}+-{std[0, t]}") + print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[0, t]}+-{std[0, t]}") + print() From 9a01d4e743bf48157079634a195c467465964fb2 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Mon, 16 May 2022 09:21:20 +0200 Subject: [PATCH 02/20] add warning --- torchmetrics/metric.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index a6987f98d4f..97b7996075a 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -73,7 +73,7 @@ class Metric(Module, ABC): __jit_unused_properties__ = ["is_differentiable"] is_differentiable: Optional[bool] = None higher_is_better: Optional[bool] = None - full_state_update: bool = True + full_state_update: Optional[bool] = None def __init__( self, @@ -126,6 +126,19 @@ def __init__( # state management self._is_synced = False self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None + + if self.full_state_update is None: + rank_zero_warn( + f"""Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has + not been set for this class ({self.__class__.__name__}). The property determines if ``update`` by + default needs access to the full metric state. If this is not the case, significant speedups can be + achived. We provide an checking function + ``from torchmetrics.utilities import check_forward_no_full_state`` + that can be used to check if the ``full_state_update=True`` (old and potential slower behaviour, + default for now) or if ``full_state_update=False`` can be used safely + """, + UserWarning, + ) @property def _update_called(self) -> bool: @@ -216,7 +229,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: "HINT: Did you forget to call ``unsync`` ?." ) - if self.full_state_update or self.dist_sync_on_step: + if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step: self._forward_cache = self._forward_full_state_update(*args, **kwargs) else: self._forward_cache = self._forward_reduce_state_update(*args, **kwargs) From d7f7686f3b444f983a3ea7343c1d2957dcbef194 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Mon, 16 May 2022 09:24:38 +0200 Subject: [PATCH 03/20] add check function --- torchmetrics/utilities/checks.py | 79 +++++++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 12 deletions(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index e78208638ec..cd6d7ef5fa7 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -622,12 +622,54 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8): def check_forward_no_full_state( metric_class: object, - init_args: Dict[str, Any], - *input_args, + init_args: Dict[str, Any] = {}, + input_args: Dict[str, Any] = {}, num_update_to_compare: Sequence[int] = [10, 100, 1000], reps: int = 5, ) -> bool: - """""" + """Utility function for checking if the new ``full_state_update`` property can safely be set + to ``False`` which will for most metrics results in a speedup when using ``forward``. + + Args: + metric_class: metric class object that should be checked + init_args: dict containing arguments for initializing the metric class + input_args: dict containing arguments to pass to ``forward`` + num_update_to_compare: if we successfully detech that the flag is safe to set to ``False`` + we will run some speedup test. This arg should be a list of integers for how many + steps to compare over. + reps: number of repetitions of speedup test + + Example (states in ``update`` are independent, save to set ``full_state_update=False``) + >>> from torchmetrics import ConfusionMatrix + >>> check_forward_no_full_state( + ... ConfusionMatrix, + ... init_args = {'num_classes': 3}, + ... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))}, + ... ) # doctest: +SKIP + Full state for 10 steps took: 0.0072667645290493965+-3.828236731351353e-05 + Partial state for 10 steps took: 0.004743088968098164+-0.0005820328951813281 + Full state for 100 steps took: 0.0730440765619278+-0.0003615743189584464 + Partial state for 100 steps took: 0.04705753177404404+-0.002143740188330412 + Full state for 1000 steps took: 0.8512250781059265+-0.052338723093271255 + Partial state for 1000 steps took: 0.5545409917831421+-0.04722180590033531 + True + + Example (states in ``update`` are dependend meaning that ``full_state_update=True``): + >>> from torchmetrics import ConfusionMatrix + >>> class MyMetric(ConfusionMatrix): + ... def update(self, preds, target): + ... super().update(preds, target) + ... # by construction make future states dependent on prior states + ... if self.confmat.sum() > 20: + ... self.reset() + >>> check_forward_no_full_state( + ... MyMetric, + ... init_args = {'num_classes': 3}, + ... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))}, + ... ) + False + + """ class FullState(metric_class): full_state_update = True @@ -640,20 +682,32 @@ class PartState(metric_class): equal = True for _ in range(10): - out1 = fullstate(*input_args) - out2 = fullstate(*input_args) - equal = equal | _allclose_recursive(out1, out2) - - if not equal: # we can stop early because the states did not match + out1 = fullstate(**input_args) + try: # if it fails, the code most likely need access to the full state + out2 = partstate(**input_args) + except RuntimeError: + equal = False + break + equal = equal & _allclose_recursive(out1, out2) + + res1 = fullstate.compute() + try: # if it fails, the code most likely need access to the full state + res2 = partstate.compute() + except RuntimeError: + equal = False + equal = equal & _allclose_recursive(res1, res2) + + if not equal: # we can stop early because the results did not match return False + # Do timings res = torch.zeros(2, len(num_update_to_compare), reps) for i, metric in enumerate([fullstate, partstate]): for j, t in enumerate(num_update_to_compare): for r in range(reps): start = perf_counter() for _ in range(t): - _ = metric(*input_args) + _ = metric(**input_args) end = perf_counter() res[i, j, r] = end - start metric.reset() @@ -662,6 +716,7 @@ class PartState(metric_class): std = torch.std(res, -1) for t in range(len(num_update_to_compare)): - print(f"Full state for {num_update_to_compare[t]} steps took: {mean[0, t]}+-{std[0, t]}") - print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[0, t]}+-{std[0, t]}") - print() + print(f"Full state for {num_update_to_compare[t]:0.3f} steps took: {mean[0, t]}+-{std[0, t]:0.3f}") + print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[1, t]:0.3f}+-{std[1, t]:0.3f}") + + return (mean[1, -1] < mean[0, -1]).item() # if faster on average, we recommend upgrading From 33e02f4fdd64db5b805826ab0e265c48dce490f8 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Mon, 16 May 2022 09:24:46 +0200 Subject: [PATCH 04/20] fix tests --- tests/bases/test_metric.py | 12 ++++++++++++ tests/helpers/testers.py | 2 ++ tests/test_utilities.py | 13 ++++++++++++- torchmetrics/metric.py | 4 ++-- 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 55342eeb2d6..f915b33ea2b 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -409,3 +409,15 @@ def get_memory_usage(): metric.update(x.sum()) memory = get_memory_usage() assert base_memory_level >= memory, "memory increased above base level" + + +@pytest.mark.parametrize("metric_class", [DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum]) +def test_warning_on_not_set_full_state_update(metric_class): + class UnsetProperty(metric_class): + full_state_update = None + + with pytest.warns( + UserWarning, + match=r"Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has.*" + ): + UnsetProperty() diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index a75d5174b2a..7bc0c3cb2e5 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -569,6 +569,7 @@ def run_differentiability_test( class DummyMetric(Metric): name = "Dummy" + full_state_update: Optional[bool] = True def __init__(self, **kwargs): super().__init__(**kwargs) @@ -583,6 +584,7 @@ def compute(self): class DummyListMetric(Metric): name = "DummyList" + full_state_update: Optional[bool] = True def __init__(self): super().__init__() diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 40dd37fe1e3..ce4c76ae94d 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -15,7 +15,8 @@ import torch from torch import tensor -from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn +from torchmetrics import MeanSquaredError, PearsonCorrCoef +from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn, check_forward_no_full_state from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot from torchmetrics.utilities.distributed import class_reduce, reduce @@ -126,3 +127,13 @@ def test_bincount(): # check for correctness assert torch.allclose(res1, res2) assert torch.allclose(res1, res3) + + +@pytest.mark.parametrize("metric_class, expected", [(MeanSquaredError, True), (PearsonCorrCoef, False)]) +def test_check_full_state_update_fn(metric_class, expected): + """ Test that the check function works as it should.""" + out = check_forward_no_full_state( + metric_class=metric_class, + input_args={'preds': torch.randn(100,), 'target': torch.randn(100,)} + ) + assert out == expected diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 97b7996075a..bfb8056ae36 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -126,7 +126,7 @@ def __init__( # state management self._is_synced = False self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None - + if self.full_state_update is None: rank_zero_warn( f"""Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has @@ -138,7 +138,7 @@ def __init__( default for now) or if ``full_state_update=False`` can be used safely """, UserWarning, - ) + ) @property def _update_called(self) -> bool: From 377b0d0b35fff0e26af6e0850dfbaa1541272fc5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 May 2022 08:43:23 +0000 Subject: [PATCH 05/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_metric.py | 2 +- tests/test_utilities.py | 13 ++++++++++--- torchmetrics/utilities/checks.py | 5 ++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index f915b33ea2b..87997cd04a0 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -418,6 +418,6 @@ class UnsetProperty(metric_class): with pytest.warns( UserWarning, - match=r"Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has.*" + match=r"Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has.*", ): UnsetProperty() diff --git a/tests/test_utilities.py b/tests/test_utilities.py index ce4c76ae94d..f58322cc9bd 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -16,7 +16,7 @@ from torch import tensor from torchmetrics import MeanSquaredError, PearsonCorrCoef -from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn, check_forward_no_full_state +from torchmetrics.utilities import check_forward_no_full_state, rank_zero_debug, rank_zero_info, rank_zero_warn from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot from torchmetrics.utilities.distributed import class_reduce, reduce @@ -131,9 +131,16 @@ def test_bincount(): @pytest.mark.parametrize("metric_class, expected", [(MeanSquaredError, True), (PearsonCorrCoef, False)]) def test_check_full_state_update_fn(metric_class, expected): - """ Test that the check function works as it should.""" + """Test that the check function works as it should.""" out = check_forward_no_full_state( metric_class=metric_class, - input_args={'preds': torch.randn(100,), 'target': torch.randn(100,)} + input_args={ + "preds": torch.randn( + 100, + ), + "target": torch.randn( + 100, + ), + }, ) assert out == expected diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index cd6d7ef5fa7..6d5033291eb 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -627,8 +627,8 @@ def check_forward_no_full_state( num_update_to_compare: Sequence[int] = [10, 100, 1000], reps: int = 5, ) -> bool: - """Utility function for checking if the new ``full_state_update`` property can safely be set - to ``False`` which will for most metrics results in a speedup when using ``forward``. + """Utility function for checking if the new ``full_state_update`` property can safely be set to ``False`` which + will for most metrics results in a speedup when using ``forward``. Args: metric_class: metric class object that should be checked @@ -668,7 +668,6 @@ def check_forward_no_full_state( ... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))}, ... ) False - """ class FullState(metric_class): From 968fae708f82739395d82236627bee3b4bfdd8ec Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 16 May 2022 13:38:46 +0200 Subject: [PATCH 06/20] Update tests/test_utilities.py Co-authored-by: Jirka Borovec --- tests/test_utilities.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index f58322cc9bd..f36828a1db9 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -134,13 +134,9 @@ def test_check_full_state_update_fn(metric_class, expected): """Test that the check function works as it should.""" out = check_forward_no_full_state( metric_class=metric_class, - input_args={ - "preds": torch.randn( - 100, - ), - "target": torch.randn( - 100, - ), - }, + input_args=dict( + preds=torch.randn(100), + target=torch.randn(100) + ), ) assert out == expected From 3d176c9b759927b25538aa7225a84b45e7a66b3a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 May 2022 11:39:22 +0000 Subject: [PATCH 07/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_utilities.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index f36828a1db9..c74f984410f 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -134,9 +134,6 @@ def test_check_full_state_update_fn(metric_class, expected): """Test that the check function works as it should.""" out = check_forward_no_full_state( metric_class=metric_class, - input_args=dict( - preds=torch.randn(100), - target=torch.randn(100) - ), + input_args=dict(preds=torch.randn(100), target=torch.randn(100)), ) assert out == expected From c84c1b516a92837945380c1047809c24e84c5f65 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 16 May 2022 13:41:15 +0200 Subject: [PATCH 08/20] Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- torchmetrics/metric.py | 5 +++-- torchmetrics/utilities/checks.py | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index bfb8056ae36..b57daf7f2a3 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -132,10 +132,11 @@ def __init__( f"""Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has not been set for this class ({self.__class__.__name__}). The property determines if ``update`` by default needs access to the full metric state. If this is not the case, significant speedups can be - achived. We provide an checking function + achieved and we recommend setting this to ``False``. + We provide an checking function ``from torchmetrics.utilities import check_forward_no_full_state`` that can be used to check if the ``full_state_update=True`` (old and potential slower behaviour, - default for now) or if ``full_state_update=False`` can be used safely + default for now) or if ``full_state_update=False`` can be used safely. """, UserWarning, ) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 6d5033291eb..212730080fa 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -612,12 +612,14 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8): # single output compare if isinstance(res1, Tensor): return torch.allclose(res1, res2, atol=atol) + elif isinstance(res1, str): + return res1 == res2 elif isinstance(res1, Sequence): return all(_allclose_recursive(r1, r2) for r1, r2 in zip(res1, res2)) - elif isinstance(res1, Dict): + elif isinstance(res1, Mapping): return all(_allclose_recursive(res1[k], res2[k]) for k in res1.keys()) else: - raise ValueError("Unknown format for comparison") + return res1 == res2 def check_forward_no_full_state( @@ -680,7 +682,7 @@ class PartState(metric_class): partstate = PartState(**init_args) equal = True - for _ in range(10): + for _ in range(num_update_to_compare[0]): out1 = fullstate(**input_args) try: # if it fails, the code most likely need access to the full state out2 = partstate(**input_args) From fbddd37fc4a0fc679368e6bc02da1c5e8f1cf06a Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Mon, 16 May 2022 13:21:09 +0200 Subject: [PATCH 09/20] add test for recursive check --- tests/test_utilities.py | 16 ++++++++++++++++ torchmetrics/utilities/checks.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index c74f984410f..10f85416c30 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -17,6 +17,7 @@ from torchmetrics import MeanSquaredError, PearsonCorrCoef from torchmetrics.utilities import check_forward_no_full_state, rank_zero_debug, rank_zero_info, rank_zero_warn +from torchmetrics.utilities.checks import _allclose_recursive from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot from torchmetrics.utilities.distributed import class_reduce, reduce @@ -137,3 +138,18 @@ def test_check_full_state_update_fn(metric_class, expected): input_args=dict(preds=torch.randn(100), target=torch.randn(100)), ) assert out == expected + + +@pytest.mark.parametrize("input, expected", + [ + ((torch.ones(2,), torch.ones(2,)), True), + ((torch.rand(2,), torch.rand(2,)), False), + (([torch.ones(2,) for _ in range(2)], [torch.ones(2,) for _ in range(2)]), True), + (([torch.rand(2,) for _ in range(2)], [torch.rand(2,) for _ in range(2)]), False), + (({f'{i}' : torch.ones(2,) for i in range(2)}, {f'{i}' :torch.ones(2,) for i in range(2)}), True), + (({f'{i}' : torch.rand(2,) for i in range(2)}, {f'{i}' :torch.rand(2,) for i in range(2)}), False), + ] +) +def test_recursive_allclose(input, expected): + res = _allclose_recursive(*input) + assert res == expected diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 212730080fa..ed4da3fc0a1 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from time import perf_counter -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple import torch from torch import Tensor From ef4381af228cfd3a7c7de2e3ea365e0301e95d7c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 May 2022 11:49:14 +0000 Subject: [PATCH 10/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_utilities.py | 101 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 93 insertions(+), 8 deletions(-) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 10f85416c30..d354cfaf3b7 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -140,15 +140,100 @@ def test_check_full_state_update_fn(metric_class, expected): assert out == expected -@pytest.mark.parametrize("input, expected", +@pytest.mark.parametrize( + "input, expected", [ - ((torch.ones(2,), torch.ones(2,)), True), - ((torch.rand(2,), torch.rand(2,)), False), - (([torch.ones(2,) for _ in range(2)], [torch.ones(2,) for _ in range(2)]), True), - (([torch.rand(2,) for _ in range(2)], [torch.rand(2,) for _ in range(2)]), False), - (({f'{i}' : torch.ones(2,) for i in range(2)}, {f'{i}' :torch.ones(2,) for i in range(2)}), True), - (({f'{i}' : torch.rand(2,) for i in range(2)}, {f'{i}' :torch.rand(2,) for i in range(2)}), False), - ] + ( + ( + torch.ones( + 2, + ), + torch.ones( + 2, + ), + ), + True, + ), + ( + ( + torch.rand( + 2, + ), + torch.rand( + 2, + ), + ), + False, + ), + ( + ( + [ + torch.ones( + 2, + ) + for _ in range(2) + ], + [ + torch.ones( + 2, + ) + for _ in range(2) + ], + ), + True, + ), + ( + ( + [ + torch.rand( + 2, + ) + for _ in range(2) + ], + [ + torch.rand( + 2, + ) + for _ in range(2) + ], + ), + False, + ), + ( + ( + { + f"{i}": torch.ones( + 2, + ) + for i in range(2) + }, + { + f"{i}": torch.ones( + 2, + ) + for i in range(2) + }, + ), + True, + ), + ( + ( + { + f"{i}": torch.rand( + 2, + ) + for i in range(2) + }, + { + f"{i}": torch.rand( + 2, + ) + for i in range(2) + }, + ), + False, + ), + ], ) def test_recursive_allclose(input, expected): res = _allclose_recursive(*input) From facb9b34680a0c49dd21a26be3a4ebf5394d7a6c Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 17 May 2022 09:12:44 +0200 Subject: [PATCH 11/20] short --- tests/test_utilities.py | 96 +++-------------------------------------- 1 file changed, 6 insertions(+), 90 deletions(-) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index d354cfaf3b7..9cafb5b9d03 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -143,96 +143,12 @@ def test_check_full_state_update_fn(metric_class, expected): @pytest.mark.parametrize( "input, expected", [ - ( - ( - torch.ones( - 2, - ), - torch.ones( - 2, - ), - ), - True, - ), - ( - ( - torch.rand( - 2, - ), - torch.rand( - 2, - ), - ), - False, - ), - ( - ( - [ - torch.ones( - 2, - ) - for _ in range(2) - ], - [ - torch.ones( - 2, - ) - for _ in range(2) - ], - ), - True, - ), - ( - ( - [ - torch.rand( - 2, - ) - for _ in range(2) - ], - [ - torch.rand( - 2, - ) - for _ in range(2) - ], - ), - False, - ), - ( - ( - { - f"{i}": torch.ones( - 2, - ) - for i in range(2) - }, - { - f"{i}": torch.ones( - 2, - ) - for i in range(2) - }, - ), - True, - ), - ( - ( - { - f"{i}": torch.rand( - 2, - ) - for i in range(2) - }, - { - f"{i}": torch.rand( - 2, - ) - for i in range(2) - }, - ), - False, - ), + ((torch.ones(2), torch.ones(2)), True), + ((torch.rand(2), torch.rand(2)), False), + (([torch.ones(2) for _ in range(2)], [torch.ones(2) for _ in range(2)]), True), + (([torch.rand(2) for _ in range(2)], [torch.rand(2) for _ in range(2)]), False), + (({f"{i}": torch.ones(2) for i in range(2)}, {f"{i}": torch.ones(2) for i in range(2)}), True), + (({f"{i}": torch.rand(2) for i in range(2)}, {f"{i}": torch.rand(2) for i in range(2)}), False), ], ) def test_recursive_allclose(input, expected): From f453bf6455d0669af32432dda6b26566f434be01 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 19 May 2022 13:47:25 +0200 Subject: [PATCH 12/20] add doctest --- torchmetrics/utilities/checks.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index ed4da3fc0a1..876d0f505fb 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -647,13 +647,13 @@ def check_forward_no_full_state( ... ConfusionMatrix, ... init_args = {'num_classes': 3}, ... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))}, - ... ) # doctest: +SKIP - Full state for 10 steps took: 0.0072667645290493965+-3.828236731351353e-05 - Partial state for 10 steps took: 0.004743088968098164+-0.0005820328951813281 - Full state for 100 steps took: 0.0730440765619278+-0.0003615743189584464 - Partial state for 100 steps took: 0.04705753177404404+-0.002143740188330412 - Full state for 1000 steps took: 0.8512250781059265+-0.052338723093271255 - Partial state for 1000 steps took: 0.5545409917831421+-0.04722180590033531 + ... ) # doctest: +ELLIPSIS + Full state for 10 steps took: ... + Partial state for 10 steps took: ... + Full state for 100 steps took: ... + Partial state for 100 steps took: ... + Full state for 1000 steps took: ... + Partial state for 1000 steps took: ... True Example (states in ``update`` are dependend meaning that ``full_state_update=True``): @@ -717,7 +717,7 @@ class PartState(metric_class): std = torch.std(res, -1) for t in range(len(num_update_to_compare)): - print(f"Full state for {num_update_to_compare[t]:0.3f} steps took: {mean[0, t]}+-{std[0, t]:0.3f}") + print(f"Full state for {num_update_to_compare[t]} steps took: {mean[0, t]}+-{std[0, t]:0.3f}") print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[1, t]:0.3f}+-{std[1, t]:0.3f}") return (mean[1, -1] < mean[0, -1]).item() # if faster on average, we recommend upgrading From df89521ccc0ac76f64441701ef360d4ee65412f9 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 19 May 2022 13:56:53 +0200 Subject: [PATCH 13/20] fix mypy --- torchmetrics/utilities/checks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 876d0f505fb..c24bd04ab77 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -15,7 +15,7 @@ from typing import Any, Dict, Mapping, Optional, Sequence, Tuple import torch -from torch import Tensor +from torch import nn, Tensor from torchmetrics.utilities.data import select_topk, to_onehot from torchmetrics.utilities.enums import DataType @@ -607,7 +607,7 @@ def _check_retrieval_target_and_prediction_types( return preds.flatten(), target.flatten() -def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8): +def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool: """Utility function for recursively asserting that two results are within a certain tolerance.""" # single output compare if isinstance(res1, Tensor): @@ -623,7 +623,7 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8): def check_forward_no_full_state( - metric_class: object, + metric_class: nn.Module, init_args: Dict[str, Any] = {}, input_args: Dict[str, Any] = {}, num_update_to_compare: Sequence[int] = [10, 100, 1000], From d04075ed1909e2eacf0168bc76f728a3b3581989 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 May 2022 12:26:03 +0000 Subject: [PATCH 14/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/utilities/checks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index c24bd04ab77..6632be10b8f 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -15,7 +15,7 @@ from typing import Any, Dict, Mapping, Optional, Sequence, Tuple import torch -from torch import nn, Tensor +from torch import Tensor, nn from torchmetrics.utilities.data import select_topk, to_onehot from torchmetrics.utilities.enums import DataType @@ -607,7 +607,7 @@ def _check_retrieval_target_and_prediction_types( return preds.flatten(), target.flatten() -def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool: +def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool: """Utility function for recursively asserting that two results are within a certain tolerance.""" # single output compare if isinstance(res1, Tensor): From 7e00969396ce91e1bfb812114f34828e6f67060e Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 19 May 2022 14:02:18 +0200 Subject: [PATCH 15/20] fix --- torchmetrics/utilities/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 6632be10b8f..296d5782b13 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -623,7 +623,7 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool: def check_forward_no_full_state( - metric_class: nn.Module, + metric_class: Any, init_args: Dict[str, Any] = {}, input_args: Dict[str, Any] = {}, num_update_to_compare: Sequence[int] = [10, 100, 1000], From d0364d752fb0a66a357b886545d2e4bc06d3c0df Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 19 May 2022 14:04:40 +0200 Subject: [PATCH 16/20] update --- torchmetrics/utilities/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 296d5782b13..75c6c59dedb 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -15,7 +15,7 @@ from typing import Any, Dict, Mapping, Optional, Sequence, Tuple import torch -from torch import Tensor, nn +from torch import Tensor from torchmetrics.utilities.data import select_topk, to_onehot from torchmetrics.utilities.enums import DataType From 89270470522ed8199c09efa44ac6b39168decb8a Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Thu, 19 May 2022 14:09:55 +0200 Subject: [PATCH 17/20] again --- torchmetrics/utilities/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 75c6c59dedb..936914e6a95 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -623,7 +623,7 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool: def check_forward_no_full_state( - metric_class: Any, + metric_class, # type: ignore init_args: Dict[str, Any] = {}, input_args: Dict[str, Any] = {}, num_update_to_compare: Sequence[int] = [10, 100, 1000], From 53b6bae1ec537863810095d54f4fa560e349e59f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 23 May 2022 11:41:44 +0200 Subject: [PATCH 18/20] Apply suggestions from code review --- torchmetrics/metric.py | 12 ++++++------ torchmetrics/utilities/checks.py | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index b57daf7f2a3..ec857708e97 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -129,14 +129,14 @@ def __init__( if self.full_state_update is None: rank_zero_warn( - f"""Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has - not been set for this class ({self.__class__.__name__}). The property determines if ``update`` by + f"""Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has + not been set for this class ({self.__class__.__name__}). The property determines if `update` by default needs access to the full metric state. If this is not the case, significant speedups can be - achieved and we recommend setting this to ``False``. + achieved and we recommend setting this to `False`. We provide an checking function - ``from torchmetrics.utilities import check_forward_no_full_state`` - that can be used to check if the ``full_state_update=True`` (old and potential slower behaviour, - default for now) or if ``full_state_update=False`` can be used safely. + `from torchmetrics.utilities import check_forward_no_full_state` + that can be used to check if the `full_state_update=True` (old and potential slower behaviour, + default for now) or if `full_state_update=False` can be used safely. """, UserWarning, ) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 936914e6a95..3004fd8e3d2 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -618,8 +618,7 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool: return all(_allclose_recursive(r1, r2) for r1, r2 in zip(res1, res2)) elif isinstance(res1, Mapping): return all(_allclose_recursive(res1[k], res2[k]) for k in res1.keys()) - else: - return res1 == res2 + return res1 == res2 def check_forward_no_full_state( From b27aaac8361c249a3a6b401bb2e9ef7db628db0f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 24 May 2022 11:00:48 +0200 Subject: [PATCH 19/20] fix tests --- tests/bases/test_metric.py | 2 +- torchmetrics/utilities/checks.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 87997cd04a0..5b2c866f9f2 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -418,6 +418,6 @@ class UnsetProperty(metric_class): with pytest.warns( UserWarning, - match=r"Torchmetrics v0.9 introduced a new argument class property called ``full_state_update`` that has.*", + match="Torchmetrics v0.9 introduced a new argument class property called.*", ): UnsetProperty() diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 3004fd8e3d2..85ed9a1f0f8 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from time import perf_counter -from typing import Any, Dict, Mapping, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, Tuple import torch from torch import Tensor @@ -20,6 +20,9 @@ from torchmetrics.utilities.data import select_topk, to_onehot from torchmetrics.utilities.enums import DataType +if TYPE_CHECKING: + from torchmetrics import Metric + def _check_for_empty_tensors(preds: Tensor, target: Tensor) -> bool: if preds.numel() == target.numel() == 0: @@ -622,7 +625,7 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool: def check_forward_no_full_state( - metric_class, # type: ignore + metric_class: Metric, init_args: Dict[str, Any] = {}, input_args: Dict[str, Any] = {}, num_update_to_compare: Sequence[int] = [10, 100, 1000], From 460aa1f0c116c3953c01aebcccec6fdd5898596f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 24 May 2022 11:10:52 +0200 Subject: [PATCH 20/20] fix --- torchmetrics/utilities/checks.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 85ed9a1f0f8..16cb4a267c0 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from time import perf_counter -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, Tuple +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, no_type_check import torch from torch import Tensor @@ -20,9 +20,6 @@ from torchmetrics.utilities.data import select_topk, to_onehot from torchmetrics.utilities.enums import DataType -if TYPE_CHECKING: - from torchmetrics import Metric - def _check_for_empty_tensors(preds: Tensor, target: Tensor) -> bool: if preds.numel() == target.numel() == 0: @@ -624,8 +621,9 @@ def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool: return res1 == res2 +@no_type_check def check_forward_no_full_state( - metric_class: Metric, + metric_class, init_args: Dict[str, Any] = {}, input_args: Dict[str, Any] = {}, num_update_to_compare: Sequence[int] = [10, 100, 1000],